# underdamped.py
"""
Underdamped Langevin simulator (velocity-Verlet-like, generic F(x, v) and D(x[, v])).

This module mirrors :mod:`overdamped` as closely as possible, but simulates the
phase-space SDE::

    dx = v dt
    dv = F(x, v) dt + sqrt(2 D(x, v)) dW

where the diffusion acts on the *velocity* increments. The returned
:class:`~SFI.trajectory.collection.TrajectoryCollection` stores **positions
only**, by design.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Optional
import jax
import jax.numpy as jnp
from jax import jit, lax, random
from SFI.statefunc import PSF, SF
from SFI.utils.maths import as_default_float
from .base import Array, DiffusionKind, LangevinBase
@dataclass
class UnderdampedProcess(LangevinBase):
    """Underdamped Langevin simulator.

    Parameters
    ----------
    F : PSF | SF
        Force model with `rank=vector`, `needs_v=True`, and `pdepth∈{0,1}`.
        If a PSF is provided, bind parameters via :meth:`set_params` prior to
        simulation.
    D : float | Array | PSF | SF
        Diffusion model acting on velocities: scalar σ (interpreted as σ·I),
        constant (d×d) matrix, or a PSF/SF with `rank=matrix`.
        If provided as PSF/SF, it may depend on (x) or (x, v), controlled by
        its `needs_v` flag. A :class:`~SFI.langevin.noise.NoiseModel` whose
        `dim` matches the position dimension is also accepted by the
        diffusion contract check.

    Notes
    -----
    This class does **not** insert particle axes; it follows the `pdepth`
    convention of the statefunc objects, similarly to :class:`OverdampedProcess`.
    """

    # Whether the (state-dependent) diffusion SF requires v.
    _D_needs_v: bool = False
    # Bound diffusion callable (only used for eager structural-extras preparation).
    _D_sf: Optional[SF] = None

    # ------------------------------- API --------------------------------

    def initialize(self, x0: Array, v0: Optional[Array] = None) -> None:
        """Initialize the process state.

        Parameters
        ----------
        x0 : Array
            Initial position. Must satisfy:

            - If `F.pdepth == 0`: shape (d,). A leading singleton particle
              axis, i.e. shape (1, d), is accepted and squeezed away.
            - If `F.pdepth == 1`: shape (P, d)
        v0 : Array, optional
            Initial velocity. Must have the same shape as `x0`. Defaults to 0.

        Raises
        ------
        ValueError
            If `x0`/`v0` shapes are inconsistent with each other or with the
            force's `pdepth`, if a diffusion PSF has no bound parameters, if the
            force output shape differs from `x0`, or if the force/diffusion
            contracts are violated.
        TypeError
            If the diffusion is of an unsupported type.
        """
        self._normalize_basis_to_psf()
        self._assert_force_contract(self.F)

        # Canonicalize dtype (single normalization point; see
        # OverdampedProcess.initialize for rationale). ``v0=None`` must stay
        # None so that the zero-velocity default below is applied.
        x0 = as_default_float(x0)
        if v0 is not None:
            v0 = as_default_float(v0)

        # Deduce (P, d) from x0.
        if x0.ndim == 1:
            d = int(x0.shape[0])
            P = None
        elif x0.ndim == 2:
            P = int(x0.shape[0])
            d = int(x0.shape[1])
        else:
            raise ValueError("x0 must have shape (d,) or (P, d).")

        f_pdepth = getattr(self.F, "pdepth", None)
        if f_pdepth not in (0, 1):
            raise ValueError("F.pdepth must be 0 or 1 for underdamped simulations.")
        if f_pdepth == 0 and x0.ndim != 1:
            # x0 is 2-D here; only a single-particle leading axis can be dropped.
            if x0.shape[0] == 1:
                x0 = x0[0]
                if v0 is not None and v0.ndim == 2 and v0.shape[0] == 1:
                    v0 = v0[0]
            else:
                raise ValueError("F expects no particle axis (pdepth=0): x0 must be (d,).")
        if f_pdepth == 1 and x0.ndim != 2:
            raise ValueError("F expects a particle axis (pdepth=1): x0 must be (P, d).")

        if v0 is None:
            v0 = jnp.zeros_like(x0)
        if v0.shape != x0.shape:
            raise ValueError(f"v0.shape must match x0.shape; got v0={v0.shape}, x0={x0.shape}.")

        # Bind models.
        self._F_sf = self._bind_force()
        if isinstance(self.D, (PSF, SF)):
            if isinstance(self.D, SF):
                self._D_sf = self.D
            else:
                if self.theta_D is None:
                    raise ValueError("Diffusion PSF not bound: call set_params(theta_D=...).")
                self._D_sf = self.D.bind(self.theta_D)  # type: ignore[attr-defined]
            self._D_needs_v = bool(getattr(self.D, "needs_v", False))
        else:
            self._D_sf = None
            self._D_needs_v = False

        # Prepare structural extras once for all bound expressions.
        # (The base helper caches preparation globally; pass the full list here.)
        self._invalidate_prepared_extras()
        exprs = [self._F_sf] + ([self._D_sf] if self._D_sf is not None else [])
        extras = self._prepare_model_extras(x_probe=x0, v_probe=v0, exprs=exprs)

        # Validate force output shape with a single probe evaluation.
        f_out = self._F_sf(x0, v=v0, extras=extras)
        if f_out.shape != x0.shape:
            raise ValueError(f"Force output shape {f_out.shape} does not match input shape {x0.shape}.")

        # Diffusion (constant vs state-dependent).
        self._check_diffusion_contract(self.D, d=d, f_pdepth=int(f_pdepth))
        self._setup_diffusion(d=d, with_v=self._D_needs_v)

        # Persist runtime state/metadata basics.
        self._x = x0
        self._v = v0
        num_particles = int(P) if P is not None else 1
        self.metadata.clear()
        self.metadata.update(
            dict(
                kind="underdamped",
                dimension=d,
                pdepth=int(f_pdepth),
                num_particles=num_particles,
                x0=x0,
                v0=v0,
            )
        )

    def simulate(
        self,
        dt: float,
        Nsteps: int,
        key: Array,
        *,
        oversampling: int = 4,
        prerun: int = 0,
        jit_compile: bool = True,
        compute_observables: bool = False,
    ):
        """Run the integrator and return a :class:`TrajectoryCollection` of positions.

        Parameters
        ----------
        dt
            Time step between recorded frames.
        Nsteps
            Number of recorded time steps (must be >= 0).
        key
            PRNG key for the simulation.
        oversampling
            Number of velocity-Verlet substeps between recorded frames
            (must be >= 1). The effective substep size is ``dt / oversampling``.
            Although all integrators have a consistent continuous limit, they
            introduce short-range, algorithm-specific temporal correlations at
            the scale of a single step. Downsampling by recording only every
            ``oversampling``-th substep keeps these artefacts out of
            the inference layer. The default of 4 is a safe minimum for
            typical use; increase it when ``dt`` is large or the process
            varies rapidly.
        prerun
            Number of recorded steps to discard as burn-in (must be >= 0).
        jit_compile
            If True, JIT-compile the single-step integrator before scanning.
        compute_observables
            Not yet implemented for the underdamped case.

        Returns
        -------
        TrajectoryCollection
            A collection with a single dataset containing the positions only
            (velocities are not stored by design). The underlying dataset has:

            - ``X`` of shape ``(Nsteps, d)`` or ``(Nsteps, P, d)``,
            - metadata combining model info (kind, dimension, pdepth, etc.)
              and run info (dt, Nsteps, oversampling, prerun).

        Raises
        ------
        RuntimeError
            If :meth:`initialize` has not been called.
        NotImplementedError
            If ``compute_observables`` is requested.
        ValueError
            If ``Nsteps``, ``oversampling`` or ``prerun`` are out of range.
        """
        if self._F_sf is None:
            raise RuntimeError("Call initialize(x0, v0=...) before simulate().")
        if compute_observables:
            raise NotImplementedError(
                "Underdamped observables are not implemented in this variant, "
                "because velocities are intentionally not stored."
            )
        # Fail fast, before extras are materialized for an invalid run length.
        self._validate_scan_counts(Nsteps=Nsteps, oversampling=oversampling, prerun=prerun)

        # Split extras into static values and per-frame schedules.
        static_extras, schedules, eg_out, el_out = self._materialize_step_extras(dt=dt, Nsteps=Nsteps)
        step = self._make_step(static_extras)
        if jit_compile:
            step = jit(step)

        # Carry is (x, v), but we record positions only.
        traj_x, final_state, info = self._scan_positions(
            (self._x, self._v),
            step_fn=step,
            dt=dt,
            Nsteps=Nsteps,
            oversampling=oversampling,
            prerun=prerun,
            key=key,
            schedules=schedules or None,
        )
        # Persist the end state so consecutive simulate() calls continue the run.
        self._x, self._v = final_state

        run_meta: Dict[str, Any] = dict(self.metadata)
        run_meta.update(dict(dt=float(dt), **info))
        if schedules:
            run_meta["time_dependent_extras"] = sorted(schedules)
        return self._traj_to_collection(
            traj_x, dt=dt, meta=run_meta, extras_global_out=eg_out, extras_local_out=el_out
        )

    # ----------------------------- Internals -----------------------------

    def _make_step(self, extras=None):
        r"""Create the substep function (kick–drift–kick, velocity-Verlet-like).

        .. physics:: Velocity-Verlet integrator (underdamped)
            :label: velocity-verlet-underdamped
            :category: Simulation

            Stochastic splitting (kick–drift–kick):

            .. math::

                v_{1/2} &= v + \tfrac{1}{2}\mathrm{d}t\, F(x,v)
                    + \sqrt{\mathrm{d}t/2}\; B(x,v)\,\xi_1 \\
                x' &= x + \mathrm{d}t\, v_{1/2} \\
                v' &= v_{1/2} + \tfrac{1}{2}\mathrm{d}t\, F(x', v_{1/2})
                    + \sqrt{\mathrm{d}t/2}\; B(x', v_{1/2})\,\xi_2

            Preserves the symplectic structure of the deterministic part.

        Parameters
        ----------
        extras
            Static extras passed to every force/diffusion evaluation. When None,
            :meth:`_model_extras` is used.

        Returns
        -------
        callable
            ``step(state, ddt, key, sched=None) -> (x_new, v_new)`` where
            ``state = (x, v)`` and ``sched`` optionally overrides extras for the
            current frame.
        """
        F = self._F_sf
        if extras is None:
            extras = self._model_extras()
        kind = self._D_kind
        Bc = self._B_const
        Bf = self._B_fn
        D_needs_v = self._D_needs_v
        if F is None:
            raise RuntimeError("Force not bound. Did you call initialize() after set_params()?")

        def _eval_B(x: Array, v: Array, ex) -> Array:
            # Noise amplitude B at (x, v): state-dependent callable or constant.
            if kind is DiffusionKind.STATE_FUNC:
                if Bf is None:
                    raise RuntimeError("State-dependent diffusion not initialized.")
                return Bf(x, v, extras=ex) if D_needs_v else Bf(x, extras=ex)
            if Bc is None:
                raise RuntimeError("Constant diffusion not initialized.")
            return Bc

        def _apply(B: Array, xi: Array) -> Array:
            # Be permissive: if B happens to be (d,d), treat it as constant.
            return self._apply_B(B, xi, state_dependent=(B.ndim > 2))

        def step(state, ddt: float, key: Array, sched=None):
            x, v = state
            k1, k2 = random.split(key)
            # Each half-kick contributes noise of variance (ddt/2)·B Bᵀ.
            noise_scale = jnp.sqrt(ddt / 2.0)
            # Both half-kicks use the frame's scheduled value (zeroth-order
            # hold: the protocol is piecewise constant by definition).
            ex = extras if sched is None else {**(extras or {}), **sched}

            # --- First half-kick
            a0 = F(x, v=v, extras=ex)
            dv1 = _apply(_eval_B(x, v, ex), self._noise(k1, v))
            v_half = v + 0.5 * ddt * a0 + noise_scale * dv1

            # --- Drift
            x_new = x + ddt * v_half

            # --- Second half-kick (noise template only needs v's shape)
            a1 = F(x_new, v=v_half, extras=ex)
            dv2 = _apply(_eval_B(x_new, v_half, ex), self._noise(k2, v))
            v_new = v_half + 0.5 * ddt * a1 + noise_scale * dv2
            return (x_new, v_new)

        return step

    @staticmethod
    def _validate_scan_counts(*, Nsteps: int, oversampling: int, prerun: int) -> None:
        """Reject loop counts that cannot describe a meaningful scan."""
        if oversampling < 1:
            raise ValueError("oversampling must be >= 1")
        if Nsteps < 0:
            raise ValueError("Nsteps must be >= 0")
        if prerun < 0:
            raise ValueError("prerun must be >= 0")

    @staticmethod
    def _scan_positions(
        initial_state,
        *,
        step_fn,
        dt: float,
        Nsteps: int,
        oversampling: int,
        prerun: int,
        key: Array,
        schedules=None,
    ):
        """Scan loop that records positions only (carry contains ``(x, v)``).

        ``schedules`` follows the same per-frame contract as
        :meth:`LangevinBase._scan`: a mapping from extra name to an array whose
        leading axis is scanned over recorded frames. When it is empty or None,
        ``step_fn`` is called as ``step_fn(state, ddt, key)``; otherwise as
        ``step_fn(state, ddt, key, sched)``.

        Returns
        -------
        tuple
            ``(traj_x, final_state, info)`` where ``traj_x`` stacks the recorded
            positions along a new leading axis of length ``Nsteps``,
            ``final_state`` is the ``(x, v)`` carry after the last substep, and
            ``info`` holds the run counters.
        """
        UnderdampedProcess._validate_scan_counts(
            Nsteps=Nsteps, oversampling=oversampling, prerun=prerun
        )
        ddt = dt / float(oversampling)
        scheduled = bool(schedules)

        def record_frame(carry, sched):
            # Advance `oversampling` substeps under `sched`, then emit x.
            def substep(inner, _):
                st, k = inner
                k, sub = random.split(k)
                if sched is None:
                    st = step_fn(st, ddt, sub)
                else:
                    st = step_fn(st, ddt, sub, sched)
                return (st, k), None

            carry, _ = lax.scan(substep, carry, None, length=oversampling)
            state, _k = carry
            return carry, state[0]  # record positions only

        carry = (initial_state, key)

        if prerun > 0:
            # Burn-in holds the first frame's schedule fixed. It is computed only
            # here so an empty schedule with no burn-in is never indexed.
            sched0 = {name: arr[0] for name, arr in schedules.items()} if scheduled else None
            carry, _ = lax.scan(lambda c, _: record_frame(c, sched0), carry, None, length=prerun)

        if Nsteps > 0:
            # With xs=None lax.scan feeds None per frame -> unscheduled substeps.
            xs = LangevinBase._shift_schedules(schedules) if scheduled else None
            carry, traj_x = lax.scan(record_frame, carry, xs, length=Nsteps)
        else:
            # Zero-length trajectory with the position's trailing shape/dtype.
            traj_x = jax.tree_util.tree_map(lambda a: a[None, ...][:0], initial_state[0])

        state, _ = carry
        info = {
            "Nsteps": int(Nsteps),
            "oversampling": int(oversampling),
            "prerun": int(prerun),
        }
        return traj_x, state, info

    # --------------------------- Validations -----------------------------

    @staticmethod
    def _assert_force_contract(F: PSF | SF) -> None:
        """Validate force has the right rank/dim flags for underdamped use.

        Raises
        ------
        ValueError
            If the force does not require velocity, is not vector-ranked, lacks
            an integer `dim`, or has a `pdepth` outside {0, 1}.
        """
        if not getattr(F, "needs_v", False):
            raise ValueError("Underdamped force must require velocity (needs_v=True).")
        if getattr(F, "rank", None) != 1:
            raise ValueError("Force must have rank=vector.")
        if not isinstance(getattr(F, "dim", None), int):
            raise ValueError("Force must declare an integer `dim` attribute.")
        if getattr(F, "pdepth", None) not in (0, 1):
            raise ValueError("Force `pdepth` must be 0 or 1 for underdamped simulations.")

    @staticmethod
    def _check_diffusion_contract(D, *, d: int, f_pdepth: int) -> None:
        """Validate diffusion against the force contract, allowing benign broadcast.

        Accepted forms: a NoiseModel with matching `dim`, a Python scalar, a
        (d, d) JAX array, or a matrix-ranked PSF/SF whose `pdepth` equals the
        force's or is 0 (broadcast over particles).

        Raises
        ------
        TypeError
            If ``D`` is none of the accepted types.
        ValueError
            If shapes, rank, dimension or particle depth are incompatible.
        """
        # Local import, kept as in the original module layout.
        from SFI.langevin.noise import NoiseModel

        if isinstance(D, NoiseModel):
            if D.dim != d:
                raise ValueError(f"NoiseModel n_fields={D.dim} must match force dim={d}.")
            return
        if isinstance(D, (int, float)):
            return
        if isinstance(D, jnp.ndarray) and D.ndim == 2:
            if D.shape != (d, d):
                raise ValueError(f"Constant diffusion must be shape (d,d)={(d, d)}, got {D.shape}.")
            return
        if not isinstance(D, (PSF, SF)):
            raise TypeError("Diffusion must be a float, (d×d) array, PSF/SF, or NoiseModel.")

        if getattr(D, "rank", None) != 2:
            raise ValueError("Diffusion PSF/SF must have rank=matrix.")
        dim = getattr(D, "dim", None)
        if dim != d:
            raise ValueError(f"Diffusion dim={dim} must match force dim={d}.")
        d_pdepth = getattr(D, "pdepth", None)
        if d_pdepth not in (0, 1):
            raise ValueError("Diffusion `pdepth` must be 0 or 1.")
        if not (d_pdepth == f_pdepth or d_pdepth == 0):
            raise ValueError(
                f"Incompatible pdepth: force pdepth={f_pdepth}, diffusion pdepth={d_pdepth}. "
                "Only equal depths or diffusion depth=0 (broadcast) are allowed."
            )