add gfx950-specific MLA fp8 decode kernel#517
Conversation
b2a885f to
fb51a87
Compare
fb51a87 to
9130416
Compare
coderfeli
left a comment
There was a problem hiding this comment.
Thanks for adding the gfx950-specific path. I think this needs one more round before merge because the new target path is not actually covered by the current test, and the metadata contract is unclear.
Main concern: tests/kernels/test_mla_decode.py now skips the whole pytest on gfx950, with the reason that AITER emits a folded nh=16 work-info layout while this FlyDSL kernel only supports the native nh=128 layout. But kernels/mla_fwd_decode.py routes fp8/fp8 MLA decode on gfx950 to the new kernel, and the public API still consumes caller-provided work_indptr / work_info_set. As written, CI can pass on gfx950 while never exercising the new gfx950 kernel through pytest, and a caller reusing the same AITER metadata path would hit the newly routed kernel with the incompatible layout.
Please either add a gfx950 test path that generates/uses the native nh=128 metadata expected by this kernel, or make the launcher/API contract explicit and fail loudly when the provided metadata is the folded gfx950 layout. Also please align the CLI behavior with pytest: running python tests/kernels/test_mla_decode.py on gfx950 still calls run_single() despite the same incompatibility that causes pytest to skip.
Smaller follow-up: the routing uses arch.startswith("gfx950"). If this is meant for the CDNA4 gfx95x family, please use the same gfx95* detection style as the rest of the repo or add a comment that this is intentionally gfx950-only.
b43092c to
41cbfc9
Compare
41cbfc9 to
6318ea8
Compare
6318ea8 to
cf6401f
Compare
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Summary
ds_read_b64_tr_b8.Test plan
python tests/kernels/test_mla_decode.py --bench_aiter
MI3xx


MI35x