from __future__ import annotations
import logging
import jax.numpy as jnp
from SFI.bases.constants import constant_array
from SFI.inference.base import BaseLangevinInference
from SFI.integrate.api import integrate
from SFI.integrate.integrand import (
ConstOperand,
ExprOperand,
Integrand,
Term,
TimeOperand,
)
from SFI.integrate.timeops import stream, timeop, velocity
from SFI.statefunc import PSF, SF, Basis
from SFI.utils.maths import sqrtm_psd
from .result import InferenceResultSF # fitted SF with param_cov/meta
from .sparse import SparseScorer
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------- #
# OLI linear force-inference presets
# ---------------------------------------------------------------------------- #
# A single ``preset`` keyword bundles the (M_mode, G_mode) convention surface,
# mirroring the underdamped engine. ``'auto'`` keys off measurement noise; the
# ``legacy-*`` presets reproduce published SFI v1.0 conventions.
_FORCE_LINEAR_PRESETS: dict[str, tuple[str, str]] = {
    # Noise-robust default: Stratonovich moments + noise-decorrelating Gram.
    "robust": ("Strato", "shift"),
    # Sharper choice for verified clean / fine-sampled data.
    "clean": ("Ito", "trapeze"),
    # Kramers–Moyal: plain Itô finite-difference moments, rectangle Gram.
    "KM": ("Ito", "rectangle"),
    # SFI v1.0 (2020 PRX): Stratonovich moments, rectangle Gram.
    "legacy-sfi-v1.0": ("Strato", "rectangle"),
}


def _resolve_force_linear_preset(
    preset: str,
    *,
    M_mode: str | None = None,
    G_mode: str | None = None,
    noise_detected: bool | None = None,
) -> tuple[str, str, str]:
    """Turn an overdamped force ``preset`` into concrete ``(M_mode, G_mode)``.

    Parameters
    ----------
    preset : str
        ``'auto'`` or one of the keys of ``_FORCE_LINEAR_PRESETS``.
        ``'auto'`` maps to ``'robust'`` when ``noise_detected`` is truthy and
        to ``'clean'`` otherwise.
    M_mode : str, optional
        Moment-convention override. ``None`` and the legacy synonym ``'auto'``
        both mean "use the preset's value"; anything else is returned as-is.
    G_mode : str, optional
        Gram-convention override. ``None`` means "use the preset's value";
        anything else is returned as-is.
    noise_detected : bool, optional
        Whether measurement noise was detected. Only consulted for
        ``preset='auto'``.

    Returns
    -------
    tuple of (str, str, str)
        ``(M_mode, G_mode, resolved_preset)``. ``resolved_preset`` is the
        preset name the defaults came from, even when an override was applied.

    Raises
    ------
    RuntimeError
        If ``preset='auto'`` and ``noise_detected`` is ``None``.
    ValueError
        If ``preset`` is neither ``'auto'`` nor a known preset name.
    """
    if preset == "auto":
        if noise_detected is None:
            raise RuntimeError(
                "preset='auto' needs the measurement-noise estimate; call "
                "compute_diffusion_constant() first, or pass an explicit preset."
            )
        resolved = "robust" if noise_detected else "clean"
    else:
        if preset not in _FORCE_LINEAR_PRESETS:
            choices = ", ".join(map(repr, ("auto", *_FORCE_LINEAR_PRESETS)))
            raise ValueError(f"unknown preset {preset!r}; choose from {choices}")
        resolved = preset

    preset_M, preset_G = _FORCE_LINEAR_PRESETS[resolved]
    # Explicit overrides win on their own axis only.
    moment_mode = preset_M if (M_mode is None or M_mode == "auto") else M_mode
    gram_mode = G_mode if G_mode is not None else preset_G
    return moment_mode, gram_mode, resolved
[docs]
class OverdampedLangevinInference(BaseLangevinInference):
r"""Stochastic Force Inference concrete class for overdamped systems
This class provides tools for inferring force (drift) and
diffusion tensors from stochastic trajectory data based on overdamped
Langevin dynamics. It supports both linear and nonlinear basis
function methods.
Core Equation
~~~~~~~~~~~~~
The dynamics are described by the 1st order autonomous stochastic
differential equation (SDE)::
dx/dt = F(x) + sqrt(2D(x)) dxi(t)
where:
- ``F(x)`` is the Ito drift (force) term.
- ``D(x)`` is the diffusion tensor, evaluated in the Ito convention.
- ``dxi(t)`` is Gaussian white noise.
.. physics:: Overdamped Langevin SDE
:label: overdamped-langevin-sde
:category: Dynamical equations
.. math::
\frac{\mathrm{d}x}{\mathrm{d}t}
= F(x) + \sqrt{2\,D(x)}\;\mathrm{d}\xi(t)
:math:`F(x)` is the Itô drift, :math:`D(x)` the diffusion tensor
(Itô convention), and :math:`\mathrm{d}\xi` is Gaussian white noise.
Here x is a 2D array of shape Nparticles x dimension. All particles
are assumed to have identical properties.
This class provides tools to approximate F(x) and D(x) from a time series x(t) formatted as a TrajectoryCollection.
Note that the `Ito` and `Strato` variants of the force inference
routines do NOT refer to the convention in which the SDE is
expressed (which is always Ito), but to the way stochastic
integrals are performed to compute parameters.
Key Features
~~~~~~~~~~~~
- Force Inference:
- Linear combination of basis functions (`infer_force_linear`).
- Parametric families (`infer_force` with a `Basis` or `PSF`).
- Diffusion Inference:
- Constant diffusion estimation (`compute_diffusion_constant`).
- State-dependent diffusion with basis functions (`infer_diffusion_linear` and `infer_diffusion`).
- Sparsification:
- Force sparsification for linear inference `sparsify_force`, implementing PASTIS and other information criteria.
- Error Estimation:
- Normalized mean-squared error (MSE) prediction for both force and diffusion.
- Comparison Tools for method benchmarking:
- Evaluate inferred fields against known exact models (`compare_to_exact`).
- Simulation:
- Generate trajectories using inferred fields (`simulate_bootstrapped_trajectory`).
Workflow
~~~~~~~~
1. Initialize with a `TrajectoryCollection` containing the trajectory.
2. Use the `infer_*` methods to infer force and diffusion fields.
3. Optionally compute error estimates and/or compare with exact data for validation.
Indices Convention
~~~~~~~~~~~~~~~~~~
The code uses jnp.einsum for array manipulation, with a consistent index naming scheme for clarity:
- `t` : Time index, = 0..Ntimesteps-1
- `a, b, c...` : Basis function indices, = 0..Nfunctions - 1.
- `m, n, o...` : Spatial indices, = 0..dim-1.
- `i, j...` : Particle indices.
We also use these indices as shortcuts for array shapes. For
instance `basis_linear : im -> iam` reads `basis_linear has input
a jnp.array of shape (Nparticles,dim) and outputs a jnp.array of
shape (Nparticles,Nfunctions,dim)`.
Logging
~~~~~~~
Progress messages use Python ``logging``. Enable with
``SFI.enable_logging()`` or ``logging.getLogger('SFI').setLevel(logging.INFO)``.
Example
~~~~~~~
Fully documented examples in the "examples" folder: Lorenz model, ActiveBrownianParticles, Ornstein-Uhlenbeck...
"""
[docs]
def compute_diffusion_constant(self, method: str = "auto") -> None:
    """Estimate a constant (spatially uniform) diffusion matrix.

    Parameters
    ----------
    method : {"auto", "noisy", "WeakNoise", "MSD"}
        Instantaneous estimator to average. ``"noisy"`` is the noise-robust
        Vestergaard–Blainey–Flyvbjerg estimator. ``"auto"`` resolves to
        ``"noisy"`` when the measurement-noise trace Tr(Λ) is positive
        (localization noise detected) and to ``"WeakNoise"`` otherwise.

    Updates
    -------
    Sets ``Lambda``, ``Lambda_trace``, ``diffusion_average``,
    ``diffusion_inferred``, ``A``, ``A_inv``, ``sqrtA``, ``sqrtA_inv`` and
    ``metadata["diffusion_constant_method"]`` (the resolved estimator name).

    Raises
    ------
    RuntimeError
        If ``diffusion_average`` already exists on this object.
    """
    if hasattr(self, "diffusion_average"):
        raise RuntimeError("Diffusion already computed; create a new inference object to recompute.")

    def _per_point_mean(estimator_name):
        # Diffusion and noise are per-point quantities: every increment's
        # estimate has constant relative variance regardless of dt, so the
        # average is taken per point (weight_by_dt=False) rather than
        # dt-weighted (the force convention).
        op = self.get_diffusion_timeop(estimator_name)
        program = Integrand(times=[op], terms=[Term(eq="imn->imn", ops=(op.alias,))])
        return integrate(
            self.data,
            program,
            reduce="mean",
            weight_by_dt=False,
            chunk_target_bytes=self._chunk_target_bytes,
        )

    # Step 1 - measurement noise Λ.
    self.Lambda = _per_point_mean("Lambda")
    self.Lambda_trace = float(jnp.trace(self.Lambda))
    logger.info("Measurement noise trace: %s", self.Lambda_trace)

    # Step 2 - resolve the instantaneous estimator and record the choice.
    estimator = method
    if estimator == "auto":
        if self.Lambda_trace > 0:
            estimator = "noisy"
        else:
            estimator = "WeakNoise"
    self.metadata["diffusion_constant_method"] = estimator

    # Step 3 - average of the instantaneous diffusion estimate.
    self.diffusion_average = _per_point_mean(estimator)
    self.diffusion_inferred = constant_array(self.diffusion_average)

    # Step 4 - normalization matrices derived from A = 2 D.
    self.A = 2.0 * self.diffusion_average
    self.A_inv = jnp.linalg.inv(self.A)
    self.sqrtA = sqrtm_psd(self.A)
    self.sqrtA_inv = jnp.linalg.inv(self.sqrtA)
