# Source code for prody.kdtree.kdtree
# -*- coding: utf-8 -*-
"""This module defines :class:`KDTree` class for dealing with atomic coordinate
sets and handling periodic boundary conditions."""

from numpy import array, ndarray, concatenate, empty

from prody import LOGGER

def createKDTreeByDim(KDTclass, coords, bucketsize):
    """Build a tree with a backend whose constructor takes the dimensionality.

    Used for the ``_CKDTree`` style backends, which are constructed empty as
    ``KDTclass(dim, bucketsize)`` and must be populated afterwards.

    :arg KDTclass: tree class whose constructor takes ``(dim, bucketsize)``
    :arg coords: coordinate array with shape ``(N, 3)`` to be indexed
    :arg bucketsize: number of points per tree node
    :returns: tree instance loaded with *coords*"""
    kdt = KDTclass(3, bucketsize)
    # the constructor only allocates an empty 3-d tree; without loading the
    # coordinates every subsequent search would run against no data
    kdt.set_data(coords)
    return kdt

def createKDTreeByCoords(KDTclass, coords, bucketsize):
    """Build a tree with a backend whose constructor takes the data itself.

    Used for the ``Bio.PDB.kdtrees`` backend, constructed directly as
    ``KDTclass(coords, bucketsize)``.

    :arg KDTclass: tree class whose constructor takes ``(coords, bucketsize)``
    :arg coords: coordinate array to be indexed
    :arg bucketsize: number of points per tree node
    :returns: tree instance built over *coords*"""
    return KDTclass(coords, bucketsize)

    from ._CKDTree import KDTree as _KDTree
    CKDTree = lambda coords, bz : createKDTreeByDim(_KDTree, coords, bz) 
except ImportError:
        from Bio.PDB.kdtrees import KDTree as _KDTree
        CKDTree = lambda coords, bz : createKDTreeByCoords(_KDTree, coords, bz) 
    except ImportError:
            from Bio.KDTree._CKDTree import KDTree as _KDTree
            CKDTree = lambda coords, bz : createKDTreeByDim(_KDTree, coords, bz) 
        except ImportError:
            raise ImportError('CKDTree module could not be imported. '
                            'Reinstall ProDy or install Biopython '
                            'to solve the problem.')

# Public API of this module: only the KDTree class is exported.
__all__ = ['KDTree']

# Unit shifts along a single axis; taking one per axis yields all 27 image
# translations of a 3-d cell (the all-zero shift, i.e. the original cell,
# is row 13).  Rows are ordered with x varying slowest and z fastest.
_ = array([-1., 0., 1.])
REPLICATE = array([[dx, dy, dz]
                   for dx in _
                   for dy in _
                   for dz in _])

