# Source code for SFI.integrate.integrand

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence, Set, Tuple

import jax
import jax.numpy as jnp
import opt_einsum as oe  # required

if TYPE_CHECKING:
    from .timeops import TimeOp

# Names exported by ``from ... import *``.
__all__ = [
    "ExprOperand",
    "TimeOperand",
    "ConstOperand",
    "Term",
    "Integrand",
]


@dataclass(frozen=True)
class ExprOperand:
    """
    State-expression operand, evaluated on ``x`` (and optionally ``v``).

    Calling convention expected of ``expr``
    ---------------------------------------
    ``expr(x, v=..., mask=..., extras=..., params=...) -> array``

    - features sit on the last axis of the result
    - ``mask`` is supplied from the ``mask_out`` stream whenever that stream is present

    Fields
    ------
    expr : the state expression object that gets called
    x : TimeOp producing the positional input
    v : optional TimeOp producing the ``v`` keyword input
    params_template : default parameters; call-time ``params`` take precedence
    alias : name under which ``Term.ops`` refers to this operand
    """

    expr: object
    x: "TimeOp"
    v: Optional["TimeOp"] = None
    # Fallback parameters; an explicit call-time ``params`` replaces them.
    params_template: Optional[Mapping] = None
    alias: str = "E"

    @property
    def required_extras(self) -> Tuple[str, ...]:
        """Extras keys that ``expr`` declares it needs (empty tuple if none)."""
        declared = getattr(self.expr, "required_extras", ())
        if not declared:
            return ()
        return tuple(declared)

    @property
    def requires(self) -> Set[str]:
        """Stream keys needed by ``x``/``v``, plus ``"extras"`` when extras are required."""
        sources = [self.x] if self.v is None else [self.x, self.v]
        keys: Set[str] = set()
        for src in sources:
            keys.update(src.requires)
        if self.required_extras:
            keys.add("extras")  # make sure the collection streams the extras mapping
        return keys
@dataclass(frozen=True)
class TimeOperand:
    """
    Operand (scalar, vector or tensor) produced from the streams by a TimeOp.

    ``op`` is called with the stream keyword arguments; ``alias`` is the name
    under which ``Term.ops`` refers to the result.
    """

    op: "TimeOp"
    alias: str = "T"

    @property
    def requires(self) -> Set[str]:
        """Stream keys the wrapped TimeOp reads, returned as a fresh set."""
        needed = self.op.requires
        return set(needed)
@dataclass(frozen=True)
class ConstOperand:
    """
    Fixed array baked into the integrand program (for example ``A_inv``).

    The array is handed to the einsum contractions unchanged on every call.
    In ``Integrand.batch_call`` it is treated as having no batch axis.

    NOTE(review): the dataclass-generated ``__eq__``/``__hash__`` include
    ``value``. Comparing or hashing instances that hold multi-element arrays
    will likely raise, so confirm nothing relies on that.
    """

    # The captured constant array.
    value: "jnp.ndarray"
    alias: str = "C"
@dataclass(frozen=True)
class Term:
    """
    A single scaled einsum contraction over named operands.

    ``eq`` is the einsum subscript string, ``ops`` lists the operand aliases
    in the order they feed ``eq``, and ``scale`` multiplies the result.

    Index-letter conventions
    ------------------------
    i        : particle axis (kept until the integrator reduces over i)
    m, n, ...: spatial indices
    a, b, ...: feature axes (always last on any ExprOperand output)
    """

    eq: str
    ops: Tuple[str, ...]
    scale: float = 1.0
