Source code for proteinshake.frameworks.dgl

import dgl
import torch
from dgl.data import DGLDataset
from proteinshake.frameworks.dataset import FrameworkDataset


class DGLGraphDataset(FrameworkDataset, DGLDataset):
    """ Graph dataset for Deep Graph Library (DGL).

    Converts each protein data item into a ``dgl.DGLGraph`` built from the
    item's sparse adjacency matrix, with the item's node labels attached
    as node data.
    """

    def convert_to_framework(self, data_item):
        """ Convert one data item into a DGL graph.

        Args:
            data_item: object whose ``data`` attribute unpacks into
                ``(nodes, adj)``. ``adj`` is passed to ``dgl.from_scipy``, so it
                is presumably a SciPy sparse adjacency matrix; ``nodes`` must be
                convertible with ``torch.tensor`` (presumably one label per
                node — confirm against the caller). ``data_item.resolution``
                names the node-data field.

        Returns:
            A ``dgl.DGLGraph`` with the sparse matrix's nonzero values stored
            in ``edata['edge_weight']`` and the node labels, as a long tensor,
            stored in ``ndata[str(data_item.resolution)]``.
        """
        nodes, adj = data_item.data
        # eweight_name copies the nonzero entries of the sparse matrix into
        # edge data; this is kept unconditional to preserve existing output.
        data = dgl.from_scipy(adj, eweight_name='edge_weight')
        # Node labels are attached for every graph. Previously this was gated
        # on ``data_item.weighted_edges``, which silently dropped node
        # features from unweighted graphs even though the two are unrelated.
        data.ndata[str(data_item.resolution)] = torch.tensor(nodes).long()
        return data