Source code for proteinshake.frameworks.dataset
import os
from proteinshake.utils import save, load, fx2str, progressbar, error
class FrameworkDataset():
    """ Dataset base class for different frameworks.

    On construction, every item of ``data_list`` is converted with
    ``convert_to_framework``, optionally filtered and transformed, and written
    to ``{path}/{i}.pkl`` where ``i`` counts only the items that were kept.
    The number of kept items is written to ``{path}/size.pkl`` and a string
    representation of ``pre_transform``/``pre_filter`` to
    ``{path}/transforms.pkl``. If both of these marker files already exist,
    the conversion is skipped and the stored values are reused.

    Parameters
    ----------
    data_list: generator
        A generator of objects from a representation. Each object must expose
        the attributes ``data`` and ``protein_dict``.
    size: int
        The size of the dataset, i.e. the number of items in ``data_list``
        before filtering. Only used as the progress bar total; ``self.size``
        holds the number of items actually stored.
    path: str
        Path to save the processed dataset. Created if it does not exist.
        Delete this folder to force a new conversion, e.g. when the
        underlying data or the pre-processing functions have changed.
    transform: function
        A transform function to be applied in the __getitem__ method. Signature: transform(data, protein_dict) -> (data, protein_dict)
        It is only stored on ``self.transform`` here; it is not invoked by
        the methods of this base class.
    pre_transform: function
        A transform function to be applied before writing the data. Signature: pre_transform(data, protein_dict) -> (data, protein_dict)
    pre_filter: function
        A filter function to be applied before writing the data. Items for which it returns a falsy value are not stored. Signature: pre_filter(data, protein_dict) -> bool
    verbosity: int
        Verbosity level forwarded to the progress bar and to the error helper.
    """

    def __init__(self, data_list, size, path, transform=None, pre_transform=None, pre_filter=None, verbosity=2):
        os.makedirs(path, exist_ok=True)
        self.verbosity = verbosity
        self.path = path
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        transforms_repr = fx2str(pre_transform) + fx2str(pre_filter)
        size_file = f'{path}/size.pkl'
        transforms_file = f'{path}/transforms.pkl'
        # Both marker files are written only after the conversion loop has
        # finished, so their presence signals a complete cache. Checking for
        # the item file `{size-1}.pkl` instead is not reliable: as soon as
        # pre_filter rejects an item, fewer than `size` files are written and
        # that file never exists, which would trigger a full reconversion
        # (and overwrite transforms.pkl) on every instantiation.
        if not (os.path.exists(size_file) and os.path.exists(transforms_file)):
            i = 0  # index of the next stored item; counts only kept items
            for data_item in progressbar(data_list, desc='Converting', total=size, verbosity=self.verbosity):
                data = self.convert_to_framework(data_item)
                protein_dict = data_item.protein_dict
                if self.pre_filter is not None and not self.pre_filter(data, protein_dict):
                    continue
                if self.pre_transform is not None:
                    data, protein_dict = self.pre_transform(data, protein_dict)
                save((data, protein_dict), f'{path}/{i}.pkl')
                i += 1
            save(i, size_file)
            save(transforms_repr, transforms_file)
        self.size = load(size_file)
        original_repr = load(transforms_file)
        # A cache written with different pre-processing functions must not be
        # silently reused with the current ones.
        if original_repr != transforms_repr:
            error(f'The pre_transform and/or pre_filter are not the same as when the dataset was created. If you want to change them, delete the folder at {path}', verbosity=self.verbosity)

    def convert_to_framework(self, data_item):
        """ Converts data_item to a data object of the framework.

        The base implementation returns ``data_item.data`` unchanged.
        """
        return data_item.data

    def load_transform(self, data, protein_dict):
        """ Applies a transform after loading, for example if the data has been stored in sparse format and needs to be converted to dense.

        The base implementation returns ``(data, protein_dict)`` unchanged.
        """
        return data, protein_dict