SFI.utils.maths module
- SFI.utils.maths.as_default_float(x)
Return x cast to the JAX default float dtype. Accepts any array-like; passes None through unchanged so callers can use it on optional inputs (e.g. v0).
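The behaviour described above can be sketched as follows. This is only an illustration under stated assumptions, not the library's implementation: `as_default_float_sketch` is a hypothetical stand-in, and the default dtype is read off a freshly created JAX array.

```python
import jax.numpy as jnp


def as_default_float_sketch(x):
    """Hypothetical stand-in for as_default_float (not the SFI source)."""
    if x is None:
        # Optional inputs pass through untouched.
        return None
    # A freshly created float array carries JAX's active default float dtype.
    dtype = jnp.zeros(()).dtype
    return jnp.asarray(x, dtype=dtype)


v0 = as_default_float_sketch(None)       # optional input left as None
x0 = as_default_float_sketch([1, 2, 3])  # integer list cast to a float array
```

Casting at the boundary like this means downstream code never has to check whether a caller passed integers, Python lists, or NumPy arrays.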
- SFI.utils.maths.default_float_dtype()
Return JAX’s currently active default float dtype.
Returns float64 when JAX_ENABLE_X64 is set, float32 otherwise. Use this everywhere user-supplied arrays enter the simulation pipeline so that all internal arithmetic shares a single dtype and lax.scan carry-in / carry-out types always agree.
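To illustrate why a single dtype matters for lax.scan, here is a minimal sketch. The names `default_float_dtype_sketch` and `step` are hypothetical, and the way the dtype is obtained is an assumption rather than the library's code.

```python
import jax
import jax.numpy as jnp


def default_float_dtype_sketch():
    """Hypothetical stand-in: dtype of a fresh float array under the current x64 setting."""
    return jnp.zeros(()).dtype


def step(carry, _):
    # The carry must leave the step with the same dtype it came in with.
    new = carry + 0.5 * jnp.ones_like(carry)
    return new, new


# Integer user input is cast once, before it enters the scan.
x0 = jnp.asarray([0, 1], dtype=default_float_dtype_sketch())
final, traj = jax.lax.scan(step, x0, None, length=4)
```

If `x0` were left as an integer array, the float result of `step` would no longer match the carry type, which lax.scan rejects; casting up front avoids that class of error.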
- SFI.utils.maths.fd_velocity(x, dt, *, scheme='central')
Finite-difference velocity along the leading (time) axis.
Reconstructs v(t) ≈ dx/dt from a position array using a finite-difference stencil, following the same secant-velocity convention the underdamped inference engine uses internally. The output keeps the same leading time dimension as x (boundary frames fall back to a one-sided stencil).
- Parameters:
x (array_like, shape (T, ...)) – Positions sampled in time along axis 0 (e.g. (T, d) or (T, N, d)).
dt (float or array_like, shape (T - 1,)) – Time step. Scalar for uniform sampling, or per-interval spacings.
scheme ({"central", "forward", "backward"}) – Interior stencil. "central" is second-order accurate and the default; "forward" / "backward" are first-order.
- Returns:
v – Velocity estimate, same shape as x.
- Return type:
jax.Array, shape (T, ...)
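As a rough illustration of the central scheme with one-sided boundary frames, the sketch below handles only a scalar dt. The function name `fd_velocity_sketch` is hypothetical, and the exact boundary stencil and the treatment of per-interval dt in the real function are not specified here, so both are assumptions.

```python
import jax.numpy as jnp


def fd_velocity_sketch(x, dt):
    """Hypothetical central-difference velocity for uniform sampling (scalar dt only)."""
    x = jnp.asarray(x, dtype=jnp.zeros(()).dtype)
    # Interior frames: (x[t+1] - x[t-1]) / (2 dt), second-order accurate.
    interior = (x[2:] - x[:-2]) / (2.0 * dt)
    # Boundary frames: assumed first-order one-sided differences.
    first = (x[1:2] - x[0:1]) / dt
    last = (x[-1:] - x[-2:-1]) / dt
    return jnp.concatenate([first, interior, last], axis=0)


# Straight-line motion at constant velocity 2, shape (T, d) = (5, 1).
t = jnp.arange(5) * 0.1
x = (2.0 * t)[:, None]
v = fd_velocity_sketch(x, 0.1)
```

For linear motion every stencil is exact up to rounding, so `v` has the same shape as `x` and every entry is close to 2.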
- SFI.utils.maths.solve_or_pinv(A, b, tol=1e-15)
Solve A ⋅ x = b for x, with a fallback to the Moore–Penrose pseudo-inverse if A is singular or not square. To improve numerical stability, we first normalize A by its diagonal: A_norm = D^{-1} A D^{-1}, b_norm = D^{-1} b, solve A_norm ⋅ x_norm = b_norm, and then recover x = D^{-1} x_norm.
This ensures that the diagonal entries of A_norm are 1 (assuming A has positive diagonal), which often makes the linear solve or pseudo-inverse more robust when A has widely varying scales on its diagonal.
- Parameters:
A (jax.Array, shape (k, k)) – The matrix to solve against. We assume that A has nonnegative diagonal entries; if any diagonal entry is zero, we clip it to a small floor to avoid division by zero.
b (jax.Array, shape (k,)) – The right-hand side vector.
tol (float, default=1e-15) – The tolerance for the pseudo-inverse. If A_norm is effectively singular, we compute x_norm = pinv(A_norm, rcond=tol) @ b_norm.
- Returns:
x – The solution vector to A ⋅ x = b, computed as follows:
d_i = sqrt(max(A_{ii}, tol)) (each diagonal entry is floored to tol > 0 to avoid division by zero).
A_norm = D_inv @ A @ D_inv, where D_inv = diag(1 / d_i), and b_norm = b / d.
Solve A_norm ⋅ x_norm = b_norm: if A_norm is non-singular, use a direct solver; otherwise, fall back to x_norm = pinv(A_norm) @ b_norm.
Recover x = x_norm / d.
- Return type:
jax.Array, shape (k,)
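The steps listed for the return value can be sketched as below. This is an illustration rather than the library's code: `solve_or_pinv_sketch` is a hypothetical name, the singularity test (a non-finite result from the direct solve) is an assumption because the documentation does not say how singularity is detected, and pinv is called with its default cutoff instead of tol.

```python
import jax.numpy as jnp


def solve_or_pinv_sketch(A, b, tol=1e-15):
    """Hypothetical diagonally rescaled solve with a pseudo-inverse fallback."""
    # Step 1: floor the diagonal, then take square roots.
    d = jnp.sqrt(jnp.maximum(jnp.diag(A), tol))
    # Step 2: A_norm = D^{-1} A D^{-1}, b_norm = D^{-1} b.
    A_norm = A / jnp.outer(d, d)
    b_norm = b / d
    # Step 3: direct solve, with pinv as the fallback.
    x_solve = jnp.linalg.solve(A_norm, b_norm)
    x_pinv = jnp.linalg.pinv(A_norm) @ b_norm
    # Assumed criterion: fall back when the direct solve is not finite.
    x_norm = jnp.where(jnp.all(jnp.isfinite(x_solve)), x_solve, x_pinv)
    # Step 4: undo the scaling.
    return x_norm / d


# Diagonal entries differ by six orders of magnitude.
A = jnp.array([[1e6, 1e2], [1e2, 1.0]])
x_true = jnp.array([1e-3, 2.0])
b = A @ x_true
x = solve_or_pinv_sketch(A, b)
```

After rescaling, A_norm here is [[1, 0.1], [0.1, 1]], which is well conditioned even though A itself spans a wide range of scales.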
- SFI.utils.maths.sqrtm_psd(A, eps=0.0)
Symmetric PSD matrix square root, applied to the last two (matrix) axes.
  - Symmetrizes the input for stability.
  - Clips negative eigenvalues to 0 (PSD), then takes the square root.
  - Re-symmetrizes the result to remove small asymmetries from numerics.
- Parameters:
A (Array)
eps (float)
- Return type:
Array
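The three operations described for sqrtm_psd can be sketched with an eigendecomposition. The name `sqrtm_psd_sketch` is hypothetical, and the eps argument is left out because its role is not documented here.

```python
import jax.numpy as jnp


def sqrtm_psd_sketch(A):
    """Hypothetical PSD square root over the last two axes (eps omitted)."""
    # Symmetrize the input.
    A = 0.5 * (A + jnp.swapaxes(A, -1, -2))
    # Eigendecomposition of the symmetric matrix.
    w, V = jnp.linalg.eigh(A)
    # Clip negative eigenvalues to zero before taking the square root.
    w = jnp.maximum(w, 0.0)
    # Rebuild V diag(sqrt(w)) V^T by scaling the eigenvector columns.
    S = (V * jnp.sqrt(w)[..., None, :]) @ jnp.swapaxes(V, -1, -2)
    # Re-symmetrize to remove rounding asymmetries.
    return 0.5 * (S + jnp.swapaxes(S, -1, -2))


A = jnp.array([[2.0, 1.0], [1.0, 2.0]])
S = sqrtm_psd_sketch(A)

# Batched input: the operation applies to each trailing 2x2 matrix.
A_batch = jnp.stack([A, 4.0 * jnp.eye(2)])
S_batch = sqrtm_psd_sketch(A_batch)
```

A quick sanity check is that `S @ S` reproduces `A` up to floating-point error, and that the second batch entry is close to twice the identity.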
- SFI.utils.maths.stable_pinv(G)
Numerically-stable pseudo-inverse of a Gram matrix.
Normalizes G by its diagonal before inversion so that all diagonal entries of the rescaled matrix are 1. This avoids ill-conditioning when basis functions have very different scales.
Algorithm: let d_i = sqrt(G_{ii}) (clamped to 1 when zero). Compute pinv(G / outer(d, d)) and rescale by outer(1/d, 1/d) to recover the pseudo-inverse of the original G.
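The algorithm above can be written in a few lines. This is a sketch only: `stable_pinv_sketch` is a hypothetical name, and a square 2-D Gram matrix with a nonnegative diagonal is assumed.

```python
import jax.numpy as jnp


def stable_pinv_sketch(G):
    """Hypothetical diagonally rescaled pseudo-inverse of a Gram matrix."""
    d = jnp.sqrt(jnp.diag(G))
    # Zero diagonal entries are clamped to 1 so the rescaling is a no-op there.
    d = jnp.where(d == 0, 1.0, d)
    scale = jnp.outer(d, d)
    # Invert the unit-diagonal matrix, then map back with outer(1/d, 1/d).
    return jnp.linalg.pinv(G / scale) / scale


# Basis functions with very different scales: diagonal entries 1e4 and 1.
G = jnp.array([[1e4, 10.0], [10.0, 1.0]])
P = stable_pinv_sketch(G)
```

Here the rescaled matrix is [[1, 0.1], [0.1, 1]], so the inversion runs on a well-conditioned problem and `P @ G` is close to the identity.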