[docs]
def infer_force_linear(
    self,
    basis: Basis,
    *,
    preset: str = "auto",
    M_mode: str | None = None,
    G_mode: str | None = None,
) -> None:
    r"""Infer the force field as a linear combination of basis functions (linear regression).

    .. physics:: Linear force regression (overdamped)
        :label: linear-force-regression-overdamped
        :category: Inference

        .. math::

            \hat F(x) = \sum_a C_a\, b_a(x)
            \qquad\text{where}\qquad
            G\,C = M

        :math:`G_{ab} = \langle b_a(x_t)\, b_b(x_t) \rangle` is the Gram
        matrix, :math:`M_a = \langle v_t \cdot A^{-1} \cdot b_a \rangle` are
        the force moments, and :math:`A = 2\bar D`.

    This method computes the force field coefficients (`self.force_coefficients`) using
    the provided Basis object. The force field is represented as::

        inferred_force(x) = sum_a basis_linear(x)[:,a] * force_coefficients[a]

    These coefficients are computed by solving a linear system::

        G . force_coefficients = force_moments

    and the different options account for the manner to compute G
    and force_moments. In its simplest form::

        G_ab = < b_a(xt) b_b(xt) >                  [G_mode = 'rectangle']
        force_moments[a] = < dX[t]/dt b_a(xt) >     [M_mode = 'Ito']

    but this is rarely the best choice of parameters.

    Args:
        basis: Basis
            The fitting functions, encoded as a single callable Basis
            object (see SFI.statefunc submodule for doc and SFI.bases for examples).
        preset (str):
            Single-keyword convention bundle (mirrors the underdamped
            engine), default ``'auto'``. ``'auto'`` picks ``'robust'``
            when measurement noise is detected (Tr(Λ) > 0) and ``'clean'``
            otherwise. Presets: ``'robust'`` (Stratonovich moments +
            ``'shift'`` Gram, noise-robust); ``'clean'`` (Itô +
            ``'trapeze'``, clean / fine-sampled data); ``'KM'``
            (Kramers–Moyal: Itô + ``'rectangle'``); ``'legacy-sfi-v1.0'``
            (Stratonovich + ``'rectangle'``, the published SFI v1.0
            convention).
        M_mode (str, optional):
            Override the preset's moment convention: ``'Ito'``,
            ``'Ito-shift'``, or ``'Strato'``. ``None`` (default) uses the
            preset.
        G_mode (str, optional):
            Override the preset's Gram normalization: ``'rectangle'`` (2020
            PRX), ``'trapeze'`` (2024 Amiri et al. PRR), or ``'shift'``
            (``<b_a(xt) b_b(x_t+dt)>``, decorrelates measurement noise).
            ``None`` (default) uses the preset.

    Outputs:
        Updates the following attributes:

        - self.force_scorer: SparseScorer for model selection.
        - self.force_coefficients: The inferred coefficients for the basis functions.
        - self.force_inferred: Callable function representing the inferred force field.
        - self.force_G: The normalization matrix used in the inference process.

    Raises:
        RuntimeError: if a force was already inferred on this object
            (``force_G_full`` exists), if ``preset='auto'`` is used before a
            measurement-noise estimate is available, or if ``A_inv`` has not
            been computed yet. In all of these cases the object is left
            untouched.
        ValueError: if ``preset`` is not a known preset name.
    """
    # Refuse re-inference BEFORE touching any state: overwriting force_basis /
    # the mode attributes / metadata and only then raising would leave them
    # out of sync with the coefficients of the existing fit.
    if hasattr(self, "force_G_full"):
        raise RuntimeError("Force has already been inferred on this object - create a new instance to re-infer.")

    noise = self.Lambda_trace > 0.0 if hasattr(self, "Lambda_trace") else None
    M_mode, G_mode, _resolved_preset = _resolve_force_linear_preset(
        preset, M_mode=M_mode, G_mode=G_mode, noise_detected=noise,
    )

    # _force_G_matrix() reads self.A_inv unconditionally; fail with the same
    # explicit error _force_moments() uses instead of a bare AttributeError.
    if not hasattr(self, "A_inv"):
        raise RuntimeError("A_inv not available. Compute diffusion first.")

    logger.info(
        "Force inference: preset=%s -> M_mode=%s, G_mode=%s (Lambda trace: %s)",
        _resolved_preset, M_mode, G_mode, getattr(self, "Lambda_trace", None),
    )
    self._validate_basis(basis, expected_rank=1, label="force basis")

    # All checks passed: record the configuration of this fit.
    self.__force_M_mode__ = M_mode
    self.__force_G_mode__ = G_mode
    self.force_basis = basis
    self.metadata["force_preset"] = _resolved_preset
    self.metadata["force_M_mode"] = M_mode
    self.metadata["force_G_mode"] = G_mode

    # Structural (CSR/stencil) arrays are exposed only for this evaluation and
    # never persisted on the dataset (see ``_structural_scope``).
    with self._structural_scope(self.force_basis):
        self.force_G_full = self._force_G_matrix()
        self.force_moments = self._force_moments()

    self.force_scorer = SparseScorer(M=self.force_moments, G=self.force_G_full)
    self._update_force_coefficients(self.force_scorer.total_C)
    self.metadata["force_method"] = "linear"
