Skip to content

esimd: fix resadd_norm_gemv_int4 race on large N#399

Draft
hzjane wants to merge 1 commit into
mainfrom
fix_27b_kernel
Draft

esimd: fix resadd_norm_gemv_int4 race on large N#399
hzjane wants to merge 1 commit into
mainfrom
fix_27b_kernel

Conversation

@hzjane
Copy link
Copy Markdown
Contributor

@hzjane hzjane commented May 10, 2026

The fused ResAddNormGEMV INT4 kernel has every work-group read residual in pass 1 (for sum_sq), but only n==0 writes the updated residual back to global memory. When N is large enough that work-groups cannot all run concurrently, the later WGs observe the already-updated residual and compute added = h + (h + r_old) instead of h + r_old, which garbles the GEMV output for all downstream rows.

Observed on Qwen3.6-27B dense gate_up (N=8704, K=5120, TP=4): kernel rel-error vs IPEX is ~0.10 (quant noise) at N=128 but jumps to 0.28 once N crosses ~512, even though the per-row math is otherwise identical. Qwen3.6-35B MoE router (N=256) stays below the threshold, which is why the bug did not surface there.

Fix: for N > 512, split into a prepass kernel that owns the residual / normed_out update (single WG, no cross-WG dep), followed by a pure GEMV-from-normed kernel that just reads normed_out. For N <= 512 keep the original fused path (faster and race-free because all WGs run concurrently on the existing hardware).

Verified:

  • Standalone correctness sweep: rel-error at N=5120/8704, K up to 5120 drops from 0.28-0.33 back to ~0.10.
  • Qwen3.6-27B sym_int4 TP=4 generates coherent text.
  • Qwen3.6-35B-A3B sym_int4 TP=4 unchanged (still uses the fused path for N=256).

The fused ResAddNormGEMV INT4 kernel has every work-group read residual
in pass 1 (for sum_sq), but only n==0 writes the updated residual back
to global memory. When N is large enough that work-groups cannot all run
concurrently, the later WGs observe the already-updated residual and
compute added = h + (h + r_old) instead of h + r_old, which garbles the
GEMV output for all downstream rows.

Observed on Qwen3.6-27B dense gate_up (N=8704, K=5120, TP=4): kernel
rel-error vs IPEX is ~0.10 (quant noise) at N=128 but jumps to 0.28 once
N crosses ~512, even though the per-row math is otherwise identical.
Qwen3.6-35B MoE router (N=256) stays below the threshold, which is why
the bug did not surface there.

Fix: for N > 512, split into a prepass kernel that owns the residual /
normed_out update (single WG, no cross-WG dep), followed by a pure
GEMV-from-normed kernel that just reads normed_out. For N <= 512 keep
the original fused path (faster and race-free because all WGs run
concurrently on the existing hardware).

Verified:
  * Standalone correctness sweep: rel-error at N=5120/8704, K up to
    5120 drops from 0.28-0.33 back to ~0.10.
  * Qwen3.6-27B sym_int4 TP=4 generates coherent text.
  * Qwen3.6-35B-A3B sym_int4 TP=4 unchanged (still uses the fused
    path for N=256).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant