SFI.trajectory package¶
Trajectory submodule public API.
Exports:
- TrajectoryDataset / TrajectoryCollection: main user-facing containers
- FunctionExtra / function_extra: pass JAX-traceable callables as extras
- TimeSeriesExtra / time_series_extra: wrap time-varying array extras
I/O (save_trajectory, load_trajectory, columns_and_extras_to_dataset) is
available via SFI.trajectory.io but is not re-exported here; most users
should use TrajectoryCollection.save / TrajectoryCollection.load.
- class SFI.trajectory.FunctionExtra(func)[source]¶
Bases: object
Wrapper for a JAX-compatible callable passed through extras.
Unlike plain callables (which are invoked eagerly by TrajectoryDataset.build_extras() as time-dependent generators), a FunctionExtra is passed through unchanged so the user’s basis function can call it inside JIT.
- Parameters:
func (callable) – A JAX-traceable function, e.g. func(x) -> Array. It will be captured as a compile-time constant inside @jax.jit.
Examples
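>>> import jax.numpy as jnp
>>> from SFI.trajectory import FunctionExtra, TrajectoryCollection
>>> # X: state array of shape (T, N, d), defined elsewhere
>>> adhesion = FunctionExtra(lambda x: jnp.exp(-jnp.sum(x**2)))
>>> coll = TrajectoryCollection.from_arrays(
...     X=X, dt=0.01,
...     extras_global={"adhesion": adhesion},
... )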
- func: Callable¶
- class SFI.trajectory.TimeSeriesExtra(data)[source]¶
Bases: object
Wrapper for time-dependent extras with an explicit leading time axis.
- Parameters:
data (jax.Array) – Array with shape (T, ...) for globals or (T, N, ...) for per-particle extras. The dataset will slice data[t].
- data: Array¶
- class SFI.trajectory.TrajectoryCollection(datasets, weights)[source]¶
Bases: object
Container for one or more trajectories plus per-dataset weights.
This is the main user-facing trajectory object. It wraps a list of TrajectoryDataset instances and exposes an index-based streaming interface used by the integration runtime.
Most users should construct collections via from_arrays(), from_dataset() or load() rather than instantiating this dataclass directly.
- Parameters:
datasets (List[SFI.trajectory.dataset.TrajectoryDataset]) – List of underlying TrajectoryDataset objects. The order is preserved in iteration and determines the ordering of the weights vector.
weights (jax.Array) – 1D JAX array of shape (D,) with non-negative entries, where D = len(datasets). Set by with_weights(), which stores the entries as unnormalised per-dataset multipliers.
Notes
The collection itself does not impose any chunking heuristic. It only coordinates datasets and their weights; the integrator decides how to vmap over the indices returned by iter_slices().
- Teff(required, *, subsampling=1)[source]¶
Total effective exposure time across all datasets.
This is simply the sum of per-dataset Teff values:
sum_d datasets[d].Teff(required, subsampling=subsampling).
- Parameters:
required (Set[str])
subsampling (int)
- Return type:
float
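As a rough illustration of how the collection total relates to the per-dataset values, the sketch below sums hypothetical per-dataset Teff numbers and derives the multipliers of the "pool" and "per_dataset" policies described under with_weights(). The helper name weight_multipliers and the input values are assumptions made for this example; it is not the library's implementation.

```python
import numpy as np

def weight_multipliers(spec, teff):
    """Unnormalised per-dataset multipliers for a weight policy (sketch)."""
    teff = np.asarray(teff, dtype=float)
    if isinstance(spec, str):
        if spec == "pool":
            return np.ones_like(teff)       # every increment on equal footing
        if spec == "per_dataset":
            return teff.mean() / teff       # mean(Teff) / Teff_d
        raise ValueError(f"unknown policy: {spec!r}")
    w = np.asarray(spec, dtype=float)       # explicit multipliers
    if w.shape != teff.shape:
        raise ValueError("need one multiplier per dataset")
    return w

teff = np.array([10.0, 30.0])               # hypothetical per-dataset Teff values
total_teff = float(teff.sum())              # sum_d Teff_d, the collection-level total
pool = weight_multipliers("pool", teff)
per_dataset = weight_multipliers("per_dataset", teff)
```

Under "per_dataset", the product of multiplier and Teff_d is the same for every dataset, which is the sense in which each dataset contributes equally regardless of length.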
- concat(items, *, weights='pool')[source]¶
Concatenate this collection with other collections or datasets.
- Parameters:
items (Sequence[TrajectoryCollection | TrajectoryDataset]) – Sequence of TrajectoryCollection or TrajectoryDataset instances. Collections are flattened into their constituent datasets.
weights (str | Sequence[float]) – Weight specification for the concatenated collection. See with_weights() for accepted values.
- Returns:
New collection containing all datasets from self followed by all datasets from items.
- Return type:
TrajectoryCollection
- dataset_index(position)[source]¶
Dense index of dataset position, keyed on its stable identity.
Datasets are numbered by first appearance of their uuid, so the index a force sees (e.g. via per_dataset_scalar() or dataset_indicator()) is tied to the dataset itself, not its slot, and is therefore stable under concatenation and reordering.
- Parameters:
position (int)
- Return type:
int
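The "first appearance" numbering can be pictured with a few lines of plain Python. This is only a sketch of the idea; the helper name dense_indices and the uuid strings are hypothetical, and the library may compute the index differently.

```python
def dense_indices(uuids):
    """Number identities by first appearance; repeated uuids share an index."""
    first_seen = {}
    out = []
    for u in uuids:
        if u not in first_seen:
            first_seen[u] = len(first_seen)   # next unused dense index
        out.append(first_seen[u])
    return out

# Hypothetical uuids: dataset "a" sits at positions 0 and 2 of the collection.
indices = dense_indices(["a", "b", "a", "c"])
```

Both positions holding dataset "a" map to the same dense index, so the index follows the dataset rather than its slot.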
- datasets: List[TrajectoryDataset]¶
- degrade(*, downsample=1, motion_blur=0, data_loss_fraction=0.0, noise=None, ROI=None, seed=None, reweight='pool')[source]¶
Return a new degraded collection; the original is not modified.
This is the preferred user-facing API for degrading synthetic trajectories to mimic experimental noise, blur, and data loss.
- Parameters:
downsample (int)
motion_blur (int)
data_loss_fraction (float)
noise (None | float | ndarray)
ROI (None | float | ndarray | Callable[[ndarray], bool])
seed (int | None)
reweight (Literal['pool', 'keep'])
See SFI.trajectory.degrade.degrade_collection() for a full description of each parameter.
- Returns:
New degraded collection. The original collection is not modified.
- Return type:
TrajectoryCollection
- classmethod from_arrays(*, X, dt=None, t=None, mask=None, extras_global=None, extras_local=None, meta=None, weights='pool')[source]¶
Build a single-dataset collection from array-likes.
This is the recommended entry point when you already have tensors in memory.
- Parameters:
X (Any) – State array of shape (T, N, d) or (T, d). If (T, d), a single particle is assumed.
dt (float | None) – Either a scalar step, an array of shape (T,) (per-step), or None. If None and t is provided, effective steps are derived from t on demand.
t (Any | None) – Optional absolute time vector of shape (T,). If provided, it defines time steps when needed.
mask (Any | None) – Optional boolean mask of shape (T, N) or (T,) marking valid observations. If None, all entries are considered valid.
extras_global (Dict[str, Any] | None) – Mapping of global extras. Values can be static objects, TimeSeriesExtra, or JAX-traceable callables f(t_idx, context=None) -> Array with a leading time axis.
extras_local (Dict[str, Any] | None) – Mapping of per-particle extras, with the same typing as extras_global. Time-series entries typically have shape (T, N, ...).
meta (Dict[str, Any] | None) – Free-form metadata dictionary attached to the underlying dataset.
weights (str | Sequence[float]) – Initial weight specification for the resulting collection. See with_weights().
- Returns:
A collection with one dataset built from the provided arrays.
- Return type:
TrajectoryCollection
- classmethod from_columns(particle_idx, time_idx, state_vectors, *, extras_global=None, extras_local=None, dt=None, t=None, relabel=True, compress_particles=False, meta=None, weights='pool')[source]¶
Build a single-dataset collection from flat (particle, time) columns.
This constructor is convenient when reading trajectories from a tabular format or a custom pipeline.
- Parameters:
particle_idx (ndarray) – Integer array of shape (L,) with particle IDs for each row.
time_idx (ndarray) – Integer array of shape (L,) with time indices t for each row.
state_vectors (ndarray) – Array of shape (L, d) with state vectors.
extras_global (Mapping[str, Any] | None) – Parsed global extras (e.g. from YAML header), as described in SFI.trajectory.io.
extras_local (Mapping[str, Any] | None) – Parsed local extras, including time-series extras, as described in SFI.trajectory.io.
dt (float | None) – Optional scalar step; used only if no absolute time axis is provided via t or extras_global['t'].
t (ndarray | None) – Optional time vector of shape (T,) overriding any time axis inferred from extras.
relabel (bool) – If True, compress particle IDs to 0..N-1 and shift time to start at 0.
compress_particles (bool) – If True, apply greedy interval packing to reduce the column count by merging particles whose time supports do not overlap (with a 2-frame buffer). Per-particle extras are reindexed automatically. The mapping is stored in dataset.meta['particle_column_map'].
meta (Dict[str, Any] | None) – Metadata dictionary to attach to the dataset.
weights (str | Sequence[float]) – Initial weight specification for the resulting collection.
- Returns:
A collection with one dataset assembled from the columns.
- Return type:
TrajectoryCollection
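One way to read "greedy interval packing with a 2-frame buffer" is sketched below in plain Python: particles are visited by first frame and each is placed in the first column whose previous occupant ended at least `buffer` empty frames earlier. The helper pack_columns, the tie-breaking rule, and the exact meaning of the buffer are assumptions made for illustration; the actual compress_particles algorithm may differ.

```python
def pack_columns(supports, buffer=2):
    """Greedy interval packing (sketch).

    supports: dict particle_id -> (first_frame, last_frame), both inclusive.
    Returns a dict particle_id -> column index.
    """
    col_last = []                       # last occupied frame of each column
    mapping = {}
    for pid, (start, end) in sorted(supports.items(), key=lambda kv: kv[1][0]):
        for col, last in enumerate(col_last):
            if start > last + buffer:   # at least `buffer` empty frames in between
                col_last[col] = end
                mapping[pid] = col
                break
        else:                           # no column is free: open a new one
            col_last.append(end)
            mapping[pid] = len(col_last) - 1
    return mapping

# Hypothetical time supports of four particle IDs.
supports = {10: (0, 5), 11: (3, 9), 12: (8, 12), 13: (12, 20)}
column_map = pack_columns(supports)
n_columns = len(set(column_map.values()))
```

Here four particle IDs fit into two columns because particles 10 and 12, and likewise 11 and 13, are never present at the same time (with the buffer respected).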
- classmethod from_dataframe(df, *, particle=None, time=None, coords=None, dt=None, t=None, extras_global=None, extras_local=None, relabel=True, compress_particles=False, meta=None, weights='pool')[source]¶
Build a single-dataset collection from a pandas DataFrame.
The natural entry point for raw tracking tables (trackpy, TrackMate, custom pipelines): columns are addressed by name, in any order, and junk columns are dropped.
- Parameters:
df – A pandas DataFrame with one row per (particle, time) observation.
particle (str | None) – Name of the particle/track-ID column. Default: case-insensitive auto-detection among particle_id, particle, track_id, track, traj_id; if none is present the table is treated as a single trajectory, and if several are present a ValueError asks for an explicit choice.
time (str | None) – Name of the time column. Default: auto-detection among time_step, frame, time, t (same ambiguity rule). Integer columns are used as frame indices; float columns are factorized into frame indices and, unless t or dt is given, their sorted unique values become the absolute time axis.
coords (Sequence[str] | None) – State-vector column names, in order. Default: every remaining column without an extras prefix (G_, TG_, P_, TP_), in dataframe order. Columns not selected are silently dropped.
dt (float | None) – Time-axis specification, as in from_columns().
t (Any | None) – Time-axis specification, as in from_columns().
extras_global (Mapping[str, Any] | None) – Extra fields merged over any extras parsed from prefixed columns (user values win).
extras_local (Mapping[str, Any] | None) – Extra fields merged over any extras parsed from prefixed columns (user values win).
relabel (bool) – As in from_columns().
compress_particles (bool) – As in from_columns().
meta (Dict[str, Any] | None) – As in from_columns().
weights (str | Sequence[float]) – As in from_columns().
- Return type:
TrajectoryCollection
Examples
>>> coll = TrajectoryCollection.from_dataframe(
...     tracks, particle="track_id", time="frame",
...     coords=("x", "y"), dt=0.05,
... )
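The handling of float time columns described above (factorize into frame indices, keep the sorted unique values as the time axis) can be pictured with NumPy. This is a sketch of the idea using numpy.unique as a stand-in; how from_dataframe does it internally is not stated here, and the column values are made up.

```python
import numpy as np

# Hypothetical float time column, one entry per table row, in arbitrary order.
time_col = np.array([0.10, 0.05, 0.00, 0.05, 0.10, 0.00])

# Sorted unique values form the absolute time axis; the inverse map gives,
# for every row, the frame index into that axis.
t_axis, frame_idx = np.unique(time_col, return_inverse=True)
```

numpy.unique compares floats exactly, so timestamps that differ only by rounding noise would end up as separate frames in this sketch.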
- classmethod from_dataset(ds, *, weights='pool')[source]¶
Wrap a single TrajectoryDataset in a collection.
- Parameters:
ds (TrajectoryDataset) – The dataset to wrap.
weights (str | Sequence[float]) – Initial weight specification; default "pool". See with_weights().
- Returns:
A single-dataset collection with weights computed from ds.
- Return type:
TrajectoryCollection
- iter_slices(*, require, bytes_hint, chunk_target_bytes=67108864, subsampling=1, context=None)[source]¶
Yield chunks as (producer, t_idx) pairs for vmapped integration.
- Parameters:
require (Set[str]) – Set of stream names required by the integrator (e.g. {"X","dX","mask"}). Passed to TrajectoryDataset.valid_indices() and TrajectoryDataset.make_producer().
bytes_hint (int | None) – Approximate per-row memory footprint (in bytes) of the values produced by the program. If None or <= 0, no chunking is performed and all valid indices are yielded at once.
chunk_target_bytes (int) – Target chunk size in bytes. Combined with bytes_hint to determine how many rows to include in each chunk.
subsampling (int) – Optional subsampling factor applied to the time indices before chunking.
context (str | None) – Optional context string passed through to the dataset producer, typically used to switch extra fields.
- Yields:
dict – Mapping with keys:
- Return type:
Iterator[Mapping[str, Any]]
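A plausible reading of how bytes_hint and chunk_target_bytes combine is sketched below: the number of rows per chunk is the target size divided by the per-row footprint. The helper chunk_indices, the floor division, and the minimum of one row per chunk are assumptions of this sketch, not documented behaviour.

```python
import numpy as np

def chunk_indices(t_idx, bytes_hint, chunk_target_bytes=64 * 1024 * 1024):
    """Yield consecutive blocks of time indices sized to a byte budget (sketch)."""
    if bytes_hint is None or bytes_hint <= 0:
        yield t_idx                                  # no chunking: everything at once
        return
    rows = max(1, chunk_target_bytes // bytes_hint)  # rows that fit in one chunk
    for start in range(0, len(t_idx), rows):
        yield t_idx[start:start + rows]

# Ten valid indices, 4 bytes per row, 16-byte budget -> 4 rows per chunk.
chunks = list(chunk_indices(np.arange(10), bytes_hint=4, chunk_target_bytes=16))
chunk_lengths = [len(c) for c in chunks]
```

The default of 64 * 1024 * 1024 equals the 67108864 shown in the signature.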
- classmethod load(path, *, relabel=True, compress_particles=False, particle_column='auto', time_column='auto', state_columns=None)[source]¶
Load a collection from a single file or a directory.
- Parameters:
relabel (bool) – If True, compress particle IDs to 0..N-1 and shift time to start at 0.
compress_particles (bool) – If True, further reduce the column count by merging particles whose time supports do not overlap (greedy interval packing with a 2-frame buffer). Useful for open-boundary systems where particles enter and leave the field of view, causing the naive N to grow as the total number of unique particle IDs rather than the concurrent count. Per-particle extras are reindexed automatically; the mapping is stored in dataset.meta['particle_column_map'].
particle_column (int | str | None), time_column (int | str) – Which columns hold the particle ID and the time index, as a column name (any format) or a positional index (CSV only). "auto" (default) keeps the loader defaults: positional for CSV (column 0 = particle, column 1 = time), and the canonical names "particle_id" / "time_step" for parquet/HDF5. Pass particle_column=None for single-trajectory files.
state_columns (Sequence[int | str] | None) – Optional explicit selection of the state-vector columns (names, or indices for CSV), in order; every other non-extras column is dropped. Default: all non-ID, non-extras columns.
path (str | Path)
- Return type:
TrajectoryCollection
Notes
The default weight policy differs by path: a single-file load uses "Teff" (via from_dataset()); a directory load uses "equal". Call with_weights() after loading if a consistent policy is needed.
- merge(items, *, weights='pool')[source]¶
Combine this collection with others into one collection.
Convenience alias for concat(), useful for assembling an ensemble from several single-trajectory collections (base.merge([c1, c2, ...])). See concat() for the weights policy.
- Parameters:
items (Sequence[TrajectoryCollection | TrajectoryDataset])
weights (str | Sequence[float])
- Return type:
TrajectoryCollection
- peek_X()[source]¶
Convenience helper: peek at the “X” stream. Shape-aligned with the first valid row of “X” from peek_row.
- peek_dX()[source]¶
Convenience helper: peek at the “dX” stream. Shape-aligned with the first valid row of “dX” from peek_row.
- peek_dt()[source]¶
Convenience helper: peek at the “dt” stream. Shape-aligned with the first valid row of “dt” from peek_row.
- peek_mask()[source]¶
Convenience helper: peek at the “mask” stream.
Shape-aligned with the first valid row of “mask” from peek_row.
- peek_row(*, require=frozenset({'X', 'dX'}), context=None)[source]¶
Return a single-t sample row from the first dataset with valid indices.
- Parameters:
require (Set[str]) – Set of stream names required for the sample (as in iter_slices()).
context (str | None) – Optional context string forwarded to the producer.
- Returns:
Structure matching producer(t) for the chosen dataset.
- Return type:
dict
Notes
Useful for memory estimation and debugging program outputs.
- save(path, *, format=None, **format_kw)[source]¶
Save the collection.
Notes
Single file path (.csv/.parquet/.h5): collection must have exactly one dataset.
Directory path: write one file per dataset + manifest.yaml.
Masked samples are dropped (no masked rows written).
No relabeling at save-time; relabeling is handled at load-time.
dynamic_mask is not persisted; after a save/load round-trip it will be None (equivalent to the static mask).
- Parameters:
path (str | Path)
format (str | None)
format_kw (Any)
- Return type:
Path
- split_time(fraction=0.8, *, gap=0, reweight='pool')[source]¶
Split every dataset along time into (train, test) collections.
A side feature for data-abundant scenarios: SFI estimates its own accuracy from the training data (force_predicted_MSE) and validates fits through the diagnostics suite, neither of which costs any data. Hold out a test fraction only when data is plentiful, or to confirm a suspected bias floor with holdout_score().
- Parameters:
fraction (float) – Fraction of frames per dataset assigned to the train half.
gap (int) – Frames dropped between the halves (decorrelation; 0 is safe for increment-based estimators).
reweight ({"pool", "keep"}) – "pool" (default) recomputes per-dataset weights on each half; "keep" carries over the current relative weights.
- Return type:
Tuple[TrajectoryCollection, TrajectoryCollection]
Examples
>>> train, test = coll.split_time(0.8)
>>> inf = OverdampedLangevinInference(train)
>>> # ... fit ...
>>> inf.holdout_score(test)
- to_array(*, axis='time', as_numpy=True)[source]¶
Materialize the whole collection as one dense (T, N, d) array.
Concatenates every dataset along the time axis into a single array of positions. Use this for the legitimate non-plotting reach-ins (disk caching, ensemble bootstrap initial conditions, neighbour lists); for plotting, prefer the toolkit functions in SFI.utils.plotting, and for (t, X, mask) of a single dataset use to_arrays().
- Parameters:
axis (Literal['time']) – Only "time" is supported (axis-0 concatenation).
as_numpy (bool) – If True (default), return a NumPy array; else a JAX array.
- Return type:
ndarray, shape (sum_T, N, d)
- to_arrays(*, dataset=0, as_numpy=True, include_mask=True)[source]¶
Convenience helper: materialize one dataset as dense arrays.
- Parameters:
dataset (int) – Index of the dataset inside the collection (default 0).
as_numpy (bool) – If True, return NumPy arrays.
include_mask (bool) – If True, also return the per-particle mask.
- Returns:
- Return type:
- velocity_array(*, dataset=0, scheme='central', as_numpy=True)[source]¶
Finite-difference velocity v(t) for one dataset.
Reconstructs velocities from stored positions with SFI.utils.maths.fd_velocity(), matching the secant-velocity convention of the underdamped engine. Handy for building (x, v) phase portraits or held-out evaluation grids from position-only recordings.
- Parameters:
dataset (int) – Dataset index inside the collection (default 0).
scheme (Literal['central', 'forward', 'backward']) – Finite-difference stencil; see SFI.utils.maths.fd_velocity().
as_numpy (bool) – If True (default), return a NumPy array; else a JAX array.
- Returns:
v
- Return type:
ndarray, shape (T, N, d)
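For orientation, the three stencils can be written out in NumPy for a uniform step. This is a generic finite-difference sketch, not a copy of SFI.utils.maths.fd_velocity(): the function name fd_velocity_sketch is hypothetical, and filling frames without a complete stencil with NaN is an assumption of the example (the library may treat the boundaries differently).

```python
import numpy as np

def fd_velocity_sketch(X, dt, scheme="central"):
    """Finite-difference velocity on a uniform grid; X has shape (T, N, d)."""
    X = np.asarray(X, dtype=float)
    v = np.full_like(X, np.nan)                  # NaN where the stencil is incomplete
    if scheme == "forward":
        v[:-1] = (X[1:] - X[:-1]) / dt           # (x[t+1] - x[t]) / dt
    elif scheme == "backward":
        v[1:] = (X[1:] - X[:-1]) / dt            # (x[t] - x[t-1]) / dt
    elif scheme == "central":
        v[1:-1] = (X[2:] - X[:-2]) / (2.0 * dt)  # (x[t+1] - x[t-1]) / (2 dt)
    else:
        raise ValueError(f"unknown scheme: {scheme!r}")
    return v

# One particle in 1D following x(t) = t**2, whose exact velocity is 2 t.
dt = 0.1
t = np.arange(5) * dt
X = (t ** 2).reshape(5, 1, 1)
v_central = fd_velocity_sketch(X, dt, scheme="central")
v_forward = fd_velocity_sketch(X, dt, scheme="forward")
```

The central stencil is exact for a quadratic trajectory, while the forward stencil is offset by dt for the same data.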
- weights: Array¶
- with_weights(spec='pool', *, required=frozenset({'X', 'dX'}), subsampling=1)[source]¶
Set the per-dataset weights (an unnormalised multiplier).
- Parameters:
spec (str | Sequence[float]) – Inter-dataset weight policy: a per-dataset multiplier applied to every estimator (force, diffusion, parametric). Accepted values:
- "pool" (default): multiplier 1 for all datasets, i.e. pool every increment on equal footing. Combined with each estimator’s intrinsic within-dataset weighting (force is per-dt, diffusion per-point), this weights each dataset by its effective time (force) or point count (diffusion), which is the natural maximum-likelihood pooling.
- "per_dataset": each dataset contributes equally regardless of length (multiplier mean(Teff)/Teff_d). Exact for the force; for the diffusion it is exact when dt is uniform.
- a sequence of floats: explicit unnormalised multipliers.
required (Set[str]) – Streams used to compute Teff in the "per_dataset" policy. See TrajectoryDataset.Teff().
subsampling (int) – Optional subsampling factor used when counting valid indices.
- Returns:
The same collection with its weights field updated.
- Return type:
TrajectoryCollection
Notes
Weights are exposed to the integrator via the "weight" entry in the payloads yielded by iter_slices() and applied in every reduction (sum and mean). They are deliberately unnormalised: the absolute scale cancels in the mean-reduced estimates, while for the force Gram / covariance it sets the information scale (a single dataset carries unit weight).
- class SFI.trajectory.TrajectoryDataset(X, dt=None, t=None, mask=None, dynamic_mask=None, extras_global=<factory>, extras_local=<factory>, meta=<factory>)[source]¶
Bases: object
Immutable dataset for a single trajectory.
- Parameters:
X (jax.Array) – State array of shape (T, N, d) or (T, d). If (T, d), N is 1.
dt (jax.Array | float | None) – Either a scalar step, an array of shape (T,) (per-step), or None. If None and t is provided, steps are derived from t.
t (jax.Array | None) – Optional absolute time vector of shape (T,). If provided, it defines dt via finite differences when requested.
mask (jax.Array | None) – Optional boolean mask of shape (T, N) or (T,) marking valid observations at time t and particle n (“static mask”). If None, all ones. A True entry means the particle’s position is known and can be used for state evaluation (e.g. neighbor forces).
dynamic_mask (jax.Array | None) – Optional boolean mask of shape (T, N) or (T,) marking entries whose increments are reliable and should contribute to parameter fitting (“dynamic mask”). Must be a subset of mask (dynamic_mask ⊆ mask). If None, defaults to mask. Typical use: particles near open boundaries are statically valid (their positions are known) but dynamically masked (their neighborhoods are incomplete, biasing their increments).
extras_global (Dict[str, Any] | None) – Dict of global extras. Values are static objects, TimeSeriesExtra, or JAX-traceable callables f(t_idx, context=None) -> Array with a leading time axis.
extras_local (Dict[str, Any] | None) – Dict of per-particle extras. Same typing as extras_global. Time-series entries typically have shape (T, N, ...).
meta (Dict[str, Any] | None) – Free-form metadata.
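The relation between the two masks can be made concrete with a small NumPy sketch. The box half-width, the margin, and the dropout below are invented for illustration; only the subset requirement dynamic_mask ⊆ mask comes from the parameter description above.

```python
import numpy as np

rng = np.random.default_rng(0)
T, N, L, margin = 50, 4, 1.0, 0.2            # hypothetical box half-width and margin
X = rng.uniform(-L, L, size=(T, N, 1))       # 1D positions, shape (T, N, d)

mask = np.ones((T, N), dtype=bool)           # static mask: position is known
mask[10:15, 2] = False                       # e.g. a tracking dropout

# Dynamic mask: known positions that also sit away from the open boundary.
interior = np.abs(X[..., 0]) < (L - margin)
dynamic_mask = mask & interior

# Subset check: wherever dynamic_mask is True, mask must be True as well.
is_subset = bool(np.all(mask[dynamic_mask]))
```

Building the dynamic mask as `mask & condition` guarantees the subset property by construction.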
- property N: int¶
- property T: int¶
- Teff(required, *, subsampling=1)[source]¶
Effective exposure time over valid indices for weighting.
Defined as
Teff = sum_t N_active[t] * dt[t],
where N_active[t] is the number of active (unmasked) particles at time index t under the same stream requirements used by the integration runtime.
This reuses the same dt logic as _dt_fields_single() and the same masking logic as _output_mask_single(), so that weighting matches exactly what the runtime sees.
- Parameters:
required (Set[str])
subsampling (int)
- Return type:
float
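The defining sum can be evaluated by hand on a toy mask. The sketch below is a simplification: it counts every frame, whereas the method restricts the sum to indices that are valid for the required streams and honours subsampling. The array values are made up.

```python
import numpy as np

# Hypothetical mask of shape (T=4, N=2) and per-step dt of shape (T,).
mask = np.array([[True, True],
                 [True, False],
                 [False, False],
                 [True, True]])
dt = np.array([0.1, 0.1, 0.2, 0.2])

n_active = mask.sum(axis=1)              # active particles per frame: [2, 1, 0, 2]
teff = float(np.sum(n_active * dt))      # 0.2 + 0.1 + 0.0 + 0.4
```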
- X: Array¶
- build_extras(t_idx, *, dataset_index=0, context=None)[source]¶
Full model-facing extras at t_idx: user values plus reserved keys.
User extras are sliced at the frame(s): a TimeSeriesExtra is indexed, a callable is invoked, and anything else is forwarded. The reserved time / duration / dataset_index / particle_index are resolved for this dataset. This single mapping is what every consumer (inference, simulation, diagnostics) feeds the force/diffusion expression. extras_local overrides extras_global on key conflicts.
- Parameters:
t_idx (Array)
dataset_index (int)
context (str | None)
- Return type:
Dict[str, Any]
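The slicing and merge rules can be mimicked with stand-in classes. Everything in this sketch is hypothetical (TimeSeries, Passthrough, resolve, build_extras_sketch); it leaves out the reserved keys and only mirrors the rules stated in this reference: time series are indexed, plain callables are invoked as f(t_idx, context=None), FunctionExtra-like wrappers pass through unchanged, and local entries win over global ones.

```python
import numpy as np

class TimeSeries:
    """Stand-in for TimeSeriesExtra: data with a leading time axis."""
    def __init__(self, data):
        self.data = np.asarray(data)

class Passthrough:
    """Stand-in for FunctionExtra: forwarded unchanged, never invoked here."""
    def __init__(self, func):
        self.func = func

def resolve(value, t_idx, context=None):
    if isinstance(value, TimeSeries):
        return value.data[t_idx]                 # indexed at the frame
    if isinstance(value, Passthrough):
        return value                             # forwarded as-is
    if callable(value):
        return value(t_idx, context=context)     # invoked as f(t_idx, context)
    return value                                 # static object

def build_extras_sketch(extras_global, extras_local, t_idx, context=None):
    merged = {**extras_global, **extras_local}   # local wins on key conflicts
    return {k: resolve(v, t_idx, context) for k, v in merged.items()}

adhesion = Passthrough(lambda x: x ** 2)
extras = build_extras_sketch(
    {"temperature": TimeSeries([1.0, 2.0, 3.0]),
     "ramp": lambda t_idx, context=None: 10 * t_idx,
     "k": 5.0,
     "adhesion": adhesion},
    {"k": 7.0},
    t_idx=2,
)
```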
- property d: int¶
- dt: Array | float | None = None¶
- dynamic_mask: Array | None = None¶
- extras_global: Dict[str, Any] | None¶
- extras_local: Dict[str, Any] | None¶
- classmethod from_arrays(*, X, dt=None, t=None, mask=None, dynamic_mask=None, extras_global=None, extras_local=None, meta=None)[source]¶
Construct a dataset from array-likes.
All inputs are converted to JAX arrays where relevant; extras and meta are stored as-is (no deep conversion).
- Return type:
TrajectoryDataset
- Raises:
ValueError – If X contains NaN/Inf, has wrong dimensionality, dt <= 0, or the trajectory is too short for any useful computation.
- Parameters:
X (Any)
dt (float | None)
t (Any | None)
mask (Any | None)
dynamic_mask (Any | None)
extras_global (Dict[str, Any] | None)
extras_local (Dict[str, Any] | None)
meta (Dict[str, Any] | None)
- make_batch_producer(require, *, include_mask=True, include_dt=True, context=None, force_dt_keys=None, dataset_index=0)[source]¶
Return a function that gathers a batch of rows in one vectorised pass.
This is the batch counterpart of make_producer(). Instead of building one row at a time (designed for use inside jax.vmap), this function gathers K rows at once using array indexing, producing arrays with a leading K axis.
- Parameters:
require (Set[str]) – Same meaning as in make_producer().
include_mask (bool) – Same meaning as in make_producer().
include_dt (bool) – Same meaning as in make_producer().
context (str | None) – Same meaning as in make_producer().
force_dt_keys (Set[str] | None) – Same meaning as in make_producer().
dataset_index (int)
- Returns:
batch_producer – batch_producer(t_block) with t_block of shape (K,) returns a dict whose arrays have a leading K axis.
- Return type:
Callable[[Array], Dict[str, Any]]
Notes
Extras limitation: time-varying extras (TimeSeriesExtra and callables) are evaluated at t_block[0] only, not at each index in the block. This batch producer is designed for use cases where extras are global constants across the chunk (e.g. static boundary tensors). For per-step time-varying extras, use make_producer() with jax.vmap instead.
- make_producer(require, *, include_mask=True, include_dt=True, context=None, force_dt_keys=None, dataset_index=0)[source]¶
Return a JAX-traceable function that builds a single-t row.
- Parameters:
require (Set[str]) – Set of stream names and the special key "extras" if extras are needed by downstream expressions.
include_mask (bool) – If True, include "mask_out" computed from require.
include_dt (bool) – If True, include "dt" and neighbors when needed.
context (str | None) – Optional string to pass through to extras callables.
force_dt_keys (Set[str] | None) – Extra dt fields to force, e.g. {"dt_plus"}.
dataset_index (int) – Position of this dataset within its collection; resolves the reserved dataset_index extra on every row.
- Returns:
producer – A function such that producer(t) returns a dict whose leaves are single-row arrays. Structure is fixed across calls.
- Return type:
Callable[[Array], Dict[str, Any]]
Notes
Use valid_indices() to generate in-bounds indices. The producer assumes its input is valid and does not clamp.
- mask: Array | None = None¶
- materialize_time(*, as_numpy=True)[source]¶
Return a dense absolute time vector t of shape (T,).
Notes
- If self.t is not None, it is returned as-is.
- Else, if self.dt is a scalar, use t[k] = k * dt.
- Else, if self.dt is an array of shape (T,), interpret dt[k] as the step between X[k] and X[k+1] and build
  t[0] = 0
  t[k+1] = t[k] + dt[k] for k = 0..T-2
  The last entry dt[T-1] (if any) is ignored.
- If both t and dt are None, a ValueError is raised.
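The rules in the Notes translate directly into a few lines of NumPy. The function below is a standalone sketch with a hypothetical name and signature (the method itself takes no such arguments); it is meant only to make the recurrence concrete.

```python
import numpy as np

def materialize_time_sketch(T, t=None, dt=None):
    """Dense time vector of shape (T,) following the rules in the Notes."""
    if t is not None:
        return np.asarray(t, dtype=float)        # returned as-is
    if dt is None:
        raise ValueError("either t or dt must be provided")
    dt = np.asarray(dt, dtype=float)
    if dt.ndim == 0:
        return np.arange(T) * dt                 # t[k] = k * dt
    # t[0] = 0, t[k+1] = t[k] + dt[k]; the last entry dt[T-1] is ignored.
    return np.concatenate([[0.0], np.cumsum(dt[:T - 1])])

t_uniform = materialize_time_sketch(4, dt=0.5)
t_variable = materialize_time_sketch(4, dt=[0.1, 0.2, 0.3, 99.0])
```

The deliberately large last step (99.0) does not appear in the result, matching the statement that dt[T-1] is ignored.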
- meta: Dict[str, Any] | None¶
- split_time(fraction=0.8, *, gap=0)[source]¶
Split into (train, test) datasets along the time axis.
A side feature for data-abundant scenarios: SFI estimates its own accuracy from the training data (force_predicted_MSE) and validates fits through the diagnostics suite, neither of which costs any data. Hold out a test fraction only when data is plentiful, or to confirm a suspected bias floor.
- Parameters:
fraction (float) – Fraction of frames assigned to the train half: train is [0, round(fraction*T)), test is the remainder (after the optional gap).
gap (int) – Number of frames dropped between the halves. 0 is safe for increment-based estimators (the boundary increment belongs to neither half by construction); use a few correlation times for slowly mixing systems.
- Return type:
Tuple[TrajectoryDataset, TrajectoryDataset]
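The frame ranges implied by fraction and gap can be computed as follows. This is a sketch with a hypothetical helper name; it assumes the gap frames are removed from the start of the test half and uses Python's built-in round(), which may not match the library's rounding in half-way cases.

```python
def split_bounds(T, fraction=0.8, gap=0):
    """Half-open frame ranges (start, stop) for the train and test halves."""
    n_train = round(fraction * T)
    train = (0, n_train)                     # frames [0, round(fraction * T))
    test = (min(n_train + gap, T), T)        # remainder, after skipping `gap` frames
    return train, test

train_range, test_range = split_bounds(100, fraction=0.8, gap=5)
```

With T = 100, fraction = 0.8 and gap = 5, frames 80 to 84 belong to neither half.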
- t: Array | None = None¶
- to_arrays(*, as_numpy=True, include_mask=True)[source]¶
Materialize dense trajectory arrays for this dataset.
This is intended for plotting and quick inspection, not for JAX integration (use make_producer() for that).
- Parameters:
as_numpy (bool) – If True (default), return NumPy arrays.
include_mask (bool) – If True (default), return the per-particle validity mask.
- Returns:
t – Absolute time vector of shape (T,); see materialize_time().
X – State tensor of shape (T, N, d).
mask – Boolean mask of shape (T, N) if include_mask is True, otherwise None.
- Return type:
tuple[ndarray, ndarray, ndarray] | tuple[Array, Array, Array | None]
- property uuid: str¶
Stable identity of this dataset.
- valid_indices(required, subsampling=None)[source]¶
Return valid time indices given required streams.
A time index t is valid iff t + amin >= 0 and t + amax <= T-1, where (amin, amax) aggregates all offsets required by streams in required. Extras do not affect the window.
- Parameters:
required (Set[str]) – Set of stream names and possibly "extras".
subsampling (int | None) – Optional positive integer. If provided, keep only indices where t % subsampling == 0 (grid-aligned to multiples of subsampling). This may exclude the first valid index if it is not a multiple of subsampling.
- Returns:
1-D array of valid time indices (dtype=int32), possibly empty.
- Return type:
jax.Array
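The window rule and the grid-aligned subsampling can be sketched in NumPy. The offset table below is invented for the example (including the stream name "dX_minus"); the real per-stream offsets are defined inside the library and are not listed in this reference.

```python
import numpy as np

# Hypothetical (amin, amax) offsets per stream.
OFFSETS = {"X": (0, 0), "mask": (0, 0), "dX": (0, 1), "dX_minus": (-1, 0)}

def valid_indices_sketch(T, required, subsampling=None):
    """Indices t with t + amin >= 0 and t + amax <= T - 1, optionally grid-aligned."""
    offs = [OFFSETS[s] for s in required if s in OFFSETS]  # "extras" adds no offset
    amin = min((lo for lo, _ in offs), default=0)
    amax = max((hi for _, hi in offs), default=0)
    t = np.arange(max(0, -amin), T - max(0, amax), dtype=np.int32)
    if subsampling is not None:
        t = t[t % subsampling == 0]          # keep multiples of `subsampling`
    return t

plain = valid_indices_sketch(6, {"X", "dX"})
strided = valid_indices_sketch(6, {"X", "dX"}, subsampling=2)
shifted = valid_indices_sketch(6, {"X", "dX", "dX_minus", "extras"}, subsampling=2)
```

In the last call the first valid index is 1, which is not a multiple of 2, so the subsampled result starts at 2: the case the subsampling description warns about.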
- SFI.trajectory.function_extra(func)[source]¶
Build a FunctionExtra from a callable.
- Parameters:
func (Callable)
- Return type:
FunctionExtra
- SFI.trajectory.time_series_extra(x)[source]¶
Build a TimeSeriesExtra from an array-like.
- Parameters:
x (Any)
- Return type:
TimeSeriesExtra
Submodules¶
- SFI.trajectory.collection module
- Typical loop
- Weights
- SFI.trajectory.dataset module
- Key APIs
- Streams
- Boundary policy
- SFI.trajectory.degrade module
- SFI.trajectory.io module
- SFI.trajectory.reserved_extras module