"""
Noise models for Langevin simulators.
This module provides a hierarchy of noise models that go beyond the
simple diagonal diffusion ``D = σ I``. In particular, it supports
**conserved noise** needed by SPDE models such as Active Model B+, where
the noise takes the form ``∇·(σ η)`` and preserves the spatial integral
of the field.
Class hierarchy
---------------
- :class:`NoiseModel` — abstract base;
- :class:`WhiteNoise` — i.i.d. per-site Gaussian (reproduces the plain scalar ``D = σ`` behaviour);
- :class:`ConservedNoise` — ``sqrt(-σ² ∇²) ξ`` via FFT on a periodic grid;
- :class:`CompositeNoise` — different noise models on different field components.
Usage
-----
Pass a ``NoiseModel`` instance as ``D=`` to :class:`~SFI.langevin.OverdampedProcess`
(or any ``LangevinBase`` subclass). The simulator detects the noise model and
delegates noise generation to it instead of the traditional ``sqrt(2D)·ξ`` path.
.. code-block:: python

    from SFI.langevin import OverdampedProcess
    from SFI.langevin.noise import ConservedNoise

    noise = ConservedNoise(sigma=0.3, grid_shape=(64, 64), dx=1.0)
    proc = OverdampedProcess(BASIS, D=noise)
    proc.set_params(theta_F=theta)
    proc.set_extras(extras_global=box_extras)
    proc.initialize(X0)
    coll = proc.simulate(dt=0.02, Nsteps=3000, key=key, oversampling=4)
"""
from __future__ import annotations
import abc
from typing import List, Sequence, Tuple, Union
import jax.numpy as jnp
from jax import random
Array = jnp.ndarray
# ============================================================================
# Abstract base
# ============================================================================
class NoiseModel(abc.ABC):
    """Common interface for the noise sources consumed by Langevin simulators.

    Concrete models provide :meth:`sample` and :meth:`effective_D_per_site`.
    Under this contract the simulator asks for one increment per
    Euler–Maruyama substep via ``sample(key, x, extras)``; the increment
    already carries the ``sqrt(2)`` factor, so the update reads
    ``x += dt*F + sqrt(dt)*sample(key, x, extras)``.

    Parameters
    ----------
    n_fields : int
        How many field components live on each grid site (the ``dim`` of the
        force contract): 1 for a scalar field, 2 for a two-component system.
    """

    def __init__(self, *, n_fields: int) -> None:
        # Kept under a private name; subclasses read ``self._n_fields``.
        self._n_fields = n_fields

    @abc.abstractmethod
    def sample(self, key: Array, x: Array, extras: dict) -> Array:
        r"""Produce a single noise increment.

        Parameters
        ----------
        key : PRNG key
        x : Array, shape ``(P, d)`` or ``(d,)``
            Current state.  Additive models only need its shape;
            multiplicative models may read its values.
        extras : dict
            Process extras (grid geometry and similar side information).

        Returns
        -------
        Array, same shape as *x*
            Increment that already includes the ``sqrt(2)`` factor, so the
            integrator applies ``x += dt*F + sqrt(dt) * sample(...)``.
        """

    @abc.abstractmethod
    def effective_D_per_site(self, extras: dict) -> Array:
        r"""Give an approximate per-site diffusion matrix of shape ``(d, d)``.

        The inference pipeline uses this as a *pragmatic approximation* when
        the noise is correlated in space.  A spatially white model can
        return its exact diffusion matrix; a spatially correlated model
        returns a grid-averaged effective value.

        Returns
        -------
        Array, shape ``(d, d)``
        """

    @property
    def n_fields(self) -> int:
        """Number of field components per site."""
        return self._n_fields

    @property
    def dim(self) -> int:
        """Same value as :attr:`n_fields`, named after the force contract."""
        return self._n_fields

    @property
    def noise_kind(self) -> str:
        """Short tag identifying the noise type (the concrete class name)."""
        return type(self).__name__
