# Source code for SFI.trajectory.reserved_extras

"""Framework-owned extras and the single resolver that materialises them.

A force or diffusion expression reads two kinds of data through its ``extras``
mapping: **user** values attached to the trajectory (drive protocols,
per-particle properties, geometry) and **reserved** values supplied by the
framework. The reserved keys are defined once here, in a small registry, and
assembled together with the user values by :func:`resolve_extras` — the single
entry point used by simulation, inference, and diagnostics.

Reserved keys:

``time``
    Absolute time at each resolved frame — lets time-dependent bases (e.g.
    :func:`~SFI.bases.time_fourier`) read the clock.
``duration``
    Total trajectory span.
``dataset_index``
    Dense index of the dataset within its collection, for pooled
    multi-experiment models (:func:`~SFI.bases.per_dataset_scalar`,
    :func:`~SFI.bases.dataset_indicator`).
``particle_index``
    Per-particle integer ids, gathered per edge by interaction dispatchers.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Optional

import jax.numpy as jnp


@dataclass(frozen=True)
class ExtrasContext:
    """Everything a reserved-key resolver needs for a set of frames.

    Assembled by whoever drives the evaluation (the trajectory producer, the
    diagnostics residual builder, or the simulator) and passed to
    :func:`resolve_extras`. Instances are immutable (``frozen=True``).
    """

    # Number of particles; the ``particle_index`` resolver expands it into
    # ``arange(n_particles)``.
    n_particles: int
    # Index of the dataset within its collection; surfaced as the
    # ``dataset_index`` reserved key.
    dataset_index: int
    # Time at each resolved frame; surfaced unchanged as ``time``.
    frame_times: Any
    # Total trajectory span; surfaced unchanged as ``duration``.
    duration: Any
@dataclass(frozen=True)
class ReservedKey:
    """A framework-owned extras key and how it is materialised."""

    # Name under which the value appears in the resolved extras mapping.
    name: str
    # Callable that builds this key's value from an ``ExtrasContext``.
    resolve: Callable[[ExtrasContext], Any]
# Module-level registry of reserved keys, keyed by ``ReservedKey.name``.
# Populated through ``register``; iterated by ``resolve_reserved``.
_REGISTRY: Dict[str, ReservedKey] = {}
def register(key: ReservedKey) -> None:
    """Add a reserved key to the registry.

    The key is stored under ``key.name``; registering a name that is already
    present replaces the earlier entry.

    Note that ``RESERVED_NAMES`` is a snapshot taken when this module is
    executed, so it does not pick up keys registered afterwards.
    """
    _REGISTRY[key.name] = key
# Built-in reserved keys. Each resolver reads one field of the ExtrasContext.
register(ReservedKey("time", lambda ctx: ctx.frame_times))
register(ReservedKey("duration", lambda ctx: ctx.duration))
register(
    ReservedKey(
        "dataset_index",
        lambda ctx: jnp.asarray(ctx.dataset_index, dtype=jnp.int32),
    )
)
register(
    ReservedKey(
        "particle_index",
        lambda ctx: jnp.arange(int(ctx.n_particles), dtype=jnp.int32),
    )
)

#: The set of reserved key names; user extras may not use these.
RESERVED_NAMES = frozenset(_REGISTRY)
def is_reserved(name: str) -> bool:
    """True when ``name`` is a framework-owned reserved key.

    The lookup goes to the live registry, so keys added through
    :func:`register` are recognised as well.
    """
    return name in _REGISTRY
def resolve_reserved(ctx: ExtrasContext) -> Dict[str, Any]:
    """Materialise every reserved key for ``ctx``.

    Returns a new dict mapping each registered name to the value produced by
    that key's ``resolve`` callable when called with ``ctx``.
    """
    return {name: key.resolve(ctx) for name, key in _REGISTRY.items()}
def resolve_extras(user_extras: Mapping[str, Any], ctx: ExtrasContext) -> Dict[str, Any]:
    """Full per-frame extras: user values plus the resolved reserved keys.

    Reserved names are framework-owned; a user entry colliding with one is
    rejected so the meaning of a reserved key is never ambiguous.

    Returns a new dict; ``user_extras`` itself is not modified.

    Raises ``ValueError`` when any key of ``user_extras`` is a reserved name.
    """
    out = dict(user_extras)
    # Compare against the live registry rather than the ``RESERVED_NAMES``
    # snapshot: ``resolve_reserved`` below writes every registered key, so a
    # key added later through ``register`` would otherwise silently overwrite
    # a user entry of the same name instead of being reported.
    clash = _REGISTRY.keys() & out.keys()
    if clash:
        raise ValueError(f"extras keys {sorted(clash)} are reserved; rename the user entries.")
    out.update(resolve_reserved(ctx))
    return out
def slice_frame_extras(
    extras_global: Optional[Mapping[str, Any]],
    extras_local: Optional[Mapping[str, Any]],
    *,
    frame_idx: Any,
    context: Optional[str] = None,
) -> Dict[str, Any]:
    """Materialise user extras at ``frame_idx``.

    * :class:`~SFI.trajectory.dataset.TimeSeriesExtra` → sliced
      ``value.data[frame_idx]``;
    * :class:`~SFI.trajectory.dataset.FunctionExtra` → its callable, forwarded;
    * plain callable → invoked as ``value(frame_idx, context=context)``;
    * anything else → forwarded unchanged.

    ``extras_local`` overrides ``extras_global`` on key conflicts. A global
    entry that is overridden is dropped before materialisation, so it is
    neither sliced nor invoked.

    Either mapping may be ``None`` (or empty), in which case it contributes
    nothing. Returns a new dict.
    """
    # Imported inside the function rather than at module level
    # (NOTE(review): presumably to avoid an import cycle with
    # ``SFI.trajectory.dataset`` — confirm).
    from SFI.trajectory.dataset import FunctionExtra, TimeSeriesExtra

    def _materialise(value: Any) -> Any:
        # FunctionExtra is tested before the generic ``callable`` branch so
        # its wrapped function is forwarded rather than called.
        if isinstance(value, FunctionExtra):
            return value.func
        if isinstance(value, TimeSeriesExtra):
            return jnp.asarray(value.data)[frame_idx]
        if callable(value):
            return value(frame_idx, context=context)
        return value

    # Resolve overrides first, then materialise each surviving entry once.
    # Key order matches a global-then-local pass: global keys first, followed
    # by keys that appear only in ``extras_local``.
    merged: Dict[str, Any] = dict(extras_global or {})
    merged.update(extras_local or {})
    return {key: _materialise(value) for key, value in merged.items()}