[docs]
def infer_diffusion_linear(
    self,
    basis: Basis | None = None,
    *,
    M_mode: str = "auto",
    G_mode: str = "rectangle",
) -> None:
    """
    Fit the diffusion field as a linear combination of basis functions.

    This method computes the coefficients of the diffusion tensor field (`self.diffusion_coefficients`) using
    the provided basis functions. The diffusion tensor is represented as::

        diffusion_inferred(x, mask) = sum_a basis_linear(x, mask)[:,a] * diffusion_coefficients[a]

    Args:
        basis (Basis with rank = 2 or None): the fitting functions. When ``None``
            (default), ``symmetric_matrix_basis(d)`` is used, spanning all constant
            symmetric ``d×d`` diffusion matrices. Requires a prior
            ``compute_diffusion_constant()`` call to determine ``d``.
        M_mode (str):
            The method used for local diffusion tensor estimation and moments computation.
            See _diffusion_estimator documentation for additional information.
            ``"auto"`` (default) selects ``"noisy"`` when ``Lambda_trace > 0`` and
            ``"WeakNoise"`` otherwise, and therefore needs the measurement-noise
            estimate to be available.
        G_mode (str):
            The method used to compute the normalization matrix `G`.
            Not investigated extensively yet for diffusion inference.
            With ``M_mode="auto"`` this is always reset to ``"rectangle"``
            (a warning is logged when a different value was requested); pass an
            explicit ``M_mode`` to use another ``G_mode``.

    Updates:
        self.diffusion_coefficients: The inferred coefficients for the diffusion basis functions.
        self.diffusion_inferred: Callable representing the inferred diffusion tensor field.
        self.diffusion_G: The normalization matrix used in the inference process.

    Raises:
        RuntimeError: if diffusion was already inferred on this object
            (``diffusion_moments`` exists), if ``basis`` is ``None`` and no
            ``diffusion_average`` is available, or if ``M_mode="auto"`` is used
            without a ``Lambda_trace`` estimate. The object is left untouched
            in all of these cases.

    Note:
        This inferred tensor field is not guaranteed to be nonnegative.
    """
    # Refuse re-inference BEFORE mutating anything: overwriting diffusion_basis
    # / modes / metadata and only then raising would desynchronize them from
    # the coefficients of the existing fit.
    if hasattr(self, "diffusion_moments"):
        raise RuntimeError(
            "Diffusion has already been inferred on this object - create a new instance to re-infer."
        )

    if basis is None:
        if not hasattr(self, "diffusion_average"):
            raise RuntimeError(
                "infer_diffusion_linear() with no basis requires a prior "
                "compute_diffusion_constant() call to determine spatial dimension."
            )
        from SFI.bases import symmetric_matrix_basis

        basis = symmetric_matrix_basis(self.diffusion_average.shape[0])

    if M_mode == "auto":
        if not hasattr(self, "Lambda_trace"):
            raise RuntimeError(
                "M_mode='auto' needs the measurement-noise estimate; call "
                "compute_diffusion_constant() first, or pass an explicit M_mode."
            )
        if G_mode != "rectangle":
            # Auto-selection has always forced the rectangle Gram; keep that
            # behavior but stop discarding the caller's choice silently.
            logger.warning(
                "M_mode='auto' forces G_mode='rectangle'; ignoring requested G_mode=%r "
                "(pass an explicit M_mode to keep it).",
                G_mode,
            )
        M_mode = "noisy" if self.Lambda_trace > 0.0 else "WeakNoise"
        G_mode = "rectangle"
        logger.info(
            "Auto-selecting diffusion inference: M_mode %s, G_mode %s (Lambda trace: %s)",
            M_mode,
            G_mode,
            self.Lambda_trace,
        )

    self._validate_basis(basis, expected_rank=2, label="diffusion basis")

    # All checks passed: record the configuration of this fit.
    self.diffusion_basis = basis
    self.__diffusion_M_mode__ = M_mode
    self.__diffusion_G_mode__ = G_mode
    self.metadata["diffusion_M_mode"] = M_mode
    self.metadata["diffusion_G_mode"] = G_mode

    with self._structural_scope(self.diffusion_basis):
        self.diffusion_G_full = self._diffusion_G_matrix()
        self.diffusion_moments = self._diffusion_moments()

    self.diffusion_scorer = SparseScorer(M=self.diffusion_moments, G=self.diffusion_G_full)
    self._update_diffusion_coefficients(self.diffusion_scorer.total_C)
[docs]
def simulate_bootstrapped_trajectory(self, key, oversampling=1, simulate=True, dataset=0):
    """
    Build (and optionally run) an overdamped Langevin process from the inferred fields.

    The process uses ``self.force_inferred`` and ``self.diffusion_inferred``,
    both collapsed to a single experiment, together with that experiment's
    extras, time step, number of steps and first configuration.

    Args:
        key: JAX random key for generating noise in the simulation.
        oversampling (int, optional): Factor for oversampling (i.e. number of intermediate simulated
            points between two recorded points). Defaults to 1.
        simulate: if True, performs the simulation with the first data point as initial position;
            if False, returns an uninitialized object which can be simulated with flexible
            initial position and parameters.
        dataset (int, optional): Which experiment of a pooled fit to reproduce. The inferred
            model is collapsed to this condition via
            :meth:`~SFI.statefunc.StateExpr.specialize` (folding ``per_dataset_scalar`` /
            ``dataset_indicator`` at ``dataset``); the resulting process is a standalone
            single-condition model that does not read ``dataset_index``. Defaults to 0.

    Returns:
        When ``simulate`` is true: the pair ``(data_bootstrap, process)``, where
        ``data_bootstrap`` is whatever ``OverdampedProcess.simulate`` returns and
        ``process`` is the initialized ``OverdampedProcess``.
        When ``simulate`` is false: the un-initialized ``OverdampedProcess`` alone.
    """
    from SFI.langevin import OverdampedProcess

    which = int(dataset)

    # Collapse a (possibly pooled) fit onto the requested experiment: the
    # per-dataset parameters are folded at ``which`` and the reserved
    # ``dataset_index`` vanishes, so the simulated model has no dataset concept.
    force_single = self.force_inferred.specialize(dataset=which)
    diffusion_single = self.diffusion_inferred.specialize(dataset=which)

    process = OverdampedProcess(force_single._psf, diffusion_single._psf)
    process.set_params(theta_F=force_single.params, theta_D=diffusion_single.params)

    ds = self.data.datasets[which]
    process.set_extras(extras_global=ds.extras_global, extras_local=ds.extras_local)
    # The particle count is the only framework-supplied reserved input; the
    # specialized force does not read ``dataset_index``.
    process._reserved_overrides = {"particle_index": jnp.arange(int(ds.N))}

    if not simulate:
        return process

    from SFI.trajectory import TrajectoryCollection

    first_row = TrajectoryCollection.from_dataset(ds).peek_row(require={"X", "dt"})
    x0 = jnp.asarray(first_row["X"])
    if not jnp.all(jnp.isfinite(x0)):
        # First recorded configuration contains non-finite entries.
        x0 = self._find_finite_x0()
    process.initialize(x0)

    trajectory = process.simulate(
        dt=first_row["dt"],
        Nsteps=ds.T,
        key=key,
        prerun=0,
        oversampling=oversampling,
    )
    return trajectory, process
