# Source code for proteinshake.tasks.protein_protein_interface

import numpy as np
from sklearn import metrics

from proteinshake.datasets import ProteinProteinInterfaceDataset
from proteinshake.tasks import Task

class ProteinProteinInterfaceTask(Task):
    """ Identify the binding interface of a protein-protein complex. Protein function is driven in large part by binding events between different protein chains to form 'complexes'. Understanding how proteins interact with each other has implications in unraveling complex biological mechanisms, and designing proteins with desirable interactions. The underlying data is taken from the PDBBind database. All pairs of residues that belong to different protein chains and lie within 6 Angstrom of each other (Townshend et al., 2019) are labeled as positive examples.

    .. admonition:: Task Summary

        * **Input:** two protein chains
        * **Output:** binary label for each residue in both chains (1 if the residue belongs to the interface, 0 otherwise)
        * **Evaluation:** AUROC (*Fout, Alex, et al. "Protein interface prediction using graph convolutional networks." Advances in Neural Information Processing Systems 30 (2017)*)


    """

    DatasetClass = ProteinProteinInterfaceDataset

    type = 'Binary Classification'
    input = 'Protein and Protein'
    output = 'Protein Binding Interface Residues'

    @property
    def num_classes(self):
        """Number of label classes (binary: 2)."""
        return 2

    @property
    def task_in(self):
        """Input granularity for each of the two partners."""
        return ('residue', 'residue')

    @property
    def task_type(self):
        """Pair-level input, binary prediction."""
        return ('residue_pair', 'binary')

    @property
    def task_out(self):
        """Output type. A plain string; ``('binary')`` was never a tuple."""
        return 'binary'

    @property
    def out_dim(self):
        """Output dimension. A plain int; ``(1)`` was never a tuple."""
        return 1

    def dummy_output(self):
        """Return uniformly random 0/1 labels, one array per test target.

        Each array has the same shape as the corresponding entry of
        ``self.test_targets``.
        """
        # randint(0, 2) already yields 0/1, so no remapping is needed.
        return [np.random.randint(0, 2, p.shape) for p in self.test_targets]

    def update_index(self):
        """ Transform to pairwise indexing.

        Replaces each split's index with the array of ``(i, j)`` chain pairs
        produced by :meth:`compute_pairs`.
        """
        self.train_index = self.compute_pairs(self.train_index)
        self.val_index = self.compute_pairs(self.val_index)
        self.test_index = self.compute_pairs(self.test_index)

    def compute_targets(self):
        """Compute the targets of every split from its pairwise index.

        For each ``(i, j)`` in a split's index, ``self.target`` is called with
        ``self.proteins[i]`` and ``self.proteins[j]``.
        """
        def pair_targets(index):
            return [self.target(self.proteins[i], self.proteins[j]) for i, j in index]

        self.train_targets = pair_targets(self.train_index)
        self.val_targets = pair_targets(self.val_index)
        self.test_targets = pair_targets(self.test_index)

    def compute_pairs(self, index):
        """ Grab all pairs of chains that share an interface.

        Args:
            index: positions (into ``self.dataset.proteins()``) of the chains
                belonging to one split; any array-like of integers.

        Returns:
            Integer array of shape ``(n_pairs, 2)``. Each row ``(i, j)`` holds
            two dataset positions that are both in ``index`` and whose chains
            are listed as partners in ``self.dataset._interfaces``. The shape
            is ``(0, 2)`` when no pair is found.
        """
        # Build the membership set once: ``x in ndarray`` scans the whole array.
        members = set(np.asarray(index).ravel().tolist())

        # Iterate the dataset a single time; only the IDs are needed below.
        ids = [p['protein']['ID'] for p in self.dataset.proteins()]
        protein_to_index = {pid: i for i, pid in enumerate(ids)}

        chain_pairs = []
        for i, pid in enumerate(ids):
            if i not in members:
                continue
            pdbid, chain = pid.split('_')
            try:
                partners = self.dataset._interfaces[pdbid][chain]
            except (KeyError, IndexError):
                # chain is not in any interface: nothing to pair it with
                continue
            for partner in partners:
                j = protein_to_index.get(f'{pdbid}_{partner}')
                # A partner chain missing from the dataset (or outside this
                # split) drops only that pair, not the chain's other partners.
                if j is not None and j in members:
                    chain_pairs.append((i, j))

        return np.array(chain_pairs, dtype=int).reshape(-1, 2)