Skip to content

rocm: disable gradient_accumulation_fusion on gfx950 in test_sglang_config_mixed_offload#1157

Open
sreerohi wants to merge 2 commits into
radixark:mainfrom
sreerohi:rocm/bgradb-sglang-mixed-offload
Open

rocm: disable gradient_accumulation_fusion on gfx950 in test_sglang_config_mixed_offload#1157
sreerohi wants to merge 2 commits into
radixark:mainfrom
sreerohi:rocm/bgradb-sglang-mixed-offload

Conversation

@sreerohi
Copy link
Copy Markdown

Depends on #1153. Related to #1105 (Miles CI gap between ROCm & CUDA).

hipBLASLt on gfx950 has no algorithm for TE's bias-fused wgrad GEMM (bf16→fp32 + BGRADB + accumulate). This conditionally disables gradient_accumulation_fusion and relaxes CI numerical checkers on ROCm. CUDA is unaffected.

sreerohi added 2 commits May 19, 2026 16:29
hipBLASLt on gfx950 (MI350/MI355) has no algorithm for the triple
combination of bf16 output with fp32 accumulate + HIPBLASLT_EPILOGUE_BGRADB
epilogue + accumulate=True. This fires during TE's LayerNormLinear backward
when gradient_accumulation_fusion=True and the layer has bias.

Root cause: the bridge provider defaults gradient_accumulation_fusion to
True (via can_enable_gradient_accumulation_fusion) and ignores
--no-gradient-accumulation-fusion from the CLI. The fix propagates the
flag from Megatron args to the bridge provider.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request enables the gradient_accumulation_fusion flag to be configured via command-line arguments, specifically to resolve hipBLASLt compatibility issues on ROCm/gfx950 hardware. The changes update the Megatron model provider and E2E test configurations to support this flag and adjust test behavior in ROCm environments. The reviewer suggested using an explicit null check when assigning the fusion flag to ensure consistency with the existing codebase and prevent potential attribute overwriting.

Comment on lines +119 to +121
provider.gradient_accumulation_fusion = getattr(
args, "gradient_accumulation_fusion", provider.gradient_accumulation_fusion
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To maintain consistency with the surrounding code (e.g., lines 105-112) and to avoid potentially overwriting the provider's default with None if the attribute exists in args but is uninitialized, it is recommended to use an explicit is not None check before assignment.

        if getattr(args, "gradient_accumulation_fusion", None) is not None:
            provider.gradient_accumulation_fusion = args.gradient_accumulation_fusion

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant