diff --git a/miles/backends/megatron_utils/model_provider.py b/miles/backends/megatron_utils/model_provider.py index e7c24a3191..d99673f748 100644 --- a/miles/backends/megatron_utils/model_provider.py +++ b/miles/backends/megatron_utils/model_provider.py @@ -110,6 +110,15 @@ def wrapped_model_provider( provider.moe_router_bias_update_rate = args.moe_router_bias_update_rate if getattr(args, "moe_aux_loss_coeff", None) is not None: provider.moe_aux_loss_coeff = args.moe_aux_loss_coeff + # The bridge provider defaults gradient_accumulation_fusion=True + # (via fusions.can_enable_gradient_accumulation_fusion). On ROCm/gfx950, + # this makes TE's LayerNormLinear backward issue a bias-fused wgrad GEMM + # with bf16 inputs → fp32 output + HIPBLASLT_EPILOGUE_BGRADB + accumulate, + # for which hipBLASLt has no algorithm. Honor the Megatron CLI flag so + # that --no-gradient-accumulation-fusion actually takes effect. + provider.gradient_accumulation_fusion = getattr( + args, "gradient_accumulation_fusion", provider.gradient_accumulation_fusion + ) provider.finalize() def wrapped_bridge_provider( diff --git a/tests/e2e/sglang_config/test_sglang_config_mixed_offload.py b/tests/e2e/sglang_config/test_sglang_config_mixed_offload.py index 79c54a42a6..99d0140403 100644 --- a/tests/e2e/sglang_config/test_sglang_config_mixed_offload.py +++ b/tests/e2e/sglang_config/test_sglang_config_mixed_offload.py @@ -16,10 +16,13 @@ import os import tempfile +import torch from tests.ci.ci_register import register_cuda_ci import miles.utils.external_utils.command_utils as U +IS_ROCM = torch.version.hip is not None + register_cuda_ci(est_time=600, suite="stage-b-short-8-gpu", num_gpus=8) TIGHT_DEVICE_MEMORY = U.get_bool_env_var("MILES_TEST_TIGHT_DEVICE_MEMORY", "1") @@ -119,13 +122,17 @@ def execute(): f"--sglang-config {config_path} " ) - ci_args = "--ci-test " + ci_args = ( + "--ci-test " + + ("--ci-disable-kl-checker --ci-disable-logprobs-checker " if IS_ROCM else "") + ) misc_args = ( "--attention-dropout 0.0 " "--hidden-dropout 0.0 " "--accumulate-allreduce-grads-in-fp32 " - "--attention-softmax-in-fp32 " + + ("--no-gradient-accumulation-fusion " if IS_ROCM else "") + + "--attention-softmax-in-fp32 " "--attention-backend flash " "--actor-num-nodes 1 " "--actor-num-gpus-per-node 8 "