State function design

What is a state function?

A state function in SFI is any object that maps a dynamical state (positions, optional velocities, internal coordinates, …) plus optional extras (masks, neighbor lists, experimental tags) to tensor-valued outputs with a well-defined rank and feature axis.

Formally, a state function has a static contract, implemented by a common mixin _ContractMixin:

  • rank – tensor rank as a Rank enum (SCALAR, VECTOR, MATRIX, …);

  • dim – number of spatial coordinates per particle (or None to remain agnostic);

  • needs_v – whether the function requires velocity input;

  • pdepth – number of particle axes in the output (0, 1 or 2);

  • n_features – number of feature channels on the last axis;

  • param_suite – optional ParamSuite describing parameters, or None for deterministic nodes;

  • extras_required – a tuple of extras keys that must be present;

  • particles_input – whether the input x is expected to carry a particle axis.
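As a rough illustration, the fields listed above could be gathered in a frozen dataclass. This is only a sketch: the class name Contract, the integer values of the Rank enum, and the validation shown are assumptions made for the example, not the actual _ContractMixin implementation.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple


class Rank(Enum):
    # Assumed convention: the value is the number of spatial indices.
    SCALAR = 0
    VECTOR = 1
    MATRIX = 2


@dataclass(frozen=True)
class Contract:
    """Hypothetical stand-in for the static contract of a state function."""
    rank: Rank
    dim: Optional[int]                      # None means dimension-agnostic
    needs_v: bool = False
    pdepth: int = 0                         # particle axes in the output
    n_features: int = 1
    param_suite: Optional[object] = None    # None for deterministic nodes
    extras_required: Tuple[str, ...] = ()
    particles_input: bool = False

    def __post_init__(self):
        if self.pdepth not in (0, 1, 2):
            raise ValueError("pdepth must be 0, 1 or 2")
        if self.n_features < 1:
            raise ValueError("n_features must be positive")


# A 2D vector-valued node with three features and one particle axis.
node = Contract(rank=Rank.VECTOR, dim=2, pdepth=1, n_features=3,
                particles_input=True)
```

Making the object immutable mirrors the idea that the contract is static: it is declared once and only checked afterwards.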

Velocities are optional throughout: for overdamped systems, state functions typically depend only on positions; for underdamped systems, objects with needs_v=True additionally take velocities and can expose derivatives with respect to both.

The contract is enforced at node level by two main guards:

  • _assert_inputs checks the shape and consistency of x, v, mask and extras;

  • _assert_outputs checks that the output y matches the declared feature, rank, and particle-depth structure.

Masking and missing data

Inference in SFI must deal with complex patterns of missing data:

  • particles appear and disappear over time;

  • different coordinates of the same particle may be masked;

  • estimators often evaluate derivatives at shifted time points.

The contract enforces a simple rule for masks:

  • a mask, if provided, must broadcast to the prefix of x excluding the spatial dimension, i.e. to x.shape[:-1];

  • with particles_input=True, this prefix is (…, N_particles) and the mask has the natural shape “batch · particles”.
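The broadcasting rule can be checked with plain NumPy. The helper below is a hypothetical sketch of such a check (it is not the library's _assert_inputs); it only tests whether a mask broadcasts to x.shape[:-1].

```python
import numpy as np


def mask_is_compatible(x, mask):
    """Return True if `mask` broadcasts to x.shape[:-1] (hypothetical helper)."""
    target = x.shape[:-1]
    try:
        return np.broadcast_shapes(np.shape(mask), target) == target
    except ValueError:
        return False


x = np.zeros((5, 4, 2))                      # (batch, N_particles, dim)
full_mask = np.ones((5, 4), dtype=bool)      # one flag per (batch, particle)
per_particle = np.ones((4,), dtype=bool)     # broadcasts along the batch axis
with_dim = np.ones((5, 4, 2), dtype=bool)    # carries the spatial axis

ok_full = mask_is_compatible(x, full_mask)
ok_particle = mask_is_compatible(x, per_particle)
ok_with_dim = mask_is_compatible(x, with_dim)
```

Under this rule the first two masks are accepted, while the mask that includes the spatial axis is rejected.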

Masks can be boolean or numeric; broadcasting follows NumPy rules. In typical use, masks are constructed automatically from missing data in TrajectoryCollection (trajectory module). User-defined single-sample functions almost never manipulate masks explicitly:

  • factories accept a keyword-only argument mask=None, but most user functions do not need to mention it;

  • the surrounding SFI code constructs and applies masks as needed.

Differentiation and custom rules

Inference modules need derivatives with respect to:

  • positions and velocities (for drift and diffusion estimators);

  • parameters (for likelihood-based inference and gradient-based optimisation).

Naive autodiff through raw code is not enough: derivatives must interact correctly with masking, particle indexing, and multi-particle interactions. The StateExpr graph therefore carries explicit derivative nodes, and the API provides consistent methods such as d_x(), d_v() and d_theta() on PSF and SF. Custom differentiation rules for multi-particle systems are implemented at node level while still leveraging JAX under the hood.
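One reason masking and autodiff need care can be shown with plain JAX. In the sketch below (the functions radial_energy and masked_total are hypothetical and unrelated to SFI's node-level rules), masked positions are replaced by a safe dummy value before evaluation, so that invalid entries such as NaNs or singular points cannot leak into the gradient.

```python
import jax
import jax.numpy as jnp


def radial_energy(x):
    """Single-particle scalar: distance to the origin (singular gradient at 0)."""
    return jnp.sqrt(jnp.sum(x ** 2))


def masked_total(x, mask):
    """Sum of per-particle energies over valid particles only."""
    # Substitute a harmless position for masked particles *before* evaluating,
    # then zero their contribution afterwards.
    x_safe = jnp.where(mask[:, None], x, 1.0)
    per_particle = jax.vmap(radial_energy)(x_safe)
    return jnp.sum(jnp.where(mask, per_particle, 0.0))


x = jnp.array([[3.0, 4.0], [0.0, 0.0], [jnp.nan, 1.0]])
mask = jnp.array([True, False, False])

# Gradient with respect to positions: finite everywhere, zero where masked.
grad_x = jax.grad(masked_total)(x, mask)
```

For the valid particle the gradient is the unit vector x / |x|; the two masked rows receive exactly zero instead of NaN.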

Modularity and algebra of operators

Realistic models are built from reusable bricks: monomials, radial functions, local tensors, external fields, etc. The state function abstraction provides an algebra of operators on StateExpr:

  • concatenation of feature dictionaries;

  • products with scalar fields;

  • Einstein summations (einsum) to contract features with unit vectors, projectors, and tensors;

  • feature slicing and reshaping;

  • mapping over extra axes (e.g. species or feature groups).

The static contract for composite nodes is obtained by merging children via merge_contract with mode-specific rules:

  • mode="concat" – concatenate features; requires equal rank, sums n_features;

  • mode="map" – elementwise maps; requires equal rank and n_features;

  • mode="einsum" – spatial contraction; the Einstein spec determines the output rank and the features are the Cartesian product of the inputs.

Particle depth and particles_input are merged using a strict broadcast rule (either identical, or scalars combined with particle-wise nodes).
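The mode-specific rules can be mimicked on a reduced contract that tracks only rank and feature count. The sketch below is illustrative: MiniContract and merge_sketch are hypothetical names, the output rank of an einsum is passed in by hand rather than parsed from a spec, and particle-depth merging is left out.

```python
from typing import NamedTuple


class MiniContract(NamedTuple):
    """Hypothetical, reduced contract: tensor rank and feature count only."""
    rank: int
    n_features: int


def merge_sketch(a, b, mode, out_rank=None):
    """Illustrative merge rules; not the library's merge_contract."""
    if mode == "concat":
        if a.rank != b.rank:
            raise ValueError("concat requires equal rank")
        return MiniContract(a.rank, a.n_features + b.n_features)
    if mode == "map":
        if a != b:
            raise ValueError("map requires equal rank and n_features")
        return a
    if mode == "einsum":
        if out_rank is None:
            raise ValueError("einsum needs the output rank from its spec")
        # Features combine as a Cartesian product of the input features.
        return MiniContract(out_rank, a.n_features * b.n_features)
    raise ValueError(f"unknown mode: {mode!r}")


scalars = MiniContract(rank=0, n_features=4)    # e.g. four radial kernels
unit_vec = MiniContract(rank=1, n_features=1)   # e.g. a single unit vector

bigger_basis = merge_sketch(scalars, scalars, "concat")
vector_basis = merge_sketch(scalars, unit_vec, "einsum", out_rank=1)
```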

Because each operator preserves the static contract, complex bases can be built safely by combining simpler ones. In particular, it is natural to:

  • start from scalar features such as monomials or radial kernels;

  • contract them with unit vectors to obtain vector forces;

  • assemble higher-rank tensors (diffusion matrices, projectors, stress tensors) with prescribed symmetry.
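The three steps above can be sketched with NumPy's einsum on raw arrays. This does not use the StateExpr operators; the array layout (samples, spatial axes, features last) and the exponential radial kernels are assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
r = rng.normal(size=(10, 3))                        # 10 samples in 3D
dist = np.linalg.norm(r, axis=-1, keepdims=True)    # shape (10, 1)
unit = r / dist                                     # unit vectors, shape (10, 3)

# Four scalar radial features, feature axis last: shape (10, 4).
length_scales = np.array([0.5, 1.0, 2.0, 4.0])
radial = np.exp(-dist / length_scales)

# Combine each scalar feature with the unit vector: one radial vector field
# per feature, shape (samples, dim, features).
vector_features = np.einsum("sf,sd->sdf", radial, unit)

# Rank-2 features with a prescribed symmetric, longitudinal structure,
# shape (samples, dim, dim, features).
tensor_features = np.einsum("sf,sa,sb->sabf", radial, unit, unit)
```

The feature count is unchanged by these products because the unit vector carries a single feature; only the rank grows.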

Identity, naming and model selection

Inference modules need to keep track of which feature is which:

  • coefficients must be associated with meaningful labels;

  • model selection routines need stable feature identities for pruning;

  • pretty-printing and diagnostics require readable names.

Each state function carries labels and descriptors for its feature axis, and these are preserved under composition (with appropriate concatenation or prefixing rules). This traceability is what allows higher-level code to talk about “turning off a specific kernel” or “selecting the best radial terms” without losing track of which contribution is which.

Extras: pass-through pytrees

Many models depend on extras that are not part of the core state:

  • boundary conditions (e.g. periodic box size);

  • neighbor or adjacency lists;

  • experiment-specific parameters or tags;

  • particle identities and types.

State functions declare which extras keys they require, via extras_required. At runtime:

  • _assert_inputs guarantees that required keys are present if any extras mapping is provided;

  • extras themselves can be any JAX-compatible pytree (arrays, nested dicts, tuples, …) and are passed through to leaves and dispatchers without shape validation at this level.
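A minimal sketch of how a required key might be checked and then consumed by a local kernel is given below. Both helpers are hypothetical, the key name "box" is an assumption, and the behaviour when no extras mapping is passed is a guess based on the wording above.

```python
import numpy as np


def check_extras(extras, required):
    """Raise if a provided extras mapping lacks a required key (sketch)."""
    if extras is None:
        return  # assumption: nothing to check when no mapping is passed
    missing = [key for key in required if key not in extras]
    if missing:
        raise KeyError(f"missing extras keys: {missing}")


def minimum_image(xi, xj, extras):
    """Displacement xj - xi under periodic boundaries of size extras['box']."""
    box = np.asarray(extras["box"])
    d = xj - xi
    return d - box * np.round(d / box)


extras = {"box": np.array([10.0, 10.0]), "species": np.array([0, 1])}
check_extras(extras, required=("box",))

# Particles at x = 1 and x = 9 in a box of size 10 are 2 apart, not 8.
d = minimum_image(np.array([1.0, 1.0]), np.array([9.0, 2.0]), extras)
```

Only the presence of keys is checked; the content of each entry is left to the function that uses it.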

This is essential for interacting-particle and spatial models, where local kernels need access to box sizes, neighbor lists, or other contextual data.

Named parameters and parametric trees

Parametric state functions (PSF) organise parameters into named blocks described by ParamSpec and grouped into a ParamSuite:

  • each ParamSpec has a name, shape, dtype, and an initialisation rule;

  • ParamSuite is an immutable container of specs; it can map between flat vectors and parameter trees and can coerce user-provided parameter dicts to the expected shapes and dtypes.
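The mapping between flat vectors and parameter trees can be pictured as follows. This is a sketch on plain dictionaries and NumPy arrays; the names shapes, tree_to_flat and flat_to_tree are hypothetical and do not reproduce the ParamSuite interface.

```python
import math
import numpy as np

# Hypothetical specs: parameter name -> shape (insertion order fixes the layout).
shapes = {"stiffness": (), "weights": (2, 3)}


def tree_to_flat(tree):
    """Concatenate all parameter blocks into one flat vector."""
    return np.concatenate(
        [np.ravel(np.asarray(tree[name], dtype=float)) for name in shapes]
    )


def flat_to_tree(flat):
    """Split a flat vector back into named blocks with the declared shapes."""
    tree, start = {}, 0
    for name, shape in shapes.items():
        size = math.prod(shape)          # equals 1 for a scalar shape ()
        tree[name] = flat[start:start + size].reshape(shape)
        start += size
    return tree


params = {"stiffness": 1.5, "weights": np.arange(6.0).reshape(2, 3)}
flat = tree_to_flat(params)              # length 1 + 6 = 7
restored = flat_to_tree(flat)
```

A flat view of this kind is what generic optimisers typically consume, while the named tree is what model code reads.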

Crucially, parameter names encode sharing:

  • merging parameter suites (e.g. when composing PSFs) uses name-based union;

  • if a name appears in multiple suites, the corresponding specs must be compatible (same shape and dtype) and are merged into a single shared parameter; otherwise an error is raised.

Thus, reusing a parameter name in different parts of a PSF implies that the parameter is shared and must have the same array structure. This is how weight tying and shared kernels are expressed in the architecture.
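The name-based union can be sketched on plain dictionaries. Spec and merge_specs below are hypothetical stand-ins for ParamSpec and the suite-merging logic, reduced to shape and dtype; the parameter names are invented for the example.

```python
from typing import Dict, NamedTuple, Tuple


class Spec(NamedTuple):
    """Hypothetical, reduced parameter spec: shape and dtype only."""
    shape: Tuple[int, ...]
    dtype: str


def merge_specs(a: Dict[str, Spec], b: Dict[str, Spec]) -> Dict[str, Spec]:
    """Name-based union; a repeated name must describe the same array."""
    merged = dict(a)
    for name, spec in b.items():
        if name in merged and merged[name] != spec:
            raise ValueError(f"incompatible specs for shared parameter {name!r}")
        merged[name] = spec
    return merged


attraction = {"sigma": Spec((), "float32"), "k_att": Spec((), "float32")}
repulsion = {"sigma": Spec((), "float32"), "k_rep": Spec((), "float32")}

# "sigma" appears in both terms and becomes a single shared parameter.
suite = merge_specs(attraction, repulsion)
```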

Factory entry points: make_basis and make_psf

The factories make_basis() and make_psf() are the primary user-facing entry points. They take single-sample functions written in natural Python / jax.numpy and turn them into full Basis or PSF objects with explicit metadata.

Alongside the function, both factories take the declared metadata (dim, rank, n_features, labels, parameter specs); worked examples live in Models and state functions. In both cases:

  • the function is written for a single sample; SFI handles batching over time and particles;

  • mask is an optional keyword that most users can ignore;

  • metadata (dimension, rank, particle depth, feature labels, extras keys, parameter spec) is declared once at construction and then enforced.
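The single-sample convention can be illustrated with plain JAX, without reproducing the factory signatures (which are not shown here). In this sketch the function monomials is a hypothetical user function, and the nested jax.vmap stands in for the batching that SFI performs on the user's behalf.

```python
import jax
import jax.numpy as jnp


def monomials(x):
    """Single-sample features for one 2D position: [1, x0, x1, x0 * x1]."""
    return jnp.stack([jnp.ones_like(x[0]), x[0], x[1], x[0] * x[1]])


# Batch over time, then over particles, without touching the function body.
batched = jax.vmap(jax.vmap(monomials))

x = jnp.ones((100, 5, 2))        # (time, particles, dim)
features = batched(x)            # (time, particles, n_features)
```

The function body never mentions a batch or particle axis; the declared metadata (here dim = 2 and n_features = 4) is what allows the output shape to be checked afterwards.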

Single-feature PSFs are common and natural: they represent a single parametric field or force. Bases, in contrast, are typically multi-feature dictionaries used for linear combinations; if you conceptually have a single feature with a coefficient, a PSF is usually a better fit than a single-feature Basis.

Linear and nonlinear inference: Basis vs PSF

The distinction between Basis and PSF reflects the two main inference regimes in SFI:

  • Linear inference uses Basis objects. Given a basis matrix and data, the inference module solves the linear regression problem exactly (or via sparse / regularised variants) to obtain optimal coefficients.

  • Nonlinear inference uses PSF objects. Parameters are updated through nonlinear optimisers (gradient descent, quasi-Newton, etc.), using derivatives with respect to parameters exposed by the PSF.

Both are important:

  • linear inference is fast, robust, and often sufficient when a rich basis is available;

  • nonlinear PSFs are needed for models with genuinely nonlinear parameter dependencies or shared parameters across many features.
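The linear regime reduces to a least-squares problem once the basis has been evaluated on the data. The sketch below uses NumPy on noise-free synthetic data; it is not the SFI inference module, and the basis (1, x, x^3) and the coefficients are invented for the example.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.uniform(-1.0, 1.0, size=200)

# Basis matrix: one column per feature (1, x, x^3), shape (samples, features).
B = np.stack([np.ones_like(x), x, x ** 3], axis=-1)

true_coeffs = np.array([0.5, -1.0, 2.0])
y = B @ true_coeffs                  # noise-free synthetic target

# Exact linear solve for the coefficients of the basis expansion.
coeffs, *_ = np.linalg.lstsq(B, y, rcond=None)
```

No gradients with respect to parameters are needed here, which is the practical difference from the nonlinear regime.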

Once parameters are inferred, a PSF can be frozen into an SF for use in simulation and downstream analysis.

Interacting particle systems (optional)

For systems of many interacting particles (or lattice sites), SFI takes a “local rule + dispatcher” approach:

  1. A local Interactor describes how a small group of particles (typically a pair) contributes to the quantity of interest (force, torque, current, …). It operates on a local configuration Xk with shape (K, dim) and optional extras and parameters.

  2. Calling Interactor.dispatch(...) turns this local rule into a global state function (Basis/PSF/SF), specifying:

     • how to select and combine the K-tuples (all pairs, finite range, given neighbor list, …);

     • how to reduce them (sum, average, owner conventions).
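The two steps can be mimicked in a few lines of NumPy. The sketch below does not use Interactor or its dispatch method; pair_force and dispatch_all_pairs are hypothetical, the selection is "all ordered pairs", and the reduction is a sum attributed to the first particle of each pair (one possible owner convention).

```python
import numpy as np


def pair_force(Xk):
    """Local rule on K = 2 particles: linear attraction felt by particle 0."""
    return Xk[1] - Xk[0]


def dispatch_all_pairs(local_rule, X):
    """Sum the local rule over all ordered pairs (i, j), i != j, per owner i."""
    n = X.shape[0]
    out = np.zeros_like(X)
    for i in range(n):
        for j in range(n):
            if i != j:
                out[i] += local_rule(np.stack([X[i], X[j]]))
    return out


X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])   # (N, dim)
F = dispatch_all_pairs(pair_force, X)                # (N, dim)
```

The explicit double loop is only for readability; it makes the split between the local rule (what a pair contributes) and the dispatcher (which pairs are visited and how they are reduced) easy to see.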

The underlying dispatcher and interaction-node machinery lives in backend modules; Interactor and its .dispatch method are the main front-end objects. Conceptually, this design is close to graph neural networks: the interactor plays the role of a message function on edges, and the dispatcher specifies the graph and aggregation on nodes.

Users who do not work with interacting-particle or spatial systems can skip this part; the rest of statefunc is sufficient for single-particle or well-mixed models.

Particles input vs particle depth

Two related but distinct concepts control how particle axes are handled:

  • particles_input indicates whether the state function expects a particle axis in its input. Many single-particle bricks operate on a single state vector and are vmapped over particles by the surrounding code, while interactors expect a fixed number K of particles as input.

  • pdepth counts how many particle axes the output carries (0, 1 or 2). A single-particle force typically has pdepth = 1; a purely local scalar kernel might have pdepth = 0; a pairwise matrix could have pdepth = 2.

The contract for outputs enforces:

  • if particles_input=True, an input x with prefix batch · P_in · dim must produce outputs whose prefix is batch · P_in^pdepth (i.e. one particle axis per depth level);

  • if particles_input=False, a node with pdepth=0 preserves the batch prefix of x, and other combinations are rejected at runtime.
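Read literally, the two rules above determine the leading shape of the output from the input shape. The helper below is a hypothetical sketch of that bookkeeping (rank and feature axes are left out); it is not the library's _assert_outputs.

```python
def expected_output_prefix(x_shape, particles_input, pdepth):
    """Expected leading shape of the output, before rank and feature axes."""
    if particles_input:
        # x has shape batch + (P_in, dim): one particle axis per depth level.
        *batch, p_in, _dim = x_shape
        return tuple(batch) + (p_in,) * pdepth
    if pdepth != 0:
        raise ValueError("particles_input=False requires pdepth=0")
    # x has shape batch + (dim,): the batch prefix is preserved.
    return tuple(x_shape[:-1])


per_particle = expected_output_prefix((100, 5, 2), True, 1)   # (100, 5)
pairwise = expected_output_prefix((100, 5, 2), True, 2)       # (100, 5, 5)
single = expected_output_prefix((100, 2), False, 0)           # (100,)
```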

This ensures that the same operators (concatenation, derivatives, etc.) work uniformly across single-particle functions, multi-particle interactions, and spatial fields.

Symmetries and tensor structure

A key motivation for this architecture is precise control over symmetries:

  • scalar features can encode radial dependence, invariants, or other symmetry constraints;

  • unit vectors and projectors can be composed into tensors with prescribed symmetries (e.g. longitudinal vs transverse components, radial vs tangential directions);

  • dimension-dependent behaviour is explicit: basis functions can differ between components, not just be copies of the same scalar applied to each dimension.
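As a small worked example of the projector construction, the NumPy sketch below builds longitudinal and transverse projectors from a unit vector and combines them into an anisotropic symmetric tensor. The coefficients 2.0 and 0.5 are arbitrary illustration values.

```python
import numpy as np

v = np.array([3.0, 4.0, 0.0])
u = v / np.linalg.norm(v)           # unit vector along v

P_long = np.outer(u, u)             # projector onto the direction of u
P_trans = np.eye(3) - P_long        # projector onto the plane normal to u

# Symmetric tensor with distinct longitudinal and transverse coefficients,
# e.g. a direction-dependent diffusion matrix.
D = 2.0 * P_long + 0.5 * P_trans
```

Because the two projectors are orthogonal and sum to the identity, D acts as multiplication by 2.0 along u and by 0.5 on anything perpendicular to it.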

This contrasts with earlier SFI architectures, where basis functions were often implicitly “the same for all dimensions”. In the new design, the structure of forces and diffusion tensors is expressed directly at the level of state functions, making symmetries transparent and easier to enforce.

Feature axis and drop_features

All state functions place the feature axis last. Some operations can “drop” features and return a featureless tensor (for example, turning a single-feature PSF into an SF that represents a single field rather than a dictionary).

To avoid ambiguity:

  • operations that would need to drop more than one feature (n_features > 1) are forbidden and raise an error;

  • single-feature PSFs are the intended target for such operations;

  • Bases are typically multi-feature; users are expected to select individual features explicitly before applying operations that drop the feature axis.
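On raw arrays, the rule amounts to refusing to squeeze a feature axis of size greater than one. The helper below is a hypothetical sketch, not the library's drop_features; the array layout (batch, dim, features) is an assumption.

```python
import numpy as np


def drop_feature_axis(y):
    """Remove the trailing feature axis, allowed only when it has size 1."""
    if y.shape[-1] != 1:
        raise ValueError(
            f"cannot drop a feature axis of size {y.shape[-1]}; "
            "select one feature first"
        )
    return y[..., 0]


basis_out = np.zeros((100, 2, 3))      # (batch, dim, n_features = 3)
selected = basis_out[..., 1:2]         # explicit selection keeps the axis
field = drop_feature_axis(selected)    # shape (100, 2)
```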

Backend architecture: leaves, nodes, operations

Internally, statefunc organises computations as a tree of nodes:

  • leaf nodes implement primitive bricks (single-sample functions, interactors, parameter blocks);

  • operation nodes implement composition (concat, map, einsum, derivatives, slicing, interaction dispatch, …);

  • every node inherits the same static contract and participates in memory accounting.

This node-level backend (including the interaction dispatchers) is considered an implementation detail. The front-end abstractions (Basis, PSF, SF, Interactor, and the factory functions) are the stable public surface.

Performance and memory-aware evaluation

Each node also contributes to a global estimate of memory footprint. This allows the integration layer (SFI.integrate) to:

  • plan vectorised evaluations over time and particles;

  • adaptively chunk computations to stay within a target memory budget.
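The chunking idea can be sketched in NumPy as follows. This is not how SFI.integrate is implemented; evaluate_in_chunks is a hypothetical helper, and the per-sample memory estimate is supplied by hand rather than accumulated from nodes.

```python
import numpy as np


def evaluate_in_chunks(fn, x, bytes_per_sample, budget_bytes):
    """Apply `fn` along axis 0 in chunks sized to fit a memory budget."""
    chunk = max(1, budget_bytes // bytes_per_sample)
    pieces = [fn(x[start:start + chunk])
              for start in range(0, x.shape[0], chunk)]
    return np.concatenate(pieces, axis=0)


def features(x):
    """Toy batched feature function: (T, 2) -> (T, 3)."""
    return np.stack([x[:, 0], x[:, 1], x[:, 0] * x[:, 1]], axis=-1)


x = np.arange(20.0).reshape(10, 2)

# Three float64 outputs per sample take 24 bytes, so a 100-byte budget
# gives chunks of 4 samples (the last chunk holds the remaining 2).
chunked = evaluate_in_chunks(features, x, bytes_per_sample=24, budget_bytes=100)
```

The result is identical to a single vectorised call; only the peak size of intermediate arrays changes.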

This memory-aware design is what lets large systems (many particles, long trajectories, rich bases) be evaluated and differentiated efficiently without manual chunking logic in user code.

See also

The Gallery and the Running inference section showcase the evaluation engine on concrete inference problems.