"""Per-backend residual builders.
Each builder takes a fitted inference object and returns a
:class:`ResidualBundle` containing pooled standardized residuals
:math:`z = \\Sigma^{-1/2} r` ready to feed into the statistical tests.
Measurement-noise-aware, banded whitening
-----------------------------------------
Both residuals carry two correlation sources that a single-residual
whitening ignores:
* **Measurement noise** :math:`\\Sigma_\\eta`. The diagnostic residual
covariance is :math:`C = \\text{(thermal)} + c\\,\\Sigma_\\eta`, not the
thermal part alone. The estimator's profiled :math:`\\Sigma_\\eta`
(``inferer.Lambda``) is folded into ``C`` so that a *well-recovered but
noisy* fit still whitens to unit variance instead of tripping every
flag. On clean data :math:`\\Sigma_\\eta\\approx 0` and this reduces to
the thermal whitening.
* **Serial correlation.** Localisation error is shared between
neighbouring residuals, so the residual series is a moving-average
process (overdamped increment → MA(1) with lag-1 block
:math:`-\\Sigma_\\eta`; the kept underdamped acceleration series → MA(1)
with lag-1 block :math:`\\Sigma_\\eta/\\Delta t^4`). A *banded*
whitening — the sequential block-Cholesky innovations of the
tridiagonal residual covariance (:func:`_sequential_innovations`) —
decorrelates the stream, exactly paralleling the parametric core's
banded precision. On clean data the off-diagonal block vanishes and
the innovations coincide with the marginal whitening.
The whitened stream ``z`` (moments / normality / autocorrelation) uses
the banded innovations; the per-row Mahalanobis norms ``z_squared_norms``
(the chi-square / MSE-consistency *bias* check) keep the **marginal**
noise-aware form, which faithfully preserves a slowly-varying force bias
that the innovations would partly difference out.
Residual conventions
--------------------
**Overdamped**:
.. math::
r_{t,n} = X_{t+1,n} - X_{t,n} - F(X_{t,n})\\,\\Delta t,
\\qquad C_{t} = 2\\,\\bar D\\,\\Delta t + 2\\,\\Sigma_\\eta,
with lag-1 covariance :math:`-\\Sigma_\\eta`. For the linear path the
thermal part is the exact ML residual; for the parametric path it is an
approximation that is nevertheless consistent (whitened residuals should
have unit variance and no autocorrelation if the model is well
specified).
**Underdamped**: symmetric acceleration
:math:`\\hat a_t = (X_{t+1} - 2X_t + X_{t-1})/\\Delta t^2`,
.. math::
r_t = \\hat a_t - F(\\hat x_t, \\hat v_t),
\\qquad C_t = \\tfrac23\\,\\frac{2\\bar D}{\\Delta t}
+ \\frac{6\\,\\Sigma_\\eta}{\\Delta t^4}.
For both regimes residuals are pooled across time, particles, and
spatial components, applying the dataset's ``dynamic_mask`` (for
overdamped) or its 1-step erosion (for underdamped, which needs three
consecutive valid observations).
"""
from __future__ import annotations
from dataclasses import dataclass, field
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
@dataclass
class ResidualBundle:
    """Standardised residuals + metadata.

    Container returned by the residual builders and consumed by the
    statistical tests.

    Attributes
    ----------
    z : np.ndarray
        Whitened residuals, shape ``(K,)``. Pooled across time,
        particles and spatial components after masking.
    z_components : np.ndarray
        Whitened residuals organised by spatial component, shape
        ``(K_per_component, d)``. Used for per-axis statistics.
    z_squared_norms : np.ndarray
        Per-row squared Mahalanobis norm
        :math:`r_t^\\top \\Sigma_t^{-1} r_t`, shape ``(K_per_row,)``.
        Used for the diffusion / "chi-square" check.
    force_quadratic_form : np.ndarray
        Per-row quadratic form
        :math:`F^\\top A^{-1} F` evaluated on the same valid samples
        used to build ``z``. Pre-computing it here avoids a second
        evaluation of ``F`` in the MSE-consistency check downstream.
    mean_dt : float
        Average step size used in the residual construction.
    n_obs : int
        Number of valid (un-masked) observations used to build ``z``.
    d : int
        Spatial dimension.
    regime : str
        ``"OD"`` or ``"UD"``.
    backend : str
        Coarse tag of the inference path (``"linear"``, ``"parametric"``,
        ``"nonlinear"``). For diagnostic display only.
    n_particles : int
        Maximum number of particles in any dataset.
    nmse_excess_factor : float
        Conversion factor from the chi-square excess to the force NMSE in
        :func:`mse_consistency`. ``1.0`` for the overdamped increment
        residual; :data:`KAPPA_UD` for the underdamped acceleration
        residual (see that constant for the derivation).
    whitened : list of (np.ndarray, np.ndarray)
        Per-dataset ``(z_full, mask)`` pairs with ``z_full`` of shape
        ``(K, N, d)`` (time-major) and ``mask`` of shape ``(K, N)``.
        Kept so that autocorrelation can be measured strictly along
        time, per particle and per component — pooling the flattened
        ``z`` stream would mix particles and components at short lags.
    """

    z: np.ndarray
    z_components: np.ndarray
    z_squared_norms: np.ndarray
    force_quadratic_form: np.ndarray
    mean_dt: float
    n_obs: int
    d: int
    regime: str
    backend: str
    n_particles: int
    nmse_excess_factor: float = 1.0
    # default_factory gives every bundle its own list (no shared mutable default).
    whitened: list = field(default_factory=list)
