Source code for proteinshake.tasks.enzyme_class
from sklearn import metrics
from functools import cached_property
import numpy as np
from proteinshake.datasets import EnzymeCommissionDataset
from proteinshake.tasks import Task
class EnzymeClassTask(Task):
    """ Predict the type of reaction catalyzed by the given protein as given by the Enzyme Commission database.

    The Enzyme Commission classification is hierarchically organized, giving rise to one prediction
    task per level in the hierarchy. We default to the top-most level, which specifies the generic
    class of the enzyme, but this can be changed by setting ``ec_level`` when instantiating the task.
    This is a protein-level multi-class prediction.

    .. admonition:: Task Summary

        * **Input:** one protein
        * **Output:** enzyme class label (7 classes)
        * **Evaluation:** Accuracy (Ryu, Jae Yong, Hyun Uk Kim, and Sang Yup Lee. "Deep learning enables high-quality and high-throughput prediction of enzyme commission numbers." Proceedings of the National Academy of Sciences 116.28 (2019): 13996-14001.)

    Parameters
    ----------
    ec_level: int, default 0
        Zero-based index into the dot-separated EC number (``'a.b.c.d'``) selecting the
        hierarchy level used as the label. ``0`` is the top-most level.
    *args, **kwargs
        Forwarded unchanged to the parent ``Task`` constructor.
    """

    DatasetClass = EnzymeCommissionDataset

    # Descriptive metadata for the task.
    type = 'Multiclass Classification'
    input = 'Protein'
    output = 'Enzyme Commission Level 1'

    def __init__(self, ec_level=0, *args, **kwargs):
        # Set before calling the parent constructor so the level is already
        # available if the parent touches `token_map` / `target` during setup.
        self.ec_level = ec_level
        super().__init__(*args, **kwargs)

    @cached_property
    def token_map(self):
        """ Mapping from the EC label string at ``ec_level`` to a contiguous integer id.

        Labels are collected from ``self.proteins`` and sorted (as strings) so that the
        assignment of ids is deterministic. Computed once and cached on the instance.
        """
        labels = {p['protein']['EC'].split(".")[self.ec_level] for p in self.proteins}
        return {label: i for i, label in enumerate(sorted(labels))}

    def dummy_output(self):
        """ Return one uniformly random class id per entry of ``self.test_targets``. """
        import random
        tokens = list(self.token_map.values())
        return [random.choice(tokens) for _ in range(len(self.test_targets))]

    @property
    def num_classes(self):
        """ Number of distinct labels observed at ``ec_level``. """
        return len(self.token_map)

    @property
    def task_type(self):
        """ Tuple ``(input level, output kind)`` describing the task. """
        return ('protein', 'multi_class')

    @property
    def task_in(self):
        """ Input level of the task, as a plain string. """
        return 'protein'

    @property
    def task_out(self):
        """ Output kind of the task, as a plain string. """
        return 'multi_class'

    @property
    def target_dim(self):
        """ Dimension of the target, equal to the number of classes. """
        return len(self.token_map)

    @property
    def num_features(self):
        """ Number of input features per residue. """
        # NOTE(review): presumably the 20 standard amino acid types -- confirm.
        return 20

    def target(self, protein):
        """ Return the integer class id of ``protein`` at the configured ``ec_level``.

        Parameters
        ----------
        protein: dict
            Protein dictionary with a dot-separated EC number under ``protein['protein']['EC']``.

        Returns
        -------
        int
            The id assigned to the protein's label in ``token_map``.
        """
        # Must index with the same level `token_map` was built from; a hard-coded
        # level 0 here looks up the wrong key whenever `ec_level` is not 0.
        return self.token_map[protein['protein']['EC'].split(".")[self.ec_level]]

    @property
    def default_metric(self):
        """ Key of the headline metric in the dictionary returned by ``evaluate``. """
        return 'accuracy'

    def evaluate(self, y_true, y_pred):
        """ Using metrics from https://doi.org/10.1073/pnas.1821905116

        Parameters
        ----------
        y_true: array-like of int
            Ground-truth class ids.
        y_pred: array-like of int
            Predicted class ids (hard labels; both inputs are cast to ``int``).

        Returns
        -------
        dict
            Macro-averaged ``'precision'`` and ``'recall'`` (classes without predictions
            or samples score 0 instead of emitting a warning) and overall ``'accuracy'``.
        """
        y_true = np.array(y_true, dtype=int)
        y_pred = np.array(y_pred, dtype=int)
        return {
            'precision': metrics.precision_score(y_true, y_pred, average='macro', zero_division=0),
            'recall': metrics.recall_score(y_true, y_pred, average='macro', zero_division=0),
            'accuracy': metrics.accuracy_score(y_true, y_pred),
            # NOTE(review): AUROC is left disabled; presumably because multiclass AUROC
            # needs per-class scores rather than the hard labels available here -- confirm.
            #'AUROC': metrics.roc_auc_score(y_true, y_pred, average='macro', multi_class='ovo'),
        }