SFI.inference.result module

class SFI.inference.result.InferenceResultSF(sf, *, param_cov=None, meta=None)

Bases: SF

A fitted, callable state function: an SF that additionally carries the parameter covariance and metadata used for downstream uncertainty quantification.

Notes

  • param_cov is the covariance of the flattened parameter vector defined by the underlying PSF template order (see PSF.flatten_params).

  • Covariance estimation is handled upstream (in the inferer).

  • Call predict_var(), predict_cov() or predict_ci() for pointwise uncertainty.

Parameters:
  • sf (SF)

  • param_cov (Array | None)

  • meta (Dict[str, Any] | None)

d_v(*, same_particle=False, mode='auto')

Build an expression for the velocity Jacobian ∂F/∂v.

Same rules as .d_x(). Requires needs_v=True on the underlying expression.

Parameters:
  • same_particle (bool)

  • mode (str)

d_x(*, same_particle=False, mode='auto')

Build an expression for the spatial Jacobian ∂F/∂x.

Axis effects

  • Adds one derivative-dim immediately before the rank block.

  • If particles_input=True:

    • when same_particle=True: if pdepth=1, compute ∂f_i/∂x_i (no extra P axis); the particle dimension behaves like a broadcast index. For any other pdepth, an error is raised.

    • when same_particle=False (default): compute the full cross-particle Jacobian ∂f_i/∂x_j; the extra particle axis this introduces is produced by JAX's differentiation. The library never creates P axes itself; it only permutes axes into canonical order.
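The axis bookkeeping can be illustrated on plain arrays. The sketch below is not the library's implementation: the two toy functions are hypothetical, a central finite-difference Jacobian stands in for JAX autodiff, and only the same-particle result is permuted (derivative axis placed before the component axis, mirroring the rule above). It shows that the cross-particle Jacobian carries an extra particle axis, while a per-particle (pdepth=1) function reduces to its diagonal blocks.

```python
import numpy as np


def fd_jacobian(f, x, eps=1e-6):
    """Central finite-difference Jacobian; result shape is f(x).shape + x.shape."""
    fx = f(x)
    jac = np.zeros(fx.shape + x.shape)
    for idx in np.ndindex(*x.shape):
        dx = np.zeros_like(x)
        dx[idx] = eps
        jac[(Ellipsis,) + idx] = (f(x + dx) - f(x - dx)) / (2 * eps)
    return jac


def coupled(x):
    """Hypothetical toy interaction: f_i = sum_j (x_j - x_i) for x of shape (P, d)."""
    return x.sum(axis=0, keepdims=True) - x.shape[0] * x


def single_particle(x):
    """Hypothetical pdepth=1 function: f_i depends only on x_i."""
    return -x**3


P, d = 3, 2
x = np.random.default_rng(0).normal(size=(P, d))

# same_particle=False analogue: full cross-particle Jacobian df_i/dx_j.
J_cross = fd_jacobian(coupled, x)
print(J_cross.shape)  # (3, 2, 3, 2): an extra particle axis appears

# same_particle=True analogue: keep only the diagonal blocks df_i/dx_i.
J_full = fd_jacobian(single_particle, x)
J_self = np.einsum("iaib->iba", J_full)  # derivative axis b before component axis a
print(J_self.shape)  # (3, 2, 2): no extra particle axis
```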

Parameters:
  • same_particle (bool) – See axis effects above.

  • mode ({'auto', …}) – Backend differentiation mode; 'auto' selects a sensible default.

Returns:

A new expression representing the Jacobian.

Return type:

StateExpr

Notes

This method triggers no evaluation; it returns a new graph.

dense(n_out, *, weight='W', bias='b')

Apply a learnable affine map on the feature axis.

y[..., j] = sum_i x[..., i] * W[i, j] + b[j]

Spatial (rank) axes are untouched: the same W, b are shared across every spatial component. The result is always a PSF (since the dense layer introduces learnable parameters).
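A plain-NumPy sketch of the contraction, assuming the layout batch · (dim,)^rank · n_features used elsewhere on this page; the helper name dense_on_features is hypothetical. Only the trailing feature axis is contracted, so every spatial component sees the same W and b.

```python
import numpy as np


def dense_on_features(x, W, b=None):
    """y[..., j] = sum_i x[..., i] * W[i, j] + b[j]; only the last axis is contracted."""
    y = np.einsum("...i,ij->...j", x, W)
    if b is not None:
        y = y + b  # b broadcasts over batch and spatial axes
    return y


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2, 3))  # batch=5, one spatial axis of dim 2, 3 features
W = rng.normal(size=(3, 4))
b = rng.normal(size=(4,))

y = dense_on_features(x, W, b)
print(y.shape)  # (5, 2, 4): spatial axis untouched, features 3 -> 4
```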

Parameters:
  • n_out (int) – Number of output features.

  • weight (str) – Name for the weight parameter (default "W"). Use distinct names ("W1", "W2", …) when stacking multiple layers.

  • bias (str | None) – Name for the bias parameter (default "b"; None to omit). Use distinct names ("b1", "b2", …) when stacking layers.

Returns:

A parametric state function wrapping the dense layer.

Return type:

PSF

Examples

Build the hidden layers of an MLP force field:

>>> from SFI.bases import X
>>> import jax.numpy as jnp
>>> mlp = (
...     X(dim=2).vectorize(2)
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(1, weight="W2", bias="b2")
... )

property dim
dot(other, axes=None)

Spatial tensordot via einsum.

Semantics:
  • axes=None: contract last axis of self with first axis of other.

  • axes=int:
    • if self.rank == other.rank: contract all axes (Frobenius/trace for rank-2).

    • otherwise: contract the last n = axes axes of self with the first n axes of other.

  • axes=(a_axes, b_axes): NumPy-style explicit lists.

Arrays are accepted and coerced to spatial constants.
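To make the axes rules concrete, here is a hypothetical NumPy helper that mirrors them on plain spatial arrays (no batch or feature axes, which the real method carries through unchanged). It is one reading of the documented rules, not the library's code; np.tensordot with an integer n contracts the last n axes of its first argument with the first n axes of its second.

```python
import numpy as np


def spatial_dot(a, b, axes=None):
    """Hypothetical mirror of the documented .dot rules on plain spatial arrays."""
    if axes is None:
        # Last axis of a with first axis of b.
        return np.tensordot(a, b, axes=1)
    if isinstance(axes, int):
        if a.ndim == b.ndim:
            # Equal ranks: contract all axes (Frobenius inner product for rank-2).
            return np.tensordot(a, b, axes=a.ndim)
        # Otherwise: last `axes` axes of a with first `axes` axes of b.
        return np.tensordot(a, b, axes=axes)
    # Explicit (a_axes, b_axes) lists, NumPy style.
    return np.tensordot(a, b, axes=axes)


M = np.arange(9.0).reshape(3, 3)
v = np.array([1.0, 0.0, 2.0])

print(spatial_dot(M, v))                   # matrix-vector product, shape (3,)
print(spatial_dot(M, M, axes=2))           # Frobenius inner product, a scalar
print(spatial_dot(M, M, axes=([0], [0])))  # M^T M, shape (3, 3)
```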

drop_features: bool = True
classmethod einsum(spec, *operands)

General contraction on spatial axes (like jnp.einsum).

Important

  • Use only lowercase letters.

  • spec refers only to spatial axes (not the feature axis).

  • Features take a Cartesian product across operands (no implicit feature reduction or alignment). If you need feature concatenation, use &/stack. For per-feature ops, use element-wise maps or binary ops where features must match.

Arrays in operands are accepted and coerced to spatial-constant expressions with a single feature. Only spatial letters in spec are interpreted. If no StateExpr is present, a TypeError is raised because dim cannot be inferred.

Examples

Vector inner product (per-feature), two rank-1 inputs:

>>> # a: i × F1, b: i × F2 → rank-0 × (F1×F2)
>>> c = StateExpr.einsum("i,i->", a, b)

Matrix–vector product (per-feature), rank-2 with rank-1:

>>> # M: ij × F1, v: j × F2 → i × (F1×F2)
>>> y = StateExpr.einsum("ij,j->i", M, v)

Outer product (per-feature Cartesian product):

>>> # u: i × F1, v: j × F2 → ij × (F1×F2)
>>> O = StateExpr.einsum("i,j->ij", u, v)

Parameters:
  • spec (str) – An einsum string over spatial indices, e.g. "ij,j->i".

  • operands (StateExpr | array-like) – Any mix of StateExpr and arrays.
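The feature Cartesian product can be emulated in NumPy by giving each operand its own feature letter and flattening the feature axes at the end. This is a sketch under stated assumptions: the helper is hypothetical, operands use the layout batch · spatial axes · features, and the combined feature index is flattened in C order (first operand's feature varying slowest), which may differ from the library's ordering. Uppercase letters are reserved here for the batch and feature axes, which fits the lowercase-only rule for spec.

```python
import numpy as np


def einsum_feature_product(spec, *operands):
    """Contract spatial axes per spec; take the Cartesian product of feature axes."""
    lhs, out = spec.split("->")
    in_specs = lhs.split(",")
    feat = "ABCDEFGH"[: len(operands)]  # one uppercase feature letter per operand
    full_in = ["Z" + s + f for s, f in zip(in_specs, feat)]  # Z is the batch axis
    full_spec = ",".join(full_in) + "->Z" + out + feat
    y = np.einsum(full_spec, *operands)
    # Flatten the per-operand feature axes into a single feature axis (C order).
    return y.reshape(y.shape[: 1 + len(out)] + (-1,))


rng = np.random.default_rng(0)
N, d, F1, F2 = 4, 3, 2, 5
M = rng.normal(size=(N, d, d, F1))  # rank-2, F1 features
v = rng.normal(size=(N, d, F2))     # rank-1, F2 features

y = einsum_feature_product("ij,j->i", M, v)
print(y.shape)  # (4, 3, 10): rank-1 with F1 * F2 features
```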

elementwisemap(func, *, label_fn=None)

Apply func element-wise to every feature (spatial axes untouched).

func must be a pure JAX function from scalar→scalar (rank-0 arrays OK). If the expression carries feature labels (e.g., a Basis or an SF bound from a Basis), label_fn (if provided) is applied to each feature label.

Example

>>> import jax.numpy as jnp
>>> B = ...   # Basis with 4 features
>>> C = B.elementwisemap(jnp.tanh, label_fn=lambda s: f"tanh({s})")

Parameters:
  • func (Callable[[Array], Array])

  • label_fn (Callable[[str], str] | None)
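A minimal NumPy sketch of the documented behavior, using made-up feature values and labels; elementwise_with_labels is a hypothetical helper, not part of the library. The array keeps its shape and only the label strings are rewritten.

```python
import numpy as np


def elementwise_with_labels(values, labels, func, label_fn=None):
    """Apply func to every entry; spatial and feature axes keep their shapes."""
    out = func(values)
    new_labels = [label_fn(s) for s in labels] if label_fn is not None else list(labels)
    return out, new_labels


values = np.linspace(-2.0, 2.0, 8).reshape(2, 4)  # batch=2, rank-0, 4 features
labels = ["1", "x", "x^2", "x^3"]
out, new_labels = elementwise_with_labels(values, labels, np.tanh, lambda s: f"tanh({s})")
print(new_labels)  # ['tanh(1)', 'tanh(x)', 'tanh(x^2)', 'tanh(x^3)']
```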

estimate_bytes_per_sample(*, dtype=None, particle_size=None, sample=None, mode='forward')

Convenience wrapper that returns only the transient memory estimate, in bytes per sample.

Parameters:
  • particle_size (int | None)

  • sample (SampleMeta | None)

  • mode (str)

Return type:

int

features_to_rank(rank)

Unfold features into spatial axes → given rank.

The output layout changes from the current:

batch · (dim,)^self.rank · n_features

to:

batch · (dim,)^rank · (n_features / dim^(rank − self.rank),)

where the new innermost spatial axes are carved out of the feature axis. This is a pure reshape and is the exact inverse of rank_to_features() when restoring the original rank.
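On a plain array the operation is a single reshape. The sketch assumes NumPy's default C (row-major) order with the new spatial axes placed just before the feature axis; the helper is hypothetical, and the library's exact index ordering for multi-feature inputs is an assumption here.

```python
import numpy as np


def features_to_rank(arr, current_rank, rank, dim):
    """(batch, dim^current_rank..., F) -> (batch, dim^rank..., F / dim^(rank - current_rank))."""
    if rank <= current_rank:
        raise TypeError("target rank must exceed the current rank")
    n_features = arr.shape[-1]
    block = dim ** (rank - current_rank)
    if n_features % block:
        raise ValueError(f"{n_features} features not divisible by dim^Δrank = {block}")
    new_shape = arr.shape[:-1] + (dim,) * (rank - current_rank) + (n_features // block,)
    return arr.reshape(new_shape)


out = np.arange(10.0).reshape(5, 2)  # e.g. a dense(2) output: batch=5, rank-0, 2 features
vec = features_to_rank(out, current_rank=0, rank=1, dim=2)
print(vec.shape)  # (5, 2, 1): rank-1, 1 feature
```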

Parameters:

rank (int) – Target tensor rank (must be greater than the current rank).

Returns:

Expression at the requested rank with fewer features.

Return type:

StateExpr (same subclass)

Raises:
  • ValueError – If n_features is not divisible by dim^Δrank.

  • TypeError – If rank <= self.rank (use rank_to_features to go down).

Examples

Turn a dense layer’s output back into a vector field:

>>> scalar_expr.features_to_rank(1)  # rank-1, F/dim features

Build a 2→H→2 MLP force field (one hidden layer, H = 32):

>>> from SFI.bases import X
>>> import jax.numpy as jnp
>>> mlp = (
...     X(dim=2)
...     .rank_to_features()                     # rank-0, 2 features
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(2, weight="W2", bias="b2")       # rank-0, 2 features
...     .features_to_rank(1)                     # rank-1, 1 feature
... )

flatten_params()

Return θ̂ as a 1D vector using the PSF template order.

Return type:

Array

property labels

Basis labels propagated from the parent PSF.

classmethod load(path, template)

Reload a model saved by save().

Parameters:
  • path (str or Path) – Base path (without extension).

  • template (InferenceResultSF) – Skeleton with the same tree structure (same PSF + dummy params).

Return type:

InferenceResultSF

materialize_params(vec)

Inverse of flatten_params: make a param dict from a vector.

Parameters:

vec (Array)

Return type:

dict[str, Array]
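A minimal sketch of the round trip, assuming the template order is simply a fixed list of (name, shape) pairs; the helper names and the example template are hypothetical. Because param_cov is indexed in the same order, the covariance of one parameter block is a contiguous sub-block of it.

```python
import numpy as np

# Hypothetical template: parameter names and shapes in a fixed order.
TEMPLATE = [("W1", (2, 3)), ("b1", (3,)), ("W2", (3, 1))]


def flatten_params(params):
    """Concatenate raveled parameters in template order into one 1-D vector."""
    return np.concatenate([np.ravel(params[name]) for name, _ in TEMPLATE])


def materialize_params(vec):
    """Inverse of flatten_params: slice the vector back into named arrays."""
    params, start = {}, 0
    for name, shape in TEMPLATE:
        size = int(np.prod(shape))
        params[name] = vec[start:start + size].reshape(shape)
        start += size
    return params


def block_slice(name):
    """Index range of one parameter inside the flat vector (and inside param_cov)."""
    start = 0
    for n, shape in TEMPLATE:
        size = int(np.prod(shape))
        if n == name:
            return slice(start, start + size)
        start += size
    raise KeyError(name)


rng = np.random.default_rng(0)
theta = {name: rng.normal(size=shape) for name, shape in TEMPLATE}
vec = flatten_params(theta)
print(vec.shape)  # (12,): 6 + 3 + 3 entries
```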

memory_hint(*, dtype=None, particle_size=None, sample=None, mode='forward')

Conservative per-sample memory footprint for the whole expression tree. Delegates to the root node, which recursively sums its children's footprints and its own output size.

Parameters:
  • particle_size (int | None)

  • sample (SampleMeta | None)

  • mode (str)

meta: Dict[str, Any]
property n_features
property needs_v
param_cov: Array | None
params: dict[str, Array]
property particle_extras: tuple[str, ...]

Pure metadata, forwarded from the root node.

Names of extras declared as per-particle somewhere in the underlying node tree (typically by interaction leaves). The dispatcher reads this to know which keys to gather from per-particle arrays of shape (P, …) into per-edge arrays of shape (E, K, …) before calling locals.
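The gather itself is ordinary integer-array indexing. In the sketch below the extra name "charge", the edge list and the shapes are hypothetical; it only illustrates how a per-particle array of shape (P,) becomes a per-edge array of shape (E, K) when K particles take part in each edge, while a non-particle extra such as a box passes through unchanged.

```python
import numpy as np

extras = {
    "charge": np.array([1.0, -1.0, 2.0, 0.5]),  # per-particle, shape (P,) with P = 4
    "box": np.array([10.0, 10.0]),              # global, not gathered
}
particle_extras = ("charge",)

# Pair interactions: E = 3 edges, K = 2 particles per edge.
edges = np.array([[0, 1], [1, 2], [2, 3]])  # shape (E, K)

per_edge = {k: (v[edges] if k in particle_extras else v) for k, v in extras.items()}
print(per_edge["charge"].shape)  # (3, 2)
```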

property particles_input
property pdepth
predict_ci(x, *, alpha=0.95, extras=None, mask=None)

Pointwise confidence intervals via the delta method.

Parameters:
  • x (array, shape (N, dim)) – Query points.

  • alpha (float) – Confidence level (default 0.95 for a 95 % CI).

  • extras (dict, optional) – Forwarded to the underlying state function.

  • mask (array, optional) – Forwarded to the underlying state function.

Returns:

  • dict with keys

    • mean – model prediction F̂(x).

    • std – pointwise standard deviation √Var[F(x)].

    • lower, upper – symmetric CI bounds mean ∓ z · std, where z is the (1 + alpha)/2 quantile of the standard normal distribution (≈ 1.96 for alpha = 0.95).

Return type:

Dict[str, Array]
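The bounds follow from the standard deviation and a standard-normal quantile. A NumPy sketch, assuming Gaussian quantiles as is usual for delta-method intervals; the helper name and its inputs (a precomputed mean and per-point parameter Jacobian) are hypothetical, and statistics.NormalDist supplies the quantile.

```python
from statistics import NormalDist

import numpy as np


def delta_method_ci(mean, jac, param_cov, alpha=0.95):
    """mean: (N, d); jac: (N, d, p) = dF/dtheta; param_cov: (p, p)."""
    var = np.einsum("nip,pq,niq->ni", jac, param_cov, jac)  # diagonal of J Σ J^T
    std = np.sqrt(var)
    z = NormalDist().inv_cdf((1.0 + alpha) / 2.0)  # about 1.96 for alpha = 0.95
    return {"mean": mean, "std": std, "lower": mean - z * std, "upper": mean + z * std}


rng = np.random.default_rng(0)
N, d, p = 5, 2, 3
jac = rng.normal(size=(N, d, p))
A = rng.normal(size=(p, p))
param_cov = A @ A.T  # symmetric positive semi-definite
mean = rng.normal(size=(N, d))
ci = delta_method_ci(mean, jac, param_cov)
```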

predict_cov(x, *, extras=None, mask=None)

Full pointwise covariance matrix via the delta method.

\[\Sigma_F(x) = J_\theta(x)\,\Sigma_\theta\,J_\theta(x)^\top\]
Parameters:
  • x (array, shape (N, dim)) – Query points.

  • extras (dict, optional) – Forwarded to the underlying state function.

  • mask (array, optional) – Forwarded to the underlying state function.

Returns:

cov – For rank-1 models: shape (N, d, d). For rank-2 models: shape (N, d, d, d, d) (rarely needed).

Return type:

jnp.ndarray
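For a rank-1 model the formula is a batched matrix product. The sketch assumes a precomputed parameter Jacobian jac of shape (N, d, p), a hypothetical input that the method would obtain internally, and takes the pointwise variances as the diagonal of each (d, d) block.

```python
import numpy as np


def delta_method_cov(jac, param_cov):
    """Sigma_F(x_n) = J(x_n) Sigma_theta J(x_n)^T for every query point; shape (N, d, d)."""
    return np.einsum("nip,pq,njq->nij", jac, param_cov, jac)


rng = np.random.default_rng(1)
N, d, p = 4, 3, 6
jac = rng.normal(size=(N, d, p))
B = rng.normal(size=(p, p))
param_cov = B @ B.T

cov = delta_method_cov(jac, param_cov)
var = np.einsum("nii->ni", cov)  # pointwise variances
print(cov.shape, var.shape)  # (4, 3, 3) (4, 3)
```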

predict_var(x, *, extras=None, mask=None)

Pointwise predictive variance via the delta method.

\[\operatorname{Var}\!\bigl[F_i(x)\bigr] \approx \bigl(J_\theta(x)\,\Sigma_\theta\,J_\theta(x)^\top\bigr)_{ii}\]

For linear models (basis expansion) this is exact, not an approximation.
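For a linear model F(x) = Φ(x) θ the parameter Jacobian J_θ(x) is Φ(x) itself, so Var[Φθ] = Φ Σ_θ Φᵀ holds exactly rather than to first order. The Monte Carlo comparison below uses a made-up Phi and covariance at a single query point; the two estimates should agree up to sampling noise.

```python
import numpy as np

rng = np.random.default_rng(42)
d, p = 2, 4
Phi = rng.normal(size=(d, p))  # F(x) = Phi @ theta at one fixed x, so J_theta = Phi
theta_hat = rng.normal(size=p)
C = rng.normal(size=(p, p))
Sigma_theta = C @ C.T / p

var_delta = np.diag(Phi @ Sigma_theta @ Phi.T)  # delta-method variance

# Monte Carlo: sample theta ~ N(theta_hat, Sigma_theta) and push it through the model.
thetas = rng.multivariate_normal(theta_hat, Sigma_theta, size=200_000)
var_mc = (thetas @ Phi.T).var(axis=0)
print(var_delta, var_mc)  # agree up to Monte Carlo noise
```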

Parameters:
  • x (array, shape (N, dim)) – Query points.

  • extras (dict, optional) – Extra arguments forwarded to the underlying state function (e.g. {"box": box} for periodic boundary conditions).

  • mask (array, optional) – Boolean mask forwarded to evaluation.

Returns:

var – Per-component variance. Shape matches the model output rank: (N, d) for a rank-1 (force) model, (N, d, d) for rank-2 (diffusion tensor).

Return type:

jnp.ndarray

property rank
rank_to_features()

Fold all spatial (rank) axes into the feature axis → rank-0.

The output layout changes from:

batch · (dim,)^rank · n_features

to:

batch · (n_features × dim^rank,)

with rank = 0. This is a pure reshape (no copy, no learnable parameters) and is the exact inverse of features_to_rank(original_rank).
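On a plain array this is one reshape. Assuming NumPy's C (row-major) order and the layout batch · (dim,)^rank · n_features (the library's exact flattening order is an assumption here), folding a rank-2 expression with 3 features in dim = 2 yields 3 × 2² = 12 features, and reshaping back restores the original array. The helper is hypothetical.

```python
import numpy as np


def rank_to_features(arr, rank):
    """(batch, dim^rank..., F) -> (batch, F * dim**rank); rank-0 input is rejected."""
    if rank == 0:
        raise TypeError("expression is already rank-0")
    batch = arr.shape[: arr.ndim - rank - 1]
    return arr.reshape(batch + (-1,))


dim, rank, F = 2, 2, 3
D = np.random.default_rng(0).normal(size=(7, dim, dim, F))  # e.g. a rank-2 tensor field
flat = rank_to_features(D, rank)
print(flat.shape)  # (7, 12): 3 * 2**2 features

restored = flat.reshape(D.shape)  # the inverse reshape recovers the original array
```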

Returns:

Scalar expression whose feature count is self.n_features × self.dim ** self.rank.

Return type:

StateExpr (same subclass)

Raises:

TypeError – If the expression is already rank-0 (a no-op would be confusing).

Examples

Prepare a rank-1 position vector for dense layers:

>>> X(dim=2).rank_to_features()   # rank-0, 2 features

The round-trip is the identity:

>>> expr.rank_to_features().features_to_rank(expr.rank)  # same as expr

property required_extras: tuple[str, ...]

Names of extras that must be present for evaluation, forwarded from the root node. Only their presence is required; no shape or broadcast semantics are attached.

root: BaseNode
save(path)

Save this fitted model to <path>.eqx + <path>.meta.json.

The saved files can be reloaded with load(), provided the user supplies a template built from the same PSF/Basis.

See SFI.inference.serialization.save_model().

Return type:

Path

property sdims
specialize(*, dataset)

Specialize a bound function to the condition with index dataset.

Rewrites the graph (folding the leaves that read dataset_index) and projects the bound parameter values onto the shrunken template: a per-condition spec whose shape loses a leading axis is sliced at dataset; shared specs are kept verbatim.
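The parameter projection can be pictured with plain dictionaries. In this sketch the parameter names, the shapes, and the representation of the shrunken template as a dict of target shapes are hypothetical; it only shows the documented rule: slice a per-condition array whose target shape loses its leading axis, and keep shared arrays as they are.

```python
import numpy as np


def project_params(params, target_shapes, dataset):
    """Slice per-condition arrays at `dataset`; keep shared arrays verbatim."""
    out = {}
    for name, value in params.items():
        target = tuple(target_shapes[name])
        if value.shape != target and value.shape[1:] == target:
            out[name] = value[dataset]  # per-condition spec: drop the leading axis
        else:
            out[name] = value           # shared spec: kept as-is
    return out


n_conditions = 3
params = {
    "W": np.arange(12.0).reshape(n_conditions, 4),  # per-condition, leading axis = condition
    "b": np.array([0.5, -0.5, 1.0, 2.0]),           # shared across conditions
}
# Hypothetical shapes of the shrunken template after folding dataset_index leaves.
target_shapes = {"W": (4,), "b": (4,)}

specialized = project_params(params, target_shapes, dataset=1)
print(specialized["W"])  # [4. 5. 6. 7.]
```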

Parameters:

dataset (int)

Return type:

SF

sqrtm()
classmethod stack(exprs)

Concatenate along the feature axis.

Static contracts must match: equal rank and dim, and compatible pdepth.

Parameters:

exprs (Sequence[StateExpr])
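Feature-axis concatenation on plain arrays, with a hypothetical shape check standing in for the static-contract test. Only batch, rank and dim are compared here; the pdepth compatibility rule is not modeled.

```python
import numpy as np


def stack_features(arrays):
    """Concatenate (batch, dim^rank..., F_k) arrays along the last (feature) axis."""
    # Equal leading shapes imply equal batch size, rank and dim.
    leading = {a.shape[:-1] for a in arrays}
    if len(leading) != 1:
        raise ValueError(f"incompatible batch/spatial shapes: {sorted(leading)}")
    return np.concatenate(arrays, axis=-1)


a = np.zeros((5, 2, 3))  # rank-1, dim 2, 3 features
b = np.ones((5, 2, 1))   # rank-1, dim 2, 1 feature
s = stack_features([a, b])
print(s.shape)  # (5, 2, 4)
```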

summary()

Formatted coefficient table (if labels and coefficients are available).

Return type:

str

tensordot(other, axes=1)

Alias of .dot with NumPy-compatible axes.

tensorize(dim=None, mode='symmetric')

Lift a scalar expression to rank-2 (matrix).

Parameters:
  • dim (int, optional) – Spatial dimension. Inferred when possible.

  • mode (str) – 'symmetric' (default) uses symmetric_matrix_basis() (d(d+1)/2 features per scalar feature, spans all symmetric matrices). 'identity' uses identity_matrix_basis() (1 feature per scalar feature, isotropic).

Returns:

Matrix expression.

Return type:

StateExpr
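The two modes can be compared by building the matrix bases explicitly. This NumPy sketch constructs one plausible symmetric basis (E_ii, and E_ij + E_ji for i < j) and uses the identity as the isotropic basis; the helper names are hypothetical, and the normalization and ordering used by symmetric_matrix_basis() and identity_matrix_basis() may differ.

```python
import numpy as np


def symmetric_basis(d):
    """d(d+1)/2 symmetric matrices: E_ii and E_ij + E_ji for i < j."""
    mats = []
    for i in range(d):
        for j in range(i, d):
            m = np.zeros((d, d))
            m[i, j] = m[j, i] = 1.0
            mats.append(m)
    return np.stack(mats)  # shape (d(d+1)/2, d, d)


def tensorize(scalar_feats, basis):
    """(N, F) scalars x (K, d, d) basis -> (N, d, d, F*K) rank-2 features."""
    out = np.einsum("nf,kij->nijfk", scalar_feats, basis)
    return out.reshape(out.shape[:3] + (-1,))


d = 3
S = symmetric_basis(d)
print(S.shape)  # (6, 3, 3): d(d+1)/2 = 6 for d = 3

feats = np.random.default_rng(0).normal(size=(10, 2))  # N=10, 2 scalar features
print(tensorize(feats, S).shape)                # (10, 3, 3, 12): 6 features per scalar
print(tensorize(feats, np.eye(d)[None]).shape)  # (10, 3, 3, 2): identity mode, 1 per scalar
```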

vectorize(dim=None, axes=None)

Lift a scalar expression to rank-1 (vector).

Equivalent to self * unit_vector_basis(dim, axes=axes), i.e. a Cartesian product of the feature axis with unit vectors.

Parameters:
  • dim (int, optional) – Spatial dimension. Inferred from the expression’s contract when possible.

  • axes (sequence of int, optional) – Subset of spatial axes to include (default: all dim axes).

Returns:

Vector expression with n_features = self.n_features × len(axes).

Return type:

StateExpr
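The Cartesian product with unit vectors can be sketched in NumPy. The helper name, the layout (batch, dim, features) and the flattening order of the product feature axis (scalar feature slowest) are assumptions.

```python
import numpy as np


def vectorize(scalar_feats, dim, axes=None):
    """(N, F) scalars -> (N, dim, F * len(axes)) via products with unit vectors e_a."""
    axes = list(range(dim)) if axes is None else list(axes)
    units = np.eye(dim)[axes]  # (len(axes), dim): one unit vector per selected axis
    out = np.einsum("nf,ai->nifa", scalar_feats, units)
    return out.reshape(out.shape[:2] + (-1,))


feats = np.array([[2.0, 3.0]])  # N=1, F=2 scalar features
v = vectorize(feats, dim=3, axes=[0, 2])
print(v.shape)  # (1, 3, 4): F * len(axes) = 2 * 2 features
```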

SFI.inference.result.kernel_predict_ci(r_eval, kernels, coeffs, cov_block, *, alpha=0.95)

Confidence interval for a 1D kernel profile.

For a reconstructed kernel profile

\[k(r) = \sum_{\alpha} c_\alpha \, \phi_\alpha(r),\]

the variance at each \(r\) is

\[\operatorname{Var}[k(r)] = \boldsymbol{\phi}(r)^\top \, \Sigma_c \, \boldsymbol{\phi}(r),\]

where \(\Sigma_c\) is the covariance sub-block for the coefficients of this basis group.

Parameters:
  • r_eval (array_like, shape (R,)) – Radial grid on which to evaluate the kernel.

  • kernels (list of (callable, label)) – Kernel basis functions, e.g. from exp_poly_kernels(). Each callable maps r -> phi(r).

  • coeffs (array_like, shape (K,)) – Inferred coefficients for this basis block.

  • cov_block (array_like, shape (K, K)) – Covariance sub-block for these coefficients (from inf.force_coefficients_covariance[i0:i1, i0:i1] after calling inf.compute_force_error()).

  • alpha (float) – Confidence level (default 0.95 for 95 % CI).

Returns:

  • dict with keys

    • r – the input radial grid (as numpy array).

    • mean – kernel profile coeffs @ phi(r).

    • std – pointwise standard deviation.

    • lower, upper – symmetric CI bounds.

    • phi – basis matrix (K, R) (useful for further analysis).

Return type:

Dict[str, ndarray]
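The formulas above translate directly into NumPy. This sketch re-implements the documented computation under assumptions: the two exponential-polynomial kernels are hypothetical stand-ins for exp_poly_kernels() (whose exact form is not given here), the coefficients and covariance block are made up, and z is taken from the standard normal distribution.

```python
from statistics import NormalDist

import numpy as np


def kernel_ci(r_eval, kernels, coeffs, cov_block, alpha=0.95):
    r = np.asarray(r_eval, dtype=float)
    phi = np.stack([fn(r) for fn, _label in kernels])  # (K, R) basis matrix
    mean = np.asarray(coeffs) @ phi                     # k(r) = sum_a c_a phi_a(r)
    var = np.einsum("kr,kl,lr->r", phi, np.asarray(cov_block), phi)  # phi^T Sigma_c phi
    std = np.sqrt(np.maximum(var, 0.0))  # guard against tiny negative round-off
    z = NormalDist().inv_cdf((1.0 + alpha) / 2.0)
    return {"r": r, "mean": mean, "std": std,
            "lower": mean - z * std, "upper": mean + z * std, "phi": phi}


# Hypothetical stand-ins for exp_poly_kernels(): (callable, label) pairs.
kernels = [(lambda r: np.exp(-r), "exp(-r)"), (lambda r: r * np.exp(-r), "r exp(-r)")]
out = kernel_ci(np.linspace(0.0, 5.0, 6), kernels,
                coeffs=[1.0, -0.5], cov_block=[[0.04, 0.01], [0.01, 0.09]])
print(out["phi"].shape, out["std"].shape)  # (2, 6) (6,)
```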