State function design
What is a state function?
A state function in SFI is any object that maps a dynamical state (positions, optional velocities, internal coordinates, …) plus optional extras (masks, neighbor lists, experimental tags) to tensor-valued outputs with a well-defined rank and feature axis.
Formally, a state function has a static contract, implemented by a common
mixin _ContractMixin:
rank – tensor rank as a Rank enum (SCALAR, VECTOR, MATRIX, …);
dim – number of spatial coordinates per particle (or None to remain agnostic);
needs_v – whether the function requires velocity input;
pdepth – number of particle axes in the output (0, 1 or 2);
n_features – number of feature channels on the last axis;
param_suite – optional ParamSuite describing parameters, or None for deterministic nodes;
extras_required – a tuple of extras keys that must be present;
particles_input – whether the input x is expected to carry a particle axis.
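As a rough illustration, the fields listed above could be gathered in a frozen dataclass. This is a standalone sketch and not SFI's _ContractMixin: the class name ContractSketch, the enum values, and the defaults are assumptions made for the example.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple


class Rank(Enum):
    # Illustrative values only; the library's enum may differ.
    SCALAR = 0
    VECTOR = 1
    MATRIX = 2


@dataclass(frozen=True)
class ContractSketch:
    """Hypothetical stand-in for the static contract described above."""
    rank: Rank
    dim: Optional[int]                      # None = dimension-agnostic
    needs_v: bool = False
    pdepth: int = 0
    n_features: int = 1
    param_suite: Optional[object] = None    # None = deterministic node
    extras_required: Tuple[str, ...] = ()
    particles_input: bool = False

    def __post_init__(self):
        # The text allows 0, 1 or 2 particle axes in the output.
        if self.pdepth not in (0, 1, 2):
            raise ValueError("pdepth must be 0, 1 or 2")


# A 2D per-particle force with three feature channels.
force_contract = ContractSketch(rank=Rank.VECTOR, dim=2, pdepth=1,
                                n_features=3, particles_input=True)
```

Freezing the dataclass mirrors the "static" nature of the contract: once declared, the metadata cannot be mutated by downstream code.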
Velocities are optional throughout: for overdamped systems, state functions
typically depend only on positions; for underdamped systems, objects with
needs_v=True additionally take velocities and can expose derivatives with
respect to both.
The contract is enforced at node level by two main guards:
_assert_inputs checks the shape and consistency of x, v, mask and extras;
_assert_outputs checks that the output y matches the declared feature, rank, and particle-depth structure.
Masking and missing data
Inference in SFI must deal with complex patterns of missing data:
particles appear and disappear over time;
different coordinates of the same particle may be masked;
estimators often evaluate derivatives at shifted time points.
The contract enforces a simple rule for masks:
a mask, if provided, must broadcast to the prefix of x excluding the spatial dimension, i.e. to x.shape[:-1];
with particles_input=True, this prefix is (…, N_particles) and the mask has the natural shape “batch · particles”.
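The broadcasting rule above can be checked with plain NumPy. The helper name mask_is_compatible is hypothetical and only illustrates the rule; it is not the check performed by _assert_inputs.

```python
import numpy as np


def mask_is_compatible(mask, x):
    """True if `mask` broadcasts to x.shape[:-1] without enlarging it."""
    prefix = x.shape[:-1]
    try:
        return np.broadcast_shapes(np.shape(mask), prefix) == prefix
    except ValueError:
        # Shapes that cannot be broadcast together at all.
        return False


x = np.zeros((5, 4, 2))                         # (time, particles, dim)
per_particle = np.ones((5, 4), dtype=bool)      # one flag per particle and frame
per_frame = np.ones((5, 1), dtype=bool)         # one flag per frame, broadcast over particles
full_shape = np.ones((5, 4, 2), dtype=bool)     # includes the spatial axis: rejected
```

Here per_particle and per_frame pass the check, while full_shape fails because it carries the spatial axis that the rule excludes.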
Masks can be boolean or numeric; broadcasting follows NumPy rules. In typical
use, masks are constructed automatically from missing data in
TrajectoryCollection (trajectory module). User-defined single-sample
functions almost never manipulate masks explicitly:
factories accept a keyword-only argument mask=None, but most user functions do not need to mention it;
the surrounding SFI code constructs and applies masks as needed.
Differentiation and custom rules
Inference modules need derivatives with respect to:
positions and velocities (for drift and diffusion estimators);
parameters (for likelihood-based inference and gradient-based optimisation).
Naive autodiff through raw code is not enough: derivatives must interact correctly with
masking, particle indexing, and multi-particle interactions. The
StateExpr graph carries explicit derivative nodes, and the API
provides consistent methods such as d_x(), d_v() and d_theta() on
PSF and SF. Custom differentiation rules for multi-particle
systems are implemented at node level while still leveraging JAX under the
hood.
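To make the kinds of derivatives concrete, here is a minimal sketch in plain JAX, outside of SFI. The function harmonic_force and the local names d_x and d_theta are assumptions that merely echo the method names above; the actual node-level rules are more involved.

```python
import jax
import jax.numpy as jnp


def harmonic_force(x, k):
    """Single-sample force -k * x for a position x of shape (dim,)."""
    return -k * x


# Jacobian with respect to the position (argument 0), batched over samples.
d_x = jax.vmap(jax.jacfwd(harmonic_force, argnums=0), in_axes=(0, None))
# Derivative with respect to the parameter (argument 1), batched over samples.
d_theta = jax.vmap(jax.jacfwd(harmonic_force, argnums=1), in_axes=(0, None))

x = jnp.array([[1.0, 2.0], [0.5, -1.0]])   # (batch, dim)
mask = jnp.array([True, False])            # second sample is missing
k = jnp.asarray(2.0)

jac_x = d_x(x, k)                          # shape (batch, dim, dim)
jac_k = d_theta(x, k)                      # shape (batch, dim)
# Zero out derivatives at masked samples so they do not enter later sums.
jac_x = jnp.where(mask[:, None, None], jac_x, 0.0)
```

The explicit masking step at the end is the part that raw autodiff does not provide on its own.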
Modularity and algebra of operators
Realistic models are built from reusable bricks: monomials, radial functions,
local tensors, external fields, etc. The state function abstraction provides
an algebra of operators on StateExpr:
concatenation of feature dictionaries;
products with scalar fields;
Einstein summations (einsum) to contract features with unit vectors, projectors, and tensors;
feature slicing and reshaping;
mapping over extra axes (e.g. species or feature groups).
The static contract for composite nodes is obtained by merging children via
merge_contract with mode-specific rules:
mode="concat" – concatenate features; requires equal rank, sums n_features;
mode="map" – elementwise maps; requires equal rank and n_features;
mode="einsum" – spatial contraction; the Einstein spec determines the output rank and the features are the Cartesian product of the inputs.
Particle depth and particles_input are merged using a strict broadcast rule (either identical, or scalars combined with particle-wise nodes).
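The three modes can be sketched on a reduced contract that tracks only rank and feature count. MiniContract and merge_sketch are hypothetical names; the real merge_contract also handles particle depth, extras, and parameters, and derives the einsum output rank from the spec rather than taking it as an argument.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MiniContract:
    rank: int          # 0 = scalar, 1 = vector, 2 = matrix
    n_features: int


def merge_sketch(a, b, mode, out_rank=None):
    """Hypothetical merge of two contracts following the three modes above."""
    if mode == "concat":
        if a.rank != b.rank:
            raise ValueError("concat requires equal rank")
        return MiniContract(a.rank, a.n_features + b.n_features)
    if mode == "map":
        if (a.rank, a.n_features) != (b.rank, b.n_features):
            raise ValueError("map requires equal rank and n_features")
        return MiniContract(a.rank, a.n_features)
    if mode == "einsum":
        # The output rank comes from the Einstein spec; here it is passed in.
        if out_rank is None:
            raise ValueError("einsum needs the output rank from its spec")
        return MiniContract(out_rank, a.n_features * b.n_features)
    raise ValueError(f"unknown mode: {mode!r}")


scalars = MiniContract(rank=0, n_features=4)
directions = MiniContract(rank=1, n_features=2)
# Every scalar feature is paired with every direction feature.
vectors = merge_sketch(scalars, directions, "einsum", out_rank=1)
```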
Because each operator preserves the static contract, complex bases can be built safely by combining simpler ones. In particular, it is natural to:
start from scalar features such as monomials or radial kernels;
contract them with unit vectors to obtain vector forces;
assemble higher-rank tensors (diffusion matrices, projectors, stress tensors) with prescribed symmetry.
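The second step, contracting scalar features with unit vectors, reduces to a single einsum. The sketch below uses plain NumPy with the feature axis last, as in the text; the array names and the choice of radial monomials are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
r = rng.normal(size=(6, 3))                        # (samples, dim) displacement vectors
dist = np.linalg.norm(r, axis=-1, keepdims=True)   # (samples, 1)
unit = r / dist                                    # (samples, dim) unit vectors

# Scalar radial features 1, |r|, |r|^2 with the feature axis last.
powers = np.array([0, 1, 2])
radial = dist ** powers                            # (samples, n_features)

# Contract each scalar feature with the unit vector: (samples, dim, n_features).
vector_features = np.einsum("sf,sd->sdf", radial, unit)
```

Each slice vector_features[..., f] is a radial vector field, so any linear combination of these features is a central force by construction.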
Identity, naming and model selection
Inference modules need to keep track of which feature is which:
coefficients must be associated with meaningful labels;
model selection routines need stable feature identities for pruning;
pretty-printing and diagnostics require readable names.
Each state function carries labels and descriptors for its feature axis, and these are preserved under composition (with appropriate concatenation or prefixing rules). This traceability is what allows higher-level code to talk about “turning off a specific kernel” or “selecting the best radial terms” without losing track of which contribution is which.
Extras: pass-through pytrees
Many models depend on extras that are not part of the core state:
boundary conditions (e.g. periodic box size);
neighbor or adjacency lists;
experiment-specific parameters or tags;
particle identities and types.
State functions declare which extras keys they require, via
extras_required. At runtime:
_assert_inputs guarantees that required keys are present if any extras mapping is provided;
extras themselves can be any JAX-compatible pytree (arrays, nested dicts, tuples, …) and are passed through to leaves and dispatchers without shape validation at this level.
This is essential for interacting-particle and spatial models, where local kernels need access to box sizes, neighbor lists, or other contextual data.
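A minimal sketch of the key check, followed by a kernel that consumes an extra, might look as follows. The helper names and the "box" and "species" keys are hypothetical and chosen only for illustration.

```python
import numpy as np


def check_extras(extras, required):
    """Raise if a provided extras mapping lacks a required key."""
    if extras is None:
        return
    missing = [key for key in required if key not in extras]
    if missing:
        raise KeyError(f"missing extras keys: {missing}")


def minimum_image(dx, extras):
    """Wrap a displacement into a periodic box; 'box' is a hypothetical key."""
    check_extras(extras, required=("box",))
    box = np.asarray(extras["box"])
    return dx - box * np.round(dx / box)


# Extras are passed through untouched; unused keys such as 'species' are ignored.
extras = {"box": [10.0, 10.0], "species": np.array([0, 1, 1])}
wrapped = minimum_image(np.array([9.0, -6.0]), extras)
```

Only the presence of keys is validated; the shape of extras["box"] is left to the kernel that uses it, in line with the pass-through behaviour described above.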
Named parameters and parametric trees
Parametric state functions (PSF) organise parameters into named
blocks described by ParamSpec and grouped into a ParamSuite:
each ParamSpec has a name, shape, dtype, and an initialisation rule;
ParamSuite is an immutable container of specs; it can map between flat vectors and parameter trees and can coerce user-provided parameter dicts to the expected shapes and dtypes.
Crucially, parameter names encode sharing:
merging parameter suites (e.g. when composing PSFs) uses name-based union;
if a name appears in multiple suites, the corresponding specs must be compatible (same shape and dtype) and are merged into a single shared parameter; otherwise an error is raised.
Thus, reusing a parameter name in different parts of a PSF implies that the parameter is shared and must have the same array structure. This is how weight tying and shared kernels are expressed in the architecture.
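The name-based union can be sketched with plain dictionaries. SpecSketch and merge_specs are hypothetical stand-ins for ParamSpec and the suite-merging logic, reduced to the compatibility rule stated above.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class SpecSketch:
    name: str
    shape: Tuple[int, ...]
    dtype: str = "float32"


def merge_specs(*suites: Dict[str, SpecSketch]) -> Dict[str, SpecSketch]:
    """Name-based union; a repeated name must describe the same array."""
    merged: Dict[str, SpecSketch] = {}
    for suite in suites:
        for name, spec in suite.items():
            known = merged.setdefault(name, spec)
            if (known.shape, known.dtype) != (spec.shape, spec.dtype):
                raise ValueError(f"incompatible specs for shared parameter {name!r}")
    return merged


attraction = {"sigma": SpecSketch("sigma", ()), "eps_a": SpecSketch("eps_a", ())}
repulsion = {"sigma": SpecSketch("sigma", ()), "eps_r": SpecSketch("eps_r", ())}
# 'sigma' appears once in the result: the two terms share it.
shared = merge_specs(attraction, repulsion)
```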
Factory entry points: make_basis and make_psf
The factories make_basis() and make_psf() are the primary
user-facing entry points. They take single-sample functions written in
natural Python / jax.numpy and turn them into full Basis or
PSF objects with explicit metadata.
The declared metadata includes dim, rank, n_features, labels, and
parameter specs; worked examples live in
Models and state functions. In both cases:
the function is written for a single sample; SFI handles batching over time and particles;
mask is an optional keyword that most users can ignore;
metadata (dimension, rank, particle depth, feature labels, extras keys, parameter spec) is declared once at construction and then enforced.
Single-feature PSFs are common and natural: they represent a single parametric field or force. Bases, in contrast, are typically multi-feature dictionaries used for linear combinations; if you conceptually have a single feature with a coefficient, a PSF is usually a better fit than a single-feature Basis.
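To show what "written for a single sample" means in practice, the sketch below defines a three-feature function on one 2D position and batches it by hand. The helper batch_over_leading_axes is a stand-in for the batching that the library performs; it does not call make_basis() or make_psf(), whose exact signatures are not reproduced here.

```python
import numpy as np


def monomials(x):
    """Single-sample features for a 2D position x of shape (2,): returns (3,)."""
    return np.array([x[0], x[1], x[0] * x[1]])


labels = ("x", "y", "x*y")    # one label per feature channel


def batch_over_leading_axes(fn, X):
    """Stand-in for library-side batching: apply fn over all but the last axis."""
    flat = X.reshape(-1, X.shape[-1])
    out = np.stack([fn(x) for x in flat])
    return out.reshape(X.shape[:-1] + out.shape[1:])


X = np.arange(24, dtype=float).reshape(4, 3, 2)    # (time, particles, dim)
features = batch_over_leading_axes(monomials, X)   # (time, particles, n_features)
```

The user-facing function never sees the time or particle axes, which is what keeps it short and readable.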
Linear and nonlinear inference: Basis vs PSF
The distinction between Basis and PSF reflects the two main
inference regimes in SFI:
Linear inference uses Basis objects. Given a basis matrix and data, the inference module solves the linear regression problem exactly (or via sparse / regularised variants) to obtain optimal coefficients.
Nonlinear inference uses PSF objects. Parameters are updated through nonlinear optimisers (gradient descent, quasi-Newton, etc.), using derivatives with respect to parameters exposed by the PSF.
Both are important:
linear inference is fast, robust, and often sufficient when a rich basis is available;
nonlinear PSFs are needed for models with genuinely nonlinear parameter dependencies or shared parameters across many features.
Once parameters are inferred, a PSF can be frozen into an SF for use
in simulation and downstream analysis.
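The contrast between the two regimes can be illustrated on synthetic, noiseless data with plain NumPy. This is a toy sketch, not SFI's inference code: the target functions, the step size, and the iteration count are arbitrary choices for the example.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.uniform(-2.0, 2.0, size=200)

# Linear regime: coefficients of a fixed basis {x, x^3}, solved by least squares.
target = 1.5 * x - 0.5 * x**3
B = np.stack([x, x**3], axis=-1)              # (samples, n_features)
coeffs, *_ = np.linalg.lstsq(B, target, rcond=None)

# Nonlinear regime: the length ell sits inside exp(-x^2 / ell^2),
# so it is fitted by gradient descent on the mean squared error.
data = np.exp(-x**2 / 1.3**2)
ell = 1.0
for _ in range(2000):
    model = np.exp(-x**2 / ell**2)
    # d(model)/d(ell) = model * 2 x^2 / ell^3
    grad = np.mean(2.0 * (model - data) * model * 2.0 * x**2 / ell**3)
    ell -= 0.5 * grad
```

The linear fit is a single solve, while the nonlinear fit needs an iterative optimiser and an explicit parameter derivative, which is the role d_theta() plays for a PSF.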
Interacting particle systems (optional)
For systems of many interacting particles (or lattice sites), SFI takes a “local rule + dispatcher” approach:
A local Interactor describes how a small group of particles (typically a pair) contributes to the quantity of interest (force, torque, current, …). It operates on a local configuration Xk with shape (K, dim) and optional extras and parameters.
Calling Interactor.dispatch(...) turns this local rule into a global state function (Basis / PSF / SF), specifying:
* how to select and combine the K-tuples (all pairs, finite range, given neighbor list, …);
* how to reduce them (sum, average, owner conventions).
The underlying dispatcher and interaction-node machinery lives in backend
modules; Interactor and its .dispatch method are the main
front-end objects. Conceptually, this design is close to graph neural
networks: the interactor plays the role of a message function on edges, and
the dispatcher specifies the graph and aggregation on nodes.
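The "local rule + dispatcher" split can be sketched in a few lines of NumPy. Both functions below are hypothetical: pair_force plays the role of the local rule on a pair (K = 2), and dispatch_all_pairs is a naive all-pairs, sum-reduced dispatcher with the first particle of each pair as owner. Neither reproduces the Interactor API.

```python
import numpy as np


def pair_force(Xk):
    """Local rule on a pair, Xk of shape (2, dim): linear spring force on particle 0."""
    return Xk[1] - Xk[0]


def dispatch_all_pairs(local_rule, X):
    """Hypothetical dispatcher: sum the local rule over all ordered pairs (i, j), i != j."""
    n = X.shape[0]
    out = np.zeros_like(X)
    for i in range(n):              # owner particle receives the contribution
        for j in range(n):
            if i != j:
                out[i] += local_rule(X[[i, j]])
    return out


X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])   # (N_particles, dim)
forces = dispatch_all_pairs(pair_force, X)           # (N_particles, dim)
```

Because the local rule is antisymmetric under exchange of the two particles, the dispatched forces sum to zero, a property that follows from the rule alone and not from the dispatcher.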
Users who do not work with interacting-particle or spatial systems can skip
this part; the rest of statefunc is sufficient for single-particle or
well-mixed models.
Particles input vs particle depth
Two related but distinct concepts control how particle axes are handled:
particles_input indicates whether the state function expects a particle axis in its input. Many single-particle bricks operate on a single state vector and are vmapped over particles by the surrounding code, while interactors expect a fixed number K of particles as input.
pdepth counts how many particle axes the output carries (0, 1 or 2). A single-particle force typically has pdepth = 1; a purely local scalar kernel might have pdepth = 0; a pairwise matrix could have pdepth = 2.
The contract for outputs enforces:
if particles_input=True, an input x with prefix batch · P_in · dim must produce outputs whose prefix is batch · P_in^pdepth (i.e. one particle axis per depth level);
if particles_input=False, a node with pdepth=0 preserves the batch prefix of x, and other combinations are rejected at runtime.
This ensures that the same operators (concatenation, derivatives, etc.) work uniformly across single-particle functions, multi-particle interactions, and spatial fields.
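The output rule can be written as a small shape function. The name expected_output_prefix is hypothetical, and the sketch covers only the leading axes of the output, not the rank or feature axes that follow them.

```python
def expected_output_prefix(x_shape, particles_input, pdepth):
    """Leading axes of the output implied by the input shape and the contract."""
    if particles_input:
        # Input is batch · P_in · dim: keep the batch, repeat P_in once per depth level.
        *batch, n_particles, _dim = x_shape
        return tuple(batch) + (n_particles,) * pdepth
    if pdepth == 0:
        # Input is batch · dim: the batch prefix is preserved.
        *batch, _dim = x_shape
        return tuple(batch)
    raise ValueError("particles_input=False requires pdepth=0")


per_particle = expected_output_prefix((10, 5, 3), particles_input=True, pdepth=1)
pairwise = expected_output_prefix((10, 5, 3), particles_input=True, pdepth=2)
```

For 10 frames of 5 particles in 3D, a per-particle output starts with (10, 5) and a pairwise output with (10, 5, 5).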
Symmetries and tensor structure
A key motivation for this architecture is precise control over symmetries:
scalar features can encode radial dependence, invariants, or other symmetry constraints;
unit vectors and projectors can be composed into tensors with prescribed symmetries (e.g. longitudinal vs transverse components, radial vs tangential directions);
dimension-dependent behaviour is explicit: basis functions can differ between components, not just be copies of the same scalar applied to each dimension.
This contrasts with earlier SFI architectures, where basis functions were often implicitly “the same for all dimensions”. In the new design, the structure of forces and diffusion tensors is expressed directly at the level of state functions, making symmetries transparent and easier to enforce.
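As a concrete instance of a tensor with prescribed symmetry, the NumPy sketch below builds longitudinal and transverse projectors from a unit vector and combines them with different weights. The variable names and the weights a and b are illustrative assumptions.

```python
import numpy as np

r = np.array([3.0, 0.0, 4.0])
r_hat = r / np.linalg.norm(r)

P_long = np.outer(r_hat, r_hat)     # projects onto the radial direction
P_trans = np.eye(3) - P_long        # projects onto the tangential plane

# A diffusion-like tensor with distinct radial and tangential weights.
a, b = 2.0, 0.5
M = a * P_long + b * P_trans
```

The resulting M is symmetric, has r_hat as an eigenvector with eigenvalue a, and acts as b times the identity on the plane orthogonal to r_hat.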
Feature axis and drop_features
All state functions place the feature axis last. Some operations can “drop”
features and return a featureless tensor (for example, turning a single-feature
PSF into an SF that represents a single field rather than a
dictionary).
To avoid ambiguity:
operations that would need to drop more than one feature (n_features > 1) are forbidden and raise an error;
single-feature PSFs are the intended target for such operations;
Bases are typically multi-feature; users are expected to select individual features explicitly before applying operations that drop the feature axis.
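On raw arrays, the rule amounts to refusing to squeeze a feature axis of length greater than one. The helper drop_feature_axis is a hypothetical illustration in NumPy, not the library's drop_features implementation.

```python
import numpy as np


def drop_feature_axis(y):
    """Remove the trailing feature axis; only legal when it has length 1."""
    if y.shape[-1] != 1:
        raise ValueError(f"cannot drop {y.shape[-1]} features; select one first")
    return y[..., 0]


single = np.zeros((10, 3, 1))     # (batch, dim, n_features=1)
multi = np.zeros((10, 3, 4))      # (batch, dim, n_features=4)

field = drop_feature_axis(single)                # shape (10, 3)
# Explicit selection of feature 2 (kept as a length-1 axis), then drop.
selected = drop_feature_axis(multi[..., 2:3])    # shape (10, 3)
```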
Backend architecture: leaves, nodes, operations
Internally, statefunc organises computations as a tree of nodes:
leaf nodes implement primitive bricks (single-sample functions, interactors, parameter blocks);
operation nodes implement composition (concat, map, einsum, derivatives, slicing, interaction dispatch, …);
every node inherits the same static contract and participates in memory accounting.
This node-level backend (including the interaction dispatchers) is considered
an implementation detail. The front-end abstractions (Basis,
PSF, SF, Interactor, and the factory functions) are
the stable public surface.
Performance and memory-aware evaluation
Each node also contributes to a global estimate of memory footprint. This
allows the integration layer (SFI.integrate) to:
plan vectorised evaluations over time and particles;
adaptively chunk computations to stay within a target memory budget.
This memory-aware design is what lets large systems (many particles, long trajectories, rich bases) be evaluated and differentiated efficiently without manual chunking logic in user code.
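The chunking idea can be sketched independently of the library: given a per-sample memory estimate and a budget, split the sample axis into slices that fit. The function chunk_slices and the footprint estimate are assumptions for illustration; the planner in SFI.integrate is not reproduced here.

```python
def chunk_slices(n_samples, bytes_per_sample, budget_bytes):
    """Split range(n_samples) into slices whose estimated footprint fits the budget."""
    per_chunk = max(1, budget_bytes // bytes_per_sample)
    return [slice(start, min(start + per_chunk, n_samples))
            for start in range(0, n_samples, per_chunk)]


# Rough estimate: 1000 particles x 3 coordinates x 20 features x 8 bytes per time step.
bytes_per_step = 1000 * 3 * 20 * 8
chunks = chunk_slices(n_samples=10_000, bytes_per_sample=bytes_per_step,
                      budget_bytes=256 * 1024**2)
```

Each slice can then be evaluated in one vectorised call and the results accumulated, so peak memory stays near the budget regardless of trajectory length.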
See also
The demo gallery and the Running inference section showcase the evaluation engine on concrete inference problems.