# ── Parametric (windowed) force inference ──
[docs]
def infer_force(
    self,
    F,
    theta0=None,
    *,
    D=None,
    Lambda=None,
    integrator: str = "rk4",
    n_substeps: int = 1,
    inner: str = "auto",
    eiv="auto",
    max_outer: int = 5,
    inner_maxiter: int = 80,
    extra_radius: int = 1,
) -> None:
    r"""Infer the force field with the minimal parametric estimator.

    Built on ``SFI.inference.parametric_core``: a single RK4 flow step
    per observation interval defines the residual, the residual
    covariance gives a windowed-precision NLL, and the parameters are
    found by direct Gauss–Newton (linear-in-θ ``Basis``, with the
    skip-trick errors-in-variables instrument) or frozen-precision
    L-BFGS (nonlinear-in-θ ``PSF``). ``(D, Λ)`` are profiled
    natively: moment-estimator init, then one windowed
    conditional-NLL refinement at the fitted θ.

    .. physics:: Parametric windowed force inference (overdamped)
        :label: parametric-force-overdamped
        :category: Inference

        The observed positions follow
        :math:`y_t = x_t + \eta_t` where
        :math:`\eta \sim \mathcal{N}(0, \Lambda)`. The
        deterministic flow
        :math:`\Phi(x;\theta) = z(\Delta t) - x`
        (one RK4 step by default) defines the residual
        :math:`r_t = y_{t+1} - y_t - \Phi(y_t;\theta)`.
        Residuals follow a banded Gaussian whose local precision
        weights the Gauss–Newton normal equations; under
        measurement noise the left factor is replaced by the
        η-clean *skip* instrument
        :math:`\psi_{\rm inst} = \partial\Phi/\partial\theta`
        evaluated at the lagged clean point (``eiv=True``),
        giving a consistent estimating equation.

    Parameters
    ----------
    F : PSF or Basis
        Parametric drift model. A ``Basis`` is converted to a PSF
        internally (coefficients initialised to zero) and PASTIS
        sparsification is enabled; it runs the fast direct-GN path.
        A ``PSF`` (possibly nonlinear in θ) runs the L-BFGS path.
    theta0 : dict, array, or None
        Initial drift parameters (default: zeros).
    D : array (d, d), optional
        Fixed diffusion matrix. If both ``D`` and ``Lambda``
        are given, noise profiling is skipped entirely (fast path).
    Lambda : array (d, d), optional
        Fixed measurement-noise covariance.
    integrator : {"rk4", "euler"}
        Flow predictor (default ``"rk4"``, a single 4th-order step).
    n_substeps : int
        Integrator micro-steps per observation interval (default 1 —
        the single-step minimal estimator).
    inner : {"auto", "gn", "lbfgs"}
        Inner solver. ``"auto"`` → direct Gauss–Newton for a linear
        ``Basis``, L-BFGS for a ``PSF``.
    eiv : {"auto", True, False, float}
        Measurement-noise errors-in-variables instrument. ``"auto"``
        (default) → ``True`` for all models (interacting models use
        the same N-body flow for the instrument as for the residual);
        ``True`` forces the η-clean skip instrument (consistent under
        noise); ``False`` is the plain MLE; a float in ``[0, 1]``
        blends. Active on the GN path only.
    max_outer : int
        Outer Gauss–Newton / IRLS iterations (default 5).
    inner_maxiter : int
        Inner L-BFGS iterations per outer step on the PSF path
        (default 80; raise for large nonlinear families, e.g. NNs).
    extra_radius : int
        Precision-window padding beyond the covariance bandwidth
        (default 1). Raise to 2–3 in the noise-dominated regime
        β = Tr(Λ)/Tr(2DΔt) ≫ 1, where the windowed precision
        decays slowly and the default window under-resolves it.

    Updates
    -------
    Sets standard ``force_*`` attributes:
        ``force_inferred``, ``force_psf``, ``force_G``,
        ``force_G_pinv``, ``force_coefficients_full``,
        ``force_coefficients``, ``force_support``,
        ``force_moments``.
    Also sets ``diffusion_average``, ``diffusion_inferred``, ``A``,
    ``A_inv``, ``sqrtA``, ``sqrtA_inv``, ``Lambda`` and ``Lambda_trace``
    from the profiled ``(D, Λ)``.
    When ``F`` is a ``Basis``, additionally sets
        ``force_basis``, ``force_G_full``, and ``force_scorer``
        so that ``sparsify_force()`` can be called afterwards.

    Raises
    ------
    RuntimeError
        If ``force_inferred`` already exists on this object.
    ValueError
        If ``integrator`` is not ``"rk4"`` or ``"euler"``.
    TypeError
        If ``F`` is neither a ``Basis`` nor an object exposing
        ``flatten_params`` / ``unflatten_params``.

    See Also
    --------
    :ref:`parametric-concept` : Mathematical foundations.
    :ref:`parametric-algorithm` : Detailed algorithm description.
    """
    # Cheap state/argument checks first, before the solver is imported.
    if hasattr(self, "force_inferred"):
        raise RuntimeError("Force inference has already been run on this object.")
    if integrator not in ("rk4", "euler"):
        raise ValueError(f"integrator must be 'rk4' or 'euler', got {integrator!r}")

    import time as _time

    from SFI.inference.parametric_core.solve import _as_psf, solve_force_od

    # ── Accept Basis → convert to PSF, enable PASTIS ──
    basis_mode = isinstance(F, Basis)
    if basis_mode:
        self.force_basis = F
    elif not (hasattr(F, "flatten_params") and hasattr(F, "unflatten_params")):
        raise TypeError(f"F must be a PSF or a Basis. Got {type(F).__name__}.")
    F_psf = _as_psf(F)

    with self._structural_scope(F_psf):
        dt_val = float(self.data.peek_row(require={"dt"})["dt"])
        logger.info(
            "[infer_force] OD minimal parametric solve (n_params=%d, dt=%.4g, %s·n%d, basis_mode=%s)...",
            int(F_psf.template.size), dt_val, integrator, n_substeps, basis_mode,
        )
        t0 = _time.perf_counter()
        res = solve_force_od(
            self.data, F, theta0=theta0, D=D, Lambda=Lambda,
            integrator=integrator, n_substeps=n_substeps,
            inner=inner, eiv=eiv, max_outer=max_outer,
            inner_maxiter=inner_maxiter, extra_radius=extra_radius,
        )
        t_elapsed = _time.perf_counter() - t0

    # ── Store results (standard force_* attributes) ──
    theta_flat = res.theta
    G = res.G
    self.force_psf = F_psf
    self.force_G = G
    # Parameter covariance from the solver: the sandwich G⁻¹HG⁻ᵀ on the
    # IV (eiv) path, the inverse information on the symmetric path.
    # Only kept when every entry is finite; otherwise stored as None.
    G_inv = None
    if bool(jnp.all(jnp.isfinite(res.theta_cov))):
        G_inv = res.theta_cov
    self.force_G_pinv = G_inv
    self.force_coefficients_full = theta_flat
    self.force_coefficients = theta_flat
    self.force_support = jnp.arange(theta_flat.size)
    self.force_moments = G @ theta_flat

    # ── PASTIS plumbing (Basis mode only) ──
    if basis_mode:
        # SparseScorer is the module-level import; no local re-import needed.
        self.force_G_full = G
        self.force_scorer = SparseScorer(
            M=self.force_moments,
            G=self.force_G_full,
        )

    self.diffusion_average = res.D
    # Callable constant-D field, so the full result surface (incl.
    # simulate_bootstrapped_trajectory) works without a separate
    # compute_diffusion_constant() call; a later infer_diffusion()
    # overwrites it with the state-dependent fit.
    self.diffusion_inferred = constant_array(res.D)
    self.A = 2.0 * res.D
    self.A_inv = jnp.linalg.inv(self.A)
    # Keep the square-root normalization in sync with the new A, exactly as
    # compute_diffusion_constant() does; otherwise sqrtA / sqrtA_inv would be
    # missing, or stale if compute_diffusion_constant() ran before this call.
    self.sqrtA = sqrtm_psd(self.A)
    self.sqrtA_inv = jnp.linalg.inv(self.sqrtA)
    self.Lambda = res.Lambda
    self.Lambda_trace = float(jnp.trace(res.Lambda))

    # NOTE(review): the docstring defines β = Tr(Λ)/Tr(2DΔt) while this uses
    # Tr(D)·Δt (no factor 2) — confirm which convention is intended. The tiny
    # constant only guards against division by zero.
    beta = float(jnp.trace(res.Lambda) / (jnp.trace(res.D) * dt_val + 1e-30))

    sf_F = SF(F_psf, F_psf.unflatten_params(theta_flat))
    meta_F = dict(
        kind="force",
        inference="parametric",
        n_params=int(F_psf.template.size),
        beta=beta,
        A_inv=jnp.asarray(self.A_inv),
    )
    self.force_inferred = InferenceResultSF(
        sf_F,
        param_cov=G_inv,
        meta=meta_F,
    )
    self.metadata["force_method"] = "parametric"
    self.metadata["force_parametric_info"] = {
        **res.info,
        "integrator": integrator,
        "beta": beta,
        "D_matrix": res.D,
        "Lambda": res.Lambda,
    }
    logger.info(
        "[infer_force] Done in %.1f s. Λ=%s, D=%s, β=%.3f",
        t_elapsed,
        jnp.diag(res.Lambda),
        jnp.diag(res.D),
        beta,
    )
