Source code for SFI.bases.linear

# SFI.bases.linear
# =================
# Lightweight linear utilities for building first-order bases in x and v.
#
# Contract assumed here:
# - User functions are called on a SINGLE SAMPLE (no leading batch axes).
# - For rank=VECTOR leaves, x and v arrive as shape (dim,).
# - For rank=SCALAR leaves, the function may return () or (k,) where k=n_features.
# - Feature axis must be present in the final leaf output; for n_features=1 on
#   a scalar leaf, BasisLeaf will auto-insert it from a scalar return.
#
# pdepth=0 (non-interacting) everywhere in this helper module.

from __future__ import annotations

from typing import Optional, Sequence

import jax.numpy as jnp

from ..statefunc import Basis, Rank, make_basis
from .monomials import monomials_degree  # degree-specific builder

# Public names exported by a star-import of this module.
# NOTE(review): ``field_component`` is defined in this module but is not listed
# here — confirm whether the omission is intentional.
__all__ = [
    "linear_basis",
    "X",
    "V",
    "x_coordinate",
    "x_coordinates",
    "v_coordinate",
    "v_coordinates",
    "x_components",
    "v_components",
    "unit_axes",
    "frame",
]


def linear_basis(dim: int, *, include_x: bool = True, include_v: bool = False):
    """Build the degree-1 monomial basis in (x, v).

    Thin wrapper: forwards every option to ``monomials_degree`` with the
    degree fixed to 1.

    Parameters
    ----------
    dim : int
        Spatial dimension.
    include_x : bool
        Include linear x terms.
    include_v : bool
        Include linear v terms.

    Returns
    -------
    Basis
        Rank-1 (vector) basis concatenating the requested degree-1 monomials.
    """
    requested = {"include_x": include_x, "include_v": include_v}
    return monomials_degree(1, dim=dim, **requested)
def X(dim: int, *, label: Optional[str] = None) -> Basis:
    """Vector basis returning x itself, with an explicit feature axis.

    Input : x ∈ R^dim
    Output: Y ∈ R^{dim×1}

    Parameters
    ----------
    dim : int
        Spatial dimension, forwarded to ``make_basis``.
    label : str, optional
        Label of the single feature; defaults to ``"x"``.
    """
    if label is None:
        label = "x"

    def _eval(x):
        # Append a length-1 feature axis: (dim,) -> (dim, 1).
        return x[:, None]

    return make_basis(
        func=_eval,
        dim=dim,
        rank=Rank.VECTOR,
        n_features=1,
        needs_v=False,
        labels=[label],
    )
def V(dim: int, *, label: Optional[str] = None) -> Basis:
    """Vector basis returning v itself, with an explicit feature axis.

    Input : v ∈ R^dim (provided via keyword v=...)
    Output: Y ∈ R^{dim×1}

    Parameters
    ----------
    dim : int
        Spatial dimension, forwarded to ``make_basis``.
    label : str, optional
        Label of the single feature; defaults to ``"v"``.
    """
    if label is None:
        label = "v"

    def _eval(x, v):
        # x is accepted but unused; append the feature axis: (dim,) -> (dim, 1).
        return v[:, None]

    return make_basis(
        func=_eval,
        dim=dim,
        rank=Rank.VECTOR,
        n_features=1,
        needs_v=True,
        labels=[label],
    )
def x_coordinate(index: int, *, dim: int, label: Optional[str] = None) -> Basis:
    """Scalar basis exposing one coordinate of x.

    Input : x ∈ R^dim
    Return: scalar (); BasisLeaf will auto-insert feature axis → (1,)

    Parameters
    ----------
    index : int
        Position of the coordinate to extract from x.
    dim : int
        Spatial dimension, forwarded to ``make_basis``.
    label : str, optional
        Label of the single feature; defaults to ``"x{index}"``.
    """
    if label is None:
        label = f"x{index}"

    def _eval(x):
        return x[index]  # scalar, shape ()

    return make_basis(
        func=_eval,
        dim=dim,
        rank=Rank.SCALAR,
        n_features=1,
        needs_v=False,
        labels=[label],
    )
def field_component(index: int, *, n_fields: int, label: Optional[str] = None) -> Basis:
    """Extract a single field component from an SPDE state vector.

    Alias for :func:`x_coordinate` with SPDE-oriented naming: ``n_fields`` is
    passed on as ``dim`` and only the default label differs.

    Parameters
    ----------
    index : int
        Zero-based index of the field component to extract.
    n_fields : int
        Total number of field components per grid site (= ``dim``).
    label : str, optional
        Human-readable label; defaults to ``"field[{index}]"``.
    """
    resolved_label = f"field[{index}]" if label is None else label
    return x_coordinate(index, dim=n_fields, label=resolved_label)
