rocm: disable gradient_accumulation_fusion on gfx950 in test_sglang_config_mixed_offload#1157
rocm: disable gradient_accumulation_fusion on gfx950 in test_sglang_config_mixed_offload#1157sreerohi wants to merge 2 commits into
Conversation
hipBLASLt on gfx950 (MI350/MI355) has no algorithm for the triple combination of bf16 output with fp32 accumulate + HIPBLASLT_EPILOGUE_BGRADB epilogue + accumulate=True. This fires during TE's LayerNormLinear backward when gradient_accumulation_fusion=True and the layer has bias. Root cause: the bridge provider defaults gradient_accumulation_fusion to True (via can_enable_gradient_accumulation_fusion) and ignores --no-gradient-accumulation-fusion from the CLI. The fix propagates the flag from Megatron args to the bridge provider.
There was a problem hiding this comment.
Code Review
This pull request enables the gradient_accumulation_fusion flag to be configured via command-line arguments, specifically to resolve hipBLASLt compatibility issues on ROCm/gfx950 hardware. The changes update the Megatron model provider and E2E test configurations to support this flag and adjust test behavior in ROCm environments. The reviewer suggested using an explicit null check when assigning the fusion flag to ensure consistency with the existing codebase and prevent potential attribute overwriting.
| provider.gradient_accumulation_fusion = getattr( | ||
| args, "gradient_accumulation_fusion", provider.gradient_accumulation_fusion | ||
| ) |
There was a problem hiding this comment.
To maintain consistency with the surrounding code (e.g., lines 105-112) and to avoid potentially overwriting the provider's default with None if the attribute exists in args but is uninitialized, it is recommended to use an explicit is not None check before assignment.
if getattr(args, "gradient_accumulation_fusion", None) is not None:
provider.gradient_accumulation_fusion = args.gradient_accumulation_fusion
Depends on #1153. Related to #1105 (Miles CI gap between ROCm & CUDA).
hipBLASLt on gfx950 has no algorithm for TE's bias-fused wgrad GEMM (bf16→fp32 + BGRADB + accumulate). This conditionally disables gradient_accumulation_fusion and relaxes CI numerical checkers on ROCm. CUDA is unaffected.