Source code for SFI.statefunc.psf

"""PSF façade: parametric family of state functions."""

import equinox as eqx
import jax
import jax.numpy as jnp

from .nodes import BaseNode, DerivativeNode
from .params import ParamSuite
from .stateexpr import StateExpr


# ============================================================================
#  PSF  – parametric state-function family
# ============================================================================
class PSF(StateExpr):
    """Parametric State-Function family `F(x; θ)`.

    By default `drop_features=True`: when `n_features==1`, outputs **do not**
    carry a trailing feature axis. `.d_theta()` forces `drop_features=False`.

    Holds a `ParamSuite` template describing names, shapes, and dtypes of θ.
    `__call__` evaluates `F` given a parameter dict matching the template.
    Supports `.d_theta()` in addition to `.d_x()`/`.d_v()`.
    """

    # Parameter template taken from ``root.param_suite`` at construction.
    template: ParamSuite = eqx.field(static=True)
    # Squeeze the trailing feature axis of outputs when n_features == 1.
    drop_features: bool = eqx.field(static=True, default=True)
    # Last sorted (name, shape, dtype) signature seen by `_coerce_and_check`.
    _validated_stamp: tuple | None = eqx.field(static=True, default=None, repr=False)

    def __init__(self, root: BaseNode, *, drop_features: bool = True):
        """Wrap ``root`` as a parametric state-function family.

        Parameters
        ----------
        root : BaseNode
            Expression root; must carry a ``param_suite`` (possibly empty).
        drop_features : bool, keyword-only
            If True, outputs with a single feature lose their trailing axis.

        Raises
        ------
        ValueError
            If ``root.param_suite`` is None.
        """
        if root.param_suite is None:
            raise ValueError("PSF root must carry parameters")
        super().__init__(root)
        # Equinox modules are frozen; object.__setattr__ bypasses that inside __init__.
        object.__setattr__(self, "template", root.param_suite)
        object.__setattr__(self, "drop_features", bool(drop_features))

    def _resolve_params(self, params, where: str):
        """Return ``params``, or a fallback parameter dict when it is None.

        Fallback order (shared by ``__call__`` and ``bind``):
        ``{}`` for a parameter-free template (``template.size == 0``),
        otherwise ``template.defaults()``.

        Parameters
        ----------
        params : dict or None
            User-supplied parameters; returned unchanged when not None.
        where : str
            Prefix for the error message identifying the public entry point.

        Raises
        ------
        ValueError
            If ``params`` is None and the template has parameters without defaults.
        """
        if params is not None:
            return params
        if self.template.size == 0:
            return {}  # parameter-free PSF: auto-supply
        defaults = self.template.defaults()
        if defaults is None:
            raise ValueError(
                f"{where}: params are required (template has parameters without defaults)."
            )
        return defaults

    def _coerce_and_check(self, params, extras):
        """Normalize ``params`` through the template and check ``extras`` presence.

        Returns the normalized parameter dict produced by ``template.coerce``.
        """
        # Normalize/validate via the suite
        pnorm = self.template.coerce(params)
        # Signature of the normalized params; iterates ``template._lookup``
        # (presumably keyed by parameter name — TODO confirm in ParamSuite).
        stamp = tuple(sorted((k, pnorm[k].shape, pnorm[k].dtype) for k in self.template._lookup))
        if stamp != self._validated_stamp:
            # NOTE(review): this writes a `static=True` field on a frozen equinox
            # Module after construction. Static fields are part of the PyTree
            # structure, so this can change equality/treedef between calls. The
            # stamp does not gate any work here (coerce and the extras check run
            # every time); consider removing it if nothing else reads it.
            object.__setattr__(self, "_validated_stamp", stamp)
        # Only presence for extras here; shape/broadcast is handled in node contracts
        self._validate_extras_presence(extras)
        return pnorm

    def __call__(self, x, *, v=None, mask=None, extras=None, params=None):
        """Evaluate ``F(x; θ)``.

        Parameters
        ----------
        x : array
            State input forwarded to the expression.
        v, mask, extras : optional, keyword-only
            Forwarded unchanged to the inherited evaluator; ``extras`` is
            checked for presence first.
        params : dict or None, keyword-only
            Parameter dict matching ``template``. If None, use ``{}`` for a
            parameter-free template, else the template defaults.

        Returns
        -------
        array
            The evaluated output. Its trailing axis is squeezed when
            ``drop_features`` is True and ``n_features == 1``.

        Raises
        ------
        ValueError
            If ``params`` is None and some template parameter has no default.
        """
        params = self._resolve_params(params, "PSF.__call__")
        params = self._coerce_and_check(params, extras)
        y = self._caller(x, v, mask, extras, params)
        if self.drop_features and self.n_features == 1:
            return jnp.squeeze(y, axis=-1)
        return y

    def bind(self, params: dict[str, jax.Array] | None = None):
        """Freeze parameter dict into an SF with normalized arrays.

        If ``params is None``, use ``{}`` for a parameter-free template (as
        ``__call__`` does), else fall back to spec defaults
        (``ParamSuite.defaults()``).

        Raises
        ------
        ValueError
            If ``params`` is None and the template has parameters without defaults.
        """
        from .sf import SF  # local import, as in the original

        params = self._resolve_params(params, "PSF.bind()")
        return SF(self, params, drop_features=self.drop_features)

    def d_theta(self, *, mode: str = "auto"):
        """Build an expression for the Jacobian w.r.t. parameters θ.

        Shape effect
        ------------
        The final axis becomes `features × n_params_total`.
        Batch/pdepth/rank prefixes are preserved exactly.

        Notes
        -----
        The parameter PyTree is handled leafwise; each grad leaf is flattened
        over its param part, then all leaves are concatenated along the final
        axis. ``mode`` is forwarded unchanged to ``DerivativeNode``. The result
        is a new ``PSF`` with ``drop_features=False``.
        """
        if self.root.param_suite is None:  # unreachable: enforced by __init__
            raise AttributeError("Expression has no parameters to differentiate")
        node = DerivativeNode(self.root, var="theta", mode=mode)
        return PSF(node, drop_features=False)

    # ------------- SciPy helpers ---------------------------------
    @property
    def labels(self):
        """Basis labels from the underlying CoeffNode (if present).

        If the root is a ``CoeffNode``, return its ``_basis_labels``. Otherwise
        return the second element of ``root.flatten()``.
        """
        from .nodes.ops.linear import CoeffNode

        if isinstance(self.root, CoeffNode):
            return self.root._basis_labels
        _, labs, _ = self.root.flatten()
        return labs

    def flatten_params(self, params: dict[str, jax.Array]):
        """Vectorize a parameter dict according to the template order."""
        return self.template.vectorize(params)

    def unflatten_params(self, vec: jax.Array):
        """Materialize a parameter dict from a flat vector (inverse of `flatten_params`)."""
        return self.template.materialize(vec)