def x_coordinates(indices: Sequence[int], *, dim: int, labels: Optional[Sequence[str]] = None) -> Basis:
    """Select several x-coordinates as scalar features.

    Input : x ∈ R^dim
    Output: y ∈ R^{k} with k=len(indices)

    Parameters
    ----------
    indices : Sequence[int]
        Positions of the coordinates to extract, in output order.
    dim : int
        Spatial dimension, forwarded to ``make_basis``.
    labels : Sequence[str], optional
        One label per index; defaults to ``"x{i}"`` for each index ``i``.

    Returns
    -------
    Basis
        Scalar-rank basis with ``len(indices)`` features.

    Raises
    ------
    ValueError
        If ``labels`` is given and its length differs from ``len(indices)``.
    TypeError
        If an entry of ``indices`` cannot be interpreted as an integer.
    """
    import operator  # function-scope import keeps this helper self-contained

    # Normalise on the host first: default labels and the feature count are
    # then derived from plain Python ints instead of iterating a JAX array,
    # and non-integer entries are rejected here rather than at evaluation.
    index_list = [operator.index(i) for i in indices]
    n_selected = len(index_list)

    if labels is None:
        labels = [f"x{i}" for i in index_list]
    elif len(labels) != n_selected:
        raise ValueError("x_coordinates: labels length must match number of indices")

    # Explicit integer dtype: an empty list would otherwise be inferred as a
    # floating-point array, which is not a valid index.
    gather = jnp.array(index_list, dtype=int)

    def _eval(x):
        return x[gather]  # (k,)

    return make_basis(
        func=_eval,
        dim=dim,
        rank=Rank.SCALAR,
        n_features=n_selected,
        needs_v=False,
        labels=list(labels),
    )
def v_coordinate(index: int, *, dim: int, label: Optional[str] = None) -> Basis:
    """Scalar basis exposing one coordinate of v.

    Input : v ∈ R^dim (provided via keyword v=...)
    Return: scalar (); BasisLeaf will auto-insert feature axis → (1,)

    Parameters
    ----------
    index : int
        Position of the coordinate to extract from v.
    dim : int
        Spatial dimension, forwarded to ``make_basis``.
    label : str, optional
        Label of the single feature; defaults to ``"v{index}"``.
    """
    if label is None:
        label = f"v{index}"

    def _eval(x, v):
        # x is accepted but unused; only v is read.
        return v[index]  # scalar, shape ()

    return make_basis(
        func=_eval,
        dim=dim,
        rank=Rank.SCALAR,
        n_features=1,
        needs_v=True,
        labels=[label],
    )
def v_coordinates(indices: Sequence[int], *, dim: int, labels: Optional[Sequence[str]] = None) -> Basis:
    """Select several v-coordinates as scalar features.

    Input : v ∈ R^dim (provided via keyword v=...)
    Output: y ∈ R^{k} with k=len(indices)

    Parameters
    ----------
    indices : Sequence[int]
        Positions of the coordinates to extract, in output order.
    dim : int
        Spatial dimension, forwarded to ``make_basis``.
    labels : Sequence[str], optional
        One label per index; defaults to ``"v{i}"`` for each index ``i``.

    Returns
    -------
    Basis
        Scalar-rank basis with ``len(indices)`` features that requires ``v``.

    Raises
    ------
    ValueError
        If ``labels`` is given and its length differs from ``len(indices)``.
    TypeError
        If an entry of ``indices`` cannot be interpreted as an integer.
    """
    import operator  # function-scope import keeps this helper self-contained

    # Normalise on the host first: default labels and the feature count are
    # then derived from plain Python ints instead of iterating a JAX array,
    # and non-integer entries are rejected here rather than at evaluation.
    index_list = [operator.index(i) for i in indices]
    n_selected = len(index_list)

    if labels is None:
        labels = [f"v{i}" for i in index_list]
    elif len(labels) != n_selected:
        raise ValueError("v_coordinates: labels length must match number of indices")

    # Explicit integer dtype: an empty list would otherwise be inferred as a
    # floating-point array, which is not a valid index.
    gather = jnp.array(index_list, dtype=int)

    def _eval(x, v):
        # x is accepted but unused; only v is read.
        return v[gather]  # (k,)

    return make_basis(
        func=_eval,
        dim=dim,
        rank=Rank.SCALAR,
        n_features=n_selected,
        needs_v=True,
        labels=list(labels),
    )
# ---------------------------------------------------------------------------
# Component / axis unpackers
# ---------------------------------------------------------------------------

# Friendly per-axis names, used whenever dim does not exceed their length.
_DEFAULT_X_LABELS = ("x", "y", "z", "w")
_DEFAULT_V_LABELS = ("vx", "vy", "vz", "vw")
_DEFAULT_E_LABELS = ("ex", "ey", "ez", "ew")


def _auto_labels(dim: int, defaults: Sequence[str], prefix: str) -> list[str]:
    """Return ``dim`` default labels.

    The first ``dim`` entries of ``defaults`` are used when there are enough
    of them; otherwise every label is numbered as ``f"{prefix}{i}"``.
    """
    if dim > len(defaults):
        return [f"{prefix}{axis}" for axis in range(dim)]
    # zip stops at range(dim), so only the leading ``dim`` names are taken.
    return [name for _, name in zip(range(dim), defaults)]
