# Source code for SFI.utils.maths

import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsp_la
from jax import dtypes as _jdtypes
from jax import jit

# ---------------------------------------------------------------------
#  Mathematical utilities
# ---------------------------------------------------------------------


def default_float_dtype():
    """Report the float dtype JAX currently uses by default.

    The result is ``float64`` when 64-bit mode (``JAX_ENABLE_X64``) is
    enabled and ``float32`` otherwise.  Cast user-supplied arrays to this
    dtype wherever they enter the simulation pipeline, so that internal
    arithmetic shares one dtype and ``lax.scan`` carry-in / carry-out
    types always agree.
    """
    # Canonicalising float64 yields whatever JAX would actually allocate for
    # a "double" under the current configuration.
    active_dtype = _jdtypes.canonicalize_dtype(jnp.float64)
    return active_dtype
def as_default_float(x):
    """Cast ``x`` to the JAX default float dtype.

    Any array-like is accepted.  ``None`` is handed back untouched so the
    function can be applied directly to optional inputs (e.g. ``v0``).
    """
    if x is not None:
        return jnp.asarray(x, dtype=default_float_dtype())
    return None
def fd_velocity(x, dt, *, scheme: str = "central"):
    """Finite-difference velocity along the leading (time) axis.

    Reconstructs ``v(t) ≈ dx/dt`` from a position array using a
    finite-difference stencil.  The output keeps the same leading time
    dimension as ``x``; boundary frames fall back to a one-sided stencil.

    Parameters
    ----------
    x : array_like, shape ``(T, ...)``
        Positions sampled in time along axis 0 (e.g. ``(T, d)`` or
        ``(T, N, d)``).  Needs ``T >= 2``.
    dt : float or array_like, shape ``(T - 1,)``
        Time step.  A scalar (or any single-element array) means uniform
        sampling; otherwise one spacing per interval between frames.
    scheme : {"central", "forward", "backward"}
        Interior stencil.  ``"central"`` (the default) divides the
        two-frame displacement by the summed spacing of the two adjacent
        intervals, which is second-order accurate on a uniform grid;
        ``"forward"`` / ``"backward"`` are first-order.

    Returns
    -------
    v : jax.Array, shape ``(T, ...)``
        Velocity estimate, same shape as ``x``.

    Raises
    ------
    ValueError
        If ``x`` has fewer than two time points, if ``scheme`` is not one
        of the supported names, or if ``dt`` is neither a single value nor
        an array of shape ``(T - 1,)``.
    """
    x = jnp.asarray(x)
    # A 0-d input has no time axis at all; report it the same way as T < 2.
    if x.ndim == 0 or x.shape[0] < 2:
        raise ValueError("fd_velocity needs at least 2 time points.")
    if scheme not in ("central", "forward", "backward"):
        raise ValueError(f"Unknown scheme {scheme!r}; expected central/forward/backward.")

    n_steps = x.shape[0] - 1
    dt = jnp.asarray(dt)
    if dt.size == 1:
        # A lone spacing is uniform sampling, whatever its rank.
        dt = dt.reshape(())
    elif dt.shape != (n_steps,):
        # Without this check a mis-shaped dt either fails with an opaque
        # broadcasting error or silently broadcasts to a wrong result.
        raise ValueError(
            f"dt must be a scalar or have shape ({n_steps},); got {dt.shape}."
        )
    uniform = dt.ndim == 0
    n_extra = x.ndim - 1

    def _bcast(a):
        # reshape a per-step 1-D array so it broadcasts over the trailing axes
        return a.reshape(a.shape + (1,) * n_extra)

    # forward differences between consecutive frames, shape (T-1, ...)
    fwd = (x[1:] - x[:-1]) / (dt if uniform else _bcast(dt))

    if scheme == "forward":
        # last frame has no successor: reuse the final forward difference
        return jnp.concatenate([fwd, fwd[-1:]], axis=0)
    if scheme == "backward":
        # first frame has no predecessor: reuse the first forward difference
        return jnp.concatenate([fwd[:1], fwd], axis=0)

    # central: displacement across two intervals over their combined duration
    span = (2.0 * dt) if uniform else _bcast(dt[:-1] + dt[1:])
    central = (x[2:] - x[:-2]) / span
    return jnp.concatenate([fwd[:1], central, fwd[-1:]], axis=0)
def stable_pinv(G):
    """Pseudo-inverse of a Gram matrix, computed on a diagonally rescaled copy.

    With ``d_i = sqrt(G_ii)``, the matrix ``G / outer(d, d)`` is
    pseudo-inverted and the result is multiplied element-wise by
    ``outer(1/d, 1/d)`` to map it back to the scale of ``G``.  Rescaling
    puts ones on the diagonal wherever ``G_ii > 0``, which avoids
    ill-conditioning when basis functions have very different scales.

    Any ``d_i`` that is not strictly positive is replaced by 1, so the
    corresponding row and column are left unscaled and no division by zero
    occurs.
    """
    scale = jnp.sqrt(jnp.diag(G))
    # Entries failing the `> 0` test become 1 (a neutral scaling factor).
    scale = jnp.where(scale > 0, scale, 1.0)
    inv_scale = 1 / scale

    rescaled = G / jnp.outer(scale, scale)
    # Undo the normalisation on the pseudo-inverse.
    return jnp.linalg.pinv(rescaled) * jnp.outer(inv_scale, inv_scale)
