# Source code for proteinshake.tasks.gene_ontology

import itertools
import numpy as np
from sklearn import metrics
from functools import cached_property

from proteinshake.datasets import GeneOntologyDataset
from proteinshake.tasks import Task

class GeneOntologyTask(Task):
    """ Predict the Gene Ontology terms describing the functional roles of a given protein in the cell. This is a protein-level multi-label prediction.

    The prediction should be an n_samples x n_classes matrix, where the columns are ordered according to `self.classes`.
    If your model does not predict or handle a certain class, assign a zero value.

    .. admonition:: Task Summary

        * **Input:** one protein
        * **Output:** n_classes gene ontology terms
        * **Evaluation:** Fmax (Radivojac, Predrag, et al. "A large-scale evaluation of computational protein function prediction." Nature Methods 10.3 (2013): 221-227.)

    :param branch: Key under ``protein['protein']`` whose list of GO terms is used as the label set. Defaults to ``'molecular_function'``.
    """

    DatasetClass = GeneOntologyDataset

    type = 'Multilabel Classification'
    input = 'Protein'
    output = 'Gene Ontology Terms'

    # Decision thresholds scanned by fmax/smin: 0.00, 0.05, ..., 1.00 (21 values).
    _THRESHOLDS = np.linspace(0, 1, 21)

    def __init__(self, branch='molecular_function', *args, **kwargs):
        # Set before the base initializer runs, in case it already needs the
        # label vocabulary (token_map reads self.branch).
        self.branch = branch
        super().__init__(*args, **kwargs)

    @cached_property
    def token_map(self):
        """Map each GO term of the selected branch to its column index.

        The vocabulary is the union of the terms of all proteins in ``self.proteins``,
        sorted so that the column order is deterministic. Computed once per instance.
        """
        labels = set(itertools.chain.from_iterable(p['protein'][self.branch] for p in self.proteins))
        return {label: i for i, label in enumerate(sorted(labels))}

    @property
    def num_classes(self):
        """Number of distinct GO terms (columns of the prediction matrix)."""
        return len(self.token_map)

    @property
    def classes(self):
        """GO terms in column order of the prediction matrix."""
        return list(self.token_map)

    @property
    def task_in(self):
        # A plain string (not a 1-tuple).
        return 'protein'

    @property
    def task_type(self):
        return ('protein', 'multi_label')

    @property
    def task_out(self):
        # A plain string (not a 1-tuple).
        return 'multi_label'

    @property
    def target_dim(self):
        """Length of the target vector; equal to ``num_classes``."""
        return self.num_classes

    @property
    def num_features(self):
        # NOTE(review): presumably the 20 standard amino acids — confirm.
        return 20

    def target(self, protein):
        """Return the multi-hot label vector of ``protein`` for the selected branch.

        :returns: boolean array of length ``num_classes`` that is True at the column
            of every GO term annotated for the protein.
        :raises KeyError: if the protein carries a term absent from ``token_map``.
        """
        tokens = [self.token_map[term] for term in protein['protein'][self.branch]]
        target = np.zeros(self.num_classes, dtype=bool)
        target[tokens] = True
        return target

    def precision(self, y_true, y_pred, threshold):
        """Average precision over proteins with at least one score >= ``threshold``.

        Proteins with no score above the threshold are excluded from the average
        (the ``m(t)`` normalisation of the Fmax definition). Returns 0.0 when no
        protein has any prediction at this threshold.

        :param y_true: (n_samples, n_classes) binary ground-truth matrix.
        :param y_pred: (n_samples, n_classes) score matrix.
        :param threshold: scores >= threshold count as positive predictions.
        """
        y_bin = y_pred >= threshold
        # any() instead of max(): also valid when there are zero classes.
        mt = y_bin.any(axis=1).sum()
        if mt == 0:
            return 0.0
        nom = np.logical_and(y_true, y_bin).sum(axis=1).astype(np.float32)
        denom = y_bin.sum(axis=1).astype(np.float32)
        # Rows without predictions have denom == 0 and contribute 0 to the sum.
        return np.divide(nom, denom, out=np.zeros_like(nom), where=denom != 0).sum() / mt

    def recall(self, y_true, y_pred, threshold):
        """Average recall over all proteins at ``threshold``.

        Proteins with no true terms contribute 0 recall but still count in the average.
        Returns 0.0 for an empty input.
        """
        # NOTE(review): ne counts every protein, including those without any true term
        # in this branch; CAFA benchmarks only contain annotated proteins — confirm intended.
        ne = y_true.shape[0]
        if ne == 0:
            return 0.0
        y_bin = y_pred >= threshold
        nom = np.logical_and(y_true, y_bin).sum(axis=1).astype(np.float32)
        denom = y_true.sum(axis=1).astype(np.float32)
        return np.divide(nom, denom, out=np.zeros_like(nom), where=denom != 0).sum() / ne

    def remaining_uncertainty(self, y_true, y_pred, threshold):
        """Remaining uncertainty at ``threshold`` (needed for Smin). Not implemented."""
        raise NotImplementedError('remaining_uncertainty (required for Smin) is not implemented.')

    def missing_information(self, y_true, y_pred, threshold):
        """Missing information at ``threshold`` (needed for Smin). Not implemented."""
        raise NotImplementedError('missing_information (required for Smin) is not implemented.')

    def fmax(self, y_true, y_pred):
        """Maximum F1 score over the decision thresholds in ``_THRESHOLDS``.

        Thresholds where precision and recall are both zero are skipped. Returns
        0.0 if no threshold yields a positive F1.
        """
        best = 0.0
        for t in self._THRESHOLDS:
            prec = self.precision(y_true, y_pred, t)
            rec = self.recall(y_true, y_pred, t)
            if prec + rec == 0:
                continue
            best = max(best, (2 * prec * rec) / (prec + rec))
        return best

    def smin(self, y_true, y_pred):
        """Minimum semantic distance over the thresholds in ``_THRESHOLDS``.

        :raises NotImplementedError: until ``remaining_uncertainty`` and
            ``missing_information`` are implemented.
        """
        return min(
            np.sqrt(
                self.remaining_uncertainty(y_true, y_pred, t) ** 2
                + self.missing_information(y_true, y_pred, t) ** 2
            )
            for t in self._THRESHOLDS
        )

    def dummy_output(self):
        """Random scores in [0, 1) of shape (len(test_index), num_classes)."""
        return np.random.rand(len(self.test_index), self.num_classes)

    @property
    def default_metric(self):
        return 'Fmax'

    def evaluate(self, y_true, y_pred):
        """Compute the task metrics.

        :param y_true: ground-truth multi-hot matrix, n_samples x n_classes.
        :param y_pred: predicted scores, n_samples x n_classes, columns ordered as ``self.classes``.
        :returns: dict with key ``'Fmax'``.
        :raises ValueError: if ``y_true`` and ``y_pred`` do not have the same shape.
        """
        y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
        # Prevent silent broadcasting of mismatched matrices into wrong metrics.
        if y_true.shape != y_pred.shape:
            raise ValueError(
                f'y_true and y_pred must have the same shape, got {y_true.shape} and {y_pred.shape}.'
            )
        # Smin is not reported: remaining_uncertainty/missing_information are not implemented.
        return {
            'Fmax': self.fmax(y_true, y_pred),
        }