"""StructuredExpr — declarative structured tensor expressions (inner world).
Defines :class:`StructuredExpr`, the user-facing expression type for the
inner physics world, and the lightweight ``_StructNode`` IR that encodes
the computation graph. Expressions are symbolic and cannot be evaluated
directly; use ``Layout.embed()`` to compile them into outer-world
``StateExpr`` objects.
"""
from __future__ import annotations
import itertools
import math
from dataclasses import dataclass
from typing import Any, Callable, Sequence
import jax.numpy as jnp
from .params import ParamSpec, ParamSuite
# =====================================================================
# Constants
# =====================================================================
# Layout id carried by layout-agnostic expressions (coerced scalars, ``eye``);
# ``_resolve_layout`` lets it combine with any other layout id.
_FREE_LAYOUT: int = -1 # layout-agnostic sentinel (compatible with any layout)
# Pool of axis letters drawn from when ``@`` and ``dot`` build an einsum spec.
_EINSUM_LETTERS: str = "ijklmnopqrstuvwxyz"
# Sentinel for "not provided" (distinct from None); lets ``StructuredExpr._new``
# tell "inherit param_suite" apart from an explicit ``param_suite=None``.
_MISSING: object = object()
# =====================================================================
# _StructNode: lightweight IR (frozen dataclasses, no eval logic)
# =====================================================================
# Users never see these directly. The embed compiler (Phase 3) walks
# them to produce outer-world StateExpr nodes.
@dataclass(frozen=True, slots=True)
class _StructNode:
"""Base marker for all IR nodes."""
# --- Leaves -----------------------------------------------------------
@dataclass(frozen=True, slots=True)
class _SectorLeaf(_StructNode):
    """Leaf node: a field extracted from a Layout sector."""

    # Name of the sector the field is read from.
    sector_name: str
    # Indices selected within that sector.
    indices: tuple[int, ...]
    # Per-axis sizes of the extracted field.
    sdims: tuple[int, ...]
@dataclass(frozen=True, slots=True)
class _ConstNode(_StructNode):
    """Leaf node: a compile-time constant (scalar, identity matrix, ...)."""

    # The constant itself: int | float | jax.Array
    value: Any
    # Per-axis sizes of the constant (``()`` for a scalar).
    sdims: tuple[int, ...]
@dataclass(frozen=True, slots=True)
class _ParamLeaf(_StructNode):
    """Leaf node: a block of learnable parameters."""

    # Specification of the parameter block.
    param_spec: ParamSpec
    # Per-axis sizes of the parameter block.
    sdims: tuple[int, ...]
# --- Layout-engine operators ------------------------------------------
@dataclass(frozen=True, slots=True)
class _DiffOpNode(_StructNode):
    """Engine-specific differential operator (grad, lap, div, ...)."""

    # Operator identifier.
    op_name: str
    # Expression the operator is applied to.
    child: _StructNode
    # Engine payload; the algebra layer never inspects it.
    engine_meta: Any
@dataclass(frozen=True, slots=True)
class _InteractionNode(_StructNode):
    """Engine-specific K-body interaction."""

    # Callable implementing the interaction.
    fn: Callable
    # What the interaction reads (schema not visible in this module).
    reads: tuple
    # What the interaction writes (schema not visible in this module).
    writes: Any
    # Number of bodies taking part in the interaction.
    arity: int
    # NOTE(review): presumably builds the interaction spec; confirm with the engine.
    spec_factory: Any
    # Engine payload; the algebra layer never inspects it.
    engine_meta: Any
# --- Pure algebra (engine-agnostic) -----------------------------------
@dataclass(frozen=True, slots=True)
class _BinaryOp(_StructNode):
    """Engine-agnostic binary arithmetic node (``+ - * / **``)."""

    # Operator symbol: "+", "-", "*", "/" or "**".
    op: str
    # Left-hand operand.
    left: _StructNode
    # Right-hand operand.
    right: _StructNode
@dataclass(frozen=True, slots=True)
class _EinsumOp(_StructNode):
    """Einstein summation over the rank axes of its children."""

    # Subscript string, e.g. 'ij,jk->ik'
    spec: str
    # Operands, in the order they appear in ``spec``.
    children: tuple[_StructNode, ...]
@dataclass(frozen=True, slots=True)
class _ConcatOp(_StructNode):
    """Concatenation of children along the feature axis (the ``&`` operator)."""

    # Concatenated operands, left to right.
    children: tuple[_StructNode, ...]
@dataclass(frozen=True, slots=True)
class _StackOp(_StructNode):
    """Stacking of children along a new leading rank axis."""

    # Stacked operands, in order along the new axis.
    children: tuple[_StructNode, ...]
    # Size of the new leading axis.
    sdim: int
@dataclass(frozen=True, slots=True)
class _SliceOp(_StructNode):
    """Feature selection, produced by ``expr[idx]``."""

    # Expression whose features are selected.
    child: _StructNode
    # Selector: int | slice | tuple[int, ...]
    idx: Any
@dataclass(frozen=True, slots=True)
class _UnaryOp(_StructNode):
    """Unary node: negation, transpose, or an elementwise map."""

    # Operation tag: "neg" | "T" | "ew"
    op: str
    # Operand.
    child: _StructNode
    # Function to apply; only meaningful for "ew".
    fn: Callable | None = None
@dataclass(frozen=True, slots=True)
class _ReshapeOp(_StructNode):
    """Reshape that moves sizes between rank axes and the feature axis."""

    # Expression being reshaped.
    child: _StructNode
    # Rank-axis sizes after the reshape (``()`` = everything in features).
    target_sdims: tuple[int, ...]
@dataclass(frozen=True, slots=True)
class _DenseOp(_StructNode):
    """Learnable affine layer acting on the feature axis."""

    # Input expression.
    child: _StructNode
    # Number of output features.
    n_out: int
    # Specification of the layer's parameter block.
    param_spec: ParamSpec