@jit
def sqrtm_psd(A: jax.Array, eps: float = 0.0) -> jax.Array:
    """Symmetric PSD matrix square root over the last two (matrix) axes.

    Steps:

    - symmetrize the input for stability;
    - eigendecompose, clip negative eigenvalues to 0, then add ``eps``;
    - rebuild ``V diag(sqrt(w)) V^T`` and re-symmetrize the result to
      remove small asymmetries from round-off.

    Parameters
    ----------
    A : jax.Array, shape ``(..., k, k)``
        Matrix (or batch of matrices) to take the square root of.
    eps : float, default 0.0
        Optional jitter added to the clipped eigenvalues before the square
        root.  May be passed explicitly (positionally or by keyword).

    Returns
    -------
    jax.Array, shape ``(..., k, k)``
        Symmetric square-root factor.
    """
    A = 0.5 * (A + jnp.swapaxes(A, -1, -2))
    w, V = jnp.linalg.eigh(A)
    # PSD safeguard, then jitter.  The jitter is added unconditionally:
    # under @jit an explicitly passed `eps` is a traced value, so a Python
    # `if eps:` test on it fails at trace time.  Adding the default 0.0
    # leaves the clipped eigenvalues unchanged.
    w = jnp.maximum(w, 0.0) + eps
    # Scale the columns of V by sqrt(w), then multiply by V^T.
    S = (V * jnp.sqrt(w)[..., None, :]) @ jnp.swapaxes(V, -1, -2)
    return 0.5 * (S + jnp.swapaxes(S, -1, -2))
def solve_or_pinv(A: jax.Array, b: jax.Array, tol: float = 1e-15) -> jax.Array:
    """Solve ``A @ x = b``, falling back to the pseudo-inverse when needed.

    For numerical stability the system is first normalized by the diagonal
    of ``A``::

        d_i    = sqrt(max(A_ii, tol))
        A_norm = A / outer(d, d)
        b_norm = b / d

    ``A_norm @ x_norm = b_norm`` is then solved and the answer recovered as
    ``x = x_norm / d``.  When ``A`` has a positive diagonal, ``A_norm`` has
    a unit diagonal, which makes the solve more robust when the diagonal of
    ``A`` spans widely varying scales.

    The direct solver does not raise on a singular matrix; it returns
    non-finite values.  The fallback is therefore triggered by checking the
    direct solution: if any entry is NaN or infinite, the result of
    ``pinv(A_norm) @ b_norm`` is used instead.  The selection is done with
    ``jax.lax.cond`` so the function stays usable under ``jit``.  A matrix
    that is merely ill-conditioned but yields a finite direct solution is
    *not* sent to the fallback.

    Parameters
    ----------
    A : jax.Array, shape (k, k)
        The matrix to solve against.  Diagonal entries below ``tol``
        (zero, or negative from numerical noise) are floored to ``tol``
        before the square root.
    b : jax.Array, shape (k,) or (k, m)
        Right-hand side vector, or ``m`` right-hand sides stacked as columns.
    tol : float, default=1e-15
        Floor for the diagonal entries, and relative singular-value cutoff
        for the pseudo-inverse.  For the pseudo-inverse the cutoff is never
        taken below ``10 * k * machine_eps`` of the working dtype, since a
        smaller value cannot separate zero singular values from round-off.

    Returns
    -------
    x : jax.Array, same shape as ``b``
        Solution (or minimum-norm least-squares solution) of ``A @ x = b``.
    """
    A = jnp.asarray(A)
    b = jnp.asarray(b)

    # 1) Floor the diagonal so the square root and the divisions are safe.
    d = jnp.sqrt(jnp.maximum(jnp.diag(A), tol))  # shape (k,)
    inv_d = 1.0 / d

    # 2) A_norm[i, j] = A[i, j] / (d[i] * d[j]); rows first, then columns.
    A_norm = (A * inv_d[:, None]) * inv_d[None, :]

    # 3) Scale the right-hand side row-wise; the reshape lets d broadcast
    #    over the columns of a (k, m) right-hand side.
    d_rhs = d.reshape(d.shape + (1,) * (b.ndim - 1))
    b_norm = b / d_rhs

    # Relative cutoff for the pseudo-inverse, floored at round-off level.
    machine_eps = float(jnp.finfo(A_norm.dtype).eps)
    rtol = max(tol, 10.0 * max(A_norm.shape) * machine_eps)

    def _pinv_solve(mat, rhs):
        # Tolerance is passed positionally: it is the second parameter of
        # jnp.linalg.pinv regardless of how the keyword is spelled.
        return jnp.linalg.pinv(mat, rtol) @ rhs

    is_square = A_norm.ndim == 2 and A_norm.shape[0] == A_norm.shape[1]
    if not is_square:
        # Shapes are static, so this is a plain Python branch.  A direct
        # solve is impossible here; go straight to the pseudo-inverse.
        x_norm = _pinv_solve(A_norm, b_norm)
    else:
        # 4) Direct solve; no symmetry / definiteness is assumed.
        x_direct = jsp_la.solve(A_norm, b_norm, assume_a="gen")
        # 4a) Keep it if finite, otherwise fall back to the pseudo-inverse.
        x_norm = jax.lax.cond(
            jnp.all(jnp.isfinite(x_direct)),
            lambda xd, mat, rhs: xd,
            lambda xd, mat, rhs: _pinv_solve(mat, rhs).astype(xd.dtype),
            x_direct,
            A_norm,
            b_norm,
        )

    # 5) Undo the scaling: x_i = x_norm[i] / d[i].
    return x_norm / d_rhs