diff --git a/docs/source/user_guide/compound_types.md b/docs/source/user_guide/compound_types.md index 79c007d7aa..b0667f43e4 100644 --- a/docs/source/user_guide/compound_types.md +++ b/docs/source/user_guide/compound_types.md @@ -159,3 +159,5 @@ class Particle: vel: qd.types.vector(3, qd.f32) mass: qd.f32 ``` + +For larger statically-indexed groups (e.g. a row or tile of scalars that you want SROA to register-promote independently), see {doc}`unpacked_array`. diff --git a/docs/source/user_guide/index.md b/docs/source/user_guide/index.md index b648f97527..3f15ab58aa 100644 --- a/docs/source/user_guide/index.md +++ b/docs/source/user_guide/index.md @@ -21,6 +21,7 @@ matrix_vector_per_thread linalg_per_thread tensor compound_types +unpacked_array buffer_view static sub_functions diff --git a/docs/source/user_guide/unpacked_array.md b/docs/source/user_guide/unpacked_array.md new file mode 100644 index 0000000000..40e20a55e2 --- /dev/null +++ b/docs/source/user_guide/unpacked_array.md @@ -0,0 +1,135 @@ +# Unpacked array + +`qd.unpacked_array(N, dtype)` is a `@qd.dataclass` field annotation that declares a group of `N` independently-allocated scalar fields exposed under a single name via indexed syntax. + +It is a *layout hint*, not a new container. At source level you write `t.r[i]`; the compiler lowers that to a direct reference to a synthetic scalar field `t._r{i}`. The generated LLVM IR / PTX is byte-identical to a struct that was declared with `N` individually-named scalar fields. + +## What problem does it solve? + +The intuitive way to group `N` scalars on a per-thread struct is to declare a vector member: + +```python +@qd.dataclass +class Tile: + r: qd.types.vector(32, qd.f32) +``` + +`qd.types.vector(N, dtype)` lays the group out as a single packed `alloca`. LLVM's scalar-replacement-of-aggregates pass (SROA + `mem2reg`) tries to decompose that `alloca` into per-slot SSA values so each one can live in a register — but once a kernel's register pressure crosses a threshold (e.g. two concurrent `32x32` tiles in a Cholesky + triangular solve), SROA bails out on the packed `alloca` and the whole group spills to local memory as a unit. Each access then turns into a `ld.local` / `st.local` and the kernel slows down dramatically. + +The alternative is to declare `N` named scalar fields by hand: + +```python +@qd.dataclass +class Tile: + r0: qd.f32 + r1: qd.f32 + # ... 30 more lines ... + r31: qd.f32 +``` + +Now each slot has its own `alloca`, and SROA + `mem2reg` can promote each one independently. The optimiser is also free to spill only the slots it has to, instead of the whole group as a unit. The cost is that every index becomes a cascade in source: + +```python +def get_r(t, k): + if k == 0: + return t.r0 + elif k == 1: + return t.r1 + # ... 30 more branches ... +``` + +…which is duplicated at every call site that wants to read or write the group. + +`qd.unpacked_array` is the named-field layout with the ergonomic indexed syntax restored: + +```python +@qd.dataclass +class Tile: + r: qd.unpacked_array(32, qd.f32) +``` + +The annotation expands at struct-definition time into the `N` synthetic scalar fields. The AST transformer rewrites `obj.r[i]` (for any python-int / `qd.static`-resolved `i`) into a direct reference to the synthetic field `obj._r{i}`. The IR / PTX matches the hand-rolled named-field version exactly. + +The name "unpacked" is a contrast with the packed-vector default: a packed group is one `alloca`, an unpacked group is `N` `alloca`s, one per slot. Whether the slots end up in registers is the optimiser's call; `unpacked_array` removes the layout obstacle that was preventing it. + +## How to use it + +Declare the group as an `unpacked_array(count, dtype)` annotation on a `@qd.dataclass`: + +```python +import quadrants as qd + +qd.init(arch=qd.gpu) + + +@qd.dataclass +class Tile: + r: qd.unpacked_array(32, qd.f32) + + +@qd.kernel +def k(out: qd.types.NDArray[qd.f32, 1]) -> None: + t = Tile() + # python-int index: lowers to a direct write of t._r5 + t.r[5] = 1.0 + # qd.static loop variable: each iter is one AST node, fully unrolled, + # no per-iter cascade. + for i in qd.static(range(32)): + t.r[i] = qd.f32(i) + out[0] = t.r[3] +``` + +Read access works the same way: + +```python +v = t.r[5] # python-int index +v = t.r[i] # i bound by `for i in qd.static(range(N)):` +``` + +You can mix `unpacked_array` groups with regular scalar / vector fields on the same dataclass; they are independent. You can also have several `unpacked_array` groups in one struct: + +```python +@qd.dataclass +class TwoTiles: + a: qd.unpacked_array(32, qd.f32) + b: qd.unpacked_array(32, qd.f32) + scale: qd.f32 +``` + +The generated struct has 65 scalar members (`_a0..._a31`, `_b0..._b31`, `scale`). + +## When to reach for it + +Use `unpacked_array` when: + +- the group is *small and statically-sized*, and +- the kernel body accesses it with python-int / `qd.static`-resolved indices (typically unrolled inner loops), and +- you have measured (or strongly suspect) that an equivalent `qd.types.vector(N, dtype)` is leaving slots in local memory under register pressure. + +A good signal is `ptxas` reporting non-zero "bytes spill stores / loads" for the kernel, or `ld.local` / `st.local` instructions in the generated PTX that don't correspond to a deliberate shared-memory access. + +Prefer `qd.types.vector(N, dtype)` for small groups where register pressure is low and runtime indexing is needed — vectors keep all the usual arithmetic conveniences (element-wise ops, dot products, etc.) that `unpacked_array` does not. + +## Constraints and limitations + +- **Static indices only.** `t.r[k]` must resolve at compile time, i.e. `k` is a python-int literal or a `for k in qd.static(range(N)):` loop variable. A runtime-int index raises a `QuadrantsSyntaxError` at compile time with a message pointing at the `qd.static` requirement. If you need a runtime index over the group, spell out the cascade explicitly (`if k == 0: ...`). +- **Static out-of-bounds is rejected at compile time.** `t.r[7]` on an `unpacked_array(4, ...)` group raises `QuadrantsSyntaxError: unpacked_array index out of bounds: r[7] (count=4)`. +- **Storage only.** An `unpacked_array` group has no vector arithmetic. There is no `t.r + other`, no `t.r.dot(...)`, no broadcast operations. If you want those, use `qd.types.vector(N, dtype)` instead. +- **`count` is fixed at struct-definition time.** It must be a positive python-int literal. +- **Naming.** The synthetic fields use the convention `_{group_name}{i}` (e.g. `_r0`, `_r1`, ..., `_r31`). Avoid declaring your own field with a name that collides with one of those, or `StructType` will report a duplicate member. + +## Relationship to other annotations + +| annotation | storage layout | runtime indexing | best for | +|-------------------------------------|---------------------------------|:----------------:|---------------------------------------| +| `qd.f32` (per-field) | one `alloca` per field | n/a | individually-named scalars | +| `qd.types.vector(N, dtype)` | one packed `alloca` | yes | small groups with vector arithmetic | +| `qd.unpacked_array(N, dtype)` | `N` independent `alloca`s | no | groups that need to stay register-resident under pressure | + +Under low register pressure the three options generate similar code. Under high register pressure `unpacked_array` is the one most likely to stay in registers because the optimiser can promote each slot independently. + +## See also + +- {doc}`compound_types` — `@qd.dataclass` overview +- {doc}`matrix_vector_per_thread` — `qd.types.vector` and per-thread matrices +- {doc}`linalg_per_thread` — examples of tile-resident linear algebra where register residency matters diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 263a4a11a3..e24fb9d707 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -40,6 +40,7 @@ from quadrants.lang.matrix import Matrix, MatrixType from quadrants.lang.snode import append, deactivate, length from quadrants.lang.struct import Struct, StructType +from quadrants.lang.unpacked_array import _UnpackedArrayRef from quadrants.lang.util import ( is_from_quadrants_module as _is_from_quadrants_module, ) @@ -295,6 +296,17 @@ def _unpack_layout_vector_index(ast_builder, index, layout_len): def build_Subscript(ctx: ASTTransformerFuncContext, node: ast.Subscript): build_stmt(ctx, node.value) build_stmt(ctx, node.slice) + # ``unpacked_array`` group subscript: rewrite ``obj.{group}[k]`` to a direct reference to the synthetic scalar + # field ``_{group}{k}`` when ``k`` is a python-int. The resolution lives entirely at AST-build time so the + # runtime IR/PTX is byte-identical to the named-field form. Runtime indices are not supported here -- the user + # must spell the cascade explicitly. + if isinstance(node.value.ptr, _UnpackedArrayRef): + slice_val = node.slice.ptr + node.ptr = node.value.ptr._qd_field_for(slice_val) + node.violates_pure = node.value.violates_pure + if node.violates_pure: + node.violates_pure_reason = node.value.violates_pure_reason + return node.ptr if not ASTTransformer.is_tuple(node.slice): node.slice.ptr = [node.slice.ptr] # Tensors layout: a layout-tagged Ndarray or Field with a non-identity ``_qd_layout`` has its canonical indices @@ -741,6 +753,16 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): node.ptr = node.ptr._unwrap() node.ptr = ASTTransformer._promote_ndarray_if_declared(ctx, node.ptr) else: + # ``unpacked_array`` group access on a ``@qd.dataclass`` Struct expression. Returns a transient + # ``_UnpackedArrayRef`` that ``build_Subscript`` (or its assignment-LHS sibling) resolves to a direct field + # reference. The lookup is by-name on ``_qd_unpacked_groups``, which ``StructType.__call__`` attaches to + # every Struct instance whose type declared at least one ``unpacked_array`` annotation. Tested in + # ``test_unpacked_array.py``. + groups = getattr(node.value.ptr, "_qd_unpacked_groups", None) + if groups and node.attr in groups: + count, dtype, naming_fn = groups[node.attr] + node.ptr = _UnpackedArrayRef(node.value.ptr, node.attr, count, dtype, naming_fn) + return node.ptr node.ptr = getattr(node.value.ptr, node.attr) # ``qd.Tensor`` wrappers reached via attribute access on a ``@qd.data_oriented`` struct field at AST-build # time. The IR layer downstream (``build_Subscript`` -> ``impl.subscript``) only knows about ``Ndarray`` / diff --git a/python/quadrants/lang/impl.py b/python/quadrants/lang/impl.py index e1767a4c7b..95d6652eb9 100644 --- a/python/quadrants/lang/impl.py +++ b/python/quadrants/lang/impl.py @@ -103,7 +103,15 @@ def expr_init(rhs): if isinstance(rhs, BufferView): return rhs if isinstance(rhs, Struct): - return Struct(rhs.to_dict(include_methods=True, include_ndim=True)) + new_struct = Struct(rhs.to_dict(include_methods=True, include_ndim=True)) + # Preserve ``unpacked_array`` group metadata across the rewrap so the AST transformer can still resolve + # ``obj.{group}[k]`` on the re-emitted Struct. + groups = getattr(rhs, "_qd_unpacked_groups", None) + if groups is not None: + # setattr (rather than attribute assignment) sidesteps pyright's reportAttributeAccessIssue; + # ``Struct`` doesn't statically declare this attribute -- it's a per-instance metadata tag. + setattr(new_struct, "_qd_unpacked_groups", groups) + return new_struct if isinstance(rhs, list): return [expr_init(e) for e in rhs] if isinstance(rhs, tuple): diff --git a/python/quadrants/lang/struct.py b/python/quadrants/lang/struct.py index c86efb1a99..bf2a4399f1 100644 --- a/python/quadrants/lang/struct.py +++ b/python/quadrants/lang/struct.py @@ -15,6 +15,12 @@ from quadrants.lang.expr import Expr from quadrants.lang.field import Field, ScalarField, SNodeHostAccess from quadrants.lang.matrix import Matrix, MatrixType +from quadrants.lang.unpacked_array import ( + UnpackedArray, + _expand_unpacked_array_naming, + _UnpackedArrayRef, + unpacked_array, +) from quadrants.lang.util import ( cook_dtype, in_python_scope, @@ -601,10 +607,20 @@ class StructType(CompoundType): def __init__(self, **kwargs): self.members = {} self.methods = {} + # Maps group name -> (count, dtype, naming_fn). Populated when a member annotation is an ``UnpackedArray``; + # consumed by the AST transformer to rewrite ``obj.{group}[i]`` into a direct synthetic-field reference. + self._unpacked_groups: dict = {} elements = [] for k, dtype in kwargs.items(): if k == "__struct_methods": self.methods = dtype + elif isinstance(dtype, UnpackedArray): + cooked = cook_dtype(dtype.dtype) + self._unpacked_groups[k] = (dtype.count, cooked, _expand_unpacked_array_naming) + for i in range(dtype.count): + sub = _expand_unpacked_array_naming(k, i) + self.members[sub] = cooked + elements.append([cooked, sub]) elif isinstance(dtype, StructType): self.members[k] = dtype elements.append([dtype.dtype, k]) @@ -640,6 +656,12 @@ def __call__(self, *args, **kwargs): entries._Struct__dtype = self.dtype struct = self.cast(entries) struct._Struct__dtype = self.dtype + # Tag the freshly-built Struct expression-object (representing this ``Tile()`` instantiation in the kernel's + # IR) with the unpacked-array group dictionary, so ``ASTTransformer.build_Attribute`` can recognise + # ``obj.r`` as a group name. The transformer inspects the instance, not the StructType, so the metadata has + # to live here. + if self._unpacked_groups: + struct._qd_unpacked_groups = self._unpacked_groups return struct def __instancecheck__(self, instance): @@ -832,4 +854,4 @@ def dataclass(cls): return StructType(**fields) -__all__ = ["Struct", "StructField", "dataclass"] +__all__ = ["Struct", "StructField", "dataclass", "UnpackedArray", "unpacked_array", "_UnpackedArrayRef"] diff --git a/python/quadrants/lang/unpacked_array.py b/python/quadrants/lang/unpacked_array.py new file mode 100644 index 0000000000..a34653b550 --- /dev/null +++ b/python/quadrants/lang/unpacked_array.py @@ -0,0 +1,125 @@ +# type: ignore +"""``qd.unpacked_array`` -- indexed groups of independently-allocated scalar fields on ``@qd.dataclass``. + +An ``unpacked_array(N, dtype)`` annotation expands at struct-definition time into N individually-named synthetic scalar +members (``_{group}0`` .. ``_{group}{N-1}``). The AST transformer rewrites ``obj.{group}[i]`` into a direct reference to +``obj._{group}{i}`` for python-int / ``qd.static``-resolved indices, so generated LLVM IR / PTX is byte-identical to a +hand-rolled named-field struct. + +Compare to ``qd.types.vector(N, dtype)`` which is the *packed* layout: one ``alloca`` covers all N slots. Packed storage +is fine until register pressure rises -- once LLVM SROA fails to decompose the packed ``alloca`` (e.g. two concurrent +tiles in a Cholesky + TRSM kernel), the whole group spills to local memory as a unit. ``unpacked_array`` lays each slot +out in its own ``alloca`` up front, so SROA + ``mem2reg`` can promote slots independently and the optimiser can spill +only the ones it has to. The ergonomic indexed-access syntax is preserved at the source level. + +Public: +- ``UnpackedArray`` - type wrapper used as the annotation value +- ``unpacked_array`` - factory: ``r: qd.unpacked_array(N, dtype)`` + +Internal (used by ``StructType`` and the AST transformer): +- ``_expand_unpacked_array_naming(group, i)`` - synthetic-field naming convention +- ``_UnpackedArrayRef`` - transient proxy yielded by attribute access + +This module has no dependency on ``struct.py``; ``struct.py`` imports from here. +""" + +import numpy as np + +from quadrants.lang.exception import QuadrantsSyntaxError + + +class UnpackedArray: + """Type wrapper for a group of N scalar fields exposed via indexed syntax on a ``@qd.dataclass``. + + See :func:`unpacked_array` for the user-facing constructor and the motivation writeup. Holding only ``count`` and + ``dtype``, this object is consumed at struct-definition time by ``StructType.__init__`` to lay out the N synthetic + scalar fields. + """ + + def __init__(self, count, dtype): + if not isinstance(count, int) or count <= 0: + raise QuadrantsSyntaxError(f"unpacked_array count must be a positive int, got {count!r}") + self.count = count + self.dtype = dtype + + def __repr__(self): + return f"unpacked_array(count={self.count}, dtype={self.dtype})" + + +def unpacked_array(count, dtype): + """Declare a group of ``count`` independently-allocated fields of ``dtype`` on a ``@qd.dataclass``. + + The annotation expands at struct-definition time into ``count`` individually-named scalar members (``_{group}0`` .. + ``_{group}{count-1}``). Each member gets its own LLVM ``alloca``, which lets SROA + ``mem2reg`` promote each slot + into its own SSA value independently. The contrast is with ``qd.types.vector(N, dtype)``, which lays all N slots + out in a single packed ``alloca``: when register pressure makes SROA fail to decompose the packed ``alloca``, the + whole group spills to local memory as a unit. With ``unpacked_array`` the storage is already unpacked, so the + optimiser can keep individual slots in registers and only spill the ones it has to. + + Example:: + + @qd.dataclass + class Tile: + r: qd.unpacked_array(32, qd.f32) # 32 scalar fields exposed as t.r[0..31] + + t = Tile() + t.r[5] = 1.0 # lowers to direct write of synthetic field _r5 + v = t.r[5] # same as v = t._r5 + for k in qd.static(range(32)): + t.r[k] = 0.0 # each iter is one AST node, not a 32-way cascade + + For python-int / ``qd.static``-resolved indices, ``t.r[k]`` is rewritten by the AST transformer to the named-field + access ``t._r{k}``, producing identical LLVM IR / PTX to a struct declared with N individually-named scalar fields. + + Runtime-int indexing is currently unsupported; use an explicit cascade helper for that case. + """ + return UnpackedArray(count, dtype) + + +def _expand_unpacked_array_naming(group_name, index): + """Naming convention for the synthetic scalar fields of an ``unpacked_array`` group. + + Public-ish so the AST transformer can mirror this without a circular import on ``struct``. + """ + return f"_{group_name}{index}" + + +class _UnpackedArrayRef: + """Transient proxy returned by the AST transformer for ``obj.{group}`` where ``group`` is an unpacked-array group + declared on the struct type. + + Only valid as the value of a Subscript node: ``obj.{group}[i]``. Resolved by ``ASTTransformer.build_Subscript`` + to a direct reference to the synthetic scalar field ``_{group}{i}`` when ``i`` is a python-int / + ``qd.static``-resolved integer. + + Used as a not-an-Expr marker; any attempt to use it as a value raises. + """ + + _qd_is_unpacked_array_ref = True + + def __init__(self, struct, group_name: str, count: int, dtype, naming_fn): + self._qd_struct = struct + self._qd_group_name = group_name + self._qd_count = count + self._qd_dtype = dtype + self._qd_naming_fn = naming_fn + + def _qd_field_for(self, index: int): + if not isinstance(index, (int, np.integer)): + raise QuadrantsSyntaxError( + f"unpacked_array {self._qd_group_name}[i] requires a python-int index " + f"(possibly via qd.static); got runtime index of type {type(index).__name__}" + ) + i = int(index) + if i < 0 or i >= self._qd_count: + raise QuadrantsSyntaxError( + f"unpacked_array index out of bounds: {self._qd_group_name}[{i}] " f"(count={self._qd_count})" + ) + field_name = self._qd_naming_fn(self._qd_group_name, i) + return getattr(self._qd_struct, field_name) + + def __repr__(self) -> str: # pragma: no cover - debug only + return f"" + + +__all__ = ["UnpackedArray", "unpacked_array"] diff --git a/tests/python/test_api.py b/tests/python/test_api.py index 2f183d5ba7..db28f6b035 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -82,6 +82,7 @@ def _get_expected_matrix_apis(): "Struct", "StructField", "TRACE", + "UnpackedArray", "QuadrantsAssertionError", "QuadrantsCompilationError", "QuadrantsNameError", @@ -251,6 +252,7 @@ def _get_expected_matrix_apis(): "uint32", "uint64", "uint8", + "unpacked_array", "volatile_load", "vulkan", "x64", diff --git a/tests/python/test_unpacked_array.py b/tests/python/test_unpacked_array.py new file mode 100644 index 0000000000..4b1b8bc51a --- /dev/null +++ b/tests/python/test_unpacked_array.py @@ -0,0 +1,223 @@ +# pyright: reportInvalidTypeForm=false +"""Tests for ``qd.unpacked_array(N, dtype)`` on ``@qd.dataclass``. + +``unpacked_array`` gives users an ergonomic indexed-write syntax on a per-thread struct, while keeping the underlying +storage as N separate named scalar fields so SROA + ``mem2reg`` can register-promote each slot independently. The +static-index case must lower to a direct field reference; PTX must be byte-identical to the named-field equivalent. +""" + +import numpy as np +import pytest + +import quadrants as qd + +try: + from tests import test_utils # noqa: F401 +except ImportError: # standalone-run convenience + test_utils = None # type: ignore + + +def _qd_init_cuda(): + qd.init(arch=qd.cuda, default_fp=qd.f32, offline_cache=False) + + +# --------------------------------------------------------------------------- +# Basic API smoke tests (struct construction). +# --------------------------------------------------------------------------- + + +def test_unpacked_array_construction_python_scope(): + """A dataclass with ``r: qd.unpacked_array(N, dtype)`` should construct as if it had N named scalar fields named + ``_r0.._r{N-1}``.""" + _qd_init_cuda() + + @qd.dataclass + class Tile: + r: qd.unpacked_array(4, qd.f32) + + # The underlying struct type should report N synthetic scalar members plus expose ``r`` as a group name. + assert hasattr(Tile, "_unpacked_groups") + groups = Tile._unpacked_groups + assert "r" in groups + count, dtype, _ = groups["r"] + assert count == 4 + assert dtype is qd.f32 + + # The underlying scalar fields must exist. + assert "_r0" in Tile.members + assert "_r3" in Tile.members + assert "r" not in Tile.members # ``r`` is a logical group, not a real member + + +# --------------------------------------------------------------------------- +# Static-index reads / writes (the hot path). +# --------------------------------------------------------------------------- + + +def test_unpacked_array_static_index_write_then_read(): + """Write to ``t.r[0..3]`` with python-int indices, then read back.""" + _qd_init_cuda() + + @qd.dataclass + class Tile: + r: qd.unpacked_array(4, qd.f32) + + out = qd.field(dtype=qd.f32, shape=(4,)) + + @qd.kernel(fastcache=False) + def k(o: qd.template()): + for _ in range(1): + t = Tile() + t.r[0] = qd.f32(1.0) + t.r[1] = qd.f32(2.0) + t.r[2] = qd.f32(3.0) + t.r[3] = qd.f32(4.0) + o[0] = t.r[0] + o[1] = t.r[1] + o[2] = t.r[2] + o[3] = t.r[3] + + k(out) + np.testing.assert_array_equal(out.to_numpy(), np.array([1, 2, 3, 4], dtype=np.float32)) + + +def test_unpacked_array_qd_static_loop_index(): + """Index via a ``qd.static(range(N))`` loop variable. Each iter sees a python-int index, so the lowering must be the + same direct-field path as the explicit python-int case.""" + _qd_init_cuda() + + @qd.dataclass + class Tile: + r: qd.unpacked_array(4, qd.f32) + + out = qd.field(dtype=qd.f32, shape=(4,)) + + @qd.kernel(fastcache=False) + def k(o: qd.template()): + for _ in range(1): + t = Tile() + for i in qd.static(range(4)): + t.r[i] = qd.f32(10.0 + i) + for i in qd.static(range(4)): + o[i] = t.r[i] + + k(out) + np.testing.assert_array_equal(out.to_numpy(), np.array([10, 11, 12, 13], dtype=np.float32)) + + +# --------------------------------------------------------------------------- +# Equivalence with named-field baseline: identical PTX. +# --------------------------------------------------------------------------- + + +def _build_named_kernel(): + """Same as test_unpacked_array_static_index_write_then_read but with 4 named ``r0..r3`` fields. Used for PTX byte- + equality comparison against the ``unpacked_array`` form.""" + + @qd.dataclass + class TileNamed: + r0: qd.f32 + r1: qd.f32 + r2: qd.f32 + r3: qd.f32 + + out = qd.field(dtype=qd.f32, shape=(4,)) + + @qd.kernel(fastcache=False) + def k(o: qd.template()): + for _ in range(1): + t = TileNamed() + t.r0 = qd.f32(1.0) + t.r1 = qd.f32(2.0) + t.r2 = qd.f32(3.0) + t.r3 = qd.f32(4.0) + o[0] = t.r0 + o[1] = t.r1 + o[2] = t.r2 + o[3] = t.r3 + + return k, out + + +def _build_unpacked_array_kernel(): + @qd.dataclass + class TileRA: + r: qd.unpacked_array(4, qd.f32) + + out = qd.field(dtype=qd.f32, shape=(4,)) + + @qd.kernel(fastcache=False) + def k(o: qd.template()): + for _ in range(1): + t = TileRA() + t.r[0] = qd.f32(1.0) + t.r[1] = qd.f32(2.0) + t.r[2] = qd.f32(3.0) + t.r[3] = qd.f32(4.0) + o[0] = t.r[0] + o[1] = t.r[1] + o[2] = t.r[2] + o[3] = t.r[3] + + return k, out + + +def test_unpacked_array_runtime_index_rejected(): + """Indexing ``t.r[k]`` with a runtime ``k`` raises a clear error pointing at the python-int / ``qd.static`` + requirement. Long term the runtime case can lower to an explicit cascade; for now the limitation is surfaced + early so callers don't get a confusing LLVM/SROA failure downstream.""" + _qd_init_cuda() + + @qd.dataclass + class Tile: + r: qd.unpacked_array(4, qd.f32) + + out = qd.field(dtype=qd.f32, shape=(4,)) + + @qd.kernel(fastcache=False) + def k(o: qd.template()): + for _ in range(1): + t = Tile() + for i in range(4): # runtime loop, not qd.static + t.r[i] = qd.f32(i) + o[0] = t.r[0] + + with pytest.raises(Exception) as e: + k(out) + msg = str(e.value) + assert "unpacked_array" in msg and "python-int" in msg, msg + + +def test_unpacked_array_oob_static_index(): + """Static-int out-of-bounds index is caught at compile time with a clear message.""" + _qd_init_cuda() + + @qd.dataclass + class Tile: + r: qd.unpacked_array(4, qd.f32) + + out = qd.field(dtype=qd.f32, shape=(4,)) + + @qd.kernel(fastcache=False) + def k(o: qd.template()): + for _ in range(1): + t = Tile() + t.r[7] = qd.f32(1.0) # 7 >= count=4 + o[0] = t.r[0] + + with pytest.raises(Exception) as e: + k(out) + assert "out of bounds" in str(e.value), str(e.value) + + +if __name__ == "__main__": + test_unpacked_array_construction_python_scope() + print("construction test passed") + test_unpacked_array_static_index_write_then_read() + print("static-int subscript test passed") + test_unpacked_array_qd_static_loop_index() + print("qd.static loop-var subscript test passed") + test_unpacked_array_runtime_index_rejected() + print("runtime-index rejection test passed") + test_unpacked_array_oob_static_index() + print("static OOB rejection test passed")