Source code for SFI.trajectory.io

"""SFI.trajectory.io
====================

CSV I/O utilities and columnar ↔ tensor conversion for trajectory data.

File format
-----------
We support a single CSV with optional YAML header. The numerical columns include:

- ``particle_id`` (optional): if absent, it is a *single-trajectory* file.
- ``time_step``           : integer time index t (0-based after relabel).
- ``x0, x1, ..., x{d-1}`` : state vector components.

Extras are stored either in the header (YAML) or as extra numeric columns:

Prefixes (numeric columns)
~~~~~~~~~~~~~~~~~~~~~~~~~~
- ``TG_`` : *time-dependent globals* — values depend on ``t`` only.
- ``P_``  : *per-particle constants* — values depend on particle only.
- ``TP_`` : *time-dependent per-particle* — values depend on ``(t, n)``.
- ``G_``  : *global scalars* — constants stored in the header via averaging.

Note: ``TG_``/``TP_`` columns are parsed as time series and wrapped into
:class:`TimeSeriesExtra`. Header `extras_global` entries are treated as
static unless explicitly wrapped when building the dataset.

Header
~~~~~~
The YAML header may include a mapping ``extras_global`` of arbitrary scalars or
arrays. A special key ``"t"`` (vector of length ``T``) is recognized as the time
axis; when present, it replaces scalar ``dt`` in downstream builders.

Round-trip helpers
------------------
- :func:`flatten_X_to_columns` / :func:`assemble_X_from_columns` convert between
  structured tensors ``(T,N,d)`` and flat columns.
- :func:`save_trajectory_csv_with_extras` / :func:`load_trajectory_csv_with_extras`
  handle extras and header metadata.
- :func:`columns_and_extras_to_dataset` builds a :class:`TrajectoryDataset`
  ready for inference.

All functions are NumPy-based; JAX is optional for basic dtype detection only.

"""

from __future__ import annotations

import warnings
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np

from SFI.trajectory.dataset import FunctionExtra, TimeSeriesExtra  # local import to avoid cycles

__all__ = [
    "save_trajectory",
    "load_trajectory",
    "columns_and_extras_to_dataset",
]
# ---------------- utilities ----------------


