SFI.integrate package
SFI.integrate: time-averaging integration engine.
Public API
- integrate
Run an Integrand over a TrajectoryCollection and reduce.
- make_parametric_integrator
Build a reusable, jittable integrator for a parametric Integrand.
- make_minibatch_parametric_integrator
Like make_parametric_integrator but also returns a stochastic mini-batch runner.
- Integrand
Compose state expressions and time operands via einsum per time slice.
- Term, ExprOperand, TimeOperand, ConstOperand
Building blocks for Integrand programs.
- stream, timeop, velocity, scale, add
TimeOp constructors for stream access and linear combinations.
- class SFI.integrate.ConstOperand(value, alias='C')
Bases: object
Constant array captured by the program (e.g. A_inv).
- Parameters:
value (Array)
alias (str)
- alias: str = 'C'
- value: Array
- class SFI.integrate.ExprOperand(expr, x, v=None, params_template=None, alias='E')
Bases: object
Wrap a state expression evaluated at x (and optionally v).
Contract
- Signature: expr(x, v=…, mask=…, extras=…, params=…) -> array
- Features sit on the last axis of the output.
- ‘mask’ is forwarded from the ‘mask_out’ stream when present.
- alias: str = 'E'
- expr: object
- params_template: Mapping | None = None
- property required_extras: Tuple[str, ...]
Keys this expression demands in the extras mapping.
- property requires: Set[str]
- v: 'TimeOp' | None = None
- Parameters:
expr (object)
x (TimeOp)
v (Optional['TimeOp'])
params_template (Optional[Mapping])
alias (str)
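As an illustration of the contract above, the sketch below defines a hypothetical state expression, poly_basis, that accepts the documented keyword arguments and appends its feature axis last. It is written with NumPy so it runs standalone (jax.numpy offers the same calls). How it uses mask, by zeroing masked particles, is an assumption: the contract only says that the mask is forwarded.

```python
import numpy as np

def poly_basis(x, v=None, mask=None, extras=None, params=None):
    """Hypothetical state expression following the documented contract.

    x has shape (..., N, d); the output appends a feature axis of size 3
    holding [1, x, x**2] for every spatial component, so features sit last.
    Leading batch dimensions (e.g. K) pass through unchanged.
    """
    x = np.asarray(x, dtype=float)
    feats = np.stack([np.ones_like(x), x, x**2], axis=-1)  # (..., N, d, 3)
    if mask is not None:
        # Assumed use of the forwarded mask: zero out masked particles.
        # mask has shape (..., N) and is broadcast over the d and feature axes.
        feats = feats * np.asarray(mask, dtype=float)[..., None, None]
    return feats

x = np.array([[0.5, -1.0], [2.0, 0.0]])   # N=2 particles, d=2
out = poly_basis(x, mask=np.array([1.0, 0.0]))
print(out.shape)  # (2, 2, 3)
```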
- class SFI.integrate.Integrand(exprs=(), times=(), consts=(), terms=())
Bases: object
Compose state expressions and time operands via einsum per time slice.
- require() -> Set[str]: stream keys needed to evaluate this program on one time slice.
- __call__(**streams) -> jnp.ndarray: evaluate on one time slice. The collection handles masking and reduction.
- estimate_bytes_per_sample(sample_streams) -> Optional[int]: upper bound on per-particle bytes, computed from a real sample (via opt_einsum).
- Parameters:
exprs (Sequence[ExprOperand])
times (Sequence[TimeOperand])
consts (Sequence[ConstOperand])
terms (Sequence[Term])
- batch_call(*, params=None, **streams)
Evaluate with streams that carry a leading batch (K) axis.
This is the batched counterpart of __call__(). Streams such as X have shape (K, N, d) instead of (N, d), and dt has shape (K,) instead of being a scalar.
State-expression operands receive the full (K, N, d) tensor and handle arbitrary leading batch dimensions internally (the leaf’s _apply_user_func flattens the batch prefix and uses a single jax.vmap).
Time operands (e.g. velocity) are evaluated on the batch streams directly; the batch-safe velocity TimeOp handles dt broadcasting.
Einsum contractions are vmapped over the leading K axis, so existing subscript strings (which reference only particle, spatial, and feature axes) work unchanged.
- Returns:
Result with a leading K axis, of shape (K, <per-row output shape>).
- Return type:
jax.Array
- Parameters:
params (Any | None)
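The claim that vmapping over K leaves per-slice subscripts unchanged can be checked directly. In this NumPy sketch, a Python loop stands in for jax.vmap and applies the per-slice equation "im,ima->ia" row by row. The result matches an einsum that names the batch axis k explicitly. Array names and sizes are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
K, N, d, F = 4, 5, 2, 3

V = rng.normal(size=(K, N, d))      # e.g. a velocity stream with a leading K axis
E = rng.normal(size=(K, N, d, F))   # an expression output, features last

eq = "im,ima->ia"                   # per-slice subscript: no K axis mentioned

# Stand-in for jax.vmap: apply the unchanged per-slice einsum to each row k.
looped = np.stack([np.einsum(eq, V[k], E[k]) for k in range(K)])

# Equivalent single einsum with the batch axis written out explicitly.
batched = np.einsum("kim,kima->kia", V, E)

print(looped.shape)  # (4, 5, 3)
```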
- class SFI.integrate.Term(eq, ops, scale=1.0)
Bases: object
One einsum contraction over the named operands.
Axis conventions
i : particle axis (kept until the integrator reduces over i)
m, n, … : spatial indices
a, b, … : feature axes (always last on any ExprOperand output)
- eq: str
- ops: Tuple[str, ...]
- scale: float = 1.0
- Parameters:
eq (str)
ops (Tuple[str, ...])
scale (float)
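The following is a minimal sketch of how a list of Terms might be evaluated on one time slice. It assumes that each term computes scale * einsum(eq, *operands), with operands looked up by alias, and that the program output is the sum of its terms. The combination rule is not stated above. SketchTerm and eval_terms are hypothetical names. The aliases T, E, and C match the operand defaults, and the output keeps the particle axis i.

```python
import numpy as np
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple

@dataclass(frozen=True)
class SketchTerm:
    """Hypothetical mirror of Term's documented fields (eq, ops, scale)."""
    eq: str
    ops: Tuple[str, ...]
    scale: float = 1.0

def eval_terms(terms: Sequence[SketchTerm], operands: Dict[str, np.ndarray]) -> np.ndarray:
    """Sum scale * einsum(eq, *operands) over all terms (assumed semantics)."""
    total = None
    for t in terms:
        val = t.scale * np.einsum(t.eq, *(operands[a] for a in t.ops))
        total = val if total is None else total + val
    return total

N, d, F = 3, 2, 4
operands = {
    "T": np.ones((N, d)),                                     # time operand, axes i,m
    "E": np.arange(N * d * F, dtype=float).reshape(N, d, F),  # expression, axes i,m,a
    "C": np.eye(F),                                           # constant, axes a,b
}
terms = [
    SketchTerm("im,ima->ia", ("T", "E")),              # contract spatial axis m
    SketchTerm("ima,ab->ib", ("E", "C"), scale=-0.5),  # contract m and a against C
]
result = eval_terms(terms, operands)  # shape (N, F): particle axis i is kept
```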
- class SFI.integrate.TimeOperand(op, alias='T')
Bases: object
Time operand (scalar, vector, or tensor) built from streams via a TimeOp.
- Parameters:
op (TimeOp)
alias (str)
- alias: str = 'T'
- property requires: Set[str]
- SFI.integrate.add(ops, *, name=None)
Return a TimeOp that sums the outputs of all ops element-wise.
- SFI.integrate.integrate(collection, program, *, reduce='sum', reduce_over_particles=True, weight_by_dt=True, subsampling=1, chunk_target_bytes=536870912, context=None, batch=True)
Integrate an instantaneous program over time and datasets.
- Parameters:
collection – TrajectoryCollection exposing producers and time-index chunks.
program – Integrand object with require, estimate_bytes_per_sample, and __call__.
reduce ({'sum','mean'}) – Dataset-and-time reduction. ‘mean’ divides by the accumulated effective exposure computed from the same dt used in the numerator.
reduce_over_particles (bool) – If the program output has a leading particle axis, apply mask_out, then sum that axis before the time reduction.
weight_by_dt (bool) – If True (default), multiply each program output by dt before accumulation. Set to False for programs whose output should be summed without dt weighting (e.g. parametric Gram matrices).
subsampling (int) – Keep only time indices t with t % subsampling == 0.
chunk_target_bytes (int) – Target working-set size for the vmapped kernel.
context (str, optional) – Forwarded to dataset extras via producers.
batch (bool) – If True (default), use the batched integration path.
- Returns:
Reduced value with particle axis removed if requested. Shapes match the program’s output after optional particle reduction.
- Return type:
jax.Array
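The reduction options can be summarized as a small NumPy reference on precomputed per-slice outputs. This is a hypothetical sketch, not the library's kernel. It keeps indices with t % subsampling == 0, applies mask_out and sums the particle axis, and weights by dt. For reduce='mean', it divides by the sum of the same weights used in the numerator. Reading "effective exposure" as that sum is an assumption.

```python
import numpy as np

def reduce_sketch(outputs, dts, mask_out, *, reduce="sum",
                  reduce_over_particles=True, weight_by_dt=True, subsampling=1):
    """Hypothetical reference for integrate()'s documented reduction rules.

    outputs:  (T, N, ...) per-time-slice program outputs with a particle axis
    dts:      (T,) time steps
    mask_out: (T, N) 1 for valid particles, 0 otherwise
    """
    outputs = np.asarray(outputs, dtype=float)
    dts = np.asarray(dts, dtype=float)
    mask_out = np.asarray(mask_out, dtype=float)

    keep = np.arange(len(dts)) % subsampling == 0          # t % subsampling == 0
    outputs, dts, mask_out = outputs[keep], dts[keep], mask_out[keep]

    if reduce_over_particles:
        m = mask_out.reshape(mask_out.shape + (1,) * (outputs.ndim - 2))
        outputs = (outputs * m).sum(axis=1)                # apply mask, sum over i

    w = dts if weight_by_dt else np.ones_like(dts)         # per-slice weights
    w_b = w.reshape((-1,) + (1,) * (outputs.ndim - 1))
    total = (outputs * w_b).sum(axis=0)                    # reduce over time
    if reduce == "mean":
        total = total / w.sum()                            # assumed exposure: sum of weights
    return total

outputs = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])  # T=4, N=2
dts = np.array([0.1, 0.2, 0.3, 0.4])
mask_out = np.array([[1, 1], [1, 0], [1, 1], [0, 1]])
total = reduce_sketch(outputs, dts, mask_out)
```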
- SFI.integrate.make_minibatch_parametric_integrator(collection, program, *, batch_size, reduce='sum', reduce_over_particles=True, weight_by_dt=True, subsampling=1, chunk_target_bytes=536870912, context=None, bytes_per_sample=None, batch=True)
Build a parametric integrator with both full and mini-batch runners.
- Parameters:
collection – Same as make_parametric_integrator().
program – Same as make_parametric_integrator().
reduce (str) – Same as make_parametric_integrator().
reduce_over_particles (bool) – Same as make_parametric_integrator().
weight_by_dt (bool) – Same as make_parametric_integrator().
subsampling (int) – Same as make_parametric_integrator().
chunk_target_bytes (int) – Same as make_parametric_integrator().
context (str | None) – Same as make_parametric_integrator().
bytes_per_sample (int | None) – Same as make_parametric_integrator().
batch (bool) – Same as make_parametric_integrator().
batch_size (int) – Number of time indices to sample per mini-batch evaluation.
- Returns:
plan (IntegrationPlan)
run_full (callable) – run_full(theta) -> scalar: full-data evaluator.
run_batch (callable) – run_batch(theta, rng_key) -> scalar: stochastic mini-batch evaluator, an unbiased estimator of the full-data value.
- Return type:
Tuple[IntegrationPlan, Callable, Callable]
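run_batch is documented as unbiased, but its sampling scheme is not described here. One standard construction, assumed in this sketch, draws batch_size time indices uniformly without replacement and rescales their sum by T / batch_size. Averaging over every possible batch gives the expectation exactly, and it equals the full-data sum. The per-index values are made up for the example.

```python
import itertools
import numpy as np

def full_value(per_t):
    """Full-data value: sum of dt-weighted contributions over all time indices."""
    return per_t.sum()

def minibatch_value(per_t, idx):
    """Rescaled mini-batch sum: (T / B) * sum over the sampled indices."""
    T, B = len(per_t), len(idx)
    return (T / B) * per_t[list(idx)].sum()

per_t = np.array([0.3, 0.6, 3.3, 3.2, 1.0])   # made-up per-time-index contributions
B = 2

# Averaging over every size-B subset equals the expectation under
# uniform sampling without replacement.
subsets = itertools.combinations(range(len(per_t)), B)
expected = np.mean([minibatch_value(per_t, s) for s in subsets])

# A single stochastic draw, analogous to one run_batch call.
rng = np.random.default_rng(0)
draw = minibatch_value(per_t, rng.choice(len(per_t), size=B, replace=False))
```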
- SFI.integrate.make_parametric_integrator(collection, program, *, reduce='sum', reduce_over_particles=True, weight_by_dt=True, subsampling=1, chunk_target_bytes=536870912, context=None, bytes_per_sample=None, batch=True)
Build a reusable, jittable integrator for a parametric Integrand.
- Parameters:
collection – TrajectoryCollection exposing producers and time-index chunks.
program – Integrand object with require, estimate_bytes_per_sample, and a call signature program(**streams, params=theta), where theta is a PyTree of parameters.
reduce (str) – Same meaning as in integrate().
reduce_over_particles (bool) – Same meaning as in integrate().
weight_by_dt (bool) – Same meaning as in integrate().
subsampling (int) – Same meaning as in integrate().
chunk_target_bytes (int) – Same meaning as in integrate().
context (str | None) – Same meaning as in integrate().
bytes_per_sample (int, optional) – Override for the per-sample memory estimate. If None, the program’s estimate_bytes_per_sample is used.
batch (bool) – If True, use the batched integration path (see integrate()).
- Returns:
plan (IntegrationPlan) – Host-side plan describing the chunks and producers.
run (callable) – JAX-jitted function run(theta) -> value that evaluates the integration for a given set of parameters.
- Return type:
Tuple[IntegrationPlan, Callable[[Any], Array]]
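The split between a host-side plan and a run(theta) function amounts to a closure: the data is prepared once, and only the parameters change between calls. The hypothetical make_run_sketch below shows that pattern in NumPy with a made-up quadratic integrand. In JAX, run would typically be wrapped with jax.jit. The real program also receives streams as keyword arguments (program(**streams, params=theta)) rather than a single position array.

```python
import numpy as np

def make_run_sketch(xs, dts, program):
    """Hypothetical closure: data is captured once; only theta varies per call.

    xs:  (T, N, d) positions per time index
    dts: (T,) time steps
    program(x, params=theta) -> (N,) per-particle values
    """
    xs = np.asarray(xs, dtype=float)
    dts = np.asarray(dts, dtype=float)

    def run(theta):
        # Sum over particles i for each time index, then dt-weighted sum over t.
        per_t = np.stack([program(x, params=theta).sum() for x in xs])
        return float((per_t * dts).sum())

    return run

def quad_program(x, *, params):
    # Made-up parametric integrand: (theta . x_i)^2 for each particle i.
    return (x @ params) ** 2

xs = np.ones((3, 2, 2))      # T=3 time indices, N=2 particles, d=2
run = make_run_sketch(xs, [0.1, 0.2, 0.3], quad_program)
value = run(np.array([1.0, 2.0]))
```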
- SFI.integrate.scale(op, alpha, *, name=None)
Return a TimeOp that multiplies op by scalar alpha.
- SFI.integrate.stream(key, *, name=None)
Return a TimeOp that passes stream key through unchanged.
- Parameters:
key (str)
name (str | None)
- Return type:
TimeOp
- SFI.integrate.timeop(fn=None, *, name=None, batch_safe=False, requires=None)
Decorator: convert a function into a TimeOp.
- Parameters:
batch_safe (bool) – If True, the function already handles a leading batch (K) axis in its inputs without requiring an additional jax.vmap.
requires (frozenset, optional) – Explicit set of required stream keys. When given, the function signature is not inspected for parameter names. Use this when the function accepts **streams and the required keys cannot be inferred from the signature.
fn (Callable[[...], Array] | None)
name (str | None)
Notes
Without requires, the function’s parameter names define the required stream keys (**kwargs and *args parameters are ignored).
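The rule in the Notes for inferring required keys can be reproduced with the standard inspect module. The helper inferred_requires is hypothetical, and the stream names dX, dt, and X are illustrative. The sketch only mirrors the documented behavior: an explicit requires takes precedence; otherwise, named parameters are used and *args/**kwargs are skipped.

```python
import inspect
from typing import Callable, FrozenSet, Optional

def inferred_requires(fn: Callable, requires: Optional[FrozenSet[str]] = None) -> FrozenSet[str]:
    """Hypothetical helper mirroring the documented key-inference rule for timeop()."""
    if requires is not None:
        return frozenset(requires)  # explicit keys win; the signature is not inspected
    skip = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    return frozenset(
        name for name, p in inspect.signature(fn).parameters.items()
        if p.kind not in skip
    )

def speed(dX, dt):
    return dX / dt

def from_streams(**streams):
    return streams["X"]  # keys cannot be inferred from **streams

keys = inferred_requires(speed)
```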
- SFI.integrate.velocity(dx_key, dt_key, *, name=None)
Return a TimeOp that computes dx/dt, broadcasting dt over arbitrary leading dims.
- Parameters:
dx_key (str)
dt_key (str)
name (str | None)
- Return type:
TimeOp
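To make the TimeOp constructors concrete, the sketch below uses a hypothetical stand-in in which a TimeOp is simply a function of a dict of streams. stream_op, scale_op, add_op, and velocity_op are illustrative names, not the package's classes. velocity_op reshapes dt so that it broadcasts over trailing axes. This covers both a scalar dt on one slice and a (K,) dt on batched streams.

```python
import numpy as np

# Hypothetical stand-ins: a "TimeOp" here is just a function of a dict of streams.
def stream_op(key):
    return lambda streams: streams[key]              # pass the stream through

def scale_op(op, alpha):
    return lambda streams: alpha * op(streams)       # multiply by a scalar

def add_op(ops):
    def f(streams):
        out = ops[0](streams)
        for op in ops[1:]:
            out = out + op(streams)                  # element-wise sum
        return out
    return f

def velocity_op(dx_key, dt_key):
    def f(streams):
        dx = np.asarray(streams[dx_key], dtype=float)
        dt = np.asarray(streams[dt_key], dtype=float)
        # Append singleton axes to dt: scalar dt -> (1, 1) for an (N, d) dx,
        # (K,) dt -> (K, 1, 1) for a (K, N, d) dx.
        return dx / dt.reshape(dt.shape + (1,) * (dx.ndim - dt.ndim))
    return f

streams = {"X": np.ones((3, 2)), "dX": np.full((3, 2), 0.2), "dt": 0.1}
drift_minus_x = add_op([velocity_op("dX", "dt"), scale_op(stream_op("X"), -1.0)])
result = drift_minus_x(streams)   # dX/dt - X on one time slice
```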