Source code for SFI.langevin.overdamped

"""
Overdamped Langevin simulator (Euler–Maruyama / Heun) with post-run observables.
"""

from dataclasses import dataclass
from typing import Any, Dict, Optional

import jax.numpy as jnp
from jax import jit, vmap

from SFI.statefunc import PSF, SF
from SFI.utils.maths import as_default_float

from .base import Array, DiffusionKind, LangevinBase


[docs] @dataclass class OverdampedProcess(LangevinBase): """Overdamped Langevin simulator (Euler–Maruyama or stochastic Heun). Parameters ---------- F : PSF | SF Force model with `rank=vector`, `needs_v=False`, and `pdepth∈{0,1}`. If a PSF is provided, bind parameters via :meth:`set_params` prior to simulation. D : float | Array | PSF | SF Diffusion model: scalar σ (interpreted as σ·I), constant (d×d) matrix, or a PSF/SF with `rank=matrix`, `pdepth∈{0,1}` compatible with `F`. theta_F, theta_D : Optional[Array], optional Parameter vectors for binding PSF → SF. extras_global, extras_local : Optional[dict], optional Frozen, time-independent extras passed to both ``F`` and ``D`` at every call. Users should classify extras explicitly as: - extras_global: system-wide objects (geometry, external field ...) - extras_local: per-particle objects (species labels, radii, ...) At runtime these are merged into a single ``extras`` mapping, with local keys overriding global ones, and passed identically to both models. Notes ----- This class does **not** insert particle axes. The shapes must match the model contract: - If ``F.pdepth == 0``, ``x0.shape == (d,)``. - If ``F.pdepth == 1``, ``x0.shape == (P, d)``. Observables ----------- After the run (on recorded steps only), we compute: - information ``I`` approx ``0.25 * sum_t <dx_t, D_inv(x_t) . F(x_t)>`` - entropy ``S`` approx ``sum_t <dx_t, D_inv(x_mid) . (F(x_t)+F(x_{t+dt}))/2>`` where ``dx_t = x_{t+dt} - x_t`` and ``x_mid = (x_{t+dt}+x_t)/2``. We evaluate F(x) exactly once per recorded step and reuse it for both terms. """ # cached diffusion inverse for constant-D case _Dinv_const: Optional[Array] = None # bound diffusion callable for state-dependent case (so we can evaluate D at x_t / x_mid) _D_sf: Optional[SF] = None # ------------------------------- API --------------------------------
[docs] def initialize(self, x0: Array) -> None: """Initialize the process state. Parameters ---------- x0 : Array Initial position. Must satisfy: - If ``F.pdepth == 0``: shape ``(d,)`` - If ``F.pdepth == 1``: shape ``(P, d)`` Side effects ------------ Binds PSF parameters (if any), validates model contracts, and prepares diffusion shortcuts (constant vs state-dependent). """ # Basic contract inspection from the statefunc objects self._normalize_basis_to_psf() self._assert_force_contract(self.F) # Canonicalize dtype: any user-provided array (float32, float64, # numpy int, plain list, …) is cast to JAX's currently-active # default float dtype. This is the single normalization point # that makes lax.scan carry-in / carry-out dtypes always agree. x0 = as_default_float(x0) # Deduce (P, d) from x0 if x0.ndim == 1: d = int(x0.shape[0]) P = None elif x0.ndim == 2: P = int(x0.shape[0]) d = int(x0.shape[1]) else: raise ValueError("x0 must have shape (d,) or (P, d).") # Check against F.pdepth f_pdepth = getattr(self.F, "pdepth", None) if f_pdepth not in (0, 1): raise ValueError("F.pdepth must be 0 or 1 for overdamped simulations.") if f_pdepth == 0 and x0.ndim != 1: if x0.shape[0] == 1 and x0.ndim == 2: # Silently drop the first axis x0 = x0[0] else: raise ValueError("F expects no particle axis (pdepth=0): x0 must be (d,).") if f_pdepth == 1 and x0.ndim != 2: raise ValueError("F expects a particle axis (pdepth=1): x0 must be (P, d).") # Bind F now, keeping the unbound object intact self._F_sf = self._bind_force() # Bind D early *if* it's a PSF/SF, so structural extras can be prepared once for both. D_sf_for_extras: Optional[SF] = None if isinstance(self.D, SF): D_sf_for_extras = self.D elif isinstance(self.D, PSF): if self.theta_D is None: raise ValueError("Diffusion PSF not bound: call set_params(theta_D=...).") D_sf_for_extras = self.D.bind(self.theta_D) # type: ignore[attr-defined] # Eagerly materialize any structural extras needed by bound expressions. # This MUST happen before the first JIT-triggering evaluation. self._invalidate_prepared_extras() exprs = [self._F_sf] + ([D_sf_for_extras] if D_sf_for_extras is not None else []) extras = self._prepare_model_extras(x_probe=x0, exprs=exprs) # Validate output dimension of F on a cheap probe (shape-only). # Use the merged process-level extras seen by both F and D. f_out = self._F_sf(x0, extras=extras) if f_out.shape != x0.shape: raise ValueError(f"Force output shape {f_out.shape} does not match input shape {x0.shape}.") # Prepare diffusion shortcuts (constant vs state-dependent) for the integrator self._check_diffusion_contract(self.D, d=d, f_pdepth=f_pdepth) self._setup_diffusion(d=d, with_v=False) # Also prepare D or D^{-1} for post-run observables from SFI.langevin.noise import NoiseModel if isinstance(self.D, NoiseModel): # Use the effective per-site D for observables approximation extras = self._model_extras() D_eff = self.D.effective_D_per_site(extras) self._Dinv_const = jnp.linalg.pinv(D_eff) self._D_sf = None elif isinstance(self.D, (int, float)): sigma = float(self.D) self._Dinv_const = (1.0 / sigma) * jnp.eye(d) # pinv(σ I) = (1/σ) I self._D_sf = None elif isinstance(self.D, jnp.ndarray) and self.D.ndim == 2: self._Dinv_const = jnp.linalg.pinv(self.D) self._D_sf = None elif isinstance(self.D, (PSF, SF)): # Reuse the already-bound callable (if any) if D_sf_for_extras is None: # Defensive fallback; should not happen if the above binding logic ran if isinstance(self.D, SF): D_sf_for_extras = self.D else: if self.theta_D is None: raise ValueError("Diffusion PSF not bound: call set_params(theta_D=...).") D_sf_for_extras = self.D.bind(self.theta_D) # type: ignore[attr-defined] self._D_sf = D_sf_for_extras self._Dinv_const = None else: raise TypeError("D must be a float, (d×d) array, or a PSF/SF.") # Persist runtime state/metadata basics self._x = x0 num_particles = int(P) if P is not None else 1 self.metadata.clear() self.metadata.update( dict( kind="overdamped", dimension=d, pdepth=int(f_pdepth), num_particles=num_particles, x0=x0, ) )
[docs] def simulate( self, dt: float, Nsteps: int, key: Array, *, oversampling: int = 4, prerun: int = 0, jit_compile: bool = True, compute_observables: bool = True, method: str = "heun", ): r""" Integrate overdamped Langevin dynamics and return a :class:`TrajectoryCollection` of positions. Parameters ---------- dt Time step between recorded frames. Nsteps Number of recorded time steps. key PRNG key for the simulation. oversampling Number of integration substeps between recorded frames. The effective substep size is ``dt / oversampling``. Although all integrators have a consistent continuous limit, they introduce short-range, algorithm-specific temporal correlations at the scale of a single step. Downsampling by recording only every ``oversampling``-th substep ensures these artefacts never reach the inference layer. The default of 4 is a safe minimum for typical use; increase it when ``dt`` is large or the process varies rapidly. prerun Number of recorded steps to discard as burn-in, using the same ``dt`` and ``oversampling``. jit_compile If True, JIT-compile the single-step integrator before scanning. method Integration scheme. ``"heun"`` (default) selects the stochastic Heun predictor-corrector scheme, which achieves **weak order 2** for constant (additive) diffusion — the dominant use case — at the cost of two force evaluations per substep. For state-dependent diffusion the Heun scheme still uses the Itô-correct left-point noise evaluation, giving weak order 1 but with better error constants than Euler–Maruyama. ``"euler"`` selects the classical Euler–Maruyama integrator (weak order 1). compute_observables If True, compute post-run information and entropy production estimates on the recorded trajectory and store them in the dataset metadata under the ``"observables"`` key. .. physics:: Information functional & entropy production (overdamped) :label: info-entropy-overdamped :category: Observable .. math:: I \approx \tfrac{1}{4}\sum_t \mathrm{d}X_t^\top\, D^{-1}(x_t)\, F(x_t) .. math:: S \approx \sum_t \mathrm{d}X_t^\top\, D^{-1}(x_{\text{mid}})\, \tfrac{1}{2}\bigl[F(x_t)+F(x_{t+1})\bigr] :math:`I` estimates the information content; :math:`S` the entropy production (time-reversal asymmetry). Returns ------- TrajectoryCollection A collection with a single dataset containing the positions. The underlying dataset has: - ``X`` of shape ``(Nsteps, d)`` or ``(Nsteps, P, d)``, - metadata combining model info (kind, dimension, pdepth, etc.), run info (dt, Nsteps, oversampling, prerun), and optional observables. """ if self._F_sf is None: raise RuntimeError("Call initialize(x0) before simulate().") _valid_methods = ("euler", "heun") if method not in _valid_methods: raise ValueError(f"Unknown method {method!r}; expected one of {_valid_methods}.") self._method = method # Split extras into static values and per-frame schedules (the # time-dependent ones: TimeSeriesExtra of length Nsteps, or f(t) # callables materialized at the frame times). static_extras, schedules, eg_out, el_out = self._materialize_step_extras(dt=dt, Nsteps=Nsteps) step = self._make_step(static_extras) if jit_compile: step = jit(step) traj, info = self._scan( self._x, step_fn=step, dt=dt, Nsteps=Nsteps, oversampling=oversampling, prerun=prerun, key=key, schedules=schedules or None, ) # Update final state for continuation self._x = traj[-1] # Run-level metadata snapshot run_meta: Dict[str, Any] = dict(self.metadata) run_meta.update(dict(dt=float(dt), integrator=method, **info)) if schedules: run_meta["time_dependent_extras"] = sorted(schedules) # Optional post-run observables (information/entropy) if compute_observables: X = traj dX = X[1:] - X[:-1] X_mid = 0.5 * (X[1:] + X[:-1]) extras = static_extras # Frame-aligned extras: a[k] governs [X[k] -> X[k+1]] (zeroth- # order hold), so the evaluation at X[k] pairs with the frame-k # schedule slice. if schedules: def _ex_at(s): return {**(extras or {}), **s} F_all = vmap( lambda x, s: self._F_sf(x, extras=_ex_at(s)) # type: ignore[misc] )(X, schedules) else: F_all = vmap( lambda x, _extras=extras: self._F_sf(x, extras=_extras) # type: ignore[misc] )(X) F_t, F_tp = F_all[:-1], F_all[1:] F_avg = 0.5 * (F_t + F_tp) if self._Dinv_const is not None: I_terms = jnp.einsum("...m,...n,mn->", dX, F_t, self._Dinv_const) S_terms = jnp.einsum("...m,...n,mn->", dX, F_avg, self._Dinv_const) else: if self._D_sf is None: raise RuntimeError("State-dependent diffusion SF not initialized.") if schedules: sched_t = {k: v[:-1] for k, v in schedules.items()} Dinv_t = vmap( lambda x, s: jnp.linalg.pinv(self._D_sf(x, extras=_ex_at(s))) # type: ignore[misc] )(X[:-1], sched_t) Dinv_mid = vmap( lambda x, s: jnp.linalg.pinv(self._D_sf(x, extras=_ex_at(s))) # type: ignore[misc] )(X_mid, sched_t) else: Dinv_t = vmap( lambda x, _extras=extras: jnp.linalg.pinv(self._D_sf(x, extras=_extras)) # type: ignore[misc] )(X[:-1]) Dinv_mid = vmap( lambda x, _extras=extras: jnp.linalg.pinv(self._D_sf(x, extras=_extras)) # type: ignore[misc] )(X_mid) I_terms = jnp.einsum("...m,...n,...mn->", dX, F_t, Dinv_t) S_terms = jnp.einsum("...m,...n,...mn->", dX, F_avg, Dinv_mid) observables = { "information": float(0.25 * I_terms), "entropy": float(S_terms), } run_meta["observables"] = observables # Hand off to the base helper: positions → TrajectoryCollection coll = self._traj_to_collection( traj, dt=dt, meta=run_meta, extras_global_out=eg_out, extras_local_out=el_out ) return coll
# ----------------------------- Internals ----------------------------- def _make_step(self, extras=None): """Dispatch to the selected integration scheme. ``extras`` is the merged *static* extras dict; per-frame scheduled values arrive through the step's optional ``sched`` argument and override it. """ method = getattr(self, "_method", "euler") if extras is None: extras = self._model_extras() if method == "heun": return self._make_heun_step(extras) return self._make_euler_step(extras) def _make_euler_step(self, extras=None): r"""Create the Euler–Maruyama substep function (no observables inside). .. physics:: Euler–Maruyama integrator (overdamped) :label: euler-maruyama-overdamped :category: Simulation .. math:: x_{t+\mathrm{d}t} = x_t + \mathrm{d}t\, F(x_t) + \sqrt{\mathrm{d}t}\, B(x_t)\,\xi_t where :math:`B = \sqrt{2D}` and :math:`\xi_t \sim \mathcal{N}(0, I)`. """ F = self._F_sf kind = self._D_kind Bc = self._B_const Bf = self._B_fn noise_model = self._noise_model static_extras = extras if F is None: raise RuntimeError("Force not bound. Did you call initialize() after set_params()?") def step(x: Array, ddt: float, key: Array, sched=None) -> Array: """Single Euler–Maruyama substep.""" ex = static_extras if sched is None else {**(static_extras or {}), **sched} # Drift term drift = F(x, extras=ex) # Noise increment if kind is DiffusionKind.NOISE_MODEL: assert noise_model is not None inc = noise_model.sample(key, x, ex) elif kind is DiffusionKind.STATE_FUNC: assert Bf is not None xi = self._noise(key, x) Bx = Bf(x, extras=ex) inc = self._apply_B(Bx, xi, state_dependent=True) else: xi = self._noise(key, x) inc = self._apply_B(Bc, xi, state_dependent=False) return x + ddt * drift + jnp.sqrt(ddt) * inc return step def _make_heun_step(self, extras=None): r"""Create the stochastic Heun (predictor-corrector) substep function. .. physics:: Stochastic Heun integrator (overdamped) :label: heun-overdamped :category: Simulation Predictor (Euler): .. math:: \hat x = x_t + \mathrm{d}t\, F(x_t) + \sqrt{\mathrm{d}t}\, B(x_t)\,\xi_t Corrector (trapezoidal drift): .. math:: x_{t+\mathrm{d}t} = x_t + \tfrac{1}{2}\mathrm{d}t\,[F(x_t) + F(\hat x)] + \sqrt{\mathrm{d}t}\, B(x_t)\,\xi_t For constant (additive) diffusion this achieves **weak order 2**. For state-dependent diffusion it uses the left-point noise evaluation to preserve the Itô convention, giving weak order 1 but with better error constants than Euler–Maruyama. Costs two force evaluations per substep. """ F = self._F_sf kind = self._D_kind Bc = self._B_const Bf = self._B_fn noise_model = self._noise_model static_extras = extras if F is None: raise RuntimeError("Force not bound. Did you call initialize() after set_params()?") def step(x: Array, ddt: float, key: Array, sched=None) -> Array: """Single stochastic Heun substep.""" ex = static_extras if sched is None else {**(static_extras or {}), **sched} drift = F(x, extras=ex) # Noise increment (evaluated at left point for Itô correctness) if kind is DiffusionKind.NOISE_MODEL: assert noise_model is not None noise_inc = jnp.sqrt(ddt) * noise_model.sample(key, x, ex) elif kind is DiffusionKind.STATE_FUNC: assert Bf is not None xi = self._noise(key, x) Bx = Bf(x, extras=ex) noise_inc = jnp.sqrt(ddt) * self._apply_B(Bx, xi, state_dependent=True) else: xi = self._noise(key, x) noise_inc = jnp.sqrt(ddt) * self._apply_B(Bc, xi, state_dependent=False) # Euler predictor x_pred = x + ddt * drift + noise_inc # Corrector: trapezoidal average of drift, same noise realization # (predictor and corrector share the frame's scheduled extras — # the protocol is piecewise constant by definition). drift_pred = F(x_pred, extras=ex) return x + 0.5 * ddt * (drift + drift_pred) + noise_inc return step # --------------------------- Validations ----------------------------- @staticmethod def _assert_force_contract(F: PSF | SF) -> None: """Validate force has the right rank/dim flags for overdamped use.""" needs_v = getattr(F, "needs_v", False) if needs_v: raise ValueError("Overdamped force must not require velocity (needs_v=False).") rank = getattr(F, "rank", None) if rank != 1: raise ValueError("Force must have rank=vector.") dim = getattr(F, "dim", None) if dim is None or not isinstance(dim, int): raise ValueError("Force must declare an integer `dim` attribute.") pdepth = getattr(F, "pdepth", None) if pdepth not in (0, 1): raise ValueError("Force `pdepth` must be 0 or 1 for overdamped simulations.") @staticmethod def _check_diffusion_contract(D, *, d: int, f_pdepth: int) -> None: """Validate diffusion against the force contract, allowing benign broadcast.""" # NoiseModel instances from SFI.langevin.noise import NoiseModel if isinstance(D, NoiseModel): if D.dim != d: raise ValueError(f"NoiseModel n_fields={D.dim} must match force dim={d}.") return # Scalars and constant matrices are fine if isinstance(D, (int, float)): return if isinstance(D, jnp.ndarray) and D.ndim == 2: if D.shape != (d, d): raise ValueError(f"Constant diffusion must be shape (d,d)={(d, d)}, got {D.shape}.") return # PSF/SF path if not isinstance(D, (PSF, SF)): raise TypeError("Diffusion must be a float, (d×d) array, PSF/SF, or NoiseModel.") rank = getattr(D, "rank", None) if rank != 2: raise ValueError("Diffusion PSF/SF must have rank=matrix.") dim = getattr(D, "dim", None) if dim != d: raise ValueError(f"Diffusion dim={dim} must match force dim={d}.") d_pdepth = getattr(D, "pdepth", None) if d_pdepth not in (0, 1): raise ValueError("Diffusion `pdepth` must be 0 or 1.") if not (d_pdepth == f_pdepth or d_pdepth == 0): raise ValueError( f"Incompatible pdepth: force pdepth={f_pdepth}, diffusion pdepth={d_pdepth}. " "Only equal depths or diffusion depth=0 (broadcast) are allowed." )