perf(inference): fused RMSNorm+Residual kernel to reduce dispatch count by 48/token #182

@ohdearquant

Problem

Each of the 24 transformer layers has 2 normalization points, and each point currently requires 2 Metal dispatches: one for RMSNorm, one for the residual addition. That's 24 × 2 × 2 = 96 dispatches per token just for norm+residual.

Proposal

Fuse RMSNorm + residual add into a single kernel per normalization point:

  • Read input once
  • Compute mean of squares → normalize → scale → add residual
  • Write output once

This cuts 48 dispatches/token (from 304 to 256), a dispatch reduction of about 16% (48/304 ≈ 15.8%).
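A CPU reference can pin down what the fused kernel should compute before any Metal code is written. The sketch below is only an illustration: the function names, the `eps` parameter, and the flat `&[f32]` layout are assumptions and do not come from metal_qwen35.rs. The normalization statistic is the mean of squares, since RMSNorm does not subtract the mean. The two-step version mirrors the current norm-then-add split so the two paths can be compared.

```rust
/// Hypothetical CPU reference for the fused op:
/// out[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i] + residual[i]
pub fn fused_rmsnorm_residual(x: &[f32], weight: &[f32], residual: &[f32], eps: f32) -> Vec<f32> {
    assert!(!x.is_empty());
    assert_eq!(x.len(), weight.len());
    assert_eq!(x.len(), residual.len());
    // Reduction: mean of squares over the row.
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    // Single elementwise pass: normalize, scale, add residual, write once.
    x.iter()
        .zip(weight)
        .zip(residual)
        .map(|((&xi, &wi), &ri)| xi * inv_rms * wi + ri)
        .collect()
}

/// Hypothetical unfused path: one step for the scaled norm, one for the add.
pub fn unfused_rmsnorm_then_residual(x: &[f32], weight: &[f32], residual: &[f32], eps: f32) -> Vec<f32> {
    assert!(!x.is_empty());
    assert_eq!(x.len(), weight.len());
    assert_eq!(x.len(), residual.len());
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    // Step 1 (first dispatch today): scaled RMSNorm into an intermediate buffer.
    let normed: Vec<f32> = x.iter().zip(weight).map(|(&xi, &wi)| xi * inv_rms * wi).collect();
    // Step 2 (second dispatch today): residual addition.
    normed.iter().zip(residual).map(|(&n, &r)| n + r).collect()
}
```

Comparing the two functions on the same inputs gives a cheap numerical check to run alongside the PPL comparison; a GPU kernel may still differ slightly because of reduction order.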

Evidence

  • metal_qwen35.rs line ~8663 (GDN) and ~9043 (GQA): norm and residual dispatched separately
  • Metal best practice: minimize dispatch count to reduce command-buffer overhead
  • At 304 dispatches/token with ~0.02 ms overhead per dispatch, 48 fewer dispatches could save roughly 1 ms/token (48 × 0.02 ms ≈ 0.96 ms)
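The arithmetic behind these figures can be written out as a small helper. This is a sketch only: the struct and function names are hypothetical, and the inputs (24 layers, 2 normalization points, 304 dispatches/token, ~0.02 ms per dispatch) are the numbers quoted in this issue, not new measurements.

```rust
/// Hypothetical summary of the dispatch estimate from this issue.
pub struct DispatchEstimate {
    pub unfused_norm_dispatches: u32,
    pub saved: u32,
    pub total_after: u32,
    pub reduction_pct: f64,
    pub saved_ms: f64,
}

pub fn estimate_dispatch_savings(
    layers: u32,
    norm_points_per_layer: u32,
    total_before: u32,
    overhead_ms_per_dispatch: f64,
) -> DispatchEstimate {
    // Today: 2 dispatches (RMSNorm + residual add) per normalization point.
    let unfused = layers * norm_points_per_layer * 2;
    // Fused: 1 dispatch per normalization point.
    let fused = layers * norm_points_per_layer;
    let saved = unfused - fused;
    DispatchEstimate {
        unfused_norm_dispatches: unfused,
        saved,
        total_after: total_before - saved,
        reduction_pct: 100.0 * f64::from(saved) / f64::from(total_before),
        saved_ms: f64::from(saved) * overhead_ms_per_dispatch,
    }
}
```

Calling `estimate_dispatch_savings(24, 2, 304, 0.02)` reproduces the figures above: 96 norm+residual dispatches today, 48 saved, 256 remaining. The per-dispatch overhead is an estimate, so the resulting latency figure is an upper bound to be confirmed with bench-compare.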

Acceptance

  • Single kernel handles: output = RMSNorm(x) * weight + residual
  • No PPL regression (±0.01 vs reference)
  • bench-compare shows measurable intercept improvement
  • clippy clean

Priority

P1

Metadata

Labels

  • enhancement: New feature or request
  • lattice-inference: Affects the lattice-inference crate (transformer inference)