Source code for SFI
"""
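SFI – Stochastic Force Inference: package entry point.

Re-exports the core public API (inference classes, state functions and
bases, trajectory containers, diagnostics) and provides ``enable_logging``.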
"""
import logging as _logging
import os as _os
from importlib.metadata import version as _v
# ---------------------------------------------------------------------------
# JAX persistent compilation cache (opt-in)
# ---------------------------------------------------------------------------
# Loading cached XLA executables triggers a per-hit C++ warning from the
# PjRt-IFRT layer ("Assume version compatibility …"), so the cache is OFF
# by default. Enable it by pointing SFI_JAX_CACHE_DIR at a directory, e.g.
#     export SFI_JAX_CACHE_DIR=~/.cache/sfi/jax_cache
# XLA compilations dominate small-data runtime (~5 s for lorenz_demo T=100);
# caching cuts repeat runs to ~1 s.
#
# The raw value is stripped and passed through expanduser(). A shell only
# expands a leading "~" when it is unquoted; values set in quotes, .env
# files, container ENV lines or via os.environ would otherwise reach JAX
# verbatim, with "~" as a literal path component. A whitespace-only value
# is treated the same as an unset variable.
_cache_dir = _os.path.expanduser(_os.environ.get("SFI_JAX_CACHE_DIR", "").strip())
if _cache_dir:
    import jax

    jax.config.update("jax_compilation_cache_dir", _cache_dir)
    # Many SFI compilations are 50–500 ms each; the default 1 s threshold
    # would skip them, defeating the purpose of the cache.
    jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
    # Keep the package namespace free of the helper import.
    del jax
del _cache_dir
# Give the package logger a do-nothing handler, following the standard
# library-logging convention: SFI stays quiet until the application
# configures logging itself (or calls enable_logging()).
_pkg_logger = _logging.getLogger(__name__)
_pkg_logger.addHandler(_logging.NullHandler())
del _pkg_logger  # temporary name only; keep the package namespace unchanged
[docs]
def enable_logging(level: "str | int" = "INFO") -> None:
    """Quick helper to turn on SFI log output.

    Sets the level of the package logger. If no console handler is attached
    yet, it also adds a ``StreamHandler`` that formats records as
    ``[SFI <LEVEL>] <message>``. Repeated calls only change the level and
    never stack duplicate console handlers.

    Parameters
    ----------
    level
        A level name, matched case-insensitively (e.g. ``"DEBUG"``,
        ``"info"``, ``"WARNING"``), or a numeric ``logging`` level such as
        ``logging.DEBUG``. Unrecognised names fall back to ``INFO``.

    >>> import SFI
    >>> SFI.enable_logging()          # INFO-level messages
    >>> SFI.enable_logging("DEBUG")   # everything
    """
    logger = _logging.getLogger(__name__)

    if isinstance(level, int):
        resolved = level
    else:
        # getLevelName() maps a registered level *name* to its int value and
        # returns a "Level ..." string otherwise. Unlike
        # getattr(logging, name) it cannot pick up unrelated module attributes
        # (e.g. "BASIC_FORMAT"), which made setLevel() raise.
        resolved = _logging.getLevelName(level.upper())
        if not isinstance(resolved, int):
            resolved = _logging.INFO
    logger.setLevel(resolved)

    # FileHandler subclasses StreamHandler, so a bare isinstance() check would
    # mistake a user's log file for "console output already configured".
    # (NullHandler is not a StreamHandler and needs no special-casing.)
    has_console = any(
        isinstance(h, _logging.StreamHandler) and not isinstance(h, _logging.FileHandler)
        for h in logger.handlers
    )
    if not has_console:
        handler = _logging.StreamHandler()
        handler.setFormatter(_logging.Formatter("[SFI %(levelname)s] %(message)s"))
        logger.addHandler(handler)
# Public API --------------------------------------------------------------
from . import bases, diagnostics, inference, integrate, langevin, trajectory, utils
# Convenience re-exports so users can write ``from SFI import ...``
from .diagnostics import DiagnosticsReport, DynamicsOrderReport, assess, classify_dynamics
from .inference import (
InferenceResultSF,
OverdampedLangevinInference,
UnderdampedLangevinInference,
)
from .statefunc import PSF, SF, Basis, make_sf
from .trajectory import TrajectoryCollection, TrajectoryDataset
# Resolve the version from the installed distribution's metadata. A plain
# source checkout that was never pip-installed has no such metadata, and
# importlib.metadata.version() then raises PackageNotFoundError. Fall back
# to a PEP 440 local-version placeholder rather than letting that error
# make the package import itself fail.
from importlib.metadata import PackageNotFoundError as _PackageNotFoundError

try:
    __version__ = _v("StochasticForceInference")
except _PackageNotFoundError:
    __version__ = "0+unknown"
del _v, _PackageNotFoundError