Source code for SFI.trajectory.degrade

# SFI/trajectory/degrade.py
"""
SFI.trajectory.degrade
======================

Degrade synthetic trajectories to mimic real data:
- motion blur (temporal window average)
- downsampling
- additive measurement noise
- ROI filtering (mask points outside a region)
- random data loss

Two front doors:
----------------
1) Dataset/Collection API (recommended for internal use)
   - degrade_dataset(ds, ...)
   - degrade_collection(coll, ...)

2) Columns API (back-compat for I/O scripts)
   - degrade_columns(meta, particle_idx, time_idx, state_vectors, ...)

Why two? Column flow is convenient for simple scripts and file round-trips;
dataset/collection flow keeps everything rectangular so we can blur/downsample
time-dependent extras cleanly without flatten/unflatten gymnastics.

Extras semantics
----------------
- extras_global:
    - arrays with leading shape (T, ...) are blurred/downsampled along time
      like X; other entries are passed through unchanged.
- extras_local:
    - arrays with shape (N, ...): per-particle constants → unchanged
    - arrays with shape (T, N, ...): blurred/downsampled along time like X

Noise/ROI/data-loss are applied on the **mask** (not by deleting rows), so
tensor shapes remain intact. Flattening to columns (if needed) happens last.

Cache-only extras (auto-generated structural tables)
----------------------------------------------------
Keys starting with ``_cache/`` are considered auto-generated structural extras
(e.g. CSR neighbor lists, stencil hyper tables). They are **not** degraded and
are **dropped** from outputs, because any degradation/context change invalidates
such cached structural objects. They can be regenerated on demand by calling
the appropriate host-side preparation routine.

"""

from __future__ import annotations

from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Tuple, Union

import numpy as np

from SFI.statefunc.nodes.interactions.prepare import (
    CACHE_PREFIX,
    is_cache_key,
    purge_cache_extras,
)
from SFI.trajectory.collection import TrajectoryCollection
from SFI.trajectory.dataset import TrajectoryDataset

# -------------------------- public: dataset/collection -------------------------- #


