Source code for SFI.statefunc.core.runtime
from typing import Any, Callable, Dict
import equinox as eqx
# === JIT control & harness ====================================================
_JIT_ENABLED = True
# Per-root compiled function cache. Keyed by object identity of `root`.
# We intentionally use identity (id(root)) rather than a structural hash to
# keep lookups O(1) and avoid expensive tree-walk hashing or serialisation.
_COMPILED_CACHE: Dict[int, Callable[..., Any]] = {}
[docs]
def set_jit(enabled: bool = True):
"""Globally enable/disable JIT for Basis/PSF/SF __call__."""
global _JIT_ENABLED
_JIT_ENABLED = bool(enabled)
def _eager_eval(root, x, v, mask, extras, params):
# Plain Python call; used when JIT is disabled (set_jit(False)).
return root(x, params=params, v=v, mask=mask, extras=extras)
def _compiled_for_root(root) -> Callable[..., Any]:
"""
Return a compiled callable specialised to the given `root`.
The compiled function *closes over* `root`, so `root` is static
without being an argument (reduces tracing/dispatch overhead).
"""
key = id(root)
fn = _COMPILED_CACHE.get(key)
if fn is not None:
return fn
# Compile once for this root: dynamic args are (x, v, mask, extras, params)
@eqx.filter_jit
def _call(x, v, mask, extras, params):
return root(x, params=params, v=v, mask=mask, extras=extras)
_COMPILED_CACHE[key] = _call
return _call
def _jitted_eval(root, x, v, mask, extras, params):
"""Dispatch to the per-root compiled callable (JIT-enabled path)."""
return _compiled_for_root(root)(x, v, mask, extras, params)