# ── State-dependent diffusion from parametric residuals ──
[docs]
def infer_diffusion(
    self,
    basis=None,
    *,
    theta_D0=None,
    integrator: str = "rk4",
    n_substeps: int = 1,
    maxiter: int = 100,
) -> None:
    r"""Infer state-dependent diffusion D(x) from parametric residuals.

    Requires a prior parametric :meth:`infer_force` call. Holds the
    fitted force fixed and minimises the windowed conditional NLL over
    the diffusion parameters (the log-det term makes the diffusion
    level identifiable), reusing the same flow residuals and integrate
    engine as the force solve.

    .. physics:: State-dependent diffusion inference (overdamped)
        :label: parametric-diffusion-overdamped
        :category: Inference

        With the force :math:`\hat F` held fixed, the state-dependent
        diffusion :math:`D(x;\theta_D)` is optimised by minimising the
        windowed conditional negative log-likelihood; :math:`\Lambda`
        from the force inference is held fixed. A rank-2 basis gives
        :math:`D(x) = \sum_j (\theta_D)_j\, d_j(x)`; a PSF is evaluated
        directly.

    Parameters
    ----------
    basis : Basis (rank 2), PSF, or None
        Diffusion model. A rank-2 ``Basis`` gives the linear
        parameterisation; a ``PSF`` is used directly
        (``D(x) = PSF(x; θ_D)``). When ``None`` (default),
        ``symmetric_matrix_basis(d)`` — all constant symmetric
        ``d×d`` diffusion matrices.
    theta_D0 : dict, array, or None
        Initial diffusion parameters (default zeros).
    integrator : {"rk4", "euler"}
        Flow predictor (default ``"rk4"``, matching :meth:`infer_force`).
    n_substeps : int
        Integrator micro-steps per Δt (default 1).
    maxiter : int
        L-BFGS maximum iterations (default 100).

    Updates
    -------
    Sets ``diffusion_inferred``, ``diffusion_coefficients``,
    ``diffusion_coefficients_full`` (and ``diffusion_basis`` when ``basis``
    is a ``Basis``), plus ``metadata["diffusion_method"]`` and
    ``metadata["diffusion_parametric_info"]``.

    Raises
    ------
    RuntimeError
        If no force has been inferred yet, or if the recorded
        ``metadata["force_method"]`` is not a parametric one.

    See Also
    --------
    :ref:`parametric-algorithm` : Full algorithm description.
    """
    from SFI.bases import symmetric_matrix_basis
    from SFI.inference.parametric_core.solve import _as_psf, solve_diffusion_od

    if not hasattr(self, "force_inferred"):
        raise RuntimeError("infer_diffusion() requires a prior infer_force() call.")
    force_method = self.metadata.get("force_method")
    if force_method not in ("parametric", "parametric_core"):
        raise RuntimeError(
            "infer_diffusion() requires parametric force inference (call infer_force(), not infer_force_linear)."
        )

    # Spatial dimension, read off the last axis of the first dataset's positions.
    d = self.data.datasets[0].X.shape[-1]

    model = symmetric_matrix_basis(d) if basis is None else basis
    if isinstance(model, Basis):
        self._validate_basis(model, expected_rank=2, label="diffusion basis")
        self.diffusion_basis = model
    D_psf = _as_psf(model)

    # Measurement noise is held fixed; fall back to zero when none is stored.
    noise_cov = getattr(self, "Lambda", None)
    if noise_cov is None:
        noise_cov = jnp.zeros((d, d))

    res = solve_diffusion_od(
        self.data,
        self.force_psf,
        self.force_coefficients_full,
        D_psf,
        Lambda=noise_cov,
        theta_D0=theta_D0,
        n_substeps=n_substeps,
        integrator=integrator,
        maxiter=maxiter,
    )

    fitted_theta = res.theta_D
    self.diffusion_coefficients = fitted_theta
    self.diffusion_coefficients_full = fitted_theta

    diffusion_sf = SF(D_psf, D_psf.unflatten_params(fitted_theta))
    diffusion_meta = {
        "kind": "diffusion",
        "inference": "parametric",
        "n_params": int(D_psf.template.size),
        "nll": res.info["nll"],
    }
    self.diffusion_inferred = InferenceResultSF(
        diffusion_sf,
        param_cov=None,
        meta=diffusion_meta,
    )
    self.metadata["diffusion_method"] = "parametric"
    self.metadata["diffusion_parametric_info"] = dict(res.info)
