# Source code for SFI.diagnostics.report

"""DiagnosticsReport dataclass — structured output of `assess()`."""

from __future__ import annotations

import json
from dataclasses import asdict, dataclass, field
from typing import Any, Mapping

import numpy as np

#: Action hint attached to each ``flag_issues`` message. The table is written
#: grouped by section and flattened into ``(section, test)`` keys; a test name
#: of ``""`` is the section-wide fallback used when no test-specific hint
#: exists. The wording mirrors docs/source/diagnostics.rst ("Interpreting
#: flags" and "When a flag points beyond the linear estimators") — edit both
#: together.
_FLAG_HINTS: dict[tuple[str, str], str] = {
    (section, test): hint
    for section, hints_by_test in {
        "autocorr": {
            "ljung_box": (
                "missing time-correlated feature — widen the basis; "
                "if it persists, suspect coarse sampling: the parametric "
                "estimator (infer_force) extends the usable Δt"
            ),
            "ljung_box_squared": (
                "diffusion mis-estimated or state-dependent — "
                "try the other compute_diffusion_constant method or a "
                "state-dependent diffusion basis; the parametric "
                "estimators profile (D, Λ)"
            ),
        },
        "normality": {
            "": (
                "non-Gaussian residuals — rare events not captured by "
                "the basis, or a non-Gaussian noise structure"
            ),
        },
        "moments": {
            "mean": "non-zero residual mean — systematic drift bias; widen the basis",
            "std": (
                "whitened std far from 1 — D̄ likely wrong: "
                "try both compute_diffusion_constant methods, then the "
                "parametric estimator"
            ),
        },
        "mse_consistency": {
            "": (
                "realised error above predicted — model bias; "
                "on experimental data usually measurement noise: "
                "consider the parametric estimator (infer_force)"
            ),
        },
    }.items()
    for test, hint in hints_by_test.items()
}


def _hint(section: str, test: str = "") -> str:
    """Return the action hint registered for ``(section, test)``.

    A missing or empty test-specific entry falls back to the section-wide
    entry ``(section, "")``; when neither exists the result is ``""``.
    """
    specific = _FLAG_HINTS.get((section, test))
    if specific:
        return specific
    return _FLAG_HINTS.get((section, ""), "")


def _to_jsonable(obj: Any) -> Any:
    """Recursively convert numpy / jax values into plain Python objects.

    * Mappings become ``dict``: keys are kept as-is, values are converted.
    * Lists and tuples become ``list`` with converted elements.
    * Objects exposing ``tolist()`` (e.g. numpy arrays) are replaced by its
      result; if that call raises, the object falls through unchanged.
    * Remaining numpy floating / integer scalars are unwrapped via ``.item()``.
    * Anything else is returned as-is.
    """
    if isinstance(obj, Mapping):
        return {key: _to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(_to_jsonable, obj))

    converter = getattr(obj, "tolist", None)
    if converter is not None:
        try:
            return converter()
        except Exception:
            # Best effort only: a failing tolist() must not abort serialisation.
            pass

    is_np_scalar = isinstance(obj, (np.floating, np.integer))
    return obj.item() if is_np_scalar else obj


def _fmt_num(value: Any, spec: str) -> str:
    """Format *value* with the format *spec*, falling back to ``str(value)``.

    Lets the summary printer show a missing (``None``) or non-numeric entry
    as text instead of raising; numpy scalars format like Python numbers.
    """
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


@dataclass
class DiagnosticsReport:
    """Container for residual-consistency test results.

    Attributes
    ----------
    residuals : dict
        Test results. Always holds ``"moments"``; at ``level="standard"``
        also ``"autocorr"``, ``"normality"`` and ``"mse_consistency"``.
    meta : dict
        Backend tag, regime, ``n_obs``, ``n_particles``, ``d``, level.
    """

    residuals: dict = field(default_factory=dict)
    meta: dict = field(default_factory=dict)

    # ------------------------------------------------------------------ #
    # Serialisation
    # ------------------------------------------------------------------ #

    def to_dict(self) -> dict:
        """Return a JSON-serialisable representation of the report."""
        return _to_jsonable(asdict(self))

    def to_json(self, indent: int = 2) -> str:
        """Serialise the report to a JSON string.

        Objects ``json`` cannot encode natively are stringified
        (``default=str``).
        """
        return json.dumps(self.to_dict(), indent=indent, default=str)

    # ------------------------------------------------------------------ #
    # Issue flagging
    # ------------------------------------------------------------------ #

    def flag_issues(self, alpha: float = 0.01, *, hints: bool = True) -> list[str]:
        """List human-readable warnings.

        Returns one line per test whose p-value is below ``alpha`` or whose
        statistic crosses a sane threshold: residual mean outside the
        ``4/sqrt(n)`` band, residual std outside ``[0.5, 2]``, or an
        MSE-consistency z-score with ``|z| > 5`` (or ``|z| > 2`` together
        with a realised/predicted NMSE ratio above 10).

        Parameters
        ----------
        alpha : float
            Significance level for the p-value tests.
        hints : bool
            When True (default), each message carries a one-line action hint
            (" — <what to do>"); set False for bare statistics (machine
            parsing).

        Returns
        -------
        list of str
            One message per flagged test; empty when nothing is flagged.
        """
        msgs: list[str] = []
        res = self.residuals or {}

        def _emit(base: str, section: str, test: str = "") -> None:
            h = _hint(section, test) if hints else ""
            # The hint table stores bare advice text; the separator is added
            # here so the statistic and the advice do not run together.
            msgs.append(f"{base} — {h}" if h else base)

        # Normality / autocorrelation p-values
        for section_name in ("normality", "autocorr"):
            # ``or {}`` also covers a section stored explicitly as None.
            section = res.get(section_name) or {}
            for test_name, payload in section.items():
                if not isinstance(payload, Mapping):
                    continue
                p = payload.get("pvalue")
                if p is not None and p < alpha:
                    _emit(
                        f"[{section_name}/{test_name}] p={p:.2e} < {alpha}",
                        section_name,
                        test_name,
                    )

        # Residual moments
        moments = res.get("moments") or {}
        m = moments.get("mean")
        s = moments.get("std")
        n_raw = moments.get("n")
        n = float(n_raw) if n_raw is not None else 1.0
        # 4σ band for the mean of n unit-variance residuals; clamping n at 1
        # matches max(sqrt(n), 1) for n >= 0 and avoids a complex sqrt for n < 0.
        band = 4.0 / max(n, 1.0) ** 0.5
        if m is not None and s is not None and abs(m) > band:
            _emit(
                f"[moments/mean] |mean|={abs(m):.3g} (expected ~0, 4σ band={band:.3g})",
                "moments",
                "mean",
            )
        if s is not None and (s < 0.5 or s > 2.0):
            _emit(f"[moments/std] std={s:.3g}", "moments", "std")

        # Predicted vs realised MSE — flag on the chi^2 z-score of the
        # residual excess, which is sampling-noise-aware (the raw ratio is
        # too noisy at modest n_obs). A decisive z (|z|>5) flags on its
        # own; a moderately significant z (|z|>2) flags when the realised/
        # predicted ratio is also an order of magnitude off — a large,
        # consistently-signed excess at modest n_obs (e.g. a structurally
        # misspecified force family) would otherwise stay silent.
        consistency = res.get("mse_consistency") or {}
        excess_z = consistency.get("excess_z")
        ratio = consistency.get("ratio")
        if excess_z is not None:
            decisive = abs(excess_z) > 5.0
            large_excess = abs(excess_z) > 2.0 and ratio is not None and ratio > 10.0
            if decisive or large_excess:
                # ``ratio`` can be missing from the payload; formatting None
                # with ":.2g" would raise TypeError.
                ratio_str = "n/a" if ratio is None else f"{ratio:.2g}"
                _emit(
                    f"[mse_consistency] residual chi^2 z-score = {excess_z:+.2f}, "
                    f"realised/predicted NMSE = {ratio_str}",
                    "mse_consistency",
                )
        return msgs

    # ------------------------------------------------------------------ #
    # Pretty-printing
    # ------------------------------------------------------------------ #

    def print_summary(self, alpha: float = 0.01, *, hints: bool = True) -> None:
        """Print a human-readable summary of the diagnostic report.

        Prints the run metadata, the residual moments, one ✓/✗ line per
        normality / autocorrelation test (✗ when ``p < alpha``), the
        predicted-vs-realised NMSE line, and finally the flags from
        :meth:`flag_issues`. Each flagged issue carries a one-line action
        hint unless ``hints=False``.
        """
        print("\n=== SFI diagnostics report ===")
        meta = self.meta or {}
        if meta:
            print(f"backend : {meta.get('backend', '?')}")
            print(f"regime : {meta.get('regime', '?')}")
            print(
                f"n_obs : {meta.get('n_obs', '?')} "
                f"n_particles: {meta.get('n_particles', '?')} "
                f"d: {meta.get('d', '?')}"
            )
            print(f"level : {meta.get('level', '?')}")

        res = self.residuals or {}
        if res:
            print("\n-- Residuals --")
            mom = res.get("moments") or {}
            if mom:
                nan = float("nan")
                print(
                    f" mean = {_fmt_num(mom.get('mean', nan), '+.4f')} "
                    f"std = {_fmt_num(mom.get('std', nan), '.4f')} "
                    f"skew = {_fmt_num(mom.get('skew', nan), '+.3f')} "
                    f"kurt-3 = {_fmt_num(mom.get('excess_kurt', nan), '+.3f')} "
                    f"(n={mom.get('n', 0)})"
                )
            for sect_name in ("normality", "autocorr"):
                sect = res.get(sect_name) or {}
                for test_name, payload in sect.items():
                    if not isinstance(payload, Mapping):
                        continue
                    stat = payload.get("statistic")
                    p = payload.get("pvalue")
                    if stat is None and p is None:
                        continue
                    flag = "✗" if (p is not None and p < alpha) else "✓"
                    s_str = f"stat={_fmt_num(stat, '.3g')}" if stat is not None else ""
                    p_str = f"p={_fmt_num(p, '.3g')}" if p is not None else ""
                    print(f" {flag} {sect_name:9s} {test_name:18s} {s_str:14s} {p_str}")
            mc = res.get("mse_consistency") or {}
            if mc:
                pred_str = _fmt_num(mc.get("predicted_NMSE"), ".3g")
                real_str = _fmt_num(mc.get("realised_NMSE"), ".3g")
                z_str = _fmt_num(mc.get("excess_z"), "+.2f")
                print(
                    f" predicted NMSE = {pred_str} realised NMSE = {real_str} "
                    f"χ² z = {z_str} (|z|>5 ⇒ bias)"
                )

        issues = self.flag_issues(alpha=alpha, hints=hints)
        print("\n-- Flags --")
        if not issues:
            print(f" (no issues at α = {alpha:.2g})")
        else:
            for msg in issues:
                print(f" ! {msg}")
        print()