SFI.statefunc.params module

class SFI.statefunc.params.ParamSpec(name: 'str', shape: 'Tuple[int, ...]', dtype: 'Any' = <class 'jax.numpy.float32'>, init: 'Callable | str' = 'zeros', default: 'Any' = None)

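Bases: object

Specification of a single named parameter: its name, shape, dtype, initializer, and optional default value.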

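Parameters:
  • name (str) – parameter name; suites share parameters by name.

  • shape (Tuple[int, ...]) – array shape; () denotes a scalar.

  • dtype (Any) – element dtype (jax.numpy.float32 by default).

  • init (Callable | str) – initializer, given as a callable or by name ('zeros' by default).

  • default (Any) – optional default value, broadcast to shape by ParamSuite.defaults(); None means no default.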

compatible_with(other)

Return True iff the two specs can be shared as one parameter, i.e. their shape and dtype match exactly.

Parameters:

other (ParamSpec)

Return type:

bool

default: Any
dtype: Any
init: Callable | str
merged_with(other)

Return a single spec representing the shared parameter. Requires self.compatible_with(other) to be True; the result keeps self.init by default.

Parameters:

other (ParamSpec)

Return type:

ParamSpec

name: str
shape: Tuple[int, ...]
property size: int
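The sharing rules of compatible_with and merged_with can be illustrated with a small stand-in class. This is a minimal sketch, not the SFI implementation: ParamSpecSketch is a hypothetical name, and treating size as the product of shape (1 for a scalar) is our assumption.

```python
from __future__ import annotations

from dataclasses import dataclass
from math import prod
from typing import Any, Callable, Tuple

import jax.numpy as jnp


@dataclass(frozen=True)
class ParamSpecSketch:
    """Hypothetical stand-in with the same fields as ParamSpec."""

    name: str
    shape: Tuple[int, ...]
    dtype: Any = jnp.float32
    init: Callable | str = "zeros"
    default: Any = None

    @property
    def size(self) -> int:
        # Assumption: size is the number of elements implied by `shape`.
        return prod(self.shape)

    def compatible_with(self, other: ParamSpecSketch) -> bool:
        # Shareable iff shape and dtype match exactly; init/default are ignored.
        return self.shape == other.shape and self.dtype == other.dtype

    def merged_with(self, other: ParamSpecSketch) -> ParamSpecSketch:
        if not self.compatible_with(other):
            raise ValueError(f"cannot share {self.name!r}: shape/dtype differ")
        # Returns self unchanged, so self.init is kept (keeping self.default too is an assumption).
        return self
```

Because compatible_with looks only at shape and dtype, two specs with different initializers can still be tied; merged_with then decides which initializer survives.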
class SFI.statefunc.params.ParamSuite(specs)

Bases: Module

Immutable container holding a set of ParamSpec objects.

Parameters:

specs (tuple[ParamSpec, ...])

coerce(params, *, allow_scalar_for_scalar=True, allow_scalar_to_len1=True, cast_dtype=True)

Normalize a user param dict into JAX arrays matching this suite.

Notes

  • If spec.shape == (), accept Python scalars and 0-d arrays (if allow_scalar_for_scalar is True).

  • If spec.shape == (1,), accept a scalar and expand it to shape (1,) (if allow_scalar_to_len1 is True).

  • Otherwise, require the exact shape. The dtype is cast to spec.dtype if cast_dtype is True.

Returns a NEW dict with normalized arrays.

Parameters:
  • params (dict)

  • allow_scalar_for_scalar (bool)

  • allow_scalar_to_len1 (bool)

  • cast_dtype (bool)

Return type:

dict[str, Array]
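The normalization rules listed under Notes can be sketched per entry as follows. This is a minimal illustration, not SFI's code: coerce_entry and coerce_sketch are hypothetical helpers, allow_scalar_for_scalar is left out because its behavior when False is not spelled out above, and raising ValueError on a shape mismatch is an assumption.

```python
import jax.numpy as jnp


def coerce_entry(value, shape, dtype, *, allow_scalar_to_len1=True, cast_dtype=True):
    """Hypothetical helper mirroring coerce()'s documented per-entry rules."""
    arr = jnp.asarray(value)  # Python scalars, lists, and arrays all become JAX arrays
    if shape == (1,) and arr.ndim == 0 and allow_scalar_to_len1:
        arr = arr.reshape((1,))  # expand a scalar to shape (1,)
    elif arr.shape != tuple(shape):
        # A () spec accepts a 0-d input here, since its shape is already ().
        raise ValueError(f"expected shape {tuple(shape)}, got {arr.shape}")
    if cast_dtype:
        arr = arr.astype(dtype)
    return arr


def coerce_sketch(params, specs, **flags):
    """specs: {name: (shape, dtype)}; returns a NEW dict, leaving `params` untouched."""
    return {
        name: coerce_entry(params[name], shape, dtype, **flags)
        for name, (shape, dtype) in specs.items()
    }
```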

defaults()

Return a parameter dict from spec default values, or None if any spec has no default. Values are broadcast to the declared shape.

Return type:

dict[str, Array] | None

classmethod from_specs(*specs)

Build a ParamSuite from the given ParamSpec objects.

Parameters:

specs (ParamSpec)

Return type:

ParamSuite

property has_defaults: bool

True iff every spec in this suite carries a concrete default.
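A sketch of how defaults() and has_defaults relate, under our assumption that a spec "carries a concrete default" exactly when its default is not None. Both function names and the (name, shape, dtype, default) tuple layout are hypothetical.

```python
import jax.numpy as jnp


def has_defaults_sketch(specs):
    """specs: iterable of (name, shape, dtype, default) tuples (hypothetical layout)."""
    return all(default is not None for _, _, _, default in specs)


def defaults_sketch(specs):
    specs = list(specs)
    if not has_defaults_sketch(specs):
        return None  # one missing default means no default dict at all
    return {
        # Broadcast each default (e.g. a scalar) up to the declared shape.
        name: jnp.broadcast_to(jnp.asarray(default, dtype=dtype), shape)
        for name, shape, dtype, default in specs
    }
```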

materialize(vector, *, dtype_overrides=None)
Parameters:
  • vector (Array)

  • dtype_overrides (dict[str, dtype] | None)

merge(other)

Return the union of the two suites, sharing parameters by name:
  • If a name appears in both suites and the two specs are compatible (same shape and dtype), they are tied into one shared parameter (kept once).

  • If they are incompatible, an error is raised.

Parameters:

other (ParamSuite | None)

Return type:

ParamSuite

classmethod merge_many(*suites)

Merge any number of suites, sharing parameters by name (shape/dtype must match).

Parameters:

suites (ParamSuite | None)

Return type:

ParamSuite | None
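The sharing-by-name rule of merge and merge_many can be sketched on plain tuples of specs. Spec, merge_sketch, and merge_many_sketch are hypothetical; keeping the first occurrence, raising ValueError on a conflict, treating None as a no-op operand, and preserving first-seen order are assumptions made for illustration.

```python
from __future__ import annotations

from functools import reduce
from typing import NamedTuple, Optional


class Spec(NamedTuple):
    name: str
    shape: tuple
    dtype: str


def merge_sketch(a: tuple[Spec, ...], b: Optional[tuple[Spec, ...]]) -> tuple[Spec, ...]:
    if b is None:
        return a  # assumption: merging with None leaves the suite unchanged
    merged = {s.name: s for s in a}  # dicts keep insertion order
    for s in b:
        kept = merged.get(s.name)
        if kept is None:
            merged[s.name] = s
        elif (kept.shape, kept.dtype) != (s.shape, s.dtype):
            raise ValueError(f"incompatible specs for shared name {s.name!r}")
        # else: compatible, so the parameter is tied and kept once
    return tuple(merged.values())


def merge_many_sketch(*suites: Optional[tuple[Spec, ...]]) -> Optional[tuple[Spec, ...]]:
    present = [s for s in suites if s is not None]
    return reduce(merge_sketch, present) if present else None
```

Returning None when every input is None matches the documented return type ParamSuite | None.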

classmethod parse(params)

Normalize various user-facing descriptions into a ParamSuite.

Accepts:

  • None – returns None

  • ParamSuite – returned as-is

  • dict[name -> array | shape] – shape and dtype inferred from each value

  • iterable[ParamSpec] – passed to from_specs

  • iterable[str] – scalar specs for each name

Shapes may be (), (k,), (m, n, ...) or an integer k (interpreted as (k,)).

Return type:

ParamSuite | None
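A sketch of the input conventions accepted by parse. Assumptions (ours): an int or tuple value is read as a shape and anything else as an array, float32 is used when no dtype can be inferred (matching ParamSpec's default dtype), and the ParamSuite and iterable[ParamSpec] cases are omitted. parse_sketch is a hypothetical name that returns plain (shape, dtype) pairs instead of a ParamSuite.

```python
import jax.numpy as jnp


def parse_sketch(params):
    """Return {name: (shape, dtype)}; a hypothetical stand-in for ParamSuite.parse."""
    if params is None:
        return None
    if isinstance(params, dict):
        out = {}
        for name, v in params.items():
            if isinstance(v, int):
                out[name] = ((v,), jnp.float32)  # integer k -> shape (k,)
            elif isinstance(v, tuple):
                out[name] = (tuple(int(d) for d in v), jnp.float32)  # explicit shape
            else:
                arr = jnp.asarray(v)  # an array: infer shape and dtype from it
                out[name] = (arr.shape, arr.dtype)
        return out
    # Otherwise an iterable of names: one scalar spec per name.
    return {str(name): ((), jnp.float32) for name in params}
```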

property size: int
specs: tuple[ParamSpec, ...]
stamp(params_dict)

Build a stable stamp recording the (shape, dtype) of each param, in template order. The caller is expected to pass an already-coerced dict.

Parameters:

params_dict (dict)

Return type:

tuple
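One possible form for such a stamp, assuming "template order" means the order of names in the suite's specs and that dtypes are recorded by their string name. stamp_sketch is hypothetical. Because the result is a tuple of tuples it is hashable, so it could serve as a cache key; that use is our illustration, not something stated above.

```python
import jax.numpy as jnp


def stamp_sketch(names, params_dict):
    """Hypothetical: one (shape, dtype) pair per parameter, in template order."""
    return tuple(
        (tuple(params_dict[n].shape), str(params_dict[n].dtype)) for n in names
    )
```

The iteration follows `names`, not the dict, so two dicts with the same contents in a different insertion order produce the same stamp.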

tree_flatten()
classmethod tree_unflatten(aux, children)
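tree_flatten and tree_unflatten are the hook names JAX uses for classes registered with jax.tree_util.register_pytree_node_class. The sketch below shows that pattern for a hypothetical SpecBag that keeps its specs as static auxiliary data and exposes no array leaves; whether ParamSuite splits its contents the same way is not documented here and is an assumption.

```python
import jax


@jax.tree_util.register_pytree_node_class
class SpecBag:
    """Hypothetical container; its specs are static metadata, not array leaves."""

    def __init__(self, specs):
        self.specs = tuple(specs)

    def tree_flatten(self):
        # (children, aux_data): no children; the hashable specs travel as aux data.
        return (), self.specs

    @classmethod
    def tree_unflatten(cls, aux, children):
        # Rebuild from the aux data alone; `children` is always empty here.
        return cls(aux)
```

In this sketch the specs are part of the tree structure rather than its leaves, so under jax.jit a change of specs would trigger retracing instead of being traced.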
vectorize(tree)
Parameters:

tree (dict[str, Array])

Return type:

Array
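vectorize and materialize appear to convert between a parameter dict and one flat vector. The sketch below assumes (this is not stated above) that vectorize ravels each array and concatenates them in spec order, and that materialize slices the vector back in the same order. Both function names are hypothetical, and dtype_overrides is omitted.

```python
import math

import jax.numpy as jnp


def vectorize_sketch(shapes, tree):
    """Flatten {name: array} into one 1-D vector, following the order of `shapes`."""
    return jnp.concatenate([jnp.ravel(tree[name]) for name in shapes])


def materialize_sketch(shapes, vector):
    """Inverse of vectorize_sketch: slice the vector back into shaped arrays."""
    out, offset = {}, 0
    for name, shape in shapes.items():
        n = math.prod(shape)  # a () shape takes exactly one entry
        out[name] = vector[offset:offset + n].reshape(shape)
        offset += n
    return out
```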

zeros()

Return a parameter dict with all values initialized to zero.

Return type:

dict[str, Array]
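zeros() can be mimicked with jax.numpy.zeros over each spec's shape and dtype. zeros_sketch and its (name, shape, dtype) input layout are hypothetical; the example also shows that the result is a plain dict pytree, so JAX tree utilities apply to it directly.

```python
import jax
import jax.numpy as jnp


def zeros_sketch(specs):
    """specs: iterable of (name, shape, dtype); every value starts at zero."""
    return {name: jnp.zeros(shape, dtype=dtype) for name, shape, dtype in specs}


params = zeros_sketch([("k", (3,), jnp.float32), ("D", (), jnp.float32)])
# The result is an ordinary pytree, so tree utilities apply directly.
shifted = jax.tree_util.tree_map(lambda x: x + 1.0, params)
```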