# =====================================================================
# Helpers
# =====================================================================
def _merge_params(
    a: ParamSuite | None,
    b: ParamSuite | None,
) -> ParamSuite | None:
    """Combine two optional parameter suites.

    When either side is ``None`` the other one is returned unchanged (so
    ``None`` when both are); otherwise the result is ``a.merge(b)``.
    """
    if a is not None and b is not None:
        return a.merge(b)
    return b if a is None else a
def _resolve_layout(a: int, b: int) -> int:
    """Return the layout id shared by two expressions.

    ``_FREE_LAYOUT`` on either side defers to the other id. Two concrete
    ids must be equal.

    Raises
    ------
    TypeError
        If both ids are concrete and differ.
    """
    if _FREE_LAYOUT in (a, b):
        return b if a == _FREE_LAYOUT else a
    if a == b:
        return a
    raise TypeError(f"Cannot combine expressions from different Layouts (layout ids {a} vs {b})")
def _parse_einsum(spec: str) -> tuple[list[str], str]:
    """Parse ``'ij,jk->ik'`` into ``(['ij', 'jk'], 'ik')``.

    Whitespace anywhere in *spec* is ignored, so ``'ij, jk -> ik'`` parses
    the same way. Without this, a stray space would count as an axis letter
    and surface later as a misleading axis-count error.

    Parameters
    ----------
    spec : str
        Einsum subscript string in explicit form (must contain ``'->'``).

    Returns
    -------
    tuple[list[str], str]
        Per-operand subscript strings and the output subscript string.

    Raises
    ------
    ValueError
        If *spec* has no ``'->'``.
    """
    # str.split() with no argument splits on any run of whitespace.
    compact = "".join(spec.split())
    if "->" not in compact:
        raise ValueError(f"Einsum spec must contain '->': {spec!r}")
    lhs, rhs = compact.split("->", 1)
    return lhs.split(","), rhs
def _validate_einsum(
    operand_specs: list[str],
    rhs: str,
    operands: Sequence[StructuredExpr],
) -> tuple[int, ...]:
    """Check an einsum spec against its operands and derive the output sdims.

    Every operand subscript must have one letter per rank axis, every letter
    must map to a single size across all operands, and every output letter
    must occur in some input.

    Returns
    -------
    tuple[int, ...]
        Size of each output axis, in the order of *rhs*.

    Raises
    ------
    ValueError
        On an operand-count mismatch, an axis-count mismatch, conflicting
        sizes for one letter, or an unknown output letter.
    """
    n_specs, n_exprs = len(operand_specs), len(operands)
    if n_specs != n_exprs:
        raise ValueError(f"Einsum spec has {n_specs} operands, got {n_exprs} expressions")

    sizes: dict[str, int] = {}
    for letters, operand in zip(operand_specs, operands):
        if len(letters) != operand.srank:
            raise ValueError(
                f"Einsum operand spec '{letters}' has {len(letters)} axes, "
                f"but expression has srank={operand.srank} (sdims={operand.sdims})"
            )
        for axis, extent in zip(letters, operand.sdims):
            # setdefault records the first size seen and returns it afterwards.
            known = sizes.setdefault(axis, extent)
            if known != extent:
                raise ValueError(f"Einsum letter '{axis}' has conflicting sizes: {known} vs {extent}")

    unknown = next((axis for axis in rhs if axis not in sizes), None)
    if unknown is not None:
        raise ValueError(f"Einsum output letter '{unknown}' not found in any input operand")
    return tuple(sizes[axis] for axis in rhs)
def _coerce_scalar(value: Any) -> StructuredExpr | None:
    """Wrap a scalar as a layout-agnostic constant expression.

    Accepts Python ``int``/``float`` values and any object whose ``shape``
    attribute equals ``()``. Anything else yields ``None`` so the calling
    operator can return ``NotImplemented``.
    """
    is_py_number = isinstance(value, (int, float))
    if not is_py_number and not (hasattr(value, "shape") and getattr(value, "shape") == ()):
        return None
    return StructuredExpr(
        sdims=(),
        n_features=1,
        param_suite=None,
        labels=(),
        _layout_id=_FREE_LAYOUT,
        _node=_ConstNode(value=value, sdims=()),
    )
# =====================================================================
# Auto-label helpers
# =====================================================================
# Translation table mapping ASCII digits and "-" to their unicode superscript
# forms; used by ``_int_superscript`` to render exponents in labels.
_SUPERSCRIPT_DIGITS = str.maketrans("0123456789-", "\u2070\u00b9\u00b2\u00b3\u2074\u2075\u2076\u2077\u2078\u2079\u207b")
# Characters in a label that signal it is "compound" and needs
# parenthesisation when used as a factor in a product
# ("+", "-", "/" and the middle dot that einsum labels are joined with).
_MUL_PAREN_CHARS = frozenset("+-/\u00b7")
def _int_superscript(n: int) -> str:
    """Render *n* with unicode superscript digits (and superscript minus)."""
    plain = str(n)
    return plain.translate(_SUPERSCRIPT_DIGITS)
def _complete_labels(labels: tuple[str, ...], n_features: int) -> bool:
    """Tell whether *labels* names every feature.

    That requires at least one feature, exactly one label per feature, and
    no empty label.
    """
    if n_features <= 0 or len(labels) != n_features:
        return False
    return all(labels)
def _is_const_value(node: _StructNode, value: float | int) -> bool:
    """Tell whether *node* is a ``_ConstNode`` holding the Python number *value*.

    Array-valued constants never match, even if they compare equal.
    """
    if isinstance(node, _ConstNode):
        payload = node.value
        return isinstance(payload, (int, float)) and payload == value
    return False
def _const_int(node: _StructNode) -> int | None:
    """Return the integer held by a ``_ConstNode``, or ``None``.

    Python ints are returned as-is. Floats are accepted only when they are
    integral and smaller than 1000 in magnitude. Anything else yields
    ``None``: other node types, array values, non-integral floats, and
    ``inf``/``nan``.
    """
    if not isinstance(node, _ConstNode):
        return None
    v = node.value
    if isinstance(v, int):
        return v
    # float.is_integer() is False for inf and nan, whereas int(v) would raise
    # OverflowError / ValueError and crash label generation for e.g. x ** inf.
    if isinstance(v, float) and v.is_integer() and abs(v) < 1000:
        return int(v)
    return None
