FlashQLA GDN backend produces different tokens from FLA Triton on correctness-test prompts #15

@JessePinkman2019

Description

Setup

  • H800 × 4, TP=4, bf16
  • T=20468, B=1, chunk_size=64, varlen API

Observation

Output of chunk_gated_delta_rule compared against FLA Triton as the reference (zero initial state):

┌───────┬──────────┬───────────┬─────────┐
│ T │ max_diff │ mean_diff │ cos_sim │
├───────┼──────────┼───────────┼─────────┤
│ 512 │ 0.026 │ 3.7e-5 │ 0.99806 │
├───────┼──────────┼───────────┼─────────┤
│ 20468 │ 0.047 │ 1.2e-5 │ 0.99947 │
└───────┴──────────┴───────────┴─────────┘
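
For anyone who wants to reproduce the three columns, here is a minimal sketch of how such metrics can be computed. It is an assumption about the definitions (max and mean of the element-wise absolute difference, cosine similarity over the flattened output), not necessarily the exact script used for the table above. `compare_outputs` is a hypothetical helper name, and it works on plain Python sequences so it runs without any dependencies; with real tensors one would typically flatten both outputs and cast them to float32 first.

```python
import math


def compare_outputs(ref, out):
    """Return (max_diff, mean_diff, cos_sim) for two flat, equal-length sequences.

    ref: reference values (e.g. the flattened FLA Triton output)
    out: values under test (e.g. the flattened FlashQLA output)
    """
    if len(ref) != len(out):
        raise ValueError("ref and out must have the same number of elements")
    if len(ref) == 0:
        raise ValueError("inputs must be non-empty")

    # Element-wise absolute error.
    abs_diffs = [abs(r - o) for r, o in zip(ref, out)]
    max_diff = max(abs_diffs)
    mean_diff = math.fsum(abs_diffs) / len(abs_diffs)

    # Cosine similarity over the flattened vectors.
    dot = math.fsum(r * o for r, o in zip(ref, out))
    norm_ref = math.sqrt(math.fsum(r * r for r in ref))
    norm_out = math.sqrt(math.fsum(o * o for o in out))
    if norm_ref == 0.0 or norm_out == 0.0:
        cos_sim = float("nan")  # undefined for an all-zero vector
    else:
        cos_sim = dot / (norm_ref * norm_out)

    return max_diff, mean_diff, cos_sim


if __name__ == "__main__":
    # Toy values only; these are not data from the issue.
    print(compare_outputs([1.0, 2.0, 3.0], [1.0, 2.0, 3.5]))
```

Note that cos_sim is dominated by the large-magnitude elements, so it can stay close to 1 even when a few positions are far off; max_diff is the metric that exposes those positions.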

The mean error is small, but a max_diff of up to 0.047 is large enough to flip the argmax at specific positions. In an end-to-end
generation test with prompts padded to T=20468, FlashQLA produces wrong tokens on 2 of 3 prompts where Triton is correct (e.g.
"The capital of the UK is" → Triton: "London", FlashQLA: ".").
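
To illustrate why a small absolute error can change the sampled token, here is a rough sketch of the argmax-margin argument. It assumes, purely for illustration, that every logit can move by at most `eps`; the 0.047 above was measured on the kernel output rather than on the final logits, so the real bound at the logits may be larger or smaller. `argmax_margin` and `argmax_can_flip` are hypothetical helper names.

```python
def argmax_margin(logits):
    """Gap between the largest and the second-largest logit."""
    if len(logits) < 2:
        raise ValueError("need at least two logits")
    top1, top2 = sorted(logits, reverse=True)[:2]
    return top1 - top2


def argmax_can_flip(logits, eps):
    """True if per-element perturbations of at most eps could change (or tie) the argmax.

    Worst case: the top logit drops by eps while the runner-up rises by eps,
    so the margin shrinks by 2 * eps.
    """
    return argmax_margin(logits) <= 2 * eps


if __name__ == "__main__":
    # Toy logits only; these are not values from the failing prompts.
    print(argmax_can_flip([2.0, 1.95, 0.1], eps=0.047))  # close race
    print(argmax_can_flip([5.0, 1.0, 0.1], eps=0.047))   # clear winner
```

Positions where the top two candidates are nearly tied are the ones to inspect first when two backends disagree on a token.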

Question

Is this level of numerical deviation from FLA Triton expected? If so, is there a recommended tolerance or a reference correctness
test for non-padded long-context inputs?
