Source code for SFI.bases.monomials

import functools
import itertools

import jax.numpy as jnp

from ..statefunc import Basis, Rank, make_basis

_RANK_ALIASES = {
    "scalar": None,
    "vector": "vector",
    "matrix": "symmetric",
    "symmetric_matrix": "symmetric",
    "identity_matrix": "identity",
}


def _lift(B: Basis, rank: str, dim: int):
    """Lift a scalar Basis to higher rank via Cartesian product with a structural basis."""
    import warnings

    from .constants import identity_matrix_basis, symmetric_matrix_basis, unit_vector_basis

    if rank not in _RANK_ALIASES:
        raise ValueError(f"Unknown rank={rank!r}. Choose from: {list(_RANK_ALIASES.keys())}")
    mode = _RANK_ALIASES[rank]
    if mode is None:
        return B
    # Suppress the Cartesian-product warning — this is intentional lifting.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", ".*Cartesian product.*")
        if mode == "vector":
            return B * unit_vector_basis(dim)
        if mode == "identity":
            return B * identity_matrix_basis(dim)
        if mode == "symmetric":
            return B * symmetric_matrix_basis(dim)
    raise AssertionError("unreachable")  # pragma: no cover


@functools.lru_cache(maxsize=64)
def _exponent_mats(dim: int, degree: int, include_x: bool, include_v: bool):
    """Return (ex_x, ex_v, labels) for all monomials of exact total degree."""
    if not include_x and not include_v:
        raise ValueError("At least one of include_x / include_v must be True.")
    L = (dim if include_x else 0) + (dim if include_v else 0)

    # stars-and-bars enumeration for total degree == degree
    if degree == 0:
        gammas = [tuple(0 for _ in range(L))]
    else:
        gammas = []
        for stars in itertools.combinations_with_replacement(range(L), degree):
            arr = [0] * L
            for s in stars:
                arr[s] += 1
            gammas.append(tuple(arr))

    ex_x, ex_v, labels = [], [], []
    for g in gammas:
        alpha = g[:dim] if include_x else (0,) * dim
        beta = g[-dim:] if include_v else (0,) * dim

        parts = []
        for k, p in enumerate(alpha):
            if p:
                parts.append(f"x{k}^{p}" if p > 1 else f"x{k}")
        for k, p in enumerate(beta):
            if p:
                parts.append(f"v{k}^{p}" if p > 1 else f"v{k}")
        lab = "1" if not parts else "·".join(parts)

        ex_x.append(alpha)
        ex_v.append(beta)
        labels.append(lab)

    ex_x = jnp.array(ex_x, dtype=jnp.int32)  # (F, dim)
    ex_v = jnp.array(ex_v, dtype=jnp.int32)  # (F, dim)
    return ex_x, ex_v, tuple(labels)


[docs] def monomials_degree( degree: int, *, dim: int, include_x: bool = True, include_v: bool = False, rank: str = "scalar", ): """ All monomials of **exact** total degree in x and/or v. Parameters ---------- degree : int Exact total polynomial degree. dim : int Spatial dimension. include_x, include_v : bool Which variables to include. rank : str Output rank. ``'scalar'`` (default) returns a scalar Basis with ``F`` features. ``'vector'`` lifts to rank-1 via Cartesian product with ``unit_vector_basis(dim)`` (F × dim features). ``'matrix'`` / ``'symmetric_matrix'`` lifts to rank-2 via ``symmetric_matrix_basis(dim)`` (F × dim(dim+1)/2 features). ``'identity_matrix'`` lifts via ``identity_matrix_basis(dim)`` (F × 1 features, isotropic). """ ex_x, ex_v, labels = _exponent_mats(dim, degree, include_x, include_v) F = int(ex_x.shape[0]) use_x, use_v = bool(include_x), bool(include_v) # Precompute static stuff outside eval (captured in the closure) Dx = int(ex_x.max()) if use_x else 0 # max degree for x Dv = int(ex_v.max()) if use_v else 0 # max degree for v Ix = ex_x.T if use_x else None # (dim, F) gather indices Iv = ex_v.T if use_v else None # (dim, F) def _power_table_fixed(z, D): """ z: (dim,) -> (dim, D+1) = [1, z, z^2, ..., z^D] Important: avoid jnp.power with float exponents because its JVP/VJP uses log(z) and yields NaNs at z=0 even for integer powers (e.g. 0^0 pathways). This iterative construction is polynomial and differentiable everywhere. """ dim = z.shape[0] one = jnp.ones((dim, 1), dtype=z.dtype) if D == 0: return one # powers[:, k] = z^k powers = [one, z[:, None]] for _ in range(2, D + 1): powers.append(powers[-1] * z[:, None]) return jnp.concatenate(powers, axis=1) if use_x and use_v: def _eval(x, *, v=None, mask=None): # x,v: (dim,) ; returns (F,) if v is None: v = jnp.zeros_like(x) Tx = _power_table_fixed(x, Dx) # (dim, Dx+1) Tv = _power_table_fixed(v, Dv) # (dim, Dv+1) vx = jnp.take_along_axis(Tx, Ix, axis=1) # (dim, F) vv = jnp.take_along_axis(Tv, Iv, axis=1) # (dim, F) return jnp.prod(vx * vv, axis=0) # (F,) elif use_x: def _eval(x, *, v=None, mask=None): Tx = _power_table_fixed(x, Dx) # (dim, Dx+1) vx = jnp.take_along_axis(Tx, Ix, axis=1) # (dim, F) return jnp.prod(vx, axis=0) # (F,) else: # use_v only def _eval(x, *, v=None, mask=None): Tv = _power_table_fixed(v if v is not None else jnp.zeros_like(x), Dv) vv = jnp.take_along_axis(Tv, Iv, axis=1) # (dim, F) return jnp.prod(vv, axis=0) # (F,) B = make_basis( func=_eval, dim=dim, rank=Rank.SCALAR, n_features=F, needs_v=use_v, labels=list(labels), ) return _lift(B, rank, dim)
[docs] def monomials_up_to( order: int, *, dim: int, include_constant: bool = True, include_x: bool = True, include_v: bool = False, rank: str = "scalar", ): r""" Concatenate degree-wise monomial bases for degrees 0..order (ascending). .. physics:: Multivariate polynomial basis :label: polynomial-basis :category: Basis functions .. math:: f_\\alpha(x) = \\prod_{k=1}^{d} x_k^{\\alpha_k}, \\qquad |\\alpha| \\le \\texttt{order} Full polynomial dictionary up to a given total degree, optionally including velocity monomials and lifted to vector or matrix rank. Parameters ---------- order : int Maximum total polynomial degree. dim : int Spatial dimension. include_constant : bool If False, skip degree-0 (constant) term. include_x, include_v : bool Which variables to include. rank : str Output tensor rank. See :func:`monomials_degree` for allowed values: ``'scalar'``, ``'vector'``, ``'matrix'`` / ``'symmetric_matrix'``, ``'identity_matrix'``. For force inference, use ``rank='vector'`` (the most common choice). """ start = 0 if include_constant else 1 if start > order: raise ValueError("No monomials to include: set include_constant=True or order>=1.") Bs = [monomials_degree(d, dim=dim, include_x=include_x, include_v=include_v) for d in range(start, order + 1)] B = Basis.stack(Bs) return _lift(B, rank, dim)