Source code for SFI.inference.sparse.scorer

"""
SFI.inference.sparse.scorer — Normal-equations scorer
=====================================================

The :class:`SparseScorer` owns the pre-computed moment vector **M** and
Gram matrix **G** and provides efficient (JIT / vmap) evaluation of the
log-likelihood gain for any candidate support :math:`B`.

It is *stateless* with respect to the search: no Pareto-front data lives
here.  Every strategy receives a scorer and calls its methods.

Performance notes
~~~~~~~~~~~~~~~~~
* Symmetry of **G** is detected at construction time.  When symmetric
  PSD, the restricted solve uses Cholesky (``assume_a="pos"``),
  which is ~2× faster and more stable than the general LU path.
* The solve is fully JIT-compatible (no Python-level ``try``/``except``).
  Singular/rank-deficient cases are handled via ``jnp.linalg.lstsq``.
* ``vmap_info`` avoids double-JIT: the vmapped kernel is compiled
  once per support size *k* as a standalone pure function.
"""

from __future__ import annotations

import logging
from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsp_la
import numpy as np

logger = logging.getLogger(__name__)

Array = jnp.ndarray


# =====================================================================
# Pure-function kernels (no class reference ⇒ clean JIT, clean vmap)
# =====================================================================


def _solve_sym(A_norm: Array, b_norm: Array) -> Array:
    """Cholesky-based solve for a symmetric PSD system (normalised)."""
    return jsp_la.solve(A_norm, b_norm, assume_a="pos")


def _solve_gen(A_norm: Array, b_norm: Array) -> Array:
    """General LU solve for a non-symmetric system (normalised)."""
    return jsp_la.solve(A_norm, b_norm, assume_a="gen")


def _solve_lstsq(A_norm: Array, b_norm: Array) -> Array:
    """Least-squares fallback (always JIT-safe)."""
    x, _, _, _ = jnp.linalg.lstsq(A_norm, b_norm, rcond=None)
    return x


def _info_kernel_sym(M: Array, G: Array, B: Array, tol: float) -> Tuple[Array, Array]:
    """Score one support — symmetric-PSD fast path."""
    B_idx = jnp.asarray(B, dtype=jnp.int32)
    M_B = M[B_idx]
    G_BB = G[jnp.ix_(B_idx, B_idx)]

    # Diagonal preconditioning
    diag_clipped = jnp.clip(jnp.diag(G_BB), min=tol)
    d = jnp.sqrt(diag_clipped)
    D_inv = 1.0 / d
    A_norm = (G_BB * D_inv[:, None]) * D_inv[None, :]
    b_norm = M_B / d

    # Cholesky-based solve; lstsq fallback via lax.cond on det
    # We use a simple heuristic: if reciprocal condition number is
    # reasonable, use Cholesky; otherwise lstsq.
    # For JIT compatibility we always compute both paths and select.
    x_chol = _solve_sym(A_norm, b_norm)
    # If Cholesky produced NaN/Inf (silent failure on non-PSD input),
    # fall back to least-squares.
    chol_ok = jnp.all(jnp.isfinite(x_chol))
    x_norm = jax.lax.cond(
        chol_ok,
        lambda _: x_chol,
        lambda _: _solve_lstsq(A_norm, b_norm),
        None,
    )

    C_B = x_norm / d
    Q_B = C_B @ M_B
    info = 0.5 * Q_B
    return info, C_B


def _info_kernel_gen(M: Array, G: Array, B: Array, tol: float) -> Tuple[Array, Array]:
    """Score one support — general (non-symmetric) G."""
    B_idx = jnp.asarray(B, dtype=jnp.int32)
    M_B = M[B_idx]
    G_BB = G[jnp.ix_(B_idx, B_idx)]

    diag_clipped = jnp.clip(jnp.diag(G_BB), min=tol)
    d = jnp.sqrt(diag_clipped)
    D_inv = 1.0 / d
    A_norm = (G_BB * D_inv[:, None]) * D_inv[None, :]
    b_norm = M_B / d

    x_gen = _solve_gen(A_norm, b_norm)
    gen_ok = jnp.all(jnp.isfinite(x_gen))
    x_norm = jax.lax.cond(
        gen_ok,
        lambda _: x_gen,
        lambda _: _solve_lstsq(A_norm, b_norm),
        None,
    )

    C_B = x_norm / d
    Q_B = C_B @ M_B
    info = 0.5 * Q_B
    return info, C_B


