Skip to content

Commit 32c55ec

Browse files
committed
(ENH) sketch JIT-able masker (untested)
1 parent b89e114 commit 32c55ec

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

src/nitrix/_internal/util.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ def apply_vmap_over_outer(
329329
) -> PyTree:
330330
"""
331331
Apply a function across the outer dimensions of a tensor.
332+
333+
This is intended to be a QoL feature for handling the common case of
334+
applying a function across the outer dimensions of a tensor. The goal is
335+
to eliminate the need for nested calls to ``jax.vmap``.
332336
"""
333337
if isinstance(f_dim, int):
334338
f_dim = tree_map(lambda _: f_dim, x)
@@ -658,3 +662,44 @@ def mask_tensor(
658662
):
659663
mask = conform_mask(tensor=tensor, mask=mask, axis=axis)
660664
return jnp.where(mask, tensor, fill_value)
665+
666+
667+
def masker(
    mask: Tensor,
    axis: int | Sequence[int],
    output_axis: int | None = None,
) -> Callable[[Tensor], Tensor]:
    """
    Create a JIT-compatible function that applies a mask to a tensor.

    Parameters
    ----------
    mask : Tensor
        Boolean mask. Its number of dimensions must equal the number of
        entries in ``axis``: dimension ``i`` of the mask indexes tensor
        axis ``axis[i]``.
    axis : int or Sequence[int]
        Axis (or axes) of the input tensor over which the mask selects.
    output_axis : int or None (default: None)
        Position in the output at which the single axis of selected
        elements is placed. ``None`` places it last.

    Returns
    -------
    Callable[[Tensor], Tensor]
        A JIT-compiled function that indexes its argument with the
        (closed-over) mask coordinates.

    Raises
    ------
    ValueError
        If the mask contains no ``True`` elements, or if the number of
        mask dimensions does not match the number of specified axes.

    .. warning::

        This function comes with some memory overhead. Specifically, it
        closes over an integer array of the same size as the number of
        ``True`` elements in the mask. When applying a very large mask to
        a tensor, it is important to consider the trade-off between memory
        and potential performance gains of JIT compilation.

    .. warning::

        Just like any JIT-compiled function, the resulting function must be
        recompiled when the shape of the input tensor changes.
    """
    if isinstance(axis, int):
        axis = (axis,)
    mask_loc = jnp.where(mask)
    if mask_loc[0].size == 0:
        raise ValueError('Mask is empty')
    # ``jnp.where`` returns one coordinate array per mask dimension; each
    # must map onto exactly one tensor axis. Raise rather than assert so
    # the check survives ``python -O``.
    if len(mask_loc) != len(axis):
        raise ValueError(
            'Mask dimensionality must match the number of specified axes'
        )
    if output_axis is None:
        output_axis = -1

    @jax.jit
    def apply_mask(tensor: Tensor) -> Tensor:
        _axis = tuple(standard_axis_number(ax, tensor.ndim) for ax in axis)
        indexer = [slice(None)] * tensor.ndim
        for ax, loc in zip(_axis, mask_loc):
            indexer[ax] = loc
        indexed = tensor[tuple(indexer)]
        # NumPy/JAX advanced-indexing placement rule: when the array
        # indices occupy contiguous axes, the single axis they broadcast
        # into appears at the position of the first of them; otherwise it
        # is moved to the front. Locate that axis, then move it to
        # ``output_axis`` (a bare ``swapaxes(0, output_axis)`` is wrong
        # whenever the first masked axis is not 0).
        adv = sorted(_axis)
        contiguous = all(b - a == 1 for a, b in zip(adv, adv[1:]))
        result_axis = adv[0] if contiguous else 0
        return jnp.moveaxis(indexed, result_axis, output_axis)

    return apply_mask

0 commit comments

Comments
 (0)