SFI.integrate.rk4 module

rk4.py

Classical fourth-order Runge-Kutta (RK4) and forward-Euler integrators in JAX. Every step is JAX-traceable, so the ODE flow can be differentiated end to end with jax.jacobian / jax.hessian.
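The following minimal sketch shows what differentiating through the flow looks like. It assumes a textbook RK4 step and a jax.lax.scan loop like the ones documented below. The helper names rk4_step_sketch and flow_sketch are hypothetical stand-ins, not the SFI functions. For a linear field f(x) = A x the exact flow map is x0 ↦ expm(A dt) x0, so jax.jacobian of the integrated flow should match expm(A dt) up to discretization error.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def rk4_step_sketch(f, x, h):
    # Textbook classical RK4 step for an autonomous field f (hypothetical stand-in).
    k1 = f(x)
    k2 = f(x + 0.5 * h * k1)
    k3 = f(x + 0.5 * h * k2)
    k4 = f(x + h * k3)
    return x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def flow_sketch(f, x0, dt, n_substeps):
    # n_substeps fixed-size RK4 micro-steps driven by lax.scan (hypothetical stand-in).
    h = dt / n_substeps

    def body(x, _):
        return rk4_step_sketch(f, x, h), None

    x_final, _ = jax.lax.scan(body, x0, None, length=n_substeps)
    return x_final


# Linear, damped oscillator dx/dt = A x, whose exact flow map is expm(A * dt).
A = jnp.array([[0.0, 1.0], [-1.0, -0.1]])
f = lambda x: A @ x
x0 = jnp.array([1.0, 0.0])
dt = 0.5

# Jacobian of the map x0 -> x(dt), obtained by differentiating through every RK4 stage.
J = jax.jacobian(lambda x: flow_sketch(f, x, dt, 20))(x0)
J_exact = expm(A * dt)
print(J)
print(J_exact)
```

jax.hessian composes the same way. For this linear field it would return zeros, so a nonlinear f is needed to see second-order structure.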

SFI.integrate.rk4.euler_flow(f, x0, dt, n_substeps)

Integrate dx/dt = f(x) from x0 over total time dt using the forward-Euler method.

Uses n_substeps forward-Euler micro-steps, each of size h = dt / n_substeps. The loop is implemented with jax.lax.scan, so the whole computation is JAX-traceable and differentiable.

Parameters:
  • f (callable (d,) → (d,)) – Drift vector field (parameters should already be closed over).

  • x0 (array (d,)) – Initial state.

  • dt (scalar) – Total integration interval.

  • n_substeps (int) – Number of Euler micro-steps. Must be a static Python int (a compile-time constant), not a JAX tracer.

Returns:

x_final – State at time dt.

Return type:

array (d,)
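The following sketch illustrates the docstring above. It assumes euler_flow simply scans the one-step update documented under euler_step. The names euler_flow_sketch and euler_step_sketch are hypothetical, not SFI's code. For the linear decay f(x) = -x, each Euler step multiplies the state by (1 - h), so n_substeps steps give exactly (1 - h)**n_substeps * x0. That makes a convenient check.

```python
import jax
import jax.numpy as jnp


def euler_step_sketch(f, x, h):
    # One forward-Euler step: x_new = x + h * f(x).
    return x + h * f(x)


def euler_flow_sketch(f, x0, dt, n_substeps):
    # n_substeps must be a plain Python int: lax.scan needs a static trip count.
    h = dt / n_substeps

    def body(x, _):
        # Carry the state forward; no per-step output is stacked.
        return euler_step_sketch(f, x, h), None

    x_final, _ = jax.lax.scan(body, x0, None, length=n_substeps)
    return x_final


f = lambda x: -x                     # linear decay dx/dt = -x
x0 = jnp.array([1.0, 2.0])
x1 = euler_flow_sketch(f, x0, 1.0, 10)
expected = (1.0 - 0.1) ** 10 * x0    # each step scales the state by (1 - h)
print(x1, expected)
```

The loop is a scan rather than a Python for-loop, so the traced program contains a single loop body regardless of n_substeps. This keeps tracing and compile cost independent of the step count.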

SFI.integrate.rk4.euler_step(f, x, h)

One forward-Euler step.

Parameters:
  • f (callable (d,) → (d,)) – Autonomous vector field, JAX-traceable.

  • x (array (d,)) – Current state.

  • h (scalar) – Step size.

Returns:

x_new – State after one Euler step.

Return type:

array (d,)
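In equation form, one forward-Euler step with step size h is the update below. Its local truncation error is O(h²). Chaining n_substeps of these steps over a fixed interval dt, as euler_flow does, therefore gives a global error of O(h) with h = dt / n_substeps.

```latex
x_{\text{new}} = x + h\, f(x)
```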

SFI.integrate.rk4.ode_flow(f, x0, dt, n_substeps)

Integrate dx/dt = f(x) from x0 over total time dt using the classical RK4 method.

Uses n_substeps RK4 micro-steps, each of size h = dt / n_substeps. The loop is implemented with jax.lax.scan, so the whole computation is JAX-traceable and differentiable.

Parameters:
  • f (callable (d,) → (d,)) – Drift vector field (parameters should already be closed over).

  • x0 (array (d,)) – Initial state.

  • dt (scalar) – Total integration interval.

  • n_substeps (int) – Number of RK4 micro-steps. Must be a static Python int (a compile-time constant), not a JAX tracer, and must be >= 1. For forward-Euler integration, select the Euler path with integrator='euler' (see select_flow()).

Returns:

x_final – State at time dt.

Return type:

array (d,)
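The following minimal sketch shows how such a flow can be jitted while respecting the static n_substeps requirement. It assumes a scan-based loop like the one described above. The names ode_flow_sketch and rk4_step_sketch are hypothetical, not SFI's implementation. Marking n_substeps as static via static_argnames gives lax.scan a concrete length. f is also marked static, because a Python callable cannot be a traced argument. The result is compared against the exact solution exp(-dt) x0 of dx/dt = -x.

```python
from functools import partial

import jax
import jax.numpy as jnp


def rk4_step_sketch(f, x, h):
    # Classical RK4 stages for an autonomous field f.
    k1 = f(x)
    k2 = f(x + 0.5 * h * k1)
    k3 = f(x + 0.5 * h * k2)
    k4 = f(x + h * k3)
    return x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


@partial(jax.jit, static_argnames=("f", "n_substeps"))
def ode_flow_sketch(f, x0, dt, n_substeps):
    # dt may be traced; n_substeps is static, so scan gets a concrete length.
    h = dt / n_substeps

    def body(x, _):
        return rk4_step_sketch(f, x, h), None

    x_final, _ = jax.lax.scan(body, x0, None, length=n_substeps)
    return x_final


def decay(x):
    # dx/dt = -x, with exact solution x(t) = exp(-t) * x0.
    return -x


x0 = jnp.array([1.0, -0.5])
x1 = ode_flow_sketch(decay, x0, 1.0, 10)
print(x1, jnp.exp(-1.0) * x0)
```

Calling ode_flow_sketch with a different n_substeps triggers a retrace and recompile. That is the practical cost of the compile-time-constant requirement.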

SFI.integrate.rk4.rk4_step(f, x, h)

One classical fourth-order Runge-Kutta step.

Parameters:
  • f (callable (d,) → (d,)) – Autonomous vector field, JAX-traceable.

  • x (array (d,)) – Current state.

  • h (scalar) – Step size.

Returns:

x_new – State after one RK4 step.

Return type:

array (d,)
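For reference, the standard textbook scheme below is the classical fourth-order Runge-Kutta update for an autonomous field, written in the same symbols as the parameters above. The docstring names the method but does not list its stages. Its local truncation error is O(h⁵), giving O(h⁴) global error when composed over a fixed interval.

```latex
\begin{aligned}
k_1 &= f(x), \\
k_2 &= f\!\left(x + \tfrac{h}{2}\, k_1\right), \\
k_3 &= f\!\left(x + \tfrac{h}{2}\, k_2\right), \\
k_4 &= f\!\left(x + h\, k_3\right), \\
x_{\text{new}} &= x + \tfrac{h}{6}\left(k_1 + 2k_2 + 2k_3 + k_4\right).
\end{aligned}
```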

SFI.integrate.rk4.select_flow(integrator)

Return the ODE flow function for the given integrator name.

Parameters:

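integrator ({'rk4', 'euler'}) – Name of the integration scheme: 'rk4' for classical fourth-order Runge-Kutta, 'euler' for forward Euler.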

Returns:

flow – ode_flow() when integrator is 'rk4', euler_flow() when it is 'euler'.

Return type:

callable
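The following sketch illustrates the dispatch select_flow performs. It assumes a simple name-to-function lookup. The sketch flows and the ValueError raised for unknown names are illustrative assumptions, not documented SFI behavior. Both sketch flows share one scan driver and differ only in the step function.

```python
import jax
import jax.numpy as jnp


def _euler_step(f, x, h):
    # Forward-Euler update.
    return x + h * f(x)


def _rk4_step(f, x, h):
    # Classical RK4 update.
    k1 = f(x)
    k2 = f(x + 0.5 * h * k1)
    k3 = f(x + 0.5 * h * k2)
    k4 = f(x + h * k3)
    return x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def _scan_flow(step, f, x0, dt, n_substeps):
    # Shared fixed-step driver: n_substeps applications of `step` with h = dt / n_substeps.
    h = dt / n_substeps
    x_final, _ = jax.lax.scan(lambda x, _: (step(f, x, h), None), x0, None, length=n_substeps)
    return x_final


def rk4_flow_sketch(f, x0, dt, n_substeps):
    return _scan_flow(_rk4_step, f, x0, dt, n_substeps)


def euler_flow_sketch(f, x0, dt, n_substeps):
    return _scan_flow(_euler_step, f, x0, dt, n_substeps)


# Hypothetical registry mirroring the two documented integrator names.
_FLOWS = {"rk4": rk4_flow_sketch, "euler": euler_flow_sketch}


def select_flow_sketch(integrator):
    # Unknown names raise ValueError here; SFI's actual behavior is not documented.
    if integrator not in _FLOWS:
        raise ValueError(f"unknown integrator {integrator!r}; expected one of {sorted(_FLOWS)}")
    return _FLOWS[integrator]


flow = select_flow_sketch(integrator="euler")
x1 = flow(lambda x: -x, jnp.array([1.0]), 1.0, 10)
print(x1)
```

Both flows share the (f, x0, dt, n_substeps) signature, so callers can switch integrators by name without changing the call site.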