SFI.trajectory.collection module¶
Trajectory collection: index-driven streaming over datasets.
This module defines TrajectoryCollection, a thin coordinator that:
- stores multiple TrajectoryDataset objects,
- computes per-dataset weights,
- yields (producer, t_idx_chunk) pairs for vmapped integration.
No chunk heuristics live here. The dataset owns valid windows and single-t row production. The integrator vmaps over integer indices and reduces.
Typical loop¶
>>> coll = TrajectoryCollection.from_dataset(ds).with_weights("pool")
>>> for payload in coll.iter_slices(require=req, bytes_hint=bh, chunk_target_bytes=64<<20):
...     producer = payload["producer"]  # Callable[[t], row]
...     t_idx = payload["t_idx"]        # (K_chunk,)
...     w_ds = payload["weight"]        # dataset scalar weight
...     # integrator: vmap(lambda t: program(**producer(t)))(t_idx)
Weights¶
Per-dataset weights are unnormalised multipliers applied to every estimator (force, diffusion, parametric). Within-dataset weighting is intrinsic to each estimator: the force is per-dt, the diffusion per-point.
- "pool" (default): multiplier 1 for every dataset, pooling all increments on equal footing (each dataset then contributes by its effective time for the force, and by its point count for the diffusion).
- "per_dataset": each dataset contributes equally (multiplier mean(Teff)/Teff_d).
- a sequence of floats: explicit unnormalised multipliers.
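The three policies above can be pictured with a small standalone sketch. It does not call SFI; `resolve_weights` and the `teff` values are hypothetical names and numbers chosen for illustration, and the real `with_weights()` may validate or store its input differently.

```python
from typing import List, Sequence, Union


def resolve_weights(spec: Union[str, Sequence[float]], teff: Sequence[float]) -> List[float]:
    """Return one unnormalised multiplier per dataset (illustrative only)."""
    n = len(teff)
    if spec == "pool":
        # Multiplier 1 everywhere: all increments are pooled on equal footing.
        return [1.0] * n
    if spec == "per_dataset":
        # mean(Teff) / Teff_d: short datasets are boosted, long ones damped.
        mean_teff = sum(teff) / n
        return [mean_teff / t for t in teff]
    # Otherwise treat spec as explicit multipliers, one per dataset.
    weights = [float(w) for w in spec]
    if len(weights) != n:
        raise ValueError("expected one weight per dataset")
    return weights


teff = [1.0, 3.0]  # hypothetical effective times of two datasets
pool = resolve_weights("pool", teff)
balanced = resolve_weights("per_dataset", teff)
# With per-dt weighting inside each dataset, the force contribution of
# dataset d scales as weight_d * Teff_d, which "per_dataset" equalises.
contributions = [w * t for w, t in zip(balanced, teff)]
```

Under these assumptions `balanced` boosts the short dataset and damps the long one so that both entries of `contributions` coincide.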
Notes
No cross-dataset vectorization. A small Python loop over datasets is intended.
bytes_hint is the per-row memory estimate supplied by the integrator.
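As a rough picture of the intended Python loop over datasets, the sketch below reduces stand-in payloads that carry the same three keys as the typical loop above. `make_payload` and `weighted_sum` are hypothetical helpers, and a plain generator stands in for the vmapped evaluation; this is not the integrator's actual code.

```python
from typing import Any, Callable, Dict, Iterable, List


def make_payload(values: List[float], weight: float) -> Dict[str, Any]:
    """Build a stand-in payload with the keys "producer", "t_idx", "weight"."""
    def producer(t: int) -> Dict[str, float]:
        # Single-t row production: one row (here a dict of kwargs) per index.
        return {"x": values[t]}
    return {"producer": producer, "t_idx": list(range(len(values))), "weight": weight}


def weighted_sum(payloads: Iterable[Dict[str, Any]], program: Callable[..., float]) -> float:
    total = 0.0
    for payload in payloads:  # plain Python loop, one pass per chunk
        producer = payload["producer"]
        # Stand-in for vmap(lambda t: program(**producer(t)))(t_idx).
        rows = (program(**producer(t)) for t in payload["t_idx"])
        total += payload["weight"] * sum(rows)
    return total


payloads = [make_payload([1.0, 2.0], 1.0), make_payload([10.0], 0.5)]
result = weighted_sum(payloads, program=lambda x: x * x)
```

Each dataset is reduced on its own and only the scalar weight couples the datasets, which is why no cross-dataset vectorization is needed.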
- class SFI.trajectory.collection.TrajectoryCollection(datasets, weights)[source]¶
Bases: object
Container for one or more trajectories plus per-dataset weights.
This is the main user-facing trajectory object. It wraps a list of TrajectoryDataset instances and exposes an index-based streaming interface used by the integration runtime.
Most users should construct collections via from_arrays(), from_dataset() or load() rather than instantiating this dataclass directly.
- Parameters:
datasets (List[SFI.trajectory.dataset.TrajectoryDataset]) – List of underlying TrajectoryDataset objects. The order is preserved in iteration and determines the ordering of the weights vector.
weights (jax.Array) – 1D JAX array of shape (D,) with non-negative entries, where D = len(datasets). The entries are unnormalised per-dataset multipliers, set by with_weights().
Notes
The collection itself does not impose any chunking heuristic. It only coordinates datasets and their weights; the integrator decides how to vmap over the indices returned by iter_slices().
- Teff(required, *, subsampling=1)[source]¶
Total effective exposure time across all datasets.
This is simply the sum of per-dataset Teff values:
sum_d datasets[d].Teff(required, subsampling=subsampling).
- Parameters:
required (Set[str])
subsampling (int)
- Return type:
float
- concat(items, *, weights='pool')[source]¶
Concatenate this collection with other collections or datasets.
- Parameters:
items (Sequence[TrajectoryCollection | TrajectoryDataset]) – Sequence of TrajectoryCollection or TrajectoryDataset instances. Collections are flattened into their constituent datasets.
weights (str | Sequence[float]) – Weight specification for the concatenated collection. See with_weights() for accepted values.
- Returns:
New collection containing all datasets from self followed by all datasets from items.
- Return type:
TrajectoryCollection
- dataset_index(position)[source]¶
Dense index of dataset position, keyed on its stable identity.
Datasets are numbered by first appearance of their uuid, so the index a force sees (e.g. via per_dataset_scalar() or dataset_indicator()) is tied to the dataset itself, not its slot, and is stable under concatenation and reordering.
- Parameters:
position (int)
- Return type:
int
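The first-appearance numbering can be sketched as follows. This only illustrates the rule over one fixed list of identifiers; `dense_indices` and the `"ds-*"` strings are hypothetical, and how the library tracks uuids across concatenation and reordering is not shown here.

```python
from typing import Dict, Hashable, List, Sequence


def dense_indices(uuids: Sequence[Hashable]) -> List[int]:
    """Number identifiers 0, 1, 2, ... in order of first appearance."""
    first_seen: Dict[Hashable, int] = {}
    out: List[int] = []
    for u in uuids:
        if u not in first_seen:
            # A new identity gets the next free dense index.
            first_seen[u] = len(first_seen)
        out.append(first_seen[u])
    return out


slots = ["ds-a", "ds-b", "ds-a", "ds-c"]  # hypothetical uuids, one per slot
indices = dense_indices(slots)
```

A dataset that occupies two slots (here `"ds-a"`) receives the same index in both, so the index follows the dataset rather than its position.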
- datasets: List[TrajectoryDataset]¶
- degrade(*, downsample=1, motion_blur=0, data_loss_fraction=0.0, noise=None, ROI=None, seed=None, reweight='pool')[source]¶
Return a new degraded collection; the original is not modified.
This is the preferred user-facing API for degrading synthetic trajectories to mimic experimental noise, blur, and data loss.
- Parameters:
downsample (int)
motion_blur (int)
data_loss_fraction (float)
noise (None | float | ndarray)
ROI (None | float | ndarray | Callable[[ndarray], bool])
seed (int | None)
reweight (Literal['pool', 'keep'])
See SFI.trajectory.degrade.degrade_collection() for a full description of each parameter.
- Returns:
New degraded collection. The original collection is not modified.
- Return type:
TrajectoryCollection
- classmethod from_arrays(*, X, dt=None, t=None, mask=None, extras_global=None, extras_local=None, meta=None, weights='pool')[source]¶
Build a single-dataset collection from array-likes.
This is the recommended entry point when you already have tensors in memory.
- Parameters:
X (Any) – State array of shape (T, N, d) or (T, d). If (T, d), a single particle is assumed.
dt (float | None) – Either a scalar step, an array of shape (T,) (per-step), or None. If None and t is provided, effective steps are derived from t on demand.
t (Any | None) – Optional absolute time vector of shape (T,). If provided, it defines time steps when needed.
mask (Any | None) – Optional boolean mask of shape (T, N) or (T,) marking valid observations. If None, all entries are considered valid.
extras_global (Dict[str, Any] | None) – Mapping of global extras. Values can be static objects, TimeSeriesExtra, or JAX-traceable callables f(t_idx, context=None) -> Array with a leading time axis.
extras_local (Dict[str, Any] | None) – Mapping of per-particle extras, with the same typing as extras_global. Time-series entries typically have shape (T, N, ...).
meta (Dict[str, Any] | None) – Free-form metadata dictionary attached to the underlying dataset.
weights (str | Sequence[float]) – Initial weight specification for the resulting collection. See with_weights().
- Returns:
A collection with one dataset built from the provided arrays.
- Return type:
TrajectoryCollection
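When dt is None and t is given, the steps come from the time vector. A minimal sketch of that derivation, assuming plain forward differences, is shown below; `steps_from_times` is a hypothetical helper, and how the library aligns the resulting T-1 increments with the T frames is not specified here.

```python
from typing import List, Sequence


def steps_from_times(t: Sequence[float]) -> List[float]:
    """Forward differences t[i+1] - t[i]; requires strictly increasing times."""
    steps = [t[i + 1] - t[i] for i in range(len(t) - 1)]
    if any(s <= 0 for s in steps):
        raise ValueError("time vector must be strictly increasing")
    return steps


t = [0.0, 0.5, 1.0, 2.0]  # hypothetical, non-uniform sampling times
steps = steps_from_times(t)
```

With a non-uniform time vector like this one, a single scalar dt would be wrong for the last increment, which is the case the t argument covers.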
- classmethod from_columns(particle_idx, time_idx, state_vectors, *, extras_global=None, extras_local=None, dt=None, t=None, relabel=True, compress_particles=False, meta=None, weights='pool')[source]¶
Build a single-dataset collection from flat (particle, time) columns.
This constructor is convenient when reading trajectories from a tabular format or a custom pipeline.
- Parameters:
particle_idx (ndarray) – Integer array of shape (L,) with particle IDs for each row.
time_idx (ndarray) – Integer array of shape (L,) with time indices t for each row.
state_vectors (ndarray) – Array of shape (L, d) with state vectors.
extras_global (Mapping[str, Any] | None) – Parsed global extras (e.g. from YAML header), as described in SFI.trajectory.io.
extras_local (Mapping[str, Any] | None) – Parsed local extras, including time-series extras, as described in SFI.trajectory.io.
dt (float | None) – Optional scalar step; used only if no absolute time axis is provided via t or extras_global['t'].
t (ndarray | None) – Optional time vector of shape (T,) overriding any time axis inferred from extras.
relabel (bool) – If True, compress particle IDs to 0..N-1 and shift time to start at 0.
compress_particles (bool) – If True, apply greedy interval packing to reduce the column count by merging particles whose time supports do not overlap (with a 2-frame buffer). Per-particle extras are reindexed automatically. The mapping is stored in dataset.meta['particle_column_map'].
meta (Dict[str, Any] | None) – Metadata dictionary to attach to the dataset.
weights (str | Sequence[float]) – Initial weight specification for the resulting collection.
- Returns:
A collection with one dataset assembled from the columns.
- Return type:
TrajectoryCollection
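The greedy interval packing behind compress_particles can be pictured with the sketch below. It is an illustration under stated assumptions, not the library routine: `pack_particles` is a hypothetical name, particles are visited by first frame, and "2-frame buffer" is read as requiring more than two frames between the end of one occupant and the start of the next.

```python
from typing import Dict, List, Tuple


def pack_particles(supports: Dict[int, Tuple[int, int]], buffer: int = 2) -> Dict[int, int]:
    """Map particle ID -> column, reusing a column once it has been free
    for more than `buffer` frames (assumed reading of the buffer rule)."""
    column_end: List[int] = []  # last occupied frame of each column
    mapping: Dict[int, int] = {}
    # Visit particles in order of their first observed frame.
    for pid, (first, last) in sorted(supports.items(), key=lambda kv: kv[1][0]):
        for col, end in enumerate(column_end):
            if first - end > buffer:
                # First column that is free long enough: reuse it.
                column_end[col] = last
                mapping[pid] = col
                break
        else:
            # No reusable column: open a new one.
            column_end.append(last)
            mapping[pid] = len(column_end) - 1
    return mapping


# Hypothetical (first_frame, last_frame) supports keyed by particle ID.
supports = {10: (0, 4), 11: (2, 9), 12: (7, 12), 13: (5, 6)}
column_map = pack_particles(supports)
n_columns = len(set(column_map.values()))
```

Here particle 12 starts three frames after particle 10 ends and so shares its column, while particle 13 starts too soon after particle 10 and needs a column of its own.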
- classmethod from_dataframe(df, *, particle=None, time=None, coords=None, dt=None, t=None, extras_global=None, extras_local=None, relabel=True, compress_particles=False, meta=None, weights='pool')[source]¶
Build a single-dataset collection from a pandas DataFrame.
The natural entry point for raw tracking tables (trackpy, TrackMate, custom pipelines): columns are addressed by name, in any order, and junk columns are dropped.
- Parameters:
df – A pandas DataFrame with one row per (particle, time) observation.
particle (str | None) – Name of the particle/track-ID column. Default: case-insensitive auto-detection among particle_id, particle, track_id, track, traj_id; if none is present the table is treated as a single trajectory, and if several are present a ValueError asks for an explicit choice.
time (str | None) – Name of the time column. Default: auto-detection among time_step, frame, time, t (same ambiguity rule). Integer columns are used as frame indices; float columns are factorized into frame indices and, unless t or dt is given, their sorted unique values become the absolute time axis.
coords (Sequence[str] | None) – State-vector column names, in order. Default: every remaining column without an extras prefix (G_, TG_, P_, TP_), in dataframe order. Columns not selected are silently dropped.
dt (float | None) – Time-axis specification, as in from_columns().
t (Any | None) – Time-axis specification, as in from_columns().
extras_global (Mapping[str, Any] | None) – Extra fields merged over any extras parsed from prefixed columns (user values win).
extras_local (Mapping[str, Any] | None) – Extra fields merged over any extras parsed from prefixed columns (user values win).
relabel (bool) – As in from_columns().
compress_particles (bool) – As in from_columns().
meta (Dict[str, Any] | None) – As in from_columns().
weights (str | Sequence[float]) – As in from_columns().
- Return type:
TrajectoryCollection
Examples
>>> coll = TrajectoryCollection.from_dataframe(
...     tracks, particle="track_id", time="frame",
...     coords=("x", "y"), dt=0.05,
... )
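The rule for float time columns (factorize into frame indices, keep the sorted unique values as the time axis) can be sketched in plain Python. `factorize_times` is a hypothetical helper written for illustration; the library may handle rounding or tolerance on float times differently.

```python
from typing import List, Sequence, Tuple


def factorize_times(times: Sequence[float]) -> Tuple[List[int], List[float]]:
    """Return (frame index per row, sorted unique time axis)."""
    axis = sorted(set(times))
    index_of = {value: i for i, value in enumerate(axis)}
    # Each row is mapped to the position of its time in the sorted axis.
    frames = [index_of[value] for value in times]
    return frames, axis


times = [0.10, 0.00, 0.05, 0.10, 0.05]  # hypothetical float time column
frames, t_axis = factorize_times(times)
```

Rows may arrive in any order; the frame index depends only on the rank of each time value, not on its position in the table.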
- classmethod from_dataset(ds, *, weights='pool')[source]¶
Wrap a single TrajectoryDataset in a collection.
- Parameters:
ds (TrajectoryDataset) – The dataset to wrap.
weights (str | Sequence[float]) – Initial weight specification; default "pool". See with_weights().
- Returns:
A single-dataset collection with weights computed from ds.
- Return type:
TrajectoryCollection
- iter_slices(*, require, bytes_hint, chunk_target_bytes=67108864, subsampling=1, context=None)[source]¶
Yield chunks as (producer, t_idx) pairs for vmapped integration.
- Parameters:
require (Set[str]) – Set of stream names required by the integrator (e.g. {"X","dX","mask"}). Passed to TrajectoryDataset.valid_indices() and TrajectoryDataset.make_producer().
bytes_hint (int | None) – Approximate per-row memory footprint (in bytes) of the values produced by the program. If None or <= 0, no chunking is performed and all valid indices are yielded at once.
chunk_target_bytes (int) – Target chunk size in bytes. Combined with bytes_hint to determine how many rows to include in each chunk.
subsampling (int) – Optional subsampling factor applied to the time indices before chunking.
context (str | None) – Optional context string passed through to the dataset producer, typically used to switch extra fields.
- Yields:
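dict – Mapping with keys including "producer" (callable mapping a time index t to a row), "t_idx" (integer indices of the chunk, shape (K_chunk,)) and "weight" (scalar weight of the dataset), as used in the typical loop at the top of this page.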
- Return type:
Iterator[Mapping[str, Any]]
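How bytes_hint and chunk_target_bytes might combine into a chunk length can be sketched as below. The integer division and the floor of one row per chunk are assumptions made for illustration; `chunk_indices` is a hypothetical helper, not the method's actual implementation.

```python
from typing import Iterable, Iterator, List, Optional


def chunk_indices(
    t_idx: Iterable[int],
    bytes_hint: Optional[int],
    chunk_target_bytes: int = 64 << 20,
) -> Iterator[List[int]]:
    """Split valid time indices into chunks of roughly chunk_target_bytes."""
    idx = list(t_idx)
    if bytes_hint is None or bytes_hint <= 0:
        # No usable estimate: yield everything at once, as documented.
        yield idx
        return
    # Assumed rule: as many rows as fit in the target, but at least one.
    rows = max(1, chunk_target_bytes // bytes_hint)
    for start in range(0, len(idx), rows):
        yield idx[start:start + rows]


# 16 MiB per row against the 64 MiB default gives 4 rows per chunk.
chunks = list(chunk_indices(range(10), bytes_hint=16 << 20))
```

The last chunk is simply shorter when the number of valid indices is not a multiple of the chunk length.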
- classmethod load(path, *, relabel=True, compress_particles=False, particle_column='auto', time_column='auto', state_columns=None)[source]¶
Load a collection from a single file or a directory.
- Parameters:
relabel (bool) – If True, compress particle IDs to 0..N-1 and shift time to start at 0.
compress_particles (bool) – If True, further reduce the column count by merging particles whose time supports do not overlap (greedy interval packing with a 2-frame buffer). Useful for open-boundary systems where particles enter and leave the field of view, causing the naive N to grow as the total number of unique particle IDs rather than the concurrent count. Per-particle extras are reindexed automatically; the mapping is stored in dataset.meta['particle_column_map'].
particle_column (int | str | None), time_column (int | str) – Which columns hold the particle ID and the time index, as a column name (any format) or a positional index (CSV only). "auto" (default) keeps the loader defaults: positional for CSV (column 0 = particle, column 1 = time), and the canonical names "particle_id" / "time_step" for parquet/HDF5. Pass particle_column=None for single-trajectory files.
state_columns (Sequence[int | str] | None) – Optional explicit selection of the state-vector columns (names, or indices for CSV), in order; every other non-extras column is dropped. Default: all non-ID, non-extras columns.
path (str | Path)
- Return type:
TrajectoryCollection
Notes
The default weight policy differs by path: a single-file load uses the from_dataset() default; a directory load uses "equal". Call with_weights() after loading if a consistent policy is needed.
- merge(items, *, weights='pool')[source]¶
Combine this collection with others into one collection.
Convenience alias for concat(), useful for assembling an ensemble from several single-trajectory collections (base.merge([c1, c2, ...])). See concat() for the weights policy.
- Parameters:
items (Sequence[TrajectoryCollection | TrajectoryDataset])
weights (str | Sequence[float])
- Return type:
TrajectoryCollection
- peek_X()[source]¶
Convenience helper: peek at the “X” stream. Shape-aligned with the first valid row of “X” from peek_row.
- peek_dX()[source]¶
Convenience helper: peek at the “dX” stream. Shape-aligned with the first valid row of “dX” from peek_row.
- peek_dt()[source]¶
Convenience helper: peek at the “dt” stream. Shape-aligned with the first valid row of “dt” from peek_row.
- peek_mask()[source]¶
Convenience helper: peek at the “mask” stream.
Shape-aligned with the first valid row of “mask” from peek_row.
- peek_row(*, require=frozenset({'X', 'dX'}), context=None)[source]¶
Return a single-t sample row from the first dataset with valid indices.
- Parameters:
require (Set[str]) – Set of stream names required for the sample (as in iter_slices()).
context (str | None) – Optional context string forwarded to the producer.
- Returns:
Structure matching producer(t) for the chosen dataset.
- Return type:
dict
Notes
Useful for memory estimation and debugging program outputs.
- save(path, *, format=None, **format_kw)[source]¶
Save the collection.
Notes
Single file path (.csv/.parquet/.h5): collection must have exactly one dataset.
Directory path: write one file per dataset + manifest.yaml.
Masked samples are dropped (no masked rows written).
No relabeling at save-time; relabeling is handled at load-time.
dynamic_mask is not persisted; after a save/load round-trip it will be None (equivalent to the static mask).
- Parameters:
path (str | Path)
format (str | None)
format_kw (Any)
- Return type:
Path
- split_time(fraction=0.8, *, gap=0, reweight='pool')[source]¶
Split every dataset along time into (train, test) collections.
A side feature for data-abundant scenarios: SFI estimates its own accuracy from the training data (force_predicted_MSE) and validates fits through the diagnostics suite, neither of which costs any data. Hold out a test fraction only when data is plentiful, or to confirm a suspected bias floor with holdout_score().
- Parameters:
fraction (float) – Fraction of frames per dataset assigned to the train half.
gap (int) – Frames dropped between the halves (decorrelation; 0 is safe for increment-based estimators).
reweight ({"pool", "keep"}) – "pool" (default) recomputes per-dataset weights on each half; "keep" carries over the current relative weights.
- Return type:
(train, test) pair of TrajectoryCollection objects
Examples
>>> train, test = coll.split_time(0.8)
>>> inf = OverdampedLangevinInference(train)
>>> # ... fit ...
>>> inf.holdout_score(test)
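The roles of fraction and gap can be illustrated on bare frame indices. The sketch assumes that the train half comes first in time, that the train length is rounded down, and that the gap frames are taken from the start of the test half; `split_frames` is a hypothetical helper and the method itself may round or place the gap differently.

```python
from typing import Tuple


def split_frames(n_frames: int, fraction: float = 0.8, gap: int = 0) -> Tuple[range, range]:
    """Return (train, test) frame ranges for one dataset of n_frames frames."""
    if not 0.0 < fraction < 1.0:
        raise ValueError("fraction must lie strictly between 0 and 1")
    if gap < 0:
        raise ValueError("gap must be non-negative")
    n_train = int(n_frames * fraction)  # assumed rounding: floor
    train = range(0, n_train)
    # Frames n_train .. n_train + gap - 1 are dropped for decorrelation.
    test = range(min(n_train + gap, n_frames), n_frames)
    return train, test


train_frames, test_frames = split_frames(10, fraction=0.8, gap=1)
```

With ten frames, a fraction of 0.8 and a gap of one frame, frame 8 is discarded and only frame 9 remains for testing.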
- to_array(*, axis='time', as_numpy=True)[source]¶
Materialize the whole collection as one dense (T, N, d) array.
Concatenates every dataset along the time axis into a single array of positions. Use this for the legitimate non-plotting reach-ins (disk caching, ensemble bootstrap initial conditions, neighbour lists); for plotting, prefer the toolkit functions in SFI.utils.plotting, and for (t, X, mask) of a single dataset use to_arrays().
- Parameters:
axis (Literal['time']) – Only "time" is supported (axis-0 concatenation).
as_numpy (bool) – If True (default), return a NumPy array; else a JAX array.
- Return type:
ndarray, shape (sum_T, N, d)
- to_arrays(*, dataset=0, as_numpy=True, include_mask=True)[source]¶
Convenience helper: materialize one dataset as dense arrays.
- velocity_array(*, dataset=0, scheme='central', as_numpy=True)[source]¶
Finite-difference velocity v(t) for one dataset.
Reconstructs velocities from stored positions with SFI.utils.maths.fd_velocity(), matching the secant-velocity convention of the underdamped engine. Handy for building (x, v) phase portraits or held-out evaluation grids from position-only recordings.
- Parameters:
dataset (int) – Dataset index inside the collection (default 0).
scheme (Literal['central', 'forward', 'backward']) – Finite-difference stencil; see SFI.utils.maths.fd_velocity().
as_numpy (bool) – If True (default), return a NumPy array; else a JAX array.
- Returns:
v
- Return type:
ndarray, shape (T, N, d)
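For readers unfamiliar with the three stencils, the sketch below applies them to a scalar time series with uniform dt. It is not SFI.utils.maths.fd_velocity(): `fd_velocity_1d` is a hypothetical helper, and falling back to a one-sided difference at the endpoints is an assumption made so that the output keeps the input length.

```python
from typing import List, Sequence


def fd_velocity_1d(x: Sequence[float], dt: float, scheme: str = "central") -> List[float]:
    """Finite-difference velocity of a 1D series sampled every dt."""
    n = len(x)
    if n < 2:
        raise ValueError("need at least two samples")
    v: List[float] = []
    for i in range(n):
        if scheme == "central":
            # (x[i+1] - x[i-1]) / (2 dt), clipped to one-sided at the ends.
            lo, hi = max(i - 1, 0), min(i + 1, n - 1)
        elif scheme == "forward":
            lo, hi = (i, i + 1) if i + 1 < n else (i - 1, i)
        elif scheme == "backward":
            lo, hi = (i - 1, i) if i > 0 else (i, i + 1)
        else:
            raise ValueError(f"unknown scheme: {scheme!r}")
        v.append((x[hi] - x[lo]) / ((hi - lo) * dt))
    return v


x = [0.0, 1.0, 4.0, 9.0]  # x = t**2 sampled at t = 0, 1, 2, 3
v_central = fd_velocity_1d(x, dt=1.0, scheme="central")
v_forward = fd_velocity_1d(x, dt=1.0, scheme="forward")
v_backward = fd_velocity_1d(x, dt=1.0, scheme="backward")
```

For this quadratic test signal the central stencil recovers the exact derivative 2t at the interior points, while the one-sided stencils are off by dt.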
- weights: Array¶
- with_weights(spec='pool', *, required=frozenset({'X', 'dX'}), subsampling=1)[source]¶
Set the per-dataset weights (an unnormalised multiplier).
- Parameters:
spec (str | Sequence[float]) – Inter-dataset weight policy: a per-dataset multiplier applied to every estimator (force, diffusion, parametric). Accepted values:
- "pool" (default): multiplier 1 for all datasets, i.e. pool every increment on equal footing. Combined with each estimator's intrinsic within-dataset weighting (force is per-dt, diffusion per-point), this weights each dataset by its effective time (force) or point count (diffusion), which is the natural maximum-likelihood pooling.
- "per_dataset": each dataset contributes equally regardless of length (multiplier mean(Teff)/Teff_d). Exact for the force; for the diffusion it is exact when dt is uniform.
- a sequence of floats: explicit unnormalised multipliers.
required (Set[str]) – Streams used to compute Teff in the "per_dataset" policy. See TrajectoryDataset.Teff().
subsampling (int) – Optional subsampling factor used when counting valid indices.
- Returns:
The same collection with its weights field updated.
- Return type:
TrajectoryCollection
Notes
Weights are exposed to the integrator via the "weight" entry in the payloads yielded by iter_slices() and applied in every reduction (sum and mean). They are deliberately unnormalised: the absolute scale cancels in the mean-reduced estimates, while for the force Gram / covariance it sets the information scale (a single dataset carries unit weight).