SFI.trajectory.dataset module

Trajectory dataset: single-index producer with explicit valid index window.

This module defines TrajectoryDataset, an immutable container for one trajectory and a JAX-traceable single-t row producer used by the integration runtime. No shape heuristics are used. Time bounds are enforced by an explicit valid_indices(require) window; all gathers assume indices are in-bounds.

Key APIs

  • valid_indices(required, subsampling=None)

    Returns the time indices t for which all requested streams are in-bounds.

  • make_producer(require, include_mask=True, include_dt=True, ...)

    Returns producer(t) that builds a fixed-structure single-row mapping.

  • build_extras(t_idx, dataset_index=0, context=None)

    Assembles user + reserved extras at a single time index.

Streams

  • States: "X", "X_minus", "X_plus",

    "X_minusminus", "X_plusplus".

  • Increments: "dX_minus" = X[t]-X[t-1], "dX" = X[t+1]-X[t],

    "dX_plus" = X[t+2]-X[t+1].

  • Windows: "X_window:<W>" returns (W, N, d) containing X[t - (W-1)//2], ..., X[t + W//2] (W positions, any positive int). Odd W → symmetric; even W → one extra to the right.

  • Mask: "mask" (per-particle validity at t).

  • Time steps: if include_dt=True, provides "dt" and, when required, "dt_minus", "dt_plus" from either scalar/array dt or absolute times t.

  • Extras: include by adding "extras" to require. Values are static objects or TimeSeriesExtra that are sliced at time t. Callables are allowed if JAX-traceable and accept (t_idx, context=None).
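The increment and window streams listed above reduce to fixed integer offsets relative to t. The following NumPy sketch illustrates that index arithmetic only; it is not the library's implementation, and window_offsets and gather_window are hypothetical helper names introduced for illustration.

```python
import numpy as np

def window_offsets(W):
    """Offsets covered by "X_window:<W>": from -((W-1)//2) up to W//2."""
    left = (W - 1) // 2
    right = W // 2
    return np.arange(-left, right + 1)

def gather_window(X, t, W):
    """Return the (W, N, d) block around t; assumes all indices are in-bounds."""
    return X[t + window_offsets(W)]

T, N, d = 6, 2, 1
X = np.arange(T * N * d, dtype=float).reshape(T, N, d)  # toy trajectory
t = 2

dX_minus = X[t] - X[t - 1]     # "dX_minus"
dX = X[t + 1] - X[t]           # "dX"
dX_plus = X[t + 2] - X[t + 1]  # "dX_plus"

odd = window_offsets(3)        # symmetric around t
even = window_offsets(4)       # one extra position to the right
block = gather_window(X, t, 4) # shape (W, N, d)
```

No clamping is done here either: calling gather_window with t too close to an edge would silently wrap (negative indices) or raise, which is why the valid index window matters.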

Boundary policy

  • valid_indices computes the exact interior window from stream offsets.

  • All gathers assume in-bounds indices. No clamping. Passing an invalid index is a logic error and leads to undefined behavior under XLA.

  • Edge effects must be handled by downstream masking via "mask_out".

Example

>>> ds = TrajectoryDataset(X, dt=0.01, extras_global={"box": jnp.array([Lx, Ly])})
>>> req = {"X", "dX", "mask", "extras"}
>>> t_idx = ds.valid_indices(req)                  # e.g. arange(0, T-1)
>>> producer = ds.make_producer(req, include_dt=True)
>>> # Integrator does: Ys = jax.vmap(lambda tt: program(**producer(tt)))(t_idx)
class SFI.trajectory.dataset.FunctionExtra(func)[source]

Bases: object

Wrapper for a JAX-compatible callable passed through extras.

Unlike plain callables (which are invoked eagerly by TrajectoryDataset.build_extras() as time-dependent generators), a FunctionExtra is passed through unchanged so the user’s basis function can call it inside JIT.

Parameters:

func (callable) – A JAX-traceable function, e.g. func(x) -> Array. It will be captured as a compile-time constant inside @jax.jit.

Examples

>>> adhesion = FunctionExtra(lambda x: jnp.exp(-jnp.sum(x**2)))
>>> coll = TrajectoryCollection.from_arrays(
...     X=X, dt=0.01,
...     extras_global={"adhesion": adhesion},
... )
func: Callable
class SFI.trajectory.dataset.TimeSeriesExtra(data)[source]

Bases: object

Wrapper for time-dependent extras with an explicit leading time axis.

Parameters:

data (jax.Array) – Array with shape (T, ...) for globals or (T, N, ...) for per-particle extras. The dataset will slice data[t].

data: Array
class SFI.trajectory.dataset.TrajectoryDataset(X, dt=None, t=None, mask=None, dynamic_mask=None, extras_global=<factory>, extras_local=<factory>, meta=<factory>)[source]

Bases: object

Immutable dataset for a single trajectory.

Parameters:
  • X (jax.Array) – State array of shape (T, N, d) or (T, d). If (T, d), N is 1.

  • dt (jax.Array | float | None) – Either a scalar step, an array of shape (T,) (per-step), or None. If None and t is provided, steps are derived from t.

  • t (jax.Array | None) – Optional absolute time vector of shape (T,). If provided, it defines dt via finite differences when requested.

  • mask (jax.Array | None) – Optional boolean mask of shape (T, N) or (T,) marking valid observations at time t and particle n (“static mask”). If None, all ones. A True entry means the particle’s position is known and can be used for state evaluation (e.g. neighbor forces).

  • dynamic_mask (jax.Array | None) – Optional boolean mask of shape (T, N) or (T,) marking entries whose increments are reliable and should contribute to parameter fitting (“dynamic mask”). Must be a subset of mask (dynamic_mask ⊆ mask). If None, defaults to mask. Typical use: particles near open boundaries are statically valid (their positions are known) but dynamically masked (their neighborhoods are incomplete, biasing their increments).

  • extras_global (Dict[str, Any] | None) – Dict of global extras. Values are static objects, TimeSeriesExtra wrappers (arrays with a leading time axis), or JAX-traceable callables f(t_idx, context=None) -> Array.

  • extras_local (Dict[str, Any] | None) – Dict of per-particle extras. Same typing as extras_global. Time-series entries typically have shape (T, N, ...).

  • meta (Dict[str, Any] | None) – Free-form metadata.
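As an illustration of the mask and dynamic_mask parameters above, here is a minimal NumPy sketch that builds both for a toy 1-D system with open boundaries. The box length L_box, the margin, and the NaN-as-missing convention are assumptions made for the example, not library conventions.

```python
import numpy as np

L_box, margin = 10.0, 1.0  # hypothetical box length and edge margin

# Raw 1-D positions, shape (T, N); NaN marks a missed detection.
x = np.array([
    [0.5, 3.0, 7.0, 9.8],
    [0.7, 3.1, np.nan, 9.5],
    [1.2, 3.3, 7.2, 9.1],
])

mask = np.isfinite(x)                # static mask: position is known
x_filled = np.where(mask, x, 0.0)    # X must be finite, so fill the gaps
interior = (x_filled > margin) & (x_filled < L_box - margin)
dynamic_mask = mask & interior       # subset of mask by construction

X = x_filled[:, :, None]             # (T, N, d) with d = 1
```

Arrays built this way would be passed as X=, mask= and dynamic_mask= to the constructor or to from_arrays().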

property N: int
property T: int
Teff(required, *, subsampling=1)[source]

Effective exposure time over valid indices for weighting.

Defined as

Teff = sum_t N_active[t] * dt[t],

where N_active[t] is the number of active (unmasked) particles at time index t under the same stream requirements used by the integration runtime.

This reuses the same dt logic as _dt_fields_single() and the same masking logic as _output_mask_single(), so that weighting matches exactly what the runtime sees.

Parameters:
  • required (Set[str])

  • subsampling (int)

Return type:

float
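The definition above amounts to a masked sum over the valid indices. A minimal NumPy sketch follows, assuming the per-row activity mask is already available as a (T, N) boolean array; how the runtime derives that mask from the required streams is not reproduced here, and teff_sketch is a hypothetical name.

```python
import numpy as np

def teff_sketch(active, dt, t_idx):
    """Sum of N_active[t] * dt[t] over the given time indices."""
    active = np.asarray(active, dtype=bool)
    # Accept a scalar dt or a per-step array of shape (T,).
    dt = np.broadcast_to(np.asarray(dt, dtype=float), (active.shape[0],))
    n_active = active[t_idx].sum(axis=1)
    return float(np.sum(n_active * dt[t_idx]))

active = np.array(
    [[1, 1, 1],
     [1, 0, 1],
     [1, 1, 0],
     [1, 1, 1]],
    dtype=bool,
)
t_idx = np.arange(0, 3)  # e.g. the valid window for {"X", "dX"} with T = 4
teff = teff_sketch(active, 0.5, t_idx)
```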

X: Array
build_extras(t_idx, *, dataset_index=0, context=None)[source]

Full model-facing extras at t_idx: user values plus reserved keys.

User extras are resolved at the requested frame(s): a TimeSeriesExtra is indexed, a callable is invoked, and anything else is forwarded unchanged. The reserved keys time, duration, dataset_index, and particle_index are then resolved for this dataset. This single mapping is what every consumer (inference, simulation, diagnostics) feeds to the force/diffusion expression. extras_local overrides extras_global on key conflicts.

Parameters:
  • t_idx (Array)

  • dataset_index (int)

  • context (str | None)

Return type:

Dict[str, Any]
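The resolution rules for user extras can be sketched in a few lines. This is an illustration under assumptions, not the library's code: TimeSeriesStandIn is a hypothetical stand-in for TimeSeriesExtra, resolve_extras is a hypothetical name, and the reserved keys and the FunctionExtra pass-through are left out.

```python
import numpy as np
from dataclasses import dataclass

@dataclass(frozen=True)
class TimeSeriesStandIn:
    """Hypothetical stand-in for TimeSeriesExtra: data has a leading time axis."""
    data: np.ndarray

def resolve_extras(extras_global, extras_local, t_idx, context=None):
    merged = {**extras_global, **extras_local}  # local wins on key conflicts
    out = {}
    for key, value in merged.items():
        if isinstance(value, TimeSeriesStandIn):
            out[key] = value.data[t_idx]              # sliced at t
        elif callable(value):
            out[key] = value(t_idx, context=context)  # invoked at t
        else:
            out[key] = value                          # static: forwarded as-is
    return out

extras_global = {
    "box": np.array([10.0, 5.0]),
    "field": TimeSeriesStandIn(np.arange(4.0)),
    "gain": 1.0,
}
extras_local = {
    "gain": 2.0,
    "phase": lambda t_idx, context=None: 0.1 * t_idx,
}
row = resolve_extras(extras_global, extras_local, t_idx=2)
```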

property d: int
dt: Array | float | None = None
dynamic_mask: Array | None = None
extras_global: Dict[str, Any] | None
extras_local: Dict[str, Any] | None
classmethod from_arrays(*, X, dt=None, t=None, mask=None, dynamic_mask=None, extras_global=None, extras_local=None, meta=None)[source]

Construct a dataset from array-likes.

All inputs are converted to JAX arrays where relevant; extras and meta are stored as-is (no deep conversion).

Return type:

TrajectoryDataset

Raises:

ValueError – If X contains NaN/Inf or has the wrong dimensionality, if dt <= 0, or if the trajectory is too short for any useful computation.

Parameters:
  • X (Any)

  • dt (float | None)

  • t (Any | None)

  • mask (Any | None)

  • dynamic_mask (Any | None)

  • extras_global (Dict[str, Any] | None)

  • extras_local (Dict[str, Any] | None)

  • meta (Dict[str, Any] | None)

make_batch_producer(require, *, include_mask=True, include_dt=True, context=None, force_dt_keys=None, dataset_index=0)[source]

Return a function that gathers a batch of rows in one vectorised pass.

This is the batch counterpart of make_producer(). Instead of building one row at a time (designed for use inside jax.vmap), this function gathers K rows at once using array indexing, producing arrays with a leading K axis.

Returns:

batch_producer – A function such that batch_producer(t_block), with t_block of shape (K,), returns a dict whose arrays have a leading K axis.

Return type:

Callable[[Array], Dict[str, Any]]

Notes

Extras limitation: time-varying extras (TimeSeriesExtra and callables) are evaluated at t_block[0] only, not at each index in the block. This batch producer is designed for use cases where extras are global constants across the chunk (e.g. static boundary tensors). For per-step time-varying extras, use make_producer() with jax.vmap instead.
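To make the batch semantics and the extras limitation concrete, here is a minimal NumPy sketch of a vectorised gather. The function name batch_rows_sketch, the "field" key, and the flat layout of the returned dict are assumptions for illustration only.

```python
import numpy as np

def batch_rows_sketch(X, field, t_block):
    """Gather K rows at once; time-varying extras use t_block[0] only."""
    t_block = np.asarray(t_block)
    return {
        "X": X[t_block],                    # (K, N, d)
        "dX": X[t_block + 1] - X[t_block],  # (K, N, d)
        "field": field[t_block[0]],         # one slice shared by the block
    }

T, N, d = 5, 2, 3
X = np.arange(T * N * d, dtype=float).reshape(T, N, d)
field = 10.0 * np.arange(T)                 # a time-varying global extra
rows = batch_rows_sketch(X, field, t_block=[1, 2, 3])
```

Here rows["field"] holds the value at index 1 for all three rows, which is only correct when the extra is constant over the block.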

make_producer(require, *, include_mask=True, include_dt=True, context=None, force_dt_keys=None, dataset_index=0)[source]

Return a JAX-traceable function that builds a single-t row.

Parameters:
  • require (Set[str]) – Set of stream names and the special key "extras" if extras are needed by downstream expressions.

  • include_mask (bool) – If True, include "mask_out" computed from require.

  • include_dt (bool) – If True, include "dt" and neighbors when needed.

  • context (str | None) – Optional string to pass through to extras callables.

  • force_dt_keys (Set[str] | None) – Extra dt fields to force, e.g. {"dt_plus"}.

  • dataset_index (int) – Position of this dataset within its collection; resolves the reserved dataset_index extra on every row.

Returns:

producer – A function such that producer(t) returns a dict whose leaves are single-row arrays. Structure is fixed across calls.

Return type:

Callable[[Array], Dict[str, Any]]

Notes

  • Use valid_indices() to generate in-bounds indices. The producer assumes its input is valid and does not clamp.

mask: Array | None = None
materialize_time(*, as_numpy=True)[source]

Return a dense absolute time vector t of shape (T,).

Notes

  • If self.t is not None, it is returned as-is.

  • Else, if self.dt is a scalar, use t[k] = k * dt.

  • Else, if self.dt is an array of shape (T,), interpret dt[k] as the step between X[k] and X[k+1] and build

    t[0] = 0
    t[k+1] = t[k] + dt[k] for k = 0..T-2

    The last entry dt[T-1] (if any) is ignored.

  • If both t and dt are None, a ValueError is raised.

Parameters:

as_numpy (bool) – If True, return a NumPy array; otherwise a JAX array.

Returns:

t

Return type:

array, shape (T,)
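The three cases in the notes above can be sketched as follows. This is a standalone NumPy illustration of the stated rules; materialize_time_sketch is a hypothetical name and takes T explicitly because it has no dataset to read it from.

```python
import numpy as np

def materialize_time_sketch(T, dt=None, t=None):
    if t is not None:
        return np.asarray(t, dtype=float)  # absolute times win
    if dt is None:
        raise ValueError("need either t or dt")
    dt = np.asarray(dt, dtype=float)
    if dt.ndim == 0:
        return np.arange(T) * dt           # uniform grid: t[k] = k * dt
    # Per-step dt: t[0] = 0, t[k+1] = t[k] + dt[k]; dt[T-1] is ignored.
    return np.concatenate([[0.0], np.cumsum(dt[: T - 1])])

uniform = materialize_time_sketch(3, dt=0.5)
per_step = materialize_time_sketch(4, dt=[0.1, 0.2, 0.3, 9.0])
```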

meta: Dict[str, Any] | None
split_time(fraction=0.8, *, gap=0)[source]

Split into (train, test) datasets along the time axis.

A side feature for data-abundant scenarios: SFI estimates its own accuracy from the training data (force_predicted_MSE) and validates fits through the diagnostics suite, neither of which costs any data. Hold out a test fraction only when data is plentiful, or to confirm a suspected bias floor.

Parameters:
  • fraction (float) – Fraction of frames assigned to the train half: train is [0, round(fraction*T)), test is the remainder (after the optional gap).

  • gap (int) – Number of frames dropped between the halves. 0 is safe for increment-based estimators (the boundary increment belongs to neither half by construction); use a few correlation times for slowly mixing systems.

Return type:

Tuple[TrajectoryDataset, TrajectoryDataset]
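The frame ranges described by fraction and gap can be sketched as plain index arithmetic. split_frames is a hypothetical helper; it assumes Python's built-in round and clips the test start at T, neither of which is confirmed by the source.

```python
def split_frames(T, fraction=0.8, gap=0):
    """Return (train, test) frame ranges along the time axis."""
    n_train = round(fraction * T)           # train is [0, n_train)
    test_start = min(n_train + gap, T)      # drop `gap` frames in between
    return range(0, n_train), range(test_start, T)

train, test = split_frames(10, fraction=0.8, gap=1)
```

With T = 10 this gives eight training frames, one dropped frame, and a single test frame, which shows how quickly a gap eats into a short held-out segment.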

t: Array | None = None
to_arrays(*, as_numpy=True, include_mask=True)[source]

Materialize dense trajectory arrays for this dataset.

This is intended for plotting and quick inspection, not for JAX integration (use make_producer() for that).

Parameters:
  • as_numpy (bool) – If True (default), return NumPy arrays.

  • include_mask (bool) – If True (default), return the per-particle validity mask.

Returns:

  • t – Absolute time vector of shape (T,); see materialize_time().

  • X – State tensor of shape (T, N, d).

  • mask – Boolean mask of shape (T, N) if include_mask is True, otherwise None.

Return type:

tuple[ndarray, ndarray, ndarray] | tuple[Array, Array, Array | None]

property uuid: str

Stable identity of this dataset.

valid_indices(required, subsampling=None)[source]

Return valid time indices given required streams.

A time index t is valid iff t + amin >= 0 and t + amax <= T-1, where (amin, amax) aggregates all offsets required by streams in required. Extras do not affect the window.

Parameters:
  • required (Set[str]) – Set of stream names and possibly "extras".

  • subsampling (int | None) – Optional positive integer. If provided, keep only indices where t % subsampling == 0 (grid-aligned to multiples of subsampling). This may exclude the first valid index if it is not a multiple of subsampling.

Returns:

1-D array of valid time indices (dtype=int32), possibly empty.

Return type:

jax.Array
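The window rule above can be sketched directly from per-stream offsets. The OFFSETS table below covers only the streams whose offsets are spelled out in the module docstring, and both it and the function names are illustrative assumptions rather than the library's internals.

```python
import numpy as np

# (min, max) offsets relative to t touched by each stream.
OFFSETS = {
    "X": (0, 0),
    "mask": (0, 0),
    "dX_minus": (-1, 0),  # X[t] - X[t-1]
    "dX": (0, 1),         # X[t+1] - X[t]
    "dX_plus": (1, 2),    # X[t+2] - X[t+1]
}

def stream_offsets(name):
    if name.startswith("X_window:"):
        W = int(name.split(":", 1)[1])
        return -((W - 1) // 2), W // 2
    return OFFSETS[name]

def valid_indices_sketch(required, T, subsampling=None):
    streams = [s for s in required if s != "extras"]  # extras are ignored
    amin = min((stream_offsets(s)[0] for s in streams), default=0)
    amax = max((stream_offsets(s)[1] for s in streams), default=0)
    lo = max(0, -amin)             # t + amin >= 0
    hi = min(T - 1, T - 1 - amax)  # t + amax <= T - 1
    t = np.arange(lo, hi + 1, dtype=np.int32)
    if subsampling is not None:
        t = t[t % subsampling == 0]  # grid-aligned to multiples of subsampling
    return t

basic = valid_indices_sketch({"X", "dX", "mask", "extras"}, T=5)
strided = valid_indices_sketch({"dX_minus", "dX"}, T=6, subsampling=2)
```

In the second call the first valid index is 1, which is not a multiple of 2 and is therefore dropped, matching the caveat on subsampling.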

SFI.trajectory.dataset.function_extra(func)[source]

Build a FunctionExtra from a callable.

Parameters:

func (Callable)

Return type:

FunctionExtra

SFI.trajectory.dataset.time_series_extra(x)[source]

Build a TimeSeriesExtra from an array-like.

Parameters:

x (Any)

Return type:

TimeSeriesExtra