from typing import Sequence
import jax.numpy as jnp
from SFI.statefunc import Basis, Rank
from SFI.statefunc.nodes import SimpleLeaf
[docs]
def ones_basis(dim: int, pdepth: int = 0) -> Basis:
    """Single-feature scalar basis whose value is identically one.

    Parameters
    ----------
    dim : int
        Spatial dimensionality declared on the leaf.
    pdepth : int
        Forwarded unchanged to :class:`SimpleLeaf`.

    Returns
    -------
    Basis
        One feature labelled ``"1"``. Evaluated on ``x``, it yields ones of
        shape ``(*x.shape[:-1], 1)``: every leading axis of ``x`` is kept,
        and the last axis is replaced by the length-1 feature axis.
    """

    def _ones(x):
        batch_shape = x.shape[:-1]
        return jnp.ones(batch_shape + (1,))

    leaf = SimpleLeaf(
        func=_ones,
        n_features=1,
        labels=("1",),
        descriptor="scalar-one",
        dim=dim,
        rank=Rank.SCALAR,
        pdepth=pdepth,
    )
    return Basis(leaf)
[docs]
def constant_array(A, *, label: str = "const", descriptor: str = "constant-array", as_sf=True):
    """Constant single-feature basis (or bound PSF) equal to a fixed tensor.

    The feature's value is ``A`` everywhere. ``A`` must have shape
    ``(dim,)*rank``; ``rank`` is inferred from ``A.ndim`` and ``dim`` from
    ``A.shape[0]``. Evaluated on ``x``, the output is ``A`` broadcast over
    all leading axes of ``x`` (everything but the last), followed by a
    trailing length-1 feature axis: shape ``(*x.shape[:-1], *A.shape, 1)``.

    Parameters
    ----------
    A : array_like
        The constant tensor. It must have at least one axis, and all of its
        axes must have the same length.
    label : str
        Label of the single feature.
    descriptor : str
        Descriptor stored on the leaf.
    as_sf : bool
        If true (default), return the basis converted with ``to_psf()`` and
        bound with ``params={"coeff": 1.0}``. Otherwise return the bare
        :class:`Basis`.

    Raises
    ------
    ValueError
        If ``A`` is 0-d (use :func:`ones_basis` for scalars), or if its axes
        do not all share the same length.
    """
    arr = jnp.asarray(A)
    shape = arr.shape

    if not shape:
        raise ValueError("Use ones_basis()/scalar constants for rank=SCALAR.")
    if any(extent != shape[0] for extent in shape):
        raise ValueError(f"A must have shape (dim,)*rank (hypercube). Got {shape}.")

    dim = int(shape[0])
    rank = int(arr.ndim)
    # Append the single feature axis once; only the leading axes vary per call.
    with_feature_axis = arr[..., None]

    def _const(x):
        # Broadcast the constant over every leading axis of x (particles, time, ...).
        return jnp.broadcast_to(with_feature_axis, (*x.shape[:-1], *shape, 1))

    basis = Basis(
        SimpleLeaf(
            func=_const,
            n_features=1,
            labels=(label,),
            descriptor=descriptor,
            dim=dim,
            rank=rank,
            pdepth=0,
        )
    )
    if not as_sf:
        return basis
    return basis.to_psf().bind(params={"coeff": 1.0})
[docs]
def unit_vector_basis(dim: int, axes: Sequence[int] | None = None) -> Basis:
    """Constant rank-VECTOR basis made of Cartesian unit vectors.

    Parameters
    ----------
    dim : int
        Spatial dimensionality.
    axes : sequence of int or None
        Axes whose unit vectors become features, in the given order
        (duplicates allowed). ``None`` (default) means all axes
        ``0 .. dim-1``. Negative axes count from the end, as in NumPy
        indexing.

    Returns
    -------
    Basis
        One feature ``"e{a}"`` per entry ``a`` of ``axes``. The value does
        not depend on ``x``; it is a ``(dim, len(axes))`` array whose column
        ``i`` is the unit vector along ``axes[i]``.

    Raises
    ------
    ValueError
        If any axis lies outside ``[-dim, dim)``.
    """
    # Freeze the axes so that later mutation by the caller cannot change the basis.
    axes = tuple(range(dim)) if axes is None else tuple(axes)
    # Validate explicitly: JAX silently drops out-of-bounds scatter updates,
    # so a bad axis would otherwise become an all-zero feature with a
    # misleading "e{a}" label.
    out_of_range = [a for a in axes if not -dim <= a < dim]
    if out_of_range:
        raise ValueError(
            f"unit_vector_basis: axes {out_of_range} out of range for dim={dim}."
        )

    labels = tuple(f"e{a}" for a in axes)
    # Column i is e_{axes[i]}. The value is independent of x, so build it once.
    template = jnp.eye(dim)[:, jnp.asarray(axes, dtype=jnp.int32)]

    def _f(x):
        return template

    return Basis(
        SimpleLeaf(
            func=_f,
            n_features=len(axes),
            labels=labels,
            descriptor="unit-vector-set",
            dim=dim,
            rank=Rank.VECTOR,
            pdepth=0,
            particles_input=False,
            needs_v=False,
        )
    )