def _paren_for_pow(label: str) -> str:
    """Parenthesise *label* before attaching an exponent.

    Only labels longer than one character are wrapped.
    """
    if len(label) <= 1:
        return label
    return f"({label})"
def _paren_for_mul(label: str) -> str:
    """Parenthesise *label* when it is compound, for use as a factor.

    A label is compound when it contains any character of
    ``_MUL_PAREN_CHARS``; a single leading ``-`` is ignored for that test.
    """
    body = label.removeprefix("-")
    if _MUL_PAREN_CHARS.isdisjoint(body):
        return label
    return f"({label})"
def _auto_mul_labels(
    a_labels: tuple[str, ...],
    a_nf: int,
    a_node: _StructNode,
    b_labels: tuple[str, ...],
    b_nf: int,
    b_node: _StructNode,
) -> tuple[str, ...]:
    """Derive feature labels for the element-wise product ``a * b``.

    Multiplying a fully labelled operand by the constant 1 keeps its labels.
    Two fully labelled operands give one juxtaposed label per ``(a, b)``
    pair, with ``a`` varying slowest. Every other case yields ``()``.
    """
    a_ok = _complete_labels(a_labels, a_nf)
    b_ok = _complete_labels(b_labels, b_nf)
    if a_ok and _is_const_value(b_node, 1):
        return a_labels
    if b_ok and _is_const_value(a_node, 1):
        return b_labels
    if not (a_ok and b_ok):
        return ()
    right = [_paren_for_mul(lb) for lb in b_labels]
    return tuple(_paren_for_mul(la) + rb for la in a_labels for rb in right)
def _auto_pow_labels(
    base_labels: tuple[str, ...],
    base_nf: int,
    exp_node: _StructNode,
) -> tuple[str, ...]:
    """Derive feature labels for ``base ** exp``.

    Labels are produced only for a constant integer exponent applied to a
    fully labelled base: exponent 0 gives ``"1"``, exponent 1 keeps the
    labels, and anything else appends a unicode superscript.
    """
    power = _const_int(exp_node)
    if power is None or not _complete_labels(base_labels, base_nf):
        return ()
    if power == 0:
        return ("1",) * len(base_labels)
    if power == 1:
        return base_labels
    suffix = _int_superscript(power)
    return tuple(_paren_for_pow(lab) + suffix for lab in base_labels)
def _auto_sub_labels(
    a_labels: tuple[str, ...],
    a_nf: int,
    b_labels: tuple[str, ...],
    b_nf: int,
) -> tuple[str, ...]:
    """Derive feature labels for ``a - b``.

    Yields ``"a-b"`` per feature when both operands are fully labelled and
    have the same feature count, else ``()``. A compound subtrahend is
    parenthesised.
    """
    both_labelled = _complete_labels(a_labels, a_nf) and _complete_labels(b_labels, b_nf)
    if not both_labelled or a_nf != b_nf:
        return ()
    return tuple(f"{la}-{_paren_for_mul(lb)}" for la, lb in zip(a_labels, b_labels))
def _auto_div_labels(
    a_labels: tuple[str, ...],
    a_nf: int,
    b_labels: tuple[str, ...],
    b_nf: int,
) -> tuple[str, ...]:
    """Derive feature labels for ``a / b``.

    Both operands must be fully labelled. A single-feature divisor is
    applied to every numerator feature; otherwise the feature counts must
    match and features are paired up. Any other case yields ``()``.
    """
    if not (_complete_labels(a_labels, a_nf) and _complete_labels(b_labels, b_nf)):
        return ()
    if b_nf == 1:
        # Repeat the lone divisor label once per numerator feature.
        denominators = (b_labels[0],) * a_nf
    elif a_nf == b_nf:
        denominators = b_labels
    else:
        return ()
    return tuple(f"{_paren_for_mul(num)}/{_paren_for_mul(den)}" for num, den in zip(a_labels, denominators))
def _auto_add_labels(
    a_labels: tuple[str, ...],
    a_nf: int,
    b_labels: tuple[str, ...],
    b_nf: int,
) -> tuple[str, ...]:
    """Derive feature labels for ``a + b``.

    Two fully labelled operands with equal feature counts give ``"a+b"``
    per feature. Otherwise the first non-empty label tuple is kept.
    """
    pairable = a_nf == b_nf and _complete_labels(a_labels, a_nf) and _complete_labels(b_labels, b_nf)
    if not pairable:
        return a_labels or b_labels
    return tuple(f"{la}+{lb}" for la, lb in zip(a_labels, b_labels))
def _auto_einsum_labels(
    operand_info: Sequence[tuple[tuple[str, ...], int]],
) -> tuple[str, ...]:
    """Derive feature labels for an einsum of several operands.

    *operand_info* holds one ``(labels, n_features)`` pair per operand. If
    every operand is fully labelled, the result is the Cartesian product of
    the label tuples, each combination joined with a middle dot; otherwise
    it is ``()``.
    """
    if not all(_complete_labels(lbls, nf) for lbls, nf in operand_info):
        return ()
    per_operand = [lbls for lbls, _ in operand_info]
    return tuple("\u00b7".join(combo) for combo in itertools.product(*per_operand))
def _auto_ew_labels(
    fn: Callable,
    labels: tuple[str, ...],
    n_features: int,
    name: str | None = None,
) -> tuple[str, ...]:
    """Derive feature labels for an elementwise map.

    Each label becomes ``"name(label)"``, where the name is *name* if given
    and ``fn.__name__`` otherwise. Yields ``()`` when the input is not fully
    labelled or no usable name exists (missing, empty, or ``"<lambda>"``).
    """
    if not _complete_labels(labels, n_features):
        return ()
    display = name or getattr(fn, "__name__", "")
    if display and display != "<lambda>":
        return tuple(f"{display}({lab})" for lab in labels)
    return ()
def _feature_count(idx: Any, n_features: int) -> int:
    """Return how many features ``expr[idx]`` selects.

    Parameters
    ----------
    idx : int | slice | list[int] | tuple[int, ...]
        Feature selector. Negative ints count from the end.
    n_features : int
        Feature count of the expression being indexed.

    Raises
    ------
    IndexError
        If an int (alone or inside a list/tuple) is out of range.
    TypeError
        If a list/tuple entry is not an int, or *idx* has an unsupported type.
    """

    def _in_range(i: int) -> bool:
        return -n_features <= i < n_features

    if isinstance(idx, int):
        if not _in_range(idx):
            raise IndexError(f"Feature index {idx} out of range for n_features={n_features}")
        return 1
    if isinstance(idx, slice):
        start, stop, step = idx.indices(n_features)
        return len(range(start, stop, step))
    if isinstance(idx, (list, tuple)):
        for entry in idx:
            if not isinstance(entry, int):
                raise TypeError(f"Feature indices must be ints, got {type(entry).__name__}")
            if not _in_range(entry):
                raise IndexError(f"Feature index {entry} out of range for n_features={n_features}")
        return len(idx)
    raise TypeError(f"Unsupported index type: {type(idx).__name__}")
# =====================================================================
# StructuredExpr
# =====================================================================
[docs]
@dataclass(frozen=True, slots=True, eq=False)
class StructuredExpr:
"""Declarative structured tensor expression (inner world).
Users build these via Layout methods and algebra. They are symbolic
and cannot be evaluated directly — use ``Layout.embed()`` to compile
to an outer-world ``StateExpr``.
Instances are immutable (``frozen=True``); every operation returns a new
expression. ``eq=False`` keeps the default identity-based ``==`` and hash
instead of field-wise comparison.
Attributes
----------
sdims : tuple[int, ...]
Per-axis sizes. ``()`` = scalar, ``(2,)`` = 2-vector, etc.
n_features : int
Number of independent regression channels (last axis).
param_suite : ParamSuite | None
Learnable parameters (``None`` = pure basis).
labels : tuple[str, ...]
Human-readable feature labels. May be empty, or shorter than
``n_features`` after concatenating with an unlabelled expression.
_layout_id : int
Id of the owning Layout, or ``_FREE_LAYOUT`` for layout-agnostic
expressions such as constants.
_node : _StructNode
Root of the IR tree that encodes this expression.
"""
sdims: tuple[int, ...]
n_features: int
param_suite: ParamSuite | None
labels: tuple[str, ...]
# Private: layout compatibility is checked through ``_resolve_layout``.
_layout_id: int
# Private: IR root.
_node: _StructNode
# --- properties ---------------------------------------------------
@property
def srank(self) -> int:
    """Structured rank: how many rank axes ``sdims`` lists."""
    axes = self.sdims
    return len(axes)
# --- private helpers ----------------------------------------------
def _new(
    self,
    *,
    sdims: tuple[int, ...] | None = None,
    n_features: int | None = None,
    param_suite: Any = _MISSING,
    labels: tuple[str, ...] = (),
    layout_id: int | None = None,
    node: _StructNode,
) -> StructuredExpr:
    """Build a derived expression, inheriting metadata from *self*.

    ``sdims``, ``n_features`` and ``layout_id`` are inherited when left as
    ``None``. ``param_suite`` is inherited only when omitted, so an explicit
    ``None`` clears it. ``labels`` is never inherited and defaults to ``()``.
    """
    resolved_sdims = self.sdims if sdims is None else sdims
    resolved_nf = self.n_features if n_features is None else n_features
    resolved_params = self.param_suite if param_suite is _MISSING else param_suite
    resolved_layout = self._layout_id if layout_id is None else layout_id
    return StructuredExpr(
        sdims=resolved_sdims,
        n_features=resolved_nf,
        param_suite=resolved_params,
        labels=labels,
        _layout_id=resolved_layout,
        _node=node,
    )
def _check_layout(self, other: StructuredExpr) -> None:
    """Raise ``TypeError`` if *other* belongs to an incompatible Layout."""
    mine, theirs = self._layout_id, other._layout_id
    _resolve_layout(mine, theirs)
# =================================================================
# Arithmetic: + - * / ** (unary -)
# =================================================================
def __add__(self, other: Any) -> StructuredExpr:
    """Element-wise sum ``self + other``.

    Scalars are coerced to constant expressions; anything else that cannot
    be coerced yields ``NotImplemented``. Both operands must share
    ``sdims`` and ``n_features``.

    Raises
    ------
    TypeError
        If the operands come from different Layouts.
    ValueError
        If ``sdims`` or ``n_features`` differ.
    """
    rhs = other if isinstance(other, StructuredExpr) else _coerce_scalar(other)
    if rhs is None:
        return NotImplemented  # type: ignore[return-value]
    layout_id = _resolve_layout(self._layout_id, rhs._layout_id)
    if self.sdims != rhs.sdims:
        raise ValueError(f"Cannot add: sdims mismatch {self.sdims} vs {rhs.sdims}")
    if self.n_features != rhs.n_features:
        raise ValueError(f"Cannot add: n_features mismatch {self.n_features} vs {rhs.n_features}")
    merged = _merge_params(self.param_suite, rhs.param_suite)
    labels = _auto_add_labels(self.labels, self.n_features, rhs.labels, rhs.n_features)
    return StructuredExpr(
        sdims=self.sdims,
        n_features=self.n_features,
        param_suite=merged,
        labels=labels,
        _layout_id=layout_id,
        _node=_BinaryOp("+", self._node, rhs._node),
    )
def __radd__(self, other: Any) -> StructuredExpr:
    """Reflected sum ``other + self``.

    Addition is commutative, so this delegates to ``__add__``; the IR node
    therefore records *self* as the left operand.
    """
    result = self.__add__(other)
    return result
def __sub__(self, other: Any) -> StructuredExpr:
    """Element-wise difference ``self - other``.

    Scalars are coerced to constant expressions; anything else that cannot
    be coerced yields ``NotImplemented``. Both operands must share
    ``sdims`` and ``n_features``.

    Raises
    ------
    TypeError
        If the operands come from different Layouts.
    ValueError
        If ``sdims`` or ``n_features`` differ.
    """
    rhs = other if isinstance(other, StructuredExpr) else _coerce_scalar(other)
    if rhs is None:
        return NotImplemented  # type: ignore[return-value]
    layout_id = _resolve_layout(self._layout_id, rhs._layout_id)
    if self.sdims != rhs.sdims:
        raise ValueError(f"Cannot subtract: sdims mismatch {self.sdims} vs {rhs.sdims}")
    if self.n_features != rhs.n_features:
        raise ValueError(f"Cannot subtract: n_features mismatch {self.n_features} vs {rhs.n_features}")
    merged = _merge_params(self.param_suite, rhs.param_suite)
    labels = _auto_sub_labels(self.labels, self.n_features, rhs.labels, rhs.n_features)
    return StructuredExpr(
        sdims=self.sdims,
        n_features=self.n_features,
        param_suite=merged,
        labels=labels,
        _layout_id=layout_id,
        _node=_BinaryOp("-", self._node, rhs._node),
    )
def __rsub__(self, other: Any) -> StructuredExpr:
    """Reflected difference ``other - self`` for a scalar *other*.

    The scalar is coerced to a constant expression and used as the left
    operand; non-scalars yield ``NotImplemented``.
    """
    lhs = _coerce_scalar(other)
    return NotImplemented if lhs is None else lhs.__sub__(self)  # type: ignore[return-value]
def __mul__(self, other: Any) -> StructuredExpr:
    """Element-wise product ``self * other``.

    The result has ``self.n_features * other.n_features`` features. A
    rank-0 operand takes on the other operand's ``sdims``; otherwise both
    ``sdims`` must be equal. Scalars are coerced to constants; other
    non-expressions yield ``NotImplemented``.

    Raises
    ------
    TypeError
        If the operands come from different Layouts or have incompatible
        ``sdims``.
    """
    rhs = other if isinstance(other, StructuredExpr) else _coerce_scalar(other)
    if rhs is None:
        return NotImplemented  # type: ignore[return-value]
    layout_id = _resolve_layout(self._layout_id, rhs._layout_id)
    if self.srank == 0:
        out_sdims = rhs.sdims
    elif rhs.srank == 0 or self.sdims == rhs.sdims:
        out_sdims = self.sdims
    else:
        raise TypeError(
            f"Cannot multiply expressions with incompatible sdims "
            f"{self.sdims} and {rhs.sdims}. Use einsum() for mixed-rank "
            f"contractions."
        )
    labels = _auto_mul_labels(
        self.labels, self.n_features, self._node,
        rhs.labels, rhs.n_features, rhs._node,
    )
    return StructuredExpr(
        sdims=out_sdims,
        n_features=self.n_features * rhs.n_features,
        param_suite=_merge_params(self.param_suite, rhs.param_suite),
        labels=labels,
        _layout_id=layout_id,
        _node=_BinaryOp("*", self._node, rhs._node),
    )
def __rmul__(self, other: Any) -> StructuredExpr:
    """Reflected product ``other * self``.

    Multiplication is commutative, so this delegates to ``__mul__``; the IR
    node therefore records *self* as the left operand.
    """
    result = self.__mul__(other)
    return result
def __truediv__(self, other: Any) -> StructuredExpr:
    """Element-wise quotient ``self / other``.

    The divisor must have either one feature or as many as *self*, and
    must be rank-0 or share ``sdims`` with *self*. The result keeps
    ``self.sdims`` and ``self.n_features``. Scalars are coerced to
    constants; other non-expressions yield ``NotImplemented``.

    Raises
    ------
    TypeError
        If the operands come from different Layouts.
    ValueError
        If the divisor's feature count or ``sdims`` is incompatible.
    """
    divisor = other if isinstance(other, StructuredExpr) else _coerce_scalar(other)
    if divisor is None:
        return NotImplemented  # type: ignore[return-value]
    layout_id = _resolve_layout(self._layout_id, divisor._layout_id)
    if divisor.n_features != 1 and divisor.n_features != self.n_features:
        raise ValueError(
            f"Division requires divisor n_features=1 or matching "
            f"n_features={self.n_features}, got {divisor.n_features}"
        )
    if divisor.srank > 0 and divisor.sdims != self.sdims:
        raise ValueError(f"Cannot divide: sdims mismatch {self.sdims} vs {divisor.sdims}")
    labels = _auto_div_labels(self.labels, self.n_features, divisor.labels, divisor.n_features)
    return StructuredExpr(
        sdims=self.sdims,
        n_features=self.n_features,
        param_suite=_merge_params(self.param_suite, divisor.param_suite),
        labels=labels,
        _layout_id=layout_id,
        _node=_BinaryOp("/", self._node, divisor._node),
    )
def __rtruediv__(self, other: Any) -> StructuredExpr:
    """Reflected quotient ``other / self`` for a scalar *other*.

    The scalar is coerced to a constant expression and used as the
    numerator; non-scalars yield ``NotImplemented``.
    """
    numerator = _coerce_scalar(other)
    return NotImplemented if numerator is None else numerator.__truediv__(self)  # type: ignore[return-value]