#################################################################
################ BACKEND ############################
#################################################################
# Hooks required by BaseLangevinInference:
def _force_G_matrix(self) -> jnp.ndarray:
    """Gram matrix of the force basis, delegated to ``__G_matrix__``.

    The left factor is ``force_basis`` and the right factor is
    ``force_basis @ A_inv``; they are contracted with ``"ima,imb->iab"`` using
    the Gram convention stored in ``__force_G_mode__``.
    """
    return self.__G_matrix__(
        self.force_basis,
        self.force_basis @ self.A_inv,
        self.__force_G_mode__,
        "ima,imb->iab",
    )
def _force_moments(self):
    r"""
    Compute force moments ⟨ v · A_inv · b ⟩ in the Ito / Ito-shift / Strato flavors.

    .. physics:: Overdamped force moments (linear regression)
        :label: force-moments-overdamped
        :category: Inference

        **Itô moments:**

        .. math::

            M_a = \bigl\langle v_t \cdot A^{-1} \cdot b_a(X_t) \bigr\rangle

        **Stratonovich moments** (trapezoid + gradient correction):

        .. math::

            M_a^{\text{S}} = \tfrac{1}{2}\bigl\langle v_t \cdot A^{-1}
            \cdot \bigl[b_a(X_t) + b_a(X_{t+1})\bigr] \bigr\rangle
            \;-\; \bigl\langle D_{\text{inst}} : (A^{-1} \cdot \nabla_x b_a) \bigr\rangle

        The force coefficients solve :math:`G \cdot C = M`
        where :math:`G_{ab} = \langle b_a \cdot b_b \rangle` (with mode variants).

    Contractions
    ------------
    - RHS (Ito or Ito-shift):
        eq='im,mn,ina->ia' with ops (V, A, B)
        shapes:
            V: (N, m)     from velocity(dX, dt)
            A: (m, n)     constant A_inv
            B: (i, n, a)  basis at X or X_minus, features last
    - Stratonovich v∘b:
        trapezoid average over X and X_plus, same rhs contraction on each, then 0.5*(...+...)
    - Stratonovich gradient correction:
        eq='imn,ioma,no->ia' with ops (D, G, A)
        shapes:
            D: (i, m, n)     instantaneous diffusion (N, d, d) e.g. noisy
            G: (i, o, m, a)  basis.d_x()(X) shape (N, d_deriv, d_force, F)
            A: (m, n)        A_inv (d, d)

    Masking
    -------
    Mask is applied by the integrator on the leading particle axis i,
    and forwarded to state-expression calls via `mask_out`. The leaf
    fill policy (`zerostop` by default) replaces masked entries with 0
    via `jnp.where`, which naturally gives zero tangents for masked
    entries without blocking the Jacobian for active entries.

    Side effects
    ------------
    In the Stratonovich flavor the two partial sums are kept on the object as
    ``force_v_moments`` and ``_force_D_grad_b_average``.

    Raises
    ------
    RuntimeError
        If ``A_inv`` is not available.
    KeyError
        If ``__force_M_mode__`` is not one of ``'Ito'``, ``'Ito-shift'``,
        ``'Strato'``.
    """
    if not hasattr(self, "A_inv"):
        raise RuntimeError("A_inv not available. Compute diffusion first.")

    # Operands shared by every flavor.
    a_inv_op = ConstOperand(self.A_inv, alias="A")
    velocity_op = TimeOperand(velocity("dX", "dt"), alias="V")

    mode = getattr(self, "__force_M_mode__", "Ito")
    if mode not in ("Ito", "Ito-shift", "Strato"):
        raise KeyError(f"Unknown __force_M_mode__: {mode}")

    def _v_ainv_b(x_key, alias):
        # Program for < v . A_inv . b(x_key) >, with the basis operand named `alias`.
        basis_op = ExprOperand(expr=self.force_basis, x=stream(x_key), alias=alias)
        return Integrand(
            exprs=[basis_op],
            times=[velocity_op],
            consts=[a_inv_op],
            terms=[Term(eq="im,mn,ina->ia", ops=("V", "A", alias))],
        )

    def _summed(program):
        return integrate(
            self.data, program, reduce="sum", chunk_target_bytes=self._chunk_target_bytes
        )

    if mode != "Strato":
        shifted = mode == "Ito-shift"
        ito_program = _v_ainv_b("X_minus" if shifted else "X", "B")
        logger.debug(
            "Computing Ito-shift force coefficients."
            if shifted
            else "Computing Ito force coefficients."
        )
        return _summed(ito_program)

    # ── Stratonovich ──
    logger.debug("Computing Strato force coefficients.")
    # v ∘ b: trapezoid average of the basis evaluated at X and X_plus.
    trapezoid_program = 0.5 * (_v_ainv_b("X", "B0") + _v_ainv_b("X_plus", "Bp"))
    self.force_v_moments = _summed(trapezoid_program)

    logger.debug("Computing Strato gradient term.")
    # Gradient correction: G = basis.d_x() with the mask forwarded normally.
    # The leaf's fill_policy='zerostop' zeroes masked entries through
    # jnp.where, which keeps the Jacobian intact for active particles.
    #
    # For interacting bases (pdepth >= 1) the full cross-particle Jacobian
    # d_x() has shape (N_out, N_in, d, d, F), which does not fit the Strato
    # einsum 'ioma'. same_particle=True selects the diagonal block
    # (N, d, d, F) = (i, o, m, a) instead. This is physically correct: the
    # Strato correction only needs ∂b_a(x_i)/∂x_i (same-particle gradient)
    # IF the diffusion is diagonal at particle level (no noise correlations
    # between particles).
    interacting = getattr(self.force_basis, "particles_input", False)
    gradient_op = ExprOperand(
        expr=self.force_basis.d_x(same_particle=interacting),
        x=stream("X"),
        alias="G",
    )
    diffusion_op = self.get_diffusion_timeop("noisy")  # already a TimeOperand with alias
    gradient_program = Integrand(
        exprs=[gradient_op],
        times=[diffusion_op],
        consts=[a_inv_op],
        terms=[Term(eq="imn,ioma,no->ia", ops=(diffusion_op.alias, "G", "A"))],
    )
    self._force_D_grad_b_average = _summed(gradient_program)

    return self.force_v_moments - self._force_D_grad_b_average
def _diffusion_G_matrix(self) -> jnp.ndarray:
    """Gram matrix of the diffusion basis, delegated to ``__G_matrix__``.

    The basis is contracted with itself via ``"imna,imnb->iab"`` using the
    Gram convention stored in ``__diffusion_G_mode__``.
    """
    basis = self.diffusion_basis
    # Diffusion is a per-point quantity, so the projection is weighted per
    # point (weight_by_dt=False), matching what _diffusion_moments does.
    return self.__G_matrix__(
        basis,
        basis,
        self.__diffusion_G_mode__,
        "imna,imnb->iab",
        weight_by_dt=False,
    )
