"""
SFI.bases.pairs
===============
Generic building blocks for multi-particle (pair-interaction) systems.
This module provides:
- **PBC utilities**: minimum-image displacement in arbitrary dimension.
- **Kernel families**: pre-built 1-D radial kernel functions.
- **Radial / scalar pair bases**: build Interactor objects from kernel
families, ready for dispatch over neighbor lists.
- **Angular coupling bases**: weighted orientation-coupling interactors.
- **Heading vector**: single-particle heading vector from an angle coordinate.
- **Tensor pair features**: rank-2 dyadic basis for diffusion tensors, nematic order, etc.
"""
from __future__ import annotations
from typing import Any, Callable, Sequence
import jax.numpy as jnp
from ..statefunc import Basis, Interactor, make_basis, make_interactor
# ═══════════════════════════════════════════════════════════════════════
# PBC UTILITIES
# ═══════════════════════════════════════════════════════════════════════
[docs]
def pbc_displacement(xj, xi, box):
    """Return the minimum-image separation ``xj - xi`` in a periodic box.

    The raw difference is shifted, axis by axis, by the whole number of box
    lengths closest to it, so the computation is dimension-agnostic and
    follows ordinary broadcasting rules.

    Parameters
    ----------
    xj, xi : array, shape ``(d,)``
        Positions (or sub-positions) of the two particles.
    box : array, shape ``(d,)``
        Box length along each axis.

    Returns
    -------
    dx : array, shape ``(d,)``
        ``xj - xi`` folded according to the minimum-image convention.
    """
    raw = xj - xi
    # Number of whole box lengths to remove along each axis.
    n_images = jnp.round(raw / box)
    return raw - box * n_images
