SFI.statefunc.basis module¶
Basis façade: dictionary of deterministic functions.
- class SFI.statefunc.basis.Basis(root)[source]¶
Bases: StateExpr
Deterministic dictionary façade (no parameters).
- Parameters:
root (BaseNode)
- d_v(*, same_particle=False, mode='auto')¶
Build an expression for the velocity Jacobian ∂F/∂v.
Same rules as .d_x(). Requires needs_v=True on the underlying expression.
- Parameters:
same_particle (bool)
mode (str)
- d_x(*, same_particle=False, mode='auto')¶
Build an expression for the spatial Jacobian dF/dx.
Axis effects¶
Adds one derivative-dim immediately before the rank block.
If particles_input=True:
- when same_particle=True: if pdepth=1, compute df_i/dx_i (no extra P axis); the particle dimension behaves like a broadcasted index. Otherwise, an error is raised.
- when same_particle=False (default): compute the full cross-particle Jacobian df_i/dx_j; an extra particle axis appears (from JAX). We never create P axes ourselves; we only permute to canonical order.
- Parameters:
same_particle (bool) – See axis effects above.
mode (str, one of {'auto', …}) – Backend differentiation mode; 'auto' selects a sane default.
- Returns:
A new expression representing the Jacobian.
- Return type:
StateExpr
Notes
This method triggers no evaluation; it returns a new graph.
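The axis effects described for d_x() can be illustrated outside SFI. The following is a minimal NumPy sketch of the two Jacobian shapes only, using finite differences instead of JAX autodiff. The field f, the helper full_jacobian, and the axis order (P, d, P, d) are assumptions made for illustration; they are not SFI's implementation or its canonical axis order.

```python
import numpy as np

def f(x):
    """Hypothetical per-particle field: f_i = x_i * |x_i|^2 + mean_j x_j.

    x has shape (P, d); the output has the same shape.
    """
    return x * np.sum(x**2, axis=-1, keepdims=True) + x.mean(axis=0)

def full_jacobian(func, x, eps=1e-6):
    """Central finite differences: J[i, a, j, b] = d f_i[a] / d x_j[b]."""
    P, d = x.shape
    J = np.zeros((P, d, P, d))
    for j in range(P):
        for b in range(d):
            dx = np.zeros_like(x)
            dx[j, b] = eps
            J[:, :, j, b] = (func(x + dx) - func(x - dx)) / (2 * eps)
    return J

x = np.arange(6, dtype=float).reshape(3, 2) / 10.0   # P=3 particles, d=2

# same_particle=False analogue: full cross-particle Jacobian, extra P axis
J_full = full_jacobian(f, x)                          # shape (3, 2, 3, 2)

# same_particle=True analogue: keep only the j == i blocks, no extra P axis
idx = np.arange(x.shape[0])
J_same = J_full[idx, :, idx, :]                       # shape (3, 2, 2)
```

In this sketch the cross-particle blocks are nonzero only because f couples particles through the mean; for a field with no coupling, J_full would be block-diagonal and J_same would carry all the information.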
- dense(n_out, *, weight='W', bias='b')¶
Apply a learnable affine map on the feature axis.
y[..., j] = sum_i x[..., i] * W[i, j] + b[j]
Spatial (rank) axes are untouched: the same W, b are shared across every spatial component. The result is always a PSF (since the dense layer introduces learnable parameters).
- Parameters:
n_out (int) – Number of output features.
weight (str) – Name for the weight parameter (default "W"). Use distinct names ("W1", "W2", …) when stacking multiple layers.
bias (str | None) – Name for the bias parameter (default "b"; None to omit). Use distinct names ("b1", "b2", …) when stacking layers.
- Returns:
A parametric state function wrapping the dense layer.
- Return type:
Examples
Build the hidden layers of an MLP force field:
>>> from SFI.bases import X
>>> import jax.numpy as jnp
>>> mlp = (
...     X(dim=2).vectorize(2)
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(1, weight="W2", bias="b2")
... )
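The affine formula for dense() can be checked on plain arrays. This is a NumPy sketch of the arithmetic only, assuming the layout batch · (dim,)^rank · n_features used elsewhere on this page; the array names and sizes are illustrative and nothing here calls SFI.

```python
import numpy as np

rng = np.random.default_rng(0)
B, dim, n_in, n_out = 5, 2, 3, 4          # batch, spatial dim, features in/out

x = rng.normal(size=(B, dim, n_in))       # rank-1 layout: batch · (dim,) · n_features
W = rng.normal(size=(n_in, n_out))        # one weight matrix, shared by every spatial component
b = rng.normal(size=(n_out,))             # one bias vector, also shared

# y[..., j] = sum_i x[..., i] * W[i, j] + b[j]
y = np.einsum("...i,ij->...j", x, W) + b  # shape (B, dim, n_out)
```

Only the trailing feature axis changes size; the batch and spatial axes pass through unchanged.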
- property dim¶
- dot(other, axes=None)¶
Spatial tensordot via einsum.
- Semantics:
axes=None: contract the last axis of self with the first axis of other.
axes=int: if self.rank == other.rank, contract all axes (Frobenius/trace for rank-2); otherwise contract the trailing n axes of self with the leading n axes of other, where n = axes.
axes=(a_axes, b_axes): NumPy-style explicit lists.
Arrays are accepted and coerced to spatial constants.
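The contraction modes listed for dot() have close analogues in np.tensordot. The sketch below uses spatial-only NumPy arrays and ignores the feature axis entirely. Reading the equal-rank case as the Frobenius pairing (index i with i, j with j) is an assumption, since the text above says only "Frobenius/trace".

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 3
M = rng.normal(size=(dim, dim))   # rank-2 spatial tensor
N = rng.normal(size=(dim, dim))   # rank-2 spatial tensor
v = rng.normal(size=(dim,))       # rank-1 spatial tensor

# axes=None analogue: last axis of M with first axis of v (matrix-vector product)
mv = np.tensordot(M, v, axes=1)               # shape (dim,)

# axes=int with equal ranks: contract all axes (assumed Frobenius pairing)
frob = np.tensordot(M, N, axes=2)             # scalar: sum_ij M[i, j] * N[i, j]

# axes=(a_axes, b_axes) analogue: explicit NumPy-style lists
mtv = np.tensordot(M, v, axes=([0], [0]))     # contracts M's first axis, i.e. M.T @ v
```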
- classmethod einsum(spec, *operands)¶
General contraction on spatial axes (like jnp.einsum).
Important
Use only lowercase letters.
spec refers only to spatial axes (not the feature axis).
Features take a Cartesian product across operands (no implicit feature reduction or alignment). If you need feature concatenation, use &/stack. For per-feature ops, use element-wise maps or binary ops where features must match.
Arrays in operands are accepted and coerced to spatial-constant expressions with a single feature. Only spatial letters in spec are interpreted. If no StateExpr is present, a TypeError is raised because dim cannot be inferred.
Examples
Vector inner product (per-feature), two rank-1 inputs:
>>> # a, b: i × F
>>> c = StateExpr.einsum("i,i->", a, b)  # result: × F
Matrix–vector product (per-feature), rank-2 with rank-1:
>>> # M: ij × F1, v: j × F2 → i × (F1×F2)
>>> y = StateExpr.einsum("ij,j->i", M, v)
Outer product (per-feature Cartesian product):
>>> # u: i × F1, v: j × F2 → ij × (F1×F2)
>>> O = StateExpr.einsum("i,j->ij", u, v)
- Parameters:
spec (str) – An einsum string over spatial indices, e.g. "ij,j->i".
operands (mix[StateExpr, array-like]) – Any mix of StateExpr and arrays.
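The Cartesian product of features described for einsum() can be mimicked with np.einsum by giving each operand its own feature letter. This is a sketch on plain arrays, assuming the spatial · features layout from the examples above. The flattening order of the combined feature axis (first operand's features varying slowest) is an assumption, not something this page specifies.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, F1, F2 = 3, 2, 4
M = rng.normal(size=(dim, dim, F1))   # ij × F1
v = rng.normal(size=(dim, F2))        # j × F2

# Spatial spec "ij,j->i". The feature letters f and g are never contracted,
# so every feature of M is paired with every feature of v.
paired = np.einsum("ijf,jg->ifg", M, v)        # shape (dim, F1, F2)

# Flatten the feature pair into one feature axis of size F1 * F2.
y = paired.reshape(dim, F1 * F2)               # i × (F1×F2)
```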
- elementwisemap(func, *, label_fn=None)¶
Apply func element-wise to every feature (spatial axes untouched).
func must be a pure JAX function from scalar→scalar (rank-0 arrays OK). If the expression carries feature labels (e.g., a Basis or an SF bound from a Basis), label_fn (if provided) is applied to each feature label.
Example
>>> B = ...  # Basis with 4 features
>>> C = B.elementwisemap(jnp.tanh, label_fn=lambda s: f"tanh({s})")
- Parameters:
func (Callable[[Array], Array])
label_fn (Callable[[str], str] | None)
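The two effects of elementwisemap() (values transformed pointwise, labels rewritten one by one) can be sketched with a plain array and a list of strings. The labels and values below are made up for illustration, and NumPy's tanh stands in for the JAX function.

```python
import numpy as np

labels = ["x", "y", "x^2", "x*y"]                 # hypothetical feature labels
values = np.array([[0.0, 0.5, 1.0, -2.0]])        # batch · n_features

label_fn = lambda s: f"tanh({s})"

new_values = np.tanh(values)                      # applied to every entry, shape unchanged
new_labels = [label_fn(s) for s in labels]        # applied to each label independently
```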
- estimate_bytes_per_sample(*, dtype=None, particle_size=None, sample=None, mode='forward')¶
Small convenience wrapper returning only the transient bytes/sample.
- Parameters:
particle_size (int | None)
sample (SampleMeta | None)
mode (str)
- Return type:
int
- features_to_rank(rank)¶
Unfold features into spatial axes → given rank.
The output layout changes from the current:
batch · (dim,)^self.rank · n_features
to:
batch · (dim,)^rank · (n_features / dim^(rank − self.rank),)
where the new innermost spatial axes are carved out of the feature axis. This is a pure reshape and is the exact inverse of rank_to_features() when restoring the original rank.
- Parameters:
rank (int) – Target tensor rank (must be greater than the current rank).
- Returns:
Expression at the requested rank with fewer features.
- Return type:
StateExpr (same subclass)
- Raises:
ValueError – If n_features is not divisible by dim^Δrank.
TypeError – If rank ≤ self.rank (use rank_to_features to go down).
Examples
Turn a dense layer’s output back into a vector field:
>>> scalar_expr.features_to_rank(1) # rank-1, F/dim features
Build a 2→H→2 MLP force field (here H = 32):
>>> from SFI.bases import X
>>> import jax.numpy as jnp
>>> mlp = (
...     X(dim=2)
...     .rank_to_features()                    # rank-0, 2 features
...     .dense(32, weight="W1", bias="b1")
...     .elementwisemap(jnp.tanh)
...     .dense(2, weight="W2", bias="b2")      # rank-0, 2 features
...     .features_to_rank(1)                   # rank-1, 1 feature
... )
- property labels¶
- memory_hint(*, dtype=None, particle_size=None, sample=None, mode='forward')¶
Conservative per-sample memory footprint for the WHOLE expression tree. Delegates to the root node, which sums children + own output along the way.
- Parameters:
particle_size (int | None)
sample (SampleMeta | None)
mode (str)
- property n_features¶
- property needs_v¶
- property particle_extras: tuple[str, ...]¶
Pure metadata, forwarded from the root node.
Names of extras declared as per-particle somewhere in the underlying node tree (typically by interaction leaves). The dispatcher reads this to know which keys to gather from (P, …) into (E, K, …) per edge before calling locals.
- property particles_input¶
- property pdepth¶
- property rank¶
- rank_to_features()¶
Fold all spatial (rank) axes into the feature axis → rank-0.
The output layout changes from:
batch · (dim,)^rank · n_features
to:
batch · (n_features × dim^rank,)
with rank = 0. This is a pure reshape (no copy, no learnable parameters) and is the exact inverse of features_to_rank(original_rank).
- Returns:
Scalar expression whose feature count is self.n_features × self.dim ** self.rank.
- Return type:
StateExpr (same subclass)
- Raises:
TypeError – If the expression is already rank‑0 (no-op would be confusing).
Examples
Prepare a rank-1 position vector for dense layers:
>>> X(dim=2).rank_to_features() # rank-0, 2 features
The round-trip is the identity:
>>> expr.rank_to_features().features_to_rank(expr.rank) # same as expr
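Because both rank_to_features() and features_to_rank() are described as pure reshapes, the round trip can be reproduced with NumPy's row-major reshape. This sketch assumes the layout batch · (dim,)^rank · n_features given above; the divisibility check mirrors the documented ValueError but is not SFI's code.

```python
import numpy as np

B, dim, rank, F = 4, 2, 1, 3
x = np.arange(B * dim * F, dtype=float).reshape(B, dim, F)   # batch · (dim,)^1 · n_features

# rank_to_features analogue: fold the spatial axis into the feature axis
folded = x.reshape(B, F * dim**rank)            # batch · (n_features × dim^rank,)

# features_to_rank(1) analogue: carve the spatial axis back out of the features
n_feat = folded.shape[-1]
if n_feat % dim**rank != 0:
    raise ValueError("n_features is not divisible by dim**delta_rank")
restored = folded.reshape((B,) + (dim,) * rank + (n_feat // dim**rank,))
```

The restored array equals x element for element, which is what "exact inverse" means for the reshape pair.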
- property required_extras: tuple[str, ...]¶
Presence-only extras required by the expression, forwarded from the root node. No shape/broadcast semantics here.
- root: BaseNode¶
- property sdims¶
- specialize(*, dataset)¶
Collapse a pooled model to its single-condition specialization.
Returns a new expression in which every dataset_index-reading primitive (e.g. per_dataset_scalar(), dataset_indicator()) is folded at condition dataset: per-condition parameter arrays collapse to that condition’s slice and the reserved dataset_index extra drops out of required_extras. The pooled-ness is an inference-time concern; once a condition is chosen the model stands alone (no dataset concept).
On a bound SF the stored parameter values are projected to match the shrunken template; on an unbound PSF the template’s per-condition specs become scalars.
- Parameters:
dataset (int)
- Return type:
- sqrtm()¶
- classmethod stack(exprs)¶
Concatenate along the feature axis.
Static contracts must match (rank/dim, compatible pdepth).
- Parameters:
exprs (Sequence[StateExpr])
- tensordot(other, axes=1)¶
Alias of .dot with NumPy-compatible axes.
- tensorize(dim=None, mode='symmetric')¶
Lift a scalar expression to rank-2 (matrix).
- Parameters:
dim (int, optional) – Spatial dimension. Inferred when possible.
mode (str) – 'symmetric' (default) uses symmetric_matrix_basis() (d(d+1)/2 features per scalar feature, spans all symmetric matrices). 'identity' uses identity_matrix_basis() (1 feature per scalar feature, isotropic).
- Returns:
Matrix expression.
- Return type:
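The feature counts quoted for the two tensorize() modes can be checked with a small NumPy sketch. The helper symmetric_basis below is a hypothetical stand-in for symmetric_matrix_basis(): its normalisation, its ordering, and the flattening order of the combined feature axis are assumptions, not SFI's definitions.

```python
import numpy as np

def symmetric_basis(d):
    """Hypothetical stand-in: one matrix per pair i <= j, shape (d(d+1)/2, d, d)."""
    mats = []
    for i in range(d):
        for j in range(i, d):
            E = np.zeros((d, d))
            E[i, j] = E[j, i] = 1.0     # unnormalised; SFI may scale these differently
            mats.append(E)
    return np.stack(mats)

d, B, F = 3, 5, 2
s = np.random.default_rng(0).normal(size=(B, F))    # scalar expression: batch · n_features

# 'symmetric' analogue: each scalar feature times each basis matrix
S = symmetric_basis(d)                               # (6, 3, 3) for d = 3
T_sym = np.einsum("bf,kij->bijfk", s, S).reshape(B, d, d, F * S.shape[0])

# 'identity' analogue: each scalar feature times the identity matrix
T_iso = np.einsum("bf,ij->bijf", s, np.eye(d))       # feature count unchanged
```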
- to_psf(coeff_key='coeff', drop_features=True)[source]¶
Return a parametric state function whose value is a linear combination of this Basis’ features:
F(x; θ) = Σ_j θ_j · f_j(x)
Note that use cases are rare within SFI, since the PSF’s features axis is typically used for nonlinearities and/or vector/tensor components. But this can be useful for quick prototyping of linear models, benchmark comparisons of linear vs nonlinear solvers, or as a building block for more complex PSFs.
- Parameters:
coeff_key (str) – Key name for the coefficient vector in the parameter dict.
drop_features (bool) – Whether to remove the trailing size-1 feature axis (default True).
Notes
The resulting PSF shares the same spatial contract (rank/dim/pdepth, particles_input) as this Basis and, with the default drop_features=True, does not have a features axis.
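The linear combination F(x; θ) = Σ_j θ_j · f_j(x) is a contraction over the feature axis. The NumPy sketch below assumes a rank-1 Basis output laid out as batch · (dim,) · n_features and a parameter dict keyed by the default coeff_key; the arrays are random placeholders, not SFI objects.

```python
import numpy as np

rng = np.random.default_rng(0)
B, dim, F = 4, 2, 3
features = rng.normal(size=(B, dim, F))     # f_j(x): batch · (dim,) · n_features
params = {"coeff": rng.normal(size=(F,))}   # θ, stored under coeff_key="coeff"

# F(x; θ) = Σ_j θ_j · f_j(x): the feature axis is contracted away
combined = features @ params["coeff"]       # shape (B, dim)

# What keeping a trailing size-1 feature axis would look like instead
kept = combined[..., None]                  # shape (B, dim, 1)
```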
- vectorize(dim=None, axes=None)¶
Lift a scalar expression to rank-1 (vector).
Equivalent to self * unit_vector_basis(dim, axes=axes), i.e. a Cartesian product of the feature axis with unit vectors.
- Parameters:
dim (int, optional) – Spatial dimension. Inferred from the expression’s contract when possible.
axes (sequence of int, optional) – Subset of spatial axes to include (default: all dim axes).
- Returns:
Vector expression with n_features = self.n_features × len(axes).
- Return type:
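The Cartesian product with unit vectors that vectorize() describes can be written out on plain arrays. In this NumPy sketch the rows of an identity matrix stand in for unit_vector_basis(dim, axes=axes). The flattening order of the combined feature axis (scalar feature varying slowest) is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
B, F, dim = 4, 3, 2
s = rng.normal(size=(B, F))          # scalar expression: batch · n_features
axes = [0, 1]                        # default: all dim axes
units = np.eye(dim)[axes]            # (len(axes), dim): unit vector e_a per chosen axis

# Every scalar feature times every unit vector, then one flat feature axis
vec = np.einsum("bf,ai->bifa", s, units).reshape(B, dim, F * len(axes))
```

Each output feature is a scalar feature placed along exactly one spatial component, so the feature count is F × len(axes), here 6.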