Working with tasks

Note

Make sure you understand how to work with datasets first.

Of course, at some point, we would like to use the datasets to train and evaluate a model. ProteinShake provides the Task classes, which extend the datasets with data splits and metrics. They work very similarly to a Dataset in that they store a set of proteins with annotations, with some additional functionality such as splits and evaluation methods:

from proteinshake.tasks import EnzymeClassTask
task = EnzymeClassTask(split='sequence').to_voxel().torch()

Hint

You can change the split argument to retrieve 'random', 'sequence', or 'structure'-based splits. The latter two are based on sequence/structure similarity, which we pre-compute for you. The split type determines how hard it is for the model to generalize to the test set.

The split_similarity_threshold argument controls the maximum similarity between the train and test splits (see the example below). It can be any of 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, or 0.9 for split='sequence', and 0.5, 0.6, 0.7, 0.8, or 0.9 for split='structure'.

If you want more control over the similarity threshold, you can pre-process the dataset yourself. Have a look at the Release Repository.
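
For example, a task with a structure-based split and a stricter similarity threshold could be constructed as follows (the threshold value here is just an illustration):

from proteinshake.tasks import EnzymeClassTask
task = EnzymeClassTask(split='structure', split_similarity_threshold=0.7).to_voxel().torch()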

The task has a few attributes and methods that are specific to model training and evaluation. Let’s look at our prediction targets.

print(task.test_targets)

We can retrieve the train, test, and validation splits and pass them to a dataloader.

Note

ProteinShake is directly compatible with any dataloader from the supported frameworks. The usage may differ slightly. Check the Quickstart to see the differences.

from torch.utils.data import DataLoader
train, test = DataLoader(task.train), DataLoader(task.test)
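
As a minimal sketch of how these loaders are used (the batch size is arbitrary, the loop body is a placeholder, and the exact contents of each batch depend on the chosen representation and framework):

train_loader = DataLoader(task.train, batch_size=32, shuffle=True)
for batch in train_loader:
    # forward pass, loss computation, and optimizer step go here
    pass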

The task classes also implement appropriate metrics and serve as evaluators.

Tip

Every task implements a dummy_output() method that you can use for testing when you don't have model predictions at hand. It returns random values with the correct shape and type for the task.

my_model_predictions = task.dummy_output()
metrics = task.evaluate(task.test_targets, my_model_predictions)

This will return a dictionary of various relevant metrics. Each task has a default metric, which we use to rank models in the Leaderboard. Feel free to submit your scores!

print(metrics[task.default_metric])