SFI.utils package

SFI.utils — Mathematical, formatting, and plotting utilities.

SFI.utils.as_default_float(x)

Return x cast to the JAX default float dtype.

Accepts any array-like; passes through None unchanged so callers can use it on optional inputs (e.g. v0).
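A minimal sketch of the cast and the None pass-through. It assumes JAX's default float dtype can be read off `jnp.zeros(()).dtype`. The name `as_default_float_sketch` is hypothetical, and this is not the library's implementation.

```python
import jax.numpy as jnp


def as_default_float_sketch(x):
    """Hypothetical stand-in for SFI.utils.as_default_float (not the library code)."""
    if x is None:
        # Optional inputs such as an initial velocity pass through untouched.
        return None
    # jnp.zeros without an explicit dtype uses JAX's default float dtype
    # (float64 when x64 mode is enabled, float32 otherwise).
    return jnp.asarray(x, dtype=jnp.zeros(()).dtype)


v0 = as_default_float_sketch(None)       # stays None
x0 = as_default_float_sketch([1, 2, 3])  # integer list -> default-float array
```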

SFI.utils.default_float_dtype()

Return JAX’s currently active default float dtype.

Returns float64 when JAX_ENABLE_X64 is set, float32 otherwise. Use this everywhere user-supplied arrays enter the simulation pipeline so that all internal arithmetic shares a single dtype and lax.scan carry-in / carry-out types always agree.
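Why a single dtype matters for lax.scan can be seen in a small sketch. The Euler integrator `euler_positions` and the helper `default_float_dtype_sketch` are hypothetical illustrations, not SFI functions. The sketch assumes the default float dtype equals `jnp.zeros(()).dtype`. Casting every input once, at the entry point, keeps the scan carry's dtype fixed.

```python
import jax
import jax.numpy as jnp


def default_float_dtype_sketch():
    """Hypothetical stand-in: the dtype JAX gives a float array by default."""
    # float64 when x64 mode is enabled, float32 otherwise.
    return jnp.zeros(()).dtype


def euler_positions(x0, drift, dt, n_steps):
    """Hypothetical Euler integrator showing why inputs share one dtype."""
    dtype = default_float_dtype_sketch()
    # Cast every user-supplied input once, at the entry point.
    x0 = jnp.asarray(x0, dtype=dtype)
    dt = jnp.asarray(dt, dtype=dtype)

    def step(x, _):
        x_next = x + drift(x) * dt
        # The carry returned here must have the same dtype as the carry passed in;
        # an uncast integer x0 would turn into a float carry and scan would raise.
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, None, length=n_steps)
    return xs


xs = euler_positions([1, 2], lambda x: -x, 0.1, n_steps=5)
```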

SFI.utils.fd_velocity(x, dt, *, scheme='central')

Finite-difference velocity along the leading (time) axis.

Reconstructs v(t) ≈ dx/dt from a position array using a finite-difference stencil, the same secant-velocity convention that the underdamped inference engine uses internally. The output keeps the same leading time dimension as x; boundary frames fall back to a one-sided stencil.

Parameters:
  • x (array_like, shape (T, ...)) – Positions sampled in time along axis 0 (e.g. (T, d) or (T, N, d)).

  • dt (float or array_like, shape (T - 1,)) – Time step. Scalar for uniform sampling, or per-interval spacings.

  • scheme ({"central", "forward", "backward"}) – Interior stencil. "central" is second-order accurate and the default; "forward" / "backward" are first-order.

Returns:

v – Velocity estimate, same shape as x.

Return type:

jax.Array, shape (T, ...)
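The three stencils can be sketched for the special case of a uniform scalar dt. Per-interval spacings (dt of shape (T - 1,)) are not handled here, and the name `fd_velocity_sketch` is hypothetical. Interior frames use the chosen stencil, and the end frames reuse the nearest one-sided secant.

```python
import jax.numpy as jnp


def fd_velocity_sketch(x, dt, scheme="central"):
    """Hypothetical finite-difference velocity for a uniform scalar dt."""
    x = jnp.asarray(x)
    # Secant velocities between consecutive frames, shape (T - 1, ...).
    secant = (x[1:] - x[:-1]) / dt
    if scheme == "forward":
        # v[i] = (x[i+1] - x[i]) / dt; the last frame reuses the backward secant.
        return jnp.concatenate([secant, secant[-1:]], axis=0)
    if scheme == "backward":
        # v[i] = (x[i] - x[i-1]) / dt; the first frame reuses the forward secant.
        return jnp.concatenate([secant[:1], secant], axis=0)
    if scheme != "central":
        raise ValueError(f"unknown scheme {scheme!r}")
    # Interior: v[i] = (x[i+1] - x[i-1]) / (2 dt), second-order accurate.
    interior = (x[2:] - x[:-2]) / (2 * dt)
    # Boundaries fall back to one-sided (first-order) secants.
    return jnp.concatenate([secant[:1], interior, secant[-1:]], axis=0)


t = jnp.arange(5.0)
v = fd_velocity_sketch(t**2, 1.0)  # exact derivative is 2t
```

For x = t² the central stencil is exact at the interior frames, so only the two boundary values deviate from 2t.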

SFI.utils.model_summary(labels, coefficients, *, stderr=None, support=None, coeffs_true=None, support_true=None, title='Coefficient Table', max_rows=60, significance_thresholds=(2.0, 10.0, 100.0), auto_labels=False)

Build a human-readable coefficient table with SNR and significance.

Only active (support) coefficients are shown in the table body. Zeroed basis functions are listed separately below, unless labels are auto-generated.

Parameters:
  • labels (sequence of str) – One label per basis function.

  • coefficients (1-D array) – Coefficient vector (length must match labels or support).

  • stderr (1-D array or None) – Standard errors (same length as coefficients). These reflect sampling error only; discretization (finite-time-step) bias is not included.

  • support (1-D int array or None) – Indices into labels to which the entries of coefficients correspond. If None, full support is assumed (all basis functions have non-zero coefficients).

  • title (str) – Section header printed above the table.

  • max_rows (int) – If the table exceeds this many rows, truncate the middle.

  • significance_thresholds (tuple of float) – Three SNR thresholds (in multiples of stderr) for the *, **, and *** significance levels. Default: (2.0, 10.0, 100.0).

  • auto_labels (bool) – If True, labels were auto-generated (e.g. b0, b1, …) and the list of zeroed functions is suppressed.

  • coeffs_true (ndarray | None)

  • support_true (ndarray | None)

Returns:

Ready-to-print multi-line table. Terms are marked *, **, or *** according to significance_thresholds; *** terms are bold, * and ** terms are normal weight, and non-significant active terms are dimmed.

Return type:

str
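The bookkeeping behind the table (support indexing, SNR stars, and the zeroed-function list) can be sketched in plain Python. This assumes SNR = |coefficient| / stderr and that a level is reached when the SNR is greater than or equal to its threshold; the boundary convention is not stated above. The helpers `significance_marker` and `split_terms` are hypothetical, and table layout, truncation and styling are left out.

```python
def significance_marker(coef, stderr, thresholds=(2.0, 10.0, 100.0)):
    """Hypothetical helper: one star per threshold reached by SNR = |coef| / stderr."""
    snr = abs(coef) / stderr
    return "*" * sum(snr >= t for t in thresholds)


def split_terms(labels, coefficients, stderr, support=None, thresholds=(2.0, 10.0, 100.0)):
    """Hypothetical helper: (active rows, zeroed labels) as described for model_summary."""
    if support is None:
        # Full support: coefficient k belongs to basis function k.
        support = list(range(len(labels)))
    active = [
        (labels[i], c, abs(c) / s, significance_marker(c, s, thresholds))
        for i, c, s in zip(support, coefficients, stderr)
    ]
    chosen = set(support)
    # Basis functions outside the support are the "zeroed" ones.
    zeroed = [lab for i, lab in enumerate(labels) if i not in chosen]
    return active, zeroed


active, zeroed = split_terms(["1", "x", "x^2"], [0.5, 30.0], [1.0, 1.0], support=[0, 2])
```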

SFI.utils.print_model_comparison(inferences, labels, *, extra_cols=None, metrics=None, title='Model Comparison')

Build a multi-model comparison table from several inference objects.

Parameters:
  • inferences – Sequence of fitted inference objects.

  • labels – One name per inference (row labels).

  • metrics – Attribute names to read from each inference (default ["n_params", "NMSE_force", "force_predicted_MSE"]). The special "n_params" counts force_coefficients_full.

  • extra_cols – Optional mapping {column_name: {label: value}} of extra, caller-supplied cells.

  • title (str) – Header line.

Returns:

Ready-to-print table.

Return type:

str
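The way metrics, the special "n_params" entry and extra_cols combine into rows can be sketched as follows. Several assumptions apply: "n_params" is taken here to be the number of non-zero entries of force_coefficients_full, missing attributes and missing extra cells become None, and the layout is simply left-aligned columns. The helpers `comparison_rows` and `format_table` are hypothetical, and `SimpleNamespace` objects with placeholder values stand in for fitted inference objects.

```python
from types import SimpleNamespace


def comparison_rows(inferences, labels, *, extra_cols=None, metrics=None):
    """Hypothetical helper: header row plus one row per fitted inference."""
    if metrics is None:
        metrics = ["n_params", "NMSE_force", "force_predicted_MSE"]
    extra_cols = extra_cols or {}
    rows = [["model", *metrics, *extra_cols]]
    for inference, label in zip(inferences, labels):
        row = [label]
        for name in metrics:
            if name == "n_params":
                # Assumption: count the non-zero entries of force_coefficients_full.
                row.append(sum(1 for c in inference.force_coefficients_full if c != 0))
            else:
                row.append(getattr(inference, name, None))
        # Caller-supplied cells; labels missing from a column become None.
        row.extend(column.get(label) for column in extra_cols.values())
        rows.append(row)
    return rows


def format_table(rows, title="Model Comparison"):
    """Left-align every column to its widest cell and prepend the title line."""
    widths = [max(len(str(row[j])) for row in rows) for j in range(len(rows[0]))]
    lines = [title]
    lines += ["  ".join(str(v).ljust(w) for v, w in zip(row, widths)) for row in rows]
    return "\n".join(lines)


# Placeholder stand-ins for fitted inference objects (values are illustrative only).
linear = SimpleNamespace(force_coefficients_full=[1.0, 0.0, 2.0], NMSE_force=0.1, force_predicted_MSE=0.02)
cubic = SimpleNamespace(force_coefficients_full=[1.0, 0.5, 2.0], NMSE_force=0.08, force_predicted_MSE=0.03)
rows = comparison_rows([linear, cubic], ["linear", "cubic"], extra_cols={"basis": {"linear": "poly1", "cubic": "poly3"}})
table = format_table(rows)
```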

SFI.utils.solve_or_pinv(A, b, tol=1e-15)

Solve A ⋅ x = b for x, with a fallback to the Moore–Penrose pseudo-inverse if A is singular or not square. To improve numerical stability, we first normalize A by its diagonal: A_norm = D^{-1} A D^{-1}, b_norm = D^{-1} b, solve A_norm ⋅ x_norm = b_norm, and then recover x = D^{-1} x_norm.

This ensures that the diagonal entries of A_norm are 1 (assuming A has positive diagonal), which often makes the linear solve or pseudo-inverse more robust when A has widely varying scales on its diagonal.

Parameters:
  • A (jax.Array, shape (k, k)) – The matrix to solve against. We assume that A has nonnegative diagonal entries; if any diagonal entry is zero, we clip it to a small floor to avoid division by zero.

  • b (jax.Array, shape (k,)) – The right-hand side vector.

  • tol (float, default=1e-15) – The tolerance for the pseudo-inverse, also used as the floor for the diagonal entries of A. If A_norm is effectively singular, we compute x_norm = pinv(A_norm, rcond=tol) @ b_norm.

Returns:

x – The solution vector to A ⋅ x = b, computed as follows:
  1. d_i = sqrt(max(A_{ii}, tol)) (each diagonal entry is floored at tol > 0 to avoid division by zero)

  2. A_norm = D_inv @ A @ D_inv and b_norm = b / d, where D_inv = diag(1 / d_i)

  3. Solve A_norm ⋅ x_norm = b_norm:
     - if A_norm is non-singular, use a direct solver;
     - otherwise, fall back to x_norm = pinv(A_norm) @ b_norm.

  4. Recover x = x_norm / d

Return type:

jax.Array, shape (k,)
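The four steps can be restated as an eager sketch for a square A. How the library detects singularity is not stated above; this sketch assumes a full-rank check with `jnp.linalg.matrix_rank` and branches in Python, so it is not jit-compatible. The name `solve_or_pinv_sketch` is hypothetical.

```python
import jax.numpy as jnp


def solve_or_pinv_sketch(A, b, tol=1e-15):
    """Hypothetical eager re-statement of the four documented steps."""
    A = jnp.asarray(A)
    b = jnp.asarray(b)
    # 1. Diagonal scales, floored at tol so zero diagonals cannot cause 0/0.
    d = jnp.sqrt(jnp.maximum(jnp.diag(A), tol))
    # 2. A_norm = D^{-1} A D^{-1} and b_norm = D^{-1} b.
    A_norm = A / jnp.outer(d, d)
    b_norm = b / d
    # 3. Direct solve when A_norm has full numerical rank, else pseudo-inverse.
    #    (Assumed singularity test; Python branching, so not jit-compatible.)
    if int(jnp.linalg.matrix_rank(A_norm)) == A_norm.shape[0]:
        x_norm = jnp.linalg.solve(A_norm, b_norm)
    else:
        # Second positional argument is the relative singular-value cutoff.
        x_norm = jnp.linalg.pinv(A_norm, tol) @ b_norm
    # 4. Undo the scaling: x = D^{-1} x_norm.
    return x_norm / d
```

Since A = D A_norm D, substituting x_norm = D x into A_norm x_norm = D^{-1} b recovers A x = b. That is why step 4 divides by d rather than multiplying.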

SFI.utils.sqrtm_psd(A, eps=0.0)

Symmetric PSD matrix square root, applied to the last two (matrix) axes:
  • Symmetrizes the input for stability.
  • Clips negative eigenvalues to 0 (enforcing PSD), then takes the square root.
  • Re-symmetrizes the result to remove small numerical asymmetries.

Parameters:
  • A (Array)

  • eps (float)

Return type:

Array
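The three steps map onto an eigendecomposition, sketched below. The role of eps is not documented above, so the sketch omits it, and the name `sqrtm_psd_sketch` is hypothetical. `jnp.linalg.eigh` operates on the last two axes, so stacked matrices are handled as a batch.

```python
import jax.numpy as jnp


def sqrtm_psd_sketch(A):
    """Hypothetical PSD square root over the last two axes (eps not modelled)."""
    A = jnp.asarray(A)
    # Symmetrize the input.
    A = 0.5 * (A + jnp.swapaxes(A, -1, -2))
    w, V = jnp.linalg.eigh(A)
    # Clip negative eigenvalues to 0, then take the square root.
    root_w = jnp.sqrt(jnp.maximum(w, 0.0))
    # V diag(sqrt(w)) V^T, with the diagonal applied by scaling V's columns.
    S = (V * root_w[..., None, :]) @ jnp.swapaxes(V, -1, -2)
    # Re-symmetrize to remove rounding asymmetries.
    return 0.5 * (S + jnp.swapaxes(S, -1, -2))


M = jnp.array([[1.0, 2.0], [0.5, -1.0]])
P = M @ M.T  # a PSD test matrix
R = sqrtm_psd_sketch(P)
```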

SFI.utils.stable_pinv(G)

Numerically-stable pseudo-inverse of a Gram matrix.

Normalizes G by its diagonal before inversion so that all diagonal entries of the rescaled matrix are 1. This avoids ill-conditioning when basis functions have very different scales.

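Algorithm: let d_i = sqrt(G_{ii}), clamped to 1 when G_{ii} is zero. Compute pinv(G / outer(d, d)) and rescale it elementwise by outer(1/d, 1/d). When G is invertible, this equals G^{-1} exactly. When G is singular, the result is a generalized inverse (it satisfies G X G = G) that can differ from the exact Moore–Penrose pseudo-inverse of G.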
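The algorithm can be sketched in a few lines. It assumes JAX's default pinv cutoff, since no tolerance is documented for stable_pinv, and the name `stable_pinv_sketch` is hypothetical. The rank-1 example shows the rescaled result reproducing G through G X G = G while differing from `jnp.linalg.pinv(G)`.

```python
import jax.numpy as jnp


def stable_pinv_sketch(G):
    """Hypothetical diagonal-rescaled pseudo-inverse of a Gram matrix."""
    G = jnp.asarray(G)
    d = jnp.sqrt(jnp.diag(G))
    # Clamp zero scales to 1 so those rows/columns are left unscaled.
    d = jnp.where(d == 0, 1.0, d)
    scale = jnp.outer(d, d)
    # pinv of the unit-diagonal matrix, then elementwise division by outer(d, d),
    # i.e. multiplication by outer(1/d, 1/d).
    return jnp.linalg.pinv(G / scale) / scale


G = jnp.array([[4.0, 2.0], [2.0, 3.0]])    # invertible Gram matrix
R1 = jnp.array([[1.0, 2.0], [2.0, 4.0]])   # rank-1 Gram matrix
X = stable_pinv_sketch(R1)
```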

Submodules