# --------------------------------------------------------------------- #
# Helpers
# --------------------------------------------------------------------- #
def _inv_sqrt_psd(S: jnp.ndarray) -> jnp.ndarray:
"""Batched symmetric inverse square root ``S^{-1/2}`` (eigen-clamped).
``S`` has shape ``(..., d, d)`` and is symmetrised and floored before
inversion so a marginally-indefinite conditional covariance (high
measurement noise) stays well-posed.
"""
S = 0.5 * (S + jnp.swapaxes(S, -1, -2))
w, U = jnp.linalg.eigh(S)
floor = jnp.maximum(jnp.max(w, axis=-1, keepdims=True) * 1e-12, 1e-30)
w = jnp.clip(w, floor, None)
return jnp.einsum("...am,...m,...bm->...ab", U, 1.0 / jnp.sqrt(w), U)
@jax.jit
def _sequential_innovations(r, mask, contiguous, A_blocks, Lambda, offdiag_coef):
    r"""Whiten a tridiagonal (MA(1)) residual stream by block-Cholesky innovations.

    The residual series has diagonal covariance blocks ``A_blocks[k]`` and
    lag-1 blocks ``C_k = Cov(r_{k-1}, r_k) = offdiag_coef[k]·Λ``.  The
    LDL\ :sup:`T` factorisation of the resulting block-tridiagonal
    covariance gives the innovations recursion (the exact whitening, the
    diagnostic twin of the parametric core's banded precision):

    .. math::

        M_k = C_k\,S_{k-1}^{-1},\quad
        w_k = r_k - M_k w_{k-1},\quad
        S_k = A_k - M_k C_k,\quad
        z_k = S_k^{-1/2} w_k,

    so ``z`` has unit covariance *and no serial correlation* — unlike the
    marginal whitening ``A_k^{-1/2} r_k``, which leaves the measurement-noise
    off-diagonal in place.  The recursion **resets** (drops to the marginal
    form ``z_k = A_k^{-1/2} r_k``) at the start of each contiguous run: where
    ``contiguous[k]`` is ``False`` (the kept index is not the immediate
    successor of the previous one) or either endpoint is masked, so gaps in
    the trajectory do not couple unrelated residuals.

    Parameters
    ----------
    r : ``(K, N, d)`` time-major residuals.
    mask : ``(K, N)`` bool validity.
    contiguous : ``(K,)`` bool — ``True`` where index ``k`` is the sampling
        successor of ``k-1`` (so the lag-1 block applies).
    A_blocks : ``(K, d, d)`` diagonal covariance blocks (= the marginal,
        noise-aware ``C``).
    Lambda : ``(d, d)`` measurement-noise covariance ``Λ``.
    offdiag_coef : ``(K,)`` lag-1 scalar coefficient (overdamped ``-1``;
        kept underdamped ``1/Δt^4``).

    Returns
    -------
    z : ``(K, N, d)`` whitened innovations (zero on masked rows).
    """
    _, N, d = r.shape
    I_d = jnp.eye(d, dtype=r.dtype)

    def body(carry, x):
        w_prev, S_prev, valid_prev = carry
        r_k, A_k, coef_k, contig_k, valid_k = x

        C = coef_k * Lambda  # (d, d) symmetric lag-1 block
        # Couple to the previous step only inside a contiguous run whose
        # two endpoints are both valid for this particle.
        couple = contig_k & valid_prev & valid_k  # (N,)

        # Substitute the identity where the previous row was masked so the
        # solve stays well-posed; those rows are decoupled via ``couple``.
        S_safe = jnp.where(valid_prev[:, None, None], S_prev, I_d[None])
        # M_k = C S_prev^{-1} = (S_prev^{-1} C)^T   (C, S_prev symmetric)
        M = jnp.swapaxes(jnp.linalg.solve(S_safe, jnp.broadcast_to(C, (N, d, d))), -1, -2)
        M = jnp.where(couple[:, None, None], M, 0.0)

        w_k = r_k - jnp.einsum("nij,nj->ni", M, w_prev)
        S_k = A_k[None] - jnp.einsum("nij,jk->nik", M, C)  # A_k − M C
        z_k = jnp.einsum("nij,nj->ni", _inv_sqrt_psd(S_k), w_k)

        # Carry the innovation covariance forward; reset masked rows so the
        # next step decouples from them.
        w_out = jnp.where(valid_k[:, None], w_k, 0.0)
        S_out = jnp.where(valid_k[:, None, None], S_k, I_d[None])
        z_out = jnp.where(valid_k[:, None], z_k, 0.0)
        return (w_out, S_out, valid_k), z_out

    # Start with "no valid predecessor" so the first row is whitened marginally.
    init = (
        jnp.zeros((N, d), r.dtype),
        jnp.broadcast_to(I_d, (N, d, d)),
        jnp.zeros((N,), bool),
    )
    _, z = lax.scan(body, init, (r, A_blocks, offdiag_coef, contiguous, mask))
    return z
