Source code for proteinshake.frameworks.tf
import tensorflow as tf
from proteinshake.frameworks.dataset import FrameworkDataset
class TensorflowVoxelDataset(FrameworkDataset):
""" Voxel dataset for TensorFlow.
"""
def __init__(self, *args, transform=lambda x:x[0], **kwargs):
super().__init__(*args, transform=transform, **kwargs)
def convert_to_framework(self, data_item):
return tf.sparse.from_dense(tf.convert_to_tensor(data_item.data, dtype=tf.float32))
def load_transform(self, data, protein_dict):
return tf.sparse.to_dense(data), protein_dict