diff --git a/bionemo-recipes/recipes/esm2_native_te/fp8_stats.yaml b/bionemo-recipes/recipes/esm2_native_te/fp8_stats.yaml
new file mode 100644
index 0000000000..54b0bd0ba1
--- /dev/null
+++ b/bionemo-recipes/recipes/esm2_native_te/fp8_stats.yaml
@@ -0,0 +1,28 @@
+example_fp8_tensor_stat_collection:
+  enabled: True
+  layers:
+    layer_types: [self_attention]
+  transformer_engine:
+    LogFp8TensorStats:
+      enabled: True
+      tensors_struct:
+        - tensor: activation
+          stats: [mxfp8_underflows%]
+          freq: 1
+        - tensor: activation
+          stats: [mxfp8_overflows%]
+          freq: 1
+        - tensor: activation
+          stats: [mxfp8_scale_inv_min]
+          freq: 1
+        - tensor: activation
+          stats: [mxfp8_scale_inv_max]
+          freq: 1
+        - tensor: activation
+          stats: [mxfp8_mse]
+          freq: 1
+        - tensor: gradient
+          stats: [underflows%]
+          freq: 5
+      start_step: 0
+      end_step: 80
\ No newline at end of file
diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
index 293eafe84d..9a3c72474b 100644
--- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
+++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
@@ -36,6 +36,8 @@
 from perf_logger import PerfLogger
 from scheduler import get_linear_schedule_with_warmup
 
+import nvdlfw_inspect.api as debug_api
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
@@ -84,6 +86,16 @@ def main(args: DictConfig) -> float | None:
 
     logger.info("Initialized Model:\n%s", model)
 
+    # TE Debug feature logging
+    debug_api.initialize(
+        config_file="/workspaces/bionemo-framework/bionemo-recipes/recipes/esm2_native_te/fp8_stats.yaml",
+        feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
+        log_dir="./log",
+        default_logging_enabled=True
+    )
+
+    debug_api.infer_and_assign_layer_names(model)
+
     # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models.
     transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer
     for layer in transformer_stack:
@@ -126,6 +138,9 @@ def main(args: DictConfig) -> float | None:
 
     perf_logger = PerfLogger(dist_config, args)
 
+
+
+
     # Training loop
     step = start_step
     while step < args.num_train_steps:
@@ -146,6 +161,9 @@ def main(args: DictConfig) -> float | None:
             # Step optimizer.
             optimizer.step()
             scheduler.step()
+
+            debug_api.step()
+
             optimizer.zero_grad()
 
             perf_logger.log_step(
@@ -169,6 +187,7 @@ def main(args: DictConfig) -> float | None:
                 max_checkpoints=args.checkpoint.max_checkpoints,
            )
 
+
             step += 1
             if step >= args.num_train_steps:
                 break
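
Note on the debug wiring above: the debug_api.initialize call hard-codes an absolute workspace path for config_file and a Python 3.12 site-packages path for feature_dirs, so the patch only runs inside one specific container layout. Below is a minimal sketch of a more portable wiring; the helper name setup_te_debug_logging and the derivation of the feature directory from the installed transformer_engine package are assumptions for illustration, while the debug_api.initialize, debug_api.infer_and_assign_layer_names, and debug_api.step calls are exactly the ones used in the diff.

import os

import nvdlfw_inspect.api as debug_api
import transformer_engine


def setup_te_debug_logging(model, config_file: str, log_dir: str = "./log") -> None:
    """Hypothetical helper: enable the TE debug-feature stat collection described in a YAML config."""
    # Locate Transformer Engine's debug feature directory from the installed
    # package instead of hard-coding a site-packages path (assumes the
    # debug/features layout used in the diff above).
    feature_dir = os.path.join(os.path.dirname(transformer_engine.__file__), "debug", "features")
    debug_api.initialize(
        config_file=config_file,
        feature_dirs=[feature_dir],
        log_dir=log_dir,
        default_logging_enabled=True,
    )
    # Name the model's layers so the layer_types selector in the YAML can match them.
    debug_api.infer_and_assign_layer_names(model)

With a helper like this, the training loop only needs the per-step hook the diff already adds (debug_api.step() after optimizer.step()), which is what makes the freq, start_step, and end_step fields in fp8_stats.yaml count in optimizer steps.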