Source code for SFI.integrate.timeops

from __future__ import annotations

import inspect
from typing import Callable, FrozenSet, Iterable, Optional

import jax.numpy as jnp

__all__ = [
    "TimeOp",
    "timeop",
    "stream",
    "velocity",
    "scale",
    "add",
]


[docs] class TimeOp: """ A time-local operator evaluated on one time slice. A TimeOp declares which streamed fields it needs either via its function signature or the explicit *requires* parameter. Example ------- >>> @timeop ... def vel(dX_minus, dt_minus): ... return dX_minus / dt_minus[..., None] >>> vel.requires frozenset({'dX_minus', 'dt_minus'}) """ def __init__( self, fn: Callable[..., jnp.ndarray], name: str | None = None, *, batch_safe: bool = False, requires: Optional[FrozenSet[str]] = None, ): self.fn = fn self._name = name or fn.__name__ self.batch_safe = batch_safe if requires is not None: self._requires: FrozenSet[str] = frozenset(requires) else: sig = inspect.signature(fn) self._requires = frozenset( p.name for p in sig.parameters.values() if p.kind in ( inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY, ) ) def __call__(self, **streams) -> jnp.ndarray: return self.fn(**{k: streams[k] for k in self._requires}) @property def requires(self) -> FrozenSet[str]: return self._requires @property def name(self) -> str: return self._name
[docs] def timeop( fn: Callable[..., jnp.ndarray] | None = None, *, name: str | None = None, batch_safe: bool = False, requires: Optional[FrozenSet[str]] = None, ): """ Decorator: convert a function into a TimeOp. Parameters ---------- batch_safe : bool If True, the function already handles a leading batch (K) axis in its inputs without requiring an additional ``jax.vmap``. requires : frozenset, optional Explicit set of required stream keys. When given, the function signature is not inspected for parameter names. Use this when the function accepts ``**streams`` and the required keys cannot be inferred from the signature. Notes ----- Without *requires*, the function's parameter names define the required stream keys (``**kwargs``/``*args`` parameters are ignored). """ def _wrap(f: Callable[..., jnp.ndarray]) -> TimeOp: return TimeOp(f, name=name, batch_safe=batch_safe, requires=requires) return _wrap if fn is None else _wrap(fn)
# --------- Stock, composable timeops ---------
[docs] def stream(key: str, *, name: str | None = None) -> TimeOp: """Return a TimeOp that passes stream *key* through unchanged.""" def _op(**streams: jnp.ndarray) -> jnp.ndarray: return streams[key] return TimeOp(_op, name=name or f"stream[{key}]", batch_safe=True, requires=frozenset({key}))
[docs] def velocity(dx_key: str, dt_key: str, *, name: str | None = None) -> TimeOp: """Return a TimeOp that computes dx/dt, broadcasting dt over arbitrary leading dims.""" def _op(**streams: jnp.ndarray) -> jnp.ndarray: dx = streams[dx_key] dt = streams[dt_key] # Expand dt to broadcast with dx for arbitrary leading batch dims. # Single-row: dt scalar, dx (N, d) -> dt becomes (1, 1) -> OK. # Batch: dt (K,), dx (K, N, d) -> dt becomes (K, 1, 1) -> OK. while dt.ndim < dx.ndim: dt = dt[..., jnp.newaxis] return dx / dt return TimeOp( _op, name=name or f"V[{dx_key}/{dt_key}]", batch_safe=True, requires=frozenset({dx_key, dt_key}), )
[docs] def scale(op: TimeOp, alpha: float | int, *, name: str | None = None) -> TimeOp: """Return a TimeOp that multiplies *op* by scalar *alpha*.""" def _op(**streams: jnp.ndarray) -> jnp.ndarray: return alpha * op(**streams) return TimeOp( _op, name=name or f"{alpha}*{op.name}", batch_safe=op.batch_safe, requires=op.requires, )
[docs] def add(ops: Iterable[TimeOp], *, name: str | None = None) -> TimeOp: """Return a TimeOp that sums the outputs of all *ops* element-wise. Parameters ---------- ops : iterable of TimeOp At least one operand is required. """ ops = tuple(ops) if not ops: raise ValueError( "add() requires at least one TimeOp; got an empty sequence." ) nm = name or " + ".join(o.name for o in ops) req = frozenset().union(*(o.requires for o in ops)) all_safe = all(o.batch_safe for o in ops) def _op(**streams: jnp.ndarray) -> jnp.ndarray: out = None for o in ops: cur = o(**streams) out = cur if out is None else out + cur return out # type: ignore[return-value] # out is set because ops is non-empty return TimeOp(_op, name=nm, batch_safe=all_safe, requires=req)