# Source code for SFI.inference.sparse.lasso

"""
SFI.inference.sparse.lasso — ℓ₁-penalised regression (LASSO)
=============================================================

Solve the :math:`\\ell_1`-regularised normal-equations problem:

.. math::

   \\hat C = \\arg\\min_C \\;
      \\tfrac{1}{2}\\,C^\\top G\\,C \\;-\\; M^\\top C
      \\;+\\; \\alpha\\,\\|C\\|_1

using proximal coordinate descent.  No access to the raw design matrix
:math:`\\Phi` is needed — only ``(M, G)`` — preserving the clean
decoupling from the data pipeline.

Sweeping over a regularisation path of :math:`\\alpha` values produces
supports at varying sparsity levels, from which a Pareto front is
assembled.
"""

from __future__ import annotations

import logging
import time

import jax.numpy as jnp
import numpy as np

from .base import SparsityStrategy
from .result import SparsityResult
from .scorer import SparseScorer

logger = logging.getLogger(__name__)


def _soft_threshold(x: float, lam: float) -> float:
    """Scalar soft-thresholding operator."""
    if x > lam:
        return x - lam
    if x < -lam:
        return x + lam
    return 0.0


class LassoStrategy(SparsityStrategy):
    r"""Coordinate-descent LASSO on the normal equations.

    Parameters
    ----------
    alpha : float or None
        Fixed regularisation strength.  If *None* (default), an automatic
        log-spaced path from :math:`\alpha_{\max}` (where the solution is
        entirely zero) down to :math:`10^{-4}\,\alpha_{\max}` is
        constructed.
    n_alphas : int, default 50
        Number of :math:`\alpha` values in the automatic path.
    max_iter : int, default 1000
        Maximum coordinate-descent iterations per :math:`\alpha`.
    tol : float, default 1e-7
        Convergence tolerance (max absolute change in any coefficient).
    report_time : bool, default False
        Log elapsed wall-clock time when done.
    """

    name = "lasso"

    def __init__(
        self,
        *,
        alpha: float | None = None,
        n_alphas: int = 50,
        max_iter: int = 1000,
        tol: float = 1e-7,
        report_time: bool = False,
    ):
        self.alpha = alpha
        self.n_alphas = n_alphas
        self.max_iter = max_iter
        self.tol = tol
        self.report_time = report_time

    # -----------------------------------------------------------------
    def _alpha_path(self, M: np.ndarray) -> list[float]:
        """Build the automatic, decreasing log-spaced regularisation path.

        The path runs from ``alpha_max = max(|M|)`` down to
        ``1e-4 * alpha_max`` in ``self.n_alphas`` steps.

        Parameters
        ----------
        M : np.ndarray
            Linear term of the objective.

        Returns
        -------
        list of float
            The alpha values, largest first.  Empty when ``M`` is empty,
            identically zero, or contains non-finite values: in those
            cases ``log10(alpha_max)`` is undefined and no alpha can
            produce a non-empty support.
        """
        # alpha_max: smallest alpha that zeros everything out
        alpha_max = float(np.max(np.abs(M))) if M.size else 0.0
        if not np.isfinite(alpha_max) or alpha_max <= 0.0:
            return []
        return np.logspace(
            np.log10(alpha_max),
            np.log10(alpha_max * 1e-4),
            self.n_alphas,
        ).tolist()

    # -----------------------------------------------------------------
    def _coordinate_descent(
        self,
        G: np.ndarray,
        M: np.ndarray,
        alpha: float,
        C_init: np.ndarray,
    ) -> np.ndarray:
        """Run coordinate descent for one alpha.  Pure numpy, no JAX.

        Parameters
        ----------
        G : np.ndarray
            Matrix of the quadratic term; read as ``G[j]`` (row) and
            ``G[j, j]`` (diagonal entry).
        M : np.ndarray
            Linear term; its length sets the number of coordinates.
        alpha : float
            Regularisation strength used as the soft-threshold level.
        C_init : np.ndarray
            Starting coefficients.  Copied, never modified in place.

        Returns
        -------
        np.ndarray
            Coefficients after at most ``self.max_iter`` full sweeps, or
            earlier once the largest per-coordinate change in a sweep
            drops below ``self.tol``.
        """
        p = len(M)
        C = C_init.copy()
        for _ in range(self.max_iter):
            max_delta = 0.0
            for j in range(p):
                # Partial residual for coordinate j (own contribution
                # added back so it is excluded from the fit).
                r_j = float(M[j] - G[j] @ C + G[j, j] * C[j])
                G_jj = float(G[j, j])
                if G_jj < 1e-30:
                    # Vanishing diagonal: pin the coefficient to zero
                    # instead of dividing by (almost) nothing.
                    new_val = 0.0
                else:
                    new_val = _soft_threshold(r_j, alpha) / G_jj
                max_delta = max(max_delta, abs(new_val - C[j]))
                C[j] = new_val
            if max_delta < self.tol:
                break
        return C

    # -----------------------------------------------------------------
    def run(self, scorer: SparseScorer, *, max_k: int, **_kwargs) -> SparsityResult:
        """Sweep the regularisation path and keep the best support per size.

        For every alpha the LASSO problem is solved by coordinate descent,
        warm-started from the previous solution.  The resulting non-zero
        support is re-scored through ``scorer.info_and_coeffs``, and the
        highest-information support of each size ``k`` is retained.

        Parameters
        ----------
        scorer : SparseScorer
            Supplies ``p``, ``G``, ``M``, ``total_info``, ``total_C`` and
            ``info_and_coeffs``.
        max_k : int
            Largest support size to record; clipped to ``scorer.p``.
        **_kwargs
            Accepted and ignored.

        Returns
        -------
        SparsityResult
            Per-size best information, support and coefficients.  Index 0
            is the null model (information ``0.0``, empty support, no
            coefficients).  The ``second_*`` fields are filled with
            ``-inf`` / empty placeholders.
        """
        t0 = time.perf_counter()
        p = scorer.p
        max_k = min(max_k, p)

        # Convert to numpy for the coordinate descent loop
        G_np = np.asarray(scorer.G)
        M_np = np.asarray(scorer.M)

        # alpha_max is only needed for the automatic path, so a fixed
        # alpha never touches the log10 / max computation.
        if self.alpha is not None:
            alphas = [self.alpha]
        else:
            alphas = self._alpha_path(M_np)

        best_info = [-np.inf] * (max_k + 1)
        best_support = [[] for _ in range(max_k + 1)]
        best_coeffs = [None] * (max_k + 1)

        # Null model
        best_info[0] = 0.0

        # Warm-start: start from zeros for the largest alpha
        C_warm = np.zeros(p)
        for alpha_val in alphas:
            C_warm = self._coordinate_descent(G_np, M_np, alpha_val, C_warm)

            # Identify nonzero support
            support = np.flatnonzero(np.abs(C_warm) > 1e-14).tolist()
            k = len(support)
            if k > max_k or k == 0:
                continue

            # Re-solve exactly on the LASSO support (de-biased LASSO)
            B = jnp.array(support, dtype=jnp.int32)
            info, coeffs = scorer.info_and_coeffs(B)
            info_val = float(info)
            if info_val > best_info[k]:
                best_info[k] = info_val
                best_support[k] = support
                best_coeffs[k] = coeffs

        # Also record the full model if within max_k
        if p <= max_k:
            full_info = float(scorer.total_info)
            if full_info > best_info[p]:
                best_info[p] = full_info
                best_support[p] = list(range(p))
                best_coeffs[p] = scorer.total_C

        if self.report_time:
            dt = time.perf_counter() - t0
            logger.info("LASSO done in %.2fs (%d alphas).", dt, len(alphas))

        return SparsityResult(
            p=scorer.p,
            total_info=float(scorer.total_info),
            method=self.name,
            best_info_by_k=best_info,
            best_support_by_k=best_support,
            best_coeffs_by_k=best_coeffs,
            second_info_by_k=[-np.inf] * (max_k + 1),
            second_support_by_k=[[] for _ in range(max_k + 1)],
        )