def _diffusion_moments(self) -> jnp.ndarray:
    """
    Project the local diffusion estimator onto the diffusion basis.

    The estimator selected by ``__diffusion_M_mode__`` is contracted with the
    basis evaluated on the ``X`` stream through ``'imna,imn->ia'`` and the
    result is accumulated by ``integrate`` with ``reduce="sum"``.
    """
    logger.debug("Computing diffusion linear moments.")
    estimator = self.get_diffusion_timeop(self.__diffusion_M_mode__)
    basis_operand = ExprOperand(expr=self.diffusion_basis, x=stream("X"), alias="B")
    projection = Term(eq="imna,imn->ia", ops=("B", estimator.alias))
    program = Integrand(
        exprs=[basis_operand],
        times=[estimator],
        terms=[projection],
    )
    # Per-point weighting (weight_by_dt=False): every increment's diffusion
    # estimate contributes equally.
    return integrate(
        self.data,
        program,
        reduce="sum",
        weight_by_dt=False,
        chunk_target_bytes=self._chunk_target_bytes,
    )
def _update_force_inferred(self) -> None:
    """
    Materialize the fitted force as an SF and wrap it with param covariance.

    The linear path (``force_basis`` present and not ``None``) is preferred;
    otherwise the parametric path (``force_psf``) is used.

    Produces: self.force_inferred : InferenceResultSF

    Raises
    ------
    RuntimeError
        If neither a force basis nor a force PSF is available on the instance.
    """
    # Look both attributes up defensively: either may be missing or None
    # depending on which fitting path was taken.
    basis = getattr(self, "force_basis", None)
    psf = getattr(self, "force_psf", None)

    if basis is not None:
        # Basis -> PSF is the supported path; tests already use .to_psf()
        P = basis.to_psf()
        theta = {"coeff": jnp.asarray(self.force_coefficients_full)}
    elif psf is not None:
        P = psf
        theta = self.force_params_nonlinear
    else:
        raise RuntimeError("_update_force_inferred called before a force model was fitted.")

    sf = SF(P, theta)  # fixed-θ state function

    # A_inv is optional metadata; only convert it when it actually holds a value.
    a_inv = getattr(self, "A_inv", None)
    meta = dict(
        kind="force",
        modes=dict(
            M=getattr(self, "__force_M_mode__", None),
            G=getattr(self, "__force_G_mode__", None),
        ),
        A_inv=None if a_inv is None else jnp.asarray(a_inv),
        # On the PSF path there is no basis: fall back to 0 features / no labels
        # instead of dereferencing a missing attribute.
        basis_features=int(getattr(basis, "n_features", 0)),
        basis_labels=getattr(basis, "labels", None),
    )
    cov = getattr(self, "force_coefficients_cov", None)  # may be absent
    self.force_inferred = InferenceResultSF(sf, param_cov=cov, meta=meta)
def _update_diffusion_inferred(self) -> None:
    """
    Materialize the fitted diffusion tensor as an SF and wrap it.

    Produces: self.diffusion_inferred : InferenceResultSF

    Raises
    ------
    RuntimeError
        If no diffusion basis is available (attribute missing or ``None``).
    """
    # A missing attribute is treated like None so the caller always gets the
    # same explicit error when diffusion has not been fitted.
    basis = getattr(self, "diffusion_basis", None)
    if basis is None:
        raise RuntimeError("_update_diffusion_inferred called before diffusion was fitted.")

    P: PSF = basis.to_psf()  # rank-2 PSF
    theta = {"coeff": jnp.asarray(self.diffusion_coefficients_full)}
    sf = SF(P, theta)  # callable (x[, mask/extras]) -> (i, m, n)

    # A_inv is optional metadata; only convert it when it actually holds a value.
    a_inv = getattr(self, "A_inv", None)
    meta = dict(
        kind="diffusion",
        mode=getattr(self, "__diffusion_M_mode__", None),
        A_inv=None if a_inv is None else jnp.asarray(a_inv),
        basis_features=int(getattr(basis, "n_features", 0)),
        basis_labels=getattr(basis, "labels", None),
    )
    cov = getattr(self, "diffusion_coefficients_cov", None)  # may be absent
    self.diffusion_inferred = InferenceResultSF(sf, param_cov=cov, meta=meta)
def __G_matrix__(
    self,
    b_left: callable,
    b_right: callable,
    G_mode: str,
    einsum_string: str,
    subsampling: int = 1,
    weight_by_dt: bool = True,
) -> jnp.ndarray:
    """
    Build the Gram/normalization matrix G = < b_left ⊗ b_right >.

    Parameters
    ----------
    b_left, b_right : stateexpr callables
        Expressions following the contract
        expr(x, v=..., mask=..., extras=..., params=...), with features on the
        last axis and the particle axis leading (index ``i``).
    G_mode : {'rectangle', 'trapeze', 'shift'}
        Where the right-hand expression is evaluated:

        * rectangle: b_left(X_t) ⊗ b_right(X_t)
        * trapeze:   b_left(X_t) ⊗ 0.5 [ b_right(X_t) + b_right(X_{t+}) ]
        * shift:     b_left(X_t) ⊗ b_right(X_{t+})
    einsum_string : str
        Einstein summation string, particle index included in the inputs
        (e.g. 'iam,ibn->iabmn').
    subsampling : int
        Forwarded to ``integrate``; use every ``subsampling``-th time row.
    weight_by_dt : bool
        Forwarded unchanged to ``integrate``.

    Returns
    -------
    jnp.ndarray
        Output of ``integrate`` with ``reduce="sum"`` and
        ``reduce_over_particles=True``.

    Raises
    ------
    KeyError
        If ``G_mode`` is not one of the three supported modes.
    """
    logger.debug("Computing G matrix with einsum: %s", einsum_string)

    # Left factor always lives at X_t; the right factor exists in two
    # flavours, evaluated at X_t and at the forward-shifted X_{t+}.
    left = ExprOperand(expr=b_left, x=stream("X"), alias="BL")
    right_now = ExprOperand(expr=b_right, x=stream("X"), alias="BR0")
    right_next = ExprOperand(expr=b_right, x=stream("X_plus"), alias="BRP")

    def _paired(right_operand, right_alias):
        # One-term integrand contracting the left factor with a right factor.
        return Integrand(
            exprs=[left, right_operand],
            terms=[Term(eq=einsum_string, ops=("BL", right_alias))],
        )

    if G_mode == "shift":
        program = _paired(right_next, "BRP")
    elif G_mode == "rectangle":
        program = _paired(right_now, "BR0")
    elif G_mode == "trapeze":
        program = 0.5 * (_paired(right_now, "BR0") + _paired(right_next, "BRP"))
    else:
        raise KeyError("Wrong G_mode argument")

    # Particle reduction is delegated to the integrator.
    return integrate(
        self.data,
        program,
        reduce="sum",
        reduce_over_particles=True,
        subsampling=subsampling,
        weight_by_dt=weight_by_dt,
        chunk_target_bytes=self._chunk_target_bytes,
    )
def _build_diffusion_timeoperands(self):
    """
    Populate ``self._diff_ops`` with the diffusion-estimator TimeOperands.

    The cache maps a method name to a ``TimeOperand`` wrapping the matching
    estimator. If ``_diff_ops`` already exists nothing is rebuilt, so repeated
    calls are harmless.
    """
    if not hasattr(self, "_diff_ops"):
        # (lookup key, estimator function, operand alias)
        estimators = (
            ("MSD", _D_msd, "D_msd"),
            ("noisy", _D_noisy, "D_noisy"),
            ("WeakNoise", _D_weaknoise, "D_weaknoise"),
            ("Lambda", _Lambda, "Lambda"),
        )
        self._diff_ops = {
            key: TimeOperand(estimator, alias=alias)
            for key, estimator, alias in estimators
        }
