Source code for SFI.inference.sparse.metrics
"""
SFI.inference.sparse.metrics — Benchmark helpers
=================================================
Standalone functions for comparing inferred supports / coefficients
against ground truth. Useful for benchmarking and papers but not
required for normal inference.
"""
from __future__ import annotations
from typing import Dict, List
import jax.numpy as jnp
# Shorthand for the JAX array type; used only in the type annotations of
# this module (e.g. ``Phi_test: Array``).
Array = jnp.ndarray
[docs]
def overlap_metrics(true_support: List[int], pred_support: List[int]) -> Dict:
    """Score a predicted support set against the ground-truth support.

    Both inputs are turned into sets, so ordering and repeated indices
    have no effect on the result.

    Parameters
    ----------
    true_support, pred_support : list[int]
        Indices of the true and predicted active basis functions.

    Returns
    -------
    dict
        ``TP`` / ``FP`` / ``FN`` -- counts of true positives, false
        positives and false negatives;
        ``prec`` / ``rec`` -- precision and recall, each ``0.0`` when its
        denominator is zero;
        ``exact`` -- ``True`` iff the two supports are identical as sets.
    """
    truth = set(true_support)
    guess = set(pred_support)

    hits = len(truth & guess)
    false_pos = len(guess - truth)
    false_neg = len(truth - guess)

    def _ratio(num: int, den: int) -> float:
        # Empty denominator is reported as 0.0 rather than raising.
        return num / den if den else 0.0

    return {
        "TP": hits,
        "FP": false_pos,
        "FN": false_neg,
        "prec": _ratio(hits, hits + false_pos),
        "rec": _ratio(hits, hits + false_neg),
        # No false positives and no false negatives <=> equal sets.
        "exact": truth == guess,
    }
[docs]
def predictive_nmse(
    Phi_test: Array,
    true_support: List[int],
    true_coeffs,
    inferred_support: List[int],
    inferred_coeffs,
) -> float:
    r"""Normalised mean-squared error on a held-out design matrix.

    The ground-truth signal is ``y = Phi_test[:, true_support] @ true_coeffs``
    and the prediction is
    ``y_hat = Phi_test[:, inferred_support] @ inferred_coeffs``.

    Parameters
    ----------
    Phi_test : (n_test, p) Array
        Design matrix evaluated on test data (anything ``jnp.asarray``
        accepts).
    true_support : list[int]
        Ground-truth active indices (may be empty).
    true_coeffs : array-like
        Ground-truth coefficient vector (length ``len(true_support)``).
    inferred_support : list[int]
        Inferred active indices.
    inferred_coeffs : array-like
        Inferred coefficient vector (length ``len(inferred_support)``).

    Returns
    -------
    float
        :math:`\|\hat y - y\|^2 / \|y\|^2`, with these conventions for
        degenerate inputs:

        * empty ``inferred_support`` -> ``1.0`` (predicting nothing);
        * zero ground-truth signal (e.g. empty ``true_support``) -> ``0.0``
          if the prediction is also exactly zero, otherwise ``inf``.
    """
    if len(inferred_support) == 0:
        return 1.0

    Phi = jnp.asarray(Phi_test)
    # Explicit integer dtype: ``jnp.array([])`` is float32 and JAX refuses
    # float index arrays, so an empty support used to crash here.
    true_idx = jnp.asarray(true_support, dtype=jnp.int32)
    pred_idx = jnp.asarray(inferred_support, dtype=jnp.int32)

    true_signal = Phi[:, true_idx] @ jnp.asarray(true_coeffs)
    pred_signal = Phi[:, pred_idx] @ jnp.asarray(inferred_coeffs)

    residual_energy = jnp.sum((true_signal - pred_signal) ** 2)
    signal_energy = jnp.sum(true_signal**2)
    if signal_energy == 0:
        # Avoid a silent nan (0/0) from the array division; a perfect
        # prediction of a zero signal scores 0, anything else is unbounded.
        return 0.0 if residual_energy == 0 else float("inf")
    return float(residual_energy / signal_energy)