Source code for SFI.inference.sparse.hillclimb

"""
SFI.inference.sparse.hillclimb — Stochastic hill-climbing selection
====================================================================

Implements the search strategy described in Gerardos & Ronceray (2025):
starting from an initial model, accept random single-parameter
additions or removals if they increase the chosen information
criterion, until convergence.  Multiple independent chains are run
(one after another in this implementation), one per cardinality
*k* ∈ {0, …, max_k} (including the null and full models), each
starting from a uniformly random support of size *k*.

This is complementary to the deterministic strategies (greedy, beam):

* **Greedy / beam** explore a systematic path and may get trapped in a
  single monotonic trajectory through the lattice.
* **Hill-climbing** performs a stochastic local search at every
  complexity level, which can escape greedy traps and explore broader
  neighbourhoods of the combinatorial lattice.

References
----------
* Gerardos, A. & Ronceray, P. (2025).  "Principled model selection for
  stochastic dynamics."  (describes the algorithm in the main text.)
"""

from __future__ import annotations

import logging
import math
import time
from typing import Optional

import jax.numpy as jnp
import numpy as np

from .base import SparsityStrategy
from .result import SparsityResult
from .scorer import SparseScorer

logger = logging.getLogger(__name__)


# ------------------------------------------------------------------
# Inline IC penalties (no SparsityResult needed during the search)
# ------------------------------------------------------------------


def _ic_penalty(
    name: str,
    k: int,
    *,
    n0: int,
    p_param: float,
    tau: Optional[float],
    gamma: float,
    total_info: float,
) -> float:
    """Return the penalty term for a given information criterion.

    The score is ``info - penalty``, so a *larger* penalty means more
    pruning.
    """
    name = name.upper()
    if name == "AIC":
        return float(k)
    if name == "BIC":
        if tau is None:
            raise ValueError("BIC requires 'tau' (total trajectory time).")
        return 0.5 * k * math.log(tau)
    if name == "EBIC":
        if tau is None:
            raise ValueError("EBIC requires 'tau' (total trajectory time).")
        lc = math.lgamma(n0 + 1) - math.lgamma(k + 1) - math.lgamma(n0 - k + 1)
        return 0.5 * k * math.log(tau) + 2.0 * gamma * lc
    if name == "PASTIS":
        return k * math.log(n0 / p_param)
    if name == "SIC":
        if total_info <= 0:
            return float("inf")
        return k * math.log(total_info)
    raise ValueError(f"Unknown criterion {name!r}")


# ------------------------------------------------------------------
# Single chain
# ------------------------------------------------------------------


def _run_chain(
    scorer: SparseScorer,
    init_support: np.ndarray,
    *,
    ic_name: str,
    n0: int,
    max_k: int,
    p_param: float,
    tau: Optional[float],
    gamma: float,
    total_info: float,
    patience: int,
    rng: np.random.Generator,
) -> dict:
    """Run a single stochastic hill-climbing chain from ``init_support``.

    At each step one index is either added to or removed from the current
    support, chosen at random.  The move is kept only when it strictly
    raises the penalised score; the chain ends after ``patience``
    consecutive rejections (or when no move is possible at all).

    Returns
    -------
    trace : dict[int, (info, support, coeffs)]
        For every cardinality the chain evaluated, the entry with the
        highest raw information — whether the corresponding proposal was
        accepted or rejected.  ``support`` is a sorted list of indices.
    """
    universe = set(range(n0))
    state = {int(j) for j in init_support}
    best_by_size: dict = {}

    def _score(members: set) -> tuple[float, float, Optional[jnp.ndarray]]:
        """Return ``(penalised score, raw info, coeffs)`` for ``members``."""
        size = len(members)
        # The penalty is computed first, so invalid criterion settings
        # raise even when the support is empty.
        pen = _ic_penalty(
            ic_name,
            size,
            n0=n0,
            p_param=p_param,
            tau=tau,
            gamma=gamma,
            total_info=total_info,
        )
        if size == 0:
            # Empty model: zero information, no coefficients, no scorer call.
            return -pen, 0.0, None
        idx = jnp.array(sorted(members), dtype=jnp.int32)
        gain, fitted = scorer.info_and_coeffs(idx)
        gain = float(gain)
        return gain - pen, gain, fitted

    def _remember(members: set, gain: float, fitted) -> None:
        """Keep ``members`` if it beats the stored entry of the same size."""
        size = len(members)
        known = best_by_size.get(size)
        if known is not None and not gain > known[0]:
            return
        best_by_size[size] = (gain, sorted(members), fitted)

    current_ic, gain, fitted = _score(state)
    _remember(state, gain, fitted)

    rejected = 0
    while rejected < patience:
        size = len(state)
        may_grow = size < max_k and size < n0
        may_shrink = size > 0
        if not (may_grow or may_shrink):
            break  # stuck at empty with max_k=0
        if may_grow and may_shrink:
            # Both directions are open: flip a fair coin.
            grow = bool(rng.integers(2))
        else:
            # Exactly one direction is open; take it without drawing.
            grow = may_grow

        if grow:
            pick = int(rng.choice(list(universe - state)))
            trial = state | {pick}
        else:
            pick = int(rng.choice(list(state)))
            trial = state - {pick}

        trial_ic, trial_gain, trial_fit = _score(trial)
        # Rejected proposals still feed the per-cardinality trace.
        _remember(trial, trial_gain, trial_fit)
        if trial_ic > current_ic:
            state, current_ic, rejected = trial, trial_ic, 0
        else:
            rejected += 1

    return best_by_size


