Source code for prody.dynamics.compare

# -*- coding: utf-8 -*-
"""This module defines functions for comparing normal modes from different
models."""

import numpy as np
from numbers import Integral
from prody import LOGGER, SETTINGS
from prody.utilities import openFile, isListLike

from .nma import NMA
from .modeset import ModeSet
from .mode import Mode, Vector
from .gnm import ZERO
from .analysis import calcFractVariance, calcSqFlucts

# Public names exported by this module.
__all__ = ['calcOverlap', 'calcCumulOverlap', 'calcSubspaceOverlap', 'calcSpectralOverlap', 
           'calcCovOverlap', 'printOverlapTable', 'writeOverlapTable', 
           'calcSquareInnerProduct','pairModes', 'matchModes']

# Module-level caches used by calcSpectralOverlap when *turbo* is true.
# Entries are keyed by ``(model1, model2)`` tuples.
# SO_CACHE is used for the unweighted calculation ...
SO_CACHE = {}
# ... and WO_CACHE when *weighted* is true.
WO_CACHE = {}

def calcOverlap(rows, cols, diag=False):
    """Returns overlap (or correlation) between two sets of modes (*rows* and
    *cols*).  Returns a matrix whose rows correspond to modes passed as *rows*
    argument, and columns correspond to those passed as *cols* argument.
    Both rows and columns are normalized prior to calculating overlap.

    Arrays passed in as *rows* or *cols* are left untouched: normalization is
    carried out on a new array, so the caller's data is never rescaled.

    :arg rows: first set of modes; an array is expected to hold one mode
        per column
    :type rows: :class:`.NMA`, :class:`.ModeSet`, :class:`.Mode`,
        :class:`.Vector`, :class:`~numpy.ndarray`

    :arg cols: second set of modes, same accepted types as *rows*

    :arg diag: if **True**, only the diagonal of the overlap matrix is
        returned
    :type diag: bool

    :raises TypeError: when *rows* or *cols* is of an unsupported type
    :raises ValueError: when the vectors in *rows* and *cols* differ in length
    """

    if not isinstance(rows, (NMA, ModeSet, Mode, Vector, np.ndarray)):
        raise TypeError('rows must be NMA, ModeSet, Mode, Vector, or array, not {0}'
                        .format(type(rows)))
    if not isinstance(cols, (NMA, ModeSet, Mode, Vector, np.ndarray)):
        raise TypeError('cols must be NMA, ModeSet, Mode, or Vector, or array, not {0}'
                        .format(type(cols)))

    if isinstance(rows, np.ndarray):
        num_rows = rows.shape[0]
    else:
        num_rows = rows.numEntries()

    if isinstance(cols, np.ndarray):
        num_cols = cols.shape[0]
    else:
        num_cols = cols.numEntries()

    if num_rows != num_cols:
        raise ValueError('the length of vectors in rows and '
                         'cols must be the same')

    if not isinstance(rows, np.ndarray):
        rows = rows.getArray()
    # Out-of-place on purpose: ``rows *= ...`` would rescale an array owned
    # by the caller and fails outright for integer arrays.
    rows = rows * (1 / (rows ** 2).sum(0) ** 0.5)

    if not isinstance(cols, np.ndarray):
        cols = cols.getArray()
    cols = cols * (1 / (cols ** 2).sum(0) ** 0.5)

    if diag:
        # i-th element is the overlap of the i-th row mode with the i-th
        # column mode, computed without forming the full matrix.
        overlaps = np.einsum('ij,ji->i', rows.T, cols)
    else:
        overlaps = np.dot(rows.T, cols)
    return overlaps
def printOverlapTable(rows, cols):
    """Print the overlap (correlation) table of two sets of modes.

    Modes in *rows* label the rows of the printed table and modes in *cols*
    label its columns.  Handy for a quick look at how the modes of two
    models correspond to each other.

    >>> # Compare top 3 PCs and slowest 3 ANM modes
    >>> printOverlapTable(p38_pca[:3], p38_anm[:3]) # doctest: +SKIP
    Overlap Table
                            ANM 1p38
                        #1     #2     #3
    PCA p38 xray #1   -0.39  +0.04  -0.71
    PCA p38 xray #2   -0.78  -0.20  +0.22
    PCA p38 xray #3   +0.05  -0.57  +0.06"""

    table = getOverlapTable(rows, cols)
    print(table)
def writeOverlapTable(filename, rows, cols):
    """Write table of overlaps (correlations) between two sets of modes to a
    file.  *rows* and *cols* are sets of normal modes, and correspond to rows
    and columns of the overlap table.  See also :func:`.printOverlapTable`.

    :arg filename: path of the output file
    :type filename: str

    :returns: *filename*, unchanged
    """

    assert isinstance(filename, str), 'filename must be a string'
    # Build the table before opening the file, so that a failure while
    # computing overlaps does not leave an empty file behind.
    table = getOverlapTable(rows, cols)
    out = openFile(filename, 'w')
    try:
        out.write(table)
    finally:
        # Always release the handle, even when the write fails.
        out.close()
    return filename
def getOverlapTable(rows, cols):
    """Return the overlaps between modes in *rows* and *cols* formatted as a
    multi-line table string (title, column model name, column mode numbers,
    then one line per row mode)."""

    def _labels(modes):
        # Zero-based mode indices and the display name for one side.
        if isinstance(modes, Mode):
            return [modes.getIndex()], str(modes.getModel())
        if isinstance(modes, NMA):
            return np.arange(len(modes)), str(modes)
        if isinstance(modes, ModeSet):
            return modes.getIndices(), str(modes.getModel())
        return [0], str(modes)

    overlap = calcOverlap(rows, cols)
    rids, rname = _labels(rows)
    cids, cname = _labels(cols)
    rlen, clen = len(rids), len(cids)
    overlap = overlap.reshape((rlen, clen))

    # Blank gutter as wide as a row label (name plus mode number field).
    gutter = ' '*(len(rname)+5)
    lines = ['Overlap Table',
             (gutter + cname.center(clen*7)).rstrip()]

    header = gutter
    for cid in cids:
        header += ('#{0}'.format(cid+1)).center(7)
    lines.append(header.rstrip())

    for rid, values in zip(rids, overlap):
        line = rname + (' #{0}'.format(rid+1)).ljust(5)
        for value in values:
            # Values that round to zero are printed without a sign.
            if abs(value).round(2) == 0.00:
                sign = ' '
            elif value < 0:
                sign = '-'
            else:
                sign = '+'
            line += (sign+'{0:-.2f}').format(abs(value)).center(7)
        lines.append(line.rstrip())

    return '\n'.join(lines) + '\n'
def calcCumulOverlap(modes1, modes2, array=False):
    """Returns cumulative overlap of modes in *modes2* with those in *modes1*.
    Returns a number if *modes1* contains a single :class:`.Mode` or a
    :class:`.Vector` instance. If *modes1* contains multiple modes, returns an
    array. Elements of the array correspond to cumulative overlaps for modes
    in *modes1* with those in *modes2*.  If *array* is **True**, returns an array
    of cumulative overlaps. Returned array has the shape ``(len(modes1),
    len(modes2))``.  Each row corresponds to cumulative overlaps calculated for
    modes in *modes1* with those in *modes2*.  Each value in a row corresponds
    to cumulative overlap calculated using upto that many number of modes from
    *modes2*.

    :arg array: if **True**, return the running cumulative overlaps instead
        of only the total over all modes in *modes2*
    :type array: bool
    """

    overlap = calcOverlap(modes1, modes2)
    squared = np.power(overlap, 2)
    # Accumulate over the modes2 axis, which is always the last one.
    last_axis = overlap.ndim - 1
    if array:
        # Running total: entry k uses the first k+1 modes of modes2.
        return np.sqrt(squared.cumsum(axis=last_axis))
    else:
        # Total over all of modes2: one value per mode in modes1.
        return np.sqrt(squared.sum(axis=last_axis))
def calcSubspaceOverlap(modes1, modes2):
    """Returns the subspace overlap of *modes1* and *modes2* as a single
    number.  This quantity is also known as the root mean square inner
    product (RMSIP) of essential subspaces [AA99]_.

    .. [AA99] Amadei A, Ceruso MA, Di Nola A. On the convergence of the
       conformational coordinates basis set obtained by the essential
       dynamics analysis of proteins' molecular dynamics simulations.
       *Proteins* **1999** 36(4):419-424."""

    squared_overlaps = np.power(calcOverlap(modes1, modes2), 2)
    # A lone Mode counts as one mode; anything else reports its own length.
    n_modes = 1 if isinstance(modes1, Mode) else len(modes1)
    return np.sqrt(squared_overlaps.sum() / n_modes)
def calcSquareInnerProduct(modes1, modes2):
    """Returns the square inner product (SIP) of fluctuations [SK02]_ as a
    single number.  Each argument may be an :class:`.NMA` or
    :class:`.ModeSet`, from which square fluctuations are calculated, or a
    list-like fluctuation profile that is used as given.

    .. [SK02] Kundu S, Melton JS, Sorensen DC, Phillips GN: Dynamics of
        proteins in crystals: comparison of experiment with simple models.
        Biophys J. 2002, 83: 723-732.
    """

    candidates = (
        (modes1, 'modes1 should be a profile or an NMA or ModeSet object'),
        (modes2, 'modes2 should be a profile or an NMA or ModeSet object'),
    )
    profiles = []
    for modes, complaint in candidates:
        if isinstance(modes, (NMA, ModeSet)):
            profiles.append(calcSqFlucts(modes))
        elif isListLike(modes):
            profiles.append(modes)
        else:
            raise TypeError(complaint)

    w1, w2 = profiles
    cross = np.dot(w1, w2)
    return cross**2 / (np.dot(w1, w1) * np.dot(w2, w2))
def _spectralVarsAndIndices(modes, weighted, name):
    """Return ``(modes, variances, indices)`` for one side of a spectral
    overlap calculation.  *modes* is replaced by ``modes[:]`` when the object
    passed in does not provide :meth:`getIndices` itself."""

    if isinstance(modes, Mode):
        if weighted:
            variances = np.array([calcFractVariance(modes)])
        else:
            variances = np.array([modes.getVariance()])
        return modes, variances, np.array([modes.getIndex()])

    if weighted:
        variances = calcFractVariance(modes)
    else:
        variances = modes.getVariances()

    try:
        indices = modes.getIndices()
    except Exception:
        # Fall back to slicing, which is expected to yield an object that
        # does provide getIndices().
        try:
            modes = modes[:]
            indices = modes.getIndices()
        except Exception:
            raise TypeError('{0} should be ModeSet or an object from which '
                            'a ModeSet can be obtained'.format(name))
    return modes, variances, indices


def calcSpectralOverlap(modes1, modes2, weighted=False, turbo=False):
    """Returns overlap between covariances of *modes1* and *modes2*.  Overlap
    between covariances are calculated using normal modes (eigenvectors),
    hence modes in both models must have been calculated.  This function
    implements equation 11 in [BH02]_.

    .. [BH02] Hess B. Convergence of sampling in protein simulations.
       *Phys Rev E* **2002** 65(3):031910.

    :arg weighted: if **True** then covariances are weighted by the trace.
    :type weighted: bool

    :arg turbo: if **True**, the squared overlaps between all modes of the
        two parent models are computed once and kept in a module-level cache,
        so later calls on subsets of the same models reuse them
    :type turbo: bool

    :raises TypeError: when one argument is 3-dimensional and the other is
        not, or when mode indices cannot be obtained from an argument
    :raises ValueError: when the number of atoms differs
    """

    if modes1.is3d() ^ modes2.is3d():
        raise TypeError('models must be either both 1-dimensional or 3-dimensional')
    if modes1.numAtoms() != modes2.numAtoms():
        raise ValueError('modes1 and modes2 must have same number of atoms')

    modes1, varA, I = _spectralVarsAndIndices(modes1, weighted, 'modes1')
    modes2, varB, J = _spectralVarsAndIndices(modes2, weighted, 'modes2')

    if turbo:
        model1 = modes1.getModel()
        model2 = modes2.getModel()
        if weighted:
            CACHE = WO_CACHE
        else:
            CACHE = SO_CACHE

        # Only the squared mode-mode overlaps are cached.  The variance
        # factors are applied below from varA/varB, so weighted and
        # unweighted calls give the same answer as the non-turbo path.
        if (model1, model2) in CACHE:
            sq_dots = CACHE[(model1, model2)]
        elif (model2, model1) in CACHE:
            # Stored with the models swapped: rows belong to model2, so
            # transpose before indexing rows with I and columns with J.
            sq_dots = CACHE[(model2, model1)].T
        else:
            farrayA = model1._getArray()
            farrayB = model2._getArray()
            CACHE[(model1, model2)] = sq_dots = np.dot(farrayA.T, farrayB)**2
        sq_dots = sq_dots[I, :][:, J]
    else:
        arrayA = modes1._getArray()
        arrayB = modes2._getArray()
        sq_dots = np.dot(arrayA.T, arrayB)**2

    weights = np.outer(varA**0.5, varB**0.5) * sq_dots

    total = varA.sum() + varB.sum()
    diff = total - 2 * np.sum(weights)
    # Round-off can push the difference marginally below zero; clamp it
    # instead of taking the square root of a negative number.
    if diff < ZERO:
        diff = 0
    else:
        diff = diff ** 0.5
    return 1 - diff / np.sqrt(total)
def calcCovOverlap(modes1, modes2, turbo=False):
    """Returns the (unweighted) overlap between the covariances of *modes1*
    and *modes2*.  The overlap is evaluated from normal modes
    (eigenvectors), so the modes of both models must already have been
    calculated.  Implements equation 11 in [BH02]_ by delegating to
    :func:`calcSpectralOverlap`."""

    overlap = calcSpectralOverlap(modes1, modes2, turbo=turbo)
    return overlap