def _sanitize_metadata(obj: Any) -> Any:
    """Convert arrays/scalars into plain Python types recursively for YAML/JSON."""
    # Catch any array-like (numpy, JAX, PyTorch, …) by duck-typing.
    # The old check ("ndarray" in type(obj).__name__) missed JAX ArrayImpl.
    if hasattr(obj, "shape") and hasattr(obj, "dtype"):
        try:
            return np.asarray(obj).tolist()
        except Exception:
            pass
    if hasattr(obj, "item") and not isinstance(obj, (dict, list, tuple)):
        try:
            return obj.item()
        except Exception:
            pass
    if isinstance(obj, dict):
        return {k: _sanitize_metadata(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_sanitize_metadata(v) for v in obj]
    return obj


# ---------------- core converters ----------------


def flatten_X_to_columns(X: np.ndarray, mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Tensor trajectory to columnar representation.

    Parameters
    ----------
    X : ndarray, shape (T, N, d)
        State tensor.
    mask : ndarray, optional
        Boolean mask with shape ``(T,N)`` or ``(T,)``; invalid rows are filtered out.

    Returns
    -------
    particle_idx : ndarray, shape (L,)
    time_idx : ndarray, shape (L,)
    state_vectors : ndarray, shape (L, d)

    Notes
    -----
    Rows with NaNs in ``state_vectors`` are dropped. If ``mask`` is provided,
    rows where mask is False are also dropped.
    """
    X = np.asarray(X)
    if X.ndim != 3:
        raise ValueError(f"X must be (T,N,d); got {X.shape}")
    T, N, d = X.shape
    time_idx = np.repeat(np.arange(T, dtype=int), N)
    particle_idx = np.tile(np.arange(N, dtype=int), T)
    state_vectors = X.reshape(T * N, d)
    valid = ~np.isnan(state_vectors).any(axis=1)
    if mask is not None:
        m = np.asarray(mask, dtype=bool)
        if m.shape == (T,):
            m = np.broadcast_to(m[:, None], (T, N))
        if m.shape != (T, N):
            raise ValueError(f"mask must be (T,) or (T,N); got {m.shape}")
        valid &= m.reshape(T * N)
    return (
        particle_idx[valid].astype(int, copy=False),
        time_idx[valid].astype(int, copy=False),
        state_vectors[valid],
    )


def assemble_X_from_columns(
    particle_idx: np.ndarray,
    time_idx: np.ndarray,
    state_vectors: np.ndarray,
    *,
    fill_value: float = np.nan,
    relabel: bool = True,
    compress_particles: bool = False,
) -> Tuple[np.ndarray, np.ndarray, Optional[Any]]:
    """Columnar to tensor trajectory (and mask).

    Parameters
    ----------
    particle_idx, time_idx : ndarray
        Integer columns.
    state_vectors : ndarray, shape (L, d)
        Flat state vectors.
    fill_value : float
        Value to fill missing entries in ``X``.
    relabel : bool
        If True, compress particle IDs to contiguous ``0..N-1`` and shift
        ``time_idx`` to start at 0.
    compress_particles : bool
        If True, apply greedy interval packing to further reduce the number of
        columns by merging particles whose time supports do not overlap (with a
        2-frame buffer).  Implies relabeling of particle IDs first.
        See :func:`_greedy_compress_particles`.

    Returns
    -------
    X : ndarray, shape (T, N, d)
    mask : ndarray, shape (T, N)
        True where entries are present in the columns.
    id_map : ndarray of shape ``(N,)`` or dict or None
        * ``relabel=True``, ``compress_particles=False``: shape ``(N,)`` array
          of original particle IDs in column order.
        * ``compress_particles=True``: dict with keys ``'column_origins'``
          (list of lists of compact IDs per column), ``'t_first'``, ``'t_last'``
          (ndarrays of shape ``(N_orig,)`` giving the time span of each compact
          particle before compression).
        * Otherwise: ``None``.
    """
    particle_idx = np.asarray(particle_idx, dtype=int)
    time_idx = np.asarray(time_idx, dtype=int)
    state_vectors = np.asarray(state_vectors)
    if state_vectors.ndim != 2:
        raise ValueError("state_vectors must be 2D (L, d)")
    time_idx = time_idx - time_idx.min()
    if (time_idx < 0).any():
        raise ValueError("time_idx normalization failed (negative after shift).")
    id_map: Optional[Any] = None
    if relabel or compress_particles:
        uniq = np.unique(particle_idx)
        remap = {old: new for new, old in enumerate(uniq)}
        particle_idx = np.vectorize(remap.__getitem__, otypes=[int])(particle_idx)
        N = len(uniq)
        if not compress_particles:
            # Only record the map when IDs were not already 0..N-1
            if not np.array_equal(uniq, np.arange(len(uniq))):
                id_map = uniq  # original IDs in column order
    else:
        N = int(particle_idx.max()) + 1 if len(particle_idx) > 0 else 0
    if compress_particles:
        particle_idx, column_origins, t_first, t_last = _greedy_compress_particles(particle_idx, time_idx)
        N = len(column_origins)
        id_map = {
            "column_origins": column_origins,
            "t_first": t_first,
            "t_last": t_last,
        }
    T = int(time_idx.max()) + 1 if len(time_idx) > 0 else 0
    d = int(state_vectors.shape[1])
    X = np.full((T, N, d), fill_value, dtype=state_vectors.dtype)
    mask = np.zeros((T, N), dtype=bool)
    X[time_idx, particle_idx] = state_vectors
    mask[time_idx, particle_idx] = True
    return X, mask, id_map


def _greedy_compress_particles(
    particle_idx: np.ndarray,
    time_idx: np.ndarray,
) -> Tuple[np.ndarray, List[List[int]], np.ndarray, np.ndarray]:
    """Greedy interval packing: merge particle columns with non-overlapping time windows.

    Two particles assigned to the same column are always separated by at least
    one masked frame (gap ≥ 2 time steps between time supports), preventing
    spurious increments across identity changes.

    Parameters
    ----------
    particle_idx : ndarray, shape (L,)
        Compact particle indices in ``0..N_orig-1``.
    time_idx : ndarray, shape (L,)
        Compact time indices in ``0..T-1``.

    Returns
    -------
    new_particle_idx : ndarray, shape (L,)
        Updated column index for each observation row.
    column_origins : list[list[int]]
        ``column_origins[c]`` lists the compact IDs packed into column ``c``,
        in temporal order.
    t_first : ndarray, shape (N_orig,)
        First time index for each compact particle.
    t_last : ndarray, shape (N_orig,)
        Last time index for each compact particle.
    """
    if len(particle_idx) == 0:
        return (
            particle_idx.copy(),
            [],
            np.array([], dtype=np.intp),
            np.array([], dtype=np.intp),
        )
    N_orig = int(particle_idx.max()) + 1
    t_first = np.full(N_orig, np.iinfo(np.intp).max, dtype=np.intp)
    t_last = np.full(N_orig, -1, dtype=np.intp)
    np.minimum.at(t_first, particle_idx, time_idx)
    np.maximum.at(t_last, particle_idx, time_idx)
    # Sort particles by first appearance (stable to make assignment deterministic)
    order = np.argsort(t_first, kind="stable")
    column_last: List[int] = []  # last time index of each open column
    column_origins: List[List[int]] = []
    assignment = np.empty(N_orig, dtype=np.intp)
    for p in order:
        tf = int(t_first[p])
        tl = int(t_last[p])
        assigned = -1
        for c, clast in enumerate(column_last):
            if clast <= tf - 2:  # ≥ 2-frame gap ensures at least one masked frame
                assigned = c
                break
        if assigned < 0:
            assigned = len(column_last)
            column_last.append(tl)
            column_origins.append([int(p)])
        else:
            column_last[assigned] = tl
            column_origins[assigned].append(int(p))
        assignment[p] = assigned
    return assignment[particle_idx], column_origins, t_first, t_last


def _reindex_extras_local_for_compression(
    extras_local: Dict[str, Any],
    column_origins: List[List[int]],
    t_first: np.ndarray,
    t_last: np.ndarray,
    T: int,
    N_comp: int,
) -> Dict[str, Any]:
    """Reindex per-particle extras from ``N_orig`` to ``N_comp`` columns.

    After :func:`_greedy_compress_particles` has merged particles into fewer
    columns, the extras stored in ``extras_local`` still refer to the original
    compact particle indices.  This function rebuilds them so they match the
    compressed shape ``(T, N_comp, …)``.

    * ``TimeSeriesExtra`` with shape ``(T, N_orig, …)`` → ``(T, N_comp, …)``.
    * ndarray with leading axis ``N_orig`` → promoted to ``(T, N_comp, …)``
      wrapped in a ``TimeSeriesExtra`` so each ``(t, c)`` slot returns the
      value of whichever compact particle is active in column ``c`` at time ``t``.
    * Callables / ``FunctionExtra`` and other shapes are kept unchanged.
    """
    from SFI.trajectory.dataset import TimeSeriesExtra, time_series_extra

    N_orig = len(t_first)
    out: Dict[str, Any] = {}
    for key, val in extras_local.items():
        if callable(val):
            out[key] = val
            continue
        if isinstance(val, TimeSeriesExtra):
            arr = np.asarray(val.data)
            if arr.ndim >= 2 and arr.shape[0] == T and arr.shape[1] == N_orig:
                tail = arr.shape[2:]
                new_arr = np.full((T, N_comp) + tail, np.nan, dtype=float)
                for c, orig_list in enumerate(column_origins):
                    for p in orig_list:
                        tf, tl = int(t_first[p]), int(t_last[p])
                        new_arr[tf : tl + 1, c] = arr[tf : tl + 1, p]
                out[key] = time_series_extra(new_arr)
            else:
                out[key] = val
        else:
            arr = np.asarray(val)
            if arr.ndim >= 1 and arr.shape[0] == N_orig:
                # Per-particle constant (N_orig, …) → time-varying (T, N_comp, …)
                tail = arr.shape[1:]
                new_arr = np.full((T, N_comp) + tail, np.nan, dtype=float)
                for c, orig_list in enumerate(column_origins):
                    for p in orig_list:
                        tf, tl = int(t_first[p]), int(t_last[p])
                        new_arr[tf : tl + 1, c] = arr[p]
                out[key] = time_series_extra(new_arr)
            else:
                out[key] = val
    return out


# -------- builder that wires t from extras_global if present --------


[docs] def columns_and_extras_to_dataset( particle_idx: np.ndarray, time_idx: np.ndarray, state_vectors: np.ndarray, *, extras_global: Mapping[str, Any] | None = None, extras_local: Mapping[str, Any] | None = None, dt: Optional[float] = None, t: Optional[np.ndarray] = None, mask_fill_value: float = np.nan, relabel: bool = True, compress_particles: bool = False, meta: Optional[Dict[str, Any]] = None, ): """Build a :class:`TrajectoryDataset` from columns and parsed extras. Preference order for the time axis: 1) explicit ``t`` argument, 2) ``extras_global['t']`` (from header or ``TG_t``), 3) fallback to scalar ``dt``. Parameters ---------- compress_particles : bool If True, apply greedy interval packing so that particles with non-overlapping time supports share the same column index. This can dramatically reduce ``N`` for open-boundary systems where particles enter and leave the field of view over time. Per-particle extras are automatically reindexed to the compressed column layout. The mapping is stored as ``meta['particle_column_map']``. When False (default) and ``relabel=True``, the original particle IDs are recorded as ``extras_local['original_particle_id']``. Returns ------- TrajectoryDataset """ if isinstance(t, TimeSeriesExtra): t = np.asarray(t.data) from SFI.trajectory.dataset import TrajectoryDataset X, mask, id_map = assemble_X_from_columns( particle_idx=particle_idx, time_idx=time_idx, state_vectors=state_vectors, fill_value=mask_fill_value, relabel=relabel, compress_particles=compress_particles, ) eg = dict(extras_global or {}) el = dict(extras_local or {}) meta_out = dict(meta or {}) if id_map is not None: if isinstance(id_map, dict): # compress_particles=True: store column map and reindex per-particle extras meta_out["particle_column_map"] = id_map["column_origins"] el = _reindex_extras_local_for_compression( el, id_map["column_origins"], id_map["t_first"], id_map["t_last"], T=X.shape[0], N_comp=X.shape[1], ) else: # relabel=True: record the original→compact ID mapping el["original_particle_id"] = id_map t_vec = t if t_vec is None and "t" in eg: tv = eg["t"] raw = tv.data if isinstance(tv, TimeSeriesExtra) else tv arr = np.asarray(raw) if arr.ndim == 1: t_vec = arr return TrajectoryDataset.from_arrays( X=X, dt=None if t_vec is not None else dt, t=None if t_vec is None else t_vec, mask=mask, extras_global=eg, extras_local=el, meta=meta_out, )
# ---------------- unified save/load with extras (csv/parquet/h5) ---------------- def _build_tabular_with_extras( *, particle_idx: Optional[np.ndarray], time_idx: np.ndarray, state_vectors: np.ndarray, extras_global: Optional[Dict[str, Any]] = None, extras_local: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None, prefix_G: str = "G_", prefix_TG: str = "TG_", prefix_P: str = "P_", prefix_TP: str = "TP_", ): """Create a pandas DataFrame with all columns (x*, IDs, prefixed extras) and a YAML-able metadata dict for header/schema. Returns ------- df : pandas.DataFrame yaml_meta : dict (already sanitized to plain Python types) base_cols : list[str] # column ordering hint """ import pandas as pd time_idx = np.asarray(time_idx, dtype=int) particle_idx = None if particle_idx is None else np.asarray(particle_idx, dtype=int) state_vectors = np.asarray(state_vectors) L, d = state_vectors.shape cols: Dict[str, Any] = {} if particle_idx is not None: cols["particle_id"] = particle_idx cols["time_step"] = time_idx for j in range(d): cols[f"x{j}"] = state_vectors[:, j] # Shapes T = int(time_idx.max()) + 1 if L else 0 N = (int(particle_idx.max()) + 1) if (L and particle_idx is not None) else (1 if L else 0) part = particle_idx if particle_idx is not None else np.zeros_like(time_idx, dtype=int) # Global extras: TimeSeriesExtra → TG_* columns; others → header header_globals: Dict[str, Any] = {} for key, val in (extras_global or {}).items(): if isinstance(val, TimeSeriesExtra): arr = np.asarray(val.data) if arr.shape[0] != T and L: raise ValueError( f"extras_global['{key}'] TimeSeriesExtra expects leading axis T={T}, got {arr.shape[0]}" ) vals = arr[time_idx] if L else arr.reshape(0, *arr.shape[1:]) flat = vals.reshape(L, -1) if L else vals.reshape(0, -1) for kcol in range(flat.shape[1] if L else (arr.reshape(-1).shape[0] or 1)): name = ( f"{prefix_TG}{key}" if (flat.shape[1] if L else arr.reshape(-1).shape[0]) == 1 else f"{prefix_TG}{key}_{kcol}" ) cols[name] = flat[:, kcol] if L else np.asarray([], dtype=float) else: if callable(val) or isinstance(val, FunctionExtra): warnings.warn( f"extras_global['{key}'] is a callable and cannot be serialized to disk. " "It will be omitted from the saved file. Re-attach it to the dataset after loading.", UserWarning, stacklevel=2, ) continue header_globals[key] = _sanitize_metadata(val) # Local extras: # - TimeSeriesExtra → TP_* columns # - array-like with first axis N → P_* columns # - otherwise → header for key, val in (extras_local or {}).items(): if isinstance(val, TimeSeriesExtra): arr = np.asarray(val.data) # expected (T, N, …) or (T, 1, …) for single-trajectory if arr.ndim < 2: raise ValueError( f"extras_local['{key}'] TimeSeriesExtra must have at least 2 dims (T,N,...), got {arr.shape}" ) if L and arr.shape[0] != T: raise ValueError(f"extras_local['{key}'] TimeSeriesExtra expects T={T}, got {arr.shape[0]}") pid = part if (particle_idx is None) and arr.shape[1] == 1: pid = np.zeros_like(time_idx) if L: vals = arr[time_idx, pid] flat = vals.reshape(L, -1) for kcol in range(flat.shape[1]): name = f"{prefix_TP}{key}" if flat.shape[1] == 1 else f"{prefix_TP}{key}_{kcol}" cols[name] = flat[:, kcol] else: # empty table, still create no data columns; metadata only header_globals[key] = _sanitize_metadata(arr) else: if callable(val) or isinstance(val, FunctionExtra): warnings.warn( f"extras_local['{key}'] is a callable and cannot be serialized to disk. " "It will be omitted from the saved file. Re-attach it to the dataset after loading.", UserWarning, stacklevel=2, ) continue arr = np.asarray(val) if arr.ndim >= 1 and (N == 0 or arr.shape[0] == N): # per-particle constants: (N, …) vals = arr[part if particle_idx is not None else np.zeros_like(time_idx)] flat = vals.reshape(L, -1) for kcol in range(flat.shape[1]): name = f"{prefix_P}{key}" if flat.shape[1] == 1 else f"{prefix_P}{key}_{kcol}" cols[name] = flat[:, kcol] else: header_globals[key] = _sanitize_metadata(arr) df = pd.DataFrame(cols) yaml_meta = dict(_sanitize_metadata(metadata or {})) if header_globals: yaml_meta.setdefault("extras_global", {}).update(header_globals) base_cols = ([] if particle_idx is None else ["particle_id"]) + ["time_step"] + [f"x{j}" for j in range(d)] return df, yaml_meta, base_cols def _parse_tabular_with_extras( df, yaml_meta: Dict[str, Any], *, particle_column: Optional[str], time_column: str, relabel: bool, prefix_G: str = "G_", prefix_TG: str = "TG_", prefix_P: str = "P_", prefix_TP: str = "TP_", ): """Inverse of _build_tabular_with_extras for a pandas DataFrame + header/metadata dict.""" import numpy as np from SFI.trajectory.dataset import time_series_extra colnames = list(df.columns) has_particles = particle_column is not None and (particle_column in df.columns) if has_particles: particle_indices = df[particle_column].to_numpy(dtype=int) else: particle_indices = np.zeros((len(df),), dtype=int) time_indices = df[time_column].to_numpy(dtype=int) # identify state columns: not ids and not extras prefixes def is_extra(name: str) -> bool: return any(name.startswith(px) for px in (prefix_G, prefix_TG, prefix_P, prefix_TP)) or name in { particle_column, time_column, } state_cols = [c for c in colnames if not is_extra(c)] state_vectors = df[state_cols].to_numpy(dtype=float) # relabel like CSV loader if relabel and len(state_vectors) > 0: if has_particles: _, inv = np.unique(particle_indices, return_inverse=True) particle_indices = inv.astype(int, copy=False) time_indices = time_indices - int(time_indices.min()) L = len(df) T = int(time_indices.max()) + 1 if L else 0 N = int(particle_indices.max()) + 1 if L else 0 # Start from header-provided extras (may contain globals) extras_global: Dict[str, Any] = dict(yaml_meta.get("extras_global", {}) or {}) extras_local: Dict[str, Any] = {} # Collect prefixed numeric columns def collect(prefix: str) -> Dict[str, np.ndarray]: return {name: df[name].to_numpy() for name in df.columns if name.startswith(prefix)} TG_cols = collect(prefix_TG) P_cols = collect(prefix_P) TP_cols = collect(prefix_TP) G_cols = collect(prefix_G) # TG_: time-dependent globals → TimeSeriesExtra(T, …) tg_grouped: Dict[str, List[Tuple[str, np.ndarray]]] = {} for name, vals in TG_cols.items(): base = name[len(prefix_TG) :].split("_")[0] tg_grouped.setdefault(base, []).append((name, vals)) for key, items in tg_grouped.items(): items_sorted = sorted(items, key=lambda kv: kv[0]) mat = np.column_stack([v for _, v in items_sorted]) # (L, k) tg_matrix = np.full((T, mat.shape[1]), np.nan, dtype=float) for t in range(T): sel = time_indices == t if np.any(sel): tg_matrix[t] = np.nanmean(mat[sel], axis=0) tg_matrix = tg_matrix.squeeze(-1) if tg_matrix.shape[1] == 1 else tg_matrix extras_global[key] = time_series_extra(tg_matrix) # TP_: time-dependent per-particle → TimeSeriesExtra(T, N, …) tp_grouped: Dict[str, List[Tuple[str, np.ndarray]]] = {} for name, vals in TP_cols.items(): base = name[len(prefix_TP) :].split("_")[0] tp_grouped.setdefault(base, []).append((name, vals)) for key, items in tp_grouped.items(): items_sorted = sorted(items, key=lambda kv: kv[0]) mat = np.column_stack([v for _, v in items_sorted]) # (L, k) out = np.full((T, N, mat.shape[1]), np.nan, dtype=float) for t in range(T): sel_t = time_indices == t if not np.any(sel_t): continue for n in range(N): sel = sel_t & ((particle_indices == n) if has_particles else (particle_indices == 0)) if np.any(sel): out[t, n] = np.nanmean(mat[sel], axis=0) out = out.squeeze(-1) if out.shape[2] == 1 else out extras_local[key] = time_series_extra(out) # P_: per-particle constants p_grouped: Dict[str, List[Tuple[str, np.ndarray]]] = {} for name, vals in P_cols.items(): base = name[len(prefix_P) :].split("_")[0] p_grouped.setdefault(base, []).append((name, vals)) for key, items in p_grouped.items(): items_sorted = sorted(items, key=lambda kv: kv[0]) mat = np.column_stack([v for _, v in items_sorted]) out = np.full((N, mat.shape[1]), np.nan, dtype=float) for n in range(N): sel = particle_indices == n if np.any(sel): out[n] = np.nanmean(mat[sel], axis=0) extras_local[key] = out.squeeze(-1) if out.shape[1] == 1 else out # G_: global scalars → average for name, vals in G_cols.items(): key = name[len(prefix_G) :] extras_global[key] = float(np.nanmean(vals)) return ( yaml_meta, colnames, particle_indices, time_indices, state_vectors, extras_global, extras_local, ) def _resolve_column_name(spec: Union[int, str], colnames: List[str], *, what: str) -> str: """Resolve an ``int`` index or ``str`` name to a column name. Raises a ``ValueError`` naming the available columns on a miss. """ if isinstance(spec, str): if spec not in colnames: raise ValueError(f"{what} column {spec!r} not found; available columns: {colnames}") return spec i = int(spec) if not (-len(colnames) <= i < len(colnames)): raise ValueError(f"{what} column index {i} out of range for {len(colnames)} columns: {colnames}") return colnames[i] def _subselect_state_columns( df, *, particle_name: Optional[str], time_name: str, state_names: List[str], prefixes: Tuple[str, ...], ): """Keep only id/time, the selected state columns (in order), and extras columns.""" keep = ([particle_name] if particle_name else []) + [time_name] + list(state_names) keep += [c for c in df.columns if any(c.startswith(px) for px in prefixes) and c not in keep] return df[keep] def _apply_named_knobs( df, *, fmt: str, particle_column: Optional[Union[int, str]], time_column: Union[int, str], state_columns: Optional[Sequence[Union[int, str]]], prefixes: Tuple[str, ...], ): """Resolve column knobs for the name-addressed formats (parquet / h5). ``str`` values are honored (and validated); ``int`` values keep the historical behavior — they cannot be distinguished from the defaults, so the canonical names ``particle_id`` / ``time_step`` are used. Returns ``(particle_name_or_None, time_name, df_possibly_subselected)``. """ colnames = list(df.columns) if isinstance(particle_column, str): p_name: Optional[str] = _resolve_column_name(particle_column, colnames, what="particle") elif particle_column is None: p_name = None else: p_name = "particle_id" if "particle_id" in colnames else None t_name = ( _resolve_column_name(time_column, colnames, what="time") if isinstance(time_column, str) else "time_step" ) if state_columns is not None: bad = [c for c in state_columns if not isinstance(c, str)] if bad: raise ValueError(f"state_columns must be column names (str) for {fmt} files; got {bad!r}") state_names = [_resolve_column_name(c, colnames, what="state") for c in state_columns] df = _subselect_state_columns( df, particle_name=p_name, time_name=t_name, state_names=state_names, prefixes=prefixes ) return p_name, t_name, df def _infer_format_from_suffix(path: str) -> str: lower = path.lower() if lower.endswith(".parquet") or lower.endswith(".pq"): return "parquet" if lower.endswith(".h5") or lower.endswith(".hdf5"): return "h5" return "csv"
[docs] def save_trajectory( filename: str, *, particle_idx: Optional[np.ndarray], time_idx: np.ndarray, state_vectors: np.ndarray, extras_global: Optional[Dict[str, Any]] = None, extras_local: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None, format: Optional[str] = None, # CSV-only knobs: float_fmt: str = "%.8f", # Parquet-only knobs: compression: str = "snappy", # Prefixes (shared semantics across formats): prefix_G: str = "G_", prefix_TG: str = "TG_", prefix_P: str = "P_", prefix_TP: str = "TP_", ) -> None: """Unified saver for {'csv','parquet','h5'} (inferred from filename if format=None).""" fmt = (format or _infer_format_from_suffix(filename)).lower() # Dispatcher-owned structural arrays (``_cache/`` keys) are derivable and must # never be serialized; strip them here so the saver upholds the invariant on # its own, mirroring ``simulate`` output and ``degrade`` (no normal path # persists them onto a dataset, so this is a defense-in-depth guard). from SFI.statefunc.nodes.interactions.prepare import purge_cache_extras extras_global = purge_cache_extras(extras_global) extras_local = purge_cache_extras(extras_local) df, yaml_meta, base_cols = _build_tabular_with_extras( particle_idx=particle_idx, time_idx=time_idx, state_vectors=state_vectors, extras_global=extras_global, extras_local=extras_local, metadata=metadata, prefix_G=prefix_G, prefix_TG=prefix_TG, prefix_P=prefix_P, prefix_TP=prefix_TP, ) if fmt == "csv": import yaml # YAML header as comment lines + CSV body yaml_str = yaml.safe_dump(yaml_meta, sort_keys=False) yaml_header = "# ---\n" + "\n".join(f"# {line}" for line in yaml_str.strip().splitlines()) ordered = base_cols + [c for c in df.columns if c not in base_cols] with open(filename, "w") as f: f.write(yaml_header + "\n") df.to_csv(f, index=False, columns=ordered, float_format=float_fmt) return if fmt == "parquet": try: import pyarrow as pa import pyarrow.parquet as pq import yaml except Exception as e: # pragma: no cover raise ImportError("Saving Parquet requires pyarrow and yaml.") from e table = pa.Table.from_pandas(df, preserve_index=False) md = dict(table.schema.metadata or {}) md[b"sfi_yaml_header"] = yaml.safe_dump(yaml_meta, sort_keys=False).encode("utf-8") table = table.replace_schema_metadata(md) pq.write_table(table, filename, compression=compression) return if fmt == "h5": try: import h5py import yaml except Exception as e: # pragma: no cover raise ImportError("Saving HDF5 requires h5py and yaml.") from e with h5py.File(filename, "w") as h5: grp = h5.create_group("table") for c in df.columns: data = np.asarray(df[c].to_numpy()) grp.create_dataset(c, data=data, compression="gzip", shuffle=True, fletcher32=True) # store YAML in root attr h5.attrs["sfi_yaml_header"] = yaml.safe_dump(yaml_meta, sort_keys=False) return raise ValueError("format must be one of {'csv','parquet','h5'}")
[docs] def load_trajectory( filename: str, *, format: Optional[str] = None, # Column-selection knobs (int index or str name): particle_column: Optional[Union[int, str]] = 0, # None => single-trajectory file time_column: Union[int, str] = 1, state_columns: Optional[Sequence[Union[int, str]]] = None, relabel: bool = True, # Prefixes: prefix_G: str = "G_", prefix_TG: str = "TG_", prefix_P: str = "P_", prefix_TP: str = "TP_", ): """Unified loader for {'csv','parquet','h5'} (inferred from filename if format=None). Parameters ---------- particle_column, time_column Which columns hold the particle ID and the time index. Accept a column *name* (``str``) for any format, or a positional *index* (``int``) for CSV files only. CSV defaults are positional (column 0 = particle, column 1 = time); parquet and HDF5 default to the canonical names ``"particle_id"`` and ``"time_step"`` and ignore ``int`` values (their defaults cannot be distinguished from "unspecified"). ``particle_column=None`` marks a single-trajectory file. state_columns Optional explicit selection of the state-vector columns (names, or indices for CSV), in order. When given, every other non-extras column is dropped. Default: every column that is not an ID and does not carry an extras prefix is a state component. Returns the standard tuple: (metadata, column_headers, particle_indices, time_indices, state_vectors, extras_global, extras_local) """ fmt = (format or _infer_format_from_suffix(filename)).lower() extra_prefixes = (prefix_G, prefix_TG, prefix_P, prefix_TP) if fmt == "csv": # Parse YAML header then read CSV import pandas as pd import yaml metadata: Dict[str, Any] = {} yaml_lines: List[str] = [] with open(filename, "r") as f: for line in f: if line.startswith("# "): yaml_lines.append(line[2:]) elif not line.startswith("#"): break if yaml_lines: try: metadata = yaml.safe_load("".join(yaml_lines)) or {} except Exception: metadata = {} df = pd.read_csv(filename, comment="#") colnames = list(df.columns) # Resolve int/str knobs to column names. particle_name = ( _resolve_column_name(particle_column, colnames, what="particle") if particle_column is not None else None ) time_name = _resolve_column_name(time_column, colnames, what="time") # Optional explicit state-column selection (drops everything else). if state_columns is not None: state_names = [_resolve_column_name(c, colnames, what="state") for c in state_columns] df = _subselect_state_columns( df, particle_name=particle_name, time_name=time_name, state_names=state_names, prefixes=extra_prefixes, ) # Canonical names for the uniform parser. rename: Dict[str, str] = {} if particle_name is not None and particle_name != "particle_id": if "particle_id" in df.columns: raise ValueError( f"particle column {particle_name!r} selected, but a distinct " "'particle_id' column also exists — drop or rename one." ) rename[particle_name] = "particle_id" if time_name != "time_step": if "time_step" in df.columns: raise ValueError( f"time column {time_name!r} selected, but a distinct " "'time_step' column also exists — drop or rename one." ) rename[time_name] = "time_step" if rename: # pandas-stubs rename overloads reject a plain dict mapper (false positive). df = df.rename(columns=rename) # type: ignore[call-overload] # parse return _parse_tabular_with_extras( df, metadata, particle_column="particle_id" if particle_name is not None else None, time_column="time_step", relabel=relabel, prefix_G=prefix_G, prefix_TG=prefix_TG, prefix_P=prefix_P, prefix_TP=prefix_TP, ) if fmt == "parquet": import pandas as pd import pyarrow.parquet as pq import yaml table = pq.read_table(filename) md = dict(table.schema.metadata or {}) yaml_bytes = md.get(b"sfi_yaml_header", None) metadata = yaml.safe_load(yaml_bytes.decode("utf-8")) if yaml_bytes else {} df = table.to_pandas(types_mapper=None) p_name, t_name, df = _apply_named_knobs( df, fmt=fmt, particle_column=particle_column, time_column=time_column, state_columns=state_columns, prefixes=extra_prefixes, ) return _parse_tabular_with_extras( df, metadata, particle_column=p_name, time_column=t_name, relabel=relabel, prefix_G=prefix_G, prefix_TG=prefix_TG, prefix_P=prefix_P, prefix_TP=prefix_TP, ) if fmt == "h5": import h5py import pandas as pd import yaml with h5py.File(filename, "r") as h5: metadata = {} if "sfi_yaml_header" in h5.attrs: try: # h5py stubs type attrs values as array-like; the header is str/bytes. metadata = yaml.safe_load(h5.attrs["sfi_yaml_header"]) or {} # type: ignore[arg-type] except Exception: metadata = {} if "table" not in h5: raise ValueError("HDF5 file missing 'table' group.") grp = h5["table"] if not isinstance(grp, h5py.Group): raise ValueError("HDF5 'table' entry is not a group.") # reconstruct DataFrame from datasets (the group holds only Datasets; # h5py stubs widen __getitem__ to include Datatype) cols = {name: grp[name][...] for name in grp.keys()} # type: ignore[index] df = pd.DataFrame(cols) p_name, t_name, df = _apply_named_knobs( df, fmt=fmt, particle_column=particle_column, time_column=time_column, state_columns=state_columns, prefixes=extra_prefixes, ) return _parse_tabular_with_extras( df, metadata, particle_column=p_name, time_column=t_name, relabel=relabel, prefix_G=prefix_G, prefix_TG=prefix_TG, prefix_P=prefix_P, prefix_TP=prefix_TP, ) raise ValueError("format must be one of {'csv','parquet','h5'}")