Working with tasks


Make sure you understand how to work with datasets first.

Of course, at some point, we would like to use the datasets to train and evaluate a model. ProteinShake provides the Task classes, which extend the datasets with data splits and metrics. They work very similarly to a Dataset in that they store a set of proteins with annotations, but add functionality such as splits and evaluation methods:

from proteinshake.tasks import EnzymeClassTask
# sequence-split enzyme class task, voxelized, served as a PyTorch dataset
task = EnzymeClassTask(split='sequence').to_voxel().torch()


You can set the split argument to retrieve 'random', 'sequence', or 'structure'-based splits. The latter two are based on sequence/structure similarity, which we pre-compute for you. The split type determines how hard it is for the model to generalize to the test set.

The split_similarity_threshold argument controls the maximum similarity between the train and test sets. It can be any of 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, or 0.9 for split='sequence', and 0.5, 0.6, 0.7, 0.8, or 0.9 for split='structure'.
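For example, a sequence-based split where train and test proteins share at most 70% similarity (a minimal sketch, reusing the EnzymeClassTask from above):

from proteinshake.tasks import EnzymeClassTask
# sequence split with a 0.7 similarity cutoff between train and test
task = EnzymeClassTask(split='sequence', split_similarity_threshold=0.7).to_voxel().torch()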

If you want more control over the similarity threshold, you can pre-process the dataset yourself. Have a look at the Release Repository.

The task has a few attributes and methods that are specific to model training and evaluation. Let’s look at our prediction targets.
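For example, the ground-truth labels of the test split are exposed as test_targets (also used for evaluation below); assuming the targets are array-like, a quick look:

# ground-truth enzyme classes of the test split
print(task.test_targets[:5])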


We can retrieve the train, test, and validation splits to put them into a dataloader.


ProteinShake is directly compatible with any dataloader from the supported frameworks. Usage may differ slightly between them; check the Quickstart to see the differences.

# PyTorch DataLoader, matching the .torch() conversion above
from torch.utils.data import DataLoader
train, test = DataLoader(task.train), DataLoader(task.test)
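If you also need the validation split, the same pattern applies; here we assume it is exposed as task.val, analogous to task.train and task.test:

# assumed attribute name for the validation split, mirroring train/test
val = DataLoader(task.val)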

The task classes also implement appropriate metrics and serve as evaluators.


Every task implements a dummy_output() method you can use for testing if you don’t have model predictions at hand. This method will return random values with the correct shape and type for the task.

my_model_predictions = task.dummy_output()
metrics = task.evaluate(task.test_targets, my_model_predictions)

This returns a dictionary of various relevant metrics. Each task has a default metric, which we use to rank models in the Leaderboard. Feel free to submit your scores!
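For example, you can iterate over the returned dictionary to inspect each score (the exact metric names depend on the task):

# print each metric and its score
for name, score in metrics.items():
    print(name, score)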