# Source code for proteinshake.datasets.dataset

# -*- coding: utf-8 -*-
"""
Base dataset class for protein 3D structures.
"""
import os, gzip, inspect, time, itertools, tarfile, io, requests
import copy
from collections import defaultdict, Counter
from functools import cached_property
import multiprocessing as mp

import pandas as pd
import numpy as np
import freesasa
from biopandas.pdb import PandasPdb
from joblib import Parallel, delayed
from sklearn.neighbors import kneighbors_graph, radius_neighbors_graph
from fastavro import reader as avro_reader

from proteinshake.transforms import IdentityTransform, RandomRotateTransform, CenterTransform
from proteinshake.utils import download_url, save, load, unzip_file, write_avro, Generator, progressbar, warning, error

# Three-letter residue names of the 20 standard amino acids -> one-letter codes.
AA_THREE_TO_ONE = {
    'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
    'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
    'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
    'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
}
# Inverse lookup (one-letter code -> three-letter name); the mapping above is one-to-one.
AA_ONE_TO_THREE = dict(zip(AA_THREE_TO_ONE.values(), AA_THREE_TO_ONE.keys()))

# Maps a release tag to its Zenodo record identifier (used to build the repository download URL).
RELEASES = {
    'latest': '1212262',
}

class Dataset():
    """ Base dataset class.
    Holds the logic for downloading and parsing PDB files.
    If ``use_precomputed=True``, fetches pre-processed data from Zenodo.
    Else, builds the dataset from scratch by executing: :meth:`download()` to fetch structures in PDB format, then :meth:`parse()` is applied to each to extract the relevant info and store it in a protein dictionary which has three outer keys ``'protein'``, ``'residue'``, and ``'atom'``.
    Overriding :meth:`add_protein_attributes` in a subclass lets the user include custom attributes.

    .. note::

        All child classes inherit these attributes and optionally add their own.

    .. list-table:: Annotations
      :widths: 25 35 45
      :header-rows: 1

      * - Attribute
        - Key
        - Sample value
      * - Protein identifier
        - :code:`['protein']['ID']`
        - ``'1JC8'``
      * - Sequence
        - :code:`['protein']['sequence']`
        - ``'MIWGDSGKL...'``
      * - Assigned train/val/test split
        - :code:`['protein']['sequence_split_<CUTOFF>']`, :code:`['protein']['structure_split_<CUTOFF>']`
        - ``'train'``
      * - Residue position on chain
        - :code:`['residue']['residue_number']`
        - ``[1, 2, 3, ...]``
      * - Amino acid type (single letter)
        - :code:`['residue']['residue_type']`
        - :code:`['M', 'I', ...]`
      * - 3D coordinates
        - :code:`[{'residue' | 'atom'}][{'x'|'y'|'z'}]`
        - ``[5.191, ...]``
      * - Solvent accessible surface area
        - :code:`[{'residue'|'atom'}]['SASA']`
        - :code:`[242.031, ...]`
      * - Relative accessible surface area
        - :code:`['residue']['RSA']`
        - :code:`[1.377, ...]`
      * - Atom position
        - :code:`['atom']['atom_number']`
        - :code:`[1, 2, 3, ...]`
      * - Atom type
        - :code:`['atom']['atom_type']`
        - :code:`['N', 'CA', ...]`


    Arguments
    -----------
    root: str, default 'data'
        The data root directory to store both raw and parsed data.
    use_precomputed: bool, default True
        If `True`, will download the processed dataset from the ProteinShake repository (recommended). If `False`, will force to download the raw data from the original sources and process them on your device. You can use this option if you wish to create a custom dataset. Using `False` is compute-intensive, consider increasing `n_jobs`.
    release: str, default 'latest'
        The tag of the dataset release. See https://github.com/BorgwardtLab/proteinshake/releases for all available releases. "latest" (default) is recommended.
    only_single_chain: bool, default False
        If `True`, will only use single-chain proteins.
    check_sequence: bool, default False
        If `True`, will discard proteins whose primary sequence is not identical with the sequence of amino acids in the structure. This can happen if the structure is not complete (e.g. for parts that could not be crystallized).
    n_jobs: int, default 1
        The number of jobs for downloading and parsing files. It is recommended to increase the number of jobs with `use_precomputed=False`.
    minimum_length: int, default 10
        Proteins smaller than minimum_length residues will be skipped.
    maximum_length: int, default 2048
        Proteins larger than maximum_length residues will be skipped.
    exclude_ids: list, default []
        Exclude PDB IDs from the dataset.
    skip_signature_check: bool, default False
        If True, skips the signature check. 
    verbosity: int, default 2
        Verbosity level of output logging. 2: full output, 1: no progress bars, 0: only warnings and errors, -1: only errors, -2: no output.
    """

    additional_files = [] # indicates the additional file names that are to be included in the release
    exlude_args_from_signature = []

    def __init__(self,
            root                           = 'data',
            use_precomputed                = True,
            release                        = 'latest',
            only_single_chain              = False,
            check_sequence                 = False,
            n_jobs                         = 1,
            minimum_length                 = 10,
            maximum_length                 = 2048,
            exclude_ids                    = [],
            skip_signature_check           = False,
            verbosity                      = 2,
            # center                         = True, Put back after submission
            # random_rotate                  = True
            ):
        """ Configure the dataset and, unless precomputed files are used, download and parse the raw structures.

        The arguments are documented in the class docstring.

        Raises
        ------
        KeyError
            If ``release`` is not one of the keys of ``RELEASES``.
        """
        # Fail early with a message naming the bad tag instead of a bare KeyError from the lookup below.
        if release not in RELEASES:
            raise KeyError(f'Unknown release {release!r}. Available releases: {sorted(RELEASES)}')
        self.root = root
        self.repository_url = f'https://sandbox.zenodo.org/record/{RELEASES[release]}/files'
        self.n_jobs = n_jobs
        # self.random_rotate = random_rotate
        # self.center = center
        # Fall back to local processing when nothing is cached under root and the repository
        # does not serve a precomputed file for this dataset.
        if use_precomputed and not self.precomputed_already_downloaded() and not self.precomputed_available():
            warning('Could not find precomputed file in the ProteinShake data repository. Setting use_precomputed to False. The dataset will be processed locally.', verbosity=verbosity)
            use_precomputed = False
        self.use_precomputed = use_precomputed
        self.minimum_length = minimum_length
        self.maximum_length = maximum_length
        self.only_single_chain = only_single_chain
        self.check_sequence = check_sequence
        self.release = release
        # The default ``[]`` is a single list object shared by every call (it must stay ``[]`` in the
        # signature because compute_signature renders declared defaults). Store a copy so a mutation
        # on one instance cannot leak into later ones; copy.copy keeps the caller's container type,
        # so str(self.exclude_ids) -- and therefore the signature string -- is unchanged.
        self.exclude_ids = copy.copy(exclude_ids)
        self.skip_signature_check = skip_signature_check
        self.verbosity = verbosity

        os.makedirs(f'{self.root}', exist_ok=True)
        # NOTE(review): signature checking is currently disabled here.
        #self.check_signature()

        if not use_precomputed:
            self.start_download()
            self.parse()
        else:
            pass  # self.check_signature_same_as_hosted()

    def precomputed_already_downloaded(self):
        return os.path.exists(f'{self.root}/{self.name}.residue.avro') or os.path.exists(f'{self.root}/{self.name}.atom.avro')
    
    def precomputed_available(self):
        """ Check whether the hosted repository serves the precomputed residue-level file for this dataset.

        Returns
        -------
        bool
            ``True`` if a HEAD request to ``<repository_url>/<name>.residue.avro.gz`` ends with HTTP 200.

        Network errors raised by ``requests`` (e.g. the 5 s timeout) are not caught here.
        """
        # ``requests.head`` does not follow redirects unless asked to. Without this, a moved file URL
        # answers 301/302 and is misreported as unavailable, silently forcing local processing.
        response = requests.head(
            f'{self.repository_url}/{self.name}.residue.avro.gz',
            timeout=5,
            allow_redirects=True,
        )
        return response.status_code == 200

    def compute_signature(self, use_defaults=False):
        signature = dict(inspect.signature(self.__init__).parameters.items())
        class_object = self.__class__
        while True: # add base signatures to subclass signature
            signature = {**dict(inspect.signature(class_object.__init__).parameters.items()), **signature}
            if len(class_object.__bases__) == 0: break
            class_object = class_object.__bases__[0]
        arg_names = [n for n in signature.keys() if not n in ['self', 'args', 'kwargs', 'n_jobs', 'root', 'verbosity']+self.exlude_args_from_signature]
        if use_defaults:
            return self.name + ' | ' + ', '.join([k + '=' + str(signature[k].default) for k in arg_names])
        return self.name + ' | ' + ', '.join([k + '=' + str(getattr(self, k)) for k in arg_names])

    @cached_property
    def default_signature(self):
        return self.compute_signature(use_defaults=True)

    @cached_property
    def signature(self):
        return self.compute_signature(use_defaults=False)

    def check_signature(self):
        if self.skip_signature_check: return
        if os.path.exists(f'{self.root}/signature.txt'):
            with open(f'{self.root}/signature.txt','r') as file:
                if not file.read() == self.signature: error('The Dataset is called with different arguments than were used to create it. Delete or change the root.', verbosity=self.verbosity)
        else:
            with open(f'{self.root}/signature.txt','w') as file:
                file.write(self.signature)

    def check_signature_same_as_hosted(self):
        """ Safety check to ensure the provided dataset arguments are the same as were used to precompute the datasets. Only relevant with `use_precomputed=True`.
        """
        if not self.signature == self.default_signature: error('The dataset arguments do not match the precomputed dataset arguments (the default settings). Set use_precomputed to False if you wish to generate a new dataset.', verbosity=self.verbosity)
def proteins(self, resolution='residue'): """ Returns a generator of proteins from the avro file. Parameters ---------- resolution: str, default 'residue' The resolution of the proteins. Can be 'atom' or 'residue'. Returns ------- generator An avro reader object. .. code-block:: python >>> from proteinshake.datasets import RCSBDataset >>> protein = next(RCSBDataset().proteins()) """ self.download_precomputed(resolution=resolution) with open(f'{self.root}/{self.name}.{resolution}.avro', 'rb') as file: total = int(avro_reader(file).metadata['number_of_proteins']) def reader(): with open(f'{self.root}/{self.name}.{resolution}.avro', 'rb') as file: for x in avro_reader(file): yield x return Generator(reader(), total)