def x_components(dim: int, *, labels: Optional[Sequence[str]] = None) -> tuple[Basis, ...]:
    """Unpack scalar x-coordinate bases, one per axis.

    >>> x, y, z = x_components(3)

    Each returned basis is rank-0 with one feature. Labels default to
    ``("x", "y", "z", "w")`` for ``dim <= 4`` and ``("x0", "x1", ...)``
    otherwise.

    Raises
    ------
    ValueError
        If ``labels`` is given and its length differs from ``dim``.
    """
    if labels is not None and len(labels) != dim:
        raise ValueError(f"x_components: labels length ({len(labels)}) must equal dim ({dim})")

    names = _auto_labels(dim, _DEFAULT_X_LABELS, "x") if labels is None else labels
    return tuple(
        x_coordinate(axis, dim=dim, label=name)
        for axis, name in zip(range(dim), names)
    )
def v_components(dim: int, *, labels: Optional[Sequence[str]] = None) -> tuple[Basis, ...]:
    """Unpack scalar v-coordinate bases, one per axis.

    >>> vx, vy, vz = v_components(3)

    Labels default to the leading entries of ``("vx", "vy", "vz", "vw")`` when
    ``dim <= 4`` and to ``("v0", "v1", ...)`` otherwise.

    Raises
    ------
    ValueError
        If ``labels`` is given and its length differs from ``dim``.
    """
    if labels is not None and len(labels) != dim:
        raise ValueError(f"v_components: labels length ({len(labels)}) must equal dim ({dim})")

    names = _auto_labels(dim, _DEFAULT_V_LABELS, "v") if labels is None else labels
    return tuple(
        v_coordinate(axis, dim=dim, label=name)
        for axis, name in zip(range(dim), names)
    )
def unit_axes(dim: int, *, labels: Optional[Sequence[str]] = None) -> tuple[Basis, ...]:
    """Unpack unit-vector bases (one per spatial axis).

    >>> ex, ey, ez = unit_axes(3)

    Each returned basis is rank-1 with a single feature carrying the unit
    vector along that axis. Labels default to ``("ex", "ey", "ez", "ew")``
    for ``dim <= 4`` and ``("e0", "e1", ...)`` otherwise.

    Raises
    ------
    ValueError
        If ``labels`` is given and its length differs from ``dim``.
    """
    # Imported here rather than at module level to avoid a circular import.
    from .constants import unit_vector_basis

    if labels is not None:
        if len(labels) != dim:
            raise ValueError(f"unit_axes: labels length ({len(labels)}) must equal dim ({dim})")
        names = labels
    else:
        names = _auto_labels(dim, _DEFAULT_E_LABELS, "e")

    def _labelled_axis(axis: int) -> Basis:
        basis = unit_vector_basis(dim, axes=[axis])
        # HACK: replace the leaf's label with the friendly one after
        # construction. object.__setattr__ sidesteps the Equinox freeze, so
        # this is fragile if Equinox changes its freeze mechanism; building
        # the basis with the labels already set would be preferable.
        object.__setattr__(basis.root, "labels", (names[axis],))
        return basis

    return tuple(_labelled_axis(axis) for axis in range(dim))
def frame(
    dim: int,
    *,
    velocity: bool = False,
    x_labels: Optional[Sequence[str]] = None,
    v_labels: Optional[Sequence[str]] = None,
    e_labels: Optional[Sequence[str]] = None,
) -> tuple[Basis, ...]:
    """Default compositional frame: constant ``1`` + coordinate scalars + unit axes.

    Overdamped (``velocity=False``)::

        one, *x_components(dim), *unit_axes(dim)

    Underdamped (``velocity=True``)::

        one, *x_components(dim), *v_components(dim), *unit_axes(dim)

    Parameters
    ----------
    dim : int
        Spatial dimension shared by every piece of the frame.
    velocity : bool
        When true, the v-coordinate scalars are inserted after the
        x-coordinate scalars.
    x_labels, v_labels, e_labels : Sequence[str], optional
        Forwarded as ``labels`` to ``x_components``, ``v_components`` and
        ``unit_axes`` respectively.

    Examples
    --------
    >>> one, x, y, z, ex, ey, ez = frame(3)
    >>> one, x, y, z, vx, vy, vz, ex, ey, ez = frame(3, velocity=True)

    Custom labels (useful for ``dim > 4``):

    >>> bundle = frame(5, x_labels=["q0","q1","q2","q3","q4"])
    """
    from .constants import ones_basis

    # Built in output order so any error surfaces for the earliest piece.
    one = ones_basis(dim)
    positions = x_components(dim, labels=x_labels)
    velocities = v_components(dim, labels=v_labels) if velocity else ()
    axes = unit_axes(dim, labels=e_labels)
    return (one, *positions, *velocities, *axes)