Source code for SFI.integrate.api

"""
Integration runtime: vmapped over time indices with dataset-owned producers.

Contract
--------
- `collection.peek_row(require=...)` returns a single-row sample mapping for
  memory sizing.
- `collection.iter_slices(require=..., bytes_hint=..., chunk_target_bytes=..., subsampling=..., context=...)`
  yields dictionaries with:

    - "producer": Callable[[t], row]  — JAX-traceable single-t builder,
    - "t_idx": jax.Array[int32]       — indices for this chunk,
    - "weight": float                 — dataset-level weight in [0,1],
    - "dataset_index": int            — for bookkeeping.
- `program` implements:
    - `require() -> set[str]` of streams (plus "extras" if needed),
    - `estimate_bytes_per_sample(sample_row) -> Optional[int]`,
    - `__call__(**streams)` for one time slice; for the parametric route
      it additionally supports a keyword-only argument `params`.

This module provides:

- `integrate(...)`: one-off integration using an `Integrand` `program`
  (backwards compatible front-end).
- `make_parametric_integrator(...)`: build a reusable, jittable integrator
  for a parameterised `Integrand`, with a clear separation between host-side
  planning and JAX-side runtime.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Optional, Tuple

import jax
import jax.numpy as jnp

Row = Mapping[str, Any]
Producer = Callable[[int], Row]
BatchProducer = Callable[[jnp.ndarray], Row]
RowEval = Callable[[Row, Any], jnp.ndarray]


@dataclass(frozen=True)
class ChunkSpec:
    """
    Immutable record describing one block of time indices of one dataset.

    The block may be padded up to a fixed length; ``valid_block`` tells
    which entries of ``t_block`` are real and which are padding.
    """

    # Key of the owning dataset (used to look up its producers).
    dataset_index: int
    # Multiplier applied to this chunk's contribution.
    weight: float
    # Time indices of the block, shape (K_chunk,).
    t_block: jnp.ndarray
    # Boolean validity mask, shape (K_chunk,).
    valid_block: jnp.ndarray
@dataclass(frozen=True)
class IntegrationPlan:
    """
    Immutable host-side description of an integration.

    It bundles everything the runtimes need: the per-dataset row builders
    (single-``t`` and batched), the time chunks with their validity masks
    and weights, the reduction semantics, and the memory hints that were
    used to size the chunks.
    """

    # dataset index -> single-t row builder
    producers: Dict[int, Producer]
    # dataset index -> batch-t row builder
    batch_producers: Dict[int, BatchProducer]
    # time blocks (possibly padded) with validity masks and weights
    chunks: Tuple[ChunkSpec, ...]
    # reduction requested by the caller
    reduce: str
    # whether a leading particle axis is summed before the time reduction
    reduce_over_particles: bool
    # whether each row's value is multiplied by dt before accumulation
    weight_by_dt: bool
    # per-sample byte estimate used for chunk sizing (None if unknown)
    bytes_hint: Optional[int]
    # fixed chunk length (None when chunks are left unpadded)
    K_fixed: Optional[int]
    # context string forwarded to the datasets
    context: Optional[str]
# --------------------------------------------------------------------------- # Planning # --------------------------------------------------------------------------- def _build_plan( collection, program, *, reduce: str, reduce_over_particles: bool, weight_by_dt: bool = True, subsampling: int, chunk_target_bytes: int, context: Optional[str], bytes_per_sample: Optional[int] = None, ) -> IntegrationPlan: """ Plan an integration once on the host side. Computes: - required streams from `program.require()`, - one real sample to estimate bytes per sample, - a fixed chunk size K for stable vmapped kernels, - the list of ChunkSpec and per-dataset producers. """ # Required streams; include sentinel for dt window if datasets support it. require = set(program.require()) require.add("__dt__") # dataset.valid_indices should treat as offsets (0,+1) # Size hint from a real row try: sample_row = collection.peek_row(require=require, context=context) except ValueError: # ``peek_row`` raises ValueError both when no dataset has a usable time # window (the legitimate empty-plan case) and when materialising a row # genuinely fails (e.g. a malformed extra). Only the former should be # swallowed into an empty plan — masking the latter turns an informative # error into a cryptic downstream crash (a scalar Gram fed to swapaxes). 
if any(ds.valid_indices(require).size > 0 for ds in collection.datasets): raise # No dataset has valid rows for these requirements return IntegrationPlan( producers={}, batch_producers={}, chunks=tuple(), reduce=reduce, reduce_over_particles=reduce_over_particles, weight_by_dt=weight_by_dt, bytes_hint=None, K_fixed=None, context=context, ) if bytes_per_sample is not None: bytes_hint = int(bytes_per_sample) else: estimator = getattr(program, "estimate_bytes_per_sample", None) if callable(estimator): bytes_hint = estimator(sample_row) else: bytes_hint = None # Derive fixed K per chunk from hint if not bytes_hint or bytes_hint <= 0: K_fixed: Optional[int] = None else: # Conservative: at least one row K_fixed = max(1, int(chunk_target_bytes // int(bytes_hint))) producers: Dict[int, Producer] = {} batch_producers: Dict[int, BatchProducer] = {} chunks: list[ChunkSpec] = [] # Collect all payloads first so we can cap K_fixed at the actual data size. # Without this cap, tiny datasets get absurdly padded (e.g. 19 valid indices # padded to 6.4M when bytes_hint is small and chunk_target_bytes is large). 
payloads = list( collection.iter_slices( require=require, bytes_hint=bytes_hint, chunk_target_bytes=chunk_target_bytes, subsampling=subsampling, context=context, ) ) if K_fixed is not None and payloads: max_payload = max(int(p["t_idx"].shape[0]) for p in payloads) K_fixed = min(K_fixed, max_payload) for payload in payloads: ds_idx: int = int(payload["dataset_index"]) producer = payload["producer"] t_idx = payload["t_idx"] weight = float(payload.get("weight", 1.0)) if ds_idx not in producers: producers[ds_idx] = producer # Build the matching batch producer from the underlying dataset ds = collection.datasets[ds_idx] batch_producers[ds_idx] = ds.make_batch_producer( require, include_mask=True, include_dt=True, context=context, force_dt_keys={"dt"}, dataset_index=collection.dataset_index(ds_idx), ) if K_fixed is None: # No padding, one chunk per payload t_block = t_idx valid_block = jnp.ones_like(t_idx, dtype=bool) chunks.append( ChunkSpec( dataset_index=ds_idx, weight=weight, t_block=t_block, valid_block=valid_block, ) ) else: # Split t_idx into blocks of size K_fixed and pad the last one K_total = int(t_idx.shape[0]) for start in range(0, K_total, K_fixed): stop = min(start + K_fixed, K_total) cur = t_idx[start:stop] K = int(cur.shape[0]) pad = K_fixed - K if pad > 0: pad_idx = jnp.pad(cur, (0, pad), mode="edge") valid = jnp.concatenate([jnp.ones((K,), dtype=bool), jnp.zeros((pad,), dtype=bool)]) else: pad_idx = cur valid = jnp.ones((K_fixed,), dtype=bool) chunks.append( ChunkSpec( dataset_index=ds_idx, weight=weight, t_block=pad_idx, valid_block=valid, ) ) return IntegrationPlan( producers=producers, batch_producers=batch_producers, chunks=tuple(chunks), reduce=reduce, reduce_over_particles=reduce_over_particles, weight_by_dt=weight_by_dt, bytes_hint=bytes_hint, K_fixed=K_fixed, context=context, ) # --------------------------------------------------------------------------- # Core row kernel and runner (shared) # 
--------------------------------------------------------------------------- def _row_kernel( row_eval: RowEval, row: Row, theta: Any, *, reduce_over_particles: bool, weight_by_dt: bool = True, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Single-row kernel. ``row_eval(row, theta) -> y``, same semantics as ``program(**row)`` in the non-parametric case. Returns ``(weighted_value, dt_eff)`` where ``weighted_value`` already includes dt when *weight_by_dt* is True. """ y = row_eval(row, theta) # Optional particle reduction if reduce_over_particles: if y.ndim == 0: raise ValueError("reduce_over_particles=True but row_eval returned a scalar.") m = row.get("mask_out", None) if m is not None: if y.ndim == 0 or y.shape[0] != m.shape[0]: raise ValueError( "mask_out mismatch: row_eval must return an array with " f"leading particle axis of size {m.shape[0]} when mask_out is present." ) mexp = m.reshape((m.shape[0],) + (1,) * (y.ndim - 1)) y = jnp.where(mexp, y, 0.0) y = jnp.sum(y, axis=0) dt = row["dt"] if weight_by_dt: dteff = dt * row["N_active"] y_w = y * dt else: dteff = row["N_active"] y_w = y return y_w, dteff def _run_plan_core( plan: IntegrationPlan, row_eval: RowEval, theta: Any, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Core runtime: integrates according to a pre-built IntegrationPlan and a given row_eval. Returns ``(acc, Teff_total)`` where: - ``acc`` is the sum over chunks of ``y_w``, - ``Teff_total`` is the sum over chunks of effective exposure. Suitable for jitting when ``row_eval`` and ``theta`` are JAX-traceable. """ if not plan.chunks: zero = jnp.asarray(0.0, dtype=jnp.float32) return zero, zero reduce_over_particles = plan.reduce_over_particles weight_by_dt = plan.weight_by_dt producer_by_idx: Dict[int, Producer] = plan.producers # Per-dataset weights are applied in every reduction (sum and mean) so the # force Gram, diffusion average, and parametric Gram pool datasets the same # way. 
Within-dataset weighting (per-dt vs per-point) is the caller's # ``weight_by_dt``; the per-dataset multiplier is orthogonal. use_weight = True # per-dataset weights applied in all reductions (unit weights => no-op) acc = None Teff_total = jnp.asarray(0.0, dtype=jnp.float32) for chunk in plan.chunks: producer = producer_by_idx[chunk.dataset_index] t_block = chunk.t_block # (K,) valid_block = chunk.valid_block # (K,) weight = chunk.weight def row_masked(t, is_valid, theta_): row = producer(t) y_w, dteff = _row_kernel( row_eval, row, theta_, reduce_over_particles=reduce_over_particles, weight_by_dt=weight_by_dt, ) assert y_w is not None maskf = is_valid.astype(y_w.dtype) if use_weight: return ( y_w * maskf * weight, dteff * is_valid.astype(dteff.dtype) * weight, ) else: return ( y_w * maskf, dteff * is_valid.astype(dteff.dtype), ) Ys, Dteffs = jax.vmap(row_masked, in_axes=(0, 0, None))(t_block, valid_block, theta) y_sum = jnp.sum(Ys, axis=0) dteff_sum = jnp.sum(Dteffs, axis=0) Teff_total = Teff_total + dteff_sum acc = y_sum if acc is None else (acc + y_sum) return acc, Teff_total # --------------------------------------------------------------------------- # Batched runner: batch gather + batch statefunc + vmapped einsum # --------------------------------------------------------------------------- BatchRowEval = Callable[[Row, Any], jnp.ndarray] def _run_plan_batched( plan: IntegrationPlan, batch_row_eval: BatchRowEval, theta: Any, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Batched runtime: gathers K rows at once and evaluates in batch. Instead of vmapping over individual time indices (which nests vmaps inside the state-expression leaves), this path: 1. Gathers K rows with a single batch producer (one XLA gather), 2. Passes ``(K, N, d)`` tensors to state expressions, which handle the leading batch dimensions in a single fused vmap, 3. vmaps the einsum contractions over the K axis. The result shapes are ``(K, ...)`` where ``...`` is the per-row result shape. 
Particle and time reductions happen outside the batch evaluator. Returns ``(acc, Teff_total)`` with the same semantics as :func:`_run_plan_core`. """ if not plan.chunks: zero = jnp.asarray(0.0, dtype=jnp.float32) return zero, zero reduce_over_particles = plan.reduce_over_particles batch_producer_by_idx: Dict[int, BatchProducer] = plan.batch_producers use_weight = True # per-dataset weights applied in all reductions (unit weights => no-op) acc = None Teff_total = jnp.asarray(0.0, dtype=jnp.float32) for chunk in plan.chunks: batch_producer = batch_producer_by_idx[chunk.dataset_index] t_block = chunk.t_block # (K,) valid_block = chunk.valid_block # (K,), bool weight = chunk.weight # 1) Batch gather — one XLA gather per stream batch_row = batch_producer(t_block) # {X: (K,N,d), dX: (K,N,d), dt: (K,), mask_out: (K,N), ...} # 2) Batch evaluate — statefunc sees (K,N,d), einsum vmapped over K y = batch_row_eval(batch_row, theta) # y shape: (K, N, ..., F) if particle axis is present, # (K, ..., F) if the einsum already contracted particles. # 3) Particle reduction over axis 1 if reduce_over_particles: if y.ndim < 2: raise ValueError( "reduce_over_particles=True but batch_row_eval returned an array with fewer than 2 dimensions." 
) m = batch_row.get("mask_out", None) # (K, N) if m is not None: # Expand mask to broadcast: (K, N, 1, ..., 1) n_trail = y.ndim - 2 # dims after particle axis mexp = m.reshape(m.shape + (1,) * n_trail) y = jnp.where(mexp, y, 0.0) y = jnp.sum(y, axis=1) # (K, ..., F) # 4) dt weighting dt = batch_row["dt"] # (K,) if plan.weight_by_dt: dteff = dt * batch_row["N_active"] # (K,) dt_exp = dt.reshape((dt.shape[0],) + (1,) * (y.ndim - 1)) y_w = y * dt_exp # (K, ..., F) else: dteff = batch_row["N_active"] # (K,) y_w = y # 5) Valid masking valid_f = valid_block.astype(y_w.dtype) valid_exp = valid_f.reshape((valid_f.shape[0],) + (1,) * (y_w.ndim - 1)) y_w = y_w * valid_exp dteff = dteff * valid_block.astype(dteff.dtype) if use_weight: y_w = y_w * weight dteff = dteff * weight # 6) Sum over K y_sum = jnp.sum(y_w, axis=0) dteff_sum = jnp.sum(dteff, axis=0) Teff_total = Teff_total + dteff_sum acc = y_sum if acc is None else (acc + y_sum) return acc, Teff_total # --------------------------------------------------------------------------- # Public API: integrate (non-parametric) # --------------------------------------------------------------------------- def _has_time_varying_required_extras(collection, program) -> bool: """True when the program reads an extras key that is time-varying. The batched runtime gathers extras once per chunk (``make_batch_producer`` collects them at the chunk's first index), which is only correct for static extras. Programs that read :class:`TimeSeriesExtra` values or plain time-generator callables must run on the per-``t`` core runtime, where ``build_extras(t)`` slices them per frame. """ req = getattr(program, "required_extras", None) keys: tuple = tuple(req() or ()) if callable(req) else () if not keys: return False # The reserved ``time`` extra (auto-injected per frame by # ``build_extras``) is inherently time-varying, so any program that # reads it — e.g. 
a :func:`~SFI.bases.time_fourier` dictionary — must # run on the per-``t`` core runtime even though it is not stored as a # TimeSeriesExtra on the dataset. if "time" in keys: return True from SFI.trajectory.dataset import FunctionExtra, TimeSeriesExtra for ds in collection.datasets: for src in (ds.extras_global, ds.extras_local): for k in keys: v = (src or {}).get(k) if isinstance(v, TimeSeriesExtra) or (callable(v) and not isinstance(v, FunctionExtra)): return True return False
def integrate(
    collection,
    program,
    *,
    reduce: str = "sum",  # {'sum','mean'}
    reduce_over_particles: bool = True,  # sum over leading i if present
    weight_by_dt: bool = True,
    subsampling: int = 1,
    chunk_target_bytes: int = 512 * 1024**2,
    context: Optional[str] = None,
    batch: bool = True,
) -> jnp.ndarray:
    """
    Integrate an instantaneous program over time and datasets.

    Parameters
    ----------
    collection
        TrajectoryCollection exposing producers and time-index chunks.
    program
        Integrand object with `require`, `estimate_bytes_per_sample`, and
        `__call__`.
    reduce : {'sum','mean'}
        Dataset-and-time reduction. `'mean'` divides the accumulated value by
        the accumulated effective exposure.
    reduce_over_particles : bool
        If the program output has a leading particle axis, apply `mask_out`,
        then sum that axis before the time reduction.
    weight_by_dt : bool
        If True (default), multiply each program output by ``dt`` before
        accumulation. Set to False for programs whose output should be summed
        without dt weighting (e.g. parametric Gram matrices).
    subsampling : int
        Forwarded to the collection when it slices time indices.
    chunk_target_bytes : int
        Target working-set size used to size the chunks.
    context : str, optional
        Forwarded to the datasets and to the batch evaluator.
    batch : bool
        If True (default), use the batched runtime. The per-``t`` runtime is
        used instead when ``batch`` is False or when the program reads
        time-varying extras.

    Returns
    -------
    jax.Array
        Reduced value, with the particle axis removed if requested.

    Raises
    ------
    ValueError
        If ``reduce`` is neither ``'sum'`` nor ``'mean'``, or if a mean is
        requested and the total exposure is not positive.
    """
    if reduce not in {"sum", "mean"}:
        raise ValueError("reduce must be 'sum' or 'mean'")

    # Time-varying extras must be sliced per frame: only the per-t core
    # runtime does that (the batch producer gathers extras once per chunk).
    use_batched = batch and not _has_time_varying_required_extras(collection, program)

    plan = _build_plan(
        collection,
        program,
        reduce=reduce,
        reduce_over_particles=reduce_over_particles,
        weight_by_dt=weight_by_dt,
        subsampling=subsampling,
        chunk_target_bytes=chunk_target_bytes,
        context=context,
    )

    if use_batched:
        evaluator = _build_batch_row_eval(program, context=context, parametric=False)
        total, exposure = _run_plan_batched(plan, evaluator, theta=None)
    else:
        # Per-t path: the program is called on one row at a time.
        def _eval_single_row(row: Row, _theta: Any) -> jnp.ndarray:
            return program(**row)

        total, exposure = _run_plan_core(plan, _eval_single_row, theta=None)

    if reduce == "mean":
        # The exposure is checked on the host so a degenerate mean raises
        # instead of silently returning NaN/inf.
        if float(exposure) <= 0.0:
            raise ValueError("Mean reduction requested but total exposure is non-positive.")
        return total / jnp.asarray(exposure, dtype=total.dtype)

    return total
# --------------------------------------------------------------------------- # Public API: parametric integrator # ---------------------------------------------------------------------------
def make_parametric_integrator(
    collection,
    program,
    *,
    reduce: str = "sum",
    reduce_over_particles: bool = True,
    weight_by_dt: bool = True,
    subsampling: int = 1,
    chunk_target_bytes: int = 512 * 1024**2,
    context: Optional[str] = None,
    bytes_per_sample: Optional[int] = None,
    batch: bool = True,
) -> Tuple[IntegrationPlan, Callable[[Any], jnp.ndarray]]:
    """
    Build a reusable, jittable integrator for a parametric Integrand.

    Parameters
    ----------
    collection
        TrajectoryCollection exposing producers and time-index chunks.
    program
        Integrand object with `require`, `estimate_bytes_per_sample`, and a
        call signature ``program(**streams, params=theta)`` where `theta`
        is a PyTree of parameters.
    reduce, reduce_over_particles, weight_by_dt, subsampling, chunk_target_bytes, context
        Same meaning as in :func:`integrate`.
    bytes_per_sample : int, optional
        Optional override for the per-sample memory estimate. If None,
        the program's `estimate_bytes_per_sample` is used.
    batch : bool
        If True, use the batched integration path, unless the program reads
        time-varying extras, in which case the per-``t`` path is used.

    Returns
    -------
    plan : IntegrationPlan
        Host-side plan describing the chunks and producers.
    run : callable
        Function ``run(theta) -> value`` evaluating the integration for a
        given set of parameters. For an empty plan it is a jitted function
        that always returns a float32 ``0.0``. With ``reduce='mean'`` a
        non-positive total exposure is replaced by one in the division
        rather than raising.

    Raises
    ------
    ValueError
        If ``reduce`` is neither ``'sum'`` nor ``'mean'``.
    """
    if reduce not in {"sum", "mean"}:
        raise ValueError("reduce must be 'sum' or 'mean'")

    # Same correctness rule as `integrate`: time-varying extras require the
    # per-t core runtime.
    use_batched = batch and not _has_time_varying_required_extras(collection, program)

    plan = _build_plan(
        collection,
        program,
        reduce=reduce,
        reduce_over_particles=reduce_over_particles,
        weight_by_dt=weight_by_dt,
        subsampling=subsampling,
        chunk_target_bytes=chunk_target_bytes,
        context=context,
        bytes_per_sample=bytes_per_sample,
    )

    if not plan.chunks:
        # Empty-plan edge case: always return 0.0
        @jax.jit
        def run_empty(theta):
            del theta
            return jnp.asarray(0.0, dtype=jnp.float32)

        return plan, run_empty

    # Pick the runtime and its matching evaluator once; both runtimes share
    # the call shape ``runtime(plan, evaluator, theta)``.
    if use_batched:
        runtime = _run_plan_batched
        evaluator = _build_batch_row_eval(program, context=context)
    else:
        runtime = _run_plan_core

        def evaluator(row: Row, theta: Any) -> jnp.ndarray:
            return program(params=theta, **row)

    def run(theta):
        acc, Teff_total = runtime(plan, evaluator, theta)
        if reduce == "sum":
            return acc
        # Trace-friendly guard: divide by one where the exposure is not positive.
        Teff_safe = jnp.where(Teff_total > 0, Teff_total, jnp.ones_like(Teff_total))
        return acc / Teff_safe.astype(acc.dtype)

    return plan, run
# ---------------------------------------------------------------------------
# Mini-batch parametric integrator
# ---------------------------------------------------------------------------


def _build_batch_row_eval(program, context=None, *, parametric: bool = True):
    """Build a JIT-compiled batch row evaluator.

    Parameters
    ----------
    program
        Integrand or duck-typed program with ``require`` / ``__call__``.
        If it has a ``batch_call`` attribute, that is called on the whole
        batch; otherwise ``program`` is vmapped over the leading axis of
        every array entry of the row.
    context : str, optional
        Forwarded to extras as a static constant.
    parametric : bool
        If True (default), forward ``theta`` as ``params=theta`` on every
        call.  Set False for non-parametric programs whose ``__call__``
        does not accept a ``params`` keyword.

    Returns
    -------
    batch_row_eval : callable
        ``batch_row_eval(row, theta) -> jnp.ndarray``.
    """
    _base_static: Dict[str, Any] = {}
    if context is not None:
        _base_static["context"] = context

    def _split_extras(row: Row):
        """Split ``row["extras"]`` into array-like entries (kept in the row)
        and static entries (callables and plain Python scalars/strings)."""
        ext = row.get("extras")
        if ext is None or not isinstance(ext, dict):
            return row, {}, ()
        arr_ext: Dict[str, Any] = {}
        fn_ext: Dict[str, Any] = {}
        for k, v in ext.items():
            if hasattr(v, "shape") and hasattr(v, "dtype"):
                arr_ext[k] = v
            elif callable(v):
                fn_ext[k] = v
            elif isinstance(v, (str, type(None), bool, int, float)):
                fn_ext[k] = v
            else:
                arr_ext[k] = v
        if not fn_ext:
            return row, {}, ()
        row = dict(row)
        row["extras"] = arr_ext if arr_ext else {}
        # Static entries are identified by object identity; each distinct set
        # gets its own compiled function.
        cache_key = tuple(sorted((k, id(v)) for k, v in fn_ext.items()))
        return row, fn_ext, cache_key

    _jit_cache: Dict[tuple, Any] = {}

    def _make_jit_fn(fn_extras: Dict[str, Any]):
        static = {**_base_static, **fn_extras}

        def _with_static(row: Row) -> Row:
            """Return *row* with the static extras merged into ``extras``."""
            if not static:
                return row
            # ``extras`` may be missing or explicitly None; both mean "empty".
            ext = dict(row.get("extras") or {})
            ext.update(static)
            return {**row, "extras": ext}

        if hasattr(program, "batch_call"):

            @jax.jit
            def jit_fn(row: Row, theta: Any) -> jnp.ndarray:
                row = _with_static(row)
                if parametric:
                    return program.batch_call(params=theta, **row)
                return program.batch_call(**row)

        else:

            @jax.jit
            def jit_fn(row: Row, theta: Any) -> jnp.ndarray:
                row = _with_static(row)
                # Entries with at least one axis are mapped over; everything
                # else (including the extras dict) is passed through as-is.
                batched = {k: v for k, v in row.items() if hasattr(v, "ndim") and v.ndim >= 1}
                scalars = {k: v for k, v in row.items() if not hasattr(v, "ndim") or v.ndim < 1}

                def _call(b):
                    if parametric:
                        return program(params=theta, **{**b, **scalars})
                    return program(**{**b, **scalars})

                return jax.vmap(_call)(batched)

        return jit_fn

    _default_jit = _make_jit_fn({})

    def batch_row_eval(row: Row, theta: Any) -> jnp.ndarray:
        stripped, fn_extras, cache_key = _split_extras(row)
        if not cache_key:
            return _default_jit(stripped, theta)
        if cache_key not in _jit_cache:
            _jit_cache[cache_key] = _make_jit_fn(fn_extras)
        return _jit_cache[cache_key](stripped, theta)

    return batch_row_eval


# [new code — parametric update] minibatch infrastructure below


def _build_minibatch_runner(
    plan: IntegrationPlan,
    program,
    *,
    batch_size: int,
    context: Optional[str] = None,
):
    """Build a stochastic mini-batch evaluator from an existing plan.

    Parameters
    ----------
    plan : IntegrationPlan
        A plan already built by ``_build_plan`` / ``make_parametric_integrator``.
    program : Integrand
        The same program used to build *plan*.
    batch_size : int
        Number of time indices to sample per evaluation. It is split across
        datasets in proportion to their number of valid indices, with at
        least one index per non-empty dataset.
    context : str, optional
        Forwarded to extras.

    Returns
    -------
    run_batch : callable
        ``run_batch(theta, rng_key) -> scalar``.  An unbiased estimator of
        the full-data loss (with ``reduce="sum"`` semantics), including the
        per-dataset weights that the full-data runners apply.
    """
    if not plan.chunks:

        def run_batch_empty(theta, rng_key):
            del theta, rng_key
            return jnp.asarray(0.0, dtype=jnp.float32)

        return run_batch_empty

    reduce_over_particles = plan.reduce_over_particles
    batch_row_eval = _build_batch_row_eval(program, context=context)

    # Pool valid indices per dataset from the plan, and remember each
    # dataset's weight. The weight is a dataset-level quantity, so the first
    # chunk of a dataset is taken as representative.
    ds_indices: Dict[int, jnp.ndarray] = {}
    ds_weight: Dict[int, float] = {}
    for chunk in plan.chunks:
        ds_idx = chunk.dataset_index
        valid = chunk.t_block[chunk.valid_block]
        if ds_idx in ds_indices:
            ds_indices[ds_idx] = jnp.concatenate([ds_indices[ds_idx], valid])
        else:
            ds_indices[ds_idx] = valid
            ds_weight[ds_idx] = float(chunk.weight)

    # Pre-compute per-dataset batch sizes (proportional allocation).
    total_valid = sum(int(idx.shape[0]) for idx in ds_indices.values())
    ds_batch_info = []  # list of (ds_idx, all_idx, n_ds, n_batch, weight)
    for ds_idx, all_idx in ds_indices.items():
        n_ds = int(all_idx.shape[0])
        if n_ds == 0:
            # Nothing to sample; also keeps the allocation below free of a
            # division by zero when no dataset has valid indices.
            continue
        n_batch = max(1, min(n_ds, round(batch_size * n_ds / total_valid)))
        ds_batch_info.append((ds_idx, all_idx, n_ds, n_batch, ds_weight[ds_idx]))

    def run_batch(theta, rng_key):
        acc = jnp.asarray(0.0, dtype=jnp.float32)

        for ds_idx, all_idx, n_ds, n_batch, weight in ds_batch_info:
            rng_key, subkey = jax.random.split(rng_key)
            # Sample without replacement
            perm = jax.random.permutation(subkey, n_ds)[:n_batch]
            sampled = all_idx[perm]  # (n_batch,)

            batch_producer = plan.batch_producers[ds_idx]
            batch_row = batch_producer(sampled)

            y = batch_row_eval(batch_row, theta)
            # y shape: (n_batch, N, ...) or (n_batch, ...)

            if reduce_over_particles:
                if y.ndim < 2:
                    raise ValueError("reduce_over_particles=True but batch_row_eval returned < 2 dimensions.")
                m = batch_row.get("mask_out", None)
                if m is not None:
                    n_trail = y.ndim - 2
                    mexp = m.reshape(m.shape + (1,) * n_trail)
                    y = jnp.where(mexp, y, 0.0)
                y = jnp.sum(y, axis=1)

            # dt weighting
            if plan.weight_by_dt:
                dt = batch_row["dt"]
                dt_exp = dt.reshape((dt.shape[0],) + (1,) * (y.ndim - 1))
                y_w = y * dt_exp
            else:
                y_w = y

            # Sum over batch, rescale by n_ds / n_batch so the estimate is
            # unbiased, and apply the dataset weight used by the full runners.
            y_sum = jnp.sum(y_w, axis=0)
            scale = jnp.asarray(weight * n_ds / n_batch, dtype=y_sum.dtype)
            acc = acc + y_sum * scale

        return acc

    return run_batch
def make_minibatch_parametric_integrator(
    collection,
    program,
    *,
    batch_size: int,
    reduce: str = "sum",
    reduce_over_particles: bool = True,
    weight_by_dt: bool = True,
    subsampling: int = 1,
    chunk_target_bytes: int = 512 * 1024**2,
    context: Optional[str] = None,
    bytes_per_sample: Optional[int] = None,
    batch: bool = True,
) -> Tuple[IntegrationPlan, Callable, Callable]:
    """Build a parametric integrator with both full and mini-batch runners.

    The full-data runner and its plan come from
    :func:`make_parametric_integrator`; the mini-batch runner is then built
    on top of that same plan.

    Parameters
    ----------
    collection, program, reduce, reduce_over_particles, weight_by_dt,
    subsampling, chunk_target_bytes, context, bytes_per_sample, batch
        Same as :func:`make_parametric_integrator`.
    batch_size : int
        Number of time indices to sample per mini-batch evaluation.

    Returns
    -------
    plan : IntegrationPlan
    run_full : callable
        ``run_full(theta) -> scalar`` — full-data evaluator.
    run_batch : callable
        ``run_batch(theta, rng_key) -> scalar`` — stochastic mini-batch
        evaluator.
    """
    # Options shared verbatim with the full-data integrator.
    full_options = dict(
        reduce=reduce,
        reduce_over_particles=reduce_over_particles,
        weight_by_dt=weight_by_dt,
        subsampling=subsampling,
        chunk_target_bytes=chunk_target_bytes,
        context=context,
        bytes_per_sample=bytes_per_sample,
        batch=batch,
    )
    plan, run_full = make_parametric_integrator(collection, program, **full_options)

    # The stochastic runner reuses the plan instead of re-planning.
    run_batch = _build_minibatch_runner(
        plan,
        program,
        batch_size=batch_size,
        context=context,
    )
    return plan, run_full, run_batch