
[Feat] Align quant and fused rmsnorm kernels with aiter/triton #481

Open
cschenjunlin wants to merge 13 commits into main from cjl/fused_quant_rmsnorm

Conversation

@cschenjunlin (Contributor) commented May 8, 2026

Motivation

Align the quant and fused RMSNorm kernels with aiter/triton.

Technical Details

Test Plan

Test Result

Tested on MI308 + ROCm 7.1:

Quant RMSNorm performance comparison:

====================================================================================================
Perf Compare (gpu us): FlyDSL vs AIter
====================================================================================================
op         shape              dtype  FlyDSL(gpu us)  AIter(gpu us)    speedup
rmsnorm_dq 64x256             f32              28.6           37.0      1.29x
rmsnorm_dq 128x1024           f32              28.1           37.5      1.33x
rmsnorm_dq 32x128             f16              29.5           36.9      1.25x
rmsnorm_dq 64x2000            f32              29.3           38.7      1.32x
rmsnorm_dq 16x512             bf16             29.1           37.1      1.27x
rmsnorm_dq 1024x8192          bf16             28.9           37.4      1.30x
rmsnorm_dq 32768x8192         bf16            400.8        1,089.9      2.72x
rmsnorm_sq 64x256             f32              30.9           38.8      1.25x
rmsnorm_sq 128x1024           f32              30.7           38.0      1.24x
rmsnorm_sq 32x128             f16              30.1           38.7      1.29x
rmsnorm_sq 64x2000            f32              30.3           38.4      1.27x
rmsnorm_sq 16x512             bf16             30.5           39.7      1.30x
rmsnorm_sq 1024x8192          bf16             30.6           46.6      1.52x
rmsnorm_sq 32768x8192         bf16            535.8        1,476.0      2.75x
====================================================================================================
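For orientation, a minimal NumPy sketch of what the two op families are assumed to compute — `_dq` deriving a dynamic per-row scale from the row's absolute max, `_sq` applying a caller-supplied static scale. The int8 output and qmax = 127 are illustrative assumptions (the actual kernels may target a different format), and the `_ref` names are ours, not the PR's:

```python
import numpy as np

def rmsnorm_ref(x, g, eps=1e-6):
    # RMSNorm: scale each row by rsqrt(mean(x^2)), then by the gain g.
    rrms = 1.0 / np.sqrt(np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True) + eps)
    return x.astype(np.float32) * rrms * g

def rmsnorm_dq_ref(x, g, qmax=127.0):
    # Dynamic quant: per-row scale computed on the fly from the normalized row.
    y = rmsnorm_ref(x, g)
    scale = np.abs(y).max(axis=-1, keepdims=True) / qmax
    return np.clip(np.round(y / scale), -qmax, qmax).astype(np.int8), scale

def rmsnorm_sq_ref(x, g, scale, qmax=127.0):
    # Static quant: one precomputed scale supplied by the caller.
    y = rmsnorm_ref(x, g)
    return np.clip(np.round(y / scale), -qmax, qmax).astype(np.int8)
```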

Fused-add RMSNorm performance comparison:

====================================================================================================
Perf Compare (gpu us): FlyDSL vs AIter
====================================================================================================
op         shape              dtype  FlyDSL(gpu us)  AIter(gpu us)    speedup
rmsnorm_add 64x256             f32              31.5           52.1      1.65x
rmsnorm_add 128x1024           f32              31.5           51.5      1.63x
rmsnorm_add 32x128             f16              31.3           50.6      1.62x
rmsnorm_add 64x2000            f32              30.8           51.4      1.67x
rmsnorm_add 16x512             bf16             31.1           51.3      1.65x
rmsnorm_add 1024x8192          bf16             31.2           55.7      1.79x
rmsnorm_add 32768x8192         bf16            814.7        1,661.4      2.04x
====================================================================================================

Fused-add quant RMSNorm performance comparison:

====================================================================================================
Perf Compare (gpu us): FlyDSL vs AIter
====================================================================================================
op         shape              dtype  FlyDSL(gpu us)  AIter(gpu us)    speedup
rmsnorm_add_dq 64x256             f32              32.8           38.8      1.18x
rmsnorm_add_dq 128x1024           f32              33.3           37.6      1.13x
rmsnorm_add_dq 32x128             f16              33.2           37.6      1.13x
rmsnorm_add_dq 64x2000            f32              33.5           39.3      1.17x
rmsnorm_add_dq 16x512             bf16             32.8           39.2      1.20x
rmsnorm_add_dq 1024x8192          bf16             33.3           54.6      1.64x
rmsnorm_add_dq 32768x8192         bf16            731.4        1,562.0      2.14x
rmsnorm_add_sq 64x256             f32              35.2           42.2      1.20x
rmsnorm_add_sq 128x1024           f32              35.1           41.9      1.19x
rmsnorm_add_sq 32x128             f16              35.5           39.8      1.12x
rmsnorm_add_sq 64x2000            f32              36.4           40.8      1.12x
rmsnorm_add_sq 16x512             bf16             35.1           43.1      1.23x
rmsnorm_add_sq 1024x8192          bf16             35.8           64.1      1.79x
rmsnorm_add_sq 32768x8192         bf16            888.1        1,847.8      2.08x
====================================================================================================

Submission Checklist

  • quant_rms_norm_kernel
  • fused_add_rmsnorm_kernel
  • quant_fused_add_rmsnorm_kernel
  • rmsnorm_kernel_large_m_small_n

Comment thread on kernels/rmsnorm_kernel.py
y = (added * rrms) * g
_store_vec(_to_elem_vec(y), out_div, idx)

else:
Collaborator: No need to copy all the code in the else branch?

@coderfeli (Collaborator)

@cschenjunlin any update?

@cschenjunlin (Contributor, Author)

> @cschenjunlin any update?

I have pushed new commits to resolve the duplication issue mentioned above.

Some common helper functions were added to reduce code duplication across the variant kernels. However, the scope of refactoring here is limited, as the variants introduce some new data-flow logic.

For each variant kernel, a standalone build_xx_module function is retained. This aligns with the aiter/triton implementation, where each variant is likewise a standalone kernel separate from the base version.

If further reduction of code duplication is needed, I can converge the build_xx_module functions of all variants into a unified implementation and use flag parameters to differentiate the branch logic internally.
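A minimal sketch of that flag-parameter idea, with the branch selection hoisted to build time; every name here (`build_rmsnorm_kernel`, `fused_add`, `quant_mode`) is hypothetical and only illustrates the shape of a unified builder, not FlyDSL's actual module-building API:

```python
def build_rmsnorm_kernel(fused_add=False, quant_mode=None):
    # quant_mode: None | "dynamic" | "static"; variants chosen once at build time.
    def kernel(x, g, residual=None, scale=None, eps=1e-6, qmax=127.0):
        src = x + residual if fused_add else x
        rrms = 1.0 / np.sqrt(np.mean(src.astype(np.float32) ** 2,
                                     axis=-1, keepdims=True) + eps)
        y = src * rrms * g
        if quant_mode == "dynamic":
            scale = np.abs(y).max(axis=-1, keepdims=True) / qmax
        if quant_mode is not None:
            return np.clip(np.round(y / scale), -qmax, qmax).astype(np.int8), scale
        return y
    return kernel
```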

@i-chaochen

@coderfeli Sorry for the delay, please check this PR again.

Thanks!

Comment thread on kernels/rmsnorm_kernel.py (Outdated)
return allocator, red_offset, red2_offset


def _load_scalar(copy_atom, scalar_reg_ty, scalar_reg_lay, divided_tensor, index):
Collaborator: This now conflicts with main, which moved these register temporaries to fx.make_rmem_tensor/internal types. Please rebase and keep the shared helpers on the new API instead of reintroducing MemRefType + memref_alloca.

return ok, flydsl_gpu_us


def test_rmsnorm_fused_add_dynamicquant():
Collaborator: These new variant tests are pytest-only today. run_benchmark.sh executes this file as a script, but main only calls test_all(), so the fused/quant variants are not exercised in that benchmark path.
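One possible fix for this gap, sketched under the assumption that `test_all()` is the script entry point run_benchmark.sh reaches: extend it to call the new variant tests. `test_rmsnorm_fused_add_dynamicquant` appears in this PR's diff; the other test names are assumptions about the existing file:

```python
def test_all():
    test_rmsnorm()                         # assumed existing base test
    test_rmsnorm_dynamicquant()            # assumed variant test name
    test_rmsnorm_fused_add()               # assumed variant test name
    test_rmsnorm_fused_add_dynamicquant()  # present in this PR's diff
```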

Resolve RMSNorm conflicts by keeping the branch variants on the current make_rmem_tensor-based register helper API.

Co-authored-by: Cursor <cursoragent@cursor.com>