def _info_kernel_residual_sym(M: Array, G: Array, B: Array, tol: float, norm_X2: float, n: int) -> Tuple[Array, Array]:
    """Symmetric-PSD path with residual-based information gain."""
    info_raw, C_B = _info_kernel_sym(M, G, B, tol)
    Q_B = 2.0 * info_raw  # undo the 0.5 factor
    RSS_B = norm_X2 - Q_B
    info = 0.5 * n * jnp.log(norm_X2 / RSS_B)
    return info, C_B


def _info_kernel_residual_gen(M: Array, G: Array, B: Array, tol: float, norm_X2: float, n: int) -> Tuple[Array, Array]:
    """General path with residual-based information gain."""
    info_raw, C_B = _info_kernel_gen(M, G, B, tol)
    Q_B = 2.0 * info_raw
    RSS_B = norm_X2 - Q_B
    info = 0.5 * n * jnp.log(norm_X2 / RSS_B)
    return info, C_B


# =====================================================================
# Shared JIT / vmap caches   (keyed by scorer config, NOT instance)
# =====================================================================
# By keeping M and G as call-time arguments instead of partial-bound,
# all SparseScorer instances with the same (is_sym, use_residuals, tol,
# norm_X2, n) share compiled XLA code.  This avoids per-instance
# recompilation — a major speed-up when running many scorers with the
# same p (e.g. benchmarks, cross-validation, bootstrapping).

_single_jit_cache: dict[tuple, callable] = {}
_vmap_jit_cache: dict[tuple, callable] = {}


def _get_single_jit(is_sym: bool, use_residuals: bool, tol: float, norm_X2: float = 0.0, n: int = 1):
    key = (is_sym, use_residuals, tol, norm_X2, n)
    if key not in _single_jit_cache:
        kernel = _pick_kernel(is_sym, use_residuals, tol, norm_X2, n)
        _single_jit_cache[key] = jax.jit(kernel)
    return _single_jit_cache[key]


def _get_vmap_jit(k: int, is_sym: bool, use_residuals: bool, tol: float, norm_X2: float = 0.0, n: int = 1):
    key = (k, is_sym, use_residuals, tol, norm_X2, n)
    if key not in _vmap_jit_cache:
        kernel = _pick_kernel(is_sym, use_residuals, tol, norm_X2, n)
        _vmap_jit_cache[key] = jax.jit(jax.vmap(kernel, in_axes=(None, None, 0), out_axes=(0, 0)))
    return _vmap_jit_cache[key]


def _pick_kernel(is_sym, use_residuals, tol, norm_X2, n):
    """Return a partial-bound kernel (M, G, B) -> (info, C_B)."""
    if use_residuals:
        if is_sym:
            return partial(_info_kernel_residual_sym, tol=tol, norm_X2=norm_X2, n=n)
        return partial(_info_kernel_residual_gen, tol=tol, norm_X2=norm_X2, n=n)
    if is_sym:
        return partial(_info_kernel_sym, tol=tol)
    return partial(_info_kernel_gen, tol=tol)


# =====================================================================
# SparseScorer
# =====================================================================


