[V100/SM70] gemma-4 (31B + MTP) support: fully-FA hybrid attention (sliding-window + head_dim-512) #59

Closed
rivetphilbot wants to merge 25 commits into 1CatAI:main from rivetphilbot:feat/gemma4-mtp

@rivetphilbot

Summary

Adds gemma-4 (31B, multimodal) + MTP speculative decoding support on V100/SM70, and makes gemma's hybrid attention run entirely on the Volta FLASH_ATTN_V100 kernels (decode and prefill, no Triton fallback). Stacks on the SM70 W4A16/DeltaNet stack (cf. #45 / rollup #55).

What's here

Model/MTP enablement (gemma4 backbone, multimodal towers, MTP drafter gemma4_mtp):

  • Register gemma4_assistant config so the MTP drafter loads; proportional-RoPE / activation / MM runtime fixes.
  • Fix MTP target KV-cache corruption (set kv_sharing_target_layer_name on the attn impl, not just the module).
  • Load vision/audio tower clip buffers + MTP ordered-embedding buffers.

FLASH_ATTN_V100 hybrid-attention support (the perf core):

  • Sliding-window in the decode + paged-prefill + dense kernels (per-token mask + whole-partition early-out; prefill window-tile-skip). Unblocks the 50 sliding (head_dim-256) layers.
  • head_dim-512 for the 10 global full-attention layers — decode (head-dim-generic GEMV) and prefill (Config<512> with small blocks to fit 96KB smem; the WMMA body already accumulates over D in chunks, so no split-D rewrite).
  • Backend gate accepts causal sliding windows + routes all 60 layers to FA.

Results (gemma-4-31b-qat-w4a16 + bf16-assistant MTP, 2×V100, c1)

| metric @ 8k ctx | Triton baseline | fully-FA | gain |
|---|---|---|---|
| decode TPOT | 310 ms | 25 ms | 12× |
| TTFT (prefill) | 63 s | 8.3 s | 7.6× |
| throughput | 1.25 tok/s | 11.2 tok/s | 9× |
| short-context c1 / c4 (tok/s) | 17.6 / 54.6 | 39.7 / 102.9 | ~2× |

Correctness

  • Standalone kernel parity vs fp32 reference (flash-attention-v100/test_window.py): decode + paged/dense prefill, full + windowed, D∈{128,256,512}, incl. S=8192 edge cases — all pass (~1e-3).
  • In-model: long-context retrieval correct (fact at position 0 of a 3,359-token prompt recalled), coherence clean.

🤖 Generated with Claude Code

rivetphilbot and others added 25 commits May 19, 2026 01:29

Add SM70TurboMindLinearKernel, an MPLinearKernel implementation that
routes compressed-tensors / AWQ WNA16 dense GEMMs through the bundled
TurboMind sm70_884_4 INT4 path. V100 (CC 7.0) has only first-gen FP16
WMMA cores and no Turing INT4 tensor-core GEMM, so the stock CUTLASS /
Machete kernels are unavailable; this kernel gives dense WNA16 layers a
working code path on SM70.

Register it at the head of the CUDA _POSSIBLE_KERNELS priority list so
it is preferred when running on V100; on newer architectures the
existing kernels still win their min-capability checks.

CompressedTensorsWNA16.get_min_capability hard-coded 75, so loading a
compressed-tensors WNA16 model on a V100 failed with 'Failed to find a
kernel that can implement the WNA16 linear layer' before the new
SM70TurboMindLinearKernel ever got a chance to bid.

Lower the reported minimum to 70 specifically when running on an SM70
device (CC 7.0). Older pre-Turing GPUs (sm_60/61/62) still get 75 and
remain correctly rejected, since only V100 has the FP16 WMMA path the
TurboMind kernel relies on.
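The capability rule described above can be sketched as follows. This is a minimal illustration, not the fork's code: `wna16_min_capability` and `scheme_is_supported` are hypothetical names, and capabilities are written as plain integers (70 for CC 7.0, 75 for CC 7.5).

```python
# Hypothetical sketch of the SM70-only capability relaxation.
# Names are illustrative and do not come from the vLLM code base.
SM70 = 70    # V100, compute capability 7.0
TURING = 75  # the minimum the WNA16 scheme reported before the change


def wna16_min_capability(device_capability: int) -> int:
    """Report 70 only when running on an SM70 device; otherwise keep 75."""
    return SM70 if device_capability == SM70 else TURING


def scheme_is_supported(device_capability: int) -> bool:
    """A device qualifies when it meets the minimum reported for it."""
    return device_capability >= wna16_min_capability(device_capability)
```

Under this rule a V100 passes the check, sm_60/61/62 devices are still rejected, and Turing or newer devices see no change.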
CompressedTensorsSM70WNA16MoEMethod delegates its decode to
AWQSM70MoEMethod.apply(), but only allocated a subset of the buffers
that path reads, so a CT-quantized MoE model crashed on the first
decode step on V100.

- Allocate the full buffer set AWQSM70MoEMethod.process_weights_after_
  loading creates: gate/up and permutation scratch, sorted-output and
  m-index buffers, int64 expert offsets, and the single-token batched
  pointer buffers.
- Publish sm70_hidden/intermediate logical+aligned sizes (CT weights
  are already in TurboMind layout, so logical == aligned).
- Build per-expert StridedPtr row views and record sm70_ptr_row_bytes
  for the batched GEMM path.
- Pass interleave_gated_silu=True to awq_sm70_prepare so the fused
  gate/up weights match the decode kernel's expectation.

Also switch the import to _DEFAULT_PERSISTENT_MAX_TOKENS; awq_sm70_moe
renamed _DEFAULT_MAX_TOKENS, leaving the old name a dangling import.

…compressed-tensors

Two related fixes for running Qwen3.5/3.6 compressed-tensors checkpoints:

- Qwen3NextSparseMoeBlock: the MoE router gate is stored as bf16 in the
  checkpoint and has no quantized form. Passing the model quant_config
  to its ReplicatedLinear made the loader expect quantized weights;
  force quant_config=None so the gate stays bf16.

- _uses_split_gdn_input_projections only inspected modules_to_not_convert
  and ignored_layers. Compressed-tensors records its skip list under the
  ignore attribute, so the BF16 in_proj_a / in_proj_b GDN projections of
  a CT checkpoint were not detected and the split-projection layout was
  not selected. Consult quant_config.ignore as a final fallback.
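A rough sketch of the skip-list lookup with the added fallback, under simplifying assumptions: the helper names are hypothetical, the entries are treated as plain module names matched by suffix, and any pattern syntax a real skip list may use is ignored.

```python
from types import SimpleNamespace


def quant_skip_list(quant_config) -> list:
    """Return the first non-empty skip list, checking `ignore` last."""
    for attr in ("modules_to_not_convert", "ignored_layers", "ignore"):
        names = getattr(quant_config, attr, None)
        if names:
            return list(names)
    return []


def gdn_projections_unquantized(quant_config) -> bool:
    """True when both in_proj_a and in_proj_b appear in the skip list."""
    skipped = quant_skip_list(quant_config)
    return all(
        any(name.endswith(suffix) for name in skipped)
        for suffix in ("in_proj_a", "in_proj_b")
    )


# Example: a compressed-tensors style config that only exposes `ignore`.
ct_config = SimpleNamespace(
    ignore=["layers.0.linear_attn.in_proj_a", "layers.0.linear_attn.in_proj_b"]
)
uses_split_layout = gdn_projections_unquantized(ct_config)
```

Without the `ignore` entry in the attribute tuple, the example config would yield an empty skip list and the split layout would not be chosen.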
CompressedTensorsWNA16 creates auxiliary parameters -- weight_shape
(BasevLLMParameter) and weight_g_idx (RowvLLMParameter) -- that hold
metadata or input-dim-sharded indices rather than output-dim weight
data, so they have no output_dim attribute.

When the qkvz stacked-load mapping in Qwen3_5Model.load_weights reached
one of these via the tuple-shard path, it hit AttributeError on
param.output_dim. Fix: when output_dim is absent, load through the
standard non-shard weight_loader (last-write-wins for replicated
metadata) and break out of the sub-id loop. The companion debug log
now formats output_dim with %s / getattr default so it tolerates the
missing attribute too.
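The guard can be illustrated with a toy loader. This is a simplified sketch, not the actual Qwen3_5Model.load_weights code: `FakeParam` and `load_stacked` are hypothetical stand-ins, and the "sharded" branch only records what it would have loaded.

```python
class FakeParam:
    """Stand-in parameter; metadata params simply lack `output_dim`."""

    def __init__(self, output_dim=None):
        if output_dim is not None:
            self.output_dim = output_dim
        self.loaded = []

    def weight_loader(self, weight):
        # Standard non-shard load: the whole tensor, last write wins.
        self.loaded.append(("whole", weight))


def load_stacked(param, weight, shard_ids):
    """Load one checkpoint tensor through a tuple of shard ids."""
    for shard_id in shard_ids:
        if getattr(param, "output_dim", None) is None:
            # No output dimension to slice along: load once and stop.
            param.weight_loader(weight)
            break
        param.loaded.append((shard_id, weight))


meta = FakeParam()               # e.g. a weight_shape style parameter
load_stacked(meta, [4096, 8192], ("q", "k", "v", "z"))

packed = FakeParam(output_dim=0)  # ordinary output-dim-sharded weight
load_stacked(packed, "w", ("q", "k"))
```

The `getattr` default is what lets both the loader and the debug log tolerate parameters that never define `output_dim`.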
Replace the fork's vendored sm70_884_4.cu tile registry (lmdeploy
v0.12.1) with the upstream lmdeploy main version (commit e5fbd4da,
from PR #4429 'fully implement compressed-tensors gs32 support').
mainloop_sm70.h, iterator_sm70.h and scheduler_sm70.cuh are byte
identical between the two snapshots -- only the Registry::sm70_884_4
tile-config list changed.

- Add a Config_U4_d<kColMajor> block with 21 gs32 tiles (the fork
  carried none for this layout).
- Expand the Config_U4_g<kColMajor> gs32 block from 6 to 17 tiles.
- Drop the gs64 block; both deployed quants (qwen3.6-27b-int4,
  granite-4.1-8b-awq-int4) are gs32.

Decode on V100 TP=2 is ~83% turbomind::gemm; the autotuner had no
gs32 candidates in the most common kColMajor layout, forcing fallback
tiles. Net diff +45 -11. Requires a full _C extension rebuild since
kernel registration is statically linked.

…ma4MTPModel

Vendored upstream vllm-project/vllm model_executor/models/gemma4.py (1714 LOC)
and registered Gemma4ForCausalLM (text backbone) + Gemma4MTPModel toward
serving gemma-4-31B + its MTP drafter on V100/SM70.

Not yet import-clean on this base: gemma4.py imports `GateLinear` from
layers.fused_moe (newer-upstream symbol absent here) — first API gap to
backport/adapt. transformers 5.7 (Gemma4Config) and the KV-sharing utils
are already present, so the foundation is in place.

Next: close the fused_moe/GateLinear gap, then add gemma4_mtp + spec-decode
core adaptations (PR vllm-project/vllm#41745).

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
- Vendor upstream fused_moe/router/gate_linear.py. Its specialized GEMM tiers
  (DSV3/fp32/cuBLAS) are all SM90+-gated, so on V100/SM70 it falls through to
  the ReplicatedLinear F.linear path; the missing ops.fp32_router_gemm is
  never reached and import/registration don't touch it.
- Export GateLinear and add module-level fused_moe_make_expert_params_mapping
  (delegates to FusedMoE.make_expert_params_mapping, which upstream refactored
  into a standalone function).

gemma4.py now imports cleanly against base d4f98f3 (verified against the live
.venv-v110 runtime); Gemma4ForCausalLM present.

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
- gemma4_mtp.py (627 LOC) imports clean against base.
- v1/spec_decode/gemma4.py (340 LOC) vendored but NOT yet import-clean: it
  subclasses SpecDecodeBaseProposer from v1/spec_decode/llm_base_proposer,
  a newer-upstream proposer abstraction our base predates. Next: adapt the
  proposer onto our base's eagle.py architecture (or backport the base class).

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
Gemma4Proposer's overridden methods all already exist on our base
SpecDecodeBaseProposer; only its 3 gemma4-specific methods are new. Our base
keeps SpecDecodeBaseProposer in eagle.py rather than upstream's separate
llm_base_proposer module, so redirect the import there. All four gemma-4
modules (gemma4, gemma4_mtp, spec_decode/gemma4) now import clean against
base d4f98f3, verified on the live .venv-v110 runtime.

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
…gnition

- model_arch_config_convertor: backport Gemma4ModelArchConfigConvertor
  (dual head_dim/global_head_dim sizing) + add Gemma4MTPModelArchConfigConvertor
  (speculator buffer sized to backbone_hidden_size); register gemma4/
  gemma4_text/gemma4_mtp.
- speculative: add gemma4_mtp to MTP types; remap model_type gemma4_assistant
  -> gemma4_mtp (n_predict=1, Gemma4MTPModel, zero cross-model KV-shared
  layers); add use_gemma4_mtp().

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
gpu_model_runner: import Gemma4Proposer; construct it for use_gemma4_mtp()
(before use_eagle, since gemma4 MTP is method "mtp"); add it to the Eagle/
DFlash isinstance + union sites; capture per-group block tables for it.

eagle.py (our base's home for SpecDecodeBaseProposer): add constant_draft_
positions (default False, so existing proposers incl. Deckard qwen3_5_mtp are
byte-identical); extract the per-step slot-mapping/metadata update into
_update_positions_dependent_metadata; guard it + the attn-metadata rebuild so
the Gemma4 constant-positions drafter builds once and reuses.

Verified on live .venv-v110: Gemma4ForCausalLM + Gemma4MTPModel register;
gpu_model_runner + Gemma4Proposer import clean.

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
…Generation)

- Vendor gemma4_mm.py (1706 LOC): SigLIP vision tower + audio tower +
  multimodal embedders on top of Gemma4ForCausalLM; register
  Gemma4ForConditionalGeneration -> gemma4_mm.
- Backport recursive_replace_linear into models/transformers/utils.py
  (deps replace_linear_class + maybe_prefix already present).
- Redirect MultiModalDataDict import to vllm.multimodal (this base's home;
  upstream re-exports via vllm.inputs).
- Skip gemma4_unified.py: that's the encoder-free 12B Unified variant (needs
  transformers.models.gemma4_unified, absent in transformers 5.7); not our 31B.

Verified on live .venv-v110: gemma4_mm imports clean; Gemma4ForConditional
Generation registers.

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
Found while bringing up google/gemma-4-31B-it-qat-w4a16-ct on the V100 SM70
W4A16 path (loads via SM70TurboMindLinearKernel; full graph builds):
- activation: register gelu_pytorch_tanh -> GeluAndMul(approximate="tanh")
  (gemma4 looks it up by name via the generic act-and-mul registry; gemma3
  special-cased it inline).
- rotary: vendor rotary_embedding/gemma4_rope.py (Gemma4RotaryEmbedding) and
  wire the get_rope `proportional` branch (gemma4 global/full attention).
- gemma4_mm: guard get_merged_mm_kwargs (newer InputProcessingContext method
  absent on this base) -> fall back to the config default soft-token count.

Known remaining: vision-tower weight mapping (std_bias / HF-built tower param
names) — MM-specific, LM weights load fine.

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
The HF Gemma4VisionModel (standardize=True) registers std_bias/std_scale as
persistent BUFFERS at the tower root; they're in the checkpoint and used at
runtime ((states - std_bias) * std_scale). vLLM's AutoWeightsLoader only loads
nn.Parameters (+ a BatchNorm-only buffer rescue), so it raised
"There is no module or parameter named vision_tower.std_bias".

Fix: in Gemma4ForConditionalGeneration.load_weights, intercept the two
checkpoint keys model.vision_tower.std_{bias,scale}, copy_ them into the
registered buffers, and pass the rest to AutoWeightsLoader unchanged (LM load
path byte-identical). Also fix the cosmetic 'vision_towerencoder' missing-dot
in the AutoWeightsLoader error message (base_prefix + k -> _get_qualname).

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
…tensor

In _clear_mm_prefix_range for full-attention layers, the port set both
metadata.mm_prefix_range = None and metadata.mm_prefix_range_tensor = None.
But in this fork's TritonAttentionMetadata, mm_prefix_range_tensor is a
read-only @property derived from mm_prefix_range (no setter) -> AttributeError
at first forward. Clearing the source dict already nulls the derived tensor;
drop the broken assignment.
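The failure mode is ordinary Python behaviour and can be reproduced in isolation. The class below is a hypothetical stand-in for the metadata object; it returns a list where the real property presumably builds a tensor.

```python
class FakeAttentionMetadata:
    """Stand-in: one stored field and one read-only derived view."""

    def __init__(self, mm_prefix_range):
        self.mm_prefix_range = mm_prefix_range

    @property
    def mm_prefix_range_tensor(self):
        # Derived on access from the source dict; there is no setter.
        if self.mm_prefix_range is None:
            return None
        return [list(r) for rs in self.mm_prefix_range.values() for r in rs]


meta = FakeAttentionMetadata({0: [(2, 5)]})

try:
    meta.mm_prefix_range_tensor = None  # the assignment the port made
    assignment_failed = False
except AttributeError:
    assignment_failed = True

meta.mm_prefix_range = None             # clearing the source is enough
```

After the source dict is cleared, reading `meta.mm_prefix_range_tensor` returns `None`, so the second assignment was redundant as well as invalid.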

With this + serving on TRITON_ATTN (gemma-4's 512-dim global-attention layers
exceed FLASH_ATTN_V100's D<=256 cap), gemma-4-31B-it-qat-w4a16-ct generates
coherent text on the dual V100 via the SM70 TurboMind W4A16 path.

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
…d backend)

FlashAttnV100Backend defined get_supported_head_sizes()->[64,128,256] but that
method is never called by validate_configuration() — which uses
supports_head_size(). FA-V100 inherited TritonAttentionBackend.supports_head_size
(head_size >= 32), so it wrongly validated head_size=512 and would hard-crash
the Volta CUDA kernel (TORCH_CHECK D<=256). Add the override returning
{64,128,256}. This lets vLLM's per-layer backend auto-selection route gemma-4's
50 sliding (head_dim=256) layers to the fast FA-V100 kernel and fall through to
TRITON_ATTN only for the 10 global (head_dim=512) layers — instead of forcing
Triton on all 60. Deploy by DROPPING --attention-backend so auto-select runs.

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
The gemma4 MTP drafter checkpoint advertises model_type
"gemma4_assistant", which Transformers' AutoConfig does not recognize.
SpeculativeConfig's gemma4_assistant -> gemma4_mtp remap
(config/speculative.py) runs only after the draft config is loaded, so
get_config fell through to AutoConfig and raised a ValidationError before
the remap could fire.

Resolve gemma4_assistant to the multimodal Gemma4Config via _CONFIG_REGISTRY
(it carries the .text_config the remap expects). Verified on 2×V100: the
drafter now loads, the engine boots healthy, and serves. (Output-quality /
0%-draft-acceptance is a separate downstream issue in the KV-sharing path.)

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
The gemma4 MTP draft layers are Q-only and must read the target's KV
read-only via cross-model KV sharing. _setup_gemma4_kv_sharing set
kv_sharing_target_layer_name on the attention *module* only, but the
backend impl captured that value at construction (None for draft layers,
which are built before target layer names are known). The KV-write gate
checks the *impl's* copy (e.g. triton_attn.py: `if
self.kv_sharing_target_layer_name is None: <store K/V>`), so the draft
layers wrote their draft K/V into the target's shared layer-N slots,
poisoning the target's verify pass — output was correct for the first
decoded token then degenerated into garbage, with ~0% draft acceptance.

Propagate the target-layer name to attn.impl as well so the write is
correctly skipped. Verified on 2xV100 (gemma-4-31b-qat-w4a16 + assistant
drafter): output now matches target-only generation and draft acceptance
is 46-73% (mean accept len 1.47-1.73).
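The stale-copy problem can be shown with a small model. The classes below are hypothetical stand-ins; `shared_slot` represents the target layer's KV storage that the draft layer aliases, and the write gate mirrors the check quoted from triton_attn.py.

```python
class FakeImpl:
    def __init__(self, kv_sharing_target_layer_name=None):
        # Captured once at construction; None for draft layers.
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name


class FakeAttentionModule:
    def __init__(self):
        self.kv_sharing_target_layer_name = None
        self.impl = FakeImpl(self.kv_sharing_target_layer_name)


def setup_kv_sharing(module, target_name, propagate_to_impl):
    module.kv_sharing_target_layer_name = target_name
    if propagate_to_impl:  # the fix: update the impl's copy too
        module.impl.kv_sharing_target_layer_name = target_name


def draft_forward(module, shared_slot, draft_kv):
    # The KV-write gate reads the impl's copy, not the module's.
    if module.impl.kv_sharing_target_layer_name is None:
        shared_slot["kv"] = draft_kv


buggy_slot = {"kv": "target-kv"}
buggy = FakeAttentionModule()
setup_kv_sharing(buggy, "target.layers.5", propagate_to_impl=False)
draft_forward(buggy, buggy_slot, "draft-kv")   # overwrites the target's KV

fixed_slot = {"kv": "target-kv"}
fixed = FakeAttentionModule()
setup_kv_sharing(fixed, "target.layers.5", propagate_to_impl=True)
draft_forward(fixed, fixed_slot, "draft-kv")   # write is skipped
```

In the buggy setup the module reports the right target name, yet the slot still ends up holding the draft K/V, which is why the symptom only showed at the verify pass.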

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
The E2B/E4B models (and their MTP assistants) exercise gemma4 efficiency
features the 31B did not, all hitting the same root cause: vLLM's
AutoWeightsLoader cannot place plain register_buffer tensors.

- gemma4_mm.py: Gemma4ClippableLinear (use_clipped_linears=True, in BOTH
  the vision and audio encoders) registers input_min/input_max/output_min/
  output_max activation clamps as buffers. Generalize the existing vision
  std_bias/std_scale buffer-load to walk every persistent buffer in both
  towers, so the audio tower loads too.
- gemma4_mtp.py: assistants with use_ordered_embeddings=True (E2B/E4B
  drafters) carry masked_embedding.token_ordering (token->centroid map) as
  a buffer. Load masked_embedding.* buffers before AutoWeightsLoader.

Verified on 2xV100: gemma-4-E2B-it and gemma-4-E4B-it both serve as full
multimodal targets with their matched MTP assistants, clean output, draft
acceptance 62-76%.

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
Plumb a `window` param (attended-token count; -1 = unlimited) through the
decode-paged, prefill-paged, and dense forward Volta kernels + relax the
backend gate so causal sliding-window (right==0) layers run on flash instead
of falling back to Triton.

- flash_decode_paged.cu: per-token window mask (warp-uniform skip of the dot)
  + whole-partition early-out that writes neutral stats so the cross-partition
  reduce stays correct.
- fused_mha_forward_paged.cu + fused_mha_forward.cu: window term in the causal
  mask (global_q_pos - global_n < window); relax dense kernel's
  window_size_left==-1 TORCH_CHECK + the python guard in flash_attn_interface.
- flash_attn_v100.py backend: gate accepts (-1,-1) or right==0 windows; add
  _flash_window = sliding_window[0]+1; pass to all paged/dense flash calls.
- test_window.py: standalone fp32-reference parity (decode + paged/dense
  prefill, full & windowed, D=128/256, edge cases) — all pass, err ~1e-3.
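A pure-Python reference for the mask semantics listed above, intended only as a sketch. The function names are hypothetical, and treating a negative left bound with right==0 as "unlimited" is an assumption the commit message does not spell out.

```python
def flash_window(sliding_window):
    """Convert a (left, right) window into an attended-token count; -1 = unlimited."""
    left, right = sliding_window
    if (left, right) == (-1, -1):
        return -1
    if right != 0:
        raise ValueError("only causal windows (right == 0) are accepted")
    # Assumption: a negative left bound with right == 0 means plain causal.
    return -1 if left < 0 else left + 1


def is_attended(q_pos, k_pos, window):
    """Causal mask with the window term: q_pos - k_pos < window."""
    if k_pos > q_pos:
        return False
    return window < 0 or (q_pos - k_pos) < window


window = flash_window((3, 0))
visible = [k for k in range(10) if is_attended(9, k, window)]
```

With a left bound of 3 the query at position 9 sees positions 6 through 9, that is, the current token plus the three before it, which is why the attended-token count is `sliding_window[0] + 1`.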

Validated correct in-model (coherent output, sliding->flash / global->triton
auto-route). NOTE: benchmarks show flash-auto currently LOSES to forced-Triton
at long context (dead-partition launch overhead masks the window work-reduction).
Kept available behind --attention-backend; see memory v100-vllm-own-fork.

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
Pure-FA (all 60 gemma layers on FLASH_ATTN_V100) — decode now beats Triton at
every context, verified correct on long-context retrieval.

- flash_decode_paged.cu: add head_dim 512 case. Decode kernel is head-dim-
  generic (dot_qk_cache<D> loops, no WMMA), so 512 is just the switch entry;
  q_shared[512]=1KB fits smem at M=1. Lets the global layers' DECODE run on FA.
- flash_attn_v100.py: supports_head_size/get_supported_head_sizes += 512;
  forward() gates 512 PREFILL to Triton (flash_prefill_ok=head_size<=256) since
  the prefill kernels cap at 256 (smem); decode + small-query verifier stay FA.
- fused_mha_forward{,_paged}.cu: sliding-window lower-bound tile skip — skip
  key-tiles entirely older than `window` (per-element mask already guarantees
  correctness; this just stops computing fully-masked tiles). Sliding prefill
  O(N^2) -> O(N*W).
- test_window.py: + D=512 decode cases (full + windowed, incl S=8192). All pass.
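The lower-bound tile skip can be sketched as index arithmetic. This is a simplified model with illustrative tile sizes (64×64), not the kernels' actual block configuration, and `key_tiles_to_visit` is a hypothetical name.

```python
def key_tiles_to_visit(q_start, block_m, block_n, window):
    """Key-tile indices a query tile must process under a causal sliding window."""
    q_end = q_start + block_m - 1
    last_tile = q_end // block_n                 # causal upper bound
    if window < 0:
        first_tile = 0                           # unlimited window
    else:
        # Oldest key that the first query row of the tile can still see.
        oldest_key = max(0, q_start - window + 1)
        first_tile = oldest_key // block_n       # older tiles are fully masked
    return range(first_tile, last_tile + 1)


def total_tiles(seq_len, block_m, block_n, window):
    return sum(
        len(key_tiles_to_visit(q_start, block_m, block_n, window))
        for q_start in range(0, seq_len, block_m)
    )


full = total_tiles(8192, 64, 64, window=-1)
windowed = total_tiles(8192, 64, 64, window=1024)
```

For these illustrative sizes the windowed count grows linearly with sequence length once the window is full, while the full-attention count grows quadratically; the per-element mask is still needed inside the first and last visited tiles.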

Measured (c1, decode TPOT): Triton 48/116/181/310ms @512/2k/4k/8k -> pure-FA
22/24/24/26ms FLAT (~12x at 8k). Short throughput c1 39.7 / c4 102.9 tok/s
(was 17.65/54.65). Long-ctx retrieval correct (fact at pos 0 of 3359 toks).
Remaining bottleneck: global-512 PREFILL on Triton dominates TTFT (needs
split-D prefill-512). See memory v100-vllm-own-fork.

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
Move the gemma global layers' PREFILL onto FA too. The WMMA prefill body
already accumulates QK over D in WMMA_K chunks and loops PV over D, so it's
head-dim-generic; the only blocker for 512 was fitting the wider Q/K/V/O tiles
in 96KB smem. Add a Config<512> with BLOCK_M=16/BLOCK_N=32/WARPS=16 (~84KB),
which the existing body uses unchanged -- no split-D kernel rewrite needed.

- fused_mha_forward{,_paged}.cu + flash_v100_traits.cuh: BLOCK_M/N_512 consts,
  D==512 in the config/traits ternaries, dispatch case 512, relax D<=256 check.
  WARPS_512=16 keeps THREADS_PER_BLOCK=512 consistent with the traits-based
  paged KV loader (which assumes 16 warps).
- flash_attn_v100.py: flash_prefill_ok = head_size<=512 (512 prefill now FA).
- test_window.py: + D=512 prefill (paged + dense, full + windowed). All pass ~1e-3.

Now ALL 60 layers (decode AND prefill) run on FLASH_ATTN_V100 -- zero Triton in
the attention path. Measured (c1) vs original Triton baseline:
  TTFT 8192: 63s -> 8.3s (7.6x);  decode TPOT 8192: 310ms -> 25ms (12x, flat);
  throughput 8192: 1.25 -> 11.2 tok/s (9x); short c1 39.7 / c4 102.9 tok/s.
Correctness: test_window all-pass + long-ctx retrieval (fact@pos0 of 3359 toks).

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
The paged prefill kernel (fused_mha_forward_paged.cu) copies the
per-sequence block table into shared memory, so its smem is
TOTAL_SMEM[head_dim] + align128(max_num_blocks * 4), where max_num_blocks
= block_table width = ceil(max_model_len / page_block_size). The per-D
base is already ~84-96KB, so at long max_model_len the block table alone
pushes total smem past V100's 96KB ceiling and the kernel's TORCH_CHECK
aborts the worker, killing the server. head_dim 256 at 177k ctx needs
138368 bytes (11088 blocks); it only stayed hidden because no-prefix
prefill uses the dense kernel and short-context servers (e.g. 11k) fit.

Add _paged_prefill_smem_fits() mirroring the kernel's smem formula and
gate the paged prefill call in _flash_v100_prefill_with_prefix on it.
When it does not fit, fall through to the existing gather + dense path,
which is smem-safe at any context length and still fully on FA (no Triton
fallback). Verified: 31b @ 177k survives an 8.4k chunked-prefill prompt
(previously crashed); e2b @ 11k still uses the paged kernel.
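The smem formula quoted above can be checked numerically. In this sketch the function names are hypothetical, the page block size of 16 is an assumption, and the 93,952-byte base for head_dim 256 is inferred by subtracting the aligned block-table size from the quoted 138,368-byte total rather than read from the kernel.

```python
import math

V100_SMEM_LIMIT = 96 * 1024  # bytes


def align128(n: int) -> int:
    return (n + 127) // 128 * 128


def paged_prefill_smem_bytes(base_smem, max_model_len, page_block_size):
    """base + align128(max_num_blocks * 4), following the commit's formula."""
    max_num_blocks = math.ceil(max_model_len / page_block_size)
    return base_smem + align128(max_num_blocks * 4)


def paged_prefill_smem_fits(base_smem, max_model_len, page_block_size):
    total = paged_prefill_smem_bytes(base_smem, max_model_len, page_block_size)
    return total <= V100_SMEM_LIMIT


BASE_256 = 138368 - align128(11088 * 4)  # inferred, not taken from the source
long_ctx = paged_prefill_smem_bytes(BASE_256, 11088 * 16, 16)
long_fits = paged_prefill_smem_fits(BASE_256, 11088 * 16, 16)
short_fits = paged_prefill_smem_fits(BASE_256, 11264, 16)
```

Under these assumptions the 11,088-block table alone adds about 43 KB, so the long-context case exceeds the limit while an 11k-token server stays within it.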

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>
Gemma E2B/E4B use KV-cache sharing (num_kv_shared_layers): the last N
decoder layers reuse an earlier layer's KV instead of projecting their
own. Per gemma4.py, a shared layer applies RoPE to Q only and passes
raw, un-normed/un-RoPE'd K/V to Attention, relying on it reading the
TARGET layer's cache via kv_sharing_target_layer_name.

The FA_V100 no-prefix prefill path consumed the passed K/V directly, so
shared layers attended to junk -> coherent-but-wrong output (e2b:
"capital of France" -> "Hanoi"; raw template -> degenerate
<start_of_turn> loop). Decode was already correct (the paged decode
kernel reads kv_cache directly), and the kernel math was fine (parity
passes at num_kv_heads=1) -- the bug was purely no-prefix prefill.

Route a shared layer's no-prefix prefill through the prefix path, which
reads the TARGET cache (aliased into kv_cache, already written by the
earlier layer this pass) via the paged kernel when smem-safe, else
gather+dense. 31b is unaffected (num_kv_shared_layers=0). Verified: e2b
now runs fully on FA with correct output.
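The resulting routing can be summarised as a small decision function. This is a sketch of the behaviour described in the commit message, with hypothetical names and string labels standing in for the actual kernel calls.

```python
def choose_prefill_path(has_prefix, is_kv_shared_layer, paged_smem_fits):
    """Pick the FA prefill path for one layer."""
    if not has_prefix and not is_kv_shared_layer:
        # The layer's own freshly projected K/V are valid: use them directly.
        return "dense"
    # Prefix path: K/V must come from the (possibly shared) cache.
    return "paged" if paged_smem_fits else "gather+dense"


regular = choose_prefill_path(False, False, True)
shared = choose_prefill_path(False, True, True)
shared_long_ctx = choose_prefill_path(False, True, False)
```

A KV-shared layer never takes the direct dense route, because the K/V it is handed are raw placeholders rather than the target layer's cached values.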

Co-Authored-By: RivetOS Claude <noreply@anthropic.com>