def __pow__(self, other: Any) -> StructuredExpr:
    """Element-wise power ``self ** other``.

    The exponent must be a rank-0, single-feature expression (or a scalar,
    which is coerced). The result keeps the base's ``sdims`` and
    ``n_features``.

    Raises
    ------
    TypeError
        If the operands come from different Layouts, or the exponent is
        not a single-feature scalar.
    """
    exponent = other if isinstance(other, StructuredExpr) else _coerce_scalar(other)
    if exponent is None:
        return NotImplemented  # type: ignore[return-value]
    layout_id = _resolve_layout(self._layout_id, exponent._layout_id)
    if not (exponent.srank == 0 and exponent.n_features == 1):
        raise TypeError(
            f"Exponent must be a scalar with n_features=1, got sdims={exponent.sdims}, n_features={exponent.n_features}"
        )
    return StructuredExpr(
        sdims=self.sdims,
        n_features=self.n_features,
        param_suite=_merge_params(self.param_suite, exponent.param_suite),
        labels=_auto_pow_labels(self.labels, self.n_features, exponent._node),
        _layout_id=layout_id,
        _node=_BinaryOp("**", self._node, exponent._node),
    )
def __neg__(self) -> StructuredExpr:
    """Negation ``-self``; all metadata, including labels, is kept as-is."""
    return StructuredExpr(
        sdims=self.sdims,
        n_features=self.n_features,
        param_suite=self.param_suite,
        labels=self.labels,
        _layout_id=self._layout_id,
        _node=_UnaryOp("neg", self._node),
    )
def __pos__(self) -> StructuredExpr:
    """Unary plus ``+self``: a no-op that returns the same object."""
    return self
# =================================================================
# Human-readable label
# =================================================================
[docs]
def with_label(self, label: str) -> StructuredExpr:
    """Return the same single-feature expression under a new label.

    Handy for naming derived quantities (arithmetic, einsum, ...) so that
    ``print_report`` shows a meaningful term name instead of a generic
    fallback.

    Parameters
    ----------
    label : str
        Human-readable name for this term, e.g. ``"|Q|²Q"``.

    Returns
    -------
    StructuredExpr
        Expression sharing this one's IR node and metadata, with
        ``labels=(label,)``.

    Raises
    ------
    ValueError
        If the expression has more than one feature (or none).
    """
    if self.n_features == 1:
        # _new inherits sdims, n_features, param_suite and layout from self.
        return self._new(labels=(label,), node=self._node)
    raise ValueError(
        f"with_label() requires n_features=1, "
        f"got n_features={self.n_features}. "
        "Use the & operator to concatenate labelled single-feature "
        "terms instead of labelling a multi-feature block."
    )
# =================================================================
# Feature concatenation (&)
# =================================================================
def __and__(self, other: Any) -> StructuredExpr:
    """Feature concatenation ``self & other``.

    Both operands must share ``sdims``. Feature counts add up and the label
    tuples are joined end to end. Non-expressions yield ``NotImplemented``.

    Raises
    ------
    TypeError
        If the operands come from different Layouts.
    ValueError
        If ``sdims`` differ.
    """
    if not isinstance(other, StructuredExpr):
        return NotImplemented  # type: ignore[return-value]
    layout_id = _resolve_layout(self._layout_id, other._layout_id)
    if self.sdims != other.sdims:
        raise ValueError(f"Cannot concatenate (&): sdims mismatch {self.sdims} vs {other.sdims}")
    # Flatten a concat on the left so that ``a & b & c`` yields one node
    # with three children, keeping labels[i] aligned with children[i]
    # (avoids misalignment when _compile_sector iterates node.children).
    left_children = self._node.children if isinstance(self._node, _ConcatOp) else (self._node,)
    return StructuredExpr(
        sdims=self.sdims,
        n_features=self.n_features + other.n_features,
        param_suite=_merge_params(self.param_suite, other.param_suite),
        labels=self.labels + other.labels,
        _layout_id=layout_id,
        _node=_ConcatOp(left_children + (other._node,)),
    )
# =================================================================
# Feature selection (expr[idx])
# =================================================================
def __getitem__(self, idx: Any) -> StructuredExpr:
    """Feature selection ``expr[idx]``.

    Parameters
    ----------
    idx : int | slice | list[int] | tuple[int, ...]
        Which features to keep. Negative ints count from the end.

    Returns
    -------
    StructuredExpr
        Expression with the selected features. Labels are carried over for
        int and slice selectors when every selected feature has one;
        list/tuple selectors always drop labels.

    Raises
    ------
    IndexError, TypeError
        As raised by ``_feature_count`` for invalid selectors.
    """
    n = _feature_count(idx, self.n_features)
    n_labels = len(self.labels)
    if isinstance(idx, int):
        # Resolve negatives against n_features, not len(labels): with no
        # labels, ``expr[-1]`` used to index an empty tuple and crash; with
        # partial labels it picked the wrong label.
        pos = idx + self.n_features if idx < 0 else idx
        lbls = (self.labels[pos],) if pos < n_labels else ()
    elif isinstance(idx, slice):
        # Same alignment rule: enumerate feature positions first, then keep
        # labels only if each selected position actually has one.
        positions = range(*idx.indices(self.n_features))
        if all(p < n_labels for p in positions):
            lbls = tuple(self.labels[p] for p in positions)
        else:
            lbls = ()
    else:
        lbls = ()
    return self._new(
        n_features=n,
        labels=lbls,
        node=_SliceOp(self._node, idx),
    )
# =================================================================
# Transpose
# =================================================================
@property
def T(self) -> StructuredExpr:
    """Transpose: swap the last two rank axes.

    Labels are not carried over to the result.

    Raises
    ------
    TypeError
        If ``srank < 2``.
    """
    if self.srank < 2:
        raise TypeError(f"Transpose requires srank >= 2, got srank={self.srank} (sdims={self.sdims})")
    *leading, rows, cols = self.sdims
    return StructuredExpr(
        sdims=(*leading, cols, rows),
        n_features=self.n_features,
        param_suite=self.param_suite,
        labels=(),
        _layout_id=self._layout_id,
        _node=_UnaryOp("T", self._node),
    )
