# Source code for SFI.statefunc.factory

import logging
from typing import Any, Callable, Iterable, Optional, Sequence

import jax.numpy as jnp

from .basis import Basis
from .interactor import Interactor
from .nodes import SimpleLeaf
from .nodes.contract import Rank
from .nodes.leaf import InteractionLeaf
from .params import ParamSpec, ParamSuite
from .psf import PSF
from .sf import SF

logger = logging.getLogger(__name__)


def _probe_call(obj, *, dim: int, rank: int, n_features: int, label: str):
    """Run a zero-valued test call to catch shape/signature errors early.

    Called by the factories when ``dim`` is known at construction time.
    Every failure is logged as a warning rather than raised, so construction
    still succeeds (the user may intend a shape that only works with real
    data).

    Parameters
    ----------
    obj : callable
        The freshly built object (e.g. a ``Basis`` or ``SF``); it is called
        once as ``obj(x)`` with a single all-zero sample.
    dim : int
        Width of the probe input; ``x`` has shape ``(1, dim)``.
    rank, n_features : int
        Accepted for call-site symmetry; not currently used by the check.
    label : str
        Prefix for the log messages (usually the factory name).

    Returns
    -------
    None
    """
    x_probe = jnp.zeros((1, dim))  # (N=1, dim)
    try:
        # Convert inside the try: a return value that cannot become an array
        # is also a failed probe and must not abort construction.
        out = jnp.asarray(obj(x_probe))
    except Exception as exc:  # deliberate best-effort: log, never raise
        logger.warning(
            "%s: probe call with x of shape %s failed — the function may error at runtime.  Original error: %s",
            label,
            x_probe.shape,
            exc,
        )
        return

    # A 0-d result has no leading axis at all; indexing ``shape[0]`` on it
    # would raise IndexError and break the "warn, don't raise" contract.
    if out.ndim == 0 or out.shape[0] != 1:
        logger.warning(
            "%s: probe call returned shape %s but expected leading axis 1 (matching N=1 input).  Check output shape.",
            label,
            out.shape,
        )
        return
    logger.debug("%s: probe call OK, output shape %s", label, out.shape)