[docs]
def identity_matrix_basis(dim: int, pdepth: int = 0) -> Basis:
    """Single-feature rank-MATRIX basis equal to the ``dim × dim`` identity.

    Parameters
    ----------
    dim : int
        Spatial dimensionality, i.e. the size of the identity matrix.
    pdepth : int
        Forwarded unchanged to :class:`SimpleLeaf`.

    Returns
    -------
    Basis
        One feature labelled ``"I"``. Evaluated on ``x``, it yields the
        identity broadcast over the leading axes of ``x``, followed by a
        trailing length-1 feature axis: shape
        ``(*x.shape[:-1], dim, dim, 1)``. This is the same features-last
        layout that :func:`ones_basis` and :func:`constant_array` use for
        their single feature.
    """
    # (dim, dim, 1). The trailing axis is the feature axis required by
    # n_features=1; without it, the last matrix axis would be read as features.
    eye = jnp.eye(dim)[..., None]

    def _f(x):
        return jnp.broadcast_to(eye, (*x.shape[:-1], dim, dim, 1))

    return Basis(
        SimpleLeaf(
            func=_f,
            n_features=1,
            labels=("I",),
            descriptor="identity",
            dim=dim,
            rank=Rank.MATRIX,
            pdepth=pdepth,
        )
    )
[docs]
def symmetric_matrix_basis(dim: int, pdepth: int = 0) -> Basis:
    """Constant templates spanning the real symmetric ``dim × dim`` matrices.

    For ``dim = d`` there are ``F = d(d+1)/2`` features, one per
    upper-triangle entry ``(i, j)`` with ``i <= j``. They are ordered row by
    row (``(0,0), (0,1), ..., (1,1), ...``) and labelled ``"S{i}{j}"``.
    Template ``S_{(i,j)}`` is the ``(dim, dim)`` matrix with entries
    ``(S_{(i,j)})_{ab} = 1`` if ``{a, b} = {i, j}`` and ``0`` otherwise.
    That is ``δ_{ia}δ_{jb} + δ_{ib}δ_{ja}`` for ``i != j``, and
    ``δ_{ia}δ_{ib}`` for ``i == j``. An off-diagonal template therefore has
    a 1 in both symmetric slots ``(i, j)`` and ``(j, i)``, and a diagonal
    template has a single 1 at ``(i, i)``.

    Rank is ``MATRIX`` (2). The value does not depend on ``x``; it is a
    float32 array of shape ``(dim, dim, F)``.

    Parameters
    ----------
    dim : int
        Spatial dimensionality.
    pdepth : int
        Forwarded unchanged to :class:`SimpleLeaf`.
    """
    pairs = [(i, j) for i in range(dim) for j in range(i, dim)]
    labels = tuple(f"S{i}{j}" for i, j in pairs)

    # Build the (dim, dim, F) template once, when the basis is constructed;
    # _f only returns it. NumPy allows cheap in-place assignment here and is
    # imported lazily because it is needed only at build time, not during
    # JAX evaluation.
    import numpy as _np

    tpl = _np.zeros((dim, dim, len(pairs)), dtype="float32")
    for k, (i, j) in enumerate(pairs):
        # Assigning (not adding) keeps diagonal templates at 1, not 2.
        tpl[i, j, k] = tpl[j, i, k] = 1.0
    tpl_jnp = jnp.array(tpl)

    def _f(x):
        return tpl_jnp

    return Basis(
        SimpleLeaf(
            func=_f,
            n_features=len(pairs),
            labels=labels,
            descriptor="symmetric-matrix-templates",
            dim=dim,
            rank=Rank.MATRIX,
            pdepth=pdepth,
            particles_input=False,
            needs_v=False,
        )
    )
