SFI.integrate package

SFI.integrate — Time-averaging integration engine.

Public API

integrate

Run an Integrand over a TrajectoryCollection and reduce.

make_parametric_integrator

Build a reusable, jittable integrator for a parametric Integrand.

make_minibatch_parametric_integrator

Like make_parametric_integrator but also returns a stochastic mini-batch runner.

Integrand

Compose state expressions and time operands via einsum per time slice.

Term, ExprOperand, TimeOperand, ConstOperand

Building blocks for Integrand programs.

stream, timeop, velocity, scale, add

TimeOp constructors for stream access and linear combinations.

class SFI.integrate.ConstOperand(value, alias='C')

Bases: object

Constant array captured by the program (e.g. A_inv).

Parameters:
  • value (Array)

  • alias (str)

alias: str = 'C'
value: Array
class SFI.integrate.ExprOperand(expr, x, v=None, params_template=None, alias='E')

Bases: object

Wrap a state expression evaluated at x (and optionally v).

Contract

expr(x, v=…, mask=…, extras=…, params=…) -> array
  • features sit on the last axis

  • ‘mask’ is forwarded from the ‘mask_out’ stream when present

alias: str = 'E'
expr: object
params_template: Mapping | None = None
property required_extras: Tuple[str, ...]

Keys this expression demands in the extras mapping.

property requires: Set[str]
v: 'TimeOp' | None = None
x: TimeOp
Parameters:
  • expr (object)

  • x (TimeOp)

  • v (Optional['TimeOp'])

  • params_template (Optional[Mapping])

  • alias (str)
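The following is a minimal sketch of a callable that satisfies the expr contract above, using plain NumPy arrays. The name polynomial_features, the "w" parameter key and the choice to zero masked rows are hypothetical illustrations. The real package wraps such callables in its own state-expression objects. The function returns two features per spatial component, placed on the last axis.

```python
import numpy as np


def polynomial_features(x, v=None, mask=None, extras=None, params=None):
    """Hypothetical state expression following the
    expr(x, v=..., mask=..., extras=..., params=...) contract.

    x has shape (N, d); the result has shape (N, d, 2) with the
    features [x, x**3] on the last axis, as the contract requires.
    """
    feats = np.stack([x, x**3], axis=-1)  # (N, d, F=2), features last
    if params is not None:
        # Assumed parametric form: one weight per feature, shape (2,)
        feats = feats * params["w"]
    if mask is not None:
        # Assumed use of the mask: zero out particles whose mask is False
        feats = feats * mask[:, None, None]
    return feats
```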

class SFI.integrate.Integrand(exprs=(), times=(), consts=(), terms=())

Bases: object

Compose state expressions and time operands via einsum per time slice.

require() -> Set[str]

Stream keys needed to evaluate this program on one time slice.

__call__(**streams) -> jnp.ndarray

Evaluate on one time slice. The collection handles mask and reduction.

estimate_bytes_per_sample(sample_streams) -> Optional[int]

Upper bound on bytes per time sample, derived from the shapes of a real sample (via opt_einsum).

batch_call(*, params=None, **streams)

Evaluate with streams that carry a leading batch (K) axis.

This is the batched counterpart of __call__(). Streams such as X have shape (K, N, d) instead of (N, d), and dt has shape (K,) instead of being a scalar.

State-expression operands receive the full (K, N, d) tensor and handle arbitrary leading batch dimensions internally (the leaf’s _apply_user_func flattens the batch prefix and uses a single jax.vmap).

Time operands (e.g. velocity) are evaluated on the batch streams directly; the batch-safe velocity TimeOp handles dt broadcasting.

Einsum contractions are vmapped over the leading K axis so that existing subscript strings (which reference particle/spatial/feature axes only) work unchanged.

Returns:

Result with a leading K axis. Shape is (K, <per-row output shape>).

Return type:

jax.Array

Parameters:

  • params (Any | None)
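A small NumPy sketch can check that vmapping over K leaves the subscript strings unchanged. Applying a per-slice contraction row by row stands in for jax.vmap. This gives the same result as one einsum whose operands and output each gain a leading ellipsis. The subscript "ima,im->ia" and the helper names are illustrative assumptions.

```python
import numpy as np

# Per-slice subscript: particle i, spatial m, feature a (see Term's axis conventions).
EQ = "ima,im->ia"


def per_slice(E, T):
    """Contraction on one time slice: E is (N, d, F), T is (N, d)."""
    return np.einsum(EQ, E, T)


def batched_by_loop(E_batch, T_batch):
    """Manual stand-in for jax.vmap over the leading K axis."""
    return np.stack([per_slice(E, T) for E, T in zip(E_batch, T_batch)])


def batched_by_ellipsis(E_batch, T_batch):
    """Equivalent single einsum: prepend '...' to every operand and to the output."""
    lhs, out = EQ.split("->")
    eq = ",".join("..." + s for s in lhs.split(",")) + "->..." + out
    return np.einsum(eq, E_batch, T_batch)
```

Both functions return an array of shape (K, N, F). The per-row subscript string never mentions K.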

estimate_bytes_per_sample(sample_streams, *, dtypesize=4)

Conservative upper bound, in bytes, on the memory needed for one time sample (one K-row). It uses the static shape hints of the state expressions (StateExpr) and a shapes-only einsum contraction path. Nothing is evaluated.

Parameters:
  • sample_streams (Mapping[str, Array])

  • dtypesize (int) – Bytes per array element (default 4, e.g. float32).

Return type:

int | None
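The next sketch shows roughly how a shapes-only estimate can become a chunk length. It counts the elements of every operand and of the output for one time sample, then multiplies the count by dtypesize. A working-set target is then divided by that result; the default chunk_target_bytes is 536870912 bytes, i.e. 512 MiB. This is a hypothetical simplification. It ignores einsum intermediates, which the package's estimate covers through the contraction path. The helper names are not part of SFI.integrate.

```python
import math


def bytes_per_sample(operand_shapes, output_shape, dtypesize=4):
    """Crude estimate: operand sizes plus output size for one time sample."""
    n_elems = sum(math.prod(s) for s in operand_shapes) + math.prod(output_shape)
    return n_elems * dtypesize


def chunk_length(per_sample_bytes, chunk_target_bytes=536_870_912):
    """Number of time samples that fit in the target working set (at least 1)."""
    return max(1, chunk_target_bytes // per_sample_bytes)
```

Take operands of shape (1000, 3, 10) and (1000, 3) with an output of shape (1000, 10). Together they hold 43000 elements, which is 172000 bytes at dtypesize=4. Since 536870912 // 172000 = 3121, that many samples fit per chunk.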

require()
Return type:

Set[str]

required_extras()
Return type:

Set[str]

class SFI.integrate.Term(eq, ops, scale=1.0)

Bases: object

One einsum contraction over the named operands.

Axis conventions

  • i : particle axis (kept until the integrator reduces over i)

  • m, n, … : spatial indices

  • a, b, … : feature axes (always last on any ExprOperand output)

eq: str
ops: Tuple[str, ...]
scale: float = 1.0
Parameters:
  • eq (str)

  • ops (Tuple[str, ...])

  • scale (float)
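The sketch below makes the axis conventions concrete by evaluating one Term against named operands with NumPy. The helper eval_term and the operand dictionary are hypothetical. The equation keeps the particle axis i in its output. The final sum over i stands in for the integrator's particle reduction.

```python
from typing import Mapping, Tuple

import numpy as np


def eval_term(eq: str, ops: Tuple[str, ...], scale: float,
              operands: Mapping[str, np.ndarray]) -> np.ndarray:
    """Contract the operands named in ops with einsum and multiply by scale."""
    return scale * np.einsum(eq, *(operands[name] for name in ops))


# Hypothetical operands for N=2 particles, d=3 spatial dims, F=4 features.
rng = np.random.default_rng(1)
operands = {
    "E": rng.normal(size=(2, 3, 4)),  # ExprOperand output: features last
    "T": rng.normal(size=(2, 3)),     # TimeOperand, e.g. a velocity
}

# Keep i in the output, then reduce over particles as the integrator would.
per_particle = eval_term("ima,im->ia", ("E", "T"), 0.5, operands)  # shape (2, 4)
reduced = per_particle.sum(axis=0)                                    # shape (4,)
```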

class SFI.integrate.TimeOperand(op, alias='T')

Bases: object

Time operand (scalar/vector/tensor) built from streams via a TimeOp.

Parameters:
  • op (TimeOp)

  • alias (str)

alias: str = 'T'
op: TimeOp
property requires: Set[str]
SFI.integrate.add(ops, *, name=None)

Return a TimeOp that sums the outputs of all ops element-wise.

Parameters:
  • ops (iterable of TimeOp) – At least one operand is required.

  • name (str | None)

Return type:

TimeOp
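The combinators stream, scale and add compose naturally if a TimeOp is modeled as two things: a callable over a mapping of streams, and the set of stream keys it needs. The sketch below uses that model. The SimpleTimeOp class is a hypothetical stand-in, not the package's TimeOp. The stream names "dX" and "Y" are only examples.

```python
from dataclasses import dataclass
from functools import reduce
from typing import Callable, FrozenSet, Iterable, Mapping

import numpy as np


@dataclass(frozen=True)
class SimpleTimeOp:
    fn: Callable[[Mapping[str, np.ndarray]], np.ndarray]
    requires: FrozenSet[str]

    def __call__(self, streams):
        return self.fn(streams)


def stream(key):
    """Pass the stream named key through unchanged."""
    return SimpleTimeOp(lambda s: s[key], frozenset({key}))


def scale(op, alpha):
    """Multiply the output of op by the scalar alpha."""
    return SimpleTimeOp(lambda s: alpha * op(s), op.requires)


def add(ops: Iterable[SimpleTimeOp]):
    """Element-wise sum of all operands; at least one is required."""
    ops = tuple(ops)
    if not ops:
        raise ValueError("add() needs at least one operand")
    return SimpleTimeOp(
        lambda s: reduce(lambda acc, op: acc + op(s), ops[1:], ops[0](s)),
        frozenset().union(*(op.requires for op in ops)),  # union of required keys
    )
```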

SFI.integrate.integrate(collection, program, *, reduce='sum', reduce_over_particles=True, weight_by_dt=True, subsampling=1, chunk_target_bytes=536870912, context=None, batch=True)

Integrate an instantaneous program over time and datasets.

Parameters:
  • collection – TrajectoryCollection exposing producers and time-index chunks.

  • program – Integrand object with require, estimate_bytes_per_sample, and __call__.

  • reduce ({'sum','mean'}) – Dataset-and-time reduction. ‘mean’ divides by the accumulated effective exposure computed from the same dt used in the numerator.

  • reduce_over_particles (bool) – If the program output has a leading particle axis, apply mask_out, then sum that axis before the time reduction.

  • weight_by_dt (bool) – If True (default), multiply each program output by dt before accumulation. Set to False for programs whose output should be summed without dt weighting (e.g. parametric Gram matrices).

  • subsampling (int) – Keep only the time indices t with t % subsampling == 0 (1 keeps every sample).

  • chunk_target_bytes (int) – Target working-set size for the vmapped kernel.

  • context (str, optional) – Forwarded to dataset extras via producers.

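  • batch (bool) – If True (default), use the batched integration path. In that path, a chunk of time samples is evaluated with a leading K axis (see Integrand.batch_call()).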

Returns:

Reduced value. Its shape matches the program’s per-slice output. When reduce_over_particles is True, the leading particle axis is removed.

Return type:

jax.Array
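A plain Python loop can sketch the reduction semantics described above. These are subsampling, the masked particle sum, dt weighting and the 'mean' normalization. This is an illustration under stated assumptions, not the package's implementation. The helper integrate_sketch is hypothetical and processes one time index at a time instead of vmapped chunks. It also assumes the "effective exposure" is the sum of the same weights used in the numerator.

```python
import numpy as np


def integrate_sketch(program, X, dt, mask_out=None, *, reduce="sum",
                     reduce_over_particles=True, weight_by_dt=True, subsampling=1):
    """X: (T, N, d) positions, dt: (T,) time steps, mask_out: optional (T, N) bool."""
    total, exposure = 0.0, 0.0
    for t in range(X.shape[0]):
        if t % subsampling != 0:
            continue  # keep only indices with t % subsampling == 0
        out = program(X[t])  # per-slice output with a leading particle axis
        if reduce_over_particles:
            if mask_out is not None:
                # Apply mask_out before summing over particles
                keep = mask_out[t].reshape((-1,) + (1,) * (out.ndim - 1))
                out = out * keep
            out = out.sum(axis=0)
        w = dt[t] if weight_by_dt else 1.0
        total = total + w * out
        exposure += w  # assumed exposure: same weights as the numerator
    if reduce == "mean":
        return total / exposure
    return total
```

In the package, time indices are processed in vmapped chunks sized by chunk_target_bytes. Chunking only affects execution, so the sketch uses a plain loop.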

SFI.integrate.make_minibatch_parametric_integrator(collection, program, *, batch_size, reduce='sum', reduce_over_particles=True, weight_by_dt=True, subsampling=1, chunk_target_bytes=536870912, context=None, bytes_per_sample=None, batch=True)

Build a parametric integrator with both full and mini-batch runners.

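Parameters:
  • batch_size – Size of the mini-batch drawn by run_batch.

  • collection, program, and the remaining keyword arguments – Same meaning as in make_parametric_integrator().
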
Returns:

  • plan (IntegrationPlan)

  • run_full (callable) – run_full(theta) -> scalar — full-data evaluator.

  • run_batch (callable) – run_batch(theta, rng_key) -> scalar — stochastic mini-batch evaluator. Unbiased estimator of the full-data value.

Return type:

Tuple[IntegrationPlan, Callable, Callable]
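One standard way to get an unbiased mini-batch estimate of a sum over n samples has two steps. First, draw batch_size indices uniformly without replacement. Then rescale the batch sum by n / batch_size. The sketch below shows that estimator and checks unbiasedness exactly by averaging over every possible batch. The source does not say which unit run_batch samples, whether time indices, chunks or something else. Treat the sampling unit, and this specific estimator, as assumptions.

```python
from itertools import combinations

import numpy as np


def minibatch_estimate(values, batch_size, rng):
    """Unbiased estimate of values.sum(): rescaled sum over a random subset."""
    n = len(values)
    idx = rng.choice(n, size=batch_size, replace=False)
    return values[idx].sum() * n / batch_size


def exact_expectation(values, batch_size):
    """Average of the estimator over all equally likely subsets of size batch_size."""
    n = len(values)
    estimates = [values[list(c)].sum() * n / batch_size
                 for c in combinations(range(n), batch_size)]
    return float(np.mean(estimates))
```

Each element appears in a fraction batch_size / n of all subsets. Multiplying by n / batch_size therefore restores its full weight in expectation.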

SFI.integrate.make_parametric_integrator(collection, program, *, reduce='sum', reduce_over_particles=True, weight_by_dt=True, subsampling=1, chunk_target_bytes=536870912, context=None, bytes_per_sample=None, batch=True)

Build a reusable, jittable integrator for a parametric Integrand.

Parameters:
  • collection – TrajectoryCollection exposing producers and time-index chunks.

  • program – Integrand object with require, estimate_bytes_per_sample, and a call signature program(**streams, params=theta) where theta is a PyTree of parameters.

  • reduce (str) – Same meaning as in integrate().

  • reduce_over_particles (bool) – Same meaning as in integrate().

  • weight_by_dt (bool) – Same meaning as in integrate().

  • subsampling (int) – Same meaning as in integrate().

  • chunk_target_bytes (int) – Same meaning as in integrate().

  • context (str | None) – Same meaning as in integrate().

  • bytes_per_sample (int, optional) – Optional override for the per-sample memory estimate. If None, the program’s estimate_bytes_per_sample is used.

  • batch (bool) – If True, use the batched integration path (see integrate()).

Returns:

  • plan (IntegrationPlan) – Host-side plan describing the chunks and producers.

  • run (callable) – JAX-jitted function run(theta) -> value that evaluates the integration for a given set of parameters.

Return type:

Tuple[IntegrationPlan, Callable[[Any], Array]]
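The split between a host-side plan and a reusable run(theta) follows a closure pattern. Everything that does not depend on the parameters is prepared once. Only the parameter-dependent evaluation runs on each call. The NumPy sketch below is hypothetical, as are make_integrator_sketch, the plan dictionary and the quadratic program. In the package, run is additionally JAX-jitted.

```python
import numpy as np


def make_integrator_sketch(X, dt, program, subsampling=1):
    """Precompute parameter-independent data once; return (plan, run)."""
    keep = np.arange(X.shape[0]) % subsampling == 0
    plan = {"X": X[keep], "dt": dt[keep], "n_samples": int(keep.sum())}

    def run(theta):
        # Only this part depends on theta: evaluate, sum particles, weight by dt.
        per_sample = np.stack([program(x, params=theta).sum(axis=0) for x in plan["X"]])
        dt_shape = (-1,) + (1,) * (per_sample.ndim - 1)
        return (plan["dt"].reshape(dt_shape) * per_sample).sum(axis=0)

    return plan, run


def quadratic_program(x, params):
    """Hypothetical parametric program: w * x**2 per particle, x of shape (N, d)."""
    return params["w"] * x**2
```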

SFI.integrate.scale(op, alpha, *, name=None)

Return a TimeOp that multiplies op by scalar alpha.

Parameters:
  • op (TimeOp)

  • alpha (float | int)

  • name (str | None)

Return type:

TimeOp

SFI.integrate.stream(key, *, name=None)

Return a TimeOp that passes the stream named key through unchanged.

Parameters:
  • key (str)

  • name (str | None)

Return type:

TimeOp

SFI.integrate.timeop(fn=None, *, name=None, batch_safe=False, requires=None)

Decorator: convert a function into a TimeOp.

Parameters:
  • batch_safe (bool) – If True, the function already handles a leading batch (K) axis in its inputs without requiring an additional jax.vmap.

  • requires (frozenset, optional) – Explicit set of required stream keys. When given, the function signature is not inspected for parameter names. Use this when the function accepts **streams and the required keys cannot be inferred from the signature.

  • fn (Callable[[...], Array] | None)

  • name (str | None)

Notes

Without requires, the function’s parameter names define the required stream keys (**kwargs/*args parameters are ignored).
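The rule in the Notes can be reproduced with inspect.signature: parameter names define the required stream keys unless requires is given. The sketch below is a hypothetical decorator, timeop_sketch. It only attaches the required keys, the batch_safe flag and a name to the function, whereas timeop builds a TimeOp. Because fn defaults to None, it works both bare and with keyword arguments, as timeop does.

```python
import inspect
from typing import Callable, Optional


def timeop_sketch(fn: Optional[Callable] = None, *, name=None, batch_safe=False, requires=None):
    def wrap(f):
        if requires is not None:
            keys = frozenset(requires)  # explicit keys; signature is not inspected
        else:
            params = inspect.signature(f).parameters.values()
            # *args / **kwargs parameters are ignored, as in the Notes above
            keys = frozenset(
                p.name for p in params
                if p.kind not in (inspect.Parameter.VAR_POSITIONAL,
                                  inspect.Parameter.VAR_KEYWORD)
            )
        f.requires = keys
        f.batch_safe = batch_safe
        f.op_name = name or f.__name__
        return f

    # Support both @timeop_sketch and @timeop_sketch(...)
    return wrap if fn is None else wrap(fn)


@timeop_sketch
def displacement_speed(dX, dt):
    return dX / dt


@timeop_sketch(requires=frozenset({"dX", "dt"}), batch_safe=True)
def from_kwargs(**streams):
    return streams["dX"] * streams["dt"]
```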

SFI.integrate.velocity(dx_key, dt_key, *, name=None)

Return a TimeOp that computes dx/dt, broadcasting dt over arbitrary leading dims.

Parameters:
  • dx_key (str)

  • dt_key (str)

  • name (str | None)

Return type:

TimeOp
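Broadcasting dt over arbitrary leading dimensions means reshaping dt so that it gains trailing singleton axes up to the rank of dx. The hypothetical helper below sketches that rule in NumPy. It handles a single slice, with scalar dt and dx of shape (N, d). It also handles a batch, with dt of shape (K,) and dx of shape (K, N, d).

```python
import numpy as np


def velocity_sketch(dx, dt):
    """Compute dx / dt, appending singleton axes to dt so it aligns with dx's leading dims."""
    dx = np.asarray(dx)
    dt = np.asarray(dt)
    dt = dt.reshape(dt.shape + (1,) * (dx.ndim - dt.ndim))
    return dx / dt
```

The reshape matters because plain NumPy broadcasting aligns trailing axes. Without it, a dt of shape (K,) is matched against the last axis of dx. That either raises an error or, when K happens to equal d, silently divides by the wrong values.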

Submodules