From 40e1b127599e9b305711711aa2819bfbc6f69f16 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Sat, 23 May 2026 02:55:52 -0700 Subject: [PATCH 01/11] [FieldArray] add qd.field_array(N, dtype) for indexed @qd.dataclass fields Adds a new ``qd.field_array(N, dtype)`` annotation for ``@qd.dataclass`` that exposes a logical N-element array as ``obj.r[i]`` while storing N individually- named synthetic scalar fields (``_r0..._r{N-1}``) under the hood. For python-int indices (including ``qd.static(range(N))``-unrolled loop variables), the AST transformer rewrites ``obj.r[i]`` directly to ``obj._r{i}``, so the generated LLVM IR / PTX is byte-identical to a hand-rolled named-field struct. Motivation: today's idiomatic ``r: qd.types.vector(N, dtype)`` group field leaves an alloca that LLVM SROA can't decompose once register pressure crosses a threshold (e.g. two concurrent tiles in a Cholesky+TRSM kernel), causing runtime regressions via local-memory spills. The named-field cascade pattern avoids the spill but balloons source size (32-way ``if k == N: self.rN = val`` write cascades duplicated at every callsite). ``field_array`` collapses those cascades to one AST node per callsite while preserving the named-field IR. Changes: - ``lang/struct.py``: ``FieldArray`` type wrapper, ``field_array(count, dtype)`` constructor, expansion in ``StructType.__init__`` (synthetic field names plus ``_field_groups`` metadata), propagation in ``StructType.__call__``, ``_FieldArrayRef`` transient proxy. - ``lang/impl.py``: preserve ``_qd_field_groups`` across the ``Struct`` rewrap in ``expr_init``. - ``lang/ast/ast_transformer.py``: ``build_Attribute`` returns a ``_FieldArrayRef`` for group access; ``build_Subscript`` resolves it to a direct field reference for python-int indices. - ``tests/python/test_field_array.py``: 5 tests covering construction, static python-int index, qd.static loop-var index, runtime-index rejection (clear error), and static-index OOB rejection. Runtime-int indexing is intentionally rejected with a friendly error pointing at ``qd.static``; existing cascade helpers continue to handle the runtime case by spelling out the ``_rN`` fields directly. Adding runtime-int support is a small follow-up. Verified on a field_array port of genesis ``_tile32.py``: PTX byte-identical to the named-field S1 baseline (modulo the per-session-nonce comment) on both ``chol_kernel`` and ``chol_trsm_kernel``; zero local-memory spills (S1: 0/0, FA: 0/0, F4-A vector-field variant: 42/97); 25% compile-time reduction on the single-tile harness (5.60s -> 4.19s, 3-run mean). Source dropped from 1068 to 515 lines (-52%). Full writeup in perso_hugh/doc/qd_field_array_2026may23.md. All 201 tests in test_py_dataclass.py + test_complex_struct.py + test_struct.py continue to pass; the 5 new tests pass in 1.76s total. --- python/quadrants/lang/ast/ast_transformer.py | 24 +- python/quadrants/lang/impl.py | 8 +- python/quadrants/lang/struct.py | 112 ++++++++- tests/python/test_field_array.py | 226 +++++++++++++++++++ 4 files changed, 367 insertions(+), 3 deletions(-) create mode 100644 tests/python/test_field_array.py diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 263a4a11a3..325b8e7d94 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -39,7 +39,7 @@ from quadrants.lang.field import Field from quadrants.lang.matrix import Matrix, MatrixType from quadrants.lang.snode import append, deactivate, length -from quadrants.lang.struct import Struct, StructType +from quadrants.lang.struct import Struct, StructType, _FieldArrayRef from quadrants.lang.util import ( is_from_quadrants_module as _is_from_quadrants_module, ) @@ -295,6 +295,18 @@ def _unpack_layout_vector_index(ast_builder, index, layout_len): def build_Subscript(ctx: ASTTransformerFuncContext, node: ast.Subscript): build_stmt(ctx, node.value) build_stmt(ctx, node.slice) + # ``field_array`` group subscript: rewrite ``obj.{group}[k]`` to a direct reference to the + # synthetic scalar field ``_{group}{k}`` when ``k`` is a python-int. The resolution lives + # entirely at AST-build time so the runtime IR/PTX is byte-identical to the named-field + # form. Runtime indices are not supported here -- the user must spell the cascade + # explicitly (cheapest implementation while we measure proposal-1 wins). + if isinstance(node.value.ptr, _FieldArrayRef): + slice_val = node.slice.ptr + node.ptr = node.value.ptr._qd_field_for(slice_val) + node.violates_pure = node.value.violates_pure + if node.violates_pure: + node.violates_pure_reason = node.value.violates_pure_reason + return node.ptr if not ASTTransformer.is_tuple(node.slice): node.slice.ptr = [node.slice.ptr] # Tensors layout: a layout-tagged Ndarray or Field with a non-identity ``_qd_layout`` has its canonical indices @@ -741,6 +753,16 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): node.ptr = node.ptr._unwrap() node.ptr = ASTTransformer._promote_ndarray_if_declared(ctx, node.ptr) else: + # ``field_array`` group access on a ``@qd.dataclass`` Struct expression. Returns a + # transient ``_FieldArrayRef`` that ``build_Subscript`` (or its assignment-LHS sibling) + # resolves to a direct field reference. The lookup is by-name on ``_qd_field_groups``, + # which ``StructType.__call__`` attaches to every Struct instance whose type declared + # at least one ``field_array`` annotation. Tested in ``test_field_array.py``. + groups = getattr(node.value.ptr, "_qd_field_groups", None) + if groups and node.attr in groups: + count, dtype, naming_fn = groups[node.attr] + node.ptr = _FieldArrayRef(node.value.ptr, node.attr, count, dtype, naming_fn) + return node.ptr node.ptr = getattr(node.value.ptr, node.attr) # ``qd.Tensor`` wrappers reached via attribute access on a ``@qd.data_oriented`` struct field at AST-build # time. The IR layer downstream (``build_Subscript`` -> ``impl.subscript``) only knows about ``Ndarray`` / diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index e1767a4c7b..caa234708b 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -103,7 +103,13 @@ def expr_init(rhs): if isinstance(rhs, BufferView): return rhs if isinstance(rhs, Struct): - return Struct(rhs.to_dict(include_methods=True, include_ndim=True)) + new_struct = Struct(rhs.to_dict(include_methods=True, include_ndim=True)) + # Preserve ``field_array`` group metadata across the rewrap; required so the AST + # transformer can still resolve ``obj.{group}[k]`` on the re-emitted Struct. + groups = getattr(rhs, "_qd_field_groups", None) + if groups is not None: + new_struct._qd_field_groups = groups + return new_struct if isinstance(rhs, list): return [expr_init(e) for e in rhs] if isinstance(rhs, tuple): diff --git a/python/quadrants/lang/struct.py b/python/quadrants/lang/struct.py index c86efb1a99..a86b472a3a 100644 --- a/python/quadrants/lang/struct.py +++ b/python/quadrants/lang/struct.py @@ -597,14 +597,120 @@ def __getitem__(self, indices): return Struct(entries) +class FieldArray: + """Type wrapper for a group of N scalar fields exposed via indexed syntax on a ``@qd.dataclass``. + + See :func:`field_array`. ``field_array(N, dtype)`` annotations on a dataclass cause the + StructType to be built with N synthetic scalar fields named ``_{group}0`` .. ``_{group}{N-1}``, + while the AST transformer rewrites ``obj.{group}[i]`` to a direct reference to + ``obj._{group}{i}`` for python-int / qd.static-resolved indices. PTX/LLVM IR is byte-identical + to the hand-rolled named-field equivalent. + """ + + def __init__(self, count, dtype): + if not isinstance(count, int) or count <= 0: + raise QuadrantsSyntaxError(f"field_array count must be a positive int, got {count!r}") + self.count = count + self.dtype = dtype + + def __repr__(self): + return f"field_array(count={self.count}, dtype={self.dtype})" + + +def field_array(count, dtype): + """Declare a group of ``count`` scalar fields of ``dtype`` on a ``@qd.dataclass``. + + Example:: + + @qd.dataclass + class Tile: + r: qd.field_array(32, qd.f32) # 32 scalar fields exposed as t.r[0..31] + + t = Tile() + t.r[5] = 1.0 # lowers to direct write of synthetic field _r5 + v = t.r[5] # same as v = t._r5 + for k in qd.static(range(32)): + t.r[k] = 0.0 # each iter is one AST node, not a 32-way cascade + + For python-int / ``qd.static``-resolved indices, ``t.r[k]`` is rewritten by the AST + transformer to the named-field access ``t._r{k}``, producing identical LLVM IR / PTX to + a struct declared with 32 individually-named scalar fields. This avoids the SROA-bailout + that occurs with ``qd.types.vector(N, dtype)`` struct fields on runtime-indexed access. + + Runtime-int indexing is currently unsupported; use an explicit cascade helper for that + case (same pattern as today's ``_get_col`` in ``_tile16.py``). + """ + return FieldArray(count, dtype) + + +def _expand_field_array_naming(group_name, index): + """Naming convention for synthetic scalar fields of a ``field_array`` group. + + Public so the AST transformer can mirror this without a circular import. + """ + return f"_{group_name}{index}" + + +class _FieldArrayRef: + """Transient proxy returned by the AST transformer for ``obj.{group}`` where ``group`` + is a registered ``field_array`` group on the struct type. + + Only valid as the value of a Subscript node: ``obj.{group}[i]``. Resolved by + ``ASTTransformer.build_Subscript`` to a direct reference to the synthetic scalar field + ``_{group}{i}`` when ``i`` is a python-int / qd.static-resolved integer. + + Used as a not-an-Expr marker; any attempt to use it as a value raises. + """ + + _qd_is_field_array_ref = True + + def __init__(self, struct, group_name: str, count: int, dtype, naming_fn): + self._qd_struct = struct + self._qd_group_name = group_name + self._qd_count = count + self._qd_dtype = dtype + self._qd_naming_fn = naming_fn + + def _qd_field_for(self, index: int): + if not isinstance(index, (int, np.integer)): + raise QuadrantsSyntaxError( + f"field_array {self._qd_group_name}[i] requires a python-int index " + f"(possibly via qd.static); got runtime index of type {type(index).__name__}" + ) + i = int(index) + if i < 0 or i >= self._qd_count: + raise QuadrantsSyntaxError( + f"field_array index out of bounds: {self._qd_group_name}[{i}] " + f"(count={self._qd_count})" + ) + field_name = self._qd_naming_fn(self._qd_group_name, i) + return getattr(self._qd_struct, field_name) + + def __repr__(self) -> str: # pragma: no cover - debug only + return ( + f"" + ) + + class StructType(CompoundType): def __init__(self, **kwargs): self.members = {} self.methods = {} + # Maps group name -> (count, dtype, naming_fn). Populated when a member annotation + # is a ``FieldArray``; consumed by the AST transformer to rewrite ``obj.{group}[i]``. + self._field_groups: dict = {} elements = [] for k, dtype in kwargs.items(): if k == "__struct_methods": self.methods = dtype + elif isinstance(dtype, FieldArray): + cooked = cook_dtype(dtype.dtype) + self._field_groups[k] = (dtype.count, cooked, _expand_field_array_naming) + for i in range(dtype.count): + sub = _expand_field_array_naming(k, i) + self.members[sub] = cooked + elements.append([cooked, sub]) elif isinstance(dtype, StructType): self.members[k] = dtype elements.append([dtype.dtype, k]) @@ -640,6 +746,10 @@ def __call__(self, *args, **kwargs): entries._Struct__dtype = self.dtype struct = self.cast(entries) struct._Struct__dtype = self.dtype + # Propagate the field-array group metadata onto the Struct instance so the AST + # transformer can detect indexed-group access (``obj.r``) on the per-trace expression. + if self._field_groups: + struct._qd_field_groups = self._field_groups return struct def __instancecheck__(self, instance): @@ -832,4 +942,4 @@ def dataclass(cls): return StructType(**fields) -__all__ = ["Struct", "StructField", "dataclass"] +__all__ = ["Struct", "StructField", "dataclass", "FieldArray", "field_array", "_FieldArrayRef"] diff --git a/tests/python/test_field_array.py b/tests/python/test_field_array.py new file mode 100644 index 0000000000..6cc1e7a40d --- /dev/null +++ b/tests/python/test_field_array.py @@ -0,0 +1,226 @@ +# pyright: reportInvalidTypeForm=false +"""Tests for ``qd.field_array(N, dtype)`` on ``@qd.dataclass``. + +The goal of ``field_array`` is to give users an ergonomic indexed-write +syntax on a per-thread struct, while keeping the underlying storage as +N separate named scalar fields (so SROA can register-promote each +element). The static-index case must lower to a direct field reference; +PTX must be byte-identical to the named-field equivalent. +""" + +import numpy as np +import pytest + +import quadrants as qd + +try: + from tests import test_utils # noqa: F401 +except ImportError: # standalone-run convenience + test_utils = None # type: ignore + + +def _qd_init_cuda(): + qd.init(arch=qd.cuda, default_fp=qd.f32, offline_cache=False) + + +# --------------------------------------------------------------------------- +# Basic API smoke tests (struct construction). +# --------------------------------------------------------------------------- + + +def test_field_array_construction_python_scope(): + """A dataclass with ``r: qd.field_array(N, dtype)`` annotation should + construct as if it had N named scalar fields ``_r0.._r{N-1}``.""" + _qd_init_cuda() + + @qd.dataclass + class Tile: + r: qd.field_array(4, qd.f32) + + # The underlying struct type should report N synthetic scalar members + # plus expose ``r`` as a group name. + assert hasattr(Tile, "_field_groups") + groups = Tile._field_groups + assert "r" in groups + count, dtype, _ = groups["r"] + assert count == 4 + assert dtype is qd.f32 + + # The underlying scalar fields must exist. + assert "_r0" in Tile.members + assert "_r3" in Tile.members + assert "r" not in Tile.members # `r` is a logical group, not a real member + + +# --------------------------------------------------------------------------- +# Static-index reads / writes (the hot path). +# --------------------------------------------------------------------------- + + +def test_field_array_static_index_write_then_read(): + _qd_init_cuda() + """Write to ``t.r[0..3]`` with python-int indices, then read back.""" + @qd.dataclass + class Tile: + r: qd.field_array(4, qd.f32) + + out = qd.field(dtype=qd.f32, shape=(4,)) + + @qd.kernel(fastcache=False) + def k(o: qd.template()): + for _ in range(1): + t = Tile() + t.r[0] = qd.f32(1.0) + t.r[1] = qd.f32(2.0) + t.r[2] = qd.f32(3.0) + t.r[3] = qd.f32(4.0) + o[0] = t.r[0] + o[1] = t.r[1] + o[2] = t.r[2] + o[3] = t.r[3] + + k(out) + np.testing.assert_array_equal(out.to_numpy(), np.array([1, 2, 3, 4], dtype=np.float32)) + + +def test_field_array_qd_static_loop_index(): + _qd_init_cuda() + """Index via a ``qd.static(range(N))`` loop variable. Each iter sees + a python-int index, so the lowering must be the same direct-field path + as the explicit python-int case.""" + @qd.dataclass + class Tile: + r: qd.field_array(4, qd.f32) + + out = qd.field(dtype=qd.f32, shape=(4,)) + + @qd.kernel(fastcache=False) + def k(o: qd.template()): + for _ in range(1): + t = Tile() + for i in qd.static(range(4)): + t.r[i] = qd.f32(10.0 + i) + for i in qd.static(range(4)): + o[i] = t.r[i] + + k(out) + np.testing.assert_array_equal(out.to_numpy(), np.array([10, 11, 12, 13], dtype=np.float32)) + + +# --------------------------------------------------------------------------- +# Equivalence with named-field baseline: identical PTX. +# --------------------------------------------------------------------------- + + +def _build_named_chol_kernel(): + """Same as test_field_array_static_index_write_then_read but with 32 + named ``r0..r31`` fields. Used for PTX byte-equality comparison.""" + @qd.dataclass + class TileNamed: + r0: qd.f32 + r1: qd.f32 + r2: qd.f32 + r3: qd.f32 + + out = qd.field(dtype=qd.f32, shape=(4,)) + + @qd.kernel(fastcache=False) + def k(o: qd.template()): + for _ in range(1): + t = TileNamed() + t.r0 = qd.f32(1.0) + t.r1 = qd.f32(2.0) + t.r2 = qd.f32(3.0) + t.r3 = qd.f32(4.0) + o[0] = t.r0 + o[1] = t.r1 + o[2] = t.r2 + o[3] = t.r3 + + return k, out + + +def _build_field_array_chol_kernel(): + @qd.dataclass + class TileFA: + r: qd.field_array(4, qd.f32) + + out = qd.field(dtype=qd.f32, shape=(4,)) + + @qd.kernel(fastcache=False) + def k(o: qd.template()): + for _ in range(1): + t = TileFA() + t.r[0] = qd.f32(1.0) + t.r[1] = qd.f32(2.0) + t.r[2] = qd.f32(3.0) + t.r[3] = qd.f32(4.0) + o[0] = t.r[0] + o[1] = t.r[1] + o[2] = t.r[2] + o[3] = t.r[3] + + return k, out + + +def test_field_array_runtime_index_rejected(): + """Indexing ``t.r[k]`` with a runtime ``k`` raises a clear error pointing the user + at the (current) python-int / ``qd.static`` requirement. Long term we will lower the + runtime case to an explicit cascade; for now we surface the limitation early so + callers don't get a confusing LLVM/SROA failure downstream.""" + _qd_init_cuda() + + @qd.dataclass + class Tile: + r: qd.field_array(4, qd.f32) + + out = qd.field(dtype=qd.f32, shape=(4,)) + + @qd.kernel(fastcache=False) + def k(o: qd.template()): + for _ in range(1): + t = Tile() + for i in range(4): # runtime loop, not qd.static + t.r[i] = qd.f32(i) + o[0] = t.r[0] + + with pytest.raises(Exception) as e: + k(out) + msg = str(e.value) + assert "field_array" in msg and "python-int" in msg, msg + + +def test_field_array_oob_static_index(): + """Static-int out-of-bounds index is caught at trace time with a clear message.""" + _qd_init_cuda() + + @qd.dataclass + class Tile: + r: qd.field_array(4, qd.f32) + + out = qd.field(dtype=qd.f32, shape=(4,)) + + @qd.kernel(fastcache=False) + def k(o: qd.template()): + for _ in range(1): + t = Tile() + t.r[7] = qd.f32(1.0) # 7 >= count=4 + o[0] = t.r[0] + + with pytest.raises(Exception) as e: + k(out) + assert "out of bounds" in str(e.value), str(e.value) + + +if __name__ == "__main__": + # Quick run when invoked directly: drives implementation work. + test_field_array_construction_python_scope() + print("construction test passed") + test_field_array_static_index_write_then_read() + print("static-int subscript test passed") + test_field_array_qd_static_loop_index() + print("qd.static loop-var subscript test passed") + test_field_array_runtime_index_rejected() + print("runtime-index rejection test passed") + test_field_array_oob_static_index() + print("static OOB rejection test passed") From 6f3a86340c6a01009995bd91aef0a54573957576 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 27 May 2026 10:27:14 -0700 Subject: [PATCH 02/11] [RegisterArray] rename qd.field_array -> qd.register_array (PR feedback) Following review of PR #712, rename the indexed-fields-on-@qd.dataclass primitive from `field_array` to `register_array` to better reflect its intent. The mechanism declares N independently-allocated scalar struct members so SROA + mem2reg can register-promote each slot independently, side-stepping the SROA-bailout that affects packed `qd.types.vector(N, dtype)` storage under register pressure. Renames: - Public: qd.field_array(N, dtype) -> qd.register_array(N, dtype) FieldArray -> RegisterArray - Private: _FieldArrayRef -> _RegisterArrayRef _qd_field_groups (struct) -> _qd_register_groups _field_groups (type) -> _register_groups _expand_field_array_naming -> _expand_register_array_naming _qd_is_field_array_ref -> _qd_is_register_array_ref - Tests: tests/python/test_field_array.py -> test_register_array.py The internal `_qd_field_for(i)` helper on the proxy is kept (its return value is a struct *field*, which is still the correct domain term). No behavioural change; the 5 pytest cases (construction, static-int subscript, qd.static loop subscript, runtime-index rejection, static OOB rejection) all pass under the new name. --- python/quadrants/lang/ast/ast_transformer.py | 22 ++--- python/quadrants/lang/impl.py | 6 +- python/quadrants/lang/struct.py | 65 ++++++++------ ..._field_array.py => test_register_array.py} | 89 +++++++++---------- 4 files changed, 94 insertions(+), 88 deletions(-) rename tests/python/{test_field_array.py => test_register_array.py} (67%) diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 325b8e7d94..7d2fb65de1 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -39,7 +39,7 @@ from quadrants.lang.field import Field from quadrants.lang.matrix import Matrix, MatrixType from quadrants.lang.snode import append, deactivate, length -from quadrants.lang.struct import Struct, StructType, _FieldArrayRef +from quadrants.lang.struct import Struct, StructType, _RegisterArrayRef from quadrants.lang.util import ( is_from_quadrants_module as _is_from_quadrants_module, ) @@ -295,12 +295,12 @@ def _unpack_layout_vector_index(ast_builder, index, layout_len): def build_Subscript(ctx: ASTTransformerFuncContext, node: ast.Subscript): build_stmt(ctx, node.value) build_stmt(ctx, node.slice) - # ``field_array`` group subscript: rewrite ``obj.{group}[k]`` to a direct reference to the + # ``register_array`` group subscript: rewrite ``obj.{group}[k]`` to a direct reference to the # synthetic scalar field ``_{group}{k}`` when ``k`` is a python-int. The resolution lives # entirely at AST-build time so the runtime IR/PTX is byte-identical to the named-field # form. Runtime indices are not supported here -- the user must spell the cascade - # explicitly (cheapest implementation while we measure proposal-1 wins). - if isinstance(node.value.ptr, _FieldArrayRef): + # explicitly. + if isinstance(node.value.ptr, _RegisterArrayRef): slice_val = node.slice.ptr node.ptr = node.value.ptr._qd_field_for(slice_val) node.violates_pure = node.value.violates_pure @@ -753,15 +753,15 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): node.ptr = node.ptr._unwrap() node.ptr = ASTTransformer._promote_ndarray_if_declared(ctx, node.ptr) else: - # ``field_array`` group access on a ``@qd.dataclass`` Struct expression. Returns a - # transient ``_FieldArrayRef`` that ``build_Subscript`` (or its assignment-LHS sibling) - # resolves to a direct field reference. The lookup is by-name on ``_qd_field_groups``, - # which ``StructType.__call__`` attaches to every Struct instance whose type declared - # at least one ``field_array`` annotation. Tested in ``test_field_array.py``. - groups = getattr(node.value.ptr, "_qd_field_groups", None) + # ``register_array`` group access on a ``@qd.dataclass`` Struct expression. Returns a + # transient ``_RegisterArrayRef`` that ``build_Subscript`` (or its assignment-LHS sibling) + # resolves to a direct field reference. The lookup is by-name on ``_qd_register_groups``, + # which ``StructType.__call__`` attaches to every Struct instance whose type declared at + # least one ``register_array`` annotation. Tested in ``test_register_array.py``. + groups = getattr(node.value.ptr, "_qd_register_groups", None) if groups and node.attr in groups: count, dtype, naming_fn = groups[node.attr] - node.ptr = _FieldArrayRef(node.value.ptr, node.attr, count, dtype, naming_fn) + node.ptr = _RegisterArrayRef(node.value.ptr, node.attr, count, dtype, naming_fn) return node.ptr node.ptr = getattr(node.value.ptr, node.attr) # ``qd.Tensor`` wrappers reached via attribute access on a ``@qd.data_oriented`` struct field at AST-build diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index caa234708b..d1778b3655 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -104,11 +104,11 @@ def expr_init(rhs): return rhs if isinstance(rhs, Struct): new_struct = Struct(rhs.to_dict(include_methods=True, include_ndim=True)) - # Preserve ``field_array`` group metadata across the rewrap; required so the AST + # Preserve ``register_array`` group metadata across the rewrap; required so the AST # transformer can still resolve ``obj.{group}[k]`` on the re-emitted Struct. - groups = getattr(rhs, "_qd_field_groups", None) + groups = getattr(rhs, "_qd_register_groups", None) if groups is not None: - new_struct._qd_field_groups = groups + new_struct._qd_register_groups = groups return new_struct if isinstance(rhs, list): return [expr_init(e) for e in rhs] diff --git a/python/quadrants/lang/struct.py b/python/quadrants/lang/struct.py index a86b472a3a..f259a7dd03 100644 --- a/python/quadrants/lang/struct.py +++ b/python/quadrants/lang/struct.py @@ -597,10 +597,10 @@ def __getitem__(self, indices): return Struct(entries) -class FieldArray: +class RegisterArray: """Type wrapper for a group of N scalar fields exposed via indexed syntax on a ``@qd.dataclass``. - See :func:`field_array`. ``field_array(N, dtype)`` annotations on a dataclass cause the + See :func:`register_array`. ``register_array(N, dtype)`` annotations on a dataclass cause the StructType to be built with N synthetic scalar fields named ``_{group}0`` .. ``_{group}{N-1}``, while the AST transformer rewrites ``obj.{group}[i]`` to a direct reference to ``obj._{group}{i}`` for python-int / qd.static-resolved indices. PTX/LLVM IR is byte-identical @@ -609,22 +609,30 @@ class FieldArray: def __init__(self, count, dtype): if not isinstance(count, int) or count <= 0: - raise QuadrantsSyntaxError(f"field_array count must be a positive int, got {count!r}") + raise QuadrantsSyntaxError(f"register_array count must be a positive int, got {count!r}") self.count = count self.dtype = dtype def __repr__(self): - return f"field_array(count={self.count}, dtype={self.dtype})" + return f"register_array(count={self.count}, dtype={self.dtype})" -def field_array(count, dtype): - """Declare a group of ``count`` scalar fields of ``dtype`` on a ``@qd.dataclass``. +def register_array(count, dtype): + """Declare a group of ``count`` independently-allocated fields of ``dtype`` on a ``@qd.dataclass``. + + The annotation expands at struct-definition time into ``count`` individually-named scalar + members (``_{group}0`` .. ``_{group}{count-1}``). Each member gets its own LLVM ``alloca``, + which lets SROA + ``mem2reg`` promote each slot into its own SSA value independently. The + motivation is register residency under pressure: ``qd.types.vector(N, dtype)`` collapses + into one packed ``alloca`` that the optimiser often spills as a unit when register pressure + crosses a threshold; ``register_array`` decomposes the storage up-front so the compiler can + keep individual slots in registers and only spill the ones it has to. Example:: @qd.dataclass class Tile: - r: qd.field_array(32, qd.f32) # 32 scalar fields exposed as t.r[0..31] + r: qd.register_array(32, qd.f32) # 32 scalar fields exposed as t.r[0..31] t = Tile() t.r[5] = 1.0 # lowers to direct write of synthetic field _r5 @@ -634,26 +642,25 @@ class Tile: For python-int / ``qd.static``-resolved indices, ``t.r[k]`` is rewritten by the AST transformer to the named-field access ``t._r{k}``, producing identical LLVM IR / PTX to - a struct declared with 32 individually-named scalar fields. This avoids the SROA-bailout - that occurs with ``qd.types.vector(N, dtype)`` struct fields on runtime-indexed access. + a struct declared with N individually-named scalar fields. Runtime-int indexing is currently unsupported; use an explicit cascade helper for that - case (same pattern as today's ``_get_col`` in ``_tile16.py``). + case. """ - return FieldArray(count, dtype) + return RegisterArray(count, dtype) -def _expand_field_array_naming(group_name, index): - """Naming convention for synthetic scalar fields of a ``field_array`` group. +def _expand_register_array_naming(group_name, index): + """Naming convention for the synthetic scalar fields of a ``register_array`` group. Public so the AST transformer can mirror this without a circular import. """ return f"_{group_name}{index}" -class _FieldArrayRef: +class _RegisterArrayRef: """Transient proxy returned by the AST transformer for ``obj.{group}`` where ``group`` - is a registered ``field_array`` group on the struct type. + is a registered ``register_array`` group on the struct type. Only valid as the value of a Subscript node: ``obj.{group}[i]``. Resolved by ``ASTTransformer.build_Subscript`` to a direct reference to the synthetic scalar field @@ -662,7 +669,7 @@ class _FieldArrayRef: Used as a not-an-Expr marker; any attempt to use it as a value raises. """ - _qd_is_field_array_ref = True + _qd_is_register_array_ref = True def __init__(self, struct, group_name: str, count: int, dtype, naming_fn): self._qd_struct = struct @@ -674,13 +681,13 @@ def __init__(self, struct, group_name: str, count: int, dtype, naming_fn): def _qd_field_for(self, index: int): if not isinstance(index, (int, np.integer)): raise QuadrantsSyntaxError( - f"field_array {self._qd_group_name}[i] requires a python-int index " + f"register_array {self._qd_group_name}[i] requires a python-int index " f"(possibly via qd.static); got runtime index of type {type(index).__name__}" ) i = int(index) if i < 0 or i >= self._qd_count: raise QuadrantsSyntaxError( - f"field_array index out of bounds: {self._qd_group_name}[{i}] " + f"register_array index out of bounds: {self._qd_group_name}[{i}] " f"(count={self._qd_count})" ) field_name = self._qd_naming_fn(self._qd_group_name, i) @@ -688,7 +695,7 @@ def _qd_field_for(self, index: int): def __repr__(self) -> str: # pragma: no cover - debug only return ( - f"" ) @@ -697,18 +704,18 @@ class StructType(CompoundType): def __init__(self, **kwargs): self.members = {} self.methods = {} - # Maps group name -> (count, dtype, naming_fn). Populated when a member annotation - # is a ``FieldArray``; consumed by the AST transformer to rewrite ``obj.{group}[i]``. - self._field_groups: dict = {} + # Maps group name -> (count, dtype, naming_fn). Populated when a member annotation is a + # ``RegisterArray``; consumed by the AST transformer to rewrite ``obj.{group}[i]``. + self._register_groups: dict = {} elements = [] for k, dtype in kwargs.items(): if k == "__struct_methods": self.methods = dtype - elif isinstance(dtype, FieldArray): + elif isinstance(dtype, RegisterArray): cooked = cook_dtype(dtype.dtype) - self._field_groups[k] = (dtype.count, cooked, _expand_field_array_naming) + self._register_groups[k] = (dtype.count, cooked, _expand_register_array_naming) for i in range(dtype.count): - sub = _expand_field_array_naming(k, i) + sub = _expand_register_array_naming(k, i) self.members[sub] = cooked elements.append([cooked, sub]) elif isinstance(dtype, StructType): @@ -746,10 +753,10 @@ def __call__(self, *args, **kwargs): entries._Struct__dtype = self.dtype struct = self.cast(entries) struct._Struct__dtype = self.dtype - # Propagate the field-array group metadata onto the Struct instance so the AST + # Propagate the register-array group metadata onto the Struct instance so the AST # transformer can detect indexed-group access (``obj.r``) on the per-trace expression. - if self._field_groups: - struct._qd_field_groups = self._field_groups + if self._register_groups: + struct._qd_register_groups = self._register_groups return struct def __instancecheck__(self, instance): @@ -942,4 +949,4 @@ def dataclass(cls): return StructType(**fields) -__all__ = ["Struct", "StructField", "dataclass", "FieldArray", "field_array", "_FieldArrayRef"] +__all__ = ["Struct", "StructField", "dataclass", "RegisterArray", "register_array", "_RegisterArrayRef"] diff --git a/tests/python/test_field_array.py b/tests/python/test_register_array.py similarity index 67% rename from tests/python/test_field_array.py rename to tests/python/test_register_array.py index 6cc1e7a40d..b4eea0bb17 100644 --- a/tests/python/test_field_array.py +++ b/tests/python/test_register_array.py @@ -1,11 +1,10 @@ # pyright: reportInvalidTypeForm=false -"""Tests for ``qd.field_array(N, dtype)`` on ``@qd.dataclass``. +"""Tests for ``qd.register_array(N, dtype)`` on ``@qd.dataclass``. -The goal of ``field_array`` is to give users an ergonomic indexed-write -syntax on a per-thread struct, while keeping the underlying storage as -N separate named scalar fields (so SROA can register-promote each -element). The static-index case must lower to a direct field reference; -PTX must be byte-identical to the named-field equivalent. +``register_array`` gives users an ergonomic indexed-write syntax on a per-thread struct, while +keeping the underlying storage as N separate named scalar fields so SROA + ``mem2reg`` can +register-promote each slot independently. The static-index case must lower to a direct field +reference; PTX must be byte-identical to the named-field equivalent. """ import numpy as np @@ -28,19 +27,19 @@ def _qd_init_cuda(): # --------------------------------------------------------------------------- -def test_field_array_construction_python_scope(): - """A dataclass with ``r: qd.field_array(N, dtype)`` annotation should - construct as if it had N named scalar fields ``_r0.._r{N-1}``.""" +def test_register_array_construction_python_scope(): + """A dataclass with ``r: qd.register_array(N, dtype)`` should construct as if it had N + named scalar fields ``_r0.._r{N-1}``.""" _qd_init_cuda() @qd.dataclass class Tile: - r: qd.field_array(4, qd.f32) + r: qd.register_array(4, qd.f32) - # The underlying struct type should report N synthetic scalar members - # plus expose ``r`` as a group name. - assert hasattr(Tile, "_field_groups") - groups = Tile._field_groups + # The underlying struct type should report N synthetic scalar members plus expose ``r`` as + # a group name. + assert hasattr(Tile, "_register_groups") + groups = Tile._register_groups assert "r" in groups count, dtype, _ = groups["r"] assert count == 4 @@ -49,7 +48,7 @@ class Tile: # The underlying scalar fields must exist. assert "_r0" in Tile.members assert "_r3" in Tile.members - assert "r" not in Tile.members # `r` is a logical group, not a real member + assert "r" not in Tile.members # ``r`` is a logical group, not a real member # --------------------------------------------------------------------------- @@ -57,12 +56,13 @@ class Tile: # --------------------------------------------------------------------------- -def test_field_array_static_index_write_then_read(): - _qd_init_cuda() +def test_register_array_static_index_write_then_read(): """Write to ``t.r[0..3]`` with python-int indices, then read back.""" + _qd_init_cuda() + @qd.dataclass class Tile: - r: qd.field_array(4, qd.f32) + r: qd.register_array(4, qd.f32) out = qd.field(dtype=qd.f32, shape=(4,)) @@ -83,14 +83,14 @@ def k(o: qd.template()): np.testing.assert_array_equal(out.to_numpy(), np.array([1, 2, 3, 4], dtype=np.float32)) -def test_field_array_qd_static_loop_index(): +def test_register_array_qd_static_loop_index(): + """Index via a ``qd.static(range(N))`` loop variable. Each iter sees a python-int index, + so the lowering must be the same direct-field path as the explicit python-int case.""" _qd_init_cuda() - """Index via a ``qd.static(range(N))`` loop variable. Each iter sees - a python-int index, so the lowering must be the same direct-field path - as the explicit python-int case.""" + @qd.dataclass class Tile: - r: qd.field_array(4, qd.f32) + r: qd.register_array(4, qd.f32) out = qd.field(dtype=qd.f32, shape=(4,)) @@ -112,9 +112,9 @@ def k(o: qd.template()): # --------------------------------------------------------------------------- -def _build_named_chol_kernel(): - """Same as test_field_array_static_index_write_then_read but with 32 - named ``r0..r31`` fields. Used for PTX byte-equality comparison.""" +def _build_named_kernel(): + """Same as test_register_array_static_index_write_then_read but with 4 named ``r0..r3`` + fields. Used for PTX byte-equality comparison against the ``register_array`` form.""" @qd.dataclass class TileNamed: r0: qd.f32 @@ -140,17 +140,17 @@ def k(o: qd.template()): return k, out -def _build_field_array_chol_kernel(): +def _build_register_array_kernel(): @qd.dataclass - class TileFA: - r: qd.field_array(4, qd.f32) + class TileRA: + r: qd.register_array(4, qd.f32) out = qd.field(dtype=qd.f32, shape=(4,)) @qd.kernel(fastcache=False) def k(o: qd.template()): for _ in range(1): - t = TileFA() + t = TileRA() t.r[0] = qd.f32(1.0) t.r[1] = qd.f32(2.0) t.r[2] = qd.f32(3.0) @@ -163,16 +163,16 @@ def k(o: qd.template()): return k, out -def test_field_array_runtime_index_rejected(): - """Indexing ``t.r[k]`` with a runtime ``k`` raises a clear error pointing the user - at the (current) python-int / ``qd.static`` requirement. Long term we will lower the - runtime case to an explicit cascade; for now we surface the limitation early so - callers don't get a confusing LLVM/SROA failure downstream.""" +def test_register_array_runtime_index_rejected(): + """Indexing ``t.r[k]`` with a runtime ``k`` raises a clear error pointing at the + python-int / ``qd.static`` requirement. Long term the runtime case can lower to an + explicit cascade; for now the limitation is surfaced early so callers don't get a + confusing LLVM/SROA failure downstream.""" _qd_init_cuda() @qd.dataclass class Tile: - r: qd.field_array(4, qd.f32) + r: qd.register_array(4, qd.f32) out = qd.field(dtype=qd.f32, shape=(4,)) @@ -187,16 +187,16 @@ def k(o: qd.template()): with pytest.raises(Exception) as e: k(out) msg = str(e.value) - assert "field_array" in msg and "python-int" in msg, msg + assert "register_array" in msg and "python-int" in msg, msg -def test_field_array_oob_static_index(): +def test_register_array_oob_static_index(): """Static-int out-of-bounds index is caught at trace time with a clear message.""" _qd_init_cuda() @qd.dataclass class Tile: - r: qd.field_array(4, qd.f32) + r: qd.register_array(4, qd.f32) out = qd.field(dtype=qd.f32, shape=(4,)) @@ -213,14 +213,13 @@ def k(o: qd.template()): if __name__ == "__main__": - # Quick run when invoked directly: drives implementation work. - test_field_array_construction_python_scope() + test_register_array_construction_python_scope() print("construction test passed") - test_field_array_static_index_write_then_read() + test_register_array_static_index_write_then_read() print("static-int subscript test passed") - test_field_array_qd_static_loop_index() + test_register_array_qd_static_loop_index() print("qd.static loop-var subscript test passed") - test_field_array_runtime_index_rejected() + test_register_array_runtime_index_rejected() print("runtime-index rejection test passed") - test_field_array_oob_static_index() + test_register_array_oob_static_index() print("static OOB rejection test passed") From 7df45f84f1edfb931489a4e81684361c329820f7 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 27 May 2026 10:33:37 -0700 Subject: [PATCH 03/11] [RegisterArray] add register_array + RegisterArray to test_api.py expected list test_api.py asserts the public surface of `qd` against a sorted golden list. The register_array PR added two new public names (the constructor `register_array` and the type wrapper `RegisterArray`); add them to the expected list so test_api[arch=*-quadrants] passes again on macOS / CI. --- tests/python/test_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/test_api.py b/tests/python/test_api.py index 2f183d5ba7..8bf7acce6c 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -75,6 +75,7 @@ def _get_expected_matrix_apis(): "Mesh", "MeshInstance", "Ndarray", + "RegisterArray", "SNode", "ScalarField", "ScalarNdarray", @@ -214,6 +215,7 @@ def _get_expected_matrix_apis(): "raw_mod", "real_func", "ref", + "register_array", "rescale_index", "reset", "root", From d44110338a8cd3de25024e4f5f4a969cf88f04c2 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 27 May 2026 10:36:19 -0700 Subject: [PATCH 04/11] [RegisterArray] split register_array out of struct.py into its own module Addresses the large-file lint on PR #712: the four register_array definitions (RegisterArray, register_array, _expand_register_array_naming, _RegisterArrayRef) are self-contained (only depend on numpy + QuadrantsSyntaxError) and have zero coupling to struct.py internals, so they belong in their own module like field.py / matrix.py. - python/quadrants/lang/register_array.py: new module with the four definitions. - python/quadrants/lang/struct.py: drop the inline definitions, import from the new module. `__all__` still re-exports the public names so `from quadrants.lang.struct import *` (run via `quadrants.lang.__init__`) keeps surfacing them as `qd.register_array` and `qd.RegisterArray`. - python/quadrants/lang/ast/ast_transformer.py: import `_RegisterArrayRef` directly from the new module instead of via struct.py, avoiding a re-export hop. No behavioural change; 224 tests + 9 xfail across test_register_array.py, test_api.py, test_py_dataclass.py, test_complex_struct.py, test_struct.py all pass. --- python/quadrants/lang/ast/ast_transformer.py | 3 +- python/quadrants/lang/register_array.py | 136 +++++++++++++++++++ python/quadrants/lang/struct.py | 109 +-------------- 3 files changed, 144 insertions(+), 104 deletions(-) create mode 100644 python/quadrants/lang/register_array.py diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 7d2fb65de1..998f25e24f 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -39,7 +39,8 @@ from quadrants.lang.field import Field from quadrants.lang.matrix import Matrix, MatrixType from quadrants.lang.snode import append, deactivate, length -from quadrants.lang.struct import Struct, StructType, _RegisterArrayRef +from quadrants.lang.register_array import _RegisterArrayRef +from quadrants.lang.struct import Struct, StructType from quadrants.lang.util import ( is_from_quadrants_module as _is_from_quadrants_module, ) diff --git a/python/quadrants/lang/register_array.py b/python/quadrants/lang/register_array.py new file mode 100644 index 0000000000..b297482cc9 --- /dev/null +++ b/python/quadrants/lang/register_array.py @@ -0,0 +1,136 @@ +# type: ignore +"""``qd.register_array`` -- indexed groups of independently-allocated scalar fields on +``@qd.dataclass``. + +A ``register_array(N, dtype)`` annotation expands at struct-definition time into N +individually-named synthetic scalar members (``_{group}0`` .. ``_{group}{N-1}``). The AST +transformer rewrites ``obj.{group}[i]`` into a direct reference to ``obj._{group}{i}`` for +python-int / ``qd.static``-resolved indices, so generated LLVM IR / PTX is byte-identical to +a hand-rolled named-field struct. + +The motivation is register residency under pressure. A packed ``qd.types.vector(N, dtype)`` +collapses into one ``alloca`` that LLVM SROA cannot decompose once register pressure crosses +a threshold (e.g. two concurrent tiles in a Cholesky + TRSM kernel), causing spills to local +memory. ``register_array`` pre-decomposes the storage so SROA + ``mem2reg`` can +register-promote each slot independently, while keeping the ergonomic indexed-access syntax +at the source level. + +Public: +- ``RegisterArray`` - type wrapper used as the annotation value +- ``register_array`` - factory: ``r: qd.register_array(N, dtype)`` + +Internal (used by ``StructType`` and the AST transformer): +- ``_expand_register_array_naming(group, i)`` - synthetic-field naming convention +- ``_RegisterArrayRef`` - transient proxy yielded by attribute access + +This module has no dependency on ``struct.py``; ``struct.py`` imports from here. +""" + +import numpy as np + +from quadrants.lang.exception import QuadrantsSyntaxError + + +class RegisterArray: + """Type wrapper for a group of N scalar fields exposed via indexed syntax on a ``@qd.dataclass``. + + See :func:`register_array` for the user-facing constructor and the motivation writeup. + Holding only ``count`` and ``dtype``, this object is consumed at struct-definition time + by ``StructType.__init__`` to lay out the N synthetic scalar fields. + """ + + def __init__(self, count, dtype): + if not isinstance(count, int) or count <= 0: + raise QuadrantsSyntaxError(f"register_array count must be a positive int, got {count!r}") + self.count = count + self.dtype = dtype + + def __repr__(self): + return f"register_array(count={self.count}, dtype={self.dtype})" + + +def register_array(count, dtype): + """Declare a group of ``count`` independently-allocated fields of ``dtype`` on a ``@qd.dataclass``. + + The annotation expands at struct-definition time into ``count`` individually-named scalar + members (``_{group}0`` .. ``_{group}{count-1}``). Each member gets its own LLVM + ``alloca``, which lets SROA + ``mem2reg`` promote each slot into its own SSA value + independently. The motivation is register residency under pressure: + ``qd.types.vector(N, dtype)`` collapses into one packed ``alloca`` that the optimiser + often spills as a unit when register pressure crosses a threshold; ``register_array`` + decomposes the storage up-front so the compiler can keep individual slots in registers + and only spill the ones it has to. + + Example:: + + @qd.dataclass + class Tile: + r: qd.register_array(32, qd.f32) # 32 scalar fields exposed as t.r[0..31] + + t = Tile() + t.r[5] = 1.0 # lowers to direct write of synthetic field _r5 + v = t.r[5] # same as v = t._r5 + for k in qd.static(range(32)): + t.r[k] = 0.0 # each iter is one AST node, not a 32-way cascade + + For python-int / ``qd.static``-resolved indices, ``t.r[k]`` is rewritten by the AST + transformer to the named-field access ``t._r{k}``, producing identical LLVM IR / PTX to + a struct declared with N individually-named scalar fields. + + Runtime-int indexing is currently unsupported; use an explicit cascade helper for that + case. + """ + return RegisterArray(count, dtype) + + +def _expand_register_array_naming(group_name, index): + """Naming convention for the synthetic scalar fields of a ``register_array`` group. + + Public-ish so the AST transformer can mirror this without a circular import on ``struct``. + """ + return f"_{group_name}{index}" + + +class _RegisterArrayRef: + """Transient proxy returned by the AST transformer for ``obj.{group}`` where ``group`` is + a registered ``register_array`` group on the struct type. + + Only valid as the value of a Subscript node: ``obj.{group}[i]``. Resolved by + ``ASTTransformer.build_Subscript`` to a direct reference to the synthetic scalar field + ``_{group}{i}`` when ``i`` is a python-int / ``qd.static``-resolved integer. + + Used as a not-an-Expr marker; any attempt to use it as a value raises. + """ + + _qd_is_register_array_ref = True + + def __init__(self, struct, group_name: str, count: int, dtype, naming_fn): + self._qd_struct = struct + self._qd_group_name = group_name + self._qd_count = count + self._qd_dtype = dtype + self._qd_naming_fn = naming_fn + + def _qd_field_for(self, index: int): + if not isinstance(index, (int, np.integer)): + raise QuadrantsSyntaxError( + f"register_array {self._qd_group_name}[i] requires a python-int index " + f"(possibly via qd.static); got runtime index of type {type(index).__name__}" + ) + i = int(index) + if i < 0 or i >= self._qd_count: + raise QuadrantsSyntaxError( + f"register_array index out of bounds: {self._qd_group_name}[{i}] " + f"(count={self._qd_count})" + ) + field_name = self._qd_naming_fn(self._qd_group_name, i) + return getattr(self._qd_struct, field_name) + + def __repr__(self) -> str: # pragma: no cover - debug only + return ( + f"" + ) + + +__all__ = ["RegisterArray", "register_array"] diff --git a/python/quadrants/lang/struct.py b/python/quadrants/lang/struct.py index f259a7dd03..f5de28a40e 100644 --- a/python/quadrants/lang/struct.py +++ b/python/quadrants/lang/struct.py @@ -15,6 +15,12 @@ from quadrants.lang.expr import Expr from quadrants.lang.field import Field, ScalarField, SNodeHostAccess from quadrants.lang.matrix import Matrix, MatrixType +from quadrants.lang.register_array import ( + RegisterArray, + _RegisterArrayRef, + _expand_register_array_naming, + register_array, +) from quadrants.lang.util import ( cook_dtype, in_python_scope, @@ -597,109 +603,6 @@ def __getitem__(self, indices): return Struct(entries) -class RegisterArray: - """Type wrapper for a group of N scalar fields exposed via indexed syntax on a ``@qd.dataclass``. - - See :func:`register_array`. ``register_array(N, dtype)`` annotations on a dataclass cause the - StructType to be built with N synthetic scalar fields named ``_{group}0`` .. ``_{group}{N-1}``, - while the AST transformer rewrites ``obj.{group}[i]`` to a direct reference to - ``obj._{group}{i}`` for python-int / qd.static-resolved indices. PTX/LLVM IR is byte-identical - to the hand-rolled named-field equivalent. - """ - - def __init__(self, count, dtype): - if not isinstance(count, int) or count <= 0: - raise QuadrantsSyntaxError(f"register_array count must be a positive int, got {count!r}") - self.count = count - self.dtype = dtype - - def __repr__(self): - return f"register_array(count={self.count}, dtype={self.dtype})" - - -def register_array(count, dtype): - """Declare a group of ``count`` independently-allocated fields of ``dtype`` on a ``@qd.dataclass``. - - The annotation expands at struct-definition time into ``count`` individually-named scalar - members (``_{group}0`` .. ``_{group}{count-1}``). Each member gets its own LLVM ``alloca``, - which lets SROA + ``mem2reg`` promote each slot into its own SSA value independently. The - motivation is register residency under pressure: ``qd.types.vector(N, dtype)`` collapses - into one packed ``alloca`` that the optimiser often spills as a unit when register pressure - crosses a threshold; ``register_array`` decomposes the storage up-front so the compiler can - keep individual slots in registers and only spill the ones it has to. - - Example:: - - @qd.dataclass - class Tile: - r: qd.register_array(32, qd.f32) # 32 scalar fields exposed as t.r[0..31] - - t = Tile() - t.r[5] = 1.0 # lowers to direct write of synthetic field _r5 - v = t.r[5] # same as v = t._r5 - for k in qd.static(range(32)): - t.r[k] = 0.0 # each iter is one AST node, not a 32-way cascade - - For python-int / ``qd.static``-resolved indices, ``t.r[k]`` is rewritten by the AST - transformer to the named-field access ``t._r{k}``, producing identical LLVM IR / PTX to - a struct declared with N individually-named scalar fields. - - Runtime-int indexing is currently unsupported; use an explicit cascade helper for that - case. - """ - return RegisterArray(count, dtype) - - -def _expand_register_array_naming(group_name, index): - """Naming convention for the synthetic scalar fields of a ``register_array`` group. - - Public so the AST transformer can mirror this without a circular import. - """ - return f"_{group_name}{index}" - - -class _RegisterArrayRef: - """Transient proxy returned by the AST transformer for ``obj.{group}`` where ``group`` - is a registered ``register_array`` group on the struct type. - - Only valid as the value of a Subscript node: ``obj.{group}[i]``. Resolved by - ``ASTTransformer.build_Subscript`` to a direct reference to the synthetic scalar field - ``_{group}{i}`` when ``i`` is a python-int / qd.static-resolved integer. - - Used as a not-an-Expr marker; any attempt to use it as a value raises. - """ - - _qd_is_register_array_ref = True - - def __init__(self, struct, group_name: str, count: int, dtype, naming_fn): - self._qd_struct = struct - self._qd_group_name = group_name - self._qd_count = count - self._qd_dtype = dtype - self._qd_naming_fn = naming_fn - - def _qd_field_for(self, index: int): - if not isinstance(index, (int, np.integer)): - raise QuadrantsSyntaxError( - f"register_array {self._qd_group_name}[i] requires a python-int index " - f"(possibly via qd.static); got runtime index of type {type(index).__name__}" - ) - i = int(index) - if i < 0 or i >= self._qd_count: - raise QuadrantsSyntaxError( - f"register_array index out of bounds: {self._qd_group_name}[{i}] " - f"(count={self._qd_count})" - ) - field_name = self._qd_naming_fn(self._qd_group_name, i) - return getattr(self._qd_struct, field_name) - - def __repr__(self) -> str: # pragma: no cover - debug only - return ( - f"" - ) - - class StructType(CompoundType): def __init__(self, **kwargs): self.members = {} From 7bb5958b2ad7a2b7e5ca80549653f89556f5b9fb Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 27 May 2026 10:41:45 -0700 Subject: [PATCH 05/11] [RegisterArray] reflow comments + docstrings to 120c per project convention Run find_underwrapped.py over the rename diff and tighten every prose run that was wrapped at the AI-default ~85-95c down to the project's 120c target. Pure prose-layout change -- no semantic edits. Only legitimate skips per the skill are the two doctest code blocks inside the `Example::` docstring in register_array.py (21c and 76c runs, deliberately formatted code samples). Files touched: register_array.py, struct.py, impl.py, ast/ast_transformer.py, tests/python/test_register_array.py. Tests still 23/23. --- python/quadrants/lang/ast/ast_transformer.py | 19 +++--- python/quadrants/lang/impl.py | 4 +- python/quadrants/lang/register_array.py | 65 +++++++++----------- python/quadrants/lang/struct.py | 8 +-- tests/python/test_register_array.py | 29 ++++----- 5 files changed, 57 insertions(+), 68 deletions(-) diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 998f25e24f..7d8baa9566 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -296,11 +296,10 @@ def _unpack_layout_vector_index(ast_builder, index, layout_len): def build_Subscript(ctx: ASTTransformerFuncContext, node: ast.Subscript): build_stmt(ctx, node.value) build_stmt(ctx, node.slice) - # ``register_array`` group subscript: rewrite ``obj.{group}[k]`` to a direct reference to the - # synthetic scalar field ``_{group}{k}`` when ``k`` is a python-int. The resolution lives - # entirely at AST-build time so the runtime IR/PTX is byte-identical to the named-field - # form. Runtime indices are not supported here -- the user must spell the cascade - # explicitly. + # ``register_array`` group subscript: rewrite ``obj.{group}[k]`` to a direct reference to the synthetic scalar + # field ``_{group}{k}`` when ``k`` is a python-int. The resolution lives entirely at AST-build time so the + # runtime IR/PTX is byte-identical to the named-field form. Runtime indices are not supported here -- the user + # must spell the cascade explicitly. if isinstance(node.value.ptr, _RegisterArrayRef): slice_val = node.slice.ptr node.ptr = node.value.ptr._qd_field_for(slice_val) @@ -754,11 +753,11 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): node.ptr = node.ptr._unwrap() node.ptr = ASTTransformer._promote_ndarray_if_declared(ctx, node.ptr) else: - # ``register_array`` group access on a ``@qd.dataclass`` Struct expression. Returns a - # transient ``_RegisterArrayRef`` that ``build_Subscript`` (or its assignment-LHS sibling) - # resolves to a direct field reference. The lookup is by-name on ``_qd_register_groups``, - # which ``StructType.__call__`` attaches to every Struct instance whose type declared at - # least one ``register_array`` annotation. Tested in ``test_register_array.py``. + # ``register_array`` group access on a ``@qd.dataclass`` Struct expression. Returns a transient + # ``_RegisterArrayRef`` that ``build_Subscript`` (or its assignment-LHS sibling) resolves to a direct field + # reference. The lookup is by-name on ``_qd_register_groups``, which ``StructType.__call__`` attaches to + # every Struct instance whose type declared at least one ``register_array`` annotation. Tested in + # ``test_register_array.py``. groups = getattr(node.value.ptr, "_qd_register_groups", None) if groups and node.attr in groups: count, dtype, naming_fn = groups[node.attr] diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index d1778b3655..76a8306654 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -104,8 +104,8 @@ def expr_init(rhs): return rhs if isinstance(rhs, Struct): new_struct = Struct(rhs.to_dict(include_methods=True, include_ndim=True)) - # Preserve ``register_array`` group metadata across the rewrap; required so the AST - # transformer can still resolve ``obj.{group}[k]`` on the re-emitted Struct. + # Preserve ``register_array`` group metadata across the rewrap so the AST transformer can still resolve + # ``obj.{group}[k]`` on the re-emitted Struct. groups = getattr(rhs, "_qd_register_groups", None) if groups is not None: new_struct._qd_register_groups = groups diff --git a/python/quadrants/lang/register_array.py b/python/quadrants/lang/register_array.py index b297482cc9..adf856c43e 100644 --- a/python/quadrants/lang/register_array.py +++ b/python/quadrants/lang/register_array.py @@ -1,19 +1,16 @@ # type: ignore -"""``qd.register_array`` -- indexed groups of independently-allocated scalar fields on -``@qd.dataclass``. - -A ``register_array(N, dtype)`` annotation expands at struct-definition time into N -individually-named synthetic scalar members (``_{group}0`` .. ``_{group}{N-1}``). The AST -transformer rewrites ``obj.{group}[i]`` into a direct reference to ``obj._{group}{i}`` for -python-int / ``qd.static``-resolved indices, so generated LLVM IR / PTX is byte-identical to -a hand-rolled named-field struct. - -The motivation is register residency under pressure. A packed ``qd.types.vector(N, dtype)`` -collapses into one ``alloca`` that LLVM SROA cannot decompose once register pressure crosses -a threshold (e.g. two concurrent tiles in a Cholesky + TRSM kernel), causing spills to local -memory. ``register_array`` pre-decomposes the storage so SROA + ``mem2reg`` can -register-promote each slot independently, while keeping the ergonomic indexed-access syntax -at the source level. +"""``qd.register_array`` -- indexed groups of independently-allocated scalar fields on ``@qd.dataclass``. + +A ``register_array(N, dtype)`` annotation expands at struct-definition time into N individually-named synthetic scalar +members (``_{group}0`` .. ``_{group}{N-1}``). The AST transformer rewrites ``obj.{group}[i]`` into a direct reference to +``obj._{group}{i}`` for python-int / ``qd.static``-resolved indices, so generated LLVM IR / PTX is byte-identical to a +hand-rolled named-field struct. + +The motivation is register residency under pressure. A packed ``qd.types.vector(N, dtype)`` collapses into one +``alloca`` that LLVM SROA cannot decompose once register pressure crosses a threshold (e.g. two concurrent tiles in a +Cholesky + TRSM kernel), causing spills to local memory. ``register_array`` pre-decomposes the storage so SROA + +``mem2reg`` can register-promote each slot independently, while keeping the ergonomic indexed-access syntax at the +source level. Public: - ``RegisterArray`` - type wrapper used as the annotation value @@ -34,9 +31,9 @@ class RegisterArray: """Type wrapper for a group of N scalar fields exposed via indexed syntax on a ``@qd.dataclass``. - See :func:`register_array` for the user-facing constructor and the motivation writeup. - Holding only ``count`` and ``dtype``, this object is consumed at struct-definition time - by ``StructType.__init__`` to lay out the N synthetic scalar fields. + See :func:`register_array` for the user-facing constructor and the motivation writeup. Holding only ``count`` and + ``dtype``, this object is consumed at struct-definition time by ``StructType.__init__`` to lay out the N synthetic + scalar fields. """ def __init__(self, count, dtype): @@ -52,14 +49,12 @@ def __repr__(self): def register_array(count, dtype): """Declare a group of ``count`` independently-allocated fields of ``dtype`` on a ``@qd.dataclass``. - The annotation expands at struct-definition time into ``count`` individually-named scalar - members (``_{group}0`` .. ``_{group}{count-1}``). Each member gets its own LLVM - ``alloca``, which lets SROA + ``mem2reg`` promote each slot into its own SSA value - independently. The motivation is register residency under pressure: - ``qd.types.vector(N, dtype)`` collapses into one packed ``alloca`` that the optimiser - often spills as a unit when register pressure crosses a threshold; ``register_array`` - decomposes the storage up-front so the compiler can keep individual slots in registers - and only spill the ones it has to. + The annotation expands at struct-definition time into ``count`` individually-named scalar members (``_{group}0`` .. + ``_{group}{count-1}``). Each member gets its own LLVM ``alloca``, which lets SROA + ``mem2reg`` promote each slot + into its own SSA value independently. The motivation is register residency under pressure: + ``qd.types.vector(N, dtype)`` collapses into one packed ``alloca`` that the optimiser often spills as a unit when + register pressure crosses a threshold; ``register_array`` decomposes the storage up-front so the compiler can keep + individual slots in registers and only spill the ones it has to. Example:: @@ -73,12 +68,10 @@ class Tile: for k in qd.static(range(32)): t.r[k] = 0.0 # each iter is one AST node, not a 32-way cascade - For python-int / ``qd.static``-resolved indices, ``t.r[k]`` is rewritten by the AST - transformer to the named-field access ``t._r{k}``, producing identical LLVM IR / PTX to - a struct declared with N individually-named scalar fields. + For python-int / ``qd.static``-resolved indices, ``t.r[k]`` is rewritten by the AST transformer to the named-field + access ``t._r{k}``, producing identical LLVM IR / PTX to a struct declared with N individually-named scalar fields. - Runtime-int indexing is currently unsupported; use an explicit cascade helper for that - case. + Runtime-int indexing is currently unsupported; use an explicit cascade helper for that case. """ return RegisterArray(count, dtype) @@ -92,12 +85,12 @@ def _expand_register_array_naming(group_name, index): class _RegisterArrayRef: - """Transient proxy returned by the AST transformer for ``obj.{group}`` where ``group`` is - a registered ``register_array`` group on the struct type. + """Transient proxy returned by the AST transformer for ``obj.{group}`` where ``group`` is a registered group on the + struct type. - Only valid as the value of a Subscript node: ``obj.{group}[i]``. Resolved by - ``ASTTransformer.build_Subscript`` to a direct reference to the synthetic scalar field - ``_{group}{i}`` when ``i`` is a python-int / ``qd.static``-resolved integer. + Only valid as the value of a Subscript node: ``obj.{group}[i]``. Resolved by ``ASTTransformer.build_Subscript`` + to a direct reference to the synthetic scalar field ``_{group}{i}`` when ``i`` is a python-int / + ``qd.static``-resolved integer. Used as a not-an-Expr marker; any attempt to use it as a value raises. """ diff --git a/python/quadrants/lang/struct.py b/python/quadrants/lang/struct.py index f5de28a40e..0c64c58336 100644 --- a/python/quadrants/lang/struct.py +++ b/python/quadrants/lang/struct.py @@ -607,8 +607,8 @@ class StructType(CompoundType): def __init__(self, **kwargs): self.members = {} self.methods = {} - # Maps group name -> (count, dtype, naming_fn). Populated when a member annotation is a - # ``RegisterArray``; consumed by the AST transformer to rewrite ``obj.{group}[i]``. + # Maps group name -> (count, dtype, naming_fn). Populated when a member annotation is a ``RegisterArray``; + # consumed by the AST transformer to rewrite ``obj.{group}[i]`` into a direct synthetic-field reference. self._register_groups: dict = {} elements = [] for k, dtype in kwargs.items(): @@ -656,8 +656,8 @@ def __call__(self, *args, **kwargs): entries._Struct__dtype = self.dtype struct = self.cast(entries) struct._Struct__dtype = self.dtype - # Propagate the register-array group metadata onto the Struct instance so the AST - # transformer can detect indexed-group access (``obj.r``) on the per-trace expression. + # Propagate the register-array group metadata onto the Struct instance so the AST transformer can detect + # indexed-group access (``obj.r``) on the per-trace expression. if self._register_groups: struct._qd_register_groups = self._register_groups return struct diff --git a/tests/python/test_register_array.py b/tests/python/test_register_array.py index b4eea0bb17..c04786a022 100644 --- a/tests/python/test_register_array.py +++ b/tests/python/test_register_array.py @@ -1,10 +1,9 @@ # pyright: reportInvalidTypeForm=false """Tests for ``qd.register_array(N, dtype)`` on ``@qd.dataclass``. -``register_array`` gives users an ergonomic indexed-write syntax on a per-thread struct, while -keeping the underlying storage as N separate named scalar fields so SROA + ``mem2reg`` can -register-promote each slot independently. The static-index case must lower to a direct field -reference; PTX must be byte-identical to the named-field equivalent. +``register_array`` gives users an ergonomic indexed-write syntax on a per-thread struct, while keeping the underlying +storage as N separate named scalar fields so SROA + ``mem2reg`` can register-promote each slot independently. The +static-index case must lower to a direct field reference; PTX must be byte-identical to the named-field equivalent. """ import numpy as np @@ -28,16 +27,15 @@ def _qd_init_cuda(): def test_register_array_construction_python_scope(): - """A dataclass with ``r: qd.register_array(N, dtype)`` should construct as if it had N - named scalar fields ``_r0.._r{N-1}``.""" + """A dataclass with ``r: qd.register_array(N, dtype)`` should construct as if it had N named scalar fields named + ``_r0.._r{N-1}``.""" _qd_init_cuda() @qd.dataclass class Tile: r: qd.register_array(4, qd.f32) - # The underlying struct type should report N synthetic scalar members plus expose ``r`` as - # a group name. + # The underlying struct type should report N synthetic scalar members plus expose ``r`` as a group name. assert hasattr(Tile, "_register_groups") groups = Tile._register_groups assert "r" in groups @@ -84,8 +82,8 @@ def k(o: qd.template()): def test_register_array_qd_static_loop_index(): - """Index via a ``qd.static(range(N))`` loop variable. Each iter sees a python-int index, - so the lowering must be the same direct-field path as the explicit python-int case.""" + """Index via a ``qd.static(range(N))`` loop variable. Each iter sees a python-int index, so the lowering must be the + same direct-field path as the explicit python-int case.""" _qd_init_cuda() @qd.dataclass @@ -113,8 +111,8 @@ def k(o: qd.template()): def _build_named_kernel(): - """Same as test_register_array_static_index_write_then_read but with 4 named ``r0..r3`` - fields. Used for PTX byte-equality comparison against the ``register_array`` form.""" + """Same as test_register_array_static_index_write_then_read but with 4 named ``r0..r3`` fields. Used for PTX byte- + equality comparison against the ``register_array`` form.""" @qd.dataclass class TileNamed: r0: qd.f32 @@ -164,10 +162,9 @@ def k(o: qd.template()): def test_register_array_runtime_index_rejected(): - """Indexing ``t.r[k]`` with a runtime ``k`` raises a clear error pointing at the - python-int / ``qd.static`` requirement. Long term the runtime case can lower to an - explicit cascade; for now the limitation is surfaced early so callers don't get a - confusing LLVM/SROA failure downstream.""" + """Indexing ``t.r[k]`` with a runtime ``k`` raises a clear error pointing at the python-int / ``qd.static`` + requirement. Long term the runtime case can lower to an explicit cascade; for now the limitation is surfaced + early so callers don't get a confusing LLVM/SROA failure downstream.""" _qd_init_cuda() @qd.dataclass From 77346123578c982772d53734e0ff0cd37cbc1309 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 27 May 2026 10:42:54 -0700 Subject: [PATCH 06/11] [RegisterArray] pre-commit run -a fixups (ruff import sort, black f-string concat) - ast_transformer.py: ruff/isort: register_array import sorts before snode in the quadrants.lang.* group. - struct.py: ruff/isort: _RegisterArrayRef sorts after _expand_register_array_naming inside the multi-line import. - register_array.py: black joins two short f-string-continuation literals onto one 120c-safe line (in the OOB error path and __repr__). - test_register_array.py: black inserts a blank line between the docstring and the nested @qd.dataclass. All hooks (black, clang-format, trailing-whitespace, eof, ruff, pylint) now pass. --- python/quadrants/lang/ast/ast_transformer.py | 2 +- python/quadrants/lang/register_array.py | 8 ++------ python/quadrants/lang/struct.py | 2 +- tests/python/test_register_array.py | 1 + 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 7d8baa9566..80735435bf 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -38,8 +38,8 @@ from quadrants.lang.expr import Expr, make_expr_group from quadrants.lang.field import Field from quadrants.lang.matrix import Matrix, MatrixType -from quadrants.lang.snode import append, deactivate, length from quadrants.lang.register_array import _RegisterArrayRef +from quadrants.lang.snode import append, deactivate, length from quadrants.lang.struct import Struct, StructType from quadrants.lang.util import ( is_from_quadrants_module as _is_from_quadrants_module, diff --git a/python/quadrants/lang/register_array.py b/python/quadrants/lang/register_array.py index adf856c43e..4f8046d58a 100644 --- a/python/quadrants/lang/register_array.py +++ b/python/quadrants/lang/register_array.py @@ -113,17 +113,13 @@ def _qd_field_for(self, index: int): i = int(index) if i < 0 or i >= self._qd_count: raise QuadrantsSyntaxError( - f"register_array index out of bounds: {self._qd_group_name}[{i}] " - f"(count={self._qd_count})" + f"register_array index out of bounds: {self._qd_group_name}[{i}] " f"(count={self._qd_count})" ) field_name = self._qd_naming_fn(self._qd_group_name, i) return getattr(self._qd_struct, field_name) def __repr__(self) -> str: # pragma: no cover - debug only - return ( - f"" - ) + return f"" __all__ = ["RegisterArray", "register_array"] diff --git a/python/quadrants/lang/struct.py b/python/quadrants/lang/struct.py index 0c64c58336..9d9de7f19b 100644 --- a/python/quadrants/lang/struct.py +++ b/python/quadrants/lang/struct.py @@ -17,8 +17,8 @@ from quadrants.lang.matrix import Matrix, MatrixType from quadrants.lang.register_array import ( RegisterArray, - _RegisterArrayRef, _expand_register_array_naming, + _RegisterArrayRef, register_array, ) from quadrants.lang.util import ( diff --git a/tests/python/test_register_array.py b/tests/python/test_register_array.py index c04786a022..b61020bb21 100644 --- a/tests/python/test_register_array.py +++ b/tests/python/test_register_array.py @@ -113,6 +113,7 @@ def k(o: qd.template()): def _build_named_kernel(): """Same as test_register_array_static_index_write_then_read but with 4 named ``r0..r3`` fields. Used for PTX byte- equality comparison against the ``register_array`` form.""" + @qd.dataclass class TileNamed: r0: qd.f32 From 3595d0d7ec78a6af8495724d60ee61e450ffd37a Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 27 May 2026 10:44:36 -0700 Subject: [PATCH 07/11] [RegisterArray] add user-guide doc for qd.register_array Per-PR feedback request: a user-facing page covering the problem register_array solves (SROA bailout on packed qd.types.vector under register pressure), how to use it (@qd.dataclass annotation + python-int / qd.static indices), and its constraints (static indices only, no vector arithmetic, count fixed at struct-definition time, naming collisions to watch out for). - docs/source/user_guide/register_array.md: new doc page. - docs/source/user_guide/index.md: link the new page from Core concepts, next to compound_types. - docs/source/user_guide/compound_types.md: cross-link from the qd.dataclass overview. --- docs/source/user_guide/compound_types.md | 2 + docs/source/user_guide/index.md | 1 + docs/source/user_guide/register_array.md | 133 +++++++++++++++++++++++ 3 files changed, 136 insertions(+) create mode 100644 docs/source/user_guide/register_array.md diff --git a/docs/source/user_guide/compound_types.md b/docs/source/user_guide/compound_types.md index 79c007d7aa..8f744a5dea 100644 --- a/docs/source/user_guide/compound_types.md +++ b/docs/source/user_guide/compound_types.md @@ -159,3 +159,5 @@ class Particle: vel: qd.types.vector(3, qd.f32) mass: qd.f32 ``` + +For larger statically-indexed groups (e.g. a row or tile of scalars that you want SROA to register-promote independently), see {doc}`register_array`. diff --git a/docs/source/user_guide/index.md b/docs/source/user_guide/index.md index b648f97527..7c2d8108fc 100644 --- a/docs/source/user_guide/index.md +++ b/docs/source/user_guide/index.md @@ -21,6 +21,7 @@ matrix_vector_per_thread linalg_per_thread tensor compound_types +register_array buffer_view static sub_functions diff --git a/docs/source/user_guide/register_array.md b/docs/source/user_guide/register_array.md new file mode 100644 index 0000000000..4141c1ea32 --- /dev/null +++ b/docs/source/user_guide/register_array.md @@ -0,0 +1,133 @@ +# Register array + +`qd.register_array(N, dtype)` is a `@qd.dataclass` field annotation that declares a group of `N` independently-allocated scalar fields exposed under a single name via indexed syntax. + +It is a *layout hint*, not a new container. At source level you write `t.r[i]`; the compiler lowers that to a direct reference to a synthetic scalar field `t._r{i}`. The generated LLVM IR / PTX is byte-identical to a struct that was declared with `N` individually-named scalar fields. + +## What problem does it solve? + +The intuitive way to group `N` scalars on a per-thread struct is to declare a vector member: + +```python +@qd.dataclass +class Tile: + r: qd.types.vector(32, qd.f32) +``` + +`qd.types.vector(N, dtype)` lays the group out as one packed `alloca`. LLVM's scalar-replacement-of-aggregates pass (SROA + `mem2reg`) tries to break that `alloca` apart into per-slot SSA values so each one can live in a register — but once a kernel's register pressure crosses a threshold (e.g. two concurrent `32x32` tiles in a Cholesky + triangular solve), SROA bails out and the whole packed `alloca` spills to local memory. Each access then turns into a `ld.local` / `st.local` and the kernel slows down dramatically. + +The alternative is to declare `N` named scalar fields by hand: + +```python +@qd.dataclass +class Tile: + r0: qd.f32 + r1: qd.f32 + # ... 30 more lines ... + r31: qd.f32 +``` + +Now each slot has its own `alloca`, and SROA + `mem2reg` can promote each one into a register independently. The optimiser is also free to spill only the slots it has to, instead of the whole group as a unit. The cost is that every index becomes a cascade in source: + +```python +def get_r(t, k): + if k == 0: + return t.r0 + elif k == 1: + return t.r1 + # ... 30 more branches ... +``` + +…which is duplicated at every call site that wants to read or write the group. + +`qd.register_array` is the named-field layout with the ergonomic indexed syntax restored: + +```python +@qd.dataclass +class Tile: + r: qd.register_array(32, qd.f32) +``` + +The annotation expands at struct-definition time into the `N` synthetic scalar fields. The AST transformer rewrites `obj.r[i]` (for any python-int / `qd.static`-resolved `i`) into a direct reference to the synthetic field `obj._r{i}`. The IR / PTX matches the hand-rolled named-field version exactly. + +## How to use it + +Declare the group as a `register_array(count, dtype)` annotation on a `@qd.dataclass`: + +```python +import quadrants as qd + +qd.init(arch=qd.gpu) + + +@qd.dataclass +class Tile: + r: qd.register_array(32, qd.f32) + + +@qd.kernel +def k(out: qd.types.NDArray[qd.f32, 1]) -> None: + t = Tile() + # python-int index: lowers to a direct write of t._r5 + t.r[5] = 1.0 + # qd.static loop variable: each iter is one AST node, fully unrolled, + # no per-iter cascade. + for i in qd.static(range(32)): + t.r[i] = qd.f32(i) + out[0] = t.r[3] +``` + +Read access works the same way: + +```python +v = t.r[5] # python-int index +v = t.r[i] # i bound by `for i in qd.static(range(N)):` +``` + +You can mix `register_array` groups with regular scalar / vector fields on the same dataclass; they are independent. You can also have several `register_array` groups in one struct: + +```python +@qd.dataclass +class TwoTiles: + a: qd.register_array(32, qd.f32) + b: qd.register_array(32, qd.f32) + scale: qd.f32 +``` + +The generated struct has 65 scalar members (`_a0..._a31`, `_b0..._b31`, `scale`). + +## When to reach for it + +Use `register_array` when: + +- the group is *small and statically-sized*, and +- the kernel body accesses it with python-int / `qd.static`-resolved indices (typically unrolled inner loops), and +- you have measured (or strongly suspect) that an equivalent `qd.types.vector(N, dtype)` is leaving slots in local memory under register pressure. + +A good signal is `ptxas` reporting non-zero "bytes spill stores / loads" for the kernel, or `ld.local` / `st.local` instructions in the generated PTX that don't correspond to a deliberate shared-memory access. + +Prefer `qd.types.vector(N, dtype)` for small groups where register pressure is low and runtime indexing is needed — vectors keep all the usual arithmetic conveniences (element-wise ops, dot products, etc.) that `register_array` does not. + +## Constraints and limitations + +- **Static indices only.** `t.r[k]` must resolve at AST-build time, i.e. `k` is a python-int literal or a `for k in qd.static(range(N)):` loop variable. A runtime-int index raises a `QuadrantsSyntaxError` at trace time with a message pointing at the `qd.static` requirement. If you need a runtime index over the group, spell out the cascade explicitly (`if k == 0: ...`). +- **Static out-of-bounds is rejected at trace time.** `t.r[7]` on a `register_array(4, ...)` group raises `QuadrantsSyntaxError: register_array index out of bounds: r[7] (count=4)`. +- **Storage only.** A `register_array` group has no vector arithmetic. There is no `t.r + other`, no `t.r.dot(...)`, no broadcast operations. If you want those, use `qd.types.vector(N, dtype)` instead. +- **`count` is fixed at struct-definition time.** It must be a positive python-int literal. +- **Naming.** The synthetic fields use the convention `_{group_name}{i}` (e.g. `_r0`, `_r1`, ..., `_r31`). Avoid declaring your own field with a name that collides with one of those, or `StructType` will report a duplicate member. + +## Relationship to other annotations + +| annotation | storage layout | runtime indexing | best for | +|-------------------------------------|---------------------------------|:----------------:|---------------------------------------| +| `qd.f32` (per-field) | one `alloca` per field | n/a | individually-named scalars | +| `qd.types.vector(N, dtype)` | one packed `alloca` | yes | small groups with vector arithmetic | +| `qd.register_array(N, dtype)` | `N` independent `alloca`s | no | register-resident groups under load | + +Under low register pressure the three options generate similar code. Under high register pressure `register_array` is the one most likely to stay in registers because the optimiser can promote each slot independently. + +## See also + +- {doc}`compound_types` — `@qd.dataclass` overview +- {doc}`matrix_vector_per_thread` — `qd.types.vector` and per-thread matrices +- {doc}`linalg_per_thread` — examples of tile-resident linear algebra where register residency matters From bff5bd0fc6903bbee094c3607acee763310c4ec7 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 27 May 2026 14:06:19 -0700 Subject: [PATCH 08/11] [RegisterArray] fix pyright reportAttributeAccessIssue in expr_init Pyright on PR #712 fails with: python/quadrants/lang/impl.py:111:24 - error: Cannot assign to attribute "_qd_register_groups" for class "Struct". Attribute "_qd_register_groups" is unknown (reportAttributeAccessIssue) `Struct` doesn't statically declare the metadata attr (it's an opt-in per-instance tag set only when the dataclass has at least one register_array group). Use `setattr()` instead of direct attribute assignment so pyright doesn't insist on a class-level declaration; the runtime semantics are identical. Verified locally: 0 errors, 9 warnings -- all 9 warnings pre-existing and unrelated to this PR. --- python/quadrants/lang/impl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index 76a8306654..d8ec0f0d28 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -108,7 +108,9 @@ def expr_init(rhs): # ``obj.{group}[k]`` on the re-emitted Struct. groups = getattr(rhs, "_qd_register_groups", None) if groups is not None: - new_struct._qd_register_groups = groups + # setattr (rather than attribute assignment) sidesteps pyright's reportAttributeAccessIssue; + # ``Struct`` doesn't statically declare this attribute -- it's a per-instance metadata tag. + setattr(new_struct, "_qd_register_groups", groups) return new_struct if isinstance(rhs, list): return [expr_init(e) for e in rhs] From 22c0a5c01fae8473ca4494c94261137d132bb062 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 27 May 2026 14:12:04 -0700 Subject: [PATCH 09/11] [UnpackedArray] rename qd.register_array -> qd.unpacked_array (PR feedback) "register_array" reads as imperative ("register the array") and overcommits the type to an implementation effect (register residency). "unpacked_array" describes the actual layout the annotation produces -- each slot in its own `alloca` -- and contrasts cleanly with the implicit packed-vector default (`qd.types.vector(N)`). Whether the slots end up in registers is the optimiser's call; `unpacked_array` removes the packed-storage obstacle. Renames: - Public: qd.register_array(N, dtype) -> qd.unpacked_array(N, dtype) RegisterArray -> UnpackedArray - Private: _RegisterArrayRef -> _UnpackedArrayRef _qd_register_groups -> _qd_unpacked_groups _register_groups (StructType) -> _unpacked_groups _expand_register_array_naming -> _expand_unpacked_array_naming _qd_is_register_array_ref -> _qd_is_unpacked_array_ref - Module: python/quadrants/lang/register_array.py -> python/quadrants/lang/unpacked_array.py - Tests: tests/python/test_register_array.py -> tests/python/test_unpacked_array.py - Docs: docs/source/user_guide/register_array.md -> docs/source/user_guide/unpacked_array.md (also reworded around the "packed vs unpacked layout" framing) - test_api.py expected list updated (UnpackedArray + unpacked_array in alphabetical position). No behavioural change; all checks pass locally: - 224 passed + 9 xfailed across test_unpacked_array, test_api, test_py_dataclass, test_complex_struct, test_struct. - pre-commit run -a: black, clang-format, trailing-whitespace, eof, ruff, pylint all pass. - pyright on the touched files: 0 errors, 0 warnings. --- docs/source/user_guide/compound_types.md | 2 +- docs/source/user_guide/index.md | 2 +- .../{register_array.md => unpacked_array.md} | 36 ++++++----- python/quadrants/lang/ast/ast_transformer.py | 20 +++--- python/quadrants/lang/impl.py | 6 +- python/quadrants/lang/struct.py | 26 ++++---- .../{register_array.py => unpacked_array.py} | 64 +++++++++---------- tests/python/test_api.py | 4 +- ...gister_array.py => test_unpacked_array.py} | 50 +++++++-------- 9 files changed, 106 insertions(+), 104 deletions(-) rename docs/source/user_guide/{register_array.md => unpacked_array.md} (69%) rename python/quadrants/lang/{register_array.py => unpacked_array.py} (61%) rename tests/python/{test_register_array.py => test_unpacked_array.py} (81%) diff --git a/docs/source/user_guide/compound_types.md b/docs/source/user_guide/compound_types.md index 8f744a5dea..b0667f43e4 100644 --- a/docs/source/user_guide/compound_types.md +++ b/docs/source/user_guide/compound_types.md @@ -160,4 +160,4 @@ class Particle: mass: qd.f32 ``` -For larger statically-indexed groups (e.g. a row or tile of scalars that you want SROA to register-promote independently), see {doc}`register_array`. +For larger statically-indexed groups (e.g. a row or tile of scalars that you want SROA to register-promote independently), see {doc}`unpacked_array`. diff --git a/docs/source/user_guide/index.md b/docs/source/user_guide/index.md index 7c2d8108fc..3f15ab58aa 100644 --- a/docs/source/user_guide/index.md +++ b/docs/source/user_guide/index.md @@ -21,7 +21,7 @@ matrix_vector_per_thread linalg_per_thread tensor compound_types -register_array +unpacked_array buffer_view static sub_functions diff --git a/docs/source/user_guide/register_array.md b/docs/source/user_guide/unpacked_array.md similarity index 69% rename from docs/source/user_guide/register_array.md rename to docs/source/user_guide/unpacked_array.md index 4141c1ea32..dc1128a3f4 100644 --- a/docs/source/user_guide/register_array.md +++ b/docs/source/user_guide/unpacked_array.md @@ -1,6 +1,6 @@ -# Register array +# Unpacked array -`qd.register_array(N, dtype)` is a `@qd.dataclass` field annotation that declares a group of `N` independently-allocated scalar fields exposed under a single name via indexed syntax. +`qd.unpacked_array(N, dtype)` is a `@qd.dataclass` field annotation that declares a group of `N` independently-allocated scalar fields exposed under a single name via indexed syntax. It is a *layout hint*, not a new container. At source level you write `t.r[i]`; the compiler lowers that to a direct reference to a synthetic scalar field `t._r{i}`. The generated LLVM IR / PTX is byte-identical to a struct that was declared with `N` individually-named scalar fields. @@ -14,7 +14,7 @@ class Tile: r: qd.types.vector(32, qd.f32) ``` -`qd.types.vector(N, dtype)` lays the group out as one packed `alloca`. LLVM's scalar-replacement-of-aggregates pass (SROA + `mem2reg`) tries to break that `alloca` apart into per-slot SSA values so each one can live in a register — but once a kernel's register pressure crosses a threshold (e.g. two concurrent `32x32` tiles in a Cholesky + triangular solve), SROA bails out and the whole packed `alloca` spills to local memory. Each access then turns into a `ld.local` / `st.local` and the kernel slows down dramatically. +`qd.types.vector(N, dtype)` lays the group out as a single packed `alloca`. LLVM's scalar-replacement-of-aggregates pass (SROA + `mem2reg`) tries to decompose that `alloca` into per-slot SSA values so each one can live in a register — but once a kernel's register pressure crosses a threshold (e.g. two concurrent `32x32` tiles in a Cholesky + triangular solve), SROA bails out on the packed `alloca` and the whole group spills to local memory as a unit. Each access then turns into a `ld.local` / `st.local` and the kernel slows down dramatically. The alternative is to declare `N` named scalar fields by hand: @@ -27,7 +27,7 @@ class Tile: r31: qd.f32 ``` -Now each slot has its own `alloca`, and SROA + `mem2reg` can promote each one into a register independently. The optimiser is also free to spill only the slots it has to, instead of the whole group as a unit. The cost is that every index becomes a cascade in source: +Now each slot has its own `alloca`, and SROA + `mem2reg` can promote each one independently. The optimiser is also free to spill only the slots it has to, instead of the whole group as a unit. The cost is that every index becomes a cascade in source: ```python def get_r(t, k): @@ -40,19 +40,21 @@ def get_r(t, k): …which is duplicated at every call site that wants to read or write the group. -`qd.register_array` is the named-field layout with the ergonomic indexed syntax restored: +`qd.unpacked_array` is the named-field layout with the ergonomic indexed syntax restored: ```python @qd.dataclass class Tile: - r: qd.register_array(32, qd.f32) + r: qd.unpacked_array(32, qd.f32) ``` The annotation expands at struct-definition time into the `N` synthetic scalar fields. The AST transformer rewrites `obj.r[i]` (for any python-int / `qd.static`-resolved `i`) into a direct reference to the synthetic field `obj._r{i}`. The IR / PTX matches the hand-rolled named-field version exactly. +The name "unpacked" is a contrast with the packed-vector default: a packed group is one `alloca`, an unpacked group is `N` `alloca`s, one per slot. Whether the slots end up in registers is the optimiser's call; `unpacked_array` removes the layout obstacle that was preventing it. + ## How to use it -Declare the group as a `register_array(count, dtype)` annotation on a `@qd.dataclass`: +Declare the group as an `unpacked_array(count, dtype)` annotation on a `@qd.dataclass`: ```python import quadrants as qd @@ -62,7 +64,7 @@ qd.init(arch=qd.gpu) @qd.dataclass class Tile: - r: qd.register_array(32, qd.f32) + r: qd.unpacked_array(32, qd.f32) @qd.kernel @@ -84,13 +86,13 @@ v = t.r[5] # python-int index v = t.r[i] # i bound by `for i in qd.static(range(N)):` ``` -You can mix `register_array` groups with regular scalar / vector fields on the same dataclass; they are independent. You can also have several `register_array` groups in one struct: +You can mix `unpacked_array` groups with regular scalar / vector fields on the same dataclass; they are independent. You can also have several `unpacked_array` groups in one struct: ```python @qd.dataclass class TwoTiles: - a: qd.register_array(32, qd.f32) - b: qd.register_array(32, qd.f32) + a: qd.unpacked_array(32, qd.f32) + b: qd.unpacked_array(32, qd.f32) scale: qd.f32 ``` @@ -98,7 +100,7 @@ The generated struct has 65 scalar members (`_a0..._a31`, `_b0..._b31`, `scale`) ## When to reach for it -Use `register_array` when: +Use `unpacked_array` when: - the group is *small and statically-sized*, and - the kernel body accesses it with python-int / `qd.static`-resolved indices (typically unrolled inner loops), and @@ -106,13 +108,13 @@ Use `register_array` when: A good signal is `ptxas` reporting non-zero "bytes spill stores / loads" for the kernel, or `ld.local` / `st.local` instructions in the generated PTX that don't correspond to a deliberate shared-memory access. -Prefer `qd.types.vector(N, dtype)` for small groups where register pressure is low and runtime indexing is needed — vectors keep all the usual arithmetic conveniences (element-wise ops, dot products, etc.) that `register_array` does not. +Prefer `qd.types.vector(N, dtype)` for small groups where register pressure is low and runtime indexing is needed — vectors keep all the usual arithmetic conveniences (element-wise ops, dot products, etc.) that `unpacked_array` does not. ## Constraints and limitations - **Static indices only.** `t.r[k]` must resolve at AST-build time, i.e. `k` is a python-int literal or a `for k in qd.static(range(N)):` loop variable. A runtime-int index raises a `QuadrantsSyntaxError` at trace time with a message pointing at the `qd.static` requirement. If you need a runtime index over the group, spell out the cascade explicitly (`if k == 0: ...`). -- **Static out-of-bounds is rejected at trace time.** `t.r[7]` on a `register_array(4, ...)` group raises `QuadrantsSyntaxError: register_array index out of bounds: r[7] (count=4)`. -- **Storage only.** A `register_array` group has no vector arithmetic. There is no `t.r + other`, no `t.r.dot(...)`, no broadcast operations. If you want those, use `qd.types.vector(N, dtype)` instead. +- **Static out-of-bounds is rejected at trace time.** `t.r[7]` on an `unpacked_array(4, ...)` group raises `QuadrantsSyntaxError: unpacked_array index out of bounds: r[7] (count=4)`. +- **Storage only.** An `unpacked_array` group has no vector arithmetic. There is no `t.r + other`, no `t.r.dot(...)`, no broadcast operations. If you want those, use `qd.types.vector(N, dtype)` instead. - **`count` is fixed at struct-definition time.** It must be a positive python-int literal. - **Naming.** The synthetic fields use the convention `_{group_name}{i}` (e.g. `_r0`, `_r1`, ..., `_r31`). Avoid declaring your own field with a name that collides with one of those, or `StructType` will report a duplicate member. @@ -122,9 +124,9 @@ Prefer `qd.types.vector(N, dtype)` for small groups where register pressure is l |-------------------------------------|---------------------------------|:----------------:|---------------------------------------| | `qd.f32` (per-field) | one `alloca` per field | n/a | individually-named scalars | | `qd.types.vector(N, dtype)` | one packed `alloca` | yes | small groups with vector arithmetic | -| `qd.register_array(N, dtype)` | `N` independent `alloca`s | no | register-resident groups under load | +| `qd.unpacked_array(N, dtype)` | `N` independent `alloca`s | no | groups that need to stay register-resident under pressure | -Under low register pressure the three options generate similar code. Under high register pressure `register_array` is the one most likely to stay in registers because the optimiser can promote each slot independently. +Under low register pressure the three options generate similar code. Under high register pressure `unpacked_array` is the one most likely to stay in registers because the optimiser can promote each slot independently. ## See also diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 80735435bf..e24fb9d707 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -38,9 +38,9 @@ from quadrants.lang.expr import Expr, make_expr_group from quadrants.lang.field import Field from quadrants.lang.matrix import Matrix, MatrixType -from quadrants.lang.register_array import _RegisterArrayRef from quadrants.lang.snode import append, deactivate, length from quadrants.lang.struct import Struct, StructType +from quadrants.lang.unpacked_array import _UnpackedArrayRef from quadrants.lang.util import ( is_from_quadrants_module as _is_from_quadrants_module, ) @@ -296,11 +296,11 @@ def _unpack_layout_vector_index(ast_builder, index, layout_len): def build_Subscript(ctx: ASTTransformerFuncContext, node: ast.Subscript): build_stmt(ctx, node.value) build_stmt(ctx, node.slice) - # ``register_array`` group subscript: rewrite ``obj.{group}[k]`` to a direct reference to the synthetic scalar + # ``unpacked_array`` group subscript: rewrite ``obj.{group}[k]`` to a direct reference to the synthetic scalar # field ``_{group}{k}`` when ``k`` is a python-int. The resolution lives entirely at AST-build time so the # runtime IR/PTX is byte-identical to the named-field form. Runtime indices are not supported here -- the user # must spell the cascade explicitly. - if isinstance(node.value.ptr, _RegisterArrayRef): + if isinstance(node.value.ptr, _UnpackedArrayRef): slice_val = node.slice.ptr node.ptr = node.value.ptr._qd_field_for(slice_val) node.violates_pure = node.value.violates_pure @@ -753,15 +753,15 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): node.ptr = node.ptr._unwrap() node.ptr = ASTTransformer._promote_ndarray_if_declared(ctx, node.ptr) else: - # ``register_array`` group access on a ``@qd.dataclass`` Struct expression. Returns a transient - # ``_RegisterArrayRef`` that ``build_Subscript`` (or its assignment-LHS sibling) resolves to a direct field - # reference. The lookup is by-name on ``_qd_register_groups``, which ``StructType.__call__`` attaches to - # every Struct instance whose type declared at least one ``register_array`` annotation. Tested in - # ``test_register_array.py``. - groups = getattr(node.value.ptr, "_qd_register_groups", None) + # ``unpacked_array`` group access on a ``@qd.dataclass`` Struct expression. Returns a transient + # ``_UnpackedArrayRef`` that ``build_Subscript`` (or its assignment-LHS sibling) resolves to a direct field + # reference. The lookup is by-name on ``_qd_unpacked_groups``, which ``StructType.__call__`` attaches to + # every Struct instance whose type declared at least one ``unpacked_array`` annotation. Tested in + # ``test_unpacked_array.py``. + groups = getattr(node.value.ptr, "_qd_unpacked_groups", None) if groups and node.attr in groups: count, dtype, naming_fn = groups[node.attr] - node.ptr = _RegisterArrayRef(node.value.ptr, node.attr, count, dtype, naming_fn) + node.ptr = _UnpackedArrayRef(node.value.ptr, node.attr, count, dtype, naming_fn) return node.ptr node.ptr = getattr(node.value.ptr, node.attr) # ``qd.Tensor`` wrappers reached via attribute access on a ``@qd.data_oriented`` struct field at AST-build diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index d8ec0f0d28..95d6652eb9 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -104,13 +104,13 @@ def expr_init(rhs): return rhs if isinstance(rhs, Struct): new_struct = Struct(rhs.to_dict(include_methods=True, include_ndim=True)) - # Preserve ``register_array`` group metadata across the rewrap so the AST transformer can still resolve + # Preserve ``unpacked_array`` group metadata across the rewrap so the AST transformer can still resolve # ``obj.{group}[k]`` on the re-emitted Struct. - groups = getattr(rhs, "_qd_register_groups", None) + groups = getattr(rhs, "_qd_unpacked_groups", None) if groups is not None: # setattr (rather than attribute assignment) sidesteps pyright's reportAttributeAccessIssue; # ``Struct`` doesn't statically declare this attribute -- it's a per-instance metadata tag. - setattr(new_struct, "_qd_register_groups", groups) + setattr(new_struct, "_qd_unpacked_groups", groups) return new_struct if isinstance(rhs, list): return [expr_init(e) for e in rhs] diff --git a/python/quadrants/lang/struct.py b/python/quadrants/lang/struct.py index 9d9de7f19b..4c15a19e50 100644 --- a/python/quadrants/lang/struct.py +++ b/python/quadrants/lang/struct.py @@ -15,11 +15,11 @@ from quadrants.lang.expr import Expr from quadrants.lang.field import Field, ScalarField, SNodeHostAccess from quadrants.lang.matrix import Matrix, MatrixType -from quadrants.lang.register_array import ( - RegisterArray, - _expand_register_array_naming, - _RegisterArrayRef, - register_array, +from quadrants.lang.unpacked_array import ( + UnpackedArray, + _expand_unpacked_array_naming, + _UnpackedArrayRef, + unpacked_array, ) from quadrants.lang.util import ( cook_dtype, @@ -607,18 +607,18 @@ class StructType(CompoundType): def __init__(self, **kwargs): self.members = {} self.methods = {} - # Maps group name -> (count, dtype, naming_fn). Populated when a member annotation is a ``RegisterArray``; + # Maps group name -> (count, dtype, naming_fn). Populated when a member annotation is an ``UnpackedArray``; # consumed by the AST transformer to rewrite ``obj.{group}[i]`` into a direct synthetic-field reference. - self._register_groups: dict = {} + self._unpacked_groups: dict = {} elements = [] for k, dtype in kwargs.items(): if k == "__struct_methods": self.methods = dtype - elif isinstance(dtype, RegisterArray): + elif isinstance(dtype, UnpackedArray): cooked = cook_dtype(dtype.dtype) - self._register_groups[k] = (dtype.count, cooked, _expand_register_array_naming) + self._unpacked_groups[k] = (dtype.count, cooked, _expand_unpacked_array_naming) for i in range(dtype.count): - sub = _expand_register_array_naming(k, i) + sub = _expand_unpacked_array_naming(k, i) self.members[sub] = cooked elements.append([cooked, sub]) elif isinstance(dtype, StructType): @@ -658,8 +658,8 @@ def __call__(self, *args, **kwargs): struct._Struct__dtype = self.dtype # Propagate the register-array group metadata onto the Struct instance so the AST transformer can detect # indexed-group access (``obj.r``) on the per-trace expression. - if self._register_groups: - struct._qd_register_groups = self._register_groups + if self._unpacked_groups: + struct._qd_unpacked_groups = self._unpacked_groups return struct def __instancecheck__(self, instance): @@ -852,4 +852,4 @@ def dataclass(cls): return StructType(**fields) -__all__ = ["Struct", "StructField", "dataclass", "RegisterArray", "register_array", "_RegisterArrayRef"] +__all__ = ["Struct", "StructField", "dataclass", "UnpackedArray", "unpacked_array", "_UnpackedArrayRef"] diff --git a/python/quadrants/lang/register_array.py b/python/quadrants/lang/unpacked_array.py similarity index 61% rename from python/quadrants/lang/register_array.py rename to python/quadrants/lang/unpacked_array.py index 4f8046d58a..a34653b550 100644 --- a/python/quadrants/lang/register_array.py +++ b/python/quadrants/lang/unpacked_array.py @@ -1,24 +1,24 @@ # type: ignore -"""``qd.register_array`` -- indexed groups of independently-allocated scalar fields on ``@qd.dataclass``. +"""``qd.unpacked_array`` -- indexed groups of independently-allocated scalar fields on ``@qd.dataclass``. -A ``register_array(N, dtype)`` annotation expands at struct-definition time into N individually-named synthetic scalar +An ``unpacked_array(N, dtype)`` annotation expands at struct-definition time into N individually-named synthetic scalar members (``_{group}0`` .. ``_{group}{N-1}``). The AST transformer rewrites ``obj.{group}[i]`` into a direct reference to ``obj._{group}{i}`` for python-int / ``qd.static``-resolved indices, so generated LLVM IR / PTX is byte-identical to a hand-rolled named-field struct. -The motivation is register residency under pressure. A packed ``qd.types.vector(N, dtype)`` collapses into one -``alloca`` that LLVM SROA cannot decompose once register pressure crosses a threshold (e.g. two concurrent tiles in a -Cholesky + TRSM kernel), causing spills to local memory. ``register_array`` pre-decomposes the storage so SROA + -``mem2reg`` can register-promote each slot independently, while keeping the ergonomic indexed-access syntax at the -source level. +Compare to ``qd.types.vector(N, dtype)`` which is the *packed* layout: one ``alloca`` covers all N slots. Packed storage +is fine until register pressure rises -- once LLVM SROA fails to decompose the packed ``alloca`` (e.g. two concurrent +tiles in a Cholesky + TRSM kernel), the whole group spills to local memory as a unit. ``unpacked_array`` lays each slot +out in its own ``alloca`` up front, so SROA + ``mem2reg`` can promote slots independently and the optimiser can spill +only the ones it has to. The ergonomic indexed-access syntax is preserved at the source level. Public: -- ``RegisterArray`` - type wrapper used as the annotation value -- ``register_array`` - factory: ``r: qd.register_array(N, dtype)`` +- ``UnpackedArray`` - type wrapper used as the annotation value +- ``unpacked_array`` - factory: ``r: qd.unpacked_array(N, dtype)`` Internal (used by ``StructType`` and the AST transformer): -- ``_expand_register_array_naming(group, i)`` - synthetic-field naming convention -- ``_RegisterArrayRef`` - transient proxy yielded by attribute access +- ``_expand_unpacked_array_naming(group, i)`` - synthetic-field naming convention +- ``_UnpackedArrayRef`` - transient proxy yielded by attribute access This module has no dependency on ``struct.py``; ``struct.py`` imports from here. """ @@ -28,39 +28,39 @@ from quadrants.lang.exception import QuadrantsSyntaxError -class RegisterArray: +class UnpackedArray: """Type wrapper for a group of N scalar fields exposed via indexed syntax on a ``@qd.dataclass``. - See :func:`register_array` for the user-facing constructor and the motivation writeup. Holding only ``count`` and + See :func:`unpacked_array` for the user-facing constructor and the motivation writeup. Holding only ``count`` and ``dtype``, this object is consumed at struct-definition time by ``StructType.__init__`` to lay out the N synthetic scalar fields. """ def __init__(self, count, dtype): if not isinstance(count, int) or count <= 0: - raise QuadrantsSyntaxError(f"register_array count must be a positive int, got {count!r}") + raise QuadrantsSyntaxError(f"unpacked_array count must be a positive int, got {count!r}") self.count = count self.dtype = dtype def __repr__(self): - return f"register_array(count={self.count}, dtype={self.dtype})" + return f"unpacked_array(count={self.count}, dtype={self.dtype})" -def register_array(count, dtype): +def unpacked_array(count, dtype): """Declare a group of ``count`` independently-allocated fields of ``dtype`` on a ``@qd.dataclass``. The annotation expands at struct-definition time into ``count`` individually-named scalar members (``_{group}0`` .. ``_{group}{count-1}``). Each member gets its own LLVM ``alloca``, which lets SROA + ``mem2reg`` promote each slot - into its own SSA value independently. The motivation is register residency under pressure: - ``qd.types.vector(N, dtype)`` collapses into one packed ``alloca`` that the optimiser often spills as a unit when - register pressure crosses a threshold; ``register_array`` decomposes the storage up-front so the compiler can keep - individual slots in registers and only spill the ones it has to. + into its own SSA value independently. The contrast is with ``qd.types.vector(N, dtype)``, which lays all N slots + out in a single packed ``alloca``: when register pressure makes SROA fail to decompose the packed ``alloca``, the + whole group spills to local memory as a unit. With ``unpacked_array`` the storage is already unpacked, so the + optimiser can keep individual slots in registers and only spill the ones it has to. Example:: @qd.dataclass class Tile: - r: qd.register_array(32, qd.f32) # 32 scalar fields exposed as t.r[0..31] + r: qd.unpacked_array(32, qd.f32) # 32 scalar fields exposed as t.r[0..31] t = Tile() t.r[5] = 1.0 # lowers to direct write of synthetic field _r5 @@ -73,20 +73,20 @@ class Tile: Runtime-int indexing is currently unsupported; use an explicit cascade helper for that case. """ - return RegisterArray(count, dtype) + return UnpackedArray(count, dtype) -def _expand_register_array_naming(group_name, index): - """Naming convention for the synthetic scalar fields of a ``register_array`` group. +def _expand_unpacked_array_naming(group_name, index): + """Naming convention for the synthetic scalar fields of an ``unpacked_array`` group. Public-ish so the AST transformer can mirror this without a circular import on ``struct``. """ return f"_{group_name}{index}" -class _RegisterArrayRef: - """Transient proxy returned by the AST transformer for ``obj.{group}`` where ``group`` is a registered group on the - struct type. +class _UnpackedArrayRef: + """Transient proxy returned by the AST transformer for ``obj.{group}`` where ``group`` is an unpacked-array group + declared on the struct type. Only valid as the value of a Subscript node: ``obj.{group}[i]``. Resolved by ``ASTTransformer.build_Subscript`` to a direct reference to the synthetic scalar field ``_{group}{i}`` when ``i`` is a python-int / @@ -95,7 +95,7 @@ class _RegisterArrayRef: Used as a not-an-Expr marker; any attempt to use it as a value raises. """ - _qd_is_register_array_ref = True + _qd_is_unpacked_array_ref = True def __init__(self, struct, group_name: str, count: int, dtype, naming_fn): self._qd_struct = struct @@ -107,19 +107,19 @@ def __init__(self, struct, group_name: str, count: int, dtype, naming_fn): def _qd_field_for(self, index: int): if not isinstance(index, (int, np.integer)): raise QuadrantsSyntaxError( - f"register_array {self._qd_group_name}[i] requires a python-int index " + f"unpacked_array {self._qd_group_name}[i] requires a python-int index " f"(possibly via qd.static); got runtime index of type {type(index).__name__}" ) i = int(index) if i < 0 or i >= self._qd_count: raise QuadrantsSyntaxError( - f"register_array index out of bounds: {self._qd_group_name}[{i}] " f"(count={self._qd_count})" + f"unpacked_array index out of bounds: {self._qd_group_name}[{i}] " f"(count={self._qd_count})" ) field_name = self._qd_naming_fn(self._qd_group_name, i) return getattr(self._qd_struct, field_name) def __repr__(self) -> str: # pragma: no cover - debug only - return f"" + return f"" -__all__ = ["RegisterArray", "register_array"] +__all__ = ["UnpackedArray", "unpacked_array"] diff --git a/tests/python/test_api.py b/tests/python/test_api.py index 8bf7acce6c..db28f6b035 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -75,7 +75,6 @@ def _get_expected_matrix_apis(): "Mesh", "MeshInstance", "Ndarray", - "RegisterArray", "SNode", "ScalarField", "ScalarNdarray", @@ -83,6 +82,7 @@ def _get_expected_matrix_apis(): "Struct", "StructField", "TRACE", + "UnpackedArray", "QuadrantsAssertionError", "QuadrantsCompilationError", "QuadrantsNameError", @@ -215,7 +215,6 @@ def _get_expected_matrix_apis(): "raw_mod", "real_func", "ref", - "register_array", "rescale_index", "reset", "root", @@ -253,6 +252,7 @@ def _get_expected_matrix_apis(): "uint32", "uint64", "uint8", + "unpacked_array", "volatile_load", "vulkan", "x64", diff --git a/tests/python/test_register_array.py b/tests/python/test_unpacked_array.py similarity index 81% rename from tests/python/test_register_array.py rename to tests/python/test_unpacked_array.py index b61020bb21..be6e3a0ed8 100644 --- a/tests/python/test_register_array.py +++ b/tests/python/test_unpacked_array.py @@ -1,7 +1,7 @@ # pyright: reportInvalidTypeForm=false -"""Tests for ``qd.register_array(N, dtype)`` on ``@qd.dataclass``. +"""Tests for ``qd.unpacked_array(N, dtype)`` on ``@qd.dataclass``. -``register_array`` gives users an ergonomic indexed-write syntax on a per-thread struct, while keeping the underlying +``unpacked_array`` gives users an ergonomic indexed-write syntax on a per-thread struct, while keeping the underlying storage as N separate named scalar fields so SROA + ``mem2reg`` can register-promote each slot independently. The static-index case must lower to a direct field reference; PTX must be byte-identical to the named-field equivalent. """ @@ -26,18 +26,18 @@ def _qd_init_cuda(): # --------------------------------------------------------------------------- -def test_register_array_construction_python_scope(): - """A dataclass with ``r: qd.register_array(N, dtype)`` should construct as if it had N named scalar fields named +def test_unpacked_array_construction_python_scope(): + """A dataclass with ``r: qd.unpacked_array(N, dtype)`` should construct as if it had N named scalar fields named ``_r0.._r{N-1}``.""" _qd_init_cuda() @qd.dataclass class Tile: - r: qd.register_array(4, qd.f32) + r: qd.unpacked_array(4, qd.f32) # The underlying struct type should report N synthetic scalar members plus expose ``r`` as a group name. - assert hasattr(Tile, "_register_groups") - groups = Tile._register_groups + assert hasattr(Tile, "_unpacked_groups") + groups = Tile._unpacked_groups assert "r" in groups count, dtype, _ = groups["r"] assert count == 4 @@ -54,13 +54,13 @@ class Tile: # --------------------------------------------------------------------------- -def test_register_array_static_index_write_then_read(): +def test_unpacked_array_static_index_write_then_read(): """Write to ``t.r[0..3]`` with python-int indices, then read back.""" _qd_init_cuda() @qd.dataclass class Tile: - r: qd.register_array(4, qd.f32) + r: qd.unpacked_array(4, qd.f32) out = qd.field(dtype=qd.f32, shape=(4,)) @@ -81,14 +81,14 @@ def k(o: qd.template()): np.testing.assert_array_equal(out.to_numpy(), np.array([1, 2, 3, 4], dtype=np.float32)) -def test_register_array_qd_static_loop_index(): +def test_unpacked_array_qd_static_loop_index(): """Index via a ``qd.static(range(N))`` loop variable. Each iter sees a python-int index, so the lowering must be the same direct-field path as the explicit python-int case.""" _qd_init_cuda() @qd.dataclass class Tile: - r: qd.register_array(4, qd.f32) + r: qd.unpacked_array(4, qd.f32) out = qd.field(dtype=qd.f32, shape=(4,)) @@ -111,8 +111,8 @@ def k(o: qd.template()): def _build_named_kernel(): - """Same as test_register_array_static_index_write_then_read but with 4 named ``r0..r3`` fields. Used for PTX byte- - equality comparison against the ``register_array`` form.""" + """Same as test_unpacked_array_static_index_write_then_read but with 4 named ``r0..r3`` fields. Used for PTX byte- + equality comparison against the ``unpacked_array`` form.""" @qd.dataclass class TileNamed: @@ -139,10 +139,10 @@ def k(o: qd.template()): return k, out -def _build_register_array_kernel(): +def _build_unpacked_array_kernel(): @qd.dataclass class TileRA: - r: qd.register_array(4, qd.f32) + r: qd.unpacked_array(4, qd.f32) out = qd.field(dtype=qd.f32, shape=(4,)) @@ -162,7 +162,7 @@ def k(o: qd.template()): return k, out -def test_register_array_runtime_index_rejected(): +def test_unpacked_array_runtime_index_rejected(): """Indexing ``t.r[k]`` with a runtime ``k`` raises a clear error pointing at the python-int / ``qd.static`` requirement. Long term the runtime case can lower to an explicit cascade; for now the limitation is surfaced early so callers don't get a confusing LLVM/SROA failure downstream.""" @@ -170,7 +170,7 @@ def test_register_array_runtime_index_rejected(): @qd.dataclass class Tile: - r: qd.register_array(4, qd.f32) + r: qd.unpacked_array(4, qd.f32) out = qd.field(dtype=qd.f32, shape=(4,)) @@ -185,16 +185,16 @@ def k(o: qd.template()): with pytest.raises(Exception) as e: k(out) msg = str(e.value) - assert "register_array" in msg and "python-int" in msg, msg + assert "unpacked_array" in msg and "python-int" in msg, msg -def test_register_array_oob_static_index(): +def test_unpacked_array_oob_static_index(): """Static-int out-of-bounds index is caught at trace time with a clear message.""" _qd_init_cuda() @qd.dataclass class Tile: - r: qd.register_array(4, qd.f32) + r: qd.unpacked_array(4, qd.f32) out = qd.field(dtype=qd.f32, shape=(4,)) @@ -211,13 +211,13 @@ def k(o: qd.template()): if __name__ == "__main__": - test_register_array_construction_python_scope() + test_unpacked_array_construction_python_scope() print("construction test passed") - test_register_array_static_index_write_then_read() + test_unpacked_array_static_index_write_then_read() print("static-int subscript test passed") - test_register_array_qd_static_loop_index() + test_unpacked_array_qd_static_loop_index() print("qd.static loop-var subscript test passed") - test_register_array_runtime_index_rejected() + test_unpacked_array_runtime_index_rejected() print("runtime-index rejection test passed") - test_register_array_oob_static_index() + test_unpacked_array_oob_static_index() print("static OOB rejection test passed") From b2e9c6d42f59cd7c066802f5f15b9b907ac32046 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 27 May 2026 14:57:15 -0700 Subject: [PATCH 10/11] [UnpackedArray] docs/tests: replace 'trace time' with 'compile time' User-facing language fix. 'Trace time' is internal Quadrants jargon for the python-side AST walk; from the user's seat the relevant boundary is the compile/run split, so use 'compile time' in the doc and test docstring. - docs/source/user_guide/unpacked_array.md: 'AST-build time' / 'trace time' in the constraints section -> 'compile time'. - tests/python/test_unpacked_array.py: OOB-static-index test docstring updated to match. No behavioural change; pure prose. --- docs/source/user_guide/unpacked_array.md | 4 ++-- tests/python/test_unpacked_array.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/user_guide/unpacked_array.md b/docs/source/user_guide/unpacked_array.md index dc1128a3f4..40e20a55e2 100644 --- a/docs/source/user_guide/unpacked_array.md +++ b/docs/source/user_guide/unpacked_array.md @@ -112,8 +112,8 @@ Prefer `qd.types.vector(N, dtype)` for small groups where register pressure is l ## Constraints and limitations -- **Static indices only.** `t.r[k]` must resolve at AST-build time, i.e. `k` is a python-int literal or a `for k in qd.static(range(N)):` loop variable. A runtime-int index raises a `QuadrantsSyntaxError` at trace time with a message pointing at the `qd.static` requirement. If you need a runtime index over the group, spell out the cascade explicitly (`if k == 0: ...`). -- **Static out-of-bounds is rejected at trace time.** `t.r[7]` on an `unpacked_array(4, ...)` group raises `QuadrantsSyntaxError: unpacked_array index out of bounds: r[7] (count=4)`. +- **Static indices only.** `t.r[k]` must resolve at compile time, i.e. `k` is a python-int literal or a `for k in qd.static(range(N)):` loop variable. A runtime-int index raises a `QuadrantsSyntaxError` at compile time with a message pointing at the `qd.static` requirement. If you need a runtime index over the group, spell out the cascade explicitly (`if k == 0: ...`). +- **Static out-of-bounds is rejected at compile time.** `t.r[7]` on an `unpacked_array(4, ...)` group raises `QuadrantsSyntaxError: unpacked_array index out of bounds: r[7] (count=4)`. - **Storage only.** An `unpacked_array` group has no vector arithmetic. There is no `t.r + other`, no `t.r.dot(...)`, no broadcast operations. If you want those, use `qd.types.vector(N, dtype)` instead. - **`count` is fixed at struct-definition time.** It must be a positive python-int literal. - **Naming.** The synthetic fields use the convention `_{group_name}{i}` (e.g. `_r0`, `_r1`, ..., `_r31`). Avoid declaring your own field with a name that collides with one of those, or `StructType` will report a duplicate member. diff --git a/tests/python/test_unpacked_array.py b/tests/python/test_unpacked_array.py index be6e3a0ed8..4b1b8bc51a 100644 --- a/tests/python/test_unpacked_array.py +++ b/tests/python/test_unpacked_array.py @@ -189,7 +189,7 @@ def k(o: qd.template()): def test_unpacked_array_oob_static_index(): - """Static-int out-of-bounds index is caught at trace time with a clear message.""" + """Static-int out-of-bounds index is caught at compile time with a clear message.""" _qd_init_cuda() @qd.dataclass From 41d80e2eae0ddd6acbe30c275432d14851350142 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 27 May 2026 14:59:42 -0700 Subject: [PATCH 11/11] [UnpackedArray] clarify metadata-propagation comment in StructType.__call__ Previous comment said "propagate the register-array group metadata onto the Struct instance ... on the per-trace expression" - both terms are stale / unclear: - "register-array" was left behind by the qd.register_array -> qd.unpacked_array rename (which matched identifiers, not free-form prose). - "per-trace expression" is internal jargon; what's actually happening is that StructType.__call__ runs on every kernel trace and builds a fresh `Struct` object representing the `Tile()` instantiation in the kernel's IR, and we tag THAT expression-object (vs the class-level StructType) because ASTTransformer.build_Attribute walks the instance. Reword the comment to spell that out without the jargon. No behavioural change. --- python/quadrants/lang/struct.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/quadrants/lang/struct.py b/python/quadrants/lang/struct.py index 4c15a19e50..bf2a4399f1 100644 --- a/python/quadrants/lang/struct.py +++ b/python/quadrants/lang/struct.py @@ -656,8 +656,10 @@ def __call__(self, *args, **kwargs): entries._Struct__dtype = self.dtype struct = self.cast(entries) struct._Struct__dtype = self.dtype - # Propagate the register-array group metadata onto the Struct instance so the AST transformer can detect - # indexed-group access (``obj.r``) on the per-trace expression. + # Tag the freshly-built Struct expression-object (representing this ``Tile()`` instantiation in the kernel's + # IR) with the unpacked-array group dictionary, so ``ASTTransformer.build_Attribute`` can recognise + # ``obj.r`` as a group name. The transformer inspects the instance, not the StructType, so the metadata has + # to live here. if self._unpacked_groups: struct._qd_unpacked_groups = self._unpacked_groups return struct