Source code for proteinshake.tasks.structure_search

import itertools
from collections import defaultdict
from functools import cached_property

from scipy.stats import spearmanr
import numpy as np
from sklearn import metrics
from sklearn.model_selection import train_test_split

from proteinshake.datasets import TMAlignDataset
from proteinshake.tasks import Task

class StructureSearchTask(Task):
    """ Retrieve similar proteins to a query based on structural similarity.
    Evaluation is cast in the setting of recommender systems where we wish
    to retrieve 'relevant' documents from a large pool of documents.
    Here, a protein is a document and the relevant ones are all proteins
    with a minimum similarity to the query protein.


    .. admonition:: Task Summary 

        * **Input:** one protein
        * **Output:** list of similar proteins from dataset 
        * **Evaluation:** precision@k (Aung, Zeyar, and Kian-Lee Tan. "Rapid 3D protein structure database searching using information retrieval techniques." Bioinformatics 20.7 (2004): 1045-1052.)


    """

    DatasetClass = TMAlignDataset
    
    type = 'Retrieval'
    input = 'Protein'
    output = 'Similar Proteins'

    def __init__(self, min_sim=0.8, *args, **kwargs):
        self.min_sim = min_sim
        super().__init__(*args, **kwargs)

    @property
    def task_type(self):
        return ('protein', 'retrieval')


    @property
    def task_in(self):
        return ('protein')

    @property
    def task_out(self):
        return ('retrieval')

    @cached_property
    def targets(self):
        """ Precompute the set of similar proteins for each query """
        targets = {}
        for query in self.proteins:
            targets[query['protein']['ID']] = [c['protein']['ID'] for c in self.proteins if self.dataset.lddt(query['protein']['ID'], c['protein']['ID']) >= self.min_sim]
        return targets

    def target(self, protein):
        """ The target for a protein is a list of proteins deemed 'relevant'
        according to `self.min_sim`.
        """
        return self.targets[protein['protein']['ID']]

    def _precision_at_k(self, y_true, y_pred, k):
        return len(set(y_pred[:k]).intersection(set(y_true))) / len(y_pred)

    def _recall_at_k(self, y_true, y_pred, k):
        return len(set(y_pred[:k]).intersection(set(y_true))) / len(y_true)

    def dummy_output(self):
        import random
        pred = []
        ids = [p['protein']['ID'] for p in self.proteins] 
        for query in self.proteins[self.test_index]:
            targets = self.target(query)
            pred.append(random.sample(ids, len(targets)))
        return pred

    @property
    def default_metric(self):
        return 'precision_at_k'

    def evaluate(self, y_true, y_pred, k=5):
        """ Retrieval metrics.

        Arguments
        -----------
        y_pred:
            List of indices of items (hits) in the dataset for a query.
        """
        results = defaultdict(list)
        for yt, yp in zip(y_true, y_pred):
            results['precision_at_k'].append(self._precision_at_k(yt, yp, k))
            results['recall_at_k'].append(self._recall_at_k(yt, yp, k))

        return {k: np.mean(v) for k, v in results.items()}