frameworks

Datasets can be converted natively to several deep learning frameworks. The base class FrameworkDataset is an iterable dataset that also supports indexing via __getitem__. Each framework subclass implements convert_to_framework and, optionally, load_transform to perform the framework-specific conversion.

 from proteinshake.datasets import RCSBDataset
 # Build an epsilon-neighborhood graph (eps=8), then convert it to a PyTorch Geometric dataset
 dataset = RCSBDataset().to_graph(eps=8).pyg()
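The subclassing pattern described above can be illustrated with a minimal, standard-library-only sketch. `SketchFrameworkDataset` and `ListDataset` are hypothetical names, not part of proteinshake. The sketch keeps converted items in memory, whereas the real FrameworkDataset takes a `path` for the processed data plus further options, so read it only as a picture of where convert_to_framework, load_transform and transform fit.

```python
class SketchFrameworkDataset:
    """Hypothetical stand-in for FrameworkDataset; stores items in memory, not on disk."""

    def __init__(self, data_list, size, transform=None):
        self.size = size
        self.transform = transform
        # Convert every (data, protein_dict) pair once, up front.
        self.items = [(self.convert_to_framework(data), protein_dict)
                      for data, protein_dict in data_list]

    def convert_to_framework(self, data_item):
        # Each framework subclass must provide its own conversion.
        raise NotImplementedError

    def load_transform(self, data, protein_dict):
        # Default: no post-loading conversion.
        return data, protein_dict

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Post-loading conversion first, then the user transform.
        data, protein_dict = self.load_transform(*self.items[idx])
        if self.transform is not None:
            data, protein_dict = self.transform(data, protein_dict)
        return data, protein_dict


class ListDataset(SketchFrameworkDataset):
    """Hypothetical 'framework' whose native data object is a plain list."""

    def convert_to_framework(self, data_item):
        return list(data_item)


pairs = [((1, 2, 3), {"name": "A"}), ((4, 5), {"name": "B"})]
ds = ListDataset(pairs, size=2)
print(ds[1])  # ([4, 5], {'name': 'B'})
```

Because `__getitem__` raises IndexError past the last item, plain iteration (`for data, protein_dict in ds`) also works in this sketch through Python's sequence protocol.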
class FrameworkDataset(data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2)

Bases: object

Dataset base class for different frameworks.

Parameters:
  • data_list (generator) – A generator of objects from a representation.

  • size (int) – The size of the dataset.

  • path (str) – Path to save the processed dataset.

  • transform (function) – A transform function to be applied in the __getitem__ method. Signature: transform(data, protein_dict) -> (data, protein_dict)

  • pre_transform (function) – A transform function to be applied before writing the data. Signature: pre_transform((data, protein_dict)) -> (data, protein_dict); note that it receives a single (data, protein_dict) tuple.

  • pre_filter (function) – A filter function to be applied before writing the data. Signature: pre_filter(data, protein_dict) -> bool
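The three callables follow the signatures listed above. The sketch below shows one function of each kind on toy data: `data` is a list of lowercase residue letters and `protein_dict` a plain dict, both hypothetical stand-ins for what a real representation produces. The order in which the driver applies pre_filter and pre_transform is an assumption made for illustration, not a statement about the library's internals.

```python
def pre_filter(data, protein_dict):
    """Keep only proteins with at least 3 residues; must return a bool."""
    return len(data) >= 3


def pre_transform(item):
    """Receives a single (data, protein_dict) tuple, as documented."""
    data, protein_dict = item
    # Annotate the metadata once, before storage.
    protein_dict = {**protein_dict, "length": len(data)}
    return data, protein_dict


def transform(data, protein_dict):
    """Applied on every access: here, upper-case the residue letters."""
    return [r.upper() for r in data], protein_dict


raw = [(["a", "c", "d"], {"name": "P1"}), (["g"], {"name": "P2"})]

# Assumed ordering (illustrative only): filter, then pre_transform, then store.
stored = [pre_transform((d, p)) for d, p in raw if pre_filter(d, p)]
# transform runs on each retrieval.
loaded = [transform(d, p) for d, p in stored]
print(loaded)  # [(['A', 'C', 'D'], {'name': 'P1', 'length': 3})]
```

Work that only needs to happen once (filtering, expensive annotation) fits pre_filter and pre_transform; anything that should vary per access, such as random augmentation, fits transform.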

convert_to_framework(data_item)

Converts data_item to the framework's native data object.

load_transform(data, protein_dict)

Applies a transform after loading, for example if the data has been stored in sparse format and needs to be converted to dense.
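As an illustration of the sparse-to-dense case mentioned above, here is a hypothetical load_transform written as a standalone function. The storage format (a dict with "shape", "indices" and "values") is an assumption chosen for the example; the actual on-disk format used by the voxel datasets is not documented here.

```python
def load_transform(data, protein_dict):
    """Hypothetical: expand a sparsely stored 3D voxel grid into a dense nested list."""
    nx, ny, nz = data["shape"]
    # Allocate an all-zero dense grid of the stored shape (no aliased rows).
    dense = [[[0.0] * nz for _ in range(ny)] for _ in range(nx)]
    # Scatter the stored non-zero entries back into place.
    for (x, y, z), v in zip(data["indices"], data["values"]):
        dense[x][y][z] = v
    return dense, protein_dict


sparse = {"shape": (2, 2, 2), "indices": [(0, 0, 1), (1, 1, 0)], "values": [1.0, 2.5]}
grid, meta = load_transform(sparse, {"name": "P1"})
print(grid[0][0][1], grid[1][1][0], grid[1][0][0])  # 1.0 2.5 0.0
```

Storing only the occupied voxels keeps the written files small; the cost is this expansion step each time an item is loaded.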

len()

get()

class TorchVoxelDataset(data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2)

Bases: FrameworkDataset, Dataset

Voxel dataset for PyTorch.

load_transform(data, protein_dict)

Applies a transform after loading, for example if the data has been stored in sparse format and needs to be converted to dense.

class TorchPointDataset(data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2)

Bases: FrameworkDataset, Dataset

Point dataset for PyTorch.

class TensorflowVoxelDataset(*args, transform=<function TensorflowVoxelDataset.<lambda>>, **kwargs)

Bases: FrameworkDataset

Voxel dataset for TensorFlow.

load_transform(data, protein_dict)

Applies a transform after loading, for example if the data has been stored in sparse format and needs to be converted to dense.

class TensorflowPointDataset(*args, transform=<function TensorflowPointDataset.<lambda>>, **kwargs)

Bases: FrameworkDataset

Point dataset for TensorFlow.

class NumpyVoxelDataset(*args, transform=<function NumpyVoxelDataset.<lambda>>, **kwargs)

Bases: FrameworkDataset

Voxel dataset for NumPy.

class NumpyPointDataset(*args, transform=<function NumpyPointDataset.<lambda>>, **kwargs)

Bases: FrameworkDataset

Point dataset for NumPy.

class PygGraphDataset(data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2)

Bases: FrameworkDataset, Dataset

Graph dataset for PyG.

class DGLGraphDataset(data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2)

Bases: FrameworkDataset, DGLDataset

Graph dataset for Deep Graph Library (DGL).

class NetworkxGraphDataset(data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2)

Bases: FrameworkDataset

Graph dataset for NetworkX.