# ---------------------------------------------------------------------------
# Named scalar parameters (rank-0 PSFs whose value is a single named param)
# ---------------------------------------------------------------------------
[docs]
def named_scalar(name: str, default=None, *, dim: int | None = None, label: str | None = None):
    """Rank-0, 1-feature PSF whose value is one named scalar parameter.

    The PSF declares exactly one :class:`ParamSpec`: named ``name``, with
    shape ``()`` and the given ``default``. Its feature label is ``label``
    when that is non-empty, otherwise ``name``.

    Parameters
    ----------
    name : str
        Parameter name; also the fallback feature label.
    default : scalar or None
        Optional default value. When set, the PSF can be evaluated, bound,
        or passed to a simulator without explicit ``params``.
    dim : int or None
        Spatial dimensionality. ``None`` (default) lets it be inferred at
        first call (the value is independent of ``x``).
    label : str or None
        Optional human-readable feature label (defaults to ``name``).

    Examples
    --------
    >>> sigma = named_scalar("sigma", default=20.0)
    >>> sigma()  # uses default
    Array(20., dtype=float32)
    >>> sigma(params={"sigma": 30.})  # explicit override
    Array(30., dtype=float32)
    """
    from ..statefunc import make_psf
    from ..statefunc.params import ParamSpec

    feature_label = label if label else name
    param_spec = ParamSpec(name, (), default=default)

    def _f(x, *, params):
        # The value ignores x entirely; it is just the named parameter.
        return params[name]

    return make_psf(
        _f,
        dim=dim,
        rank=0,
        n_features=1,
        labels=[feature_label],
        descriptor=f"named-scalar({name})",
        params=[param_spec],
    )
