"""TMalign needs to be in your $PATH; follow the TM-align installation instructions to set it up."""
import glob
import requests
import os
import itertools
import re
import subprocess
import tempfile
import shutil
import numpy as np
from biopandas.pdb import PandasPdb
from collections import defaultdict
from joblib import Parallel, delayed
from functools import cached_property

from proteinshake.datasets import RCSBDataset
from proteinshake.utils import (extract_tar,

class TMAlignDataset(RCSBDataset):
    """ Proteins that were aligned with TMalign annotated with distance/similarity metrics.
    The dataset provides the TM-score, RMSD, Global Distance Test (GDT), and Local Distance Difference Test (LDDT).

    .. admonition:: Please cite

      Zhang, Yang, and Jeffrey Skolnick. "TM-align: a protein structure alignment algorithm based on the TM-score." Nucleic acids research 33.7 (2005): 2302-2309.

      Berman, H M et al. “The Protein Data Bank.” Nucleic acids research vol. 28,1 (2000): 235-42. doi:10.1093/nar/28.1.235

    .. admonition:: Source

      Raw data was obtained and modified from `RCSB Protein Data Bank <https://www.rcsb.org/>`_, originally licensed under `CC0 1.0 <https://creativecommons.org/publicdomain/zero/1.0/>`_.

    .. list-table:: Dataset stats
       :widths: 100
       :header-rows: 1

       * - # proteins
       * - 994

    .. code-block:: python

        from proteinshake.datasets import TMAlignDataset

        dataset = TMAlignDataset()
        proteins = dataset.proteins()
        protein_1, protein_2 = next(proteins)['protein']['ID'], next(proteins)['protein']['ID']

        dataset.tm_score(protein_1, protein_2)
        >>> 0.03
        dataset.rmsd(protein_1, protein_2)
        >>> 3.64
        dataset.gdt(protein_1, protein_2)
        >>> 0.61
        dataset.lddt(protein_1, protein_2)
        >>> 0.65

    """

    # NOTE(review): the entries of this list were missing from the reviewed
    # source. They are reconstructed from the four matrix files this class
    # reads and writes, assuming ``self.name`` equals the class name -- confirm
    # against whatever consumes ``additional_files``.
    additional_files = [
        'TMAlignDataset.tmscore.npy',
        'TMAlignDataset.rmsd.npy',
        'TMAlignDataset.gdt.npy',
        'TMAlignDataset.lddt.npy',
    ]

    def __init__(self, **kwargs):
        """Initialise the dataset and load the four pairwise metric matrices.

        Args:
            **kwargs: Forwarded unchanged to the parent initialiser.
        """
        # The attributes read below (``use_precomputed``, ``root``,
        # ``repository_url``) are not set anywhere in this class, so the parent
        # initialiser has to run first.
        # NOTE(review): presumably RCSBDataset provides all of them -- confirm.
        super().__init__(**kwargs)
        if not self.use_precomputed:
            self.align_structures()
        # Row/column order of the matrices follows the order of this list.
        self.protein_ids = [p['protein']['ID'] for p in self.proteins()]

        def download_file(filename):
            """Fetch ``filename`` into ``self.root`` if absent, then load it."""
            local_path = f'{self.root}/{filename}'
            if not os.path.exists(local_path):
                # NOTE(review): the ``.gz`` archive is downloaded but the
                # uncompressed name is loaded -- confirm that ``download_url``
                # (or ``load``) takes care of decompression.
                download_url(f'{self.repository_url}/{filename}.gz', f'{self.root}', verbosity=0)
            return load(local_path)

        self._tm_score = download_file(self._matrix_filename('tmscore'))
        self._rmsd = download_file(self._matrix_filename('rmsd'))
        self._gdt = download_file(self._matrix_filename('gdt'))
        self._lddt = download_file(self._matrix_filename('lddt'))

    def _matrix_filename(self, metric):
        """Return the file name (without directory) of the matrix for ``metric``."""
        # NOTE(review): the prefix expression was empty (``{}``) in the reviewed
        # source; ``self.name`` is assumed to be the intended value -- confirm.
        return f'{self.name}.{metric}.npy'

    def _pair_value(self, matrix, protein_1, protein_2):
        """Return ``matrix[i][j]`` for the positions of the two protein IDs.

        Raises:
            ValueError: If either ID is not contained in ``self.protein_ids``.
        """
        row = self.protein_ids.index(protein_1)
        col = self.protein_ids.index(protein_2)
        return matrix[row][col]

    def limit(self):
        """Return the fixed limit of 1000.

        NOTE(review): presumably read by the parent class to cap the number of
        proteins -- verify against RCSBDataset.
        """
        return 1000

    def align_structures(self):
        """ Calls TMalign on all pairs of structures and saves the output.

        Does nothing when the TM-score matrix file already exists in
        ``self.root``. Otherwise every unordered pair of proteins is aligned
        with ``tmalign_wrapper`` (in parallel, ``self.n_jobs`` workers) and four
        ``float16`` matrices are written to ``self.root``.
        """
        if os.path.exists(f'{self.root}/{self._matrix_filename("tmscore")}'):
            # NOTE(review): the body of this check was missing in the reviewed
            # source; an early return (skip recomputation) is assumed.
            return

        protein_ids = [p['protein']['ID'] for p in self.proteins()]
        path_by_id = {
            self.get_id_from_filename(os.path.basename(path)): path
            for path in self.get_raw_files()
        }
        paths = [path_by_id[protein_id] for protein_id in protein_ids]
        num_proteins = len(paths)

        # One row (i, j) with i < j per unordered pair. The reshape keeps the
        # array two-dimensional even when there are fewer than two proteins,
        # so the column slices below cannot fail on an empty pair list.
        pair_index = np.array(
            list(itertools.combinations(range(num_proteins), 2)), dtype=int
        ).reshape(-1, 2)
        rows, cols = pair_index[:, 0], pair_index[:, 1]

        # Unfilled entries stay NaN; the diagonal holds each metric's value for
        # a protein compared with itself.
        TM, RMSD, GDT, LDDT = (
            np.full((num_proteins, num_proteins), np.nan, dtype=np.float16)
            for _ in range(4)
        )
        np.fill_diagonal(TM, 1.0)
        np.fill_diagonal(RMSD, 0.0)
        np.fill_diagonal(GDT, 1.0)
        np.fill_diagonal(LDDT, 1.0)

        results = Parallel(n_jobs=self.n_jobs)(
            delayed(tmalign_wrapper)(paths[i], paths[j])
            for i, j in progressbar(pair_index, desc='Aligning', verbosity=self.verbosity)
        )

        # The TM-score matrix is filled asymmetrically: 'TM1' goes above the
        # diagonal and 'TM2' below it. The other metrics are mirrored.
        TM[rows, cols] = [result['TM1'] for result in results]
        TM[cols, rows] = [result['TM2'] for result in results]
        for matrix, key in ((RMSD, 'RMSD'), (GDT, 'GDT'), (LDDT, 'LDDT')):
            values = [result[key] for result in results]
            matrix[rows, cols] = values
            matrix[cols, rows] = values

        # save
        # NOTE(review): the save calls were garbled in the reviewed source; the
        # surviving (path, array) argument order matches ``np.save`` -- confirm.
        np.save(f'{self.root}/{self._matrix_filename("tmscore")}', TM)
        np.save(f'{self.root}/{self._matrix_filename("rmsd")}', RMSD)
        np.save(f'{self.root}/{self._matrix_filename("gdt")}', GDT)
        np.save(f'{self.root}/{self._matrix_filename("lddt")}', LDDT)

    def tm_score(self, protein_1, protein_2):
        """Return the stored TM-score for the pair ``(protein_1, protein_2)``.

        The value may differ when the arguments are swapped, because
        ``align_structures`` stores 'TM1' and 'TM2' on opposite sides of the
        diagonal.

        Raises:
            ValueError: If either ID is not contained in ``self.protein_ids``.
        """
        return self._pair_value(self._tm_score, protein_1, protein_2)

    def rmsd(self, protein_1, protein_2):
        """Return the stored RMSD for the pair ``(protein_1, protein_2)``.

        Raises:
            ValueError: If either ID is not contained in ``self.protein_ids``.
        """
        return self._pair_value(self._rmsd, protein_1, protein_2)

    def gdt(self, protein_1, protein_2):
        """Return the stored Global Distance Test value for the pair.

        Raises:
            ValueError: If either ID is not contained in ``self.protein_ids``.
        """
        return self._pair_value(self._gdt, protein_1, protein_2)

    def lddt(self, protein_1, protein_2):
        """Return the stored Local Distance Difference Test value for the pair.

        Raises:
            ValueError: If either ID is not contained in ``self.protein_ids``.
        """
        return self._pair_value(self._lddt, protein_1, protein_2)