SFI.integrate.api module
Integration runtime: vmapped over time indices with dataset-owned producers.
Contract
- collection.peek_row(require=…) returns a single-row sample mapping for memory sizing.
- collection.iter_slices(require=…, bytes_hint=…, chunk_target_bytes=…, subsampling=…, context=…) yields dictionaries with:
"producer": Callable[[t], row], a JAX-traceable single-t builder,
"t_idx": jax.Array[int32], the indices for this chunk,
"weight": float, the dataset-level weight in [0,1],
"dataset_index": int, for bookkeeping.
- program implements:
require() -> set[str] of streams (plus "extras" if needed),
estimate_bytes_per_sample(sample_row) -> Optional[int],
__call__(**streams) for one time slice; for the parametric route it additionally supports a keyword-only argument params.
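The program side of the contract can be sketched as a small class. This is an illustration only: the class name `SquaredSpeed`, the stream name `"v"`, and the `"scale"` entry of `params` are hypothetical and not part of SFI.

```python
from typing import Optional

import jax.numpy as jnp
import numpy as np


class SquaredSpeed:
    """Hypothetical integrand; the stream name "v" is an assumption."""

    def require(self) -> set[str]:
        # Streams the producer must supply for each time slice.
        return {"v"}

    def estimate_bytes_per_sample(self, sample_row) -> Optional[int]:
        # Rough working-set size of one time slice, taken from a peeked row.
        return int(sum(np.asarray(x).nbytes for x in sample_row.values()))

    def __call__(self, *, v, params=None):
        # One time slice: per-particle value with a leading particle axis.
        # `params` is keyword-only and only used on the parametric route.
        scale = 0.5 if params is None else params["scale"]
        return scale * jnp.sum(v**2, axis=-1)
```

The same object serves both routes: called without `params` it behaves as a plain integrand, and with `params` it becomes a function of a parameter PyTree.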
This module provides:
integrate(…): one-off integration using an Integrand program (backwards-compatible front-end).
make_parametric_integrator(…): build a reusable, jittable integrator for a parameterised Integrand, with a clear separation between host-side planning and JAX-side runtime.
make_minibatch_parametric_integrator(…): build a parametric integrator with both full and mini-batch runners.
- class SFI.integrate.api.ChunkSpec(dataset_index, weight, t_block, valid_block)
Bases: object
One time chunk (possibly padded) for a given dataset.
- Parameters:
dataset_index (int)
weight (float)
t_block (Array)
valid_block (Array)
- dataset_index: int
- t_block: Array
- valid_block: Array
- weight: float
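A minimal sketch of how a padded time block and its validity mask might be built. The helper `pad_time_block` is hypothetical, and the padding rule (repeat the last real index so producers stay in range) is an assumption; the reference above only states that a chunk is possibly padded.

```python
import jax.numpy as jnp
import numpy as np


def pad_time_block(t_idx, block_len):
    """Pad t_idx to block_len and return (t_block, valid_block)."""
    t_idx = np.asarray(t_idx, dtype=np.int32)
    n = t_idx.shape[0]
    if n == 0 or n > block_len:
        raise ValueError("need 1 <= len(t_idx) <= block_len")
    t_block = np.full(block_len, t_idx[-1], dtype=np.int32)  # pad with last index
    t_block[:n] = t_idx
    valid_block = np.arange(block_len) < n                   # True for real entries
    return jnp.asarray(t_block), jnp.asarray(valid_block)
```

Fixed-length blocks let a jitted kernel keep a single compiled shape; the mask ensures padded entries can be zeroed before accumulation.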
- class SFI.integrate.api.IntegrationPlan(producers, batch_producers, chunks, reduce, reduce_over_particles, weight_by_dt, bytes_hint, K_fixed, context)
Bases: object
Host-side integration plan.
- Contains:
producers: per-dataset single-t row builders,
batch_producers: per-dataset batch-t row builders,
chunks: padded time blocks with validity masks and weights,
reduction semantics and memory hints.
- Parameters:
producers (Dict[int, Callable[[int], Mapping[str, Any]]])
batch_producers (Dict[int, Callable[[Array], Mapping[str, Any]]])
chunks (Tuple[ChunkSpec, ...])
reduce (str)
reduce_over_particles (bool)
weight_by_dt (bool)
bytes_hint (int | None)
K_fixed (int | None)
context (str | None)
- K_fixed: int | None
- batch_producers: Dict[int, Callable[[Array], Mapping[str, Any]]]
- bytes_hint: int | None
- context: str | None
- producers: Dict[int, Callable[[int], Mapping[str, Any]]]
- reduce: str
- reduce_over_particles: bool
- weight_by_dt: bool
- SFI.integrate.api.integrate(collection, program, *, reduce='sum', reduce_over_particles=True, weight_by_dt=True, subsampling=1, chunk_target_bytes=536870912, context=None, batch=True)
Integrate an instantaneous program over time and datasets.
- Parameters:
collection – TrajectoryCollection exposing producers and time-index chunks.
program – Integrand object with require, estimate_bytes_per_sample, and __call__.
reduce ({'sum','mean'}) – Dataset-and-time reduction. ‘mean’ divides by the accumulated effective exposure computed from the same dt used in the numerator.
reduce_over_particles (bool) – If the program output has a leading particle axis, apply mask_out, then sum that axis before the time reduction.
weight_by_dt (bool) – If True (default), multiply each program output by dt before accumulation. Set to False for programs whose output should be summed without dt weighting (e.g. parametric Gram matrices).
subsampling (int) – Keep indices with t % subsampling == 0.
chunk_target_bytes (int) – Target working-set size for the vmapped kernel.
context (str, optional) – Forwarded to dataset extras via producers.
batch (bool)
- Returns:
Reduced value with particle axis removed if requested. Shapes match the program’s output after optional particle reduction.
- Return type:
jax.Array
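The reduction semantics described above can be illustrated with a toy, self-contained version. This is a sketch under stated assumptions, not SFI's implementation: `toy_integrate`, the stream names `"x"` and `"dt"`, and the reading that the 'mean' exposure is the sum of the weights actually applied (dt, or 1 when weight_by_dt is False) are all hypothetical, and the mask_out step is omitted.

```python
import jax
import jax.numpy as jnp
import numpy as np


def toy_integrate(producer, program, t_idx, *, reduce="sum",
                  reduce_over_particles=True, weight_by_dt=True, subsampling=1):
    # Host side: keep only indices with t % subsampling == 0.
    t_idx = np.asarray(t_idx, dtype=np.int32)
    t_idx = jnp.asarray(t_idx[t_idx % subsampling == 0])

    def one_t(t):
        row = producer(t)                      # single-t row of streams
        out = program(x=row["x"])              # leading particle axis assumed
        if reduce_over_particles:
            out = jnp.sum(out, axis=0)         # sum particles before time reduction
        dt = row["dt"] if weight_by_dt else jnp.ones(())
        return out * dt, dt

    num, exposure = jax.vmap(one_t)(t_idx)     # vmapped over time indices
    total = jnp.sum(num, axis=0)
    if reduce == "mean":
        total = total / jnp.sum(exposure)      # same weights as in the numerator
    return total
```

Here `producer(t)` must be JAX-traceable and return `"dt"` as a scalar array, mirroring the single-t producers in the contract.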
- SFI.integrate.api.make_minibatch_parametric_integrator(collection, program, *, batch_size, reduce='sum', reduce_over_particles=True, weight_by_dt=True, subsampling=1, chunk_target_bytes=536870912, context=None, bytes_per_sample=None, batch=True)
Build a parametric integrator with both full and mini-batch runners.
- Parameters:
collection – Same as make_parametric_integrator().
program – Same as make_parametric_integrator().
reduce (str) – Same as make_parametric_integrator().
reduce_over_particles (bool) – Same as make_parametric_integrator().
weight_by_dt (bool) – Same as make_parametric_integrator().
subsampling (int) – Same as make_parametric_integrator().
chunk_target_bytes (int) – Same as make_parametric_integrator().
context (str | None) – Same as make_parametric_integrator().
bytes_per_sample (int | None) – Same as make_parametric_integrator().
batch (bool) – Same as make_parametric_integrator().
batch_size (int) – Number of time indices to sample per mini-batch evaluation.
- Returns:
plan (IntegrationPlan)
run_full (callable) – run_full(theta) -> scalar: full-data evaluator.
run_batch (callable) – run_batch(theta, rng_key) -> scalar: stochastic mini-batch evaluator. Unbiased estimator of the full-data value.
- Return type:
Tuple[IntegrationPlan, Callable, Callable]
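One way a mini-batch runner can be an unbiased estimator of a full-data sum is sketched below. The sampling scheme (uniform with replacement, rescaled by n_times / batch_size) is an assumption, as are the names `make_toy_runners` and `per_t_value`; the reference above only states that run_batch is unbiased.

```python
import jax
import jax.numpy as jnp


def make_toy_runners(per_t_value, n_times, batch_size):
    """per_t_value(theta, t) -> scalar contribution of time index t (hypothetical)."""
    all_t = jnp.arange(n_times, dtype=jnp.int32)

    @jax.jit
    def run_full(theta):
        # Full-data value: sum over every time index.
        return jnp.sum(jax.vmap(lambda t: per_t_value(theta, t))(all_t))

    @jax.jit
    def run_batch(theta, rng_key):
        # Uniform sampling with replacement; the rescaling makes the
        # expectation over rng_key equal to run_full(theta).
        t = jax.random.randint(rng_key, (batch_size,), 0, n_times)
        vals = jax.vmap(lambda t_: per_t_value(theta, t_))(t)
        return (n_times / batch_size) * jnp.sum(vals)

    return run_full, run_batch
```

Because both runners are functions of `theta`, either can be passed through `jax.grad`, which is what makes the mini-batch variant usable for stochastic optimisation.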
- SFI.integrate.api.make_parametric_integrator(collection, program, *, reduce='sum', reduce_over_particles=True, weight_by_dt=True, subsampling=1, chunk_target_bytes=536870912, context=None, bytes_per_sample=None, batch=True)
Build a reusable, jittable integrator for a parametric Integrand.
- Parameters:
collection – TrajectoryCollection exposing producers and time-index chunks.
program – Integrand object with require, estimate_bytes_per_sample, and a call signature program(**streams, params=theta) where theta is a PyTree of parameters.
reduce (str) – Same meaning as in integrate().
reduce_over_particles (bool) – Same meaning as in integrate().
weight_by_dt (bool) – Same meaning as in integrate().
subsampling (int) – Same meaning as in integrate().
chunk_target_bytes (int) – Same meaning as in integrate().
context (str | None) – Same meaning as in integrate().
bytes_per_sample (int, optional) – Optional override for the per-sample memory estimate. If None, the program’s estimate_bytes_per_sample is used.
batch (bool) – If True, use the batched integration path (see integrate()).
- Returns:
plan (IntegrationPlan) – Host-side plan describing the chunks and producers.
run (callable) – JAX-jitted function run(theta) -> value that evaluates the integration for a given set of parameters.
- Return type:
Tuple[IntegrationPlan, Callable[[Any], Array]]
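The split between host-side planning and a jitted runner can be sketched with a toy closure. Everything here is hypothetical (`make_toy_integrator`, the dict-based plan, the single stream `"x"`, the scan-over-blocks layout); it shows the general pattern, not SFI's internals.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_toy_integrator(data, program, block_len):
    # Host side (runs once): split time indices into padded, fixed-size blocks.
    n = data.shape[0]
    n_blocks = -(-n // block_len)                        # ceiling division
    t = np.arange(n_blocks * block_len, dtype=np.int32)
    plan = {
        "t_blocks": jnp.asarray(np.minimum(t, n - 1).reshape(n_blocks, block_len)),
        "valid": jnp.asarray((t < n).reshape(n_blocks, block_len)),
    }

    @jax.jit
    def run(theta):
        def one_t(t_i, ok):
            out = program(x=data[t_i], params=theta)
            return jnp.where(ok, out, 0.0)               # padded entries add zero

        def one_block(acc, block):
            t_b, ok_b = block
            return acc + jnp.sum(jax.vmap(one_t)(t_b, ok_b)), None

        init = jnp.zeros((), dtype=jnp.float32)
        total, _ = jax.lax.scan(one_block, init, (plan["t_blocks"], plan["valid"]))
        return total

    return plan, run
```

The plan is built once with NumPy, so repeated calls to `run(theta)` reuse the same compiled kernel and only the parameters change; `jax.grad(run)` then differentiates through the whole integration.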