"""Chunked simulation with periodic neighbor-list rebuilds.
Provides :func:`simulate_chunked`, a thin wrapper around
:meth:`OverdampedProcess.simulate` that breaks a long simulation into
shorter chunks and rebuilds the CSR neighbor list (via
:func:`~SFI.utils.neighbors.build_neighbor_csr`) between chunks.
This avoids the O(N²) cost of ``AutoPairs`` for large particle systems
while keeping the extras interface exactly as-is.
"""
from __future__ import annotations
from typing import Optional
import jax
import jax.numpy as jnp
import numpy as np
from SFI.trajectory.collection import TrajectoryCollection
from SFI.utils.neighbors import build_neighbor_csr, pad_neighbor_csr
[docs]
def simulate_chunked(
proc,
dt: float,
Nsteps: int,
key,
*,
cutoff: float,
box: np.ndarray,
skin: float = 0.0,
rebuild_every: int = 50,
save_every: Optional[int] = None,
spatial_dims: slice = slice(None, 2),
indptr_key: str = "indptr",
indices_key: str = "indices",
nnz_safety: float = 1.25,
oversampling: int = 1,
prerun: int = 0,
compute_observables: bool = False,
jit_compile: bool = True,
verbose: bool = False,
) -> TrajectoryCollection:
"""Run a chunked overdamped simulation with periodic neighbor rebuilds.
Parameters
----------
proc : OverdampedProcess
An initialized process whose force uses
``dispatch_pairs_from_extras(indptr_key, indices_key)``.
dt : float
Time step.
Nsteps : int
Total number of steps.
key : jax PRNG key
Random key for the simulation.
cutoff : float
Cutoff radius for neighbor list construction.
box : array-like, shape ``(d,)``
Periodic box lengths.
skin : float
Verlet skin width. The neighbor list is built with radius
``cutoff + skin`` so that particles drifting into range between
rebuilds are already included. After each chunk the maximum
particle displacement is checked; a warning is printed if it
exceeds ``skin / 2`` (the Verlet safety threshold).
rebuild_every : int
Number of simulation steps between neighbor-list rebuilds.
save_every : int, optional
Number of simulation steps per output dataset. If *None*,
defaults to ``rebuild_every``. Must be a multiple of
``rebuild_every``.
spatial_dims : slice
Slice into the state vector that selects spatial coordinates
(default: first two components ``[:2]``).
indptr_key, indices_key : str
Extras keys for the CSR neighbor list.
nnz_safety : float
Fraction by which ``max_nnz`` is enlarged beyond the initial
neighbor count to absorb fluctuations (default 1.25 = 25%).
oversampling, prerun, compute_observables, jit_compile
Forwarded to ``proc.simulate()``.
verbose : bool
Print progress info.
Returns
-------
TrajectoryCollection
Concatenated trajectory from all chunks.
"""
# Time-dependent extras are not supported here: simulate() is re-invoked
# per rebuild chunk, which would mis-slice frame-aligned schedules.
from SFI.trajectory.dataset import TimeSeriesExtra
for src in (proc.extras_global, proc.extras_local):
for k, v in (src or {}).items():
if isinstance(v, TimeSeriesExtra) or (callable(v) and not hasattr(v, "func")):
raise NotImplementedError(
f"simulate_chunked does not support time-dependent extras (got {k!r}); "
"use proc.simulate() directly."
)
box = np.asarray(box, dtype=np.float64)
cutoff_list = cutoff + skin # Verlet list radius
if save_every is None:
save_every = rebuild_every
if save_every % rebuild_every != 0:
raise ValueError(f"save_every ({save_every}) must be a multiple of rebuild_every ({rebuild_every})")
rebuilds_per_save = save_every // rebuild_every
# --- initial neighbor list ---
positions = np.asarray(proc._x) # (P, d) or (d,)
# Assumes layout (P, d) for ndim==2 or (d,) for ndim==1; (d, P) would silently mislabel.
pos_spatial = positions[:, spatial_dims] if positions.ndim == 2 else positions[spatial_dims]
indptr, indices = build_neighbor_csr(pos_spatial, cutoff_list, box)
pos_at_rebuild = pos_spatial.copy() # track displacements
# Fixed max_nnz with safety margin
max_nnz = max(int(len(indices) * nnz_safety), len(indices) + 1)
indptr_pad, indices_pad = pad_neighbor_csr(indptr, indices, max_nnz)
# Merge CSR into existing extras (don't clobber other keys)
base_extras = dict(proc.extras_global or {})
base_extras[indptr_key] = jnp.array(indptr_pad)
base_extras[indices_key] = jnp.array(indices_pad)
proc.set_extras(extras_global=base_extras)
# --- chunk the simulation ---
n_rebuilds = max(1, (Nsteps + rebuild_every - 1) // rebuild_every)
remaining = Nsteps
collections = [] # final output datasets
sub_collections = [] # accumulate rebuild-chunks within a save-chunk
step_done = 0
max_disp_in_save = 0.0 # track across rebuilds within a save-chunk
for rebuild_i in range(n_rebuilds):
chunk_steps = min(rebuild_every, remaining)
if chunk_steps <= 0:
break
key, sub_key = jax.random.split(key)
coll = proc.simulate(
dt,
Nsteps=chunk_steps,
key=sub_key,
oversampling=oversampling,
prerun=prerun if rebuild_i == 0 else 0,
compute_observables=compute_observables,
jit_compile=jit_compile,
)
sub_collections.append(coll)
remaining -= chunk_steps
step_done += chunk_steps
# Emit a save-chunk when we've accumulated enough rebuilds
if (rebuild_i + 1) % rebuilds_per_save == 0 or remaining <= 0:
# Merge sub-collection X arrays into one contiguous dataset.
# Each sub-collection has one dataset with X shape (T_sub, N, d)
# where T_sub = rebuild_every (no duplicate frames).
X_parts = [np.asarray(sc.datasets[0].X) for sc in sub_collections]
X_merged = np.concatenate(X_parts, axis=0)
if verbose:
print(f" merged {len(sub_collections)} sub-chunks → X shape {X_merged.shape}")
merged = TrajectoryCollection.from_arrays(
X=X_merged,
dt=dt,
)
collections.append(merged)
sub_collections = []
if verbose:
print(f" chunk {len(collections)}: {step_done}/{Nsteps} steps")
# Rebuild neighbors from updated positions
if remaining > 0:
positions = np.asarray(proc._x) # (P, d) or (d,)
pos_spatial = positions[:, spatial_dims] if positions.ndim == 2 else positions[spatial_dims]
# --- Verlet displacement check ---
disp = pos_spatial - pos_at_rebuild
if box is not None:
disp = disp - box * np.round(disp / box)
max_disp = float(np.sqrt((disp * disp).sum(axis=-1)).max())
max_disp_in_save = max(max_disp_in_save, max_disp)
if skin > 0 and max_disp > skin / 2:
import warnings
warnings.warn(
f"Rebuild {rebuild_i}: max displacement {max_disp:.3f} "
f"> skin/2 = {skin / 2:.3f}. Neighbor list may have "
f"missed interactions. Increase skin or decrease "
f"rebuild_every.",
stacklevel=2,
)
if verbose and (rebuild_i + 1) % rebuilds_per_save == 0:
print(f" max displacement = {max_disp_in_save:.3f} (skin/2 = {skin / 2:.3f})")
max_disp_in_save = 0.0
indptr, indices = build_neighbor_csr(pos_spatial, cutoff_list, box)
pos_at_rebuild = pos_spatial.copy()
# Grow max_nnz if needed
if len(indices) > max_nnz:
max_nnz = int(len(indices) * nnz_safety)
if verbose:
print(f" max_nnz grew to {max_nnz}")
indptr_pad, indices_pad = pad_neighbor_csr(indptr, indices, max_nnz)
# Update extras, preserving all non-CSR keys
updated_extras = dict(proc.extras_global or {})
updated_extras[indptr_key] = jnp.array(indptr_pad)
updated_extras[indices_key] = jnp.array(indices_pad)
proc.set_extras(extras_global=updated_extras)
# Free old JIT caches to prevent GPU memory buildup
jax.clear_caches()
# --- concatenate chunks ---
if len(collections) == 1:
return collections[0]
return collections[0].concat(collections[1:], weights="pool")