SFI.inference.sparse.scorer module

SFI.inference.sparse.scorer — Normal-equations scorer

The SparseScorer owns the pre-computed moment vector M and Gram matrix G and provides efficient (JIT / vmap) evaluation of the log-likelihood gain for any candidate support \(B\).

It is stateless with respect to the search: no Pareto-front data lives here. Every strategy receives a scorer and calls its methods.

Performance notes

  • Symmetry of G is detected at construction time. When G is symmetric positive definite, the restricted solve uses a Cholesky factorization (assume_a="pos"), which is roughly 2× faster and numerically more stable than the general LU path.

  • The solve is fully JIT-compatible (no Python-level try/except). Singular/rank-deficient cases are handled via jnp.linalg.lstsq.

  • vmap_info avoids double-JIT: the vmapped kernel is compiled once per support size k as a standalone pure function.
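The solver selection described in these notes can be illustrated with a minimal sketch. It is not the library's implementation: make_restricted_solver and its rtol argument are hypothetical names, and the sketch assumes that the symmetry test is a plain comparison of G with its transpose. It uses jax.scipy.linalg.solve, whose assume_a argument accepts "gen" (LU) and "pos" (Cholesky). The lstsq fallback for singular blocks is omitted here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve


def make_restricted_solver(G, rtol=1e-6):
    """Hypothetical helper: choose the solver path once, outside of JIT."""
    # Eager check on concrete values, so the compiled kernel needs no branching.
    is_symmetric = bool(jnp.allclose(G, G.T, rtol=rtol, atol=0.0))
    # Assumption: a symmetric G is also positive definite. If it is not,
    # the Cholesky path returns NaNs instead of raising.
    assume_a = "pos" if is_symmetric else "gen"

    @jax.jit
    def restricted_solve(M, B):
        G_BB = G[B[:, None], B[None, :]]  # (k, k) block of G on the support
        M_B = M[B]                        # (k,) moments on the support
        return solve(G_BB, M_B, assume_a=assume_a)

    return restricted_solve, is_symmetric


G_sym = jnp.array([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
G_gen = jnp.array([[4.0, 1.0, 0.0], [0.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
M = jnp.array([1.0, 2.0, 3.0])

solve_sym, sym_flag = make_restricted_solver(G_sym)  # Cholesky path
solve_gen, gen_flag = make_restricted_solver(G_gen)  # general LU path
C_sym = solve_sym(M, jnp.array([0, 2]))
C_gen = solve_gen(M, jnp.array([0, 1]))
```

Because the choice of assume_a is made in Python before tracing, each compiled kernel contains only one solver path.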

class SFI.inference.sparse.scorer.SparseScorer(*, M, G, norm_X2=0.0, n=1, pinv_tol=1e-08, use_residuals=False)

Bases: object

Score candidate supports by solving the restricted normal equations.

Parameters:
  • M ((p,) Array) – Pre-computed moment vector (cross-moments between data and basis functions).

  • G ((p, p) Array) – Normal-equations matrix. May be non-symmetric (e.g. when using Itô-shift moment estimators). Symmetry is detected automatically and a Cholesky fast-path is used when possible.

  • norm_X2 (float, default 0.0) – Sum of squared observations. Only used when use_residuals=True.

  • n (int, default 1) – Sample count prefactor used in the residual-based information gain \(\tfrac{1}{2} n\,\log(\lVert X\rVert^2 / \mathrm{RSS})\). Only used when use_residuals=True.

  • pinv_tol (float, default 1e-8) – Tolerance for the diagonal preconditioning floor.

  • use_residuals (bool, default False) – If True, the information gain is computed via the residual sum-of-squares expression instead of the explicit quadratic form.
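For orientation, the two gain expressions can be related under one assumption that this page does not state: that M and G are unnormalized least-squares moments of a design matrix \(\Phi\), i.e. \(M = \Phi^\top X\) and \(G = \Phi^\top \Phi\). Writing \(C_B\) for the solution of \(G_{BB}\,C_B = M_B\) on a support \(B\), the residual sum of squares would then be

\[
\mathrm{RSS}
= \lVert X - \Phi_B C_B \rVert^2
= \lVert X \rVert^2 - 2\,C_B^\top M_B + C_B^\top G_{BB}\,C_B
= \lVert X \rVert^2 - C_B^\top M_B ,
\]

so that, under this assumption, the residual-based gain \(\tfrac{1}{2} n\,\log(\lVert X\rVert^2 / \mathrm{RSS})\) needs only norm_X2, n and the same product \(C_B^\top M_B\) that appears in the quadratic form. The actual normalization of M and G in the library may differ.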

info_and_coeffs(B)

Solve \(G_{BB}\,C_B = M_B\) and return the information gain.

Parameters:

B ((k,) int Array) – Indices of the active basis functions (the support).

Returns:

  • info (scalar Array) – \(\tfrac{1}{2}\,C_B^\top M_B\) (or the RSS variant when use_residuals=True).

  • C_B ((k,) Array) – Maximum-likelihood coefficients for the restricted support.

Return type:

Tuple[Array, Array]
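A minimal sketch of the computation this method describes, limited to the quadratic-form gain (use_residuals=False). The function name info_and_coeffs_sketch is hypothetical, and the sketch always goes through jnp.linalg.lstsq, whereas the library is stated to use it only for singular or rank-deficient cases.

```python
import jax
import jax.numpy as jnp


@jax.jit
def info_and_coeffs_sketch(M, G, B):
    """Hypothetical stand-in for info_and_coeffs (quadratic-form gain only)."""
    G_BB = G[B[:, None], B[None, :]]  # (k, k) restricted normal matrix
    M_B = M[B]                        # (k,) restricted moment vector
    # lstsq returns the minimum-norm solution, so a singular G_BB does not raise.
    C_B, _, _, _ = jnp.linalg.lstsq(G_BB, M_B)
    info = 0.5 * C_B @ M_B            # (1/2) C_B^T M_B
    return info, C_B


M = jnp.array([1.0, 1.0, 2.0])
G = jnp.array([[1.0, 1.0, 0.0],
               [1.0, 1.0, 0.0],
               [0.0, 0.0, 4.0]])

# Support {0, 1}: the 2x2 block of G is singular (two identical rows).
info_a, C_a = info_and_coeffs_sketch(M, G, jnp.array([0, 1]))
# Support {0, 2}: the 2x2 block is diagonal and well conditioned.
info_b, C_b = info_and_coeffs_sketch(M, G, jnp.array([0, 2]))
```

Both calls share one compiled kernel because the two supports have the same length k; a support of a different size triggers a new compilation.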

vmap_info(batch)

Score a batch of supports of the same cardinality.

The batch is padded to the next power-of-2 length so that JAX’s compilation cache stays bounded: at most about 12 distinct shapes per support size k, instead of one per distinct batch size.

Parameters:

batch ((n_supports, k) int Array) – Each row is a sorted support of length k.

Returns:

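  • infos ((n_supports,) Array) – Information gain of each support, in the same order as the rows of batch.

  • coeffs ((n_supports, k) Array) – Coefficients for each support, one row per support.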

Return type:

Tuple[Array, Array]
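The padding strategy can be sketched as follows. This is an illustration under stated assumptions, not the library's code: next_pow2, pad_to_pow2, score_one and vmap_info_sketch are hypothetical names, padding rows are assumed to be copies of the last real support, and the per-support kernel uses a plain jnp.linalg.solve with the quadratic-form gain.

```python
import jax
import jax.numpy as jnp


def next_pow2(m):
    """Smallest power of two that is >= m (for m >= 1)."""
    return 1 if m <= 1 else 1 << (m - 1).bit_length()


def pad_to_pow2(batch):
    """Pad the leading axis to a power of two by repeating the last row."""
    n_supports = batch.shape[0]
    idx = jnp.minimum(jnp.arange(next_pow2(n_supports)), n_supports - 1)
    return batch[idx], n_supports


def score_one(M, G, B):
    """Score a single support: solve G_BB C_B = M_B, return (gain, C_B)."""
    G_BB = G[B[:, None], B[None, :]]
    M_B = M[B]
    C_B = jnp.linalg.solve(G_BB, M_B)
    return 0.5 * C_B @ M_B, C_B


# One jitted, vmapped kernel; JAX recompiles only when input shapes change.
score_batch = jax.jit(jax.vmap(score_one, in_axes=(None, None, 0)))


def vmap_info_sketch(M, G, batch):
    padded, n_supports = pad_to_pow2(batch)
    infos, coeffs = score_batch(M, G, padded)
    # Drop the results computed for the padding rows.
    return infos[:n_supports], coeffs[:n_supports]


M = jnp.array([1.0, 2.0, 3.0, 4.0])
G = jnp.diag(jnp.array([1.0, 2.0, 3.0, 4.0])) + 0.1
batch = jnp.array([[0, 1], [0, 2], [1, 2]])  # 3 supports, padded to 4 rows
infos, coeffs = vmap_info_sketch(M, G, batch)
```

With this scheme, batches of 3 and 4 supports of the same size k reuse one compiled kernel, since both are padded to 4 rows.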