def make_basis(
    func: Callable,
    *,
    dim: int | None = None,
    rank: int,
    n_features: int = 1,
    needs_v: bool = False,
    labels: Optional[Sequence[str]] = None,
    descriptor: Any = "custom",
    extras_keys: Optional[Sequence[str]] = None,
    particle_extras: Optional[Sequence[str]] = None,
    specialize_at: Optional[Callable] = None,
) -> Basis:
    """
    Construct a **deterministic Basis** from a *single-sample* user function,
    with **no particle semantics**. Particle axes (if present in ``x`` at call
    time) are treated purely as batch axes and vmapped over.

    ``particle_extras`` names extras keys whose values are **per-sample**
    arrays aligned with the batch/particle axes (e.g. an ``extras_local``
    entry of shape ``(N, ...)``): they are vmapped alongside ``x``, so the
    single-sample function sees *its own* particle's value instead of the
    whole array — the route to per-particle terms in single-particle bases
    (home-range centres, individual labels, ...).

    User function signature — declare **only** the kwargs you need:

    - Simplest: ``f(x) -> array``
    - With velocity: ``f(x, *, v) -> array``
    - With extras: ``f(x, *, extras) -> array``

    The full signature is ``f(x, *, v=None, mask=None, extras=None)``;
    we introspect and pass only the kwargs you declare.

    Shapes (single sample)::

        x: (dim,)
        return: (*rank_axes, m)   # feature last; m == n_features

    If ``n_features == 1`` you may omit the last axis; a singleton feature
    axis is auto-inserted.

    Extras
    ~~~~~~
    If ``extras`` is declared, you may provide ``extras_keys=(...)`` to
    enforce keys. Extras arrays must broadcast over the **batch prefix**
    (never over rank/feature).

    Construction-time probe
    ~~~~~~~~~~~~~~~~~~~~~~~
    When ``dim`` is given and the function needs neither ``v`` nor any
    extras (``extras_keys`` / ``particle_extras`` both empty), the new Basis
    is called once on a zero input of shape ``(1, dim)``. Any failure is
    logged as a warning; construction never fails because of the probe.

    Parameters
    ----------
    func : callable
        Single-sample user function (see the signature notes above).
    dim : int or None
        Input dimensionality; ``None`` leaves it unspecified here and
        disables the construction-time probe.
    rank : int
        Tensor rank of the output (wrapped in ``Rank``).
    n_features : int
        Number of output features (default 1).
    needs_v : bool
        Whether ``func`` requires the velocity ``v``.
    labels : sequence of str or None
        Feature labels; defaults to ``("f0", "f1", ...)``.
    descriptor : any
        Metadata tag stored on the leaf node.
    extras_keys : sequence of str or None
        Required keys in the ``extras`` mapping.
    particle_extras : sequence of str or None
        Extras keys vmapped alongside ``x`` (see above).
    specialize_at : callable or None
        Forwarded unchanged to the underlying ``SimpleLeaf``.

    Returns
    -------
    Basis

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from SFI.statefunc import make_basis
    >>> B = make_basis(lambda x: x, dim=2, rank=1, n_features=1)  # equivalent to the built-in X(dim=2)

    JAX
    ~~~
    Write ``f`` with ``jax.numpy`` and keep it pure; works with
    jit/vmap/autodiff.
    """
    n_feat = int(n_features)
    # Materialise the optional key sequences exactly once, so one-shot
    # iterables (e.g. generators) are not consumed twice below.
    extras_keys_t = tuple(extras_keys) if extras_keys is not None else ()
    particle_extras_t = tuple(particle_extras) if particle_extras is not None else ()

    leaf = SimpleLeaf(
        func=func,
        n_features=n_feat,
        labels=tuple(labels) if labels is not None else tuple(f"f{j}" for j in range(n_feat)),
        descriptor=descriptor,
        dim=dim,
        rank=Rank(rank),
        needs_v=bool(needs_v),
        # SimpleLeaf forbids particles & pdepth by construction
        # (pdepth=0, particles_input=False).
        extras_keys=extras_keys_t,
        particle_extras=particle_extras_t,
        specialize_at=specialize_at,
    )
    result = Basis(leaf)

    # The probe feeds only ``x`` (no ``v``, no extras), so skip it whenever
    # the function is declared to need either; otherwise it would just fail
    # and emit a misleading warning.
    if dim is not None and not needs_v and not extras_keys_t and not particle_extras_t:
        _probe_call(result, dim=dim, rank=rank, n_features=n_feat, label="make_basis")
    return result
def make_psf(
    func: Callable,
    *,
    dim: int | None = None,
    rank: int,
    n_features: int = 1,
    drop_features: bool = True,
    needs_v: bool = False,
    labels: Optional[Sequence[str]] = None,
    descriptor: Any = "parametric",
    params: ParamSuite | Iterable[ParamSpec] | dict[str, Any],
    extras_keys: Optional[Sequence[str]] = None,
    specialize_at: Optional[Callable] = None,
) -> PSF:
    """
    Build a **parametric state-function family (PSF)** from a
    *single-sample* user function, **without particle semantics**.

    User function signature — declare **only** the kwargs you need:

    - Simplest: ``f(x, *, params) -> array``
    - With velocity: ``f(x, *, v, params) -> array``
    - With extras: ``f(x, *, params, extras) -> array``

    The full signature is ``f(x, *, params, v=None, mask=None, extras=None)``;
    only the kwargs you declare are passed (found by introspection).

    Shapes (single sample)::

        x: (dim,)
        return: (*rank_axes, m)   # feature last; m == n_features

    With ``n_features == 1`` the trailing feature axis may be omitted; a
    singleton axis is inserted automatically.

    The ``params`` description may be any of:

    - a ``ParamSuite``,
    - an iterable of ``ParamSpec``,
    - a dict of shapes, e.g. ``{'W': (d,d), 'b': ()}``,
    - or a dict of **sample arrays** from which (shape, dtype) are inferred.

    Whatever form is given is normalised via ``ParamSuite.parse``.

    Extras
    ~~~~~~
    Same rules as ``make_basis`` (``extras_keys`` optional; broadcast over
    the batch prefix).

    Unlike ``make_basis`` / ``make_sf``, no construction-time probe call is
    made here.

    JAX
    ~~~
    Works with jit/vmap/autodiff w.r.t. inputs and parameters.

    Returns
    -------
    PSF
        The family wrapping the new leaf; ``drop_features`` is forwarded.
    """
    # Supplying a (parsed) ParamSuite is what enables the parametric path.
    suite = ParamSuite.parse(params)

    feature_count = int(n_features)
    if labels is None:
        feature_labels = tuple(f"f{j}" for j in range(n_features))
    else:
        feature_labels = tuple(labels)
    tensor_rank = Rank(rank)
    wants_velocity = bool(needs_v)
    required_extras = () if extras_keys is None else tuple(extras_keys)

    node = SimpleLeaf(
        func=func,
        n_features=feature_count,
        labels=feature_labels,
        descriptor=descriptor,
        dim=dim,
        rank=tensor_rank,
        needs_v=wants_velocity,
        param_suite=suite,
        extras_keys=required_extras,
        specialize_at=specialize_at,
    )
    return PSF(node, drop_features=bool(drop_features))
