Simulation

The SFI.langevin subpackage provides Langevin simulators that use the same state-function objects (PSF / SF) as the inference engines. This closes the inference → simulation → validation loop: you infer a model from data, simulate from it, and compare the synthetic trajectory back to the original.

Two simulators are available:

  • OverdampedProcess: stochastic Heun integrator (default; Euler–Maruyama optional) for \(\mathrm{d}\mathbf{x} = \mathbf{F}(\mathbf{x})\,\mathrm{d}t + \sqrt{2\mathbf{D}(\mathbf{x})}\,\mathrm{d}W_t\)

  • UnderdampedProcess: velocity-Verlet-like integrator for \(\mathrm{d}\mathbf{x} = \mathbf{v}\,\mathrm{d}t,\;\mathrm{d}\mathbf{v} = \mathbf{F}(\mathbf{x},\mathbf{v})\,\mathrm{d}t + \sqrt{2\mathbf{D}}\,\mathrm{d}W_t\)
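
To make the two overdamped schemes concrete, here is a minimal NumPy sketch of one Euler–Maruyama step and one stochastic Heun (predictor-corrector) step. It assumes a constant scalar D; for state-dependent D the corrector would also have to average the noise amplitude. The names euler_maruyama_step, heun_step, and simulate_sketch are illustrative and are not part of SFI.langevin.

```python
import numpy as np

def euler_maruyama_step(x, force, D, dt, xi):
    """One Euler-Maruyama step for dx = F(x) dt + sqrt(2 D) dW (constant scalar D)."""
    return x + force(x) * dt + np.sqrt(2.0 * D * dt) * xi

def heun_step(x, force, D, dt, xi):
    """One stochastic Heun step; the same noise increment is used in both stages."""
    noise = np.sqrt(2.0 * D * dt) * xi
    x_pred = x + force(x) * dt + noise           # Euler predictor
    drift = 0.5 * (force(x) + force(x_pred))     # drift averaged over both ends
    return x + drift * dt + noise

def simulate_sketch(step, x0, force, D, dt, n_steps, seed=0):
    """Iterate `step` n_steps times and return all visited states."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    traj = np.empty((n_steps + 1,) + x.shape)
    traj[0] = x
    for k in range(n_steps):
        xi = rng.standard_normal(x.shape)        # unit Gaussian per component
        x = step(x, force, D, dt, xi)
        traj[k + 1] = x
    return traj

force = lambda x: -x                             # harmonic trap, as in the quick example
traj = simulate_sketch(heun_step, np.zeros(2), force, D=0.5, dt=0.01, n_steps=1000)
```

With the noise switched off, a single step from x = 1 with dt = 0.1 gives 0.9 for Euler–Maruyama and 0.905 for Heun, against the exact value exp(-0.1) ≈ 0.9048, which illustrates why Heun is the more accurate default for the drift.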

Quick example

import jax.numpy as jnp
from jax import random
from SFI.langevin import OverdampedProcess
from SFI import make_sf

# Define force as a simple function, wrap as SF
F_sf = make_sf(lambda x, *, mask=None: -x, dim=2, rank=1)
proc = OverdampedProcess(F_sf, D=jnp.eye(2) * 0.5)
proc.initialize(jnp.zeros(2))

coll = proc.simulate(
    dt=0.01,
    Nsteps=10_000,
    key=random.PRNGKey(0),
    prerun=100,
    oversampling=10,
)

Workflow

  1. Construct the process with a force model F (PSF or SF) and a diffusion D (scalar, matrix, or PSF/SF).

  2. Bind parameters via set_params() if F or D is an unbound PSF object.

  3. Initialize the state with initialize() (position for overdamped; position + velocity for underdamped).

  4. Simulate via simulate(), which returns a TrajectoryCollection.

Diffusion specification

The D argument accepts multiple forms:

| Input | Interpretation |
| --- | --- |
| float (scalar) | Isotropic: \(D = \sigma \cdot I_d\), with \(\sigma\) the scalar passed |
| (d, d) array | Constant diffusion matrix |
| PSF / SF with rank=2 | State-dependent \(D(x)\) (or \(D(x, v)\)) |
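
As an illustration of how these three input forms can be treated uniformly, the sketch below normalizes each of them to a callable returning a (d, d) matrix, and builds a noise amplitude B with B Bᵀ = 2D. This is not the library's internal code: as_diffusion_fn and noise_amplitude are hypothetical helpers, a plain Python callable stands in for a rank-2 PSF / SF, and the Cholesky factor is only one valid choice for \(\sqrt{2\mathbf{D}}\).

```python
import numpy as np

def as_diffusion_fn(D, dim):
    """Return a callable x -> (dim, dim) diffusion matrix (hypothetical helper)."""
    if callable(D):                       # stand-in for a rank-2 PSF / SF
        return D
    D_arr = np.asarray(D, dtype=float)
    if D_arr.ndim == 0:                   # scalar: isotropic diffusion
        D_mat = float(D_arr) * np.eye(dim)
    elif D_arr.shape == (dim, dim):       # constant diffusion matrix
        D_mat = D_arr
    else:
        raise ValueError(f"D must be a scalar, a ({dim}, {dim}) array, or a callable")
    return lambda x: D_mat

def noise_amplitude(D_mat):
    """Lower-triangular B with B @ B.T == 2 D; requires D positive definite."""
    return np.linalg.cholesky(2.0 * D_mat)

D_iso = as_diffusion_fn(0.5, dim=2)(np.zeros(2))
D_const = as_diffusion_fn(np.array([[1.0, 0.2], [0.2, 0.5]]), dim=2)(np.zeros(2))
D_state = as_diffusion_fn(lambda x: (1.0 + x @ x) * np.eye(2), dim=2)(np.ones(2))
B = noise_amplitude(D_const)
```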

Particle systems

Both simulators respect the pdepth contract of the state functions:

  • pdepth=0: single particle, x0 has shape (d,)

  • pdepth=1: interacting particles, x0 has shape (P, d)

For interacting-particle systems, attach extras_global and/or extras_local with set_extras() before simulating, to pass system metadata (box size, species labels, neighbor lists, …) through to the models at every time step.

For a comprehensive guide to setting up particle systems, see Particle systems.
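
The shape contract can be illustrated with a small NumPy sketch of an interacting-particle force that reads system metadata from a dictionary, in the spirit of extras_global. The function pair_force, the keys "box" and "k", and the harmonic interaction are assumptions made for this example; they are not SFI conventions.

```python
import numpy as np

def pair_force(x, extras_global):
    """Harmonic pair attraction in a periodic box; x has shape (P, d)."""
    box = extras_global["box"]                  # hypothetical key: box side length
    k = extras_global["k"]                      # hypothetical key: spring constant
    diff = x[:, None, :] - x[None, :, :]        # (P, P, d), diff[i, j] = x_i - x_j
    diff -= box * np.round(diff / box)          # minimum-image convention
    return -k * diff.sum(axis=1)                # (P, d): net force on each particle

x0_single = np.zeros(2)                         # pdepth=0: shape (d,)
x0 = np.array([[0.1, 0.0],
               [0.9, 0.0],
               [0.5, 0.3]])                     # pdepth=1: shape (P, d) = (3, 2)
extras_global = {"box": 1.0, "k": 2.0}
F = pair_force(x0, extras_global)
```

The returned force has the same (P, d) shape as the state, and because the interaction is pairwise and antisymmetric the forces sum to zero over particles.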

Choosing dt and oversampling

The time step dt is the interval between recorded frames. The oversampling parameter controls how many internal sub-steps are taken per recorded frame:

  • Larger oversampling → more accurate integration, but slower.

  • A rule of thumb: oversampling should be large enough that \(\mathrm{d}t_{\text{internal}} = \mathrm{d}t / \text{oversampling}\) is small compared to the fastest timescale in the dynamics.

  • For stiff systems (strong gradients, large forces), increase oversampling to 10–100.

  • For diffusion-dominated systems with gentle forces, oversampling=1 may suffice.
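
The effect of oversampling on a stiff system can be seen in a small NumPy sketch. It uses plain Euler–Maruyama sub-steps for brevity (not the default Heun scheme), records each frame before advancing (an assumption about the frame convention), and the name simulate_oversampled is illustrative only.

```python
import numpy as np

def simulate_oversampled(x0, force, D, dt, n_frames, oversampling=1, seed=0):
    """Record n_frames states spaced by dt, taking `oversampling` sub-steps per frame."""
    rng = np.random.default_rng(seed)
    dt_internal = dt / oversampling              # internal integration step
    x = float(x0)
    frames = np.empty(n_frames)
    for k in range(n_frames):
        frames[k] = x                            # record, then advance to the next frame
        for _ in range(oversampling):
            xi = rng.standard_normal()
            x += force(x) * dt_internal + np.sqrt(2.0 * D * dt_internal) * xi
    return frames

stiff = lambda x: -50.0 * x                      # relaxation time 1/50 = 0.02 < dt
coarse = simulate_oversampled(1.0, stiff, D=0.0, dt=0.05, n_frames=20, oversampling=1)
fine = simulate_oversampled(1.0, stiff, D=0.0, dt=0.05, n_frames=20, oversampling=10)
```

Here the noise is switched off to isolate the drift. With oversampling=1 each step multiplies x by 1 - 50 × 0.05 = -1.5, so the trajectory oscillates and blows up; with oversampling=10 the internal step is 0.005, each sub-step multiplies x by 0.75, and the trajectory decays as it should.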

Time-dependent extras (protocols)

Drives, ramps, and switching protocols enter a simulation as time-dependent extras through the unchanged set_extras API:

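import numpy as np
from SFI.trajectory import time_series_extra

# Assumes proc, dt, and key are already defined (e.g. as in the quick example)
Nsteps = 10_000
k_t = (np.arange(Nsteps) // 1000 % 2).astype(float)   # square wave, one value per frame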
proc.set_extras(extras_global={"k_drive": time_series_extra(k_t)})
coll = proc.simulate(dt=dt, Nsteps=Nsteps, key=key)

Conventions:

  • A TimeSeriesExtra must carry one value per recorded frame (leading axis == Nsteps); the value is held constant across the oversampling substeps of its frame (zeroth-order hold), and the prerun uses the frame-0 value.

  • A plain callable is interpreted as f(t) of physical time and materialized at the frame times \(t = k\,\mathrm{d}t\) before the run.

  • The schedule is attached to the returned collection (as a TimeSeriesExtra), aligned so that the increment X[k+1] - X[k] was generated under the frame-k value — exactly the pairing the inference layer assumes. The round-trip idiom is therefore one line on each side: simulate with the protocol, then infer_force_linear() on a basis containing extra_scalar() terms (Time-dependent forcing — protocols as extras).

  • simulate_chunked does not support time-dependent extras.
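
The zeroth-order-hold and alignment conventions above can be sketched in NumPy as follows. This is a stand-alone illustration, not the SFI implementation: materialize_schedule and simulate_with_protocol are hypothetical names, the integrator is plain Euler–Maruyama, and the sketch returns one more position than there are schedule values so that every increment has a matching value.

```python
import numpy as np

def materialize_schedule(f, dt, n_frames):
    """Evaluate a callable f(t) at the frame times t = k * dt."""
    return np.array([f(k * dt) for k in range(n_frames)], dtype=float)

def simulate_with_protocol(x0, force, D, dt, schedule, oversampling=1, seed=0):
    """force(x, value) is driven by schedule[k] during the whole of frame k."""
    rng = np.random.default_rng(seed)
    dt_internal = dt / oversampling
    x = float(x0)
    X = np.empty(len(schedule) + 1)
    X[0] = x
    for k, value in enumerate(schedule):
        for _ in range(oversampling):            # value held constant across sub-steps
            xi = rng.standard_normal()
            x += force(x, value) * dt_internal + np.sqrt(2.0 * D * dt_internal) * xi
        X[k + 1] = x                             # X[k+1] - X[k] used schedule[k]
    return X

dt = 0.1
schedule = materialize_schedule(lambda t: 2.0 * t, dt=dt, n_frames=5)   # linear ramp
drive_only = lambda x, value: value              # force equal to the drive itself
X = simulate_with_protocol(0.0, drive_only, D=0.0, dt=dt,
                           schedule=schedule, oversampling=4)
```

With the noise off and the force equal to the drive, each increment X[k+1] - X[k] equals schedule[k] * dt, which makes the pairing between increments and schedule values explicit.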

Simulation parameters

| Parameter | Description |
| --- | --- |
| dt | Time step between recorded frames |
| Nsteps | Number of recorded frames |
| key | JAX PRNG key |
| prerun | Warm-up steps (discarded) |
| oversampling | Internal sub-steps per recorded frame (improves accuracy) |

Observables (overdamped)

By default the overdamped process also computes entropy- and information-production estimates on the recorded trajectory (pass compute_observables=False to skip them). They are attached to the returned collection’s dataset metadata:

obs = coll.datasets[0].meta["observables"]
I = obs["information"]   # I ≈ (1/4) Σ_t ⟨Δx_t, D⁻¹ F(x_t)⟩
S = obs["entropy"]       # S ≈ Σ_t ⟨Δx_t, D⁻¹(x_mid) · F̄_t⟩
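
For orientation, the two sums written in the comments above can be evaluated by hand on a recorded trajectory. The sketch below is one reading of those formulas, not the library's implementation: it assumes a constant diffusion matrix (so D⁻¹(x_mid) reduces to D⁻¹), takes F̄_t to be the average of the force at the two endpoints of each increment, and uses the hypothetical name observables_sketch.

```python
import numpy as np

def observables_sketch(X, force, D):
    """X: (T, d) trajectory; force maps (d,) -> (d,); D: constant (d, d) matrix."""
    D_inv = np.linalg.inv(D)
    dX = np.diff(X, axis=0)                      # increments, shape (T-1, d)
    F = np.array([force(x) for x in X])          # force at each frame, shape (T, d)
    F_bar = 0.5 * (F[:-1] + F[1:])               # endpoint average (assumption)
    # I ~ (1/4) sum_t <dx_t, D^-1 F(x_t)>, force taken at the start of each step
    information = 0.25 * np.einsum("ti,ij,tj->", dX, D_inv, F[:-1])
    # S ~ sum_t <dx_t, D^-1 F_bar_t>, force averaged over each step
    entropy = np.einsum("ti,ij,tj->", dX, D_inv, F_bar)
    return {"information": float(information), "entropy": float(entropy)}

X = np.array([[1.0, 0.0], [0.5, 0.0], [0.25, 0.0]])
obs_sketch = observables_sketch(X, lambda x: -x, D=0.5 * np.eye(2))
```

On this three-frame toy trajectory the sketch gives information = 0.3125 and entropy = 0.9375, which can be checked term by term against the formulas.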