"""Cell-list neighbor builder for truncated-range pair interactions.
Provides :func:`build_neighbor_csr` which constructs a sparse CSR
neighbor list from particle positions and a cutoff radius, using
``scipy.spatial.cKDTree``. The returned ``(indptr, indices)`` arrays
plug directly into ``dispatch_pairs_from_extras``.
All routines run on the host (pure NumPy) and are meant to be called
*between* JIT-compiled simulation chunks, not inside ``jax.lax.scan``.
"""
from __future__ import annotations
from typing import Optional, Tuple
import jax.numpy as jnp
import numpy as np
from scipy.spatial import cKDTree
[docs]
def build_neighbor_csr(
    positions: np.ndarray,
    cutoff: float,
    box: Optional[np.ndarray] = None,
    *,
    exclude_self: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build a CSR neighbor list using ``scipy.spatial.cKDTree``.

    Every ordered pair ``(i, j)`` whose separation does not exceed
    ``cutoff`` is recorded, so each unordered pair appears twice (once in
    row ``i`` and once in row ``j``). Within a row the column indices are
    sorted in ascending order.

    Parameters
    ----------
    positions : ndarray, shape ``(N, d)``
        Particle positions (spatial coordinates only).
    cutoff : float
        Cutoff radius. Pairs with ``r_ij > cutoff`` are excluded.
    box : ndarray, shape ``(d,)``, optional
        Periodic box lengths. If *None*, open (non-periodic) boundaries.
        When given, positions are wrapped into ``[0, box)`` before the
        tree is built.
    exclude_self : bool
        If *True* (default), self-pairs ``(i, i)`` are never included.
        Distinct particles that happen to sit at exactly the same
        position are still reported as neighbours of each other.

    Returns
    -------
    indptr : ndarray, shape ``(N + 1,)``, dtype int32
        CSR row pointers.
    indices : ndarray, shape ``(nnz,)``, dtype int32
        CSR column indices (neighbour particle indices).
    """
    positions = np.asarray(positions, dtype=np.float64)
    N, d = positions.shape
    if N == 0:
        return (
            np.zeros(1, dtype=np.int32),
            np.empty(0, dtype=np.int32),
        )

    # --- wrap positions into the primary box ---
    if box is not None:
        box = np.asarray(box, dtype=np.float64)
        positions = positions % box
        # Floating-point modulo can round a tiny negative coordinate up to
        # exactly the box length, which cKDTree rejects; that point is
        # periodically equivalent to 0.
        positions = np.where(positions >= box, 0.0, positions)

    # --- build KD-tree and query pairs ---
    tree = cKDTree(positions, boxsize=box)
    pairs = tree.sparse_distance_matrix(tree, cutoff, output_type="coo_matrix")
    rows = pairs.row.astype(np.int64)
    cols = pairs.col.astype(np.int64)

    if exclude_self:
        # Filter on the indices rather than on the stored distance: a zero
        # distance also occurs for distinct, coincident particles, which
        # must remain in the list.
        off_diagonal = rows != cols
        rows = rows[off_diagonal]
        cols = cols[off_diagonal]

    # --- assemble CSR ---
    # Encoding each pair as a single row-major key lets np.unique both
    # sort the entries (by row, then column) and drop any duplicates.
    keys = np.unique(rows * N + cols)
    indices = (keys % N).astype(np.int32)
    per_row = np.bincount(keys // N, minlength=N)
    indptr = np.zeros(N + 1, dtype=np.int32)
    indptr[1:] = np.cumsum(per_row)
    return indptr, indices
[docs]
def pad_neighbor_csr(
    indptr: np.ndarray,
    indices: np.ndarray,
    max_nnz: int,
    *,
    fill_index: int = 0,
) -> Tuple[np.ndarray, np.ndarray]:
    """Extend the ``indices`` array of a CSR neighbor list to ``max_nnz``.

    A fixed-length ``indices`` array keeps array shapes constant between
    simulation chunks, which avoids JAX JIT recompilation when the number
    of neighbours fluctuates.

    Only ``indices`` is padded; ``indptr`` is handed back untouched, so
    it still describes the true row lengths and the trailing filler
    entries lie outside every row.

    Parameters
    ----------
    indptr, indices
        As returned by :func:`build_neighbor_csr`.
    max_nnz : int
        Target length for ``indices``.
    fill_index : int
        Value written into the padded tail (default 0).

    Returns
    -------
    indptr, indices_padded
        The same ``indptr`` object and an ``indices`` array of length
        ``max_nnz``. When no padding is needed the input ``indices`` is
        returned as-is rather than copied.

    Raises
    ------
    ValueError
        If ``indices`` already holds more than ``max_nnz`` entries.
    """
    nnz = len(indices)
    shortfall = max_nnz - nnz

    if shortfall < 0:
        raise ValueError(
            f"Actual nnz ({nnz}) exceeds max_nnz ({max_nnz}). Increase max_nnz or enlarge the cutoff safety margin."
        )
    if shortfall == 0:
        # Already the requested length: nothing to append.
        return indptr, indices

    # Start from an all-filler buffer and copy the real entries to the front.
    out = np.full(max_nnz, fill_index, dtype=indices.dtype)
    out[:nnz] = indices
    return indptr, out