[docs]
def time_fourier(
n_modes: int,
period: float | None = None,
*,
dim: int | None = None,
label: str | None = None,
):
r"""Rank-0 time-Fourier dictionary read from the reserved ``time`` extra.
Emits ``1 + 2 * n_modes`` parameter-free features
.. math::
\bigl\{\,1,\; \cos(k\omega t),\; \sin(k\omega t)\,\bigr\}_{k=1}^{n_\text{modes}},
\qquad \omega = 2\pi / P,
evaluated at each frame's absolute time ``t`` — the auto-injected
``time`` extra (see :meth:`TrajectoryDataset.build_extras`), so no
bookkeeping is required. Tensor it with a spatial basis to learn an
*unknown* time-dependent force field by expansion: ``time_fourier(4) *
X(dim=1)`` recovers a time-varying stiffness :math:`k(t)`, and
``time_fourier(4) * unit_vector_basis(1)`` a moving trap centre.
Sparse selection (:term:`PASTIS`) keeps only the harmonics the data
support.
Parameters
----------
n_modes : int
Number of harmonics; produces ``1 + 2 * n_modes`` features.
period : float or None
Fundamental period :math:`P`. If ``None`` (default), it defaults
to the **full trajectory duration** (read from the auto-injected
``duration`` extra), i.e. the fundamental frequency is the inverse
of the total observation time.
dim : int or None
Spatial dimensionality; ``None`` lets it be inferred (the value is
independent of ``x``).
label : str or None
Optional label prefix for the features.
Examples
--------
>>> from SFI.bases import X, time_fourier
>>> B = time_fourier(4) * X(dim=1) # learn k(t) over the trajectory
>>> inf.infer_force_linear(B)
"""
from ..statefunc import make_basis
if n_modes < 1:
raise ValueError("time_fourier requires n_modes >= 1")
keys = ("time",) if period is not None else ("time", "duration")
n_feat = 1 + 2 * n_modes
pre = (label + " ") if label else ""
labels = [f"{pre}1"]
for k in range(1, n_modes + 1):
labels += [f"{pre}cos({k}wt)", f"{pre}sin({k}wt)"]
ks = jnp.arange(1, n_modes + 1, dtype=float)
def _f(x, *, extras):
t = jnp.asarray(extras["time"], dtype=float)
P = period if period is not None else jnp.asarray(extras["duration"], dtype=float)
ang = ks * (2.0 * jnp.pi / P) * t # (n_modes,)
cs = jnp.stack([jnp.cos(ang), jnp.sin(ang)], axis=-1).reshape(-1) # (2 n_modes,)
return jnp.concatenate([jnp.ones((1,), dtype=cs.dtype), cs]) # (1 + 2 n_modes,)
return make_basis(
_f,
dim=dim,
rank=0,
n_features=n_feat,
labels=labels,
descriptor=f"time-fourier(n={n_modes})",
extras_keys=keys,
)
[docs]
def per_dataset_scalar(name: str, n_datasets: int, default=None, *, dim: int | None = None):
"""Rank-0, 1-feature PSF whose value is dataset-specific.
Carries a parameter array of shape ``(n_datasets,)`` and reads the
entry of the current dataset through the reserved
``extras["dataset_index"]`` (injected automatically by
:class:`~SFI.trajectory.TrajectoryCollection`). Compose it with
shared :func:`named_scalar` terms to fit pooled multi-experiment
models where part of the parameters is experiment-specific and the
rest is shared.
Parameters
----------
name : str
Parameter name (one entry per dataset).
n_datasets : int
Number of datasets in the collection the model will be fit on.
default : scalar or array of shape (n_datasets,), optional
Optional default value(s); a scalar is broadcast.
dim : int or None
Spatial dimensionality (``None`` → inferred).
Notes
-----
Indexed parameter access is nonlinear in the bookkeeping sense, so
the parametric estimators fit models containing this primitive on
the L-BFGS path. For the **linear estimators**, use the one-hot
route instead: :func:`dataset_indicator`.
"""
from ..statefunc import make_psf
from ..statefunc.params import ParamSpec
if default is not None:
default = jnp.broadcast_to(jnp.asarray(default, dtype=float), (n_datasets,))
spec = ParamSpec(name, (int(n_datasets),), default=default)
def _f(x, *, params, extras):
return params[name][extras["dataset_index"]]
def _specialize_at(k, _name=name, _dim=dim, _default=default):
# Condition-k slice: the (n_datasets,) param collapses to a scalar and
# the dataset_index read disappears (see StateExpr.specialize).
d_k = None if _default is None else jnp.asarray(_default)[int(k)]
spec_k = ParamSpec(_name, (), default=d_k)
def _fk(x, *, params):
return params[_name]
return make_psf(
_fk,
dim=_dim,
rank=0,
n_features=1,
labels=[_name],
descriptor=f"per-dataset-scalar({_name})@{int(k)}",
params=[spec_k],
)
return make_psf(
_f,
dim=dim,
rank=0,
n_features=1,
labels=[name],
descriptor=f"per-dataset-scalar({name})",
params=[spec],
extras_keys=("dataset_index",),
specialize_at=_specialize_at,
)
[docs]
def dataset_indicator(n_datasets: int, *, dim: int | None = None):
"""Rank-0 Basis of ``n_datasets`` one-hot features ``1{dataset == d}``.
The **linear-estimator** route to per-dataset coefficients: multiply
a feature by the indicator and concatenate, and each dataset gets an
independent linear coefficient for that feature (the Gram is
block-diagonal across datasets), PASTIS-prunable like any feature.
.. code-block:: python
B = B_shared & (dataset_indicator(n) * X(dim))
Reads the reserved ``extras["dataset_index"]`` injected by
:class:`~SFI.trajectory.TrajectoryCollection`.
"""
from ..statefunc import make_basis
n = int(n_datasets)
def _f(x, *, extras):
return (jnp.arange(n) == extras["dataset_index"]).astype(float)
def _specialize_at(k, _n=n, _dim=dim):
# Condition-k one-hot becomes a constant vector; no dataset_index read.
onehot = (jnp.arange(_n) == int(k)).astype(float)
def _fk(x):
return onehot
return make_basis(
_fk,
dim=_dim,
rank=0,
n_features=_n,
labels=[f"ds{i}" for i in range(_n)],
descriptor=f"dataset-indicator({_n})@{int(k)}",
)
return make_basis(
_f,
dim=dim,
rank=0,
n_features=n,
labels=[f"ds{i}" for i in range(n)],
descriptor=f"dataset-indicator({n})",
extras_keys=("dataset_index",),
specialize_at=_specialize_at,
)
# Per-particle inferred parameters (the ``params_local`` counterpart of
# ``extras_local``) are expressed inside interactor kernels. Declare the
# reserved ``particle_index`` extra as particle-aligned, and index a
# (P,)-shaped parameter with it, e.g.
#
# def local(Xk, *, params, extras):
# mob = params["mob"][extras["particle_index"]] # (K,) per edge
# ...
# inter = make_interactor(local, ..., params={"mob": (P,)},
# extras_keys=("particle_index",),
# particle_extras=("particle_index",))
#
# See :func:`SFI.statefunc.make_interactor`.
[docs]
def named_scalars(*args, **kwargs):
    """Build several :func:`named_scalar` PSFs at once, one per name.

    Exactly one of two call styles is accepted; both have deterministic
    ordering.

    Positional names (no defaults)::

        sigma, rho, beta = named_scalars("sigma", "rho", "beta")

    Keyword ``name=default`` pairs (Python preserves call-site order)::

        sigma, rho, beta = named_scalars(sigma=20.0, rho=8.0, beta=2.0)

    Returns
    -------
    tuple[PSF, ...]
        One :class:`PSF` per name, in the order given.

    Raises
    ------
    TypeError
        If both styles are mixed, no name is given, or a positional
        argument is not a ``str``.
    """
    if kwargs:
        if args:
            raise TypeError("named_scalars: pass either positional names OR keyword name=default, not both.")
        return tuple(named_scalar(key, default=value) for key, value in kwargs.items())

    if not args:
        raise TypeError("named_scalars: at least one name is required.")
    if any(not isinstance(arg, str) for arg in args):
        raise TypeError("named_scalars: positional arguments must be parameter names (str).")
    return tuple(map(named_scalar, args))