SFI.integrate.api module

Integration runtime: vmapped over time indices with dataset-owned producers.

Contract

  • collection.peek_row(require=…) returns a single-row sample mapping for memory sizing.

  • collection.iter_slices(require=…, bytes_hint=…, chunk_target_bytes=…, subsampling=…, context=…) yields dictionaries with:

    • “producer”: Callable[[t], row] — JAX-traceable single-t builder,

    • “t_idx”: jax.Array[int32] — indices for this chunk,

    • “weight”: float — dataset-level weight in [0,1],

    • “dataset_index”: int — for bookkeeping.

  • program implements:
    • require() -> set[str] of streams (plus “extras” if needed),

    • estimate_bytes_per_sample(sample_row) -> Optional[int],

    • __call__(**streams) for one time slice; for the parametric route it additionally supports a keyword-only argument params.
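The contract above can be illustrated with a small, dependency-free sketch. Everything here is hypothetical: ToyCollection, SquareProgram, and the stream name "X" are illustrative stand-ins, and plain Python lists replace JAX arrays so the example runs without JAX. It only shows how the three pieces (peek_row, iter_slices, and the program protocol) might fit together, not how the module implements them.

```python
from typing import Any, Dict, Iterator, Mapping, Optional, Set


class ToyCollection:
    """Hypothetical single-dataset collection (not the real TrajectoryCollection)."""

    def __init__(self, xs):
        self.xs = list(xs)

    def _producer(self, t: int) -> Dict[str, float]:
        # Single-t row builder: one row per time index.
        return {"X": self.xs[t]}

    def peek_row(self, require: Set[str]) -> Mapping[str, Any]:
        # One sample row, used only for memory sizing.
        row = self._producer(0)
        return {k: row[k] for k in require}

    def iter_slices(self, require, bytes_hint=None, chunk_target_bytes=None,
                    subsampling=1, context=None) -> Iterator[Dict[str, Any]]:
        kept = [t for t in range(len(self.xs)) if t % subsampling == 0]
        yield {
            "producer": self._producer,
            "t_idx": kept,        # a jax.Array[int32] in the real contract
            "weight": 1.0,
            "dataset_index": 0,
        }


class SquareProgram:
    """Hypothetical program returning X**2 for one time slice."""

    def require(self) -> Set[str]:
        return {"X"}

    def estimate_bytes_per_sample(self, sample_row) -> Optional[int]:
        return 8 * len(sample_row)  # assumption: 8 bytes per scalar stream

    def __call__(self, **streams):
        return streams["X"] ** 2


collection = ToyCollection([1.0, 2.0, 3.0, 4.0])
program = SquareProgram()
need = program.require()
bytes_hint = program.estimate_bytes_per_sample(collection.peek_row(require=need))

total = 0.0
for sl in collection.iter_slices(require=need, bytes_hint=bytes_hint, subsampling=2):
    for t in sl["t_idx"]:
        row = sl["producer"](t)
        total += sl["weight"] * program(**{k: row[k] for k in need})
```

With subsampling=2 only t = 0 and t = 2 are kept, so the loop accumulates 1.0² + 3.0².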

This module provides:

  • integrate(…): one-off integration using an Integrand program (backwards-compatible front-end).

  • make_parametric_integrator(…): build a reusable, jittable integrator for a parameterised Integrand, with a clear separation between host-side planning and JAX-side runtime.

  • make_minibatch_parametric_integrator(…): build a parametric integrator with both full-data and mini-batch runners.

class SFI.integrate.api.ChunkSpec(dataset_index, weight, t_block, valid_block)[source]

Bases: object

One time-chunk (possibly padded) for a given dataset.

Parameters:
  • dataset_index (int)

  • weight (float)

  • t_block (Array)

  • valid_block (Array)

dataset_index: int
t_block: Array
valid_block: Array
weight: float
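
A padded time-chunk can be pictured with the following sketch. ToyChunk and pad_into_chunks are hypothetical names, tuples stand in for arrays, and the choice to pad with a repeat of the last valid index is an assumption made here so that padded slots stay in range; the module may pad differently.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class ToyChunk:
    """Hypothetical mirror of ChunkSpec using tuples instead of arrays."""
    dataset_index: int
    weight: float
    t_block: Tuple[int, ...]
    valid_block: Tuple[bool, ...]


def pad_into_chunks(t_idx: List[int], K: int, dataset_index: int = 0,
                    weight: float = 1.0) -> List[ToyChunk]:
    """Split t_idx into blocks of exactly K entries, padding the last block."""
    chunks = []
    for start in range(0, len(t_idx), K):
        block = t_idx[start:start + K]
        n_pad = K - len(block)
        # Assumed padding rule: repeat the last valid index, and mark the
        # padded slots as invalid so they can be masked out later.
        t_block = tuple(block) + (block[-1],) * n_pad
        valid = (True,) * len(block) + (False,) * n_pad
        chunks.append(ToyChunk(dataset_index, weight, t_block, valid))
    return chunks


chunks = pad_into_chunks([0, 2, 4, 6, 8], K=3)
```

Every block then has the same length K, which is what lets a single compiled kernel handle all chunks; the validity mask records which entries are real.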
class SFI.integrate.api.IntegrationPlan(producers, batch_producers, chunks, reduce, reduce_over_particles, weight_by_dt, bytes_hint, K_fixed, context)[source]

Bases: object

Host-side integration plan.

Contains:
  • producers: per-dataset single-t row builders,

  • batch_producers: per-dataset batch-t row builders,

  • chunks: padded time blocks with validity masks and weights,

  • reduction semantics and memory hints.

Parameters:
  • producers (Dict[int, Callable[[int], Mapping[str, Any]]])

  • batch_producers (Dict[int, Callable[[Array], Mapping[str, Any]]])

  • chunks (Tuple[ChunkSpec, ...])

  • reduce (str)

  • reduce_over_particles (bool)

  • weight_by_dt (bool)

  • bytes_hint (int | None)

  • K_fixed (int | None)

  • context (str | None)

K_fixed: int | None
batch_producers: Dict[int, Callable[[Array], Mapping[str, Any]]]
bytes_hint: int | None
chunks: Tuple[ChunkSpec, ...]
context: str | None
producers: Dict[int, Callable[[int], Mapping[str, Any]]]
reduce: str
reduce_over_particles: bool
weight_by_dt: bool
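
One way to read the plan is as a recipe: look up the producer for each chunk's dataset, evaluate the program at every index in the block, and let the validity mask zero out padded slots. The sketch below follows that reading with hypothetical stand-ins (plain tuples for chunks, a run_plan helper that is not part of the module) and assumes the dataset weight multiplies each valid contribution.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for a chunk: (dataset_index, weight, t_block, valid_block).
Chunk = Tuple[int, float, Tuple[int, ...], Tuple[bool, ...]]


def run_plan(producers: Dict[int, Callable[[int], dict]],
             chunks: List[Chunk],
             program: Callable[..., float]) -> float:
    total = 0.0
    for dataset_index, weight, t_block, valid_block in chunks:
        producer = producers[dataset_index]
        for t, valid in zip(t_block, valid_block):
            value = program(**producer(t))
            # Padded slots are evaluated but contribute zero.
            total += weight * (value if valid else 0.0)
    return total


data = {0: [1.0, 2.0, 3.0], 1: [10.0, 20.0]}
producers = {i: (lambda t, xs=xs: {"X": xs[t]}) for i, xs in data.items()}
chunks = [
    (0, 1.0, (0, 1), (True, True)),
    (0, 1.0, (2, 2), (True, False)),  # last block of dataset 0, padded to length 2
    (1, 0.5, (0, 1), (True, True)),
]
result = run_plan(producers, chunks, lambda X: X)
```

Dataset 0 contributes 1 + 2 + 3 and dataset 1 contributes 0.5 × (10 + 20); the padded slot adds nothing.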
SFI.integrate.api.integrate(collection, program, *, reduce='sum', reduce_over_particles=True, weight_by_dt=True, subsampling=1, chunk_target_bytes=536870912, context=None, batch=True)[source]

Integrate an instantaneous program over time and datasets.

Parameters:
  • collection – TrajectoryCollection exposing producers and time-index chunks.

  • program – Integrand object with require, estimate_bytes_per_sample, and __call__.

  • reduce ({'sum','mean'}) – Dataset-and-time reduction. ‘mean’ divides by the accumulated effective exposure computed from the same dt used in the numerator.

  • reduce_over_particles (bool) – If the program output has a leading particle axis, apply mask_out, then sum that axis before the time reduction.

  • weight_by_dt (bool) – If True (default), multiply each program output by dt before accumulation. Set to False for programs whose output should be summed without dt weighting (e.g. parametric Gram matrices).

  • subsampling (int) – Keep indices with t % subsampling == 0.

  • chunk_target_bytes (int) – Target working-set size, in bytes, for the vmapped kernel. The default, 536870912, is 512 MiB.

  • context (str, optional) – Forwarded to dataset extras via producers.

  • batch (bool) – If True (default), use the batched integration path.
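
The interaction between subsampling and chunk_target_bytes can be sketched as follows. plan_chunk_length is a hypothetical helper, and the sizing rule (as many samples as fit in the target, but at least one) is an assumption for illustration; only the index filter t % subsampling == 0 comes from the parameter description above.

```python
from typing import List, Tuple


def plan_chunk_length(n_times: int, subsampling: int, bytes_per_sample: int,
                      chunk_target_bytes: int = 536870912) -> Tuple[List[int], int]:
    """Return the kept time indices and a chunk length K."""
    # Keep indices with t % subsampling == 0, as documented.
    kept = [t for t in range(n_times) if t % subsampling == 0]
    # Assumed sizing rule: fit as many samples as the target allows,
    # never fewer than one and never more than the number of kept indices.
    K = max(1, min(len(kept), chunk_target_bytes // bytes_per_sample))
    return kept, K


kept, K_small = plan_chunk_length(10, 3, bytes_per_sample=1024, chunk_target_bytes=2048)
_, K_default = plan_chunk_length(10, 3, bytes_per_sample=1024)
```

With a 2048-byte target and 1024 bytes per sample, only two samples fit per chunk; with the 512 MiB default, all four kept indices fit in one chunk.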

Returns:

Reduced value with particle axis removed if requested. Shapes match the program’s output after optional particle reduction.

Return type:

jax.Array
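
The reduce and weight_by_dt options can be worked through on a few numbers. The helper below is a hypothetical sketch on plain lists; in particular, treating the exposure as a sample count when weight_by_dt is False is an assumption, since the description above only specifies the dt-weighted case.

```python
from typing import List


def reduce_over_time(values: List[float], dts: List[float],
                     reduce: str = "sum", weight_by_dt: bool = True) -> float:
    if reduce not in ("sum", "mean"):
        raise ValueError(f"unknown reduce: {reduce!r}")
    # With weight_by_dt, each output is multiplied by its dt before accumulation.
    weights = dts if weight_by_dt else [1.0] * len(values)
    numerator = sum(v * w for v, w in zip(values, weights))
    if reduce == "sum":
        return numerator
    # 'mean' divides by the exposure accumulated from the same weights.
    exposure = sum(weights)
    return numerator / exposure


values = [1.0, 2.0, 4.0]
dts = [0.5, 0.5, 1.0]
time_integral = reduce_over_time(values, dts, reduce="sum")
time_average = reduce_over_time(values, dts, reduce="mean")
plain_sum = reduce_over_time(values, dts, reduce="sum", weight_by_dt=False)
```

Here the dt-weighted sum is 0.5 + 1.0 + 4.0, the total exposure is 2.0, and the unweighted sum simply adds the three values.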

SFI.integrate.api.make_minibatch_parametric_integrator(collection, program, *, batch_size, reduce='sum', reduce_over_particles=True, weight_by_dt=True, subsampling=1, chunk_target_bytes=536870912, context=None, bytes_per_sample=None, batch=True)[source]

Build a parametric integrator with both full and mini-batch runners.

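Parameters:
  • batch_size – Required keyword-only argument (no default); sets the mini-batch size used by run_batch.

  • collection, program, reduce, reduce_over_particles, weight_by_dt, subsampling, chunk_target_bytes, context, bytes_per_sample, batch – Same names and defaults as in make_parametric_integrator().

Returns: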

  • plan (IntegrationPlan)

  • run_full (callable) – run_full(theta) -> scalar — full-data evaluator.

  • run_batch (callable) – run_batch(theta, rng_key) -> scalar — stochastic mini-batch evaluator. Unbiased estimator of the full-data value.

Return type:

Tuple[IntegrationPlan, Callable, Callable]
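
The unbiasedness claim for run_batch can be illustrated with a generic estimator. This is a sketch under stated assumptions, not the module's sampling scheme: it assumes uniform sampling without replacement over n per-chunk contributions, rescaled by n / batch_size, and uses Python's random module in place of a JAX PRNG key. The names full_value and minibatch_value are hypothetical.

```python
import random
from itertools import combinations
from typing import List


def full_value(chunk_values: List[float]) -> float:
    """Full-data value: sum of all per-chunk contributions."""
    return sum(chunk_values)


def minibatch_value(chunk_values: List[float], batch_size: int,
                    rng: random.Random) -> float:
    """Sample batch_size contributions without replacement and rescale."""
    n = len(chunk_values)
    picked = rng.sample(range(n), batch_size)
    return (n / batch_size) * sum(chunk_values[i] for i in picked)


values = [1.0, 2.0, 3.0, 4.0]

# Enumerate every possible mini-batch of size 2 to compute the exact expectation.
estimates = [(len(values) / 2) * sum(c) for c in combinations(values, 2)]
mean_estimate = sum(estimates) / len(estimates)

# A single stochastic draw; its value depends on the seed.
draw = minibatch_value(values, 2, random.Random(0))
```

Averaging the rescaled estimate over all equally likely mini-batches recovers the full-data sum exactly, which is what "unbiased" means here; any single draw will generally differ from it.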

SFI.integrate.api.make_parametric_integrator(collection, program, *, reduce='sum', reduce_over_particles=True, weight_by_dt=True, subsampling=1, chunk_target_bytes=536870912, context=None, bytes_per_sample=None, batch=True)[source]

Build a reusable, jittable integrator for a parametric Integrand.

Parameters:
  • collection – TrajectoryCollection exposing producers and time-index chunks.

  • program – Integrand object with require, estimate_bytes_per_sample, and a call signature program(**streams, params=theta) where theta is a PyTree of parameters.

  • reduce (str) – Same meaning as in integrate().

  • reduce_over_particles (bool) – Same meaning as in integrate().

  • weight_by_dt (bool) – Same meaning as in integrate().

  • subsampling (int) – Same meaning as in integrate().

  • chunk_target_bytes (int) – Same meaning as in integrate().

  • context (str | None) – Same meaning as in integrate().

  • bytes_per_sample (int, optional) – Optional override for the per-sample memory estimate. If None, the program’s estimate_bytes_per_sample is used.

  • batch (bool) – If True, use the batched integration path (see integrate()).

Returns:

  • plan (IntegrationPlan) – Host-side plan describing the chunks and producers.

  • run (callable) – JAX-jitted function run(theta) -> value that evaluates the integration for a given set of parameters.

Return type:

Tuple[IntegrationPlan, Callable[[Any], Array]]
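
The separation between host-side planning and a reusable run(theta) can be sketched with a closure. make_toy_parametric_integrator and squared_residual are hypothetical, the plan is a plain dictionary rather than an IntegrationPlan, and nothing here is jitted; the point is only that planning happens once while run can be called for many parameter values.

```python
from typing import Any, Callable, Dict, List, Tuple


def make_toy_parametric_integrator(
    xs: List[float], dt: float, program: Callable[..., float]
) -> Tuple[Dict[str, Any], Callable[[Any], float]]:
    # Host-side planning: done once, independent of theta.
    plan = {"t_idx": tuple(range(len(xs))), "dt": dt}

    def run(theta: Any) -> float:
        # Runtime: evaluate the program at each planned index and
        # accumulate with dt weighting (the reduce='sum' case).
        return sum(program(X=xs[t], params=theta) * plan["dt"]
                   for t in plan["t_idx"])

    return plan, run


def squared_residual(*, X: float, params: Dict[str, float]) -> float:
    # Follows the documented call shape program(**streams, params=theta).
    return (X - params["mu"]) ** 2


plan, run = make_toy_parametric_integrator([1.0, 2.0, 3.0], 0.5, squared_residual)
values = [run({"mu": mu}) for mu in (0.0, 2.0)]
```

Because theta is the only argument of run, the same plan can be reused inside an optimisation loop without repeating the planning step.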