Open
88 commits
06b7c6a
[Test] Pin behaviour for @qd.data_oriented with raw qd.ndarray members
hughperkins May 16, 2026
d4350ef
[Fix] Recurse through nested data_oriented / dataclass children when …
hughperkins May 16, 2026
97afa6d
[Fix] Launch-context stale guard fires for @qd.data_oriented containe…
hughperkins May 16, 2026
49a723b
[Test] Extend @qd.data_oriented + ndarray coverage: cross-container n…
hughperkins May 16, 2026
9bdeca5
[Doc] @qd.data_oriented can contain ndarrays
hughperkins May 16, 2026
dc7997b
[Fix] Gap A: template-mapper spec key descends into data_oriented nda…
hughperkins May 16, 2026
906ce19
[Test] Gap A: spec-key descent into data_oriented ndarray members
hughperkins May 16, 2026
a0db648
[Fix] Template-mapper args_hash invalidates when data_oriented ndarra…
hughperkins May 16, 2026
c9598ad
[Fix] Clear error for @qd.data_oriented field type inside typed-datac…
hughperkins May 16, 2026
93893e5
[Perf] Per-class cache of data_oriented ndarray attribute paths for G…
hughperkins May 16, 2026
ce769a7
[Doc] Nesting compatibility matrix for compound types + spot tests
hughperkins May 16, 2026
dd4de40
[Doc] Fix @qd.struct ghost reference in compound_types
hughperkins May 16, 2026
46825ab
[Test] Pin fastcache + @qd.data_oriented + ndarray end-to-end behavior
hughperkins May 16, 2026
ee5fbbb
[Doc] Fastcache with @qd.data_oriented: worked example, semantics, fo…
hughperkins May 16, 2026
b132b81
[Doc] Restructure fastcache.md: simple main body, Advanced subsection…
hughperkins May 16, 2026
6d1c820
[Doc] Use 'member' consistently for compound-type members; drop ambig…
hughperkins May 16, 2026
1de65b9
[Doc] Mirror qd.Template wording for @qd.data_oriented primitive memb…
hughperkins May 16, 2026
a648c3f
[Doc] @qd.data_oriented row: 'types and values' to mirror qd.Template…
hughperkins May 16, 2026
a55d360
[Doc] Tighten path-cache stability restriction: actual failure modes …
hughperkins May 16, 2026
6667ba6
[Test] Fix fastcache cross-init tests: filter captured launches by ke…
hughperkins May 16, 2026
e9c50b4
[Style] pre-commit auto-fixes: black wrap + ruff import-sort
hughperkins May 16, 2026
abf242b
[Doc] Move @qd.kernel inside @qd.data_oriented class in the ndarray-m…
hughperkins May 17, 2026
4c27e2e
[Doc] Document primitive members on @qd.data_oriented self as templat…
hughperkins May 17, 2026
1f539e6
[Doc] State ndarray-member subscript behaviour directly instead of cr…
hughperkins May 17, 2026
730cbcb
[Doc] Drop 'as with dataclasses.dataclass' cross-reference in ndarray…
hughperkins May 17, 2026
57e1b95
[Doc] Simplify fastcache cross-link in @qd.data_oriented section: dro…
hughperkins May 17, 2026
d4ca211
[Doc] Drop ndarray-reassign note and tighten fastcache cross-link in …
hughperkins May 17, 2026
b72a7a7
[Doc] Drop ndarray subscript-access description in @qd.data_oriented …
hughperkins May 17, 2026
18ff7bd
[Doc] Promote fastcache cross-link to its own ### Fastcache subsectio…
hughperkins May 17, 2026
33f4744
[Doc] Rename '### ndarray members' to '### Tensor members'; cover qd.…
hughperkins May 17, 2026
883243e
[Doc] @qd.data_oriented Fastcache subsection: spell out 'disabled for…
hughperkins May 17, 2026
3504250
[Doc] Tensor members: shorten qd.tensor description to 'or qd.Tensor'
hughperkins May 17, 2026
cc01339
[Doc] Tensor members: simplify nested-container sentence to 'Nested @…
hughperkins May 17, 2026
df3113e
[Doc] Fastcache subsection: 'methods of @qd.data_oriented classes'
hughperkins May 17, 2026
7f5fd12
[Doc] Tensor members: drop qd.Vector.ndarray / qd.Matrix.ndarray pare…
hughperkins May 17, 2026
e7fafeb
[Doc] Tensor members: drop the mixing-backends + nesting trailer sent…
hughperkins May 17, 2026
f9a35df
[Doc] Restrictions: drop redundant 'A few combinations are still unsu…
hughperkins May 17, 2026
d336dcd
[Doc] @qd.dataclass section opener: cut to the constraint
hughperkins May 17, 2026
4c5f622
[Doc] Remove top-level Recommendation section
hughperkins May 17, 2026
56a4399
[Doc] Expand @qd.dataclass section: what it does, when to use it, con…
hughperkins May 17, 2026
ef5f8a6
[Doc] @qd.dataclass section: drop use-cases / constraints / cross-ref…
hughperkins May 17, 2026
06580f1
[Doc] @qd.dataclass section opener: explain the kernel-side vs python…
hughperkins May 17, 2026
8899357
[Doc] Restore verbatim prose for the @qd.struct vs other-compound-typ…
hughperkins May 17, 2026
8fef507
[Doc] Replace @qd.struct with @qd.dataclass in opener prose (actual A…
hughperkins May 17, 2026
92f5fe1
[Doc] @qd.dataclass: 'element type of fields' not 'tensors'
hughperkins May 17, 2026
9ea8e5b
[Doc] @qd.dataclass: add sentences about @qd.func methods and qd.type…
hughperkins May 17, 2026
6ff0848
[Doc] @qd.dataclass methods sentence: 'Methods can be added to ... an…
hughperkins May 17, 2026
fd8cd0a
[Doc] @qd.dataclass section: move qd.types.struct paragraph to end wi…
hughperkins May 17, 2026
004cd9a
[Doc] qd.types.struct sentence: drop 'useful when members are compute…
hughperkins May 17, 2026
bf85e4e
[Doc] @qd.dataclass: split into bare-struct example, then methods + @…
hughperkins May 17, 2026
ccaae54
[Doc] First @qd.dataclass example uses AOS layout (the unique-to-Stru…
hughperkins May 17, 2026
820c01a
[Doc] Move 'Nesting compatibility' section to end of compound_types.md
hughperkins May 17, 2026
06d2e86
[Doc] Overview table: dataclasses.dataclass supports differentiation …
hughperkins May 17, 2026
f7dd090
[Test] AD through dataclasses.dataclass with ndarray, field, and qd.t…
hughperkins May 17, 2026
8c0377c
[Doc] compound_types: rephrase intro bullets to describe each type's …
hughperkins May 17, 2026
71a53da
[Doc] compound_types: prefix dataclasses.dataclass with @ in intro/ta…
hughperkins May 17, 2026
46fef24
[Test] AD dataclass: tensor(FIELD) member works when annotated as qd.…
hughperkins May 17, 2026
18f995b
[Doc] tensor: note qd.Tensor is also the dataclass-member annotation
hughperkins May 17, 2026
3ce0ab0
[Doc] compound_types: add 'Under the hood' subsection for each type
hughperkins May 17, 2026
35be370
[Doc] compound_types: rewrite 'Under the hood' subsections at a highe…
hughperkins May 17, 2026
94e455a
[Doc] compound_types: drop 'once' from compile-time capture phrasing
hughperkins May 17, 2026
31b27d7
[Doc] compound_types: replace overview table with differentiating one
hughperkins May 17, 2026
36dc933
[Doc] compound_types: drop 'historical reasons' line
hughperkins May 17, 2026
07dc486
[Fix] _build_struct_nd_paths: handle NamedTuple via _asdict() fallback
hughperkins May 18, 2026
3aa4fe1
[Fix] test_ad_dataclass: require data64 extension for f64 tests
hughperkins May 18, 2026
89bb005
[Style] test docstrings: reflow at 120c per repo line-width
hughperkins May 18, 2026
4923d68
[Doc] compound_types: rephrase top-table row to 'Can be used as tenso…
hughperkins May 29, 2026
16d50f5
[Doc] compound_types: clarify 'members read-only' for dataclasses.dat…
hughperkins May 29, 2026
e510a53
[Doc] compound_types: rewrite @qd.dataclass intro paragraph
hughperkins May 29, 2026
6d339a5
[Doc] compound_types: introduce 'Frozen vs non-frozen' early under da…
hughperkins May 29, 2026
6d6b25d
[Doc] compound_types: drop 'legacy' framing from @qd.dataclass restri…
hughperkins May 29, 2026
f788395
[Doc] compound_types: remove duplicate frozen=True restriction bullet
hughperkins May 29, 2026
786223a
[Doc] fastcache: move 'Compound-type cache keying' from Advanced to A…
hughperkins May 29, 2026
a7adc97
[Doc] _template_mapper_hotpath: FIXME for class-level path cache vs p…
hughperkins May 29, 2026
d396bb5
[Fix] launch_kernel: stale-cache guard OR's mutability across full at…
hughperkins May 29, 2026
eaf16fc
[Doc] compound_types: rephrase @qd.data_oriented intro per duburcqa f…
hughperkins May 29, 2026
15da770
[Doc] compound_types: clarify 'Kernel-side representation' row
hughperkins May 29, 2026
39a1df7
[Doc] compound_types: 'Members can be tensors' row + SoA caveat under…
hughperkins May 29, 2026
9f3b4e6
[Doc] compound_types: 'extrudes' instead of 'splits' for SoA member b…
hughperkins May 29, 2026
3b9f17a
[Doc] compound_types: use 'extrudes' in the canonical SoA bullet too
hughperkins May 29, 2026
0607d55
[Deprecate] @dataclasses.dataclass passed via qd.template()
hughperkins May 29, 2026
8d7e473
[Doc] Add @qd.func / kernel-arg-annotation rows to compound_types ove…
hughperkins May 29, 2026
d0eda52
[Doc] Reword "methods on self" → "instance methods" in compound_types…
hughperkins May 29, 2026
c2196b0
[Doc/Code] Use `qd.Template` (the class) instead of `qd.template()…
hughperkins May 29, 2026
31c291a
[Doc] Trim the @qd.func / @dataclasses.dataclass cell to a bare `no`
hughperkins May 29, 2026
0a7506a
[Doc] Drop the historical-context parenthetical from Frozen-vs-non-fr…
hughperkins May 29, 2026
fa353e4
[Doc] Correct Frozen-vs-non-frozen — note that frozen=True speeds up …
hughperkins May 29, 2026
4cf3b8c
[Doc] Frozen-vs-non-frozen — frozen=True speeds up kernel launch
hughperkins May 29, 2026
186 changes: 171 additions & 15 deletions docs/source/user_guide/compound_types.md

Large diffs are not rendered by default.

66 changes: 33 additions & 33 deletions docs/source/user_guide/fastcache.md
@@ -50,43 +50,13 @@ qd.init(arch=qd.gpu)
# qd.init(arch=qd.gpu, print_non_pure=True)
```

## Dataclass fields with cached values

By default, for `dataclasses.dataclass` parameters, fastcache only includes the *types* of each field in the cache key, not their values. This is fine for fields like ndarrays, where the compiled kernel doesn't depend on the actual data, only the dtype and dimensionality.

However, some dataclass fields hold configuration values that get baked into the compiled kernel — typically values used with `qd.static()`, such as loop bounds or feature flags:

```python
for i in qd.static(range(config.num_layers)):
    ...
```

Here the value of `num_layers` is compiled into the kernel: the loop is unrolled at compile time. If `num_layers` changes, a different kernel must be compiled.

Mark such fields with `add_value_to_cache_key` so their values are included in the cache key:

```python
import dataclasses
from quadrants.lang._fast_caching import FIELD_METADATA_CACHE_VALUE

@dataclasses.dataclass
class SimConfig:
    num_envs: int = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True})
    dt: float = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True})
    use_gravity: bool = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True})
```

With this annotation, changing `num_envs` from 100 to 200 produces a different cache key so the correct compiled kernel is looked up (or compiled if not yet cached). Without it, the wrong kernel could be loaded.

Note: `@qd.data_oriented` objects and `qd.Template` parameters already include primitive values in the cache key automatically — this annotation is only needed for `dataclasses.dataclass` fields.

## Constraints

A kernel is eligible for fastcache only if all of the following hold:

### 1. All data flows through parameters

The kernel must receive every piece of data it operates on as an explicit parameter. It must **not** capture variables from the enclosing Python scope (closures over fields, ndarrays, or mutable globals). This is the core "purity" constraint — the compiled kernel's behavior must be fully determined by its arguments.
The kernel must receive every piece of data it operates on as an explicit parameter. It must **not** capture variables from the enclosing Python scope (closures over ndarrays, mutable globals, or any other external state). This is the core "purity" constraint — the compiled kernel's behavior must be fully determined by its arguments.

```python
a = qd.ndarray(qd.f32, (10,))
@@ -125,8 +95,8 @@ Fastcache supports the following parameter types:
| `qd.types.NDArray` (scalar, vector, matrix) | Yes | dtype, ndim, layout |
| `torch.Tensor` | Yes | dtype, ndim |
| `numpy.ndarray` | Yes | dtype, ndim |
| `dataclasses.dataclass` | Yes | field types recursively; field values if annotated with `add_value_to_cache_key` (see [above](#dataclass-fields-with-cached-values)) |
| `@qd.data_oriented` objects | Yes | member types and primitive member values recursively |
| `dataclasses.dataclass` | Yes | member types recursively; member values if annotated with `FIELD_METADATA_CACHE_VALUE` (see [Appendix — compound-type cache keying](#compound-type-cache-keying)) |
| `@qd.data_oriented` objects | Yes | member types recursively; primitive member types and values baked into kernel (see [Appendix — compound-type cache keying](#compound-type-cache-keying)) |
| `qd.Template` primitives (int, float, bool) | Yes | type and value (baked into kernel) |
| Non-template primitives (int, float, bool) | Yes | type only |
| `enum.Enum` | Yes | name and value |
@@ -172,3 +142,33 @@ print(obs.cache_stored) # True if the compiled kernel was stored to cach
```

On the first run you'll see `cache_stored=True` but `cache_loaded=False`. On the second run (after `qd.init`), `cache_loaded=True`.

## Appendix

### Compound-type cache keying

The args hasher walks compound-type kernel parameters recursively. For each leaf member it decides what (if anything) contributes to the cache key. The headline rules:

**`@qd.data_oriented`:** the walker descends into `vars(obj)`. For each child:

- `qd.ndarray` member — `(dtype, ndim, layout)` is included in the cache key. Element values are not.
- Primitive (`int` / `float` / `bool` / `enum.Enum`) member — value is baked into the kernel (same semantics as a `qd.Template` primitive). Two instances of the same class with different primitive member values get different cache entries.
- Nested `@qd.data_oriented` member — recurses.
- Nested `dataclasses.dataclass` member — recurses (with the dataclass rules below).
- `qd.field` member — fastcache is disabled for the entire kernel call. The kernel still runs via normal compilation; a warn-level log line is emitted.
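
The `@qd.data_oriented` rules above can be sketched in plain Python. This is an illustration only, not the args hasher itself: `FakeNdarray`, `FakeField`, `CachingDisabled`, and `key_parts` are hypothetical names, the real key format is not shown in this diff, and the nested `dataclasses.dataclass` rule is left out for brevity.

```python
import enum


class FakeNdarray:
    """Hypothetical stand-in for an ndarray: only its metadata reaches the key."""

    def __init__(self, dtype, ndim, layout=None):
        self.dtype, self.ndim, self.layout = dtype, ndim, layout


class FakeField:
    """Hypothetical stand-in for a field member, which disables caching."""


class CachingDisabled(Exception):
    """Raised by this sketch when a member rules out caching for the call."""


PRIMITIVES = (bool, int, float, enum.Enum)


def key_parts(obj, prefix=()):
    """Walk vars(obj) and collect what each member contributes to the key."""
    parts = []
    for name, value in vars(obj).items():
        path = prefix + (name,)
        if isinstance(value, FakeNdarray):
            # Metadata only; element values never enter the key.
            parts.append((path, "ndarray", value.dtype, value.ndim, value.layout))
        elif isinstance(value, FakeField):
            raise CachingDisabled(".".join(path))
        elif isinstance(value, PRIMITIVES):
            # Type and value both enter the key.
            parts.append((path, type(value).__name__, value))
        elif hasattr(value, "__dict__"):
            # Nested container: recurse with the extended path.
            parts.extend(key_parts(value, path))
    return tuple(parts)


class State:
    def __init__(self, n, arr):
        self.n = n
        self.arr = arr
```

In this sketch, `State(4, FakeNdarray("f32", 2))` and `State(8, FakeNdarray("f32", 2))` produce different keys because the primitive member `n` differs, while two instances that agree on `n` and on the array metadata produce equal keys.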

**`dataclasses.dataclass`:** the walker descends into the declared members. For each member, only the *type* is included in the cache key by default — **not** the value. To include a member's value, annotate it:
Comment on lines +150 to +160
Contributor

This section is great, but it should probably be moved to the Appendix, because it is not supposed to be relevant for users.


```python
import dataclasses
from quadrants.lang._fast_caching import FIELD_METADATA_CACHE_VALUE

@dataclasses.dataclass
class SimConfig:
    num_layers: int = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True})
    dt: float = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True})
```

This is necessary whenever the compiled kernel depends on the member's *value* rather than just its type (for example, when the value is used as a loop bound that the compiler bakes into the generated code). Without the annotation, two `SimConfig` instances with different `num_layers` values would share a fastcache key, and the second instance would silently load a kernel compiled for the wrong value.

Note the asymmetry: `@qd.data_oriented` primitive members are baked into the kernel automatically (same semantics as `qd.Template`); `dataclasses.dataclass` members contribute only their *type* to the cache key unless you opt in per-member.
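
The per-member opt-in can be illustrated with a small, self-contained sketch. It assumes a hypothetical metadata key `CACHE_VALUE` in place of `FIELD_METADATA_CACHE_VALUE` (whose actual value is not shown here) and a hypothetical `dataclass_key` helper; it is not the fastcache implementation.

```python
import dataclasses

# Hypothetical stand-in for FIELD_METADATA_CACHE_VALUE.
CACHE_VALUE = "cache_value"


@dataclasses.dataclass
class SimConfig:
    # Annotated: the value is part of the key.
    num_layers: int = dataclasses.field(metadata={CACHE_VALUE: True})
    # Not annotated: only the type is part of the key.
    dt: float = 0.01


def dataclass_key(obj):
    """Build a key from member types, plus values for opted-in members."""
    parts = []
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        if f.metadata.get(CACHE_VALUE, False):
            parts.append((f.name, type(value).__name__, value))
        else:
            parts.append((f.name, type(value).__name__))
    return tuple(parts)
```

Under these assumptions, changing `num_layers` changes the key, while changing `dt` does not, which is the failure mode the annotation is meant to prevent for values the compiler bakes into the kernel.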
9 changes: 9 additions & 0 deletions docs/source/user_guide/tensor.md
@@ -203,6 +203,15 @@ fill(b) # ndarray branch

The kernel argument is unwrapped to the bare impl before the template-mapper / AST sees it, so kernel bodies still write `x[i, j]` and pay no per-call cost for the wrapper.

`qd.Tensor` is also the right annotation when storing a tensor as a `dataclasses.dataclass` member:

```python
@dataclass
class State:
    a: qd.Tensor
    b: qd.Tensor
```

## Pickle

`qd.Tensor` objects are picklable on **both** backends, including under non-identity layouts. Round-trip (pickle then unpickle) preserves the canonical data, the dtype, the shape, and the layout:
32 changes: 31 additions & 1 deletion python/quadrants/lang/_template_mapper.py
@@ -5,10 +5,29 @@
from quadrants.lang import impl
from quadrants.lang.impl import Program
from quadrants.lang.kernel_arguments import ArgMetadata
from quadrants.lang.util import is_data_oriented

from .._test_tools import warnings_helper
from ._kernel_types import ArgsHash
from ._template_mapper_hotpath import _extract_arg, _primitive_types
from ._template_mapper_hotpath import (
    _extract_arg,
    _primitive_types,
    _struct_nd_paths_for,
)


def _collect_data_oriented_nd_ids(arg: Any, out: list) -> None:
    """Append ``id(ndarray)`` for every ndarray reachable from ``arg``, using the per-class path cache in
    ``_template_mapper_hotpath._struct_nd_paths_for`` so the first call walks ``vars(arg)`` once and subsequent calls
    are just ``getattr`` chains. Empty path list short-circuits with zero work, which is critical for genesis's
    ``@qd.data_oriented`` Solver passed as ``self`` to every kernel.
    """
    for chain in _struct_nd_paths_for(arg):
        v = arg
        for a in chain:
            v = getattr(v, a)
        out.append(id(v))


Key: TypeAlias = tuple[Any, ...]

@@ -71,6 +90,17 @@ def lookup(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> tupl
# branching for primitive types dramatically improve performance of hash computation.
mapping_cache_tracker: list[ReferenceType | None] | None = None
args_hash: ArgsHash = tuple([id(arg) for arg in args])
# ``@qd.data_oriented`` containers can have their member ndarrays reassigned between calls on the same instance
# (``state.x = other_ndarray``). The id(arg) alone does not capture that, so the spec-key cache below would
# serve a stale entry and the new ndarray's dtype/ndim would be wrong. Fold the reachable ndarray ids into the
# hash. No-op for data_oriented containers that hold no ndarrays — the walker returns an empty list. See
# ``_collect_data_oriented_nd_ids``.
nd_ids: list = []
for arg in args:
if is_data_oriented(arg):
_collect_data_oriented_nd_ids(arg, nd_ids)
if nd_ids:
args_hash = args_hash + tuple(nd_ids)
try:
mapping_cache_tracker = self._mapping_cache_tracker[args_hash]
except KeyError:
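
The `args_hash` change in the hunk above can be pictured with a minimal sketch. It uses hypothetical stand-ins (`Arr`, `State`, and a fixed `member_names` tuple in place of the cached attribute paths) and only shows why folding member ids into the hash makes a reassignment visible.

```python
class Arr:
    """Hypothetical stand-in for an ndarray."""


class State:
    """Hypothetical stand-in for a data_oriented container with one member."""

    def __init__(self, x):
        self.x = x


def args_hash(args, member_names=("x",)):
    """Tuple of id(arg) for every argument, followed by the ids of tracked members."""
    ids = [id(a) for a in args]
    for a in args:
        if isinstance(a, State):
            ids.extend(id(getattr(a, name)) for name in member_names)
    return tuple(ids)


first, second = Arr(), Arr()
state = State(first)
before = args_hash((state,))
state.x = second  # same container instance, different member object
after = args_hash((state,))
```

Here `before[0] == after[0]` because the container is the same object, yet `before != after` because the member id changed. With `id(arg)` alone the two calls would be indistinguishable.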
100 changes: 99 additions & 1 deletion python/quadrants/lang/_template_mapper_hotpath.py
@@ -25,6 +25,7 @@
a consequence of inlining 'is_dataclass' and 'fields'.
"""

import dataclasses
import weakref
from dataclasses import _FIELD, _FIELDS
from typing import Any, Union
@@ -71,6 +72,88 @@
_primitive_types = {int, float, bool}


# Per-class cache: ``type(arg) -> list[tuple[str, ...]]`` of attribute paths whose values are ``Ndarray`` instances at
# first observation. Populated lazily by ``_struct_nd_paths_for`` on the first call with each new data_oriented (or
# nested dataclass) class. Empty list means "this class holds no ndarrays anywhere", in which case subsequent calls
# pay only a dict-lookup per arg. Non-empty list short-circuits the full ``vars()`` recursion and just resolves each
# cached path via ``getattr`` chains. Critical for the genesis field-backend hot path: the ``@qd.data_oriented``
# Solver is passed as ``self`` to most kernels and holds dozens of attributes, so a full per-call ``vars()`` walk
# costs >100ns per kernel and trashed FPS until this cache was added.
_struct_nd_paths_cache: dict[type, list[tuple]] = {}


def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list) -> None:
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        children = ((f.name, getattr(obj, f.name)) for f in dataclasses.fields(obj))
    else:
        # ``NamedTuple`` (decorated as ``@qd.data_oriented``) has no instance ``__dict__``, so fall back to
        # ``_asdict()``, which materialises a dict view of the named fields. Mirrors the same fallback in
        # ``args_hasher.stringify_obj_type`` so the per-class path cache here picks up ndarray members on NamedTuples
        # too (regression covered by ``test_args_hasher_named_tuple``).
        try:
            children = obj._asdict().items()
        except AttributeError:
            children = obj.__dict__.items()
    for k, v in children:
        chain = prefix + (k,)
        if type(v) in _TENSOR_WRAPPER_TYPES:
            v = v._unwrap()
        v_type = type(v)
        if issubclass(v_type, Ndarray):
            out.append(chain)
        elif is_data_oriented(v) or (dataclasses.is_dataclass(v) and not isinstance(v, type)):
            _build_struct_nd_paths(v, chain, out)


def _struct_nd_paths_for(arg: Any) -> list[tuple]:
    """Return the cached attribute paths (each a tuple of attr-name strings) at which ``Ndarray`` instances are
    reachable from ``arg`` of type ``type(arg)``. First call for a class walks ``arg`` once via
    ``_build_struct_nd_paths``; subsequent calls are dict-lookups.

    Trades freshness for speed: assumes the *set* of ndarray-holding attribute paths is stable across instances of
    the same class. The genesis Solver and similar ``@qd.data_oriented`` containers satisfy this: their ndarray
    members are declared in ``__init__`` and not added later. If you need to add an ndarray attribute after the first
    kernel launch on an instance of a given class, the new attribute won't be tracked. Call
    ``invalidate_struct_nd_paths_for`` (below) or restart the program.

    FIXME (Codex #3 on PR #704, https://github.com/Genesis-Embodied-AI/quadrants/pull/704#discussion_r3253281957):
    the cache is keyed by ``type(arg)`` only. If two instances of the same class have *polymorphic attribute
    structure* (e.g. instance A has ``.x`` as a ``qd.ndarray``-backed ``qd.Tensor`` while instance B has the same
    ``.x`` as a field-backed ``qd.Tensor``), the paths discovered from the first-walked instance are reused for the
    second. ``_collect_struct_nd_descriptors`` then unconditionally reads ndarray-only attrs (``element_type``,
    ``grad``, ``_qd_layout``) on what is now a ``ScalarField``, raising before the kernel can run. The fix is the
    per-instance walk implemented on top of this branch in PR #705; this branch ships the class-level cache as-is.
    """
    cls = type(arg)
    paths = _struct_nd_paths_cache.get(cls)
    if paths is None:
        paths = []
        _build_struct_nd_paths(arg, (), paths)
        _struct_nd_paths_cache[cls] = paths
    return paths
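
The trade-off described in the docstring above (class-keyed cache, stale paths for structurally different instances) can be reproduced in a few lines of plain Python. This is a sketch with hypothetical stand-ins (`Arr`, `Fld`, `Inner`, `Solver`, `build_paths`, `paths_for`, `resolve`), not the hot-path code itself.

```python
_paths_cache = {}  # type -> list of attribute-name tuples


class Arr:
    """Hypothetical stand-in for an ndarray."""


class Fld:
    """Hypothetical stand-in for a field."""


def build_paths(obj, prefix, out):
    """Record every attribute path under obj that currently holds an Arr."""
    for name, value in vars(obj).items():
        chain = prefix + (name,)
        if isinstance(value, Arr):
            out.append(chain)
        elif hasattr(value, "__dict__"):
            build_paths(value, chain, out)


def paths_for(obj):
    """Walk once per class; afterwards serve the cached list."""
    cls = type(obj)
    paths = _paths_cache.get(cls)
    if paths is None:
        paths = []
        build_paths(obj, (), paths)
        _paths_cache[cls] = paths
    return paths


def resolve(obj, chain):
    """Follow a getattr chain from obj to the member it names."""
    for attr in chain:
        obj = getattr(obj, attr)
    return obj


class Inner:
    def __init__(self):
        self.buf = Arr()


class Solver:
    def __init__(self, x):
        self.x = x
        self.inner = Inner()
        self.steps = 3


a = Solver(Arr())  # first instance seen: x holds an Arr
b = Solver(Fld())  # same class, but x holds a Fld
paths_a = paths_for(a)
paths_b = paths_for(b)  # served from the cache built from `a`
still_valid = [c for c in paths_b if isinstance(resolve(b, c), Arr)]
```

In this sketch `paths_b` still contains `("x",)` even though `b.x` is a `Fld`, which mirrors the failure mode in the FIXME. Filtering on the resolved type, as `still_valid` does, is one way to revalidate a cached path before using it.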


def _collect_struct_nd_descriptors(arg: Any, out: list) -> None:
    """Emit per-ndarray shape descriptors ``(joined-path, element_type, ndim, needs_grad, layout)`` for every ndarray
    reachable from ``arg``. Used by the template-mapper to refine the spec key for ``@qd.data_oriented`` args holding
    ndarrays; see the data_oriented branch in ``_extract_arg``.

    FIXME (Codex #3 on PR #704): when a polymorphic instance reuses a cached path that pointed to an ``Ndarray`` on
    the first-walked instance, ``v`` here can be a ``ScalarField`` and the ``v.element_type`` / ``v.grad`` /
    ``v._qd_layout`` reads will raise. See ``_struct_nd_paths_for`` above for details. Fixed in PR #705 via the
    per-instance walk redesign.
    """
    for chain in _struct_nd_paths_for(arg):
        v = arg
        for a in chain:
            v = getattr(v, a)
        if type(v) in _TENSOR_WRAPPER_TYPES:
            v = v._unwrap()
        type_id = id(v.element_type)
Comment on lines +150 to +152

P2: Handle cached ndarray paths that now point at fields

If the first @qd.data_oriented instance of a class has x as a qd.Tensor/ndarray, _struct_nd_paths_cache records ('x',) for the whole class. A later instance of the same class with x as a field-backed qd.Tensor (or a field assigned at that path) will still walk that cached path, unwrap to a ScalarField, and then unconditionally read ndarray-only attributes like element_type/grad, raising before the kernel can run. This breaks backend-polymorphic data-oriented containers depending on which backend instance is seen first; revalidate that the resolved value is still an Ndarray before emitting a descriptor.


        element_type = type_id if type_id in primitive_types.type_ids else v.element_type
        out.append((".".join(chain), element_type, len(v.shape), v.grad is not None, v._qd_layout))


def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: AnnotationType, arg_name: str) -> Any:
# ``qd.Tensor`` wrappers passed as struct fields. Top-level kernel-arg unwrap in ``Kernel.__call__`` covers direct
# args, but the dataclass-field recursion at the bottom of this function walks struct attributes via raw
@@ -124,7 +207,7 @@ def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: Annotati
raise QuadrantsRuntimeTypeError(
"Ndarray shouldn't be passed in via `qd.template()`, please annotate your kernel using `qd.types.ndarray(...)` instead"
)
if arg_type in _composite_mutable_types or is_data_oriented(arg):
if arg_type in _composite_mutable_types:
# [Composite arguments] Return weak reference to the object
# Quadrants kernel will cache the extracted arguments, thus we can't simply return the original argument.
# Instead, a weak reference to the original value is returned to avoid memory leak.
@@ -134,6 +217,21 @@ def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: Annotati
# 1. Invalid weak-ref will leave a dead(dangling) entry in both caches: "self.mapping" and "self.compiled_functions"
# 2. Different argument instances with same type and same value, will get templatized into separate kernels.
return weakref.ref(arg)
if is_data_oriented(arg):
# Same memory-leak avoidance as above — keep ``weakref.ref(arg)`` so the spec key never holds a strong
# reference to user state. But for data_oriented containers that hold ``Ndarray`` members, the live
# ``weakref`` alone is too coarse: same instance with ``state.x = other_ndarray`` of a different dtype/ndim
# would re-use the previously-compiled kernel, which was specialised for the old shape. Walk the reachable
# ndarrays and prepend their shape descriptors so dtype/ndim changes trigger re-specialisation. Mirrors what
# the dataclass branch below does via ``annotation_fields``.
#
# Containers with no ndarrays keep the original short-path (one spec per instance via weakref) so this is
# a no-op for the existing data_oriented + qd.field workloads (genesis field-backend).
nd_descriptors: list = []
_collect_struct_nd_descriptors(arg, nd_descriptors)
if nd_descriptors:
return (id(type(arg)), tuple(nd_descriptors), weakref.ref(arg))
return weakref.ref(arg)

# Return value directly for other types, i.e. primitive types and all qd.Field-derived classes
if raise_on_templated_floats and arg_type is float:
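
The data_oriented branch added to `_extract_arg` above returns either a bare weak reference or a tuple that also carries ndarray descriptors. A minimal sketch of that shape, assuming hypothetical stand-ins (`Arr`, `State`, `spec_key`) and a simplified `(name, dtype, ndim)` descriptor, looks like this:

```python
import weakref


class Arr:
    """Hypothetical stand-in for an ndarray with the metadata a kernel specialises on."""

    def __init__(self, dtype, ndim):
        self.dtype, self.ndim = dtype, ndim


class State:
    """Hypothetical stand-in for a data_oriented container."""

    def __init__(self, x):
        self.x = x


def spec_key(arg):
    """Weak reference alone, or (class id, descriptors, weak reference) if Arr members exist."""
    descriptors = tuple(
        (name, value.dtype, value.ndim)
        for name, value in vars(arg).items()
        if isinstance(value, Arr)
    )
    ref = weakref.ref(arg)  # never keeps the user's object alive
    if descriptors:
        return (id(type(arg)), descriptors, ref)
    return ref


state = State(Arr("f32", 1))
key_before = spec_key(state)
state.x = Arr("f64", 2)  # same instance, member with a different dtype and ndim
key_after = spec_key(state)
plain = State(3)
key_plain = spec_key(plain)  # no Arr members: falls back to the bare weak reference
```

The weak-reference component is equal in `key_before` and `key_after`, so only the descriptor tuple distinguishes them. That is the property the branch relies on to trigger re-specialisation when a member's dtype or ndim changes.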
@@ -34,7 +34,7 @@
from quadrants.lang.matrix import MatrixType
from quadrants.lang.stream import stream_parallel
from quadrants.lang.struct import StructType
from quadrants.lang.util import to_quadrants_type
from quadrants.lang.util import is_data_oriented, to_quadrants_type
from quadrants.types import annotations, buffer_view_type, ndarray_type, primitive_types


@@ -149,6 +149,21 @@ def _transform_kernel_arg(
field.type,
this_arg_features[field_idx],
)
elif isinstance(field.type, type) and getattr(field.type, "_data_oriented", False):
# ``@qd.data_oriented`` field type inside a typed-dataclass kernel arg. The two patterns are
# semantically incompatible at this layer: dataclass kernel-arg recursion uses annotations to
# flatten leaf fields into per-leaf kernel args at compile time, but data_oriented containers don't
# carry per-attribute type annotations — they need a value-driven walk
# (``_predeclare_struct_ndarrays``), which only fires for ``qd.template()`` / ``qd.Tensor``
# annotations. Rather than silently miscompile, raise a clear error pointing users to the
# recommended pattern.
raise QuadrantsSyntaxError(
f"Kernel arg {argument_name!r}: field {field.name!r} has @qd.data_oriented type "
f"{field.type.__name__!r}, which cannot be flattened into a typed-dataclass kernel arg. "
f"Use ``{argument_name}: qd.template()`` for the outer kernel arg annotation instead; "
f"data_oriented contents (including nested ndarrays) are walked at kernel-compile time via "
f"the template path."
)
else:
result, obj = FunctionDefTransformer._decl_and_create_variable(
ctx,
@@ -226,14 +241,18 @@ def _walk_obj(obj, arg_idx, path):
child = child._unwrap()
if isinstance(child, _ndarray.Ndarray):
_register_ndarray(child, arg_idx, (*path, field.name))
elif dataclasses.is_dataclass(child) and not isinstance(child, type):
elif (dataclasses.is_dataclass(child) and not isinstance(child, type)) or is_data_oriented(child):
_walk_obj(child, arg_idx, (*path, field.name))
else:
for attr_name, attr_val in vars(obj).items():
if isinstance(attr_val, _TensorClass):
attr_val = attr_val._unwrap()
if isinstance(attr_val, _ndarray.Ndarray):
_register_ndarray(attr_val, arg_idx, (*path, attr_name))
elif (dataclasses.is_dataclass(attr_val) and not isinstance(attr_val, type)) or is_data_oriented(
attr_val
):
_walk_obj(attr_val, arg_idx, (*path, attr_name))

def _register_ndarray(nd, arg_idx, attr_chain):
key = id(nd)