def _coerce_F_value(value, K: int, N: int, d: int) -> jnp.ndarray:
"""Reshape an SF / Basis / callable output to ``(K, N, d)``.
The OD / UD ``force_inferred`` callable accepts batched inputs
of shape ``(M, d)`` and returns ``(M, d)``; we always feed it the
flattened ``(K * N, d)`` form and reshape back.
"""
arr = jnp.asarray(value)
if arr.shape == (K * N, d):
return arr.reshape(K, N, d)
if arr.shape == (K, N, d):
return arr
raise ValueError(f"Force callable returned shape {arr.shape}; expected ({K * N}, {d}) or ({K}, {N}, {d}).")
def _measurement_noise(inferer, d: int) -> jnp.ndarray:
"""PSD measurement-noise covariance Λ (``inferer.Lambda``), or zero.
The parametric estimator profiles Λ natively and the diffusion
estimators expose it as ``Lambda``; on clean data Λ ≈ 0 so the
noise-aware whitening reduces to the thermal case. The estimate can
be marginally non-PSD on clean data, which the ``6Λ/Δt⁴`` weighting
would amplify — so we clamp to the PSD cone.
"""
Lam = getattr(inferer, "Lambda", None)
if Lam is None:
return jnp.zeros((d, d))
Lam = jnp.asarray(Lam)
if Lam.shape != (d, d):
return jnp.zeros((d, d))
w, U = jnp.linalg.eigh(0.5 * (Lam + Lam.T))
return (U * jnp.maximum(w, 0.0)) @ U.T
def _backend_tag(inferer) -> str:
if hasattr(inferer, "metadata") and isinstance(inferer.metadata, dict):
return str(inferer.metadata.get("force_method", "linear"))
return "linear"
# Continuous-limit noise factor for the underdamped acceleration residual:
# Var(â_noise) = KAPPA_UD · A / dt with A = 2D (derivation below).
#
# The underdamped diagnostic residual is the symmetric finite-difference
# acceleration â(t) = (x_{t+1} - 2 x_t + x_{t-1}) / dt² minus the fitted
# force F(x̂, v̂) — the same quantity the underdamped force estimator fits
# (the symmetric ULI kinematics in SFI.inference.underdamped: _A_sym_uli /
# _V_sym_uli / _X_sym_uli).
#
# For dx = v dt, dv = F dt + sqrt(2D) dW the position is C¹, so the noise
# part of â is the second difference of the integrated velocity noise.
# Writing N_t = sqrt(2D) ∫ B over one sampling cell (B the velocity
# Brownian motion), the adjacent-cell autocovariance integral gives
#
#     Var(x_{t+1} - 2 x_t + x_{t-1}) = (4/3) D dt³ = (2/3) (2D) dt³,
#
# hence Var(â_noise) = (2/3) (2D) / dt = KAPPA_UD · A / dt with A = 2D.
#
# This factor is exact for continuously sampled data (the physical case
# for experimental trajectories) and is the limit that finely oversampled
# simulations converge to.  The thermal residual is a clean MA(1) process in
# time (lag-1 ≈ 1/4, lag ≥ 2 ≈ 0), so the builder keeps every second valid
# time index to remove that thermal lag-1.  The leftover measurement-noise
# correlation (a lag-1 block Λ/Δt⁴ in the kept series) is removed by the
# banded innovations whitening (_sequential_innovations).
KAPPA_UD = 2.0 / 3.0
def _process_chunk(
    *,
    F_at: jnp.ndarray,  # (K, N, d)
    r: jnp.ndarray,  # (K, N, d) raw residual
    dt: jnp.ndarray,  # (K,) physical step (pooled into mean_dt)
    mask: jnp.ndarray,  # (K, N) bool
    A: jnp.ndarray,  # (d, d) = 2 D
    A_inv: jnp.ndarray,  # (d, d)
    contiguous: np.ndarray,  # (K,) bool sampling-successor flag
    offdiag_coef: jnp.ndarray,  # (K,) lag-1 coefficient on Λ
    var_scale: jnp.ndarray | None = None,  # (K,) thermal coefficient on A
    Lambda: jnp.ndarray | None = None,  # (d, d) measurement-noise covariance Λ
    noise_scale: jnp.ndarray | None = None,  # (K,) coefficient on Λ
):
    """Whiten residuals, compute Mahalanobis norms and ``F^T A^{-1} F``,
    and return only the masked-valid rows pooled along ``(K, N)``.

    Two whitenings are produced from the same residual covariance
    ``C = var_scale·A + noise_scale·Λ`` (diagonal blocks) with lag-1 block
    ``offdiag_coef·Λ``:

    * **Banded innovations** ``z`` (returned for the moments / normality /
      autocorrelation tests) — the sequential block-Cholesky whitening
      (:func:`_sequential_innovations`) of the tridiagonal covariance.  It
      has unit variance *and no serial correlation*, so it does not trip the
      Ljung--Box test on measurement-noise-correlated residuals.
    * **Marginal Mahalanobis norms** ``sqn = rᵀ C⁻¹ r`` (returned for the
      MSE-consistency / chi-square *bias* check) — the single-residual form,
      kept because the banded innovations would partly difference out a
      slowly-varying force bias that this check is meant to detect.

    The **measurement-noise term** ``noise_scale · Λ`` makes both
    noise-aware: for the underdamped acceleration residual it scales as
    ``6 Λ / Δt⁴`` and otherwise overwhelms the thermal term
    ``(2/3)(2D)/Δt`` even when the force is well recovered.  When ``Λ`` is
    zero (clean data, where the estimator profiles ``Λ ≈ 0``) the
    off-diagonal block vanishes and both reduce to the thermal whitening
    ``(var_scale·A)^{-1/2} r``.  ``var_scale`` defaults to ``dt`` (the
    overdamped Euler scale).

    Measurement noise is only used when *both* ``Lambda`` and
    ``noise_scale`` are given; otherwise Λ is treated as zero in the
    diagonal *and* the lag-1 blocks, so the two always describe the same
    covariance model.

    Returns
    -------
    tuple
        ``(z_valid, sqn_valid, F2_valid, dt_valid, z_full, mask)`` where the
        ``*_valid`` arrays keep only the rows selected by ``mask``
        (``z_valid`` is ``(Kv, d)``, the others ``(Kv,)``), ``z_full`` is the
        un-masked ``(K, N, d)`` innovations array and ``mask`` is the
        ``(K, N)`` boolean NumPy mask.
    """
    scale = jnp.asarray(dt if var_scale is None else var_scale)
    noise_aware = Lambda is not None and noise_scale is not None

    if noise_aware:
        noise_scale = jnp.asarray(noise_scale)
        Lam = jnp.asarray(Lambda)
        # Drop a spuriously-small Λ: if its residual contribution is below
        # 20% of the thermal scale the data is *effectively clean*, and
        # folding it in would deflate the chi-square / MSE-consistency bias
        # signal — masking genuine misspecification.  Real underdamped noise
        # is amplified as 6Λ/Δt⁴, far above this floor, so it is never dropped.
        thermal = jnp.mean(scale) * jnp.trace(A)
        noise = jnp.mean(noise_scale) * jnp.trace(Lam)
        Lam = jnp.where(noise < 0.2 * thermal, jnp.zeros_like(Lam), Lam)
    else:
        # Without a diagonal noise coefficient the lag-1 block must vanish
        # too, otherwise the banded whitening sees an inconsistent covariance.
        Lam = jnp.zeros_like(A)

    C = scale[:, None, None] * A[None]  # (K, d, d) diagonal blocks
    if noise_aware:
        C = C + noise_scale[:, None, None] * Lam[None]

    # Marginal Mahalanobis norms (bias / chi-square channel): rᵀ C⁻¹ r.
    w, U = jnp.linalg.eigh(C)  # (K, d), (K, d, d)
    C_inv = jnp.einsum("kam,km,kbm->kab", U, 1.0 / w, U)
    sqn = jnp.einsum("kni,kij,knj->kn", r, C_inv, r)  # (K, N)
    F2 = jnp.einsum("kjm,mp,kjp->kj", F_at, A_inv, F_at)  # (K, N)

    # Force a boolean mask: an integer mask would be read as a fancy index.
    m_np = np.asarray(mask).astype(bool)

    # Banded innovations (serial-decorrelation channel) used for z / z_full.
    z = _sequential_innovations(
        jnp.asarray(r),
        jnp.asarray(m_np),
        jnp.asarray(contiguous),
        C,
        Lam,
        jnp.asarray(offdiag_coef),
    )  # (K, N, d)

    z_full = np.asarray(z)  # (K, N, d) time-major, pre-masking
    z_np = z_full[m_np]  # (Kv, d)
    sqn_np = np.asarray(sqn)[m_np]  # (Kv,)
    F2_np = np.asarray(F2)[m_np]  # (Kv,)

    # Per-valid-row dt (broadcast the per-frame dt over particles).
    K, N = m_np.shape
    dt_kn = np.broadcast_to(np.asarray(dt)[:, None], (K, N))[m_np]
    return z_np, sqn_np, F2_np, dt_kn, z_full, m_np