# ------------------------------------------------------------------
# Strategy class
# ------------------------------------------------------------------


class HillClimbStrategy(SparsityStrategy):
    """Stochastic hill-climbing model selection.

    For each cardinality *k* ∈ {0, …, max_k}, a chain starts from a random
    support of size *k* and accepts random add/remove moves that improve
    the chosen information criterion, stopping after *patience*
    consecutive failures.  The best-per-*k* results form the Pareto front
    returned as a :class:`SparsityResult`.

    This is the search algorithm described in Gerardos & Ronceray (2025),
    §"Model selection", which recommends parallel searches from null,
    full, and random starting points.

    Parameters
    ----------
    ic : str, default ``"PASTIS"``
        Information criterion used as the acceptance objective.
        One of ``"AIC"``, ``"BIC"``, ``"EBIC"``, ``"PASTIS"``, ``"SIC"``.
    p_param : float, default 1e-3
        PASTIS significance level :math:`p_0`.
    tau : float or None
        Total trajectory time (required for BIC / EBIC).
    gamma : float, default 0.5
        EBIC tuning parameter (:math:`\\gamma \\in [0,1]`).
    patience : int, default 200
        Stop a chain after this many consecutive rejected moves.
    seed : int or None
        Random seed for reproducibility.
    report_time : bool, default False
        Log elapsed wall-clock time when done.
    """

    name = "hillclimb"

    def __init__(
        self,
        *,
        ic: str = "PASTIS",
        p_param: float = 1e-3,
        tau: Optional[float] = None,
        gamma: float = 0.5,
        patience: int = 200,
        seed: Optional[int] = None,
        report_time: bool = False,
    ):
        # The criterion name is stored upper-cased.
        self.ic = ic.upper()
        self.p_param = p_param
        self.tau = tau
        self.gamma = gamma
        self.patience = patience
        self.seed = seed
        self.report_time = report_time

    # -----------------------------------------------------------------

    def run(self, scorer: SparseScorer, *, max_k: int, **_kwargs) -> SparsityResult:
        """Run one chain per starting cardinality and merge the results.

        Parameters
        ----------
        scorer : SparseScorer
            Provides the library size (``scorer.p``), the total
            information (``scorer.total_info``) and the per-support
            evaluation used by the chains.
        max_k : int
            Largest support size to explore; clipped to ``scorer.p``.
            Must be non-negative.
        **_kwargs
            Accepted and ignored.

        Returns
        -------
        SparsityResult
            Best information, support and coefficients found for each
            cardinality ``0 … max_k``.  No runner-up is tracked, so the
            ``second_*`` fields are filled with ``-inf`` / empty lists.

        Raises
        ------
        ValueError
            If ``max_k`` is negative.
        """
        if max_k < 0:
            # Otherwise the per-k lists would be empty and the null-model
            # assignment below would fail with an opaque IndexError.
            raise ValueError(f"max_k must be non-negative, got {max_k!r}.")

        t0 = time.perf_counter()
        n0 = scorer.p
        max_k = min(max_k, n0)
        # One generator is shared by all chains, so for a given seed the
        # outcome depends on the order in which the chains are run.
        rng = np.random.default_rng(self.seed)
        total_info = float(scorer.total_info)

        best_info = [-np.inf] * (max_k + 1)
        best_support = [[] for _ in range(max_k + 1)]
        best_coeffs = [None] * (max_k + 1)

        # Null model (k=0) always has info=0
        best_info[0] = 0.0

        def _record(k: int, info: float, support: list, coeffs):
            """Keep ``support`` if it strictly improves the best at size ``k``."""
            if 0 <= k <= max_k and info > best_info[k]:
                best_info[k] = info
                best_support[k] = list(support)
                best_coeffs[k] = coeffs

        # ------ Launch one chain per cardinality k --------------------
        for k in range(0, max_k + 1):
            # Random starting support of size k
            if k == 0:
                init = np.array([], dtype=np.int32)
            elif k >= n0:
                init = np.arange(n0, dtype=np.int32)
            else:
                init = np.sort(rng.choice(n0, size=k, replace=False))

            trace = _run_chain(
                scorer,
                init,
                ic_name=self.ic,
                n0=n0,
                max_k=max_k,
                p_param=self.p_param,
                tau=self.tau,
                gamma=self.gamma,
                total_info=total_info,
                patience=self.patience,
                rng=rng,
            )
            # Fold every cardinality the chain explored into the global
            # Pareto front, not only its final landing point.
            for k_state, (info, support, coeffs) in trace.items():
                _record(k_state, info, support, coeffs)

            logger.debug(
                "Hill-climb chain k₀=%d: %d cardinalities recorded.",
                k,
                len(trace),
            )

        if self.report_time:
            dt = time.perf_counter() - t0
            logger.info(
                "Hill-climb (%s, patience=%d) done in %.2fs.",
                self.ic,
                self.patience,
                dt,
            )

        return SparsityResult(
            p=n0,
            total_info=total_info,
            method="hillclimb",
            best_info_by_k=best_info,
            best_support_by_k=best_support,
            best_coeffs_by_k=best_coeffs,
            second_info_by_k=[-np.inf] * (max_k + 1),
            second_support_by_k=[[] for _ in range(max_k + 1)],
        )