
Rmsnorm backward fusing sum #101

Open
AaronWang04 wants to merge 3 commits into Dao-AILab:main from AaronWang04:rmsnorm_bwd_bulk_reduce

Conversation

@AaronWang04 AaronWang04 commented Apr 8, 2026

Eliminates the extra `.sum(dim=0)` kernel for the `dw_partial` reduction in the RMSNorm backward pass. Each CTA writes its partial to `dw_partial` as before, loads it back into smem contiguously, then thread 0 issues a bulk async reduce-add into the `dw` tensor. Only enabled for N ≤ 8192 (the path with no intra-CTA reduction).
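For context, here is a minimal CPU sketch (in NumPy, with made-up sizes) of the reduction being fused. Each of the `num_ctas` CTAs writes one partial weight-gradient row into `dw_partial`; the baseline then launches a separate kernel just to do the final `.sum(dim=0)`, while this PR has each CTA reduce-add its partial directly into `dw` instead:

```python
import numpy as np

# Hypothetical sizes: num_ctas partial rows, N columns (hidden dim).
num_ctas, N = 8, 512
rng = np.random.default_rng(0)

# Each CTA writes its partial dW row to global memory...
dw_partial = rng.standard_normal((num_ctas, N)).astype(np.float32)

# Baseline: a separate kernel reduces over the CTA axis.
dw = dw_partial.sum(axis=0)  # the extra .sum(dim=0) launch this PR removes

# Fused path: each CTA reduce-adds its partial directly into dw.
# Same result up to floating-point summation order.
dw_fused = np.zeros(N, dtype=np.float32)
for cta in range(num_ctas):
    dw_fused += dw_partial[cta]

assert np.allclose(dw, dw_fused)
```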

Benchmarks swept M from 4k to 65k and N from 256 to 8k; no performance regressions in any case, and a 1.2x average speedup on GB200.

Rough results with `triton.testing.do_bench`:

==========================================================================================
dtype: torch.bfloat16
==========================================================================================
M=  4096, N=  256 | fused=0.0335ms  separate=0.0571ms  speedup=1.71x  FUSED
M=  8192, N=  256 | fused=0.0344ms  separate=0.0548ms  speedup=1.59x  FUSED
M= 16384, N=  256 | fused=0.0357ms  separate=0.0516ms  speedup=1.44x  FUSED
M= 32768, N=  256 | fused=0.0399ms  separate=0.0514ms  speedup=1.29x  FUSED
M= 65536, N=  256 | fused=0.0464ms  separate=0.0527ms  speedup=1.14x  FUSED
M=  4096, N=  512 | fused=0.0245ms  separate=0.0583ms  speedup=2.38x  FUSED
M=  8192, N=  512 | fused=0.0267ms  separate=0.0542ms  speedup=2.03x  FUSED
M= 16384, N=  512 | fused=0.0311ms  separate=0.0529ms  speedup=1.70x  FUSED
M= 32768, N=  512 | fused=0.0426ms  separate=0.0532ms  speedup=1.25x  FUSED
M= 65536, N=  512 | fused=0.0646ms  separate=0.0701ms  speedup=1.09x  FUSED
M=  4096, N= 1024 | fused=0.0257ms  separate=0.0557ms  speedup=2.17x  FUSED
M=  8192, N= 1024 | fused=0.0298ms  separate=0.0546ms  speedup=1.83x  FUSED
M= 16384, N= 1024 | fused=0.0405ms  separate=0.0520ms  speedup=1.28x  FUSED
M= 32768, N= 1024 | fused=0.0592ms  separate=0.0685ms  speedup=1.16x  FUSED
M= 65536, N= 1024 | fused=0.0953ms  separate=0.1132ms  speedup=1.19x  FUSED
M=  4096, N= 2048 | fused=0.0262ms  separate=0.0530ms  speedup=2.02x  FUSED
M=  8192, N= 2048 | fused=0.0386ms  separate=0.0513ms  speedup=1.33x  FUSED
M= 16384, N= 2048 | fused=0.0597ms  separate=0.0683ms  speedup=1.14x  FUSED
M= 32768, N= 2048 | fused=0.1012ms  separate=0.1064ms  speedup=1.05x  FUSED
M= 65536, N= 2048 | fused=0.1845ms  separate=0.1827ms  speedup=0.99x  SEPARATE
M=  4096, N= 4096 | fused=0.0347ms  separate=0.0491ms  speedup=1.41x  FUSED
M=  8192, N= 4096 | fused=0.0559ms  separate=0.0644ms  speedup=1.15x  FUSED
M= 16384, N= 4096 | fused=0.0962ms  separate=0.1071ms  speedup=1.11x  FUSED
M= 32768, N= 4096 | fused=0.1801ms  separate=0.1930ms  speedup=1.07x  FUSED
M= 65536, N= 4096 | fused=0.3455ms  separate=0.3642ms  speedup=1.05x  FUSED

Also, from my testing with bf16, numerics are at most 1 ULP off from the unfused path.

Comment thread on quack/rmsnorm.py, lines +861 to +863:

```python
copy_utils.cpasync_reduce_bulk_add_f32(
    sdW_buf.iterator, mdW_final.iterator, store_bytes,
)
```
Collaborator

Without a semaphore guarding this, it will result in non-deterministic reductions.
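The determinism concern here stems from floating-point addition being non-associative: if the order in which the CTAs' reduce-adds land is not fixed, the low bits of dW can differ run to run. A minimal float32 illustration:

```python
import numpy as np

a = np.float32(1e8)
b = np.float32(1.0)

# Same addends, two summation orders:
order1 = (a + b) - a   # 1e8 + 1 rounds back to 1e8 in float32, so this is 0.0
order2 = (a - a) + b   # exact, gives 1.0

# order1 != order2: reduction order changes the result.
```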

Author

@AaronWang04 commented Apr 9, 2026

Yes, but if we add a semaphore I don't think this fusion remains worth it.

Should I add a bool flag for determinism that forces the unfused .sum() path instead?

FWIW, I think the performance gains are worth it (I still haven't tuned sm_count yet).

Collaborator

I don't think a backward kernel with a non-deterministic reduction is worth adding, tbh. I do not see anyone using it.

Member

tridao commented Apr 13, 2026

We could try with a semaphore to see if there's still any perf win.

Member

tridao commented Apr 13, 2026

I don't think you need to load dW from gmem -> smem. Just do cp.async.bulk.reduce from smem -> gmem. With a semaphore this would be deterministic.
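The semaphore idea can be sketched at host level (a purely illustrative Python threading analogy, not the CuTe DSL code): each "CTA" waits until the semaphore equals its own index, performs its reduce-add, then passes the turn to the next index. Because the partials always land in the same order, the accumulated dW is bit-for-bit reproducible:

```python
import threading
import numpy as np

num_ctas, N = 8, 16
rng = np.random.default_rng(1)
dw_partial = rng.standard_normal((num_ctas, N)).astype(np.float32)

dw = np.zeros(N, dtype=np.float32)
turn = 0                      # the "semaphore": which CTA may reduce next
cv = threading.Condition()

def cta(idx):
    global turn
    with cv:
        cv.wait_for(lambda: turn == idx)   # spin until it is our turn
        dw[:] += dw_partial[idx]           # ordered reduce-add into "gmem"
        turn += 1                          # release to the next CTA
        cv.notify_all()

threads = [threading.Thread(target=cta, args=(i,)) for i in range(num_ctas)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Ordered accumulation matches a fixed serial sum bit-for-bit.
ref = np.zeros(N, dtype=np.float32)
for i in range(num_ctas):
    ref += dw_partial[i]
assert np.array_equal(dw, ref)
```

The cost, of course, is that CTAs serialize on the final reduce-add, which is why the perf win needs re-measuring.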