class Integrand:
    """
    Compose state expressions and time operands via einsum per time slice.

    ``require()`` -> ``Set[str]``
        Stream keys needed to evaluate this program on one time slice.

    ``__call__(**streams)`` -> ``jnp.ndarray``
        Evaluate on one time slice. The collection handles mask and reduction.

    ``batch_call(**streams)`` -> ``jnp.ndarray``
        Same program, evaluated on streams that carry a leading batch (K) axis.

    ``estimate_bytes_per_sample(sample_streams)`` -> ``Optional[int]``
        Conservative upper bound on bytes per time-sample, built from static
        shape hints and an opt_einsum path analysis.

    Integrands combine linearly:

    - ``I1 + I2`` concatenates the terms. Operands sharing an alias must be
      the same object.
    - ``alpha * I`` rescales every term.
    - ``sum([...])`` works because ``0 + I`` returns ``I``.
    """

    def __init__(
        self,
        exprs: Sequence["ExprOperand"] = (),
        times: Sequence["TimeOperand"] = (),
        consts: Sequence["ConstOperand"] = (),
        terms: Sequence["Term"] = (),
    ):
        """
        Register operands by alias and store the contraction terms.

        Parameters
        ----------
        exprs, times, consts
            Operands, keyed internally by their ``alias``. Any iterable is
            accepted; it is materialized once, so one-shot iterators work.
        terms
            Einsum contractions whose ``ops`` reference the aliases above.

        Raises
        ------
        ValueError
            If an alias is used by more than one operand.
        """
        # Materialize first: both the alias check and the dict builds iterate.
        exprs, times, consts = tuple(exprs), tuple(times), tuple(consts)
        self._check_aliases(exprs, times, consts)
        self._exprs: Dict[str, "ExprOperand"] = {e.alias: e for e in exprs}
        self._times: Dict[str, "TimeOperand"] = {t.alias: t for t in times}
        self._consts: Dict[str, "ConstOperand"] = {c.alias: c for c in consts}
        self._terms: Tuple["Term", ...] = tuple(terms)

    @staticmethod
    def _check_aliases(
        exprs: Sequence["ExprOperand"],
        times: Sequence["TimeOperand"],
        consts: Sequence["ConstOperand"],
    ) -> None:
        """Raise ``ValueError`` if any alias is registered more than once."""
        seen: Dict[str, str] = {}
        groups = [("expr", exprs), ("time", times), ("const", consts)]
        for kind, ops in groups:
            for op in ops:
                alias = op.alias  # type: ignore[union-attr]
                if alias in seen:
                    raise ValueError(
                        f"Duplicate alias {alias!r}: already registered as "
                        f"{seen[alias]}, now also as {kind}."
                    )
                seen[alias] = kind

    @staticmethod
    def _merge_operand_dicts(
        d1: Dict[str, Any],
        d2: Dict[str, Any],
        kind: str,
    ) -> Dict[str, Any]:
        """Merge two alias→operand dicts; allow same-object sharing, raise on conflict."""
        merged = dict(d1)
        for alias, op in d2.items():
            if alias in merged:
                if merged[alias] is not op:
                    raise ValueError(
                        f"Cannot add Integrands: alias {alias!r} maps to "
                        f"different {kind} operands in the two summands."
                    )
            else:
                merged[alias] = op
        return merged

    @staticmethod
    def _validate_extras(alias: str, E: "ExprOperand", extras: Any) -> None:
        """Fail fast with a readable KeyError if ``E``'s required extras are missing."""
        if not E.required_extras:
            return
        try:
            # StateExpr provides the validator
            E.expr._validate_extras_presence(extras)  # type: ignore[attr-defined]
        except KeyError:
            avail = sorted(extras.keys()) if isinstance(extras, Mapping) else None
            msg = f"[{alias}] missing extras {list(E.required_extras)}"
            if avail is not None:
                msg += f"; available: {avail}"
            raise KeyError(msg) from None

    def _eval_exprs(
        self, params: Optional[Any], streams: Mapping[str, Any]
    ) -> Dict[str, Any]:
        """
        Evaluate every state-expression operand on ``streams``.

        Shared by ``__call__`` and ``batch_call``. Returns a fresh dict mapping
        alias -> expression output.
        """
        bufs: Dict[str, Any] = {}
        extras = streams.get("extras", None)
        mask = streams.get("mask_out", None)
        for a, E in self._exprs.items():
            self._validate_extras(a, E, extras)
            x = E.x(**streams)
            v = None if E.v is None else E.v(**streams)
            kwargs: Dict[str, Any] = {}
            if v is not None:
                kwargs["v"] = v
            if mask is not None:
                kwargs["mask"] = mask
            if "extras" in streams:
                kwargs["extras"] = extras
            # Parameter routing: explicit call-time params win; otherwise
            # fall back to the operand's params_template.
            par = E.params_template if params is None else params
            if par is not None:
                kwargs["params"] = par
            bufs[a] = E.expr(x, **kwargs)
        return bufs

    # ---------- integration API ----------

    def require(self) -> Set[str]:
        """Return the stream keys needed by all expression and time operands."""
        req: Set[str] = set()
        for e in self._exprs.values():
            req |= e.requires
        for t in self._times.values():
            req |= t.requires
        return req

    def required_extras(self) -> Set[str]:
        """Return the union of extras keys required by the expression operands."""
        keys: Set[str] = set()
        for E in self._exprs.values():
            keys.update(E.required_extras)
        return keys

    def __call__(self, *, params: Optional[Any] = None, **streams):
        """
        Evaluate the program on one time slice.

        Parameters
        ----------
        params
            Parameters forwarded to every state expression. When ``None``,
            each expression falls back to its own ``params_template``.
        **streams
            Stream values for one slice. Expression and time operands are
            called with these keyword arguments. ``mask_out`` (when not
            ``None``) is forwarded to expressions as ``mask``; ``extras``
            (when present) is forwarded as ``extras``.

        Returns
        -------
        jnp.ndarray or None
            Sum of ``scale * einsum(eq, *operands)`` over all terms, or
            ``None`` if the program has no terms.

        Raises
        ------
        KeyError
            If an expression's required extras are missing.
        """
        bufs = self._eval_exprs(params, streams)
        for a, T in self._times.items():
            bufs[a] = T.op(**streams)
        for a, C in self._consts.items():
            bufs[a] = C.value

        out = None
        for term in self._terms:
            val = jnp.einsum(term.eq, *[bufs[a] for a in term.ops])
            out = term.scale * val if out is None else out + term.scale * val
        return out

    def batch_call(self, *, params: Optional[Any] = None, **streams):
        """
        Evaluate with streams that carry a leading batch (K) axis.

        This is the batched counterpart of :meth:`__call__`. Streams such as
        ``X`` have shape ``(K, N, d)`` instead of ``(N, d)``, and ``dt`` has
        shape ``(K,)`` instead of being a scalar.

        State-expression operands receive the full ``(K, N, d)`` tensor and
        handle arbitrary leading batch dimensions internally (the leaf's
        ``_apply_user_func`` flattens the batch prefix and uses a single
        ``jax.vmap``).

        Time operands flagged ``batch_safe`` are called on the batch streams
        directly. Other time operands are vmapped over K: stream leaves with
        ``ndim >= 1`` are mapped over their leading axis, while 0-d leaves
        (including plain Python scalars) are broadcast.

        Einsum contractions are vmapped over the leading K axis, so existing
        subscript strings (which reference particle/spatial/feature axes only)
        work unchanged. Constants carry no K axis.

        Returns
        -------
        jax.Array or None
            Result with a leading ``K`` axis, i.e. ``(K, <per-row output
            shape>)``, or ``None`` if the program has no terms.
        """
        # Expressions: full (K, N, d) inputs; output (K, N, ..., F) or (K, ..., F).
        bufs = self._eval_exprs(params, streams)

        # --- time operands: call directly if batch_safe, else vmap over K ---
        for a, T in self._times.items():
            if getattr(T.op, "batch_safe", False):
                # TimeOp already handles leading batch dims — call directly
                bufs[a] = T.op(**streams)
                continue
            req_streams = {k: streams[k] for k in T.op.requires if k in streams}
            # Per-leaf axes; jnp.ndim also accepts Python scalars, which have
            # no ``.ndim`` attribute.
            in_axes_t = jax.tree_util.tree_map(
                lambda leaf: 0 if jnp.ndim(leaf) >= 1 else None, req_streams
            )
            bufs[a] = jax.vmap(
                lambda s, _op=T.op: _op(**s), in_axes=(in_axes_t,)
            )(req_streams)

        # --- constants ---
        for a, C in self._consts.items():
            bufs[a] = C.value

        # --- contractions: vmap each einsum over the leading K axis ---
        const_aliases = frozenset(self._consts)
        out = None
        for term in self._terms:
            operands = [bufs[a] for a in term.ops]
            # Constants have no K axis → in_axes=None; others → 0
            in_axes = tuple(None if a in const_aliases else 0 for a in term.ops)
            val = jax.vmap(
                lambda *args, _eq=term.eq: jnp.einsum(_eq, *args),
                in_axes=in_axes,
            )(*operands)
            out = term.scale * val if out is None else out + term.scale * val
        return out

    # ---------- sugar for linear combos ----------

    def __add__(self, other: "Integrand") -> "Integrand":
        """Concatenate terms; operands sharing an alias must be the same object."""
        me = self._merge_operand_dicts(self._exprs, other._exprs, "expr")
        mt = self._merge_operand_dicts(self._times, other._times, "time")
        mc = self._merge_operand_dicts(self._consts, other._consts, "const")
        return Integrand(
            exprs=list(me.values()),
            times=list(mt.values()),
            consts=list(mc.values()),
            terms=[*self._terms, *other._terms],
        )

    def __radd__(self, other):
        """Allow ``sum([...], start=0)``: ``0 + I`` returns ``I`` unchanged."""
        return self if other == 0 else NotImplemented

    def __mul__(self, alpha: float) -> "Integrand":
        """Return a new Integrand with every term's scale multiplied by ``alpha``."""
        if isinstance(alpha, Integrand):
            # Integrand * Integrand is undefined. Without this guard
            # ``t.scale * alpha`` would dispatch back here and silently store
            # an Integrand as a Term scale.
            return NotImplemented
        return Integrand(
            exprs=self._exprs.values(),
            times=self._times.values(),
            consts=self._consts.values(),
            terms=[Term(eq=t.eq, ops=t.ops, scale=t.scale * alpha) for t in self._terms],
        )

    __rmul__ = __mul__

    # ---------- memory hint on a real sample ----------

    def estimate_bytes_per_sample(
        self,
        sample_streams: Mapping[str, jnp.ndarray],
        *,
        dtypesize: int = 4,
    ) -> Optional[int]:
        """
        Conservative upper bound in bytes per **time-sample** (one k-row).

        The estimate adds three contributions:

        - the state-expression outputs, with shapes derived from the static
          hints ``pdepth``, ``rank``, ``dim`` and ``n_features`` on each
          expression (particle axes have length ``N_total``);
        - the largest einsum intermediate over all terms (opt_einsum greedy
          path, shapes only), or the summed operand sizes if that fails;
        - each expression's own ``estimate_bytes_per_sample`` hint, if any.

        Parameters
        ----------
        sample_streams
            One sample of the streams. It must contain ``"N_total"``. Time
            operands are evaluated on it once to obtain their shapes. The
            dtype of ``"X"``, if present, is forwarded to the expression hints.
        dtypesize
            Bytes per element used for the shape-based sizes.

        Returns
        -------
        Optional[int]
            Estimated bytes, or ``None`` when the total is not positive.

        Raises
        ------
        KeyError
            If ``"N_total"`` is missing from ``sample_streams``.
        """
        _X = sample_streams.get("X", None)
        _dtype = None if _X is None else _X.dtype
        P = int(sample_streams["N_total"])  # use provided, conservative
        itemsize = int(dtypesize)

        def numel(shape) -> int:
            # Exact Python-int product. A jnp product would run in int32 under
            # JAX's default (x64-disabled) config and silently overflow for
            # large particle counts.
            return math.prod(int(s) for s in shape)

        # ---- shapes (expressions from static hints, times probed once)
        def expr_shape(E) -> Tuple[int, ...]:
            p = int(getattr(E.expr, "pdepth", 0))
            r = int(getattr(E.expr, "rank", 0))
            d = int(getattr(E.expr, "dim", 0))
            F = int(getattr(E.expr, "n_features", 1))
            paxes = () if p == 0 else (P,) * p  # full particle axes (per-sample)
            return (*paxes, *(d,) * r, F)

        shapes: Dict[str, Tuple[int, ...]] = {}
        for a, E in self._exprs.items():
            shapes[a] = expr_shape(E)
        for a, T in self._times.items():
            shapes[a] = tuple(T.op(**sample_streams).shape)  # evaluates the TimeOp
        for a, C in self._consts.items():
            shapes[a] = tuple(C.value.shape)

        def match_len(shape: Tuple[int, ...], subscript: str) -> Tuple[int, ...]:
            # Pad with trailing singleton axes so the rank matches the subscript.
            L = len(subscript)
            return (*shape, *([1] * (L - len(shape)))) if len(shape) < L else shape

        def size_for_term(t: "Term") -> int:
            lhs = t.eq.split("->", 1)[0]
            subs = [s.strip() for s in lhs.split(",")]
            ops_shapes = [match_len(shapes[a], subs[i]) for i, a in enumerate(t.ops)]
            try:
                _, info = oe.contract_path(
                    t.eq, *ops_shapes, shapes=True, optimize="greedy"
                )
                if hasattr(info, "largest_intermediate"):
                    return int(info.largest_intermediate) * itemsize
                if getattr(info, "intermediate_shapes", None):
                    return max(numel(sh) for sh in info.intermediate_shapes) * itemsize  # type: ignore[attr-defined]
            except Exception:
                # Best effort: any path-analysis failure falls back to operand sizes.
                pass
            return sum(numel(s) for s in ops_shapes) * itemsize

        peak_terms = max((size_for_term(t) for t in self._terms), default=0)
        state_bytes = sum(numel(shapes[a]) * itemsize for a in self._exprs)

        expr_transient = 0
        for E in self._exprs.values():
            estimator = getattr(E.expr, "estimate_bytes_per_sample", None)
            if callable(estimator):
                expr_transient += int(estimator(dtype=_dtype, particle_size=P))

        total = expr_transient + state_bytes + peak_terms
        return int(total) if total > 0 else None