from __future__ import annotations
import contextlib
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from SFI.inference.sparse import SparsityResult
import jax
import jax.numpy as jnp
from SFI.integrate.api import integrate
from SFI.integrate.integrand import (
ConstOperand,
ExprOperand,
Integrand,
Term,
TimeOperand,
)
from SFI.integrate.timeops import stream, timeop, velocity
from SFI.statefunc.nodes.interactions.prepare import (
prepare_collection_for_expr,
)
from SFI.utils.maths import stable_pinv
logger = logging.getLogger(__name__)
[docs]
class BaseLangevinInference(ABC):
"""Stochastic Force Inference main class
This class provides tools for inferring force (drift) and
diffusion tensors from stochastic trajectory data based on
Langevin dynamics. It contains the shared logic for Overdamped and
Underdamped Langevin inference.
These subclasses must implement a handful of hooks that depend on
the physics (e.g. whether velocities are observed). The details of
the physics assumptions and definitions, as well as extensive doc
strings, are given in the headers of these classes.
Key Features
------------
- **Force inference** — linear combination of basis functions
(``infer_force_linear`` with a ``Basis``, the canonical path) or
parametric families (``infer_force`` with a ``Basis`` or ``PSF``, the
single-step flow estimator with native (D, Λ) profiling).
- **Diffusion inference** — constant (``compute_diffusion_constant``) or
state-dependent via a linear basis (``infer_diffusion_linear``).
- **Sparsification** — pluggable strategies (beam search, greedy, STLSQ,
LASSO) with information-criterion selection (AIC, BIC, PASTIS).
- **Error estimation** — normalized mean-squared-error (NMSE) prediction
for force and diffusion.
- **Comparison** — evaluate inferred fields against known exact models
(``compare_to_exact``).
- **Simulation** — generate trajectories from inferred fields
(``simulate_bootstrapped_trajectory``).
Workflow
--------
1. Initialize with a ``TrajectoryCollection`` holding the trajectory.
2. Use the ``infer_*`` methods to infer force and diffusion fields.
3. Optionally sparsify the results to mitigate overfitting.
4. Optionally compute error estimates and/or compare with exact data.
Indices Convention
------------------
The code uses ``jnp.einsum`` with a consistent index naming scheme:
- ``t`` — time index, 0..Ntimesteps-1.
- ``a, b, c...`` — basis-function indices, 0..Nfunctions-1.
- ``m, n, o...`` — state / spatial indices, 0..dim-1.
- ``i, j...`` — particle indices (size Nparticles, or 1 if there is no
particle structure).
These also serve as array-shape shorthands: e.g.
``basis_linear : im -> iam`` means ``basis_linear`` takes an array of
shape (Nparticles, dim) and returns one of shape
(Nparticles, Nfunctions, dim).
Logging levels control output (configure via ``logging`` or
``SFI.enable_logging()``):
- INFO -> inference steps, key results.
- DEBUG -> detailed computation progress.
"""
def __init__(self, data, *, max_memory_gb=1.0, **kwargs):
"""Initialize the inference object.
Parameters
----------
data : TrajectoryCollection
The trajectory data to infer from.
max_memory_gb : float
Approximate memory budget for integration batches.
Raises
------
TypeError
If ``data`` is not a ``TrajectoryCollection``.
"""
from SFI.trajectory.collection import TrajectoryCollection
if not isinstance(data, TrajectoryCollection):
raise TypeError(
f"Expected a TrajectoryCollection, got {type(data).__name__}. "
f"Use TrajectoryCollection.from_dataset(ds) or "
f"TrajectoryCollection.from_arrays(X=..., dt=...) first."
)
self.data = data
self.max_memory_gb = max_memory_gb
self.metadata = {}
@property
def _chunk_target_bytes(self) -> int:
"""Convert ``max_memory_gb`` to bytes for the integration engine."""
return max(1, int(self.max_memory_gb * 1024**3))
@contextlib.contextmanager
def _structural_scope(self, *exprs):
"""Expose the structural (CSR/stencil) arrays required by ``exprs`` for the
duration of the block, without ever persisting them on the dataset.
On entry ``self.data`` is swapped for a transient collection carrying the
freshly-built structural arrays; on exit the original collection object is
restored. Nothing is left on the dataset, so a stale structural table can
never survive a later transform, and teardown is exception-safe (there is
no purge step to forget).
"""
original = self.data
try:
exprs = tuple(e for e in exprs if e is not None)
if exprs:
self.data = prepare_collection_for_expr(original, *exprs)
yield
finally:
self.data = original
# ---- shared validation helpers ---------------------------------------- #
@staticmethod
def _validate_basis(basis, *, expected_rank: int, label: str = "basis"):
"""Check that ``basis`` is a Basis with the correct rank.
Parameters
----------
basis : object
Should be a ``Basis`` instance.
expected_rank : int
Required spatial rank (1 for force, 2 for diffusion).
label : str
Human-readable name for error messages.
Raises
------
TypeError
If ``basis`` is not a ``Basis``.
ValueError
If the rank does not match ``expected_rank``.
"""
from SFI.statefunc import Basis as _Basis
if not isinstance(basis, _Basis):
from SFI.statefunc import PSF as _PSF
from SFI.statefunc import SF as _SF
if isinstance(basis, (_PSF, _SF)):
raise TypeError(
f"`{label}` must be a Basis (deterministic dictionary), "
f"not a {type(basis).__name__}. "
f"For parametric inference with a PSF, use infer_force()."
)
raise TypeError(f"`{label}` must be a Basis object, got {type(basis).__name__}.")
if basis.rank != expected_rank:
raise ValueError(
f"`{label}` must have rank={expected_rank} "
f"(got rank={basis.rank}). "
+ (
"Force bases should produce vectors (rank 1)."
if expected_rank == 1
else "Diffusion bases should produce matrices (rank 2)."
)
)
[docs]
def sparsify_force(
self,
*,
criterion: str = "PASTIS",
p: float = 0.05,
method: str = "beam",
max_k: int | None = None,
**strategy_kwargs,
) -> "SparsityResult":
"""Sparsify the inferred force by selecting a subset of basis functions.
Builds a Pareto front of sparse models using the chosen
``method``, then selects the model that maximises the given
information ``criterion``.
Parameters
----------
criterion : ``"PASTIS"`` | ``"AIC"`` | ``"BIC"`` | ``"EBIC"`` | ``"SIC"``, default ``"PASTIS"``
Information criterion for model selection.
p : float, default 0.05
Prior-scale parameter :math:`p_0` for the PASTIS penalty.
method : str, default ``"beam"``
Search strategy. One of:
- ``"beam"`` — bidirectional beam search (PASTIS original).
Extra kwargs: ``beam_width`` (int, default 3),
``aic_patience`` (int, default 2).
- ``"greedy"`` — forward stepwise selection.
Extra kwargs: ``direction`` (``"forward"`` | ``"backward"``
| ``"both"``, default ``"forward"``).
- ``"stlsq"`` — Sequential Thresholded Least Squares
(SINDy-style). Extra kwargs: ``threshold`` (float or None),
``mode`` (``"relative"`` | ``"absolute"``),
``n_thresholds`` (int).
- ``"lasso"`` — :math:`\\ell_1`-penalised coordinate descent.
Extra kwargs: ``alpha`` (float or None),
``n_alphas`` (int).
- ``"hillclimb"`` — stochastic hill-climbing
(Gerardos & Ronceray, 2025). Extra kwargs: ``ic``,
``patience`` (int), ``seed`` (int or None).
max_k : int or None
Maximum model size. Defaults to the full basis size.
**strategy_kwargs
Passed to the strategy constructor.
Returns
-------
SparsityResult
The full Pareto-front result, also stored as
``self.force_sparsity_result``.
"""
from SFI.inference.sparse import (
BeamSearchStrategy,
GreedyStepwiseStrategy,
HillClimbStrategy,
LassoStrategy,
STLSQStrategy,
)
scorer = self.force_scorer
if max_k is None:
max_k = scorer.p
# Build the strategy object
_strategies = {
"beam": BeamSearchStrategy,
"greedy": GreedyStepwiseStrategy,
"stlsq": STLSQStrategy,
"lasso": LassoStrategy,
"hillclimb": HillClimbStrategy,
}
key = method.lower()
if key not in _strategies:
raise ValueError(f"Unknown sparsity method {method!r}. Choose from {list(_strategies)}.")
# Default beam_width for beam search
if key == "beam":
strategy_kwargs.setdefault("beam_width", 3)
strategy_kwargs.setdefault("aic_patience", 2)
strategy_kwargs.setdefault("report_time", True)
elif key == "hillclimb":
# Hill-climb uses the selection criterion as its acceptance
# objective by default, so it stays consistent with the IC
# used at the final `select_by_ic` step.
strategy_kwargs.setdefault("ic", criterion)
strategy_kwargs.setdefault("p_param", p)
strategy = _strategies[key](**strategy_kwargs)
result: SparsityResult = strategy.run(scorer, max_k=max_k)
# Select the best model according to the information criterion.
# BIC/EBIC penalise by the total trajectory time tau (== Teff);
# supply it from the data so those criteria work out of the box.
tau = float(self.data.Teff({"dt"}))
k, support, score, coeffs = result.select_by_ic(criterion, p_param=p, tau=tau)
self._update_force_coefficients(coeffs, support)
self.force_sparsity_result = result
return result
########################## ERROR ANALYSIS ###########################
[docs]
def compute_force_error(self):
r"""
Estimate sampling error for force inference.
.. physics:: Force coefficient covariance & predicted error
:label: force-error-covariance
:category: Error analysis
.. math::
\operatorname{Cov}(C) = G^{-1},
\qquad
\mathbb{E}\!\left[\langle \delta F^\top A^{-1} \delta F \rangle\right]
= \operatorname{Tr}\!\left(G\,\operatorname{Cov}(C)\right),
\qquad
I_F = \tfrac{1}{2}\,C^\top M,
\qquad
\text{NMSE}_{F,\text{pred}}
= \frac{\operatorname{Tr}(G\cdot\operatorname{Cov}(C))}{C^\top M}
= \frac{\operatorname{Tr}(G\cdot\operatorname{Cov}(C))}{2 I_F}
Assumes the sampling error dominates; measurement noise and
discretization biases are not addressed.
This method evaluates the covariance of the inferred force coefficients, the standard error,
and computes the predicted normalized mean squared error (MSE) of the inferred force field. This analysis
assumes that the sampling error dominates, and measurement noise or discretization biases are
not explicitly addressed. It is common to OLI and ULI (by construction of the normal matrix G).
Updates:
self.force_coefficients_covariance (jnp.ndarray): Covariance matrix of the force coefficients.
self.force_coefficients_stderr (jnp.ndarray): Standard error for each force coefficient.
self.force_information (float): Estimated information content of the inferred force field.
self.force_predicted_MSE (float): Predicted normalized mean squared error of the inferred force field.
"""
# Estimate the covariance of the force coefficients.
# force_G is the NLL Hessian for all force methods:
# - linear: G_ab = Σ_t dt b_a⊤ A⁻¹ b_b (GLS Gram = Onsager–Machlup Hessian)
# - parametric: G = Σ_t ψ_t⊤ Σ_t⁻¹ ψ_t (Gauss–Newton NLL Hessian)
# - nonlinear: G = H(NLL) from L-BFGS-B inverse Hessian
# In all cases Cov(C) = G⁻¹ exactly (Fisher information bound).
self.force_coefficients_covariance = self.force_G_pinv
# Calculate the standard error for each force coefficient
self.force_coefficients_stderr = jnp.einsum("aa->a", self.force_coefficients_covariance) ** 0.5
# Propagate covariance into the existing InferenceResultSF (if present)
if hasattr(self, "force_inferred") and self.force_inferred is not None:
object.__setattr__(self.force_inferred, "param_cov", self.force_coefficients_covariance)
# Compute time-integrated squared error of the force field
force_SSE = float(jnp.einsum("ab,ba", self.force_G, self.force_coefficients_covariance))
# Compute normalized MSE
if hasattr(self, "force_moments"):
self.force_information = float(0.5 * self.force_coefficients_full @ self.force_moments)
force_energy = 2.0 * self.force_information
elif hasattr(self, "force_optimization_results_nonlinear"):
self.force_information = -self.force_optimization_results_nonlinear["fun"]
# For parametric/nonlinear, force_information = -NLL_min; energy = 2*I
force_energy = 2.0 * self.force_information
else:
raise RuntimeError("Force information is unavailable. Run force inference before computing the error.")
# Guard against empty support (zero inferred force energy → null model, MSE undefined)
if force_energy == 0.0:
self.force_predicted_MSE = float("nan")
else:
self.force_predicted_MSE = float(force_SSE / force_energy)
[docs]
def compute_diffusion_error(self):
r"""
Estimate sampling error for diffusion inference.
Mirrors :meth:`compute_force_error` for the diffusion field.
Uses the diffusion Gram matrix (normal matrix) and its inverse.
For linear diffusion inference the moments covariance is
proportional to the Gram matrix, giving
``Cov(θ_D) = cov_factor * G_D⁻¹``.
.. note::
This error estimate is approximate. The diffusion inference is more
complex than the force inference: diffusion coefficients are inferred
from force residuals, a positive-definiteness constraint applies, and
the simple covariance formula ``Cov(θ_D) = cov_factor * G_D⁻¹`` may
not capture all sources of uncertainty. Treat the result as a rough
guide rather than a rigorous confidence interval.
Updates
-------
self.diffusion_coefficients_covariance : jnp.ndarray
Covariance matrix of the diffusion coefficients.
self.diffusion_coefficients_stderr : jnp.ndarray
Standard error for each diffusion coefficient.
self.diffusion_information : float
Estimated information content of the inferred diffusion field.
self.diffusion_predicted_MSE : float
Predicted normalized mean squared error.
"""
if not hasattr(self, "diffusion_G_pinv") or self.diffusion_G_pinv is None:
raise RuntimeError(
"Diffusion Gram inverse is unavailable. Run infer_diffusion_linear() before computing the error."
)
diffusion_method = self.metadata.get("diffusion_method", "linear")
if diffusion_method in ("parametric_nll",):
# G is the NLL Hessian → Cov(θ_D) = G⁻¹
cov_factor = 1.0
else:
# Linear diffusion is OLS (not GLS): the local estimator D̂_t has
# chi-squared fluctuations with Var(D̂_t) ≈ 2D².
# Propagating through the MoM formula gives
# Cov(Ĉ_D) ≈ 2D²_eff · dt · G_D⁻¹
# where D_eff = Tr(D)/d is the average eigenvalue of the diffusion tensor
# and dt is the observation interval. This is inherently approximate;
# any fixed factor misses the state-dependence of D.
d = int(self.diffusion_average.shape[-1])
D_eff = float(jnp.trace(self.diffusion_average)) / d
dt_val = float(self.data.peek_row(require={"dt"})["dt"])
cov_factor = 2.0 * D_eff**2 * dt_val
self.diffusion_coefficients_covariance = cov_factor * self.diffusion_G_pinv
self.diffusion_coefficients_cov = self.diffusion_coefficients_covariance
self.diffusion_coefficients_stderr = jnp.einsum("aa->a", self.diffusion_coefficients_covariance) ** 0.5
# Propagate into the existing InferenceResultSF
if hasattr(self, "diffusion_inferred") and self.diffusion_inferred is not None:
object.__setattr__(
self.diffusion_inferred,
"param_cov",
self.diffusion_coefficients_covariance,
)
diffusion_SSE = float(jnp.einsum("ab,ba", self.diffusion_G, self.diffusion_coefficients_covariance))
if hasattr(self, "diffusion_moments"):
self.diffusion_information = float(0.5 * self.diffusion_coefficients_full @ self.diffusion_moments)
diffusion_energy = 2.0 * self.diffusion_information
else:
self.diffusion_information = float("nan")
diffusion_energy = 0.0
if diffusion_energy == 0.0:
self.diffusion_predicted_MSE = float("nan")
else:
self.diffusion_predicted_MSE = float(diffusion_SSE / diffusion_energy)
[docs]
def diagnose(self, *, level: str = "standard", **kwargs):
"""Run the consistency-check suite from :mod:`SFI.diagnostics`.
Convenience wrapper for :func:`SFI.diagnostics.assess`. See its
docstring for the available ``level`` presets.
"""
from SFI.diagnostics import assess
return assess(self, level=level, **kwargs)
[docs]
def holdout_score(self, data, *, require_error: bool = False) -> dict:
"""Held-out NMSE of the fitted force on an independent collection.
A *side feature for data-abundant scenarios*: SFI estimates its
own accuracy from the training data (``force_predicted_MSE``)
and validates fits through the diagnostics suite, neither of
which costs any data. Reach for an explicit train/test split
(:meth:`TrajectoryCollection.split_time
<SFI.trajectory.TrajectoryCollection.split_time>`) only when
data is plentiful, or to confirm a suspected bias floor: a
``ratio`` near 1 means the fit is sampling-limited, a ratio
``≫ 1`` means a bias floor (often measurement noise — see the
noise-and-sampling guide).
The score is the residual-based normalised mean-square error of
``force_inferred`` on ``data``, with the diffusion noise floor
subtracted (a bias detector, not a precision instrument: its
resolution is set by the χ² fluctuations of the residuals).
Parameters
----------
data : TrajectoryCollection
Independent test data (e.g. the second half of
``coll.split_time(0.8)``).
require_error : bool
If True, run :meth:`compute_force_error` first when the
predicted error is missing, so ``ratio`` is always defined.
Returns
-------
dict
``{"holdout_NMSE", "predicted_NMSE", "ratio", "excess_z",
"n_obs"}``. Also stored as ``self.force_holdout_NMSE``.
Notes
-----
Bases that read *time-dependent* extras are not supported on the
held-out path (the residual builders pass extras unsliced).
"""
from SFI.diagnostics.residual_tests import mse_consistency
from SFI.diagnostics.residuals import build_residuals
if require_error and getattr(self, "force_predicted_MSE", None) is None:
self.compute_force_error()
bundle = build_residuals(self, data=data)
out = mse_consistency(self, bundle)
result = {
"holdout_NMSE": out.get("realised_NMSE"),
"predicted_NMSE": out.get("predicted_NMSE"),
"ratio": out.get("ratio"),
"excess_z": out.get("excess_z"),
"n_obs": bundle.n_obs,
}
self.force_holdout_NMSE = result["holdout_NMSE"]
return result
[docs]
def print_report(self):
"""
Print a summary report of the inference results.
Provides insights into the inferred diffusion and force fields, along with error metrics
such as sampling error, trajectory length, discretization bias, and measurement noise.
"""
print("\n --- StochasticForceInference Report --- ")
# Average diffusion tensor
print("Average diffusion tensor:\n", self.diffusion_average)
# Measurement noise tensor
print("Measurement noise tensor:\n", self.Lambda)
# Entropy production
if hasattr(self, "DeltaS"):
print(
"Entropy production: inferred/bootstrapped error:",
self.DeltaS,
self.error_DeltaS,
)
# Force inference metrics
if hasattr(self, "force_predicted_MSE"):
print("Force estimated information:", self.force_information)
print(
"Force: estimated normalized mean squared error (sampling only):",
self.force_predicted_MSE,
)
# To add: bias estimates
# Diffusion error metrics
if hasattr(self, "diffusion_predicted_MSE"):
print("Diffusion estimated information:", self.diffusion_information)
print(
"Diffusion: estimated normalized mean squared error (sampling only):",
self.diffusion_predicted_MSE,
)
# To add: bias estimates
# Exact-comparison NMSE (set by compare_to_exact)
if hasattr(self, "NMSE_force"):
print(f"Normalized MSE (force): {self.NMSE_force:.4f}")
if hasattr(self, "NMSE_diffusion"):
print(f"Normalized MSE (diffusion): {self.NMSE_diffusion:.4f}")
# Force coefficient table
if hasattr(self, "force_coefficients_full"):
print()
print(self.summary("force"))
[docs]
def summary(self, field: str = "force") -> str:
"""
Return a formatted coefficient table for the inferred model.
Parameters
----------
field : ``"force"`` or ``"diffusion"``
Which inferred field to summarize.
Returns
-------
str
Multi-line table ready for ``print()``.
"""
import numpy as np
from SFI.utils.formatting import model_summary
if field == "force":
labels = getattr(self, "force_basis_labels", None)
coeffs = getattr(self, "force_coefficients_full", None)
stderr = getattr(self, "force_coefficients_stderr", None)
support = getattr(self, "force_support", None)
title = "Force Coefficient Table"
elif field == "diffusion":
labels = getattr(self, "diffusion_basis_labels", None)
coeffs = getattr(self, "diffusion_coefficients_full", None)
stderr = getattr(self, "diffusion_coefficients_stderr", None)
support = getattr(self, "diffusion_support", None)
title = "Diffusion Coefficient Table"
else:
raise ValueError(f"Unknown field {field!r}; expected 'force' or 'diffusion'.")
if coeffs is None:
return f" (No {field} coefficients available.)"
coeffs = np.asarray(coeffs)
n = coeffs.shape[0]
auto_labels = labels is None
if auto_labels:
labels = [f"b{j}" for j in range(n)]
if stderr is not None:
stderr = np.asarray(stderr)
# support may be full or sparse
support_arr = None
if support is not None:
support_arr = np.asarray(support)
if support_arr.shape[0] == n:
support_arr = None # full support, do not highlight
return model_summary(
labels,
coeffs,
stderr=stderr,
support=support_arr,
title=title,
auto_labels=auto_labels,
)
[docs]
def report_dict(self) -> dict:
"""Return a structured summary of inference results as a dictionary.
This is the machine-readable counterpart of :meth:`print_report`.
All values are plain Python scalars or numpy arrays (no JAX arrays).
Returns
-------
dict
Keys include ``"diffusion_average"``, ``"Lambda"``,
``"force_information"``, ``"force_predicted_MSE"``,
``"NMSE_force"``, ``"NMSE_diffusion"``, and others
when available. Missing quantities are omitted.
"""
import numpy as np
d: dict = {}
d["metadata"] = dict(self.metadata)
if hasattr(self, "diffusion_average"):
d["diffusion_average"] = np.asarray(self.diffusion_average)
if hasattr(self, "Lambda"):
d["Lambda"] = np.asarray(self.Lambda)
if hasattr(self, "DeltaS"):
d["DeltaS"] = float(self.DeltaS)
d["error_DeltaS"] = float(self.error_DeltaS)
if hasattr(self, "force_coefficients"):
d["force_coefficients"] = np.asarray(self.force_coefficients)
if hasattr(self, "force_coefficients_full"):
d["force_coefficients_full"] = np.asarray(self.force_coefficients_full)
if hasattr(self, "force_support"):
d["force_support"] = np.asarray(self.force_support)
if hasattr(self, "force_coefficients_stderr"):
d["force_coefficients_stderr"] = np.asarray(self.force_coefficients_stderr)
if hasattr(self, "force_information"):
d["force_information"] = float(self.force_information)
if hasattr(self, "force_predicted_MSE"):
d["force_predicted_MSE"] = float(self.force_predicted_MSE)
if hasattr(self, "NMSE_force"):
d["NMSE_force"] = float(self.NMSE_force)
if hasattr(self, "MSE_force"):
d["MSE_force"] = float(self.MSE_force)
if hasattr(self, "diffusion_coefficients"):
d["diffusion_coefficients"] = np.asarray(self.diffusion_coefficients)
if hasattr(self, "diffusion_information"):
d["diffusion_information"] = float(self.diffusion_information)
if hasattr(self, "diffusion_predicted_MSE"):
d["diffusion_predicted_MSE"] = float(self.diffusion_predicted_MSE)
if hasattr(self, "NMSE_diffusion"):
d["NMSE_diffusion"] = float(self.NMSE_diffusion)
if hasattr(self, "MSE_diffusion"):
d["MSE_diffusion"] = float(self.MSE_diffusion)
return d
[docs]
def compare_to_exact(
self,
*,
model_exact=None,
data_exact=None,
force_exact=None,
diffusion_exact=None, # callable | float | (d,d) array
maxpoints: int = 1000,
) -> None:
r"""
Compare inferred vs exact using dt-weighted time means via the integrate() API.
This function evaluates the inferred force/diffusion against "exact" references
on a (possibly exact/synthetic) dataset. It updates:
self.MSE_force / self.NMSE_force
self.MSE_diffusion / self.NMSE_diffusion
Inputs: exact references
~~~~~~~~~~~~~~~~~~~~~~~~
You can provide exact references in two ways:
1) Preferred: `model_exact`
A model object (from SFI.langevin submodule) exposing:
- model_exact.force_sf : exact force/drift (SF/StateExpr-like)
- model_exact.diffusion_sf : exact diffusion (SF/StateExpr-like)
OR a constant (float or (d,d) matrix) via ``model_exact.D``
2) Explicit: `force_exact`, `diffusion_exact`
- `force_exact`: SF/StateExpr-like callable returning (N, d)
- `diffusion_exact`:
* callable returning (N, d, d), OR
* float meaning σ·I, OR
* (d,d) matrix constant diffusion.
These are used if `model_exact` is not provided. If `model_exact` is provided,
its members take precedence unless they are missing, in which case the explicit
arguments can be used as fallback.
Velocity provisioning (underdamped)
-----------------------------------
If an evaluated expression advertises `needs_v=True`, this routine supplies:
v := dX/dt (secant velocity from the data stream)
i.e. it uses `velocity("dX", "dt")` as the `v=...` keyword argument. This works for both
exact and inferred expressions and keeps underdamped comparisons possible even when
the dataset only stores positions.
Metrics
-------
Force:
e = Fe - Fh
MSE_force = < e^T A_inv e >
NMSE_force = MSE_force / < Fh^T A_inv Fh >
Diffusion:
E = De - Dh
MSE_diffusion = < tr(A_inv E A_inv E) >
.. physics:: Normalized MSE metrics (force & diffusion)
:label: nmse-metrics
:category: Error analysis
.. math::
\text{NMSE}_F
= \frac{\langle (F_{\text{exact}} - \hat F)^\top A^{-1}
(F_{\text{exact}} - \hat F) \rangle}
{\langle \hat F^\top A^{-1} \hat F \rangle}
.. math::
\text{NMSE}_D
= \frac{\langle \operatorname{tr}(A^{-1} E\, A^{-1} E) \rangle}
{\langle \operatorname{tr}(A^{-1} \hat D\, A^{-1} \hat D) \rangle}
where :math:`E = D_{\text{exact}} - \hat D`.
NMSE_diffusion = MSE_diffusion / < tr(A_inv Dh A_inv Dh) >
where A_inv is `self.A_inv` (typically (2 D̄)^{-1} from the inferred constant diffusion normalization).
Subsampling
-----------
Uses a simple subsampling heuristic so that the total number of evaluated points is ~<= `maxpoints`,
accounting for the maximum number of particles.
Requirements
------------
- `self.A_inv` must exist (run compute_diffusion_constant() or otherwise set A_inv).
- The dataset must provide streams `X`, `dt`, and if any evaluated expr needs v: `dX` as well.
"""
data_exact = data_exact or self.data
# ---------- resolve exact references ----------
if model_exact is not None:
F_exact = (
getattr(model_exact, "force_sf", None)
if force_exact is None
else getattr(model_exact, "force_sf", force_exact)
)
D_exact = (
getattr(model_exact, "diffusion_sf", None)
if diffusion_exact is None
else getattr(model_exact, "diffusion_sf", diffusion_exact)
)
# Fall back to model_exact.D when the bound SF is absent (e.g. constant
# diffusion in OverdampedProcess sets _D_sf=None but stores the raw value
# in the .D field).
if D_exact is None:
D_exact = diffusion_exact if diffusion_exact is not None else getattr(model_exact, "D", None)
else:
F_exact = force_exact
D_exact = diffusion_exact
if not hasattr(self, "A_inv"):
raise RuntimeError("A_inv not available. Run compute_diffusion_constant() (or equivalent) first.")
nsteps = int(getattr(data_exact, "Nsteps", 0) or 0)
nmaxp = int(getattr(data_exact, "Nmaxparticles", 1) or 1)
subsampling = max(1, nsteps // max(1, maxpoints // max(1, nmaxp))) if nsteps else 1
logger.info("Comparing to exact data...")
A = ConstOperand(jnp.asarray(self.A_inv), alias="A")
d = int(jnp.asarray(self.A_inv).shape[0])
# Helper: provide v only if the expression wants it.
def _maybe_v(expr):
return velocity("dX", "dt") if bool(getattr(expr, "needs_v", False)) else None
# Helper: wrap constant diffusion into a TimeOperand returning (N,d,d)
def _const_diffusion_operand(Dconst, *, alias: str):
Dconst = jnp.asarray(Dconst)
if Dconst.ndim == 0:
# scalar σ -> σ I
sigma = float(Dconst)
Dmat = sigma * jnp.eye(d, dtype=jnp.asarray(self.A_inv).dtype)
elif Dconst.ndim == 2:
Dmat = Dconst
else:
raise TypeError("Constant diffusion must be a float (σ) or a (d,d) matrix.")
@timeop(name=f"D_const_{alias}")
def _Dconst(**streams):
N = streams["X"].shape[0]
return jnp.broadcast_to(Dmat[None, :, :], (N, Dmat.shape[0], Dmat.shape[1]))
_Dconst._requires = frozenset({"X"}) # type: ignore[attr-defined]
return TimeOperand(_Dconst, alias=alias)
# Helper: build diffusion operand (ExprOperand or TimeOperand) with alias control.
def _diffusion_operand(Dobj, *, alias: str):
# NoiseModel instances (e.g. ConservedNoise) are not callable exprs
# and cannot be converted to JAX arrays — return None so the caller
# skips the (point-wise) diffusion comparison for them.
from SFI.langevin.noise import NoiseModel
if isinstance(Dobj, NoiseModel):
return None
if callable(Dobj):
vD = _maybe_v(Dobj)
return ExprOperand(expr=Dobj, x=stream("X"), v=vD, alias=alias)
return _const_diffusion_operand(Dobj, alias=alias)
# ------------------------------ FORCE ------------------------------ #
if hasattr(self, "force_inferred") and (F_exact is not None):
if not callable(F_exact):
raise TypeError("Exact force must be callable (SF/StateExpr-like).")
Fh = getattr(self.force_inferred, "sf", self.force_inferred)
# Structural extras needed by expr graphs are built into a transient
# collection (never persisted). Only StateExpr-like objects have them;
# plain callables are skipped (the `root` attribute marks StateExprs).
force_exprs = [e for e in (F_exact, Fh) if hasattr(e, "root")]
if force_exprs:
data_exact = prepare_collection_for_expr(data_exact, *force_exprs)
Fe_op = ExprOperand(expr=F_exact, x=stream("X"), v=_maybe_v(F_exact), alias="Fe")
Fh_op = ExprOperand(expr=Fh, x=stream("X"), v=_maybe_v(Fh), alias="Fh")
# Numerator: ⟨(Fe−Fh)^T A (Fe−Fh)⟩
num_prog = Integrand(
exprs=[Fe_op, Fh_op],
consts=[A],
terms=[
Term(eq="im,mn,in->i", ops=("Fe", "A", "Fe"), scale=+1.0),
Term(eq="im,mn,in->i", ops=("Fe", "A", "Fh"), scale=-2.0),
Term(eq="im,mn,in->i", ops=("Fh", "A", "Fh"), scale=+1.0),
],
)
# Denominator: ⟨Fh^T A Fh⟩
den_prog = Integrand(
exprs=[Fh_op],
consts=[A],
terms=[Term(eq="im,mn,in->i", ops=("Fh", "A", "Fh"))],
)
num = integrate(
data_exact,
num_prog,
reduce="mean",
reduce_over_particles=True,
subsampling=subsampling,
chunk_target_bytes=self._chunk_target_bytes,
)
den = integrate(
data_exact,
den_prog,
reduce="mean",
reduce_over_particles=True,
subsampling=subsampling,
chunk_target_bytes=self._chunk_target_bytes,
)
self.MSE_force = num
self.NMSE_force = num / (den + 1e-12)
logger.info("Normalized MSE (force): %s", self.NMSE_force)
# ------------------------------ DIFFUSION ------------------------------ #
if D_exact is not None:
# Inferred diffusion: prefer callable diffusion_inferred if available, else constant diffusion_average.
Dh_obj = None
if hasattr(self, "diffusion_inferred"):
Dh_candidate = getattr(self.diffusion_inferred, "sf", self.diffusion_inferred)
if callable(Dh_candidate):
Dh_obj = Dh_candidate
if Dh_obj is None:
if hasattr(self, "diffusion_average"):
Dh_obj = jnp.asarray(self.diffusion_average)
else:
raise RuntimeError("No inferred diffusion callable and no diffusion_average available.")
# Build operands (supports callable OR constant for both exact and inferred).
De_op = _diffusion_operand(D_exact, alias="De")
Dh_op = _diffusion_operand(Dh_obj, alias="Dh")
if De_op is None or Dh_op is None:
logger.info(
"Skipping diffusion NMSE: exact diffusion is a NoiseModel "
"(e.g. ConservedNoise) that cannot be compared point-wise."
)
else:
# Prepare structural extras only for callable exprs with node trees.
diff_exprs = []
if isinstance(De_op, ExprOperand) and hasattr(D_exact, "root"):
diff_exprs.append(D_exact)
if isinstance(Dh_op, ExprOperand) and hasattr(Dh_obj, "root"):
diff_exprs.append(Dh_obj)
if diff_exprs:
data_exact = prepare_collection_for_expr(data_exact, *diff_exprs)
exprs_num = [op for op in (De_op, Dh_op) if isinstance(op, ExprOperand)]
times_num = [op for op in (De_op, Dh_op) if isinstance(op, TimeOperand)]
exprs_den = [Dh_op] if isinstance(Dh_op, ExprOperand) else []
times_den = [Dh_op] if isinstance(Dh_op, TimeOperand) else []
# Numerator: ⟨ tr(A (De−Dh) A (De−Dh)) ⟩
# Expanded form with contractions using eq="imn,iop,no,pm->i".
num_prog = Integrand(
exprs=exprs_num,
times=times_num,
consts=[A],
terms=[
Term(eq="imn,iop,no,pm->i", ops=("De", "De", "A", "A"), scale=+1.0),
Term(eq="imn,iop,no,pm->i", ops=("De", "Dh", "A", "A"), scale=-2.0),
Term(eq="imn,iop,no,pm->i", ops=("Dh", "Dh", "A", "A"), scale=+1.0),
],
)
den_prog = Integrand(
exprs=exprs_den,
times=times_den,
consts=[A],
terms=[Term(eq="imn,iop,no,pm->i", ops=("Dh", "Dh", "A", "A"))],
)
num = integrate(
data_exact,
num_prog,
reduce="mean",
reduce_over_particles=True,
subsampling=subsampling,
chunk_target_bytes=self._chunk_target_bytes,
)
den = integrate(
data_exact,
den_prog,
reduce="mean",
reduce_over_particles=True,
subsampling=subsampling,
chunk_target_bytes=self._chunk_target_bytes,
)
self.MSE_diffusion = num
self.NMSE_diffusion = num / (den + 1e-12)
logger.info("Normalized MSE (diffusion): %s", self.NMSE_diffusion)
# ------------------------------------------------------------------
# Exact-vs-inferred sample arrays + scatter (graphical comparison)
# ------------------------------------------------------------------
def _comparison_points(self, data=None, *, maxpoints: int = 2000, need_v: bool = False):
"""Subsample ``(X_flat, V_flat|None, mask_flat)`` to ~``maxpoints`` points."""
import numpy as np
data = data if data is not None else self.data
t, X, M = data.to_arrays(dataset=0)
X = np.asarray(X)
M = np.asarray(M)
T, _, d = X.shape
stride = max(1, T // max(1, maxpoints // max(1, X.shape[1])))
Xs, Ms = X[::stride], M[::stride]
Vf = None
if need_v:
from SFI.utils.maths import fd_velocity
dt = np.diff(np.asarray(t, dtype=float))
Vf = np.asarray(fd_velocity(X, dt))[::stride].reshape(-1, d)
return Xs.reshape(-1, d), Vf, Ms.reshape(-1).astype(bool)
def _eval_on_points(self, field, Xf, Vf):
"""Evaluate a callable force/field on flat points, supplying ``v`` when needed."""
import jax.numpy as jnp
import numpy as np
fn = getattr(field, "sf", field)
needs_v = bool(getattr(field, "needs_v", False))
if needs_v and Vf is not None:
out = fn(jnp.asarray(Xf), v=jnp.asarray(Vf))
else:
out = fn(jnp.asarray(Xf))
return np.asarray(out)
def _eval_diffusion_on_points(self, Dobj, Xf, Vf):
"""Evaluate a diffusion field (callable) or broadcast a constant to ``(M, d, d)``."""
import numpy as np
if callable(getattr(Dobj, "sf", Dobj)):
return self._eval_on_points(Dobj, Xf, Vf)
Dc = np.asarray(Dobj)
d = Xf.shape[1]
if Dc.ndim == 0:
Dmat = float(Dc) * np.eye(d)
elif Dc.ndim == 2:
Dmat = Dc
else:
raise TypeError("Constant diffusion must be a scalar or a (d, d) matrix.")
return np.broadcast_to(Dmat[None, :, :], (Xf.shape[0], d, d))
[docs]
def force_comparison_arrays(self, *, model_exact=None, force_exact=None, data=None, maxpoints: int = 2000):
"""Return ``(F_exact, F_inferred)`` evaluated along the trajectory.
Evaluates the exact and inferred force on the (subsampled, masked)
trajectory points, supplying finite-difference velocities for
underdamped fields. Feeds :meth:`comparison_scatter`; also handy
for custom diagnostics.
Parameters
----------
model_exact :
Object exposing ``force_sf`` (e.g. an ``OverdampedProcess``).
force_exact :
Explicit callable exact force (overrides ``model_exact``).
data :
Collection to evaluate on (default: the training data).
maxpoints :
Approximate number of points to evaluate.
Returns
-------
(F_exact, F_inferred) : tuple of ndarray, shape ``(n_points, d)``
"""
F_exact = force_exact
if F_exact is None and model_exact is not None:
F_exact = getattr(model_exact, "force_sf", None)
if F_exact is None or not callable(F_exact):
raise ValueError(
"force_comparison_arrays needs a callable exact force "
"(model_exact.force_sf or force_exact=)."
)
if not hasattr(self, "force_inferred"):
raise RuntimeError("No inferred force; run infer_force_linear / infer_force first.")
need_v = bool(getattr(F_exact, "needs_v", False)) or bool(
getattr(self.force_inferred, "needs_v", False)
)
Xf, Vf, mask = self._comparison_points(data, maxpoints=maxpoints, need_v=need_v)
Fe = self._eval_on_points(F_exact, Xf, Vf)[mask]
Fi = self._eval_on_points(self.force_inferred, Xf, Vf)[mask]
return Fe, Fi
[docs]
def diffusion_comparison_arrays(self, *, model_exact=None, diffusion_exact=None, data=None, maxpoints: int = 2000):
"""Return ``(D_exact, D_inferred)`` evaluated along the trajectory.
Like :meth:`force_comparison_arrays` but for the diffusion field;
a constant exact/inferred diffusion is broadcast to ``(n_points,
d, d)``.
"""
import jax.numpy as jnp
D_exact = diffusion_exact
if D_exact is None and model_exact is not None:
D_exact = getattr(model_exact, "diffusion_sf", None)
if D_exact is None:
D_exact = getattr(model_exact, "D", None)
if D_exact is None:
raise ValueError(
"diffusion_comparison_arrays needs an exact diffusion "
"(model_exact or diffusion_exact=)."
)
D_inf = getattr(self, "diffusion_inferred", None)
if D_inf is None or not callable(getattr(D_inf, "sf", D_inf)):
if not hasattr(self, "diffusion_average"):
raise RuntimeError("No inferred diffusion callable and no diffusion_average available.")
D_inf = jnp.asarray(self.diffusion_average)
need_v = bool(getattr(D_exact, "needs_v", False)) or bool(getattr(D_inf, "needs_v", False))
Xf, Vf, mask = self._comparison_points(data, maxpoints=maxpoints, need_v=need_v)
De = self._eval_diffusion_on_points(D_exact, Xf, Vf)[mask]
Di = self._eval_diffusion_on_points(D_inf, Xf, Vf)[mask]
return De, Di
[docs]
def comparison_scatter(self, *, model_exact=None, field: str = "force", data=None, ax=None, maxpoints: int = 2000, **plot_kw):
"""Scatter inferred-vs-exact force (or diffusion) along the trajectory.
Evaluates both fields on the data with
:meth:`force_comparison_arrays` / :meth:`diffusion_comparison_arrays`
and renders them with
:func:`SFI.utils.plotting.comparison_scatter` (identity line +
Pearson ``r`` + MSE). Replaces hand-rolled exact-vs-inferred
scatters in demos.
Parameters
----------
model_exact :
Object exposing ``force_sf`` / ``diffusion_sf`` / ``D``.
field : {"force", "diffusion"}
Which field to compare.
data :
Collection to evaluate on (default: training data).
ax :
Target axes (default: current axes).
maxpoints :
Approximate number of points to evaluate.
**plot_kw :
Forwarded to :func:`SFI.utils.plotting.comparison_scatter`.
Returns
-------
matplotlib.axes.Axes
"""
import matplotlib.pyplot as plt
from SFI.utils import plotting
if field == "force":
exact, inferred = self.force_comparison_arrays(model_exact=model_exact, data=data, maxpoints=maxpoints)
elif field == "diffusion":
exact, inferred = self.diffusion_comparison_arrays(model_exact=model_exact, data=data, maxpoints=maxpoints)
else:
raise ValueError(f"Unknown field {field!r}; expected 'force' or 'diffusion'.")
if ax is not None:
plt.sca(ax)
plotting.comparison_scatter(exact, inferred, **plot_kw)
return plt.gca()
[docs]
def compare_params_to_exact(self, theta_true, *, psf=None) -> dict:
"""Compare inferred parametric coefficients to known ground truth.
For a model fitted with a parametric family, returns a per-parameter
dict of absolute and relative error. ``theta_true`` may be a flat
array (compared elementwise to ``force_coefficients_full``) or a
``{name: value}`` dict (unflattened from the fitted coefficients via
``psf.unflatten_params``, falling back to ``self.force_psf``).
Returns
-------
dict
``{name: {"true", "inferred", "abs_error", "rel_error"}}``; also
stored as ``self.parameter_comparison``.
"""
import numpy as np
theta_inf = np.asarray(self.force_coefficients_full).ravel()
out: dict = {}
if isinstance(theta_true, dict):
psf = psf if psf is not None else getattr(self, "force_psf", None)
inf_dict: dict = {}
if psf is not None and hasattr(psf, "unflatten_params"):
try:
inf_dict = {k: np.asarray(v) for k, v in dict(psf.unflatten_params(theta_inf)).items()}
except Exception:
inf_dict = {}
for name, tv in theta_true.items():
tv = np.asarray(tv, dtype=float)
iv = np.asarray(inf_dict.get(name, np.full(tv.shape, np.nan)), dtype=float)
abs_err = float(np.sqrt(np.mean((iv - tv) ** 2))) if iv.shape == tv.shape else float("nan")
rel = abs_err / (float(np.sqrt(np.mean(tv**2))) + 1e-12)
out[name] = {"true": tv, "inferred": iv, "abs_error": abs_err, "rel_error": rel}
else:
tv = np.asarray(theta_true, dtype=float).ravel()
iv = theta_inf[: tv.shape[0]]
err = float(np.linalg.norm(iv - tv))
out["theta"] = {
"true": tv,
"inferred": iv,
"abs_error": err,
"rel_error": err / (float(np.linalg.norm(tv)) + 1e-12),
}
self.parameter_comparison = out
return out
[docs]
def coeff_block(self, block, *, field: str = "force"):
"""Return the coefficient (and covariance) slice for a basis sub-block.
Compound bases (e.g. a multi-kernel or time-Fourier library) pack
several conceptual blocks into one flat coefficient vector. This
returns the slice for one block without hand-computed offsets.
Parameters
----------
block :
``(start, stop)`` indices, a ``slice``, an ``int``, or a
``Basis`` (located by matching its labels as a contiguous run of
the fitted basis labels).
field : {"force", "diffusion"}
Returns
-------
(coeffs, cov) : tuple
``coeffs`` is the 1-D slice; ``cov`` is the matching covariance
block (or ``None`` if no covariance is available).
"""
import numpy as np
coeffs = np.asarray(getattr(self, f"{field}_coefficients_full"))
cov = getattr(self, f"{field}_coefficients_covariance", None)
cov = np.asarray(cov) if cov is not None else None
if isinstance(block, slice):
i0 = block.start or 0
i1 = block.stop if block.stop is not None else len(coeffs)
elif isinstance(block, (tuple, list)) and len(block) == 2 and all(
isinstance(v, (int, np.integer)) for v in block
):
i0, i1 = int(block[0]), int(block[1])
elif isinstance(block, (int, np.integer)):
i0, i1 = int(block), int(block) + 1
elif hasattr(block, "labels") and hasattr(block, "n_features"):
full_labels = list(getattr(self, f"{field}_basis_labels", []) or [])
block_labels = list(block.labels)
nf = int(block.n_features)
i0 = None
for start in range(0, len(full_labels) - nf + 1):
if full_labels[start : start + nf] == block_labels:
i0 = start
break
if i0 is None:
raise ValueError("Could not locate the basis block within the fitted basis labels.")
i1 = i0 + nf
else:
raise TypeError("block must be (start, stop), a slice, an int, or a Basis.")
cslice = coeffs[i0:i1]
covslice = cov[i0:i1, i0:i1] if cov is not None else None
return cslice, covslice
[docs]
def predict_time_profile(self, basis_block, t, *, field: str = "force", x=None):
"""Evaluate a (time-dependent) basis block's coefficient profile at ``t``.
Contracts the fitted coefficients of ``basis_block`` with the basis's
own evaluation at times ``t`` (via the reserved ``time`` extra),
returning the time profile — e.g. ``-k(t)`` for the ``x`` block of a
time-Fourier trap. Avoids re-deriving the design matrix by hand.
"""
import jax.numpy as jnp
import numpy as np
coeffs, _ = self.coeff_block(basis_block, field=field)
t = np.asarray(t)
dim = int(getattr(basis_block, "dim", 1))
x0 = np.zeros((t.shape[0], dim)) if x is None else np.broadcast_to(np.asarray(x), (t.shape[0], dim))
duration = float(t.max() - t.min()) if t.size else 1.0
D = np.asarray(
basis_block(jnp.asarray(x0), extras={"time": jnp.asarray(t)[:, None], "duration": duration})
).reshape(t.shape[0], -1)
return D @ np.asarray(coeffs)
#################################################################
################ BACKEND ############################
#################################################################
def _update_force_coefficients(self, coeffs, support=None, jit_inferred=True):
"""Write or update force coefficients and rebuild ``force_inferred``.
Called after the initial linear solve and again after sparsification
to set the active support and (re-)build the callable force field.
Parameters
----------
coeffs : Array
Coefficient vector for the active basis functions.
support : array-like or None
Indices of the active basis functions. ``None`` means all.
jit_inferred : bool
Whether to JIT-compile the resulting ``force_inferred``.
"""
self.force_coefficients = coeffs
if support is None:
self.force_support = jnp.arange(self.force_scorer.p)
else:
self.force_support = jnp.array(support)
# Persist basis labels for downstream reporting (best-effort)
if hasattr(self, "force_basis"):
self.force_basis_labels = getattr(self.force_basis, "labels", None)
elif not hasattr(self, "force_basis_labels"):
self.force_basis_labels = None
self.force_G = self.force_G_full[jnp.ix_(self.force_support, self.force_support)]
self.force_G_pinv = stable_pinv(self.force_G)
# Sparse coeffs on the complete basis:
self.force_coefficients_full = jnp.zeros_like(self.force_moments)
if len(self.force_support) > 0:
self.force_coefficients_full = self.force_coefficients_full.at[self.force_support].set(coeffs)
# Call the inferred-constructing subclass-specific hook:
self._update_force_inferred()
def _update_diffusion_coefficients(self, coeffs, support=None, jit_inferred=True):
"""Write or update diffusion coefficients and rebuild ``diffusion_inferred``."""
self.diffusion_coefficients = coeffs
if support is None:
self.diffusion_support = jnp.arange(self.diffusion_scorer.p)
else:
self.diffusion_support = jnp.array(support)
self.diffusion_G = self.diffusion_G_full[jnp.ix_(self.diffusion_support, self.diffusion_support)]
self.diffusion_G_pinv = stable_pinv(self.diffusion_G)
# Sparse coeffs on the complete basis:
self.diffusion_coefficients_full = jnp.zeros_like(self.diffusion_moments)
self.diffusion_coefficients_full = self.diffusion_coefficients_full.at[self.diffusion_support].set(coeffs)
# Call the inferred-constructing subclass-specific hook:
self._update_diffusion_inferred()
def _detach_from_jax(self):
"""Convert all JAX arrays inside this object to NumPy arrays
to prevent memory leaks. Use this before deleting this object,
as the Jax traces might persist otherwise. Important when
performing a large number of inference runs in the same run
(e.g. for benchmarking the method over many
parameters/trajectories).
"""
import gc
for attr_name in vars(self): # Loop through all attributes
attr_value = getattr(self, attr_name)
if isinstance(attr_value, jnp.ndarray): # If it's a JAX array
setattr(self, attr_name, jax.device_get(attr_value)) # Convert to NumPy
elif isinstance(attr_value, dict): # If it's a dictionary, check inside
for key, value in attr_value.items():
if isinstance(value, jnp.ndarray):
attr_value[key] = jax.device_get(value)
elif isinstance(attr_value, list): # If it's a list, check each item
for i in range(len(attr_value)):
if isinstance(attr_value[i], jnp.ndarray):
attr_value[i] = jax.device_get(attr_value[i])
# Clear any lingering references in JAX's cache
jax.clear_caches()
jax.device_get(jax.numpy.zeros(1)) # Forces JAX to clear buffers
gc.collect()
# ---- simulation helpers ----------------------------------------------- #
def _find_finite_x0(self, also_dx: bool = False):
"""Return the first X row (and optionally dX) with no NaN fill values.
Masked datasets store NaN as a fill value at missing positions.
Using such a row as a simulation initial condition propagates NaN
through the whole trajectory. This helper scans the datasets to
find the first time step where every element of X is finite.
Parameters
----------
also_dx : bool
If True, also return ``dX = X[t+1] - X[t]`` at the chosen step
(both rows must be finite).
Returns
-------
x0 : Array
First fully-finite position row.
dX : Array, only when ``also_dx=True``
Increment at that time step.
"""
import numpy as np
for ds in self.data.datasets:
X_np = np.asarray(ds.X)
T = X_np.shape[0]
flat = X_np.reshape(T, -1)
rows_finite = np.all(np.isfinite(flat), axis=1)
if also_dx:
valid = np.where(rows_finite[:-1] & rows_finite[1:])[0]
else:
valid = np.where(rows_finite)[0]
if len(valid) > 0:
t0 = int(valid[0])
x0 = jnp.asarray(X_np[t0])
if also_dx:
return x0, jnp.asarray(X_np[t0 + 1] - X_np[t0])
return x0
# Fallback: peek_row + replace any remaining NaN with zeros.
import warnings
x0 = jnp.asarray(self.data.peek_row(require={"X"})["X"])
x0_clean = jnp.where(jnp.isfinite(x0), x0, jnp.zeros_like(x0))
warnings.warn(
"simulate_bootstrapped_trajectory: no fully-finite time step found in "
"the dataset. NaN fill values in the initial condition have been replaced "
"with 0.0.",
UserWarning,
stacklevel=3,
)
if also_dx:
return x0_clean, jnp.zeros_like(x0_clean)
return x0_clean
### Subclass hooks ###
@abstractmethod
def _force_G_matrix(self) -> jnp.ndarray: ...
@abstractmethod
def _force_moments(self) -> jnp.ndarray: ...
@abstractmethod
def _diffusion_G_matrix(self) -> jnp.ndarray: ...
@abstractmethod
def _diffusion_moments(self) -> jnp.ndarray: ...
[docs]
@abstractmethod
def get_diffusion_timeop(self, method: str) -> TimeOperand: ...
@abstractmethod
def _update_force_inferred(self) -> None: ...
@abstractmethod
def _update_diffusion_inferred(self) -> None: ...
# ---- persistence ------------------------------------------------------ #
[docs]
def save_results(self, path) -> "Path":
"""Save ``report_dict()`` to ``<path>.npz`` + ``<path>.json``.
See :func:`SFI.inference.serialization.save_results` for details.
"""
from SFI.inference.serialization import save_results
return save_results(self, path)
[docs]
@staticmethod
def load_results(path) -> dict:
"""Reload a dict previously saved by :meth:`save_results`.
See :func:`SFI.inference.serialization.load_results` for details.
"""
from SFI.inference.serialization import load_results
return load_results(path)