def make_sf(
    func: Callable,
    *,
    dim: int | None = None,
    rank: int,
    n_features: int = 1,
    drop_features: bool = True,
    needs_v: bool = False,
    labels: Optional[Sequence[str]] = None,
    descriptor: Any = "custom_sf",
    extras_keys: Optional[Sequence[str]] = None,
) -> SF:
    """
    Build an **SF** (bound state function) straight from a parameter-free
    user function, with no ``Basis`` or ``PSF`` step on the caller's side.

    This is the shortest route when you already have a fixed function (e.g.
    an exact model to compare against, or a hand-written feature) and only
    need a callable that plugs into the SFI expression-tree ecosystem. The
    result supports ``.d_x()`` and ``.d_v()`` and can be handed to
    ``compare_to_exact``, ``integrate``, or any other API taking an ``SF`` /
    ``StateExpr``.

    User function signature — declare **only** the kwargs you need:

    - Simplest: ``f(x) -> array``
    - With velocity: ``f(x, *, v) -> array``
    - With extras: ``f(x, *, extras) -> array``

    Shapes (single sample)::

        x: (dim,)
        return: (*rank_axes, n_features)

    With ``n_features == 1`` the trailing feature axis may be omitted; it is
    inserted automatically, and squeezed back out when
    ``drop_features=True`` (the default).

    When ``dim`` is given and the function needs neither ``v`` nor extras,
    the resulting SF is probed once on a zero input of shape ``(1, dim)``;
    problems are logged as warnings, never raised.

    Parameters
    ----------
    func : callable
        Pure JAX function, compatible with jit/vmap/autodiff.
    dim : int or None
        Spatial dimensionality (None = infer at first call).
    rank : int
        Tensor rank of the output (0=scalar, 1=vector, 2=matrix).
    n_features : int
        Number of output features (default 1).
    drop_features : bool
        Remove trailing size-1 feature axis (default True).
    needs_v : bool
        Whether ``func`` requires velocity ``v``.
    labels : sequence of str or None
        Human-readable feature labels; ``("f0", "f1", ...)`` if None.
    descriptor : any
        Metadata tag stored on the leaf node.
    extras_keys : sequence of str or None
        Required keys in the ``extras`` mapping.

    Returns
    -------
    SF
        A bound, callable state function with no free parameters.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from SFI.statefunc import make_sf
    >>> harmonic = make_sf(lambda x: -x, rank=1, dim=2)
    >>> harmonic(jnp.array([1.0, 2.0]))
    Array([-1., -2.], dtype=float32)
    """
    # An empty (but non-None) ParamSuite makes the leaf acceptable to PSF.
    # It declares no parameters, so binding with an empty mapping freezes
    # nothing.
    no_params = ParamSuite([])

    feature_count = int(n_features)
    if labels is None:
        feature_labels = tuple(f"f{j}" for j in range(n_features))
    else:
        feature_labels = tuple(labels)
    tensor_rank = Rank(rank)
    wants_velocity = bool(needs_v)
    required_extras = () if extras_keys is None else tuple(extras_keys)

    family = PSF(
        SimpleLeaf(
            func=func,
            n_features=feature_count,
            labels=feature_labels,
            descriptor=descriptor,
            dim=dim,
            rank=tensor_rank,
            needs_v=wants_velocity,
            param_suite=no_params,
            extras_keys=required_extras,
        ),
        drop_features=bool(drop_features),
    )
    bound = family.bind({})

    # The probe can only supply ``x``: skip it when dim is unknown or the
    # function needs extras or a velocity.
    if dim is None or extras_keys or needs_v:
        return bound
    _probe_call(bound, dim=dim, rank=rank, n_features=n_features, label="make_sf")
    return bound