# --------------------------------------------------------------------- #
# Overdamped builder
# --------------------------------------------------------------------- #
def build_overdamped_residuals(inferer, data=None) -> ResidualBundle:
    """Build standardised Euler--Maruyama residuals for an OD inferer.

    Routes data access through ``TrajectoryDataset.make_batch_producer`` —
    the same low-level streaming layer used by ``SFI.integrate`` — so
    multi-particle, masked, and multi-dataset trajectories are handled
    transparently.

    Works for any overdamped inference path (linear, parametric, nonlinear)
    as long as ``inferer.force_inferred`` is callable and
    ``inferer.A_inv`` is available.

    Parameters
    ----------
    inferer
        Fitted inference object.  Must expose ``force_inferred``, ``A_inv``
        and either ``A`` or ``diffusion_average`` (``A = 2 D``); ``Lambda``
        and ``metadata`` are optional.
    data
        Optional collection (anything with a ``datasets`` attribute) to
        evaluate on; defaults to ``inferer.data``.

    Returns
    -------
    ResidualBundle
        Pooled residuals with ``regime="OD"``.

    Raises
    ------
    RuntimeError
        If ``force_inferred`` or ``A_inv`` is missing.
    ValueError
        If a dataset's dimension differs from that of ``A``.
    """
    if getattr(inferer, "force_inferred", None) is None:
        raise RuntimeError("inferer.force_inferred is missing; run a force-inference method first.")
    if not hasattr(inferer, "A_inv"):
        raise RuntimeError("inferer.A_inv is missing; run compute_diffusion_constant() first.")

    F = inferer.force_inferred
    # Resolve A lazily: a getattr default is evaluated eagerly, which would
    # require ``diffusion_average`` even when ``A`` is present.
    A_source = getattr(inferer, "A", None)
    if A_source is None:
        A_source = 2.0 * inferer.diffusion_average
    A = jnp.asarray(A_source)
    A_inv = jnp.asarray(inferer.A_inv)
    d = int(A.shape[0])
    Lambda = _measurement_noise(inferer, d)  # Λ; increment noise = 2 Λ

    require = {"X", "X_plus"}
    z_chunks: list[np.ndarray] = []
    sqnorm_chunks: list[np.ndarray] = []
    F2_chunks: list[np.ndarray] = []
    dt_chunks: list[np.ndarray] = []
    whitened_chunks: list = []
    n_particles_max = 0
    backend = _backend_tag(inferer)

    collection = data if data is not None else inferer.data
    for ds_idx, ds in enumerate(collection.datasets):
        t_idx = ds.valid_indices(require)
        if t_idx.size == 0:
            continue
        d_ds = int(ds.d)
        if d_ds != d:
            raise ValueError(f"Dataset dimension {d_ds} does not match A_inv dimension {d}.")
        n_particles_max = max(n_particles_max, int(ds.N))

        producer = ds.make_batch_producer(
            require,
            include_mask=True,
            include_dt=True,
            force_dt_keys={"dt"},
        )
        row = producer(t_idx)
        X = row["X"]  # (K, N, d)
        Xp = row["X_plus"]  # (K, N, d)
        dt = row["dt"]  # (K,)
        mask = row["mask_out"]  # (K, N) bool
        K, N, _ = X.shape
        extras = ds.build_extras(t_idx, dataset_index=ds_idx)

        # Sampling adjacency: the lag-1 measurement-noise block (−Λ) only
        # couples residuals at consecutive frames; reset across any gap.
        t_arr = np.asarray(t_idx)
        contiguous = np.zeros(K, dtype=bool)
        contiguous[1:] = np.diff(t_arr) == 1
        offdiag_coef = -np.ones(K)  # Cov(r_{k-1}, r_k) = −Λ

        if N > 1:
            # Multiparticle: each frame has N interacting particles — compute
            # the force on each frame separately to preserve particle structure.
            # Per-particle extras (N, ...) reach each particle on the (N,) batch.
            F_frames = [np.asarray(F(np.asarray(frame), extras=extras)) for frame in X]
            # Validate the stacked shape the same way as the flattened path.
            F_at = _coerce_F_value(np.stack(F_frames, axis=0), K, N, d)
        else:
            F_at = _coerce_F_value(F(X.reshape(K * N, d), extras=extras), K, N, d)

        # Euler--Maruyama increment residual.
        r = (Xp - X) - F_at * dt[:, None, None]

        z_v, sqn_v, F2_v, dt_v, z_full, mask_full = _process_chunk(
            F_at=F_at,
            r=r,
            dt=dt,
            mask=mask,
            A=A,
            A_inv=A_inv,
            contiguous=contiguous,
            offdiag_coef=jnp.asarray(offdiag_coef),
            var_scale=jnp.asarray(dt),
            Lambda=Lambda,
            noise_scale=2.0 * jnp.ones_like(jnp.asarray(dt)),
        )
        z_chunks.append(z_v)
        sqnorm_chunks.append(sqn_v)
        F2_chunks.append(F2_v)
        dt_chunks.append(dt_v)
        whitened_chunks.append((z_full, mask_full))

    return _assemble_bundle(
        z_chunks,
        sqnorm_chunks,
        F2_chunks,
        dt_chunks,
        whitened_chunks,
        d=d,
        regime="OD",
        backend=backend,
        n_particles=n_particles_max,
    )
