SFI.inference.sparse.scorer module
SFI.inference.sparse.scorer — Normal-equations scorer
The SparseScorer owns the pre-computed moment vector M and
Gram matrix G and provides efficient (JIT / vmap) evaluation of the
log-likelihood gain for any candidate support \(B\).
It is stateless with respect to the search: no Pareto-front data lives here. Every strategy receives a scorer and calls its methods.
Performance notes
- Symmetry of G is detected at construction time. When G is symmetric PSD, the restricted solve uses Cholesky (assume_a="pos"), which is ~2× faster and more stable than the general LU path.
- The solve is fully JIT-compatible (no Python-level try/except). Singular/rank-deficient cases are handled via jnp.linalg.lstsq.
- vmap_info avoids double-JIT: the vmapped kernel is compiled once per support size k as a standalone pure function.
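The first two notes can be illustrated with a minimal sketch. It is not the module's actual implementation: the helper name make_restricted_solver, the sym_tol argument, and the jnp.where-based selection between the fast solve and the lstsq result are all assumptions made for illustration. The only parts taken from the notes above are the construction-time symmetry check, the assume_a="pos" Cholesky path, and the use of jnp.linalg.lstsq for singular cases.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve


def make_restricted_solver(G, sym_tol=1e-10):
    """Hypothetical helper: choose the solve path once, outside of jit."""
    # Concrete Python bool, evaluated eagerly at construction time.
    is_symmetric = bool(jnp.allclose(G, G.T, atol=sym_tol))
    # Principal sub-blocks of a symmetric matrix are symmetric as well,
    # so the same choice applies to every restricted block G_BB.
    assume_a = "pos" if is_symmetric else "gen"

    @jax.jit
    def solve_restricted(G_BB, M_B):
        # Fast path: Cholesky ("pos") or LU ("gen").
        c_fast = solve(G_BB, M_B, assume_a=assume_a)
        # Fallback: least-squares solution (assumed selection mechanism).
        # Both branches are computed, so there is no Python-level
        # try/except and the function stays traceable.
        c_lstsq = jnp.linalg.lstsq(G_BB, M_B)[0]
        ok = jnp.all(jnp.isfinite(c_fast))
        return jnp.where(ok, c_fast, c_lstsq)

    return solve_restricted


# Well-conditioned symmetric example.
G = jnp.array([[2.0, 1.0], [1.0, 2.0]])
M = jnp.array([5.0, 6.0])
c = make_restricted_solver(G)(G, M)

# Rank-deficient example: the fast path yields non-finite values,
# so the least-squares result is returned instead.
G_sing = jnp.array([[1.0, 1.0], [1.0, 1.0]])
M_sing = jnp.array([2.0, 2.0])
c_sing = make_restricted_solver(G_sing)(G_sing, M_sing)
```

Computing both branches on every call is wasteful; it is done here only to keep the sketch short and traceable.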
- class SFI.inference.sparse.scorer.SparseScorer(*, M, G, norm_X2=0.0, n=1, pinv_tol=1e-08, use_residuals=False)
Bases: object
Score candidate supports by solving the restricted normal equations.
- Parameters:
M ((p,) Array) – Pre-computed moment vector (cross-moments between data and basis functions).
G ((p, p) Array) – Normal-equations matrix. May be non-symmetric (e.g. when using Itô-shift moment estimators). Symmetry is detected automatically and a Cholesky fast-path is used when possible.
norm_X2 (float, default 0.0) – Sum of squared observations. Only used when use_residuals=True.
n (int, default 1) – Sample count prefactor used in the residual-based information gain \(\tfrac{1}{2} n\,\log(\lVert X\rVert^2 / \mathrm{RSS})\). Only used when use_residuals=True.
pinv_tol (float, default 1e-8) – Tolerance for the diagonal preconditioning floor.
use_residuals (bool, default False) – If True, the information gain is computed via the residual sum-of-squares expression instead of the explicit quadratic form.
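The two information-gain expressions named in the parameter list can be compared on a toy problem. The sketch below assumes ordinary least-squares moments, M = Φᵀ X and G = Φᵀ Φ for a design matrix Φ, in which case RSS = ‖X‖² − C_Bᵀ M_B at the solution of the normal equations. That relation, the helper names, and the toy data are assumptions for illustration; how SparseScorer obtains the RSS internally is not documented here.

```python
import jax.numpy as jnp


def quadratic_information_gain(C_B, M_B):
    # Explicit quadratic form: 1/2 * C_B^T M_B.
    return 0.5 * C_B @ M_B


def rss_information_gain(C_B, M_B, norm_X2, n):
    # Assumption: least-squares moments, so RSS = ||X||^2 - C_B^T M_B.
    rss = norm_X2 - C_B @ M_B
    # Residual-based form: 1/2 * n * log(||X||^2 / RSS).
    return 0.5 * n * jnp.log(norm_X2 / rss)


# Toy least-squares problem (hypothetical data).
Phi = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
X = jnp.array([1.0, 2.0, 4.0])

G = Phi.T @ Phi        # [[2, 1], [1, 2]]
M = Phi.T @ X          # [5, 6]
norm_X2 = X @ X        # 21
n = X.shape[0]         # 3

C = jnp.linalg.solve(G, M)
info_quadratic = quadratic_information_gain(C, M)
info_rss = rss_information_gain(C, M, norm_X2, n)
```

The two values are on different scales, so gains computed with use_residuals=True and use_residuals=False should not be compared with each other.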
- info_and_coeffs(B)
Solve \(G_{BB}\,C_B = M_B\) and return the information gain.
- Parameters:
B ((k,) int Array) – Indices of the active basis functions (the support).
- Returns:
info (scalar Array) – \(\tfrac{1}{2}\,C_B^\top M_B\) (or the RSS variant when use_residuals=True).
C_B ((k,) Array) – Maximum-likelihood coefficients for the restricted support.
- Return type:
Tuple[Array, Array]
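As a rough illustration of what this method computes in the default (use_residuals=False) case, the sketch below restricts M and G to a support B, solves the restricted system, and evaluates the quadratic form. The function name info_and_coeffs_sketch and the example values are hypothetical, and a plain jnp.linalg.solve stands in for whatever preconditioning and solve path the class actually uses.

```python
import jax.numpy as jnp


def info_and_coeffs_sketch(M, G, B):
    """Hypothetical stand-in for the documented method (quadratic form only)."""
    # Restrict to the support: rows and columns of G, entries of M.
    G_BB = G[B[:, None], B[None, :]]
    M_B = M[B]
    # Solve G_BB C_B = M_B.
    C_B = jnp.linalg.solve(G_BB, M_B)
    # Information gain: 1/2 * C_B^T M_B.
    info = 0.5 * C_B @ M_B
    return info, C_B


# Example with p = 3 basis functions and support B = {0, 2}.
G = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 2.0, 0.0],
               [0.0, 0.0, 4.0]])
M = jnp.array([5.0, 6.0, 8.0])
B = jnp.array([0, 2])

info, C_B = info_and_coeffs_sketch(M, G, B)
```

With this support, G_BB is diagonal, so C_B is simply M_B divided elementwise by the diagonal entries.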
- vmap_info(batch)
Score a batch of supports of the same cardinality.
The batch is padded to the next power-of-2 length so that JAX’s compilation cache stays bounded (≤ ~12 unique shapes per support size k instead of one per distinct batch count).
- Parameters:
batch ((n_supports, k) int Array) – Each row is a sorted support of length k.
- Returns:
infos ((n_supports,) Array)
coeffs ((n_supports, k) Array)
- Return type:
Tuple[Array, Array]
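The padding scheme described above can be sketched as follows. This is an illustration under stated assumptions, not the library's code: next_pow2, _score_one, _batched_kernel, and vmap_info_sketch are hypothetical names, padding rows are filled by repeating the last support (any valid support would do, since the padded results are discarded), and the per-support kernel uses a plain solve with the quadratic-form gain.

```python
import jax
import jax.numpy as jnp


def next_pow2(m):
    # Smallest power of two >= m (for m >= 1).
    return 1 if m <= 1 else 1 << (m - 1).bit_length()


def _score_one(M, G, B):
    # Score a single support: solve G_BB C_B = M_B, gain = 1/2 C_B^T M_B.
    G_BB = G[B[:, None], B[None, :]]
    M_B = M[B]
    C_B = jnp.linalg.solve(G_BB, M_B)
    return 0.5 * C_B @ M_B, C_B


# Standalone pure function: vmapped over supports, jitted once.
# JAX recompiles only when the padded batch shape changes.
_batched_kernel = jax.jit(jax.vmap(_score_one, in_axes=(None, None, 0)))


def vmap_info_sketch(M, G, batch):
    n_supports = batch.shape[0]
    n_pad = next_pow2(n_supports) - n_supports
    # Pad the batch axis by repeating the last row (assumed fill strategy).
    padded = jnp.pad(batch, ((0, n_pad), (0, 0)), mode="edge")
    infos, coeffs = _batched_kernel(M, G, padded)
    # Drop the results that belong to padding rows.
    return infos[:n_supports], coeffs[:n_supports]


G = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 2.0, 0.0],
               [0.0, 0.0, 4.0]])
M = jnp.array([5.0, 6.0, 8.0])
# Three sorted supports of size k = 2; padded internally to 4 rows.
batch = jnp.array([[0, 1], [0, 2], [1, 2]])

infos, coeffs = vmap_info_sketch(M, G, batch)
```

Batches of 3 and 4 supports share one compiled kernel here, since both are padded to 4 rows; a batch of 5 would trigger a second compilation at 8 rows.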