# Source code for proteinshake.tasks.structural_class

from sklearn import metrics
from functools import cached_property
import numpy as np

from proteinshake.datasets import SCOPDataset
from proteinshake.tasks import Task

class StructuralClassTask(Task):
    """ Predict the SCOP class of a protein structure.

    SCOP labels proteins according to a hierarchy of structural and
    evolutionary information. By default the task predicts the ``SCOP-FA``
    level. You can customize the task to use a different level by setting
    ``scop_level`` to ``SCOP-{level}``, where level is any of TP=protein type,
    CL=protein class, CF=fold, SF=superfamily, FA=family. This is a
    protein-level multi-class prediction.

    .. admonition:: Task Summary

        * **Input:** one protein
        * **Output:** SCOP class (3042 classes)
        * **Evaluation:** Accuracy (custom task)

    Parameters
    ----------
    scop_level: str, default 'SCOP-FA'
        Key of the protein-level annotation used as the prediction target.
        It is looked up as ``protein['protein'][scop_level]``.
    *args, **kwargs
        Forwarded unchanged to the parent ``Task`` constructor.
    """

    DatasetClass = SCOPDataset

    # Human-readable descriptors of the task.
    type = 'Multiclass Classification'
    input = 'Protein'
    output = 'SCOP Class'

    def __init__(self, scop_level='SCOP-FA', *args, **kwargs):
        # Stored before the parent constructor runs; presumably the parent
        # may already call ``target``/``token_map`` during init -- TODO confirm.
        self.scop_level = scop_level
        super().__init__(*args, **kwargs)

    @property
    def num_classes(self) -> int:
        """ Number of distinct labels found at the selected SCOP level. """
        return len(self.token_map)

    @cached_property
    def token_map(self) -> dict:
        """ Mapping from SCOP label to an integer class index.

        Labels are collected from ``self.proteins`` (expected to be provided
        by the parent class) and indexed in sorted order, so the mapping is
        deterministic for a given set of proteins. The result is computed
        once per instance and cached.
        """
        labels = {p['protein'][self.scop_level] for p in self.proteins}
        return {label: i for i, label in enumerate(sorted(labels))}

    @property
    def target_dim(self) -> int:
        """ Dimension of the prediction target; equal to ``num_classes``. """
        return self.num_classes

    def dummy_output(self):
        """ Return one uniformly random class index per test target.

        Returns
        -------
        list of int
            Random entries of ``token_map.values()``, one for each element of
            ``self.test_targets``.
        """
        import random
        tokens = list(self.token_map.values())
        n_samples = len(self.test_targets)
        return [random.choice(tokens) for _ in range(n_samples)]

    @property
    def task_in(self):
        """ Input kind of the task. """
        # NOTE: parentheses without a trailing comma -- this is the plain
        # string 'protein', not a 1-tuple. Kept as-is for compatibility.
        return ('protein')

    @property
    def task_type(self):
        """ Pair of (prediction level, problem type). """
        return ('protein', 'multi-class')

    @property
    def task_out(self):
        """ Output kind of the task. """
        # NOTE: plain string 'multi_class', not a 1-tuple (see ``task_in``).
        return ('multi_class')

    @property
    def num_features(self) -> int:
        """ Number of input features; fixed to 20 for this task. """
        return 20

    def target(self, protein):
        """ Return the integer class index of ``protein``.

        Raises
        ------
        KeyError
            If the protein lacks the ``scop_level`` annotation or its label
            is not part of ``token_map``.
        """
        return self.token_map[protein['protein'][self.scop_level]]

    @property
    def default_metric(self) -> str:
        """ Name of the key in ``evaluate``'s result used as main metric. """
        return 'accuracy'

    def evaluate(self, y_true, y_pred) -> dict:
        """ Using metrics from https://doi.org/10.1073/pnas.1821905116

        Parameters
        ----------
        y_true: array-like of int
            Ground-truth class indices.
        y_pred: array-like of int
            Predicted class indices (hard labels, same length as ``y_true``).

        Returns
        -------
        dict
            Macro-averaged ``'precision'`` and ``'recall'``, and ``'accuracy'``.
        """
        y_true = np.array(y_true, dtype=int)
        y_pred = np.array(y_pred, dtype=int)
        # zero_division=0: a class with no predicted (precision) or no true
        # (recall) samples contributes 0 to the macro average instead of
        # triggering an undefined-metric warning.
        # NOTE: AUROC is intentionally not reported; multi-class
        # ``roc_auc_score`` needs per-class scores, not hard label predictions.
        return {
            'precision': metrics.precision_score(y_true, y_pred, average='macro', zero_division=0),
            'recall': metrics.recall_score(y_true, y_pred, average='macro', zero_division=0),
            'accuracy': metrics.accuracy_score(y_true, y_pred),
        }