[docs] class SparseScorer: r"""Score candidate supports by solving the restricted normal equations. Parameters ---------- M : (p,) Array Pre-computed moment vector (cross-moments between data and basis functions). G : (p, p) Array Normal-equations matrix. May be non-symmetric (e.g. when using Itô-shift moment estimators). Symmetry is detected automatically and a Cholesky fast-path is used when possible. norm_X2 : float, default 0.0 Sum of squared observations. Only used when ``use_residuals=True``. n : int, default 1 Sample count prefactor used in the residual-based information gain :math:`\tfrac{1}{2} n\,\log(\lVert X\rVert^2 / \mathrm{RSS})`. Only used when ``use_residuals=True``. pinv_tol : float, default 1e-8 Tolerance for the diagonal preconditioning floor. use_residuals : bool, default False If *True*, the information gain is computed via the residual sum-of-squares expression instead of the explicit quadratic form. """ def __init__( self, *, M: Array, G: Array, norm_X2: float = 0.0, n: int = 1, pinv_tol: float = 1e-8, use_residuals: bool = False, ): self.M = jnp.asarray(M) self.G = jnp.asarray(G) self.p: int = int(self.M.shape[0]) self.pinv_tol = pinv_tol self.norm_X2 = norm_X2 self.n = n self.use_residuals = use_residuals # --- detect symmetry of G (once, on host) -------------------- G_np = np.asarray(self.G) self.G_is_symmetric: bool = bool(np.allclose(G_np, G_np.T, atol=1e-12 * (np.abs(G_np).max() + 1e-30))) if self.G_is_symmetric: logger.debug("SparseScorer: G is symmetric → Cholesky fast path.") else: logger.debug("SparseScorer: G is non-symmetric → general LU path.") # --- select the right kernel --------------------------------- # Use shared module-level JIT cache so all instances with the # same config share compiled XLA code. self._jit_kernel = _get_single_jit( self.G_is_symmetric, self.use_residuals, self.pinv_tol, self.norm_X2, self.n, ) # Pre-compute the full (dense) solution. self.total_info, self.total_C = self.info_and_coeffs(jnp.arange(self.p)) # ----------------------------------------------------------------- # Score a single support B # -----------------------------------------------------------------
[docs] def info_and_coeffs(self, B: Array) -> Tuple[Array, Array]: r"""Solve :math:`G_{BB}\,C_B = M_B` and return the information gain. Parameters ---------- B : (k,) int Array Indices of the active basis functions (the *support*). Returns ------- info : scalar Array :math:`\tfrac{1}{2}\,C_B^\top M_B` (or the RSS variant when ``use_residuals=True``). C_B : (k,) Array Maximum-likelihood coefficients for the restricted support. """ if B.size == 0: return jnp.array(0.0), jnp.zeros(0, dtype=self.M.dtype) return self._jit_kernel(self.M, self.G, B)
# ----------------------------------------------------------------- # Batched evaluation # ----------------------------------------------------------------- @staticmethod def _pad_to_pow2(n: int) -> int: """Round *n* up to next power of 2 (min 1).""" if n <= 1: return 1 return 1 << (n - 1).bit_length()
[docs] def vmap_info(self, batch: Array) -> Tuple[Array, Array]: """Score a batch of supports of the *same* cardinality. The batch is padded to the next power-of-2 length so that JAX's compilation cache stays bounded (≤ ~12 unique shapes per support size *k* instead of one per distinct batch count). Parameters ---------- batch : (n_supports, k) int Array Each row is a sorted support of length *k*. Returns ------- infos : (n_supports,) Array coeffs : (n_supports, k) Array """ n_real = batch.shape[0] k = int(batch.shape[1]) n_padded = self._pad_to_pow2(n_real) if n_padded > n_real: # Pad with copies of the first row (valid support — avoids # pathological zeros). Results from padded rows are discarded. pad_rows = jnp.broadcast_to(batch[0:1], (n_padded - n_real, k)) batch = jnp.concatenate([batch, pad_rows], axis=0) fn = _get_vmap_jit( k, self.G_is_symmetric, self.use_residuals, self.pinv_tol, self.norm_X2, self.n, ) infos, coeffs = fn(self.M, self.G, batch) # Slice off padding return infos[:n_real], coeffs[:n_real]