# --------------------------------------------------------------------- #
# Underdamped builder
# --------------------------------------------------------------------- #
def build_underdamped_residuals(inferer, data=None) -> ResidualBundle:
    """Build standardised innovations for a UD inferer from the symmetric
    acceleration residual.

    Uses the symmetric ULI kinematics that the underdamped force estimator
    itself fits (see ``SFI.inference.underdamped``):

    .. math::

        \\hat x = \\tfrac13(X_{t-1}+X_t+X_{t+1}), \\quad
        \\hat v = \\frac{X_{t+1}-X_{t-1}}{2\\Delta t}, \\quad
        \\hat a = \\frac{X_{t+1}-2X_t+X_{t-1}}{\\Delta t^2},

    and forms the residual :math:`r_t = \\hat a - F(\\hat x, \\hat v)`.
    Its thermal noise covariance is :math:`\\tfrac23 A/\\Delta t` (see
    :data:`KAPPA_UD`); with measurement noise the diagonal block gains
    :math:`6\\Sigma_\\eta/\\Delta t^4`.  The thermal residual is MA(1), so
    only every second valid index is kept (removing the thermal lag-1);
    the residual measurement-noise correlation (lag-1 block
    :math:`\\Sigma_\\eta/\\Delta t^4`) is removed by the banded innovations
    whitening, leaving a serially independent stream.

    Like :func:`build_overdamped_residuals`, all data access uses
    ``TrajectoryDataset.make_batch_producer`` so masking and
    multi-dataset / multi-particle pooling are handled by the same
    streaming layer that powers ``SFI.integrate``.

    Parameters
    ----------
    inferer
        Fitted inference object.  Must expose ``force_inferred``, ``A_inv``
        and either ``A`` or ``diffusion_average`` (``A = 2 D``); ``Lambda``
        and ``metadata`` are optional.
    data
        Optional collection (anything with a ``datasets`` attribute) to
        evaluate on; defaults to ``inferer.data``.

    Returns
    -------
    ResidualBundle
        Pooled residuals with ``regime="UD"`` and
        ``nmse_excess_factor=KAPPA_UD``.

    Raises
    ------
    RuntimeError
        If ``force_inferred`` or ``A_inv`` is missing.
    ValueError
        If a dataset's dimension differs from that of ``A``.
    """
    if getattr(inferer, "force_inferred", None) is None:
        raise RuntimeError("inferer.force_inferred is missing; run a force-inference method first.")
    if not hasattr(inferer, "A_inv"):
        raise RuntimeError("inferer.A_inv is missing; run compute_diffusion_constant() first.")

    F = inferer.force_inferred
    # Resolve A lazily: a getattr default is evaluated eagerly, which would
    # require ``diffusion_average`` even when ``A`` is present.
    A_source = getattr(inferer, "A", None)
    if A_source is None:
        A_source = 2.0 * inferer.diffusion_average
    A = jnp.asarray(A_source)
    A_inv = jnp.asarray(inferer.A_inv)
    d = int(A.shape[0])
    Lambda = _measurement_noise(inferer, d)  # Λ; acceleration noise = 6 Λ / Δt⁴

    # Symmetric 3-point stencil: X_{t-1}, X_t, X_{t+1}.
    require = {"X_minus", "X", "X_plus"}
    z_chunks: list[np.ndarray] = []
    sqnorm_chunks: list[np.ndarray] = []
    F2_chunks: list[np.ndarray] = []
    dt_chunks: list[np.ndarray] = []
    whitened_chunks: list = []
    n_particles_max = 0
    backend = _backend_tag(inferer)

    collection = data if data is not None else inferer.data
    for ds_idx, ds in enumerate(collection.datasets):
        t_idx = ds.valid_indices(require)
        # Adjacent acceleration residuals share two of three positions, so the
        # *thermal* series is MA(1).  Keeping every second valid index removes
        # that thermal lag-1 correlation (lag ≥ 2 thermal ≈ 0).  Under
        # measurement noise the kept series still carries a lag-1 block
        # Λ/Δt⁴ (the original lag-2), which the banded whitening below
        # decorrelates; on clean data that block vanishes.
        t_idx = t_idx[::2]
        if t_idx.size == 0:
            continue
        d_ds = int(ds.d)
        if d_ds != d:
            raise ValueError(f"Dataset dimension {d_ds} does not match A_inv dimension {d}.")
        n_particles_max = max(n_particles_max, int(ds.N))

        producer = ds.make_batch_producer(
            require,
            include_mask=True,
            include_dt=True,
            force_dt_keys={"dt"},
        )
        row = producer(t_idx)
        Xm = row["X_minus"]  # (K, N, d) at t-1
        X = row["X"]  # (K, N, d) at t
        Xp = row["X_plus"]  # (K, N, d) at t+1
        dt0 = row["dt"]  # (K,) step t -> t+1
        mask = row["mask_out"]  # (K, N) bool — already AND'd
        K, N, _ = X.shape

        # Strided sampling adjacency: the lag-1 block (Λ/Δt⁴) only couples
        # kept residuals two original frames apart; reset across any gap.
        t_arr = np.asarray(t_idx)
        contiguous = np.zeros(K, dtype=bool)
        contiguous[1:] = np.diff(t_arr) == 2

        # Symmetric ULI kinematics (match SFI.inference.underdamped):
        #   x̂ = (X₋ + X + X₊)/3,  v̂ = (X₊ − X₋)/(2 dt),  â = (X₊ − 2X + X₋)/dt²
        dt_b = dt0[:, None, None]
        x_hat = (Xm + X + Xp) / 3.0
        v_hat = (Xp - Xm) / (2.0 * dt_b)
        a_hat = (Xp - 2.0 * X + Xm) / (dt_b * dt_b)

        extras = ds.build_extras(t_idx, dataset_index=ds_idx)
        if N > 1:
            # Preserve the particle axis so bases with per-particle extras
            # (e.g. per-agent home ranges) see each particle's own value.
            F_at = _coerce_F_value(F(x_hat, v=v_hat, extras=extras), K, N, d)
        else:
            F_at = _coerce_F_value(
                F(x_hat.reshape(K * N, d), v=v_hat.reshape(K * N, d), extras=extras),
                K,
                N,
                d,
            )

        # Residual in acceleration units; thermal Cov = (2/3) A / dt
        # (KAPPA_UD), measurement-noise Cov = 6 Λ / dt⁴ (second difference of
        # i.i.d. localisation errors: variance (1+4+1) Λ / dt⁴).
        r = a_hat - F_at
        var_scale = KAPPA_UD / dt0  # (K,)
        noise_scale = 6.0 / (dt0 ** 4)  # (K,)
        offdiag_coef = 1.0 / (dt0 ** 4)  # Cov(r_{k-1}, r_k) = Λ/Δt⁴

        z_v, sqn_v, F2_v, dt_pool, z_full, mask_full = _process_chunk(
            F_at=F_at,
            r=r,
            dt=dt0,
            mask=mask,
            A=A,
            A_inv=A_inv,
            contiguous=contiguous,
            offdiag_coef=offdiag_coef,
            var_scale=var_scale,
            Lambda=Lambda,
            noise_scale=noise_scale,
        )
        z_chunks.append(z_v)
        sqnorm_chunks.append(sqn_v)
        F2_chunks.append(F2_v)
        dt_chunks.append(dt_pool)
        whitened_chunks.append((z_full, mask_full))

    return _assemble_bundle(
        z_chunks,
        sqnorm_chunks,
        F2_chunks,
        dt_chunks,
        whitened_chunks,
        d=d,
        regime="UD",
        backend=backend,
        n_particles=n_particles_max,
        nmse_excess_factor=KAPPA_UD,
    )
