diff --git a/recml/core/data/iterator.py b/recml/core/data/iterator.py index f86c922..0b97066 100644 --- a/recml/core/data/iterator.py +++ b/recml/core/data/iterator.py @@ -21,6 +21,7 @@ from etils import epath import numpy as np import tensorflow as tf +import jax Iterator = clu_data.DatasetIterator @@ -67,17 +68,22 @@ def _maybe_to_numpy( ) -> np.ndarray | tf.SparseTensor | tf.RaggedTensor: if isinstance(x, (tf.SparseTensor, tf.RaggedTensor, np.ndarray)): return x + # FIX: Check for attribute existence to avoid crashes on non-Tensor objects if hasattr(x, "_numpy"): numpy = x._numpy() # pylint: disable=protected-access - else: + elif hasattr(x, "numpy"): numpy = x.numpy() + else: + return x # Return as-is if it can't be converted + if isinstance(numpy, np.ndarray): # `numpy` shares the same underlying buffer as the `x` Tensor. # Tensors are expected to be immutable, so we disable writes. numpy.setflags(write=False) return numpy - return tf.nest.map_structure(_maybe_to_numpy, batch) + # FIX: Use jax.tree.map instead of tf.nest.map_structure + return jax.tree.map(_maybe_to_numpy, batch) @property def element_spec(self) -> clu_data.ElementSpec: @@ -109,7 +115,8 @@ def _to_element_spec( ) return clu_data.ArraySpec(dtype=x.dtype, shape=tuple(x.shape)) - element_spec = tf.nest.map_structure(_to_element_spec, batch) + # element_spec = tf.nest.map_structure(_to_element_spec, batch) + element_spec = jax.tree.map(_to_element_spec, batch) self._element_spec = element_spec return element_spec diff --git a/recml/core/ops/embedding_ops.py b/recml/core/ops/embedding_ops.py index a1de4f0..f9e17bc 100644 --- a/recml/core/ops/embedding_ops.py +++ b/recml/core/ops/embedding_ops.py @@ -38,7 +38,7 @@ class SparsecoreParams: """Embedding parameters.""" feature_specs: Nested[FeatureSpec] - mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh + mesh: jax.sharding.Mesh data_axes: Sequence[str | None] embedding_axes: Sequence[str | None] sharding_strategy: str diff --git a/recml/core/training/mesh_context.py b/recml/core/training/mesh_context.py new file mode 100644 index 0000000..b5491c5 --- /dev/null +++ b/recml/core/training/mesh_context.py @@ -0,0 +1,10 @@ +"""Simple global context for mesh, replacing jax.experimental.maps.""" + +_GLOBAL_MESH = None + +def set_global_mesh(mesh): + global _GLOBAL_MESH + _GLOBAL_MESH = mesh + +def get_global_mesh(): + return _GLOBAL_MESH diff --git a/recml/core/training/partitioning.py b/recml/core/training/partitioning.py index 4dc3b76..30ff444 100644 --- a/recml/core/training/partitioning.py +++ b/recml/core/training/partitioning.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """Utilities for partitioning.""" import abc @@ -21,6 +22,8 @@ import flax.linen as nn import jax import numpy as np +# FIXED: Use the public experimental module available in JAX 0.4.30 +from recml.core.training import mesh_context PyTree = Any @@ -67,7 +70,8 @@ class DataParallelPartitioner(Partitioner): """Data parallel partitioner.""" def __init__(self, data_axis: str = "batch"): - self.mesh = jax.make_mesh((jax.device_count(),), (data_axis,)) + devices = jax.devices() + self.mesh = jax.sharding.Mesh(devices, (data_axis,)) self.data_sharding = jax.sharding.NamedSharding( self.mesh, jax.sharding.PartitionSpec(data_axis) ) @@ -107,8 +111,10 @@ def _shard(x: np.ndarray) -> jax.Array: def partition_init( self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None ) -> CreateStateFn: - with jax.sharding.use_mesh(self.mesh): + # FIXED: Use 'with self.mesh' + with self.mesh: if abstract_batch is not None: + mesh_context.set_global_mesh(self.mesh) abstract_state = jax.eval_shape(init_fn, abstract_batch) specs = nn.get_partition_spec(abstract_state) self.state_sharding = jax.tree.map( @@ -117,7 +123,8 @@ def partition_init( init_fn = jax.jit(init_fn, out_shardings=self.state_sharding) def _wrapped_init(batch: PyTree) -> State: - with jax.sharding.use_mesh(self.mesh): + # FIXED: Use 'with self.mesh' + with self.mesh: state = init_fn(batch) state = _maybe_unbox_state(state) return state @@ -130,7 +137,9 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: jit_kws["out_shardings"] = (self.state_sharding, None) jit_kws["donate_argnums"] = (1,) - with jax.sharding.use_mesh(self.mesh): + # FIXED: Use 'with self.mesh' and legacy bridge + with self.mesh: + mesh_context.set_global_mesh(self.mesh) step_fn = jax.jit( fn, in_shardings=(self.data_sharding, self.state_sharding), @@ -138,7 +147,8 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: ) def _wrapped_step(batch: PyTree, state: State) -> Any: - with jax.sharding.use_mesh(self.mesh): + # FIXED: Use 'with self.mesh' + with self.mesh: return step_fn(batch, state) return _wrapped_step @@ -190,7 +200,8 @@ def __init__( if axis_sizes[0] == -1: axis_sizes[0] = len(devices) // math.prod(axis_sizes[1:]) - self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices) + # self.mesh = jax.make_mesh(axis_sizes, axis_names, devices=devices) + self.mesh = jax.sharding.Mesh(devices, axis_names) self.rules = rules self.aot_compile = aot_compile self.options = options @@ -213,12 +224,6 @@ def __init__( self.abstract_batch = None self.abstract_state = None - @property - def mesh_context_manager( - self, - ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]: - return jax.sharding.use_mesh - def shard_inputs(self, inputs: PyTree) -> PyTree: def _shard(x: np.ndarray) -> jax.Array: return jax.make_array_from_process_local_data(self.data_sharding, x) @@ -234,7 +239,10 @@ def partition_init( " model parallel partitioner." ) - with self.mesh_context_manager(self.mesh): + # FIXED: Use 'with self.mesh' directly + with self.mesh: + # FIXED: Legacy bridge + mesh_context.set_global_mesh(self.mesh) abstract_state = jax.eval_shape(init_fn, abstract_batch) specs = nn.get_partition_spec(abstract_state) @@ -247,7 +255,8 @@ def partition_init( compiled_init_fn = jax.jit(init_fn, out_shardings=state_sharding) def _init(batch: PyTree) -> State: - with self.mesh_context_manager(self.mesh): + # FIXED: Use 'with self.mesh' directly + with self.mesh: state = compiled_init_fn(batch) state = _maybe_unbox_state(state) return state @@ -265,7 +274,9 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: else: jit_kws["out_shardings"] = None - with self.mesh_context_manager(self.mesh): + # FIXED: Use 'with self.mesh' directly and legacy bridge + with self.mesh: + mesh_context.set_global_mesh(self.mesh) step_fn = jax.jit( fn, in_shardings=(self.data_sharding, self.state_sharding), @@ -286,7 +297,8 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: ) def _step(batch: PyTree, state: State) -> Any: - with self.mesh_context_manager(self.mesh): + # FIXED: Use 'with self.mesh' directly + with self.mesh: return step_fn(batch, state) return _step @@ -302,4 +314,4 @@ def _maybe_unbox(x: Any) -> Any: _maybe_unbox, x, is_leaf=lambda k: isinstance(k, nn.Partitioned), - ) + ) \ No newline at end of file diff --git a/recml/examples/dlrm_experiment.py b/recml/examples/dlrm_experiment.py index 36da20f..eeda133 100644 --- a/recml/examples/dlrm_experiment.py +++ b/recml/examples/dlrm_experiment.py @@ -19,6 +19,12 @@ import dataclasses from typing import Generic, Literal, TypeVar +import sys +import os +# Add the RecML folder to the system path +sys.path.append(os.path.join(os.getcwd(), "../../../RecML")) +os.environ["KERAS_BACKEND"] = "jax" + from etils import epy import fiddle as fdl import flax.linen as nn diff --git a/recml/examples/dlrm_experiment_test.py b/recml/examples/dlrm_experiment_test.py index d4b44c0..07902a4 100644 --- a/recml/examples/dlrm_experiment_test.py +++ b/recml/examples/dlrm_experiment_test.py @@ -13,6 +13,12 @@ # limitations under the License. """Tests for the DLRM experiment.""" +import sys +import os +# Add the RecML folder to the system path +sys.path.append(os.path.join(os.getcwd(), "../../../RecML")) +os.environ["KERAS_BACKEND"] = "jax" + from absl.testing import absltest import fiddle as fdl from fiddle import selectors @@ -32,8 +38,8 @@ def test_dlrm_experiment(self): experiment = dlrm_experiment.experiment() - experiment.task.train_data.global_batch_size = 4 - experiment.task.eval_data.global_batch_size = 4 + experiment.task.train_data.global_batch_size = 128 + experiment.task.eval_data.global_batch_size = 128 experiment.trainer.train_steps = 12 experiment.trainer.steps_per_loop = 4 experiment.trainer.steps_per_eval = 4 diff --git a/recml/examples/train_hstu_jax.py b/recml/examples/train_hstu_jax.py new file mode 100644 index 0000000..3a416f9 --- /dev/null +++ b/recml/examples/train_hstu_jax.py @@ -0,0 +1,271 @@ +"""HSTU Experiment Configuration using Fiddle and RecML with JaxTrainer""" + +import dataclasses +from typing import Mapping, Tuple +import sys +import os + +os.environ["KERAS_BACKEND"] = "jax" + +import fiddle as fdl +import jax +import jax.numpy as jnp +import keras +import optax +import tensorflow as tf +import clu.metrics as clu_metrics +from absl import app +from absl import flags +from absl import logging + +# Add the RecML folder to the system path +sys.path.append(os.path.join(os.getcwd(), "../../../RecML")) + +# RecML Imports +from recml.core.training import core +from recml.core.training import jax_trainer +from recml.core.training import partitioning +from recml.layers.keras import hstu +import recml + +# Define command-line flags +FLAGS = flags.FLAGS + +flags.DEFINE_string("train_path", None, "Path (or pattern) to training data") +flags.DEFINE_string("eval_path", None, "Path (or glob pattern) to evaluation data") + +flags.DEFINE_string("model_dir", "/tmp/hstu_model_jax", "Where to save the model") +flags.DEFINE_integer("vocab_size", 5_000_000, "Vocabulary size") +flags.DEFINE_integer("train_steps", 2000, "Total training steps") + +# Mark flags as required +flags.mark_flag_as_required("train_path") +flags.mark_flag_as_required("eval_path") + +@dataclasses.dataclass +class HSTUModelConfig: + """Configuration for the HSTU model architecture""" + vocab_size: int = 5_000_000 + max_sequence_length: int = 50 + model_dim: int = 64 + num_heads: int = 4 + num_layers: int = 4 + dropout: float = 0.5 + learning_rate: float = 1e-3 + +class TFRecordDataFactory(recml.Factory[tf.data.Dataset]): + """Reusable Data Factory for TFRecord datasets""" + + path: str + batch_size: int + max_sequence_length: int + feature_key: str = "sequence" + target_key: str = "target" + is_training: bool = True + + def make(self) -> tf.data.Dataset: + """Builds the tf.data.Dataset""" + if not self.path: + logging.warning("No path provided for dataset factory") + return tf.data.Dataset.empty() + + dataset = tf.data.Dataset.list_files(self.path) + dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=tf.data.AUTOTUNE) + + def _parse_fn(serialized_example): + features = { + self.feature_key: tf.io.VarLenFeature(tf.int64), + self.target_key: tf.io.FixedLenFeature([1], tf.int64), + } + parsed = tf.io.parse_single_example(serialized_example, features) + + seq = tf.sparse.to_dense(parsed[self.feature_key]) + padding_needed = self.max_sequence_length - tf.shape(seq)[0] + seq = tf.pad(seq, [[0, padding_needed]]) + seq = tf.ensure_shape(seq, [self.max_sequence_length]) + seq = tf.cast(seq, tf.int32) + + target = tf.cast(parsed[self.target_key], tf.int32) + return seq, target + + dataset = dataset.map(_parse_fn, num_parallel_calls=tf.data.AUTOTUNE) + + if self.is_training: + dataset = dataset.repeat() + + return dataset.batch(self.batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE) + +class HSTUTask(jax_trainer.JaxTask): + """JaxTask for HSTU model""" + + def __init__( + self, + model_config: HSTUModelConfig, + train_data_factory: recml.Factory[tf.data.Dataset], + eval_data_factory: recml.Factory[tf.data.Dataset], + ): + self.config = model_config + self.train_data_factory = train_data_factory + self.eval_data_factory = eval_data_factory + + def create_datasets(self) -> Tuple[tf.data.Dataset, tf.data.Dataset]: + return self.train_data_factory.make(), self.eval_data_factory.make() + + def _create_model(self) -> keras.Model: + inputs = keras.Input( + shape=(self.config.max_sequence_length,), dtype="int32", name="input_ids" + ) + padding_mask = keras.ops.cast(keras.ops.not_equal(inputs, 0), "int32") + + hstu_layer = hstu.HSTU( + vocab_size=self.config.vocab_size, + max_positions=self.config.max_sequence_length, + model_dim=self.config.model_dim, + num_heads=self.config.num_heads, + num_layers=self.config.num_layers, + dropout=self.config.dropout, + ) + + logits = hstu_layer(inputs, padding_mask=padding_mask) + + def get_last_token_logits(args): + seq_logits, mask = args + lengths = keras.ops.sum(keras.ops.cast(mask, "int32"), axis=1) + last_indices = lengths - 1 + indices = keras.ops.expand_dims(keras.ops.expand_dims(last_indices, -1), -1) + return keras.ops.squeeze(keras.ops.take_along_axis(seq_logits, indices, axis=1), axis=1) + + output_logits = keras.layers.Lambda(get_last_token_logits)([logits, padding_mask]) + output_logits = keras.layers.Activation("linear", dtype="float32")(output_logits) + + model = keras.Model(inputs=inputs, outputs=output_logits) + return model + + def create_state(self, batch, rng) -> jax_trainer.KerasState: + inputs, _ = batch + model = self._create_model() + # Build the model to initialize variables + model.build(inputs.shape) + + optimizer = optax.adam(learning_rate=self.config.learning_rate) + return jax_trainer.KerasState.create(model=model, tx=optimizer) + + def train_step( + self, batch, state: jax_trainer.KerasState, rng: jax.Array + ) -> Tuple[jax_trainer.KerasState, Mapping[str, clu_metrics.Metric]]: + inputs, targets = batch + + def loss_fn(tvars): + logits, _ = state.model.stateless_call(tvars, state.ntvars, inputs) + loss = optax.softmax_cross_entropy_with_integer_labels( + logits, jnp.squeeze(targets) + ) + return jnp.mean(loss), logits + + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + (loss, logits), grads = grad_fn(state.tvars) + state = state.update(grads=grads) + + metrics = self._compute_metrics(loss, logits, targets) + return state, metrics + + def eval_step( + self, batch, state: jax_trainer.KerasState + ) -> Mapping[str, clu_metrics.Metric]: + inputs, targets = batch + logits, _ = state.model.stateless_call(state.tvars, state.ntvars, inputs) + loss = optax.softmax_cross_entropy_with_integer_labels( + logits, jnp.squeeze(targets) + ) + loss = jnp.mean(loss) + return self._compute_metrics(loss, logits, targets) + + def _compute_metrics(self, loss, logits, targets): + targets = jnp.squeeze(targets) + metrics = {"loss": clu_metrics.Average.from_model_output(loss)} + + # def get_acc(k): + # _, top_k_indices = jax.nn.top_k(logits, k) + # correct = jnp.sum(top_k_indices == targets[:, None], axis=-1) + # return jnp.mean(correct) + + # metrics["HR_10"] = clu_metrics.Average.from_model_output(get_acc(10)) + # metrics["HR_50"] = clu_metrics.Average.from_model_output(get_acc(50)) + # metrics["HR_200"] = clu_metrics.Average.from_model_output(get_acc(200)) + return metrics + +def experiment() -> fdl.Config[recml.Experiment]: + """Defines the experiment structure using Fiddle configs""" + + max_seq_len = 50 + batch_size = 128 + + model_cfg = fdl.Config( + HSTUModelConfig, + vocab_size=5_000_000, + max_sequence_length=max_seq_len, + model_dim=64, + num_layers=4, + dropout=0.5 + ) + + train_data = fdl.Config( + TFRecordDataFactory, + path="", # Placeholder + batch_size=batch_size, + max_sequence_length=max_seq_len, + is_training=True + ) + + eval_data = fdl.Config( + TFRecordDataFactory, + path="", # Placeholder + batch_size=batch_size, + max_sequence_length=max_seq_len, + is_training=False + ) + + task = fdl.Config( + HSTUTask, + model_config=model_cfg, + train_data_factory=train_data, + eval_data_factory=eval_data + ) + + trainer = fdl.Config( + jax_trainer.JaxTrainer, + partitioner=fdl.Config(partitioning.DataParallelPartitioner), + model_dir="/tmp/default_dir", # Placeholder + train_steps=2000, + steps_per_eval=10, + steps_per_loop=10, + ) + + return fdl.Config(recml.Experiment, task=task, trainer=trainer) + +def main(_): + # Ensure JAX uses the correct backend + logging.info(f"JAX Backend: {jax.default_backend()}") + + config = experiment() + + logging.info(f"Setting Train Path to: {FLAGS.train_path}") + config.task.train_data_factory.path = FLAGS.train_path + + logging.info(f"Setting Eval Path to: {FLAGS.eval_path}") + config.task.eval_data_factory.path = FLAGS.eval_path + + config.task.model_config.vocab_size = FLAGS.vocab_size + + logging.info(f"Setting Model Dir to: {FLAGS.model_dir}") + config.trainer.model_dir = FLAGS.model_dir + config.trainer.train_steps = FLAGS.train_steps + + expt = fdl.build(config) + + logging.info("Starting experiment execution...") + core.run_experiment(expt, core.Experiment.Mode.TRAIN_AND_EVAL) + + +if __name__ == "__main__": + app.run(main) \ No newline at end of file diff --git a/recml/examples/train_hstu_keras.py b/recml/examples/train_hstu_keras.py new file mode 100644 index 0000000..ee5aa40 --- /dev/null +++ b/recml/examples/train_hstu_keras.py @@ -0,0 +1,223 @@ +"""HSTU Experiment Configuration using Fiddle and RecML with KerasTrainer""" + +import dataclasses +from typing import Optional +import sys +import os + +import fiddle as fdl +import keras +import tensorflow as tf +from absl import app +from absl import flags +from absl import logging + +# Add the RecML folder to the system path +sys.path.append(os.path.join(os.getcwd(), "../../../RecML")) + +# RecML Imports +from recml.core.training import core +from recml.core.training import keras_trainer +from recml.layers.keras import hstu +import recml +import jax +print(jax.devices()) + +# Define command-line flags +FLAGS = flags.FLAGS + +flags.DEFINE_string("train_path", None, "Path (or pattern) to training data") +flags.DEFINE_string("eval_path", None, "Path (or glob pattern) to evaluation data") + +flags.DEFINE_string("model_dir", "/tmp/hstu_model", "Where to save the model") +flags.DEFINE_integer("vocab_size", 5_000_000, "Vocabulary size") +flags.DEFINE_integer("train_steps", 2000, "Total training steps") + +# Mark flags as required +flags.mark_flag_as_required("train_path") +flags.mark_flag_as_required("eval_path") + +@dataclasses.dataclass +class HSTUModelConfig: + """Configuration for the HSTU model architecture""" + vocab_size: int = 5_000_000 + max_sequence_length: int = 50 + model_dim: int = 64 + num_heads: int = 4 + num_layers: int = 4 + dropout: float = 0.5 + learning_rate: float = 1e-3 + +class TFRecordDataFactory(recml.Factory[tf.data.Dataset]): + """Reusable Data Factory for TFRecord datasets""" + + path: str + batch_size: int + max_sequence_length: int + feature_key: str = "sequence" + target_key: str = "target" + is_training: bool = True + + def make(self) -> tf.data.Dataset: + """Builds the tf.data.Dataset""" + if not self.path: + logging.warning("No path provided for dataset factory") + return tf.data.Dataset.empty() + + dataset = tf.data.Dataset.list_files(self.path) + dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=tf.data.AUTOTUNE) + + def _parse_fn(serialized_example): + features = { + self.feature_key: tf.io.VarLenFeature(tf.int64), + self.target_key: tf.io.FixedLenFeature([1], tf.int64), + } + parsed = tf.io.parse_single_example(serialized_example, features) + + seq = tf.sparse.to_dense(parsed[self.feature_key]) + padding_needed = self.max_sequence_length - tf.shape(seq)[0] + seq = tf.pad(seq, [[0, padding_needed]]) + seq = tf.ensure_shape(seq, [self.max_sequence_length]) + seq = tf.cast(seq, tf.int32) + + target = tf.cast(parsed[self.target_key], tf.int32) + return seq, target + + dataset = dataset.map(_parse_fn, num_parallel_calls=tf.data.AUTOTUNE) + + if self.is_training: + dataset = dataset.repeat() + + return dataset.batch(self.batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE) + +class HSTUTask(keras_trainer.KerasTask): + """KerasTask that receives its dependencies via injection""" + + def __init__( + self, + model_config: HSTUModelConfig, + train_data_factory: recml.Factory[tf.data.Dataset], + eval_data_factory: recml.Factory[tf.data.Dataset], + ): + self.config = model_config + self.train_data_factory = train_data_factory + self.eval_data_factory = eval_data_factory + + def create_dataset(self, training: bool) -> tf.data.Dataset: + if training: + return self.train_data_factory.make() + return self.eval_data_factory.make() + + def create_model(self) -> keras.Model: + inputs = keras.Input( + shape=(self.config.max_sequence_length,), dtype="int32", name="input_ids" + ) + padding_mask = keras.ops.cast(keras.ops.not_equal(inputs, 0), "int32") + + hstu_layer = hstu.HSTU( + vocab_size=self.config.vocab_size, + max_positions=self.config.max_sequence_length, + model_dim=self.config.model_dim, + num_heads=self.config.num_heads, + num_layers=self.config.num_layers, + dropout=self.config.dropout, + ) + + logits = hstu_layer(inputs, padding_mask=padding_mask) + + def get_last_token_logits(args): + seq_logits, mask = args + lengths = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1) + last_indices = lengths - 1 + return tf.gather(seq_logits, last_indices, batch_dims=1) + + output_logits = keras.layers.Lambda(get_last_token_logits)([logits, padding_mask]) + output_logits = keras.layers.Activation("linear", dtype="float32")(output_logits) + + model = keras.Model(inputs=inputs, outputs=output_logits) + + model.compile( + optimizer=keras.optimizers.Adam(learning_rate=self.config.learning_rate), + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[ + keras.metrics.SparseTopKCategoricalAccuracy(k=10, name="HR_10"), + keras.metrics.SparseTopKCategoricalAccuracy(k=50, name="HR_50"), + keras.metrics.SparseTopKCategoricalAccuracy(k=200, name="HR_200"), + ], + ) + return model + +def experiment() -> fdl.Config[recml.Experiment]: + """Defines the experiment structure using Fiddle configs""" + + max_seq_len = 50 + batch_size = 128 + + model_cfg = fdl.Config( + HSTUModelConfig, + vocab_size=5_000_000, + max_sequence_length=max_seq_len, + model_dim=64, + num_layers=4, + dropout=0.5 + ) + + train_data = fdl.Config( + TFRecordDataFactory, + path="", # Placeholder + batch_size=batch_size, + max_sequence_length=max_seq_len, + is_training=True + ) + + eval_data = fdl.Config( + TFRecordDataFactory, + path="", # Placeholder + batch_size=batch_size, + max_sequence_length=max_seq_len, + is_training=False + ) + + task = fdl.Config( + HSTUTask, + model_config=model_cfg, + train_data_factory=train_data, + eval_data_factory=eval_data + ) + + trainer = fdl.Config( + keras_trainer.KerasTrainer, + model_dir="/tmp/default_dir", # Placeholder + train_steps=2000, + steps_per_eval=10, + steps_per_loop=10, + ) + + return fdl.Config(recml.Experiment, task=task, trainer=trainer) + +def main(_): + keras.mixed_precision.set_global_policy("mixed_bfloat16") + logging.info("Mixed precision policy set to mixed_bfloat16") + + config = experiment() + + logging.info(f"Setting Train Path to: {FLAGS.train_path}") + config.task.train_data_factory.path = FLAGS.train_path + + logging.info(f"Setting Eval Path to: {FLAGS.eval_path}") + config.task.eval_data_factory.path = FLAGS.eval_path + + config.task.model_config.vocab_size = FLAGS.vocab_size + + logging.info(f"Setting Model Dir to: {FLAGS.model_dir}") + config.trainer.model_dir = FLAGS.model_dir + config.trainer.train_steps = FLAGS.train_steps + + expt = fdl.build(config) + + logging.info("Starting experiment execution...") + core.run_experiment(expt, core.Experiment.Mode.TRAIN_AND_EVAL) + + +if __name__ == "__main__": + app.run(main) \ No newline at end of file diff --git a/recml/layers/linen/sparsecore.py b/recml/layers/linen/sparsecore.py index a908ab8..d48be9c 100644 --- a/recml/layers/linen/sparsecore.py +++ b/recml/layers/linen/sparsecore.py @@ -28,10 +28,12 @@ from recml.core.ops import embedding_ops import tensorflow as tf +from recml.core.training import mesh_context + with epy.lazy_imports(): # pylint: disable=g-import-not-at-top - from jax_tpu_embedding.sparsecore.lib.flax import embed + from jax_tpu_embedding.sparsecore.lib.flax.linen import embed from jax_tpu_embedding.sparsecore.lib.nn import embedding from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec from jax_tpu_embedding.sparsecore.lib.nn import table_stacking @@ -369,16 +371,28 @@ class SparsecoreEmbed(nn.Module): sparsecore_config: SparsecoreConfig mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh | None = None - def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh: - if self.mesh is not None: - return self.mesh - abstract_mesh = jax.sharding.get_abstract_mesh() - if not abstract_mesh.shape_tuple: + # def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh: + # if self.mesh is not None: + # return self.mesh + # abstract_mesh = jax.sharding.get_abstract_mesh() + # if not abstract_mesh.shape_tuple: + # raise ValueError( + # 'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make' + # ' sure to set the mesh when calling the sparsecore module.' + # ) + # return abstract_mesh + + def get_mesh(self) -> jax.sharding.Mesh: + # Try to get the mesh from our custom global context + mesh = mesh_context.get_global_mesh() + + if mesh is None: raise ValueError( - 'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make' - ' sure to set the mesh when calling the sparsecore module.' + "No global mesh found. Make sure to call " + "`partitioning.partition_init` (which sets the mesh) " + "before initializing SparseCore." ) - return abstract_mesh + return mesh def get_sharding_axis( self, mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh diff --git a/requirements.txt b/requirements.txt index 580d6c9..998ee15 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ absl-py==2.2.2 +aiofiles==25.1.0 +array-record==0.8.3 astroid==3.3.9 astunparse==1.6.3 attrs==25.3.0 @@ -16,7 +18,7 @@ etils==1.12.2 fiddle==0.3.0 filelock==3.18.0 flatbuffers==25.2.10 -flax==0.10.5 +flax==0.12.2 fsspec==2025.3.2 gast==0.6.0 google-pasta==0.2.0 @@ -31,18 +33,22 @@ immutabledict==4.2.1 importlib-resources==6.5.2 iniconfig==2.1.0 isort==6.0.1 -jax==0.6.0 -jaxlib==0.6.0 +jax==0.8.2 +jax-tpu-embedding==0.1.0.dev20251208 +jaxlib==0.8.2 jaxtyping==0.3.1 -jinja2==3.1.6 +Jinja2==3.1.6 kagglehub==0.3.11 keras==3.9.2 keras-hub==0.20.0 libclang==18.1.1 libcst==1.7.0 -markdown==3.8 +libtpu==0.0.32 +# libtpu-nightly is usually installed directly via URL, but pinning it helps tracking +# libtpu-nightly==0.1.dev20240617+default +Markdown==3.8 markdown-it-py==3.0.0 -markupsafe==3.0.2 +MarkupSafe==3.0.2 mccabe==0.7.0 mdurl==0.1.2 ml-collections==1.1.0 @@ -54,23 +60,37 @@ nest-asyncio==1.6.0 networkx==3.4.2 nodeenv==1.9.1 numpy==2.1.3 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.6.2 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 opt-einsum==3.4.0 optax==0.2.4 optree==0.15.0 -orbax-checkpoint==0.11.12 +orbax-checkpoint==0.11.31 packaging==24.2 platformdirs==4.3.7 pluggy==1.5.0 +portpicker==1.6.0 pre-commit==4.2.0 promise==2.3 -protobuf==5.29.4 +# protobuf==6.33.4 psutil==7.0.0 pyarrow==19.0.1 -pygments==2.19.1 +Pygments==2.19.1 pylint==3.3.6 pytest==8.3.5 pytest-env==1.1.5 -pyyaml==6.0.2 +PyYAML==6.0.2 regex==2024.11.6 requests==2.32.3 rich==14.0.0 @@ -84,9 +104,10 @@ tensorboard==2.19.0 tensorboard-data-server==0.7.2 tensorflow==2.19.0 tensorflow-datasets==4.9.8 +tensorflow-io-gcs-filesystem==0.37.1 tensorflow-metadata==1.17.1 tensorflow-text==2.19.0 -tensorstore==0.1.73 +tensorstore==0.1.80 termcolor==3.0.1 toml==0.10.2 tomlkit==0.13.2 @@ -94,11 +115,12 @@ toolz==1.0.0 torch==2.6.0 tqdm==4.67.1 treescope==0.1.9 -typing-extensions==4.13.2 +triton==3.2.0 +typing_extensions==4.13.2 urllib3==2.4.0 virtualenv==20.30.0 wadler-lindig==0.1.5 -werkzeug==3.1.3 +Werkzeug==3.1.3 wheel==0.45.1 wrapt==1.17.2 -zipp==3.21.0 +zipp==3.21.0 \ No newline at end of file