SFI.integrate.integrand module

class SFI.integrate.integrand.ConstOperand(value, alias='C')

Bases: object

Constant array captured by the program (e.g. A_inv).

Parameters:
  • value (Array)

  • alias (str)

alias: str = 'C'
value: Array
class SFI.integrate.integrand.ExprOperand(expr, x, v=None, params_template=None, alias='E')

Bases: object

Wrap a state expression evaluated at x (and optionally v).

Contract

expr(x, v=…, mask=…, extras=…, params=…) -> array
  • features sit on the last axis

  • ‘mask’ is forwarded from the ‘mask_out’ stream when present
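As an illustration of this contract, here is a minimal sketch of a state expression that assumes nothing beyond the keyword signature above. The function name affine_features and the params key "center" are hypothetical. v, mask and extras are accepted but ignored, because how real expressions use them is not specified here.

```python
import jax.numpy as jnp


def affine_features(x, v=None, mask=None, extras=None, params=None):
    """Hypothetical state expression: features [1, x_1, ..., x_d] per particle.

    x has shape (..., d); the result has shape (..., 1 + d), so the
    feature axis is the last axis, as the contract requires.
    v, mask and extras are accepted to match the contract but unused here.
    """
    if params is not None and "center" in params:
        x = x - params["center"]  # hypothetical parameter, for illustration only
    ones = jnp.ones(x.shape[:-1] + (1,), dtype=x.dtype)
    return jnp.concatenate([ones, x], axis=-1)


x = jnp.arange(6.0).reshape(3, 2)  # N = 3 particles in d = 2
print(affine_features(x).shape)  # (3, 3)
print(affine_features(x, params={"center": jnp.ones(2)})[1])  # [1. 1. 2.]
```

Building the ones from x.shape[:-1] keeps the function valid for any leading batch axes, so the same expression also works on (K, N, d) input.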

alias: str = 'E'
expr: object
params_template: Mapping | None = None
property required_extras: Tuple[str, ...]

Keys this expression demands in the extras mapping.

property requires: Set[str]
v: 'TimeOp' | None = None
x: TimeOp
Parameters:
  • expr (object)

  • x (TimeOp)

  • v (Optional['TimeOp'])

  • params_template (Optional[Mapping])

  • alias (str)

class SFI.integrate.integrand.Integrand(exprs=(), times=(), consts=(), terms=())

Bases: object

Compose state expressions and time operands via einsum per time slice.

require() -> Set[str]

Stream keys needed to evaluate this program on one time slice.

__call__(**streams) -> jnp.ndarray

Evaluate on one time slice. The collection handles mask and reduction.

estimate_bytes_per_sample(sample_streams) -> Optional[int]

Upper bound on per-particle bytes using a real sample (via opt_einsum).
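The summary above does not say how require() gathers its keys. ExprOperand and TimeOperand each expose a requires set, so one plausible reading is a union over all operands. The sketch below assumes exactly that. ToyOperand and the stream key X_next are hypothetical.

```python
from dataclasses import dataclass
from typing import FrozenSet, Iterable, Set


@dataclass(frozen=True)
class ToyOperand:
    """Hypothetical stand-in for ExprOperand / TimeOperand / ConstOperand."""
    alias: str
    requires: FrozenSet[str] = frozenset()


def require(operands: Iterable[ToyOperand]) -> Set[str]:
    """Assumed behaviour: every stream key that any operand reads."""
    keys: Set[str] = set()
    for op in operands:
        keys |= op.requires
    return keys


ops = (
    ToyOperand("E", frozenset({"X"})),                  # state expression
    ToyOperand("T", frozenset({"X", "X_next", "dt"})),  # velocity-like time operand
    ToyOperand("C"),                                    # constants read no streams
)
print(sorted(require(ops)))  # ['X', 'X_next', 'dt']
```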

batch_call(*, params=None, **streams)

Evaluate with streams that carry a leading batch (K) axis.

This is the batched counterpart of __call__(). Streams such as X have shape (K, N, d) instead of (N, d), and dt has shape (K,) instead of being a scalar.

State-expression operands receive the full (K, N, d) tensor and handle arbitrary leading batch dimensions internally (the leaf’s _apply_user_func flattens the batch prefix and uses a single jax.vmap).
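The flatten-then-vmap pattern described above can be sketched as follows. apply_with_batch_prefix and per_particle are hypothetical names. This is not the library's _apply_user_func. It only illustrates how collapsing all leading batch axes into one lets a single jax.vmap cover them.

```python
import jax
import jax.numpy as jnp


def per_particle(xi):
    # Hypothetical per-particle function: position (d,) -> features (1 + 2d,).
    return jnp.concatenate([jnp.ones((1,), xi.dtype), xi, xi**2])


def apply_with_batch_prefix(f, x):
    # x has shape (*batch, d). Flatten every leading axis into one,
    # vmap once, then restore the batch prefix in front of the feature axis.
    *batch, d = x.shape
    flat = x.reshape(-1, d)             # (prod(batch), d)
    out = jax.vmap(f)(flat)             # (prod(batch), F)
    return out.reshape(*batch, out.shape[-1])


x_slice = jnp.ones((5, 2))      # (N, d): one time slice
x_batch = jnp.ones((4, 5, 2))   # (K, N, d): K slices at once
print(apply_with_batch_prefix(per_particle, x_slice).shape)  # (5, 5)
print(apply_with_batch_prefix(per_particle, x_batch).shape)  # (4, 5, 5)
```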

Time operands (e.g. velocity) are evaluated on the batch streams directly; the batch-safe velocity TimeOp handles dt broadcasting.

Einsum contractions are vmapped over the leading K axis so that existing subscript strings (which reference particle/spatial/feature axes only) work unchanged.
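The two points above can be checked with a small sketch: a dt of shape (K,) broadcasting against (K, N, d) increments, and per-slice einsum subscripts reused under jax.vmap. batch_velocity and its forward-difference definition of velocity are assumptions for illustration, not the library's velocity TimeOp.

```python
import jax
import jax.numpy as jnp

K, N, d, F = 3, 4, 2, 5
key_x, key_b = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (K + 1, N, d))   # positions at K + 1 times
dt = jnp.array([0.1, 0.2, 0.1])               # (K,) per-row time steps
B = jax.random.normal(key_b, (K, N, d, F))    # basis values b[k, i, m, a]


def batch_velocity(X, dt):
    # Hypothetical batch-safe velocity: (K, N, d) increments divided by a
    # (K,) dt reshaped to (K, 1, 1) so it broadcasts over particle/space axes.
    dX = X[1:] - X[:-1]
    return dX / dt[:, None, None]


V = batch_velocity(X, dt)                     # (K, N, d)

# The per-slice subscripts mention only particle/space/feature axes ...
per_slice = lambda b, v: jnp.einsum("ima,im->ia", b, v)
# ... so vmapping over the leading K axis reuses them unchanged.
batched = jax.vmap(per_slice)(B, V)           # (K, N, F)

# Same result as writing the K axis into the subscripts explicitly.
explicit = jnp.einsum("kima,kim->kia", B, V)
print(batched.shape, bool(jnp.allclose(batched, explicit, atol=1e-5)))
```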

Returns:

Result with a leading K axis. Shape is (K, <per-row output shape>).

Return type:

jax.Array

Parameters:
  • params (Any | None)

estimate_bytes_per_sample(sample_streams, *, dtypesize=4)

Conservative upper bound in bytes per time sample (one row along the K axis). Uses StateExpr static hints and a shapes-only einsum path; nothing is evaluated.

Parameters:
  • sample_streams (Mapping[str, Array])

  • dtypesize (int)

Return type:

int | None
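The "shapes only, no evaluation" idea can be illustrated with jax.eval_shape, which traces a function on abstract jax.ShapeDtypeStruct inputs without computing anything. This is a swapped-in technique, not the estimator described above, which relies on StateExpr hints and an einsum path. slice_program and the sizes are hypothetical, and the tally counts only inputs and output, not intermediates.

```python
from math import prod

import jax
import jax.numpy as jnp


def slice_program(b, v):
    # Hypothetical one-slice program: basis (N, d, F) contracted with velocity (N, d).
    return jnp.einsum("ima,im->ia", b, v)


N, d, F = 1000, 3, 10
b_spec = jax.ShapeDtypeStruct((N, d, F), jnp.float32)
v_spec = jax.ShapeDtypeStruct((N, d), jnp.float32)

# eval_shape traces slice_program abstractly; no array data is computed.
out_spec = jax.eval_shape(slice_program, b_spec, v_spec)

# Rough bytes for one time sample: inputs plus output (intermediates ignored).
nbytes = sum(prod(s.shape) * s.dtype.itemsize for s in (b_spec, v_spec, out_spec))
print(out_spec.shape, nbytes)  # (1000, 10) 172000
```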

require()
Return type:

Set[str]

required_extras()
Return type:

Set[str]

class SFI.integrate.integrand.Term(eq, ops, scale=1.0)

Bases: object

One einsum contraction over the named operands.

Axis conventions

  • i : particle axis (kept until the integrator reduces over i)

  • m, n, … : spatial indices

  • a, b, … : feature axes (always last on any ExprOperand output)

eq: str
ops: Tuple[str, ...]
scale: float = 1.0
Parameters:
  • eq (str)

  • ops (Tuple[str, ...])

  • scale (float)
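To make the axis conventions concrete, here is a sketch of evaluating Term-like contractions against operands looked up by alias. ToyTerm and evaluate_terms are hypothetical. The source does not state that an integrand multiplies each einsum by its scale and sums the terms, so that behaviour is assumed here. The aliases follow the defaults: E (state expression, axes ima), T (time operand, axes im) and C (a constant such as A_inv, axes ab).

```python
from dataclasses import dataclass
from typing import Mapping, Sequence, Tuple

import jax.numpy as jnp


@dataclass(frozen=True)
class ToyTerm:
    # Hypothetical mirror of Term(eq, ops, scale): ops names operand aliases.
    eq: str
    ops: Tuple[str, ...]
    scale: float = 1.0


def evaluate_terms(terms: Sequence[ToyTerm], operands: Mapping[str, jnp.ndarray]) -> jnp.ndarray:
    # Assumed semantics: one scaled einsum per term, summed over terms.
    # The particle axis i is kept; reducing over i is left to the integrator.
    total = None
    for t in terms:
        out = t.scale * jnp.einsum(t.eq, *(operands[a] for a in t.ops))
        total = out if total is None else total + out
    return total


operands = {
    "E": jnp.ones((2, 2, 2)),                   # b[i, m, a]: N = 2, d = 2, F = 2
    "T": jnp.array([[1.0, 2.0], [3.0, 4.0]]),   # velocity v[i, m]
    "C": jnp.eye(2),                            # constant, e.g. A_inv[a, b]
}
terms = (
    ToyTerm("ima,im,ab->ib", ("E", "T", "C")),
    ToyTerm("ima,ima->ia", ("E", "E"), scale=-0.5),
)
print(evaluate_terms(terms, operands))  # per-particle (N, F) result
```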

class SFI.integrate.integrand.TimeOperand(op, alias='T')

Bases: object

Time operand (scalar/vector/tensor) built from streams via a TimeOp.

Parameters:
  • op (TimeOp)

  • alias (str)

alias: str = 'T'
op: TimeOp
property requires: Set[str]
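The TimeOp interface is not spelled out here. The sketch below assumes a TimeOp behaves like a callable over named streams and shows how one set of streams can yield the scalar, vector and tensor operands mentioned above. ToyTimeOp and the stream key X_next are hypothetical, and the increment-based definitions are illustrative only.

```python
from dataclasses import dataclass
from typing import Callable, FrozenSet, Mapping

import jax.numpy as jnp


@dataclass(frozen=True)
class ToyTimeOp:
    # Hypothetical stand-in for a TimeOp: a function of named streams
    # plus the set of stream keys it reads.
    fn: Callable[..., jnp.ndarray]
    requires: FrozenSet[str]

    def __call__(self, streams: Mapping[str, jnp.ndarray]) -> jnp.ndarray:
        return self.fn(**{k: streams[k] for k in self.requires})


keys = frozenset({"X", "X_next", "dt"})
dt_op = ToyTimeOp(lambda dt: dt, frozenset({"dt"}))                 # scalar
vel_op = ToyTimeOp(lambda X, X_next, dt: (X_next - X) / dt, keys)   # (N, d)
outer_op = ToyTimeOp(                                               # (N, d, d)
    lambda X, X_next, dt: jnp.einsum("im,in->imn", X_next - X, X_next - X) / dt,
    keys,
)

streams = {"X": jnp.zeros((3, 2)), "X_next": jnp.ones((3, 2)), "dt": jnp.asarray(0.5)}
print(dt_op(streams).shape, vel_op(streams).shape, outer_op(streams).shape)
```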