def _assemble_bundle(
    z_chunks,
    sqnorm_chunks,
    F2_chunks,
    dt_chunks,
    whitened_chunks,
    *,
    d: int,
    regime: str,
    backend: str,
    n_particles: int,
    nmse_excess_factor: float = 1.0,
) -> ResidualBundle:
    """Concatenate per-dataset chunks into a single :class:`ResidualBundle`.

    Each ``*_chunks`` argument is a list with one array per processed
    dataset; the arrays are concatenated along their first axis.  With no
    chunks at all, empty arrays of the right rank are used and ``mean_dt``
    is ``nan``.  ``whitened_chunks`` is stored on the bundle as given.
    """
    if not z_chunks:
        # Nothing was processed: keep the ranks consistent downstream.
        rows = np.zeros((0, d))
        sq_norms = np.zeros((0,))
        force_forms = np.zeros((0,))
        step_mean = float("nan")
    else:
        rows = np.concatenate(z_chunks, axis=0)
        sq_norms = np.concatenate(sqnorm_chunks, axis=0)
        force_forms = np.concatenate(F2_chunks, axis=0)
        steps = np.concatenate(dt_chunks, axis=0)
        step_mean = float(np.mean(steps)) if steps.size else float("nan")

    return ResidualBundle(
        z=rows.reshape(-1),  # flatten rows × components into one stream
        z_components=rows,
        z_squared_norms=sq_norms,
        force_quadratic_form=force_forms,
        mean_dt=step_mean,
        n_obs=int(rows.shape[0]),
        d=d,
        regime=regime,
        backend=backend,
        n_particles=n_particles,
        nmse_excess_factor=nmse_excess_factor,
        whitened=whitened_chunks,
    )
# --------------------------------------------------------------------- #
# Dispatch
# --------------------------------------------------------------------- #
def build_residuals(inferer, data=None) -> ResidualBundle:
    """Dispatch to the OD / UD residual builder based on the engine class.

    The class name of ``inferer`` is inspected first (``"Underdamped"``
    takes precedence over ``"Overdamped"``).  If neither substring is
    present, an inferer that has a ``_force_inference_underdamped``
    attribute is treated as underdamped and anything else as overdamped.

    ``data`` (optional) evaluates the residuals on an independent
    :class:`~SFI.trajectory.TrajectoryCollection` instead of the
    training data — the held-out path used by ``holdout_score``.
    """
    cls_name = type(inferer).__name__
    if "Underdamped" in cls_name:
        return build_underdamped_residuals(inferer, data=data)
    if "Overdamped" in cls_name:
        return build_overdamped_residuals(inferer, data=data)

    # Fallback: try by attribute.
    if hasattr(inferer, "_force_inference_underdamped"):
        return build_underdamped_residuals(inferer, data=data)
    return build_overdamped_residuals(inferer, data=data)