class KDTree(object):

    """An interface to Thomas Hamelryck's C KDTree module that can handle
    periodic boundary conditions.  Both point and pair search are performed
    using the single :meth:`search` method and results are retrieved using
    :meth:`getIndices` and :meth:`getDistances`.

    **Periodic Boundary Conditions**

    *Point search*

    A point search around a *center*, indicated with a question mark (``?``)
    below, involves making images of the point in cells sharing a wall or an
    edge with the unitcell that contains the system.  The search is performed
    for all images of the *center* (27 in 3-dimensional space) and unique
    indices with the minimum distance from them to the *center* are returned.
    ::

          _____________________________
         |        1|        2|        3|
         |       ? |       ? |       ? |
         |_________|_________|_________|
         |        4|o  h h  5|        6| ? and H interact in periodic image 4
         |       ?H| h  o  ? |       ? | but not in the original unitcell (5)
         |_________|_________|_________|
         |        7|        8|        9|
         |       ? |       ? |       ? |
         |_________|_________|_________|

    There are two requirements for this approach to work: (i) the *center*
    must be in the original unitcell, and (ii) the system must be in the
    original unitcell with parts in its immediate periodic images.

    *Pair search*

    A pair search involves making 26 (or 8 in 2-d) replicas of the system
    coordinates.  A KDTree is built for the system (``O`` and ``H``) and all
    its replicas (``o`` and ``h``).  After pair search is performed, unique
    pairs of indices and minimum distance between them are returned.
    ::

          _____________________________
         |o  h h  1|o  h h  2|o  h h  3|
        h| h  o   h| h  o   h| h  o    |
         |_________|_________|_________|
         |o  h h  4|O  H H  5|o  h h  6|
        h| h  o   H| H  O   h| h  o    |
         |_________|_________|_________|
         |o  h h  7|o  h h  8|o  h h  9|
        h| h  o   h| h  o   h| h  o    |
         |_________|_________|_________|

    Only requirement for this approach to work is that the system must be
    in the original unitcell with parts in its immediate periodic images.

    .. seealso::
       :func:`.wrapAtoms` can be used for wrapping atoms into the single
       periodic image of the system."""

    def __init__(self, coords, **kwargs):
        """
        :arg coords: coordinate array with shape ``(N, 3)``, where N is number
            of atoms
        :type coords: :class:`numpy.ndarray`, :class:`.Atomic`, :class:`.Frame`

        :arg unitcell: orthorhombic unitcell dimension array with shape
            ``(3,)``; when not given and *coords* has ``getUnitcell``, its
            unitcell is used
        :type unitcell: :class:`numpy.ndarray`

        :arg bucketsize: number of points per tree node, default is 10
        :type bucketsize: int

        :arg none: zero-argument callable whose return value is given by
            :meth:`getIndices` and :meth:`getDistances` when the most recent
            search found nothing, default returns **None**

        :arg oncall: ``'both'`` (default) makes calling the instance return
            ``(indices, distances)``, ``'dist'`` makes it return distances
            only"""

        unitcell = kwargs.get('unitcell')
        if not isinstance(coords, ndarray):
            if unitcell is None:
                # fall back to the unitcell carried by the object, if any
                try:
                    unitcell = coords.getUnitcell()
                except AttributeError:
                    pass
                else:
                    if unitcell is not None:
                        LOGGER.info('Unitcell information from {0} will be '
                                    'used.'.format(str(coords)))
            try:
                # using getCoords() because coords will be stored internally
                # and reused when needed, this will avoid unexpected results
                # due to changes made to coordinates externally
                coords = coords.getCoords()
            except AttributeError:
                raise TypeError('coords must be a Numpy array or must have '
                                'getCoords attribute')
        else:
            # private copy so later changes to the caller's array cannot
            # affect the coordinates stored below
            coords = coords.copy()

        if coords.ndim != 2:
            raise Exception('coords.ndim must be 2')
        if coords.shape[-1] != 3:
            raise Exception('coords.shape must be (N,3)')
        # NOTE(review): presumably a range limit of the C backend - confirm
        if coords.min() <= -1e6 or coords.max() >= 1e6:
            raise Exception('coords must be between -1e6 and 1e6')

        self._bucketsize = kwargs.get('bucketsize', 10)
        if not isinstance(self._bucketsize, int):
            raise TypeError('bucketsize must be an integer')
        if self._bucketsize < 1:
            raise ValueError('bucketsize must be a positive integer')

        self._coords = None
        self._unitcell = None
        self._neighbors = None
        if unitcell is None:
            self._kdtree = CKDTree(coords, self._bucketsize)
        else:
            if not isinstance(unitcell, ndarray):
                raise TypeError('unitcell must be a Numpy array')
            if unitcell.shape != (3,):
                raise ValueError('unitcell.shape must be (3,)')
            self._kdtree = CKDTree(coords, self._bucketsize)
            self._coords = coords
            self._unitcell = unitcell
            # translation vectors of the 27 periodic images
            self._replicate = REPLICATE * unitcell
            # tree over all images, built lazily by the first pair search
            self._kdtree2 = None
            # PBC results: index (or index pair) -> minimum distance, and
            # the key order used for returned arrays
            self._pbcdict = {}
            self._pbckeys = []
        self._n_atoms = coords.shape[0]

        self._none = kwargs.pop('none', lambda: None)
        try:
            self._none()
        except TypeError:
            raise TypeError('none argument must be callable')

        self._oncall = kwargs.pop('oncall', 'both')
        # NOTE: assert is stripped under ``python -O``
        assert self._oncall in ('both', 'dist'), 'oncall must be both or dist'

    def __call__(self, radius, center=None):
        """Shorthand method for searching and retrieving results.

        Runs :meth:`search` and returns ``(indices, distances)`` or only
        distances, depending on *oncall* given at construction."""

        self.search(radius, center)
        if self._oncall == 'both':
            return self.getIndices(), self.getDistances()
        elif self._oncall == 'dist':
            return self.getDistances()

    def search(self, radius, center=None):
        """Search pairs within *radius* of each other or points within *radius*
        of *center*.

        :arg radius: distance (Å)
        :type radius: float

        :arg center: a point in Cartesian coordinate system
        :type center: :class:`numpy.ndarray`

        :raises TypeError: if *radius* is not a positive number or *center* is
            not a :class:`numpy.ndarray`
        :raises ValueError: if *center* does not have shape ``(3,)``"""

        if not isinstance(radius, (float, int)):
            raise TypeError('radius must be a number')
        if radius <= 0:
            # TypeError (not ValueError) kept for callers that catch it
            raise TypeError('radius must be a positive number')

        if center is not None:
            if not isinstance(center, ndarray):
                raise TypeError('center must be a Numpy array instance')
            if center.shape != (3,):
                raise ValueError('center.shape must be (3,)')
            if self._unitcell is None:
                self._kdtree.search_center_radius(center, radius)
                self._neighbors = None
            else:
                self._searchCenterPBC(radius, center)
        else:
            if self._unitcell is None:
                self._neighbors = self._kdtree.neighbor_search(radius)
            else:
                self._searchPairsPBC(radius)

    def _searchCenterPBC(self, radius, center):
        """Point search around every periodic image of *center*, keeping the
        minimum distance found for each atom index."""

        kdtree = self._kdtree
        best = {}
        for image in center + self._replicate:
            kdtree.search_center_radius(image, radius)
            if not kdtree.get_count():
                continue
            indices = get_KDTree_indices(kdtree)
            radii = get_KDTree_radii(kdtree)
            for index, dist in zip(indices, radii):
                # explicit comparison instead of a 1e6 sentinel default, so
                # large distances are never clamped
                if index not in best or dist < best[index]:
                    best[index] = dist
        self._pbcdict = best
        self._pbckeys = list(best)

    def _searchPairsPBC(self, radius):
        """Pair search over the system and its 26 replicas, keeping the
        minimum distance for each unique pair of original atom indices."""

        kdtree = self._kdtree2
        if kdtree is None:
            # images are stacked one after another, so ``index % n_atoms``
            # maps any image atom back to its original index
            coords = concatenate([self._coords + shift
                                  for shift in self._replicate])
            kdtree = CKDTree(coords, self._bucketsize)
            self._kdtree2 = kdtree
        n_atoms = len(self._coords)
        best = {}
        neighbors = kdtree.neighbor_search(radius)
        if kdtree.neighbor_get_count():
            for nb in neighbors:
                i = nb.index1 % n_atoms
                j = nb.index2 % n_atoms
                if i == j:
                    # an atom paired with one of its own images
                    continue
                key = (i, j) if i < j else (j, i)
                if key not in best or nb.radius < best[key]:
                    best[key] = nb.radius
        self._pbcdict = best
        self._pbckeys = list(best)

    def getIndices(self):
        """Returns array of indices for points or pairs, depending on the type
        of the most recent search.  When nothing was found, the value of the
        *none* callable is returned."""

        if not self.getCount():
            return self._none()
        if self._unitcell is not None:
            return array(self._pbckeys)
        if self._neighbors is None:
            return get_KDTree_indices(self._kdtree)
        return array([(n.index1, n.index2) for n in self._neighbors], int)

    def getDistances(self):
        """Returns array of distances, in the same order as
        :meth:`getIndices`.  When nothing was found, the value of the *none*
        callable is returned."""

        if not self.getCount():
            return self._none()
        if self._unitcell is not None:
            pbcdict = self._pbcdict
            return array([pbcdict[key] for key in self._pbckeys])
        if self._neighbors is None:
            return get_KDTree_radii(self._kdtree)
        return array([n.radius for n in self._neighbors])

    def getCount(self):
        """Returns number of points or pairs."""

        if self._unitcell is None:
            if self._neighbors is None:
                return self._kdtree.get_count()
            else:
                return self._kdtree.neighbor_get_count()
        else:
            return len(self._pbcdict)
def _fetch_kdtree_array(kdtree, getter_name, dtype):
    """Call ``kdtree.<getter_name>`` in whichever form the backend supports.

    First tries a no-argument call and returns its result.  If that fails,
    allocates a buffer of ``kdtree.get_count()`` elements of *dtype*, lets the
    getter fill it, and returns it.  Returns **None** when that fallback finds
    no results."""

    try:
        return getattr(kdtree, getter_name)()
    except Exception:
        # the backend needs an output buffer instead; narrowed from a bare
        # ``except:`` so KeyboardInterrupt/SystemExit are not swallowed
        n = kdtree.get_count()
        if not n:
            return None
        out = empty(n, dtype)
        getattr(kdtree, getter_name)(out)
        return out


def get_KDTree_indices(kdtree):
    """Returns indices found by the most recent point search on *kdtree*,
    or **None** if there are none (buffer-filling backends only)."""

    return _fetch_kdtree_array(kdtree, 'get_indices', int)


def get_KDTree_radii(kdtree):
    """Returns distances (single precision when filled into a buffer) found
    by the most recent point search on *kdtree*, or **None** if there are
    none (buffer-filling backends only)."""

    return _fetch_kdtree_array(kdtree, 'get_radii', 'f')