SFI.statefunc.stateexpr module

High-level immutable façade over backend node trees.

This module exposes three public classes:

  • Basis – deterministic dictionary of features (no parameters),

  • PSF – parametric state-function family F(x; theta),

  • SF – bound state function with theta fixed.

They all share the StateExpr algebra (broadcasting, linear ops, feature product/concatenation rules, differentiation builders).

Shape conventions (runtime evaluation)

Users call Basis/PSF/SF on single inputs or batched arrays; the library handles batching and vectorization internally. Let x be the runtime input:

  • If particles_input=False: x.shape == batch · dim

  • If particles_input=True: x.shape == batch · P · dim

Single inputs have batch=(). Outputs always end with the feature axis (length n_features):

y.shape == batch · [P]^pdepth · (dim)^rank · n_features

pdepth is strict: outputs have exactly pdepth particle axes. If particles_input=False, then pdepth must be 0 (no particle axes can be created without a particle input axis). Any particle axis is simply treated as batch in that case.
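The shape rule above can be checked mechanically. The helper below is hypothetical (expected_output_shape is not part of SFI). It is a minimal sketch that only encodes the documented formula and the pdepth restriction, which can be handy for shape assertions in tests.

```python
def expected_output_shape(batch, dim, rank, n_features, *,
                          particles_input=False, P=None, pdepth=0):
    """Hypothetical helper mirroring
    y.shape == batch · [P]^pdepth · (dim)^rank · n_features.

    `batch` is a tuple; pass () for a single (unbatched) input.
    """
    if not particles_input and pdepth != 0:
        # no particle axes can be created without a particle input axis
        raise ValueError("pdepth must be 0 when particles_input=False")
    if pdepth > 0 and P is None:
        raise ValueError("P is required when pdepth > 0")
    return tuple(batch) + (P,) * pdepth + (dim,) * rank + (n_features,)


# single rank-1 input in 3-D with 4 features
single = expected_output_shape((), 3, 1, 4)
# batch of 10 systems of 7 particles, rank-2 output, pdepth=1, 5 features
batched = expected_output_shape((10,), 2, 2, 5, particles_input=True, P=7, pdepth=1)
```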

User function contract (single-sample)

Factories (make_basis, make_psf) accept single-sample callables. Your function never sees batch axes; it receives:

  • x: (dim,) or (P, dim) if particles_input=True

  • Optional keywords: any subset of {v, mask, extras, params} that you declare in the signature. We only pass those you declare.

Return shape (no batch axis):

(P,)*pdepth + (dim,)*rank + (n_features?,)

If n_features==1 you may omit the last axis; we insert a singleton feature axis automatically.
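Below is a minimal sketch of two user functions that satisfy this contract. The factories make_basis/make_psf are deliberately not called, because their exact signatures are not shown here. with_feature_axis is a hypothetical stand-in for the automatic singleton insertion, and the list comprehension stands in for the batching the library performs internally.

```python
try:
    import jax.numpy as jnp  # preferred: write user functions with jax.numpy
except ImportError:          # fallback so this sketch also runs with plain NumPy
    import numpy as jnp
import numpy as np


def harmonic_features(x):
    """rank=1, n_features=2, particles_input=False: x has shape (dim,)."""
    r2 = jnp.sum(x ** 2)
    # trailing feature axis -> shape (dim, 2)
    return jnp.stack([-x, -r2 * x], axis=-1)


def energy(x):
    """rank=0, n_features=1: the feature axis is omitted -> shape ()."""
    return 0.5 * jnp.sum(x ** 2)


def with_feature_axis(y, rank, dim):
    # hypothetical helper: insert the singleton feature axis when omitted
    y = np.asarray(y)
    return y[..., None] if y.shape == (dim,) * rank else y


# What the library does for you: map the single-sample function over the batch.
xs = np.random.default_rng(0).normal(size=(5, 3))        # batch=(5,), dim=3
F = np.stack([np.asarray(harmonic_features(x)) for x in xs])  # (5, 3, 2)
E = np.stack([with_feature_axis(energy(x), rank=0, dim=3) for x in xs])  # (5, 1)
```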

mask semantics: must broadcast to the prefix of x including the particle axis (i.e. batch * P when present). Numeric or boolean masks are accepted.
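To make the mask semantics concrete, here is a minimal single-sample sketch for particles_input=True and pdepth=1. The function itself is hypothetical. It assumes that any nonzero (or True) mask entry marks a particle as present, and it accepts both boolean and numeric masks as the text allows.

```python
try:
    import jax.numpy as jnp
except ImportError:  # plain-NumPy fallback for running the sketch
    import numpy as jnp
import numpy as np


def distance_to_masked_centroid(x, *, mask):
    """particles_input=True, pdepth=1, rank=0, n_features=1.

    x: (P, dim); mask: (P,), boolean or numeric (nonzero = present).
    Returns shape (P,); the singleton feature axis is omitted.
    """
    w = jnp.where(jnp.asarray(mask) != 0, 1.0, 0.0)     # 0/1 weights
    centroid = jnp.sum(w[:, None] * x, axis=0) / jnp.maximum(jnp.sum(w), 1.0)
    # masked-out particles get 0 and do not influence the centroid
    return w * jnp.linalg.norm(x - centroid, axis=-1)


x = np.array([[0.0, 0.0], [2.0, 0.0], [100.0, 100.0]])
d = np.asarray(distance_to_masked_centroid(x, mask=np.array([True, True, False])))
```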

extras semantics: Extras are pass-through data for user functions. The expression enforces presence only:

  • If a leaf declares extras_keys=("a","b",...), those keys must be present in the extras mapping at call time.

  • If a leaf declares extras but no keys, only the presence of a mapping is required; no keys are enforced.

The expression performs no shape checking or broadcasting of extras. Three kinds exist, distinguished by declaration (never by shape):

  • global – default; forwarded unchanged to user functions;

  • particle – declared only by interaction leaves; gathered by the dispatcher per edge and then forwarded downstream as globals;

  • structural – rule/dispatcher-owned (e.g., CSR arrays); never forwarded.
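The presence-only rule can be sketched as follows. check_extras is a hypothetical helper written for illustration, not the library's validator. Note that values are never inspected, matching "no shape/broadcasting of extras".

```python
from collections.abc import Mapping


def check_extras(extras, *, declares_extras, extras_keys=()):
    """Hypothetical presence-only check mirroring the rules above."""
    if not declares_extras:
        return                          # leaf does not use extras at all
    if not isinstance(extras, Mapping):
        raise TypeError("this leaf declares extras: a mapping is required")
    missing = [k for k in extras_keys if k not in extras]
    if missing:
        raise KeyError(f"missing required extras: {missing}")
    # values are passed through untouched: no shape or broadcast checks


check_extras({}, declares_extras=True)                        # no keys enforced
check_extras({"a": 1, "b": [2, 3]}, declares_extras=True, extras_keys=("a", "b"))
```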

JAX use & autodiff

All computations are JAX-friendly. Write user functions with jax.numpy. Expressions compose under jit/vmap, and support automatic differentiation:

  • .d_x() adds a spatial derivative axis to the rank block. When particles_input=True it also adds a particle axis (produced by JAX; we only permute/reshape), unless pdepth=1 and same_particle=True.

  • .d_v() behaves the same way (requires needs_v=True).

  • .d_theta() (PSF only) returns a Jacobian with the feature axis fused with parameters.

Internally, derivative axis ordering is canonicalized by permutation only; we never create particle axes ourselves—if particles_input=True, extra particle axes in a Jacobian come from JAX itself.

class SFI.statefunc.stateexpr.StateExpr(root)

Bases: Module

Immutable state expression backed by a static node tree.

Think: a read-only NumPy array whose last axis is features. Every algebraic operation returns a new StateExpr (functional style), and static contract metadata (rank, dim, pdepth, n_features, needs_v, particles_input) is validated at graph-construction time.

Runtime shapes

Inputs are batched at call time; the library handles batching.

  • If particles_input=False: x.shape == batch · dim

  • If particles_input=True: x.shape == batch · P · dim

Outputs always end with the feature axis (length n_features):

y.shape == batch · [P]^pdepth · (dim)^rank · n_features

If particles_input=False, pdepth must be 0.

User function contract (single-sample)

Factories accept single-sample callables; user code never sees batch axes. Your function gets x of shape (dim,) (or (P, dim) if particles_input) and any subset of keyword-only args it declares: {v, mask, extras, params}. Return shape (no batch axis): (P,)*pdepth + (dim,)*rank + (n_features?,). If n_features==1, you may omit the last axis; a singleton is inserted.

  • mask must broadcast to the prefix of x including the particle axis.

  • extras presence: if a leaf declares extras with no explicit keys, presence is required (any dict). If extras_keys is given, those keys are required. Values may be scalars or arrays that broadcast over batch only.

Operators

Element-wise arithmetic

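  • + and - are element-wise on spatial axes and require matching features; * and / are element-wise on spatial axes but take a Cartesian product over features (see Type mixing and broadcasting). Scalars and 1-D vectors (length n_features) broadcast along features.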

  • Unary: +expr, -expr.

  • NumPy/JAX ufuncs: sin, exp, etc. forward to element-wise maps with the same broadcasting rules; binary ufuncs accept (StateExpr, constant) and (StateExpr, StateExpr) operand pairs (features must match for the latter).

Linear-algebra-like

  • @ (matmul): true matrix multiplication on spatial axes, (..., m, k) @ (..., k, n) -> (..., m, n); features form a Cartesian product between operands (result features = F_left × F_right).

  • StateExpr.einsum(spec, *operands): generic spatial contraction; features take a Cartesian product across all operands (no implicit feature reduction).

  • .dot(other): spatial inner product between the last rank axis of self and the first rank axis of other; features take a Cartesian product.

  • .sqrtm(): matrix square root per-feature; requires rank==2.

Feature-axis manipulation

  • expr1 & expr2, or StateExpr.stack([...]): concatenate features. Static spatial contracts must match; labels (if present) are concatenated.

  • expr[idx]: feature selection (slice/list/bool/int). The spatial contract is unchanged; labels are subset accordingly when available.

  • .elementwisemap(func, label_fn=None): apply a scalar-to-scalar map to each feature independently (spatial axes untouched). Optional label_fn updates labels for Basis.
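The three feature-axis manipulations above can be sketched on evaluated arrays. This is a minimal illustration that represents an expression's output as a hypothetical (array, labels) pair, with array shape batch + (dim,)*rank + (n_features,). It does not use the StateExpr classes themselves.

```python
import numpy as np


def concat_features(a, b):
    """Like expr1 & expr2: spatial shapes must match, labels are concatenated."""
    (ya, la), (yb, lb) = a, b
    if ya.shape[:-1] != yb.shape[:-1]:
        raise ValueError("spatial contracts must match")
    return np.concatenate([ya, yb], axis=-1), la + lb


def select_features(a, idx):
    """Like expr[idx]: slice/list/bool/int; the feature axis is always kept."""
    y, labels = a
    keep = np.atleast_1d(np.arange(y.shape[-1])[idx])
    return y[..., keep], [labels[i] for i in keep]


def elementwise_map(a, func, label_fn=None):
    """Like .elementwisemap: scalar-to-scalar map, spatial axes untouched."""
    y, labels = a
    new_labels = [label_fn(s) for s in labels] if label_fn else labels
    return func(y), new_labels


rng = np.random.default_rng(0)
a = (rng.normal(size=(4, 2, 3)), ["x", "x2", "x3"])  # batch=(4,), dim=2, rank-1
b = (rng.normal(size=(4, 2, 1)), ["c"])
ab = concat_features(a, b)                             # 4 features
t = elementwise_map(select_features(ab, [0, 3]), np.tanh,
                    label_fn=lambda s: f"tanh({s})")
```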

Differentiation builders

All builders return new expressions (no evaluation).

  • .d_x(same_particle=False, mode='auto') – spatial Jacobian dF/dx.

    • Adds one derivative-dim axis immediately before the rank block.

    • If particles_input=True:

      • when same_particle=False (default), builds the full cross-particle Jacobian df_i/dx_j and a second particle axis appears (from JAX);

      • when same_particle=True and pdepth=1, computes the same-particle Jacobian df_i/dx_i without adding a new particle axis; otherwise an error is raised.

  • .d_v(same_particle=False, mode='auto') – velocity Jacobian dF/dv (requires needs_v=True). Same axis rules as .d_x().

  • .d_theta(mode='auto') – Jacobian w.r.t. parameters (PSF only); the final axis becomes features × n_params_total. Batch/particle/rank prefixes are preserved.
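The fused layout of .d_theta() can be illustrated with a toy parametric function and central differences. Everything here is hypothetical. In particular, the feature-major fusion order (feature index outer, parameter index inner) is an assumption, since the text only says that features and parameters share the final axis.

```python
import numpy as np


def psf(x, theta):
    """Hypothetical rank-0 PSF with n_features=2 and n_params=3:
    F(x; theta) = [theta0 * x + theta1, theta2 * x**2]."""
    return np.array([theta[0] * x + theta[1], theta[2] * x ** 2])


def d_theta_fused(f, x, theta, eps=1e-6):
    """Central-difference dF/dtheta, fused to (n_features * n_params,)."""
    theta = np.asarray(theta, dtype=float)
    cols = []
    for k in range(theta.size):
        e = np.zeros_like(theta)
        e[k] = eps
        cols.append((f(x, theta + e) - f(x, theta - e)) / (2 * eps))
    J = np.stack(cols, axis=-1)                 # (n_features, n_params)
    return J.reshape(J.shape[:-2] + (-1,))      # feature-major fusion (assumed)


J = d_theta_fused(psf, 2.0, [1.0, 0.5, -1.0])
```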

Type mixing and broadcasting

  • Scalars and ndarrays are treated as purely spatial constants: they must be broadcastable to the spatial rank block (dim,)*rank and are then broadcast uniformly across the feature axis. Bare arrays cannot target the feature axis directly.

  • Combining two StateExpr requires matching static contracts for rank, dim, and pdepth.

  • For element-wise ops such as +, - and most binary ufuncs, n_features must match (per-feature operations).

  • For multiplicative ops (*, / and their ufuncs), as well as @ and .einsum, feature axes take a Cartesian product between operands: F_out = F_left × F_right. When both operands have more than one feature, a one-off warning is emitted, because this can grow n_features quickly.

  • needs_v is OR-combined: if any operand needs v, the result does.

  • particles_input is OR-combined: if any operand uses particle input, the result does too. An operand without particle input is broadcast uniformly across the particle axis.
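The feature rules for + and * can be sketched on evaluated arrays (trailing feature axis). This is a hypothetical illustration, not the library's implementation. The left-major ordering of the product features (index i * F_right + j) is an assumption.

```python
import numpy as np


def add_features(a, b):
    # +/-: per-feature operation; n_features must match
    if a.shape[-1] != b.shape[-1]:
        raise ValueError("n_features must match for + and -")
    return a + b


def mul_features(a, b):
    # * and /: element-wise on spatial axes, Cartesian product on features
    # out[..., i * F_right + j] = a[..., i] * b[..., j]   (ordering assumed)
    out = a[..., :, None] * b[..., None, :]
    return out.reshape(out.shape[:-2] + (-1,))


def spatial_constant(arr):
    # bare arrays act as spatial constants with a single feature
    return np.asarray(arr, dtype=float)[..., None]


rng = np.random.default_rng(0)
a = rng.normal(size=(5, 2, 3))     # batch=(5,), dim=2, rank-1, F_left=3
b = rng.normal(size=(5, 2, 4))     # F_right=4
c = mul_features(a, b)             # F_out = 3 × 4 = 12
s = mul_features(a, spatial_constant([1.0, -1.0]))  # 3 × 1 = 3 features
```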

Array interop

Plain JAX/NumPy arrays are accepted in binary ops with StateExpr. They are treated as spatial constants with a single feature. Arrays broadcast over spatial axes and batch/particles only. Features never arise from arrays and are never contracted unless requested by explicit feature-aware APIs.

Supported operations with arrays:

  • Elementwise: +, -, *, /, **, and their reflected forms.

  • Linear algebra: A @ B, B @ A.

  • Tensor algebra: einsum(spec, ...), dot(...), tensordot(...).

JAX compatibility and autodiff

Write user functions with jax.numpy as jnp. Expressions compose under jit / vmap, and support automatic differentiation:

  • .d_x(), .d_v() add a derivative-dim axis (and a particle axis when particles_input=True).

  • .d_theta() fuses features × n_params on the last axis. Derivative axis ordering is canonicalized by permutation only.

d_v(*, same_particle=False, mode='auto')

Build an expression for the velocity Jacobian ∂F/∂v.

Same rules as .d_x(). Requires needs_v=True on the underlying expression.

Parameters:
  • same_particle (bool)

  • mode (str)

d_x(*, same_particle=False, mode='auto')

Build an expression for the spatial Jacobian dF/dx.

Axis effects

  • Adds one derivative-dim axis immediately before the rank block.

  • If particles_input=True:

    • when same_particle=True: if pdepth=1, compute df_i/dx_i (no extra P axis); the particle dimension behaves like a broadcasted index. Otherwise, raises an error.

    • when same_particle=False (default): compute the full cross-particle Jacobian df_i/dx_j; an extra particle axis appears (from JAX). We never create P axes ourselves; we only permute to canonical order.
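The difference between the two modes can be seen on a raw Jacobian. The sketch below uses a hypothetical pairwise function and central differences instead of JAX, storing the result as output axes followed by input axes (the same layout jax.jacfwd produces). Taking the i == j diagonal gives the same-particle Jacobian. Where the library places the second particle axis in its canonical order is not specified here, so only the raw layout is shown.

```python
import numpy as np


def f(x):
    """Hypothetical pdepth=1, rank=0, n_features=1 function of x: (P, dim).
    f_i = sum_j |x_i - x_j|^2."""
    diff = x[:, None, :] - x[None, :, :]
    return np.sum(diff ** 2, axis=(1, 2))[:, None]        # (P, 1)


def full_jacobian(f, x, eps=1e-6):
    """jac[i, feat, j, a] = d f[i, feat] / d x[j, a]  (cross-particle)."""
    P, dim = x.shape
    jac = np.zeros(f(x).shape + (P, dim))
    for j in range(P):
        for a in range(dim):
            e = np.zeros_like(x)
            e[j, a] = eps
            jac[..., j, a] = (f(x + e) - f(x - e)) / (2 * eps)
    return jac


x = np.random.default_rng(0).normal(size=(4, 2))
J = full_jacobian(f, x)                      # (P, F, P, dim): extra particle axis
J_same = np.einsum("ifia->ifa", J)           # same_particle: diagonal i == j
J_same = np.moveaxis(J_same, -1, 1)          # derivative axis before features
```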

Parameters:
  • same_particle (bool) – See axis effects above.

  • mode ({'auto', …}) – Backend differentiation mode; 'auto' selects a sane default.

Returns:

A new expression representing the Jacobian.

Return type:

StateExpr

Notes

This method triggers no evaluation; it returns a new graph.

dense(n_out, *, weight='W', bias='b')

Apply a learnable affine map on the feature axis.

y[..., j] = sum_i x[..., i] * W[i, j] + b[j]

Spatial (rank) axes are untouched: the same W, b are shared across every spatial component. The result is always a PSF (since the dense layer introduces learnable parameters).
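The affine map y[..., j] = sum_i x[..., i] * W[i, j] + b[j] is a contraction over the trailing feature axis only. Here it is sketched on evaluated arrays with NumPy (a hypothetical helper, not the PSF machinery), showing that every spatial component reuses the same W and b.

```python
import numpy as np


def dense(x, W, b=None):
    """y[..., j] = sum_i x[..., i] * W[i, j] + b[j] on the trailing feature axis.

    Leading (batch, particle, rank) axes are untouched, so the same W and b
    are shared across every spatial component.
    """
    y = np.einsum("...i,ij->...j", x, W)
    return y if b is None else y + b


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2, 3))       # batch=(5,), dim=2, rank-1, n_features=3
W, b = rng.normal(size=(3, 4)), rng.normal(size=4)
y = dense(x, W, b)                   # (5, 2, 4): n_out=4
```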

Parameters:
  • n_out (int) – Number of output features.

  • weight (str) – Name for the weight parameter (default "W"). Use distinct names ("W1", "W2", …) when stacking multiple layers.

  • bias (str | None) – Name for the bias parameter (default "b"; None to omit). Use distinct names ("b1", "b2", …) when stacking layers.

Returns:

A parametric state function wrapping the dense layer.

Return type:

PSF

Examples

Build a small MLP on the 2-D position:

>>> from SFI.bases import X
>>> import jax.numpy as jnp
>>> mlp = (
...     X(dim=2).rank_to_features()            # rank-0, 2 features
...     .dense(32, weight="W1", bias="b1")     # rank-0, 32 features
...     .elementwisemap(jnp.tanh)
...     .dense(1, weight="W2", bias="b2")      # rank-0, 1 feature
... )

property dim
dot(other, axes=None)

Spatial tensordot via einsum.

Semantics:
  • axes=None: contract last axis of self with first axis of other.

  • axes=int:
    • if self.rank == other.rank: contract all axes (Frobenius/trace for rank-2).

    • otherwise: contract the trailing axes of self (as many as the integer axes) with the same number of leading axes of other.

  • axes=(a_axes, b_axes): NumPy-style explicit lists.
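The axis-selection rules can be sketched with np.tensordot on single-sample arrays that carry a trailing feature axis. resolve_dot_axes and spatial_dot are hypothetical helpers that mirror the rules as written; the left-major order of the product features is an assumption.

```python
import numpy as np


def resolve_dot_axes(rank_a, rank_b, axes=None):
    """Return (a_axes, b_axes) over spatial axes following the rules above."""
    if axes is None:
        return [rank_a - 1], [0]                      # last of self, first of other
    if isinstance(axes, int):
        if rank_a == rank_b:
            return list(range(rank_a)), list(range(rank_b))   # contract all
        return list(range(rank_a - axes, rank_a)), list(range(axes))
    a_axes, b_axes = axes                             # NumPy-style explicit lists
    return list(a_axes), list(b_axes)


def spatial_dot(a, b, rank_a, rank_b, axes=None):
    """a: (dim,)*rank_a + (F1,), b: (dim,)*rank_b + (F2,)."""
    a_axes, b_axes = resolve_dot_axes(rank_a, rank_b, axes)
    out = np.tensordot(a, b, axes=(a_axes, b_axes))
    # tensordot layout: free spatial axes of a, F1, free spatial axes of b, F2
    out = np.moveaxis(out, rank_a - len(a_axes), -2)  # move F1 next to F2
    return out.reshape(out.shape[:-2] + (-1,))        # F1 × F2 features


rng = np.random.default_rng(0)
M = rng.normal(size=(3, 3, 2))                 # rank-2, F1=2
v = rng.normal(size=(3, 4))                    # rank-1, F2=4
y = spatial_dot(M, v, 2, 1)                    # matrix-vector: (3, 8)
fro = spatial_dot(M, M, 2, 2, axes=2)          # Frobenius inner products: (4,)
```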

Arrays are accepted and coerced to spatial constants.

classmethod einsum(spec, *operands)

General contraction on spatial axes (like jnp.einsum).

Important

  • Use only lowercase letters.

  • spec refers only to spatial axes (not the feature axis).

  • Features take a Cartesian product across operands (no implicit feature reduction or alignment). If you need feature concatenation, use &/stack. For per-feature ops, use element-wise maps or binary ops where features must match.

Arrays in operands are accepted and coerced to spatial-constant expressions with a single feature. Only spatial letters in spec are interpreted. If no StateExpr is present, a TypeError is raised because dim cannot be inferred.

Examples

Vector inner product (per feature pair), two rank-1 inputs:

>>> # a: i × F1, b: i × F2 → rank-0 × (F1×F2)
>>> c = StateExpr.einsum("i,i->", a, b)

Matrix–vector product (per feature pair), rank-2 with rank-1:

>>> # M: ij × F1, v: j × F2 → i × (F1×F2)
>>> y = StateExpr.einsum("ij,j->i", M, v)

Outer product (per-feature Cartesian product):

>>> # u: i × F1, v: j × F2 → ij × (F1×F2)
>>> O = StateExpr.einsum("i,j->ij", u, v)

Parameters:
  • spec (str) – An einsum string over spatial indices, e.g. “ij,j->i”.

  • operands (mix[StateExpr, array-like]) – Any mix of StateExpr and arrays.
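One plausible way to realize "spec refers only to spatial axes" is to append a fresh uppercase letter per operand for its feature axis and flatten those letters at the end. This would also explain why only lowercase letters are allowed in spec, though that reading is an assumption. feature_einsum is a hypothetical sketch on plain arrays, not the library's code.

```python
import string

import numpy as np


def feature_einsum(spec, *operands):
    """Spatial einsum with a Cartesian product over trailing feature axes."""
    inputs, output = spec.replace(" ", "").split("->")
    in_specs = inputs.split(",")
    if len(in_specs) != len(operands):
        raise ValueError("spec and operands disagree")
    feat = string.ascii_uppercase[: len(operands)]          # 'A', 'B', ...
    full = ",".join(s + f for s, f in zip(in_specs, feat)) + "->" + output + feat
    out = np.einsum(full, *operands)                        # e.g. "ijA,jB->iAB"
    return out.reshape(out.shape[: len(output)] + (-1,))    # flatten features


rng = np.random.default_rng(0)
M, v = rng.normal(size=(3, 3, 2)), rng.normal(size=(3, 4))
y = feature_einsum("ij,j->i", M, v)        # i × (2×4)
a, b = rng.normal(size=(3, 2)), rng.normal(size=(3, 5))
c = feature_einsum("i,i->", a, b)          # rank-0 × (2×5)
```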

elementwisemap(func, *, label_fn=None)

Apply func element-wise to every feature (spatial axes untouched).

func must be a pure JAX function from scalar→scalar (rank-0 arrays OK). If the expression carries feature labels (e.g., a Basis or an SF bound from a Basis), label_fn (if provided) is applied to each feature label.

Example

>>> import jax.numpy as jnp
>>> B = ...   # Basis with 4 features
>>> C = B.elementwisemap(jnp.tanh, label_fn=lambda s: f"tanh({s})")

Parameters:
  • func (Callable[[Array], Array])

  • label_fn (Callable[[str], str] | None)

estimate_bytes_per_sample(*, dtype=None, particle_size=None, sample=None, mode='forward')

Convenience wrapper that returns only the transient bytes per sample.

Parameters:
  • particle_size (int | None)

  • sample (SampleMeta | None)

  • mode (str)

Return type:

int

features_to_rank(rank)

Unfold features into spatial axes → given rank.

The output layout changes from the current:

batch · (dim,)^self.rank · n_features

to:

batch · (dim,)^rank · (n_features / dim^(rank − self.rank),)

where the new innermost spatial axes are carved out of the feature axis. This is a pure reshape and is the exact inverse of rank_to_features() when restoring the original rank.
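The reshape and its error conditions can be sketched with NumPy. This is a hypothetical helper on evaluated arrays. It assumes a C-order (row-major) split of the trailing feature axis, which is the ordering under which it is the exact inverse of a C-order rank_to_features flattening.

```python
import numpy as np


def features_to_rank(y, dim, self_rank, rank):
    """batch·(dim,)^self_rank·(F,) -> batch·(dim,)^rank·(F / dim^Δ,)."""
    if rank <= self_rank:
        raise TypeError("target rank must exceed the current rank")
    delta = rank - self_rank
    F = y.shape[-1]
    if F % dim ** delta:
        raise ValueError(f"n_features={F} is not divisible by dim^{delta}")
    # new innermost spatial axes are carved out of the feature axis (C order)
    return y.reshape(y.shape[:-1] + (dim,) * delta + (F // dim ** delta,))


y = np.arange(5 * 6, dtype=float).reshape(5, 6)   # batch=(5,), rank-0, F=6
z = features_to_rank(y, dim=2, self_rank=0, rank=1)
```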

Parameters:

rank (int) – Target tensor rank (must be greater than the current rank).

Returns:

Expression at the requested rank with fewer features.

Return type:

StateExpr (same subclass)

Raises:
  • ValueError – If n_features is not divisible by dim^Δrank.

  • TypeError – If rank ≤ self.rank (use rank_to_features to go down).

Examples

Turn a dense layer’s output back into a vector field:

>>> scalar_expr.features_to_rank(1)  # rank-1, F/dim features

Build a 2→H→2 MLP force field:

>>> from SFI.bases import X
>>> import jax.numpy as jnp
>>> mlp = (
...     X(dim=2)
...     .rank_to_features()                     # rank-0, 2 features
...     .dense(32, weight="W1", bias="b1")      # rank-0, 32 features
...     .elementwisemap(jnp.tanh)
...     .dense(2, weight="W2", bias="b2")       # rank-0, 2 features
...     .features_to_rank(1)                    # rank-1, 1 feature
... )

memory_hint(*, dtype=None, particle_size=None, sample=None, mode='forward')

Conservative per-sample memory footprint for the WHOLE expression tree. Delegates to the root node, which sums children + own output along the way.

Parameters:
  • particle_size (int | None)

  • sample (SampleMeta | None)

  • mode (str)

property n_features
property needs_v
property particle_extras: tuple[str, ...]

Pure metadata, forwarded from the root node.

Names of extras declared as per-particle somewhere in the underlying node tree (typically by interaction leaves). The dispatcher reads this to know which keys to gather from (P, …) into (E, K, …) per edge before calling locals.
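The gather from (P, …) into (E, K, …) amounts to integer indexing by an edge list. Here is a minimal sketch under simplifying assumptions: the edges are a dense (E, K) array of particle indices, with K read as the number of particles per interaction. The dispatcher's actual structural data (e.g. CSR arrays) is not shown, and all names here are hypothetical.

```python
import numpy as np

# hypothetical edge list: E=3 pair interactions (K=2 particles each)
edges = np.array([[0, 1], [1, 2], [0, 2]])                   # (E, K)

# hypothetical per-particle extras, leading axis P=3
charge = np.array([1.0, -1.0, 2.0])                          # (P,)
radius = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])     # (P, 2)


def gather_per_edge(extra, edges):
    # integer fancy indexing maps (P, ...) -> (E, K, ...)
    return np.asarray(extra)[edges]


q_edges = gather_per_edge(charge, edges)   # (3, 2)
r_edges = gather_per_edge(radius, edges)   # (3, 2, 2)
```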

property particles_input
property pdepth
property rank
rank_to_features()

Fold all spatial (rank) axes into the feature axis → rank-0.

The output layout changes from:

batch · (dim,)^rank · n_features

to:

batch · (n_features × dim^rank,)

with rank = 0. This is a pure reshape (no copy, no learnable parameters) and is the exact inverse of features_to_rank(original_rank).
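The round trip can be checked with NumPy reshapes on an evaluated array. These helpers are hypothetical and assume C-order flattening of the trailing (dim,)^rank · F block. In NumPy, reshaping a contiguous array returns a view, which matches the "no copy" remark.

```python
import numpy as np


def rank_to_features(y, rank):
    """batch·(dim,)^rank·(F,) -> batch·(F·dim^rank,) by C-order reshape."""
    if rank == 0:
        raise TypeError("expression is already rank-0")
    lead = y.shape[: y.ndim - (rank + 1)]       # batch (and particle) axes
    return y.reshape(lead + (-1,))


def unfold(y, dim, rank):
    """Inverse direction: rank-0 -> rank, also a pure reshape."""
    return y.reshape(y.shape[:-1] + (dim,) * rank + (y.shape[-1] // dim ** rank,))


x = np.random.default_rng(0).normal(size=(7, 2, 2, 3))   # batch=(7,), rank-2, F=3
flat = rank_to_features(x, rank=2)                        # (7, 12)
back = unfold(flat, dim=2, rank=2)                        # (7, 2, 2, 3)
```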

Returns:

Scalar expression whose feature count is self.n_features × self.dim ** self.rank.

Return type:

StateExpr (same subclass)

Raises:

TypeError – If the expression is already rank-0 (no-op would be confusing).

Examples

Prepare a rank-1 position vector for dense layers:

>>> X(dim=2).rank_to_features()   # rank-0, 2 features

The round-trip is the identity:

>>> expr.rank_to_features().features_to_rank(expr.rank)  # same as expr
property required_extras: tuple[str, ...]

Presence-only extras required by the expression, forwarded from the root node. No shape/broadcast semantics here.

root: BaseNode
property sdims
specialize(*, dataset)

Collapse a pooled model to its single-condition specialization.

Returns a new expression in which every dataset_index-reading primitive (e.g. per_dataset_scalar(), dataset_indicator()) is folded at condition dataset: per-condition parameter arrays collapse to that condition’s slice and the reserved dataset_index extra drops out of required_extras. The pooled-ness is an inference-time concern; once a condition is chosen the model stands alone (no dataset concept).

On a bound SF the stored parameter values are projected to match the shrunken template; on an unbound PSF the template’s per-condition specs become scalars.
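What specialization does to stored values can be pictured with a plain dictionary. This is a hypothetical sketch, not the library's data model. It assumes per-condition parameters carry a leading n_datasets axis while shared parameters do not; "temperature" is an invented extra used only to show that other required extras survive.

```python
import numpy as np

# hypothetical pooled parameter values
pooled = {
    "k": np.array([1.0, 2.0, 3.0]),   # per-condition scalar (n_datasets=3)
    "W": np.ones((2, 2)),             # shared across conditions
}
per_condition = {"k"}
required_extras = ("dataset_index", "temperature")


def specialize_params(params, per_condition, dataset):
    # per-condition arrays collapse to that condition's slice
    return {name: (value[dataset] if name in per_condition else value)
            for name, value in params.items()}


def specialize_extras(required, reserved="dataset_index"):
    # the reserved dataset_index extra drops out once a condition is fixed
    return tuple(k for k in required if k != reserved)


single = specialize_params(pooled, per_condition, dataset=1)
extras = specialize_extras(required_extras)
```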

Parameters:

dataset (int)

Return type:

StateExpr

sqrtm()
classmethod stack(exprs)

Concatenate along the feature axis.

Static contracts must match (rank/dim, compatible pdepth).

Parameters:

exprs (Sequence[StateExpr])

tensordot(other, axes=1)

Alias of .dot with NumPy-compatible axes.

tensorize(dim=None, mode='symmetric')

Lift a scalar expression to rank-2 (matrix).

Parameters:
  • dim (int, optional) – Spatial dimension. Inferred when possible.

  • mode (str) – 'symmetric' (default) uses symmetric_matrix_basis() (d(d+1)/2 features per scalar feature, spans all symmetric matrices). 'identity' uses identity_matrix_basis() (1 feature per scalar feature, isotropic).
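Why 'symmetric' yields d(d+1)/2 features per scalar feature: one basis matrix per pair i ≤ j spans all symmetric d×d matrices. The sketch below is hypothetical. It builds such a basis with E_ij = 1 at (i, j) and (j, i), but the normalization used by symmetric_matrix_basis() is not given here, so this choice is an assumption, as is the left-major product order in tensorize.

```python
import numpy as np


def symmetric_basis(d):
    """d(d+1)/2 symmetric d×d matrices spanning all symmetric matrices."""
    mats = []
    for i in range(d):
        for j in range(i, d):
            E = np.zeros((d, d))
            E[i, j] = E[j, i] = 1.0      # normalization is an assumption
            mats.append(E)
    return np.stack(mats, axis=-1)       # (d, d, d(d+1)/2), features last


def tensorize(s, d):
    """Scalar values (..., F) -> (..., d, d, F * d(d+1)/2)."""
    B = symmetric_basis(d)                                  # (d, d, K)
    out = s[..., None, None, :, None] * B[:, :, None, :]    # (..., d, d, F, K)
    return out.reshape(out.shape[:-2] + (-1,))


B = symmetric_basis(3)
T = tensorize(np.ones((4, 2)), 3)
```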

Returns:

Matrix expression.

Return type:

StateExpr

vectorize(dim=None, axes=None)

Lift a scalar expression to rank-1 (vector).

Equivalent to self * unit_vector_basis(dim, axes=axes), i.e. a Cartesian product of the feature axis with unit vectors.

Parameters:
  • dim (int, optional) – Spatial dimension. Inferred from the expression’s contract when possible.

  • axes (sequence of int, optional) – Subset of spatial axes to include (default: all dim axes).
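The equivalence self * unit_vector_basis(dim, axes=axes) can be pictured on evaluated arrays. Each scalar feature is multiplied by each selected unit vector, giving n_features × len(axes) vector features. The helpers are hypothetical, and the left-major product order is an assumption.

```python
import numpy as np


def unit_vectors(dim, axes=None):
    axes = list(range(dim)) if axes is None else list(axes)
    return np.eye(dim)[:, axes]                   # column k is e_{axes[k]}


def vectorize(s, dim, axes=None):
    """Scalar values (..., F) -> (..., dim, F * len(axes))."""
    U = unit_vectors(dim, axes)                   # (dim, A)
    out = s[..., None, :, None] * U[:, None, :]   # (..., dim, F, A)
    return out.reshape(out.shape[:-2] + (-1,))


s = np.array([[2.0, 3.0]])                        # batch=(1,), rank-0, F=2
v = vectorize(s, 3)                               # (1, 3, 6)
vz = vectorize(s, 3, axes=[2])                    # (1, 3, 2): only e_2
```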

Returns:

Vector expression with n_features = self.n_features × len(axes).

Return type:

StateExpr

Parameters:

root (BaseNode)