# Source code for proteinshake.tasks.pfam_task

from sklearn import metrics
from functools import cached_property
import numpy as np

from proteinshake.datasets import ProteinFamilyDataset
from proteinshake.tasks import Task

class ProteinFamilyTask(Task):
    """ Classify a protein structure into its protein family (Pfam), i.e. the group of evolutionarily-related proteins it belongs to. This is a protein-level multi-class prediction.

    .. admonition:: Task Summary

        * **Input:** one protein
        * **Output:** protein family class (5163 classes)
        * **Evaluation:** Accuracy (custom task)

    """

    DatasetClass = ProteinFamilyDataset

    type = 'Multiclass Classification'
    input = 'Protein'
    output = 'Protein Family (Pfam)'

    def __init__(self, *args, **kwargs):
        # No task-specific setup: every argument is forwarded to the base class.
        super().__init__(*args, **kwargs)

    @property
    def num_classes(self) -> int:
        """ Number of distinct family labels known to this task. """
        return len(self.token_map)

    @cached_property
    def token_map(self) -> dict:
        """ Mapping from family label to integer class index.

        Each protein carries a list of Pfam annotations, e.g.
        ``'Pfam': ['Fis1 N-terminal tetratricopeptide repeat (Fis1_TPR_N)',
        'Fis1 C-terminal tetratricopeptide repeat (Fis1_TPR_C)']``;
        only the first entry of that list is used as the label.
        Indices follow the sorted order of the labels. The result is computed
        once per instance and then cached.
        """
        families = sorted({entry['protein']['Pfam'][0] for entry in self.proteins})
        return dict(zip(families, range(len(families))))

    def dummy_output(self) -> list:
        """ Return one uniformly random class index per test target. """
        import random
        candidates = list(self.token_map.values())
        n_samples = len(self.test_targets)
        return [random.choice(candidates) for _ in range(n_samples)]

    @property
    def task_in(self) -> str:
        """ Kind of input the task consumes. """
        # A plain string (not a one-element tuple).
        return 'protein'

    @property
    def task_type(self) -> tuple:
        """ Pair of (input level, prediction kind). """
        return ('protein', 'multi_class')

    @property
    def task_out(self) -> str:
        """ Kind of output the task produces. """
        # A plain string (not a one-element tuple).
        return 'multi_class'

    @property
    def out_dim(self) -> int:
        """ Output dimension: one slot per family label. """
        return len(self.token_map)

    @property
    def num_features(self) -> int:
        """ Number of input features. """
        # NOTE(review): presumably the 20 standard amino acid types -- confirm.
        return 20

    def target(self, protein) -> int:
        """ Return the class index of the first Pfam annotation of ``protein``.

        Raises ``KeyError`` if that label is not in :attr:`token_map`.
        """
        label_to_index = self.token_map
        return label_to_index[protein['protein']['Pfam'][0]]

    @property
    def default_metric(self) -> str:
        """ Key of the headline metric in the dict returned by :meth:`evaluate`. """
        return 'accuracy'

    def evaluate(self, y_true, y_pred) -> dict:
        """ Using metrics from https://doi.org/10.1073/pnas.1821905116

        Both inputs are converted to integer arrays. Returns a dict with
        macro-averaged ``'precision'`` and ``'recall'`` (classes with no
        samples to score count as 0) and the overall ``'accuracy'``.
        """
        truth = np.array(y_true, dtype=int)
        predicted = np.array(y_pred, dtype=int)
        macro = {'average': 'macro', 'zero_division': 0}

        scores = {}
        scores['precision'] = metrics.precision_score(truth, predicted, **macro)
        scores['recall'] = metrics.recall_score(truth, predicted, **macro)
        scores['accuracy'] = metrics.accuracy_score(truth, predicted)
        # NOTE: an 'AUROC' entry (metrics.roc_auc_score, macro average,
        # multi_class='ovo') is currently disabled.
        return scores