Simulation
The SFI.langevin subpackage provides Langevin simulators that
use the same state-function objects (PSF / SF) as the inference
engines. This closes the inference → simulation → validation loop:
you infer a model from data, simulate from it, and compare the synthetic
trajectory back to the original.
Two simulators are available:
- OverdampedProcess: stochastic Heun integrator (default; Euler–Maruyama optional) for \(\mathrm{d}\mathbf{x} = \mathbf{F}(\mathbf{x})\,\mathrm{d}t + \sqrt{2\mathbf{D}(\mathbf{x})}\,\mathrm{d}W_t\)
- UnderdampedProcess: velocity-Verlet-like integrator for \(\mathrm{d}\mathbf{x} = \mathbf{v}\,\mathrm{d}t,\;\mathrm{d}\mathbf{v} = \mathbf{F}(\mathbf{x},\mathbf{v})\,\mathrm{d}t + \sqrt{2\mathbf{D}}\,\mathrm{d}W_t\)
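This section contrasts the two overdamped schemes. Below is a minimal JAX sketch of one Euler–Maruyama step and one stochastic Heun (predictor–corrector) step. It assumes additive noise (constant D), so the predictor and the corrector reuse the same Gaussian increment. The helper names em_step and heun_step, and the explicit noise argument xi, are hypothetical and not part of SFI.langevin. The library's integrator may differ in details, for example for state-dependent D.

```python
import jax.numpy as jnp
from jax import random


def em_step(F, x, dt, L, xi):
    """One Euler–Maruyama step for dx = F(x) dt + L dW, where L @ L.T == 2 D.

    xi is a standard-normal vector with the same shape as x.
    """
    noise = jnp.sqrt(dt) * (L @ xi)
    return x + F(x) * dt + noise


def heun_step(F, x, dt, L, xi):
    """One stochastic Heun step, assuming additive noise (constant D),
    so predictor and corrector share the same Gaussian increment."""
    noise = jnp.sqrt(dt) * (L @ xi)
    x_pred = x + F(x) * dt + noise                     # Euler–Maruyama predictor
    return x + 0.5 * (F(x) + F(x_pred)) * dt + noise   # trapezoidal corrector


# Usage with the quick-example force F(x) = -x and D = 0.5 * I
F = lambda x: -x
L = jnp.linalg.cholesky(2.0 * 0.5 * jnp.eye(2))       # L @ L.T == 2 D
xi = random.normal(random.PRNGKey(0), (2,))
x1 = heun_step(F, jnp.ones(2), 0.01, L, xi)
```

Switch the noise off (xi = 0) and take F(x) = -x. One step of size 0.1 from x = 1 then gives 0.9 for Euler–Maruyama and 0.905 for Heun, against the exact exp(-0.1) ≈ 0.9048. This illustrates why a scheme with a trapezoidal drift correction is a reasonable default.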
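Quick example
```python
import jax.numpy as jnp
from jax import random

from SFI.langevin import OverdampedProcess
from SFI import make_sf

# Define the force F(x) = -x as a simple function and wrap it as an SF
F_sf = make_sf(lambda x, *, mask=None: -x, dim=2, rank=1)

# Overdamped process with a constant diffusion matrix D = 0.5 * I
proc = OverdampedProcess(F_sf, D=jnp.eye(2) * 0.5)
proc.initialize(jnp.zeros(2))  # initial position, shape (d,)

coll = proc.simulate(
    dt=0.01,            # time between recorded frames
    Nsteps=10_000,      # number of recorded frames
    key=random.PRNGKey(0),
    prerun=100,         # warm-up steps, discarded
    oversampling=10,    # internal sub-steps per recorded frame
)
# coll is a TrajectoryCollection
```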
Workflow
1. Construct the process with a force model F (PSF or SF) and a diffusion D (scalar, matrix, or PSF/SF).
2. Bind parameters via set_params() if F or D are unbound PSF objects.
3. Initialize the state with initialize() (position for overdamped; position + velocity for underdamped).
4. Simulate via simulate(), which returns a TrajectoryCollection.
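Diffusion specification
The D argument accepts multiple forms: a scalar, a constant matrix (such as jnp.eye(2) * 0.5 in the quick example), or a PSF/SF state function for position-dependent diffusion. If D is an unbound PSF, bind its parameters with set_params() before simulating.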
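Whatever form D takes, the noise term needs a matrix square root of 2D. The plain-JAX sketch below shows one way to build a factor L with L Lᵀ = 2D using a Cholesky decomposition. It assumes a scalar D means D times the identity and that D is symmetric positive definite. The helpers noise_factor and noise_increment are hypothetical and not SFI functions. For a state-dependent D, the same factorization would apply pointwise to D(x).

```python
import jax.numpy as jnp
from jax import random


def noise_factor(D, dim):
    """Return L with L @ L.T == 2 * D, for a scalar or a (dim, dim) matrix D.

    Hypothetical helper; assumes D is symmetric positive definite.
    """
    D = jnp.asarray(D, dtype=jnp.float32)
    if D.ndim == 0:                       # scalar: isotropic diffusion D * I
        return jnp.sqrt(2.0 * D) * jnp.eye(dim)
    return jnp.linalg.cholesky(2.0 * D)   # full (dim, dim) matrix


def noise_increment(key, D, dim, dt):
    """One Gaussian increment with covariance 2 D dt."""
    L = noise_factor(D, dim)
    xi = random.normal(key, (dim,))
    return jnp.sqrt(dt) * (L @ xi)


# A scalar 0.5 and the matrix 0.5 * I describe the same diffusion
L_scalar = noise_factor(0.5, 2)
L_matrix = noise_factor(0.5 * jnp.eye(2), 2)
```

A Cholesky factor is only one valid choice. Any L with L Lᵀ = 2D yields Gaussian increments with the same covariance.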
Particle systems
Both simulators respect the pdepth contract of the state
functions:
- pdepth=0: single particle; x0 has shape (d,)
- pdepth=1: interacting particles; x0 has shape (P, d)
For interacting-particle systems, attach extras_global and/or
extras_local with set_extras() before simulating, to pass
system metadata (box size, species labels, neighbor lists, …) through to
the models at every time step.
For a comprehensive guide to setting up particle systems, see Particle systems.
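The plain-JAX sketch below makes the shape contract concrete. It shows a pdepth=1 style force that maps positions of shape (P, d) to forces of shape (P, d). The soft pairwise repulsion and the function name pair_force are hypothetical. So is the box_size argument, which stands in for a box-size entry one might pass through extras_global. Wrapping such a function as a PSF/SF is not shown here.

```python
import jax.numpy as jnp


def pair_force(x, box_size, k=1.0, r_cut=1.0):
    """Soft harmonic repulsion between all particle pairs.

    x: positions, shape (P, d); box_size: periodic box length (scalar).
    Returns forces with the same shape (P, d).
    """
    # Pairwise displacements r_ij = x_i - x_j, shape (P, P, d)
    dx = x[:, None, :] - x[None, :, :]
    # Minimum-image convention for a periodic box
    dx = dx - box_size * jnp.round(dx / box_size)
    r = jnp.sqrt(jnp.sum(dx**2, axis=-1) + 1e-12)  # (P, P); eps avoids 0/0 on the diagonal
    # Magnitude k (r_cut - r) inside the cutoff, zero outside and for self-pairs
    P = x.shape[0]
    mag = jnp.where((r < r_cut) & ~jnp.eye(P, dtype=bool), k * (r_cut - r), 0.0)
    # Force on particle i points along r_ij / r, i.e. away from particle j
    return jnp.sum((mag / r)[..., None] * dx, axis=1)


x_single = jnp.zeros(2)          # pdepth=0: one particle, shape (d,)
x_many = jnp.array([[0.0, 0.0],  # pdepth=1: P = 3 particles, shape (P, d)
                    [0.5, 0.0],
                    [5.0, 5.0]])
forces = pair_force(x_many, box_size=10.0)
```

Here the first two particles push each other apart with equal and opposite forces. The third lies outside the cutoff and feels no force.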
Choosing dt and oversampling
The time step dt is the interval between recorded frames.
The oversampling parameter controls how many internal sub-steps
are taken per recorded frame:
- Larger oversampling gives more accurate integration but a slower simulation.
- Rule of thumb: oversampling should be large enough that the internal step \(\mathrm{d}t_{\text{internal}} = \mathrm{d}t / \text{oversampling}\) is small compared to the fastest timescale in the dynamics.
- For stiff systems (strong gradients, large forces), increase oversampling to 10–100.
- For diffusion-dominated systems with gentle forces, oversampling=1 may suffice.
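The rule of thumb can be turned into a small calculation. The sketch below assumes a linear (or linearized) drift F(x) ≈ -K x. The fastest relaxation timescale of that drift is then 1/λ_max, where λ_max is the largest real part among the eigenvalues of K. The helper names and the safety factor of 0.1 are illustrative choices, not SFI defaults.

```python
import math

import numpy as np


def fastest_timescale(K):
    """Fastest relaxation time 1 / max Re(lambda) for the drift F(x) = -K x.

    Assumes K has at least one eigenvalue with positive real part.
    """
    lam_max = np.max(np.linalg.eigvals(np.asarray(K, dtype=float)).real)
    return 1.0 / lam_max


def choose_oversampling(dt, tau_fast, safety=0.1):
    """Smallest integer n >= 1 with dt / n <= safety * tau_fast."""
    return max(1, math.ceil(dt / (safety * tau_fast)))


# Stiff 2D example: one fast direction (rate 200) and one slow one (rate 1)
K = np.diag([200.0, 1.0])
tau = fastest_timescale(K)            # set by the fast direction, 1 / 200
n = choose_oversampling(0.01, tau)    # internal step 0.01 / n must be <= 0.1 * tau
```

The slow direction plays no role in this estimate. Stiffness comes from the fastest mode alone, which is why strong gradients call for larger oversampling.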
Time-dependent extras (protocols)
Drives, ramps, and switching protocols enter a simulation as
time-dependent extras through the unchanged set_extras API:
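```python
import numpy as np
from jax import random
from SFI.trajectory import time_series_extra

dt, Nsteps, key = 0.01, 10_000, random.PRNGKey(0)
k_t = (np.arange(Nsteps) // 1000 % 2).astype(float)  # square wave, one value per frame
proc.set_extras(extras_global={"k_drive": time_series_extra(k_t)})  # proc as in the quick example
coll = proc.simulate(dt=dt, Nsteps=Nsteps, key=key)
```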
Conventions:
- A TimeSeriesExtra must carry one value per recorded frame (leading axis == Nsteps). The value is held constant across the oversampling substeps of its frame (zeroth-order hold), and the prerun uses the frame-0 value.
- A plain callable is interpreted as f(t), a function of physical time, and is materialized at the frame times \(t = k\,\mathrm{d}t\) before the run.
- The schedule is attached to the returned collection (as a TimeSeriesExtra). It is aligned so that the increment X[k+1] - X[k] was generated under the frame-k value, which is exactly the pairing the inference layer assumes. The round-trip idiom is therefore one line on each side: simulate with the protocol, then call infer_force_linear() on a basis containing extra_scalar() terms (see Time-dependent forcing — protocols as extras).
- simulate_chunked does not support time-dependent extras.
Simulation parameters
| Parameter | Description |
|---|---|
| dt | Time step between recorded frames |
| Nsteps | Number of recorded frames |
| key | JAX PRNG key |
| prerun | Warm-up steps (discarded) |
| oversampling | Internal sub-steps per recorded step (improves accuracy) |
Observables (overdamped)
By default the overdamped process also computes entropy- and
information-production estimates on the recorded trajectory (pass
compute_observables=False to skip them). They are attached to the
returned collection’s dataset metadata:
```python
obs = coll.datasets[0].meta["observables"]
I = obs["information"]  # I ≈ (1/4) Σ_t ⟨Δx_t, D⁻¹ F(x_t)⟩
S = obs["entropy"]      # S ≈ Σ_t ⟨Δx_t, D⁻¹(x_mid) · F̄_t⟩
```
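As a cross-check, the two sums in the comments above can be evaluated by hand on a recorded trajectory. The NumPy sketch below assumes a constant diffusion matrix D, so that D⁻¹(x_mid) = D⁻¹. It also reads F̄_t as the midpoint average (F(x_t) + F(x_{t+1}))/2. The function name and that reading of F̄_t are assumptions, and SFI's own normalization or discretization may differ.

```python
import numpy as np


def info_entropy_estimates(X, F, D):
    """Hand-rolled versions of the two sums for a trajectory X of shape (T, d).

    I = (1/4) sum_t <dx_t, D^-1 F(x_t)>
    S = sum_t <dx_t, D^-1 Fbar_t>, with Fbar_t = (F(x_t) + F(x_{t+1})) / 2
    Assumes a constant diffusion matrix D of shape (d, d).
    """
    Dinv = np.linalg.inv(D)
    dX = X[1:] - X[:-1]                            # (T-1, d) increments
    F_start = np.apply_along_axis(F, 1, X[:-1])    # F(x_t)
    F_end = np.apply_along_axis(F, 1, X[1:])       # F(x_{t+1})
    F_bar = 0.5 * (F_start + F_end)
    I = 0.25 * np.einsum("ti,ij,tj->", dX, Dinv, F_start)
    S = np.einsum("ti,ij,tj->", dX, Dinv, F_bar)
    return I, S


# Any trajectory works for the identity below; a random walk is used here
rng = np.random.default_rng(1)
X = np.cumsum(0.1 * rng.standard_normal((500, 2)), axis=0)
D = 0.5 * np.eye(2)
I, S = info_entropy_estimates(X, lambda x: -x, D)
U = 0.5 * np.sum(X**2, axis=1)                     # potential with F = -grad U
```

For the gradient force F = -x with isotropic D = 0.5 I, the midpoint sum telescopes exactly to S = -(U[-1] - U[0]) / 0.5. Entropy production then reduces to a boundary term, as expected for an equilibrium force.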