2 changes: 2 additions & 0 deletions docs/source/user_guide/compound_types.md
@@ -159,3 +159,5 @@ class Particle:
    vel: qd.types.vector(3, qd.f32)
    mass: qd.f32
```

For larger statically-indexed groups (e.g. a row or tile of scalars that you want SROA to register-promote independently), see {doc}`register_array`.
1 change: 1 addition & 0 deletions docs/source/user_guide/index.md
@@ -21,6 +21,7 @@ matrix_vector_per_thread
linalg_per_thread
tensor
compound_types
register_array
buffer_view
static
sub_functions
133 changes: 133 additions & 0 deletions docs/source/user_guide/register_array.md
@@ -0,0 +1,133 @@
# Register array

`qd.register_array(N, dtype)` is a `@qd.dataclass` field annotation that declares a group of `N` independently-allocated scalar fields exposed under a single name via indexed syntax.

It is a *layout hint*, not a new container. At source level you write `t.r[i]`; the compiler lowers that to a direct reference to a synthetic scalar field `t._r{i}`. The generated LLVM IR / PTX is byte-identical to a struct that was declared with `N` individually-named scalar fields.

## What problem does it solve?

The intuitive way to group `N` scalars on a per-thread struct is to declare a vector member:

```python
@qd.dataclass
class Tile:
    r: qd.types.vector(32, qd.f32)
```

`qd.types.vector(N, dtype)` lays the group out as one packed `alloca`. LLVM's scalar-replacement-of-aggregates pass (SROA), together with `mem2reg`, tries to break that `alloca` apart into per-slot SSA values so that each one can live in a register. Once a kernel's register pressure crosses a threshold (e.g. two concurrent `32x32` tiles in a Cholesky + triangular solve), however, SROA bails out and the whole packed `alloca` spills to local memory. Each access then turns into a `ld.local` / `st.local`, and the kernel slows down dramatically.

The alternative is to declare `N` named scalar fields by hand:

```python
@qd.dataclass
class Tile:
    r0: qd.f32
    r1: qd.f32
    # ... 30 more lines ...
    r31: qd.f32
```

Now each slot has its own `alloca`, and SROA + `mem2reg` can promote each one into a register independently. The optimiser is also free to spill only the slots it has to, instead of the whole group as a unit. The cost is that every indexed access becomes an `if` / `elif` cascade in the source:

```python
def get_r(t, k):
    if k == 0:
        return t.r0
    elif k == 1:
        return t.r1
    # ... 30 more branches ...
```

…which is duplicated at every call site that wants to read or write the group.

`qd.register_array` is the named-field layout with the ergonomic indexed syntax restored:

```python
@qd.dataclass
class Tile:
    r: qd.register_array(32, qd.f32)
```

The annotation expands at struct-definition time into the `N` synthetic scalar fields. The AST transformer rewrites `obj.r[i]` (for any python-int / `qd.static`-resolved `i`) into a direct reference to the synthetic field `obj._r{i}`. The IR / PTX matches the hand-rolled named-field version exactly.
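
The rewrite step can be illustrated with a toy transformer built on Python's standard `ast` module. This is only a sketch of the idea, not the Quadrants implementation: the class `RegisterArrayRewriter` and the helper `rewrite` are hypothetical names, the sketch works on literal integer indices only (it does not model `qd.static` loop variables), and it requires Python 3.9+ for `ast.unparse`.

```python
import ast


class RegisterArrayRewriter(ast.NodeTransformer):
    """Toy rewrite of `obj.<group>[<int literal>]` into `obj._<group><int>`.

    Hypothetical illustration only; the real transformer resolves indices
    while it builds the kernel AST rather than by rewriting source text.
    """

    def __init__(self, group_names):
        # Names of the register_array groups declared on the struct, e.g. {"r"}.
        self.group_names = set(group_names)

    def visit_Subscript(self, node):
        self.generic_visit(node)
        target, index = node.value, node.slice
        is_static_int = (
            isinstance(index, ast.Constant)
            and isinstance(index.value, int)
            and not isinstance(index.value, bool)
        )
        if isinstance(target, ast.Attribute) and target.attr in self.group_names and is_static_int:
            # Replace the whole subscript with a plain attribute access,
            # keeping the Load/Store context of the original node.
            renamed = ast.Attribute(
                value=target.value,
                attr=f"_{target.attr}{index.value}",
                ctx=node.ctx,
            )
            return ast.copy_location(renamed, node)
        return node


def rewrite(source, group_names):
    """Return `source` with every static group subscript replaced by a field access."""
    tree = RegisterArrayRewriter(group_names).visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite("t.r[5] = t.r[3] + 1.0", {"r"}))  # t._r5 = t._r3 + 1.0
```

After the rewrite no subscript is left in the statement, which is why the result is indistinguishable from code written against named fields.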

## How to use it

Declare the group as a `register_array(count, dtype)` annotation on a `@qd.dataclass`:

```python
import quadrants as qd

qd.init(arch=qd.gpu)


@qd.dataclass
class Tile:
    r: qd.register_array(32, qd.f32)


@qd.kernel
def k(out: qd.types.NDArray[qd.f32, 1]) -> None:
    t = Tile()
    # python-int index: lowers to a direct write of t._r5
    t.r[5] = 1.0
    # qd.static loop variable: each iter is one AST node, fully unrolled,
    # no per-iter cascade.
    for i in qd.static(range(32)):
        t.r[i] = qd.f32(i)
    out[0] = t.r[3]
```

Read access works the same way:

```python
v = t.r[5] # python-int index
v = t.r[i] # i bound by `for i in qd.static(range(N)):`
```

You can mix `register_array` groups with regular scalar / vector fields on the same dataclass; they are independent. You can also have several `register_array` groups in one struct:

```python
@qd.dataclass
class TwoTiles:
    a: qd.register_array(32, qd.f32)
    b: qd.register_array(32, qd.f32)
    scale: qd.f32
```

The generated struct has 65 scalar members (`_a0..._a31`, `_b0..._b31`, `scale`).
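
The member count can be checked with a small pure-Python model of the expansion. This is a sketch under stated assumptions: `expand_members` is a hypothetical helper, a `(count, dtype)` tuple stands in for a `register_array` annotation, and dtypes are plain strings rather than Quadrants types. Only the `_{group}{i}` naming convention is taken from the text above.

```python
def expand_members(annotations):
    """Flatten annotations into (name, dtype) pairs.

    A (count, dtype) tuple is a stand-in for register_array(count, dtype);
    any other value is treated as an ordinary field.
    """
    members = []
    for name, spec in annotations.items():
        if isinstance(spec, tuple):
            count, dtype = spec
            # One synthetic scalar member per slot: _a0, _a1, ...
            members.extend((f"_{name}{i}", dtype) for i in range(count))
        else:
            members.append((name, spec))
    return members


two_tiles = {"a": (32, "f32"), "b": (32, "f32"), "scale": "f32"}
members = expand_members(two_tiles)

print(len(members))                 # 65
print(members[0][0], members[31][0])  # _a0 _a31
print(members[32][0], members[-1][0])  # _b0 scale
```

The total is 32 + 32 + 1 = 65, and the members keep declaration order, with each group's slots laid out contiguously.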

## When to reach for it

Use `register_array` when:

- the group is *small and statically-sized*, and
- the kernel body accesses it with python-int / `qd.static`-resolved indices (typically unrolled inner loops), and
- you have measured (or strongly suspect) that an equivalent `qd.types.vector(N, dtype)` is leaving slots in local memory under register pressure.

A good signal is `ptxas` reporting non-zero "bytes spill stores / loads" for the kernel, or `ld.local` / `st.local` instructions in the generated PTX that don't correspond to a local array you deliberately index at runtime.
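
If you already have the PTX for a kernel as text, a quick scan for local-memory traffic might look like the sketch below. How to dump PTX from Quadrants is not covered here, so the input is just a string; `count_local_accesses` is a hypothetical helper and the sample PTX is hand-written for illustration. The scan only recognises the plain `ld.local` / `st.local` spellings.

```python
import re

# Matches the plain local-memory load/store mnemonics.
LOCAL_ACCESS = re.compile(r"\b(ld|st)\.local\b")


def count_local_accesses(ptx_text):
    """Count lines containing ld.local / st.local, ignoring `//` comments."""
    counts = {"ld.local": 0, "st.local": 0}
    for line in ptx_text.splitlines():
        code = line.split("//", 1)[0]  # drop trailing comment, if any
        match = LOCAL_ACCESS.search(code)
        if match:
            counts[f"{match.group(1)}.local"] += 1
    return counts


# Hand-written sample, not real compiler output.
sample_ptx = """
    ld.param.u64 %rd1, [k_param_0];
    st.local.f32 [%rd2+20], %f1;
    ld.local.f32 %f2, [%rd2+12];
    // ld.local in a comment is ignored
    st.global.f32 [%rd1], %f2;
"""

print(count_local_accesses(sample_ptx))  # {'ld.local': 1, 'st.local': 1}
```

Non-zero counts are only a hint: they show that something lives in local memory, not which variable it is.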

Prefer `qd.types.vector(N, dtype)` for small groups where register pressure is low or runtime indexing is needed. Vectors keep all the usual arithmetic conveniences (element-wise ops, dot products, etc.) that `register_array` does not provide.

## Constraints and limitations

- **Static indices only.** `t.r[k]` must resolve at AST-build time, i.e. `k` is a python-int literal or a `for k in qd.static(range(N)):` loop variable. A runtime-int index raises a `QuadrantsSyntaxError` at trace time with a message pointing at the `qd.static` requirement. If you need a runtime index over the group, spell out the cascade explicitly (`if k == 0: ...`).
- **Static out-of-bounds is rejected at trace time.** `t.r[7]` on a `register_array(4, ...)` group raises `QuadrantsSyntaxError: register_array index out of bounds: r[7] (count=4)`.
- **Storage only.** A `register_array` group has no vector arithmetic. There is no `t.r + other`, no `t.r.dot(...)`, no broadcast operations. If you want those, use `qd.types.vector(N, dtype)` instead.
- **`count` is fixed at struct-definition time.** It must be a positive python `int`, evaluated when the class is defined; anything else raises `QuadrantsSyntaxError`.
- **Naming.** The synthetic fields use the convention `_{group_name}{i}` (e.g. `_r0`, `_r1`, ..., `_r31`). Avoid declaring your own field with a name that collides with one of those, or `StructType` will report a duplicate member.
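
The first two constraints amount to a check that runs before any code is generated. A minimal stand-alone sketch of that check follows; `resolve_slot` is a hypothetical name, and the built-in `TypeError` / `IndexError` stand in for `QuadrantsSyntaxError` so the example runs without Quadrants installed.

```python
def resolve_slot(group_name, count, index):
    """Map a static index to its synthetic field name, or raise.

    Stand-in exceptions: the library raises QuadrantsSyntaxError in both cases.
    """
    if not isinstance(index, int):
        # Anything that is not a plain Python int is treated as a runtime index.
        raise TypeError(
            f"register_array {group_name}[i] requires a python-int index; "
            f"got {type(index).__name__}"
        )
    if index < 0 or index >= count:
        raise IndexError(
            f"register_array index out of bounds: {group_name}[{index}] (count={count})"
        )
    return f"_{group_name}{index}"


print(resolve_slot("r", 4, 3))  # _r3

try:
    resolve_slot("r", 4, 7)
except IndexError as exc:
    print(exc)  # register_array index out of bounds: r[7] (count=4)
```

Negative indices are rejected as well, so `t.r[-1]` does not wrap around to the last slot the way a Python list would.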

## Relationship to other annotations

| annotation | storage layout | runtime indexing | best for |
|-------------------------------------|---------------------------------|:----------------:|---------------------------------------|
| `qd.f32` (per-field) | one `alloca` per field | n/a | individually-named scalars |
| `qd.types.vector(N, dtype)` | one packed `alloca` | yes | small groups with vector arithmetic |
| `qd.register_array(N, dtype)` | `N` independent `alloca`s | no | register-resident groups under load |

Under low register pressure the three options generate similar code. Under high register pressure `register_array` is the one most likely to stay in registers because the optimiser can promote each slot independently.

## See also

- {doc}`compound_types` — `@qd.dataclass` overview
- {doc}`matrix_vector_per_thread` — `qd.types.vector` and per-thread matrices
- {doc}`linalg_per_thread` — examples of tile-resident linear algebra where register residency matters
22 changes: 22 additions & 0 deletions python/quadrants/lang/ast/ast_transformer.py
@@ -38,6 +38,7 @@
from quadrants.lang.expr import Expr, make_expr_group
from quadrants.lang.field import Field
from quadrants.lang.matrix import Matrix, MatrixType
from quadrants.lang.register_array import _RegisterArrayRef
from quadrants.lang.snode import append, deactivate, length
from quadrants.lang.struct import Struct, StructType
from quadrants.lang.util import (
@@ -295,6 +296,17 @@ def _unpack_layout_vector_index(ast_builder, index, layout_len):
    def build_Subscript(ctx: ASTTransformerFuncContext, node: ast.Subscript):
        build_stmt(ctx, node.value)
        build_stmt(ctx, node.slice)
        # ``register_array`` group subscript: rewrite ``obj.{group}[k]`` to a direct reference to the synthetic scalar
        # field ``_{group}{k}`` when ``k`` is a python-int. The resolution lives entirely at AST-build time so the
        # runtime IR/PTX is byte-identical to the named-field form. Runtime indices are not supported here -- the user
        # must spell the cascade explicitly.
        if isinstance(node.value.ptr, _RegisterArrayRef):
            slice_val = node.slice.ptr
            node.ptr = node.value.ptr._qd_field_for(slice_val)
            node.violates_pure = node.value.violates_pure
            if node.violates_pure:
                node.violates_pure_reason = node.value.violates_pure_reason
            return node.ptr
        if not ASTTransformer.is_tuple(node.slice):
            node.slice.ptr = [node.slice.ptr]
        # Tensors layout: a layout-tagged Ndarray or Field with a non-identity ``_qd_layout`` has its canonical indices
@@ -741,6 +753,16 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute):
            node.ptr = node.ptr._unwrap()
            node.ptr = ASTTransformer._promote_ndarray_if_declared(ctx, node.ptr)
        else:
            # ``register_array`` group access on a ``@qd.dataclass`` Struct expression. Returns a transient
            # ``_RegisterArrayRef`` that ``build_Subscript`` (or its assignment-LHS sibling) resolves to a direct field
            # reference. The lookup is by-name on ``_qd_register_groups``, which ``StructType.__call__`` attaches to
            # every Struct instance whose type declared at least one ``register_array`` annotation. Tested in
            # ``test_register_array.py``.
            groups = getattr(node.value.ptr, "_qd_register_groups", None)
            if groups and node.attr in groups:
                count, dtype, naming_fn = groups[node.attr]
                node.ptr = _RegisterArrayRef(node.value.ptr, node.attr, count, dtype, naming_fn)
                return node.ptr
            node.ptr = getattr(node.value.ptr, node.attr)
            # ``qd.Tensor`` wrappers reached via attribute access on a ``@qd.data_oriented`` struct field at AST-build
            # time. The IR layer downstream (``build_Subscript`` -> ``impl.subscript``) only knows about ``Ndarray`` /
8 changes: 7 additions & 1 deletion python/quadrants/lang/impl.py
@@ -103,7 +103,13 @@ def expr_init(rhs):
     if isinstance(rhs, BufferView):
         return rhs
     if isinstance(rhs, Struct):
-        return Struct(rhs.to_dict(include_methods=True, include_ndim=True))
+        new_struct = Struct(rhs.to_dict(include_methods=True, include_ndim=True))
+        # Preserve ``register_array`` group metadata across the rewrap so the AST transformer can still resolve
+        # ``obj.{group}[k]`` on the re-emitted Struct.
+        groups = getattr(rhs, "_qd_register_groups", None)
+        if groups is not None:
+            new_struct._qd_register_groups = groups
+        return new_struct
     if isinstance(rhs, list):
         return [expr_init(e) for e in rhs]
     if isinstance(rhs, tuple):
125 changes: 125 additions & 0 deletions python/quadrants/lang/register_array.py
@@ -0,0 +1,125 @@
# type: ignore
"""``qd.register_array`` -- indexed groups of independently-allocated scalar fields on ``@qd.dataclass``.

A ``register_array(N, dtype)`` annotation expands at struct-definition time into N individually-named synthetic scalar
members (``_{group}0`` .. ``_{group}{N-1}``). The AST transformer rewrites ``obj.{group}[i]`` into a direct reference to
``obj._{group}{i}`` for python-int / ``qd.static``-resolved indices, so generated LLVM IR / PTX is byte-identical to a
hand-rolled named-field struct.

The motivation is register residency under pressure. A packed ``qd.types.vector(N, dtype)`` collapses into one
``alloca`` that LLVM SROA cannot decompose once register pressure crosses a threshold (e.g. two concurrent tiles in a
Cholesky + TRSM kernel), causing spills to local memory. ``register_array`` pre-decomposes the storage so SROA +
``mem2reg`` can register-promote each slot independently, while keeping the ergonomic indexed-access syntax at the
source level.

Public:
- ``RegisterArray`` - type wrapper used as the annotation value
- ``register_array`` - factory: ``r: qd.register_array(N, dtype)``

Internal (used by ``StructType`` and the AST transformer):
- ``_expand_register_array_naming(group, i)`` - synthetic-field naming convention
- ``_RegisterArrayRef`` - transient proxy yielded by attribute access

This module has no dependency on ``struct.py``; ``struct.py`` imports from here.
"""

import numpy as np

from quadrants.lang.exception import QuadrantsSyntaxError


class RegisterArray:
    """Type wrapper for a group of N scalar fields exposed via indexed syntax on a ``@qd.dataclass``.

    See :func:`register_array` for the user-facing constructor and the motivation writeup. Holding only ``count`` and
    ``dtype``, this object is consumed at struct-definition time by ``StructType.__init__`` to lay out the N synthetic
    scalar fields.
    """

    def __init__(self, count, dtype):
        if not isinstance(count, int) or count <= 0:
            raise QuadrantsSyntaxError(f"register_array count must be a positive int, got {count!r}")
        self.count = count
        self.dtype = dtype

    def __repr__(self):
        return f"register_array(count={self.count}, dtype={self.dtype})"


def register_array(count, dtype):
    """Declare a group of ``count`` independently-allocated fields of ``dtype`` on a ``@qd.dataclass``.

    The annotation expands at struct-definition time into ``count`` individually-named scalar members (``_{group}0`` ..
    ``_{group}{count-1}``). Each member gets its own LLVM ``alloca``, which lets SROA + ``mem2reg`` promote each slot
    into its own SSA value independently. The motivation is register residency under pressure:
    ``qd.types.vector(N, dtype)`` collapses into one packed ``alloca`` that the optimiser often spills as a unit when
    register pressure crosses a threshold; ``register_array`` decomposes the storage up-front so the compiler can keep
    individual slots in registers and only spill the ones it has to.

    Example::

        @qd.dataclass
        class Tile:
            r: qd.register_array(32, qd.f32) # 32 scalar fields exposed as t.r[0..31]

        t = Tile()
        t.r[5] = 1.0 # lowers to direct write of synthetic field _r5
        v = t.r[5] # same as v = t._r5
        for k in qd.static(range(32)):
            t.r[k] = 0.0 # each iter is one AST node, not a 32-way cascade

    For python-int / ``qd.static``-resolved indices, ``t.r[k]`` is rewritten by the AST transformer to the named-field
    access ``t._r{k}``, producing identical LLVM IR / PTX to a struct declared with N individually-named scalar fields.

    Runtime-int indexing is currently unsupported; use an explicit cascade helper for that case.
    """
    return RegisterArray(count, dtype)


def _expand_register_array_naming(group_name, index):
    """Naming convention for the synthetic scalar fields of a ``register_array`` group.

    Public-ish so the AST transformer can mirror this without a circular import on ``struct``.
    """
    return f"_{group_name}{index}"


class _RegisterArrayRef:
    """Transient proxy returned by the AST transformer for ``obj.{group}`` where ``group`` is a registered group on the
    struct type.

    Only valid as the value of a Subscript node: ``obj.{group}[i]``. Resolved by ``ASTTransformer.build_Subscript``
    to a direct reference to the synthetic scalar field ``_{group}{i}`` when ``i`` is a python-int /
    ``qd.static``-resolved integer.

    Used as a not-an-Expr marker; any attempt to use it as a value raises.
    """

    _qd_is_register_array_ref = True

    def __init__(self, struct, group_name: str, count: int, dtype, naming_fn):
        self._qd_struct = struct
        self._qd_group_name = group_name
        self._qd_count = count
        self._qd_dtype = dtype
        self._qd_naming_fn = naming_fn

    def _qd_field_for(self, index: int):
        if not isinstance(index, (int, np.integer)):
            raise QuadrantsSyntaxError(
                f"register_array {self._qd_group_name}[i] requires a python-int index "
                f"(possibly via qd.static); got runtime index of type {type(index).__name__}"
            )
        i = int(index)
        if i < 0 or i >= self._qd_count:
            raise QuadrantsSyntaxError(
                f"register_array index out of bounds: {self._qd_group_name}[{i}] (count={self._qd_count})"
            )
        field_name = self._qd_naming_fn(self._qd_group_name, i)
        return getattr(self._qd_struct, field_name)

    def __repr__(self) -> str:  # pragma: no cover - debug only
        return f"<register_array_ref group={self._qd_group_name!r} count={self._qd_count} dtype={self._qd_dtype}>"


__all__ = ["RegisterArray", "register_array"]
22 changes: 21 additions & 1 deletion python/quadrants/lang/struct.py
@@ -15,6 +15,12 @@
from quadrants.lang.expr import Expr
from quadrants.lang.field import Field, ScalarField, SNodeHostAccess
from quadrants.lang.matrix import Matrix, MatrixType
from quadrants.lang.register_array import (
    RegisterArray,
    _expand_register_array_naming,
    _RegisterArrayRef,
    register_array,
)
from quadrants.lang.util import (
    cook_dtype,
    in_python_scope,
@@ -601,10 +607,20 @@ class StructType(CompoundType):
    def __init__(self, **kwargs):
        self.members = {}
        self.methods = {}
        # Maps group name -> (count, dtype, naming_fn). Populated when a member annotation is a ``RegisterArray``;
        # consumed by the AST transformer to rewrite ``obj.{group}[i]`` into a direct synthetic-field reference.
        self._register_groups: dict = {}
        elements = []
        for k, dtype in kwargs.items():
            if k == "__struct_methods":
                self.methods = dtype
            elif isinstance(dtype, RegisterArray):
                cooked = cook_dtype(dtype.dtype)
                self._register_groups[k] = (dtype.count, cooked, _expand_register_array_naming)
                for i in range(dtype.count):
                    sub = _expand_register_array_naming(k, i)
                    self.members[sub] = cooked
                    elements.append([cooked, sub])
            elif isinstance(dtype, StructType):
                self.members[k] = dtype
                elements.append([dtype.dtype, k])
@@ -640,6 +656,10 @@ def __call__(self, *args, **kwargs):
        entries._Struct__dtype = self.dtype
        struct = self.cast(entries)
        struct._Struct__dtype = self.dtype
        # Propagate the register-array group metadata onto the Struct instance so the AST transformer can detect
        # indexed-group access (``obj.r``) on the per-trace expression.
        if self._register_groups:
            struct._qd_register_groups = self._register_groups
        return struct

    def __instancecheck__(self, instance):
@@ -832,4 +852,4 @@ def dataclass(cls):
     return StructType(**fields)


-__all__ = ["Struct", "StructField", "dataclass"]
+__all__ = ["Struct", "StructField", "dataclass", "RegisterArray", "register_array", "_RegisterArrayRef"]
2 changes: 2 additions & 0 deletions tests/python/test_api.py
@@ -75,6 +75,7 @@ def _get_expected_matrix_apis():
"Mesh",
"MeshInstance",
"Ndarray",
"RegisterArray",
"SNode",
"ScalarField",
"ScalarNdarray",
@@ -214,6 +215,7 @@ def _get_expected_matrix_apis():
"raw_mod",
"real_func",
"ref",
"register_array",
"rescale_index",
"reset",
"root",