Setup
- H800 × 4, TP=4, bf16
- T=20468, B=1, chunk_size=64, varlen API
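For reference, T=20468 is not a multiple of chunk_size=64, so the single sequence ends in a partial chunk. The sketch below works out that layout. It assumes the common varlen convention of passing cumulative sequence offsets (`cu_seqlens`); the variable names are illustrative, not taken from either library.

```python
import math

# Values from the setup above: one sequence (B=1) of T tokens, chunked by 64.
T = 20468
CHUNK_SIZE = 64

# Assumed varlen convention: cumulative offsets, one sequence spanning [0, T).
cu_seqlens = [0, T]

num_chunks = math.ceil(T / CHUNK_SIZE)                   # full chunks + 1 partial
last_chunk_len = T - (num_chunks - 1) * CHUNK_SIZE       # tokens in the final chunk

print(cu_seqlens, num_chunks, last_chunk_len)  # [0, 20468] 320 52
```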
Observation
Comparing chunk_gated_delta_rule output against the FLA Triton implementation as the reference (zero initial state):
┌───────┬──────────┬───────────┬─────────┐
│ T     │ max_diff │ mean_diff │ cos_sim │
├───────┼──────────┼───────────┼─────────┤
│ 512   │ 0.026    │ 3.7e-5    │ 0.99806 │
├───────┼──────────┼───────────┼─────────┤
│ 20468 │ 0.047    │ 1.2e-5    │ 0.99947 │
└───────┴──────────┴───────────┴─────────┘
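The report does not spell out how the three metrics are computed. A minimal sketch, assuming max_diff and mean_diff are the max and mean of the elementwise absolute difference and cos_sim is cosine similarity over the flattened outputs, might look like this (bf16 tensors would typically be upcast to float32 before comparing):

```python
import math


def compare(out, ref):
    """Return (max_diff, mean_diff, cos_sim) for two flat sequences of floats.

    Assumed definitions: elementwise |out - ref| for the diffs, and cosine
    similarity computed over the whole flattened output.
    """
    if len(out) != len(ref):
        raise ValueError("out and ref must have the same length")
    diffs = [abs(a - b) for a, b in zip(out, ref)]
    max_diff = max(diffs)
    mean_diff = sum(diffs) / len(diffs)
    dot = sum(a * b for a, b in zip(out, ref))
    norm_out = math.sqrt(sum(a * a for a in out))
    norm_ref = math.sqrt(sum(b * b for b in ref))
    cos_sim = dot / (norm_out * norm_ref)
    return max_diff, mean_diff, cos_sim


# Toy usage with made-up values: one element off by 0.5.
print(compare([1.0, 2.0, 3.0], [1.0, 2.0, 3.5]))
```

Because mean_diff and cos_sim average over every element, a handful of large outliers can coexist with a tiny mean error and a cosine similarity near 1, which matches the pattern in the table.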
Mean error is small, but a max_diff of up to 0.047 is enough to flip the argmax at specific positions. In an end-to-end generation test with prompts padded to T=20468, FlashQLA produces wrong tokens on 2 of 3 prompts where Triton is correct (e.g. for "The capital of the UK is", Triton predicts "London" while FlashQLA predicts ".").
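A toy illustration of why a small max error can change the output token: when the top two logits at a position are closer than the perturbation, the argmax flips. The logits below are made up, not measured, and the sketch simplifies by applying the error directly to the logits rather than propagating a kernel-level error through the remaining layers.

```python
def argmax(xs):
    # Index of the largest value (first index wins on ties).
    return max(range(len(xs)), key=xs.__getitem__)


def argmax_flips(logits_a, logits_b):
    """Positions where two per-position logit lists disagree on the argmax."""
    return [i for i, (a, b) in enumerate(zip(logits_a, logits_b)) if argmax(a) != argmax(b)]


# Hypothetical logits: 3 positions over a 3-token vocabulary.
ref = [[5.00, 4.98, 1.0],   # near-tie between tokens 0 and 1
       [3.00, 1.00, 0.5],   # clear winner
       [0.20, 2.00, 2.10]]  # margin of 0.10
# Every entry differs from ref by at most 0.04 (below the 0.047 max_diff).
perturbed = [[4.96, 5.02, 1.0],
             [3.04, 0.96, 0.5],
             [0.24, 2.04, 2.06]]

print(argmax_flips(ref, perturbed))  # [0]: only the near-tie position flips
```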
Question
Is this level of numerical deviation from FLA Triton expected? If so, is there a recommended tolerance or a reference correctness
test for non-padded long-context inputs?
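One possible shape for such a correctness test is a relative-RMSE check, which is less sensitive to single outliers than max_diff and scales with the magnitude of the reference. This is a sketch under assumptions: `rel_rmse` and `check_close` are hypothetical helpers, and the `ratio` threshold is a placeholder, not a recommended tolerance from either project.

```python
import math


def rel_rmse(out, ref, eps=1e-8):
    """RMS of (out - ref), normalized by the RMS of ref."""
    if len(out) != len(ref):
        raise ValueError("out and ref must have the same length")
    n = len(ref)
    err = math.sqrt(sum((a - b) ** 2 for a, b in zip(out, ref)) / n)
    base = math.sqrt(sum(b * b for b in ref) / n)
    return err / (base + eps)


def check_close(name, out, ref, ratio):
    """Raise if the relative RMSE exceeds `ratio` (placeholder threshold)."""
    r = rel_rmse(out, ref)
    if r > ratio:
        raise AssertionError(f"{name}: rel_rmse={r:.3e} exceeds {ratio:.1e}")
    return r


# Toy usage with made-up values.
print(check_close("o", [1.0, 2.0], [1.0, 2.0], ratio=5e-3))  # 0.0
```

For long-context inputs, the same check could be applied per sequence segment as well as over the whole output, so that an error concentrated near the end of a long sequence is not averaged away.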