[docs] def degrade_dataset( ds: TrajectoryDataset, *, downsample: int = 1, motion_blur: int = 0, data_loss_fraction: float = 0.0, noise: Union[None, float, np.ndarray] = None, ROI: Union[None, float, np.ndarray, Callable[[np.ndarray], bool]] = None, seed: Optional[int] = None, ) -> TrajectoryDataset: """ Degrade a single :class:`TrajectoryDataset`. The function operates in tensor space; it returns a new dataset where: - ``X`` is motion-blurred over ``motion_blur + 1`` frames and downsampled by ``downsample``, - the mask is AND-reduced over the blur window, then modified by ROI and random data loss, - ``t`` (if present) is averaged over the blur window and downsampled, otherwise scalar ``dt`` is multiplied by ``downsample``, - extras are processed consistently (see module docstring). Parameters ---------- ds Input dataset to degrade. downsample Integer downsampling factor along the time axis (must be ``>= 1``). motion_blur Temporal averaging window size minus one. The actual blur window is ``motion_blur + 1`` frames and must satisfy ``0 <= motion_blur < downsample``. data_loss_fraction Fraction of currently valid entries to drop uniformly at random after ROI filtering (in ``[0, 1)``). noise Additive Gaussian noise scale. If a float, isotropic noise with standard deviation ``noise`` is applied. If an array, broadcast to the state dimension. ROI Region-of-interest predicate or mask. Can be: - float: radial cutoff — keeps positions with ``‖x‖₂ ≤ ROI``, - ``(2, d)`` ndarray: axis-aligned box (row 0 = lower bound, row 1 = upper bound), - ``Callable[[np.ndarray], bool]``: predicate evaluated on each observed position. seed Optional RNG seed for the noise and data-loss generators. Returns ------- TrajectoryDataset Degraded dataset with the same number of particles but fewer time steps. """ if downsample < 1: raise ValueError("downsample must be >= 1") if motion_blur < 0 or motion_blur >= downsample: raise ValueError("motion_blur must satisfy 0 <= motion_blur < downsample.") if not (0.0 <= data_loss_fraction < 1.0): raise ValueError("data_loss_fraction must be in [0, 1).") rng = np.random.default_rng(seed) # Shapes X = np.asarray(ds._X3d()) # (T, N, d) M = np.asarray(ds._M2d()) # (T, N) T, N, d = X.shape has_dynamic = ds.dynamic_mask is not None if has_dynamic: M_dyn = np.asarray(ds.dynamic_mask) if M_dyn.ndim == 1: M_dyn = M_dyn[:, None] else: M_dyn = None # Downsampling schedule window = motion_blur + 1 keep_times = np.arange(0, max(T - motion_blur, 0), downsample, dtype=int) K = keep_times.size if K == 0: # Degenerate (e.g. very small T vs blur) → return an empty dataset with copied meta/extras return TrajectoryDataset.from_arrays( X=np.zeros((0, N, d), dtype=X.dtype), t=None if ds.t is None else np.zeros((0,), dtype=float), dt=None if ds.t is not None else ds.dt, mask=np.zeros((0, N), dtype=bool), extras_global=purge_cache_extras(_downsample_extras_global(ds.extras_global, keep_times, window)), extras_local=purge_cache_extras(_downsample_extras_local(ds.extras_local, keep_times, window)), meta=_update_meta( dict(ds.meta), downsample, motion_blur, data_loss_fraction, noise, ROI, seed, ), ) # 1) Blur/Downsample X and mask along time X_ds = np.empty((K, N, d), dtype=X.dtype) M_ds = np.empty((K, N), dtype=bool) M_dyn_ds = np.empty((K, N), dtype=bool) if has_dynamic else None for k, t0 in enumerate(keep_times): segX = X[t0 : t0 + window] # (window, N, d) segM = M[t0 : t0 + window] # (window, N) with np.errstate(invalid="ignore"): X_ds[k] = np.nanmean(np.where(segM[..., None], segX, np.nan), axis=0) M_ds[k] = segM.all(axis=0) # require all valid within window if has_dynamic and M_dyn_ds is not None and M_dyn is not None: M_dyn_ds[k] = M_dyn[t0 : t0 + window].all(axis=0) # 2) Add measurement noise to observed entries if noise is not None: X_ds = _add_noise_to_masked(X_ds, M_ds, rng, noise) # 3) ROI on observed positions at blurred times if ROI is not None: inside = _roi_predicate(ROI, d) # evaluate on every (k,n) that is currently valid flat = X_ds.reshape(K * N, d) M_flat = M_ds.reshape(K * N) sel = np.where(M_flat)[0] if sel.size > 0: inside_mask = np.array([inside(flat[i]) for i in sel], dtype=bool) M_flat[sel] = M_flat[sel] & inside_mask if has_dynamic: assert M_dyn_ds is not None M_dyn_flat = M_dyn_ds.reshape(K * N) M_dyn_flat[sel] = M_dyn_flat[sel] & inside_mask M_dyn_ds = M_dyn_flat.reshape(K, N) M_ds = M_flat.reshape(K, N) # 4) Random data loss on valid entries if data_loss_fraction > 0.0: M_flat = M_ds.reshape(-1) valid_idx = np.where(M_flat)[0] keep = int(round(valid_idx.size * (1.0 - data_loss_fraction))) if keep < valid_idx.size: drop = rng.choice(valid_idx, size=valid_idx.size - keep, replace=False) M_flat[drop] = False if has_dynamic: assert M_dyn_ds is not None M_dyn_flat = M_dyn_ds.reshape(-1) M_dyn_flat[drop] = False M_dyn_ds = M_dyn_flat.reshape(K, N) M_ds = M_flat.reshape(K, N) # 5) Downsample 't' or scale 'dt' if ds.t is not None: t = np.asarray(ds.t) t_out = np.array([np.mean(t[t0 : t0 + window]) for t0 in keep_times], dtype=float) dt_arg = None t_arg = t_out else: t_arg = None if ds.dt is None: dt_arg = None else: dta = np.asarray(ds.dt) if dta.ndim == 0: # Scalar dt: new step is downsample × original step. dt_arg = float(dta) * downsample else: # Variable dt array: sum over each downsampled window. dt_arg = np.array( [float(np.sum(dta[t0 : t0 + downsample])) for t0 in keep_times], dtype=float, ) # 6) Downsample/blur extras extras_g = purge_cache_extras(_downsample_extras_global(ds.extras_global, keep_times, window)) extras_l = purge_cache_extras(_downsample_extras_local(ds.extras_local, keep_times, window)) # 7) Update meta with provenance meta2 = _update_meta(dict(ds.meta), downsample, motion_blur, data_loss_fraction, noise, ROI, seed) if ds.dt is not None and dt_arg is not None: meta2.setdefault("original_dt", np.asarray(ds.dt).tolist()) meta2["dt"] = dt_arg if np.ndim(dt_arg) == 0 else np.asarray(dt_arg).tolist() return TrajectoryDataset.from_arrays( X=X_ds, t=t_arg, dt=dt_arg, mask=M_ds, dynamic_mask=M_dyn_ds, extras_global=extras_g, extras_local=extras_l, meta=meta2, )
[docs] def degrade_collection( coll: TrajectoryCollection, *, downsample: int = 1, motion_blur: int = 0, data_loss_fraction: float = 0.0, noise: Union[None, float, np.ndarray] = None, ROI: Union[None, float, np.ndarray, Callable[[np.ndarray], bool]] = None, seed: Optional[int] = None, reweight: Literal["pool", "keep"] = "pool", ) -> TrajectoryCollection: """ Degrade all datasets in a collection and optionally recompute weights. Parameters ---------- coll Input collection to degrade. downsample, motion_blur, data_loss_fraction, noise, ROI, seed Same semantics as in :func:`degrade_dataset`. reweight Policy for updating collection-level weights after degradation: - ``"pool"``: recompute weights via ``with_weights("pool")``. - ``"keep"``: preserve the relative weights from ``coll.weights``. Returns ------- TrajectoryCollection New collection whose datasets have been degraded in the same way. Notes ----- This function is purely functional: the input collection is not modified. """ ds2: List[TrajectoryDataset] = [ degrade_dataset( ds, downsample=downsample, motion_blur=motion_blur, data_loss_fraction=data_loss_fraction, noise=noise, ROI=ROI, seed=seed, ) for ds in coll.datasets ] out = TrajectoryCollection(ds2, coll.weights) if reweight == "pool": return out.with_weights("pool") if reweight == "keep": return out.with_weights(list(map(float, coll.weights))) raise ValueError(f"Unknown reweight policy: {reweight!r}")
# ------------------------------- helpers -------------------------------- # def _add_noise_to_masked( X: np.ndarray, M: np.ndarray, rng: np.random.Generator, noise: Union[float, np.ndarray], ) -> np.ndarray: """Add Gaussian noise to X **only where M is True**.""" Y = X.copy() T, N, d = Y.shape if np.isscalar(noise): # iid σ on each coordinate eps = rng.normal(scale=float(noise), size=Y.shape) Y[M] = Y[M] + eps[M] return Y noise_arr = np.asarray(noise, dtype=float) if noise_arr.ndim == 1: if noise_arr.shape[0] != d: raise ValueError("Noise vector length must match state dimension d.") eps = rng.normal(size=Y.shape) * noise_arr.reshape((1, 1, d)) Y[M] = Y[M] + eps[M] return Y if noise_arr.ndim == 2: if noise_arr.shape != (d, d): raise ValueError("Noise matrix must be (d, d).") eps = rng.normal(size=Y.shape) @ noise_arr Y[M] = Y[M] + eps[M] return Y raise ValueError("noise must be scalar, (d,), or (d,d).") def _roi_predicate( ROI: Union[None, float, np.ndarray, Callable[[np.ndarray], bool]], d: int ) -> Callable: """Build an ROI predicate in R^d. ``float`` ROI is a radial cutoff: ``‖x‖₂ ≤ ROI``. ``(2, d)`` array ROI is an axis-aligned box: ``lo ≤ x ≤ hi`` element-wise. """ if ROI is None: return lambda x: True if np.isscalar(ROI): r = float(ROI) return lambda x: bool(np.linalg.norm(x) <= r) if isinstance(ROI, np.ndarray) or hasattr(ROI, "__array__"): arr = np.asarray(ROI, dtype=float) if arr.shape == (2, d): lo, hi = arr[0].copy(), arr[1].copy() return lambda x: bool(np.all((lo <= x) & (x <= hi))) raise ValueError(f"ndarray ROI must have shape (2, {d}), got {arr.shape}.") if callable(ROI): return ROI # trusted raise ValueError("ROI must be None, scalar float, (2,d) ndarray, or callable.") def _downsample_extras_global(eg: Dict[str, Any], keep_times: np.ndarray, window: int) -> Dict[str, Any]: """ Blur/downsample extras_global entries with leading time dimension T. Cache-only extras policy ------------------------ Keys under ``_cache/`` are structural/derivable objects and are **dropped** during degradation. They must be regenerated later from the new context. """ out: Dict[str, Any] = {} K = keep_times.size for k, v in (eg or {}).items(): if is_cache_key(k, prefix=CACHE_PREFIX): continue # Pass through callables (e.g. FunctionExtra-unwrapped) unchanged. if callable(v): out[k] = v continue arr = np.asarray(v) # Guard: keep_times may be empty (degenerate trajectory). time_limit = (keep_times[-1] + window) if K > 0 else 0 if arr.ndim >= 1 and K > 0 and arr.shape[0] >= time_limit: flat = arr.reshape((arr.shape[0], -1)) buf = np.empty((K, flat.shape[1]), dtype=float) for i, t0 in enumerate(keep_times): with np.errstate(invalid="ignore"): buf[i] = np.nanmean(flat[t0 : t0 + window], axis=0) out[k] = buf.reshape((K,) + arr.shape[1:]) elif arr.ndim >= 1 and K == 0 and arr.shape[0] > 0: # Degenerate: return empty leading axis, preserve trailing shape. out[k] = arr[:0] else: out[k] = v return out def _downsample_extras_local(el: Dict[str, Any], keep_times: np.ndarray, window: int) -> Dict[str, Any]: """ Blur/downsample extras_local entries. Cache-only extras policy ------------------------ Keys under ``_cache/`` are structural/derivable objects and are **dropped** during degradation. """ out: Dict[str, Any] = {} K = keep_times.size for k, v in (el or {}).items(): if is_cache_key(k, prefix=CACHE_PREFIX): continue # Pass through callables unchanged. if callable(v): out[k] = v continue arr = np.asarray(v) time_limit = (keep_times[-1] + window) if K > 0 else 0 if arr.ndim >= 2 and K > 0 and arr.shape[0] >= time_limit: T_orig, N = arr.shape[0], arr.shape[1] tail = arr.shape[2:] or () flat = arr.reshape((T_orig, N, -1)) buf = np.empty((K, N, flat.shape[2]), dtype=float) for i, t0 in enumerate(keep_times): with np.errstate(invalid="ignore"): buf[i] = np.nanmean(flat[t0 : t0 + window], axis=0) out[k] = buf.reshape((K, N) + tail) elif arr.ndim >= 2 and K == 0 and arr.shape[0] > 0: out[k] = arr[:0] else: out[k] = v return out def _update_meta( meta: Dict[str, Any], downsample: int, motion_blur: int, data_loss_fraction: float, noise: Union[None, float, np.ndarray], ROI: Any, seed: Optional[int], ) -> Dict[str, Any]: """Record degradation provenance in meta dict (shallow copy upstream).""" meta.update( { "degrade_downsample": downsample, "degrade_motion_blur": motion_blur, "degrade_data_loss_frac": data_loss_fraction, "degrade_noise_spec": ( None if noise is None else float(noise) if np.isscalar(noise) else np.asarray(noise).tolist() ), "degrade_ROI_spec": (None if ROI is None else ("callable" if callable(ROI) else np.asarray(ROI).tolist())), "degrade_rng_seed": seed, } ) return meta # ============================================================================= # Spatial degradation for grid-based SPDE datasets # =============================================================================
[docs] def degrade_spatial_data( coll: TrajectoryCollection, *, downscale: int | Tuple[int, ...] = 2, method: Literal["mean", "subsample"] = "mean", blur_radius: int = 0, data_loss_fraction: float = 0.0, noise: Union[None, float, np.ndarray] = None, seed: Optional[int] = None, mask_threshold: float = 0.5, bc: Literal["noflux", "pbc"] = "noflux", prefix: str = "box", order: Literal["C", "F"] = "C", ) -> TrajectoryCollection: """ Degrade an SPDE-style collection in *space* (blur/coarsen/pixel-loss/noise). Assumes the standard SPDE convention: - particle axis N is a flattened grid of shape `grid_shape`, - state dim d is #fields per site. ``dx`` is read from ``extras_global['{prefix}/dx']`` and updated automatically; it does not need to be supplied here. Also updates 'box/' box parameters and erases structural outputs starting with _cache (regenerated on next use). """ rng = np.random.default_rng(seed) ds2 = [] for ds in coll.datasets: ds2.append( degrade_spatial_dataset( ds, downscale=downscale, method=method, blur_radius=blur_radius, data_loss_fraction=data_loss_fraction, noise=noise, rng=rng, mask_threshold=mask_threshold, bc=bc, prefix=prefix, order=order, ) ) out = TrajectoryCollection(ds2, coll.weights) return out.with_weights(list(map(float, coll.weights)))
[docs] def degrade_spatial_dataset( ds: TrajectoryDataset, *, downscale: int | Tuple[int, ...] = 1, method: Literal["mean", "subsample"] = "mean", blur_radius: int = 0, data_loss_fraction: float = 0.0, noise: Union[None, float, np.ndarray] = None, rng: np.random.Generator, mask_threshold: float = 0.5, bc: Literal["noflux", "pbc"] = "noflux", prefix: str = "box", order: Literal["C", "F"] = "C", ) -> TrajectoryDataset: """Spatial degradation of a single SPDE-style dataset. Key invariants ensured by this routine -------------------------------------- 1) The flattening convention is preserved (``order="C"`` or ``"F"``). 2) Box metadata (grid_shape, dx) is updated consistently after coarsening. 3) Any prepared structural stencil payload is dropped so it is rebuilt for the new grid. 4) Mask handling is conservative: a coarse cell is valid only if enough fine pixels are valid. """ # ---- materialize tensors ---- X = np.asarray(ds._X3d()) # (T, N, d) M = np.asarray(ds._M2d()) # (T, N) T, N, d = X.shape # Fetch grid shape gs = (ds.extras_global or {}).get(f"{prefix}/grid_shape", None) if gs is None: raise ValueError( "degrade_spatial_dataset: grid_shape not provided and could not be " f"inferred from extras_global['{prefix}/grid_shape']." ) grid_shape = tuple(int(v) for v in np.asarray(gs).tolist()) ndim = len(grid_shape) if int(np.prod(grid_shape, dtype=np.int64)) != N: raise ValueError(f"grid_shape={grid_shape} incompatible with N={N}.") # ---- downscale factors ---- if isinstance(downscale, int): fac = (int(downscale),) * ndim else: fac = tuple(int(v) for v in downscale) if len(fac) != ndim: raise ValueError(f"downscale must have length {ndim} (got {len(fac)})") if any(v < 1 for v in fac): raise ValueError("downscale factors must be >= 1") # ---- infer dx (needed to update metadata after coarsening) ---- dx_ex = (ds.extras_global or {}).get(f"{prefix}/dx", None) dx_in = _normalize_dx(dx_ex, ndim=ndim) # ---- reshape to grid (respect flattening convention!) ---- # IMPORTANT: if you ever use order="F" anywhere in stencil construction, you must # preserve it here, otherwise the coarse field will be indexed differently than # the (re)built neighbor tables. Xg = X.reshape((T,) + grid_shape + (d,), order=order) Mg = M.reshape((T,) + grid_shape, order=order) # ---- optional blur (masked box blur; ignores missing pixels) ---- if blur_radius > 0: Xg, Mg = _box_blur_nd(Xg, Mg, radius=int(blur_radius), bc=bc) # ---- downscale ---- Xc, Mc = _downscale_nd( Xg, Mg, factors=fac, method=method, mask_threshold=mask_threshold, ) out_shape = tuple(int(s) for s in Xc.shape[1:-1]) N2 = int(np.prod(out_shape, dtype=np.int64)) # ---- flatten back (same order) ---- Xc2 = Xc.reshape((T, N2, d), order=order) Mc2 = Mc.reshape((T, N2), order=order) # ---- optional measurement noise on observed pixels ---- if noise is not None: Xc2 = _add_noise_to_masked(Xc2, Mc2, rng, noise) # ---- random pixel loss (on the coarse grid) ---- if data_loss_fraction > 0.0: flat = Mc2.reshape(-1) valid_idx = np.where(flat)[0] keep = int(round(valid_idx.size * (1.0 - data_loss_fraction))) if keep < valid_idx.size: drop = rng.choice(valid_idx, size=valid_idx.size - keep, replace=False) flat[drop] = False Mc2 = flat.reshape((T, N2)) # ---- extras: purge invalid structural payloads, then update box metadata ---- extras_g = dict(ds.extras_global or {}) # Update dx for the coarse grid: dx_out = dx_in * factor dx_out = None if dx_in is not None: dx_out = tuple(float(dx_in[ax]) * float(fac[ax]) for ax in range(ndim)) # Keep the box extras in sync with the new grid. from SFI.bases.spde import square_grid_extras extras_g.update( square_grid_extras( grid_shape=out_shape, dx=(1.0 if dx_out is None else dx_out), prefix=prefix, ) ) # ---- extras_local: downscale per-site fields if present ---- extras_l = None if ds.extras_local: extras_l = _downscale_extras_local_grid( extras_local=dict(ds.extras_local), grid_shape=grid_shape, factors=fac, method=method, # If your helper supports it, pass Mg/mask_threshold so it can # handle invalid blocks consistently. Otherwise omit. mask=Mg, mask_threshold=mask_threshold, order=order, ) # ---- meta ---- meta2 = dict(ds.meta) meta2.update( dict( degrade_spatial_downscale=fac, degrade_spatial_method=method, degrade_spatial_blur_radius=int(blur_radius), degrade_spatial_data_loss_frac=float(data_loss_fraction), degrade_spatial_noise_spec=( None if noise is None else float(noise) if np.isscalar(noise) else np.asarray(noise).tolist() ), degrade_spatial_bc=bc, degrade_spatial_order=order, ) ) return TrajectoryDataset.from_arrays( X=Xc2, t=ds.t, dt=ds.dt, mask=Mc2, extras_global=extras_g, extras_local=extras_l, meta=meta2, )
def _box_filter_1d( X: np.ndarray, M: np.ndarray, *, radius: int, axis: int, bc: Literal["pbc", "noflux", "drop"] = "noflux", ) -> tuple[np.ndarray, np.ndarray]: """ 1D **masked** box filter along one spatial axis. This is the core primitive used by `_box_blur_nd` (separable N-D blur). Parameters ---------- X Array of shape ``(..., L, C)`` *along the chosen axis*, where: - L is the axis length being blurred, - C is the "value/channel" axis (typically state dimension, or a flattened tail). In your SPDE pipeline this will typically be ``(T, *grid, d)`` reshaped/moved. M Boolean mask of shape ``(..., L)`` aligned with `X` (same leading axes, no channel axis). Masked entries (False) are ignored in the average. radius Blur radius r. Window size is ``W = 2r + 1``. axis Axis index in `X` to blur over (same logical axis in `M`). In your grid convention: time is axis 0, grid axes are 1..ndim, channel is last. bc Boundary condition: - "pbc": periodic wrapping - "noflux": clamp to edge (replicate boundary samples) - "drop": treat out-of-bounds as missing (do not contribute) Returns ------- Xf, Mf Filtered values and updated mask: - `Mf` is True where at least one valid sample contributed to the window. - Where `Mf` is False, `Xf` is set to 0 (by convention). """ r = int(radius) if r <= 0: return X, M # We implement the blur by: # 1) moving the target axis next to the channel axis, # 2) padding according to `bc`, # 3) sliding-window sums via cumsum (O(L), not O(L*W)). # # This is *masked* averaging: # num = sum( X * M ) # den = sum( M ) # X_out = num / den where den>0 # M_out = den>0 # Move blur axis to be the second-to-last axis (right before channels). # After this: X1 has shape (..., L, C), M1 has shape (..., L). X1 = np.moveaxis(X, axis, -2) M1 = np.moveaxis(M, axis, -1) if X1.shape[-2] != M1.shape[-1]: raise ValueError("X and M incompatible along blur axis.") W = 2 * r + 1 # --- pad along L according to bc --- if bc == "pbc": # Wrap: [..., L, C] -> [..., L+2r, C] Xp = np.concatenate([X1[..., -r:, :], X1, X1[..., :r, :]], axis=-2) Mp = np.concatenate([M1[..., -r:], M1, M1[..., :r]], axis=-1) elif bc == "noflux": # Clamp: replicate edge values. padX = [(0, 0)] * X1.ndim padM = [(0, 0)] * M1.ndim padX[-2] = (r, r) padM[-1] = (r, r) Xp = np.pad(X1, pad_width=padX, mode="edge") Mp = np.pad(M1, pad_width=padM, mode="edge") elif bc == "drop": # Out-of-bounds are invalid: pad mask with False, values arbitrary (0). padX = [(0, 0)] * X1.ndim padM = [(0, 0)] * M1.ndim padX[-2] = (r, r) padM[-1] = (r, r) Xp = np.pad(X1, pad_width=padX, mode="constant", constant_values=0.0) Mp = np.pad(M1, pad_width=padM, mode="constant", constant_values=False) else: raise ValueError(f"Unknown bc={bc!r}") # --- masked sliding sums with cumsum --- # Numerator: sum(X * M) over the window; denominator: sum(M) over the window. # We prepend a zero so "sum over [i, i+W)" becomes csum[i+W] - csum[i]. Xw = Xp * Mp[..., None] # broadcast mask onto channels num_csum = np.cumsum(Xw, axis=-2) den_csum = np.cumsum(Mp.astype(np.int64), axis=-1) # Prepend zeros along the summed axis to use the classic sliding window diff. num0 = np.concatenate([np.zeros_like(num_csum[..., :1, :]), num_csum], axis=-2) den0 = np.concatenate([np.zeros_like(den_csum[..., :1]), den_csum], axis=-1) num = num0[..., W:, :] - num0[..., :-W, :] # (..., L, C) den = den0[..., W:] - den0[..., :-W] # (..., L) # Average where den>0; else output 0 and mask False. Xf = np.divide( num, den[..., None], out=np.zeros_like(num, dtype=X1.dtype), where=(den[..., None] > 0), ) Mf = den > 0 # Move axes back to original positions. X_out = np.moveaxis(Xf, -2, axis) M_out = np.moveaxis(Mf, -1, axis) return X_out, M_out def _box_blur_nd( Xg: np.ndarray, Mg: np.ndarray, *, radius: int, bc: Literal["pbc", "noflux", "drop"] = "noflux", ) -> tuple[np.ndarray, np.ndarray]: """ Separable N-D masked box blur over **all grid axes**. Parameters ---------- Xg Grid data of shape ``(T, *grid_shape, d)``. Mg Grid mask of shape ``(T, *grid_shape)``. radius Blur radius (same on every grid axis). Window size per axis: ``2r+1``. bc Boundary handling; passed to `_box_filter_1d`. Notes ----- This is implemented as consecutive 1D masked box filters along each grid axis. That yields an N-D box kernel because the box is separable. The mask is updated at each 1D pass: a voxel survives if it had **at least one** valid contributor in the window along that axis (and recursively along all axes after all passes). """ r = int(radius) if r <= 0: return Xg, Mg X = Xg M = Mg # Convention assumed by your SPDE code: time axis 0, grid axes 1..ndim, channel last. ndim = X.ndim - 2 for ax in range(1, 1 + ndim): X, M = _box_filter_1d(X, M, radius=r, axis=ax, bc=bc) return X, M def _downscale_nd( Xg: np.ndarray, Mg: np.ndarray, *, factors: tuple[int, ...], method: Literal["average", "mean", "subsample", "nearest"] = "average", mask_threshold: float = 0.0, ) -> tuple[np.ndarray, np.ndarray]: """ Downscale a masked grid by integer block factors. Parameters ---------- Xg Shape ``(T, *grid_shape, d)``. Mg Shape ``(T, *grid_shape)``. factors Integer factors per grid axis, e.g. (2,2) maps (Nx,Ny)->(Nx/2,Ny/2). method - "average"/"mean": masked block-average (recommended) - "subsample"/"nearest": pick the [0,0,...] corner of each block mask_threshold Fraction in [0,1]. A coarse cell is marked valid iff the number of valid fine pixels in the block is at least: ceil(mask_threshold * prod(factors)) Special case: if mask_threshold <= 0, we only require at least 1 valid pixel. Returns ------- Xc, Mc Coarse grid values and mask, shapes ``(T, *coarse_shape, d)``, ``(T, *coarse_shape)``. Notes ----- - If a grid axis is not divisible by its factor, we **crop** the high end. (This avoids inventing boundary conventions at the degradation level.) - For "average": values are averaged over **valid** pixels only. Blocks with no valid pixels output X=0 and mask=False. """ if Xg.ndim < 3: raise ValueError("Xg must have shape (T, *grid, d).") if Mg.shape != Xg.shape[:-1]: raise ValueError("Mg must have shape (T, *grid) aligned with Xg.") ndim = Xg.ndim - 2 if len(factors) != ndim: raise ValueError(f"factors must have length {ndim} (got {len(factors)})") fac = tuple(int(f) for f in factors) if any(f < 1 for f in fac): raise ValueError("All downscale factors must be >= 1.") T = Xg.shape[0] grid_shape = Xg.shape[1:-1] d = Xg.shape[-1] coarse_shape = tuple(int(gs // f) for gs, f in zip(grid_shape, fac)) crop_shape = tuple(cs * f for cs, f in zip(coarse_shape, fac)) # Crop fine grid so each axis is divisible by its factor. slicer = (slice(None),) + tuple(slice(0, Lc) for Lc in crop_shape) + (slice(None),) X = Xg[slicer] M = Mg[(slice(None),) + tuple(slice(0, Lc) for Lc in crop_shape)] # Reshape into interleaved (coarse_axis, block_axis) pairs: # (T, n1, f1, n2, f2, ..., d) new_shape_X = [T] new_shape_M = [T] for cs, f in zip(coarse_shape, fac): new_shape_X.extend([cs, f]) new_shape_M.extend([cs, f]) new_shape_X.append(d) Xr = X.reshape(tuple(new_shape_X)) Mr = M.reshape(tuple(new_shape_M)) # Block axes are the "f" axes: indices 2,4,6,... in Mr; in Xr they are 2,4,6,... as well. block_axes = tuple(1 + 2 * i + 1 for i in range(ndim)) # (2,4,6,... in 1-based count) -> actual indices # Example ndim=2: shape (T, n1, f1, n2, f2, d) -> block_axes=(2,4) # Count valid pixels per block. den = Mr.sum(axis=block_axes) # (T, *coarse_shape) # Decide coarse mask according to threshold. block_size = int(np.prod(fac, dtype=np.int64)) if mask_threshold <= 0.0: required = 1 else: required = int(np.ceil(float(mask_threshold) * block_size)) required = max(1, min(required, block_size)) Mc = den >= required meth = method.lower() if meth in ("average", "mean"): # Masked sum over blocks. num = (Xr * Mr[..., None]).sum(axis=block_axes) # (T, *coarse_shape, d) Xc = np.divide( num, den[..., None], out=np.zeros_like(num, dtype=Xg.dtype), where=(den[..., None] > 0), ) # Ensure blocks deemed invalid are neutral (for downstream ops expecting mask-consistency). Xc = np.where(Mc[..., None], Xc, 0.0) return Xc, Mc if meth in ("subsample", "nearest"): # Take the [0,0,...] entry of each block. # Build an index tuple selecting 0 on each block axis. idx = [slice(None)] for _ in range(ndim): idx.append(slice(None)) # coarse axis idx.append(0) # block axis idx.append(slice(None)) # channels Xc = Xr[tuple(idx)] # For mask, do the same slicing. idxM = [slice(None)] for _ in range(ndim): idxM.append(slice(None)) idxM.append(0) Mc0 = Mr[tuple(idxM)] # Still respect the threshold rule if requested. Mc = Mc & Mc0 Xc = np.where(Mc[..., None], Xc, 0.0) return Xc, Mc raise ValueError(f"Unknown downscale method: {method!r}") def _downscale_extras_local_grid( extras_local: Optional[Mapping[str, Any]], *, grid_shape: tuple[int, ...], factors: tuple[int, ...], method: Literal["average", "mean", "subsample", "nearest"] = "average", mask: Optional[np.ndarray] = None, mask_threshold: float = 0.0, ) -> Dict[str, Any]: """ Downscale **flattened per-site** extras alongside a grid downscale. This is meant for datasets where "particles" are actually grid sites (N = prod(grid_shape)), and `extras_local` may store per-site arrays like: - coordinates (x,y) per pixel, - edge vectors per pixel, - per-pixel calibration factors, etc. Heuristics (non-destructive): ----------------------------- - If an entry looks like per-site constants with shape (N, ...), it is reshaped to (*grid_shape, ...), downscaled, then re-flattened. - If an entry looks like time-dependent per-site extras with shape (T, N, ...), it is reshaped to (T, *grid_shape, ...), downscaled, then re-flattened. - Anything else is passed through unchanged. Parameters ---------- extras_local Dict-like mapping of extras. grid_shape Fine grid shape used to interpret the flattened N axis. factors Integer downscale factors per grid axis. method, mask_threshold Passed to `_downscale_nd`. mask Optional fine-grid mask to use for time-dependent extras, either: - shape (T, N) / (T, *grid_shape), or - shape (T, *grid_shape). If omitted, extras are downscaled as if fully observed. Returns ------- dict New extras_local dictionary with downscaled per-site entries. """ out: Dict[str, Any] = dict(extras_local or {}) if not extras_local: return out grid_shape = tuple(int(s) for s in grid_shape) ndim = len(grid_shape) N = int(np.prod(grid_shape, dtype=np.int64)) # Normalize mask to (T, *grid) if provided. Mg = None if mask is not None: Mm = np.asarray(mask) if Mm.ndim == 2 and Mm.shape[1] == N: # (T, N) flattened -> (T, *grid) Tm = Mm.shape[0] Mg = Mm.reshape((Tm,) + grid_shape) elif Mm.ndim == 1 and Mm.shape[0] == N: Mg = Mm.reshape((1,) + grid_shape) elif Mm.ndim == 1 + ndim: Mg = Mm else: raise ValueError(f"mask has incompatible shape {Mm.shape} for grid_shape={grid_shape} (N={N}).") for k, v in list(extras_local.items()): arr = np.asarray(v) # Case A: per-site constants (N, ...) if arr.ndim >= 1 and arr.shape[0] == N: tail = arr.shape[1:] # Represent as (T=1, *grid, C) by flattening tail into channels. C = int(np.prod(tail, dtype=np.int64)) if tail else 1 Xg = arr.reshape(grid_shape + tail).reshape((1,) + grid_shape + (C,)) Mg_use = np.ones((1,) + grid_shape, dtype=bool) Xc, Mc = _downscale_nd( Xg, Mg_use, factors=factors, method=method, mask_threshold=mask_threshold, ) # Flatten back: (*coarse, C) -> (N2, tail) coarse_shape = Xc.shape[1:-1] N2 = int(np.prod(coarse_shape, dtype=np.int64)) out[k] = Xc.reshape((1, N2, C))[0].reshape((N2,) + tail) continue # Case B: time-dependent per-site extras (T, N, ...) if arr.ndim >= 2 and arr.shape[1] == N: Tm = arr.shape[0] tail = arr.shape[2:] C = int(np.prod(tail, dtype=np.int64)) if tail else 1 Xg = arr.reshape((Tm,) + grid_shape + tail).reshape((Tm,) + grid_shape + (C,)) Mg_use = Mg if Mg_use is None: Mg_use = np.ones((Tm,) + grid_shape, dtype=bool) elif Mg_use.shape[0] == 1 and Tm != 1: # Broadcast a static mask across time if needed. Mg_use = np.broadcast_to(Mg_use, (Tm,) + grid_shape) Xc, Mc = _downscale_nd( Xg, Mg_use, factors=factors, method=method, mask_threshold=mask_threshold, ) coarse_shape = Xc.shape[1:-1] N2 = int(np.prod(coarse_shape, dtype=np.int64)) out[k] = Xc.reshape((Tm, N2, C)).reshape((Tm, N2) + tail) continue # Otherwise: keep as-is. out[k] = v return out def _normalize_dx(dx, *, ndim: int) -> tuple[float, ...] | None: """Normalize dx to a length-ndim tuple of floats (or None).""" if dx is None: return None arr = np.asarray(dx) if arr.ndim == 0: return (float(arr),) * ndim vals = tuple(float(v) for v in arr.reshape(-1).tolist()) if len(vals) == 1 and ndim > 1: return (vals[0],) * ndim if len(vals) != ndim: raise ValueError(f"dx must be scalar or length {ndim} (got {len(vals)})") return vals