# =================================================================
# Matmul (@)
# =================================================================
def __matmul__(self, other: Any) -> StructuredExpr:
    """Matrix product ``self @ other``.

    Contracts the last rank axis of *self* with the first rank axis of
    *other* by building an einsum spec and delegating to ``einsum``.

    Raises
    ------
    TypeError
        If either operand has ``srank < 1``.
    ValueError
        If the contracted axes differ in size.
    """
    if not isinstance(other, StructuredExpr):
        return NotImplemented  # type: ignore[return-value]
    if self.srank < 1 or other.srank < 1:
        raise TypeError(f"Matmul requires srank >= 1 on both sides, got {self.srank} and {other.srank}")
    if self.sdims[-1] != other.sdims[0]:
        raise ValueError(
            f"Matmul contraction mismatch: self.sdims[-1]={self.sdims[-1]} vs other.sdims[0]={other.sdims[0]}"
        )
    letters = iter(_EINSUM_LETTERS)
    left = "".join(next(letters) for _ in range(self.srank))
    # The right operand reuses the left's last letter as its first axis.
    right = left[-1] + "".join(next(letters) for _ in range(other.srank - 1))
    spec = f"{left},{right}->{left[:-1]}{right[1:]}"
    return type(self).einsum(spec, self, other)
# =================================================================
# Dot (contract last axes of both)
# =================================================================
[docs]
def dot(self, other: StructuredExpr) -> StructuredExpr:
    """Contract the last rank axis of *self* with the last one of *other*.

    The remaining axes of *self* come first in the result, followed by the
    remaining axes of *other*. Implemented by delegating to ``einsum``.

    Raises
    ------
    TypeError
        If *other* is not a ``StructuredExpr`` or an operand has
        ``srank < 1``.
    ValueError
        If the contracted axes differ in size.
    """
    if not isinstance(other, StructuredExpr):
        raise TypeError(f"Expected StructuredExpr, got {type(other).__name__}")
    if self.srank < 1 or other.srank < 1:
        raise TypeError(f"dot requires srank >= 1 on both sides, got {self.srank} and {other.srank}")
    if self.sdims[-1] != other.sdims[-1]:
        raise ValueError(
            f"dot contraction mismatch: self.sdims[-1]={self.sdims[-1]} vs other.sdims[-1]={other.sdims[-1]}"
        )
    letters = iter(_EINSUM_LETTERS)
    left_free = "".join(next(letters) for _ in range(self.srank - 1))
    right_free = "".join(next(letters) for _ in range(other.srank - 1))
    contracted = next(letters)
    spec = f"{left_free}{contracted},{right_free}{contracted}->{left_free}{right_free}"
    return type(self).einsum(spec, self, other)
# =================================================================
# Einsum (static method)
# =================================================================
[docs]
@staticmethod
def einsum(spec: str, *operands: StructuredExpr) -> StructuredExpr:
    """Einstein summation over rank axes.

    The output feature count is the product of the operands' feature
    counts; parameter suites are merged and layouts must be compatible.

    Example::

        Q = StructuredExpr.einsum('i,j->ij', n, n)

    Raises
    ------
    ValueError
        If *spec* is malformed or inconsistent with the operands.
    TypeError
        If the operands come from different Layouts.
    """
    input_specs, output_spec = _parse_einsum(spec)
    out_sdims = _validate_einsum(input_specs, output_spec, operands)

    shared_layout = _FREE_LAYOUT
    merged: ParamSuite | None = None
    for operand in operands:
        shared_layout = _resolve_layout(shared_layout, operand._layout_id)
        merged = _merge_params(merged, operand.param_suite)

    label_info = [(operand.labels, operand.n_features) for operand in operands]
    return StructuredExpr(
        sdims=out_sdims,
        n_features=math.prod(operand.n_features for operand in operands),
        param_suite=merged,
        labels=_auto_einsum_labels(label_info),
        _layout_id=shared_layout,
        _node=_EinsumOp(spec, tuple(operand._node for operand in operands)),
    )
# =================================================================
# Stack (classmethod — build vector from scalars)
# =================================================================
[docs]
@classmethod
def stack(
    cls,
    exprs: Sequence[StructuredExpr],
    *,
    sdim: int | None = None,
) -> StructuredExpr:
    """Stack expressions along a new leading rank axis.

    All inputs must share the same ``sdims`` and ``n_features``. ``sdim``
    defaults to ``len(exprs)`` and, if given, must equal it. The result
    carries no labels.

    Raises
    ------
    ValueError
        If *exprs* is empty, *sdim* disagrees with its length, or the
        inputs differ in ``sdims`` or ``n_features``.
    TypeError
        If the inputs come from different Layouts.
    """
    items = list(exprs)
    if not items:
        raise ValueError("stack requires at least one expression")
    count = len(items)
    if sdim is not None and sdim != count:
        raise ValueError(f"sdim={sdim} does not match number of expressions ({count})")
    sdim = count

    first, rest = items[0], items[1:]
    shared_layout = first._layout_id
    merged = first.param_suite
    for item in rest:
        shared_layout = _resolve_layout(shared_layout, item._layout_id)
        if item.sdims != first.sdims:
            raise ValueError(f"stack: all expressions must have same sdims, got {first.sdims} and {item.sdims}")
        if item.n_features != first.n_features:
            raise ValueError(
                f"stack: all expressions must have same n_features, got {first.n_features} and {item.n_features}"
            )
        merged = _merge_params(merged, item.param_suite)

    return StructuredExpr(
        sdims=(sdim,) + first.sdims,
        n_features=first.n_features,
        param_suite=merged,
        labels=(),
        _layout_id=shared_layout,
        _node=_StackOp(tuple(item._node for item in items), sdim),
    )
# =================================================================
# Eye (classmethod — identity matrix)
# =================================================================
[docs]
@classmethod
def eye(
    cls,
    sdim: int,
    *,
    layout_id: int = _FREE_LAYOUT,
) -> StructuredExpr:
    """Constant identity matrix expression.

    Returns an expression with ``sdims=(sdim, sdim)``, ``n_features=1``,
    label ``"I"`` and no parameters. It is layout-agnostic unless
    *layout_id* is given.
    """
    shape = (sdim, sdim)
    identity = _ConstNode(value=jnp.eye(sdim), sdims=shape)
    return StructuredExpr(
        sdims=shape,
        n_features=1,
        param_suite=None,
        labels=("I",),
        _layout_id=layout_id,
        _node=identity,
    )
