SFI.statefunc.factory module

SFI.statefunc.factory.make_basis(func, *, dim=None, rank, n_features=1, needs_v=False, labels=None, descriptor='custom', extras_keys=None, particle_extras=None, specialize_at=None)

Construct a deterministic Basis from a single-sample user function, with no particle semantics. Particle axes (if present in x at call time) are treated purely as batch axes and vmapped over.

particle_extras names the extras keys whose values are per-sample arrays aligned with the batch/particle axes (e.g. an extras_local entry of shape (N, ...)). These arrays are vmapped alongside x, so the single-sample function sees its own particle’s value instead of the whole array. This is the route to per-particle terms in single-particle bases (home-range centres, individual labels, …).
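The following is a rough NumPy sketch of this batching behaviour. It assumes that the leading axes of each particle_extras array match the full batch shape of x. A plain Python loop stands in for jax.vmap here, and apply_over_batch is a hypothetical helper, not part of SFI.

```python
import numpy as np

def apply_over_batch(f, x, *, extras=None, particle_extras=()):
    """Hypothetical stand-in for the vmapped dispatch (not SFI's API).

    x has shape (*batch, dim). Keys listed in particle_extras are assumed to
    hold arrays whose leading axes equal the batch shape; they are sliced per
    sample. All other extras are passed to f whole.
    """
    extras = dict(extras or {})
    batch_shape = x.shape[:-1]
    outputs = []
    for idx in np.ndindex(*batch_shape):
        sample_extras = {
            key: (val[idx] if key in particle_extras else val)
            for key, val in extras.items()
        }
        outputs.append(f(x[idx], extras=sample_extras))
    out = np.stack(outputs)
    # Restore the batch prefix in front of the per-sample output axes.
    return out.reshape(batch_shape + out.shape[1:])


# Per-particle home-range centres: each particle is pulled toward its own centre.
def pull_to_centre(x, *, extras):
    return extras["centre"] - x                               # (dim,)

x = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])          # (N=3, dim=2)
centres = np.array([[1.0, 0.0], [1.0, 2.0], [0.0, 0.0]])    # (N=3, dim=2)
out = apply_over_batch(pull_to_centre, x, extras={"centre": centres},
                       particle_extras=("centre",))
```

If "centre" were not listed in particle_extras, pull_to_centre would receive the full (3, 2) array for every sample instead of its own row.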

User function signature — declare only the kwargs you need:

  • Simplest: f(x) -> array

  • With velocity: f(x, *, v) -> array

  • With extras: f(x, *, extras) -> array

The full signature is f(x, *, v=None, mask=None, extras=None); we introspect and pass only the kwargs you declare.
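Here is a minimal sketch of this kind of signature introspection, using the standard-library inspect module. call_with_declared is a hypothetical name. SFI's actual dispatch may differ in details, such as how it treats **kwargs or missing required arguments.

```python
import inspect

def call_with_declared(f, x, **available):
    """Hypothetical helper: forward only the keyword arguments f declares."""
    declared = inspect.signature(f).parameters
    kwargs = {name: val for name, val in available.items() if name in declared}
    return f(x, **kwargs)


def f_plain(x):
    return x

def f_with_v(x, *, v):
    return [xi * vi for xi, vi in zip(x, v)]

available = {"v": [3.0, 4.0], "mask": None, "extras": {}}
print(call_with_declared(f_plain, [1.0, 2.0], **available))   # [1.0, 2.0]
print(call_with_declared(f_with_v, [1.0, 2.0], **available))  # [3.0, 8.0]
```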

Shapes (single sample):

x: (dim,)
return: (*rank_axes, m)  # feature last; m == n_features

If n_features == 1 you may omit the last axis; a singleton feature axis is auto-inserted.
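The shape contract can be written out as plain tuple arithmetic. This sketch assumes that each rank axis has length dim, so rank 0 is a scalar, rank 1 is (dim,), and rank 2 is (dim, dim). The helpers are hypothetical and only illustrate when the singleton feature axis is added. For instance, the identity function with dim=2, rank=1 returns shape (2,), which is normalised to (2, 1).

```python
def expected_output_shape(dim, rank, n_features):
    """Feature-last single-sample shape, assuming every rank axis has length dim."""
    return (dim,) * rank + (n_features,)

def with_feature_axis(out_shape, dim, rank, n_features):
    """Hypothetical normalizer: insert the singleton feature axis when omitted."""
    full = expected_output_shape(dim, rank, n_features)
    if n_features == 1 and tuple(out_shape) == full[:-1]:
        return full                      # feature axis was omitted; add it
    if tuple(out_shape) != full:
        raise ValueError(f"expected shape {full}, got {tuple(out_shape)}")
    return full

print(with_feature_axis((2,), dim=2, rank=1, n_features=1))       # (2, 1)
print(with_feature_axis((2, 2, 3), dim=2, rank=2, n_features=3))  # (2, 2, 3)
```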

Extras

If extras is declared, you may pass extras_keys=(...) to enforce that those keys are present. Extras arrays must broadcast over the batch prefix (never over the rank or feature axes).

Examples

>>> import jax.numpy as jnp
>>> from SFI.statefunc import make_basis
>>> B = make_basis(lambda x: x, dim=2, rank=1, n_features=1)  # equivalent to the built-in X(dim=2)

JAX

Write f with jax.numpy and keep it pure; the resulting Basis then works with jit, vmap and autodiff.

Parameters:
  • func (Callable)

  • dim (int | None)

  • rank (int)

  • n_features (int)

  • needs_v (bool)

  • labels (Sequence[str] | None)

  • descriptor (Any)

  • extras_keys (Sequence[str] | None)

  • particle_extras (Sequence[str] | None)

  • specialize_at (Callable | None)

Return type:

Basis

SFI.statefunc.factory.make_interactor(func, *, dim, rank, K=None, Kmax=None, n_features=1, needs_v=False, labels=(), descriptor=None, params=None, extras_keys=(), particle_extras=())

Build a local interaction dictionary (Interactor) from a single-sample user function that consumes the positions of one local tuple, shape (K, dim), and returns a feature-last array.

Pass exactly one of:
  • K=int → fixed arity

  • Kmax=int → variable arity (ragged via mask)
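This sketch shows the usual fixed-shape bookkeeping for variable arity. It assumes that shorter tuples are padded up to Kmax and that a boolean mask marks the real members. pad_to_kmax is hypothetical, and the padding value SFI actually uses is not specified here.

```python
def pad_to_kmax(tuples, kmax, fill=0):
    """Hypothetical helper: pad variable-length index tuples to kmax, plus a validity mask."""
    padded, masks = [], []
    for t in tuples:
        if len(t) > kmax:
            raise ValueError(f"tuple {t} longer than Kmax={kmax}")
        pad = kmax - len(t)
        padded.append(list(t) + [fill] * pad)
        masks.append([True] * len(t) + [False] * pad)
    return padded, masks

idx, mask = pad_to_kmax([(0, 1), (2, 3, 4), (5,)], kmax=3)
# idx  -> [[0, 1, 0], [2, 3, 4], [5, 0, 0]]
# mask -> [[True, True, False], [True, True, True], [True, False, False]]
```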

particle_extras names the extras keys whose values are per-particle arrays of shape (P, ...). The dispatcher gathers them per edge member, so inside func they arrive with shape (K, ...), one entry per member of the local tuple. Combining the reserved "particle_index" extra (injected by TrajectoryCollection) with a (P,)-shaped parameter gives per-particle inferred parameters:

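def local(Xk, *, params, extras):
    # Xk: (K, dim); extras["particle_index"]: (K,), gathered per tuple member
    mob = params["mob"][extras["particle_index"]]   # (K,) one mobility per member
    ...  # compute and return the feature-last output
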
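The following NumPy sketch makes the gather step concrete. It assumes that the dispatcher indexes each per-particle array with an (E, K) array of member ids. The edge list and variable names are hypothetical, and row e of the gathered arrays corresponds to what func would see for one tuple.

```python
import numpy as np

P, dim = 4, 2
positions = np.arange(P * dim, dtype=float).reshape(P, dim)  # (P, dim)
particle_index = np.arange(P)                                 # (P,) reserved extra
mob = np.array([0.5, 1.0, 2.0, 4.0])                          # (P,) per-particle parameter

# Hypothetical pair list (K = 2): each row holds the particle ids of one local tuple.
edges = np.array([[0, 1], [2, 3], [1, 3]])                    # (E, K)

# Gathering per edge member turns (P, ...) arrays into (E, K, ...) arrays.
Xk = positions[edges]          # (E, K, dim)
pidx_k = particle_index[edges] # (E, K)

# For one tuple e, params["mob"][extras["particle_index"]] picks one mobility per member.
mob_k = mob[pidx_k]            # (E, K)
print(Xk.shape, mob_k[1])      # (3, 2, 2) [2. 4.]
```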
Parameters:
  • func (Callable)

  • dim (int)

  • rank (Rank)

  • K (int | None)

  • Kmax (int | None)

  • n_features (int)

  • needs_v (bool)

  • labels (Iterable[str])

  • params (ParamSuite | None)

  • extras_keys (Iterable[str])

  • particle_extras (Iterable[str])

SFI.statefunc.factory.make_psf(func, *, dim=None, rank, n_features=1, drop_features=True, needs_v=False, labels=None, descriptor='parametric', params, extras_keys=None, specialize_at=None)

Construct a parametric state-function family (PSF) from a single-sample user function, without particle semantics.

User function signature — declare only the kwargs you need:

  • Simplest: f(x, *, params) -> array

  • With velocity: f(x, *, v, params) -> array

  • With extras: f(x, *, params, extras) -> array

The full signature is f(x, *, params, v=None, mask=None, extras=None); we introspect and pass only the kwargs you declare.

Shapes (single sample):

x: (dim,)
return: (*rank_axes, m)  # feature last; m == n_features

If n_features == 1 you may omit the last axis; we auto-insert a singleton feature axis.

Parameters (params) may be described as:

  • a ParamSuite,

  • an iterable of ParamSpec,

  • a dict of shapes, e.g. {'W': (d,d), 'b': ()},

  • or a dict of sample arrays from which (shape, dtype) are inferred.
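This sketch shows how the two dict forms could be reduced to (shape, dtype) pairs. It assumes that a tuple of ints is read as a shape and that shape-only entries default to float32. param_specs is hypothetical, and it does not cover the ParamSuite or ParamSpec forms.

```python
import numpy as np

def param_specs(params, default_dtype=np.float32):
    """Hypothetical normalizer: map a dict of shapes or sample arrays to (shape, dtype)."""
    specs = {}
    for name, val in params.items():
        if isinstance(val, tuple) and all(isinstance(n, int) for n in val):
            specs[name] = (val, np.dtype(default_dtype))  # shape given directly
        else:
            arr = np.asarray(val)                          # sample array given
            specs[name] = (arr.shape, arr.dtype)
    return specs

d = 3
from_shapes = param_specs({"W": (d, d), "b": ()})
from_samples = param_specs({"W": np.zeros((d, d)), "b": np.float64(0.0)})
```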

Extras

Same rules as make_basis: extras_keys is optional, and extras arrays must broadcast over the batch prefix.

JAX

The resulting PSF works with jit, vmap and autodiff with respect to both inputs and parameters.

Parameters:
  • func (Callable)

  • dim (int | None)

  • rank (int)

  • n_features (int)

  • drop_features (bool)

  • needs_v (bool)

  • labels (Sequence[str] | None)

  • descriptor (Any)

  • params (ParamSuite | Iterable[ParamSpec] | dict[str, Any])

  • extras_keys (Sequence[str] | None)

  • specialize_at (Callable | None)

Return type:

PSF

SFI.statefunc.factory.make_sf(func, *, dim=None, rank, n_features=1, drop_features=True, needs_v=False, labels=None, descriptor='custom_sf', extras_keys=None)

Construct an SF (bound state function) directly from a parameter-free user function — no Basis or PSF intermediate needed.

This is the simplest entry point when you have a known, fixed function (e.g. an exact model for comparison, or a hand-coded feature) and just want a callable that participates in the SFI expression-tree ecosystem.

The resulting SF supports .d_x(), .d_v(), and can be passed to compare_to_exact, integrate, or any other API that accepts an SF / StateExpr.

User function signature — declare only the kwargs you need:

  • Simplest: f(x) -> array

  • With velocity: f(x, *, v) -> array

  • With extras: f(x, *, extras) -> array

Shapes (single sample):

x:      (dim,)
return: (*rank_axes, n_features)

If n_features == 1 you may omit the trailing feature axis; it is auto-inserted. The resulting SF squeezes it back when drop_features=True (default).
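The effect of drop_features on the shape seen by the caller can be sketched as follows. The sketch assumes that only a trailing size-1 feature axis is ever squeezed. public_shape is hypothetical, and the first case matches the (2,)-shaped output in the harmonic example.

```python
def public_shape(internal_shape, drop_features=True):
    """Hypothetical: shape seen by the caller after the optional feature-axis squeeze."""
    internal_shape = tuple(internal_shape)
    if drop_features and internal_shape and internal_shape[-1] == 1:
        return internal_shape[:-1]
    return internal_shape

# dim=2, rank=1, n_features=1 -> internal feature-last shape (2, 1)
print(public_shape((2, 1)))                       # (2,)
print(public_shape((2, 1), drop_features=False))  # (2, 1)
print(public_shape((2, 3)))                       # (2, 3), more than one feature is kept
```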

Parameters:
  • func (callable) – Pure JAX function, compatible with jit/vmap/autodiff.

  • dim (int or None) – Spatial dimensionality (None = infer at first call).

  • rank (int) – Tensor rank of the output (0=scalar, 1=vector, 2=matrix).

  • n_features (int) – Number of output features (default 1).

  • drop_features (bool) – Remove trailing size-1 feature axis (default True).

  • needs_v (bool) – Whether func requires velocity v.

  • labels (sequence of str or None) – Human-readable feature labels (auto-generated if None).

  • descriptor (any) – Metadata tag stored on the leaf node.

  • extras_keys (sequence of str or None) – Required keys in the extras mapping.

Returns:

A bound, callable state function with no free parameters.

Return type:

SF

Examples

>>> import jax.numpy as jnp
>>> from SFI.statefunc import make_sf
>>> harmonic = make_sf(lambda x: -x, rank=1, dim=2)
>>> harmonic(jnp.array([1.0, 2.0]))
Array([-1., -2.], dtype=float32)