Source code for proteinshake.frameworks.pyg

import torch
from torch_geometric.utils import from_scipy_sparse_matrix
from torch_geometric.data import Data, Dataset as PygDataset
from proteinshake.frameworks.dataset import FrameworkDataset


class PygGraphDataset(FrameworkDataset, PygDataset):
    """ Graph dataset for PyG.

    Turns the ``(nodes, adj)`` pair carried by each data item into a
    ``torch_geometric.data.Data`` graph object.
    """

    def convert_to_framework(self, data_item):
        """ Build a PyG ``Data`` object from a single data item.

        Parameters
        ----------
        data_item
            Object whose ``data`` attribute unpacks into ``(nodes, adj)``.
            ``nodes`` is handed to ``torch.from_numpy`` and ``adj`` to
            ``from_scipy_sparse_matrix``, so they are expected to be a
            NumPy array and a SciPy sparse matrix respectively.

        Returns
        -------
        Data
            Graph with ``x`` (node features, dtype inherited from ``nodes``),
            ``edge_index`` (cast to ``long``) and ``edge_attr`` (the sparse
            matrix values cast to ``float`` with a trailing axis added).
        """
        nodes, adj = data_item.data
        # Convert the sparse adjacency once; the helper returns both the
        # edge indices and the matching edge weights in a single call.
        edge_index, edge_weight = from_scipy_sparse_matrix(adj)
        return Data(
            x=torch.from_numpy(nodes),
            edge_index=edge_index.long(),
            # unsqueeze(1) appends a feature axis to the per-edge weights.
            edge_attr=edge_weight.unsqueeze(1).float(),
        )