SFI.trajectory package

Trajectory submodule public API.

Exports:
  • TrajectoryDataset / TrajectoryCollection – main user-facing containers

  • FunctionExtra / function_extra – pass JAX-traceable callables as extras

  • TimeSeriesExtra / time_series_extra – wrap time-varying array extras

I/O (save_trajectory, load_trajectory, columns_and_extras_to_dataset) is available via SFI.trajectory.io but is not re-exported here; most users should use TrajectoryCollection.save / TrajectoryCollection.load.

class SFI.trajectory.FunctionExtra(func)[source]

Bases: object

Wrapper for a JAX-compatible callable passed through extras.

Unlike plain callables (which are invoked eagerly by TrajectoryDataset.build_extras() as time-dependent generators), a FunctionExtra is passed through unchanged so the user’s basis function can call it inside JIT.

Parameters:

func (callable) – A JAX-traceable function, e.g. func(x) -> Array. It will be captured as a compile-time constant inside @jax.jit.

Examples

>>> adhesion = FunctionExtra(lambda x: jnp.exp(-jnp.sum(x**2)))
>>> coll = TrajectoryCollection.from_arrays(
...     X=X, dt=0.01,
...     extras_global={"adhesion": adhesion},
... )
func: Callable
class SFI.trajectory.TimeSeriesExtra(data)[source]

Bases: object

Wrapper for time-dependent extras with an explicit leading time axis.

Parameters:

data (jax.Array) – Array with shape (T, ...) for globals or (T, N, ...) for per-particle extras. The dataset will slice data[t].

data: Array
class SFI.trajectory.TrajectoryCollection(datasets, weights)[source]

Bases: object

Container for one or more trajectories plus per-dataset weights.

This is the main user-facing trajectory object. It wraps a list of TrajectoryDataset instances and exposes an index-based streaming interface used by the integration runtime.

Most users should construct collections via from_arrays(), from_dataset() or load() rather than instantiating this dataclass directly.

Parameters:
  • datasets (List[SFI.trajectory.dataset.TrajectoryDataset]) – List of underlying TrajectoryDataset objects. The order is preserved in iteration and determines the ordering of the weights vector.

  • weights (jax.Array) – 1D JAX array of shape (D,) with non-negative entries, where D = len(datasets). Entries are unnormalised per-dataset multipliers; see with_weights().

Notes

The collection itself does not impose any chunking heuristic. It only coordinates datasets and their weights; the integrator decides how to vmap over the indices returned by iter_slices().

Teff(required, *, subsampling=1)[source]

Total effective exposure time across all datasets.

This is simply the sum of per-dataset Teff values:

sum_d datasets[d].Teff(required, subsampling=subsampling).

Parameters:
  • required (Set[str])

  • subsampling (int)

Return type:

float

concat(items, *, weights='pool')[source]

Concatenate this collection with other collections or datasets.

Parameters:
Returns:

New collection containing all datasets from self followed by all datasets from items.

Return type:

TrajectoryCollection

dataset_index(position)[source]

Dense index of dataset position, keyed on its stable identity.

Datasets are numbered by first appearance of their uuid, so the index a force sees (e.g. via per_dataset_scalar() or dataset_indicator()) is tied to the dataset itself, not its slot — stable under concatenation and reordering.

Parameters:

position (int)

Return type:

int
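The first-appearance numbering described above can be illustrated in plain Python. This is only a sketch of the idea, not SFI's implementation: dense_indices is a hypothetical helper, and the short strings stand in for dataset uuid values.

```python
def dense_indices(uuids):
    """Number identifiers by first appearance (hypothetical helper, not part of SFI)."""
    first_seen = {}
    out = []
    for u in uuids:
        # setdefault assigns the next free integer the first time `u` is seen
        out.append(first_seen.setdefault(u, len(first_seen)))
    return out

# The dataset "a" occupies two slots but keeps a single dense index.
indices = dense_indices(["a", "b", "a", "c"])
```

Because the index depends only on the identifier, reordering or concatenating the slots changes positions but not which slots share an index.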

datasets: List[TrajectoryDataset]
degrade(*, downsample=1, motion_blur=0, data_loss_fraction=0.0, noise=None, ROI=None, seed=None, reweight='pool')[source]

Return a new degraded collection; the original is not modified.

This is the preferred user-facing API for degrading synthetic trajectories to mimic experimental noise, blur, and data loss.

Parameters:
Returns:

New degraded collection. The original collection is not modified.

Return type:

TrajectoryCollection

classmethod from_arrays(*, X, dt=None, t=None, mask=None, extras_global=None, extras_local=None, meta=None, weights='pool')[source]

Build a single-dataset collection from array-likes.

This is the recommended entry point when you already have tensors in memory.

Parameters:
  • X (Any) – State array of shape (T, N, d) or (T, d). If (T, d), a single particle is assumed.

  • dt (float | None) – Either a scalar step, an array of shape (T,) (per-step), or None. If None and t is provided, effective steps are derived from t on demand.

  • t (Any | None) – Optional absolute time vector of shape (T,). If provided, it defines time steps when needed.

  • mask (Any | None) – Optional boolean mask of shape (T, N) or (T,) marking valid observations. If None, all entries are considered valid.

  • extras_global (Dict[str, Any] | None) – Mapping of global extras. Values can be static objects, TimeSeriesExtra, or JAX-traceable callables f(t_idx, context=None) -> Array with a leading time axis.

  • extras_local (Dict[str, Any] | None) – Mapping of per-particle extras, with the same typing as extras_global. Time-series entries typically have shape (T, N, ...).

  • meta (Dict[str, Any] | None) – Free-form metadata dictionary attached to the underlying dataset.

  • weights (str | Sequence[float]) – Initial weight specification for the resulting collection. See with_weights().

Returns:

A collection with one dataset built from the provided arrays.

Return type:

TrajectoryCollection

classmethod from_columns(particle_idx, time_idx, state_vectors, *, extras_global=None, extras_local=None, dt=None, t=None, relabel=True, compress_particles=False, meta=None, weights='pool')[source]

Build a single-dataset collection from flat (particle, time) columns.

This constructor is convenient when reading trajectories from a tabular format or a custom pipeline.

Parameters:
  • particle_idx (ndarray) – Integer array of shape (L,) with particle IDs for each row.

  • time_idx (ndarray) – Integer array of shape (L,) with time indices t for each row.

  • state_vectors (ndarray) – Array of shape (L, d) with state vectors.

  • extras_global (Mapping[str, Any] | None) – Parsed global extras (e.g. from YAML header), as described in SFI.trajectory.io.

  • extras_local (Mapping[str, Any] | None) – Parsed local extras, including time-series extras, as described in SFI.trajectory.io.

  • dt (float | None) – Optional scalar step; used only if no absolute time axis is provided via t or extras_global['t'].

  • t (ndarray | None) – Optional time vector of shape (T,) overriding any time axis inferred from extras.

  • relabel (bool) – If True, compress particle IDs to 0..N-1 and shift time to start at 0.

  • compress_particles (bool) – If True, apply greedy interval packing to reduce the column count by merging particles whose time supports do not overlap (with a 2-frame buffer). Per-particle extras are reindexed automatically. The mapping is stored in dataset.meta['particle_column_map'].

  • meta (Dict[str, Any] | None) – Metadata dictionary to attach to the dataset.

  • weights (str | Sequence[float]) – Initial weight specification for the resulting collection.

Returns:

A collection with one dataset assembled from the columns.

Return type:

TrajectoryCollection
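The compress_particles option can be pictured with a small greedy interval-packing sketch in plain Python. It is not SFI's implementation: pack_intervals is a hypothetical helper, and the exact buffer rule (a column is reused only when the gap to its last occupied frame exceeds the buffer) is an assumption.

```python
def pack_intervals(supports, buffer=2):
    """Greedy interval packing (illustrative sketch, not SFI's implementation).

    supports: dict mapping particle ID -> (first_frame, last_frame), inclusive.
    Returns a dict mapping particle ID -> column index.
    """
    column_last = []  # last occupied frame of each column
    column_of = {}
    # Visit particles in order of first appearance.
    for pid, (first, last) in sorted(supports.items(), key=lambda kv: kv[1][0]):
        for col, end in enumerate(column_last):
            # Assumed rule: reuse a column only if the gap exceeds the buffer.
            if first > end + buffer:
                column_last[col] = last
                column_of[pid] = col
                break
        else:
            # No existing column is free: open a new one.
            column_last.append(last)
            column_of[pid] = len(column_last) - 1
    return column_of

supports = {0: (0, 10), 1: (5, 20), 2: (13, 30), 3: (23, 40)}
column_of = pack_intervals(supports)
n_columns = len(set(column_of.values()))
```

Here four particle IDs fit into two columns, which is the kind of reduction that matters for open-boundary systems where IDs accumulate over time.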

classmethod from_dataframe(df, *, particle=None, time=None, coords=None, dt=None, t=None, extras_global=None, extras_local=None, relabel=True, compress_particles=False, meta=None, weights='pool')[source]

Build a single-dataset collection from a pandas DataFrame.

The natural entry point for raw tracking tables (trackpy, TrackMate, custom pipelines): columns are addressed by name, in any order, and junk columns are dropped.

Parameters:
  • df – A pandas DataFrame with one row per (particle, time) observation.

  • particle (str | None) – Name of the particle/track-ID column. Default: case-insensitive auto-detection among particle_id, particle, track_id, track, traj_id; if none is present the table is treated as a single trajectory, and if several are present a ValueError asks for an explicit choice.

  • time (str | None) – Name of the time column. Default: auto-detection among time_step, frame, time, t (same ambiguity rule). Integer columns are used as frame indices; float columns are factorized into frame indices and, unless t or dt is given, their sorted unique values become the absolute time axis.

  • coords (Sequence[str] | None) – State-vector column names, in order. Default: every remaining column without an extras prefix (G_, TG_, P_, TP_), in dataframe order. Columns not selected are silently dropped.

  • dt (float | None) – Time-axis specification, as in from_columns().

  • t (Any | None) – Time-axis specification, as in from_columns().

  • extras_global (Mapping[str, Any] | None) – Extra fields merged over any extras parsed from prefixed columns (user values win).

  • extras_local (Mapping[str, Any] | None) – Extra fields merged over any extras parsed from prefixed columns (user values win).

  • relabel (bool) – As in from_columns().

  • compress_particles (bool) – As in from_columns().

  • meta (Dict[str, Any] | None) – As in from_columns().

  • weights (str | Sequence[float]) – As in from_columns().

Return type:

TrajectoryCollection

Examples

>>> coll = TrajectoryCollection.from_dataframe(
...     tracks, particle="track_id", time="frame",
...     coords=("x", "y"), dt=0.05,
... )
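The handling of float time columns (factorization into frame indices, with the sorted unique values becoming the time axis) can be sketched in plain Python. The helper factorize_times is hypothetical and only mirrors the behaviour described for the time parameter; it is not SFI code.

```python
def factorize_times(times):
    """Map float time stamps to frame indices (sketch; hypothetical helper)."""
    t_axis = sorted(set(times))               # sorted unique values
    frame_of = {t: k for k, t in enumerate(t_axis)}
    frame_idx = [frame_of[t] for t in times]  # one frame index per row
    return frame_idx, t_axis

frame_idx, t_axis = factorize_times([0.0, 0.05, 0.0, 0.1, 0.05])
```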
classmethod from_dataset(ds, *, weights='pool')[source]

Wrap a single TrajectoryDataset in a collection.

Parameters:
Returns:

A single-dataset collection with weights computed from ds.

Return type:

TrajectoryCollection

iter_slices(*, require, bytes_hint, chunk_target_bytes=67108864, subsampling=1, context=None)[source]

Yield chunks as (producer, t_idx) pairs for vmapped integration.

Parameters:
  • require (Set[str]) – Set of stream names required by the integrator (e.g. {"X","dX","mask"}). Passed to TrajectoryDataset.valid_indices() and TrajectoryDataset.make_producer().

  • bytes_hint (int | None) – Approximate per-row memory footprint (in bytes) of the values produced by the program. If None or <= 0, no chunking is performed and all valid indices are yielded at once.

  • chunk_target_bytes (int) – Target chunk size in bytes. Combined with bytes_hint to determine how many rows to include in each chunk.

  • subsampling (int) – Optional subsampling factor applied to the time indices before chunking.

  • context (str | None) – Optional context string passed through to the dataset producer, typically used to switch extra fields.

Yields:

dict – Mapping with keys:

  • "producer": Callable[[jax.Array], dict], single-t row builder.

  • "t_idx": 1D JAX array of integer time indices.

  • "dataset_index": index of the underlying dataset in datasets.

  • "weight": float dataset weight, taken from weights.

Return type:

Iterator[Mapping[str, Any]]
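How bytes_hint and chunk_target_bytes might combine into a chunk length can be sketched as follows. This is a plain-Python illustration under an assumed rounding rule (integer division, at least one row per chunk); chunk_indices is a hypothetical helper, not the library's code.

```python
def chunk_indices(t_idx, bytes_hint, chunk_target_bytes=64 * 1024 * 1024):
    """Yield chunks of time indices (sketch; the rounding rule is an assumption)."""
    t_idx = list(t_idx)
    if bytes_hint is None or bytes_hint <= 0:
        # No usable hint: yield every valid index at once.
        yield t_idx
        return
    rows = max(1, chunk_target_bytes // bytes_hint)
    for start in range(0, len(t_idx), rows):
        yield t_idx[start:start + rows]

chunks = list(chunk_indices(range(10), bytes_hint=16, chunk_target_bytes=64))
unchunked = list(chunk_indices(range(10), bytes_hint=None))
```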

classmethod load(path, *, relabel=True, compress_particles=False, particle_column='auto', time_column='auto', state_columns=None)[source]

Load a collection from a single file or a directory.

Parameters:
  • relabel (bool) – If True, compress particle IDs to 0..N-1 and shift time to start at 0.

  • compress_particles (bool) – If True, further reduce the column count by merging particles whose time supports do not overlap (greedy interval packing with a 2-frame buffer). Useful for open-boundary systems where particles enter and leave the field of view, causing the naive N to grow as the total number of unique particle IDs rather than the concurrent count. Per-particle extras are reindexed automatically; the mapping is stored in dataset.meta['particle_column_map'].

  • particle_column (int | str | None) – Column holding the particle ID, given as a column name (any format) or a positional index (CSV only). "auto" (default) keeps the loader default: column 0 for CSV, the canonical name "particle_id" for parquet/HDF5. Pass particle_column=None for single-trajectory files.

  • time_column (int | str) – Column holding the time index, given as a column name (any format) or a positional index (CSV only). "auto" (default) keeps the loader default: column 1 for CSV, the canonical name "time_step" for parquet/HDF5.

  • state_columns (Sequence[int | str] | None) – Optional explicit selection of the state-vector columns (names, or indices for CSV), in order; every other non-extras column is dropped. Default: all non-ID, non-extras columns.

  • path (str | Path)

Return type:

TrajectoryCollection

Notes

The default weight policy differs by path: a single-file load uses "Teff" (via from_dataset()); a directory load uses "equal". Call with_weights() after loading if a consistent policy is needed.

merge(items, *, weights='pool')[source]

Combine this collection with others into one collection.

Convenience alias for concat() — useful for assembling an ensemble from several single-trajectory collections (base.merge([c1, c2, ...])). See concat() for the weights policy.

Parameters:
Return type:

TrajectoryCollection

peek_X()[source]

Convenience helper: peek at the “X” stream. Shape-aligned with the first valid row of “X” from peek_row.

peek_dX()[source]

Convenience helper: peek at the “dX” stream. Shape-aligned with the first valid row of “dX” from peek_row.

peek_dt()[source]

Convenience helper: peek at the “dt” stream. Shape-aligned with the first valid row of “dt” from peek_row.

peek_mask()[source]

Convenience helper: peek at the “mask” stream.

Shape-aligned with the first valid row of “mask” from peek_row.

peek_row(*, require=frozenset({'X', 'dX'}), context=None)[source]

Return a single-t sample row from the first dataset with valid indices.

Parameters:
  • require (Set[str]) – Set of stream names required for the sample (as in iter_slices()).

  • context (str | None) – Optional context string forwarded to the producer.

Returns:

Structure matching producer(t) for the chosen dataset.

Return type:

dict

Notes

Useful for memory estimation and debugging program outputs.

save(path, *, format=None, **format_kw)[source]

Save the collection.

Notes

  • Single file path (.csv/.parquet/.h5): collection must have exactly one dataset.

  • Directory path: write one file per dataset + manifest.yaml.

  • Masked samples are dropped (no masked rows written).

  • No relabeling at save-time; relabeling is handled at load-time.

  • dynamic_mask is not persisted; after a save/load round-trip it will be None (equivalent to the static mask).

Parameters:
  • path (str | Path)

  • format (str | None)

  • format_kw (Any)

Return type:

Path

split_time(fraction=0.8, *, gap=0, reweight='pool')[source]

Split every dataset along time into (train, test) collections.

A side feature for data-abundant scenarios: SFI estimates its own accuracy from the training data (force_predicted_MSE) and validates fits through the diagnostics suite, neither of which costs any data. Hold out a test fraction only when data is plentiful, or to confirm a suspected bias floor with holdout_score().

Parameters:
  • fraction (float) – Fraction of frames per dataset assigned to the train half.

  • gap (int) – Frames dropped between the halves (decorrelation; 0 is safe for increment-based estimators).

  • reweight (str) – Weight policy for the two halves. "pool" (the default in the signature) recomputes per-dataset weights on each half; "keep" carries over the current relative weights.

Return type:

tuple[TrajectoryCollection, TrajectoryCollection]

Examples

>>> train, test = coll.split_time(0.8)
>>> inf = OverdampedLangevinInference(train)
>>> # ... fit ...
>>> inf.holdout_score(test)
to_array(*, axis='time', as_numpy=True)[source]

Materialize the whole collection as one dense (T, N, d) array.

Concatenates every dataset along the time axis into a single array of positions. Use this for the legitimate non-plotting reach-ins (disk caching, ensemble bootstrap initial conditions, neighbour lists); for plotting, prefer the toolkit functions in SFI.utils.plotting, and for (t, X, mask) of a single dataset use to_arrays().

Parameters:
  • axis (Literal['time']) – Only "time" is supported (axis-0 concatenation).

  • as_numpy (bool) – If True (default), return a NumPy array; else a JAX array.

Return type:

ndarray, shape (sum_T, N, d)

to_arrays(*, dataset=0, as_numpy=True, include_mask=True)[source]

Convenience helper: materialize one dataset as dense arrays.

Parameters:
  • dataset (int) – Index of the dataset inside the collection (default 0).

  • as_numpy (bool) – If True, return NumPy arrays.

  • include_mask (bool) – If True, also return the per-particle mask.

Returns:

See TrajectoryDataset.to_arrays().

Return type:

t, X, mask

velocity_array(*, dataset=0, scheme='central', as_numpy=True)[source]

Finite-difference velocity v(t) for one dataset.

Reconstructs velocities from stored positions with SFI.utils.maths.fd_velocity(), matching the secant-velocity convention of the underdamped engine. Handy for building (x, v) phase portraits or held-out evaluation grids from position-only recordings.

Parameters:
  • dataset (int) – Dataset index inside the collection (default 0).

  • scheme (Literal['central', 'forward', 'backward']) – Finite-difference stencil; see SFI.utils.maths.fd_velocity().

  • as_numpy (bool) – If True (default), return a NumPy array; else a JAX array.

Returns:

v

Return type:

ndarray, shape (T, N, d)

weights: Array
with_weights(spec='pool', *, required=frozenset({'X', 'dX'}), subsampling=1)[source]

Set the per-dataset weights (an unnormalised multiplier).

Parameters:
  • spec (str | Sequence[float]) –

    Inter-dataset weight policy — a per-dataset multiplier applied to every estimator (force, diffusion, parametric). Accepted values:

    • "pool" (default): multiplier 1 for all datasets, i.e. pool every increment on equal footing. Combined with each estimator’s intrinsic within-dataset weighting (force is per-dt, diffusion per-point), this weights each dataset by its effective time (force) or point count (diffusion) — the natural maximum-likelihood pooling.

    • "per_dataset": each dataset contributes equally regardless of length (multiplier mean(Teff)/Teff_d). Exact for the force; for the diffusion it is exact when dt is uniform.

    • a sequence of floats: explicit unnormalised multipliers.

  • required (Set[str]) – Streams used to compute Teff in the "per_dataset" policy. See TrajectoryDataset.Teff().

  • subsampling (int) – Optional subsampling factor used when counting valid indices.

Returns:

The same collection with its weights field updated.

Return type:

TrajectoryCollection

Notes

Weights are exposed to the integrator via the "weight" entry in the payloads yielded by iter_slices() and applied in every reduction (sum and mean). They are deliberately unnormalised: the absolute scale cancels in the mean-reduced estimates, while for the force Gram / covariance it sets the information scale (a single dataset carries unit weight).
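The multipliers implied by each policy can be written out in a few lines of plain Python. This sketch follows the formulas quoted above (multiplier 1 for "pool", mean(Teff)/Teff_d for "per_dataset"); weight_multipliers is a hypothetical helper and the Teff values are made up for illustration.

```python
def weight_multipliers(spec, teffs):
    """Per-dataset multipliers for a weight policy (sketch, not SFI's code)."""
    if spec == "pool":
        return [1.0] * len(teffs)
    if spec == "per_dataset":
        mean_teff = sum(teffs) / len(teffs)
        return [mean_teff / t for t in teffs]
    return [float(w) for w in spec]  # explicit unnormalised multipliers

teffs = [10.0, 30.0]  # illustrative effective times of two datasets
pool = weight_multipliers("pool", teffs)
per_dataset = weight_multipliers("per_dataset", teffs)
```

With "per_dataset", multiplier times Teff is the same for every dataset, which is what "contributes equally regardless of length" means for the force estimator.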

class SFI.trajectory.TrajectoryDataset(X, dt=None, t=None, mask=None, dynamic_mask=None, extras_global=<factory>, extras_local=<factory>, meta=<factory>)[source]

Bases: object

Immutable dataset for a single trajectory.

Parameters:
  • X (jax.Array) – State array of shape (T, N, d) or (T, d). If (T, d), N is 1.

  • dt (jax.Array | float | None) – Either a scalar step, an array of shape (T,) (per-step), or None. If None and t is provided, steps are derived from t.

  • t (jax.Array | None) – Optional absolute time vector of shape (T,). If provided, it defines dt via finite differences when requested.

  • mask (jax.Array | None) – Optional boolean mask of shape (T, N) or (T,) marking valid observations at time t and particle n (“static mask”). If None, all ones. A True entry means the particle’s position is known and can be used for state evaluation (e.g. neighbor forces).

  • dynamic_mask (jax.Array | None) – Optional boolean mask of shape (T, N) or (T,) marking entries whose increments are reliable and should contribute to parameter fitting (“dynamic mask”). Must be a subset of mask (dynamic_mask ⊆ mask). If None, defaults to mask. Typical use: particles near open boundaries are statically valid (their positions are known) but dynamically masked (their neighborhoods are incomplete, biasing their increments).

  • extras_global (Dict[str, Any] | None) – Dict of global extras. Values are static objects, TimeSeriesExtra, or JAX-traceable callables f(t_idx, context=None) -> Array with a leading time axis.

  • extras_local (Dict[str, Any] | None) – Dict of per-particle extras. Same typing as extras_global. Time-series entries typically have shape (T, N, ...).

  • meta (Dict[str, Any] | None) – Free-form metadata.
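The subset requirement between the static and dynamic masks can be checked with a short plain-Python sketch. The helper is_subset is hypothetical and uses nested lists of shape (T, N) in place of JAX arrays; whether and how the library itself validates this is not stated here.

```python
def is_subset(dynamic_mask, mask):
    """True if every dynamically valid entry is also statically valid (sketch)."""
    return all(
        (not dyn) or stat
        for dyn_row, stat_row in zip(dynamic_mask, mask)
        for dyn, stat in zip(dyn_row, stat_row)
    )

mask = [[True, True], [True, False]]         # shape (T, N) = (2, 2)
dynamic_ok = [[True, False], [True, False]]  # one particle excluded from fitting
dynamic_bad = [[True, True], [True, True]]   # marks an unobserved entry as reliable
```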

property N: int
property T: int
Teff(required, *, subsampling=1)[source]

Effective exposure time over valid indices for weighting.

Defined as

Teff = sum_t N_active[t] * dt[t],

where N_active[t] is the number of active (unmasked) particles at time index t under the same stream requirements used by the integration runtime.

This reuses the same dt logic as _dt_fields_single() and the same masking logic as _output_mask_single(), so that weighting matches exactly what the runtime sees.

Parameters:
  • required (Set[str])

  • subsampling (int)

Return type:

float
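The Teff definition above reduces to a weighted count of active particles. A minimal plain-Python sketch, assuming a (T, N) mask given as nested lists and a per-step dt list, and ignoring the stream requirements that the real method takes into account:

```python
def teff(mask, dt):
    """Teff = sum_t N_active[t] * dt[t] (sketch; ignores stream requirements)."""
    # sum(row) counts the True entries, i.e. the active particles at that step
    return sum(sum(row) * step for row, step in zip(mask, dt))

mask = [[True, True], [True, False], [False, False]]  # (T, N) = (3, 2)
dt = [0.5, 0.5, 0.5]
total = teff(mask, dt)  # 2 * 0.5 + 1 * 0.5 + 0 * 0.5
```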

X: Array
build_extras(t_idx, *, dataset_index=0, context=None)[source]

Full model-facing extras at t_idx: user values plus reserved keys.

User extras are sliced at the frame(s) — a TimeSeriesExtra is indexed, a callable invoked, anything else forwarded — and the reserved time / duration / dataset_index / particle_index are resolved for this dataset. This single mapping is what every consumer (inference, simulation, diagnostics) feeds the force/diffusion expression. extras_local overrides extras_global on key conflicts.

Parameters:
  • t_idx (Array)

  • dataset_index (int)

  • context (str | None)

Return type:

Dict[str, Any]
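The dispatch described above (index a time series, invoke a plain callable, forward anything else, with local keys overriding global ones) can be sketched in plain Python. SeriesStandIn and FunctionStandIn are hypothetical stand-ins for TimeSeriesExtra and FunctionExtra, resolve_extras is not an SFI function, and the reserved keys (time, duration, dataset_index, particle_index) are left out.

```python
class SeriesStandIn:
    """Hypothetical stand-in for TimeSeriesExtra: data with a leading time axis."""
    def __init__(self, data):
        self.data = data

class FunctionStandIn:
    """Hypothetical stand-in for FunctionExtra: a callable passed through as-is."""
    def __init__(self, func):
        self.func = func

def resolve_extras(extras_global, extras_local, t_idx, context=None):
    merged = {**extras_global, **extras_local}  # local wins on key conflicts
    out = {}
    for key, value in merged.items():
        if isinstance(value, SeriesStandIn):
            out[key] = value.data[t_idx]              # slice the time series
        elif isinstance(value, FunctionStandIn):
            out[key] = value                          # passed through unchanged
        elif callable(value):
            out[key] = value(t_idx, context=context)  # time-dependent generator
        else:
            out[key] = value                          # static object, forwarded
    return out

field = FunctionStandIn(lambda x: 2 * x)
extras = resolve_extras(
    {"temperature": SeriesStandIn([300.0, 310.0, 320.0]), "gain": 1.0},
    {"gain": 5.0, "phase": lambda t_idx, context=None: 0.1 * t_idx, "field": field},
    t_idx=2,
)
```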

property d: int
dt: Array | float | None = None
dynamic_mask: Array | None = None
extras_global: Dict[str, Any] | None
extras_local: Dict[str, Any] | None
classmethod from_arrays(*, X, dt=None, t=None, mask=None, dynamic_mask=None, extras_global=None, extras_local=None, meta=None)[source]

Construct a dataset from array-likes.

All inputs are converted to JAX arrays where relevant; extras and meta are stored as-is (no deep conversion).

Return type:

TrajectoryDataset

Raises:

ValueError – If X contains NaN/Inf or has the wrong dimensionality, if dt <= 0, or if the trajectory is too short for any useful computation.

Parameters:
  • X (Any)

  • dt (float | None)

  • t (Any | None)

  • mask (Any | None)

  • dynamic_mask (Any | None)

  • extras_global (Dict[str, Any] | None)

  • extras_local (Dict[str, Any] | None)

  • meta (Dict[str, Any] | None)

make_batch_producer(require, *, include_mask=True, include_dt=True, context=None, force_dt_keys=None, dataset_index=0)[source]

Return a function that gathers a batch of rows in one vectorised pass.

This is the batch counterpart of make_producer(). Instead of building one row at a time (designed for use inside jax.vmap), this function gathers K rows at once using array indexing, producing arrays with a leading K axis.

Parameters:
Returns:

batch_producer – A function such that batch_producer(t_block), with t_block of shape (K,), returns a dict whose arrays have a leading K axis.

Return type:

Callable[[Array], Dict[str, Any]]

Notes

Extras limitation: time-varying extras (TimeSeriesExtra and callables) are evaluated at t_block[0] only, not at each index in the block. This batch producer is designed for use cases where extras are global constants across the chunk (e.g. static boundary tensors). For per-step time-varying extras, use make_producer() with jax.vmap instead.

make_producer(require, *, include_mask=True, include_dt=True, context=None, force_dt_keys=None, dataset_index=0)[source]

Return a JAX-traceable function that builds a single-t row.

Parameters:
  • require (Set[str]) – Set of stream names and the special key "extras" if extras are needed by downstream expressions.

  • include_mask (bool) – If True, include "mask_out" computed from require.

  • include_dt (bool) – If True, include "dt" and neighbors when needed.

  • context (str | None) – Optional string to pass through to extras callables.

  • force_dt_keys (Set[str] | None) – Extra dt fields to force, e.g. {"dt_plus"}.

  • dataset_index (int) – Position of this dataset within its collection; resolves the reserved dataset_index extra on every row.

Returns:

producer – A function such that producer(t) returns a dict whose leaves are single-row arrays. Structure is fixed across calls.

Return type:

Callable[[Array], Dict[str, Any]]

Notes

  • Use valid_indices() to generate in-bounds indices. The producer assumes its input is valid and does not clamp.

mask: Array | None = None
materialize_time(*, as_numpy=True)[source]

Return a dense absolute time vector t of shape (T,).

Notes

  • If self.t is not None, it is returned as-is.

  • Else, if self.dt is a scalar, use t[k] = k * dt.

  • Else, if self.dt is an array of shape (T,), interpret dt[k] as the step between X[k] and X[k+1] and build

    t[0] = 0
    t[k+1] = t[k] + dt[k] for k = 0..T-2

    The last entry dt[T-1] (if any) is ignored.

  • If both t and dt are None, a ValueError is raised.

Parameters:

as_numpy (bool) – If True, return a NumPy array; otherwise a JAX array.

Returns:

t

Return type:

array, shape (T,)
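The rules in the Notes can be restated as a short plain-Python sketch. It uses lists instead of NumPy or JAX arrays and is not the library's implementation; the function below merely shares the method's name for readability.

```python
from itertools import accumulate

def materialize_time(T, t=None, dt=None):
    """Dense time vector of length T (plain-Python sketch of the rules above)."""
    if t is not None:
        return list(t)
    if dt is None:
        raise ValueError("need either t or dt")
    if isinstance(dt, (int, float)):
        return [k * dt for k in range(T)]
    # Per-step dt: t[0] = 0 and t[k+1] = t[k] + dt[k]; dt[T-1] is unused.
    return list(accumulate(dt[:T - 1], initial=0.0))

uniform = materialize_time(4, dt=0.5)
variable = materialize_time(4, dt=[0.5, 1.0, 2.0, 99.0])
```

The last entry of the per-step list (99.0 here) does not influence the result, matching the note that dt[T-1] is ignored.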

meta: Dict[str, Any] | None
split_time(fraction=0.8, *, gap=0)[source]

Split into (train, test) datasets along the time axis.

A side feature for data-abundant scenarios: SFI estimates its own accuracy from the training data (force_predicted_MSE) and validates fits through the diagnostics suite, neither of which costs any data. Hold out a test fraction only when data is plentiful, or to confirm a suspected bias floor.

Parameters:
  • fraction (float) – Fraction of frames assigned to the train half: train is [0, round(fraction*T)), test is the remainder (after the optional gap).

  • gap (int) – Number of frames dropped between the halves. 0 is safe for increment-based estimators (the boundary increment belongs to neither half by construction); use a few correlation times for slowly mixing systems.

Return type:

Tuple[TrajectoryDataset, TrajectoryDataset]
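The frame ranges implied by fraction and gap can be sketched in plain Python. The train range follows the formula quoted above; placing the start of the test range at round(fraction*T) + gap is an assumption about how the gap is applied, and split_bounds is a hypothetical helper.

```python
def split_bounds(T, fraction=0.8, gap=0):
    """Frame ranges for the train and test halves (sketch; gap placement assumed)."""
    n_train = round(fraction * T)
    train = range(0, n_train)
    # Assumed: the gap frames directly follow the train half and are dropped.
    test = range(min(n_train + gap, T), T)
    return train, test

train, test = split_bounds(100, fraction=0.8, gap=5)
```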

t: Array | None = None
to_arrays(*, as_numpy=True, include_mask=True)[source]

Materialize dense trajectory arrays for this dataset.

This is intended for plotting and quick inspection, not for JAX integration (use make_producer() for that).

Parameters:
  • as_numpy (bool) – If True (default), return NumPy arrays.

  • include_mask (bool) – If True (default), return the per-particle validity mask.

Returns:

  • t – Absolute time vector of shape (T,); see materialize_time().

  • X – State tensor of shape (T, N, d).

  • mask – Boolean mask of shape (T, N) if include_mask is True, otherwise None.

Return type:

tuple[ndarray, ndarray, ndarray] | tuple[Array, Array, Array | None]

property uuid: str

Stable identity of this dataset.

valid_indices(required, subsampling=None)[source]

Return valid time indices given required streams.

A time index t is valid iff t + amin >= 0 and t + amax <= T-1, where (amin, amax) aggregates all offsets required by streams in required. Extras do not affect the window.

Parameters:
  • required (Set[str]) – Set of stream names and possibly "extras".

  • subsampling (int | None) – Optional positive integer. If provided, keep only indices where t % subsampling == 0 (grid-aligned to multiples of subsampling). This may exclude the first valid index if it is not a multiple of subsampling.

Returns:

1-D array of valid time indices (dtype=int32), possibly empty.

Return type:

jax.Array
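The validity window and the grid-aligned subsampling can be sketched in plain Python. In this illustration the aggregated offsets (amin, amax) are passed in directly rather than derived from stream names, since the per-stream offsets are not listed here; the function is a stand-in, not the method itself.

```python
def valid_indices(T, amin, amax, subsampling=None):
    """Time indices t with 0 <= t + amin and t + amax <= T - 1 (sketch)."""
    lo = max(0, -amin)
    hi = min(T - 1, T - 1 - amax)
    idx = range(lo, hi + 1)
    if subsampling is not None:
        # Keep only indices on the grid of multiples of `subsampling`.
        idx = [t for t in idx if t % subsampling == 0]
    return list(idx)

forward = valid_indices(10, amin=0, amax=1)                   # needs frame t + 1
centered = valid_indices(10, amin=-1, amax=1, subsampling=3)  # needs t - 1 and t + 1
```

In the second call the first valid index is 1, which is not a multiple of 3, so the subsampled result starts at 3.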

SFI.trajectory.function_extra(func)[source]

Build a FunctionExtra from a callable.

Parameters:

func (Callable)

Return type:

FunctionExtra

SFI.trajectory.time_series_extra(x)[source]

Build a TimeSeriesExtra from an array-like.

Parameters:

x (Any)

Return type:

TimeSeriesExtra

Submodules