# ============================================================================
# White noise (recovers current constant-scalar behaviour)
# ============================================================================
class WhiteNoise(NoiseModel):
    r"""Gaussian noise with no spatial correlation: ``B = sqrt(2σ) I``.

    Every grid site and every field component gets an independent draw, so
    that after the integrator's ``sqrt(dt)`` scaling the increment is
    ``N(0, 2σ dt)``.  This matches a scalar constant ``D = σ``.

    Parameters
    ----------
    sigma : float
        Scalar diffusion coefficient (the *D* in ``dx = F dt + sqrt(2D) dW``).
        Must be non-negative.
    n_fields : int
        Number of field components per site.

    Raises
    ------
    ValueError
        If *sigma* is negative.
    """

    def __init__(self, sigma: float, *, n_fields: int = 1) -> None:
        super().__init__(n_fields=n_fields)
        if sigma < 0:
            raise ValueError(f"sigma must be non-negative, got {sigma}")
        # Amplitude sqrt(2σ) is cached so ``sample`` is a single multiply.
        self._sqrt2sigma = float(jnp.sqrt(2.0 * sigma))
        self._sigma = float(sigma)

    @property
    def sigma(self) -> float:
        """The scalar diffusion coefficient passed at construction."""
        return self._sigma

    def sample(self, key: Array, x: Array, extras: dict) -> Array:
        """Return ``sqrt(2σ)`` times standard-normal draws shaped like *x*.

        Only ``x.shape`` is used; *extras* is ignored.
        """
        return self._sqrt2sigma * random.normal(key, shape=x.shape)

    def effective_D_per_site(self, extras: dict) -> Array:
        """Return ``σ · I`` of shape ``(n_fields, n_fields)``; *extras* is ignored."""
        identity = jnp.eye(self._n_fields)
        return self._sigma * identity

    def __repr__(self) -> str:
        return f"WhiteNoise(sigma={self._sigma}, n_fields={self._n_fields})"
