13 changes: 10 additions & 3 deletions recml/core/data/iterator.py
@@ -21,6 +21,7 @@
from etils import epath
import numpy as np
import tensorflow as tf
import jax


Iterator = clu_data.DatasetIterator
@@ -67,17 +68,22 @@ def _maybe_to_numpy(
) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor:
if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
return x
# FIX: Check for attribute existence to avoid crashes on non-Tensor objects
if hasattr(x, "_numpy"):
numpy = x._numpy() # pylint: disable=protected-access
else:
elif hasattr(x, "numpy"):
numpy = x.numpy()
else:
return x # Return as-is if it can't be converted

if isinstance(numpy, np.ndarray):
# `numpy` shares the same underlying buffer as the `x` Tensor.
# Tensors are expected to be immutable, so we disable writes.
numpy.setflags(write=False)
return numpy

return tf.nest.map_structure(_maybe_to_numpy, batch)
# FIX: Use jax.tree.map instead of tf.nest.map_structure
return jax.tree.map(_maybe_to_numpy, batch)

@property
def element_spec(self) -> clu_data.ElementSpec:
@@ -109,7 +115,8 @@ def _to_element_spec(
)
return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape))

element_spec = tf.nest.map_structure(_to_element_spec, batch)
# element_spec = tf.nest.map_structure(_to_element_spec, batch)
element_spec = jax.tree.map(_to_element_spec, batch)
self._element_spec = element_spec
return element_spec

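Note: a minimal standalone sketch (not part of this PR) of the `jax.tree.map` conversion pattern adopted above, assuming `tf.Tensor` values are not registered as JAX pytree nodes and are therefore treated as leaves. The helper name mirrors the PR's code but is illustrative:

    import jax
    import numpy as np
    import tensorflow as tf

    def _maybe_to_numpy(x):
      # Convert eager tf.Tensor leaves to read-only NumPy arrays; pass others through.
      if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)):
        return x
      if hasattr(x, "numpy"):
        out = x.numpy()
        if isinstance(out, np.ndarray):
          out.setflags(write=False)
        return out
      return x

    batch = {"dense": tf.constant([[1.0, 2.0]]), "label": tf.constant([1])}
    numpy_batch = jax.tree.map(_maybe_to_numpy, batch)  # dict of np.ndarray leaves

One behavioral difference worth noting: `jax.tree.map` only recurses into registered pytree containers (dicts, lists, tuples, and similar), so a custom batch container would need to be registered as a pytree node, whereas `tf.nest.map_structure` uses TensorFlow's own notion of nested structure.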
2 changes: 1 addition & 1 deletion recml/core/ops/embedding_ops.py
@@ -38,7 +38,7 @@ class SparsecoreParams:
"""Embedding parameters."""

feature_specs: Nested[FeatureSpec]
mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh
mesh: jax.sharding.Mesh
data_axes: Sequence[str | None]
embedding_axes: Sequence[str | None]
sharding_strategy: str
10 changes: 10 additions & 0 deletions recml/core/training/mesh_context.py
@@ -0,0 +1,10 @@
"""Simple global context for mesh, replacing jax.experimental.maps."""

_GLOBAL_MESH = None

def set_global_mesh(mesh):
global _GLOBAL_MESH
_GLOBAL_MESH = mesh

def get_global_mesh():
return _GLOBAL_MESH
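A hypothetical usage sketch of this bridge (the call sites below are illustrative, not taken from the repo):

    import jax
    from recml.core.training import mesh_context

    mesh = jax.sharding.Mesh(jax.devices(), ("batch",))
    mesh_context.set_global_mesh(mesh)

    # Code that previously looked up the ambient mesh via jax.experimental.maps
    # can query the bridge instead.
    assert mesh_context.get_global_mesh() is mesh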
46 changes: 29 additions & 17 deletions recml/core/training/partitioning.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for partitioning."""

import abc
@@ -21,6 +22,8 @@
import flax.linen as nn
import jax
import numpy as np
# FIXED: Import the local mesh_context bridge that replaces jax.experimental.maps
from recml.core.training import mesh_context


PyTree = Any
@@ -67,7 +70,8 @@ class DataParallelPartitioner(Partitioner):
"""Data parallel partitioner."""

def __init__(self, data_axis: str = "batch"):
self.mesh = jax.make_mesh((jax.device_count(),), (data_axis,))
devices = jax.devices()
self.mesh = jax.sharding.Mesh(devices, (data_axis,))
self.data_sharding = jax.sharding.NamedSharding(
self.mesh, jax.sharding.PartitionSpec(data_axis)
)
@@ -107,8 +111,10 @@ def _shard(x: np.ndarray) -> jax.Array:
def partition_init(
self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
) -> CreateStateFn:
with jax.sharding.use_mesh(self.mesh):
# FIXED: Use 'with self.mesh'
with self.mesh:
if abstract_batch is not None:
mesh_context.set_global_mesh(self.mesh)
abstract_state = jax.eval_shape(init_fn, abstract_batch)
specs = nn.get_partition_spec(abstract_state)
self.state_sharding = jax.tree.map(
@@ -117,7 +123,8 @@ def partition_init(
init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)

def _wrapped_init(batch: PyTree) -> State:
with jax.sharding.use_mesh(self.mesh):
# FIXED: Use 'with self.mesh'
with self.mesh:
state = init_fn(batch)
state = _maybe_unbox_state(state)
return state
@@ -130,15 +137,18 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
jit_kws["out_shardings"] = (self.state_sharding, None)
jit_kws["donate_argnums"] = (1,)

with jax.sharding.use_mesh(self.mesh):
# FIXED: Use 'with self.mesh' and legacy bridge
with self.mesh:
mesh_context.set_global_mesh(self.mesh)
step_fn = jax.jit(
fn,
in_shardings=(self.data_sharding, self.state_sharding),
**jit_kws,
)

def _wrapped_step(batch: PyTree, state: State) -> Any:
with jax.sharding.use_mesh(self.mesh):
# FIXED: Use 'with self.mesh'
with self.mesh:
return step_fn(batch, state)

return _wrapped_step
@@ -190,7 +200,8 @@ def __init__(
if axis_sizes[0] == -1:
axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:])

self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
# self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices)
self.mesh = jax.sharding.Mesh(np.asarray(devices).reshape(axis_sizes), axis_names)
self.rules = rules
self.aot_compile = aot_compile
self.options = options
@@ -213,12 +224,6 @@ def __init__(
self.abstract_batch = None
self.abstract_state = None

@property
def mesh_context_manager(
self,
) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
return jax.sharding.use_mesh

def shard_inputs(self, inputs: PyTree) -> PyTree:
def _shard(x: np.ndarray) -> jax.Array:
return jax.make_array_from_process_local_data(self.data_sharding, x)
@@ -234,7 +239,10 @@ def partition_init(
" model parallel partitioner."
)

with self.mesh_context_manager(self.mesh):
# FIXED: Use 'with self.mesh' directly
with self.mesh:
# FIXED: Legacy bridge
mesh_context.set_global_mesh(self.mesh)
abstract_state = jax.eval_shape(init_fn, abstract_batch)
specs = nn.get_partition_spec(abstract_state)

@@ -247,7 +255,8 @@
compiled_init_fn = jax.jit(init_fn, out_shardings=state_sharding)

def _init(batch: PyTree) -> State:
with self.mesh_context_manager(self.mesh):
# FIXED: Use 'with self.mesh' directly
with self.mesh:
state = compiled_init_fn(batch)
state = _maybe_unbox_state(state)
return state
@@ -265,7 +274,9 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
else:
jit_kws["out_shardings"] = None

with self.mesh_context_manager(self.mesh):
# FIXED: Use 'with self.mesh' directly and legacy bridge
with self.mesh:
mesh_context.set_global_mesh(self.mesh)
step_fn = jax.jit(
fn,
in_shardings=(self.data_sharding, self.state_sharding),
@@ -286,7 +297,8 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
)

def _step(batch: PyTree, state: State) -> Any:
with self.mesh_context_manager(self.mesh):
# FIXED: Use 'with self.mesh' directly
with self.mesh:
return step_fn(batch, state)

return _step
@@ -302,4 +314,4 @@ def _maybe_unbox(x: Any) -> Any:
_maybe_unbox,
x,
is_leaf=lambda k: isinstance(k, nn.Partitioned),
)
)
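For reference, a sketch of the pattern this file standardizes on: build the `Mesh` directly from the device list and enter it as a context manager around `jax.jit`, in place of `jax.sharding.use_mesh`. The axis names and sizes below are assumptions for illustration only:

    import jax
    import numpy as np

    axis_names = ("data", "model")
    axis_sizes = (jax.device_count(), 1)
    # Mesh expects the device array's rank to match the number of axis names,
    # hence the reshape of the flat device list (see the ModelParallelPartitioner fix above).
    devices = np.asarray(jax.devices()).reshape(axis_sizes)
    mesh = jax.sharding.Mesh(devices, axis_names)
    data_sharding = jax.sharding.NamedSharding(
        mesh, jax.sharding.PartitionSpec("data")
    )

    # Entering the Mesh makes it the active mesh while tracing/compiling,
    # which is the role jax.sharding.use_mesh(...) played before.
    with mesh:
      step = jax.jit(lambda x: x * 2, in_shardings=data_sharding)
      out = step(np.zeros((jax.device_count(), 4), dtype=np.float32))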
6 changes: 6 additions & 0 deletions recml/examples/dlrm_experiment.py
@@ -19,6 +19,12 @@
import dataclasses
from typing import Generic, Literal, TypeVar

import sys
import os
# Add the RecML folder to the system path
sys.path.append(os.path.join(os.getcwd(), "../../../RecML"))
os.environ["KERAS_BACKEND"] = "jax"

from etils import epy
import fiddle as fdl
import flax.linen as nn
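One caveat on the environment setup added above: Keras 3 reads `KERAS_BACKEND` when `keras` is first imported, so the assignment has to run before any import that pulls in Keras. A minimal standalone illustration (not from the repo):

    import os

    os.environ["KERAS_BACKEND"] = "jax"  # must be set before keras is imported

    import keras  # now backed by JAX
    assert keras.backend.backend() == "jax"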
10 changes: 8 additions & 2 deletions recml/examples/dlrm_experiment_test.py
@@ -13,6 +13,12 @@
# limitations under the License.
"""Tests for the DLRM experiment."""

import sys
import os
# Add the RecML folder to the system path
sys.path.append(os.path.join(os.getcwd(), "../../../RecML"))
os.environ["KERAS_BACKEND"] = "jax"

from absl.testing import absltest
import fiddle as fdl
from fiddle import selectors
@@ -32,8 +38,8 @@ def test_dlrm_experiment(self):

experiment = dlrm_experiment.experiment()

experiment.task.train_data.global_batch_size = 4
experiment.task.eval_data.global_batch_size = 4
experiment.task.train_data.global_batch_size = 128
experiment.task.eval_data.global_batch_size = 128
experiment.trainer.train_steps = 12
experiment.trainer.steps_per_loop = 4
experiment.trainer.steps_per_eval = 4