2222
2323import functools
2424import inspect
25+ from dataclasses import dataclass
2526
26- import flax
2727import jax
2828import jax .numpy as jnp
2929import numpy as np
3030from flax import nnx
3131from jax_autovmap import autovmap
3232
33- from ..utils import Const
3433from .base import Bijection
3534
3635
@@ -73,11 +72,6 @@ class BinaryMask(Bijection):
7372 >>> # reconstructed == x
7473 """
7574
76- masks : Const
77- primary_indices : Const
78- secondary_indices : Const
79- event_shape : tuple [int , ...]
80-
8175 def __init__ (
8276 self ,
8377 primary_indices : tuple [np .ndarray , ...],
@@ -90,20 +84,20 @@ def __init__(
9084 masks = (mask , ~ mask )
9185 if secondary_indices is None :
9286 secondary_indices = np .where (masks [1 ])
93- self .masks = Const ( masks )
94- self .primary_indices = Const ( primary_indices )
95- self .secondary_indices = Const ( secondary_indices )
96- self .event_shape = event_shape
87+ self .primary_indices = nnx . data ( primary_indices )
88+ self .event_shape = nnx . static ( event_shape )
89+ self .masks = nnx . data ( masks )
90+ self .secondary_indices = nnx . data ( secondary_indices )
9791
9892 @property
9993 def count_primary (self ):
10094 """Number of elements in the primary (True) mask region."""
101- return self .primary_indices . value [0 ].size
95+ return self .primary_indices [0 ].size
10296
10397 @property
10498 def count_secondary (self ):
10599 """Number of elements in the secondary (False) mask region."""
106- return self .secondary_indices . value [0 ].size
100+ return self .secondary_indices [0 ].size
107101
108102 @property
109103 def counts (self ):
@@ -145,7 +139,7 @@ def from_boolean_mask(cls, mask: jax.Array):
145139 @property
146140 def boolean_mask (self ):
147141 """Primary boolean mask array."""
148- return self .masks . value [0 ]
142+ return self .masks [0 ]
149143
150144 def indices (
151145 self , extra_channel_dims : int = 0 , batch_safe : bool = True , primary : bool = True
@@ -161,7 +155,7 @@ def indices(
161155 Indexing tuple suitable for array subscripting.
162156 """
163157 ind = (...,) if batch_safe else ()
164- ind += self .primary_indices . value if primary else self .secondary_indices . value
158+ ind += self .primary_indices if primary else self .secondary_indices
165159 ind += (np .s_ [:],) * extra_channel_dims
166160 return ind
167161
@@ -172,10 +166,10 @@ def flip(self):
172166 New BinaryMask with primary and secondary regions swapped.
173167 """
174168 return self .__class__ (
175- self .secondary_indices . value ,
169+ self .secondary_indices ,
176170 self .event_shape ,
177- masks = self .masks . value [::- 1 ],
178- secondary_indices = self .primary_indices . value ,
171+ masks = self .masks [::- 1 ],
172+ secondary_indices = self .primary_indices ,
179173 )
180174
181175 def split (self , array , extra_channel_dims : int = 0 , batch_safe : bool = True ):
@@ -323,7 +317,7 @@ def checker_mask(shape, parity: bool):
323317 return BinaryMask .from_boolean_mask (mask .astype (bool ))
324318
325319
326- class ModuleReconstructor :
320+ class ModuleReconstructor ( nnx . Pytree ) :
327321 """
328322 Parameter management utility for dynamically parameterizing modules.
329323
@@ -341,42 +335,39 @@ class ModuleReconstructor:
341335 - Full nnx state, use `from_params`
342336 """
343337
344- # params_treedef: Any # static
345- # params_leaves: list[jax.core.ShapedArray] # static
346- # unconditional: nnx.State # array leaf
347- # graph: Any | None = None # static
348-
349338 def __init__ (
350- self , module_or_state : nnx .State | nnx .Module , filter : nnx .Param = nnx .Param
339+ self ,
340+ module_or_state : nnx .State | nnx .Module ,
341+ filter : nnx .Param = nnx .Param ,
351342 ):
352343 if isinstance (module_or_state , nnx .State ):
353344 self .graph = None
354345 state = module_or_state
355346 else :
356347 graph , state = nnx .split (module_or_state )
357- self .graph = graph
348+ self .graph = nnx . static ( graph )
358349
359350 params , unconditional = nnx .split_state (state , filter , ...)
360351
361352 params = jax .tree .map (lambda x : jax .core .ShapedArray (x .shape , x .dtype ), params )
362353
363354 params_leaves , params_treedef = jax .tree .flatten (params )
364355
365- self .params_treedef = params_treedef
366- self .params_leaves = params_leaves
367- self .unconditional = unconditional
356+ self .params_treedef = nnx . static ( params_treedef )
357+ self .params_leaves = nnx . static ( params_leaves )
358+ self .unconditional = nnx . data ( unconditional )
368359
369- def _tree_flatten (self ):
370- children = (self .unconditional ,)
371- aux_data = (self .params_treedef , self .params_leaves , self .graph )
372- return children , aux_data
360+ # def _tree_flatten(self):
361+ # children = (self.unconditional,)
362+ # aux_data = (self.params_treedef, self.params_leaves, self.graph)
363+ # return children, aux_data
373364
374- @classmethod
375- def _tree_unflatten (cls , aux_data , children ):
376- self = object .__new__ (cls )
377- self .params_treedef , self .params_leaves , self .graph = aux_data
378- (self .unconditional ,) = children
379- return self
365+ # @classmethod
366+ # def _tree_unflatten(cls, aux_data, children):
367+ # self = object.__new__(cls)
368+ # self.params_treedef, self.params_leaves, self.graph = aux_data
369+ # (self.unconditional,) = children
370+ # return self
380371
381372 @property
382373 def params (self ):
@@ -490,15 +481,15 @@ def __repr__(self):
490481 return f"ModuleReconstructor:{ state_or_module } "
491482
492483
493- jax .tree_util .register_pytree_node (
494- ModuleReconstructor ,
495- ModuleReconstructor ._tree_flatten ,
496- ModuleReconstructor ._tree_unflatten ,
497- )
484+ # jax.tree_util.register_pytree_node(
485+ # ModuleReconstructor,
486+ # ModuleReconstructor._tree_flatten,
487+ # ModuleReconstructor._tree_unflatten,
488+ # )
498489
499490
500- @flax . struct . dataclass
501- class AutoVmapReconstructor :
491+ @dataclass ( frozen = True )
492+ class AutoVmapReconstructor ( nnx . Pytree ) :
502493 r"""Automatic vectorization for module reconstruction with batched parameters.
503494
504495 This class provides a solution for bijections that do not natively support
@@ -539,8 +530,8 @@ class AutoVmapReconstructor:
539530 """
540531
541532 reconstructor : ModuleReconstructor
542- params : nnx .State | dict | list [jax .Array ] | jax .Array
543- params_rank : int | dict = 1
533+ params : nnx .Data [ nnx . State | dict | list [jax .Array ] | jax .Array ]
534+ params_rank : nnx . Data [ int | dict ] = 1
544535
545536 def __call__ (self , fn_name , * args , input_ranks : tuple [int , ...] = (0 , 0 ), ** kwargs ):
546537
0 commit comments