SFI.statefunc.factory module
- SFI.statefunc.factory.make_basis(func, *, dim=None, rank, n_features=1, needs_v=False, labels=None, descriptor='custom', extras_keys=None, particle_extras=None, specialize_at=None)
Construct a deterministic Basis from a single-sample user function, with no particle semantics. Particle axes (if present in x at call time) are treated purely as batch axes and vmapped over.
particle_extras names extras keys whose values are per-sample arrays aligned with the batch/particle axes (e.g. an extras_local entry of shape (N, ...)). These are vmapped alongside x, so the single-sample function sees its own particle’s value instead of the whole array. This is the route to per-particle terms in single-particle bases (home-range centres, individual labels, …).
User function signature (declare only the kwargs you need):
Simplest:
f(x) -> array
With velocity:
f(x, *, v) -> array
With extras:
f(x, *, extras) -> array
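The "declare only the kwargs you need" convention can be sketched with the standard library's inspect module. This is a minimal illustration assuming that kwargs are matched by parameter name; the helper call_declared and the two example functions are hypothetical and are not part of SFI.

```python
import inspect
from typing import Any, Callable

# Parameter kinds that can receive a value passed by keyword.
_KW_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)


def call_declared(func: Callable[..., Any], x: Any, **available: Any) -> Any:
    """Hypothetical helper (not SFI's implementation): call func(x, **kw),
    forwarding only the keyword arguments that func declares by name."""
    params = inspect.signature(func).parameters
    declared = {name for name, p in params.items() if p.kind in _KW_KINDS}
    kwargs = {key: val for key, val in available.items() if key in declared}
    return func(x, **kwargs)


def f_plain(x):
    return 2 * x


def f_extras(x, *, extras):
    return x + extras["shift"]


# Both calls offer v, mask and extras; each function receives only what it declares.
r_plain = call_declared(f_plain, 3.0, v=1.0, mask=None, extras={"shift": 5.0})
r_extras = call_declared(f_extras, 3.0, v=1.0, mask=None, extras={"shift": 5.0})
```

Here r_plain is 6.0 and r_extras is 8.0: f_plain never sees v, mask, or extras, while f_extras receives only extras.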
The full signature is f(x, *, v=None, mask=None, extras=None); we introspect and pass only the kwargs you declare.
Shapes (single sample):
x: (dim,)
return: (*rank_axes, m)  # feature last; m == n_features
If n_features == 1, you may omit the last axis; a singleton feature axis is auto-inserted.
Extras
If extras is declared, you may provide extras_keys=(...) to enforce keys. Extras arrays must broadcast over the batch prefix (never over the rank/feature axes).
Examples
>>> import jax.numpy as jnp
>>> from SFI.statefunc import make_basis
>>> B = make_basis(lambda x: x, dim=2, rank=1, n_features=1)  # (equivalent to the built-in X(dim=2))
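The particle_extras behaviour described above (per-particle extras mapped alongside x, everything else broadcast) can be mimicked with plain jax.vmap. This sketch does not use SFI's dispatcher; the extras keys "centre" and "stiffness" and the helpers single_sample and batched are hypothetical, chosen to echo the home-range-centre use case.

```python
import jax
import jax.numpy as jnp


def single_sample(x, *, extras):
    # x: (dim,). extras["centre"]: (dim,), this particle's own centre.
    # extras["stiffness"]: scalar shared by every particle.
    # Returns (dim, 1): a rank-1 output with a trailing feature axis (m = 1).
    return (-extras["stiffness"] * (x - extras["centre"]))[:, None]


def batched(X, extras, particle_extras=("centre",)):
    # Keys listed in particle_extras are mapped over axis 0 together with X;
    # all other extras are broadcast unchanged (in_axes=None).
    extras_axes = {k: (0 if k in particle_extras else None) for k in extras}
    return jax.vmap(lambda x, ex: single_sample(x, extras=ex), in_axes=(0, extras_axes))(X, extras)


X = jnp.array([[1.0, 0.0], [0.0, 3.0], [2.0, 2.0]])  # N = 3 particles, dim = 2
extras = {
    "centre": jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 1.0]]),  # (N, dim), per particle
    "stiffness": jnp.asarray(2.0),                               # shared scalar
}
out = batched(X, extras)  # shape (3, 2, 1)
```

Each row of out is computed from that particle's own centre only, which is the per-particle semantics particle_extras is documented to provide.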
JAX
Write f with jax.numpy and keep it pure; it then works with jit, vmap, and autodiff.
- Parameters:
func (Callable)
dim (int | None)
rank (int)
n_features (int)
needs_v (bool)
labels (Sequence[str] | None)
descriptor (Any)
extras_keys (Sequence[str] | None)
particle_extras (Sequence[str] | None)
specialize_at (Callable | None)
- Return type:
- SFI.statefunc.factory.make_interactor(func, *, dim, rank, K=None, Kmax=None, n_features=1, needs_v=False, labels=(), descriptor=None, params=None, extras_keys=(), particle_extras=())
Build a local interaction dictionary (Interactor) from a single-sample user function that consumes an array of shape (K, dim), one row per member of the local tuple, and returns a feature-last array.
- Pass exactly one of:
K=int → fixed arity
Kmax=int → variable arity (ragged via mask)
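A variable-arity (Kmax) local function can be written to ignore padded rows via the mask. The sketch below assumes the mask arrives as a (Kmax,) boolean array that is True for real tuple members; the name local_masked and the mean-position feature are hypothetical illustrations, not SFI built-ins.

```python
import jax.numpy as jnp


def local_masked(Xk, *, mask):
    """Hypothetical variable-arity local function.

    Xk:   (Kmax, dim) member positions, padded to Kmax rows.
    mask: (Kmax,) boolean, True for real members, False for padding.
    Returns (dim, 1): the mean position of the real members, feature last.
    """
    w = mask.astype(Xk.dtype)[:, None]          # (Kmax, 1) weights, 0 on padding rows
    n = jnp.maximum(w.sum(), 1.0)               # guard against an all-padding tuple
    return ((w * Xk).sum(axis=0) / n)[:, None]  # (dim, 1)


Xk = jnp.array([[0.0, 0.0], [2.0, 4.0], [99.0, 99.0]])  # third row is padding
mask = jnp.array([True, True, False])
res = local_masked(Xk, mask=mask)  # mean of the first two rows: [[1.], [2.]]
```

Because the padded row gets zero weight, its placeholder values (99.0 here) never affect the result, and the function keeps a fixed (Kmax, dim) input shape that jit and vmap can handle.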
particle_extras names the extras keys whose values are per-particle arrays (shape (P, ...)): the dispatcher gathers them per edge member, so inside func they arrive with shape (K, ...), one entry per member of the local tuple. The reserved "particle_index" extra (injected by TrajectoryCollection), combined with a (P,)-shaped parameter, gives per-particle inferred parameters:
def local(Xk, *, params, extras):
    mob = params["mob"][extras["particle_index"]]  # (K,)
    ...
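The per-edge gather described above can be sketched in plain JAX: index every (P, ...) array with an (E, K) table of particle indices, then vmap the local function over edges. The dispatch helper, the edges table, and the mobility-weighted pair feature are hypothetical; SFI's actual dispatcher may be organised differently.

```python
import jax
import jax.numpy as jnp


def local(Xk, *, params, extras):
    # Xk: (K, dim); extras["particle_index"]: (K,) global indices of the members.
    mob = params["mob"][extras["particle_index"]]       # (K,) per-member mobility
    # Hypothetical pair feature (K == 2): mean mobility times the separation vector.
    return (mob.mean() * (Xk[1] - Xk[0]))[:, None]      # (dim, 1), feature last


def dispatch(X, edges, params, particle_extras):
    """Hypothetical dispatcher sketch, not SFI's.

    X: (P, dim); edges: (E, K) integer particle indices;
    particle_extras: dict of (P, ...) arrays.
    """
    Xe = X[edges]                                             # (E, K, dim)
    ex_e = {k: v[edges] for k, v in particle_extras.items()}  # each (E, K, ...)
    return jax.vmap(lambda xk, ex: local(xk, params=params, extras=ex))(Xe, ex_e)


X = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])  # P = 3, dim = 2
edges = jnp.array([[0, 1], [1, 2]])                  # E = 2 pairs, K = 2
params = {"mob": jnp.array([1.0, 3.0, 5.0])}         # (P,) per-particle parameter
out = dispatch(X, edges, params, {"particle_index": jnp.arange(3)})  # (2, 2, 1)
```

Gathering particle_index like any other per-particle extra is what lets a single (P,) parameter vector supply a different value to each tuple member.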
- Parameters:
func (Callable)
dim (int)
rank (Rank)
K (int | None)
Kmax (int | None)
n_features (int)
needs_v (bool)
labels (Iterable[str])
params (ParamSuite | None)
extras_keys (Iterable[str])
particle_extras (Iterable[str])
- SFI.statefunc.factory.make_psf(func, *, dim=None, rank, n_features=1, drop_features=True, needs_v=False, labels=None, descriptor='parametric', params, extras_keys=None, specialize_at=None)
Construct a parametric state-function family (PSF) from a single-sample user function, without particle semantics.
User function signature — declare only the kwargs you need:
Simplest:
f(x, *, params) -> array
With velocity:
f(x, *, v, params) -> array
With extras:
f(x, *, params, extras) -> array
The full signature is f(x, *, params, v=None, mask=None, extras=None); we introspect and pass only the kwargs you declare.
Shapes (single sample):
x: (dim,)
return: (*rank_axes, m)  # feature last; m == n_features
If n_features == 1, you may omit the last axis; we auto-insert a singleton feature axis.
Parameters (params) may be described as:
a ParamSuite,
an iterable of ParamSpec,
a dict of shapes, e.g. {'W': (d,d), 'b': ()}, or
a dict of sample arrays from which (shape, dtype) are inferred.
Extras
Same rules as make_basis (extras_keys optional; extras broadcast over the batch prefix).
JAX
Works with jit/vmap/autodiff w.r.t. inputs and parameters.
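Because the user function is pure JAX, gradients with respect to both the input and the parameters follow from ordinary JAX transformations. This sketch applies jax.grad and jax.jit directly to a function with the f(x, *, params) signature, using the {'W': (d,d), 'b': ()} layout from above; it illustrates the JAX property rather than any SFI wrapper.

```python
import jax
import jax.numpy as jnp


def f(x, *, params):
    # Single-sample parametric function with params {"W": (d, d), "b": ()}.
    # Returns (d,): rank 1 with the singleton feature axis omitted.
    return params["W"] @ x + params["b"]


def loss(p, xx):
    # Scalar objective, so jax.grad applies.
    return jnp.sum(f(xx, params=p) ** 2)


d = 2
params = {"W": jnp.eye(d), "b": jnp.asarray(0.5)}
x = jnp.array([1.0, -2.0])

value = jax.jit(loss)(params, x)                              # 1.5**2 + (-1.5)**2 = 4.5
g_params, g_x = jax.grad(loss, argnums=(0, 1))(params, x)     # gradients w.r.t. params and x
```

g_params mirrors the structure of params (a dict with "W" and "b"), so parameter gradients can be fed straight back into an optimiser.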
- Parameters:
func (Callable)
dim (int | None)
rank (int)
n_features (int)
drop_features (bool)
needs_v (bool)
labels (Sequence[str] | None)
descriptor (Any)
params (ParamSuite | Iterable[ParamSpec] | dict[str, Any])
extras_keys (Sequence[str] | None)
specialize_at (Callable | None)
- Return type:
- SFI.statefunc.factory.make_sf(func, *, dim=None, rank, n_features=1, drop_features=True, needs_v=False, labels=None, descriptor='custom_sf', extras_keys=None)
Construct an SF (bound state function) directly from a parameter-free user function, with no Basis or PSF intermediate needed.
This is the simplest entry point when you have a known, fixed function (e.g. an exact model for comparison, or a hand-coded feature) and just want a callable that participates in the SFI expression-tree ecosystem.
The resulting SF supports .d_x() and .d_v(), and can be passed to compare_to_exact, integrate, or any other API that accepts an SF/StateExpr.
User function signature (declare only the kwargs you need):
Simplest:
f(x) -> array
With velocity:
f(x, *, v) -> array
With extras:
f(x, *, extras) -> array
Shapes (single sample):
x: (dim,)
return: (*rank_axes, n_features)
If n_features == 1, you may omit the trailing feature axis; it is auto-inserted. The resulting SF squeezes it back out when drop_features=True (the default).
- Parameters:
func (callable) – Pure JAX function, compatible with jit/vmap/autodiff.
dim (int or None) – Spatial dimensionality (None = infer at first call).
rank (int) – Tensor rank of the output (0=scalar, 1=vector, 2=matrix).
n_features (int) – Number of output features (default 1).
drop_features (bool) – Remove trailing size-1 feature axis (default True).
needs_v (bool) – Whether func requires the velocity v.
labels (sequence of str or None) – Human-readable feature labels (auto-generated if None).
descriptor (any) – Metadata tag stored on the leaf node.
extras_keys (sequence of str or None) – Required keys in the extras mapping.
- Returns:
A bound, callable state function with no free parameters.
- Return type:
Examples
>>> import jax.numpy as jnp
>>> from SFI.statefunc import make_sf
>>> harmonic = make_sf(lambda x: -x, rank=1, dim=2)
>>> harmonic(jnp.array([1.0, 2.0]))
Array([-1., -2.], dtype=float32)
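The feature-axis rules (auto-insert a singleton axis when n_features == 1, then squeeze it back when drop_features=True) can be sketched as two small shape helpers. The assumption that an omitted feature axis is detected by y.ndim == rank is ours; normalize_features and drop_singleton_feature are hypothetical names, not SFI functions.

```python
import jax.numpy as jnp


def normalize_features(y, rank, n_features=1):
    """Hypothetical: ensure a trailing feature axis of size n_features.

    Assumes an output with y.ndim == rank has omitted the feature axis, which
    is only allowed when n_features == 1; a singleton axis is then appended.
    """
    y = jnp.asarray(y)
    if y.ndim == rank:
        if n_features != 1:
            raise ValueError("the feature axis may only be omitted when n_features == 1")
        y = y[..., None]
    if y.shape[-1] != n_features:
        raise ValueError(f"expected {n_features} features, got {y.shape[-1]}")
    return y


def drop_singleton_feature(y, drop_features=True):
    # Mirrors drop_features=True: squeeze a trailing size-1 feature axis.
    return y[..., 0] if drop_features and y.shape[-1] == 1 else y


y = normalize_features(jnp.array([-1.0, -2.0]), rank=1)  # (2,) -> (2, 1)
scalar = normalize_features(3.0, rank=0)                  # () -> (1,)
```

With drop_features=True, drop_singleton_feature(y) returns shape (2,), matching the rank-1 output in the harmonic example above; with drop_features=False the (2, 1) shape is kept.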