Skip to content

Commit 69395ab

Browse files
committed
Migrate to newer nnx.Pytree api, away from flax.struct.dataclass.
1 parent 211cc5e commit 69395ab

8 files changed

Lines changed: 114 additions & 107 deletions

File tree

docs/source/intro/cg.ipynb

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,20 @@
100100
"output_type": "stream",
101101
"text": [
102102
"Euler Tableau (CG1):\n",
103-
" ButcherTableau(stages=1, a=((0,),), b=(1,), c=(0,))\n",
103+
" \u001b[38;2;79;201;177mButcherTableau\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n",
104+
" \u001b[38;2;156;220;254mstages\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;182;207;169m1\u001b[0m,\n",
105+
" \u001b[38;2;156;220;254ma\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m0.0\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m,\n",
106+
" \u001b[38;2;156;220;254mb\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m1.0\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m,\n",
107+
" \u001b[38;2;156;220;254mc\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m0.0\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m\n",
108+
"\u001b[38;2;255;213;3m)\u001b[0m\n",
104109
"\n",
105110
"CG2 Tableau:\n",
106-
" ButcherTableau(stages=2, a=((0, 0), (0.5, 0)), b=(0, 1), c=(0, 0.5))\n"
111+
" \u001b[38;2;79;201;177mButcherTableau\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n",
112+
" \u001b[38;2;156;220;254mstages\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m,\n",
113+
" \u001b[38;2;156;220;254ma\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m0.0\u001b[0m, \u001b[38;2;182;207;169m0.0\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m0.5\u001b[0m, \u001b[38;2;182;207;169m0.0\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m,\n",
114+
" \u001b[38;2;156;220;254mb\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m0.0\u001b[0m, \u001b[38;2;182;207;169m1.0\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m,\n",
115+
" \u001b[38;2;156;220;254mc\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m0.0\u001b[0m, \u001b[38;2;182;207;169m0.5\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m\n",
116+
"\u001b[38;2;255;213;3m)\u001b[0m\n"
107117
]
108118
}
109119
],

docs/source/tutorials/scalar-theory.ipynb

Lines changed: 4 additions & 3 deletions
Large diffs are not rendered by default.

src/bijx/bijections/coupling.py

Lines changed: 39 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,14 @@
2222

2323
import functools
2424
import inspect
25+
from dataclasses import dataclass
2526

26-
import flax
2727
import jax
2828
import jax.numpy as jnp
2929
import numpy as np
3030
from flax import nnx
3131
from jax_autovmap import autovmap
3232

33-
from ..utils import Const
3433
from .base import Bijection
3534

3635

@@ -73,11 +72,6 @@ class BinaryMask(Bijection):
7372
>>> # reconstructed == x
7473
"""
7574

76-
masks: Const
77-
primary_indices: Const
78-
secondary_indices: Const
79-
event_shape: tuple[int, ...]
80-
8175
def __init__(
8276
self,
8377
primary_indices: tuple[np.ndarray, ...],
@@ -90,20 +84,20 @@ def __init__(
9084
masks = (mask, ~mask)
9185
if secondary_indices is None:
9286
secondary_indices = np.where(masks[1])
93-
self.masks = Const(masks)
94-
self.primary_indices = Const(primary_indices)
95-
self.secondary_indices = Const(secondary_indices)
96-
self.event_shape = event_shape
87+
self.primary_indices = nnx.data(primary_indices)
88+
self.event_shape = nnx.static(event_shape)
89+
self.masks = nnx.data(masks)
90+
self.secondary_indices = nnx.data(secondary_indices)
9791

9892
@property
9993
def count_primary(self):
10094
"""Number of elements in the primary (True) mask region."""
101-
return self.primary_indices.value[0].size
95+
return self.primary_indices[0].size
10296

10397
@property
10498
def count_secondary(self):
10599
"""Number of elements in the secondary (False) mask region."""
106-
return self.secondary_indices.value[0].size
100+
return self.secondary_indices[0].size
107101

108102
@property
109103
def counts(self):
@@ -145,7 +139,7 @@ def from_boolean_mask(cls, mask: jax.Array):
145139
@property
146140
def boolean_mask(self):
147141
"""Primary boolean mask array."""
148-
return self.masks.value[0]
142+
return self.masks[0]
149143

150144
def indices(
151145
self, extra_channel_dims: int = 0, batch_safe: bool = True, primary: bool = True
@@ -161,7 +155,7 @@ def indices(
161155
Indexing tuple suitable for array subscripting.
162156
"""
163157
ind = (...,) if batch_safe else ()
164-
ind += self.primary_indices.value if primary else self.secondary_indices.value
158+
ind += self.primary_indices if primary else self.secondary_indices
165159
ind += (np.s_[:],) * extra_channel_dims
166160
return ind
167161

