@@ -329,6 +329,10 @@ def apply_vmap_over_outer(
329329) -> PyTree :
330330 """
331331 Apply a function across the outer dimensions of a tensor.
332+
333+ This is intended to be a QoL feature for handling the common case of
334+ applying a function across the outer dimensions of a tensor. The goal is
335+ to eliminate the need for nested calls to ``jax.vmap``.
332336 """
333337 if isinstance (f_dim , int ):
334338 f_dim = tree_map (lambda _ : f_dim , x )
@@ -658,3 +662,44 @@ def mask_tensor(
658662):
659663 mask = conform_mask (tensor = tensor , mask = mask , axis = axis )
660664 return jnp .where (mask , tensor , fill_value )
665+
666+
667+ def masker (
668+ mask : Tensor ,
669+ axis : int | Sequence [int ],
670+ output_axis : int | None = None ,
671+ ) -> Callable [[Tensor ], Tensor ]:
672+ """
673+ Create a JIT-compatible function that applies a mask to a tensor.
674+
675+ .. warning::
676+
677+ This function comes with some memory overhead. Specifically, it
678+ closes over an integer array of the same size as the number of
679+ ``True`` elements in the mask. When applying a very large mask to
680+ a tensor, it is important to consider the trade-off between memory
681+ and potential performance gains of JIT compilation.
682+
683+ .. warning::
684+
685+ Just like any JIT-compiled function, the resulting function must be
686+ recompiled when the shape of the input tensor changes.
687+ """
688+ if isinstance (axis , int ):
689+ axis = (axis ,)
690+ mask_loc = jnp .where (mask )
691+ if mask_loc [0 ].size == 0 :
692+ raise ValueError ('Mask is empty' )
693+ assert len (mask_loc ) == len (axis )
694+ if output_axis is None :
695+ output_axis = - 1
696+
697+ @jax .jit
698+ def apply_mask (tensor : Tensor ) -> Tensor :
699+ _axis = tuple (standard_axis_number (ax , tensor .ndim ) for ax in axis )
700+ indexer = [slice (None )] * tensor .ndim
701+ for ax , loc in zip (_axis , mask_loc ):
702+ indexer [ax ] = loc
703+ return tensor [tuple (indexer )].swapaxes (0 , output_axis )
704+
705+ return apply_mask
0 commit comments