# Source code for SFI.inference.optimizers

# SFI/inference/optimizers.py
"""
Optimizer back-ends for parametric inference methods.

Provides L-BFGS-B (via SciPy) and Adam (via optax) wrappers with
logging, best-parameter tracking, and a unified result interface.
"""

from __future__ import annotations

import logging

import jax
import jax.numpy as jnp

logger = logging.getLogger(__name__)


def optimize_lbfgsb(loss, loss_grad, theta0_flat, *, maxiter, tol):
    """L-BFGS-B via SciPy (returns a SciPy OptimizeResult).

    Parameters
    ----------
    loss : callable
        ``loss(theta) -> scalar``. Called with a JAX array; the result is
        converted to a Python ``float`` before being handed to SciPy.
    loss_grad : callable
        ``loss_grad(theta) -> grad``. Called with a JAX array; the result is
        converted to a float64 NumPy array before being handed to SciPy.
    theta0_flat : array-like
        Initial parameter vector.
    maxiter : int
        Forwarded to SciPy as ``options["maxiter"]``.
    tol : float
        Forwarded to SciPy as ``tol``.

    Returns
    -------
    scipy.optimize.OptimizeResult
        The unmodified result of ``scipy.optimize.minimize``.
    """
    import time

    import numpy as np
    from scipy.optimize import minimize

    def fun(theta_np):
        return float(loss(jnp.asarray(theta_np)))

    def jac(theta_np):
        # SciPy's L-BFGS-B operates on float64 NumPy buffers.  A JAX array
        # requested with ``dtype=float`` is still 32-bit unless x64 mode is
        # enabled, so convert explicitly on the NumPy side instead.
        return np.asarray(loss_grad(jnp.asarray(theta_np)), dtype=np.float64)

    # Same reasoning for the starting point: keep it as host float64 so the
    # initial guess is not rounded to float32 before SciPy sees it.
    x0 = np.asarray(theta0_flat, dtype=np.float64)

    logger.info("[force_nonlinear] Starting L-BFGS-B optimization...")
    t0 = time.perf_counter()
    res = minimize(
        fun,
        x0,
        jac=jac,
        method="L-BFGS-B",
        tol=tol,
        options={"maxiter": maxiter},
    )
    t1 = time.perf_counter()
    logger.info(
        "[force_nonlinear] L-BFGS-B finished in %.2f s | nit=%d, nfev=%d, final f=%.6g, status=%d (%s)",
        t1 - t0,
        res.nit,
        res.nfev,
        res.fun,
        res.status,
        res.message,
    )
    return res
def _select_batch_phase(batch_schedule, frac):
    """Return ``(index, grad_fn)`` of the schedule entry active at *frac*.

    The active entry is the first one whose step fraction is strictly greater
    than *frac*.  If no entry qualifies, the last entry stays active and its
    own index is reported.
    """
    for idx, (step_fraction, grad_fn) in enumerate(batch_schedule):
        if frac < step_fraction:
            return idx, grad_fn
    return len(batch_schedule) - 1, batch_schedule[-1][1]


def optimize_adam(
    loss,
    loss_grad,
    theta0_flat,
    *,
    maxiter,
    learning_rate,
    lr_schedule,
    loss_grad_batch=None,
    batch_rng_seed=0,
    batch_schedule=None,
):
    """Adam via optax (returns a namespace mimicking SciPy OptimizeResult).

    When *loss_grad_batch* is provided, mini-batch stochastic gradients are
    used for parameter updates while the full-data *loss* is still used for
    tracking / best-parameter selection.

    Parameters
    ----------
    loss : callable
        ``loss(theta) -> scalar``. Evaluated roughly every ``maxiter // 20``
        steps, on the last step, and once more after the loop.
    loss_grad : callable
        ``loss_grad(theta) -> grad``. Used only when neither
        *loss_grad_batch* nor *batch_schedule* is given.
    theta0_flat : array-like
        Initial parameter vector.
    maxiter : int
        Number of Adam steps; also the decay horizon of the cosine schedule.
    learning_rate : float
        Initial (cosine) or fixed (constant) learning rate.
    lr_schedule : {"cosine", "constant", None}
        Learning-rate schedule; ``None`` behaves like ``"constant"``.
    loss_grad_batch : callable, optional
        ``loss_grad_batch(theta, rng_key) -> grad``.  When given, each Adam
        step uses a stochastic gradient instead of the full-data gradient.
        Ignored when *batch_schedule* is set.
    batch_rng_seed : int
        Seed for the mini-batch PRNG stream.
    batch_schedule : list of (float, callable), optional
        Batch-size annealing schedule.  Each entry is
        ``(step_fraction, grad_fn)`` where *grad_fn* has the signature
        ``grad_fn(theta, rng_key) -> grad``.  The list must be sorted by
        ascending *step_fraction*; the last entry should have fraction 1.0.
        At each step the active gradient function is the first whose
        fraction exceeds ``step / maxiter``; if none does, the last entry
        is used.

    Returns
    -------
    types.SimpleNamespace
        Fields ``x`` (parameters with the lowest evaluated loss), ``fun``
        (that loss), ``nit`` and ``nfev`` (both set to *maxiter*),
        ``success`` (always ``True``), ``message`` and ``status`` (``0``).

    Raises
    ------
    ValueError
        If *lr_schedule* is not recognised, or *batch_schedule* is given but
        empty.
    """
    import time
    from types import SimpleNamespace

    import optax

    # Build learning-rate schedule
    if lr_schedule == "cosine":
        schedule = optax.cosine_decay_schedule(
            init_value=float(learning_rate),
            decay_steps=maxiter,
        )
    elif lr_schedule == "constant" or lr_schedule is None:
        schedule = float(learning_rate)
    else:
        raise ValueError(f"Unknown lr_schedule {lr_schedule!r}; choose 'cosine' or 'constant'.")

    use_schedule = batch_schedule is not None
    use_minibatch = loss_grad_batch is not None or use_schedule

    # An empty schedule has no gradient function to fall back on; fail with a
    # clear message rather than an IndexError from inside the training loop.
    if use_schedule and len(batch_schedule) == 0:
        raise ValueError("batch_schedule must contain at least one (step_fraction, grad_fn) entry.")

    opt = optax.adam(schedule)
    theta = jnp.asarray(theta0_flat)
    opt_state = opt.init(theta)

    best_f = float("inf")
    best_theta = theta

    if use_minibatch:
        rng_key = jax.random.PRNGKey(batch_rng_seed)
        if use_schedule:
            logger.info(
                "[force_nonlinear] Starting Adam with batch-size annealing (%d phases, maxiter=%d, seed=%d)...",
                len(batch_schedule),
                maxiter,
                batch_rng_seed,
            )
        else:
            logger.info(
                "[force_nonlinear] Starting Adam optimization with mini-batch (maxiter=%d, seed=%d)...",
                maxiter,
                batch_rng_seed,
            )
    else:
        logger.info("[force_nonlinear] Starting Adam optimization (maxiter=%d)...", maxiter)

    # Full-data loss is only evaluated at ~20 evenly spaced checkpoints (plus
    # the last step), so best-parameter tracking is limited to those steps.
    log_every = max(1, maxiter // 20)
    last_step = maxiter - 1

    prev_phase = -1
    t0 = time.perf_counter()
    for step in range(maxiter):
        if use_schedule:
            phase, active_grad = _select_batch_phase(batch_schedule, step / maxiter)
            if phase != prev_phase:
                prev_phase = phase
                logger.info(
                    "[force_nonlinear] phase %d/%d starts at step %d",
                    phase,
                    len(batch_schedule),
                    step,
                )
            rng_key, subkey = jax.random.split(rng_key)
            g = active_grad(theta, subkey)
        elif use_minibatch:
            # Reaching this branch implies loss_grad_batch is not None.
            rng_key, subkey = jax.random.split(rng_key)
            g = loss_grad_batch(theta, subkey)
        else:
            g = loss_grad(theta)

        updates, opt_state = opt.update(g, opt_state, theta)
        theta = optax.apply_updates(theta, updates)

        if step % log_every == 0 or step == last_step:
            f_val = float(loss(theta))
            if f_val < best_f:
                best_f = f_val
                best_theta = theta
            # The reported |grad| is the gradient used for this update, i.e.
            # taken at the parameters before the step was applied.
            logger.info(
                "[force_nonlinear] step %5d / %d loss=%.6g |grad|=%.3e",
                step,
                maxiter,
                f_val,
                float(jnp.linalg.norm(g)),
            )
    t1 = time.perf_counter()

    # Final eval (also covers maxiter == 0, where the loop never runs)
    f_final = float(loss(theta))
    if f_final < best_f:
        best_f = f_final
        best_theta = theta

    logger.info(
        "[force_nonlinear] Adam finished in %.2f s | %d steps, final loss=%.6g, best loss=%.6g",
        t1 - t0,
        maxiter,
        f_final,
        best_f,
    )

    # Return an object with the same .x / .fun interface as SciPy
    return SimpleNamespace(
        x=best_theta,
        fun=best_f,
        nit=maxiter,
        nfev=maxiter,
        success=True,
        message="Adam optimization completed.",
        status=0,
    )