Conversation
copy_utils.cpasync_reduce_bulk_add_f32(
    sdW_buf.iterator, mdW_final.iterator, store_bytes,
)
without a semaphore guarding this, this will result in non-deterministic reductions
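To see why: float addition is not associative, so the order in which CTAs' partials land changes the rounded result. A toy illustration in pure Python (not the kernel itself), emulating f32 accumulation:

```python
import struct

def f32(x):
    """Round a Python float to float32, mimicking f32 accumulation on the GPU."""
    return struct.unpack('f', struct.pack('f', x))[0]

# Three partial sums committing in two different arrival orders:
a, b, c = 1e8, -1e8, 0.0001
order1 = f32(f32(a + b) + c)  # a and b commit first
order2 = f32(a + f32(b + c))  # b and c commit first; 0.0001 is lost
print(order1, order2)
```

`order2` collapses to `0.0` because `0.0001` is below the ULP of `1e8` in float32, while `order1` keeps it, so the same partials can produce different results run to run.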
Yes, but if we add a semaphore I don't think this fusion stays worth it.
Should I add a bool flag for determinism that forces the unfused .sum() path instead?
FWIW I think the performance gains are worth it (I still haven't tuned sm_count yet).
I don't think a backward kernel without deterministic reduction is worth adding, tbh. I don't see anyone using it.
We could try with a semaphore to see if there's still any perf win.
Force-pushed from ffacb0e to 521174c
I don't think you need to load dW from gmem -> smem. Just do cp.async.bulk.reduce from smem -> gmem. With a semaphore this would be deterministic.
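The "with a semaphore this would be deterministic" point can be modeled in pure Python (a sketch, not the PR's CUDA code): CTAs become threads, and a ticket-style semaphore forces partials to commit in CTA-id order, so the rounding sequence, and hence the result, is identical every run, at the cost of serializing the commits.

```python
import threading

class OrderedAccumulator:
    """Accumulate partials in a fixed CTA-id order via a ticket 'semaphore'."""
    def __init__(self):
        self.total = 0.0
        self.turn = 0                      # id of the CTA allowed to commit next
        self.cond = threading.Condition()

    def commit(self, cta_id, partial):
        with self.cond:
            while self.turn != cta_id:     # wait until it is our turn
                self.cond.wait()
            self.total += partial          # fixed order => deterministic rounding
            self.turn += 1
            self.cond.notify_all()

acc = OrderedAccumulator()
partials = [0.1, 0.2, 0.3, 0.4]
threads = [threading.Thread(target=acc.commit, args=(i, p))
           for i, p in enumerate(partials)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(acc.total)
```

Regardless of which thread finishes first, the additions always happen in the order 0, 1, 2, 3, which is exactly the property an unguarded bulk reduce-add lacks.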
Eliminates the extra .sum(dim=0) kernel for dw_partial reduction in the RMSNorm backward pass. Each CTA writes its partial to dw_partial as before, loads it back into smem contiguously, then thread 0 issues a bulk async reduce-add into a dw tensor. This path is only taken for N ≤ 8192 (no intra-CTA reduction).
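What the fusion replaces can be sketched in pure Python (names like `dw_partial` follow the description above; the layout is my simplification, not the PR's actual tensors): the unfused path runs a separate reduction over all CTA partials, while the fused path has each CTA add its own row into dw as it finishes.

```python
import random

random.seed(0)
num_ctas, N = 8, 16
# one row of partial dW per CTA
dw_partial = [[random.gauss(0, 1) for _ in range(N)] for _ in range(num_ctas)]

# unfused path: a separate reduction kernel over the partials (.sum(dim=0))
dw_unfused = [sum(col) for col in zip(*dw_partial)]

# fused path: each CTA reduce-adds its own row directly into dw
# (order fixed here; without a semaphore the real kernel's order is arbitrary)
dw_fused = [0.0] * N
for row in dw_partial:
    for j, v in enumerate(row):
        dw_fused[j] += v
```

With a fixed commit order the two paths add the same values in the same sequence, so they agree exactly; the fusion just removes the extra kernel launch.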
Benchmarks swept M=4k to 65k and N=256 to 8k: no performance regressions in any case and a 1.2x average speedup on GB200.
Rough results collected with triton.testing.do_bench.
Also, from my testing with bf16, numerics are at most 1 ULP off.
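A quick way to check a "1 ULP" claim for bf16 (helper names are mine, not from the PR): bf16 is the top 16 bits of an f32 pattern, so for positive same-sign values the ULP distance is just the difference of those 16-bit patterns.

```python
import struct

def bf16_bits(x):
    """Top 16 bits of the float32 pattern of x, i.e. its bf16 representation."""
    return struct.unpack('<I', struct.pack('<f', x))[0] >> 16

def ulp_distance(a, b):
    """ULP distance between two positive bf16-representable values."""
    return abs(bf16_bits(a) - bf16_bits(b))

ref = 1.0
got = 1.0078125   # 1 + 2**-7: one bf16 step above 1.0 (7 mantissa bits)
print(ulp_distance(ref, got))
```

Comparing the fused kernel's dw against the unfused `.sum(dim=0)` reference with a check like this bounds the error in representable steps rather than an absolute tolerance.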