SFI.trajectory.dataset module
Trajectory dataset: single-index producer with explicit valid index window.
This module defines TrajectoryDataset, an immutable container for one
trajectory and a JAX-traceable single-t row producer used by the integration
runtime. No shape heuristics are used. Time bounds are enforced by an explicit
valid_indices(required) window; all gathers assume indices are in-bounds.
Key APIs
- valid_indices(required, subsampling=None): returns the time indices t for which all requested streams are in-bounds.
- make_producer(require, include_mask=True, include_dt=True, ...): returns producer(t), which builds a fixed-structure single-row mapping.
- build_extras(t, dataset_index=0, context=None): assembles user and reserved extras at a single time index.
Streams
- States: "X", "X_minus", "X_plus", "X_minusminus", "X_plusplus".
- Increments: "dX_minus" = X[t]-X[t-1], "dX" = X[t+1]-X[t], "dX_plus" = X[t+2]-X[t+1].
- Windows: "X_window:<W>" returns (W, N, d) containing X[t - (W-1)//2], ..., X[t + W//2] (W positions; W may be any positive int). Odd W gives a symmetric window; even W takes one extra position to the right.
- Mask: "mask" (per-particle validity at t).
- Time steps: if include_dt=True, provides "dt" and, when required, "dt_minus" and "dt_plus", derived from either the scalar/array dt or the absolute times t.
- Extras: include by adding "extras" to require. Values are static objects or TimeSeriesExtra instances that are sliced at time t. Callables are allowed if they are JAX-traceable and accept (t_idx, context=None).
Boundary policy
- valid_indices computes the exact interior window from stream offsets.
- All gathers assume in-bounds indices. No clamping. Passing an invalid index is a logic error and leads to undefined behavior under XLA.
- Edge effects must be handled by downstream masking via "mask_out".
Example
>>> ds = TrajectoryDataset(X, dt=0.01, extras_global={"box": jnp.array([Lx, Ly])})
>>> req = {"X", "dX", "mask", "extras"}
>>> t_idx = ds.valid_indices(req) # e.g. arange(0, T-1)
>>> producer = ds.make_producer(req, include_dt=True)
>>> # Integrator does: Ys = jax.vmap(lambda tt: program(**producer(tt)))(t_idx)
- class SFI.trajectory.dataset.FunctionExtra(func)[source]¶
Bases: object
Wrapper for a JAX-compatible callable passed through extras.
Unlike plain callables (which TrajectoryDataset.build_extras() invokes eagerly as time-dependent generators), a FunctionExtra is passed through unchanged so that the user's basis function can call it inside JIT.
- Parameters:
func (callable) – A JAX-traceable function, e.g. func(x) -> Array. It will be captured as a compile-time constant inside @jax.jit.
Examples
>>> adhesion = FunctionExtra(lambda x: jnp.exp(-jnp.sum(x**2)))
>>> coll = TrajectoryCollection.from_arrays(
...     X=X, dt=0.01,
...     extras_global={"adhesion": adhesion},
... )
- func: Callable¶
- class SFI.trajectory.dataset.TimeSeriesExtra(data)[source]¶
Bases: object
Wrapper for time-dependent extras with an explicit leading time axis.
- Parameters:
data (jax.Array) – Array with shape (T, ...) for globals or (T, N, ...) for per-particle extras. The dataset will slice data[t].
- data: Array¶
- class SFI.trajectory.dataset.TrajectoryDataset(X, dt=None, t=None, mask=None, dynamic_mask=None, extras_global=<factory>, extras_local=<factory>, meta=<factory>)[source]¶
Bases: object
Immutable dataset for a single trajectory.
- Parameters:
X (jax.Array) – State array of shape (T, N, d) or (T, d). If (T, d), N is 1.
dt (jax.Array | float | None) – Either a scalar step, an array of shape (T,) (per-step), or None. If None and t is provided, steps are derived from t.
t (jax.Array | None) – Optional absolute time vector of shape (T,). If provided, it defines dt via finite differences when requested.
mask (jax.Array | None) – Optional boolean mask of shape (T, N) or (T,) marking valid observations at time t and particle n ("static mask"). If None, all ones. A True entry means the particle's position is known and can be used for state evaluation (e.g. neighbor forces).
dynamic_mask (jax.Array | None) – Optional boolean mask of shape (T, N) or (T,) marking entries whose increments are reliable and should contribute to parameter fitting ("dynamic mask"). Must be a subset of mask (dynamic_mask ⊆ mask). If None, defaults to mask. Typical use: particles near open boundaries are statically valid (their positions are known) but dynamically masked (their neighborhoods are incomplete, which biases their increments).
extras_global (Dict[str, Any] | None) – Dict of global extras. Values are static objects, TimeSeriesExtra, or JAX-traceable callables f(t_idx, context=None) -> Array with a leading time axis.
extras_local (Dict[str, Any] | None) – Dict of per-particle extras. Same typing as extras_global. Time-series entries typically have shape (T, N, ...).
meta (Dict[str, Any] | None) – Free-form metadata.
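The subset requirement dynamic_mask ⊆ mask can be checked before constructing a dataset. The following is a minimal sketch on nested Python lists; the helper name is_subset_mask is hypothetical and is not part of the SFI API, and whether the library performs an equivalent check itself is not stated here.

```python
def is_subset_mask(dynamic_mask, mask):
    """Return True if every dynamically valid entry is also statically valid.

    Both arguments are (T, N) nested lists of booleans (illustrative stand-ins
    for the arrays described above).
    """
    return all(
        (not dyn) or stat
        for dyn_row, stat_row in zip(dynamic_mask, mask)
        for dyn, stat in zip(dyn_row, stat_row)
    )


mask = [[True, True], [True, False]]
dynamic_ok = [[True, False], [True, False]]   # boundary particle masked dynamically only
dynamic_bad = [[True, True], [True, True]]    # claims an increment where the position is unknown

print(is_subset_mask(dynamic_ok, mask))
print(is_subset_mask(dynamic_bad, mask))
```
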
- property N: int¶
- property T: int¶
- Teff(required, *, subsampling=1)[source]¶
Effective exposure time over valid indices for weighting.
Defined as
Teff = sum_t N_active[t] * dt[t],
where N_active[t] is the number of active (unmasked) particles at time index t under the same stream requirements used by the integration runtime.
This reuses the same dt logic as _dt_fields_single() and the same masking logic as _output_mask_single(), so that the weighting matches exactly what the runtime sees.
- Parameters:
required (Set[str])
subsampling (int)
- Return type:
float
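The definition Teff = sum_t N_active[t] * dt[t] can be illustrated with a small pure-Python sketch. It assumes the rows passed in are already restricted to valid time indices and that the active mask has already been computed; the function name effective_time is hypothetical and does not reproduce the library's internal dt or masking logic.

```python
def effective_time(active_mask, dt):
    """Sum N_active[t] * dt[t] over the supplied time indices.

    active_mask: list of per-time rows of booleans (one entry per particle).
    dt:          list of per-time step sizes, same length as active_mask.
    """
    total = 0.0
    for row, step in zip(active_mask, dt):
        total += sum(row) * step  # summing booleans counts the active particles
    return total


active = [
    [True, True, False],   # 2 active particles
    [True, True, True],    # 3 active particles
    [False, False, True],  # 1 active particle
]
dt = [0.5, 0.5, 1.0]
teff = effective_time(active, dt)  # 2*0.5 + 3*0.5 + 1*1.0
print(teff)
```
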
- X: Array¶
- build_extras(t_idx, *, dataset_index=0, context=None)[source]¶
Full model-facing extras at t_idx: user values plus reserved keys.
User extras are sliced at the frame(s): a TimeSeriesExtra is indexed, a callable is invoked, and anything else is forwarded. The reserved keys time, duration, dataset_index, and particle_index are resolved for this dataset. This single mapping is what every consumer (inference, simulation, diagnostics) feeds to the force/diffusion expression.
extras_local overrides extras_global on key conflicts.
- Parameters:
t_idx (Array)
dataset_index (int)
context (str | None)
- Return type:
Dict[str, Any]
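The slicing rules described for build_extras can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: SeriesStandIn is a hypothetical stand-in for TimeSeriesExtra, resolve_extras is a hypothetical name, and the reserved keys (time, duration, dataset_index, particle_index) and the FunctionExtra pass-through are deliberately left out.

```python
class SeriesStandIn:
    """Hypothetical stand-in for TimeSeriesExtra: data with a leading time axis."""

    def __init__(self, data):
        self.data = data


def resolve_extras(extras_global, extras_local, t_idx, context=None):
    """Slice user extras at t_idx; local entries win on key conflicts."""
    merged = {**extras_global, **extras_local}  # later dict overrides earlier
    out = {}
    for key, value in merged.items():
        if isinstance(value, SeriesStandIn):
            out[key] = value.data[t_idx]               # time series: index at t
        elif callable(value):
            out[key] = value(t_idx, context=context)   # callable: invoke
        else:
            out[key] = value                           # static object: forward
    return out


row = resolve_extras(
    {"box": (10.0, 10.0), "field": SeriesStandIn([0.1, 0.2, 0.3])},
    {"box": (5.0, 5.0), "phase": lambda t, context=None: 2 * t},
    t_idx=1,
)
print(row)
```
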
- property d: int¶
- dt: Array | float | None = None¶
- dynamic_mask: Array | None = None¶
- extras_global: Dict[str, Any] | None¶
- extras_local: Dict[str, Any] | None¶
- classmethod from_arrays(*, X, dt=None, t=None, mask=None, dynamic_mask=None, extras_global=None, extras_local=None, meta=None)[source]¶
Construct a dataset from array-likes.
All inputs are converted to JAX arrays where relevant; extras and meta are stored as-is (no deep conversion).
- Return type:
- Raises:
ValueError – If X contains NaN/Inf, has wrong dimensionality, dt <= 0, or the trajectory is too short for any useful computation.
- Parameters:
X (Any)
dt (float | None)
t (Any | None)
mask (Any | None)
dynamic_mask (Any | None)
extras_global (Dict[str, Any] | None)
extras_local (Dict[str, Any] | None)
meta (Dict[str, Any] | None)
- make_batch_producer(require, *, include_mask=True, include_dt=True, context=None, force_dt_keys=None, dataset_index=0)[source]¶
Return a function that gathers a batch of rows in one vectorised pass.
This is the batch counterpart of make_producer(). Instead of building one row at a time (the design intended for use inside jax.vmap), this function gathers K rows at once using array indexing, producing arrays with a leading K axis.
- Parameters:
require (Set[str]) – Same meaning as in make_producer().
include_mask (bool) – Same meaning as in make_producer().
include_dt (bool) – Same meaning as in make_producer().
context (str | None) – Same meaning as in make_producer().
force_dt_keys (Set[str] | None) – Same meaning as in make_producer().
dataset_index (int)
- Returns:
batch_producer – batch_producer(t_block) with t_block of shape (K,) returns a dict whose arrays have a leading K axis.
- Return type:
Callable[[Array], Dict[str, Any]]
Notes
Extras limitation: time-varying extras (TimeSeriesExtra and callables) are evaluated at t_block[0] only, not at each index in the block. This batch producer is designed for use cases where extras are global constants across the chunk (e.g. static boundary tensors). For per-step time-varying extras, use make_producer() with jax.vmap instead.
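The practical consequence of the extras limitation can be seen in a toy comparison. The two helpers here are hypothetical and use plain lists in place of JAX arrays; they only mimic the documented behavior (per-index slicing versus a single slice at t_block[0]) and do not call the SFI API.

```python
def rows_one_by_one(X, series, t_block):
    """Per-index gather: each row sees the extra at its own time index."""
    return [{"X": X[t], "extra": series[t]} for t in t_block]


def rows_as_block(X, series, t_block):
    """Block gather: states for all K indices, extra sliced once at t_block[0]."""
    return {"X": [X[t] for t in t_block], "extra": series[t_block[0]]}


X = [0.0, 1.0, 2.0, 3.0]
series = [10, 20, 30, 40]   # a time-varying extra with a leading time axis
t_block = [1, 2, 3]

per_row = rows_one_by_one(X, series, t_block)
block = rows_as_block(X, series, t_block)

print([r["extra"] for r in per_row])  # one value per time index
print(block["extra"])                 # a single value for the whole block
```

If the extra is constant in time the two paths agree, which is the situation the batch producer is designed for.
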
- make_producer(require, *, include_mask=True, include_dt=True, context=None, force_dt_keys=None, dataset_index=0)[source]¶
Return a JAX-traceable function that builds a single-t row.
- Parameters:
require (Set[str]) – Set of stream names, plus the special key "extras" if extras are needed by downstream expressions.
include_mask (bool) – If True, include "mask_out" computed from require.
include_dt (bool) – If True, include "dt" and its neighbors ("dt_minus", "dt_plus") when needed.
context (str | None) – Optional string to pass through to extras callables.
force_dt_keys (Set[str] | None) – Extra dt fields to force, e.g. {"dt_plus"}.
dataset_index (int) – Position of this dataset within its collection; resolves the reserved dataset_index extra on every row.
- Returns:
producer – A function such that producer(t) returns a dict whose leaves are single-row arrays. Structure is fixed across calls.
- Return type:
Callable[[Array], Dict[str, Any]]
Notes
Use valid_indices() to generate in-bounds indices. The producer assumes its input is valid and does not clamp.
- mask: Array | None = None¶
- materialize_time(*, as_numpy=True)[source]¶
Return a dense absolute time vector t of shape (T,).
Notes
- If self.t is not None, it is returned as-is.
- Else, if self.dt is a scalar, use t[k] = k * dt.
- Else, if self.dt is an array of shape (T,), interpret dt[k] as the step between X[k] and X[k+1] and build
  t[0] = 0
  t[k+1] = t[k] + dt[k] for k = 0..T-2
  The last entry dt[T-1] (if any) is ignored.
- If both t and dt are None, a ValueError is raised.
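The three cases in the notes amount to a short cumulative sum. The sketch that follows mirrors them with plain Python lists; time_vector is a hypothetical name, and the real method additionally handles JAX/NumPy conversion via as_numpy.

```python
from itertools import accumulate


def time_vector(T, dt=None, t=None):
    """Build an absolute time vector of length T from t, a scalar dt, or per-step dt."""
    if t is not None:
        return list(t)                       # explicit times are returned as-is
    if dt is None:
        raise ValueError("need either t or dt")
    if isinstance(dt, (int, float)):
        return [k * dt for k in range(T)]    # uniform grid: t[k] = k * dt
    # Per-step dt: t[0] = 0, t[k+1] = t[k] + dt[k] for k = 0..T-2.
    # The last entry dt[T-1] is never used.
    return list(accumulate(dt[: T - 1], initial=0.0))


uniform = time_vector(4, dt=0.5)
variable = time_vector(4, dt=[1.0, 2.0, 0.5, 9.0])
print(uniform)
print(variable)
```
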
- meta: Dict[str, Any] | None¶
- split_time(fraction=0.8, *, gap=0)[source]¶
Split into (train, test) datasets along the time axis.
A side feature for data-abundant scenarios: SFI estimates its own accuracy from the training data (force_predicted_MSE) and validates fits through the diagnostics suite, neither of which costs any data. Hold out a test fraction only when data is plentiful, or to confirm a suspected bias floor.
- Parameters:
fraction (float) – Fraction of frames assigned to the train half: train is [0, round(fraction*T)), test is the remainder (after the optional gap).
gap (int) – Number of frames dropped between the halves. 0 is safe for increment-based estimators (the boundary increment belongs to neither half by construction); use a few correlation times for slowly mixing systems.
- Return type:
Tuple[TrajectoryDataset, TrajectoryDataset]
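The frame ranges implied by fraction and gap can be worked out by hand. This sketch computes them with Python's built-in round, which is an assumption: the rounding rule the library applies at exact ties is not specified here. The name split_ranges is hypothetical.

```python
def split_ranges(T, fraction=0.8, gap=0):
    """Return (train, test) frame ranges for a trajectory of T frames."""
    n_train = round(fraction * T)          # train is [0, n_train)
    test_start = min(n_train + gap, T)     # skip `gap` frames, never past the end
    return range(0, n_train), range(test_start, T)


train, test = split_ranges(100, fraction=0.8, gap=5)
print(len(train), test[0], len(test))
```

With T=100, fraction=0.8 and gap=5, frames 80 through 84 belong to neither half.
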
- t: Array | None = None¶
- to_arrays(*, as_numpy=True, include_mask=True)[source]¶
Materialize dense trajectory arrays for this dataset.
This is intended for plotting and quick inspection, not for JAX integration (use make_producer() for that).
- Parameters:
as_numpy (bool) – If True (default), return NumPy arrays.
include_mask (bool) – If True (default), return the per-particle validity mask.
- Returns:
t – Absolute time vector of shape (T,); see materialize_time().
X – State tensor of shape (T, N, d).
mask – Boolean mask of shape (T, N) if include_mask is True, otherwise None.
- Return type:
tuple[ndarray, ndarray, ndarray] | tuple[Array, Array, Array | None]
- property uuid: str¶
Stable identity of this dataset.
- valid_indices(required, subsampling=None)[source]¶
Return valid time indices given required streams.
A time index t is valid iff t + amin >= 0 and t + amax <= T-1, where (amin, amax) aggregates all offsets required by the streams in required. Extras do not affect the window.
- Parameters:
required (Set[str]) – Set of stream names and possibly "extras".
subsampling (int | None) – Optional positive integer. If provided, keep only indices where t % subsampling == 0 (grid-aligned to multiples of subsampling). This may exclude the first valid index if it is not a multiple of subsampling.
- Returns:
1-D array of valid time indices (dtype=int32), possibly empty.
- Return type:
jax.Array
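The window rule can be made concrete with a pure-Python sketch. The offset table is inferred from the stream definitions in this module (for example "dX" = X[t+1]-X[t] touches offsets 0 and 1); treating "mask" as offset 0 and always including offset 0 itself are assumptions. The names STREAM_OFFSETS, stream_offsets and valid_indices_sketch are hypothetical, and the real method returns an int32 jax.Array rather than a list.

```python
# (amin, amax) offsets touched by each stream, relative to the time index t.
STREAM_OFFSETS = {
    "X": (0, 0), "mask": (0, 0),
    "X_minus": (-1, -1), "X_plus": (1, 1),
    "X_minusminus": (-2, -2), "X_plusplus": (2, 2),
    "dX_minus": (-1, 0), "dX": (0, 1), "dX_plus": (1, 2),
}


def stream_offsets(name):
    """Offsets for one stream; "X_window:<W>" spans t-(W-1)//2 .. t+W//2."""
    if name.startswith("X_window:"):
        W = int(name.split(":", 1)[1])
        return -((W - 1) // 2), W // 2
    return STREAM_OFFSETS.get(name, (0, 0))  # e.g. "extras" does not move the window


def valid_indices_sketch(T, required, subsampling=None):
    offsets = [stream_offsets(name) for name in required]
    amin = min([0] + [lo for lo, _ in offsets])   # assumption: t itself must be in range
    amax = max([0] + [hi for _, hi in offsets])
    idx = range(-amin, T - amax)                  # t + amin >= 0 and t + amax <= T-1
    if subsampling is not None:
        idx = [t for t in idx if t % subsampling == 0]
    return list(idx)


a = valid_indices_sketch(6, {"X", "dX", "mask", "extras"})
b = valid_indices_sketch(6, {"X_minus", "dX_plus"}, subsampling=2)
c = valid_indices_sketch(6, {"X_window:4"})
print(a, b, c)
```

The second call shows the grid-alignment caveat: the unsubsampled window starts at t=1, which is not a multiple of 2, so it is dropped.
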
- SFI.trajectory.dataset.function_extra(func)[source]¶
Build a FunctionExtra from a callable.
- Parameters:
func (Callable)
- Return type:
- SFI.trajectory.dataset.time_series_extra(x)[source]¶
Build a TimeSeriesExtra from an array-like.
- Parameters:
x (Any)
- Return type: