SFI.integrate.rk4 module
rk4.py
Classical fourth-order Runge-Kutta integrator in JAX, suitable for
differentiating through the ODE flow via jax.jacobian / jax.hessian.
- SFI.integrate.rk4.euler_flow(f, x0, dt, n_substeps)
Integrate dx/dt = f(x) from x0 over total time dt using the forward Euler method.
Uses n_substeps forward-Euler micro-steps, each of size
h = dt / n_substeps. The loop is implemented via jax.lax.scan, so the full computation is JAX-traceable and differentiable.
- Parameters:
- Returns:
x_final – State at time dt.
- Return type:
array (d,)
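The forward-Euler loop described above can be sketched as follows. This is a minimal illustration, not the module's actual source: euler_flow_sketch is a hypothetical stand-in that only assumes the documented signature (f, x0, dt, n_substeps) and the documented use of jax.lax.scan.

```python
import jax
import jax.numpy as jnp


def euler_flow_sketch(f, x0, dt, n_substeps):
    """Hypothetical stand-in for euler_flow (an assumption, not the library code)."""
    h = dt / n_substeps  # micro-step size

    def step(x, _):
        # One forward-Euler update: x <- x + h * f(x)
        return x + h * f(x), None

    # length must be a static Python int, so n_substeps cannot be a tracer.
    x_final, _ = jax.lax.scan(step, x0, None, length=n_substeps)
    return x_final


# Linear decay dx/dt = -x has the exact solution x0 * exp(-dt).
f = lambda x: -x
x0 = jnp.array([1.0, 2.0])
x_euler = euler_flow_sketch(f, x0, dt=0.1, n_substeps=100)
x_exact = x0 * jnp.exp(-0.1)
```

Forward Euler is first-order accurate, so the gap between x_euler and x_exact shrinks roughly in proportion to 1 / n_substeps.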
- SFI.integrate.rk4.ode_flow(f, x0, dt, n_substeps)
Integrate dx/dt = f(x) from x0 over total time dt.
Uses n_substeps RK4 micro-steps, each of size
h = dt / n_substeps. The loop is implemented via jax.lax.scan, so the full computation is JAX-traceable and differentiable.
- Parameters:
f (callable (d,) → (d,)) – Drift vector field (parameters should already be closed over).
x0 (array (d,)) – Initial state.
dt (scalar) – Total integration interval.
n_substeps (int (static Python int, not a JAX tracer)) – Number of RK4 micro-steps. Must be a compile-time constant and >= 1; use integrator='euler' for the Euler path.
- Returns:
x_final – State at time dt.
- Return type:
array (d,)
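As an illustration of the RK4 micro-step loop and of differentiating through the flow with jax.jacobian, here is a minimal sketch. The name rk4_flow_sketch is hypothetical and the body is an assumption based on the classical RK4 scheme and the documented signature, not the module's actual source.

```python
import math

import jax
import jax.numpy as jnp


def rk4_flow_sketch(f, x0, dt, n_substeps):
    """Hypothetical stand-in for ode_flow (an assumption, not the library code)."""
    h = dt / n_substeps  # micro-step size

    def step(x, _):
        # Classical fourth-order Runge-Kutta stages.
        k1 = f(x)
        k2 = f(x + 0.5 * h * k1)
        k3 = f(x + 0.5 * h * k2)
        k4 = f(x + h * k3)
        return x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4), None

    # length must be a static Python int, so n_substeps cannot be a tracer.
    x_final, _ = jax.lax.scan(step, x0, None, length=n_substeps)
    return x_final


# Harmonic oscillator dx/dt = A x, whose exact flow is a rotation.
A = jnp.array([[0.0, 1.0], [-1.0, 0.0]])
f = lambda x: A @ x  # parameters (here A) are closed over
x0 = jnp.array([1.0, 0.0])
dt = 0.5

x_final = rk4_flow_sketch(f, x0, dt, n_substeps=10)
x_exact = jnp.array([math.cos(dt), -math.sin(dt)])

# Jacobian of the flow map with respect to the initial state.
# For this linear system it should approximate exp(A * dt).
J = jax.jacobian(lambda x: rk4_flow_sketch(f, x, dt, 10))(x0)
J_exact = jnp.array([[math.cos(dt), math.sin(dt)],
                     [-math.sin(dt), math.cos(dt)]])
```

Because the loop is a single jax.lax.scan, the same function can also be wrapped in jax.jit or jax.hessian, provided n_substeps stays a plain Python int.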
- SFI.integrate.rk4.select_flow(integrator)
Return the ODE flow function for the given integrator name.
- Parameters:
integrator ({'rk4', 'euler'}) – Name of the integrator.
- Returns:
flow – Either ode_flow() or euler_flow().
- Return type:
callable
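A name-to-function lookup of this kind might look like the sketch below. Everything here is hypothetical: select_flow_sketch, the helper names, and in particular the ValueError raised for an unknown name are assumptions, since the documentation does not state how unsupported names are handled.

```python
import math

import jax
import jax.numpy as jnp


def _euler_step(f, x, h):
    # Forward-Euler update.
    return x + h * f(x)


def _rk4_step(f, x, h):
    # Classical fourth-order Runge-Kutta update.
    k1 = f(x)
    k2 = f(x + 0.5 * h * k1)
    k3 = f(x + 0.5 * h * k2)
    k4 = f(x + h * k3)
    return x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def _make_flow(step):
    """Build a flow(f, x0, dt, n_substeps) function from a one-step rule."""
    def flow(f, x0, dt, n_substeps):
        h = dt / n_substeps
        body = lambda x, _: (step(f, x, h), None)
        x_final, _ = jax.lax.scan(body, x0, None, length=n_substeps)
        return x_final
    return flow


# Hypothetical registry mirroring the two documented integrator names.
_FLOWS = {"rk4": _make_flow(_rk4_step), "euler": _make_flow(_euler_step)}


def select_flow_sketch(integrator):
    """Return the flow function for an integrator name (assumed behavior)."""
    if integrator not in _FLOWS:
        # Assumption: the real function may handle unknown names differently.
        raise ValueError(f"unknown integrator: {integrator!r}")
    return _FLOWS[integrator]


# Compare both integrators on dx/dt = -x with the same step budget.
f = lambda x: -x
x0 = jnp.array([1.0])
exact = math.exp(-1.0)
err_rk4 = abs(float(select_flow_sketch("rk4")(f, x0, 1.0, 10)[0]) - exact)
err_euler = abs(float(select_flow_sketch("euler")(f, x0, 1.0, 10)[0]) - exact)
```

Both flows share one call signature, so calling code can switch integrators by changing only the name passed in.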