# collection.py
"""
Trajectory collection: index-driven streaming over datasets.
This module defines :class:`TrajectoryCollection`, a thin coordinator that:
- stores multiple :class:`TrajectoryDataset` objects,
- computes per-dataset weights,
- yields **(producer, t_idx_chunk)** pairs for vmapped integration.
No chunk heuristics live here. The dataset owns valid windows and single-t row
production. The integrator vmaps over integer indices and reduces.
Typical loop
------------
>>> coll = TrajectoryCollection.from_dataset(ds).with_weights("pool")
>>> for payload in coll.iter_slices(require=req, bytes_hint=bh, chunk_target_bytes=64<<20):
... producer = payload["producer"] # Callable[[t], row]
... t_idx = payload["t_idx"] # (K_chunk,)
... w_ds = payload["weight"] # dataset scalar weight
... # integrator: vmap(lambda t: program(**producer(t)))(t_idx)
Weights
-------
Per-dataset weights are **unnormalised** multipliers applied to every
estimator (force, diffusion, parametric). Within-dataset weighting is
intrinsic to each estimator: the force is per-dt, the diffusion per-point.
- "pool" (default): multiplier 1 for every dataset — pool all increments on
equal footing (each dataset then contributes by its effective time for the
force, by its point count for the diffusion).
- "per_dataset": each dataset contributes equally (multiplier mean(Teff)/Teff_d).
- a sequence of floats: explicit unnormalised multipliers.
Notes
-----
- No cross-dataset vectorization. A small Python loop over datasets is intended.
- `bytes_hint` is the per-row memory estimate supplied by the integrator.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Union,
)
import jax.numpy as jnp
import numpy as np
from SFI.trajectory.dataset import TimeSeriesExtra, TrajectoryDataset, time_series_extra
from .io import (
_parse_tabular_with_extras,
columns_and_extras_to_dataset,
flatten_X_to_columns,
load_trajectory,
save_trajectory,
)
WeightSpec = Union[str, Sequence[float]]
def _is_single_file(path: Union[str, Path]) -> bool:
    """Tell whether ``path`` names a single trajectory file.

    The decision is purely lexical: the (case-insensitive) suffix of
    ``path`` is compared against the recognised single-file extensions.
    The filesystem is never consulted.
    """
    single_file_suffixes = (".csv", ".parquet", ".pq", ".h5", ".hdf5")
    suffix = Path(path).suffix.lower()
    return suffix in single_file_suffixes
[docs]
@dataclass
class TrajectoryCollection:
"""
Container for one or more trajectories plus per-dataset weights.
This is the main user-facing trajectory object. It wraps a list of
:class:`TrajectoryDataset` instances and exposes an index-based streaming
interface used by the integration runtime.
Most users should construct collections via :meth:`from_arrays`,
:meth:`from_dataset` or :meth:`load` rather than instantiating this
dataclass directly.
Parameters
----------
datasets
List of underlying :class:`TrajectoryDataset` objects. The order is
preserved in iteration and determines the ordering of the ``weights``
vector.
weights
1D JAX array of shape ``(D,)`` holding one multiplier per dataset, where
``D = len(datasets)``. The values are **unnormalised**:
:meth:`with_weights` stores the multipliers it computes (or is given)
as-is and does not rescale them to sum to 1.
Notes
-----
The collection itself does not impose any chunking heuristic. It only
coordinates datasets and their weights; the integrator decides how to
vmap over the indices returned by :meth:`iter_slices`.
"""
datasets: List[TrajectoryDataset]
weights: jnp.ndarray # shape (D,), unnormalised per-dataset multipliers
# ---------- construction ----------
[docs]
@classmethod
def from_dataset(cls, ds: TrajectoryDataset, *, weights: WeightSpec = "pool") -> "TrajectoryCollection":
    """Build a one-dataset collection around ``ds``.

    Parameters
    ----------
    ds
        The dataset to wrap.
    weights
        Weight specification forwarded to :meth:`with_weights`;
        default ``"pool"``.

    Returns
    -------
    TrajectoryCollection
        A single-dataset collection whose weights were set by
        :meth:`with_weights`.
    """
    # Unit placeholder only: with_weights() overwrites it right away.
    placeholder = jnp.array([1.0], dtype=jnp.float32)
    return cls([ds], placeholder).with_weights(weights)
[docs]
def concat(
    self,
    items: Sequence[Union["TrajectoryCollection", TrajectoryDataset]],
    *,
    weights: WeightSpec = "pool",
) -> "TrajectoryCollection":
    """
    Append other collections or datasets to this collection's datasets.

    Parameters
    ----------
    items
        Sequence of :class:`TrajectoryCollection` or
        :class:`TrajectoryDataset` instances. Collections contribute
        their constituent datasets (they are flattened); anything else
        is appended as a single dataset.
    weights
        Weight specification for the result, forwarded to
        :meth:`with_weights`.

    Returns
    -------
    TrajectoryCollection
        New collection holding the datasets of ``self`` followed by
        those of ``items``, in order.
    """
    combined: List[TrajectoryDataset] = list(self.datasets)
    for entry in items:
        if not isinstance(entry, TrajectoryCollection):
            combined.append(entry)
            continue
        combined.extend(entry.datasets)
    # Unit placeholder weights; the requested policy is applied just below.
    unit = jnp.ones((len(combined),), dtype=jnp.float32)
    return TrajectoryCollection(combined, unit).with_weights(weights)
def __and__(
    self, other: Union["TrajectoryCollection", TrajectoryDataset]
) -> "TrajectoryCollection":
    """Implement ``self & other`` as a concatenation with default weights.

    ``c1 & c2`` is ``c1.concat([c2])``: the datasets of ``other`` are
    appended to those of ``self`` under :meth:`concat`'s default
    ``"pool"`` policy. The operator chains (``c1 & c2 & c3``). For the
    ``"per_dataset"`` policy or explicit weights, call :meth:`concat`
    or :meth:`with_weights` instead.
    """
    appended = [other]
    return self.concat(appended)
# ---------- weighting ----------
[docs]
def with_weights(
    self,
    spec: WeightSpec = "pool",
    *,
    required: Set[str] = frozenset({"X", "dX"}),
    subsampling: int = 1,
) -> "TrajectoryCollection":
    """
    Assign the per-dataset weights (**unnormalised** multipliers).

    Parameters
    ----------
    spec
        Inter-dataset weight policy: a per-dataset multiplier applied to
        every estimator (force, diffusion, parametric). Accepted values:

        - ``"pool"`` (default): multiplier ``1`` for all datasets, i.e.
          pool every increment on equal footing. Combined with each
          estimator's intrinsic within-dataset weighting (force is
          per-dt, diffusion per-point), this weights each dataset by its
          effective time (force) or point count (diffusion), which is
          the natural maximum-likelihood pooling.
        - ``"per_dataset"``: each dataset contributes equally regardless
          of length (multiplier ``mean(Teff)/Teff_d``, the mean taken
          over datasets with positive ``Teff``; datasets with
          non-positive ``Teff`` get multiplier ``0``). Exact for the
          force; for the diffusion it is exact when ``dt`` is uniform.
        - a sequence of floats: explicit unnormalised multipliers, one
          per dataset.
    required
        Streams used to compute ``Teff`` in the ``"per_dataset"`` policy.
        See :meth:`TrajectoryDataset.Teff`.
    subsampling
        Subsampling factor forwarded to ``Teff`` in the
        ``"per_dataset"`` policy.

    Returns
    -------
    TrajectoryCollection
        ``self``: the collection is updated in place and returned to
        allow chaining.

    Raises
    ------
    ValueError
        If ``spec`` is an unknown policy string, or if the resulting
        weight vector does not have shape ``(D,)``.

    Notes
    -----
    Weights are exposed to the integrator via the ``"weight"`` entry in
    the payloads yielded by :meth:`iter_slices` and applied in every
    reduction (sum and mean). They are deliberately **unnormalised**:
    the absolute scale cancels in the mean-reduced estimates, while for
    the force Gram / covariance it sets the information scale (a single
    dataset carries unit weight).
    """
    n_datasets = len(self.datasets)
    if not isinstance(spec, str):
        multipliers = jnp.array(spec, dtype=jnp.float32)
    elif spec == "pool":
        multipliers = jnp.ones((n_datasets,), dtype=jnp.float32)
    elif spec == "per_dataset":
        exposure = [float(ds.Teff(required, subsampling=subsampling)) for ds in self.datasets]
        positive = [t for t in exposure if t > 0]
        # Fall back to 1.0 when no dataset has positive exposure; every
        # multiplier is 0 in that case anyway.
        reference = (sum(positive) / len(positive)) if positive else 1.0
        multipliers = jnp.array(
            [(reference / t) if t > 0 else 0.0 for t in exposure],
            dtype=jnp.float32,
        )
    else:
        raise ValueError(
            f"unknown weight policy {spec!r}; use 'pool', 'per_dataset', "
            "or an explicit per-dataset multiplier array."
        )
    if multipliers.shape != (n_datasets,):
        raise ValueError(f"weights length mismatch: got {multipliers.shape}, expected {(n_datasets,)}")
    # Stored as-is: these are relative multipliers and are not normalised.
    self.weights = multipliers
    return self
# ---------- streaming ----------
[docs]
def iter_slices(
    self,
    *,
    require: Set[str],
    bytes_hint: Optional[int],
    chunk_target_bytes: int = 64 * 1024**2,
    subsampling: int = 1,
    context: Optional[str] = None,
) -> Iterator[Mapping[str, Any]]:
    """
    Yield chunks as (producer, t_idx) payloads for vmapped integration.

    Parameters
    ----------
    require
        Set of stream names required by the integrator (e.g.
        ``{"X","dX","mask"}``). Passed to
        :meth:`TrajectoryDataset.valid_indices` and
        :meth:`TrajectoryDataset.make_producer`.
    bytes_hint
        Approximate per-row memory footprint (in bytes) of the values
        produced by the program. If ``None`` or ``<= 0``, no chunking is
        performed and all valid indices of a dataset are yielded at
        once. A positive hint below one byte is treated as one byte.
    chunk_target_bytes
        Target chunk size in bytes. Combined with ``bytes_hint`` to
        determine how many rows to include in each chunk (always at
        least one).
    subsampling
        Subsampling factor forwarded to
        :meth:`TrajectoryDataset.valid_indices`.
    context
        Optional context string passed through to the dataset producer,
        typically used to switch extra fields.

    Yields
    ------
    dict
        Mapping with keys:

        - ``"producer"``: ``Callable[[jax.Array], dict]``, single-t row
          builder.
        - ``"t_idx"``: 1D array of integer time indices.
        - ``"dataset_index"``: position of the underlying dataset in
          :attr:`datasets`.
        - ``"weight"``: float dataset weight, taken from
          :attr:`weights`.

    Raises
    ------
    ValueError
        If ``subsampling`` is not positive. Because this is a generator,
        the error is raised when iteration starts, not at call time.

    Notes
    -----
    Datasets without any valid index are skipped silently.
    """
    if subsampling <= 0:
        raise ValueError("subsampling must be a positive integer")

    unchunked = (not bytes_hint) or bytes_hint <= 0
    # A positive hint below 1 would truncate to 0 and make the floor
    # division below raise ZeroDivisionError; clamp it to one byte.
    row_bytes = None if unchunked else max(1, int(bytes_hint))

    for ds_idx, ds in enumerate(self.datasets):
        base_idx = ds.valid_indices(require, subsampling=subsampling)
        if base_idx.size == 0:
            continue
        producer = ds.make_producer(
            require,
            include_mask=True,
            include_dt=True,
            context=context,
            force_dt_keys={"dt"},
            dataset_index=self.dataset_index(ds_idx),
        )
        # Convert the weight once per dataset rather than once per chunk.
        weight = float(self.weights[ds_idx])

        if unchunked:
            yield {
                "producer": producer,
                "t_idx": base_idx,
                "dataset_index": ds_idx,
                "weight": weight,
            }
            continue

        total = int(base_idx.size)
        rows_per_chunk = min(total, max(1, int(chunk_target_bytes // row_bytes)))
        for start in range(0, total, rows_per_chunk):
            yield {
                "producer": producer,
                "t_idx": base_idx[start : start + rows_per_chunk],
                "dataset_index": ds_idx,
                "weight": weight,
            }
[docs]
def peek_row(
    self,
    *,
    require: Set[str] = frozenset({"X", "dX"}),
    context: Optional[str] = None,
) -> Mapping[str, Any]:
    """
    Produce one sample row from the first dataset that has valid indices.

    Parameters
    ----------
    require
        Set of stream names required for the sample (as in
        :meth:`iter_slices`).
    context
        Optional context string forwarded to the producer.

    Returns
    -------
    dict
        Whatever ``producer(t)`` returns for the chosen dataset,
        evaluated at that dataset's first valid index.

    Raises
    ------
    ValueError
        If no dataset has a valid index for ``require``.

    Notes
    -----
    Useful for memory estimation and debugging program outputs.
    """
    for position, dataset in enumerate(self.datasets):
        candidates = dataset.valid_indices(require)
        if candidates.size != 0:
            first_t = candidates[0]
            build_row = dataset.make_producer(
                require,
                include_mask=True,
                include_dt=True,
                context=context,
                force_dt_keys={"dt"},
                dataset_index=self.dataset_index(position),
            )
            return build_row(first_t)
    raise ValueError("peek_row: no dataset has valid indices for the requested streams.")
[docs]
def peek_X(self):
    """Return the ``"X"`` entry of a sample row.

    Shorthand for ``self.peek_row(require={"X"})["X"]``.
    """
    return self.peek_row(require={"X"})["X"]
[docs]
def peek_dX(self):
    """Return the ``"dX"`` entry of a sample row.

    Shorthand for ``self.peek_row(require={"dX"})["dX"]``.
    """
    return self.peek_row(require={"dX"})["dX"]
[docs]
def peek_mask(self):
    """Return the ``"mask"`` entry of a sample row.

    Shorthand for ``self.peek_row(require={"mask"})["mask"]``.
    """
    return self.peek_row(require={"mask"})["mask"]
[docs]
def peek_dt(self):
    """Return the ``"dt"`` entry of a sample row.

    Shorthand for ``self.peek_row(require={"dt"})["dt"]``.
    """
    return self.peek_row(require={"dt"})["dt"]
# ---------- aggregate Teff over datasets ---------- #
[docs]
def Teff(self, required: Set[str], *, subsampling: int = 1) -> float:
    """Total effective exposure time, summed over all datasets.

    Each dataset is asked for ``Teff(required, subsampling=subsampling)``
    and the results are added up; an empty collection yields ``0.0``.
    """
    total = 0
    for ds in self.datasets:
        total = total + ds.Teff(required, subsampling=subsampling)
    return float(total)
# ---------- persistence API ----------
[docs]
def save(
    self,
    path: Union[str, Path],
    *,
    format: Optional[str] = None,
    **format_kw: Any,
) -> Path:
    """
    Write the collection to disk and return the destination path.

    Rules
    -----
    - Single file path (.csv/.parquet/.pq/.h5/.hdf5): the collection
      must hold exactly one dataset; ``format`` is passed to the writer
      unchanged.
    - Any other path is treated as a directory: one ``ds_NNN.<ext>``
      file per dataset plus a ``manifest.yaml``. ``format`` defaults to
      ``"parquet"`` and must be ``"csv"``, ``"parquet"`` or ``"h5"``.
    - Masked samples are dropped (no masked rows written).
    - No relabeling at save-time; relabeling is handled at load-time.
    - ``dynamic_mask`` is not persisted; after a save/load round-trip it
      will be ``None`` (equivalent to the static mask).

    Extra keyword arguments are forwarded to ``save_trajectory``.

    Raises
    ------
    ValueError
        If a single-file path is given for a collection that does not
        hold exactly one dataset, or if the directory ``format`` is not
        supported.
    """
    dst = Path(path)

    def _write_one(ds: TrajectoryDataset, filename: Path, fmt: Optional[str]) -> None:
        # Flatten to (particle, time, state) columns, dropping masked rows.
        positions = np.asarray(ds._X3d())
        valid = np.asarray(ds._M2d(), dtype=bool)
        pid, tidx, vecs = flatten_X_to_columns(positions, mask=valid)

        # Work on copies so the dataset's own extras are never mutated.
        global_extras = dict(ds.extras_global or {})
        local_extras = dict(ds.extras_local or {})

        # Persist the time axis: prefer ``t``; fall back to ``dt`` only
        # when neither key is already present in the global extras.
        has_t = "t" in global_extras
        if ds.t is not None and not has_t:
            global_extras["t"] = time_series_extra(np.asarray(np.array(ds.t)))
        elif ds.dt is not None and not has_t and "dt" not in global_extras:
            step = np.asarray(ds.dt)
            global_extras["dt"] = time_series_extra(step) if step.ndim == 1 else float(step)

        save_trajectory(
            str(filename),
            particle_idx=pid,
            time_idx=tidx,
            state_vectors=vecs,
            extras_global=global_extras,
            extras_local=local_extras,
            metadata=dict(ds.meta or {}),
            format=fmt,
            **format_kw,
        )

    # Case A: a single file holds exactly one dataset.
    if _is_single_file(dst):
        if len(self.datasets) != 1:
            raise ValueError("Saving to a single file requires exactly one dataset.")
        _write_one(self.datasets[0], dst, format)
        return dst

    # Case B: a directory with one file per dataset plus a manifest.
    dst.mkdir(parents=True, exist_ok=True)
    fmt = (format or "parquet").lower()
    suffix_for = {"csv": ".csv", "parquet": ".parquet", "h5": ".h5"}
    if fmt not in suffix_for:
        raise ValueError("format must be 'csv', 'parquet', or 'h5'.")
    ext = suffix_for[fmt]

    entries = []
    for i, ds in enumerate(self.datasets):
        rel = Path(f"ds_{i:03d}{ext}")
        _write_one(ds, dst / rel, fmt)
        entries.append({"name": getattr(ds, "name", f"dataset_{i}"), "file": rel.as_posix()})

    import yaml

    manifest = {"version": 1, "n_datasets": len(entries), "datasets": entries}
    (dst / "manifest.yaml").write_text(yaml.safe_dump(manifest), encoding="utf-8")
    return dst
[docs]
@classmethod
def load(
    cls,
    path: Union[str, Path],
    *,
    relabel: bool = True,
    compress_particles: bool = False,
    particle_column: Union[int, str, None] = "auto",
    time_column: Union[int, str] = "auto",
    state_columns: Optional[Sequence[Union[int, str]]] = None,
) -> "TrajectoryCollection":
    """
    Load a collection from a single file or a directory.

    Parameters
    ----------
    path
        A single trajectory file (``.csv``, ``.parquet``, ``.pq``,
        ``.h5``, ``.hdf5``) or a directory. A directory is read through
        its ``manifest.yaml`` when present; otherwise every file with
        one of those suffixes is loaded, in sorted order.
    relabel
        If True, compress particle IDs to 0..N-1 and shift time to start
        at 0.
    compress_particles
        If True, further reduce the column count by merging particles
        whose time supports do not overlap (greedy interval packing with
        a 2-frame buffer). Useful for open-boundary systems where
        particles enter and leave the field of view, causing the naive N
        to grow as the total number of unique particle IDs rather than
        the concurrent count. Per-particle extras are reindexed
        automatically; the mapping is stored in
        ``dataset.meta['particle_column_map']``.
    particle_column, time_column
        Which columns hold the particle ID and the time index, as a
        column *name* (any format) or a positional *index* (CSV only).
        ``"auto"`` (default) keeps the loader defaults: CSV positional
        (column 0 = particle, column 1 = time), parquet/HDF5 the
        canonical names ``"particle_id"`` / ``"time_step"``. Pass
        ``particle_column=None`` for single-trajectory files.
    state_columns
        Optional explicit selection of the state-vector columns
        (names, or indices for CSV), in order; every other non-extras
        column is dropped. Default: all non-ID, non-extras columns.

    Raises
    ------
    FileNotFoundError
        If ``path`` is neither a recognised single-file path nor an
        existing directory.

    Notes
    -----
    Both paths end up with the ``"pool"`` weight policy: a single-file
    load goes through :meth:`from_dataset` with its default, and a
    directory load calls ``with_weights("pool")``. Call
    :meth:`with_weights` after loading to select another policy.
    """
    src = Path(path)

    # Only forward column selectors the caller actually overrode, so the
    # loader's own defaults stay in effect otherwise.
    column_kw: Dict[str, Any] = {}
    if particle_column != "auto":
        column_kw["particle_column"] = particle_column
    if time_column != "auto":
        column_kw["time_column"] = time_column
    if state_columns is not None:
        column_kw["state_columns"] = state_columns

    def _read_one(filename: Path) -> TrajectoryDataset:
        yaml_meta, _cols, pid, tidx, vecs, eg, el = load_trajectory(
            str(filename), relabel=relabel, **column_kw
        )
        # An explicit time axis in the extras wins; otherwise take dt
        # from the extras first, then from the YAML metadata.
        dt_pass = None
        if "t" not in eg:
            if "dt" in eg:
                stored = eg["dt"]
                dt_pass = stored.data if isinstance(stored, TimeSeriesExtra) else float(stored)
            elif yaml_meta and "dt" in yaml_meta:
                dt_pass = float(yaml_meta["dt"])
        return columns_and_extras_to_dataset(
            pid,
            tidx,
            vecs,
            extras_global=eg,
            extras_local=el,
            dt=dt_pass,
            relabel=relabel,
            compress_particles=compress_particles,
            meta=yaml_meta,
        )

    # Single file -> one dataset.
    if _is_single_file(src):
        return cls.from_dataset(_read_one(src))

    # Directory -> many datasets.
    if not src.is_dir():
        raise FileNotFoundError(f"No such file or directory: {src}")

    files: Iterable[Path]
    manifest = src / "manifest.yaml"
    if manifest.exists():
        import yaml

        listing = yaml.safe_load(manifest.read_text(encoding="utf-8")) or {}
        files = [src / Path(entry["file"]) for entry in listing.get("datasets", [])]
    else:
        # No manifest: pick up every recognised trajectory file, sorted.
        files = sorted(p for p in src.iterdir() if _is_single_file(p))

    datasets = [_read_one(fp) for fp in files]
    unit = jnp.ones((len(datasets),), dtype=jnp.float32)
    return cls(datasets=datasets, weights=unit).with_weights("pool")
# Constructors:
[docs]
@classmethod
def from_arrays(
    cls,
    *,
    X: Any,
    dt: Optional[float] = None,
    t: Optional[Any] = None,
    mask: Optional[Any] = None,
    extras_global: Optional[Dict[str, Any]] = None,
    extras_local: Optional[Dict[str, Any]] = None,
    meta: Optional[Dict[str, Any]] = None,
    weights: WeightSpec = "pool",
) -> "TrajectoryCollection":
    """
    Create a single-dataset collection from in-memory array-likes.

    This is the recommended entry point when the tensors already live in
    memory. All array arguments are handed to
    :meth:`TrajectoryDataset.from_arrays`; the resulting dataset is then
    wrapped with :meth:`from_dataset`.

    Parameters
    ----------
    X
        State array of shape ``(T, N, d)`` or ``(T, d)``. If ``(T, d)``,
        a single particle is assumed.
    dt
        Either a scalar step, an array of shape ``(T,)`` (per-step),
        or ``None``. If ``None`` and ``t`` is provided, effective steps
        are derived from ``t`` on demand.
    t
        Optional absolute time vector of shape ``(T,)``. If provided, it
        defines time steps when needed.
    mask
        Optional boolean mask of shape ``(T, N)`` or ``(T,)`` marking
        valid observations. If ``None``, all entries are considered
        valid.
    extras_global
        Mapping of global extras. Values can be static objects,
        :class:`TimeSeriesExtra`, or JAX-traceable callables
        ``f(t_idx, context=None) -> Array`` with a leading time axis.
    extras_local
        Mapping of per-particle extras, with the same typing as
        ``extras_global``. Time-series entries typically have shape
        ``(T, N, ...)``.
    meta
        Free-form metadata dictionary attached to the underlying dataset.
    weights
        Initial weight specification for the resulting collection.
        See :meth:`with_weights`.

    Returns
    -------
    TrajectoryCollection
        A collection with one dataset built from the provided arrays.
    """
    dataset_fields = dict(
        X=X,
        dt=dt,
        t=t,
        mask=mask,
        extras_global=extras_global,
        extras_local=extras_local,
        meta=meta,
    )
    dataset = TrajectoryDataset.from_arrays(**dataset_fields)
    return cls.from_dataset(dataset, weights=weights)
[docs]
@classmethod
def from_columns(
    cls,
    particle_idx: np.ndarray,
    time_idx: np.ndarray,
    state_vectors: np.ndarray,
    *,
    extras_global: Mapping[str, Any] | None = None,
    extras_local: Mapping[str, Any] | None = None,
    dt: Optional[float] = None,
    t: Optional[np.ndarray] = None,
    relabel: bool = True,
    compress_particles: bool = False,
    meta: Optional[Dict[str, Any]] = None,
    weights: WeightSpec = "pool",
) -> "TrajectoryCollection":
    """
    Create a single-dataset collection from flat (particle, time) columns.

    Convenient when trajectories come from a tabular format or a custom
    pipeline. The columns are assembled by
    ``columns_and_extras_to_dataset`` and the resulting dataset is
    wrapped with :meth:`from_dataset`.

    Parameters
    ----------
    particle_idx
        Integer array of shape ``(L,)`` with particle IDs for each row.
    time_idx
        Integer array of shape ``(L,)`` with time indices ``t`` for each
        row.
    state_vectors
        Array of shape ``(L, d)`` with state vectors.
    extras_global
        Parsed global extras (e.g. from YAML header), as described in
        :mod:`SFI.trajectory.io`.
    extras_local
        Parsed local extras, including time-series extras, as described
        in :mod:`SFI.trajectory.io`.
    dt
        Optional scalar step; used only if no absolute time axis is
        provided via ``t`` or ``extras_global['t']``.
    t
        Optional time vector of shape ``(T,)`` overriding any time axis
        inferred from extras.
    relabel
        If True, compress particle IDs to ``0..N-1`` and shift time to
        start at 0.
    compress_particles
        If True, apply greedy interval packing to reduce the column
        count by merging particles whose time supports do not overlap
        (with a 2-frame buffer). Per-particle extras are reindexed
        automatically. The mapping is stored in
        ``dataset.meta['particle_column_map']``.
    meta
        Metadata dictionary to attach to the dataset.
    weights
        Initial weight specification for the resulting collection.
        See :meth:`with_weights`.

    Returns
    -------
    TrajectoryCollection
        A collection with one dataset assembled from the columns.
    """
    assembly_options = dict(
        extras_global=extras_global,
        extras_local=extras_local,
        dt=dt,
        t=t,
        relabel=relabel,
        compress_particles=compress_particles,
        meta=meta,
    )
    dataset = columns_and_extras_to_dataset(
        particle_idx,
        time_idx,
        state_vectors,
        **assembly_options,
    )
    return cls.from_dataset(dataset, weights=weights)
#: Column-name candidates tried (case-insensitively) by `from_dataframe`
#: when ``particle=`` / ``time=`` are not passed explicitly. A table that
#: matches more than one candidate from the same tuple is rejected there
#: as ambiguous.
_PARTICLE_COLUMN_CANDIDATES = ("particle_id", "particle", "track_id", "track", "traj_id")
_TIME_COLUMN_CANDIDATES = ("time_step", "frame", "time", "t")
[docs]
@classmethod
def from_dataframe(
    cls,
    df,
    *,
    particle: Optional[str] = None,
    time: Optional[str] = None,
    coords: Optional[Sequence[str]] = None,
    dt: Optional[float] = None,
    t: Optional[Any] = None,
    extras_global: Mapping[str, Any] | None = None,
    extras_local: Mapping[str, Any] | None = None,
    relabel: bool = True,
    compress_particles: bool = False,
    meta: Optional[Dict[str, Any]] = None,
    weights: WeightSpec = "pool",
) -> "TrajectoryCollection":
    """
    Create a single-dataset collection from a pandas DataFrame.

    The natural entry point for raw tracking tables (trackpy,
    TrackMate, custom pipelines): columns are addressed by *name*, in
    any order, and junk columns are dropped.

    Parameters
    ----------
    df
        A pandas DataFrame with one row per ``(particle, time)``
        observation.
    particle
        Name of the particle/track-ID column. Default: case-insensitive
        auto-detection among ``particle_id, particle, track_id, track,
        traj_id``; if none is present the table is treated as a single
        trajectory, and if several are present a ``ValueError`` asks
        for an explicit choice.
    time
        Name of the time column. Default: auto-detection among
        ``time_step, frame, time, t`` (same ambiguity rule); a
        ``ValueError`` is raised when none is found. Integer columns are
        used as frame indices; other columns are factorized into frame
        indices and, unless ``t`` or ``dt`` is given, their sorted
        unique values become the absolute time axis.
    coords
        State-vector column names, in order. Default: every remaining
        column without an extras prefix (``G_``, ``TG_``, ``P_``,
        ``TP_``), in dataframe order. Columns not selected are
        silently dropped.
    dt, t
        Time-axis specification, as in :meth:`from_columns`.
    extras_global, extras_local
        Extra fields merged **over** any extras parsed from prefixed
        columns (user values win).
    relabel, compress_particles, meta, weights
        As in :meth:`from_columns`.

    Raises
    ------
    ImportError
        If pandas cannot be imported.
    ValueError
        If an explicitly named column is missing, auto-detection is
        ambiguous, no time column is found, or no state column is
        selected.

    Examples
    --------
    >>> coll = TrajectoryCollection.from_dataframe(
    ...     tracks, particle="track_id", time="frame",
    ...     coords=("x", "y"), dt=0.05,
    ... )
    """
    try:
        import pandas as pd  # noqa: F401
    except Exception as e:  # pragma: no cover
        raise ImportError("TrajectoryCollection.from_dataframe requires pandas.") from e

    colnames = list(df.columns)
    prefixes = ("G_", "TG_", "P_", "TP_")

    def _pick(explicit: Optional[str], candidates: Sequence[str], what: str, required: bool) -> Optional[str]:
        # Resolve one role (particle / time) to an actual column name.
        if explicit is None:
            by_lower = {c.lower(): c for c in colnames}
            hits = [by_lower[c] for c in candidates if c in by_lower]
            if len(hits) > 1:
                raise ValueError(
                    f"ambiguous {what} column — found {hits}; pass {what}= explicitly"
                )
            if hits:
                return hits[0]
            if required:
                raise ValueError(
                    f"no {what} column found (tried {tuple(candidates)}); pass {what}= explicitly"
                )
            return None
        if explicit not in colnames:
            raise ValueError(f"{what} column {explicit!r} not found; available columns: {colnames}")
        return explicit

    particle_name = _pick(particle, cls._PARTICLE_COLUMN_CANDIDATES, "particle", required=False)
    time_name = _pick(time, cls._TIME_COLUMN_CANDIDATES, "time", required=True)

    if coords is not None:
        coord_names = list(coords)
        missing = [c for c in coord_names if c not in colnames]
        if missing:
            raise ValueError(f"coords columns not found: {missing}; available columns: {colnames}")
    else:
        # Default: everything that is neither an ID column nor an extras column.
        skip = {particle_name, time_name}
        coord_names = [c for c in colnames if c not in skip and not c.startswith(prefixes)]
    if not coord_names:
        raise ValueError("no state (coordinate) columns selected")

    # Column order handed to the parser: [particle], time, coords, extras.
    selected = []
    if particle_name:
        selected.append(particle_name)
    selected.append(time_name)
    selected.extend(coord_names)
    extras_columns = [c for c in colnames if c.startswith(prefixes) and c not in selected]
    selected.extend(extras_columns)
    sub = df[selected].copy()

    # Non-integer time column -> factorize to frame indices (+ time axis).
    t_resolved = t
    tvals = np.asarray(sub[time_name].to_numpy())
    if not np.issubdtype(tvals.dtype, np.integer):
        uniq, inv = np.unique(tvals, return_inverse=True)
        if t is None and dt is None:
            # No explicit time axis: the sorted unique values become it.
            t_resolved = uniq
        sub[time_name] = inv.astype(int)

    _meta, _cols, pid, tidx, vecs, eg, el = _parse_tabular_with_extras(
        sub,
        {},
        particle_column=particle_name,
        time_column=time_name,
        relabel=relabel,
    )

    # User-supplied extras override the ones parsed from prefixed columns.
    merged_global = dict(eg)
    if extras_global:
        merged_global.update(dict(extras_global))
    merged_local = dict(el)
    if extras_local:
        merged_local.update(dict(extras_local))

    return cls.from_columns(
        pid,
        tidx,
        vecs,
        extras_global=merged_global,
        extras_local=merged_local,
        dt=dt,
        t=t_resolved,
        relabel=relabel,
        compress_particles=compress_particles,
        meta=meta,
        weights=weights,
    )
[docs]
def split_time(
    self,
    fraction: float = 0.8,
    *,
    gap: int = 0,
    reweight: Literal["pool", "keep"] = "pool",
) -> tuple["TrajectoryCollection", "TrajectoryCollection"]:
    """Split every dataset along time into ``(train, test)`` collections.

    A side feature for data-abundant scenarios: SFI estimates its own
    accuracy from the training data (``force_predicted_MSE``) and
    validates fits through the diagnostics suite, neither of which
    costs any data. Hold out a test fraction only when data is
    plentiful, or to confirm a suspected bias floor with
    :meth:`~SFI.inference.base.BaseLangevinInference.holdout_score`.

    Parameters
    ----------
    fraction : float
        Fraction of frames per dataset assigned to the train half.
    gap : int
        Frames dropped between the halves (decorrelation; ``0`` is
        safe for increment-based estimators).
    reweight : {"pool", "keep"}
        ``"pool"`` (default) resets both halves to the ``"pool"``
        policy of :meth:`with_weights`; ``"keep"`` carries over the
        current per-dataset weights unchanged.

    Returns
    -------
    (TrajectoryCollection, TrajectoryCollection)
        The train and test collections, each with one dataset per
        dataset of ``self`` and in the same order.

    Raises
    ------
    ValueError
        If ``reweight`` is neither ``"pool"`` nor ``"keep"``.

    Examples
    --------
    >>> train, test = coll.split_time(0.8)
    >>> inf = OverdampedLangevinInference(train)
    >>> # ... fit ...
    >>> inf.holdout_score(test)
    """
    # Reject unknown policies up front: silently treating them as "keep"
    # hides typos and stale option names.
    if reweight not in ("pool", "keep"):
        raise ValueError(f"unknown reweight policy {reweight!r}; use 'pool' or 'keep'.")

    pairs = [ds.split_time(fraction, gap=gap) for ds in self.datasets]
    train_ds = [pair[0] for pair in pairs]
    test_ds = [pair[1] for pair in pairs]

    n = len(self.datasets)
    spec: WeightSpec = "pool" if reweight == "pool" else np.asarray(self.weights)
    # type(self) keeps subclasses intact across the split.
    build = type(self)
    train = build(datasets=train_ds, weights=jnp.ones((n,), dtype=jnp.float32)).with_weights(spec)
    test = build(datasets=test_ds, weights=jnp.ones((n,), dtype=jnp.float32)).with_weights(spec)
    return train, test
[docs]
def dataset_index(self, position: int) -> int:
    """Dense index of the dataset at ``position``, keyed on its ``uuid``.

    Datasets are numbered by first appearance of their ``uuid`` while
    walking :attr:`datasets` in order, so two entries sharing a ``uuid``
    share an index. The index a force sees (e.g. via
    :func:`~SFI.bases.per_dataset_scalar` or
    :func:`~SFI.bases.dataset_indicator`) is therefore tied to the
    dataset's identity rather than to its slot.
    """
    first_seen: Dict[str, int] = {}
    for ds in self.datasets:
        key = ds.uuid
        if key not in first_seen:
            first_seen[key] = len(first_seen)
    return first_seen[self.datasets[position].uuid]
[docs]
def degrade(
    self,
    *,
    downsample: int = 1,
    motion_blur: int = 0,
    data_loss_fraction: float = 0.0,
    noise: Union[None, float, np.ndarray] = None,
    ROI: Union[None, float, np.ndarray, Callable[[np.ndarray], bool]] = None,
    seed: Optional[int] = None,
    reweight: Literal["pool", "keep"] = "pool",
) -> "TrajectoryCollection":
    """
    Return a degraded copy of the collection.

    This is the preferred user-facing API for degrading synthetic
    trajectories to mimic experimental noise, blur, and data loss. It is
    a thin wrapper: every argument is forwarded unchanged to
    :func:`SFI.trajectory.degrade.degrade_collection`.

    Parameters
    ----------
    downsample, motion_blur, data_loss_fraction, noise, ROI, seed, reweight
        See :func:`SFI.trajectory.degrade.degrade_collection` for a full
        description of each parameter.

    Returns
    -------
    TrajectoryCollection
        New degraded collection. The original collection is not modified.
    """
    # Imported lazily, at call time rather than at module import.
    from SFI.trajectory.degrade import degrade_collection

    options = dict(
        downsample=downsample,
        motion_blur=motion_blur,
        data_loss_fraction=data_loss_fraction,
        noise=noise,
        ROI=ROI,
        seed=seed,
        reweight=reweight,
    )
    return degrade_collection(self, **options)
[docs]
def to_arrays(
    self,
    *,
    dataset: int = 0,
    as_numpy: bool = True,
    include_mask: bool = True,
):
    """
    Materialize one dataset of the collection as dense arrays.

    Parameters
    ----------
    dataset :
        Index of the dataset inside the collection (default 0). Negative
        indices are rejected.
    as_numpy :
        If True, return NumPy arrays.
    include_mask :
        If True, also return the per-particle mask.

    Returns
    -------
    t, X, mask :
        See :meth:`TrajectoryDataset.to_arrays`.

    Raises
    ------
    IndexError
        If ``dataset`` is outside ``[0, D)``.
    """
    n_datasets = len(self.datasets)
    if dataset < 0 or dataset >= n_datasets:
        raise IndexError(f"dataset index {dataset} out of range for D={n_datasets}")
    chosen = self.datasets[dataset]
    return chosen.to_arrays(as_numpy=as_numpy, include_mask=include_mask)
[docs]
def merge(
    self,
    items: Sequence[Union["TrajectoryCollection", TrajectoryDataset]],
    *,
    weights: WeightSpec = "pool",
) -> "TrajectoryCollection":
    """Combine this collection with others into one collection.

    Convenience alias for :meth:`concat`, useful for assembling an
    ensemble from several single-trajectory collections
    (``base.merge([c1, c2, ...])``). Both arguments are passed straight
    through; see :meth:`concat` for the ``weights`` policy.
    """
    combined = self.concat(items, weights=weights)
    return combined
[docs]
def to_array(self, *, axis: Literal["time"] = "time", as_numpy: bool = True):
    """Materialize the whole collection as one dense ``(T, N, d)`` array.

    Concatenates every dataset along the time axis into a single array
    of positions. Use this for the legitimate non-plotting reach-ins
    (disk caching, ensemble bootstrap initial conditions, neighbour
    lists); for plotting, prefer the toolkit functions in
    :mod:`SFI.utils.plotting`, and for ``(t, X, mask)`` of a single
    dataset use :meth:`to_arrays`.

    Parameters
    ----------
    axis :
        Only ``"time"`` is supported (axis-0 concatenation).
    as_numpy :
        If True (default), return a NumPy array; else a JAX array.

    Returns
    -------
    ndarray, shape ``(sum_T, N, d)``

    Raises
    ------
    ValueError
        If ``axis`` is not ``"time"`` or the collection has no datasets.
    """
    if axis != "time":
        raise ValueError(f"to_array only supports axis='time', got {axis!r}.")
    if not self.datasets:
        raise ValueError("Empty TrajectoryCollection.")
    per_dataset = [ds._X3d() for ds in self.datasets]
    stacked = jnp.concatenate(per_dataset, axis=0)
    if as_numpy:
        return np.asarray(stacked)
    return stacked
[docs]
def velocity_array(
    self,
    *,
    dataset: int = 0,
    scheme: Literal["central", "forward", "backward"] = "central",
    as_numpy: bool = True,
):
    """Finite-difference velocity ``v(t)`` for one dataset.

    Reconstructs velocities from stored positions with
    :func:`SFI.utils.maths.fd_velocity`, matching the secant-velocity
    convention of the underdamped engine. Handy for building
    ``(x, v)`` phase portraits or held-out evaluation grids from
    position-only recordings. Time steps are taken as the successive
    differences of the dataset's time vector; the mask returned by
    :meth:`to_arrays` is not used.

    Parameters
    ----------
    dataset :
        Dataset index inside the collection (default 0).
    scheme :
        Finite-difference stencil; see :func:`SFI.utils.maths.fd_velocity`.
    as_numpy :
        If True (default), return a NumPy array; else a JAX array.

    Returns
    -------
    v : ndarray, shape ``(T, N, d)``

    Raises
    ------
    ValueError
        If the dataset has fewer than 2 frames.
    """
    # Imported lazily, at call time rather than at module import.
    from SFI.utils.maths import fd_velocity

    times, positions, _ = self.to_arrays(dataset=dataset, as_numpy=True, include_mask=True)
    steps = np.diff(np.asarray(times, dtype=float))
    if steps.size == 0:
        raise ValueError("velocity_array needs at least 2 frames.")
    velocity = fd_velocity(positions, steps, scheme=scheme)
    if as_numpy:
        return np.asarray(velocity)
    return velocity
# ------------------------------------------------------------------
# Attribute forwarding for the most common case: a single dataset
# ------------------------------------------------------------------
def _single_dataset(self):
    """Return the only dataset when the collection holds exactly one, else ``None``."""
    held = self.datasets
    return held[0] if len(held) == 1 else None
def __getattr__(self, name):
"""
Forward attribute access to the sole dataset when exactly one is present.
Python only calls ``__getattr__`` after normal attribute lookup has
failed, so existing attributes on the collection itself are never
intercepted.
Raises ``AttributeError`` when the collection does not hold exactly one
dataset, or when that dataset has no attribute ``name``.
"""
# Guard: during pickle/unpickle, __dict__ may not yet contain
# 'datasets', so _single_dataset() would recurse back here.
# This check must therefore stay ahead of the _single_dataset() call.
if name == "datasets":
raise AttributeError(name)
ds = self._single_dataset()
# Forward only when there is exactly one dataset and it has the attribute.
if ds is not None and hasattr(ds, name):
return getattr(ds, name)
raise AttributeError(f"{type(self).__name__!s} has no attribute {name!r}")
def __getitem__(self, key):
    """
    Dict-style forwarding to the sole dataset's attributes.

    ``coll["extras_global"]`` returns ``dataset.extras_global`` when the
    collection holds exactly one dataset.

    Raises
    ------
    KeyError
        If ``key`` is not a string, if the collection does not hold
        exactly one dataset, or if that dataset has no attribute
        ``key``.
    """
    # hasattr() raises TypeError for non-string names. Without this
    # guard, coll[0] (and the implicit integer lookups behind ``in`` and
    # iter()) would fail with TypeError instead of a lookup error.
    if not isinstance(key, str):
        raise KeyError(key)
    ds = self._single_dataset()
    if ds is not None and hasattr(ds, key):
        return getattr(ds, key)
    raise KeyError(key)
def __dir__(self):
"""
Extend tab completion: if a single dataset is present,
expose its attributes as if they were part of the collection.
"""
base = super().__dir__()
ds = self._single_dataset()
if ds is not None:
return sorted(set(base) | set(dir(ds)))
return base