# Source code for SFI.statefunc.params

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Iterable, Tuple

import equinox as eqx
import jax.numpy as jnp


# ---------------------------------------------------------------------
# ParamSpec  - static description of one parameter block
# ---------------------------------------------------------------------
[docs] @dataclass(frozen=True, slots=True) class ParamSpec: name: str # <-- KEY for sharing shape: Tuple[int, ...] dtype: Any = jnp.float32 init: Callable | str = "zeros" # PRNG->array or keyword default: Any = None # optional concrete value (scalar or array-like) @property def size(self) -> int: from functools import reduce from operator import mul return reduce(mul, self.shape, 1)
[docs] def compatible_with(self, other: "ParamSpec") -> bool: """Shareable iff shape and dtype match exactly.""" return (self.shape == other.shape) and (self.dtype == other.dtype)
[docs] def merged_with(self, other: "ParamSpec") -> "ParamSpec": """ Return a single spec representing the shared parameter. Requires compatibility; keeps `self.init` by default. """ if not self.compatible_with(other): raise ValueError( f"ParamSpec mismatch for '{self.name}': {self.shape}/{self.dtype} vs {other.shape}/{other.dtype}" ) # deterministically keep `self`, but salvage a default from `other` if self.default is None and other.default is not None: from dataclasses import replace return replace(self, default=other.default) return self
class ParamSuite(eqx.Module):
    """Immutable container holding a set of ``ParamSpec`` objects.

    Spec order is significant: it defines the flat layout used by
    :meth:`vectorize` / :meth:`materialize` and the order of :meth:`stamp`.
    """

    specs: tuple[ParamSpec, ...] = eqx.field(static=True)
    _lookup: dict[str, int] = eqx.field(static=True)  # name → idx into ``specs``
    # NOTE(review): ``_lookup`` is a dict stored as a static field. JAX/equinox
    # generally expect static data to be hashable; confirm a ParamSuite is never
    # hashed as part of a jit cache key, or switch to a hashable representation.

    # ---------- construction ---------- #
    def __init__(self, specs: Iterable[ParamSpec]):
        """Build a suite from ``specs`` (consumed once, order preserved).

        Raises:
            ValueError: if two specs share the same ``name``.
        """
        from collections import Counter

        specs = tuple(specs)
        counts = Counter(ps.name for ps in specs)
        if len(counts) != len(specs):
            dup = {n for n, c in counts.items() if c > 1}
            raise ValueError(f"Duplicate parameter names: {dup}")
        object.__setattr__(self, "specs", specs)
        object.__setattr__(self, "_lookup", {ps.name: i for i, ps in enumerate(specs)})

    @classmethod
    def from_specs(cls, *specs: ParamSpec) -> "ParamSuite":
        """Variadic convenience constructor: ``ParamSuite.from_specs(a, b, ...)``."""
        return cls(specs)

    # ---------- basic info ---------- #
    def __iter__(self):
        return iter(self.specs)

    def __len__(self):
        return len(self.specs)

    def __getitem__(self, name: str) -> ParamSpec:
        """Look up a spec by name; raises ``KeyError`` for unknown names."""
        return self.specs[self._lookup[name]]

    @property
    def size(self) -> int:
        """Total number of scalar elements across all specs (flat vector length)."""
        return sum(ps.size for ps in self.specs)

    # ---------- convenience constructors ---------- #
    def zeros(self) -> dict[str, jnp.ndarray]:
        """Return a parameter dict with all values initialized to zero."""
        return {ps.name: jnp.zeros(ps.shape, dtype=ps.dtype) for ps in self.specs}

    @property
    def has_defaults(self) -> bool:
        """True iff every spec in this suite carries a concrete ``default``."""
        return len(self.specs) > 0 and all(ps.default is not None for ps in self.specs)

    def defaults(self) -> dict[str, jnp.ndarray] | None:
        """Return a parameter dict from spec ``default`` values, or None if any
        spec has no default (or the suite is empty).

        Values are cast to the spec dtype and broadcast to the declared shape.
        """
        if not self.has_defaults:
            return None
        out: dict[str, jnp.ndarray] = {}
        for ps in self.specs:
            arr = jnp.asarray(ps.default, dtype=ps.dtype)
            if arr.shape != ps.shape:
                arr = jnp.broadcast_to(arr, ps.shape)
            out[ps.name] = arr
        return out

    # ---------- param ↔ vector helpers ---------- #
    def materialize(
        self,
        vector: jnp.ndarray,
        *,
        dtype_overrides: dict[str, jnp.dtype] | None = None,
    ) -> dict[str, jnp.ndarray]:
        """Split a flat vector into a ``{name: array}`` dict (inverse of :meth:`vectorize`).

        Args:
            vector: 1-D array of length ``self.size``, consumed in spec order.
            dtype_overrides: optional per-name dtype used instead of the spec dtype.

        Raises:
            ValueError: if ``vector`` is not 1-D or its length is not ``self.size``.
        """
        if vector.ndim != 1 or vector.size != self.size:
            raise ValueError("Flat vector has wrong length")
        overrides = dtype_overrides or {}
        out: dict[str, jnp.ndarray] = {}
        offset = 0
        for ps in self.specs:
            n = ps.size
            out[ps.name] = (
                vector[offset : offset + n]
                .reshape(ps.shape)
                .astype(overrides.get(ps.name, ps.dtype))
            )
            offset += n
        return out

    def vectorize(self, tree: dict[str, jnp.ndarray]) -> jnp.ndarray:
        """Flatten a ``{name: array}`` dict into one 1-D vector in spec order.

        Each leaf is raveled and cast to its spec dtype. An empty suite yields a
        zero-length float32 vector, matching what :meth:`materialize` accepts.

        Raises:
            KeyError: if a spec name is missing from ``tree``.
            ValueError: if a leaf's element count differs from its spec's size
                (which would otherwise silently shift every later parameter).
        """
        parts = []
        for ps in self.specs:
            leaf = jnp.ravel(tree[ps.name])
            if leaf.size != ps.size:
                raise ValueError(
                    f"Param '{ps.name}': expected {ps.size} elements (shape {ps.shape}), got {leaf.size}"
                )
            parts.append(leaf.astype(ps.dtype))
        if not parts:
            # jnp.concatenate rejects an empty sequence.
            return jnp.zeros((0,), dtype=jnp.float32)
        return jnp.concatenate(parts, axis=0)

    # ---------- merging (parameter sharing) ---------- #
    def merge(self, other: "ParamSuite | None") -> "ParamSuite":
        """
        Union with sharing-by-name:
          - If a name appears in both suites and specs are compatible (shape/dtype),
            they are **tied** (kept once, via ``ParamSpec.merged_with``).
          - If incompatible → ``ValueError``.

        Order: all of ``self``'s specs first, then ``other``'s new names.

        Raises:
            TypeError: if ``other`` is neither a ParamSuite nor None.
            ValueError: on an incompatible shared name.
        """
        if other is None:
            return self
        if not isinstance(other, ParamSuite):
            raise TypeError(f"merge expects ParamSuite or None, got {type(other).__name__}")
        if not self.specs:
            return other
        if not other.specs:
            return self

        out: dict[str, ParamSpec] = {ps.name: ps for ps in self.specs}
        for ps_r in other.specs:
            ps_l = out.get(ps_r.name)
            # merged_with validates compatibility and raises on mismatch.
            out[ps_r.name] = ps_r if ps_l is None else ps_l.merged_with(ps_r)
        return ParamSuite(out.values())

    @classmethod
    def merge_many(cls, *suites: "ParamSuite | None") -> "ParamSuite | None":
        """Merge any number of suites, sharing parameters by name (shape/dtype must
        match). ``None`` entries are skipped; returns None if every entry is None."""
        merged: ParamSuite | None = None
        for s in suites:
            if s is None:
                continue
            merged = s if merged is None else merged.merge(s)
        return merged

    # ---------- PyTree protocol ---------- #
    # NOTE(review): eqx.Module already registers its own pytree flattening, so
    # these two methods look unused by JAX itself; kept for callers that may use
    # them directly — confirm before removing.
    def tree_flatten(self):
        children = ()  # no array leaves
        aux = self.specs  # static
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(aux)

    # ---------- universal parser ---------- #
    @classmethod
    def parse(cls, params) -> "ParamSuite | None":
        """
        Normalize various user-facing descriptions into a ParamSuite.

        Accepts:
          - ``None``                          -- returns ``None``
          - ``ParamSuite``                    -- returned as-is
          - ``dict[name -> array | shape]``   -- infer shape/dtype
          - ``iterable[ParamSpec]``           -- from_specs
          - ``iterable[str]``                 -- scalar specs for each name
          - a single ``str`` or ``ParamSpec`` -- treated as a one-item iterable

        Shapes may be ``()``, ``(k,)``, ``(m, n, ...)`` or an integer k
        (interpreted as ``(k,)``). In the dict path a ``None`` value means a
        scalar, and a ``ParamSpec`` value is used as-is (its own name is kept).

        Raises:
            TypeError: for unsupported inputs or dict values.
        """
        from collections.abc import Iterable as _Iterable

        if params is None:
            return None
        if isinstance(params, ParamSuite):
            return params
        # A bare str is itself iterable; without this it would be split into
        # one-character parameter names ("ab" -> "a", "b").
        if isinstance(params, (str, ParamSpec)):
            params = (params,)

        # dict path: values can be arrays (sample), shapes, ints, ParamSpec, or None→scalar
        if isinstance(params, dict):
            specs: list[ParamSpec] = []
            for name, val in params.items():
                if isinstance(val, ParamSpec):
                    specs.append(val)
                elif hasattr(val, "shape"):
                    dtype = getattr(val, "dtype", jnp.float32)
                    specs.append(ParamSpec(name, tuple(val.shape), dtype=dtype))
                elif val is None:
                    specs.append(ParamSpec(name, ()))
                elif isinstance(val, int):
                    specs.append(ParamSpec(name, (int(val),)))
                elif isinstance(val, tuple) and all(isinstance(n, int) for n in val):
                    specs.append(ParamSpec(name, tuple(int(n) for n in val)))
                else:
                    raise TypeError(f"ParamSuite.parse: unsupported dict value for '{name}': {type(val).__name__}")
            return cls.from_specs(*specs)

        # iterable path: ParamSpec or str
        if isinstance(params, _Iterable):
            specs = []
            for it in params:
                if isinstance(it, ParamSpec):
                    specs.append(it)
                elif isinstance(it, str):
                    specs.append(ParamSpec(it, ()))
                else:
                    raise TypeError(
                        f"ParamSuite.parse: iterable must contain ParamSpec or str (got {type(it).__name__})"
                    )
            return cls.from_specs(*specs)

        raise TypeError(f"ParamSuite.parse: unsupported params of type {type(params).__name__}")

    def coerce(
        self,
        params: dict,
        *,
        allow_scalar_for_scalar: bool = True,
        allow_scalar_to_len1: bool = True,
        cast_dtype: bool = True,
    ) -> dict[str, jnp.ndarray]:
        """
        Normalize a user param dict into JAX arrays matching this suite.

        Rules:
          - If spec.shape == (): 0-d values (Python scalars / 0-d arrays) are
            always accepted; a length-1 ``(1,)`` value is squeezed to ``()`` only
            if ``allow_scalar_for_scalar``.
          - If spec.shape == (1,): a scalar is expanded to ``(1,)`` if
            ``allow_scalar_to_len1``.
          - Otherwise, the shape must match exactly.
          - dtype is cast if ``cast_dtype``; otherwise a mismatch raises.

        Keys in ``params`` that are not in the suite are ignored.
        Returns a NEW dict with normalized arrays.

        Raises:
            KeyError: if any suite parameter is missing from ``params``.
            TypeError: on a shape or (with ``cast_dtype=False``) dtype mismatch.
        """
        out = {}
        look = self._lookup
        missing = [k for k in look if k not in params]
        if missing:
            raise KeyError(f"Missing params: {missing}")

        for spec in self:
            name, shape, dtype = spec.name, spec.shape, spec.dtype
            val = params[name]
            arr = jnp.asarray(val, dtype=dtype if cast_dtype else None)

            # Handle scalar vs length-1 convenience
            if shape == ():  # true scalar
                if arr.shape == ():
                    pass
                elif arr.shape == (1,) and allow_scalar_for_scalar:
                    arr = jnp.reshape(arr, ())
                else:
                    raise TypeError(f"Param '{name}': expected scalar (), got {arr.shape}")
            elif shape == (1,):
                if arr.shape == ():
                    if allow_scalar_to_len1:
                        arr = jnp.reshape(arr, (1,))
                    else:
                        raise TypeError(f"Param '{name}': expected (1,), got scalar ()")
                elif arr.shape != (1,):
                    raise TypeError(f"Param '{name}': expected {shape}, got {arr.shape}")
            elif arr.shape != shape:
                raise TypeError(f"Param '{name}': expected {shape}, got {arr.shape}")

            # Final dtype enforcement
            if cast_dtype:
                arr = arr.astype(dtype)
            elif arr.dtype != dtype:
                raise TypeError(f"Param '{name}': expected dtype {dtype}, got {arr.dtype}")

            out[name] = arr
        return out

    def stamp(self, params_dict: dict) -> tuple:
        """
        Build a stable stamp of ``(name, shape, dtype)`` per param, in spec order.

        Caller is expected to pass an already-coerced dict (e.g. from :meth:`coerce`).
        """
        return tuple((spec.name, params_dict[spec.name].shape, params_dict[spec.name].dtype) for spec in self)