Source code for SFI.utils.plotting

from datetime import datetime

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize
from scipy.stats import pearsonr

from SFI.trajectory.collection import TrajectoryCollection
from SFI.trajectory.dataset import TrajectoryDataset

""" Utility classes for plotting the results of SFI, used only to
display the results of the example files.

"""

# ---------------------------------------------------------------------------
# Gallery colour palette
# ---------------------------------------------------------------------------

#: Named colour palette for consistent plotting across gallery examples
#: and user notebooks.  Adjusted for readability on both dark (#131416)
#: and light backgrounds.
SFI_COLORS = dict(
    data="#3B9EFF",  # bright blue
    inferred="#FFC20A",  # gold
    exact="#FF7A1A",  # bright orange
    bootstrap="#5D3A9B",  # purple
    highlight="#40B0A6",  # teal
    error="#FF2D6F",  # bright pink
    secondary="#1A85FF",  # lighter blue
    tertiary="#D35FB7",  # pink/magenta
)


# ---------------------------------------------------------------------------
# Figure timestamp
# ---------------------------------------------------------------------------


[docs] def stamp_output(): """Print a discreet generation timestamp to stdout. Call once per script so the timestamp appears in terminal output blocks on gallery pages and in notebook cells. """ print(f"[Generated: {datetime.now().strftime('%Y-%m-%d %H:%M')}]")
[docs] def stamp_fig(fig=None): """Add a discreet generation timestamp to the bottom-right of a figure. Useful for tracking when a cached or gallery figure was last rendered. Idempotent: a figure is only stamped once. Parameters ---------- fig : matplotlib.figure.Figure, optional Figure to stamp. Defaults to ``plt.gcf()``. """ if fig is None: fig = plt.gcf() if getattr(fig, "_sfi_stamped", False): return fig._sfi_stamped = True now = datetime.now().strftime("%Y-%m-%d %H:%M") fig.text( 0.99, 0.005, now, fontsize=5, color="#808080", alpha=0.4, ha="right", va="bottom", transform=fig.transFigure, )
# --------------------------------------------------------------------------- # Dark-theme figure helpers (reusable across gallery demos & notebooks) # ---------------------------------------------------------------------------
[docs] def dark_ax(ax): """Style a matplotlib Axes for a dark background. Sets black face colour, white ticks and labels, and dark-grey spines. """ ax.set_facecolor("black") ax.tick_params(colors="white", which="both") for spine in ax.spines.values(): spine.set_edgecolor("0.3") ax.xaxis.label.set_color("white") ax.yaxis.label.set_color("white") ax.title.set_color("white")
[docs] def dark_fig(nrows=1, ncols=1, **kw): """Create a figure and axes with a black background. All arguments are forwarded to ``plt.subplots``. Returns ``(fig, axes)`` like ``plt.subplots``. """ fig, axes = plt.subplots(nrows, ncols, **kw) fig.patch.set_facecolor("black") for ax in np.atleast_1d(axes).flat if hasattr(axes, "flat") else [axes]: dark_ax(ax) return fig, axes
[docs] def wrap_positions(X, box): """Wrap positions into a periodic box ``[0, box_i)`` per axis. Parameters ---------- X : array_like, shape (..., d) Position array. box : array_like, shape (d,) or (2,) Box dimensions. Returns ------- X_wrapped : ndarray Copy of *X* with each column wrapped modulo the corresponding box dimension. """ X = np.array(X, copy=True, dtype=float) box = np.asarray(box) for i in range(len(box)): X[..., i] = X[..., i] % box[i] return X
def _equal_aspect(ax): """Set an equal aspect ratio, choosing an adjustable that is legal here. ``adjustable="datalim"`` (and ``axis("equal")``) raise at draw time when the axes shares its x or y with a sibling (e.g. ``sharex=True`` subplots), so fall back to ``adjustable="box"`` in that case. """ shared = ( len(ax.get_shared_x_axes().get_siblings(ax)) > 1 or len(ax.get_shared_y_axes().get_siblings(ax)) > 1 ) ax.set_aspect("equal", adjustable="box" if shared else "datalim") # --------------------------------------------------------------------------- # Sparse-model diagnostics # ---------------------------------------------------------------------------
[docs] def plot_pareto_front(result, *, criteria=("PASTIS", "BIC", "AIC"), ax=None): """Plot the Pareto front: information gain vs model size, with IC optima. Parameters ---------- result : SparsityResult As returned by ``inf.sparsify_force()``. criteria : tuple of str Information-criterion names to mark on the plot. ax : matplotlib Axes, optional If *None*, a new figure is created. Returns ------- ax : matplotlib Axes """ if ax is None: _, ax = plt.subplots() infos = result.best_info_by_k ks = list(range(len(infos))) valid = [(k, float(info)) for k, info in zip(ks, infos) if float(info) > -1e30] if valid: ks_v, infos_v = zip(*valid) else: ks_v, infos_v = [], [] ax.plot(ks_v, infos_v, ".-", color="#B0B0B0", lw=1.5, label="Pareto front") colors = [SFI_COLORS["inferred"], SFI_COLORS["exact"], SFI_COLORS["highlight"]] for ic_name, c in zip(criteria, colors): try: k_sel, _, score, _ = result.select_by_ic(ic_name, p_param=1e-3) info_at_k = float(infos[k_sel]) if k_sel < len(infos) else None if info_at_k is not None and info_at_k > -1e30: ax.axvline(k_sel, color=c, ls="--", alpha=0.7, label=f"{ic_name} (k={k_sel})") ax.plot(k_sel, info_at_k, "o", color=c, ms=8, zorder=5) except Exception: # ic_name absent or result schema mismatch pass ax.set_xlabel("Model size k") ax.set_ylabel("Information gain") ax.legend(fontsize=8) ax.set_title("Sparse model selection") return ax
[docs] def plot_recovery_bar( coeffs_inferred, support_inferred, *, coeffs_true=None, support_true=None, labels=None, stderr=None, yscale: str = "linear", sort: bool = False, show_pruned: bool = False, ax=None, ): """Bar chart of inferred sparse coefficients vs ground truth. Parameters ---------- coeffs_inferred : array_like Coefficient values for the inferred support. support_inferred : array_like Indices of the selected basis functions. coeffs_true, support_true : array_like, optional Ground-truth coefficients / support (paired comparison). labels : list of str, optional Tick labels for basis functions. stderr : array_like, optional Standard errors for the inferred coefficients (drawn as error caps). yscale : str Matplotlib y-scale (``"linear"`` or ``"log"``; use ``"log"`` for magnitude bars with widely varying scales). sort : bool If True, order bars by descending ``|coefficient|``. show_pruned : bool If True (and ``labels`` given), append faded zero-bars for basis functions outside the inferred support. ax : matplotlib Axes, optional If *None*, a new figure is created. Returns ------- ax : matplotlib Axes """ if ax is None: _, ax = plt.subplots() coeffs_inferred = np.asarray(coeffs_inferred, dtype=float) support_inferred = list(support_inferred) stderr = np.asarray(stderr, dtype=float) if stderr is not None else None if sort: order = np.argsort(-np.abs(coeffs_inferred)) coeffs_inferred = coeffs_inferred[order] support_inferred = [support_inferred[k] for k in order] if stderr is not None: stderr = stderr[order] n = len(support_inferred) x = np.arange(n) bar_width = min(0.6, 6.0 / n) if n >= 5 else 0.8 paired = coeffs_true is not None and support_true is not None offset = bar_width / 2 if paired else 0.0 ax.bar( x - offset, coeffs_inferred, width=bar_width, color=SFI_COLORS["inferred"], alpha=0.8, yerr=stderr, capsize=4 if stderr is not None else 0, ecolor="#B0B0B0", label="inferred", ) if paired: support_true_l = list(support_true) true_mapped = np.zeros(n) for i, s in enumerate(support_inferred): if s in support_true_l: true_mapped[i] = float(np.asarray(coeffs_true)[support_true_l.index(s)]) ax.bar( x + offset, true_mapped, width=bar_width, fill=False, edgecolor=SFI_COLORS["exact"], lw=2, label="exact", ) extra_ticks, extra_labels = [], [] if show_pruned and labels is not None: pruned = [k for k in range(len(labels)) if k not in set(support_inferred)] for off, s in enumerate(pruned): xp = n + off ax.bar(xp, 0.0, width=bar_width, color="#808080", alpha=0.3) extra_ticks.append(xp) extra_labels.append(labels[s]) if labels is not None: tick_labels = [labels[s] if s < len(labels) else str(s) for s in support_inferred] + extra_labels ax.set_xticks(list(x) + extra_ticks) ax.set_xticklabels(tick_labels, rotation=45, ha="right", fontsize=8 if n < 8 else 7) else: ax.set_xticks(x) ax.set_xticklabels([str(s) for s in support_inferred]) if n < 5 and not extra_ticks: ax.set_xlim(-0.8, n - 0.2) ax.set_yscale(yscale) ax.set_ylabel("Coefficient value") ax.legend() ax.set_title("Sparse model coefficients") return ax
[docs] def plot_recovery_bar_multi(coeffs_list, labels, *, coeffs_true=None, group_names=None, ax=None): """Grouped bar chart comparing coefficients across several models. ``coeffs_list`` is a list of coefficient vectors (one per regime / solver), each aligned to ``labels``. An optional ``coeffs_true`` is overlaid as a step reference. """ if ax is None: _, ax = plt.subplots() coeffs_list = [np.asarray(c, dtype=float) for c in coeffs_list] G = len(coeffs_list) m = len(labels) x = np.arange(m) group_names = group_names if group_names is not None else [f"model {i + 1}" for i in range(G)] width = 0.8 / max(G, 1) palette = list(SFI_COLORS.values()) for gi, (c, name) in enumerate(zip(coeffs_list, group_names)): ax.bar(x + (gi - (G - 1) / 2) * width, c, width=width, color=palette[gi % len(palette)], alpha=0.85, label=name) if coeffs_true is not None: ct = np.asarray(coeffs_true, dtype=float) ax.step(np.concatenate([x - 0.5, [x[-1] + 0.5]]), np.concatenate([ct, ct[-1:]]), where="post", color=SFI_COLORS["exact"], lw=1.5, label="exact") ax.set_xticks(x) ax.set_xticklabels(list(labels), rotation=45, ha="right", fontsize=8 if m < 8 else 7) ax.set_ylabel("Coefficient value") ax.axhline(0, color="#808080", lw=0.5) ax.legend(fontsize=8) return ax
[docs] def plot_recovery_matrix(true, inferred, *, row_labels=None, col_labels=None, cmap="RdBu_r", vmax=None, axes=None): """Side-by-side ``imshow`` of a true vs inferred parameter matrix.""" true = np.asarray(true, dtype=float) inferred = np.asarray(inferred, dtype=float) if axes is None: _, axes = plt.subplots(1, 2, figsize=(8, 4)) if vmax is None: vmax = float(max(np.abs(true).max(), np.abs(inferred).max())) im = None for ax, mat, title in zip(axes, [true, inferred], ["True", "Inferred"]): im = ax.imshow(mat, cmap=cmap, vmin=-vmax, vmax=vmax, aspect="auto") ax.set_title(title) if col_labels is not None: ax.set_xticks(range(mat.shape[1])) ax.set_xticklabels(col_labels, rotation=45, ha="right", fontsize=8) if row_labels is not None: ax.set_yticks(range(mat.shape[0])) ax.set_yticklabels(row_labels, fontsize=8) if im is not None: plt.colorbar(im, ax=list(axes), shrink=0.8) return axes
def _collection_to_arrays( coll: TrajectoryCollection, *, dataset: int = 0, ): """ Extract (t, X, mask) from a TrajectoryCollection as NumPy arrays. Parameters ---------- coll : TrajectoryCollection object. dataset : Dataset index inside the collection. Returns ------- t : ndarray, shape (T,) Absolute times. X : ndarray, shape (T, N, d) State array. mask : ndarray, shape (T, N) Boolean validity mask. """ if not isinstance(coll, TrajectoryCollection): raise TypeError(f"Expected TrajectoryCollection, got {type(coll)!r}") if not coll.datasets: raise ValueError("Empty TrajectoryCollection") if not (0 <= dataset < len(coll.datasets)): raise IndexError(f"dataset index {dataset} out of range for D={len(coll.datasets)}") ds: TrajectoryDataset = coll.datasets[dataset] # Positions, always as (T, N, d) X = np.asarray(ds._X3d()) # Mask try: M = np.asarray(ds._M2d()) except Exception: M = np.ones(X.shape[:2], dtype=bool) T = X.shape[0] # Time axis if ds.t is not None: t = np.asarray(ds.t) else: dt = ds.dt if dt is None: t = np.arange(T, dtype=float) else: dt_arr = np.asarray(dt) if dt_arr.ndim == 0: t = np.arange(T, dtype=float) * float(dt_arr) elif dt_arr.ndim == 1: if T == 0: t = np.zeros((0,), dtype=float) else: # t[0] = 0; t[1:] = cumsum(dt[:-1]) t = np.concatenate([[0.0], np.cumsum(dt_arr[:-1])]) else: raise ValueError("dt must be scalar or (T,) to build a time axis.") return t, X, M
[docs] def axisvector(index, dim): """d-dimensional unit vector pointing in direction `index`.""" e = np.zeros(dim, dtype=float) e[index] = 1.0 return e
[docs] def comparison_scatter( Xexact, Xinferred, error=None, maxpoints=10000, vmax=None, color=None, alpha=0.05, y=0.8, mode="both", fontsize=9, ): """This method is used to compare inferred components to the exact ones along the trajectory, in a graphical way. Xexact, Xinferred: jnp arrays (Nsteps,...); must have the same shape. error: predicted standard deviation for X_inferred. maxpoints: if Nsteps / 2 * maxpoints, data will be subsampled. """ subsample = max(1, Xexact.shape[0] // maxpoints) # Flatten the data: Xe = np.array(Xexact)[::subsample].reshape(-1) Xi = np.array(Xinferred)[::subsample].reshape(-1) MSE = sum((Xe - Xi) ** 2) / sum(Xe**2 + Xi**2) if vmax is None: vmax = max(abs(Xe).max(), abs(Xi).max()) plt.scatter(Xe, Xi, alpha=alpha, linewidth=0, c=color) if error is not None: xvals = np.array([-vmax, vmax]) confidence_interval = 2 * error**0.5 * Xi.std() plt.plot(xvals, xvals + confidence_interval, ls=":", color="#808080") plt.plot(xvals, xvals - confidence_interval, ls=":", color="#808080") (r, p) = pearsonr(Xe, Xi) plt.plot([-1e10, 1e10], [-1e10, 1e10], ls="-", color="#808080") plt.grid(True) _equal_aspect(plt.gca()) plt.xlabel("exact") plt.ylabel("inferred") titlestring = "" if mode == "r" or mode == "both": titlestring += r"$r=" + str(round(r, 2 if r < 0.98 else 3 if r < 0.999 else 4 if r < 0.9999 else 5)) + "$" if mode == "both": titlestring += "\n" if mode == "MSE" or mode == "both": titlestring += "MSE=" + str(round(MSE, 3)) plt.title(titlestring, loc="left", y=y, x=0.05, fontsize=fontsize) plt.xticks([0.0]) plt.yticks([0.0]) plt.xlim(-vmax, vmax) plt.ylim(-vmax, vmax)
[docs] def timeseries( coll: TrajectoryCollection, *, dims=None, dataset: int = 0, particles=None, transform=None, ax=None, **plot_kw, ): """ Plot x[dim](t) for one or many particles from a TrajectoryCollection. By default, plot all dimensions for the selected particles. Parameters ---------- coll : TrajectoryCollection containing the data. dims : Iterable of state dimensions to plot. If None, plot all dims. dataset : Dataset index inside the collection (default 0). particles : Iterable of particle indices to include. If None, include all particles. transform : Optional callable applied elementwise to the plotted values (e.g. ``np.exp`` to map a log-space coordinate back to population space). ax : Optional Matplotlib Axes. If None, use ``plt.gca()``. **plot_kw : Forwarded to ``ax.plot``. Returns ------- ax : Matplotlib Axes used. """ if ax is None: ax = plt.gca() t, X, M = _collection_to_arrays(coll, dataset=dataset) # X: (T, N, D) T, N, D = X.shape # Select particles if particles is None: Xp = X Mp = M Nsel = N else: idx = np.asarray(particles, dtype=int) Xp = X[:, idx, :] Mp = M[:, idx] Nsel = Xp.shape[1] # Select dims if dims is None: dims = range(D) dims = list(dims) # Mask-aware plotting: one line per particle per dimension for n in range(Nsel): mask_n = Mp[:, n].astype(bool) for d in dims: vals = Xp[mask_n, n, d] if transform is not None: vals = transform(vals) ax.plot(t[mask_n], vals, **plot_kw) ax.set_xlabel("t") if len(dims) == 1: ax.set_ylabel(f"x[{dims[0]}]") else: ax.set_ylabel("x[d]") return ax
[docs] def phase2d( coll: TrajectoryCollection, *, dataset: int = 0, dims=None, dir1=None, dir2=None, shift=(0.0, 0.0), tmin=None, tmax=None, particles=None, cmap="viridis", linewidth: float = 1.5, alpha: float = 1.0, plot_colorbar: bool = False, box=None, transform=None, color=None, ax=None, drop_masked: bool = True, ) -> LineCollection: """ 2D phase-space plot with connected line segments colored along the trajectory. Parameters ---------- coll : TrajectoryCollection containing the data. dataset : Dataset index inside the collection. dims : Pair of coordinate indices (i, j) to plot. Ignored if ``dir1``/``dir2`` are provided. If None and no directions are given, defaults to (0, 1). dir1, dir2 : Optional projection directions in R^d. If given, positions are projected onto these directions instead of using coordinate axes. shift : (xshift, yshift) added to all positions. tmin, tmax : Integer time-index bounds. If None, full range is used. Negative ``tmax`` is interpreted as from the end. particles : Sequence of particle indices to include. If None, include all. cmap : Colormap for the time-coloring. linewidth : Line width of the trajectory segments. alpha : Global alpha for the line collection. plot_colorbar : If True, add a colorbar for the time coloring. box : Optional periodic box ``(Lx, Ly)``. If given, positions are wrapped modulo the box and segments that cross a boundary are dropped, so trajectories render correctly in a periodic domain. transform : Optional callable applied elementwise to the projected ``x`` and ``y`` (e.g. ``np.exp`` to map a log-space coordinate back to population space). Applied after projection/shift/box-wrap. color : Optional single color. If given, draw the trajectory in this solid color instead of the time-colored gradient (incompatible with ``plot_colorbar``). ax : Optional Matplotlib Axes. If None, use ``plt.gca()``. drop_masked : If True, drop segments where either endpoint is masked. Returns ------- lc : The created :class:`matplotlib.collections.LineCollection`. """ if ax is None: ax = plt.gca() t, X, M = _collection_to_arrays(coll, dataset=dataset) # X: (T, N, d) T, N, d = X.shape # Time window if tmin is None: tmin = 0 if tmax is None: tmax = T if tmax < 0: tmax = T + tmax tmin = max(0, int(tmin)) tmax = min(T, int(tmax)) if tmax <= tmin: raise ValueError("Empty time window in phase2d") X = X[tmin:tmax] # (T', N, d) M = M[tmin:tmax] # (T', N) t_slice = t[tmin:tmax] # (T',) # Particle selection if particles is None: X_sel = X M_sel = M else: idx = np.asarray(particles, dtype=int) X_sel = X[:, idx, :] M_sel = M[:, idx] # Projection if dir1 is not None or dir2 is not None: if dir1 is None or dir2 is None: raise ValueError("Either provide both dir1 and dir2, or neither.") dir1 = np.asarray(dir1, dtype=float) dir2 = np.asarray(dir2, dtype=float) if dir1.shape != (d,) or dir2.shape != (d,): raise ValueError(f"dir1/dir2 must have shape ({d},)") x = np.tensordot(X_sel, dir1, axes=([-1], [0])) # (T', N_sel) y = np.tensordot(X_sel, dir2, axes=([-1], [0])) # (T', N_sel) else: if dims is None: dims = (0, 1) i, j = dims if i >= d or j >= d: raise IndexError(f"dims {dims} out of range for state dimension d={d}") x = X_sel[..., i] y = X_sel[..., j] x = x + float(shift[0]) y = y + float(shift[1]) # Periodic wrap: fold into [0, L) and flag boundary-crossing segments. jump = None if box is not None: box = np.asarray(box, dtype=float) x = x % box[0] y = y % box[1] jump = (np.abs(np.diff(x, axis=0)) > 0.5 * box[0]) | ( np.abs(np.diff(y, axis=0)) > 0.5 * box[1] ) # (T'-1, N_sel) if transform is not None: x = transform(x) y = transform(y) # Build segments XY = np.stack([x, y], axis=-1) # (T', N_sel, 2) XY_start = XY[:-1] # (T'-1, N_sel, 2) XY_end = XY[1:] # (T'-1, N_sel, 2) seg_valid = M_sel[:-1] & M_sel[1:] # (T'-1, N_sel) if jump is not None: seg_valid = seg_valid & ~jump segs = np.stack([XY_start, XY_end], axis=2) # (T'-1, N_sel, 2, 2) segs_flat = segs.reshape(-1, 2, 2) valid_flat = seg_valid.reshape(-1) if drop_masked: segs_flat = segs_flat[valid_flat] if color is not None: if plot_colorbar: raise ValueError("phase2d: `color` and `plot_colorbar` are mutually exclusive.") lc = LineCollection(segs_flat, colors=color, linewidth=linewidth, alpha=alpha) ax.add_collection(lc) else: # Color by mid-time t_mid = 0.5 * (t_slice[:-1] + t_slice[1:]) # (T'-1,) t_mid_2d = np.broadcast_to(t_mid[:, None], seg_valid.shape) # (T'-1, N_sel) c_flat = t_mid_2d.reshape(-1) if drop_masked: c_flat = c_flat[valid_flat] norm = Normalize(vmin=float(c_flat.min()), vmax=float(c_flat.max())) lc = LineCollection(segs_flat, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha) lc.set_array(c_flat) ax.add_collection(lc) # Limits xy_valid = XY.reshape(-1, 2) if drop_masked: mask_flat = M_sel.reshape(-1) xy_valid = xy_valid[mask_flat] if xy_valid.size: ax.set_xlim(xy_valid[:, 0].min(), xy_valid[:, 0].max()) ax.set_ylim(xy_valid[:, 1].min(), xy_valid[:, 1].max()) _equal_aspect(ax) ax.set_xlabel("x") ax.set_ylabel("y") if plot_colorbar: plt.colorbar(lc, ax=ax, label="t") return lc
def _time_window(T, tmin, tmax): """Resolve integer (tmin, tmax) bounds, with negative tmax from the end.""" lo = 0 if tmin is None else max(0, int(tmin)) if tmax is None: hi = T elif tmax < 0: hi = T + int(tmax) else: hi = min(T, int(tmax)) if hi <= lo: raise ValueError("Empty time window.") return lo, hi
[docs] def phase2d_scalar( coll: TrajectoryCollection, *, color_fn, dataset: int = 0, dims=(0, 1), tmin=None, tmax=None, particles=None, cmap="plasma", linewidth: float = 1.5, alpha: float = 1.0, plot_colorbar: bool = True, colorbar_label: str = "", ax=None, drop_masked: bool = True, ) -> LineCollection: """2D phase-space plot colored by a scalar field of the coordinate. Like :func:`phase2d`, but each segment is colored by ``color_fn(midpoint)`` rather than by time — e.g. the local diffusivity ``D(x)``, the speed, or a potential. ``color_fn`` receives the **full d-dimensional** segment midpoints, shape ``(n_segments, d)``, and must return a scalar per segment, shape ``(n_segments,)``. """ if ax is None: ax = plt.gca() t, X, M = _collection_to_arrays(coll, dataset=dataset) T, N, d = X.shape lo, hi = _time_window(T, tmin, tmax) X, M = X[lo:hi], M[lo:hi] if particles is not None: idx = np.asarray(particles, dtype=int) X, M = X[:, idx], M[:, idx] i, j = dims XY = np.stack([X[..., i], X[..., j]], axis=-1) # (T', N_sel, 2) Xmid = 0.5 * (X[:-1] + X[1:]) # (T'-1, N_sel, d) seg_valid = (M[:-1] & M[1:]).reshape(-1) segs = np.stack([XY[:-1], XY[1:]], axis=2).reshape(-1, 2, 2) cvals = np.asarray(color_fn(Xmid.reshape(-1, d))).reshape(-1) if drop_masked: segs, cvals = segs[seg_valid], cvals[seg_valid] norm = Normalize(vmin=float(cvals.min()), vmax=float(cvals.max())) lc = LineCollection(segs, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha) lc.set_array(cvals) ax.add_collection(lc) xy = XY.reshape(-1, 2) if drop_masked: xy = xy[M.reshape(-1)] if xy.size: ax.set_xlim(xy[:, 0].min(), xy[:, 0].max()) ax.set_ylim(xy[:, 1].min(), xy[:, 1].max()) _equal_aspect(ax) ax.set_xlabel("x") ax.set_ylabel("y") if plot_colorbar: plt.colorbar(lc, ax=ax, label=colorbar_label) return lc
[docs] def timeseries_colored( coll: TrajectoryCollection, *, color_fn, dataset: int = 0, dims=None, particles=None, cmap="plasma", colorbar_label: str = "", plot_colorbar: bool = True, ax=None, s: float = 4.0, alpha: float = 1.0, rasterized: bool = True, **scatter_kw, ): """Plot ``x[dim](t)`` as a scatter colored by a scalar field. Mask-aware time series in which each point is colored by ``color_fn(X)`` — e.g. the local diffusivity along the trajectory. ``color_fn`` receives points of shape ``(n_points, d)`` and returns a scalar per point. A single shared colorbar spans all series. """ if ax is None: ax = plt.gca() t, X, M = _collection_to_arrays(coll, dataset=dataset) T, N, D = X.shape if particles is not None: idx = np.asarray(particles, dtype=int) X, M = X[:, idx], M[:, idx] Nsel = X.shape[1] dims = range(D) if dims is None else list(dims) cvals = np.asarray(color_fn(X.reshape(-1, D))).reshape(T, Nsel) norm = Normalize(vmin=float(np.nanmin(cvals)), vmax=float(np.nanmax(cvals))) sm = None for n in range(Nsel): mask_n = M[:, n].astype(bool) for dd in dims: sm = ax.scatter( t[mask_n], X[mask_n, n, dd], c=cvals[mask_n, n], cmap=cmap, norm=norm, s=s, alpha=alpha, rasterized=rasterized, **scatter_kw, ) ax.set_xlabel("t") ax.set_ylabel(f"x[{list(dims)[0]}]" if len(list(dims)) == 1 else "x[d]") if plot_colorbar and sm is not None: plt.colorbar(sm, ax=ax, label=colorbar_label) return ax
[docs] def phase3d( coll: TrajectoryCollection, *, dataset: int = 0, dims=(0, 1, 2), tmin=None, tmax=None, particles=None, cmap="viridis", linewidth: float = 1.5, alpha: float = 1.0, scatter_endpoints: bool = True, scatter_size: float = 30.0, ax=None, drop_masked: bool = True, ): """3D trajectory plot with segments colored along time. The 3D analog of :func:`phase2d`: draws each particle's path as a time-colored :class:`Line3DCollection`. Pass a 3D axes via ``ax`` or a new one is created (``projection="3d"``). """ from mpl_toolkits.mplot3d.art3d import Line3DCollection # noqa: WPS433 if ax is None: ax = plt.gcf().add_subplot(111, projection="3d") t, X, M = _collection_to_arrays(coll, dataset=dataset) T, N, d = X.shape lo, hi = _time_window(T, tmin, tmax) X, M, tt = X[lo:hi], M[lo:hi], t[lo:hi] if particles is not None: idx = np.asarray(particles, dtype=int) X, M = X[:, idx], M[:, idx] i, j, k = dims norm = Normalize(vmin=float(tt.min()), vmax=float(tt.max())) allpts = [] for n in range(X.shape[1]): m = M[:, n].astype(bool) pts = np.stack([X[m, n, i], X[m, n, j], X[m, n, k]], axis=-1) # (Tm, 3) if len(pts) < 2: continue allpts.append(pts) segs = np.stack([pts[:-1], pts[1:]], axis=1) # (Tm-1, 2, 3) t_mid = 0.5 * (tt[m][:-1] + tt[m][1:]) lc = Line3DCollection(segs, cmap=cmap, norm=norm, linewidth=linewidth, alpha=alpha) lc.set_array(t_mid) ax.add_collection3d(lc) if scatter_endpoints: ax.scatter(*pts[-1], s=scatter_size, color=plt.get_cmap(cmap)(1.0)) if allpts: P = np.concatenate(allpts, axis=0) ax.set_xlim(P[:, 0].min(), P[:, 0].max()) ax.set_ylim(P[:, 1].min(), P[:, 1].max()) ax.set_zlim(P[:, 2].min(), P[:, 2].max()) ax.set_xlabel("x") ax.set_ylabel("y") ax.set_zlabel("z") return ax
[docs] def trajectory_scatter( coll: TrajectoryCollection, *, dataset: int = 0, dims=(0, 1), particles=None, tmin=None, tmax=None, cmap=None, s: float = 2.0, alpha: float = 0.1, ax=None, drop_masked: bool = True, **scatter_kw, ): """All-frames density scatter cloud of a 2D projection. Unlike :func:`phase2d` (connected lines) or :func:`plot_particles` (one frame), this scatters every valid ``(particle, frame)`` position — useful for occupancy / home-range visualisations. With ``cmap`` set, points are colored by time. """ if ax is None: ax = plt.gca() t, X, M = _collection_to_arrays(coll, dataset=dataset) T, N, d = X.shape lo, hi = _time_window(T, tmin, tmax) X, M, tt = X[lo:hi], M[lo:hi], t[lo:hi] if particles is not None: idx = np.asarray(particles, dtype=int) X, M = X[:, idx], M[:, idx] i, j = dims xf = X[..., i].reshape(-1) yf = X[..., j].reshape(-1) maskf = M.reshape(-1).astype(bool) cf = np.broadcast_to(tt[:, None], X.shape[:2]).reshape(-1) if cmap is not None else None if drop_masked: xf, yf = xf[maskf], yf[maskf] if cf is not None: cf = cf[maskf] if cf is not None: ax.scatter(xf, yf, c=cf, cmap=cmap, s=s, alpha=alpha, **scatter_kw) else: ax.scatter(xf, yf, color=SFI_COLORS["data"], s=s, alpha=alpha, **scatter_kw) _equal_aspect(ax) ax.set_xlabel("x") ax.set_ylabel("y") return ax
[docs] def plot_field( coll: TrajectoryCollection, field, *, dataset: int = 0, dir1=None, dir2=None, center=None, N: int = 10, scale: float = 1.0, autoscale: bool = False, color="g", radius=None, positions=None, powernorm: float = 0.0, mask_unvisited: bool = False, clip_magnitude=None, **kwargs, ): """Plot a 2D vector field (or a 2D slice of a higher-dimensional field). Parameters ---------- coll : TrajectoryCollection providing typical positions for scaling/centering. field : Callable ``field(X) -> F`` with X of shape (n_points, d) and F of same shape. dataset : Dataset index inside the collection used to estimate center/radius. dir1, dir2 : Projection directions in R^d. If None, use coordinate axes 0 and 1. center, N, scale, autoscale, color, radius, positions, powernorm, **kwargs : Grid / scaling controls. mask_unvisited : If True, drop arrows in grid cells with no trajectory data within one grid spacing (keeps the quiver legible in sparsely-sampled regions). clip_magnitude : If set, clip each arrow's magnitude to this value (long arrows in high-force regions no longer dominate the plot). """ t, X, M = _collection_to_arrays(coll, dataset=dataset) d = X.shape[-1] if dir1 is None: dir1 = axisvector(0, d) if dir2 is None: dir2 = axisvector(1, d) if center is None: center = X.mean(axis=(0, 1)) if radius is None: radius = 0.5 * (X.max(axis=(0, 1)) - X.min(axis=(0, 1))).max() if positions is None: positions = [] for a in np.linspace(-radius, radius, N): for b in np.linspace(-radius, radius, N): positions.append(center + a * dir1 + b * dir2) gridX, gridY = [], [] vX, vY = [], [] for pos in positions: x = dir1.dot(pos) y = dir2.dot(pos) gridX.append(x) gridY.append(y) v = field(pos.reshape((1, d))) if powernorm != 0: v /= np.linalg.norm(v) ** powernorm vX.append(dir1.dot(v[0, :])) vY.append(dir2.dot(v[0, :])) vX = np.array(vX) vY = np.array(vY) if clip_magnitude is not None: mag = np.sqrt(vX**2 + vY**2) factor = np.where(mag > clip_magnitude, clip_magnitude / np.maximum(mag, 1e-12), 1.0) vX = vX * factor vY = vY * factor if autoscale: scale /= float(np.nanmax(np.sqrt(vX**2 + vY**2))) if mask_unvisited: data = X.reshape(-1, d) if data.shape[0] > 5000: data = data[:: max(1, data.shape[0] // 5000)] pos_arr = np.asarray([np.asarray(p) for p in positions]) spacing = (2.0 * radius) / max(N - 1, 1) covered = np.zeros(len(pos_arr), dtype=bool) for ci in range(0, len(pos_arr), 256): chunk = pos_arr[ci : ci + 256] dmin = np.sqrt(((chunk[:, None, :] - data[None, :, :]) ** 2).sum(-1)).min(axis=1) covered[ci : ci + 256] = dmin <= spacing vX = np.where(covered, vX, np.nan) vY = np.where(covered, vY, np.nan) plt.quiver( gridX, gridY, scale * vX, scale * vY, scale=1.0, units="xy", color=color, minlength=0.0, **kwargs, ) plt.ylim(-radius + dir2.dot(center), radius + dir2.dot(center)) plt.xlim(-radius + dir1.dot(center), radius + dir1.dot(center)) _equal_aspect(plt.gca()) plt.xticks([]) plt.yticks([])
[docs] def plot_tensor_field( coll: TrajectoryCollection, field, *, dataset: int = 0, center=None, N: int = 10, scale: float = 1.0, autoscale: bool = False, color="g", radius=None, positions=None, mode: str = "eigencross", **kwargs, ): """Plot a tensor field for 2D processes from a TrajectoryCollection. ``mode="eigencross"`` (default) draws each tensor as a pair of eigen-axis arrows; ``mode="ellipse"`` draws an eigen-aligned ellipse glyph (axis lengths ``∝ sqrt(eigenvalue)``), a clearer rendering for anisotropic diffusion fields. """ if mode not in ("eigencross", "ellipse"): raise ValueError(f"Unknown mode {mode!r}; expected 'eigencross' or 'ellipse'.") t, X, M = _collection_to_arrays(coll, dataset=dataset) d = X.shape[-1] if d != 2: raise ValueError(f"plot_tensor_field expects d=2, got d={d}") if center is None: center = X.mean(axis=(0, 1)) if radius is None: radius = 0.5 * (X.max(axis=(0, 1)) - X.min(axis=(0, 1))).max() if positions is None: positions = [] for a in np.linspace(-radius, radius, N): for b in np.linspace(-radius, radius, N): positions.append(center + np.array([a, b])) if mode == "ellipse": from matplotlib.patches import Ellipse ax = plt.gca() for pos in positions: tensor = np.asarray(field(pos.reshape((1, d)))).reshape(2, 2) w, Vv = np.linalg.eigh(0.5 * (tensor + tensor.T)) w = np.clip(w, 1e-4, None) ang = np.degrees(np.arctan2(Vv[1, 1], Vv[0, 1])) ax.add_patch( Ellipse( (pos[0], pos[1]), width=scale * float(np.sqrt(w[1])), height=scale * float(np.sqrt(w[0])), angle=ang, fill=False, lw=1.2, color=color, **kwargs, ) ) ax.set_xlim(center[0] - radius, center[0] + radius) ax.set_ylim(center[1] - radius, center[1] + radius) _equal_aspect(plt.gca()) plt.xticks([]) plt.yticks([]) return Xp, Yp, U, V = [], [], [], [] for pos in positions: posr = pos.reshape((1, d)) tensor = field(posr) eigvals, eigvecs = np.linalg.eigh(tensor.reshape((2, 2))) for j in range(2): Xp.append(pos[0]) Yp.append(pos[1]) U.append(eigvals[j] * eigvecs[0, j]) V.append(eigvals[j] * eigvecs[1, j]) if autoscale: scale /= max(np.array(U) ** 2 + np.array(V) ** 2) ** 0.5 Xp = np.array(Xp) Yp = np.array(Yp) dX = 0.5 * scale * np.array(U) dY = 0.5 * scale * np.array(V) plt.quiver( Xp - dX, Yp - dY, 2 * dX, 2 * dY, scale=1.0, units="xy", color=color, minlength=0.0, headwidth=1.0, headlength=0.0, **kwargs, ) _equal_aspect(plt.gca()) plt.xticks([]) plt.yticks([])
[docs] def plot_profile_1d( coll: TrajectoryCollection, field, *, exact_field=None, dataset: int = 0, dim: int = 0, component=None, N: int = 200, ci=None, samples: bool = False, ax=None, margin: float = 0.05, label_exact: str = "Exact", label_inferred: str = "Inferred", ): """1D profile of an inferred field, optionally overlaid on the exact one. Evaluates ``field`` on a grid spanning the data range along ``dim`` and plots the inferred profile (gold), an optional exact overlay (orange dashes), an optional confidence band, and an optional sample histogram backdrop. The 1D analog of :func:`plot_field` for forces ``F(x)`` and scalar diffusion profiles ``D(x)`` / ``D(v)``. Parameters ---------- field, exact_field : Callables ``f(X) -> array`` (vector ``(N, d)`` or tensor ``(N, d, d)``). ``exact_field`` is optional. dim : Coordinate axis to sweep. component : Which output component to plot. Defaults to ``dim`` for a vector field and ``(dim, dim)`` for a tensor field. ci : Optional dict with ``"lower"``/``"upper"`` arrays (same length as the grid) for a confidence band. samples : If True, draw a faint histogram of the data along ``dim`` behind the curves. """ if ax is None: ax = plt.gca() t, X, M = _collection_to_arrays(coll, dataset=dataset) d = X.shape[-1] xmin = float(X[..., dim].min()) xmax = float(X[..., dim].max()) pad = margin * (xmax - xmin) grid = np.linspace(xmin - pad, xmax + pad, N) pts = np.zeros((N, d)) pts[:, dim] = grid def _val(f): out = np.asarray(f(jnp.asarray(pts))) if out.ndim == 1: return out if out.ndim == 2: return out[:, component if component is not None else dim] if out.ndim == 3: c = component if component is not None else dim return out[:, c, c] return out.reshape(N, -1)[:, 0] if samples: data1 = X[..., dim].reshape(-1)[M.reshape(-1).astype(bool)] axh = ax.twinx() axh.hist(data1, bins=60, density=True, alpha=0.18, color=SFI_COLORS["data"]) axh.set_yticks([]) if exact_field is not None: ax.plot(grid, _val(exact_field), "--", lw=2, color=SFI_COLORS["exact"], label=label_exact) ax.plot(grid, _val(field), lw=2, color=SFI_COLORS["inferred"], label=label_inferred) if ci is not None: ax.fill_between( grid, np.asarray(ci["lower"]), np.asarray(ci["upper"]), color=SFI_COLORS["inferred"], alpha=0.2, ) ax.axhline(0, color="#808080", lw=0.5) ax.set_xlabel(f"x[{dim}]") ax.legend() return ax
[docs] def plot_field_error( coll: TrajectoryCollection, field_inferred, field_exact, *, dataset: int = 0, N: int = 60, norm: str = "l2", cmap: str = "inferno", vmax=None, ax=None, ): """2D heatmap of the pointwise force-reconstruction error. Evaluates both fields on a grid spanning the (masked) data bounding box and renders ``||F_exact - F_inferred||`` as a ``pcolormesh``. The spatial complement to :func:`comparison_scatter`. """ if ax is None: ax = plt.gca() t, X, M = _collection_to_arrays(coll, dataset=dataset) d = X.shape[-1] if d != 2: raise ValueError(f"plot_field_error requires 2D data, got d={d}") flat = X.reshape(-1, d)[M.reshape(-1).astype(bool)] xmin, ymin = flat.min(axis=0) xmax, ymax = flat.max(axis=0) GX, GY = np.meshgrid(np.linspace(xmin, xmax, N), np.linspace(ymin, ymax, N)) pts = np.stack([GX.ravel(), GY.ravel()], axis=-1) Fi = np.asarray(field_inferred(jnp.asarray(pts))) Fe = np.asarray(field_exact(jnp.asarray(pts))) ordmap = {"l2": 2, "l1": 1, "linf": np.inf} err = np.linalg.norm(Fe - Fi, ord=ordmap.get(norm, 2), axis=-1).reshape(GX.shape) im = ax.pcolormesh(GX, GY, err, cmap=cmap, vmax=vmax, shading="auto") plt.colorbar(im, ax=ax, label=f"||F_exact - F_inferred|| ({norm})") ax.set_aspect("equal") ax.set_xlabel("x") ax.set_ylabel("y") return ax
[docs] def stream_field( coll: TrajectoryCollection, field, *, dataset: int = 0, dir1=None, dir2=None, center=None, radius=None, N: int = 20, density: float = 1.0, color: str = "#808080", ax=None, **streamplot_kw, ): """Streamplot of a callable vector field over the data domain. The streamline counterpart to :func:`plot_field` (quiver): integrates ``field(X) -> F`` into flow lines on a grid spanning the data bounding box. Useful for visualising the topology of an inferred or analytic drift field. """ if ax is None: ax = plt.gca() t, X, M = _collection_to_arrays(coll, dataset=dataset) d = X.shape[-1] if dir1 is None: dir1 = axisvector(0, d) if dir2 is None: dir2 = axisvector(1, d) if center is None: center = X.mean(axis=(0, 1)) if radius is None: radius = 0.5 * (X.max(axis=(0, 1)) - X.min(axis=(0, 1))).max() g = np.linspace(-radius, radius, N) GX, GY = np.meshgrid(g, g) pts = ( center[None, :] + GX.ravel()[:, None] * dir1[None, :] + GY.ravel()[:, None] * dir2[None, :] ) F = np.asarray(field(jnp.asarray(pts))) U = F.dot(dir1).reshape(N, N) V = F.dot(dir2).reshape(N, N) ax.streamplot(g + float(dir1.dot(center)), g + float(dir2.dot(center)), U, V, density=density, color=color, **streamplot_kw) ax.set_aspect("equal") ax.set_xlabel("x") ax.set_ylabel("y") return ax
[docs] def plot_particles( coll: TrajectoryCollection, a: int = 0, b: int = 1, t_index: int = -1, colored: bool = True, active: bool = False, u: float = 0.35, *, dataset: int = 0, color_dim=None, cmap=None, vmin=None, vmax=None, quiver: bool = False, heading_dim=None, box=None, s: float = 100.0, ax=None, quiver_kw=None, **kwargs, ): """Display all particles at time index ``t_index`` from a collection. Parameters ---------- a, b : State dimensions for the x/y axes of the snapshot. t_index : Frame index (negative counts from the end). colored : If True (and no ``color_dim``), color particles by index (magma). color_dim : Color particles by the value of this state dimension instead of by index — e.g. heading angle (default colormap ``"hsv"``). cmap, vmin, vmax : Colormap / normalisation for the coloring. quiver, heading_dim : If ``quiver=True``, overlay heading arrows from ``cos/sin(X[:, heading_dim])`` (``heading_dim`` defaults to 2, the orientation channel of an active particle). box : Optional periodic box ``(Lx, Ly)``; positions are wrapped into it. active, u : Legacy orientation marker (a dot at ``u·(cosθ, sinθ)``); superseded by ``quiver`` for active-matter snapshots. ax : Target axes (default: current axes). """ if ax is None: ax = plt.gca() t, X, M = _collection_to_arrays(coll, dataset=dataset) # X: (T, N, d) T = X.shape[0] if t_index < 0: t_index = T + t_index if not (0 <= t_index < T): raise IndexError(f"t_index {t_index} out of range for T={T}") X_t = X[t_index] # (N, d) xy = X_t[:, [a, b]].astype(float) if box is not None: xy = wrap_positions(xy, np.asarray(box)[:2]) x, y = xy[:, 0], xy[:, 1] if color_dim is not None: ax.scatter(x, y, c=X_t[:, color_dim], cmap=cmap or "hsv", vmin=vmin, vmax=vmax, s=s, **kwargs) elif colored: ax.scatter(x, y, cmap=cmap or "magma", s=s, c=np.linspace(0, 1, len(X_t)), vmin=vmin, vmax=vmax, **kwargs) else: ax.scatter(x, y, s=s, c="w", edgecolor="#808080", **kwargs) hd = heading_dim if heading_dim is not None else (2 if X_t.shape[1] >= 3 else None) if quiver and hd is not None: qkw = dict(color="#B0B0B0", pivot="mid", units="xy") qkw.update(quiver_kw or {}) ax.quiver(x, y, np.cos(X_t[:, hd]), np.sin(X_t[:, hd]), **qkw) elif active and X_t.shape[1] >= 3: xa = x + u * np.cos(X_t[:, 2]) ya = y + u * np.sin(X_t[:, 2]) ax.scatter(xa, ya, c="#B0B0B0", s=20) ax.set_aspect("equal") ax.set_xticks([]) ax.set_yticks([]) return ax
[docs] def plot_particles_field( coll: TrajectoryCollection, field, *, dataset: int = 0, t_index: int = -1, dir1=None, dir2=None, center=None, radius=None, scale: float = 1.0, autoscale: bool = False, color="g", **kwargs, ): """ Plot a 2D vector field evaluated at particle positions at a given time. """ t, X, M = _collection_to_arrays(coll, dataset=dataset) T, N, d = X.shape if t_index < 0: t_index = T + t_index if not (0 <= t_index < T): raise IndexError(f"t_index {t_index} out of range for T={T}") X_t = X[t_index] if dir1 is None: dir1 = axisvector(0, d) if dir2 is None: dir2 = axisvector(1, d) F = field(X_t) if center is None: center = X_t.mean(axis=0) if radius is None: radius = 0.5 * (X_t.max(axis=0) - X_t.min(axis=0)).max() gridX, gridY, vX, vY = [], [], [], [] for ind, pos in enumerate(X_t): x = dir1.dot(pos) y = dir2.dot(pos) gridX.append(x) gridY.append(y) vX.append(dir1.dot(F[ind, :])) vY.append(dir2.dot(F[ind, :])) if autoscale: scale /= max(np.array(vX) ** 2 + np.array(vY) ** 2) ** 0.5 plt.quiver( gridX, gridY, scale * np.array(vX), scale * np.array(vY), scale=1.0, units="xy", color=color, minlength=0.0, **kwargs, ) plt.ylim(-radius + dir2.dot(center), radius + dir2.dot(center)) plt.xlim(-radius + dir1.dot(center), radius + dir1.dot(center)) _equal_aspect(plt.gca()) plt.xticks([]) plt.yticks([])
[docs] def plot_nematic_director( ax, Qxx, Qxy, rho, *, skip: int = 2, scale: float = 2.5, color: str = "white", alpha: float = 0.55, linewidth: float = 0.45, **quiver_kw, ): """Overlay the nematic director field of a Q-tensor on an image axes. Given the order-parameter fields ``Qxx``, ``Qxy`` and density ``rho`` (each a 2D grid), draws a *headless* quiver of the director ``ψ = ½·atan2(Qxy, Qxx)`` on a subsampled grid — the canonical active-nematic overlay. Parameters ---------- ax : Target axes (typically holding an ``imshow`` of the density). Qxx, Qxy, rho : 2D arrays of equal shape. skip : Subsampling stride for the director glyphs. scale : Glyph length (passed through to ``quiver`` ``scale``). """ Qxx = np.asarray(Qxx) Qxy = np.asarray(Qxy) rho = np.maximum(np.asarray(rho), 1e-3) psi = 0.5 * np.arctan2(Qxy / rho, Qxx / rho) ny, nx = psi.shape iy, ix = np.meshgrid( np.arange(skip // 2, ny, skip), np.arange(skip // 2, nx, skip), indexing="ij" ) th = psi[iy, ix] qkw = dict( scale=scale, color=color, alpha=alpha, linewidth=linewidth, headwidth=0, headlength=0, headaxislength=0, pivot="mid", units="xy", ) qkw.update(quiver_kw) # Return the Quiver artist so callers can update it per animation frame # (set_offsets / set_UVC); for static use just ignore the return value. return ax.quiver(ix, iy, np.cos(th), np.sin(th), **qkw)
[docs] def plot_rods( ax, X_frame, *, angle_index: int = 2, length: float = 0.85, color: str = "#d4a96a", linewidth: float = 2.8, capstyle: str = "round", **kwargs, ): """Draw oriented rods (active-matter particles) as a ``LineCollection``. Each particle in ``X_frame`` (rows ``(x, y, …, θ, …)``) becomes a short segment of length ``length`` centred on ``(x, y)`` and oriented at ``θ = X_frame[:, angle_index]``. """ X_frame = np.asarray(X_frame) x, y = X_frame[:, 0], X_frame[:, 1] th = X_frame[:, angle_index] h = 0.5 * length dx, dy = h * np.cos(th), h * np.sin(th) starts = np.stack([x - dx, y - dy], axis=-1) ends = np.stack([x + dx, y + dy], axis=-1) segs = np.stack([starts, ends], axis=1) # (N, 2, 2) lc = LineCollection(segs, colors=color, linewidths=linewidth, capstyle=capstyle, **kwargs) ax.add_collection(lc) ax.set_aspect("equal") return lc
[docs] def plot_spde_snapshot( coll: TrajectoryCollection, t_indices, *, dataset: int = 0, scalar_channel: int = 0, vector_channels=None, grid_shape=None, dx=None, render: str = "imshow", streamplot_kw=None, quiver_kw=None, axes=None, vmin=None, vmax=None, cmap: str = "magma", ): """Render SPDE field snapshots from a gridded TrajectoryCollection. Reshapes each requested frame ``X[t]`` of shape ``(N, n_channels)`` to ``(*grid_shape, n_channels)`` and draws ``scalar_channel`` as an ``imshow`` image, optionally overlaying ``vector_channels`` as streamlines (``render="streamplot"``) or arrows (``render="quiver"``). Parameters ---------- t_indices : A single frame index, or a sequence of indices (one panel each). grid_shape : ``(nx, ny)`` of the field grid. Inferred as a square grid when omitted. dx : Physical grid spacing (default 1.0). vmin, vmax : Color limits; default to the 0.5/99.5 percentiles of the field. """ t, X, M = _collection_to_arrays(coll, dataset=dataset) T, N, C = X.shape if grid_shape is None: s = int(round(np.sqrt(N))) if s * s != N: raise ValueError("Provide grid_shape; cannot infer a non-square grid.") grid_shape = (s, s) dx = 1.0 if dx is None else float(dx) single = np.isscalar(t_indices) tis = [int(t_indices)] if single else [int(i) for i in t_indices] if axes is None: _, axes = plt.subplots(1, len(tis), figsize=(4 * len(tis), 4), squeeze=False) axes = axes[0] axes = np.atleast_1d(axes) gx = (np.arange(grid_shape[0]) + 0.5) * dx gy = (np.arange(grid_shape[1]) + 0.5) * dx for ax, ti in zip(axes, tis): field = X[ti].reshape(*grid_shape, C) scal = field[..., scalar_channel] vlo = float(np.percentile(scal, 0.5)) if vmin is None else vmin vhi = float(np.percentile(scal, 99.5)) if vmax is None else vmax ax.imshow( scal.T, origin="lower", cmap=cmap, vmin=vlo, vmax=vhi, extent=[0, grid_shape[0] * dx, 0, grid_shape[1] * dx], ) if vector_channels is not None: vx = field[..., vector_channels[0]] vy = field[..., vector_channels[1]] if render == "streamplot": ax.streamplot(gx, gy, vx.T, vy.T, **(streamplot_kw or {})) else: GX, GY = np.meshgrid(gx, gy, indexing="ij") qkw = dict(color="#B0B0B0") qkw.update(quiver_kw or {}) ax.quiver(GX, GY, vx, vy, **qkw) ax.set_xticks([]) ax.set_yticks([]) return axes[0] if single else axes
[docs] def spatial_acorr2d(field_2d, *, dx: float = 1.0, n_bins=None, normalize: bool = True): """Radially-averaged 2D spatial autocorrelation via FFT (periodic). Returns ``(r, C)`` where ``C(r)`` is the angle-averaged autocorrelation of ``field_2d`` (mean removed) at radial separation ``r``. Assumes a periodic grid. """ f = np.asarray(field_2d, dtype=float) f = f - f.mean() F = np.fft.rfft2(f) C = np.fft.irfft2(F * np.conj(F), s=f.shape) / f.size if normalize and C[0, 0] != 0: C = C / C[0, 0] nx, ny = f.shape ix = (np.arange(nx) + nx // 2) % nx - nx // 2 iy = (np.arange(ny) + ny // 2) % ny - ny // 2 GX, GY = np.meshgrid(ix, iy, indexing="ij") r = np.sqrt(GX**2 + GY**2) * dx if n_bins is None: n_bins = min(nx, ny) // 2 edges = np.linspace(0.0, float(r.max()), n_bins + 1) which = np.clip(np.digitize(r.ravel(), edges) - 1, 0, n_bins - 1) cvals = C.ravel() radial = np.array( [cvals[which == b].mean() if np.any(which == b) else np.nan for b in range(n_bins)] ) centers = 0.5 * (edges[:-1] + edges[1:]) return centers, radial
[docs] def animate_particles( coll: TrajectoryCollection, *, dataset: int = 0, dims=(0, 1), trail: int = 0, overlay_fn=None, skip: int = 1, cmap: str = "magma", s: float = 100.0, color_dim=None, vmin=None, vmax=None, quiver: bool = False, heading_dim=None, box=None, interval: int = 50, ax=None, fig=None, blit: bool = False, **anim_kw, ): """Animate particle positions over time (frames read via the collection). Returns a :class:`matplotlib.animation.FuncAnimation`. With ``trail>0`` each particle leaves a fading tail; ``overlay_fn(ax, t_index, X_t)`` is called per frame for custom overlays. For active matter, color points by a state dimension (``color_dim``, e.g. heading angle), overlay heading arrows (``quiver=True`` + ``heading_dim``), and wrap into a periodic box (``box=(Lx, Ly)``) — mirroring :func:`plot_particles`. """ from matplotlib.animation import FuncAnimation t, X, M = _collection_to_arrays(coll, dataset=dataset) T, N, d = X.shape frames = list(range(0, T, skip)) if fig is None and ax is None: fig, ax = plt.subplots() elif ax is None: ax = fig.gca() elif fig is None: fig = ax.figure i, j = dims box = np.asarray(box, dtype=float) if box is not None else None def _xy(ti): xi, yi = X[ti, :, i].copy(), X[ti, :, j].copy() if box is not None: xi, yi = xi % box[0], yi % box[1] return xi, yi if box is not None: ax.set_xlim(0.0, float(box[0])) ax.set_ylim(0.0, float(box[1])) else: ax.set_xlim(float(X[..., i].min()), float(X[..., i].max())) ax.set_ylim(float(X[..., j].min()), float(X[..., j].max())) ax.set_aspect("equal") x0, y0 = _xy(frames[0]) if color_dim is not None: cvals = X[:, :, color_dim] vmin = float(np.nanmin(cvals)) if vmin is None else vmin vmax = float(np.nanmax(cvals)) if vmax is None else vmax scat = ax.scatter(x0, y0, c=X[frames[0], :, color_dim], cmap=cmap or "hsv", vmin=vmin, vmax=vmax, s=s) else: scat = ax.scatter(x0, y0, c=np.linspace(0, 1, N), cmap=cmap, s=s) hd = heading_dim if heading_dim is not None else (2 if d >= 3 else None) quiv = None if quiver and hd is not None: th0 = X[frames[0], :, hd] quiv = ax.quiver(x0, y0, np.cos(th0), np.sin(th0), color="#B0B0B0", pivot="mid", units="xy") trails = [] if trail > 0: for _ in range(N): (ln,) = ax.plot([], [], lw=0.8, alpha=0.5, color="#B0B0B0") trails.append(ln) def _update(fr): ti = frames[fr] xi, yi = _xy(ti) off = np.stack([xi, yi], axis=-1) scat.set_offsets(off) if color_dim is not None: scat.set_array(X[ti, :, color_dim]) artists = [scat] if quiv is not None: quiv.set_offsets(off) th = X[ti, :, hd] quiv.set_UVC(np.cos(th), np.sin(th)) artists.append(quiv) if trail > 0: lo = max(0, ti - trail) for n in range(N): trails[n].set_data(X[lo : ti + 1, n, i], X[lo : ti + 1, n, j]) artists += trails if overlay_fn is not None: overlay_fn(ax, ti, X[ti]) return artists return FuncAnimation(fig, _update, frames=len(frames), interval=interval, blit=blit, **anim_kw)
[docs] def animate_spde_comparison( coll_a: TrajectoryCollection, coll_b: TrajectoryCollection, *, dataset: int = 0, field_component: int = 0, grid_shape=None, skip: int = 1, vmin=None, vmax=None, cmap: str = "magma", titles=("A", "B"), interval: int = 50, plot_colorbar: bool = True, blit: bool = False, **anim_kw, ): """Side-by-side animation of one channel of two gridded collections. Returns a :class:`matplotlib.animation.FuncAnimation` with two image panels sharing color limits (e.g. data vs bootstrap SPDE fields). """ from matplotlib.animation import FuncAnimation _, Xa, _ = _collection_to_arrays(coll_a, dataset=dataset) _, Xb, _ = _collection_to_arrays(coll_b, dataset=dataset) Tn = min(Xa.shape[0], Xb.shape[0]) N, C = Xa.shape[1], Xa.shape[2] if grid_shape is None: s = int(round(np.sqrt(N))) if s * s != N: raise ValueError("Provide grid_shape; cannot infer a non-square grid.") grid_shape = (s, s) frames = list(range(0, Tn, skip)) def _slab(Xarr, ti): return Xarr[ti].reshape(*grid_shape, C)[..., field_component].T if vmin is None: vmin = float(min(Xa[..., field_component].min(), Xb[..., field_component].min())) if vmax is None: vmax = float(max(Xa[..., field_component].max(), Xb[..., field_component].max())) fig, axes = plt.subplots(1, 2, figsize=(9, 4.5)) im_a = axes[0].imshow(_slab(Xa, 0), origin="lower", cmap=cmap, vmin=vmin, vmax=vmax) im_b = axes[1].imshow(_slab(Xb, 0), origin="lower", cmap=cmap, vmin=vmin, vmax=vmax) for ax, title in zip(axes, titles): ax.set_title(title) ax.set_xticks([]) ax.set_yticks([]) if plot_colorbar: fig.colorbar(im_b, ax=axes.tolist(), shrink=0.8) def _update(fr): ti = frames[fr] im_a.set_data(_slab(Xa, ti)) im_b.set_data(_slab(Xb, ti)) return [im_a, im_b] return FuncAnimation(fig, _update, frames=len(frames), interval=interval, blit=blit, **anim_kw)
[docs] def plot_time_profile_comparison(t, true_profiles, inferred_profiles, *, labels=None, axes=None): """Plot true vs inferred time-dependent profiles (e.g. k(t), a(t)). ``true_profiles`` and ``inferred_profiles`` are sequences of 1D arrays (one per panel); ``true_profiles`` entries may be ``None`` to skip the reference. ``labels`` gives a title per panel. """ n = len(inferred_profiles) if axes is None: _, axes = plt.subplots(1, n, figsize=(4 * n, 3.5), squeeze=False) axes = axes[0] axes = np.atleast_1d(axes) for k in range(n): ax = axes[k] if true_profiles is not None and true_profiles[k] is not None: ax.plot(t, true_profiles[k], color=SFI_COLORS["exact"], lw=2.4, label="true") ax.plot(t, inferred_profiles[k], color=SFI_COLORS["inferred"], lw=1.6, ls="--", label="inferred") if labels is not None: ax.set_title(labels[k]) ax.set_xlabel("time") ax.legend(loc="upper right") return axes