[docs]
def wrap_angle(a):
    """Wrap angle(s) into the half-open interval ``[-π, π)``.

    The lower bound is included and the upper bound is excluded: an input
    of exactly ``π`` (or any odd multiple of it) is mapped to ``-π``.

    Parameters
    ----------
    a : scalar or array
        Angle(s) in radians.

    Returns
    -------
    scalar or array
        The same angle(s) shifted by a multiple of ``2π`` into ``[-π, π)``.
    """
    # Shift by π so the modulo folds into [0, 2π), then shift back.
    return (a + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
def _pairwise_dr(XK, *, box=None, spatial_dims=None, eps=1e-12):
"""Compute displacement, distance and unit vector for a K=2 pair.
Parameters
----------
XK : array, shape ``(2, dim)``
Stacked pair ``[xi, xj]``.
box : array or None
PBC box lengths (applied to ``spatial_dims`` only).
spatial_dims : slice or index array, optional
Which dimensions of state vector are spatial coordinates.
Default: all.
eps : float
Regularisation to avoid division by zero.
Returns
-------
dx : array, shape ``(d_spatial,)``
r : scalar
rhat : array, shape ``(d_spatial,)``
"""
xi, xj = XK[0], XK[1]
if spatial_dims is not None:
xi_s, xj_s = xi[spatial_dims], xj[spatial_dims]
else:
xi_s, xj_s = xi, xj
if box is not None:
dx = pbc_displacement(xj_s, xi_s, box)
else:
dx = xj_s - xi_s
r = jnp.sqrt(jnp.sum(dx**2) + eps)
rhat = dx / r
return dx, r, rhat
# ═══════════════════════════════════════════════════════════════════════
# KERNEL FAMILIES
# ═══════════════════════════════════════════════════════════════════════
[docs]
def exp_poly_kernels(degrees, lengths):
    r"""Radial kernels :math:`\phi_{k,L}(r) = r^k \exp(-r/L)`.

    One kernel is produced for every ``(degree, length)`` combination,
    with the degree as the outer (slowest-varying) loop.

    Parameters
    ----------
    degrees : sequence of int
        Polynomial degrees *k*.
    lengths : sequence of float
        Exponential decay lengths *L*.

    Returns
    -------
    list of (callable, str)
        Each entry is ``(phi, label)`` where ``phi(r) -> scalar``.
    """

    def _entry(k, L):
        # Coerce once so the closure and the label agree on the values.
        power, decay = int(k), float(L)

        def _phi(r):
            return (r**power) * jnp.exp(-r / decay)

        return _phi, f"r^{power}·exp(-r/{decay:g})"

    return [_entry(k, L) for k in degrees for L in lengths]
[docs]
def gaussian_kernels(sigmas):
    r"""Radial kernels :math:`\phi_\sigma(r) = \exp(-r^2 / 2\sigma^2)`.

    Parameters
    ----------
    sigmas : sequence of float
        Gaussian widths.

    Returns
    -------
    list of (callable, str)
        One ``(phi, label)`` pair per width, in input order.
    """

    def _entry(sigma):
        width = float(sigma)

        def _phi(r):
            return jnp.exp(-(r**2) / (2.0 * width**2))

        return _phi, f"r^0·gauss(σ={width:g})"

    return [_entry(s) for s in sigmas]
[docs]
def power_kernels(degrees):
    r"""Radial kernels :math:`\phi_k(r) = r^k`.

    Parameters
    ----------
    degrees : sequence of int
        Exponents *k*; each is coerced with ``int``.

    Returns
    -------
    list of (callable, str)
        One ``(phi, label)`` pair per degree, in input order.
    """

    def _entry(k):
        power = int(k)

        def _phi(r):
            return r**power

        return _phi, f"r^{power}"

    return [_entry(k) for k in degrees]
[docs]
def compact_kernels(degrees, cutoff):
    r"""Compactly-supported kernels :math:`r^k (1 - r/r_c)^2` for :math:`r < r_c`.

    The kernels evaluate to zero for :math:`r \ge r_c`.

    Parameters
    ----------
    degrees : sequence of int
        Polynomial degrees *k*.
    cutoff : float
        Support radius *r_c*.

    Returns
    -------
    list of (callable, str)
        One ``(phi, label)`` pair per degree, in input order.
    """
    # Coerce the cutoff first so an invalid value fails even without degrees.
    radius = float(cutoff)

    def _entry(k):
        power = int(k)

        def _phi(r):
            taper = jnp.where(r < radius, (1.0 - r / radius) ** 2, 0.0)
            return (r**power) * taper

        return _phi, f"r^{power}·comp(rc={radius:g})"

    return [_entry(k) for k in degrees]
# ═══════════════════════════════════════════════════════════════════════
# RADIAL / SCALAR PAIR BASES
# ═══════════════════════════════════════════════════════════════════════
[docs]
def radial_pair_basis(
    kernels: Sequence[tuple[Callable, str]],
    *,
    dim: int,
    box: Any = None,
    spatial_dims: slice | Sequence[int] | None = None,
    embed_dim: int | None = None,
    embed_axes: Sequence[int] | None = None,
    labels: Sequence[str] | None = None,
) -> Interactor:
    r"""Build a rank-1 pair Interactor with radial-kernel features.

    Each feature is :math:`\phi_\alpha(r_{ij})\,\hat{\mathbf{r}}_{ij}`
    where :math:`r_{ij}` is the pairwise distance (optionally with PBC).

    .. physics:: Radial pair interaction basis
       :label: radial-pair-basis
       :category: Basis functions

       .. math::

          f_\alpha(\mathbf{r}_{ij})
              = \phi_\alpha(r_{ij})\;\hat{\mathbf{r}}_{ij}

       Scalar radial kernel :math:`\phi_\alpha` times the unit displacement
       vector.  Available kernel families: exponential-polynomial,
       Gaussian, power-law, and compactly supported.

    Parameters
    ----------
    kernels : list of (callable, str)
        1-D kernel functions and their labels, as returned by
        :func:`exp_poly_kernels`, :func:`gaussian_kernels`, etc.
    dim : int
        Full state-vector dimension per particle.
    box : array, ``"extras"``, or None
        PBC box lengths.  ``None`` (default) = free-space, no periodic
        boundaries.  Pass an array for a static box captured in the
        closure, or ``"extras"`` to read the box from
        ``extras["box"]`` at evaluation time.  The box is applied over
        ``spatial_dims`` only.
    spatial_dims : slice or index array, optional
        Which axes of the state vector are spatial coordinates (the ones
        over which distances are computed and the output vector is
        defined).  Plain lists/tuples of indices are accepted.
        Default: ``None`` (all axes are spatial).
    embed_dim : int, optional
        If the output should be embedded into a larger vector (e.g., the
        displacement lives in 2D but the state vector is 3D with an
        angle), set ``embed_dim`` to the output vector size.  Spatial
        components are placed at ``embed_axes`` indices; remaining indices
        are zero.
    embed_axes : sequence of int, optional
        Indices into the ``embed_dim``-length output where spatial
        components are placed.  Required when ``embed_dim is not None``.
    labels : sequence of str, optional
        Override labels (one per kernel).

    Returns
    -------
    Interactor
        Rank-1 (vector) interactor with ``K=2``, ``n_features=len(kernels)``.
        Call ``.dispatch_pairs(...)`` to stream over neighbour lists.

    Raises
    ------
    ValueError
        If ``embed_dim`` is set without ``embed_axes``, or if the number of
        ``embed_axes`` differs from the number of selected spatial axes.
    """
    n_feat = len(kernels)
    if labels is None:
        labels = [lab for _, lab in kernels]
    fns = [fn for fn, _ in kernels]

    # JAX arrays reject plain list indices (and read tuples as multi-axis
    # indexing), so normalise any non-slice selection to an index array.
    if spatial_dims is None or isinstance(spatial_dims, slice):
        sp_index = spatial_dims
    else:
        sp_index = jnp.asarray(spatial_dims)

    # Number of spatial components selected from the state vector.
    if spatial_dims is None:
        n_spatial = dim
    elif isinstance(spatial_dims, slice):
        n_spatial = len(range(*spatial_dims.indices(dim)))
    else:
        n_spatial = len(spatial_dims)

    # Determine output vector size
    embed = embed_dim is not None
    if embed:
        if embed_axes is None:
            raise ValueError("embed_axes required when embed_dim is set")
        if len(embed_axes) != n_spatial:
            raise ValueError(
                f"len(embed_axes)={len(embed_axes)} must equal the number of spatial "
                f"dimensions selected by spatial_dims."
            )
        out_d = embed_dim
    else:
        out_d = n_spatial

    # Resolve box mode.  The isinstance guard matters: comparing a NumPy
    # array with a string yields an element-wise array whose truth value
    # is ambiguous.
    use_extras_box = isinstance(box, str) and box == "extras"
    static_box = None if (box is None or use_extras_box) else jnp.asarray(box)

    def _pair_local(XK, *, extras=None):
        # Resolve box: the static one wins; otherwise read it from extras.
        b = static_box
        if b is None and use_extras_box and extras is not None:
            b = jnp.asarray(extras["box"])
            if sp_index is not None:
                # A scalar box applies to every axis and needs no slicing.
                b = b[sp_index] if b.ndim > 0 else b
        _, r, rhat = _pairwise_dr(XK, box=b, spatial_dims=sp_index)

        # Evaluate kernels → (F,)
        vals = jnp.stack([fn(r) for fn in fns])

        if not embed:
            return rhat[:, None] * vals[None, :]  # (d_spatial, F)

        # Guards against a state vector whose size disagrees with ``dim``.
        if len(embed_axes) != rhat.shape[0]:
            raise ValueError(
                f"embed_axes length ({len(embed_axes)}) must equal the number of "
                f"selected spatial dims ({rhat.shape[0]}) from spatial_dims."
            )
        # Embed into the larger vector; untouched rows stay zero.
        full = jnp.zeros((out_d, n_feat), dtype=XK.dtype)
        for k, ax in enumerate(embed_axes):
            full = full.at[ax, :].set(rhat[k] * vals)
        return full  # (out_d, F)

    extras_keys = ("box",) if use_extras_box else ()
    return make_interactor(
        _pair_local,
        dim=dim,
        rank=1,
        K=2,
        n_features=n_feat,
        extras_keys=extras_keys,
        labels=list(labels),
        descriptor="radial-pair-basis",
    )
[docs]
def scalar_pair_basis(
    kernels: Sequence[tuple[Callable, str]],
    *,
    dim: int,
    box: Any = None,
    spatial_dims: slice | Sequence[int] | None = None,
    labels: Sequence[str] | None = None,
) -> Interactor:
    r"""Build a rank-0 pair Interactor with scalar radial-kernel features.

    Each feature is :math:`\phi_\alpha(r_{ij})` — the raw kernel value
    without the directional :math:`\hat{r}` factor.

    Use this for energy-like quantities, as radial weights for angular
    coupling, or as building blocks for tensor pair features composed
    via the ``*`` operator.

    Parameters
    ----------
    kernels, dim, box, spatial_dims, labels
        Same as :func:`radial_pair_basis`.  ``spatial_dims`` may be a
        slice, an index array, or a plain list/tuple of indices.

    Returns
    -------
    Interactor
        Rank-0 (scalar) interactor with ``K=2``.
    """
    n_feat = len(kernels)
    if labels is None:
        labels = [lab for _, lab in kernels]
    fns = [fn for fn, _ in kernels]

    # JAX arrays reject plain list indices (and read tuples as multi-axis
    # indexing), so normalise any non-slice selection to an index array.
    if spatial_dims is None or isinstance(spatial_dims, slice):
        sp_index = spatial_dims
    else:
        sp_index = jnp.asarray(spatial_dims)

    # The isinstance guard matters: comparing a NumPy array with a string
    # yields an element-wise array whose truth value is ambiguous.
    use_extras_box = isinstance(box, str) and box == "extras"
    static_box = None if (box is None or use_extras_box) else jnp.asarray(box)

    def _pair_local(XK, *, extras=None):
        b = static_box
        if b is None and use_extras_box and extras is not None:
            b = jnp.asarray(extras["box"])
            if sp_index is not None:
                # A scalar box applies to every axis and needs no slicing.
                b = b[sp_index] if b.ndim > 0 else b
        _, r, _ = _pairwise_dr(XK, box=b, spatial_dims=sp_index)
        return jnp.stack([fn(r) for fn in fns])  # (F,)

    extras_keys = ("box",) if use_extras_box else ()
    return make_interactor(
        _pair_local,
        dim=dim,
        rank=0,
        K=2,
        n_features=n_feat,
        extras_keys=extras_keys,
        labels=list(labels),
        descriptor="scalar-pair-basis",
    )
# ═══════════════════════════════════════════════════════════════════════
# ANGULAR / ORIENTATION COUPLING
# ═══════════════════════════════════════════════════════════════════════
[docs]
def angular_pair_basis(
    kernels: Sequence[tuple[Callable, str]],
    coupling_fn: Callable,
    *,
    dim: int,
    angle_index: int,
    output_index: int | None = None,
    box: Any = None,
    spatial_dims: slice | Sequence[int] | None = None,
    coupling_label: str = "g",
    labels: Sequence[str] | None = None,
) -> Interactor:
    r"""Build a rank-1 pair Interactor for orientation coupling.

    Each feature computes
    :math:`\phi_\alpha(r_{ij})\,g(\theta_j - \theta_i)` and embeds the
    result along ``output_index`` in a ``dim``-d output vector.

    Parameters
    ----------
    kernels : list of (callable, str)
        Radial weight functions (same format as other kernel factories).
    coupling_fn : callable
        Scalar function of the angle difference, e.g. ``jnp.sin`` for
        alignment, ``lambda a: jnp.cos(2*a)`` for nematic coupling.  It
        receives the difference after :func:`wrap_angle` has been applied.
    dim : int
        Full state-vector dimension.
    angle_index : int
        Index of the angle coordinate in the state vector.
    output_index : int, optional
        Index along which the coupled output is placed.  Defaults to
        ``angle_index``.
    box, spatial_dims
        PBC and spatial-dimension controls (same as
        :func:`radial_pair_basis`).  ``spatial_dims`` may be a slice, an
        index array, or a plain list/tuple of indices.
    coupling_label : str
        Short label for the coupling (appears in feature labels).
    labels : list of str, optional
        Override labels.

    Returns
    -------
    Interactor
        Rank-1 (vector) interactor with ``K=2``.
    """
    n_feat = len(kernels)
    if output_index is None:
        output_index = angle_index
    if labels is None:
        labels = [f"{coupling_label}·{lab}" for _, lab in kernels]
    fns = [fn for fn, _ in kernels]

    # JAX arrays reject plain list indices (and read tuples as multi-axis
    # indexing), so normalise any non-slice selection to an index array.
    if spatial_dims is None or isinstance(spatial_dims, slice):
        sp_index = spatial_dims
    else:
        sp_index = jnp.asarray(spatial_dims)

    # The isinstance guard matters: comparing a NumPy array with a string
    # yields an element-wise array whose truth value is ambiguous.
    use_extras_box = isinstance(box, str) and box == "extras"
    static_box = None if (box is None or use_extras_box) else jnp.asarray(box)

    def _pair_local(XK, *, extras=None):
        xi, xj = XK[0], XK[1]

        # distance
        b = static_box
        if b is None and use_extras_box and extras is not None:
            b = jnp.asarray(extras["box"])
            if sp_index is not None:
                # A scalar box applies to every axis and needs no slicing.
                b = b[sp_index] if b.ndim > 0 else b
        _, r, _ = _pairwise_dr(XK, box=b, spatial_dims=sp_index)

        # angle coupling
        dth = wrap_angle(xj[angle_index] - xi[angle_index])
        g = coupling_fn(dth)
        vals = jnp.stack([fn(r) * g for fn in fns])  # (F,)

        # embed along output_index; every other row stays zero
        out = jnp.zeros((dim, n_feat), dtype=XK.dtype)
        out = out.at[output_index, :].set(vals)
        return out  # (dim, F)

    extras_keys = ("box",) if use_extras_box else ()
    return make_interactor(
        _pair_local,
        dim=dim,
        rank=1,
        K=2,
        n_features=n_feat,
        extras_keys=extras_keys,
        labels=list(labels),
        descriptor="angular-pair-basis",
    )
# ═══════════════════════════════════════════════════════════════════════
# SINGLE-PARTICLE: HEADING VECTOR
# ═══════════════════════════════════════════════════════════════════════
[docs]
def heading_vector(dim: int, angle_index: int, *, spatial_axes: tuple[int, ...] | None = None) -> Basis:
    r"""Single-particle heading vector from an angle coordinate.

    Produces a rank-1 Basis with one feature: the unit vector
    :math:`(\cos\theta, \sin\theta)` written into a zero-initialised
    ``dim``-d vector, cosine at ``spatial_axes[0]`` and sine at
    ``spatial_axes[1]``.

    Parameters
    ----------
    dim : int
        State-vector dimension.
    angle_index : int
        Index of the angle coordinate :math:`\theta`.
    spatial_axes : (int, int), optional
        Indices for (cos θ, sin θ) in the output.
        Default: ``(0, 1)`` — i.e. the first two axes.

    Returns
    -------
    Basis
        Rank-1, 1-feature heading-vector basis.
    """
    axes = (0, 1) if spatial_axes is None else spatial_axes
    cos_slot, sin_slot = axes[0], axes[1]

    def _f(x, *, mask=None):
        theta = x[angle_index]
        # Cosine is written first, then sine, into an all-zero vector.
        vec = (
            jnp.zeros(dim, dtype=x.dtype)
            .at[cos_slot].set(jnp.cos(theta))
            .at[sin_slot].set(jnp.sin(theta))
        )
        return vec[:, None]  # (dim, 1) — rank-1, 1 feature

    return make_basis(
        _f,
        dim=dim,
        rank=1,
        n_features=1,
        labels=("e_heading",),
        descriptor="heading-vector",
    )
# ═══════════════════════════════════════════════════════════════════════
# TENSOR PAIR FEATURES
# ═══════════════════════════════════════════════════════════════════════
[docs]
def dyadic_pair_basis(
    kernels: Sequence[tuple[Callable, str]],
    *,
    dim: int,
    box: Any = None,
    spatial_dims: slice | Sequence[int] | None = None,
    labels: Sequence[str] | None = None,
) -> Interactor:
    r"""Build a rank-2 (tensor) pair Interactor: :math:`\phi(r)\,\hat{r}\otimes\hat{r}`.

    Each feature is the outer product of the unit displacement vector
    with itself, weighted by a radial kernel.  Useful for directional
    diffusion tensors, nematic order parameters, etc.

    Parameters
    ----------
    kernels, dim, box, spatial_dims, labels
        Same as :func:`radial_pair_basis`.  ``spatial_dims`` may be a
        slice, an index array, or a plain list/tuple of indices.

    Returns
    -------
    Interactor
        Rank-2 (matrix) interactor with ``K=2``.
    """
    n_feat = len(kernels)
    if labels is None:
        labels = [f"rr·{lab}" for _, lab in kernels]
    fns = [fn for fn, _ in kernels]

    # JAX arrays reject plain list indices (and read tuples as multi-axis
    # indexing), so normalise any non-slice selection to an index array.
    if spatial_dims is None or isinstance(spatial_dims, slice):
        sp_index = spatial_dims
    else:
        sp_index = jnp.asarray(spatial_dims)

    # The isinstance guard matters: comparing a NumPy array with a string
    # yields an element-wise array whose truth value is ambiguous.
    use_extras_box = isinstance(box, str) and box == "extras"
    static_box = None if (box is None or use_extras_box) else jnp.asarray(box)

    def _pair_local(XK, *, extras=None):
        b = static_box
        if b is None and use_extras_box and extras is not None:
            b = jnp.asarray(extras["box"])
            if sp_index is not None:
                # A scalar box applies to every axis and needs no slicing.
                b = b[sp_index] if b.ndim > 0 else b
        _, r, rhat = _pairwise_dr(XK, box=b, spatial_dims=sp_index)
        vals = jnp.stack([fn(r) for fn in fns])  # (F,)
        # rhat ⊗ rhat: (d, d) then weight by each kernel → (d, d, F)
        rr = jnp.outer(rhat, rhat)  # (d, d)
        return rr[:, :, None] * vals[None, None, :]  # (d, d, F)

    extras_keys = ("box",) if use_extras_box else ()
    return make_interactor(
        _pair_local,
        dim=dim,
        rank=2,
        K=2,
        n_features=n_feat,
        extras_keys=extras_keys,
        labels=list(labels),
        descriptor="dyadic-pair-basis",
    )
# ═══════════════════════════════════════════════════════════════════════
# COMPOSABLE GEOMETRIC PRIMITIVES
# ═══════════════════════════════════════════════════════════════════════
#
# Single-feature Interactors designed for composition via ``*``
# (element-wise spatial multiplication with feature Cartesian product).
# Combine a direction (rank-1), a scalar gate (rank-0), and a
# parametric radial kernel (rank-0) to build rich pair forces.
# ═══════════════════════════════════════════════════════════════════════
[docs]
def pair_direction(
    *,
    dim: int,
    box: Any = None,
    spatial_dims: slice | Sequence[int] | None = None,
    embed_dim: int | None = None,
    embed_axes: Sequence[int] | None = None,
) -> Interactor:
    r"""Unit displacement vector :math:`\hat{r}_{ij}` as a rank-1 Interactor.

    Returns a single-feature, rank-1 pair Interactor whose output is the
    unit vector pointing from particle *i* to particle *j*.

    Parameters
    ----------
    dim : int
        Full state-vector dimension per particle.
    box : array, ``"extras"``, or None
        PBC box lengths (same semantics as :func:`radial_pair_basis`).
    spatial_dims : slice or index array, optional
        Which axes are spatial coordinates.  Plain lists/tuples of
        indices are accepted.
    embed_dim : int, optional
        Embed into a larger output vector (e.g. 2-D displacement in 3-D state).
    embed_axes : sequence of int, optional
        Indices for embedding (required when ``embed_dim`` is set).  Must
        contain one index per selected spatial axis.

    Returns
    -------
    Interactor
        Rank-1, 1-feature interactor with ``K=2``.

    Raises
    ------
    ValueError
        If ``embed_dim`` is set without ``embed_axes``, or if the number of
        ``embed_axes`` differs from the number of selected spatial axes.
    """
    # JAX arrays reject plain list indices (and read tuples as multi-axis
    # indexing), so normalise any non-slice selection to an index array.
    if spatial_dims is None or isinstance(spatial_dims, slice):
        sp_index = spatial_dims
    else:
        sp_index = jnp.asarray(spatial_dims)

    # Number of spatial components selected from the state vector.
    if spatial_dims is None:
        n_spatial = dim
    elif isinstance(spatial_dims, slice):
        n_spatial = len(range(*spatial_dims.indices(dim)))
    else:
        n_spatial = len(spatial_dims)

    if embed_dim is not None:
        if embed_axes is None:
            raise ValueError("embed_axes required when embed_dim is set")
        # Without this check a mismatch would mis-embed silently, because
        # JAX clamps out-of-range reads instead of raising.
        if len(embed_axes) != n_spatial:
            raise ValueError(
                f"len(embed_axes)={len(embed_axes)} must equal the number of spatial "
                f"dimensions selected by spatial_dims."
            )
        out_d = embed_dim
    else:
        out_d = n_spatial

    # The isinstance guard matters: comparing a NumPy array with a string
    # yields an element-wise array whose truth value is ambiguous.
    use_extras_box = isinstance(box, str) and box == "extras"
    static_box = None if (box is None or use_extras_box) else jnp.asarray(box)

    def _pair_local(XK, *, extras=None):
        b = static_box
        if b is None and use_extras_box and extras is not None:
            b = jnp.asarray(extras["box"])
            if sp_index is not None:
                # A scalar box applies to every axis and needs no slicing.
                b = b[sp_index] if b.ndim > 0 else b
        _, _, rhat = _pairwise_dr(XK, box=b, spatial_dims=sp_index)
        if embed_dim is not None:
            # Scatter the spatial components; all other entries stay zero.
            full = jnp.zeros(out_d, dtype=XK.dtype)
            for k, ax in enumerate(embed_axes):
                full = full.at[ax].set(rhat[k])
            return full[:, None]  # (out_d, 1)
        return rhat[:, None]  # (d_spatial, 1)

    extras_keys = ("box",) if use_extras_box else ()
    return make_interactor(
        _pair_local,
        dim=dim,
        rank=1,
        K=2,
        n_features=1,
        extras_keys=extras_keys,
        labels=("r̂_ij",),
        descriptor="pair-direction",
    )
[docs]
def angle_coupling(
    coupling_fn: Callable,
    *,
    dim: int,
    angle_index: int,
    output_index: int | None = None,
    label: str = "g",
) -> Interactor:
    r"""Scalar orientation coupling embedded as a rank-1 Interactor.

    Evaluates ``coupling_fn`` on the angle difference ``θ_j − θ_i`` (after
    :func:`wrap_angle`) and writes the value into row ``output_index`` of
    an otherwise zero ``dim``-d output vector.

    Parameters
    ----------
    coupling_fn : callable
        Scalar function of the wrapped angle difference, e.g. ``jnp.sin``.
    dim : int
        Full state-vector dimension.
    angle_index : int
        Index of the angle coordinate in the state vector.
    output_index : int, optional
        Output axis for the coupling value.  Defaults to ``angle_index``.
    label : str
        Short label used in feature names.

    Returns
    -------
    Interactor
        Rank-1, 1-feature interactor with ``K=2``.
    """
    slot = angle_index if output_index is None else output_index

    def _pair_local(XK, *, extras=None):
        theta_i = XK[0][angle_index]
        theta_j = XK[1][angle_index]
        coupled = coupling_fn(wrap_angle(theta_j - theta_i))
        vec = jnp.zeros(dim, dtype=XK.dtype).at[slot].set(coupled)
        return vec[:, None]  # (dim, 1)

    return make_interactor(
        _pair_local,
        dim=dim,
        rank=1,
        K=2,
        n_features=1,
        labels=(label,),
        descriptor="angle-coupling",
    )
[docs]
def particle_heading(
    which: int,
    *,
    dim: int,
    angle_index: int,
    spatial_axes: tuple[int, ...] | None = None,
) -> Interactor:
    r"""Heading vector of one particle in a pair, as a rank-1 Interactor.

    Produces :math:`(\cos\theta, \sin\theta)` for the chosen member of the
    pair (``which=0`` for the focal particle, ``which=1`` for the
    neighbor), written into an otherwise zero ``dim``-d output vector.

    Parameters
    ----------
    which : int
        ``0`` for the focal particle, ``1`` for the neighbor.
    dim : int
        Full state-vector dimension.
    angle_index : int
        Index of the angle coordinate.
    spatial_axes : (int, int), optional
        Indices for (cos θ, sin θ) in the output.  Default: ``(0, 1)``.

    Returns
    -------
    Interactor
        Rank-1, 1-feature interactor with ``K=2``.
    """
    axes = (0, 1) if spatial_axes is None else spatial_axes
    cos_slot, sin_slot = axes[0], axes[1]

    def _pair_local(XK, *, extras=None):
        theta = XK[which][angle_index]
        # Cosine is written first, then sine, into an all-zero vector.
        vec = (
            jnp.zeros(dim, dtype=XK.dtype)
            .at[cos_slot].set(jnp.cos(theta))
            .at[sin_slot].set(jnp.sin(theta))
        )
        return vec[:, None]  # (dim, 1)

    return make_interactor(
        _pair_local,
        dim=dim,
        rank=1,
        K=2,
        n_features=1,
        labels=(f"ê_θ[{which}]",),
        descriptor="particle-heading",
    )
[docs]
def vision_gate(
    gate_fn: Callable,
    *,
    dim: int,
    angle_index: int,
    box: Any = None,
    spatial_dims: slice | Sequence[int] | None = None,
) -> Interactor:
    r"""Scalar vision-cone gate as a rank-0 Interactor.

    Computes the bearing angle :math:`\delta` from the focal particle's
    heading to the displacement toward the neighbor, then returns
    ``gate_fn(δ)``.  This makes interactions *nonreciprocal*: the gate
    value depends on whether the neighbor is "in front" or "behind" the
    focal particle.

    Parameters
    ----------
    gate_fn : callable
        Scalar function of the bearing angle, e.g.
        ``lambda d: (1 + jnp.cos(d)) / 2`` for a cosine vision cone.
    dim : int
        Full state-vector dimension.
    angle_index : int
        Index of the angle coordinate.
    box : array, ``"extras"``, or None
        PBC box (same semantics as :func:`radial_pair_basis`).
    spatial_dims : slice or index array, optional
        Which axes are spatial coordinates (for the displacement).
        Must select exactly 2 dimensions — bearing angle is 2-D only.
        Plain lists/tuples of indices are accepted.

    Returns
    -------
    Interactor
        Rank-0, 1-feature interactor with ``K=2``.

    Raises
    ------
    ValueError
        At evaluation time, if the displacement is not 2-dimensional.
    """
    # JAX arrays reject plain list indices (and read tuples as multi-axis
    # indexing), so normalise any non-slice selection to an index array.
    if spatial_dims is None or isinstance(spatial_dims, slice):
        sp_index = spatial_dims
    else:
        sp_index = jnp.asarray(spatial_dims)

    # The isinstance guard matters: comparing a NumPy array with a string
    # yields an element-wise array whose truth value is ambiguous.
    use_extras_box = isinstance(box, str) and box == "extras"
    static_box = None if (box is None or use_extras_box) else jnp.asarray(box)

    def _pair_local(XK, *, extras=None):
        xi = XK[0]

        # Spatial displacement
        b = static_box
        if b is None and use_extras_box and extras is not None:
            b = jnp.asarray(extras["box"])
            if sp_index is not None:
                # A scalar box applies to every axis and needs no slicing.
                b = b[sp_index] if b.ndim > 0 else b
        dx, _, _ = _pairwise_dr(XK, box=b, spatial_dims=sp_index)
        if dx.shape[0] != 2:
            raise ValueError(
                f"vision_gate requires exactly 2 spatial dimensions, got {dx.shape[0]}. "
                "Use spatial_dims to select 2 axes from a higher-dimensional state vector."
            )

        # Bearing angle: direction of neighbour relative to heading
        phi_ij = jnp.arctan2(dx[1], dx[0])
        theta_i = xi[angle_index]
        delta = wrap_angle(phi_ij - theta_i)
        return gate_fn(delta)[None]  # (1,) — rank-0, 1 feature

    extras_keys = ("box",) if use_extras_box else ()
    return make_interactor(
        _pair_local,
        dim=dim,
        rank=0,
        K=2,
        n_features=1,
        extras_keys=extras_keys,
        labels=("vision",),
        descriptor="vision-gate",
    )
[docs]
def parametric_radial_kernel(
    kernel_fn: Callable,
    *,
    params: dict,
    dim: int,
    box: Any = None,
    spatial_dims: slice | Sequence[int] | None = None,
) -> Interactor:
    r"""Parametric scalar radial kernel as a rank-0 Interactor.

    Wraps a user-supplied function ``kernel_fn(r, params)`` into a
    rank-0 pair Interactor with learnable parameters.

    Parameters
    ----------
    kernel_fn : callable
        ``kernel_fn(r, params) -> scalar`` where *r* is the inter-particle
        distance and *params* is a dict of JAX arrays.
    params : dict
        Parameter specification passed to :func:`make_interactor`, e.g.
        ``{"eps": (), "R0": ()}`` for two scalar parameters.
    dim : int
        Full state-vector dimension.
    box : array, ``"extras"``, or None
        PBC box (same semantics as :func:`radial_pair_basis`).
    spatial_dims : slice or index array, optional
        Which axes are spatial coordinates.  Plain lists/tuples of
        indices are accepted.

    Returns
    -------
    Interactor
        Rank-0, 1-feature parametric interactor with ``K=2``.
    """
    # JAX arrays reject plain list indices (and read tuples as multi-axis
    # indexing), so normalise any non-slice selection to an index array.
    if spatial_dims is None or isinstance(spatial_dims, slice):
        sp_index = spatial_dims
    else:
        sp_index = jnp.asarray(spatial_dims)

    # The isinstance guard matters: comparing a NumPy array with a string
    # yields an element-wise array whose truth value is ambiguous.
    use_extras_box = isinstance(box, str) and box == "extras"
    static_box = None if (box is None or use_extras_box) else jnp.asarray(box)

    def _pair_local(XK, *, params=None, extras=None):
        b = static_box
        if b is None and use_extras_box and extras is not None:
            b = jnp.asarray(extras["box"])
            if sp_index is not None:
                # A scalar box applies to every axis and needs no slicing.
                b = b[sp_index] if b.ndim > 0 else b
        _, r, _ = _pairwise_dr(XK, box=b, spatial_dims=sp_index)
        return kernel_fn(r, params)[None]  # (1,) — rank-0, 1 feature

    extras_keys = ("box",) if use_extras_box else ()
    return make_interactor(
        _pair_local,
        dim=dim,
        rank=0,
        K=2,
        n_features=1,
        params=params,
        extras_keys=extras_keys,
        labels=("k(r)",),
        descriptor="parametric-radial-kernel",
    )
# ═══════════════════════════════════════════════════════════════════════
# VELOCITY-DEPENDENT PAIR BASES
# ═══════════════════════════════════════════════════════════════════════
[docs]
def pair_velocity_difference(
    *,
    dim: int,
) -> Interactor:
    r"""Velocity difference :math:`\mathbf{v}_j - \mathbf{v}_i` as a rank-1 Interactor.

    Builds a single-feature, rank-1 pair Interactor that outputs the
    neighbor's velocity minus the focal particle's velocity.  Intended to
    be combined with scalar pair Interactors through the ``*`` operator,
    e.g.::

        scalar_pair_basis(kernels, dim=d) * pair_velocity_difference(dim=d)

    Parameters
    ----------
    dim : int
        State-vector dimension per particle.

    Returns
    -------
    Interactor
        Rank-1, 1-feature interactor with ``K=2``, ``needs_v=True``.
    """

    def _pair_local(XK, *, v, extras=None):
        v_focal, v_neighbor = v[0], v[1]
        relative = v_neighbor - v_focal  # (dim,)
        return relative[:, None]  # (dim, 1)

    return make_interactor(
        _pair_local,
        dim=dim,
        rank=1,
        K=2,
        n_features=1,
        needs_v=True,
        labels=("Δv_ij",),
        descriptor="pair-velocity-difference",
    )
[docs]
def particle_velocity(
    which: int,
    *,
    dim: int,
) -> Interactor:
    r"""Velocity of one particle in a pair, as a rank-1 Interactor.

    Outputs the velocity of the selected pair member: the focal particle
    for ``which=0``, the neighbor for ``which=1``.  Intended to be
    combined with scalar pair Interactors through the ``*`` operator, e.g.::

        scalar_pair_basis(kernels, dim=d) * particle_velocity(which=1, dim=d)

    Parameters
    ----------
    which : int
        ``0`` for the focal particle's velocity, ``1`` for the neighbor's.
    dim : int
        State-vector dimension per particle.

    Returns
    -------
    Interactor
        Rank-1, 1-feature interactor with ``K=2``, ``needs_v=True``.
    """

    def _pair_local(XK, *, v, extras=None):
        selected = v[which]  # (dim,)
        return selected[:, None]  # (dim, 1)

    return make_interactor(
        _pair_local,
        dim=dim,
        rank=1,
        K=2,
        n_features=1,
        needs_v=True,
        labels=(f"v[{which}]",),
        descriptor="particle-velocity",
    )