@@ -172,10 +166,10 @@ def flip(self):
172166
New BinaryMask with primary and secondary regions swapped.
173167
"""
174168
return self.__class__(
175-
self.secondary_indices.value,
169+
self.secondary_indices,
176170
self.event_shape,
177-
masks=self.masks.value[::-1],
178-
secondary_indices=self.primary_indices.value,
171+
masks=self.masks[::-1],
172+
secondary_indices=self.primary_indices,
179173
)
180174

181175
def split(self, array, extra_channel_dims: int = 0, batch_safe: bool = True):
@@ -323,7 +317,7 @@ def checker_mask(shape, parity: bool):
323317
return BinaryMask.from_boolean_mask(mask.astype(bool))
324318

325319

326-
class ModuleReconstructor:
320+
class ModuleReconstructor(nnx.Pytree):
327321
"""
328322
Parameter management utility for dynamically parameterizing modules.
329323
@@ -341,42 +335,39 @@ class ModuleReconstructor:
341335
- Full nnx state, use `from_params`
342336
"""
343337

344-
# params_treedef: Any # static
345-
# params_leaves: list[jax.core.ShapedArray] # static
346-
# unconditional: nnx.State # array leaf
347-
# graph: Any | None = None # static
348-
349338
def __init__(
350-
self, module_or_state: nnx.State | nnx.Module, filter: nnx.Param = nnx.Param
339+
self,
340+
module_or_state: nnx.State | nnx.Module,
341+
filter: nnx.Param = nnx.Param,
351342
):
352343
if isinstance(module_or_state, nnx.State):
353344
self.graph = None
354345
state = module_or_state
355346
else:
356347
graph, state = nnx.split(module_or_state)
357-
self.graph = graph
348+
self.graph = nnx.static(graph)
358349

359350
params, unconditional = nnx.split_state(state, filter, ...)
360351

361352
params = jax.tree.map(lambda x: jax.core.ShapedArray(x.shape, x.dtype), params)
362353

363354
params_leaves, params_treedef = jax.tree.flatten(params)
364355

365-
self.params_treedef = params_treedef
366-
self.params_leaves = params_leaves
367-
self.unconditional = unconditional
356+
self.params_treedef = nnx.static(params_treedef)
357+
self.params_leaves = nnx.static(params_leaves)
358+
self.unconditional = nnx.data(unconditional)
368359

369-
def _tree_flatten(self):
370-
children = (self.unconditional,)
371-
aux_data = (self.params_treedef, self.params_leaves, self.graph)
372-
return children, aux_data
360+
# def _tree_flatten(self):
361+
# children = (self.unconditional,)
362+
# aux_data = (self.params_treedef, self.params_leaves, self.graph)
363+
# return children, aux_data
373364

374-
@classmethod
375-
def _tree_unflatten(cls, aux_data, children):
376-
self = object.__new__(cls)
377-
self.params_treedef, self.params_leaves, self.graph = aux_data
378-
(self.unconditional,) = children
379-
return self
365+
# @classmethod
366+
# def _tree_unflatten(cls, aux_data, children):
367+
# self = object.__new__(cls)
368+
# self.params_treedef, self.params_leaves, self.graph = aux_data
369+
# (self.unconditional,) = children
370+
# return self
380371

381372
@property
382373
def params(self):
@@ -490,15 +481,15 @@ def __repr__(self):
490481
return f"ModuleReconstructor:{state_or_module}"
491482

492483

493-
jax.tree_util.register_pytree_node(
494-
ModuleReconstructor,
495-
ModuleReconstructor._tree_flatten,
496-
ModuleReconstructor._tree_unflatten,
497-
)
484+
# jax.tree_util.register_pytree_node(
485+
# ModuleReconstructor,
486+
# ModuleReconstructor._tree_flatten,
487+
# ModuleReconstructor._tree_unflatten,
488+
# )
498489

499490

500-
@flax.struct.dataclass
501-
class AutoVmapReconstructor:
491+
@dataclass(frozen=True)
492+
class AutoVmapReconstructor(nnx.Pytree):
502493
r"""Automatic vectorization for module reconstruction with batched parameters.
503494
504495
This class provides a solution for bijections that do not natively support
@@ -539,8 +530,8 @@ class AutoVmapReconstructor:
539530
"""
540531

541532
reconstructor: ModuleReconstructor
542-
params: nnx.State | dict | list[jax.Array] | jax.Array
543-
params_rank: int | dict = 1
533+
params: nnx.Data[nnx.State | dict | list[jax.Array] | jax.Array]
534+
params_rank: nnx.Data[int | dict] = 1
544535

545536
def __call__(self, fn_name, *args, input_ranks: tuple[int, ...] = (0, 0), **kwargs):
546537

src/bijx/cg.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,14 @@
2020

2121
from functools import partial, reduce
2222

23-
import flax.struct
2423
import jax
2524
import jax.numpy as jnp
2625
import numpy as np
26+
from flax import nnx
2727
from jax import core, custom_derivatives
2828

2929

30-
@flax.struct.dataclass
31-
class ButcherTableau:
30+
class ButcherTableau(nnx.Pytree):
3231
r"""Butcher tableau defining coefficients for Runge-Kutta integration schemes.
3332
3433
Encodes the coefficient structure for explicit Runge-Kutta methods in
@@ -50,29 +49,16 @@ class ButcherTableau:
5049
Consistency requires: $c_i = \sum_j a_{ij}$ and $\sum_i b_i = 1$.
5150
"""
5251

53-
stages: int
54-
"""Number of stages $s$ in the method."""
55-
56-
a: tuple[tuple[int, ...]]
57-
"""Coefficient matrix $(a_{ij})$ as nested tuples."""
58-
59-
b: tuple[int, ...]
60-
"""Weight vector $(b_i)$ as tuple."""
61-
62-
c: tuple[int, ...]
63-
"""Node vector $(c_i)$ as tuple (computed from $a$)."""
64-
65-
@classmethod
66-
def from_ab(cls, a, b):
52+
def __init__(self, a, b):
6753
r"""Construct Butcher tableau from coefficient matrix and weights.
6854
6955
Creates a ButcherTableau instance from the $a$ matrix and $b$ vector,
7056
automatically computing the node vector $c_i = \sum_j a_{ij}$ and
7157
validating consistency conditions.
7258
7359
Args:
74-
a: Coefficient matrix as list of lists, shape $(s, s)$.
75-
b: Weight vector as list, length $s$.
60+
a: Coefficient matrix or list of lists, shape $(s, s)$.
61+
b: Weight vector or list, length $s$.
7662
7763
Returns:
7864
ButcherTableau instance with computed node vector.
@@ -85,12 +71,12 @@ def from_ab(cls, a, b):
8571
8672
Example:
8773
>>> # Second-order Crouch-Grossmann method
88-
>>> cg2 = ButcherTableau.from_ab(
74+
>>> cg2 = ButcherTableau(
8975
... a=[[0, 0], [1/2, 0]], b=[0, 1]
9076
... )
9177
"""
92-
a = tuple(tuple(ai) for ai in a)
93-
b = tuple(b)
78+
a = tuple(tuple(float(aij) for aij in ai) for ai in a)
79+
b = tuple(float(bi) for bi in b)
9480
c = tuple(sum(ai) for ai in a)
9581

9682
assert all(len(ai) == len(c) for ai in a)
@@ -101,10 +87,13 @@ def from_ab(cls, a, b):
10187
for i in range(j + 1):
10288
assert a[i][j] == 0, "only explicit methods supported"
10389

104-
return cls(stages=len(c), a=a, b=b, c=c)
90+
self.stages = len(c)
91+
self.a = nnx.static(a)
92+
self.b = nnx.static(b)
93+
self.c = nnx.static(c)
10594

10695

107-
EULER = ButcherTableau.from_ab(
96+
EULER = ButcherTableau(
10897
a=[[0]],
10998
b=[1],
11099
)
@@ -115,7 +104,7 @@ def from_ab(cls, a, b):
115104
For Lie groups: $g_{n+1} = \exp(h A(t_n, g_n)) g_n$.
116105
"""
117106

118-
CG2 = ButcherTableau.from_ab(
107+
CG2 = ButcherTableau(
119108
a=[[0, 0], [1 / 2, 0]],
120109
b=[0, 1],
121110
)
@@ -131,7 +120,7 @@ def from_ab(cls, a, b):
131120
Update: $g_{n+1} = \exp(h k_2) g_n$
132121
"""
133122

134-
CG3 = ButcherTableau.from_ab(
123+
CG3 = ButcherTableau(
135124
a=[[0, 0, 0], [3 / 4, 0, 0], [119 / 216, 17 / 108, 0]],
136125
b=[13 / 51, -2 / 3, 24 / 17],
137126
)
@@ -432,3 +421,4 @@ def augmented_ode(t, state, args):
432421

433422

434423
_crouch_grossmann.defvjp(_crouch_grossmann_fwd, _crouch_grossmann_rev)
424+
_crouch_grossmann.defvjp(_crouch_grossmann_fwd, _crouch_grossmann_rev)

0 commit comments

Comments
 (0)