-
Notifications
You must be signed in to change notification settings - Fork 26
[DataOriented] Fix ndarrays on data oriented #704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
06b7c6a
d4350ef
97afa6d
49a723b
9bdeca5
dc7997b
906ce19
a0db648
c9598ad
93893e5
ce769a7
dd4de40
46825ab
ee5fbbb
b132b81
6d1c820
1de65b9
a648c3f
a55d360
6667ba6
e9c50b4
abf242b
4c27e2e
1f539e6
730cbcb
57e1b95
d4ca211
b72a7a7
18ff7bd
33f4744
883243e
3504250
cc01339
df3113e
7f5fd12
e7fafeb
f9a35df
d336dcd
4c5f622
56a4399
ef5f8a6
06580f1
8899357
8fef507
92f5fe1
9ea8e5b
6ff0848
fd8cd0a
004cd9a
bf85e4e
ccaae54
820c01a
06d2e86
f7dd090
8c0377c
71a53da
46fef24
18f995b
3ce0ab0
35be370
94e455a
31b27d7
36dc933
07dc486
3aa4fe1
89bb005
4923d68
16d50f5
e510a53
6d339a5
6d6b25d
f788395
786223a
a7adc97
d396bb5
eaf16fc
15da770
39a1df7
9f3b4e6
3b9f17a
0607d55
8d7e473
d0eda52
c2196b0
31c291a
0a7506a
fa353e4
4cf3b8c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,7 @@ | |
| a consequence of inlining 'is_dataclass' and 'fields'. | ||
| """ | ||
|
|
||
| import dataclasses | ||
| import weakref | ||
| from dataclasses import _FIELD, _FIELDS | ||
| from typing import Any, Union | ||
|
|
@@ -71,6 +72,88 @@ | |
| _primitive_types = {int, float, bool} | ||
|
|
||
|
|
||
| # Per-class cache: ``type(arg) -> list[tuple[str, ...]]`` of attribute paths whose values are ``Ndarray`` instances at | ||
| # first observation. Populated lazily by ``_struct_nd_paths_for`` on the first call with each new data_oriented (or | ||
| # nested dataclass) class. Empty list means "this class holds no ndarrays anywhere", in which case subsequent calls | ||
| # pay only a dict-lookup per arg. Non-empty list short-circuits the full ``vars()`` recursion and just resolves each | ||
| # cached path via ``getattr`` chains. Critical for the genesis field-backend hot path: the ``@qd.data_oriented`` | ||
| # Solver is passed as ``self`` to most kernels and holds dozens of attributes, so a full per-call ``vars()`` walk | ||
| # costs >100ns per kernel and trashed FPS until this cache was added. | ||
| _struct_nd_paths_cache: dict[type, list[tuple]] = {} | ||
|
|
||
|
|
||
| def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list) -> None: | ||
| if dataclasses.is_dataclass(obj) and not isinstance(obj, type): | ||
| children = ((f.name, getattr(obj, f.name)) for f in dataclasses.fields(obj)) | ||
| else: | ||
| # ``NamedTuple`` (decorated as ``@qd.data_oriented``) has no instance ``__dict__`` — fall back to ``_asdict()`` | ||
| # which materialises a dict view of the named fields. Mirrors the same fallback in | ||
| # ``args_hasher.stringify_obj_type`` so the per-class path cache here picks up ndarray members on NamedTuples | ||
| # too (regression covered by ``test_args_hasher_named_tuple``). | ||
| try: | ||
| children = obj._asdict().items() | ||
| except AttributeError: | ||
| children = obj.__dict__.items() | ||
| for k, v in children: | ||
| chain = prefix + (k,) | ||
| if type(v) in _TENSOR_WRAPPER_TYPES: | ||
| v = v._unwrap() | ||
| v_type = type(v) | ||
| if issubclass(v_type, Ndarray): | ||
| out.append(chain) | ||
| elif is_data_oriented(v) or (dataclasses.is_dataclass(v) and not isinstance(v, type)): | ||
| _build_struct_nd_paths(v, chain, out) | ||
|
|
||
|
|
||
| def _struct_nd_paths_for(arg: Any) -> list[tuple]: | ||
| """Return the cached attribute paths (each a tuple of attr-name strings) at which ``Ndarray`` instances are | ||
| reachable from ``arg`` of type ``type(arg)``. First call for a class walks ``arg`` once via | ||
| ``_build_struct_nd_paths``; subsequent calls are dict-lookups. | ||
|
|
||
| Trades freshness for speed: assumes the *set* of ndarray-holding attribute paths is stable across instances of | ||
| the same class. The genesis Solver and similar ``@qd.data_oriented`` containers satisfy this — their ndarray | ||
| members are declared in ``__init__`` and not added later. If you need to add an ndarray attribute after the first | ||
| kernel launch on an instance of a given class, the new attribute won't be tracked. Call ``invalidate_struct_nd_ | ||
| paths_for`` (below) or restart the program. | ||
|
|
||
| FIXME (Codex #3 on PR #704, https://github.com/Genesis-Embodied-AI/quadrants/pull/704#discussion_r3253281957): | ||
| the cache is keyed by ``type(arg)`` only. If two instances of the same class have *polymorphic attribute | ||
| structure* — e.g. instance A has ``.x`` as a ``qd.ndarray``-backed ``qd.Tensor`` while instance B has the same | ||
| ``.x`` as a field-backed ``qd.Tensor`` — the paths discovered from the first-walked instance are reused for the | ||
| second. ``_collect_struct_nd_descriptors`` then unconditionally reads ndarray-only attrs (``element_type``, | ||
| ``grad``, ``_qd_layout``) on what is now a ``ScalarField``, raising before the kernel can run. The fix is the | ||
| per-instance walk implemented on top of this branch in PR #705; this branch ships the class-level cache as-is. | ||
| """ | ||
| cls = type(arg) | ||
| paths = _struct_nd_paths_cache.get(cls) | ||
| if paths is None: | ||
| paths = [] | ||
| _build_struct_nd_paths(arg, (), paths) | ||
| _struct_nd_paths_cache[cls] = paths | ||
| return paths | ||
|
|
||
|
|
||
| def _collect_struct_nd_descriptors(arg: Any, out: list) -> None: | ||
| """Emit per-ndarray shape descriptors ``(joined-path, element_type, ndim, needs_grad, layout)`` for every ndarray | ||
| reachable from ``arg``. Used by the template-mapper to refine the spec key for ``@qd.data_oriented`` args holding | ||
| ndarrays — see the data_oriented branch in ``_extract_arg``. | ||
|
|
||
| FIXME (Codex #3 on PR #704): when a polymorphic instance reuses a cached path that pointed to an ``Ndarray`` on | ||
| the first-walked instance, ``v`` here can be a ``ScalarField`` and the ``v.element_type`` / ``v.grad`` / | ||
| ``v._qd_layout`` reads will raise. See ``_struct_nd_paths_for`` above for details. Fixed in PR #705 via the | ||
| per-instance walk redesign. | ||
| """ | ||
| for chain in _struct_nd_paths_for(arg): | ||
| v = arg | ||
| for a in chain: | ||
| v = getattr(v, a) | ||
| if type(v) in _TENSOR_WRAPPER_TYPES: | ||
| v = v._unwrap() | ||
| type_id = id(v.element_type) | ||
|
Comment on lines
+150
to
+152
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
If the first Useful? React with 👍 / 👎. |
||
| element_type = type_id if type_id in primitive_types.type_ids else v.element_type | ||
| out.append((".".join(chain), element_type, len(v.shape), v.grad is not None, v._qd_layout)) | ||
|
|
||
|
|
||
| def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: AnnotationType, arg_name: str) -> Any: | ||
| # ``qd.Tensor`` wrappers passed as struct fields. Top-level kernel-arg unwrap in ``Kernel.__call__`` covers direct | ||
| # args, but the dataclass-field recursion at the bottom of this function walks struct attributes via raw | ||
|
|
@@ -124,7 +207,7 @@ def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: Annotati | |
| raise QuadrantsRuntimeTypeError( | ||
| "Ndarray shouldn't be passed in via `qd.template()`, please annotate your kernel using `qd.types.ndarray(...)` instead" | ||
| ) | ||
| if arg_type in _composite_mutable_types or is_data_oriented(arg): | ||
| if arg_type in _composite_mutable_types: | ||
| # [Composite arguments] Return weak reference to the object | ||
| # Quadrants kernel will cache the extracted arguments, thus we can't simply return the original argument. | ||
| # Instead, a weak reference to the original value is returned to avoid memory leak. | ||
|
|
@@ -134,6 +217,21 @@ def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: Annotati | |
| # 1. Invalid weak-ref will leave a dead(dangling) entry in both caches: "self.mapping" and "self.compiled_functions" | ||
| # 2. Different argument instances with same type and same value, will get templatized into separate kernels. | ||
| return weakref.ref(arg) | ||
| if is_data_oriented(arg): | ||
| # Same memory-leak avoidance as above — keep ``weakref.ref(arg)`` so the spec key never holds a strong | ||
| # reference to user state. But for data_oriented containers that hold ``Ndarray`` members, the live | ||
| # ``weakref`` alone is too coarse: same instance with ``state.x = other_ndarray`` of a different dtype/ndim | ||
| # would re-use the previously-compiled kernel, which was specialised for the old shape. Walk the reachable | ||
| # ndarrays and prepend their shape descriptors so dtype/ndim changes trigger re-specialisation. Mirrors what | ||
| # the dataclass branch below does via ``annotation_fields``. | ||
| # | ||
| # Containers with no ndarrays keep the original short-path (one spec per instance via weakref) so this is | ||
| # a no-op for the existing data_oriented + qd.field workloads (genesis field-backend). | ||
| nd_descriptors: list = [] | ||
| _collect_struct_nd_descriptors(arg, nd_descriptors) | ||
| if nd_descriptors: | ||
| return (id(type(arg)), tuple(nd_descriptors), weakref.ref(arg)) | ||
| return weakref.ref(arg) | ||
|
|
||
| # Return value directly for other types, i.e. primitive types and all qd.Field-derived classes | ||
| if raise_on_templated_floats and arg_type is float: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This section is great, but should probably be moved in Appendix, because it is not supposed to be relevant for users.