SFI.integrate.integrand module
- class SFI.integrate.integrand.ConstOperand(value, alias='C')
Bases: object
Constant array captured by the program (e.g. A_inv).
- Parameters:
value (Array)
alias (str)
- alias: str = 'C'
- value: Array
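To illustrate what a captured constant is for, the sketch below wraps a hypothetical d×d inverse matrix (standing in for A_inv) with an alias and contracts it against per-particle vectors using numpy.einsum. `ConstOperandSketch` is a hypothetical stand-in that only mirrors the documented `value`/`alias` fields; it is not the SFI class, and numpy is used in place of jax.numpy.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ConstOperandSketch:
    """Hypothetical stand-in mirroring ConstOperand's documented fields."""

    value: np.ndarray
    alias: str = "C"


# A constant d x d matrix, here the inverse of a fixed diagonal matrix A.
A = np.array([[2.0, 0.0], [0.0, 4.0]])
a_inv = ConstOperandSketch(value=np.linalg.inv(A), alias="Ainv")

# Per-particle vectors with shape (N, d) = (3, 2).
u = np.arange(6.0).reshape(3, 2)

# The same constant is reused for every particle i: out[i, m] = sum_n Ainv[m, n] * u[i, n].
out = np.einsum("mn,in->im", a_inv.value, u)
print(out)
```

Because the constant carries no particle index, it is broadcast over every particle by the einsum rather than being copied per particle.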
- class SFI.integrate.integrand.ExprOperand(expr, x, v=None, params_template=None, alias='E')
Bases: object
Wrap a state expression evaluated at x (and optionally v).
Contract
- expr(x, v=…, mask=…, extras=…, params=…) -> array
  - features sit on the last axis
  - ‘mask’ is forwarded from the ‘mask_out’ stream when present
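A minimal sketch of a callable that follows this contract, assuming numpy in place of jax.numpy. `linear_features` is hypothetical: for each particle and spatial component it returns the two features [1, x_m], so features sit on the last axis. What SFI does with `mask`, `extras` and `params` beyond forwarding them is not specified here, so the sketch accepts them and ignores them.

```python
import numpy as np


def linear_features(x, v=None, mask=None, extras=None, params=None):
    """Hypothetical state expression: features [1, x_m] for each particle i and component m.

    x has shape (..., N, d); the result has shape (..., N, d, 2), with features on
    the last axis. Only a trailing axis is appended, so any leading batch shape works.
    v, mask, extras and params are accepted to match the call contract but unused here.
    """
    x = np.asarray(x)
    return np.stack([np.ones_like(x), x], axis=-1)


x = np.array([[0.5, -1.0], [2.0, 3.0], [1.0, 0.0]])  # N = 3 particles, d = 2
feats = linear_features(x, v=None, mask=np.ones(3, dtype=bool), extras={}, params=None)
print(feats.shape)  # (3, 2, 2)
```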
- alias: str = 'E'
- expr: object
- params_template: Mapping | None = None
- property required_extras: Tuple[str, ...]
Keys this expression demands in the extras mapping.
- property requires: Set[str]
- v: 'TimeOp' | None = None
- Parameters:
expr (object)
x (TimeOp)
v (Optional['TimeOp'])
params_template (Optional[Mapping])
alias (str)
- class SFI.integrate.integrand.Integrand(exprs=(), times=(), consts=(), terms=())
Bases: object
Compose state expressions and time operands via einsum, one time slice at a time.
- require() -> Set[str]
  Stream keys needed to evaluate this program on one time slice.
- __call__(**streams) -> jnp.ndarray
  Evaluate on one time slice. The collection handles masking and reduction.
- estimate_bytes_per_sample(sample_streams) -> Optional[int]
  Upper bound on per-particle bytes, computed from a real sample (via opt_einsum).
- Parameters:
exprs (Sequence[ExprOperand])
times (Sequence[TimeOperand])
consts (Sequence[ConstOperand])
terms (Sequence[Term])
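The einsum-per-slice composition can be sketched as follows. This is not SFI's implementation: `TermSketch` and `evaluate_slice` are hypothetical, operands are looked up by alias in a plain dict, and each term contributes `scale * einsum(eq, *operands)`. Summing the terms is an assumption about how multiple terms combine, and it requires every term to produce the same output shape.

```python
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple

import numpy as np


@dataclass(frozen=True)
class TermSketch:
    """Hypothetical mirror of Term: one einsum over named operands."""

    eq: str
    ops: Tuple[str, ...]
    scale: float = 1.0


def evaluate_slice(terms: Sequence[TermSketch], operands: Dict[str, np.ndarray]) -> np.ndarray:
    """Sum scale * einsum(eq, *operands) over all terms for one time slice (assumed semantics)."""
    if not terms:
        raise ValueError("at least one term is required")
    total = None
    for term in terms:
        arrays = [operands[alias] for alias in term.ops]  # operands matched to eq by position
        contrib = term.scale * np.einsum(term.eq, *arrays)
        total = contrib if total is None else total + contrib
    return total


# One time slice: N = 2 particles, d = 2 dimensions, F = 2 features.
operands = {
    "E": np.array([[[1.0, 2.0], [0.0, 1.0]],
                   [[1.0, 0.0], [3.0, 1.0]]]),  # state features, shape (N, d, F)
    "V": np.array([[1.0, 1.0], [2.0, 0.0]]),   # e.g. a velocity, shape (N, d)
}
terms = [TermSketch(eq="ima,im->ia", ops=("E", "V")),
         TermSketch(eq="ima->ia", ops=("E",), scale=-0.5)]
result = evaluate_slice(terms, operands)
print(result)
```

The particle axis i survives in the result, leaving the reduction over particles (and any masking) to the caller, consistent with the note that the collection handles masking and reduction.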
- batch_call(*, params=None, **streams)
Evaluate with streams that carry a leading batch (K) axis.
This is the batched counterpart of __call__(). Streams such as X have shape (K, N, d) instead of (N, d), and dt has shape (K,) instead of being a scalar.
State-expression operands receive the full (K, N, d) tensor and handle arbitrary leading batch dimensions internally (the leaf’s _apply_user_func flattens the batch prefix and uses a single jax.vmap).
Time operands (e.g. velocity) are evaluated on the batch streams directly; the batch-safe velocity TimeOp handles dt broadcasting.
Einsum contractions are vmapped over the leading K axis so that existing subscript strings (which reference particle/spatial/feature axes only) work unchanged.
- Returns:
Result with a leading K axis, of shape (K, <per-row output shape>).
- Return type:
jax.Array
- Parameters:
params (Any | None)
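To see why vmapping an einsum over the leading K axis lets per-row subscripts work unchanged, the numpy sketch below applies the same per-row subscript string to each batch row in a Python loop (standing in for jax.vmap) and compares the result with a single einsum whose subscripts gain an explicit leading batch index k. It also shows one way a dt of shape (K,) can be broadcast against (K, N, d) streams; the finite-difference velocity and the stream names `X` and `X_next` are hypothetical and are not SFI's velocity TimeOp.

```python
import numpy as np

rng = np.random.default_rng(0)
K, N, d, F = 4, 3, 2, 5

E = rng.normal(size=(K, N, d, F))      # batched state features, (K, N, d, F)
X = rng.normal(size=(K, N, d))         # hypothetical position stream
X_next = rng.normal(size=(K, N, d))    # hypothetical next-step positions
dt = np.array([0.1, 0.2, 0.1, 0.05])   # one time step per batch row, shape (K,)

# dt has shape (K,); add two trailing axes so it broadcasts against (K, N, d).
V = (X_next - X) / dt[:, None, None]

per_row_eq = "ima,im->ia"  # subscripts mention particle/spatial/feature axes only

# "vmap" by hand: apply the unchanged per-row subscripts to each batch row k.
looped = np.stack([np.einsum(per_row_eq, E[k], V[k]) for k in range(K)])

# Equivalent single contraction with an explicit leading batch index k.
batched = np.einsum("kima,kim->kia", E, V)

print(looped.shape)  # (4, 3, 5)
print(np.allclose(looped, batched))  # True
```

In JAX the loop would typically be replaced by `jax.vmap(lambda e, v: jnp.einsum(per_row_eq, e, v))(E, V)`, which keeps the per-row subscript string untouched.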
- class SFI.integrate.integrand.Term(eq, ops, scale=1.0)
Bases: object
One einsum contraction over the named operands.
Axis conventions
- i : particle axis (kept until the integrator reduces over i)
- m, n, … : spatial indices
- a, b, … : feature axes (always last on any ExprOperand output)
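A worked example of these conventions, using numpy.einsum and hypothetical operand shapes: a state-expression output E with subscripts "ima" (features a last), a constant matrix C with "mn", and a time operand V with "in". The spatial indices are contracted, the particle axis i is kept, and the final sum over i stands in for the integrator's reduction (unmasked here, a simplification). In terms of the documented fields this would correspond to eq="ima,mn,in->ia" with ops listing the three operands' aliases; pairing ops with the eq operands by position is an assumption.

```python
import numpy as np

N, d, F = 3, 2, 4
E = np.ones((N, d, F))               # ExprOperand output: i, m, a (features last)
C = np.array([[1.0, 0.0],
              [0.0, 2.0]])           # constant, spatial indices m, n
V = np.array([[1.0, 1.0],
              [0.0, 1.0],
              [2.0, 0.0]])           # time operand, i, n

# Contract the spatial indices m and n; keep particle axis i and feature axis a.
per_particle = np.einsum("ima,mn,in->ia", E, C, V)   # shape (N, F)

# The integrator, not the Term, reduces over i (a plain unmasked sum here).
reduced = per_particle.sum(axis=0)                   # shape (F,)
print(per_particle[:, 0], reduced)
```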
- eq: str
- ops: Tuple[str, ...]
- scale: float = 1.0
- Parameters:
eq (str)
ops (Tuple[str, ...])
scale (float)