# ============================================================================
# Conserved noise (sqrt(-σ² ∇²) via FFT on periodic grids)
# ============================================================================
def _build_freq_amplitudes(
    grid_shape: Sequence[int],
    dx: Union[float, Sequence[float]],
    ndim: int,
) -> Array:
    r"""Build the Fourier-space multiplier ``|k|`` for conserved noise.

    For a periodic grid with spacing *dx*, the wavenumbers along axis α are

    .. math::

        k_\alpha = \frac{2\pi\,n_\alpha}{N_\alpha\,\Delta x_\alpha}

    and the multiplier is ``|k| = sqrt(sum_α k_α²)``, which corresponds to
    the operator ``sqrt(-∇²)`` in Fourier space.

    The layout follows ``rfftn``: every axis but the last uses the full
    ``fftfreq`` set, the last axis only the non-negative ``rfftfreq`` set,
    so the result has shape ``(N_0, ..., N_{d-2}, N_{d-1}//2 + 1)``.
    The ``k = 0`` entry is zero simply because ``|k|`` vanishes there; no
    explicit masking is applied.

    Parameters
    ----------
    grid_shape : sequence of int
        Number of grid points along each spatial axis.
    dx : float or sequence of float
        Grid spacing; a scalar (Python, NumPy or 0-d array) is used for
        every axis, a sequence must have one entry per axis.
    ndim : int
        Number of spatial axes; must equal ``len(grid_shape)``.

    Returns
    -------
    Array, real, shape matching rfft output
        ``|k|`` on the half-complex grid.

    Raises
    ------
    ValueError
        If *ndim* is not a positive number equal to ``len(grid_shape)``, or
        if a sequence *dx* does not have exactly *ndim* entries.
    """
    import numbers  # function-scope import: the module import block is left untouched

    grid_shape = tuple(int(n) for n in grid_shape)
    if ndim < 1 or len(grid_shape) != ndim:
        raise ValueError(
            f"ndim must be positive and equal len(grid_shape); "
            f"got ndim={ndim}, grid_shape={grid_shape}"
        )

    # Scalars include NumPy scalar types and 0-d arrays, not just int/float.
    if isinstance(dx, numbers.Real) or getattr(dx, "ndim", None) == 0:
        spacings = (float(dx),) * ndim
    else:
        spacings = tuple(float(s) for s in dx)
        if len(spacings) != ndim:
            raise ValueError(
                f"dx must be a scalar or have {ndim} entries, got {len(spacings)}"
            )

    # Shape of the rfftn output: only the last axis is halved.
    half_shape = grid_shape[:-1] + (grid_shape[-1] // 2 + 1,)
    last_axis = ndim - 1

    # Accumulate k² = sum_α k_α², broadcasting each 1-D axis over the grid.
    k_sq = jnp.zeros(half_shape)
    for axis, (n_points, spacing) in enumerate(zip(grid_shape, spacings)):
        if axis == last_axis:
            freq = jnp.fft.rfftfreq(n_points, d=spacing)  # shape (N//2+1,)
        else:
            freq = jnp.fft.fftfreq(n_points, d=spacing)  # shape (N,)
        k_axis = 2 * jnp.pi * freq  # cycles per length → angular wavenumber
        bcast_shape = [1] * ndim
        bcast_shape[axis] = half_shape[axis]
        k_sq = k_sq + k_axis.reshape(bcast_shape) ** 2

    return jnp.sqrt(k_sq)
class ConservedNoise(NoiseModel):
    r"""Conserved (divergence-form) noise on a periodic rectangular grid.

    Implements noise of the form

    .. math::

        \eta(x, t) = \nabla \cdot \bigl(\sigma\, \vec{\Lambda}(x,t)\bigr)

    where :math:`\vec{\Lambda}` is spatiotemporal white vector noise.
    In Fourier space this is equivalent to multiplying each mode by
    :math:`|k|`:

    .. math::

        \hat{\eta}_k = \sigma\,|k|\,\hat{\xi}_k

    This noise **conserves the spatial average** of the field
    (:math:`\sum_i \phi_i` is a constant of the noise process), as
    required by Model B / Active Model B+ dynamics.

    Parameters
    ----------
    sigma : float
        Noise amplitude (the :math:`\sigma` in the equations above).
        This is the *continuum* amplitude; the grid discretisation is
        handled internally.  Must be non-negative.
    grid_shape : sequence of int
        Grid dimensions ``(Nx, Ny, ...)`` — must match the simulation grid.
    dx : float or sequence of float
        Grid spacing: one positive scalar for all axes, or one positive
        value per axis.
    n_fields : int
        Number of field components per site.

    Raises
    ------
    ValueError
        If *sigma* is negative, *grid_shape* is empty or has a non-positive
        entry, or *dx* has the wrong length or a non-positive entry.

    Notes
    -----
    The ``sample`` method uses real FFT (``rfftn`` / ``irfftn``) for
    efficiency.  It draws white noise in real space, transforms to
    Fourier space, multiplies by :math:`\sigma\,|k|\,\sqrt{2/\Delta V}`
    (where :math:`\Delta V = \prod \Delta x_\alpha` is the cell volume),
    and transforms back.

    The factor :math:`1/\sqrt{\Delta V}` provides the correct continuum
    limit: the noise covariance
    :math:`\langle\eta_i\,\eta_j\rangle = -\sigma^2 \nabla^2 \delta_{ij} / \Delta V`
    is independent of grid resolution when *sigma* is held fixed.
    """

    def __init__(
        self,
        sigma: float,
        *,
        grid_shape: Sequence[int],
        dx: Union[float, Sequence[float]] = 1.0,
        n_fields: int = 1,
    ) -> None:
        import numbers  # function-scope import: the module import block is left untouched

        super().__init__(n_fields=n_fields)
        if sigma < 0:
            raise ValueError(f"sigma must be non-negative, got {sigma}")
        self._sigma = float(sigma)

        self._grid_shape = tuple(int(n) for n in grid_shape)
        self._ndim = len(self._grid_shape)
        if self._ndim == 0 or any(n < 1 for n in self._grid_shape):
            raise ValueError(
                "grid_shape must be a non-empty sequence of positive ints, "
                f"got {self._grid_shape}"
            )

        # Scalars include NumPy scalar types and 0-d arrays, not just int/float.
        if isinstance(dx, numbers.Real) or getattr(dx, "ndim", None) == 0:
            self._dx = (float(dx),) * self._ndim
        else:
            self._dx = tuple(float(d) for d in dx)
        if len(self._dx) != self._ndim:
            raise ValueError(
                f"dx must be a scalar or have {self._ndim} entries, got {len(self._dx)}"
            )
        if any(not d > 0 for d in self._dx):
            # A non-positive cell volume would turn sqrt(2 / dV) into NaN/inf.
            raise ValueError(f"dx entries must be positive, got {self._dx}")

        # Total number of grid sites.
        self._P = 1
        for n in self._grid_shape:
            self._P *= n

        # Cell volume for continuum-limit normalisation.
        dV = 1.0
        for d in self._dx:
            dV *= d
        self._dV = dV

        # |k| on the rfftn half-grid.
        self._k_abs = _build_freq_amplitudes(self._grid_shape, self._dx, self._ndim)

        # Combined multiplier: sigma * |k| * sqrt(2 / dV).
        # The sqrt(2) enters because the integrator step is
        #     x += sqrt(dt) * sample(...)
        # and we need <sample_i sample_j> = 2 * D_eff * delta_{ij}.
        self._multiplier = self._sigma * self._k_abs * jnp.sqrt(2.0 / self._dV)

        # Effective per-site D for the inference approximation:
        #     D_eff = sigma² * <|k|²> / dV,   <|k|²> = (1/P) * sum over ALL modes.
        # ``_k_abs`` only stores the half-spectrum, so each stored mode is
        # weighted by how many full-grid modes it stands for: 1 for the
        # self-conjugate planes of the last axis (index 0 and, for even N,
        # the Nyquist index), 2 for everything in between.
        n_last = self._grid_shape[-1]
        weights = jnp.full((n_last // 2 + 1,), 2.0).at[0].set(1.0)
        if n_last % 2 == 0:
            weights = weights.at[-1].set(1.0)
        k_sq_mean = float(jnp.sum(weights * self._k_abs**2) / self._P)
        self._D_eff = float(self._sigma**2 * k_sq_mean / self._dV)

    @property
    def sigma(self) -> float:
        """The continuum noise amplitude passed at construction."""
        return self._sigma

    @property
    def grid_shape(self) -> Tuple[int, ...]:
        """Spatial grid dimensions as a tuple of ints."""
        return self._grid_shape

    def sample(self, key: Array, x: Array, extras: dict) -> Array:
        r"""Draw one conserved-noise increment.

        Steps:

        1. Draw white noise ξ ~ N(0,1) on the grid, shape ``(Nx, Ny, ..., d)``
        2. rFFT along the spatial axes
        3. Multiply by :math:`\sigma\,|k|\,\sqrt{2/\Delta V}` (the ``_multiplier``)
        4. Inverse rFFT back to real space
        5. Flatten to ``(P, d)``

        The ``k = 0`` entry of ``_multiplier`` is zero, so the increment sums
        to zero over the grid (up to floating-point round-off).

        *x* and *extras* are not read: the output shape ``(P, d)`` is fixed
        by the grid and ``n_fields`` given at construction.
        """
        d = self._n_fields
        # White noise on the grid with a trailing field axis.
        xi = random.normal(key, shape=self._grid_shape + (d,))
        # Transform over the spatial axes only, never the field axis.
        fft_axes = tuple(range(self._ndim))
        xi_hat = jnp.fft.rfftn(xi, axes=fft_axes)
        # Multiplier has the half-grid shape; broadcast it over the field axis.
        eta_hat = xi_hat * self._multiplier[..., None]
        # ``s`` is required so odd-length last axes are reconstructed correctly.
        eta = jnp.fft.irfftn(eta_hat, s=self._grid_shape, axes=fft_axes)
        # Flatten spatial axes: (Nx, Ny, ..., d) → (P, d).
        return eta.reshape(self._P, d)

    def effective_D_per_site(self, extras: dict) -> Array:
        """Return ``D_eff · I`` of shape ``(n_fields, n_fields)``; *extras* is ignored."""
        return self._D_eff * jnp.eye(self._n_fields)

    def __repr__(self) -> str:
        return (
            f"ConservedNoise(sigma={self._sigma}, "
            f"grid_shape={self._grid_shape}, dx={self._dx}, "
            f"n_fields={self._n_fields})"
        )
# ============================================================================
# Composite noise (different models on different field components)
# ============================================================================
class CompositeNoise(NoiseModel):
    r"""Apply different noise models to different field components.

    Useful when some fields have conserved dynamics (e.g. concentration)
    and others have non-conserved dynamics (e.g. velocity).

    Parameters
    ----------
    components : list of ``(NoiseModel, field_indices)`` pairs
        Each element specifies a noise model and the field indices it
        applies to.  ``field_indices`` is a non-empty list of ints whose
        length equals the model's ``n_fields``.  Together the indices must
        cover ``range(n_fields)`` exactly once.
    n_fields : int
        Total number of field components per site.

    Raises
    ------
    ValueError
        If an index list is empty or contains a repeated index, if its
        length differs from the model's ``n_fields``, if two components
        share an index, or if the indices do not cover ``range(n_fields)``.

    Example
    -------
    >>> conserved = ConservedNoise(sigma=0.3, grid_shape=(64, 64), n_fields=1)
    >>> white = WhiteNoise(sigma=0.1, n_fields=2)
    >>> composite = CompositeNoise(
    ...     components=[(conserved, [0]), (white, [1, 2])],
    ...     n_fields=3,
    ... )
    """

    def __init__(
        self,
        *,
        components: List[Tuple[NoiseModel, List[int]]],
        n_fields: int,
    ) -> None:
        super().__init__(n_fields=n_fields)
        seen: set[int] = set()
        validated = []
        for model, indices in components:
            idx_list = list(indices)
            idx_set = set(idx_list)
            if not idx_list:
                raise ValueError("Each component needs at least one field index")
            if len(idx_set) != len(idx_list):
                # A set-based check alone would silently hide repeats.
                raise ValueError(f"Duplicate field indices within a component: {idx_list}")
            if idx_set & seen:
                raise ValueError(f"Overlapping field indices: {idx_set & seen}")
            if model.n_fields != len(idx_list):
                # Otherwise a sub-model's output could broadcast across
                # several fields without any error.
                raise ValueError(
                    f"{model!r} has n_fields={model.n_fields} but is assigned "
                    f"{len(idx_list)} field indices {idx_list}"
                )
            seen |= idx_set
            validated.append((model, idx_list))
        if seen != set(range(n_fields)):
            raise ValueError(f"Field indices must cover range({n_fields}), got {sorted(seen)}")
        # Own copy: later mutation of the caller's lists cannot bypass validation.
        self._components = validated

    def sample(self, key: Array, x: Array, extras: dict) -> Array:
        """Assemble one increment from the component models.

        Each component receives its own PRNG sub-key and the slice of *x*
        (along the last axis) holding its fields; its output is written
        into the same positions of the result, which has the shape of *x*.
        """
        result = jnp.zeros_like(x)
        for model, indices in self._components:
            # Sequential split keeps the components' random streams independent.
            key, sub = random.split(key)
            idx = jnp.array(indices)
            noise_sub = model.sample(sub, x[..., idx], extras)
            result = result.at[..., idx].set(noise_sub)
        return result

    def effective_D_per_site(self, extras: dict) -> Array:
        """Assemble the ``(n_fields, n_fields)`` effective diffusion matrix.

        Each component's matrix is written into the rows and columns given
        by its field indices; entries coupling different components stay
        zero.
        """
        D = jnp.zeros((self._n_fields, self._n_fields))
        for model, indices in self._components:
            D_sub = model.effective_D_per_site(extras)
            idx = jnp.array(indices)
            # One block update instead of one scatter per matrix entry.
            D = D.at[idx[:, None], idx[None, :]].set(D_sub)
        return D

    def __repr__(self) -> str:
        parts = [f"({m!r}, {idx})" for m, idx in self._components]
        return f"CompositeNoise(components=[{', '.join(parts)}], n_fields={self._n_fields})"