diff --git a/benchmarks/benchmarks/neighbors.py b/benchmarks/benchmarks/neighbors.py new file mode 100644 index 00000000000..a0ec0e2c85b --- /dev/null +++ b/benchmarks/benchmarks/neighbors.py @@ -0,0 +1,51 @@ +import numpy as np +from MDAnalysis.lib.pkdtree import PeriodicKDTree +from MDAnalysis.lib.distances import capped_distance +from scipy.spatial import cKDTree + + +class NeighborsBench: + """Benchmarks for neighbor searching functions.""" + + params = ([100, 1000, 10000, 100000], [20, 30, 36, 42, 48, 50, 60]) + param_names = ["number_of_atoms", "cutoff"] + + def setup(self, number_of_atoms, cutoff): + """Setup called before each benchmark with each parameter combination.""" + self.box = np.array( + [170.0, 70.0, 120.0, 90.0, 90.0, 90.0], dtype=np.float32 + ) + self.positions = ( + np.random.rand(number_of_atoms, 3) * self.box[:3] + ).astype(np.float32) + self.centre = (self.box[:3] / 2.0).reshape(1, 3) + self.cutoff = cutoff + + self.scipy_tree = cKDTree(self.positions, boxsize=self.box[:3]) + self.mda_tree = PeriodicKDTree(box=self.box) + self.mda_tree.set_coords(self.positions, cutoff=self.cutoff) + + def time_mda_tree_search(self, number_of_atoms, cutoff): + """Benchmark just the search operation on pre-built tree.""" + self.mda_tree.search(self.centre, self.cutoff) + + def time_scipy_tree_query(self, number_of_atoms, cutoff): + """Benchmark just the query operation on pre-built tree.""" + self.scipy_tree.query_ball_point(self.centre, self.cutoff) + + def time_mda_PKDtree_with_setup(self, number_of_atoms, cutoff): + """Benchmark tree construction + search.""" + tree = PeriodicKDTree(box=self.box) + tree.set_coords(self.positions, cutoff=self.cutoff) + tree.search(self.centre, self.cutoff) + + def time_scipy_cKDTree_with_setup(self, number_of_atoms, cutoff): + """Benchmark tree construction + query.""" + tree = cKDTree(self.positions, boxsize=self.box[:3]) + tree.query_ball_point(self.centre, self.cutoff) + + def time_capped_distance_array(self, number_of_atoms, cutoff): + """Benchmark capped distance calculation.""" + capped_distance( + self.centre, self.positions, max_cutoff=self.cutoff, box=self.box + ) diff --git a/package/MDAnalysis/lib/pkdtree.py b/package/MDAnalysis/lib/pkdtree.py index 952b4672e32..4699acf57b7 100644 --- a/package/MDAnalysis/lib/pkdtree.py +++ b/package/MDAnalysis/lib/pkdtree.py @@ -38,12 +38,12 @@ from MDAnalysis.lib.distances import apply_PBC import numpy.typing as npt -from typing import Optional, ClassVar +from typing import Optional, ClassVar, Union, Any __all__ = ["PeriodicKDTree"] -class PeriodicKDTree(object): +class AugmentedPKDTree(object): """Wrapper around :class:`scipy.spatial.cKDTree` Creates an object which can handle periodic as well as @@ -336,3 +336,243 @@ class initialization if pairs.size > 0: pairs = unique_rows(pairs) return pairs + + +class PeriodicKDTree(object): + + def __init__( + self, box: Optional[npt.ArrayLike] = None, leafsize: int = 10 + ) -> None: + self.leafsize = leafsize + self.dim = 3 + self.box = box + self._built = False + + self.cutoff: Optional[float] = None + self.mapping: Optional[npt.NDArray] = None + self._tree: Optional[Union[AugmentedPKDTree, cKDTree]] = None + + _use_augmented = False + if box is not None: + box_array = np.asarray(box, dtype=np.float32) + if box_array.shape == (6,): + if not np.allclose(box_array[3:], 90.0): + _use_augmented = True + else: + _use_augmented = True + + self._use_augmented = _use_augmented + + if self._use_augmented: + self._tree = AugmentedPKDTree(box=self.box, leafsize=leafsize) + else: + self._tree = None + if box is not None: + self.box = np.asarray(box, dtype=np.float32) + self._is_ortho = True + + @property + def pbc(self): + """Flag to indicate the presence of periodic boundaries. + + - ``True`` if PBC are taken into account + - ``False`` if no unitcell dimension is available. + + This is a managed attribute and can only be read. + """ + return self.box is not None + + def set_coords( + self, coords: npt.ArrayLike, cutoff: Optional[float] = None + ) -> None: + """Constructs KDTree from the coordinates + + Parameters + ---------- + coords: array_like + Coordinate array of shape ``(N, 3)`` for N atoms. + cutoff: float + Specified cutoff distance for searches. + Required for periodic calculations. + """ + if self._use_augmented: + assert self._tree is not None + self._tree.set_coords(coords, cutoff) + self._built = True + self.cutoff = cutoff + else: + coords = np.asarray(coords, dtype=np.float32) + self.cutoff = cutoff + + if self.box is None: + if cutoff is not None: + raise RuntimeError( + "Donot provide cutoff distance for non PBC aware calculations" + ) + self.coords = coords + self._tree = cKDTree(self.coords, leafsize=self.leafsize) + else: + if cutoff is None: + raise RuntimeError( + "Provide a cutoff distance with tree.set_coords(...)" + ) + self.coords = apply_PBC(coords, self.box) + box_array = np.asarray(self.box, dtype=np.float32) + self._tree = cKDTree( + self.coords, leafsize=self.leafsize, boxsize=box_array[:3] + ) + + self._built = True + + def search(self, centers: npt.ArrayLike, radius: float) -> npt.NDArray: + """Search all points within radius from centers and their periodic images. + + Parameters + ---------- + centers: array_like (N,3) + coordinate array to search for neighbors + radius: float + maximum distance to search for neighbors. + """ + if not self._built: + raise RuntimeError("Unbuilt tree. Run tree.set_coords(...)") + + if self._use_augmented: + assert self._tree is not None + return self._tree.search(centers, radius) + + centers = np.asarray(centers, dtype=np.float32) + if centers.shape == (self.dim,): + centers = centers.reshape((1, self.dim)) + + if self.pbc: + if self.cutoff is None: + raise ValueError( + "Cutoff needs to be provided when working with PBC." + ) + if self.cutoff < radius: + raise RuntimeError( + "Set cutoff greater or equal to the radius." + ) + wrapped_centers = apply_PBC(centers, self.box) + assert isinstance(self._tree, cKDTree) + indices = list( + self._tree.query_ball_point(wrapped_centers, radius) + ) + else: + assert isinstance(self._tree, cKDTree) + indices = list(self._tree.query_ball_point(centers, radius)) + + self._indices = np.array( + list(itertools.chain.from_iterable(indices)), dtype=np.intp + ) + + if self._indices.size > 0: + self._indices = np.asarray(unique_int_1d(self._indices)) + return self._indices + + def get_indices(self) -> npt.NDArray: + """Return the neighbors from the last query. + + Returns + ------ + indices : NDArray + neighbors for the last query points and search radius + """ + return self._indices + + def search_pairs(self, radius: float) -> npt.NDArray: + """Search all the pairs within a specified radius + + Parameters + ---------- + radius : float + Maximum distance between pairs of coordinates + + Returns + ------- + pairs : array + Indices of all the pairs which are within the specified radius + """ + if not self._built: + raise RuntimeError("Unbuilt Tree. Run tree.set_coords(...)") + + if self._use_augmented: + assert self._tree is not None + return self._tree.search_pairs(radius) + + if self.pbc: + if self.cutoff is None: + raise ValueError( + "Cutoff needs to be provided when working with PBC." + ) + if self.cutoff < radius: + raise RuntimeError( + "Set cutoff greater or equal to the radius." + ) + + assert isinstance(self._tree, cKDTree) + pairs = np.array(list(self._tree.query_pairs(radius)), dtype=np.intp) + + if pairs.size > 0: + pairs = np.sort(pairs, axis=1) + pairs = unique_rows(pairs) + return pairs + + def search_tree(self, centers: npt.ArrayLike, radius: float) -> np.ndarray: + """ + Searches all the pairs within `radius` between `centers` + and ``coords`` + + ``coords`` are the already initialized coordinates in the tree + during :meth:`set_coords`. + + Parameters + ---------- + centers: array_like (N,3) + coordinate array to search for neighbors + radius: float + maximum distance to search for neighbors. + + Returns + ------- + pairs : array + all the pairs between ``coords`` and ``centers`` + """ + if not self._built: + raise RuntimeError("Unbuilt tree. Run tree.set_coords(...)") + + if self._use_augmented: + assert self._tree is not None + return self._tree.search_tree(centers, radius) + + centers = np.asarray(centers, dtype=np.float32) + if centers.shape == (self.dim,): + centers = centers.reshape((1, self.dim)) + + if self.pbc: + if self.cutoff is None: + raise ValueError( + "Cutoff needs to be provided when working with PBC." + ) + if self.cutoff < radius: + raise RuntimeError( + "Set cutoff greater or equal to the radius." + ) + wrapped_centers = apply_PBC(centers, self.box) + box_array = np.asarray(self.box, dtype=np.float32) + other_tree = cKDTree( + wrapped_centers, leafsize=self.leafsize, boxsize=box_array[:3] + ) + else: + other_tree = cKDTree(centers, leafsize=self.leafsize) + + pairs_list = other_tree.query_ball_tree(self._tree, radius) + pairs = np.array( + [[i, j] for i, lst in enumerate(pairs_list) for j in lst], + dtype=np.intp, + ) + + if pairs.size > 0: + pairs = unique_rows(pairs) + return pairs diff --git a/testsuite/MDAnalysisTests/lib/test_pkdtree.py b/testsuite/MDAnalysisTests/lib/test_pkdtree.py index ec4e586b380..1b5e0a6b42f 100644 --- a/testsuite/MDAnalysisTests/lib/test_pkdtree.py +++ b/testsuite/MDAnalysisTests/lib/test_pkdtree.py @@ -26,10 +26,16 @@ from numpy.testing import assert_equal -from MDAnalysis.lib.pkdtree import PeriodicKDTree +from MDAnalysis.lib.pkdtree import PeriodicKDTree, AugmentedPKDTree from MDAnalysis.lib.distances import transform_StoR +@pytest.fixture(params=[PeriodicKDTree, AugmentedPKDTree]) +def tree_class(request): + """Fixture to run tests on both PeriodicKDTree and AugmentedPKDTree.""" + return request.param + + # fractional coordinates for data points f_dataset = np.array( [ @@ -58,23 +64,23 @@ ), ), ) -def test_setcoords(b, cut, result): +def test_setcoords(tree_class, b, cut, result): coords = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.float32) if b is not None: b = np.array(b, dtype=np.float32) - tree = PeriodicKDTree(box=b) + tree = tree_class(box=b) print(b, tree.box, cut, result) with pytest.raises(RuntimeError, match=result): tree.set_coords(coords, cutoff=cut) -def test_searchfail(): +def test_searchfail(tree_class): coords = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.float32) b = np.array([10, 10, 10, 90, 90, 90], dtype=np.float32) cutoff = 1.0 search_radius = 2.0 query = np.array([1, 1, 1], dtype=np.float32) - tree = PeriodicKDTree(box=b) + tree = tree_class(box=b) tree.set_coords(coords, cutoff=cutoff) match = "Set cutoff greater or equal to the radius." with pytest.raises(RuntimeError, match=match): @@ -89,22 +95,22 @@ def test_searchfail(): ([10, 10, 10, 45, 60, 90], [2.1, -3.1, 0.1], [2, 3]), ), ) -def test_search(b, q, result): +def test_search(tree_class, b, q, result): b = np.array(b, dtype=np.float32) q = transform_StoR(np.array(q, dtype=np.float32), b) cutoff = 3.0 coords = transform_StoR(f_dataset, b) - tree = PeriodicKDTree(box=b) + tree = tree_class(box=b) tree.set_coords(coords, cutoff=cutoff) indices = tree.search(q, cutoff) assert_equal(indices, result) -def test_nopbc(): +def test_nopbc(tree_class): cutoff = 0.3 q = np.array([0.2, 0.3, 0.1]) coords = f_dataset.copy() - tree = PeriodicKDTree(box=None) + tree = tree_class(box=None) tree.set_coords(coords) indices = tree.search(q, cutoff) assert_equal(indices, [0, 2]) @@ -123,11 +129,11 @@ def test_nopbc(): ([10, 10, 10, 45, 60, 90], 0.1, []), ), ) -def test_searchpairs(b, radius, result): +def test_searchpairs(tree_class, b, radius, result): b = np.array(b, dtype=np.float32) cutoff = 2.0 coords = transform_StoR(f_dataset, b) - tree = PeriodicKDTree(box=b) + tree = tree_class(box=b) tree.set_coords(coords, cutoff=cutoff) if cutoff < radius: with pytest.raises(RuntimeError, match=result): @@ -138,9 +144,9 @@ def test_searchpairs(b, radius, result): @pytest.mark.parametrize("radius, result", ((0.1, []), (0.3, [[0, 2]]))) -def test_ckd_searchpairs_nopbc(radius, result): +def test_ckd_searchpairs_nopbc(tree_class, radius, result): coords = f_dataset.copy() - tree = PeriodicKDTree() + tree = tree_class(box=None) tree.set_coords(coords) indices = tree.search_pairs(radius) assert_equal(indices, result) @@ -161,12 +167,12 @@ def test_ckd_searchpairs_nopbc(radius, result): [0, 3]]) )) # fmt: on -def test_searchtree(b, q, result): +def test_searchtree(tree_class, b, q, result): b = np.array(b, dtype=np.float32) cutoff = 3.0 coords = transform_StoR(f_dataset, b) q = transform_StoR(np.array(q, dtype=np.float32), b) - tree = PeriodicKDTree(box=b) + tree = tree_class(box=b) tree.set_coords(coords, cutoff=cutoff) pairs = tree.search_tree(q, cutoff) assert_equal(pairs, result)