# Source code for proteinshake.frameworks.torch
import torch
from torch.utils.data import Dataset as TorchDataset
from proteinshake.frameworks.dataset import FrameworkDataset
class TorchVoxelDataset(FrameworkDataset, TorchDataset):
    """ Voxel dataset for PyTorch.

    Provides the two conversion hooks below: one turns an item's ``data``
    into a sparse float tensor, the other turns a tensor back into its
    dense form.
    """

    def convert_to_framework(self, data_item):
        """ Convert an item into a sparse float tensor.

        Args:
            data_item: Object exposing a ``data`` attribute that
                ``torch.tensor`` can consume (presumably the voxel grid
                -- TODO confirm against the caller).

        Returns:
            A sparse ``torch.Tensor`` holding ``data_item.data`` cast to float.
        """
        # Built dense first, then converted to the sparse layout.
        return torch.tensor(data_item.data).float().to_sparse()

    def load_transform(self, data, protein_dict):
        """ Densify a tensor and pass the protein dictionary through.

        Args:
            data: Tensor supporting ``to_dense()`` (presumably the sparse
                tensor produced by ``convert_to_framework`` -- TODO confirm).
            protein_dict: Returned unchanged.

        Returns:
            Tuple ``(dense_tensor, protein_dict)``.
        """
        return data.to_dense(), protein_dict