
perf(inference): chunked GDN prefill scan (B=64 to 128, simdgroup_matrix) #175

@ohdearquant

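GDN (Gated DeltaNet) decode is inherently sequential, with one state update per token. Prefill, however, can use a chunkwise DeltaNet/GDN scan: parallel within a chunk of B tokens, with a sequential state carry across chunks, so a length-L prompt needs L/B carries instead of L. The within-chunk terms (KS0^T, KK^T, QS0^T, U^TK) are GEMM-like with M=B, which makes them the right place for simdgroup_matrix.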
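The chunkwise scan can be sketched in plain Python as a correctness oracle for the chunked-vs-sequential comparison. This is a minimal sketch under stated assumptions, not the planned kernel: a scalar per-token decay α_t and write strength β_t (Gated DeltaNet-style), L2-normalized keys, and a dv×dk state S so that S0·k_t is a row of KS0^T. The function names are hypothetical. Each token's pseudo-value u_t solves a lower-triangular system built from the KK^T term; here it is solved by forward substitution rather than a WY/UT-factored inverse, and decay ratios are kept as plain quotients rather than in log space.

```python
import math
import random


def matvec(S, x):
    """S @ x for a row-major matrix S given as a list of rows."""
    return [sum(s * xi for s, xi in zip(row, x)) for row in S]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def sequential_gdn(q, k, v, alpha, beta, dk, dv):
    """Reference gated delta rule with L sequential state updates.

    State S is dv x dk:
      S_t = a_t S_{t-1} + b_t (v_t - a_t S_{t-1} k_t) k_t^T,   o_t = S_t q_t
    """
    S = [[0.0] * dk for _ in range(dv)]
    out = []
    for t in range(len(q)):
        a, b = alpha[t], beta[t]
        pred = [a * p for p in matvec(S, k[t])]  # decayed state's prediction of v_t
        err = [b * (vc - pc) for vc, pc in zip(v[t], pred)]  # delta-rule correction
        S = [[a * S[c][r] + err[c] * k[t][r] for r in range(dk)] for c in range(dv)]
        out.append(matvec(S, q[t]))
    return out, S


def chunked_gdn(q, k, v, alpha, beta, dk, dv, B):
    """Same recurrence, but only ceil(L / B) sequential state carries."""
    S = [[0.0] * dk for _ in range(dv)]  # carried state S0
    out = []
    for s in range(0, len(q), B):
        Q, K, V, b = q[s:s + B], k[s:s + B], v[s:s + B], beta[s:s + B]
        n = len(Q)
        g, acc = [], 1.0  # g[t] = alpha_1 * ... * alpha_t within this chunk
        for a in alpha[s:s + B]:
            acc *= a
            g.append(acc)
        # Pseudo-values U solve (I + A) U = diag(b) (V - diag(g) K S0^T),
        # with A[t][i] = b_t (g_t / g_i) (k_t . k_i) for i < t (the KK^T term).
        U = []
        for t in range(n):
            ks0 = matvec(S, K[t])  # row t of K S0^T
            u = [b[t] * (V[t][c] - g[t] * ks0[c]) for c in range(dv)]
            for i in range(t):  # forward substitution (causal correction)
                coef = b[t] * (g[t] / g[i]) * dot(K[t], K[i])
                u = [u[c] - coef * U[i][c] for c in range(dv)]
            U.append(u)
        # Outputs: o_t = g_t S0 q_t + sum_{i <= t} (g_t / g_i) (q_t . k_i) u_i
        for t in range(n):
            o = [g[t] * x for x in matvec(S, Q[t])]  # decayed row t of Q S0^T
            for i in range(t + 1):
                w = (g[t] / g[i]) * dot(Q[t], K[i])  # causal, decayed QK^T entry
                o = [o[c] + w * U[i][c] for c in range(dv)]
            out.append(o)
        # Carry: S_B = g_B S0 + sum_i (g_B / g_i) u_i k_i^T (the U^T K term)
        gB = g[-1]
        S = [[gB * S[c][r] + sum((gB / g[i]) * U[i][c] * K[i][r] for i in range(n))
              for r in range(dk)] for c in range(dv)]
    return out, S


def unit(x):
    norm = math.sqrt(dot(x, x))
    return [xi / norm for xi in x]


random.seed(0)
L, dk, dv = 10, 4, 3
q = [[random.gauss(0, 1) for _ in range(dk)] for _ in range(L)]
k = [unit([random.gauss(0, 1) for _ in range(dk)]) for _ in range(L)]
v = [[random.gauss(0, 1) for _ in range(dv)] for _ in range(L)]
alpha = [random.uniform(0.9, 1.0) for _ in range(L)]
beta = [random.uniform(0.1, 0.9) for _ in range(L)]

ref, S_ref = sequential_gdn(q, k, v, alpha, beta, dk, dv)
for B in (1, 3, 4, 10):
    got, _ = chunked_gdn(q, k, v, alpha, beta, dk, dv, B)
    max_err = max(abs(x - y) for o1, o2 in zip(ref, got) for x, y in zip(o1, o2))
    print(f"B={B}: carries={math.ceil(L / B)}, max |diff|={max_err:.2e}")
```

Chunking only regroups the arithmetic, so the difference against the sequential reference should stay at floating-point rounding level for every B, while the number of sequential carries drops from L to ceil(L/B).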

Tasks

  • Implement the chunkwise WY/DPLR form (Yang et al., Gated DeltaNet; Kimi Linear KDA chunkwise): within-chunk causal correction plus decayed state carry
  • B=64 for bring-up, B=128 tiled for production (autotune over B ∈ {32, 64, 128, 256})
  • simdgroup_matrix for the GEMM pieces (crossover at M=B ≥ 64)
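The autotune step can be sketched as a best-of-N timing sweep over the candidate chunk sizes. This assumes a hypothetical `run(B)` callable that launches the prefill scan at chunk size B and returns only after the GPU work has completed (on Metal, after the command buffer finishes); `autotune_chunk_size` and its parameters are illustrative, not an existing API. The optional `measure` hook exists so the selection logic can be exercised without a GPU.

```python
import time
from typing import Callable, Dict, Iterable, Optional, Tuple


def autotune_chunk_size(
    run: Callable[[int], None],
    candidates: Iterable[int] = (32, 64, 128, 256),
    reps: int = 3,
    measure: Optional[Callable[[Callable[[], None]], float]] = None,
) -> Tuple[int, Dict[int, float]]:
    """Return the chunk size with the lowest best-of-`reps` time, plus all timings."""

    def wall_time(fn: Callable[[], None]) -> float:
        start = time.perf_counter()
        fn()
        return time.perf_counter() - start

    measure = measure or wall_time
    timings: Dict[int, float] = {}
    for B in candidates:
        run(B)  # warm-up: excludes one-time pipeline compilation and cold caches
        timings[B] = min(measure(lambda: run(B)) for _ in range(reps))
    best = min(timings, key=timings.get)
    return best, timings
```

Taking the minimum over repetitions filters out one-off stalls; a production tuner would likely also key its results by sequence length and head dimensions, since the best B can shift with both.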

Acceptance (ADR-064 gates)

  • GDN-prefill recurrence speedup vs. sequential scan: ≥5× at 4K, ≥10× at 16K (expected 8-18×)
  • Chunked vs. sequential f32-state reference: PPL Δ ≤ 0.005, and no chunk-size-dependent divergence in greedy decoding
  • TTFT reported separately at 1K/4K/16K
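The two hard gates can be encoded as a check that returns the list of failed gates. Assumptions not stated in the issue: "4k"/"16k" mean 4096/16384 tokens, the PPL gate is an absolute difference, and greedy outputs are compared as token-id sequences across chunk sizes. `check_adr064_gates` is a hypothetical helper. TTFT is only reported, so it is not gated here.

```python
from typing import Dict, List, Sequence


def check_adr064_gates(
    speedup: Dict[int, float],  # context length (tokens) -> chunked/sequential speedup
    ppl_chunked: float,
    ppl_sequential: float,  # f32-state sequential reference
    greedy_by_chunk: Dict[int, Sequence[int]],  # chunk size B -> greedy token ids
) -> List[str]:
    """Return human-readable descriptions of every failed gate (empty list = pass)."""
    failures: List[str] = []
    for ctx, need in ((4096, 5.0), (16384, 10.0)):
        got = speedup.get(ctx)
        if got is None or got < need:
            failures.append(f"speedup@{ctx}: {got} < {need}x")
    dppl = abs(ppl_chunked - ppl_sequential)
    if dppl > 0.005:
        failures.append(f"PPL delta {dppl:.4f} > 0.005")
    outputs = [list(o) for o in greedy_by_chunk.values()]
    if any(o != outputs[0] for o in outputs[1:]):
        failures.append("greedy output depends on chunk size")
    return failures
```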

Ref: d3§2. Complements the prefill FA work (#126). Note: d3§2 (WY/DPLR kernel math) is the least implementation-complete section and may warrant a focused research pass before coding.

Labels: enhancement (New feature or request); lattice-inference (Affects the lattice-inference crate (transformer inference))
