# Source code for SFI.statefunc.sf

"""SF façade: state function with fixed parameters."""

import equinox as eqx
import jax
import jax.numpy as jnp

from .nodes import BaseNode
from .psf import PSF
from .stateexpr import StateExpr, _specialize_node


# ============================================================================
#  SF  – bound (θ-fixed) state-function
# ============================================================================
class SF(StateExpr):
    """Bound state function: a PSF root paired with fixed parameter values.

    Holds the root expression of a parent ``PSF`` together with a concrete
    ``params`` dict, coerced through that PSF's template at construction.
    For evaluation it behaves like a `Basis` (there is no ``.d_theta()``),
    while ``.d_x()`` / ``.d_v()`` expressions can still be built from it.

    Output feature axis: when ``drop_features`` is true and the expression
    has exactly one feature, the trailing axis of the result is squeezed
    away. If ``drop_features`` is not given, the parent PSF's setting is used.
    """

    params: dict[str, jax.Array] = eqx.field(static=False, repr=False)
    _psf: PSF = eqx.field(static=True, repr=False)
    drop_features: bool = eqx.field(static=True, default=True)

    def __init__(
        self,
        psf: PSF,
        params: dict[str, jax.Array],
        *,
        drop_features: bool | None = None,
    ):
        super().__init__(psf.root)
        coerced = psf.template.coerce(params)
        # Fields are written with object.__setattr__ (as in the original
        # construction path) rather than plain attribute assignment.
        object.__setattr__(self, "params", coerced)
        object.__setattr__(self, "_psf", psf)
        if drop_features is None:
            resolved_drop = psf.drop_features
        else:
            resolved_drop = bool(drop_features)
        object.__setattr__(self, "drop_features", resolved_drop)

    def __call__(self, x, *, v=None, mask=None, extras=None):
        """Evaluate the bound function on a **batched** input.

        The stored ``params`` are forwarded to the underlying caller. The
        trailing feature axis is squeezed only for single-feature outputs
        with ``drop_features`` enabled.
        """
        self._validate_extras_presence(extras)
        result = self._caller(x, v, mask, extras, self.params)
        squeeze_last = self.drop_features and self.n_features == 1
        return jnp.squeeze(result, axis=-1) if squeeze_last else result

    @property
    def labels(self):
        """Basis labels, delegated to the parent PSF."""
        return self._psf.labels

    def _with_node(self, new_root: BaseNode):
        """Rebuild this bound function around ``new_root``, reusing ``params``."""
        keep_drop = self.drop_features
        rebuilt_psf = PSF(new_root, drop_features=keep_drop)
        return SF(rebuilt_psf, self.params, drop_features=keep_drop)

    def specialize(self, *, dataset: int) -> "SF":
        """Return this bound function specialized to condition ``dataset``.

        The expression graph is rewritten by folding the leaves that read
        ``dataset_index``. The bound parameter values are then projected onto
        the resulting (smaller) template. A per-condition parameter whose
        leading axis disappears is sliced at ``dataset``, and a parameter
        whose shape is unchanged is carried over as-is.
        """
        index = int(dataset)
        keep_drop = self.drop_features
        folded_root = _specialize_node(self.root, index)
        specialized_psf = PSF(folded_root, drop_features=keep_drop)
        projected = _project_params(
            self.params, self._psf.template, specialized_psf.template, index
        )
        return SF(specialized_psf, projected, drop_features=keep_drop)
def _project_params(old_params, old_template, new_template, k: int) -> dict:
    """Map a bound parameter dict onto a specialized template.

    For each spec in ``new_template``:

    * keep the old value verbatim when the shape is unchanged (a parameter
      shared across conditions);
    * slice the leading axis at ``k`` when specialization dropped it (the
      per-condition case, e.g. ``(K,) -> ()``).

    Names listed in ``new_template`` but missing from ``old_params`` are
    skipped and left for the caller to fill in. Entries of ``old_params``
    whose names do not appear in ``new_template`` are not copied.

    Args:
        old_params: Bound parameter values keyed by name.
        old_template: Template the values were bound against. It is indexed
            by name, and each entry exposes ``.shape``.
        new_template: Iterable of specs exposing ``.name`` and ``.shape``.
        k: Condition index used to slice per-condition parameters. Negative
            values index from the end, as with ordinary indexing.

    Returns:
        A new dict keyed by the names of ``new_template``.

    Raises:
        ValueError: If a shape change is not a pure leading-axis drop, or if
            ``k`` is out of range for a per-condition parameter's leading
            axis.
    """
    out: dict = {}
    for spec in new_template:
        name = spec.name
        if name not in old_params:
            continue  # genuinely new param (none today); leave to template defaults
        old_val = old_params[name]
        old_shape = tuple(old_template[name].shape)
        new_shape = tuple(spec.shape)
        if old_shape == new_shape:
            # Shared parameter: nothing to project.
            out[name] = old_val
        elif len(old_shape) == len(new_shape) + 1 and old_shape[1:] == new_shape:
            n_conditions = old_shape[0]
            # JAX clamps out-of-bounds gather indices instead of raising, so an
            # invalid ``k`` would silently select the wrong condition's values.
            if not -n_conditions <= k < n_conditions:
                raise ValueError(
                    f"Cannot project param {name!r}: dataset index {k} is out "
                    f"of range for {n_conditions} condition(s) "
                    f"(shape {old_shape})."
                )
            out[name] = old_val[k]  # per-condition slice
        else:
            raise ValueError(
                f"Cannot project param {name!r} from shape {old_shape} to "
                f"{new_shape} during specialize(dataset={k})."
            )
    return out