SFI.statefunc package¶
High-level API for state functions.
- Public classes:
Basis: deterministic dictionary façade
PSF: parametric family F(x; θ)
SF: state-function with fixed θ
Interactor: local interaction expression (pre-dispatch)
StateExpr: immutable expression tree base class
- Factory helpers:
make_basis, make_psf, make_sf, make_interactor
- Control:
set_jit: enable/disable JIT on __call__
- Power-user primitives:
Rank, ParamSpec, ParamSuite
- class SFI.statefunc.Basis(root)[source]¶
Bases: StateExpr
Deterministic dictionary façade (no parameters).
- Parameters:
root (BaseNode)
- d_v(*, same_particle=False, mode='auto')¶
Build an expression for the velocity Jacobian ∂F/∂v.
Same rules as .d_x(). Requires needs_v=True on the underlying expression.
- Parameters:
same_particle (bool)
mode (str)
- d_x(*, same_particle=False, mode='auto')¶
Build an expression for the spatial Jacobian dF/dx.
Axis effects¶
Adds one derivative-dim immediately before the rank block.
If particles_input=True:
- when same_particle=True: if pdepth=1, compute df_i/dx_i (no extra P axis); the particle dimension behaves like a broadcasted index. Otherwise, raises an error.
- when same_particle=False (default): compute the full cross-particle Jacobian df_i/dx_j; an extra particle axis appears (from JAX). We never create P axes ourselves; we only permute to canonical order.
- Parameters:
same_particle (bool) – See axis effects above.
mode ({'auto', …}) – Backend differentiation mode; 'auto' selects a sane default.
- Returns:
A new expression representing the Jacobian.
- Return type:
StateExpr
Notes
This method triggers no evaluation; it returns a new graph.
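The difference between the two same_particle settings can be illustrated outside SFI with a small NumPy sketch. The toy field `f`, the helper `full_jacobian`, and the axis layout (P_out, d_out, P_in, d_in) are assumptions made for illustration only; SFI's canonical axis order may differ.

```python
import numpy as np

def f(x):
    """Toy per-particle vector field f_i = sum_j (x_i - x_j); x has shape (P, d)."""
    return x.shape[0] * x - x.sum(axis=0, keepdims=True)

def full_jacobian(fn, x, eps=1e-6):
    """Central-difference Jacobian, stored as (P_out, d_out, P_in, d_in) (assumed layout)."""
    P, d = x.shape
    jac = np.zeros((P, d, P, d))
    for j in range(P):
        for b in range(d):
            dx = np.zeros_like(x)
            dx[j, b] = eps
            jac[:, :, j, b] = (fn(x + dx) - fn(x - dx)) / (2 * eps)
    return jac

x = np.arange(6, dtype=float).reshape(3, 2)   # P=3 particles, d=2

# same_particle=False analogue: df_i/dx_j carries an extra particle axis.
J_cross = full_jacobian(f, x)

# same_particle=True analogue: keep only df_i/dx_i, so no extra particle axis.
J_same = np.stack([J_cross[i, :, i, :] for i in range(x.shape[0])])
```

For this toy field the same-particle blocks are (P − 1)·I and every cross-particle block is −I, which makes the two results easy to check by hand.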
- dense(n_out, *, weight='W', bias='b')¶
Apply a learnable affine map on the feature axis.
y[..., j] = sum_i x[..., i] * W[i, j] + b[j]
Spatial (rank) axes are untouched: the same W, b are shared across every spatial component. The result is always a PSF (since the dense layer introduces learnable parameters).
- Parameters:
n_out (int) – Number of output features.
weight (str) – Name for the weight parameter (default "W"). Use distinct names ("W1", "W2", …) when stacking multiple layers.
bias (str | None) – Name for the bias parameter (default "b"; None to omit). Use distinct names ("b1", "b2", …) when stacking layers.
- Returns:
A parametric state function wrapping the dense layer.
- Return type:
PSF
Examples
Build the hidden layers of an MLP force field:
>>> from SFI.bases import X
>>> import jax.numpy as jnp
>>> mlp = (
...     X(dim=2).vectorize(2)
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(1, weight="W2", bias="b2")
... )
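The affine map itself can be checked with plain NumPy. This is a sketch of the formula above, not of SFI's implementation; the array layout batch · dim · features and all variable names are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, dim, n_in, n_out = 5, 2, 3, 4

x = rng.normal(size=(batch, dim, n_in))   # rank-1 values: batch · dim · features
W = rng.normal(size=(n_in, n_out))        # plays the role of the "W" parameter
b = rng.normal(size=(n_out,))             # plays the role of the "b" parameter

# y[..., j] = sum_i x[..., i] * W[i, j] + b[j]
# Only the trailing feature axis is contracted; batch and spatial axes broadcast.
y = x @ W + b
```

Each spatial component (here `y[:, 0, :]` and `y[:, 1, :]`) is produced by the same `W` and `b`, which is the sharing described above.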
- property dim¶
- dot(other, axes=None)¶
Spatial tensordot via einsum.
- Semantics:
axes=None: contract the last axis of self with the first axis of other.
axes=int: if self.rank == other.rank, contract all axes (Frobenius/trace for rank-2); otherwise contract the last `axes` axes of self with the first `axes` axes of other.
axes=(a_axes, b_axes): NumPy-style explicit lists.
Arrays are accepted and coerced to spatial constants.
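The contraction rules above mirror `numpy.tensordot` on the spatial axes. The sketch below uses bare arrays without a feature axis, so it only illustrates which spatial axes are paired; it does not call SFI.

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.normal(size=(3, 3))   # rank-2 spatial tensor
N = rng.normal(size=(3, 3))   # rank-2 spatial tensor
v = rng.normal(size=(3,))     # rank-1 spatial tensor

# axes=None analogue: last axis of the left operand with first axis of the right.
mv = np.tensordot(M, v, axes=([-1], [0]))

# Equal ranks with an integer axes: every axis is contracted (Frobenius product).
frob = np.tensordot(M, N, axes=2)

# Explicit (a_axes, b_axes): contract the first axis of M instead, giving M^T v.
mtv = np.tensordot(M, v, axes=([0], [0]))
```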
- classmethod einsum(spec, *operands)¶
General contraction on spatial axes (like jnp.einsum).
Important
Use only lowercase letters.
spec refers only to spatial axes (not the feature axis).
Features take a Cartesian product across operands (no implicit feature reduction or alignment). If you need feature concatenation, use &/stack. For per-feature ops, use element-wise maps or binary ops where features must match.
Arrays in operands are accepted and coerced to spatial-constant expressions with a single feature. Only spatial letters in spec are interpreted. If no StateExpr is present, a TypeError is raised because dim cannot be inferred.
Examples
Vector inner product (per-feature), two rank-1 inputs:
>>> # a, b: i × F
>>> c = StateExpr.einsum("i,i->", a, b)  # result: × F
Matrix–vector product (per-feature), rank-2 with rank-1:
>>> # M: ij × F1, v: j × F2 → i × (F1×F2)
>>> y = StateExpr.einsum("ij,j->i", M, v)
Outer product (per-feature Cartesian product):
>>> # u: i × F1, v: j × F2 → ij × (F1×F2)
>>> O = StateExpr.einsum("i,j->ij", u, v)
- Parameters:
spec (str) – An einsum string over spatial indices, e.g. "ij,j->i".
operands (mix[StateExpr, array-like]) – Any mix of StateExpr and arrays.
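The Cartesian product over features can be reproduced in NumPy by giving each operand its own feature letter. This is a sketch under assumptions: features are stored on the trailing axis, and the combined feature index is flattened in row-major order (f · F2 + g), which may not match SFI's internal ordering.

```python
import numpy as np

rng = np.random.default_rng(0)
d, F1, F2 = 3, 2, 4
M = rng.normal(size=(d, d, F1))   # spatial ij, then F1 features
v = rng.normal(size=(d, F2))      # spatial j, then F2 features

# Spatial spec "ij,j->i". The feature letters f and g are added by hand and
# both kept in the output, so no feature is reduced or aligned.
y = np.einsum("ijf,jg->ifg", M, v).reshape(d, F1 * F2)
```

Column `f * F2 + g` of `y` is the ordinary matrix-vector product of feature `f` of `M` with feature `g` of `v`.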
- elementwisemap(func, *, label_fn=None)¶
Apply func element-wise to every feature (spatial axes untouched).
func must be a pure JAX function from scalar→scalar (rank-0 arrays OK). If the expression carries feature labels (e.g., a Basis or an SF bound from a Basis), label_fn (if provided) is applied to each feature label.
Example
>>> B = ...  # Basis with 4 features
>>> C = B.elementwisemap(jnp.tanh, label_fn=lambda s: f"tanh({s})")
- Parameters:
func (Callable[[Array], Array])
label_fn (Callable[[str], str] | None)
- estimate_bytes_per_sample(*, dtype=None, particle_size=None, sample=None, mode='forward')¶
Small convenience wrapper returning only the transient bytes/sample.
- Parameters:
particle_size (int | None)
sample (SampleMeta | None)
mode (str)
- Return type:
int
- features_to_rank(rank)¶
Unfold features into spatial axes → given rank.
The output layout changes from the current:
batch · (dim,)^self.rank · n_features
to:
batch · (dim,)^rank · (n_features / dim^(rank − self.rank),)
where the new innermost spatial axes are carved out of the feature axis. This is a pure reshape and is the exact inverse of rank_to_features() when restoring the original rank.
- Parameters:
rank (int) – Target tensor rank (must be greater than the current rank).
- Returns:
Expression at the requested rank with fewer features.
- Return type:
StateExpr (same subclass)
- Raises:
ValueError – If n_features is not divisible by dim^Δrank.
TypeError – If rank ≤ self.rank (use rank_to_features to go down).
Examples
Turn a dense layer’s output back into a vector field:
>>> scalar_expr.features_to_rank(1) # rank-1, F/dim features
Build a 2→H→2 MLP force field:
>>> mlp = (
...     X(dim=2)
...     .rank_to_features()                  # rank-0, 2 features
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(2, weight="W2", bias="b2")    # rank-0, 2 features
...     .features_to_rank(1)                 # rank-1, 1 feature
... )
- property labels¶
- memory_hint(*, dtype=None, particle_size=None, sample=None, mode='forward')¶
Conservative per-sample memory footprint for the WHOLE expression tree. Delegates to the root node, which sums children + own output along the way.
- Parameters:
particle_size (int | None)
sample (SampleMeta | None)
mode (str)
- property n_features¶
- property needs_v¶
- property particle_extras: tuple[str, ...]¶
Pure metadata, forwarded from the root node.
Names of extras declared as per-particle somewhere in the underlying node tree (typically by interaction leaves). The dispatcher reads this to know which keys to gather from (P, …) into (E, K, …) per edge before calling locals.
- property particles_input¶
- property pdepth¶
- property rank¶
- rank_to_features()¶
Fold all spatial (rank) axes into the feature axis → rank-0.
The output layout changes from:
batch · (dim,)^rank · n_features
to:
batch · (n_features × dim^rank,)
with rank = 0. This is a pure reshape (no copy, no learnable parameters) and is the exact inverse of features_to_rank(original_rank).
- Returns:
Scalar expression whose feature count is self.n_features × self.dim ** self.rank.
- Return type:
StateExpr (same subclass)
- Raises:
TypeError – If the expression is already rank‑0 (no-op would be confusing).
Examples
Prepare a rank-1 position vector for dense layers:
>>> X(dim=2).rank_to_features() # rank-0, 2 features
The round-trip is the identity:
>>> expr.rank_to_features().features_to_rank(expr.rank) # same as expr
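On concrete arrays the round-trip amounts to two reshapes. The sketch below assumes the layout batch · (dim,)^rank · n_features stated above and row-major ordering inside the folded feature axis; the ordering is an assumption, not something this page specifies.

```python
import numpy as np

batch, dim, rank, n_features = 4, 2, 1, 3
x = np.arange(batch * dim * n_features, dtype=float).reshape(batch, dim, n_features)

# rank_to_features analogue: fold the spatial axes into the feature axis.
folded = x.reshape(batch, dim**rank * n_features)

# features_to_rank(rank) analogue: carve the spatial axes back out.
restored = folded.reshape((batch,) + (dim,) * rank + (n_features,))
```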
- property required_extras: tuple[str, ...]¶
Presence-only extras required by the expression, forwarded from the root node. No shape/broadcast semantics here.
- root: BaseNode¶
- property sdims¶
- specialize(*, dataset)¶
Collapse a pooled model to its single-condition specialization.
Returns a new expression in which every dataset_index-reading primitive (e.g. per_dataset_scalar(), dataset_indicator()) is folded at condition dataset: per-condition parameter arrays collapse to that condition’s slice and the reserved dataset_index extra drops out of required_extras. The pooled-ness is an inference-time concern; once a condition is chosen the model stands alone (no dataset concept).
On a bound SF the stored parameter values are projected to match the shrunken template; on an unbound PSF the template’s per-condition specs become scalars.
- Parameters:
dataset (int)
- Return type:
- sqrtm()¶
- classmethod stack(exprs)¶
Concatenate along the feature axis.
Static contracts must match (rank/dim, compatible pdepth).
- Parameters:
exprs (Sequence[StateExpr])
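In array terms, stacking is a concatenation on the trailing feature axis. A minimal NumPy sketch, assuming features live on the last axis and both operands share rank and dim:

```python
import numpy as np

batch, dim = 4, 2
a = np.ones((batch, dim, 3))    # rank-1 values with 3 features
b = np.zeros((batch, dim, 5))   # rank-1 values with 5 features

# Batch and spatial axes must agree; only the feature axis grows (3 + 5 = 8).
stacked = np.concatenate([a, b], axis=-1)
```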
- tensordot(other, axes=1)¶
Alias of .dot with NumPy-compatible axes.
- tensorize(dim=None, mode='symmetric')¶
Lift a scalar expression to rank-2 (matrix).
- Parameters:
dim (int, optional) – Spatial dimension. Inferred when possible.
mode (str) – 'symmetric' (default) uses symmetric_matrix_basis() (d(d+1)/2 features per scalar feature, spans all symmetric matrices). 'identity' uses identity_matrix_basis() (1 feature per scalar feature, isotropic).
- Returns:
Matrix expression.
- Return type:
- to_psf(coeff_key='coeff', drop_features=True)[source]¶
Return a parametric state function whose value is a linear combination of this Basis’ features:
F(x; θ) = Σ_j θ_j · f_j(x)
Note that use cases are rare within SFI, since the PSF’s features axis is typically used for nonlinearities and/or vector/tensor components. But this can be useful for quick prototyping of linear models, benchmark comparisons of linear vs nonlinear solvers, or as a building block for more complex PSFs.
- Parameters:
coeff_key (str) – Key name for the coefficient vector in the parameter dict.
drop_features (bool) – Whether to remove the trailing size-1 feature axis (default True).
Notes
The resulting PSF shares the same spatial contract (rank/dim/pdepth, particles_input) as this Basis, and does not have a features axis.
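The linear combination F(x; θ) = Σ_j θ_j · f_j(x) is a contraction of the feature axis against the coefficient vector. The NumPy sketch below assumes features on the trailing axis; the array names are illustrative and no SFI call is involved.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, dim, n_features = 5, 2, 3

f = rng.normal(size=(batch, dim, n_features))   # basis values: batch · dim · features
theta = np.array([0.5, -1.0, 2.0])              # coefficient vector (the "coeff" entry)

F_drop = f @ theta                # drop_features=True analogue: no feature axis
F_keep = (f @ theta)[..., None]   # drop_features=False analogue: size-1 feature axis
```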
- vectorize(dim=None, axes=None)¶
Lift a scalar expression to rank-1 (vector).
Equivalent to self * unit_vector_basis(dim, axes=axes), i.e. a Cartesian product of the feature axis with unit vectors.
- Parameters:
dim (int, optional) – Spatial dimension. Inferred from the expression’s contract when possible.
axes (sequence of int, optional) – Subset of spatial axes to include (default: all dim axes).
- Returns:
Vector expression with n_features = self.n_features × len(axes).
- Return type:
- class SFI.statefunc.Interactor(root)[source]¶
Bases: StateExpr
Local interaction expression (pre-dispatch).
- root must be a local graph built from InteractionLeaf(s):
particles_input=True, pdepth=0.
Compose as usual: inter = make_interactor(…); inter2 = (inter & inter)… Then call .dispatch(…) exactly once to obtain a Basis or a PSF.
- Parameters:
root (BaseNode)
- d_v(*, same_particle=False, mode='auto')¶
Build an expression for the velocity Jacobian ∂F/∂v.
Same rules as .d_x(). Requires needs_v=True on the underlying expression.
- Parameters:
same_particle (bool)
mode (str)
- d_x(*, same_particle=False, mode='auto')¶
Build an expression for the spatial Jacobian dF/dx.
Axis effects¶
Adds one derivative-dim immediately before the rank block.
If particles_input=True:
- when same_particle=True: if pdepth=1, compute df_i/dx_i (no extra P axis); the particle dimension behaves like a broadcasted index. Otherwise, raises an error.
- when same_particle=False (default): compute the full cross-particle Jacobian df_i/dx_j; an extra particle axis appears (from JAX). We never create P axes ourselves; we only permute to canonical order.
- Parameters:
same_particle (bool) – See axis effects above.
mode ({'auto', …}) – Backend differentiation mode; 'auto' selects a sane default.
- Returns:
A new expression representing the Jacobian.
- Return type:
StateExpr
Notes
This method triggers no evaluation; it returns a new graph.
- dense(n_out, *, weight='W', bias='b')¶
Apply a learnable affine map on the feature axis.
y[..., j] = sum_i x[..., i] * W[i, j] + b[j]
Spatial (rank) axes are untouched: the same W, b are shared across every spatial component. The result is always a PSF (since the dense layer introduces learnable parameters).
- Parameters:
n_out (int) – Number of output features.
weight (str) – Name for the weight parameter (default "W"). Use distinct names ("W1", "W2", …) when stacking multiple layers.
bias (str | None) – Name for the bias parameter (default "b"; None to omit). Use distinct names ("b1", "b2", …) when stacking layers.
- Returns:
A parametric state function wrapping the dense layer.
- Return type:
PSF
Examples
Build the hidden layers of an MLP force field:
>>> from SFI.bases import X
>>> import jax.numpy as jnp
>>> mlp = (
...     X(dim=2).vectorize(2)
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(1, weight="W2", bias="b2")
... )
- property dim¶
- dispatch(spec, *, owners='focal', focal_index=0, owner_weights=None, reducer='sum', normalize_by_degree=False, exclude_self=True, chunk_size=None, return_as='auto', drop_features=None)[source]¶
- Parameters:
spec (PairsCSR | HyperFixed | HyperCSR | SpecRule)
owners (Literal['focal', 'all', 'custom', 'global'])
focal_index (int)
reducer (Literal['sum', 'mean', 'max'])
normalize_by_degree (bool)
exclude_self (bool)
chunk_size (int | None)
return_as (Literal['auto', 'basis', 'psf'])
drop_features (bool | None)
- dispatch_pairs_from_extras(*, indptr_key, indices_key, **kwargs)[source]¶
- Parameters:
indptr_key (str)
indices_key (str)
- dot(other, axes=None)¶
Spatial tensordot via einsum.
- Semantics:
axes=None: contract the last axis of self with the first axis of other.
axes=int: if self.rank == other.rank, contract all axes (Frobenius/trace for rank-2); otherwise contract the last `axes` axes of self with the first `axes` axes of other.
axes=(a_axes, b_axes): NumPy-style explicit lists.
Arrays are accepted and coerced to spatial constants.
- classmethod einsum(spec, *operands)¶
General contraction on spatial axes (like jnp.einsum).
Important
Use only lowercase letters.
spec refers only to spatial axes (not the feature axis).
Features take a Cartesian product across operands (no implicit feature reduction or alignment). If you need feature concatenation, use &/stack. For per-feature ops, use element-wise maps or binary ops where features must match.
Arrays in operands are accepted and coerced to spatial-constant expressions with a single feature. Only spatial letters in spec are interpreted. If no StateExpr is present, a TypeError is raised because dim cannot be inferred.
Examples
Vector inner product (per-feature), two rank-1 inputs:
>>> # a, b: i × F
>>> c = StateExpr.einsum("i,i->", a, b)  # result: × F
Matrix–vector product (per-feature), rank-2 with rank-1:
>>> # M: ij × F1, v: j × F2 → i × (F1×F2)
>>> y = StateExpr.einsum("ij,j->i", M, v)
Outer product (per-feature Cartesian product):
>>> # u: i × F1, v: j × F2 → ij × (F1×F2)
>>> O = StateExpr.einsum("i,j->ij", u, v)
- Parameters:
spec (str) – An einsum string over spatial indices, e.g. "ij,j->i".
operands (mix[StateExpr, array-like]) – Any mix of StateExpr and arrays.
- elementwisemap(func, *, label_fn=None)¶
Apply func element-wise to every feature (spatial axes untouched).
func must be a pure JAX function from scalar→scalar (rank-0 arrays OK). If the expression carries feature labels (e.g., a Basis or an SF bound from a Basis), label_fn (if provided) is applied to each feature label.
Example
>>> B = ...  # Basis with 4 features
>>> C = B.elementwisemap(jnp.tanh, label_fn=lambda s: f"tanh({s})")
- Parameters:
func (Callable[[Array], Array])
label_fn (Callable[[str], str] | None)
- estimate_bytes_per_sample(*, dtype=None, particle_size=None, sample=None, mode='forward')¶
Small convenience wrapper returning only the transient bytes/sample.
- Parameters:
particle_size (int | None)
sample (SampleMeta | None)
mode (str)
- Return type:
int
- features_to_rank(rank)¶
Unfold features into spatial axes → given rank.
The output layout changes from the current:
batch · (dim,)^self.rank · n_features
to:
batch · (dim,)^rank · (n_features / dim^(rank − self.rank),)
where the new innermost spatial axes are carved out of the feature axis. This is a pure reshape and is the exact inverse of rank_to_features() when restoring the original rank.
- Parameters:
rank (int) – Target tensor rank (must be greater than the current rank).
- Returns:
Expression at the requested rank with fewer features.
- Return type:
StateExpr (same subclass)
- Raises:
ValueError – If n_features is not divisible by dim^Δrank.
TypeError – If rank ≤ self.rank (use rank_to_features to go down).
Examples
Turn a dense layer’s output back into a vector field:
>>> scalar_expr.features_to_rank(1) # rank-1, F/dim features
Build a 2→H→2 MLP force field:
>>> mlp = (
...     X(dim=2)
...     .rank_to_features()                  # rank-0, 2 features
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(2, weight="W2", bias="b2")    # rank-0, 2 features
...     .features_to_rank(1)                 # rank-1, 1 feature
... )
- memory_hint(*, dtype=None, particle_size=None, sample=None, mode='forward')¶
Conservative per-sample memory footprint for the WHOLE expression tree. Delegates to the root node, which sums children + own output along the way.
- Parameters:
particle_size (int | None)
sample (SampleMeta | None)
mode (str)
- property n_features¶
- property needs_v¶
- property particle_extras: tuple[str, ...]¶
Pure metadata, forwarded from the root node.
Names of extras declared as per-particle somewhere in the underlying node tree (typically by interaction leaves). The dispatcher reads this to know which keys to gather from (P, …) into (E, K, …) per edge before calling locals.
- property particles_input¶
- property pdepth¶
- property rank¶
- rank_to_features()¶
Fold all spatial (rank) axes into the feature axis → rank-0.
The output layout changes from:
batch · (dim,)^rank · n_features
to:
batch · (n_features × dim^rank,)
with rank = 0. This is a pure reshape (no copy, no learnable parameters) and is the exact inverse of features_to_rank(original_rank).
- Returns:
Scalar expression whose feature count is self.n_features × self.dim ** self.rank.
- Return type:
StateExpr (same subclass)
- Raises:
TypeError – If the expression is already rank‑0 (no-op would be confusing).
Examples
Prepare a rank-1 position vector for dense layers:
>>> X(dim=2).rank_to_features() # rank-0, 2 features
The round-trip is the identity:
>>> expr.rank_to_features().features_to_rank(expr.rank) # same as expr
- property required_extras: tuple[str, ...]¶
Presence-only extras required by the expression, forwarded from the root node. No shape/broadcast semantics here.
- root: BaseNode¶
- property sdims¶
- specialize(*, dataset)¶
Collapse a pooled model to its single-condition specialization.
Returns a new expression in which every dataset_index-reading primitive (e.g. per_dataset_scalar(), dataset_indicator()) is folded at condition dataset: per-condition parameter arrays collapse to that condition’s slice and the reserved dataset_index extra drops out of required_extras. The pooled-ness is an inference-time concern; once a condition is chosen the model stands alone (no dataset concept).
On a bound SF the stored parameter values are projected to match the shrunken template; on an unbound PSF the template’s per-condition specs become scalars.
- Parameters:
dataset (int)
- Return type:
- sqrtm()¶
- classmethod stack(exprs)¶
Concatenate along the feature axis.
Static contracts must match (rank/dim, compatible pdepth).
- Parameters:
exprs (Sequence[StateExpr])
- tensordot(other, axes=1)¶
Alias of .dot with NumPy-compatible axes.
- tensorize(dim=None, mode='symmetric')¶
Lift a scalar expression to rank-2 (matrix).
- Parameters:
dim (int, optional) – Spatial dimension. Inferred when possible.
mode (str) – 'symmetric' (default) uses symmetric_matrix_basis() (d(d+1)/2 features per scalar feature, spans all symmetric matrices). 'identity' uses identity_matrix_basis() (1 feature per scalar feature, isotropic).
- Returns:
Matrix expression.
- Return type:
- vectorize(dim=None, axes=None)¶
Lift a scalar expression to rank-1 (vector).
Equivalent to self * unit_vector_basis(dim, axes=axes), i.e. a Cartesian product of the feature axis with unit vectors.
- Parameters:
dim (int, optional) – Spatial dimension. Inferred from the expression’s contract when possible.
axes (sequence of int, optional) – Subset of spatial axes to include (default: all dim axes).
- Returns:
Vector expression with n_features = self.n_features × len(axes).
- Return type:
- class SFI.statefunc.PSF(root, *, drop_features=True)[source]¶
Bases: StateExpr
Parametric State-Function family F(x; θ).
By default drop_features=True: when n_features==1, outputs do not carry a trailing feature axis. .d_theta() forces drop_features=False.
Holds a ParamSuite template describing names, shapes, and dtypes of θ. __call__ evaluates F given a parameter dict matching the template. Supports .d_theta() in addition to .d_x()/.d_v().
- Parameters:
root (BaseNode)
drop_features (bool)
- bind(params=None)[source]¶
Freeze parameter dict into an SF with normalized arrays.
If params is None, fall back to spec defaults (ParamSuite.defaults()). Raises if the template has parameters without defaults.
- Parameters:
params (dict[str, Array] | None)
- d_theta(*, mode='auto')[source]¶
Build an expression for the Jacobian w.r.t. parameters θ.
Notes
The final axis becomes features × n_params_total. Batch/pdepth/rank prefixes are preserved exactly.
The parameter PyTree is handled leafwise; each grad leaf is flattened over its param part, then all leaves are concatenated along the final axis.
- Parameters:
mode (str)
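The leafwise flatten-and-concatenate step described in the notes can be sketched in NumPy. The per-leaf gradient arrays here are made up, and keeping features and parameters as two separate trailing axes is an assumption; this page does not say whether SFI merges them into one axis.

```python
import numpy as np

batch, n_features = 4, 2
prefix = (batch, n_features)

# Hypothetical per-leaf Jacobians: prefix axes followed by the leaf's own shape.
grads = {
    "W": np.ones(prefix + (3, 2)),   # leaf of shape (3, 2) contributes 6 columns
    "b": np.zeros(prefix + (2,)),    # leaf of shape (2,) contributes 2 columns
}

# Flatten each leaf over its parameter part, then concatenate along the final axis.
flat = [g.reshape(prefix + (-1,)) for g in grads.values()]
jac = np.concatenate(flat, axis=-1)  # batch · features · n_params_total
```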
- d_v(*, same_particle=False, mode='auto')¶
Build an expression for the velocity Jacobian ∂F/∂v.
Same rules as .d_x(). Requires needs_v=True on the underlying expression.
- Parameters:
same_particle (bool)
mode (str)
- d_x(*, same_particle=False, mode='auto')¶
Build an expression for the spatial Jacobian dF/dx.
Axis effects¶
Adds one derivative-dim immediately before the rank block.
If particles_input=True:
- when same_particle=True: if pdepth=1, compute df_i/dx_i (no extra P axis); the particle dimension behaves like a broadcasted index. Otherwise, raises an error.
- when same_particle=False (default): compute the full cross-particle Jacobian df_i/dx_j; an extra particle axis appears (from JAX). We never create P axes ourselves; we only permute to canonical order.
- Parameters:
same_particle (bool) – See axis effects above.
mode ({'auto', …}) – Backend differentiation mode; 'auto' selects a sane default.
- Returns:
A new expression representing the Jacobian.
- Return type:
StateExpr
Notes
This method triggers no evaluation; it returns a new graph.
- dense(n_out, *, weight='W', bias='b')¶
Apply a learnable affine map on the feature axis.
y[..., j] = sum_i x[..., i] * W[i, j] + b[j]
Spatial (rank) axes are untouched: the same W, b are shared across every spatial component. The result is always a PSF (since the dense layer introduces learnable parameters).
- Parameters:
n_out (int) – Number of output features.
weight (str) – Name for the weight parameter (default "W"). Use distinct names ("W1", "W2", …) when stacking multiple layers.
bias (str | None) – Name for the bias parameter (default "b"; None to omit). Use distinct names ("b1", "b2", …) when stacking layers.
- Returns:
A parametric state function wrapping the dense layer.
- Return type:
PSF
Examples
Build the hidden layers of an MLP force field:
>>> from SFI.bases import X
>>> import jax.numpy as jnp
>>> mlp = (
...     X(dim=2).vectorize(2)
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(1, weight="W2", bias="b2")
... )
- property dim¶
- dot(other, axes=None)¶
Spatial tensordot via einsum.
- Semantics:
axes=None: contract the last axis of self with the first axis of other.
axes=int: if self.rank == other.rank, contract all axes (Frobenius/trace for rank-2); otherwise contract the last `axes` axes of self with the first `axes` axes of other.
axes=(a_axes, b_axes): NumPy-style explicit lists.
Arrays are accepted and coerced to spatial constants.
- drop_features: bool = True¶
- classmethod einsum(spec, *operands)¶
General contraction on spatial axes (like jnp.einsum).
Important
Use only lowercase letters.
spec refers only to spatial axes (not the feature axis).
Features take a Cartesian product across operands (no implicit feature reduction or alignment). If you need feature concatenation, use &/stack. For per-feature ops, use element-wise maps or binary ops where features must match.
Arrays in operands are accepted and coerced to spatial-constant expressions with a single feature. Only spatial letters in spec are interpreted. If no StateExpr is present, a TypeError is raised because dim cannot be inferred.
Examples
Vector inner product (per-feature), two rank-1 inputs:
>>> # a, b: i × F
>>> c = StateExpr.einsum("i,i->", a, b)  # result: × F
Matrix–vector product (per-feature), rank-2 with rank-1:
>>> # M: ij × F1, v: j × F2 → i × (F1×F2)
>>> y = StateExpr.einsum("ij,j->i", M, v)
Outer product (per-feature Cartesian product):
>>> # u: i × F1, v: j × F2 → ij × (F1×F2)
>>> O = StateExpr.einsum("i,j->ij", u, v)
- Parameters:
spec (str) – An einsum string over spatial indices, e.g. "ij,j->i".
operands (mix[StateExpr, array-like]) – Any mix of StateExpr and arrays.
- elementwisemap(func, *, label_fn=None)¶
Apply func element-wise to every feature (spatial axes untouched).
func must be a pure JAX function from scalar→scalar (rank-0 arrays OK). If the expression carries feature labels (e.g., a Basis or an SF bound from a Basis), label_fn (if provided) is applied to each feature label.
Example
>>> B = ...  # Basis with 4 features
>>> C = B.elementwisemap(jnp.tanh, label_fn=lambda s: f"tanh({s})")
- Parameters:
func (Callable[[Array], Array])
label_fn (Callable[[str], str] | None)
- estimate_bytes_per_sample(*, dtype=None, particle_size=None, sample=None, mode='forward')¶
Small convenience wrapper returning only the transient bytes/sample.
- Parameters:
particle_size (int | None)
sample (SampleMeta | None)
mode (str)
- Return type:
int
- features_to_rank(rank)¶
Unfold features into spatial axes → given rank.
The output layout changes from the current:
batch · (dim,)^self.rank · n_features
to:
batch · (dim,)^rank · (n_features / dim^(rank − self.rank),)
where the new innermost spatial axes are carved out of the feature axis. This is a pure reshape and is the exact inverse of rank_to_features() when restoring the original rank.
- Parameters:
rank (int) – Target tensor rank (must be greater than the current rank).
- Returns:
Expression at the requested rank with fewer features.
- Return type:
StateExpr (same subclass)
- Raises:
ValueError – If n_features is not divisible by dim^Δrank.
TypeError – If rank ≤ self.rank (use rank_to_features to go down).
Examples
Turn a dense layer’s output back into a vector field:
>>> scalar_expr.features_to_rank(1) # rank-1, F/dim features
Build a 2→H→2 MLP force field:
>>> mlp = (
...     X(dim=2)
...     .rank_to_features()                  # rank-0, 2 features
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(2, weight="W2", bias="b2")    # rank-0, 2 features
...     .features_to_rank(1)                 # rank-1, 1 feature
... )
- flatten_params(params)[source]¶
Vectorize a parameter dict according to the template order.
- Parameters:
params (dict[str, Array])
- property labels¶
Basis labels from the underlying CoeffNode (if present).
- memory_hint(*, dtype=None, particle_size=None, sample=None, mode='forward')¶
Conservative per-sample memory footprint for the WHOLE expression tree. Delegates to the root node, which sums children + own output along the way.
- Parameters:
particle_size (int | None)
sample (SampleMeta | None)
mode (str)
- property n_features¶
- property needs_v¶
- property particle_extras: tuple[str, ...]¶
Pure metadata, forwarded from the root node.
Names of extras declared as per-particle somewhere in the underlying node tree (typically by interaction leaves). The dispatcher reads this to know which keys to gather from (P, …) into (E, K, …) per edge before calling locals.
- property particles_input¶
- property pdepth¶
- property rank¶
- rank_to_features()¶
Fold all spatial (rank) axes into the feature axis → rank-0.
The output layout changes from:
batch · (dim,)^rank · n_features
to:
batch · (n_features × dim^rank,)
with rank = 0. This is a pure reshape (no copy, no learnable parameters) and is the exact inverse of features_to_rank(original_rank).
- Returns:
Scalar expression whose feature count is self.n_features × self.dim ** self.rank.
- Return type:
StateExpr (same subclass)
- Raises:
TypeError – If the expression is already rank‑0 (no-op would be confusing).
Examples
Prepare a rank-1 position vector for dense layers:
>>> X(dim=2).rank_to_features() # rank-0, 2 features
The round-trip is the identity:
>>> expr.rank_to_features().features_to_rank(expr.rank) # same as expr
- property required_extras: tuple[str, ...]¶
Presence-only extras required by the expression, forwarded from the root node. No shape/broadcast semantics here.
- root: BaseNode¶
- property sdims¶
- specialize(*, dataset)¶
Collapse a pooled model to its single-condition specialization.
Returns a new expression in which every dataset_index-reading primitive (e.g. per_dataset_scalar(), dataset_indicator()) is folded at condition dataset: per-condition parameter arrays collapse to that condition’s slice and the reserved dataset_index extra drops out of required_extras. The pooled-ness is an inference-time concern; once a condition is chosen the model stands alone (no dataset concept).
On a bound SF the stored parameter values are projected to match the shrunken template; on an unbound PSF the template’s per-condition specs become scalars.
- Parameters:
dataset (int)
- Return type:
- sqrtm()¶
- classmethod stack(exprs)¶
Concatenate along the feature axis.
Static contracts must match (rank/dim, compatible pdepth).
- Parameters:
exprs (Sequence[StateExpr])
- template: ParamSuite¶
- tensordot(other, axes=1)¶
Alias of .dot with NumPy-compatible axes.
- tensorize(dim=None, mode='symmetric')¶
Lift a scalar expression to rank-2 (matrix).
- Parameters:
dim (int, optional) – Spatial dimension. Inferred when possible.
mode (str) – 'symmetric' (default) uses symmetric_matrix_basis() (d(d+1)/2 features per scalar feature, spans all symmetric matrices). 'identity' uses identity_matrix_basis() (1 feature per scalar feature, isotropic).
- Returns:
Matrix expression.
- Return type:
- unflatten_params(vec)[source]¶
Materialize a parameter dict from a flat vector (inverse of flatten_params).
- Parameters:
vec (Array)
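The flatten/unflatten pair can be pictured as raveling each parameter in template order and splitting the vector back by size. The sketch below uses a hypothetical list of (name, shape) pairs in place of a ParamSuite and plain NumPy arrays; it is not SFI's implementation.

```python
import numpy as np

# Hypothetical template: ordered (name, shape) pairs standing in for a ParamSuite.
template = [("W", (3, 2)), ("b", (2,)), ("scale", ())]

def flatten(params):
    """Concatenate the raveled leaves in template order."""
    return np.concatenate([np.ravel(params[name]) for name, _ in template])

def unflatten(vec):
    """Split a flat vector back into arrays with the template shapes."""
    out, offset = {}, 0
    for name, shape in template:
        size = int(np.prod(shape))        # the empty shape () has size 1
        out[name] = vec[offset:offset + size].reshape(shape)
        offset += size
    return out

params = {
    "W": np.arange(6.0).reshape(3, 2),
    "b": np.array([10.0, 11.0]),
    "scale": np.array(7.0),
}
vec = flatten(params)
roundtrip = unflatten(vec)
```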
- vectorize(dim=None, axes=None)¶
Lift a scalar expression to rank-1 (vector).
Equivalent to self * unit_vector_basis(dim, axes=axes), i.e. a Cartesian product of the feature axis with unit vectors.
- Parameters:
dim (int, optional) – Spatial dimension. Inferred from the expression’s contract when possible.
axes (sequence of int, optional) – Subset of spatial axes to include (default: all dim axes).
- Returns:
Vector expression with n_features = self.n_features × len(axes).
- Return type:
- class SFI.statefunc.ParamSpec(name: 'str', shape: 'Tuple[int, ...]', dtype: 'Any' = <class 'jax.numpy.float32'>, init: 'Callable | str' = 'zeros', default: 'Any' = None)[source]¶
Bases: object
- Parameters:
name (str)
shape (Tuple[int, ...])
dtype (Any)
init (Callable | str)
default (Any)
- compatible_with(other)[source]¶
Shareable iff shape and dtype match exactly.
- Parameters:
other (ParamSpec)
- Return type:
bool
- default: Any¶
- dtype: Any¶
- init: Callable | str¶
- merged_with(other)[source]¶
Return a single spec representing the shared parameter. Requires compatibility; keeps self.init by default.
- name: str¶
- shape: Tuple[int, ...]¶
- property size: int¶
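The sharing rule stated for compatible_with and merged_with can be sketched in plain Python. SpecSketch below is a hypothetical stand-in, not the library's ParamSpec; in particular, treating size as the product of the shape entries is an assumption.

```python
from dataclasses import dataclass
from math import prod
from typing import Any, Tuple


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in for ParamSpec (illustration only)."""
    name: str
    shape: Tuple[int, ...]
    dtype: Any = "float32"
    init: Any = "zeros"

    @property
    def size(self) -> int:
        # Assumed meaning: number of scalar entries; prod(()) == 1 for a scalar.
        return prod(self.shape)

    def compatible_with(self, other: "SpecSketch") -> bool:
        # Shareable iff shape and dtype match exactly.
        return self.shape == other.shape and self.dtype == other.dtype

    def merged_with(self, other: "SpecSketch") -> "SpecSketch":
        # Requires compatibility; keeps self.init.
        if not self.compatible_with(other):
            raise ValueError(f"incompatible specs for {self.name!r}")
        return self


a = SpecSketch("W", (2, 3))
b = SpecSketch("W", (2, 3), init="ones")
c = SpecSketch("W", (3, 2))
```

Here a and b can be tied because only their initializers differ, while a and c cannot because their shapes differ.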
- class SFI.statefunc.ParamSuite(specs)[source]¶
Bases: Module
Immutable container holding a set of ParamSpec objects.
- Parameters:
specs (tuple[ParamSpec, ...])
- coerce(params, *, allow_scalar_for_scalar=True, allow_scalar_to_len1=True, cast_dtype=True)[source]¶
Normalize a user param dict into JAX arrays matching this suite.
Notes
If spec.shape == (), accept Python scalars / 0-d arrays (if allowed).
If spec.shape == (1,), optionally accept a scalar and expand to (1,).
Otherwise, require exact shape; dtype is cast if cast_dtype is True.
Returns a NEW dict with normalized arrays.
- Parameters:
params (dict)
allow_scalar_for_scalar (bool)
allow_scalar_to_len1 (bool)
cast_dtype (bool)
- Return type:
dict[str, Array]
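The normalization rules in the notes above can be illustrated with a small NumPy sketch. coerce_sketch is a hypothetical helper that takes a plain name-to-shape mapping instead of a ParamSuite and implements only the default behaviour (all flags left at True); it is not the library's implementation.

```python
import numpy as np


def coerce_sketch(params, shapes, dtype=np.float32):
    """Return a NEW dict of arrays matching `shapes` (name -> shape)."""
    out = {}
    for name, shape in shapes.items():
        arr = np.asarray(params[name], dtype=dtype)  # dtype is cast
        if shape == (1,) and arr.shape == ():
            arr = arr.reshape(1)                     # scalar expanded to (1,)
        if arr.shape != shape:                       # otherwise exact shape required
            raise ValueError(f"{name}: expected shape {shape}, got {arr.shape}")
        out[name] = arr
    return out


normalized = coerce_sketch(
    {"k": 2, "b": 0.5, "W": [[1, 0], [0, 1]]},
    {"k": (), "b": (1,), "W": (2, 2)},
)
```

The Python scalar for "k" becomes a 0-d array, the scalar for "b" is expanded to length 1, and the nested list for "W" is cast to a float32 array.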
- defaults()[source]¶
Return a parameter dict from spec default values, or None if any spec has no default. Values are broadcast to the declared shape.
- Return type:
dict[str, Array] | None
- property has_defaults: bool¶
True iff every spec in this suite carries a concrete default.
- materialize(vector, *, dtype_overrides=None)[source]¶
- Parameters:
vector (Array)
dtype_overrides (dict[str, dtype] | None)
- merge(other)[source]¶
- Union with sharing-by-name:
If a name appears in both suites and specs are compatible (shape/dtype), they are tied (kept once).
If incompatible → error.
- Parameters:
other (ParamSuite | None)
- Return type:
- classmethod merge_many(*suites)[source]¶
Merge any number of suites, sharing parameters by name (shape/dtype must match).
- Parameters:
suites (ParamSuite | None)
- Return type:
ParamSuite | None
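Sharing-by-name, as described for merge and merge_many, might look like the following sketch. Suites are modelled here as plain dicts mapping a name to a (shape, dtype) pair, and returning None when nothing remains is an assumption about the all-None case.

```python
def merge_sketch(*suites):
    """Union of name -> (shape, dtype) dicts; None entries are skipped."""
    merged = {}
    for suite in suites:
        if suite is None:
            continue
        for name, spec in suite.items():
            if name in merged and merged[name] != spec:
                # Same name but different shape/dtype: cannot be tied.
                raise ValueError(f"incompatible specs for {name!r}")
            merged[name] = spec  # compatible duplicates are kept once
    return merged or None


first = {"W": ((2, 2), "float32")}
second = {"W": ((2, 2), "float32"), "b": ((), "float32")}
tied = merge_sketch(first, None, second)
```

The parameter "W" appears in both inputs with the same shape and dtype, so it is kept once in the result.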
- classmethod parse(params)[source]¶
Normalize various user-facing descriptions into a ParamSuite.
Accepts:
None – returns None
ParamSuite – returned as-is
dict[name -> array | shape] – infer shape/dtype
iterable[ParamSpec] – from_specs
iterable[str] – scalar specs for each name
Shapes may be (), (k,), (m, n, ...) or an integer k (interpreted as (k,)).
- Return type:
ParamSuite | None
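A rough sketch of the accepted descriptions, reduced to shape inference only. parse_sketch is hypothetical: it returns a name-to-shape dict rather than a ParamSuite, omits the iterable[ParamSpec] case, and assumes that tuples and integers are shape descriptions while anything else in a dict is a sample array.

```python
import numpy as np


def parse_sketch(params):
    """Map a user-facing description to a name -> shape dict (or None)."""
    if params is None:
        return None
    if isinstance(params, dict):
        shapes = {}
        for name, value in params.items():
            if isinstance(value, int):
                shapes[name] = (value,)                 # integer k means (k,)
            elif isinstance(value, tuple):
                shapes[name] = value                    # explicit shape
            else:
                shapes[name] = np.asarray(value).shape  # sample array
        return shapes
    return {name: () for name in params}                # names only: scalar specs


from_dict = parse_sketch({"W": (2, 3), "b": 4, "s": (), "A": np.zeros((2, 2))})
from_names = parse_sketch(["alpha", "beta"])
```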
- property size: int¶
- class SFI.statefunc.Rank(*values)[source]¶
Bases: IntEnum
- MATRIX = 2¶
- SCALAR = 0¶
- TENSOR3 = 3¶
- TENSOR4 = 4¶
- VECTOR = 1¶
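Because Rank derives from IntEnum, its members behave like integers. The sketch below re-declares the same values under a hypothetical name (RankSketch) so it runs without the library; rank_block is an illustrative helper, not part of the API.

```python
from enum import IntEnum


class RankSketch(IntEnum):
    """Mirror of the documented values (illustration only)."""
    SCALAR = 0
    VECTOR = 1
    MATRIX = 2
    TENSOR3 = 3
    TENSOR4 = 4


def rank_block(rank, dim):
    # Spatial block of an output shape: (dim,) repeated `rank` times.
    return (dim,) * int(rank)


matrix_block = rank_block(RankSketch.MATRIX, 3)
```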
- class SFI.statefunc.SF(psf, params, *, drop_features=None)[source]¶
Bases: StateExpr
State function with θ fixed (a thin wrapper over the PSF’s root).
Behaves like a Basis for evaluation purposes (no .d_theta()), but you can still build .d_x() / .d_v() expressions.
Feature axis handling mirrors the parent PSF: if drop_features=True and n_features==1, the final axis is removed.
- Parameters:
psf (PSF)
params (dict[str, Array])
drop_features (bool)
- d_v(*, same_particle=False, mode='auto')¶
Build an expression for the velocity Jacobian ∂F/∂v.
Same rules as .d_x(). Requires needs_v=True on the underlying expression.
- Parameters:
same_particle (bool)
mode (str)
- d_x(*, same_particle=False, mode='auto')¶
Build an expression for the spatial Jacobian dF/dx.
Axis effects¶
Adds one derivative-dim axis immediately before the rank block.
If particles_input=True:
when same_particle=True: if pdepth=1, compute df_i/dx_i (no extra P axis); the particle dimension behaves like a broadcasted index. Otherwise, an error is raised.
when same_particle=False (default): compute the full cross-particle Jacobian df_i/dx_j; an extra particle axis appears (from JAX). We never create P axes ourselves; we only permute to canonical order.
- Parameters:
same_particle (bool) – See axis effects above.
mode ({'auto', …}) – Backend differentiation mode; 'auto' selects a sane default.
- Returns:
A new expression representing the Jacobian.
- Return type:
StateExpr
Notes
This method triggers no evaluation; it returns a new graph.
- dense(n_out, *, weight='W', bias='b')¶
Apply a learnable affine map on the feature axis.
y[..., j] = sum_i x[..., i] * W[i, j] + b[j]
Spatial (rank) axes are untouched: the same W, b are shared across every spatial component. The result is always a PSF (since the dense layer introduces learnable parameters).
- Parameters:
n_out (int) – Number of output features.
weight (str) – Name for the weight parameter (default "W"). Use distinct names ("W1", "W2", …) when stacking multiple layers.
bias (str | None) – Name for the bias parameter (default "b"; None to omit). Use distinct names ("b1", "b2", …) when stacking layers.
- Returns:
A parametric state function wrapping the dense layer.
- Return type:
Examples
Build the hidden layers of an MLP force field:
>>> from SFI.bases import X
>>> import jax.numpy as jnp
>>> mlp = (
...     X(dim=2).vectorize(2)
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(1, weight="W2", bias="b2")
... )
- property dim¶
- dot(other, axes=None)¶
Spatial tensordot via einsum.
- Semantics:
axes=None: contract last axis of self with first axis of other.
- axes=int:
if self.rank == other.rank: contract all axes (Frobenius/trace for rank-2).
else: contract axes trailing axes of self with axes leading axes of other.
axes=(a_axes, b_axes): NumPy-style explicit lists.
Arrays are accepted and coerced to spatial constants.
- drop_features: bool = True¶
- classmethod einsum(spec, *operands)¶
General contraction on spatial axes (like jnp.einsum).
Important
Use only lowercase letters.
spec refers only to spatial axes (not the feature axis).
Features take a Cartesian product across operands (no implicit feature reduction or alignment). If you need feature concatenation, use &/stack. For per-feature ops, use element-wise maps or binary ops where features must match.
Arrays in operands are accepted and coerced to spatial-constant expressions with a single feature. Only spatial letters in spec are interpreted. If no StateExpr is present, a TypeError is raised because dim cannot be inferred.
Examples
Vector inner product (per-feature), two rank-1 inputs:
>>> # a, b: i × F
>>> c = StateExpr.einsum("i,i->", a, b)  # result: × F
Matrix–vector product (per-feature), rank-2 with rank-1:
>>> # M: ij × F1, v: j × F2 → i × (F1×F2)
>>> y = StateExpr.einsum("ij,j->i", M, v)
Outer product (per-feature Cartesian product):
>>> # u: i × F1, v: j × F2 → ij × (F1×F2)
>>> O = StateExpr.einsum("i,j->ij", u, v)
- Parameters:
spec (str) – An einsum string over spatial indices, e.g. "ij,j->i".
operands (mix[StateExpr, array-like]) – Any mix of StateExpr and arrays.
- elementwisemap(func, *, label_fn=None)¶
Apply func element-wise to every feature (spatial axes untouched).
func must be a pure JAX function from scalar→scalar (rank-0 arrays OK). If the expression carries feature labels (e.g., a Basis or an SF bound from a Basis), label_fn (if provided) is applied to each feature label.
Example
>>> B = ...  # Basis with 4 features
>>> C = B.elementwisemap(jnp.tanh, label_fn=lambda s: f"tanh({s})")
- Parameters:
func (Callable[[Array], Array])
label_fn (Callable[[str], str] | None)
- estimate_bytes_per_sample(*, dtype=None, particle_size=None, sample=None, mode='forward')¶
Small convenience wrapper returning only the transient bytes/sample.
- Parameters:
particle_size (int | None)
sample (SampleMeta | None)
mode (str)
- Return type:
int
- features_to_rank(rank)¶
Unfold features into spatial axes → given rank.
The output layout changes from the current:
batch · (dim,)^self.rank · n_features
to:
batch · (dim,)^rank · (n_features / dim^(rank − self.rank),)
where the new innermost spatial axes are carved out of the feature axis. This is a pure reshape and is the exact inverse of rank_to_features() when restoring the original rank.
- Parameters:
rank (int) – Target tensor rank (must be greater than the current rank).
- Returns:
Expression at the requested rank with fewer features.
- Return type:
StateExpr (same subclass)
- Raises:
ValueError – If n_features is not divisible by dim^Δrank.
TypeError – If rank ≤ self.rank (use rank_to_features to go down).
Examples
Turn a dense layer’s output back into a vector field:
>>> scalar_expr.features_to_rank(1)  # rank-1, F/dim features
Build a 2→H→H→2 MLP force field:
>>> mlp = (
...     X(dim=2)
...     .rank_to_features()                  # rank-0, 2 features
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(2, weight="W2", bias="b2")    # rank-0, 2 features
...     .features_to_rank(1)                 # rank-1, 1 feature
... )
- property labels¶
Basis labels propagated from the parent PSF.
- memory_hint(*, dtype=None, particle_size=None, sample=None, mode='forward')¶
Conservative per-sample memory footprint for the WHOLE expression tree. Delegates to the root node, which sums children + own output along the way.
- Parameters:
particle_size (int | None)
sample (SampleMeta | None)
mode (str)
- property n_features¶
- property needs_v¶
- params: dict[str, Array]¶
- property particle_extras: tuple[str, ...]¶
Pure metadata, forwarded from the root node.
Names of extras declared as per-particle somewhere in the underlying node tree (typically by interaction leaves). The dispatcher reads this to know which keys to gather from (P, …) into (E, K, …) per edge before calling locals.
- property particles_input¶
- property pdepth¶
- property rank¶
- rank_to_features()¶
Fold all spatial (rank) axes into the feature axis → rank-0.
The output layout changes from:
batch · (dim,)^rank · n_features
to:
batch · (n_features × dim^rank,)
with rank = 0. This is a pure reshape (no copy, no learnable parameters) and is the exact inverse of features_to_rank(original_rank).
- Returns:
Scalar expression whose feature count is self.n_features × self.dim ** self.rank.
- Return type:
StateExpr (same subclass)
- Raises:
TypeError – If the expression is already rank‑0 (no-op would be confusing).
Examples
Prepare a rank-1 position vector for dense layers:
>>> X(dim=2).rank_to_features() # rank-0, 2 features
The round-trip is the identity:
>>> expr.rank_to_features().features_to_rank(expr.rank) # same as expr
- property required_extras: tuple[str, ...]¶
Presence-only extras required by the expression, forwarded from the root node. No shape/broadcast semantics here.
- root: BaseNode¶
- property sdims¶
- specialize(*, dataset)[source]¶
Specialize a bound function at condition dataset.
Rewrites the graph (folding dataset_index-reading leaves) and projects the bound parameter values onto the shrunken template: a per-condition spec whose shape loses a leading axis is sliced at dataset; shared specs are kept verbatim.
- Parameters:
dataset (int)
- Return type:
- sqrtm()¶
- classmethod stack(exprs)¶
Concatenate along the feature axis.
Static contracts must match (rank/dim, compatible pdepth).
- Parameters:
exprs (Sequence[StateExpr])
- tensordot(other, axes=1)¶
Alias of .dot with NumPy-compatible axes.
- tensorize(dim=None, mode='symmetric')¶
Lift a scalar expression to rank-2 (matrix).
- Parameters:
dim (int, optional) – Spatial dimension. Inferred when possible.
mode (str) –
'symmetric' (default) uses symmetric_matrix_basis() (d(d+1)/2 features per scalar feature, spans all symmetric matrices).
'identity' uses identity_matrix_basis() (1 feature per scalar feature, isotropic).
- Returns:
Matrix expression.
- Return type:
- vectorize(dim=None, axes=None)¶
Lift a scalar expression to rank-1 (vector).
Equivalent to self * unit_vector_basis(dim, axes=axes), i.e. a Cartesian product of the feature axis with unit vectors.
- Parameters:
dim (int, optional) – Spatial dimension. Inferred from the expression’s contract when possible.
axes (sequence of int, optional) – Subset of spatial axes to include (default: all dim axes).
- Returns:
Vector expression with n_features = self.n_features × len(axes).
- Return type:
- class SFI.statefunc.StateExpr(root)[source]¶
Bases: Module
Immutable state expression backed by a static node tree.
Think: a read-only NumPy array whose last axis is features. Every algebraic operation returns a new StateExpr (functional style), and static contract metadata (rank, dim, pdepth, n_features, needs_v, particles_input) is validated at graph-construction time.
Runtime shapes¶
Inputs are batched at call time; the library handles batching.
If particles_input=False: x.shape == batch · dim
If particles_input=True: x.shape == batch · P · dim
Outputs always end with the feature axis (length n_features):
y.shape == batch · [P]^pdepth · (dim)^rank · n_features
If particles_input=False, pdepth must be 0.
User function contract (single-sample)¶
Factories accept single-sample callables; user code never sees batch axes. Your function gets x of shape (dim,) (or (P, dim) if particles_input) and any subset of keyword-only args it declares: {v, mask, extras, params}.
Return shape (no batch axis): (P,)*pdepth + (dim,)*rank + (n_features?,). If n_features==1, you may omit the last axis; a singleton is inserted.
mask must broadcast to the prefix of x including the particle axis.
extras presence: if a leaf declares extras with no explicit keys, presence is required (any dict). If extras_keys is given, those keys are required. Values may be scalars or arrays that broadcast over batch only.
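The output layout stated under "Runtime shapes" can be written as a small shape calculator. expected_output_shape is a hypothetical helper for illustration; it only encodes the rule batch · [P]^pdepth · (dim)^rank · n_features and the constraint that pdepth must be 0 without particle input.

```python
def expected_output_shape(batch, *, dim, rank, n_features, pdepth=0, P=None):
    """Shape of y for a given batch prefix and static contract."""
    if pdepth and P is None:
        # Particle axes only exist when the input carries a particle axis.
        raise ValueError("pdepth > 0 requires a particle count P")
    return tuple(batch) + (P,) * pdepth + (dim,) * rank + (n_features,)


# Rank-1 field with 3 features, no particles, batch of 100 samples.
vector_shape = expected_output_shape((100,), dim=2, rank=1, n_features=3)
# Per-particle scalar with a single feature, 5 particles, batch of 10.
particle_shape = expected_output_shape(
    (10,), dim=2, rank=0, n_features=1, pdepth=1, P=5
)
```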
Operators¶
Element-wise arithmetic
+ - * / – element-wise on spatial axes; features must match. Scalars and 1-D vectors (length n_features) broadcast along features.
Unary: +expr, -expr.
NumPy/JAX ufuncs: sin, exp, etc. forward to element-wise maps with the same broadcasting rules; binary ufuncs accept StateExpr ∘ const and StateExpr ∘ StateExpr (features must match for the latter).
Linear-algebra-like
@ (matmul): true matrix multiplication on spatial axes, (..., m, k) @ (..., k, n) -> (..., m, n); features form a Cartesian product between operands (result features = F_left × F_right).
StateExpr.einsum(spec, *operands): generic spatial contraction; features take a Cartesian product across all operands (no implicit feature reduction).
.dot(other): spatial inner product between the last rank axis of self and the first rank axis of other. Cartesian product over features.
.sqrtm(): matrix square root per-feature; requires rank==2.
Feature-axis manipulation
expr1 & expr2 / StateExpr.stack([...]): concatenate features. Static spatial contracts must match; labels (if present) are concatenated.
expr[idx]: feature selection (slice/list/bool/int). Spatial contract is unchanged; labels are subset when available.
.elementwisemap(func, label_fn=None): apply a scalar-to-scalar map to each feature independently (spatial axes untouched). Optional label_fn updates labels for Basis.
Differentiation builders¶
All builders return new expressions (no evaluation).
.d_x(same_particle=False, mode='auto') – spatial Jacobian dF/dx.
Adds one derivative-dim axis immediately before the rank block.
If particles_input=True:
when same_particle=False (default), builds the full cross-particle Jacobian df_i/dx_j and a second particle axis appears (from JAX);
when same_particle=True and pdepth=1, computes the same-particle Jacobian df_i/dx_i without adding a new particle axis; otherwise an error is raised.
.d_v(same_particle=False, mode='auto') – velocity Jacobian dF/dv (requires needs_v=True). Same axis rules as .d_x().
.d_theta(mode='auto') – Jacobian w.r.t. parameters (PSF only); the final axis becomes features × n_params_total. Batch/particle/rank prefixes are preserved.
Type mixing and broadcasting¶
Scalars and ndarrays are treated as purely spatial constants: they must be broadcastable to the spatial rank block (dim,)*rank and are then broadcast uniformly across the feature axis. Bare arrays cannot target the feature axis directly.
Combining two StateExpr requires matching static contracts for rank, dim, and pdepth.
For element-wise ops such as +, - and most binary ufuncs, n_features must match (per-feature operations).
For multiplicative ops (*, / and their ufuncs), as well as @ and .einsum, feature axes take a Cartesian product between operands: F_out = F_left × F_right. When both operands have more than one feature a one-off warning is emitted, as this can grow n_features quickly.
needs_v is OR-combined: if any operand needs v, the result does.
particles_input is OR-combined: if any operand uses particle input, the result does too. An operand without particle input is broadcast uniformly across the particle axis.
Array interop¶
Plain JAX/NumPy arrays are accepted in binary ops with StateExpr. They are treated as spatial constants with a single feature. Arrays broadcast over spatial axes and batch/particles only. Features never arise from arrays and are never contracted unless requested by explicit feature-aware APIs.
Supported operations with arrays:
Elementwise: +, -, *, /, **, and their reflected forms.
Linear algebra: A @ B, B @ A.
Tensor algebra: einsum(eq, ...), dot(...), tensordot(...).
JAX compatibility and autodiff¶
Write user functions with jax.numpy as jnp. Expressions compose under jit/vmap, and support automatic differentiation:
.d_x(), .d_v() add a derivative-dim axis (and a particle axis when particles_input=True).
.d_theta() fuses features × n_params on the last axis. Derivative axis ordering is canonicalized by permutation only.
- d_v(*, same_particle=False, mode='auto')[source]¶
Build an expression for the velocity Jacobian ∂F/∂v.
Same rules as .d_x(). Requires needs_v=True on the underlying expression.
- Parameters:
same_particle (bool)
mode (str)
- d_x(*, same_particle=False, mode='auto')[source]¶
Build an expression for the spatial Jacobian dF/dx.
Axis effects¶
Adds one derivative-dim axis immediately before the rank block.
If particles_input=True:
when same_particle=True: if pdepth=1, compute df_i/dx_i (no extra P axis); the particle dimension behaves like a broadcasted index. Otherwise, an error is raised.
when same_particle=False (default): compute the full cross-particle Jacobian df_i/dx_j; an extra particle axis appears (from JAX). We never create P axes ourselves; we only permute to canonical order.
- Parameters:
same_particle (bool) – See axis effects above.
mode ({'auto', …}) – Backend differentiation mode; 'auto' selects a sane default.
- Returns:
A new expression representing the Jacobian.
- Return type:
StateExpr
Notes
This method triggers no evaluation; it returns a new graph.
- dense(n_out, *, weight='W', bias='b')[source]¶
Apply a learnable affine map on the feature axis.
y[..., j] = sum_i x[..., i] * W[i, j] + b[j]
Spatial (rank) axes are untouched: the same W, b are shared across every spatial component. The result is always a PSF (since the dense layer introduces learnable parameters).
- Parameters:
n_out (int) – Number of output features.
weight (str) – Name for the weight parameter (default "W"). Use distinct names ("W1", "W2", …) when stacking multiple layers.
bias (str | None) – Name for the bias parameter (default "b"; None to omit). Use distinct names ("b1", "b2", …) when stacking layers.
- Returns:
A parametric state function wrapping the dense layer.
- Return type:
Examples
Build the hidden layers of an MLP force field:
>>> from SFI.bases import X
>>> import jax.numpy as jnp
>>> mlp = (
...     X(dim=2).vectorize(2)
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(1, weight="W2", bias="b2")
... )
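The affine formula y[..., j] = sum_i x[..., i] * W[i, j] + b[j] can be checked numerically with plain NumPy. dense_sketch below is a hypothetical stand-alone function, not the library's layer; it only shows that the map acts on the last (feature) axis and leaves batch and spatial axes alone.

```python
import numpy as np


def dense_sketch(x, W, b=None):
    """Affine map on the last axis; every leading axis is untouched."""
    y = np.einsum("...i,ij->...j", x, W)
    return y if b is None else y + b


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2, 3))  # batch=5, one spatial axis (dim=2), 3 features
W = rng.normal(size=(3, 4))     # 3 input features -> 4 output features
b = rng.normal(size=(4,))
y = dense_sketch(x, W, b)       # shape (5, 2, 4)
```

The same W and b are applied to each of the two spatial components, which is the sharing behaviour the description refers to.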
- property dim¶
- dot(other, axes=None)[source]¶
Spatial tensordot via einsum.
- Semantics:
axes=None: contract last axis of self with first axis of other.
- axes=int:
if self.rank == other.rank: contract all axes (Frobenius/trace for rank-2).
else: contract axes trailing axes of self with axes leading axes of other.
axes=(a_axes, b_axes): NumPy-style explicit lists.
Arrays are accepted and coerced to spatial constants.
- classmethod einsum(spec, *operands)[source]¶
General contraction on spatial axes (like jnp.einsum).
Important
Use only lowercase letters.
spec refers only to spatial axes (not the feature axis).
Features take a Cartesian product across operands (no implicit feature reduction or alignment). If you need feature concatenation, use &/stack. For per-feature ops, use element-wise maps or binary ops where features must match.
Arrays in operands are accepted and coerced to spatial-constant expressions with a single feature. Only spatial letters in spec are interpreted. If no StateExpr is present, a TypeError is raised because dim cannot be inferred.
Examples
Vector inner product (per-feature), two rank-1 inputs:
>>> # a, b: i × F
>>> c = StateExpr.einsum("i,i->", a, b)  # result: × F
Matrix–vector product (per-feature), rank-2 with rank-1:
>>> # M: ij × F1, v: j × F2 → i × (F1×F2)
>>> y = StateExpr.einsum("ij,j->i", M, v)
Outer product (per-feature Cartesian product):
>>> # u: i × F1, v: j × F2 → ij × (F1×F2)
>>> O = StateExpr.einsum("i,j->ij", u, v)
- Parameters:
spec (str) – An einsum string over spatial indices, e.g. "ij,j->i".
operands (mix[StateExpr, array-like]) – Any mix of StateExpr and arrays.
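To make the "Cartesian product over features" rule concrete, here is a NumPy sketch of the matrix–vector case for a single sample. It is not how the library evaluates the expression, and the order in which the F1 × F2 feature pairs are flattened is an assumption.

```python
import numpy as np

d, F1, F2 = 3, 2, 4
rng = np.random.default_rng(0)
M = rng.normal(size=(d, d, F1))  # rank-2 operand, features last
v = rng.normal(size=(d, F2))     # rank-1 operand, features last

# Spatial spec "ij,j->i": contract j, keep both feature axes (f, g),
# then flatten the feature pair into a single axis of length F1 * F2.
y = np.einsum("ijf,jg->ifg", M, v).reshape(d, F1 * F2)

# Under the assumed ordering, feature pair (f=1, g=2) sits at index 1 * F2 + 2.
pair = M[:, :, 1] @ v[:, 2]
```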
- elementwisemap(func, *, label_fn=None)[source]¶
Apply func element-wise to every feature (spatial axes untouched).
func must be a pure JAX function from scalar→scalar (rank-0 arrays OK). If the expression carries feature labels (e.g., a Basis or an SF bound from a Basis), label_fn (if provided) is applied to each feature label.
Example
>>> B = ...  # Basis with 4 features
>>> C = B.elementwisemap(jnp.tanh, label_fn=lambda s: f"tanh({s})")
- Parameters:
func (Callable[[Array], Array])
label_fn (Callable[[str], str] | None)
- estimate_bytes_per_sample(*, dtype=None, particle_size=None, sample=None, mode='forward')[source]¶
Small convenience wrapper returning only the transient bytes/sample.
- Parameters:
particle_size (int | None)
sample (SampleMeta | None)
mode (str)
- Return type:
int
- features_to_rank(rank)[source]¶
Unfold features into spatial axes → given rank.
The output layout changes from the current:
batch · (dim,)^self.rank · n_features
to:
batch · (dim,)^rank · (n_features / dim^(rank − self.rank),)
where the new innermost spatial axes are carved out of the feature axis. This is a pure reshape and is the exact inverse of rank_to_features() when restoring the original rank.
- Parameters:
rank (int) – Target tensor rank (must be greater than the current rank).
- Returns:
Expression at the requested rank with fewer features.
- Return type:
StateExpr (same subclass)
- Raises:
ValueError – If n_features is not divisible by dim^Δrank.
TypeError – If rank ≤ self.rank (use rank_to_features to go down).
Examples
Turn a dense layer’s output back into a vector field:
>>> scalar_expr.features_to_rank(1)  # rank-1, F/dim features
Build a 2→H→H→2 MLP force field:
>>> mlp = (
...     X(dim=2)
...     .rank_to_features()                  # rank-0, 2 features
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(2, weight="W2", bias="b2")    # rank-0, 2 features
...     .features_to_rank(1)                 # rank-1, 1 feature
... )
- memory_hint(*, dtype=None, particle_size=None, sample=None, mode='forward')[source]¶
Conservative per-sample memory footprint for the WHOLE expression tree. Delegates to the root node, which sums children + own output along the way.
- Parameters:
particle_size (int | None)
sample (SampleMeta | None)
mode (str)
- property n_features¶
- property needs_v¶
- property particle_extras: tuple[str, ...]¶
Pure metadata, forwarded from the root node.
Names of extras declared as per-particle somewhere in the underlying node tree (typically by interaction leaves). The dispatcher reads this to know which keys to gather from (P, …) into (E, K, …) per edge before calling locals.
- property particles_input¶
- property pdepth¶
- property rank¶
- rank_to_features()[source]¶
Fold all spatial (rank) axes into the feature axis → rank-0.
The output layout changes from:
batch · (dim,)^rank · n_features
to:
batch · (n_features × dim^rank,)
with rank = 0. This is a pure reshape (no copy, no learnable parameters) and is the exact inverse of features_to_rank(original_rank).
- Returns:
Scalar expression whose feature count is self.n_features × self.dim ** self.rank.
- Return type:
StateExpr (same subclass)
- Raises:
TypeError – If the expression is already rank‑0 (no-op would be confusing).
Examples
Prepare a rank-1 position vector for dense layers:
>>> X(dim=2).rank_to_features() # rank-0, 2 features
The round-trip is the identity:
>>> expr.rank_to_features().features_to_rank(expr.rank) # same as expr
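On evaluated arrays, the fold and unfold described above amount to a reshape of the trailing axes. The NumPy sketch below assumes row-major flattening with the spatial axes ahead of the feature axis, which matches the stated layouts but is not confirmed as the library's internal ordering.

```python
import numpy as np

batch, dim, rank, F = 7, 2, 2, 3
full_shape = (batch,) + (dim,) * rank + (F,)
y = np.arange(np.prod(full_shape), dtype=float).reshape(full_shape)

# rank_to_features: batch · (dim,)^rank · F  ->  batch · (F × dim^rank,)
folded = y.reshape(batch, F * dim**rank)

# features_to_rank(rank): carve the spatial axes back out of the feature axis.
restored = folded.reshape(full_shape)
```

The round trip returns the original array, as the identity stated above requires.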
- property required_extras: tuple[str, ...]¶
Presence-only extras required by the expression, forwarded from the root node. No shape/broadcast semantics here.
- root: BaseNode¶
- property sdims¶
- specialize(*, dataset)[source]¶
Collapse a pooled model to its single-condition specialization.
Returns a new expression in which every dataset_index-reading primitive (e.g. per_dataset_scalar(), dataset_indicator()) is folded at condition dataset: per-condition parameter arrays collapse to that condition’s slice and the reserved dataset_index extra drops out of required_extras. The pooled-ness is an inference-time concern; once a condition is chosen the model stands alone (no dataset concept).
On a bound SF the stored parameter values are projected to match the shrunken template; on an unbound PSF the template’s per-condition specs become scalars.
- Parameters:
dataset (int)
- Return type:
- classmethod stack(exprs)[source]¶
Concatenate along the feature axis.
Static contracts must match (rank/dim, compatible pdepth).
- Parameters:
exprs (Sequence[StateExpr])
- tensorize(dim=None, mode='symmetric')[source]¶
Lift a scalar expression to rank-2 (matrix).
- Parameters:
dim (int, optional) – Spatial dimension. Inferred when possible.
mode (str) –
'symmetric' (default) uses symmetric_matrix_basis() (d(d+1)/2 features per scalar feature, spans all symmetric matrices).
'identity' uses identity_matrix_basis() (1 feature per scalar feature, isotropic).
- Returns:
Matrix expression.
- Return type:
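The feature count d(d+1)/2 for the 'symmetric' mode can be illustrated with a hand-built basis. symmetric_basis_sketch is hypothetical: the element ordering and the unnormalized off-diagonal entries are assumptions and may differ from symmetric_matrix_basis().

```python
import numpy as np


def symmetric_basis_sketch(d):
    """Stack d(d+1)/2 symmetric unit matrices along a trailing feature axis."""
    mats = []
    for i in range(d):
        for j in range(i, d):
            E = np.zeros((d, d))
            E[i, j] = E[j, i] = 1.0
            mats.append(E)
    return np.stack(mats, axis=-1)  # shape (d, d, d(d+1)/2)


d = 3
S = symmetric_basis_sketch(d)
f = np.array([2.0, -1.0])  # a scalar expression with 2 features at one sample

# Cartesian product of the 2 scalar features with the 6 basis matrices.
lifted = (S[:, :, None, :] * f[None, None, :, None]).reshape(d, d, -1)
```

Each of the 12 resulting features is a symmetric d × d matrix, so any linear combination of them is symmetric as well.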
- vectorize(dim=None, axes=None)[source]¶
Lift a scalar expression to rank-1 (vector).
Equivalent to self * unit_vector_basis(dim, axes=axes), i.e. a Cartesian product of the feature axis with unit vectors.
- Parameters:
dim (int, optional) – Spatial dimension. Inferred from the expression’s contract when possible.
axes (sequence of int, optional) – Subset of spatial axes to include (default: all dim axes).
- Returns:
Vector expression with n_features = self.n_features × len(axes).
- Return type:
- Parameters:
root (BaseNode)
- SFI.statefunc.make_basis(func, *, dim=None, rank, n_features=1, needs_v=False, labels=None, descriptor='custom', extras_keys=None, particle_extras=None, specialize_at=None)[source]¶
Construct a deterministic Basis from a single-sample user function, with no particle semantics. Particle axes (if present in x at call time) are treated purely as batch axes and vmapped over.
particle_extras names extras keys whose values are per-sample arrays aligned with the batch/particle axes (e.g. an extras_local entry of shape (N, ...)): they are vmapped alongside x, so the single-sample function sees its own particle’s value instead of the whole array. This is the route to per-particle terms in single-particle bases (home-range centres, individual labels, …).
User function signature (declare only the kwargs you need):
Simplest: f(x) -> array
With velocity: f(x, *, v) -> array
With extras: f(x, *, extras) -> array
The full signature is f(x, *, v=None, mask=None, extras=None); we introspect and pass only the kwargs you declare.
Shapes (single sample):
x: (dim,)
return: (*rank_axes, m)  # feature last; m == n_features
If n_features == 1 you may omit the last axis; a singleton feature axis is auto-inserted.
Extras¶
If extras is declared, you may provide extras_keys=(...) to enforce keys. Extras arrays must broadcast over the batch prefix (never over rank/feature).
Examples
>>> import jax.numpy as jnp
>>> from SFI.statefunc import make_basis
>>> B = make_basis(lambda x: x, dim=2, rank=1, n_features=1)
>>> # (equivalent to the built-in X(dim=2))
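The "introspect and pass only the kwargs you declare" behaviour can be approximated with the standard inspect module. call_with_declared_kwargs is a hypothetical helper shown for illustration; how the factory actually performs the introspection is not specified here.

```python
import inspect


def call_with_declared_kwargs(func, x, **available):
    """Call func(x, ...) passing only keyword arguments it declares."""
    declared = inspect.signature(func).parameters
    kwargs = {k: val for k, val in available.items() if k in declared}
    return func(x, **kwargs)


def f_simple(x):
    return x


def f_velocity(x, *, v):
    return x + v


# Both calls offer v and extras; each function receives only what it declares.
out_simple = call_with_declared_kwargs(f_simple, 1.0, v=2.0, extras={})
out_velocity = call_with_declared_kwargs(f_velocity, 1.0, v=2.0, extras={})
```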
JAX¶
Write f with jax.numpy and keep it pure; works with jit/vmap/autodiff.
- Parameters:
func (Callable)
dim (int | None)
rank (int)
n_features (int)
needs_v (bool)
labels (Sequence[str] | None)
descriptor (Any)
extras_keys (Sequence[str] | None)
particle_extras (Sequence[str] | None)
specialize_at (Callable | None)
- Return type:
- SFI.statefunc.make_interactor(func, *, dim, rank, K=None, Kmax=None, n_features=1, needs_v=False, labels=(), descriptor=None, params=None, extras_keys=(), particle_extras=())[source]¶
Build a local interaction dictionary (Interactor) from a single-sample user function that consumes (K, dim) and returns feature-last.
- Pass exactly one of:
K=int → fixed arity
Kmax=int → variable arity (ragged via mask)
particle_extras names the extras keys whose values are per-particle arrays (shape (P, ...)): the dispatcher gathers them per edge member, so inside func they arrive with shape (K, ...), one entry per member of the local tuple. The reserved "particle_index" extra (injected by TrajectoryCollection) combined with a (P,)-shaped parameter gives per-particle inferred parameters:
def local(Xk, *, params, extras):
    mob = params["mob"][extras["particle_index"]]  # (K,)
    ...
- Parameters:
func (Callable)
dim (int)
rank (Rank)
K (int | None)
Kmax (int | None)
n_features (int)
needs_v (bool)
labels (Iterable[str])
params (ParamSuite | None)
extras_keys (Iterable[str])
particle_extras (Iterable[str])
- SFI.statefunc.make_psf(func, *, dim=None, rank, n_features=1, drop_features=True, needs_v=False, labels=None, descriptor='parametric', params, extras_keys=None, specialize_at=None)[source]¶
Construct a parametric state-function family (PSF) from a single-sample user function, without particle semantics.
User function signature — declare only the kwargs you need:
Simplest: f(x, *, params) -> array
With velocity: f(x, *, v, params) -> array
With extras: f(x, *, params, extras) -> array
The full signature is f(x, *, params, v=None, mask=None, extras=None); we introspect and pass only the kwargs you declare.
Shapes (single sample):
x: (dim,)
return: (*rank_axes, m)  # feature last; m == n_features
If n_features == 1 you may omit the last axis; we auto-insert a singleton feature axis.
Parameters (params) may be described as:
a ParamSuite,
an iterable of ParamSpec,
a dict of shapes, e.g. {'W': (d,d), 'b': ()},
or a dict of sample arrays from which (shape, dtype) are inferred.
Extras¶
Same rules as make_basis (extras_keys optional; broadcast over batch prefix).
JAX¶
Works with jit/vmap/autodiff w.r.t. inputs and parameters.
- Parameters:
func (Callable)
dim (int | None)
rank (int)
n_features (int)
drop_features (bool)
needs_v (bool)
labels (Sequence[str] | None)
descriptor (Any)
params (ParamSuite | Iterable[ParamSpec] | dict[str, Any])
extras_keys (Sequence[str] | None)
specialize_at (Callable | None)
- Return type:
- SFI.statefunc.make_sf(func, *, dim=None, rank, n_features=1, drop_features=True, needs_v=False, labels=None, descriptor='custom_sf', extras_keys=None)[source]¶
Construct an SF (bound state function) directly from a parameter-free user function, with no Basis or PSF intermediate needed.
This is the simplest entry point when you have a known, fixed function (e.g. an exact model for comparison, or a hand-coded feature) and just want a callable that participates in the SFI expression-tree ecosystem.
The resulting SF supports .d_x(), .d_v(), and can be passed to compare_to_exact, integrate, or any other API that accepts an SF/StateExpr.
User function signature (declare only the kwargs you need):
Simplest: f(x) -> array
With velocity: f(x, *, v) -> array
With extras: f(x, *, extras) -> array
Shapes (single sample):
x: (dim,)
return: (*rank_axes, n_features)
If n_features == 1 you may omit the trailing feature axis; it is auto-inserted. The resulting SF squeezes it back when drop_features=True (default).
- Parameters:
func (callable) – Pure JAX function, compatible with jit/vmap/autodiff.
dim (int or None) – Spatial dimensionality (None = infer at first call).
rank (int) – Tensor rank of the output (0=scalar, 1=vector, 2=matrix).
n_features (int) – Number of output features (default 1).
drop_features (bool) – Remove trailing size-1 feature axis (default True).
needs_v (bool) – Whether func requires velocity v.
labels (sequence of str or None) – Human-readable feature labels (auto-generated if None).
descriptor (any) – Metadata tag stored on the leaf node.
extras_keys (sequence of str or None) – Required keys in the extras mapping.
- Returns:
A bound, callable state function with no free parameters.
- Return type:
SF
Examples
>>> import jax.numpy as jnp
>>> from SFI.statefunc import make_sf
>>> harmonic = make_sf(lambda x: -x, rank=1, dim=2)
>>> harmonic(jnp.array([1.0, 2.0]))
Array([-1., -2.], dtype=float32)
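The single-sample shape rule for make_sf, (*rank_axes, n_features) with the size-1 feature axis optionally squeezed, can be checked by hand with a short sketch. It assumes that each of the `rank` tensor axes has length `dim`, which matches the rank=1, dim=2 example returning two components but is not stated as a general rule here. `output_shape` is a hypothetical helper, not part of SFI.

```python
def output_shape(dim, rank, n_features=1, drop_features=True):
    """Expected single-sample output shape under the stated assumption.

    Assumption: every rank axis has length `dim` (scalar=(), vector=(dim,),
    matrix=(dim, dim)), followed by a trailing feature axis.
    """
    shape = (dim,) * rank + (n_features,)
    # The trailing feature axis is squeezed only when it has size 1
    # and drop_features is enabled.
    if drop_features and n_features == 1:
        shape = shape[:-1]
    return shape


vector_shape = output_shape(dim=2, rank=1)                        # like `harmonic`
kept_feature_shape = output_shape(dim=2, rank=1, drop_features=False)
scalar_shape = output_shape(dim=3, rank=0)
matrix_features_shape = output_shape(dim=3, rank=2, n_features=4)
```

With these inputs the vector case gives `(2,)`, keeping the feature axis gives `(2, 1)`, a rank-0 function gives `()`, and a rank-2 function with four features gives `(3, 3, 4)`.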
- SFI.statefunc.set_jit(enabled=True)[source]¶
Globally enable/disable JIT for Basis/PSF/SF __call__.
- Parameters:
enabled (bool)
Subpackages¶
Submodules¶
- SFI.statefunc.basis module
Basis, Basis.d_v(), Basis.d_x(), Basis.dense(), Basis.dim, Basis.dot(), Basis.einsum(), Basis.elementwisemap(), Basis.estimate_bytes_per_sample(), Basis.features_to_rank(), Basis.labels, Basis.memory_hint(), Basis.n_features, Basis.needs_v, Basis.particle_extras, Basis.particles_input, Basis.pdepth, Basis.rank, Basis.rank_to_features(), Basis.required_extras, Basis.root, Basis.sdims, Basis.specialize(), Basis.sqrtm(), Basis.stack(), Basis.tensordot(), Basis.tensorize(), Basis.to_psf(), Basis.vectorize()
- SFI.statefunc.factory module
- SFI.statefunc.interactor module
Interactor, Interactor.d_v(), Interactor.d_x(), Interactor.dense(), Interactor.dim, Interactor.dispatch(), Interactor.dispatch_pairs(), Interactor.dispatch_pairs_from_extras(), Interactor.dot(), Interactor.einsum(), Interactor.elementwisemap(), Interactor.estimate_bytes_per_sample(), Interactor.features_to_rank(), Interactor.memory_hint(), Interactor.n_features, Interactor.needs_v, Interactor.particle_extras, Interactor.particles_input, Interactor.pdepth, Interactor.rank, Interactor.rank_to_features(), Interactor.required_extras, Interactor.root, Interactor.sdims, Interactor.specialize(), Interactor.sqrtm(), Interactor.stack(), Interactor.tensordot(), Interactor.tensorize(), Interactor.vectorize()
- SFI.statefunc.memhint module
- SFI.statefunc.params module
ParamSpec, ParamSuite, ParamSuite.coerce(), ParamSuite.defaults(), ParamSuite.from_specs(), ParamSuite.has_defaults, ParamSuite.materialize(), ParamSuite.merge(), ParamSuite.merge_many(), ParamSuite.parse(), ParamSuite.size, ParamSuite.specs, ParamSuite.stamp(), ParamSuite.tree_flatten(), ParamSuite.tree_unflatten(), ParamSuite.vectorize(), ParamSuite.zeros()
- SFI.statefunc.psf module
PSF, PSF.bind(), PSF.d_theta(), PSF.d_v(), PSF.d_x(), PSF.dense(), PSF.dim, PSF.dot(), PSF.drop_features, PSF.einsum(), PSF.elementwisemap(), PSF.estimate_bytes_per_sample(), PSF.features_to_rank(), PSF.flatten_params(), PSF.labels, PSF.memory_hint(), PSF.n_features, PSF.needs_v, PSF.particle_extras, PSF.particles_input, PSF.pdepth, PSF.rank, PSF.rank_to_features(), PSF.required_extras, PSF.root, PSF.sdims, PSF.specialize(), PSF.sqrtm(), PSF.stack(), PSF.template, PSF.tensordot(), PSF.tensorize(), PSF.unflatten_params(), PSF.vectorize()
- SFI.statefunc.sf module
SF, SF.d_v(), SF.d_x(), SF.dense(), SF.dim, SF.dot(), SF.drop_features, SF.einsum(), SF.elementwisemap(), SF.estimate_bytes_per_sample(), SF.features_to_rank(), SF.labels, SF.memory_hint(), SF.n_features, SF.needs_v, SF.params, SF.particle_extras, SF.particles_input, SF.pdepth, SF.rank, SF.rank_to_features(), SF.required_extras, SF.root, SF.sdims, SF.specialize(), SF.sqrtm(), SF.stack(), SF.tensordot(), SF.tensorize(), SF.vectorize()
- SFI.statefunc.stateexpr module
- Shape conventions (runtime evaluation)
- User function contract (single–sample)
- JAX use & autodiff
StateExpr, StateExpr.d_v(), StateExpr.d_x(), StateExpr.dense(), StateExpr.dim, StateExpr.dot(), StateExpr.einsum(), StateExpr.elementwisemap(), StateExpr.estimate_bytes_per_sample(), StateExpr.features_to_rank(), StateExpr.memory_hint(), StateExpr.n_features, StateExpr.needs_v, StateExpr.particle_extras, StateExpr.particles_input, StateExpr.pdepth, StateExpr.rank, StateExpr.rank_to_features(), StateExpr.required_extras, StateExpr.root, StateExpr.sdims, StateExpr.specialize(), StateExpr.sqrtm(), StateExpr.stack(), StateExpr.tensordot(), StateExpr.tensorize(), StateExpr.vectorize()
- SFI.statefunc.structexpr module
StructuredExpr, StructuredExpr.T, StructuredExpr.abs(), StructuredExpr.cos(), StructuredExpr.dense(), StructuredExpr.dot(), StructuredExpr.einsum(), StructuredExpr.elementwisemap(), StructuredExpr.exp(), StructuredExpr.eye(), StructuredExpr.features_to_rank(), StructuredExpr.labels, StructuredExpr.log(), StructuredExpr.n_features, StructuredExpr.param_suite, StructuredExpr.rank_to_features(), StructuredExpr.sdims, StructuredExpr.sin(), StructuredExpr.sqrt(), StructuredExpr.srank, StructuredExpr.stack(), StructuredExpr.tanh(), StructuredExpr.with_label()