# =================================================================
# Math convenience methods
# =================================================================
[docs]
def sin(self) -> StructuredExpr:
    """Elementwise sine (``jnp.sin``), labelled ``sin(...)``."""
    return self.elementwisemap(jnp.sin, name="sin")
[docs]
def cos(self) -> StructuredExpr:
    """Elementwise cosine (``jnp.cos``), labelled ``cos(...)``."""
    return self.elementwisemap(jnp.cos, name="cos")
[docs]
def exp(self) -> StructuredExpr:
    """Elementwise exponential (``jnp.exp``), labelled ``exp(...)``."""
    return self.elementwisemap(jnp.exp, name="exp")
[docs]
def log(self) -> StructuredExpr:
    """Elementwise natural logarithm (``jnp.log``), labelled ``log(...)``."""
    return self.elementwisemap(jnp.log, name="log")
[docs]
def tanh(self) -> StructuredExpr:
    """Elementwise hyperbolic tangent (``jnp.tanh``), labelled ``tanh(...)``."""
    return self.elementwisemap(jnp.tanh, name="tanh")
[docs]
def abs(self) -> StructuredExpr:
    """Elementwise absolute value (``jnp.abs``), labelled ``abs(...)``."""
    return self.elementwisemap(jnp.abs, name="abs")
[docs]
def sqrt(self) -> StructuredExpr:
    """Elementwise square root (``jnp.sqrt``), labelled ``sqrt(...)``."""
    return self.elementwisemap(jnp.sqrt, name="sqrt")
[docs]
def elementwisemap(
    self,
    fn: Callable,
    *,
    name: str | None = None,
) -> StructuredExpr:
    """Apply a JAX-traceable function elementwise.

    Records a ``_UnaryOp("ew", ...)`` node in the IR tree. At embed time
    this compiles to ``StateExpr.elementwisemap(fn)``. Shape, feature
    count, parameters and layout are unchanged.

    Parameters
    ----------
    fn : Callable
        Function to apply to every element.
    name : str, optional
        Override the function name used in auto-generated labels.
        Defaults to ``fn.__name__``.
    """
    mapped_labels = _auto_ew_labels(fn, self.labels, self.n_features, name)
    return StructuredExpr(
        sdims=self.sdims,
        n_features=self.n_features,
        param_suite=self.param_suite,
        labels=mapped_labels,
        _layout_id=self._layout_id,
        _node=_UnaryOp("ew", self._node, fn),
    )
# =================================================================
# Reshape operations
# =================================================================
[docs]
def rank_to_features(self) -> StructuredExpr:
    """Flatten all rank axes into the features axis.

    ``sdims=(2,3), n_features=5`` becomes ``sdims=(), n_features=30``.
    Labels are not carried over.
    """
    # math.prod of an empty tuple is 1, so a rank-0 input keeps its count.
    flattened = self.n_features * math.prod(self.sdims)
    return StructuredExpr(
        sdims=(),
        n_features=flattened,
        param_suite=self.param_suite,
        labels=(),
        _layout_id=self._layout_id,
        _node=_ReshapeOp(self._node, target_sdims=()),
    )
[docs]
def features_to_rank(self, target_sdims: tuple[int, ...]) -> StructuredExpr:
    """Promote features into rank axes. Requires ``srank == 0``.

    ``sdims=(), n_features=12`` with ``features_to_rank((3,4))`` becomes
    ``sdims=(3,4), n_features=1``. Labels are not carried over.

    Parameters
    ----------
    target_sdims : tuple[int, ...]
        Sizes of the new rank axes, all strictly positive. Any sequence is
        accepted and stored as a tuple.

    Raises
    ------
    TypeError
        If the expression already has rank axes.
    ValueError
        If a target size is not positive, or ``n_features`` is not
        divisible by the product of the target sizes.
    """
    if self.srank != 0:
        raise TypeError(
            f"features_to_rank requires srank == 0 (scalar input), "
            f"got srank={self.srank}. Use rank_to_features() first."
        )
    target = tuple(target_sdims)
    # Check each size rather than the product: (-2, -3) has a positive
    # product, and (-2, 3) would yield a negative n_features below.
    if any(size <= 0 for size in target):
        raise ValueError("target_sdims must have all positive sizes")
    p = math.prod(target)
    if self.n_features % p != 0:
        raise ValueError(f"n_features={self.n_features} is not divisible by prod(target_sdims)={p}")
    return self._new(
        sdims=target,
        n_features=self.n_features // p,
        node=_ReshapeOp(self._node, target_sdims=target),
    )
# =================================================================
# Dense (learnable affine layer on features)
# =================================================================
[docs]
def dense(self, n_out: int, *, name: str = "dense") -> StructuredExpr:
    """Learnable projection of the feature axis onto *n_out* features.

    Registers a parameter block called *name* with shape
    ``(n_out, n_features)`` and merges it into the expression's parameter
    suite. ``sdims`` and layout are unchanged; labels are dropped.
    """
    weight_spec = ParamSpec(name=name, shape=(n_out, self.n_features))
    merged = _merge_params(self.param_suite, ParamSuite([weight_spec]))
    return StructuredExpr(
        sdims=self.sdims,
        n_features=n_out,
        param_suite=merged,
        labels=(),
        _layout_id=self._layout_id,
        _node=_DenseOp(self._node, n_out, weight_spec),
    )
# =================================================================
# repr
# =================================================================
def __repr__(self) -> str:
    """Compact summary: sdims and n_features, plus labels/params if present."""
    summary = f"sdims={self.sdims}, n_features={self.n_features}"
    if self.labels:
        summary += f", labels={self.labels}"
    if self.param_suite is not None:
        summary += ", has_params=True"
    return f"StructuredExpr({summary})"