SFI.utils package
SFI.utils — Mathematical, formatting, and plotting utilities.
- SFI.utils.as_default_float(x)
Return x cast to the JAX default float dtype. Accepts any array-like; passes None through unchanged so callers can use it on optional inputs (e.g. v0).
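A minimal sketch of the behaviour described above, using only public jax.numpy calls. The name as_default_float_sketch is hypothetical and this is not the SFI source; reading the active default dtype from a Python float literal (jnp.asarray(0.0).dtype) is an assumption about one way to obtain it.

```python
import jax.numpy as jnp


def as_default_float_sketch(x):
    """Hypothetical stand-in for SFI.utils.as_default_float (not the library code)."""
    if x is None:
        # Optional inputs such as an initial velocity v0 pass through unchanged.
        return None
    # A Python float literal lands on JAX's active default float dtype:
    # float64 when JAX_ENABLE_X64 is set, float32 otherwise.
    target = jnp.asarray(0.0).dtype
    return jnp.asarray(x, dtype=target)


v0 = None
print(as_default_float_sketch(v0))               # None
print(as_default_float_sketch([1, 2, 3]).dtype)  # float32 unless x64 is enabled
```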
- SFI.utils.default_float_dtype()
Return JAX’s currently-active default float dtype.
Returns
float64 when JAX_ENABLE_X64 is set, float32 otherwise. Use this everywhere user-supplied arrays enter the simulation pipeline so that all internal arithmetic shares a single dtype and lax.scan carry-in / carry-out types always agree.
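To illustrate why a single dtype matters for lax.scan, here is a sketch built on JAX's public jax.dtypes.canonicalize_dtype helper. The names default_float_dtype_sketch and step are hypothetical and not part of SFI.

```python
import jax
import jax.numpy as jnp
from jax import lax


def default_float_dtype_sketch():
    """Hypothetical stand-in for SFI.utils.default_float_dtype."""
    # canonicalize_dtype maps float64 to float32 when x64 mode is disabled.
    return jax.dtypes.canonicalize_dtype(jnp.float64)


dtype = default_float_dtype_sketch()
xs = jnp.asarray([0.1, 0.2, 0.3], dtype=dtype)

# Casting the initial carry to the same dtype keeps the carry-in and carry-out
# types of lax.scan identical; a strongly typed int32 carry such as
# jnp.int32(0) would not match the float carry returned by step and would be
# rejected by scan.
init = jnp.asarray(0.0, dtype=dtype)


def step(carry, x):
    new = carry + x
    return new, new


total, running = lax.scan(step, init, xs)
print(total.dtype == dtype)  # True
```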
- SFI.utils.fd_velocity(x, dt, *, scheme='central')
Finite-difference velocity along the leading (time) axis.
Reconstructs v(t) ≈ dx/dt from a position array using a finite-difference stencil, the same secant-velocity convention the underdamped inference engine uses internally. The output keeps the same leading time dimension as x (boundary frames fall back to a one-sided stencil).
- Parameters:
x (array_like, shape (T, ...)) – Positions sampled in time along axis 0 (e.g. (T, d) or (T, N, d)).
dt (float or array_like, shape (T - 1,)) – Time step. Scalar for uniform sampling, or per-interval spacings.
scheme ({"central", "forward", "backward"}) – Interior stencil. "central" is second-order accurate and the default; "forward"/"backward" are first-order.
- Returns:
v – Velocity estimate, same shape as x.
- Return type:
jax.Array, shape (T, ...)
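The stencils can be sketched as below. This is a hypothetical fd_velocity_sketch, not the SFI implementation; the first-order one-sided differences at the two boundary frames and the non-uniform central formula (x[i+1] - x[i-1]) / (dt[i-1] + dt[i]) are assumptions.

```python
import jax.numpy as jnp


def fd_velocity_sketch(x, dt, *, scheme="central"):
    """Hypothetical stand-in for SFI.utils.fd_velocity (not the library code).

    x has shape (T, ...) with T >= 2; dt is a scalar or has shape (T - 1,).
    """
    x = jnp.asarray(x, dtype=jnp.asarray(0.0).dtype)
    T = x.shape[0]
    # Per-interval spacings, reshaped to broadcast over the trailing axes.
    dt = jnp.broadcast_to(jnp.asarray(dt, dtype=x.dtype), (T - 1,))
    dt = dt.reshape((T - 1,) + (1,) * (x.ndim - 1))

    # Secant on each interval: (x[i+1] - x[i]) / dt[i], shape (T - 1, ...).
    secant = jnp.diff(x, axis=0) / dt

    if scheme == "forward":
        # Forward differences; the last frame reuses the final secant.
        return jnp.concatenate([secant, secant[-1:]], axis=0)
    if scheme == "backward":
        # Backward differences; the first frame reuses the first secant.
        return jnp.concatenate([secant[:1], secant], axis=0)
    if scheme == "central":
        # Interior: (x[i+1] - x[i-1]) / (t[i+1] - t[i-1]); boundaries are one-sided.
        interior = (x[2:] - x[:-2]) / (dt[1:] + dt[:-1])
        return jnp.concatenate([secant[:1], interior, secant[-1:]], axis=0)
    raise ValueError(f"unknown scheme: {scheme!r}")


t = jnp.linspace(0.0, 1.0, 11)
v = fd_velocity_sketch(t**2, 0.1)  # exact velocity is 2t
```

For x = t² on a uniform grid, the interior central estimates reproduce 2t up to float rounding, because the central difference is exact for quadratics, while the two boundary frames carry the first-order error of the one-sided stencil.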
- SFI.utils.model_summary(labels, coefficients, *, stderr=None, support=None, coeffs_true=None, support_true=None, title='Coefficient Table', max_rows=60, significance_thresholds=(2.0, 10.0, 100.0), auto_labels=False)
Build a human-readable coefficient table with SNR and significance.
Only active (support) coefficients are shown in the table body. Zeroed basis functions are listed separately below, unless labels are auto-generated.
- Parameters:
labels (sequence of str) – One label per basis function.
coefficients (1-D array) – Coefficient vector (length must match labels or support).
stderr (1-D array or None) – Standard errors (same length as coefficients). These reflect sampling error only; discretization (finite-time-step) bias is not included.
support (1-D int array or None) – Indices into labels that coefficients correspond to. If None, full support is assumed (all basis functions have non-zero coefficients).
title (str) – Section header printed above the table.
max_rows (int) – If the table exceeds this many rows, truncate the middle.
significance_thresholds (tuple of float) – Three SNR thresholds (in multiples of stderr) for the *, **, and *** significance levels. Defaults to (2.0, 10.0, 100.0).
auto_labels (bool) – If True, labels were auto-generated (e.g. b0, b1, …) and the list of zeroed functions is suppressed.
coeffs_true (ndarray | None)
support_true (ndarray | None)
- Returns:
Ready-to-print multi-line table. Terms are marked *, **, or *** according to significance_thresholds; *** terms are bold, * and ** terms are normal weight, and non-significant active terms are dimmed.
- Return type:
str
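The significance marks described above can be reproduced with a small helper. significance_marks is hypothetical; it assumes SNR = |coefficient| / stderr and that a threshold counts as reached when the SNR is greater than or equal to it, neither of which the docstring states explicitly.

```python
import jax.numpy as jnp


def significance_marks(coefficients, stderr, thresholds=(2.0, 10.0, 100.0)):
    """Hypothetical helper: one star per SNR threshold reached (assumed >=)."""
    snr = jnp.abs(jnp.asarray(coefficients)) / jnp.asarray(stderr)
    marks = ["*" * sum(s >= t for t in thresholds) for s in snr.tolist()]
    return snr, marks


snr, marks = significance_marks([0.1, 5.0, 50.0, 500.0], [1.0, 1.0, 1.0, 1.0])
print(marks)  # ['', '*', '**', '***']
```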
- SFI.utils.print_model_comparison(inferences, labels, *, extra_cols=None, metrics=None, title='Model Comparison')
Build a multi-model comparison table from several inference objects.
- Parameters:
inferences – Sequence of fitted inference objects.
labels – One name per inference (row labels).
metrics – Attribute names to read from each inference (default ["n_params", "NMSE_force", "force_predicted_MSE"]). The special "n_params" counts force_coefficients_full.
extra_cols – Optional {column_name: {label: value}} mapping of caller-supplied cells.
title (str) – Header line.
- Returns:
Ready-to-print table.
- Return type:
str
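A rough sketch of how such a table can be assembled from attributes. comparison_table is hypothetical, SimpleNamespace objects stand in for fitted inference objects, and the "n_params" special case is left out because its exact counting rule is not spelled out here.

```python
from types import SimpleNamespace


def comparison_table(inferences, labels, metrics, extra_cols=None, title="Model Comparison"):
    """Hypothetical sketch of a multi-model table; not the SFI implementation."""
    extra_cols = extra_cols or {}
    header = ["model", *metrics, *extra_cols]
    rows = []
    for inference, label in zip(inferences, labels):
        row = [label]
        for name in metrics:
            value = getattr(inference, name, None)  # missing attributes become "-"
            row.append("-" if value is None else f"{value:.3g}")
        for cells in extra_cols.values():
            row.append(str(cells.get(label, "-")))  # caller-supplied cell for this row
        rows.append(row)
    # Pad every column to its widest entry so the table lines up.
    widths = [max(len(r[i]) for r in [header, *rows]) for i in range(len(header))]
    lines = [title]
    for r in [header, *rows]:
        lines.append("  ".join(cell.ljust(w) for cell, w in zip(r, widths)).rstrip())
    return "\n".join(lines)


fits = [SimpleNamespace(NMSE_force=0.0123, force_predicted_MSE=0.5),
        SimpleNamespace(NMSE_force=0.02)]
print(comparison_table(fits, ["linear", "cubic"], ["NMSE_force", "force_predicted_MSE"],
                       extra_cols={"basis": {"linear": "poly1", "cubic": "poly3"}}))
```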
- SFI.utils.solve_or_pinv(A, b, tol=1e-15)
Solve A ⋅ x = b for x, with a fallback to the Moore–Penrose pseudo-inverse if A is singular or not square. To improve numerical stability, we first normalize A by its diagonal: with D = diag(d_i) and d_i = sqrt(A_{ii}), we set A_norm = D^{-1} A D^{-1} and b_norm = D^{-1} b, solve A_norm ⋅ x_norm = b_norm, and then recover x = D^{-1} x_norm.
This ensures that the diagonal entries of A_norm are 1 (assuming A has positive diagonal), which often makes the linear solve or pseudo-inverse more robust when A has widely varying scales on its diagonal.
- Parameters:
A (jax.Array, shape (k, k)) – The matrix to solve against. We assume that A has nonnegative diagonal entries; if any diagonal entry is zero, we clip it to a small floor to avoid division by zero.
b (jax.Array, shape (k,)) – The right-hand side vector.
tol (float, default=1e-15) – The tolerance for the pseudo-inverse. If A_norm is effectively singular, we compute x_norm = pinv(A_norm, rcond=tol) @ b_norm.
- Returns:
x – The solution vector to A ⋅ x = b, computed as follows:
  - d_i = sqrt(max(A_{ii}, tol)) (each diagonal entry is floored to tol > 0 to avoid division by zero).
  - A_norm = D_inv @ A @ D_inv, where D_inv = diag(1 / d_i), and b_norm = b / d.
  - Solve A_norm ⋅ x_norm = b_norm: if A_norm is non-singular, use a direct solver; otherwise, fall back to x_norm = pinv(A_norm) @ b_norm.
  - Recover x = x_norm / d.
- Return type:
jax.Array, shape (k,)
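The four steps above translate into the following sketch. solve_or_pinv_sketch is hypothetical. Because Python control flow cannot branch on traced array values under jit, it computes both candidates and selects one with jnp.where; the singularity test (condition number against 1/eps plus a finiteness check) is an assumption rather than the documented criterion.

```python
import jax.numpy as jnp


def solve_or_pinv_sketch(A, b, tol=1e-15):
    """Hypothetical re-implementation of the steps listed for solve_or_pinv."""
    A = jnp.asarray(A, dtype=jnp.asarray(0.0).dtype)
    b = jnp.asarray(b, dtype=A.dtype)
    # Step 1: d_i = sqrt(max(A_ii, tol)), flooring the diagonal to avoid zero divides.
    d = jnp.sqrt(jnp.maximum(jnp.diag(A), tol))
    # Step 2: A_norm = D_inv @ A @ D_inv (elementwise division by outer(d, d)); b_norm = b / d.
    A_norm = A / jnp.outer(d, d)
    b_norm = b / d
    # Step 3: direct solve with a pseudo-inverse fallback. This sketch keeps
    # pinv's default cutoff, whereas the docstring describes rcond=tol.
    x_solve = jnp.linalg.solve(A_norm, b_norm)
    x_pinv = jnp.linalg.pinv(A_norm) @ b_norm
    well_posed = (jnp.linalg.cond(A_norm) < 1.0 / jnp.finfo(A_norm.dtype).eps) & jnp.all(
        jnp.isfinite(x_solve)
    )
    x_norm = jnp.where(well_posed, x_solve, x_pinv)
    # Step 4: undo the scaling, x = x_norm / d.
    return x_norm / d


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
x = solve_or_pinv_sketch(A, jnp.array([1.0, 2.0]))  # regular case: direct solve
```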
- SFI.utils.sqrtm_psd(A, eps=0.0)
Symmetric PSD matrix square root, applied to the last two (matrix) axes.
- Symmetrizes the input for stability.
- Clips negative eigenvalues to 0 (PSD), then takes the square root.
- Re-symmetrizes the result to remove small numerical asymmetries.
- Parameters:
A (Array)
eps (float)
- Return type:
Array
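A sketch of the three steps via a batched eigendecomposition (jnp.linalg.eigh), assuming that is how the square root is formed. sqrtm_psd_sketch is hypothetical and omits eps, whose role is not documented above.

```python
import jax.numpy as jnp


def _mT(M):
    # Transpose the last two (matrix) axes, leaving any batch axes alone.
    return jnp.swapaxes(M, -1, -2)


def sqrtm_psd_sketch(A):
    """Hypothetical stand-in for SFI.utils.sqrtm_psd; the eps argument is omitted."""
    A = jnp.asarray(A)
    A_sym = 0.5 * (A + _mT(A))              # symmetrize the input
    w, V = jnp.linalg.eigh(A_sym)           # batched over leading axes
    w = jnp.maximum(w, 0.0)                 # clip negative eigenvalues to 0
    S = (V * jnp.sqrt(w)[..., None, :]) @ _mT(V)  # V diag(sqrt(w)) V^T
    return 0.5 * (S + _mT(S))               # re-symmetrize


B = jnp.array([[2.0, 1.0], [1.0, 2.0]])
S = sqrtm_psd_sketch(B)  # S @ S recovers B up to float rounding
```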
- SFI.utils.stable_pinv(G)
Numerically-stable pseudo-inverse of a Gram matrix.
Normalizes G by its diagonal before inversion so that all diagonal entries of the rescaled matrix are 1. This avoids ill-conditioning when basis functions have very different scales.
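Algorithm: let d_i = sqrt(G_{ii}) (clamped to 1 when zero). Compute pinv(G / outer(d, d)) and rescale it elementwise by outer(1/d, 1/d). When G is invertible this recovers G^{-1} exactly; when G is singular the result is a generalized inverse X of G (it satisfies G X G = G and X G X = X) but need not coincide with the Moore–Penrose pseudo-inverse of G.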
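A sketch of the diagonal-rescaling algorithm. stable_pinv_sketch is hypothetical and not the SFI source; the example uses an invertible G with very different diagonal scales so the output can be checked against the analytic inverse.

```python
import jax.numpy as jnp


def stable_pinv_sketch(G):
    """Hypothetical stand-in for SFI.utils.stable_pinv."""
    G = jnp.asarray(G)
    d = jnp.sqrt(jnp.diag(G))
    d = jnp.where(d == 0, 1.0, d)  # clamp zero scales to 1
    scale = jnp.outer(d, d)
    # Pseudo-inverse of the unit-diagonal matrix, then undo the scaling.
    return jnp.linalg.pinv(G / scale) / scale


G = jnp.array([[1e4, 1.0], [1.0, 1e-2]])
X = stable_pinv_sketch(G)
# Analytic inverse of this 2x2 matrix: [[0.01, -1], [-1, 1e4]] / 99
```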
Submodules
- SFI.utils.formatting module
- SFI.utils.maths module
- SFI.utils.neighbors module
- SFI.utils.plotting module
  - SFI_COLORS
  - animate_particles()
  - animate_spde_comparison()
  - axisvector()
  - comparison_scatter()
  - dark_ax()
  - dark_fig()
  - phase2d()
  - phase2d_scalar()
  - phase3d()
  - plot_field()
  - plot_field_error()
  - plot_nematic_director()
  - plot_pareto_front()
  - plot_particles()
  - plot_particles_field()
  - plot_profile_1d()
  - plot_recovery_bar()
  - plot_recovery_bar_multi()
  - plot_recovery_matrix()
  - plot_rods()
  - plot_spde_snapshot()
  - plot_tensor_field()
  - plot_time_profile_comparison()
  - spatial_acorr2d()
  - stamp_fig()
  - stamp_output()
  - stream_field()
  - timeseries()
  - timeseries_colored()
  - trajectory_scatter()
  - wrap_positions()