[docs]
def get_diffusion_timeop(self, method: str) -> TimeOperand:
    """
    Look up an overdamped diffusion estimator by name.

    Builds the operand cache on first use and returns the ``TimeOperand``
    registered under ``method``.

    Raises
    ------
    KeyError
        If ``method`` is not a registered estimator name.
    """
    self._build_diffusion_timeoperands()
    try:
        operand = self._diff_ops[method]  # type: ignore[attr-defined]
    except KeyError as missing:
        raise KeyError(f"Unknown diffusion estimator method: {method}") from missing
    return operand
# -------------------- Overdamped diffusion estimators as TimeOps --------------------
@timeop(name="D_msd", batch_safe=True)
def _D_msd(**streams):
    r"""
    Mean-squared-displacement diffusion estimator, evaluated per particle.

    Reads the ``dX`` and ``dt`` streams and returns ``0.5 * dX ⊗ (dX / dt)``.
    Leading axes are carried through by the ellipsis einsum; the output gains
    one extra trailing axis relative to ``dX``.

    .. physics:: MSD diffusion estimator (overdamped)
        :label: D-msd-overdamped
        :category: Estimator

        .. math::

            \hat D_{\text{MSD}}(t)
            = \tfrac{1}{2}\,\mathrm{d}X_t \otimes
            \frac{\mathrm{d}X_t}{\mathrm{d}t}

        Simplest estimator; biased by measurement noise.
    """
    displacement = streams["dX"]
    step = streams["dt"]
    # Append trailing singleton axes so that dt broadcasts against dX.
    missing_axes = max(displacement.ndim - step.ndim, 0)
    if missing_axes:
        step = step[(Ellipsis,) + (None,) * missing_axes]
    rate = displacement / step
    return 0.5 * jnp.einsum("...m,...n->...mn", displacement, rate)


_D_msd._requires = frozenset({"dX", "dt"})  # type: ignore[attr-defined]
@timeop(name="D_noisy", batch_safe=True)
def _D_noisy(**streams):
    r"""
    Vestergaard–Blainey–Flyvbjerg diffusion estimator, evaluated per particle.

    Combines the current increment ``dX`` and the previous one ``dX_minus``:
    1/4 [ dX⊗(dX/dt) + 2 dX⊗(dX^-/dt) + 2 dX^-⊗(dX/dt) + dX^-⊗(dX^-/dt) ].
    Leading axes are carried through by the ellipsis einsum; the output gains
    one extra trailing axis relative to ``dX``.

    .. physics:: Noisy (Vestergaard–Blainey–Flyvbjerg) diffusion estimator
        :label: D-noisy
        :category: Estimator

        .. math::

            \hat D_{\text{noisy}}(t) = \tfrac{1}{4}\bigl[
            \mathrm{d}X_t \otimes v_t
            + 2\,\mathrm{d}X_t \otimes v_{t-1}
            + 2\,\mathrm{d}X_{t-1} \otimes v_t
            + \mathrm{d}X_{t-1} \otimes v_{t-1}
            \bigr]

        Two-point estimator robust to measurement noise
        (CL Vestergaard, PC Blainey, H Flyvbjerg - Physical Review E, 2014).
    """
    step_now = streams["dX"]
    step_prev = streams["dX_minus"]
    dt = streams["dt"]
    # Append trailing singleton axes so that dt broadcasts against dX.
    missing_axes = max(step_now.ndim - dt.ndim, 0)
    if missing_axes:
        dt = dt[(Ellipsis,) + (None,) * missing_axes]
    inv_dt = 1.0 / dt
    rate_now = step_now * inv_dt
    rate_prev = step_prev * inv_dt

    def _outer(u, w):
        # Outer product over the last axis, leading axes preserved.
        return jnp.einsum("...m,...n->...mn", u, w)

    return 0.25 * (
        _outer(step_now, rate_now)
        + 2 * _outer(step_now, rate_prev)
        + 2 * _outer(step_prev, rate_now)
        + _outer(step_prev, rate_prev)
    )


_D_noisy._requires = frozenset({"dX", "dX_minus", "dt"})  # type: ignore[attr-defined]
@timeop(name="D_weaknoise", batch_safe=True)
def _D_weaknoise(**streams):
    r"""
    Weak-noise diffusion estimator, evaluated per particle.

    Built from the difference of successive increments:
    1/4 * ( (dX - dX^-) ⊗ (dX/dt - dX^-/dt) ).
    Leading axes are carried through by the ellipsis einsum; the output gains
    one extra trailing axis relative to ``dX``.

    .. physics:: Weak-noise diffusion estimator (overdamped)
        :label: D-weaknoise-overdamped
        :category: Estimator

        .. math::

            \hat D_{\text{WN}}(t)
            = \tfrac{1}{4}\bigl(\mathrm{d}X_t - \mathrm{d}X_{t-1}\bigr)
            \otimes \bigl(v_t - v_{t-1}\bigr)

        Uses successive-displacement differences; suitable when localization
        noise is negligible.
    """
    step_now = streams["dX"]
    step_prev = streams["dX_minus"]
    dt = streams["dt"]
    # Append trailing singleton axes so that dt broadcasts against dX.
    missing_axes = max(step_now.ndim - dt.ndim, 0)
    if missing_axes:
        dt = dt[(Ellipsis,) + (None,) * missing_axes]
    inv_dt = 1.0 / dt
    increment_change = step_now - step_prev
    rate_change = step_now * inv_dt - step_prev * inv_dt
    return 0.25 * jnp.einsum("...m,...n->...mn", increment_change, rate_change)


_D_weaknoise._requires = frozenset({"dX", "dX_minus", "dt"})  # type: ignore[attr-defined]
@timeop(name="Lambda_meas_noise", batch_safe=True)
def _Lambda(**streams):
    r"""
    Measurement-noise cross term, evaluated per particle.

    Λ_i = -0.5 [ dX_i ⊗ dX^-_i + dX^-_i ⊗ dX_i ], i.e. minus the symmetrized
    outer product of the current and previous increments. No 1/dt factor is
    applied inside. Leading axes are carried through by the ellipsis einsum.

    .. physics:: Measurement noise estimator (overdamped)
        :label: Lambda-overdamped
        :category: Estimator

        .. math::

            \hat\Lambda_i
            = -\,\tfrac{1}{2}\bigl[
            \mathrm{d}X_i \otimes \mathrm{d}X_{i-1}
            + \mathrm{d}X_{i-1} \otimes \mathrm{d}X_i
            \bigr]

        Estimates localization / measurement noise from anti-correlation
        of successive increments.
    """
    step_now = streams["dX"]
    step_prev = streams["dX_minus"]
    cross = jnp.einsum("...m,...n->...mn", step_now, step_prev)
    # dX^- ⊗ dX is the transpose of dX ⊗ dX^- over the two trailing axes.
    return -0.5 * (cross + jnp.swapaxes(cross, -1, -2))


_Lambda._requires = frozenset({"dX", "dX_minus"})  # type: ignore[attr-defined]