SFI.utils.maths module

SFI.utils.maths.as_default_float(x)

Return x cast to the JAX default float dtype.

Accepts any array-like; passes through None unchanged so callers can use it on optional inputs (e.g. v0).
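A minimal sketch of the documented behavior. It assumes that the default float dtype can be read off jnp.zeros(()).dtype, which is float32 normally and float64 when x64 mode is enabled. as_default_float_sketch is a hypothetical stand-in, not the SFI source.

```python
import jax.numpy as jnp
import numpy as np


def as_default_float_sketch(x):
    """Hypothetical stand-in for as_default_float (not the SFI source)."""
    # Optional inputs such as v0 may be None; pass them through untouched.
    if x is None:
        return None
    # ASSUMPTION: jnp.zeros(()) is created with JAX's default float dtype,
    # i.e. float32 normally and float64 when x64 mode is enabled.
    return jnp.asarray(x, dtype=jnp.zeros(()).dtype)


# Usage: lists, integer data and NumPy float64 arrays all come out in one dtype.
print(as_default_float_sketch(None))
print(as_default_float_sketch([1, 2, 3]).dtype)
print(as_default_float_sketch(np.ones(2, dtype=np.float64)).dtype)
```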

SFI.utils.maths.default_float_dtype()

Return JAX’s currently active default float dtype.

Returns float64 when JAX_ENABLE_X64 is set, float32 otherwise. Use this everywhere user-supplied arrays enter the simulation pipeline so that all internal arithmetic shares a single dtype and lax.scan carry-in / carry-out types always agree.
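To illustrate why a single dtype matters: lax.scan requires the carry returned by the step function to have exactly the same dtype as the initial carry, so casting both the inputs and the initial carry through one helper keeps them aligned. In this sketch, default_float_dtype_sketch and running_sum are hypothetical names, and reading the dtype from jnp.zeros(()) is an assumption about how such a helper could be written.

```python
import jax
import jax.numpy as jnp


def default_float_dtype_sketch():
    """Hypothetical stand-in for default_float_dtype (not the SFI source)."""
    # Arrays created without an explicit dtype use the active default float.
    return jnp.zeros(()).dtype


def running_sum(xs):
    """Toy lax.scan whose carry stays in the default float dtype."""
    dtype = default_float_dtype_sketch()
    xs = jnp.asarray(xs, dtype=dtype)   # user input enters the pipeline here
    init = jnp.zeros((), dtype=dtype)   # carry-in uses the same dtype

    def step(carry, x):
        new_carry = carry + x           # carry-out keeps the carry-in dtype
        return new_carry, new_carry

    total, partial_sums = jax.lax.scan(step, init, xs)
    return total, partial_sums


total, partial_sums = running_sum([1, 2, 3, 4])
```

If init were created as float32 while xs were float64 (possible with x64 enabled), the step would return a float64 carry and lax.scan would reject the mismatch.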

SFI.utils.maths.fd_velocity(x, dt, *, scheme='central')

Finite-difference velocity along the leading (time) axis.

Reconstructs v(t) ≈ dx/dt from a position array using a finite-difference stencil, the same secant-velocity convention that the underdamped inference engine uses internally. The output keeps the same leading time dimension as x (boundary frames fall back to a one-sided stencil).

Parameters:
  • x (array_like, shape (T, ...)) – Positions sampled in time along axis 0 (e.g. (T, d) or (T, N, d)).

  • dt (float or array_like, shape (T - 1,)) – Time step. Scalar for uniform sampling, or per-interval spacings.

  • scheme ({"central", "forward", "backward"}) – Interior stencil. "central" is second-order accurate and the default; "forward" / "backward" are first-order.

Returns:

v – Velocity estimate, same shape as x.

Return type:

jax.Array, shape (T, ...)
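The stencils can be sketched as follows. This is one plausible reading of the description above, not the SFI source: fd_velocity_sketch is a hypothetical name, the boundary handling (reusing the nearest one-sided secant) is an assumption, and for per-interval spacings the central stencil is taken to be the secant over two intervals, (x_{i+1} - x_{i-1}) / (t_{i+1} - t_{i-1}), which is second-order accurate only when the spacing is uniform.

```python
import jax.numpy as jnp


def fd_velocity_sketch(x, dt, scheme="central"):
    """Hypothetical re-implementation of fd_velocity (not the SFI source)."""
    # Work in the default float dtype so integer positions are not truncated.
    x = jnp.asarray(x, dtype=jnp.zeros(()).dtype)
    T = x.shape[0]
    # One spacing per interval, shape (T-1,), then add trailing singleton axes
    # so dt broadcasts against x[1:] - x[:-1] of shape (T-1, ...).
    dt = jnp.broadcast_to(jnp.asarray(dt, dtype=x.dtype), (T - 1,))
    dt = dt.reshape((T - 1,) + (1,) * (x.ndim - 1))

    # Secant velocity on each interval [t_i, t_{i+1}].
    secant = (x[1:] - x[:-1]) / dt

    if scheme == "forward":
        # v_i = (x_{i+1} - x_i) / dt_i; the last frame reuses the final secant.
        return jnp.concatenate([secant, secant[-1:]], axis=0)
    if scheme == "backward":
        # v_i = (x_i - x_{i-1}) / dt_{i-1}; the first frame reuses the first secant.
        return jnp.concatenate([secant[:1], secant], axis=0)
    if scheme == "central":
        # Interior: (x_{i+1} - x_{i-1}) / (t_{i+1} - t_{i-1}); one-sided at both ends.
        interior = (x[2:] - x[:-2]) / (dt[1:] + dt[:-1])
        return jnp.concatenate([secant[:1], interior, secant[-1:]], axis=0)
    raise ValueError(f"unknown scheme: {scheme!r}")
```

For x = t² sampled at t = 0, 1, 2, 3, 4 with dt = 1, the central interior values are exact (2, 4, 6), while the two boundary frames inherit the first-order error of the one-sided secants.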

SFI.utils.maths.solve_or_pinv(A, b, tol=1e-15)

Solve A ⋅ x = b for x, with a fallback to the Moore–Penrose pseudo-inverse if A is singular or not square. To improve numerical stability, we first normalize A by its diagonal: with D = diag(d) and d_i = sqrt(max(A_{ii}, tol)), we form A_norm = D^{-1} A D^{-1} and b_norm = D^{-1} b, solve A_norm ⋅ x_norm = b_norm, and then recover x = D^{-1} x_norm.

This ensures that the diagonal entries of A_norm are 1 (assuming A has a positive diagonal), which often makes the linear solve or pseudo-inverse more robust when A has widely varying scales on its diagonal.
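The rescaling changes only the conditioning, not the solution: substituting x = D^{-1} x_norm into A ⋅ x = b and multiplying on the left by D^{-1} yields the normalized system, whose diagonal is 1 wherever d_i = sqrt(A_{ii}).

```latex
A x = b, \qquad D = \operatorname{diag}(d_1, \dots, d_k), \qquad x = D^{-1} x_{\mathrm{norm}}
\;\Longrightarrow\;
\underbrace{D^{-1} A D^{-1}}_{A_{\mathrm{norm}}} \, x_{\mathrm{norm}}
  = \underbrace{D^{-1} b}_{b_{\mathrm{norm}}},
\qquad
(A_{\mathrm{norm}})_{ij} = \frac{A_{ij}}{d_i d_j},
\quad
(A_{\mathrm{norm}})_{ii} = 1 \ \text{when}\ d_i = \sqrt{A_{ii}}.
```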

Parameters:
  • A (jax.Array, shape (k, k)) – The matrix to solve against. We assume that A has nonnegative diagonal entries; any diagonal entry smaller than tol (in particular a zero entry) is floored to tol to avoid division by zero.

  • b (jax.Array, shape (k,)) – The right-hand side vector.

  • tol (float, default=1e-15) – The tolerance for the pseudo-inverse. If A_norm is effectively singular, we compute x_norm = pinv(A_norm, rcond=tol) @ b_norm.

Returns:

x

The solution vector to A ⋅ x = b, computed as follows:
  1. d_i = sqrt(max(A_{ii}, tol)) (each diagonal entry is floored at tol > 0 to avoid division by zero)

  2. A_norm = D_inv @ A @ D_inv, where D_inv = diag(1 / d_i), and b_norm = b / d

  3. Solve A_norm ⋅ x_norm = b_norm:
       • if A_norm is non-singular, use a direct solver;
       • otherwise, fall back to x_norm = pinv(A_norm) @ b_norm.

  4. Recover x = x_norm / d

Return type:

jax.Array, shape (k,)
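Steps 1 to 4 can be sketched in JAX as below, under stated assumptions: solve_or_pinv_sketch is a hypothetical name, and "A_norm is singular" is detected by the direct solve returning non-finite values (an exactly zero LU pivot typically yields inf or nan). Nearly singular matrices with tiny but nonzero pivots would pass that check, so this is a simplification rather than the library's criterion.

```python
import jax.numpy as jnp


def solve_or_pinv_sketch(A, b, tol=1e-15):
    """Hypothetical sketch of diagonal-normalized solve with pinv fallback."""
    A = jnp.asarray(A)
    b = jnp.asarray(b)
    # Step 1: d_i = sqrt(max(A_ii, tol)); the floor avoids dividing by zero.
    d = jnp.sqrt(jnp.maximum(jnp.diag(A), tol))
    # Step 2: D^{-1} A D^{-1} is the same as A / outer(d, d); b_norm = b / d.
    A_norm = A / jnp.outer(d, d)
    b_norm = b / d
    # Step 3a: direct solve.
    x_direct = jnp.linalg.solve(A_norm, b_norm)
    # Step 3b: pseudo-inverse fallback. The second positional argument of
    # jnp.linalg.pinv is the relative singular-value cutoff (named rcond in
    # older JAX releases and rtol in newer ones).
    x_pinv = jnp.linalg.pinv(A_norm, tol) @ b_norm
    # ASSUMPTION: treat A_norm as singular when the direct solve is non-finite.
    x_norm = jnp.where(jnp.all(jnp.isfinite(x_direct)), x_direct, x_pinv)
    # Step 4: undo the rescaling, x = D^{-1} x_norm.
    return x_norm / d
```

Both branches are always computed and jnp.where picks one, which keeps the function traceable under jax.jit at the cost of always paying for the SVD inside pinv. With float32 arithmetic, a cutoff as small as 1e-15 may not discard round-off-sized singular values, so a larger tol can be needed for the fallback to behave.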

SFI.utils.maths.sqrtm_psd(A, eps=0.0)

Symmetric PSD matrix square root, applied to the last two (matrix) axes:
  • Symmetrizes the input for stability.
  • Clips negative eigenvalues to 0 (so the matrix is PSD), then takes the square root of the eigenvalues.
  • Re-symmetrizes the result to remove small asymmetries introduced by floating-point round-off.

Parameters:
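  • A (Array, shape (..., n, n)) – A matrix or batch of matrices; the square root is taken over the last two axes.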

  • eps (float)

Return type:

Array
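The three steps map directly onto a symmetric eigendecomposition. The sketch below (sqrtm_psd_sketch is a hypothetical name, not the SFI source) uses jnp.linalg.eigh, which is batched over leading axes. It deliberately leaves out the eps argument, whose role is not documented above.

```python
import jax.numpy as jnp


def sqrtm_psd_sketch(A):
    """Hypothetical sketch of a symmetric PSD square root (eps not modeled)."""
    A = jnp.asarray(A)
    # 1. Symmetrize the input over the last two axes.
    A_sym = 0.5 * (A + jnp.swapaxes(A, -1, -2))
    # Eigendecomposition A_sym = V diag(w) V^T, batched over leading axes.
    w, V = jnp.linalg.eigh(A_sym)
    # 2. Clip negative eigenvalues to 0, then take square roots.
    sqrt_w = jnp.sqrt(jnp.maximum(w, 0.0))
    # V diag(sqrt_w) V^T: scale the columns of V, then multiply by V^T.
    S = (V * sqrt_w[..., None, :]) @ jnp.swapaxes(V, -1, -2)
    # 3. Re-symmetrize to remove round-off asymmetry.
    return 0.5 * (S + jnp.swapaxes(S, -1, -2))


# Usage: the square root of diag(4, -1) is diag(2, 0) after clipping.
print(sqrtm_psd_sketch(jnp.array([[4.0, 0.0], [0.0, -1.0]])))
```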

SFI.utils.maths.stable_pinv(G)

Numerically-stable pseudo-inverse of a Gram matrix.

Normalizes G by its diagonal before inversion so that all diagonal entries of the rescaled matrix are 1. This avoids ill-conditioning when basis functions have very different scales.

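Algorithm: let d_i = sqrt(G_{ii}) (clamped to 1 when G_{ii} is zero). Compute pinv(G / outer(d, d)) and rescale it elementwise by outer(1/d, 1/d). When G is invertible this recovers G^{-1} exactly, since G = D G_norm D implies G^{-1} = D^{-1} G_norm^{-1} D^{-1}. When G is singular, the rescaled result X is a generalized inverse of G (it satisfies G X G = G and X G X = X) but need not coincide with the Moore–Penrose pseudo-inverse of G.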
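The algorithm fits in a few lines of JAX. This is a sketch of the described steps, not the SFI source: stable_pinv_sketch is a hypothetical name, and it relies on the default cutoff of jnp.linalg.pinv because no tolerance is documented for stable_pinv.

```python
import jax.numpy as jnp


def stable_pinv_sketch(G):
    """Hypothetical sketch of a diagonal-normalized pseudo-inverse."""
    G = jnp.asarray(G)
    d = jnp.sqrt(jnp.diag(G))
    # Clamp zero diagonal entries to 1 so the rescaling never divides by zero.
    d = jnp.where(d == 0, 1.0, d)
    scale = jnp.outer(d, d)
    # Rescaled Gram matrix with unit diagonal wherever G_ii > 0.
    G_norm = G / scale
    # Dividing elementwise by outer(d, d) equals multiplying by outer(1/d, 1/d).
    return jnp.linalg.pinv(G_norm) / scale


G = jnp.array([[100.0, 1.0], [1.0, 0.04]])
X = stable_pinv_sketch(G)
```

For this G the rescaled matrix is [[1, 0.5], [0.5, 1]], so the SVD inside pinv works with eigenvalues 1.5 and 0.5 instead of roughly 100 and 0.03.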