frameworks
Datasets can be converted natively to various deep learning frameworks.
The base FrameworkDataset is an iterable dataset with a __getitem__ method.
Each framework subclass implements convert_to_framework and, optionally, load_transform to provide the framework-specific conversion.
from proteinshake.datasets import RCSBDataset
dataset = RCSBDataset().to_graph(eps=8).pyg()
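The subclass pattern described above can be sketched without installing ProteinShake. The stand-in base class below is hypothetical and is not ProteinShake's implementation: it only mirrors the documented method names (convert_to_framework, load_transform, __getitem__). It keeps items in memory instead of writing them to a path, and it assumes each item is a (data, protein_dict) pair and that load_transform runs before transform on access.

```python
class SketchFrameworkDataset:
    """Stand-in base class (hypothetical, NOT ProteinShake's implementation)."""

    def __init__(self, data_list, size, transform=None):
        self.size = size
        self.transform = transform
        # Convert every item once, up front. Assumption: the real class
        # writes converted items to `path` rather than holding them in memory.
        self._items = [self.convert_to_framework(item) for item in data_list]

    def convert_to_framework(self, data_item):
        # Each framework subclass must provide its own conversion.
        raise NotImplementedError

    def load_transform(self, data, protein_dict):
        # Optional hook; identity unless a subclass overrides it.
        return data, protein_dict

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        if idx >= self.size:
            raise IndexError(idx)
        data, protein_dict = self.load_transform(*self._items[idx])
        if self.transform is not None:
            data, protein_dict = self.transform(data, protein_dict)
        return data, protein_dict


class ListPointDataset(SketchFrameworkDataset):
    """Hypothetical subclass: stores points as tuples, returns them as lists."""

    def convert_to_framework(self, data_item):
        coords, protein_dict = data_item
        return tuple(tuple(point) for point in coords), protein_dict

    def load_transform(self, data, protein_dict):
        return [list(point) for point in data], protein_dict


# Toy items standing in for the output of a representation.
items = [([[0.0, 1.0, 2.0]], {"ID": "a"}), ([[3.0, 4.0, 5.0]], {"ID": "b"})]
dataset = ListPointDataset((item for item in items), size=len(items))
```

Indexing with dataset[1] returns the second item after load_transform has been applied, and a plain for loop over dataset stops once the index reaches size.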
- class FrameworkDataset(data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2)
Bases: object
Dataset base class for different frameworks.
- Parameters:
data_list (generator) – A generator of objects from a representation.
size (int) – The size of the dataset.
path (str) – Path to save the processed dataset.
transform (function) – A transform function to be applied in the __getitem__ method. Signature: transform(data, protein_dict) -> (data, protein_dict)
pre_transform (function) – A transform function to be applied before writing the data. Signature: transform((data, protein_dict)) -> (data, protein_dict)
pre_filter (function) – A filter function to be applied before writing the data. Signature: pre_filter(data, protein_dict) -> bool
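As an illustration of the transform and pre_filter signatures listed above, here is a minimal sketch of two such functions. The function names and the layout of protein_dict (a "protein" entry holding a "sequence" string) are assumptions made for this example; check the actual dictionary layout of your dataset before relying on them.

```python
def min_length_filter(data, protein_dict, min_length=3):
    """pre_filter-style function: True keeps the protein, False drops it."""
    # Assumed layout: protein_dict["protein"]["sequence"] is a string.
    return len(protein_dict["protein"]["sequence"]) >= min_length


def add_length(data, protein_dict):
    """transform-style function: returns the (data, protein_dict) pair."""
    # Copy the top-level dict so the caller's dictionary is not modified.
    annotated = dict(protein_dict)
    annotated["length"] = len(protein_dict["protein"]["sequence"])
    return data, annotated


# Toy inputs; "placeholder-data" stands in for a framework data object.
example = {"protein": {"ID": "x", "sequence": "MKV"}}
keep = min_length_filter(None, example)
data, out = add_length("placeholder-data", example)
```

Functions of this shape would be passed as the pre_filter and transform arguments when the dataset is constructed.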
- convert_to_framework(data_item)
Converts data_item to a data object of the framework.
- load_transform(data, protein_dict)
Applies a transform after loading, for example if the data has been stored in sparse format and needs to be converted to dense.
- len()
- get()
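The sparse-to-dense case mentioned for load_transform can be sketched in plain Python. The storage layout (a dict with "coords", "values" and "shape" keys) and the (data, protein_dict) return value are assumptions for illustration only; this does not describe how any ProteinShake class actually stores voxel data.

```python
def sparse_to_dense(coords, values, shape, fill=0.0):
    """Scatter values at integer (x, y, z) coordinates into a dense grid."""
    nx, ny, nz = shape
    # Build an nx * ny * nz nested list filled with the background value.
    grid = [[[fill for _ in range(nz)] for _ in range(ny)] for _ in range(nx)]
    for (x, y, z), value in zip(coords, values):
        grid[x][y][z] = value
    return grid


def load_transform(data, protein_dict):
    """Hypothetical load_transform: densify the stored sparse data."""
    dense = sparse_to_dense(data["coords"], data["values"], data["shape"])
    return dense, protein_dict


# Two occupied voxels in a 2 x 2 x 2 grid (assumed sparse layout).
sparse = {"coords": [(0, 0, 0), (1, 1, 1)], "values": [1.0, 2.0], "shape": (2, 2, 2)}
dense, info = load_transform(sparse, {"ID": "x"})
```

Storing only the occupied cells keeps files small, while densifying at load time gives the model a regular grid to work with.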
- class TorchVoxelDataset(data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2)
Bases: FrameworkDataset, Dataset
Voxel dataset for PyTorch.
- load_transform(data, protein_dict)
Applies a transform after loading, for example if the data has been stored in sparse format and needs to be converted to dense.
- class TorchPointDataset(data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2)
Bases: FrameworkDataset, Dataset
Point dataset for PyTorch.
- class TensorflowVoxelDataset(*args, transform=<function TensorflowVoxelDataset.<lambda>>, **kwargs)
Bases: FrameworkDataset
Voxel dataset for TensorFlow.
- load_transform(data, protein_dict)
Applies a transform after loading, for example if the data has been stored in sparse format and needs to be converted to dense.
- class TensorflowPointDataset(*args, transform=<function TensorflowPointDataset.<lambda>>, **kwargs)
Bases: FrameworkDataset
Point dataset for TensorFlow.
- class NumpyVoxelDataset(*args, transform=<function NumpyVoxelDataset.<lambda>>, **kwargs)
Bases: FrameworkDataset
Voxel dataset for NumPy.
- class NumpyPointDataset(*args, transform=<function NumpyPointDataset.<lambda>>, **kwargs)
Bases: FrameworkDataset
Point dataset for NumPy.
- class PygGraphDataset(data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2)
Bases: FrameworkDataset, Dataset
Graph dataset for PyG.
- class DGLGraphDataset(data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2)
Bases: FrameworkDataset, DGLDataset
Graph dataset for Deep Graph Library (DGL).
- class NetworkxGraphDataset(data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2)
Bases: FrameworkDataset
Graph dataset for NetworkX.