28 changes: 28 additions & 0 deletions bionemo-recipes/recipes/esm2_native_te/fp8_stats.yaml
@@ -0,0 +1,28 @@
# nvdlfw-inspect / Transformer Engine debug config: log MXFP8 underflow/overflow
# percentages, scale_inv range, and MSE for self-attention activations every step,
# plus FP8 gradient underflows every 5 steps between steps 0 and 80.
example_fp8_tensor_stat_collection:
enabled: True
layers:
layer_types: [self_attention]
transformer_engine:
LogFp8TensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [mxfp8_underflows%]
freq: 1
- tensor: activation
stats: [mxfp8_overflows%]
freq: 1
- tensor: activation
stats: [mxfp8_scale_inv_min]
freq: 1
- tensor: activation
stats: [mxfp8_scale_inv_max]
freq: 1
- tensor: activation
stats: [mxfp8_mse]
freq: 1
- tensor: gradient
stats: [underflows%]
freq: 5
start_step: 0
end_step: 80
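For reference, a minimal sketch of how this config is consumed, using only the nvdlfw_inspect calls introduced in train_fsdp2.py below; the relative config path, the derived feature_dirs, and the `model` variable are illustrative assumptions, not part of this PR:

    import os

    import nvdlfw_inspect.api as debug_api
    import transformer_engine

    # Locate TE's debug features from the installed package rather than hard-coding the path.
    feature_dirs = [os.path.join(os.path.dirname(transformer_engine.__file__), "debug", "features")]

    debug_api.initialize(
        config_file="fp8_stats.yaml",  # assumed to be resolved from the recipe directory
        feature_dirs=feature_dirs,
        log_dir="./log",
        default_logging_enabled=True,
    )

    # `model` is the TE-backed ESM-2 model built earlier in the recipe (assumed here).
    # Naming its layers lets the `layer_types: [self_attention]` filter above match them.
    debug_api.infer_and_assign_layer_names(model)

    # Once per optimizer step inside the training loop:
    debug_api.step()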
19 changes: 19 additions & 0 deletions bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
@@ -36,6 +36,8 @@
from perf_logger import PerfLogger
from scheduler import get_linear_schedule_with_warmup

import nvdlfw_inspect.api as debug_api


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -84,6 +86,16 @@ def main(args: DictConfig) -> float | None:

logger.info("Initialized Model:\n%s", model)

# Initialize the TE debug (nvdlfw-inspect) API so the FP8 tensor stats listed in
# fp8_stats.yaml are collected during training. Note: config_file and feature_dirs
# are hard-coded to this dev-container layout.
debug_api.initialize(
    config_file="/workspaces/bionemo-framework/bionemo-recipes/recipes/esm2_native_te/fp8_stats.yaml",
    feature_dirs=["/usr/local/lib/python3.12/dist-packages/transformer_engine/debug/features/"],
    log_dir="./log",
    default_logging_enabled=True,
)

# Assign names to the model's TE layers so the `layer_types: [self_attention]`
# filter in fp8_stats.yaml can match them.
debug_api.infer_and_assign_layer_names(model)

# We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models.
transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer
for layer in transformer_stack:
@@ -126,6 +138,9 @@ def main(args: DictConfig) -> float | None:

perf_logger = PerfLogger(dist_config, args)

# Training loop
step = start_step
while step < args.num_train_steps:
@@ -146,6 +161,9 @@
# Step optimizer.
optimizer.step()
scheduler.step()

# Advance nvdlfw-inspect's internal step counter so each tensor's `freq` and the
# start_step/end_step window in fp8_stats.yaml track optimizer steps.
debug_api.step()

optimizer.zero_grad()

perf_logger.log_step(
@@ -169,6 +187,7 @@ def main(args: DictConfig) -> float | None:
max_checkpoints=args.checkpoint.max_checkpoints,
)

step += 1
if step >= args.num_train_steps:
break