def make_interactor(
    func: Callable,
    *,
    dim: int,
    rank: Rank,
    # arity:
    K: int | None = None,
    Kmax: int | None = None,
    # features & plumbing:
    n_features: int = 1,
    needs_v: bool = False,
    labels: Iterable[str] = (),
    descriptor: Any = None,
    params: ParamSuite | Iterable[ParamSpec] | dict[str, Any] | None = None,
    extras_keys: Iterable[str] = (),
    particle_extras: Iterable[str] = (),
) -> Interactor:
    """
    Build a local interaction dictionary (Interactor) from a single-sample
    user function that consumes ``(K, dim)`` and returns feature-last.

    Pass exactly one of:

    - ``K=int``    → fixed arity
    - ``Kmax=int`` → variable arity (ragged via mask)

    ``particle_extras`` names the extras keys whose values are
    **per-particle** arrays (shape ``(P, ...)``): the dispatcher gathers them
    per edge member, so inside ``func`` they arrive with shape ``(K, ...)`` —
    one entry per member of the local tuple. The reserved
    ``"particle_index"`` extra (injected by
    :class:`~SFI.trajectory.TrajectoryCollection`) combined with a
    ``(P,)``-shaped parameter gives per-particle inferred parameters::

        def local(Xk, *, params, extras):
            mob = params["mob"][extras["particle_index"]]  # (K,)
            ...

    Parameters
    ----------
    func : callable
        Single-sample local function (see above).
    dim : int
        Spatial dimensionality of each member.
    rank : Rank
        Output tensor rank, forwarded to ``InteractionLeaf`` unchanged.
    K, Kmax : int or None
        Fixed or maximum arity; exactly one must be given.
    n_features : int
        Number of output features (default 1).
    needs_v : bool
        Whether ``func`` requires the velocity ``v``.
    labels : iterable of str
        Feature labels (empty by default).
    descriptor : any
        Metadata tag stored on the leaf node.
    params : ParamSuite, iterable of ParamSpec, dict, or None
        Parameter description, normalised with ``ParamSuite.parse``.
    extras_keys : iterable of str
        Required keys in the ``extras`` mapping.
    particle_extras : iterable of str
        Per-particle extras keys gathered per edge member (see above).

    Returns
    -------
    Interactor

    Raises
    ------
    ValueError
        If both or neither of ``K`` and ``Kmax`` are given.
    """
    # Check arity first so a malformed ``params`` cannot mask this error.
    if (K is None) == (Kmax is None):
        raise ValueError("Provide exactly one of K or Kmax.")
    suite = ParamSuite.parse(params)
    leaf = InteractionLeaf(
        mode="fixed" if K is not None else "variable",
        K=K,
        Kmax=Kmax,
        func=func,
        dim=dim,
        rank=rank,
        # Coerced like the sibling factories (make_basis / make_psf / make_sf).
        n_features=int(n_features),
        needs_v=bool(needs_v),
        param_suite=suite,
        labels=tuple(labels),
        descriptor=descriptor,
        extras_keys=tuple(extras_keys),
        particle_extras=tuple(particle_extras),
    )
    return Interactor(leaf)