def pairModes(modes1, modes2, **kwargs):
    """Returns the optimal one-to-one matching between *modes1* and *modes2*,
    which must contain the same number of modes.  Modes are paired so that
    the summed cost ``1 - |overlap|`` is minimized by the assignment solver.

    By default a pair of :class:`.ModeSet` instances is returned, holding
    the matched modes of *modes1* and *modes2* in corresponding order.

    :arg index: if **True** then the row and column indices produced by the
        assignment solver are returned instead of mode sets
    :type index: bool

    :arg method: assignment solver called with the cost matrix; defaults to
        :func:`scipy.optimize.linear_sum_assignment`
    """

    index = kwargs.pop('index', False)
    method = kwargs.pop('method', None)
    if method is None:
        from scipy.optimize import linear_sum_assignment as method

    accepted = (ModeSet, NMA)
    if not isinstance(modes1, accepted) or not isinstance(modes2, accepted):
        raise TypeError('modes1 and modes2 should be ModeSet or NMA instances')

    if len(modes1) != len(modes2):
        raise ValueError('the same number of modes should be provided')

    costs = 1 - abs(calcOverlap(modes1, modes2))
    row_ind, col_ind = method(costs)

    if index:
        return row_ind, col_ind

    # Positions within a ModeSet are translated to indices in its model.
    picks1 = modes1._indices[row_ind] if isinstance(modes1, ModeSet) else row_ind
    picks2 = modes2._indices[col_ind] if isinstance(modes2, ModeSet) else col_ind

    return (ModeSet(modes1.getModel(), picks1),
            ModeSet(modes2.getModel(), picks2))
def _pairModes_wrapper(args):
    """Pair every mode set in a batch against a reference mode set.

    *args* is a ``(reference, modesets, index)`` tuple, packed into a single
    argument so that the function can be used with ``Pool.map``.  Returns a
    list with the second item of each :func:`pairModes` result, in order."""

    reference, batch, index = args
    return [pairModes(reference, modeset, index=index)[1]
            for modeset in batch]
def matchModes(*modesets, **kwargs):
    """Returns the matches of modes among *modesets*. Note that the first
    modeset will be treated as the reference so that only the matching of
    each modeset to the first modeset is guaranteed to be optimal.

    :arg index: if **True** then indices of modes will be returned instead of
        :class:`Mode` instances
    :type index: bool

    :arg turbo: if **True** then the computation will be performed in
        parallel. The number of threads is set to be the same as the number
        of CPUs. Assigning a number to specify the number of threads to be
        used. Note that if writing a script, ``if __name__ == '__main__'`` is
        necessary to protect your code when multi-tasking.
        See https://docs.python.org/2/library/multiprocessing.html for
        details. Default is **False**
    :type turbo: bool, int

    Any other keyword argument is passed on to :func:`pairModes` in the
    serial (non-turbo) path only.

    :raises ValueError: when no modeset is given
    """

    index = kwargs.pop('index', False)
    turbo = kwargs.pop('turbo', False)

    n_worker = None
    if not isinstance(turbo, bool):
        n_worker = int(turbo)

    # Validate before touching modesets[0]; otherwise an empty call fails
    # with an IndexError instead of the intended ValueError.
    n_sets = len(modesets)
    if n_sets == 0:
        raise ValueError('at least one modeset should be given')

    modeset0 = modesets[0]
    if index:
        ret = [modeset0.getIndices()]
    else:
        ret = [modeset0]

    n_modes = len(modeset0)
    if n_sets == 1:
        return ret

    if turbo:
        from multiprocessing import Pool, cpu_count

        if not n_worker:
            n_worker = cpu_count()
        # More workers than modesets to match would only spawn idle
        # processes.
        n_worker = min(n_worker, n_sets - 1)

        LOGGER.info('Matching {0} modes across {1} modesets with {2} threads...'
                    .format(n_modes, n_sets, n_worker))

        # Integer ceiling division; does not depend on true division.
        n_sets_per_worker = -(-(n_sets - 1) // n_worker)

        # Index 0 is the reference, so each batch starts one past it.
        args = []
        for i in range(n_worker):
            start = i*n_sets_per_worker + 1
            end = (i+1)*n_sets_per_worker + 1
            subset = modesets[start:end]
            args.append((modeset0, subset, index))

        pool = Pool(n_worker)
        try:
            nested_ret = pool.map(_pairModes_wrapper, args)
        finally:
            # Release the worker processes even if a task raised.
            pool.close()
            pool.join()

        for entry in nested_ret:
            ret.extend(entry)
    else:
        LOGGER.progress('Matching {0} modes across {1} modesets...'
                        .format(n_modes, n_sets), n_sets, '_prody_matchModes')
        for i, modeset in enumerate(modesets):
            LOGGER.update(i, label='_prody_matchModes')
            if i > 0:
                _, reordered_modeset = pairModes(modeset0, modeset,
                                                 index=index, **kwargs)
                ret.append(reordered_modeset)
        LOGGER.finish()

    return ret