
add gfx950-specific MLA fp8 decode kernel #517

Open
charlieguo1106 wants to merge 3 commits into main from cguo/mla_gfx950


@charlieguo1106 charlieguo1106 commented May 13, 2026

Summary

  • Add a gfx950-specific MLA FP8 decode kernel path using ds_read_b64_tr_b8.
  • Route FP8 MLA decode on gfx950 to the new kernel while keeping the existing path for other architectures.
  • Extend MLA decode benchmark/test tooling with stable repeated timing and optional aiter comparisons.
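A minimal sketch of what "stable repeated timing" could look like, assuming a warmup phase followed by a median over several runs. The helper name `bench_repeated` and its parameters are hypothetical and are not taken from this PR. For a GPU kernel, the callable would also need to synchronize the device before returning so that the wall-clock interval covers the kernel itself.

```python
import statistics
import time
from typing import Callable, Dict


def bench_repeated(
    fn: Callable[[], object], warmup: int = 3, repeats: int = 10
) -> Dict[str, float]:
    """Time fn() several times and report the median (hypothetical helper)."""
    # Warmup runs are discarded so one-time costs (compilation, caches)
    # do not skew the measured samples.
    for _ in range(warmup):
        fn()

    samples_us = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()  # assumption: fn blocks until the work is actually finished
        samples_us.append((time.perf_counter() - start) * 1e6)

    # The median is less sensitive to outlier runs than the mean.
    return {
        "median_us": statistics.median(samples_us),
        "min_us": min(samples_us),
        "max_us": max(samples_us),
        "repeats": float(len(samples_us)),
    }
```

Reporting min and max next to the median makes it easier to spot a noisy run before comparing two kernels.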

Test plan

python tests/kernels/test_mla_decode.py --bench_aiter

MI3xx
308
MI35x
355

@charlieguo1106 charlieguo1106 force-pushed the cguo/mla_gfx950 branch 2 times, most recently from b2a885f to fb51a87 Compare May 13, 2026 06:49
@charlieguo1106 charlieguo1106 requested a review from coderfeli May 13, 2026 06:57
Collaborator

@coderfeli coderfeli left a comment


Thanks for adding the gfx950-specific path. I think this needs one more round before merge because the new target path is not actually covered by the current test, and the metadata contract is unclear.

Main concern: tests/kernels/test_mla_decode.py now skips the whole pytest on gfx950, with the reason that AITER emits a folded nh=16 work-info layout while this FlyDSL kernel only supports the native nh=128 layout. But kernels/mla_fwd_decode.py routes fp8/fp8 MLA decode on gfx950 to the new kernel, and the public API still consumes caller-provided work_indptr / work_info_set. As written, CI can pass on gfx950 while never exercising the new gfx950 kernel through pytest, and a caller reusing the same AITER metadata path would hit the newly routed kernel with the incompatible layout.

Please either add a gfx950 test path that generates/uses the native nh=128 metadata expected by this kernel, or make the launcher/API contract explicit and fail loudly when the provided metadata is the folded gfx950 layout. Also please align the CLI behavior with pytest: running python tests/kernels/test_mla_decode.py on gfx950 still calls run_single() despite the same incompatibility that causes pytest to skip.
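A minimal sketch of the "fail loudly" option, assuming the launcher can be told which head count the metadata was generated for. The names `check_work_info_layout`, `MetadataLayoutError`, and `metadata_num_heads` are hypothetical. How the layout would be detected from `work_indptr` / `work_info_set` themselves is not specified in this thread, so the sketch takes it as an explicit argument.

```python
NATIVE_NUM_HEADS = 128  # native layout the gfx950 kernel supports (per the review)
FOLDED_NUM_HEADS = 16   # folded layout AITER emits on gfx950 (per the review)


class MetadataLayoutError(ValueError):
    """Raised when work-info metadata does not match the routed kernel."""


def check_work_info_layout(arch: str, metadata_num_heads: int) -> None:
    """Reject non-native metadata before launching the gfx950 fp8 kernel.

    Hypothetical helper: metadata_num_heads is assumed to be supplied by the
    caller, since the real metadata format is not described in this thread.
    """
    if not arch.startswith("gfx950"):
        return  # other architectures keep the existing path and contract
    if metadata_num_heads != NATIVE_NUM_HEADS:
        raise MetadataLayoutError(
            f"gfx950 MLA fp8 decode expects nh={NATIVE_NUM_HEADS} metadata, "
            f"got nh={metadata_num_heads}"
        )
```

With a check like this, a caller that reuses the folded metadata gets an explicit error at launch time instead of running the kernel on an incompatible layout.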

Smaller follow-up: the routing uses arch.startswith("gfx950"). If this is meant for the CDNA4 gfx95x family, please use the same gfx95* detection style as the rest of the repo or add a comment that this is intentionally gfx950-only.
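To illustrate the difference between the two detection styles, here is a small sketch. Both helper names are hypothetical and do not come from the repo. It assumes the arch string may carry feature suffixes such as `:sramecc+:xnack-`, which are stripped before matching.

```python
def base_arch(arch: str) -> str:
    """Drop feature suffixes, e.g. 'gfx950:sramecc+:xnack-' -> 'gfx950'."""
    return arch.split(":", 1)[0]


def is_gfx950_only(arch: str) -> bool:
    """Exact match: intentionally restricted to gfx950 (hypothetical helper)."""
    return base_arch(arch) == "gfx950"


def is_gfx95x_family(arch: str) -> bool:
    """Prefix match: any gfx95* target (hypothetical helper)."""
    return base_arch(arch).startswith("gfx95")
```

An exact match on the base name makes the gfx950-only intent explicit, while the prefix form would also route any other gfx95* target to the new kernel.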

@charlieguo1106 charlieguo1106 force-pushed the cguo/mla_gfx950 branch 3 times, most recently from b43092c to 41cbfc9 Compare May 14, 2026 08:27
@charlieguo1106 charlieguo1106 requested a review from coderfeli May 14, 2026 08:31
coderfeli and others added 2 commits May 16, 2026 02:23
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>