From e3abbbf3b88953916408ba91532ccab64c4e527a Mon Sep 17 00:00:00 2001 From: rivetphilbot Date: Tue, 19 May 2026 01:29:19 +0000 Subject: [PATCH 01/25] [SM70] Add V100 dense WNA16 TurboMind linear kernel Add SM70TurboMindLinearKernel, an MPLinearKernel implementation that routes compressed-tensors / AWQ WNA16 dense GEMMs through the bundled TurboMind sm70_884_4 INT4 path. V100 (CC 7.0) has only first-gen FP16 WMMA cores and no Turing INT4 tensor-core GEMM, so the stock CUTLASS / Machete kernels are unavailable; this kernel gives dense WNA16 layers a working code path on SM70. Register it at the head of the CUDA _POSSIBLE_KERNELS priority list so it is preferred when running on V100; on newer architectures the existing kernels still win their min-capability checks. --- .../kernels/mixed_precision/__init__.py | 4 + .../kernels/mixed_precision/sm70_turbomind.py | 193 ++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/sm70_turbomind.py diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index ceac387c71..20d0755d78 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -20,6 +20,9 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 ExllamaLinearKernel, ) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.sm70_turbomind import ( # noqa: E501 + SM70TurboMindLinearKernel, +) from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 MacheteLinearKernel, ) @@ -38,6 +41,7 @@ # in priority/performance order (when available) _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = { PlatformEnum.CUDA: [ + SM70TurboMindLinearKernel, CutlassW4A8LinearKernel, MacheteLinearKernel, AllSparkLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/sm70_turbomind.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/sm70_turbomind.py new file mode 100644 index 0000000000..85229bf639 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/sm70_turbomind.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 +# SM70 (V100) WNA16 kernel for compressed-tensors and legacy-AWQ formats. +# +# Reads pack-quantized weights at load time, transcodes to legacy AWQ pack +# layout (interleave order [0,2,4,6,1,3,5,7] along the output dim), then +# dispatches through the existing 1Cat TurboMind s884h kernels via +# awq_sm70_prepare / awq_gemm_sm70. +# +# Why: on SM70 (V100) Cutlass/Machete (CC>=90), AllSpark/Conch (>=80), +# Marlin (>=75) all reject. Exllama runs at CC>=60 but only accepts +# uint4b8 / uint8b128 scalar types. cyankiwi's Qwen3.6-27B quants are +# compressed-tensors uint4 (asymmetric) — nothing currently picks them. +# This kernel fills that gap. + +from __future__ import annotations + +import torch + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + +logger = init_logger(__name__) + +# AWQ unpack uses GATHER with awq_order [0,4,1,5,2,6,3,7]; the inverse +# permutation [0,2,4,6,1,3,5,7] is what packing must use for round-trip +# correctness. Same constant as the MoE SM70 transcoder uses. +_AWQ_PACK_ORDER = (0, 2, 4, 6, 1, 3, 5, 7) + + +def _awq_pack_last_dim(unpacked: torch.Tensor) -> torch.Tensor: + """[..., X, Y] uint8 -> [..., X, Y/8] int32 with AWQ interleave.""" + *prefix, x, y = unpacked.shape + assert y % 8 == 0 + grouped = unpacked.view(*prefix, x, y // 8, 8) + res = grouped[..., _AWQ_PACK_ORDER[7]].to(torch.int32) + for i in range(6, -1, -1): + res = (res << 4) | grouped[..., _AWQ_PACK_ORDER[i]].to(torch.int32) + return res + + +def _ct_qweight_to_awq(ct_q: torch.Tensor) -> torch.Tensor: + """CT [N, K/8] (sequential pack along K) -> AWQ [K, N/8] (interleave pack along N).""" + n, k_div_8 = ct_q.shape + k = k_div_8 * 8 + unpacked = torch.empty(n, k, dtype=torch.uint8, device=ct_q.device) + tmp = ct_q.clone() + for i in range(8): + unpacked[:, i::8] = (tmp & 0xF).to(torch.uint8) + tmp = tmp >> 4 + return _awq_pack_last_dim(unpacked.t().contiguous()) + + +def _ct_qzeros_to_awq(ct_zp: torch.Tensor) -> torch.Tensor: + """CT [N/8, K/gs] (pack along N) -> AWQ [K/gs, N/8] (interleave pack along N).""" + n_div_8, k_gs = ct_zp.shape + n = n_div_8 * 8 + unpacked = torch.empty(n, k_gs, dtype=torch.uint8, device=ct_zp.device) + tmp = ct_zp.clone() + for i in range(8): + unpacked[i::8, :] = (tmp & 0xF).to(torch.uint8) + tmp = tmp >> 4 + return _awq_pack_last_dim(unpacked.t().contiguous()) + + +class SM70TurboMindLinearKernel(MPLinearKernel): + """V100 dense WNA16 kernel: CT/legacy pack-quant -> TurboMind s884h GEMM.""" + + SUPPORTED_QUANT_TYPES = [scalar_types.uint4, scalar_types.uint4b8] + SUPPORTED_GROUP_SIZES = (32, 64, 128) + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: + if not current_platform.is_cuda_alike(): + return False, 'SM70TurboMind requires CUDA' + if c.act_type != torch.float16: + return False, 'SM70TurboMind requires float16 activations' + if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: + return (False, + f'SM70TurboMind: weight type {c.weight_type} not supported' + f' (need {cls.SUPPORTED_QUANT_TYPES})') + if c.group_size not in cls.SUPPORTED_GROUP_SIZES: + return (False, + f'SM70TurboMind: group_size={c.group_size} not in ' + f'{cls.SUPPORTED_GROUP_SIZES}') + k_part, n_part = c.partition_weight_shape + if k_part % 8 != 0 or n_part % 8 != 0: + return False, 'SM70TurboMind: K and N must be multiples of 8' + if k_part % c.group_size != 0: + return (False, + f'SM70TurboMind: K={k_part} not divisible by ' + f'group_size={c.group_size}') + if c.has_g_idx: + return False, 'SM70TurboMind: act-reorder (g_idx) not supported' + if not hasattr(torch.ops._C, 'awq_sm70_prepare'): + return False, 'SM70TurboMind: awq_sm70_prepare op missing' + # Only run on actual SM70 hardware - SM75+ should pick a faster kernel + cap = current_platform.get_device_capability() + if cap is None or not (cap[0] == 7 and cap[1] == 0): + return False, 'SM70TurboMind: only used on V100 (CC 7.0)' + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + c = self.config + k_part, n_part = c.partition_weight_shape + gs = c.group_size + + ct_q = getattr(layer, self.w_q_name).data + ct_s = getattr(layer, self.w_s_name).data + ct_zp = (getattr(layer, self.w_zp_name).data + if c.zero_points and self.w_zp_name else None) + device = ct_q.device + + assert ct_q.shape == (n_part, k_part // 8), ( + f'SM70TurboMind: expected CT qweight [{n_part}, {k_part//8}], ' + f'got {tuple(ct_q.shape)}') + assert ct_s.shape[0] == n_part, ( + f'SM70TurboMind: expected scale dim 0 == {n_part}, ' + f'got {tuple(ct_s.shape)}') + k_gs = ct_s.shape[1] + + logger.info( + 'SM70TurboMind: layer %s K=%d N=%d gs=%d K/gs=%d asym=%s', + getattr(layer, '_layer_name', '?'), k_part, n_part, gs, k_gs, + c.zero_points) + + awq_q = _ct_qweight_to_awq(ct_q) + awq_s = ct_s.t().contiguous().to(torch.float16) + + if c.zero_points and ct_zp is not None: + assert ct_zp.shape == (n_part // 8, k_gs), ( + f'SM70TurboMind: expected CT qzeros [{n_part//8}, {k_gs}], ' + f'got {tuple(ct_zp.shape)}') + awq_zp = _ct_qzeros_to_awq(ct_zp) + else: + zp_val = torch.tensor( + [0x88888888], dtype=torch.uint32).view(torch.int32).item() + awq_zp = torch.full((k_gs, n_part // 8), zp_val, + dtype=torch.int32, device=device) + + tm_w, tm_s, meta = ops.awq_sm70_prepare(awq_q, awq_s, awq_zp, gs) + + layer._awq_sm70_weight = torch.nn.Parameter(tm_w, requires_grad=False) + layer._awq_sm70_scales = torch.nn.Parameter(tm_s, requires_grad=False) + meta0 = meta[0].item() if torch.is_tensor(meta[0]) else meta[0] + meta1 = meta[1].item() if torch.is_tensor(meta[1]) else meta[1] + layer._awq_sm70_k_ld = int(meta0) + layer._awq_sm70_q_ld = int(meta1) + layer._awq_sm70_group_size = gs + layer._awq_sm70_prepared = True + + empty_i32 = torch.empty(0, dtype=torch.int32, device=device) + empty_fp16 = torch.empty(0, dtype=torch.float16, device=device) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(empty_i32, requires_grad=False)) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(empty_fp16, requires_grad=False)) + if c.zero_points and self.w_zp_name: + replace_parameter( + layer, self.w_zp_name, + torch.nn.Parameter(empty_i32, requires_grad=False)) + + del awq_q, awq_s, awq_zp + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + out_shape = x.shape[:-1] + (layer._awq_sm70_weight.shape[-1] * 8,) + x_2d = x.reshape(-1, x.shape[-1]) + out = ops.awq_gemm_sm70( + x_2d, + layer._awq_sm70_weight, + layer._awq_sm70_scales, + layer._awq_sm70_group_size, + layer._awq_sm70_k_ld, + layer._awq_sm70_q_ld, + ) + if bias is not None: + out.add_(bias) + return out.reshape(out_shape) From 788812f5667016264af0dc8f898ba05f4502335a Mon Sep 17 00:00:00 2001 From: rivetphilbot Date: Tue, 19 May 2026 01:29:33 +0000 Subject: [PATCH 02/25] [SM70] Admit V100 (CC 7.0) in CompressedTensorsWNA16 CompressedTensorsWNA16.get_min_capability hard-coded 75, so loading a compressed-tensors WNA16 model on a V100 failed with 'Failed to find a kernel that can implement the WNA16 linear layer' before the new SM70TurboMindLinearKernel ever got a chance to bid. Lower the reported minimum to 70 specifically when running on an SM70 device (CC 7.0). Older pre-Turing GPUs (sm_60/61/62) still get 75 and remain correctly rejected, since only V100 has the FP16 WMMA path the TurboMind kernel relies on. --- .../schemes/compressed_tensors_wNa16.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 78b1c101c8..e37671c86b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -79,7 +79,13 @@ def __init__( @classmethod def get_min_capability(cls) -> int: - # Turing and up + # SM70 (V100) is supported via SM70TurboMindLinearKernel + # (CT pack -> AWQ pack -> awq_sm70_prepare/awq_gemm_sm70). + # All other pre-Turing GPUs (sm_60, sm_61, sm_62) are unsupported. + from vllm.platforms import current_platform + cap = current_platform.get_device_capability() + if cap is not None and cap[0] == 7 and cap[1] == 0: + return 70 return 75 def create_weights( From 1b6a4cdc9334ea7dd8fa59e8ac9a43b4f550e045 Mon Sep 17 00:00:00 2001 From: rivetphilbot Date: Tue, 19 May 2026 01:29:47 +0000 Subject: [PATCH 03/25] [SM70] Wire compressed-tensors MoE decode buffers for V100 CompressedTensorsSM70WNA16MoEMethod delegates its decode to AWQSM70MoEMethod.apply(), but only allocated a subset of the buffers that path reads, so a CT-quantized MoE model crashed on the first decode step on V100. - Allocate the full buffer set AWQSM70MoEMethod.process_weights_after_ loading creates: gate/up and permutation scratch, sorted-output and m-index buffers, int64 expert offsets, and the single-token batched pointer buffers. - Publish sm70_hidden/intermediate logical+aligned sizes (CT weights are already in TurboMind layout, so logical == aligned). - Build per-expert StridedPtr row views and record sm70_ptr_row_bytes for the batched GEMM path. - Pass interleave_gated_silu=True to awq_sm70_prepare so the fused gate/up weights match the decode kernel's expectation. Also switch the import to _DEFAULT_PERSISTENT_MAX_TOKENS; awq_sm70_moe renamed _DEFAULT_MAX_TOKENS, leaving the old name a dangling import. --- .../compressed_tensors_moe.py | 77 +++++++++++++++++-- 1 file changed, 69 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 52624288e9..0a245fc51f 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -2176,7 +2176,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: """Convert CT weights to AWQ format, then run TurboMind prepare.""" from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.awq_sm70_moe import ( - _DEFAULT_MAX_TOKENS, + _DEFAULT_PERSISTENT_MAX_TOKENS, ) gs = self.group_size @@ -2213,7 +2213,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: for e in range(num_experts): r13 = ops.awq_sm70_prepare( - w13_qweight[e], w13_scales[e], w13_qzeros[e], gs) + w13_qweight[e], w13_scales[e], w13_qzeros[e], gs, + interleave_gated_silu=True) w13_tm_w.append(r13[0]) w13_tm_s.append(r13[1]) w13_meta.append(r13[2]) @@ -2250,6 +2251,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: intermediate_size = layer.sm70_w2_k_dim layer.sm70_intermediate_size = intermediate_size + # CT weights are pre-aligned to TurboMind layout, so logical == aligned. + # AWQSM70MoEMethod.apply() reads these attrs during decode. + hidden_size = layer.sm70_w13_k_dim + layer.sm70_hidden_logical_size = hidden_size + layer.sm70_hidden_aligned_size = hidden_size + layer.sm70_intermediate_logical_size = intermediate_size + layer.sm70_intermediate_aligned_size = intermediate_size + # --- Build StridedPtr arrays for batched GEMM --- w13_k_ld, w13_q_ld = layer.w13_meta_list[0] w2_k_ld, w2_q_ld = layer.w2_meta_list[0] @@ -2268,31 +2277,83 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w2_ptrs[0], requires_grad=False) layer.w2_strided_ptrs_s = torch.nn.Parameter( w2_ptrs[1], requires_grad=False) + layer.w13_strided_ptrs_w_rows = layer.w13_strided_ptrs_w.view( + num_experts, -1) + layer.w13_strided_ptrs_s_rows = layer.w13_strided_ptrs_s.view( + num_experts, -1) + layer.w2_strided_ptrs_w_rows = layer.w2_strided_ptrs_w.view( + num_experts, -1) + layer.w2_strided_ptrs_s_rows = layer.w2_strided_ptrs_s.view( + num_experts, -1) + layer.sm70_ptr_row_bytes = layer.w13_strided_ptrs_w_rows.shape[1] layer.sm70_batched_ready = True logger.info("SM70 CT MoE: batched GEMM enabled (%d experts)", num_experts) except Exception as e: layer.sm70_batched_ready = False + layer.sm70_ptr_row_bytes = 0 logger.warning("SM70 CT MoE: batched GEMM unavailable (%s)", e) # --- Pre-allocate buffers (CUDA graph safe) --- + # Mirrors AWQSM70MoEMethod.process_weights_after_loading so the + # delegated AWQSM70MoEMethod.apply() path finds every buffer it needs. + hidden_size = layer.sm70_w13_k_dim top_k = self.moe.experts_per_token - max_slots = _DEFAULT_MAX_TOKENS * top_k + persistent_tokens = _DEFAULT_PERSISTENT_MAX_TOKENS + max_slots = persistent_tokens * top_k + layer._buf_max_tokens = persistent_tokens layer._buf_max_slots = max_slots layer._buf_top_k = top_k - layer._buf_expert_counts = torch.zeros( + layer._buf_expert_counts = torch.empty( num_experts, dtype=torch.int32, device=device) - layer._buf_expert_offsets = torch.zeros( + layer._buf_expert_offsets = torch.empty( num_experts + 1, dtype=torch.int32, device=device) + layer._buf_expert_offsets64 = torch.empty( + num_experts + 1, dtype=torch.int64, device=device) + layer._buf_gate_up = torch.empty( + max_slots, layer.sm70_w13_n_dim, + dtype=torch.float16, device=device) layer._buf_intermediate = torch.empty( max_slots, intermediate_size, dtype=torch.float16, device=device) - layer._buf_ones = torch.ones( + layer._buf_permuted_input = torch.empty( + max_slots, hidden_size, dtype=torch.float16, device=device) + layer._buf_sorted_output = torch.empty( + max_slots, hidden_size, dtype=torch.float16, device=device) + layer._buf_inv_permuted_idx = torch.empty( + persistent_tokens, top_k, dtype=torch.int32, device=device) + layer._buf_topk_ids_i32 = torch.empty( + persistent_tokens, top_k, dtype=torch.int32, device=device) + layer._buf_token_expert_indices = torch.arange( + max_slots, dtype=torch.int32, device=device).view( + persistent_tokens, top_k) + layer._buf_permuted_idx = torch.empty( + max_slots, dtype=torch.int32, device=device) + layer._buf_m_indices = torch.empty( max_slots, dtype=torch.int32, device=device) - hidden_size = layer.sm70_w13_k_dim layer._buf_output = torch.empty( - _DEFAULT_MAX_TOKENS, hidden_size, + persistent_tokens, hidden_size, dtype=torch.float16, device=device) + layer._buf_ones = torch.ones( + max_slots, dtype=torch.int32, device=device) + if layer.sm70_batched_ready: + ptr_row = layer.sm70_ptr_row_bytes + layer._buf_single_topk_ids_i64 = torch.empty( + top_k, dtype=torch.int64, device=device) + layer._buf_single_w13_ptrs_w = torch.empty( + top_k, ptr_row, dtype=torch.uint8, device=device) + layer._buf_single_w13_ptrs_s = torch.empty( + top_k, ptr_row, dtype=torch.uint8, device=device) + layer._buf_single_w2_ptrs_w = torch.empty( + top_k, ptr_row, dtype=torch.uint8, device=device) + layer._buf_single_w2_ptrs_s = torch.empty( + top_k, ptr_row, dtype=torch.uint8, device=device) + layer._buf_single_expert_offsets = torch.arange( + top_k + 1, dtype=torch.int32, device=device) + layer._buf_single_expert_offsets64 = torch.arange( + top_k + 1, dtype=torch.int64, device=device) + layer._buf_single_inv_permuted_idx = torch.arange( + top_k, dtype=torch.int32, device=device).view(1, top_k) # Free original CT weights for attr in ("w13_weight_packed", "w13_weight_scale", From bd6d0c8e56dcb78c574f23179072ddb97c981c13 Mon Sep 17 00:00:00 2001 From: rivetphilbot Date: Tue, 19 May 2026 01:30:41 +0000 Subject: [PATCH 04/25] [Qwen3] Keep router gate and split GDN projections unquantized under compressed-tensors Two related fixes for running Qwen3.5/3.6 compressed-tensors checkpoints: - Qwen3NextSparseMoeBlock: the MoE router gate is stored as bf16 in the checkpoint and has no quantized form. Passing the model quant_config to its ReplicatedLinear made the loader expect quantized weights; force quant_config=None so the gate stays bf16. - _uses_split_gdn_input_projections only inspected modules_to_not_convert and ignored_layers. Compressed-tensors records its skip list under the ignore attribute, so the BF16 in_proj_a / in_proj_b GDN projections of a CT checkpoint were not detected and the split-projection layout was not selected. Consult quant_config.ignore as a final fallback. --- vllm/model_executor/models/qwen3_5.py | 4 ++++ vllm/model_executor/models/qwen3_next.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py index bb05455d22..cd75b5bbc7 100644 --- a/vllm/model_executor/models/qwen3_5.py +++ b/vllm/model_executor/models/qwen3_5.py @@ -136,6 +136,10 @@ def _uses_split_gdn_input_projections( modules_to_not_convert = getattr(quant_config, "modules_to_not_convert", None) if modules_to_not_convert is None: modules_to_not_convert = getattr(quant_config, "ignored_layers", None) + # Compressed-tensors exposes its skip list as ``ignore`` rather than + # ``modules_to_not_convert`` / ``ignored_layers``. + if modules_to_not_convert is None: + modules_to_not_convert = getattr(quant_config, "ignore", None) if not modules_to_not_convert: return False return any( diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 6fc97d738f..31e217d05b 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -203,7 +203,9 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size, config.num_experts, bias=False, - quant_config=quant_config, + # Router gate is stored as bf16 in the checkpoint; it must stay + # unquantized even when the model uses compressed-tensors / AWQ. + quant_config=None, prefix=f"{prefix}.gate", ) From 4d13a60e96439856b66ec628a1d733f2ad2707df Mon Sep 17 00:00:00 2001 From: rivetphilbot Date: Tue, 19 May 2026 01:31:03 +0000 Subject: [PATCH 05/25] [Qwen3.5] Skip tuple-shard split for non-output-dim CT params CompressedTensorsWNA16 creates auxiliary parameters -- weight_shape (BasevLLMParameter) and weight_g_idx (RowvLLMParameter) -- that hold metadata or input-dim-sharded indices rather than output-dim weight data, so they have no output_dim attribute. When the qkvz stacked-load mapping in Qwen3_5Model.load_weights reached one of these via the tuple-shard path, it hit AttributeError on param.output_dim. Fix: when output_dim is absent, load through the standard non-shard weight_loader (last-write-wins for replicated metadata) and break out of the sub-id loop. The companion debug log now formats output_dim with %s / getattr default so it tolerates the missing attribute too. --- vllm/model_executor/models/qwen3_5.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py index cd75b5bbc7..615cce65e0 100644 --- a/vllm/model_executor/models/qwen3_5.py +++ b/vllm/model_executor/models/qwen3_5.py @@ -548,6 +548,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: param = params_dict[name] weight_loader = param.weight_loader if isinstance(shard_id, tuple): + # Auxiliary compressed-tensors params (weight_shape: + # BasevLLMParameter, weight_g_idx: RowvLLMParameter, ...) + # have no output_dim and are not output-sharded. Load them + # via the standard non-shard weight_loader and bail out of + # the sub-id loop. + if not hasattr(param, "output_dim"): + weight_loader(param, loaded_weight) + break # Split by the target module's output shard metadata # instead of hardcoding tensor shapes. owner = getattr(weight_loader, "__self__", None) @@ -565,12 +573,12 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ) logger.debug( "Tuple shard load for %s: shard_ids=%s output_sizes=%s " - "weight_shape=%s output_dim=%d", + "weight_shape=%s output_dim=%s", name, shard_id, output_sizes, tuple(loaded_weight.shape), - param.output_dim, + getattr(param, "output_dim", "?"), ) for sub_id in shard_id: shard_offset = sum(output_sizes[:sub_id]) From d4f98f3b1adc005216a29f34e6c541921c13b6ee Mon Sep 17 00:00:00 2001 From: rivetphilbot Date: Tue, 19 May 2026 01:31:15 +0000 Subject: [PATCH 06/25] [SM70] Sync sm70_884_4.cu kernel registry to lmdeploy main (gs32) Replace the fork's vendored sm70_884_4.cu tile registry (lmdeploy v0.12.1) with the upstream lmdeploy main version (commit e5fbd4da, from PR #4429 'fully implement compressed-tensors gs32 support'). mainloop_sm70.h, iterator_sm70.h and scheduler_sm70.cuh are byte identical between the two snapshots -- only the Registry::sm70_884_4 tile-config list changed. - Add a Config_U4_d block with 21 gs32 tiles (the fork carried none for this layout). - Expand the Config_U4_g gs32 block from 6 to 17 tiles. - Drop the gs64 block; both deployed quants (qwen3.6-27b-int4, granite-4.1-8b-awq-int4) are gs32. Decode on V100 TP=2 is ~83% turbomind::gemm; the autotuner had no gs32 candidates in the most common kColMajor layout, forcing fallback tiles. Net diff +45 -11. Requires a full _C extension rebuild since kernel registration is statically linked. --- .../kernels/gemm/kernel/sm70_884_4.cu | 56 +++++++++++++++---- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/lmdeploy/src/turbomind/kernels/gemm/kernel/sm70_884_4.cu b/lmdeploy/src/turbomind/kernels/gemm/kernel/sm70_884_4.cu index 71682ef0c5..690aa331b6 100644 --- a/lmdeploy/src/turbomind/kernels/gemm/kernel/sm70_884_4.cu +++ b/lmdeploy/src/turbomind/kernels/gemm/kernel/sm70_884_4.cu @@ -36,7 +36,6 @@ void Registry::sm70_884_4() if constexpr (1) { // clang-format off - // GroupSizeV=128 using C = Config_U4_g; Add>(); Add>(); @@ -47,23 +46,58 @@ void Registry::sm70_884_4() Add>(); Add>(); Add>(); - Add>(); Add>(); - Add>(); - // GroupSizeV=64 - Add>(); - Add>(); - Add>(); - Add>(); - Add>(); - Add>(); - // GroupSizeV=32 + // clang-format on + } + + if constexpr (1) { + // clang-format off + using C = Config_U4_d; + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + Add>(); + // clang-format on + } + + if constexpr (1) { + // clang-format off + using C = Config_U4_g; + Add>(); Add>(); Add>(); + Add>(); Add>(); + Add>(); + Add>(); + Add>(); Add>(); Add>(); + Add>(); + Add>(); + Add>(); + Add>(); Add>(); + Add>(); + Add>(); + Add>(); // clang-format on } From 74b0ebf2cea54748d6220526a04aba7cd93b03bd Mon Sep 17 00:00:00 2001 From: Phil Date: Fri, 5 Jun 2026 21:21:34 +0000 Subject: [PATCH 07/25] [gemma4][WIP] Add gemma-4 backbone + register Gemma4ForCausalLM / Gemma4MTPModel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Vendored upstream vllm-project/vllm model_executor/models/gemma4.py (1714 LOC) and registered Gemma4ForCausalLM (text backbone) + Gemma4MTPModel toward serving gemma-4-31B + its MTP drafter on V100/SM70. Not yet import-clean on this base: gemma4.py imports `GateLinear` from layers.fused_moe (newer-upstream symbol absent here) — first API gap to backport/adapt. transformers 5.7 (Gemma4Config) and the KV-sharing utils are already present, so the foundation is in place. Next: close the fused_moe/GateLinear gap, then add gemma4_mtp + spec-decode core adaptations (PR vllm-project/vllm#41745). Co-Authored-By: RivetOS Claude --- vllm/model_executor/models/gemma4.py | 1714 ++++++++++++++++++++++++ vllm/model_executor/models/registry.py | 2 + 2 files changed, 1716 insertions(+) create mode 100644 vllm/model_executor/models/gemma4.py diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py new file mode 100644 index 0000000000..d14a767df5 --- /dev/null +++ b/vllm/model_executor/models/gemma4.py @@ -0,0 +1,1714 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 The vLLM team. +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Gemma 4 model implementation for vLLM.""" + +from collections.abc import Iterable +from dataclasses import replace +from itertools import islice + +import regex as re +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_and_mul_fn +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.fused_moe import ( + FusedMoE, + GateLinear, + fused_moe_make_expert_params_mapping, +) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.triton_utils import tl, triton +from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata + +from .interfaces import ( + EagleModelMixin, + MixtureOfExperts, + SupportsEagle3, + SupportsLoRA, + SupportsPP, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + extract_layer_index, + is_pp_missing_parameter, + make_layers, + maybe_prefix, +) + +logger = init_logger(__name__) + + +def _remap_gemma4_expert_weight_name(name: str) -> str: + return re.sub(r"(?> 31 + key = tl.where(sign_b == 0, logit_bits ^ -1, logit_bits ^ MIN32) + key = tl.where(valid, key, 0x7FFFFFFF) + sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF + packed = (sk64 << 32) | offs_e.to(tl.int64) + sorted_p = tl.sort(packed, descending=False) + + # Vectorized extraction of ALL sorted elements — no K-loop, no cross-lane reductions + all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32) + all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32) + + # Inverse bijection: recover original logit bits + sign_k = all_keys >> 31 + all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32) + all_logits = all_bits.to(tl.float32, bitcast=True) + + # Compute raw_exp for ALL BLOCK_E elements — vectorized, ~2 VALU clocks + all_raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634) + + # Sum only top-K for renorm — ONE masked reduction + top_mask = offs_e < K + renorm_raw = tl.sum(tl.where(top_mask, all_raw_exp, 0.0), axis=0) + renorm_raw = tl.where(renorm_raw > 0.0, renorm_raw, 1.0) + inv_renorm = 1.0 / renorm_raw + + # Load scales for top-K only (masked gather; scale array is tiny → L1 cached) + all_scales = tl.load( + per_expert_scale_ptr + all_ids.to(tl.int64), + mask=top_mask, + other=1.0, + ).to(tl.float32) + + # Final weights: vectorized multiply (only top-K will be stored) + all_weights = (all_raw_exp * inv_renorm * all_scales).to(tl.float32) + + # Write results with TWO masked stores — replaces K × 2 serial scalar stores + base_off = pid * K + offs_e + tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask) + tl.store(topk_weights_ptr + base_off, all_weights, mask=top_mask) + + +def gemma4_fused_routing_kernel_triton( + gating_output: torch.Tensor, + topk: int, + per_expert_scale: torch.Tensor, + num_warps: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + gating_output = gating_output.contiguous() + per_expert_scale = per_expert_scale.contiguous() + T, E = gating_output.shape + weights = torch.empty(T, topk, dtype=torch.float32, device=gating_output.device) + ids = torch.empty(T, topk, dtype=torch.int32, device=gating_output.device) + BLOCK_E = triton.next_power_of_2(E) + _gemma4_routing_kernel[(T,)]( + gating_output, + per_expert_scale, + weights, + ids, + E, + topk, + BLOCK_E, + num_warps=num_warps, + ) + return weights, ids + + +def gemma4_routing_function_torch( + gating_output: torch.Tensor, + topk: int, + per_expert_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + _, topk_ids = torch.topk(gating_output, k=topk, dim=-1) + router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1) + indicator = torch.nn.functional.one_hot( + topk_ids, num_classes=gating_output.size(-1) + ).sum(dim=-2) + gate_weights = indicator * router_probabilities + renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True) + renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0) + dispatch_weights = gate_weights / renorm_factor + + topk_weights = dispatch_weights.gather(1, topk_ids) + + # Fold per_expert_scale into routing weights + expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype) + topk_weights = topk_weights * expert_scales + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def _get_text_config(config): + """Dereference text_config if config is a nested Gemma4Config. + + Gemma4 checkpoints use architectures=["Gemma4ForConditionalGeneration"] + which yields a Gemma4Config with nested text_config. This function + transparently returns the text config regardless of nesting. + """ + if hasattr(config, "text_config"): + return config.text_config + return config + + +class Gemma4MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_activation: str, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = get_act_and_mul_fn(hidden_activation) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma4Router(nn.Module): + """Router for Gemma4 MoE that preprocesses input before projection. + + Applies RMSNorm (no learned weight), root_size scaling + (hidden_size^{-0.5}), then a learned per-dimension scale before + projecting to expert logits. + + This preprocessing is applied ONLY to the router's input, not to + the expert MLPs' input. + """ + + def __init__( + self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + # RMSNorm without learned weight — pure normalization only + self.norm = RMSNorm(self.hidden_size, eps=config.rms_norm_eps, has_weight=False) + # Per-dimension learned scale, applied after norm + root_size + self.scale = nn.Parameter(torch.ones(self.hidden_size)) + # Constant 1/sqrt(hidden_size) scaling factor + self.register_buffer( + "root_size", + torch.tensor(self.hidden_size**-0.5), + persistent=False, + ) + # Project to expert logits; replicated across TP for consistent routing + # GateLinear supports bf16 W/A → fp32 output, which is important + # because the topk kernel often needs fp32 for stable routing. + self.proj = GateLinear( + self.hidden_size, + config.num_experts, + bias=False, + out_dtype=torch.float32, + prefix=f"{prefix}.proj", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Returns raw router logits [T, E].""" + x = self.norm(x) + x = x * self.root_size.to(x.dtype) + x = x * self.scale.to(x.dtype) + router_logits, _ = self.proj(x) + return router_logits + + +class Gemma4MoE(nn.Module): + """Mixture of Experts for Gemma4 using vLLM's FusedMoE. + + Wraps FusedMoE with custom routing. The router projection is + external (Gemma4Router) — this class only handles expert dispatch. + + Gemma4 routing: softmax over ALL experts → top-k → renormalize. + per_expert_scale is folded into routing weights for mathematical + correctness with FusedMoE's fused kernel. + """ + + def __init__( + self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.num_experts = config.num_experts + + # Per-expert output scale folded into routing weights so that + # FusedMoE's fused kernel computes: Σ_e (expert_e * w_e * scale_e) + self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts)) + + # Gemma4 routing: softmax over ALL experts → top-k → renormalize. + # FusedMoE's built-in fused_topk scopes softmax differently, so + # a custom routing function is needed for numerical correctness. + # NOTE: self.per_expert_scale is read at call time (not captured into + # a local) so that torch.func.functional_call parameter substitution + # reaches the routing function correctly. + def routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + ) -> tuple[torch.Tensor, torch.Tensor]: + if current_platform.is_cuda_alike() or current_platform.is_xpu(): + return gemma4_fused_routing_kernel_triton( + gating_output, topk, self.per_expert_scale + ) + + return gemma4_routing_function_torch( + gating_output, topk, self.per_expert_scale + ) + + # FusedMoE experts with custom Gemma4 routing + self.experts = FusedMoE( + num_experts=config.num_experts, + top_k=config.top_k_experts, + hidden_size=config.hidden_size, + intermediate_size=getattr( + config, + "moe_intermediate_size", + getattr(config, "expert_intermediate_size", None), + ), + renormalize=True, + quant_config=quant_config, + prefix=f"{prefix}.experts", + custom_routing_function=routing_function, + activation="gelu_tanh", + ) + + def forward(self, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor: + return self.experts(x, router_logits) + + +class Gemma4Attention(nn.Module): + def __init__( + self, + config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + use_k_eq_v: bool = False, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + attn_logits_soft_cap: float | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.hidden_size = hidden_size + self.use_k_eq_v = use_k_eq_v + + tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Gemma4 uses scaling=1.0. + # Unlike Gemma2/3, query_pre_attn_scalar is NOT used here; + # Q/K norms with learnable weights handle scaling implicitly. + self.scaling = 1.0 + + # QKVParallelLinear handles GQA correctly for all layer types. + # k_eq_v layers load K weights into both K and V slots via + # _weight_iterator remapping — no structural difference needed. + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # Q/K norms: output = norm(x) * weight (learnable per-head scale) + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + # V norm: no learnable scale (pure normalization only) + self.v_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, has_weight=False) + + # Determine layer type and sliding window + layer_idx = extract_layer_index(prefix) + layer_type = config.layer_types[layer_idx] + self.is_sliding = layer_type == "sliding_attention" + sliding_window = config.sliding_window if self.is_sliding else None + + # Initialize RoPE based on layer type. + # Gemma4 uses different RoPE parameters for sliding vs full attention. + if layer_type in config.rope_parameters: + # Per-layer-type rope config (dict format). + # rope_parameters already contains the correct + # partial_rotary_factor per layer type (1.0 for full + # attention, 1.0 for sliding). Do NOT override with + # global_partial_rotary_factor — that config key is + # not needed for Gemma4 — config uses per-layer rope_parameters. + rope_parameters = dict(config.rope_parameters[layer_type]) + else: + # Legacy config format fallback. + rope_parameters = dict(config.rope_parameters.copy()) + if self.is_sliding: + rope_parameters["rope_theta"] = getattr( + config, "rope_local_base_freq", 10000.0 + ) + + # KV sharing: layers in the last `num_kv_shared_layers` share KV + # cache with earlier layers of the same type. + kv_sharing_target_layer_name = None + self.is_kv_shared_layer = False + num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0) + if num_kv_shared_layers > 0: + first_kv_shared_layer_idx = config.num_hidden_layers - num_kv_shared_layers + if layer_idx >= first_kv_shared_layer_idx: + self.is_kv_shared_layer = True + # Find the last non-shared layer of the same attention type + prev_layers = config.layer_types[:first_kv_shared_layer_idx] + current_layer_type = config.layer_types[layer_idx] + kv_shared_layer_index = ( + len(prev_layers) - 1 - prev_layers[::-1].index(current_layer_type) + ) + if kv_shared_layer_index >= 0: + if ".layers." in prefix: + param_name_before_layers = prefix.split(".layers.")[0] + else: + raise ValueError( + "Unexpected prefix format for Gemma4Attention: " + f"'{prefix}'. Expected to contain '.layers.'." + ) + kv_sharing_target_layer_name = ( + f"{param_name_before_layers}.layers." + f"{kv_shared_layer_index}.self_attn.attn" + ) + + self.rotary_emb = get_rope( + self.head_dim, + max_position=max_position_embeddings, + rope_parameters=rope_parameters, + is_neox_style=True, + ) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + kv_sharing_target_layer_name=kv_sharing_target_layer_name, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + # Unified QKV path (works for both k_eq_v and standard layers). + # For k_eq_v, K weights are loaded into both K and V slots of + # qkv_proj, so V == K automatically. + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # Q norm (always applied) + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + + if not self.is_kv_shared_layer: + # Non-shared: apply K norm + RoPE, V norm + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + k = k.flatten(-2, -1) + q, k = self.rotary_emb(positions, q, k) + + v = v.unflatten(-1, (self.num_kv_heads, self.head_dim)) + v = self.v_norm(v) + v = v.flatten(-2, -1) + else: + # Shared: only apply RoPE to Q + q = self.rotary_emb(positions, q, k)[0] + + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + + return output + + +class Gemma4DecoderLayer(nn.Module): + def __init__( + self, + config, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.hidden_size_per_layer_input = getattr( + config, "hidden_size_per_layer_input", 0 + ) + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + + # Gemma4 uses different head dimensions for sliding vs full attention + layer_type = config.layer_types[layer_idx] + self.is_full_attention = layer_type == "full_attention" + if self.is_full_attention: + head_dim = getattr(config, "global_head_dim", config.head_dim) + else: + head_dim = config.head_dim + + # Determine if this full-attention layer uses k_eq_v + # (laptop variant: no v_proj, K reused as V on full attention layers) + use_k_eq_v = self.is_full_attention and getattr( + config, "attention_k_eq_v", False + ) + + # For k_eq_v full-attention layers, use num_global_key_value_heads + # as the KV head count when k_eq_v is enabled. + if use_k_eq_v: + num_kv_heads = getattr( + config, "num_global_key_value_heads", config.num_key_value_heads + ) + else: + num_kv_heads = config.num_key_value_heads + + self.self_attn = Gemma4Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + max_position_embeddings=config.max_position_embeddings, + use_k_eq_v=use_k_eq_v, + cache_config=cache_config, + quant_config=quant_config, + attn_logits_soft_cap=getattr(config, "attn_logit_softcapping", None), + prefix=f"{prefix}.self_attn", + ) + + # Compute per-layer intermediate_size from config. + # When use_double_wide_mlp is set, intermediate_size doubles for + # KV-shared layers (layers >= first_kv_shared_layer_idx). + first_kv_shared_layer_idx = config.num_hidden_layers - getattr( + config, "num_kv_shared_layers", 0 + ) + is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + use_double_wide_mlp = ( + getattr(config, "use_double_wide_mlp", False) and is_kv_shared_layer + ) + layer_intermediate_size = config.intermediate_size * ( + 2 if use_double_wide_mlp else 1 + ) + + self.mlp = Gemma4MLP( + hidden_size=self.hidden_size, + intermediate_size=layer_intermediate_size, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + + # Layer norms: output = norm(x) * weight + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + # MoE (Mixture of Experts) — router + expert block parallel to MLP + self.enable_moe_block = getattr(config, "enable_moe_block", False) or getattr( + config, "use_second_mlp_block", False + ) + if self.enable_moe_block: + self.router = Gemma4Router( + config, + quant_config=quant_config, + prefix=f"{prefix}.router", + ) + self.moe = Gemma4MoE( + config, + quant_config=quant_config, + prefix=f"{prefix}.moe", + ) + self.post_feedforward_layernorm_1 = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm_2 = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm_2 = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.router = None + self.moe = None + self.post_feedforward_layernorm_1 = None + self.post_feedforward_layernorm_2 = None + self.pre_feedforward_layernorm_2 = None + + # Per-Layer Embedding (PLE) components — present in each decoder layer + if ( + self.hidden_size_per_layer_input is not None + and self.hidden_size_per_layer_input > 0 + ): + # Gate: projects hidden_states → per-layer dim for gating + self.per_layer_input_gate = ReplicatedLinear( + self.hidden_size, + self.hidden_size_per_layer_input, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.per_layer_input_gate", + return_bias=False, + ) + # Projection: projects gated per-layer input back → hidden size + self.per_layer_projection = ReplicatedLinear( + self.hidden_size_per_layer_input, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.per_layer_projection", + return_bias=False, + ) + # Post-PLE norm: output = norm(x) * weight + self.post_per_layer_input_norm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.per_layer_input_gate = None + self.per_layer_projection = None + self.post_per_layer_input_norm = None + + # Layer scalar (loaded from checkpoint) — applies to ALL text layers + self.register_buffer("layer_scalar", torch.ones(1)) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + per_layer_input: torch.Tensor | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Gemma4 residual pattern: + # 1. input_norm(x) → attn → post_attn_norm → ADD residual + # 2. pre_ff_norm → mlp → post_ff_norm → ADD residual + residual = hidden_states + + hidden_states = self.input_layernorm(residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + **kwargs, + ) + + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = hidden_states + + # MLP runs unconditionally (same inputs for MoE and non-MoE) + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + if self.enable_moe_block: + hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states) + + # Router and MoE experts see the residual (pre-MLP state), + # matching the HF transformers forward path + router_logits = self.router(residual) + hidden_states_2 = self.pre_feedforward_layernorm_2(residual) + hidden_states_2 = self.moe(hidden_states_2, router_logits) + hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2) + + # Combine MLP and MoE outputs + hidden_states = hidden_states_1 + hidden_states_2 + + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = hidden_states + residual + + # Apply PLE (Per-Layer Embedding) if configured + if per_layer_input is not None and self.per_layer_input_gate is not None: + gate = self.per_layer_input_gate(hidden_states) + gate = torch.nn.functional.gelu(gate, approximate="tanh") + gated_per_layer = gate * per_layer_input + per_layer_contribution = self.per_layer_projection(gated_per_layer) + per_layer_contribution = self.post_per_layer_input_norm( + per_layer_contribution + ) + hidden_states = hidden_states + per_layer_contribution + + # Apply layer scalar for full-attention layers + # Apply per-layer scalar (all text layers) + hidden_states = hidden_states * self.layer_scalar + + return hidden_states, None + + +def _run_decoder_layers( + decoder_layers: list[Gemma4DecoderLayer], + layer_idx_start: int, + positions: torch.Tensor, + hidden_states: torch.Tensor, + per_layer_inputs: torch.Tensor | None = None, + **kwargs, +) -> torch.Tensor: + """Run a slice of decoder layers with PLE extraction.""" + residual = None + for idx, layer in enumerate(decoder_layers): + layer_idx = idx + layer_idx_start + layer_per_input = ( + per_layer_inputs[:, layer_idx, :] if per_layer_inputs is not None else None + ) + hidden_states, residual = layer( + positions, + hidden_states, + residual, + per_layer_input=layer_per_input, + **kwargs, + ) + return hidden_states + + +@support_torch_compile( + enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill +) +class Gemma4SelfDecoderLayers(nn.Module): + """Compiled wrapper: embedding + non-KV-shared layers (YOCO first half). + + Owns the embedding and PLE modules so they are inside the compiled + graph. Gemma4Model delegates embedding methods here. + """ + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layers: list[Gemma4DecoderLayer], + layer_idx_start: int, + embed_tokens: VocabParallelEmbedding, + normalizer: torch.Tensor, + embed_tokens_per_layer: VocabParallelEmbedding | None, + embed_scale_per_layer: torch.Tensor | None, + per_layer_model_projection: ColumnParallelLinear | None, + per_layer_projection_norm: RMSNorm | None, + per_layer_input_scale: torch.Tensor | None, + per_layer_projection_scale: torch.Tensor | None, + ): + super().__init__() + self.decoder_layers = decoder_layers + self.layer_idx_start = layer_idx_start + + config = _get_text_config(vllm_config.model_config.hf_config) + self.config = config + self.hidden_size_per_layer_input = getattr( + config, "hidden_size_per_layer_input", 0 + ) + self.vocab_size_per_layer_input = getattr( + config, "vocab_size_per_layer_input", config.vocab_size + ) + + # Shared references to modules owned by Gemma4Model — must be + # inside this nn.Module so torch.compile captures them. + self.embed_tokens = embed_tokens + self.normalizer = normalizer + self.embed_tokens_per_layer = embed_tokens_per_layer + self.embed_scale_per_layer = embed_scale_per_layer + self.per_layer_model_projection = per_layer_model_projection + self.per_layer_projection_norm = per_layer_projection_norm + self.per_layer_input_scale = per_layer_input_scale + self.per_layer_projection_scale = per_layer_projection_scale + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) * self.normalizer + + def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None: + """Get per-layer embeddings from embed_tokens_per_layer. + + Returns: + Per-layer embeddings (num_tokens, num_layers, + hidden_size_per_layer_input) + """ + if self.embed_tokens_per_layer is None: + return None + per_layer_inputs_mask = torch.logical_and( + input_ids >= 0, + input_ids < self.vocab_size_per_layer_input, + ) + per_layer_inputs_tokens = torch.where( + per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids) + ) + per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens) + per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer + return per_layer_embeds.reshape( + *input_ids.shape, + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: torch.Tensor | None, + ) -> torch.Tensor | None: + """Project inputs_embeds and combine with per_layer_inputs. + + Steps: + 1. Project inputs_embeds: hidden_size → total_ple_dim + 2. Scale by hidden_size^{-0.5} + 3. Reshape to (num_tokens, num_layers, per_layer_dim) + 4. Normalize with per_layer_projection_norm + 5. Combine: (projection + per_layer_inputs) * 1/sqrt(2) + """ + if self.per_layer_model_projection is None: + return None + per_layer_projection = self.per_layer_model_projection(inputs_embeds) + per_layer_projection = per_layer_projection * self.per_layer_projection_scale + per_layer_projection = per_layer_projection.reshape( + *inputs_embeds.shape[:-1], + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + if per_layer_inputs is None: + return per_layer_projection + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + per_layer_inputs: torch.Tensor | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if inputs_embeds is not None: + hidden_states = inputs_embeds + per_layer_inputs = self.project_per_layer_inputs( + hidden_states, per_layer_inputs + ) + else: + hidden_states = self.embed_input_ids(input_ids) + per_layer_embeds = self.get_per_layer_inputs(input_ids) + per_layer_inputs = self.project_per_layer_inputs( + hidden_states, per_layer_embeds + ) + + hidden_states = _run_decoder_layers( + self.decoder_layers, + self.layer_idx_start, + positions, + hidden_states, + per_layer_inputs, + **kwargs, + ) + return hidden_states, per_layer_inputs + + +@support_torch_compile( + enable_if=lambda vllm_config: vllm_config.cache_config.kv_sharing_fast_prefill +) +class Gemma4CrossDecoderLayers(nn.Module): + """Cross-decoder layers (YOCO second half, KV-shared).""" + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layers: list[Gemma4DecoderLayer], + layer_idx_start: int, + ): + super().__init__() + self.decoder_layers = decoder_layers + self.layer_idx_start = layer_idx_start + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + per_layer_inputs: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + return _run_decoder_layers( + self.decoder_layers, + self.layer_idx_start, + positions, + hidden_states, + per_layer_inputs, + **kwargs, + ) + + +@support_torch_compile( + enable_if=lambda vllm_config: not vllm_config.cache_config.kv_sharing_fast_prefill +) +class Gemma4Model(nn.Module, EagleModelMixin): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = _get_text_config(vllm_config.model_config.hf_config) + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + # PLE config values (default to 0 if not present — disables PLE) + self.hidden_size_per_layer_input = getattr( + config, "hidden_size_per_layer_input", 0 + ) + self.vocab_size_per_layer_input = getattr( + config, "vocab_size_per_layer_input", config.vocab_size + ) + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + + # Per-Layer Embedding (PLE) components + if ( + self.hidden_size_per_layer_input is not None + and self.hidden_size_per_layer_input > 0 + ): + total_ple_dim = self.hidden_size_per_layer_input * config.num_hidden_layers + self.embed_tokens_per_layer = VocabParallelEmbedding( + self.vocab_size_per_layer_input, + total_ple_dim, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens_per_layer", + ) + # Scaled embedding factor (from config, not hardcoded) + # Register as buffer so it moves to GPU with the model + # and interacts correctly with torch.compile AOT caching. + self.register_buffer( + "embed_scale_per_layer", + torch.tensor(self.hidden_size_per_layer_input**0.5), + persistent=False, + ) + # Projection: hidden_size → total_ple_dim + # ColumnParallelLinear with gather_output=True + self.per_layer_model_projection = ColumnParallelLinear( + config.hidden_size, + total_ple_dim, + bias=False, + gather_output=True, + return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.per_layer_model_projection", + ) + # PLE projection norm: output = norm(x) * weight + self.per_layer_projection_norm = RMSNorm( + self.hidden_size_per_layer_input, + eps=config.rms_norm_eps, + ) + # Scale factor for combining projection + per_layer_inputs + # Register as buffer so it moves to GPU with the model + # and interacts correctly with torch.compile AOT caching. + self.register_buffer( + "per_layer_input_scale", + torch.rsqrt(torch.tensor(2.0)), + persistent=False, + ) + # Scaled projection: multiply output by hidden_size**-0.5. + # Register as buffer for GPU placement and torch.compile. + self.register_buffer( + "per_layer_projection_scale", + torch.tensor(config.hidden_size**-0.5), + persistent=False, + ) + else: + self.embed_tokens_per_layer = None + self.embed_scale_per_layer = None + self.per_layer_model_projection = None + self.per_layer_projection_norm = None + self.per_layer_input_scale = None + self.per_layer_projection_scale = None + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma4DecoderLayer( + config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + # Final norm: output = norm(x) * weight + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Embedding scale = sqrt(hidden_size), cast to model dtype to avoid + # mixed-precision drift from bf16 * fp32 across deep stacks. + self.register_buffer( + "normalizer", + torch.tensor( + config.hidden_size**0.5, + dtype=vllm_config.model_config.dtype, + ), + persistent=False, + ) + + # --- You Only Cache Once (YOCO) split for fast prefill --- + first_kv_shared_layer_idx = config.num_hidden_layers - getattr( + config, "num_kv_shared_layers", 0 + ) + + from vllm.compilation.backends import set_model_tag + + # Layers 0..(K-1) are self-decoder layers in YOCO + with set_model_tag("self_decoder"): + self.self_decoder = Gemma4SelfDecoderLayers( + vllm_config=vllm_config, + prefix=f"{prefix}.self_decoder", + decoder_layers=self.layers[:first_kv_shared_layer_idx], + layer_idx_start=0, + embed_tokens=self.embed_tokens, + normalizer=self.normalizer, + embed_tokens_per_layer=getattr(self, "embed_tokens_per_layer", None), + embed_scale_per_layer=getattr(self, "embed_scale_per_layer", None), + per_layer_model_projection=getattr( + self, "per_layer_model_projection", None + ), + per_layer_projection_norm=getattr( + self, "per_layer_projection_norm", None + ), + per_layer_input_scale=getattr(self, "per_layer_input_scale", None), + per_layer_projection_scale=getattr( + self, "per_layer_projection_scale", None + ), + ) + # Layers K..(N-1) are cross-decoder layers in YOCO + with set_model_tag("cross_decoder"): + self.cross_decoder = Gemma4CrossDecoderLayers( + vllm_config=vllm_config, + prefix=f"{prefix}.cross_decoder", + decoder_layers=self.layers[first_kv_shared_layer_idx:], + layer_idx_start=first_kv_shared_layer_idx, + ) + + self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill + + if self.fast_prefill_enabled: + # Allocate static buffers for CUDAGraph + max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + device = next(self.parameters()).device + self.positions = torch.zeros( + max_num_tokens, dtype=torch.int64, device=device + ) + self.hidden_states = torch.zeros( + (max_num_tokens, config.hidden_size), + dtype=vllm_config.model_config.dtype, + device=device, + ) + if ( + self.hidden_size_per_layer_input + and self.hidden_size_per_layer_input > 0 + ): + self.per_layer_inputs = torch.zeros( + ( + max_num_tokens, + config.num_hidden_layers, + self.hidden_size_per_layer_input, + ), + dtype=vllm_config.model_config.dtype, + device=device, + ) + else: + self.per_layer_inputs = None + + # Custom factory that includes per_layer_inputs for PLE-enabled PP. + # per_layer_inputs has shape (batch, num_layers, per_layer_dim), + # which differs from the standard (batch, hidden_size) shape, + # so we can't use the default factory. + ple_dim = self.hidden_size_per_layer_input + num_layers = config.num_hidden_layers + hidden_size = config.hidden_size + + def _make_empty_intermediate_tensors( + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> IntermediateTensors: + tensors: dict[str, torch.Tensor] = { + "hidden_states": torch.zeros( + (batch_size, hidden_size), + dtype=dtype, + device=device, + ), + } + if ple_dim and ple_dim > 0: + tensors["per_layer_inputs"] = torch.zeros( + (batch_size, num_layers, ple_dim), + dtype=dtype, + device=device, + ) + return IntermediateTensors(tensors) + + self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.self_decoder.embed_input_ids(input_ids) + + def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None: + """Get per-layer embeddings from embed_tokens_per_layer. + + Returns: + Per-layer embeddings (num_tokens, num_layers, + hidden_size_per_layer_input) + """ + return self.self_decoder.get_per_layer_inputs(input_ids) + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: torch.Tensor | None, + ) -> torch.Tensor | None: + """Project inputs_embeds and combine with per_layer_inputs. + + Steps: + 1. Project inputs_embeds: hidden_size → total_ple_dim + 2. Scale by hidden_size^{-0.5} + 3. Reshape to (num_tokens, num_layers, per_layer_dim) + 4. Normalize with per_layer_projection_norm + 5. Combine: (projection + per_layer_inputs) * 1/sqrt(2) + """ + return self.self_decoder.project_per_layer_inputs( + inputs_embeds, per_layer_inputs + ) + + def fast_prefill_forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + per_layer_inputs: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + logits_indices_padded, num_logits_indices = None, None + attn_metadata = get_forward_context().attn_metadata + + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + layer_attn_metadata = attn_metadata[ + self.layers[-1].self_attn.attn.layer_name + ] + if isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata): + logits_indices_padded = layer_attn_metadata.logits_indices_padded + num_logits_indices = layer_attn_metadata.num_logits_indices + + batch_size = positions.size(0) + self.positions[:batch_size].copy_(positions) + self_decoder_hidden_states, per_layer_inputs = self.self_decoder( + input_ids=input_ids, + positions=self.positions[:batch_size], + inputs_embeds=inputs_embeds, + per_layer_inputs=per_layer_inputs, + **kwargs, + ) + + if logits_indices_padded is None: + logits_indices_padded = torch.arange( + batch_size, + dtype=positions.dtype, + device=positions.device, + ) + + # NOTE: Keep .clone() until fix in + # https://github.com/vllm-project/vllm/pull/22282 + hidden_states = self_decoder_hidden_states.clone() + + num_padded = logits_indices_padded.size(0) + self.positions[:num_padded].copy_(positions[logits_indices_padded]) + self.hidden_states[:num_padded].copy_( + self_decoder_hidden_states[logits_indices_padded] + ) + if self.per_layer_inputs is not None and per_layer_inputs is not None: + self.per_layer_inputs[:num_padded].copy_( + per_layer_inputs[logits_indices_padded] + ) + + # Update batch_descriptor so the cross-decoder's piecewise + # CUDAGraphWrapper dispatches to the correct (reduced) batch size. + forward_context = get_forward_context() + orig_batch_desc = forward_context.batch_descriptor + if orig_batch_desc is not None: + forward_context.batch_descriptor = replace( + orig_batch_desc, num_tokens=num_padded + ) + + cross_per_layer = ( + self.per_layer_inputs[:num_padded] + if self.per_layer_inputs is not None + else None + ) + cross_hidden_states = self.cross_decoder( + self.positions[:num_padded], + self.hidden_states[:num_padded], + cross_per_layer, + **kwargs, + ) + + # Restore the original batch_descriptor + forward_context.batch_descriptor = orig_batch_desc + + if num_logits_indices is not None: + assert num_logits_indices > 0 + hidden_states[logits_indices_padded[:num_logits_indices]] = ( + cross_hidden_states[:num_logits_indices] + ) + else: + hidden_states = cross_hidden_states + + return hidden_states + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + per_layer_inputs: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: + if self.fast_prefill_enabled: + hidden_states = self.fast_prefill_forward( + input_ids, + positions, + inputs_embeds, + per_layer_inputs, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + # Normal (non-fast-prefill) path with PP support + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + # When called from the multimodal wrapper, raw PLE + # embeddings are pre-computed and passed explicitly. + # Project them through per_layer_model_projection. + per_layer_inputs = self.project_per_layer_inputs( + hidden_states, per_layer_inputs + ) + else: + hidden_states = self.embed_input_ids(input_ids) + # Compute per-layer inputs for PLE + per_layer_embeds = self.get_per_layer_inputs(input_ids) + per_layer_inputs = self.project_per_layer_inputs( + hidden_states, per_layer_embeds + ) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + if per_layer_inputs is not None: + per_layer_inputs = intermediate_tensors["per_layer_inputs"] + residual = None + aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual) + for layer_idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): + # Extract the per-layer embedding for this specific layer + if per_layer_inputs is not None: + actual_layer_idx = self.start_layer + layer_idx + layer_per_input = per_layer_inputs[ + :, actual_layer_idx, : + ] # (num_tokens, per_layer_dim) + else: + layer_per_input = None + hidden_states, residual = layer( + positions, + hidden_states, + residual, + per_layer_input=layer_per_input, + **kwargs, + ) + self._maybe_add_hidden_state( + aux_hidden_states, layer_idx + 1, hidden_states, residual + ) + if not get_pp_group().is_last_rank: + tensors: dict[str, torch.Tensor] = { + "hidden_states": hidden_states, + } + if per_layer_inputs is not None: + tensors["per_layer_inputs"] = per_layer_inputs + return IntermediateTensors(tensors) + # Gemma4 incorporates residual into hidden_states directly + # Apply norm without residual fusion when possible. + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # MoE expert weight mapping: checkpoint can have either: + # 1. 3D packed tensors (exploded in _weight_iterator to per-expert 2D) + # 2. Already per-expert 2D weights (if quantized) + # Map to FusedMoE parameters: + # moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13) + # moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13) + # moe.experts.{id}.down_proj → FusedMoE w2 + num_experts = getattr(self.config, "num_experts", None) or 0 + # Strategy A: dot-separated suffix + # (standard AWQ/GPTQ e.g. .qweight, .scales, .weight) + dot_suffix_expert_params_mapping = fused_moe_make_expert_params_mapping( + self, + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=num_experts, + ) + # Strategy B: underscore-separated suffix + # (CompressedTensors-format AWQ/W4A16 _packed, _scale) + underscore_suffix_expert_params_mapping = [ + ( + f"{param_name}weight_", + f"{weight_name.rstrip('.')}_", + expert_id, + shard_id, + ) + for ( + param_name, + weight_name, + expert_id, + shard_id, + ) in dot_suffix_expert_params_mapping + ] + expert_params_mapping = ( + dot_suffix_expert_params_mapping + underscore_suffix_expert_params_mapping + ) + params_dict = dict(self.named_parameters()) + # Include buffers (e.g. layer_scalar) so they can be loaded too + params_dict.update(dict(self.named_buffers())) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")): + remapped_name = maybe_remap_kv_scale_name(name, params_dict) + if remapped_name is not None and remapped_name in params_dict: + param = params_dict[remapped_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(remapped_name) + continue + + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + stacked_name = name.replace(shard_name, param_name) + # k_eq_v layers use separate q_proj/k_proj instead of + # packed qkv_proj. If the stacked param doesn't exist, + # skip this mapping and fall through to direct load. + if stacked_name not in params_dict: + continue + if is_pp_missing_parameter(stacked_name, self): + continue + param = params_dict[stacked_name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(stacked_name) + break + else: + for ( + param_name, + weight_name, + expert_id, + shard_id, + ) in expert_params_mapping: + # Match both: + # - Bare weights: "experts.0.down_proj" (from 3D explosion) + # - With suffix: "experts.0.down_proj.weight_scale" (2D quantized) + # weight_name has trailing dot, so check with and without it + weight_name_base = weight_name.rstrip(".") + if weight_name in name: + # Has suffix (e.g., .weight_scale) + moe_name = name.replace(weight_name, param_name) + elif name.endswith(weight_name_base): + # Bare weight (no suffix) + moe_name = name.replace( + weight_name_base, param_name.rstrip("_") + "_weight" + ) + else: + continue + if moe_name not in params_dict: + continue + if is_pp_missing_parameter(moe_name, self): + continue + param = params_dict[moe_name] + # Expert weights are already in the correct + # orientation for FusedMoE after _weight_iterator: + # gate/up: [I, H] → w1/w3 expects [I, H] + # down: [H, I] → w2 expects [H, I] + # Scales and other quantization params may be 1D or scalar. + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + moe_name, # Pass mapped name (handles both weights and scales) + shard_id=shard_id, + expert_id=expert_id, + ) + loaded_params.add(moe_name) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Gemma4ForCausalLM( + nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts, SupportsEagle3 +): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # Gemma4ForConditionalGeneration already loads the text stack + # from `model.language_model.*`. We reuse that same checkpoint + # and adapter naming for the text-only Gemma4ForCausalLM path, + # so LoRA keys from the conditional wrapper map onto `model.*`. + "model.language_model.": "model.", + }, + orig_to_new_substr={ + # Gemma4ForConditionalGeneration names MoE adapter targets under + # `...moe.experts.*`, while the text-only model exposes them + # under `...moe.*`. + ".moe.experts.gate_up_proj": ".moe.gate_up_proj", + ".moe.experts.down_proj": ".moe.down_proj", + }, + ) + # Note: qkv_proj packing applies to non-k_eq_v layers (sliding + # attention and full attention without k_eq_v). k_eq_v layers use + # separate q_proj + k_proj without packing. + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = _get_text_config(vllm_config.model_config.hf_config) + quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = Gemma4Model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + + self.logits_processor = LogitsProcessor( + config.vocab_size, + soft_cap=getattr(config, "final_logit_softcapping", None), + ) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + # --- MixtureOfExperts protocol --- + self.expert_weights: list[list[torch.Tensor]] = [] + self.moe_layers: list[nn.Module] = [] + example_moe: Gemma4MoE | None = None + + for layer in self.model.layers: + if hasattr(layer, "moe") and isinstance(layer.moe, Gemma4MoE): + example_moe = layer.moe + self.moe_layers.append(layer.moe.experts) + + self.num_moe_layers = len(self.moe_layers) + + if example_moe is not None: + self.num_logical_experts = example_moe.num_experts + self.num_physical_experts = example_moe.num_experts + self.num_local_physical_experts = example_moe.num_experts + self.num_routed_experts = example_moe.num_experts + else: + self.num_logical_experts = 0 + self.num_physical_experts = 0 + self.num_local_physical_experts = 0 + self.num_routed_experts = 0 + + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_redundant_experts = 0 + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.logits_processor(self.lm_head, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Checkpoint weight names use "language_model." prefix (from the + # Gemma4ForConditionalGeneration wrapper). Strip it to map to our + # model tree which is just "model.*". + def _weight_iterator(): + use_k_eq_v = getattr(self.config, "attention_k_eq_v", False) + # Build set of k_eq_v layer indices (full_attention layers + # when attention_k_eq_v is enabled). These layers have k_proj + # but no v_proj in checkpoint — we duplicate k_proj as v_proj. + k_eq_v_layer_indices: set[int] = set() + if use_k_eq_v: + for idx, lt in enumerate(self.config.layer_types): + if lt == "full_attention": + k_eq_v_layer_indices.add(idx) + + for name, weight in weights: + # Remap "language_model." → "" to match our model tree. + # Checkpoint: model.language_model.layers.X.* + # Our model: model.layers.X.* + name = name.replace("language_model.", "") + + # Remap new HF checkpoint naming to internal vLLM + # naming: HF moved per_expert_scale to router and + # renamed moe → experts in the MoE block. + name = name.replace( + ".router.per_expert_scale", + ".moe.per_expert_scale", + ) + if ".experts.gate_up_proj" in name: + name = name.replace( + ".experts.gate_up_proj", + ".moe.gate_up_proj", + ) + elif ".experts.down_proj" in name: + name = name.replace( + ".experts.down_proj", + ".moe.down_proj", + ) + + # Remap individual 2D expert weights: + # .experts.{id}.{proj} → .moe.experts.{id}.{proj} + # (This handles per-expert 2D quantized weights) + name = _remap_gemma4_expert_weight_name(name) + + # MoE expert weights: checkpoint stores as 3D packed + # tensors. Explode into per-expert 2D weights for + # FusedMoE weight_loader. + # + # Checkpoint format: + # moe.gate_up_proj: [E, 2*I, H] (fused gate + up) + # moe.down_proj: [E, H, I] + # + # FusedMoE expects per-expert: + # w1 (gate): [I, H] — first half of gate_up + # w3 (up): [I, H] — second half of gate_up + # w2 (down): [H, I] — as-is from checkpoint + # + # No transpose needed: checkpoint orientation already + # matches FusedMoE's expected layout. + if "moe.gate_up_proj" in name and weight.dim() == 3: + num_experts = weight.size(0) + intermediate_size = weight.size(1) // 2 + for expert_id in range(num_experts): + gate_weight = weight[expert_id, :intermediate_size, :] + up_weight = weight[expert_id, intermediate_size:, :] + base = name.replace("moe.", f"moe.experts.{expert_id}.") + yield base.replace("gate_up_proj", "gate_proj"), gate_weight + yield base.replace("gate_up_proj", "up_proj"), up_weight + continue + + if "moe.down_proj" in name and weight.dim() == 3: + num_experts = weight.size(0) + for expert_id in range(num_experts): + expert_name = name.replace("moe.", f"moe.experts.{expert_id}.") + yield expert_name, weight[expert_id] + continue + + # k_eq_v layers: checkpoint has k_proj but no v_proj. + # QKVParallelLinear expects both, so duplicate k_proj + # as v_proj so V gets identical weights to K. + # ONLY for full_attention layers — sliding layers have + # their own real v_proj weights. + if "self_attn.k_proj" in name and k_eq_v_layer_indices: + m = re.search(r"layers\.(\d+)\.", name) + if m and int(m.group(1)) in k_eq_v_layer_indices: + yield name, weight + yield name.replace("k_proj", "v_proj"), weight.clone() + continue + + yield name, weight + + # Skip multimodal weights — handled by the multimodal wrapper. + # Also skip lm_head when weights are tied. + skip = [ + "audio_tower.", + "vision_tower.", + "embed_audio.", + "embed_vision.", + ] + if self.config.tie_word_embeddings: + skip.append("lm_head.") + + loader = AutoWeightsLoader(self, skip_substrs=skip) + return loader.load_weights(_weight_iterator()) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 40612791f7..9ce3af62c8 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -109,6 +109,7 @@ "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"), + "Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"), "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"), "Qwen3_5ForConditionalGeneration": ("qwen3_5", "Qwen3_5ForConditionalGeneration"), # noqa: E501 "Qwen3_5MoeForConditionalGeneration": ("qwen3_5", "Qwen3_5MoeForConditionalGeneration"), # noqa: E501 @@ -509,6 +510,7 @@ "Step3p5MTP": ("step3p5_mtp", "Step3p5MTP"), "Qwen3_5MTP": ("qwen3_5_mtp", "Qwen3_5MTP"), "Qwen3_5MoeMTP": ("qwen3_5_mtp", "Qwen3_5MoeMTP"), + "Gemma4MTPModel": ("gemma4_mtp", "Gemma4MTP"), # Temporarily disabled. # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), From 3f5998dad5a81597cc41b817e508d05ad9ae06ea Mon Sep 17 00:00:00 2001 From: Phil Date: Fri, 5 Jun 2026 21:31:49 +0000 Subject: [PATCH 08/25] [gemma4] Backbone imports clean: vendor GateLinear + expert-mapping shim - Vendor upstream fused_moe/router/gate_linear.py. Its specialized GEMM tiers (DSV3/fp32/cuBLAS) are all SM90+-gated, so on V100/SM70 it falls through to the ReplicatedLinear F.linear path; the missing ops.fp32_router_gemm is never reached and import/registration don't touch it. - Export GateLinear and add module-level fused_moe_make_expert_params_mapping (delegates to FusedMoE.make_expert_params_mapping, which upstream refactored into a standalone function). gemma4.py now imports cleanly against base d4f98f3b1 (verified against the live .venv-v110 runtime); Gemma4ForCausalLM present. Co-Authored-By: RivetOS Claude --- .../layers/fused_moe/__init__.py | 10 + .../layers/fused_moe/router/gate_linear.py | 175 ++++++++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/router/gate_linear.py diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 71eb4947ae..6c4d011436 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( FusedMoERouter, ) +from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( UnquantizedFusedMoEMethod, @@ -49,8 +50,17 @@ def get_config() -> dict[str, Any] | None: return _config +def fused_moe_make_expert_params_mapping(*args, **kwargs): + # Standalone alias for FusedMoE.make_expert_params_mapping. Upstream vLLM + # refactored the classmethod into a module-level function; some models + # (e.g. gemma4) import it by that name. + return FusedMoE.make_expert_params_mapping(*args, **kwargs) + + __all__ = [ "FusedMoE", + "GateLinear", + "fused_moe_make_expert_params_mapping", "FusedMoERouter", "FusedMoEConfig", "FusedMoEMethodBase", diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py new file mode 100644 index 0000000000..0a57a6f4df --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch.nn.parameter import Parameter + +import vllm._custom_ops as ops +from vllm.model_executor.custom_op import PluggableLayer +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op + + +@PluggableLayer.register("gate_linear") +class GateLinear(ReplicatedLinear): + """MoE gate linear layer with multi-tier GEMM dispatch: + + 1. DSV3 specialized kernel (SM90+, fp32 out, M<=16, H=7168, E=256/384) + 2. fp32 specialized kernel (SM90+, bf16/fp32 in, fp32 out, + M<=32, H=3072, E=256) + 3. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 weight + fp32 out_dtype) + 4. F.linear via ReplicatedLinear (ultimate fallback) + + The ``out_dtype`` attribute is mutable and can be set after init + (e.g. when the required dtype depends on the expert quantization + method which is only known later). + """ + + # Dimensions supported by the DSV3 specialized kernel + DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] + DSV3_SUPPORTED_HIDDEN_SIZES = [7168] + + # Dimensions supported by the fp32 specialized kernel + FP32_SUPPORTED_NUM_EXPERTS = [256] + FP32_SUPPORTED_HIDDEN_SIZES = [3072] + FP32_MAX_TOKENS = 32 + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = False, + out_dtype: torch.dtype | None = None, + params_dtype: torch.dtype | None = None, + force_fp32_compute: bool = False, + prefix: str = "", + ): + is_hopper_or_blackwell = current_platform.is_device_capability( + (9, 0) + ) or current_platform.is_device_capability_family(100) + can_use_specialized_kernels = ( + current_platform.is_cuda() and is_hopper_or_blackwell and not bias + ) + + # If fp32 compute is required and no specialized kernel is available, + # store weights in fp32 so the fallback linear path computes in fp32. + if force_fp32_compute and not can_use_specialized_kernels: + params_dtype = torch.float32 + + super().__init__( + input_size, + output_size, + bias=bias, + params_dtype=params_dtype, + quant_config=None, + prefix=prefix, + ) + self.out_dtype = out_dtype + + # DSV3 specialized kernel eligibility (SM90+, exact dims) + self.allow_specialized_router_gemm = can_use_specialized_kernels + self.allow_dsv3_router_gemm = ( + self.allow_specialized_router_gemm + and output_size in self.DSV3_SUPPORTED_NUM_EXPERTS + and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES + ) + + # fp32 specialized kernel eligibility (SM90+, exact dims, fp32 weight) + self.allow_fp32_router_gemm = ( + not bias + and self.weight.dtype == torch.float32 + and current_platform.is_cuda() + and is_hopper_or_blackwell + and output_size in self.FP32_SUPPORTED_NUM_EXPERTS + and input_size in self.FP32_SUPPORTED_HIDDEN_SIZES + ) + + # cuBLAS bf16→fp32 eligibility + self.allow_cublas_router_gemm = ( + self.allow_specialized_router_gemm + and self.weight.dtype == torch.bfloat16 + and self.out_dtype == torch.float32 + ) + + def set_out_dtype(self, out_dtype: torch.dtype) -> None: + """Set output dtype for the router logits after init. + + Useful when the required dtype depends on the expert quantization + method which is only known after the gate is constructed. + """ + if self.out_dtype is not None: + raise ValueError("out_dtype has already been set") + self.out_dtype = out_dtype + + if ( + not self.allow_cublas_router_gemm + and self.allow_specialized_router_gemm + and out_dtype == torch.float32 + ): + self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16 + + def forward( + self, x: torch.Tensor + ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: + # Tier 1: DSV3 specialized kernel + if self.allow_dsv3_router_gemm and x.shape[0] <= 16: + output = ops.dsv3_router_gemm( + hidden_states=x, + router_weight=self.weight, + output_dtype=self.out_dtype, + ) + return output, None + + # Tier 2: fp32 specialized kernel (H=3072, E=256, M<=32) + # Dispatch is wrapped in a custom op so that torch.compile/CUDA-graph + # capture does not freeze the runtime num_tokens branch. + if self.allow_fp32_router_gemm and x.dtype in ( + torch.float32, + torch.bfloat16, + ): + output = torch.ops.vllm.fp32_router_gemm_dispatch(x, self.weight) + return output, None + + # Tier 3: cuBLAS bf16→fp32 + if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16: + output = torch.mm(x, self.weight.T, out_dtype=torch.float32) + return output, None + + # Tier 4: F.linear (ReplicatedLinear) + if self.out_dtype is not None and x.dtype != self.weight.dtype: + x = x.to(self.weight.dtype) + output, output_bias = super().forward(x) + if self.out_dtype is not None and output.dtype != self.out_dtype: + output = output.to(self.out_dtype) + return output, output_bias + + +_FP32_ROUTER_GEMM_MAX_TOKENS = GateLinear.FP32_MAX_TOKENS + + +def fp32_router_gemm_dispatch_impl( + x: torch.Tensor, weight: torch.Tensor +) -> torch.Tensor: + """ + Dynamically run fp32 specialized gemm if num_tokens <= FP32_MAX_TOKENS, + otherwise fall back to F.linear. + This must be wrapped in a custom op because our torch.compile integration + does not support runtime dispatching on num_tokens. + """ + if x.shape[0] <= _FP32_ROUTER_GEMM_MAX_TOKENS: + return ops.fp32_router_gemm(x, weight) + else: + return torch.nn.functional.linear(x.float(), weight) + + +def fp32_router_gemm_dispatch_fake( + x: torch.Tensor, weight: torch.Tensor +) -> torch.Tensor: + return x.new_empty((x.shape[0], weight.shape[0]), dtype=torch.float32) + + +direct_register_custom_op( + op_name="fp32_router_gemm_dispatch", + op_func=fp32_router_gemm_dispatch_impl, + fake_impl=fp32_router_gemm_dispatch_fake, +) From cf1781bd9f3fcef71674245daaf0f1e71a6b9752 Mon Sep 17 00:00:00 2001 From: Phil Date: Fri, 5 Jun 2026 21:33:04 +0000 Subject: [PATCH 09/25] [gemma4][WIP] Vendor MTP drafter (gemma4_mtp) + spec-decode proposer - gemma4_mtp.py (627 LOC) imports clean against base. - v1/spec_decode/gemma4.py (340 LOC) vendored but NOT yet import-clean: it subclasses SpecDecodeBaseProposer from v1/spec_decode/llm_base_proposer, a newer-upstream proposer abstraction our base predates. Next: adapt the proposer onto our base's eagle.py architecture (or backport the base class). Co-Authored-By: RivetOS Claude --- vllm/model_executor/models/gemma4_mtp.py | 627 +++++++++++++++++++++++ vllm/v1/spec_decode/gemma4.py | 340 ++++++++++++ 2 files changed, 967 insertions(+) create mode 100644 vllm/model_executor/models/gemma4_mtp.py create mode 100644 vllm/v1/spec_decode/gemma4.py diff --git a/vllm/model_executor/models/gemma4_mtp.py b/vllm/model_executor/models/gemma4_mtp.py new file mode 100644 index 0000000000..03961cac19 --- /dev/null +++ b/vllm/model_executor/models/gemma4_mtp.py @@ -0,0 +1,627 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Inference-only Gemma4 MTP (Multi-Token Prediction) model. + +The Gemma4 assistant model is a lightweight decoder that shares KV cache +with the target (backbone) model. All assistant decoder layers are +KV-shared: they only have Q projections (no K/V projections or norms), +and read K/V from the target model's cache at runtime. + +Checkpoint layout (``gemma4_assistant``):: + + model.embed_tokens.* -- token embeddings + model.layers.{i}.* -- decoder layers (Q-only attention + MLP) + model.norm.* -- final RMSNorm + pre_projection.* -- Linear(2 * backbone_hidden_size, hidden_size) + post_projection.* -- Linear(hidden_size, backbone_hidden_size) + lm_head.* -- language model head (tied to embed_tokens) + masked_embedding.centroids.* -- centroid projection (when use_ordered_embeddings) + masked_embedding.token_ordering -- token-to-centroid mapping buffer +""" + +from collections.abc import Iterable + +import torch +from torch import nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.sequence import IntermediateTensors + +from .gemma4 import Gemma4MLP, _get_text_config +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + extract_layer_index, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class Gemma4MTPMaskedEmbedder(nn.Module): + """Sparse logit computation via centroid-based vocabulary masking. + + Instead of computing logits against the full vocabulary, projects + hidden states to centroid scores, selects top-K centroids, and + computes logits only for the ~top_k * (vocab_size / num_centroids) + tokens belonging to those centroids. + """ + + token_ordering: torch.Tensor + + def __init__( + self, + hidden_size: int, + vocab_size: int, + num_centroids: int, + centroid_intermediate_top_k: int, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.num_centroids = num_centroids + self.centroid_intermediate_top_k = centroid_intermediate_top_k + self.vocab_size_per_centroid = vocab_size // num_centroids + self.num_selected = centroid_intermediate_top_k * self.vocab_size_per_centroid + + self.centroids = nn.Linear(hidden_size, num_centroids, bias=False) + self.register_buffer( + "token_ordering", + torch.empty(vocab_size, dtype=torch.long), + ) + + def _select_and_score( + self, + hidden_states: torch.Tensor, + lm_head_weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Centroid selection + sparse dot product. + + Returns: + logits: (num_tokens, num_selected) sparse logits. + indices: (num_tokens, num_selected) corresponding vocab indices. + """ + num_tokens = hidden_states.shape[0] + _, top_k_indices = torch.topk( + self.centroids(hidden_states), + k=self.centroid_intermediate_top_k, + dim=-1, + ) + clusters = self.token_ordering.view( + self.num_centroids, + self.vocab_size_per_centroid, + ) + selected = clusters[top_k_indices] + embeddings = lm_head_weight[selected.reshape(-1)].view( + num_tokens, + self.num_selected, + self.hidden_size, + ) + logits = torch.einsum("td,tsd->ts", hidden_states, embeddings) + return logits, selected.view(num_tokens, -1) + + def forward( + self, + hidden_states: torch.Tensor, + lm_head_weight: torch.Tensor, + ) -> torch.Tensor: + """Full-vocab logits with non-selected positions masked to -inf.""" + logits, indices = self._select_and_score(hidden_states, lm_head_weight) + output = torch.full( + (hidden_states.shape[0], self.vocab_size), + fill_value=torch.finfo(hidden_states.dtype).min, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + return output.scatter_(-1, indices, logits) + + def get_top_tokens( + self, + hidden_states: torch.Tensor, + lm_head_weight: torch.Tensor, + ) -> torch.Tensor: + """Sparse argmax — returns vocab token IDs without full-vocab tensor.""" + logits, indices = self._select_and_score(hidden_states, lm_head_weight) + return indices.gather(-1, logits.argmax(-1, keepdim=True)).squeeze(-1) + + +class Gemma4MTPAttention(nn.Module): + """Q-only attention for Gemma4 MTP layers. + + K/V come from the target model's KV cache via + ``kv_sharing_target_layer_name`` (set by the proposer after + model construction). + """ + + def __init__( + self, + config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + attn_logits_soft_cap: float | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.hidden_size = hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.scaling = 1.0 + + self.q_proj = ColumnParallelLinear( + hidden_size, + self.total_num_heads * self.head_dim, + bias=config.attention_bias, + quant_config=None, + prefix=f"{prefix}.q_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=None, + prefix=f"{prefix}.o_proj", + ) + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + layer_idx = extract_layer_index(prefix) + layer_type = config.layer_types[layer_idx] + self.is_sliding = layer_type == "sliding_attention" + sliding_window = config.sliding_window if self.is_sliding else None + + if layer_type in config.rope_parameters: + rope_parameters = dict(config.rope_parameters[layer_type]) + else: + rope_parameters = dict(config.rope_parameters.copy()) + if self.is_sliding: + rope_parameters["rope_theta"] = getattr( + config, "rope_local_base_freq", 10000.0 + ) + + self.rotary_emb = get_rope( + self.head_dim, + max_position=max_position_embeddings, + rope_parameters=rope_parameters, + is_neox_style=True, + ) + + # kv_sharing_target_layer_name is set after model construction + # by Gemma4Proposer._setup_gemma4_kv_sharing(). + self.is_kv_shared_layer = True + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + q, _ = self.q_proj(hidden_states) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + + q, _ = self.rotary_emb(positions, q, None) + + # Attention reads K/V from the target's cache via KV sharing; + # these dummy tensors are never consumed but required by the API. + num_tokens = q.shape[0] + kv_dummy = torch.empty( + num_tokens, + self.num_kv_heads * self.head_dim, + dtype=q.dtype, + device=q.device, + ) + attn_output = self.attn(q, kv_dummy, kv_dummy) + output, _ = self.o_proj(attn_output) + return output + + +class Gemma4MTPDecoderLayer(nn.Module): + def __init__( + self, + config, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + layer_idx = extract_layer_index(prefix) + layer_type = config.layer_types[layer_idx] + is_full_attention = layer_type == "full_attention" + head_dim = ( + getattr(config, "global_head_dim", config.head_dim) + if is_full_attention + else config.head_dim + ) + + use_k_eq_v = is_full_attention and getattr(config, "attention_k_eq_v", False) + if use_k_eq_v: + num_kv_heads = getattr( + config, "num_global_key_value_heads", config.num_key_value_heads + ) + else: + num_kv_heads = config.num_key_value_heads + + self.self_attn = Gemma4MTPAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + max_position_embeddings=config.max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + attn_logits_soft_cap=getattr(config, "attn_logit_softcapping", None), + prefix=f"{prefix}.self_attn", + ) + + text_config = _get_text_config(config) + self.mlp = Gemma4MLP( + hidden_size=self.hidden_size, + intermediate_size=text_config.intermediate_size, + hidden_activation=text_config.hidden_activation, + quant_config=None, + prefix=f"{prefix}.mlp", + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.pre_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.register_buffer("layer_scalar", torch.ones(1)) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + **kwargs, + ) + + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = hidden_states + + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = hidden_states + residual + + hidden_states = hidden_states * self.layer_scalar + return hidden_states, None + + +class Gemma4MultiTokenPredictor(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.speculative_config.draft_model_config.hf_config + text_config = _get_text_config(config) + self.config = text_config + + self.hidden_size = text_config.hidden_size + self.backbone_hidden_size = getattr( + config, "backbone_hidden_size", self.hidden_size + ) + self.vocab_size = text_config.vocab_size + self.num_mtp_layers = text_config.num_hidden_layers + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + self.hidden_size, + ) + + self.pre_projection = ColumnParallelLinear( + 2 * self.backbone_hidden_size, + self.hidden_size, + bias=False, + gather_output=True, + prefix=f"{prefix}.pre_projection", + ) + + self.post_projection = RowParallelLinear( + self.hidden_size, + self.backbone_hidden_size, + bias=False, + input_is_parallel=False, + prefix=f"{prefix}.post_projection", + ) + + self.layers = nn.ModuleList( + Gemma4MTPDecoderLayer( + text_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + prefix=f"{prefix}.layers.{idx}", + ) + for idx in range(self.num_mtp_layers) + ) + + self.norm = RMSNorm(self.hidden_size, eps=text_config.rms_norm_eps) + + # After embedding sharing, embed_tokens is replaced with the + # target model's backbone-dim embedding. Scale by + # sqrt(backbone_hidden_size) to match the target's convention. + self.register_buffer( + "normalizer", + torch.tensor(self.backbone_hidden_size**0.5), + persistent=False, + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) * self.normalizer + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + params_dict.update(dict(self.named_buffers())) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + spec_step_idx: int = 0, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Returns (draft_hidden_states, backbone_hidden_states). + + draft_hidden_states: draft-dim, used by compute_logits via lm_head. + backbone_hidden_states: backbone-dim, stored in the proposer's + hidden-state buffer and fed back as input to the next step. + """ + if inputs_embeds is None: + inputs_embeds = self.embed_input_ids(input_ids) + + combined = torch.cat([inputs_embeds, hidden_states], dim=-1) + hidden_states, _ = self.pre_projection(combined) + + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + ) + + draft_hidden_states = self.norm(hidden_states) + + backbone_hidden_states, _ = self.post_projection(draft_hidden_states) + return draft_hidden_states, backbone_hidden_states + + +@support_torch_compile +class Gemma4MTP(nn.Module): + """Gemma4 Multi-Token Prediction model for speculative decoding. + + forward() returns (draft_hidden_states, backbone_hidden_states). + The proposer uses draft_hidden_states for compute_logits (via + the draft-dim lm_head) and backbone_hidden_states for the + hidden-state feedback buffer. + """ + + has_own_lm_head = True + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "pre_projection.": "model.pre_projection.", + "post_projection.": "model.post_projection.", + }, + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.speculative_config.draft_model_config.hf_config + text_config = _get_text_config(config) + self.config = config + self._stable_full_lm_head_weight: torch.Tensor | None = None + + self.model = Gemma4MultiTokenPredictor( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "draft_model"), + ) + + # lm_head operates in draft-dim. Tied to embed_tokens at init + # so load_weights populates both from a single checkpoint entry. + # After embedding sharing, lm_head.weight still references the + # original draft-dim tensor. + self.lm_head = ParallelLMHead( + text_config.vocab_size, + text_config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if getattr(config, "tie_word_embeddings", True): + self.lm_head.weight = self.model.embed_tokens.weight + + self.logits_processor = LogitsProcessor( + text_config.vocab_size, + soft_cap=getattr(text_config, "final_logit_softcapping", None), + ) + + if getattr(config, "use_ordered_embeddings", False): + num_centroids = getattr(config, "num_centroids", 2048) + top_k = getattr(config, "centroid_intermediate_top_k", 32) + self.masked_embedding = Gemma4MTPMaskedEmbedder( + hidden_size=text_config.hidden_size, + vocab_size=text_config.vocab_size, + num_centroids=num_centroids, + centroid_intermediate_top_k=top_k, + ) + logger.info( + "Gemma4 MTP: centroids masking enabled " + "(num_centroids=%d, top_k=%d, active_tokens=%d/%d).", + num_centroids, + top_k, + top_k * (text_config.vocab_size // num_centroids), + text_config.vocab_size, + ) + else: + self.masked_embedding = None + + draft_cfg = vllm_config.speculative_config.draft_model_config + gen_cfg = draft_cfg.try_get_generation_config() + self._suppress_token_ids = gen_cfg.get("suppress_tokens") if gen_cfg else None + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + spec_step_idx: int = 0, + **kwargs: object, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.model( + input_ids, + positions, + hidden_states, + intermediate_tensors, + inputs_embeds, + spec_step_idx, + ) + + def _get_full_lm_head_weight(self) -> torch.Tensor: + if self._stable_full_lm_head_weight is not None: + return self._stable_full_lm_head_weight + lm_head_weight = self.lm_head.weight + tp_size = get_tensor_model_parallel_world_size() + if tp_size > 1: + lm_head_weight = tensor_model_parallel_all_gather( + lm_head_weight, + dim=0, + ) + lm_head_weight = lm_head_weight[: self.masked_embedding.vocab_size] + if tp_size > 1: + lm_head_weight = lm_head_weight.contiguous() + self._stable_full_lm_head_weight = lm_head_weight + return lm_head_weight + + def compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor | None: + if self.masked_embedding is not None: + logits = self.masked_embedding( + hidden_states, + self._get_full_lm_head_weight(), + ) + else: + logits = self.logits_processor(self.lm_head, hidden_states) + if logits is not None and self._suppress_token_ids: + logits[:, self._suppress_token_ids] = -float("inf") + return logits + + def get_top_tokens( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """Sparse argmax via centroids masking. Returns token IDs directly.""" + return self.masked_embedding.get_top_tokens( + hidden_states, + self._get_full_lm_head_weight(), + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + self._stable_full_lm_head_weight = None + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/v1/spec_decode/gemma4.py b/vllm/v1/spec_decode/gemma4.py new file mode 100644 index 0000000000..7f67ae9f49 --- /dev/null +++ b/vllm/v1/spec_decode/gemma4.py @@ -0,0 +1,340 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Gemma4 MTP (Multi-Token Prediction) proposer for speculative decoding. + +The Gemma4 assistant model runs all decoder layers per draft step +(producing one token), and all its attention layers share KV cache +with the target model via cross-model KV sharing. +""" + +from collections import defaultdict +from copy import copy + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig, get_layers_from_vllm_config, replace +from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.v1.attention.backend import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import ( + KVCacheConfig, + KVCacheSpec, + UniformTypeKVCacheSpecs, +) +from vllm.v1.spec_decode.llm_base_proposer import SpecDecodeBaseProposer +from vllm.v1.worker.utils import AttentionGroup + +logger = init_logger(__name__) + + +class Gemma4Proposer(SpecDecodeBaseProposer): + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__( + vllm_config, + device, + pass_hidden_states_to_model=True, + runner=runner, + ) + # All draft steps predict from the same position (the last + # target-model position), so positions and seq_lens must not + # advance between steps. + self.constant_draft_positions = True + + # Per-group block tables for multi-group KV cache models. + # Populated by gpu_model_runner during _prepare_inputs. + self._per_group_block_tables: dict[int, torch.Tensor] = {} + + # Centroids CUDA graphs — populated in load_model if centroids + # masking is active. _centroids_sizes is pre-sorted for fast + # lookup in _greedy_sample. + self._centroids_sizes: list[int] = [] + self._centroids_graphs: dict[int, torch.cuda.CUDAGraph] = {} + self._centroids_inputs: dict[int, torch.Tensor] = {} + self._centroids_outputs: dict[int, torch.Tensor] = {} + + def set_per_group_block_table(self, gid: int, block_table: torch.Tensor) -> None: + self._per_group_block_tables[gid] = block_table + + def model_returns_tuple(self) -> bool: + # forward() returns (draft_hidden_states, backbone_hidden_states). + # The proposer uses draft_hidden_states for compute_logits and + # backbone_hidden_states for the hidden-state feedback buffer. + return True + + def build_per_group_and_layer_attn_metadata( + self, + common_attn_metadata: CommonAttentionMetadata, + draft_index: int = 0, + ) -> tuple[list[object], dict[str, object]]: + """Build attention metadata using the correct block table per group. + + Gemma4 has multiple KV cache groups (sliding vs full attention) + with different block tables. The base class receives a single + common_attn_metadata whose block_table belongs to one group. + We swap in the correct block table for each draft attention group. + """ + per_group_attn_metadata: list[object] = [] + per_layer_attn_metadata: dict[str, object] = {} + batch_size = common_attn_metadata.batch_size() + for attn_group in self.draft_attn_groups: + gid = attn_group.kv_cache_group_id + if gid in self._per_group_block_tables: + cm = copy(common_attn_metadata) + # Slice to actual batch size to match cu_seqlens_q dimension. + # The stored block tables may be padded (num_reqs_padded) from + # the target forward pass, but the drafter operates on the + # unpadded batch. + cm.block_table_tensor = self._per_group_block_tables[gid][:batch_size] + else: + cm = common_attn_metadata + attn_metadata = attn_group.get_metadata_builder().build_for_drafting( + common_attn_metadata=cm, draft_index=draft_index + ) + per_group_attn_metadata.append(attn_metadata) + for layer_name in attn_group.layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + return per_group_attn_metadata, per_layer_attn_metadata + + def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self._centroids_sizes: + T = hidden_states.shape[0] + for size in self._centroids_sizes: + if size >= T: + self._centroids_inputs[size][:T].copy_(hidden_states) + self._centroids_graphs[size].replay() + return self._centroids_outputs[size][:T].clone() + return self.model.get_top_tokens(hidden_states) + return super()._greedy_sample(hidden_states) + + def _setup_centroids_cuda_graphs(self) -> None: + """Capture CUDA graphs for centroids get_top_tokens at key sizes.""" + masked_emb = self.model.masked_embedding + lm_head_weight = self.model._get_full_lm_head_weight() + + for size in [1, 2, 4, 8, 16, 32, 64]: + static_input = torch.zeros( + size, + masked_emb.hidden_size, + dtype=self.dtype, + device=self.device, + ) + for _ in range(3): + masked_emb.get_top_tokens(static_input, lm_head_weight) + torch.accelerator.synchronize() + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + static_output = masked_emb.get_top_tokens( + static_input, + lm_head_weight, + ) + self._centroids_graphs[size] = g + self._centroids_inputs[size] = static_input + self._centroids_outputs[size] = static_output + + self._centroids_sizes = sorted(self._centroids_graphs) + logger.info( + "Gemma4 MTP: captured centroids CUDA graphs for sizes %s.", + self._centroids_sizes, + ) + + def _create_draft_vllm_config(self) -> VllmConfig: + """Preserve the target's forced TRITON_ATTN backend for draft layers. + + Gemma4 forces TRITON_ATTN due to heterogeneous head dimensions + (head_dim=256 sliding, global_head_dim=512 full). The base class + resets attention_config.backend to None for draft models, causing + sliding layers to fall back to FLASH_ATTN which cannot handle + KV-shared cache. Override to carry the target's backend through. + """ + base = super()._create_draft_vllm_config() + target_backend = self.vllm_config.attention_config.backend + if target_backend is not None: + base = replace( + base, + attention_config=replace( + base.attention_config, + backend=target_backend, + ), + ) + return base + + def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None: + """Gemma4 MTP always keeps its own draft-dim lm_head. + + The draft model's lm_head operates in draft hidden_size (e.g. 256), + which differs from the target's backbone hidden_size (e.g. 1536). + Sharing would break compute_logits (and centroids masking when + use_ordered_embeddings is enabled). + """ + logger.info( + "Gemma4 MTP: keeping draft model's own lm_head (draft_dim != backbone_dim)." + ) + + def load_model(self, target_model: nn.Module) -> None: + target_attn_layer_names = set( + get_layers_from_vllm_config( + self.vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + ).keys() + ) + + super().load_model(target_model) + + self._setup_gemma4_kv_sharing(target_attn_layer_names) + + if getattr(self.model, "masked_embedding", None) is not None: + self._setup_centroids_cuda_graphs() + + def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: + """Draft layers span multiple KV cache groups (sliding + full + attention with different head dimensions), so skip the base + class single-group assertion.""" + + def initialize_attn_backend( + self, + kv_cache_config: KVCacheConfig, + kernel_block_sizes: list[int] | None = None, + ) -> None: + """Create separate AttentionGroup objects per KV cache spec + so that each head-dim variant gets its own metadata builder.""" + all_attn_layers = get_layers_from_vllm_config( + self.vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + ) + + layer_to_gid: dict[str, int] = {} + layer_to_spec: dict[str, KVCacheSpec] = {} + for gid, group in enumerate(kv_cache_config.kv_cache_groups): + group_spec = group.kv_cache_spec + for ln in group.layer_names: + layer_to_gid[ln] = gid + if isinstance(group_spec, UniformTypeKVCacheSpecs): + if ln in group_spec.kv_cache_specs: + layer_to_spec[ln] = group_spec.kv_cache_specs[ln] + else: + tgt = getattr( + all_attn_layers.get(ln), + "kv_sharing_target_layer_name", + None, + ) + if tgt and tgt in group_spec.kv_cache_specs: + layer_to_spec[ln] = group_spec.kv_cache_specs[tgt] + else: + layer_to_spec[ln] = group_spec + else: + layer_to_spec[ln] = group_spec + + attention_groups: dict[tuple[tuple[str, str], KVCacheSpec], AttentionGroup] = {} + for layer_name in self._draft_attn_layer_names: + if layer_name not in layer_to_spec: + continue + attn_layer = all_attn_layers[layer_name] + attn_backend = attn_layer.get_attn_backend() + spec = layer_to_spec[layer_name] + gid = layer_to_gid[layer_name] + group_key = (attn_backend.full_cls_name(), spec) + + if group_key not in attention_groups: + kernel_block_size = ( + kernel_block_sizes[gid] + if kernel_block_sizes is not None and gid < len(kernel_block_sizes) + else None + ) + attn_group = AttentionGroup( + backend=attn_backend, + layer_names=[layer_name], + kv_cache_spec=spec, + kv_cache_group_id=gid, + ) + attn_group.create_metadata_builders( + self.vllm_config, + self.device, + kernel_block_size=kernel_block_size, + ) + attention_groups[group_key] = attn_group + else: + attention_groups[group_key].layer_names.append(layer_name) + + self.draft_attn_groups = list(attention_groups.values()) + if self.draft_attn_groups: + self.kv_cache_gid = self.draft_attn_groups[0].kv_cache_group_id + self.block_size = ( + self.draft_attn_groups[0] + .get_metadata_builder() + .kv_cache_spec.block_size + ) + else: + self.kv_cache_gid = 0 + self.block_size = kv_cache_config.kv_cache_groups[ + 0 + ].kv_cache_spec.block_size + logger.debug("Using block size %d for drafting layers", self.block_size) + + def _setup_gemma4_kv_sharing( + self, + target_attn_layer_names: set[str], + ) -> None: + """Wire draft layers to share KV with the target model. + + Each draft decoder layer is mapped to the last non-KV-shared + target layer of the same attention type (sliding or full). + """ + draft_config = self.speculative_config.draft_model_config.hf_config + draft_text_config = draft_config.get_text_config() + target_config = self.vllm_config.model_config.hf_config + target_text_config = target_config.get_text_config() + target_layer_types = getattr(target_text_config, "layer_types", []) + + if not (hasattr(self.model, "model") and hasattr(self.model.model, "layers")): + return + + target_num_kv_shared = getattr(target_text_config, "num_kv_shared_layers", 0) + num_non_shared = len(target_layer_types) - target_num_kv_shared + type_to_target_indices: dict[str, list[int]] = defaultdict(list) + for idx, lt in enumerate(target_layer_types[:num_non_shared]): + type_to_target_indices[lt].append(idx) + + target_prefix = "model.layers" + for name in target_attn_layer_names: + if ".layers." in name: + target_prefix = name.split(".layers.")[0] + ".layers" + break + + draft_layer_types = getattr(draft_text_config, "layer_types", []) + for draft_idx, layer in enumerate(self.model.model.layers): + if not hasattr(layer, "self_attn"): + continue + attn = getattr(layer.self_attn, "attn", None) + if attn is None: + continue + + draft_layer_type = ( + draft_layer_types[draft_idx] + if draft_idx < len(draft_layer_types) + else "full_attention" + ) + candidates = type_to_target_indices.get(draft_layer_type, []) + if not candidates: + logger.warning( + "No target layer of type '%s' for draft layer %d", + draft_layer_type, + draft_idx, + ) + continue + + target_idx = candidates[-1] + target_layer_name = f"{target_prefix}.{target_idx}.self_attn.attn" + attn.kv_sharing_target_layer_name = target_layer_name + logger.info( + "Gemma4 MTP: draft layer %d (%s) -> %s", + draft_idx, + draft_layer_type, + target_layer_name, + ) From 2b28dbc068eb899709f3cfacb25324ad0fe0bcb6 Mon Sep 17 00:00:00 2001 From: Phil Date: Fri, 5 Jun 2026 21:35:12 +0000 Subject: [PATCH 10/25] [gemma4] spec-decode proposer imports clean (redirect base-class import) Gemma4Proposer's overridden methods all already exist on our base SpecDecodeBaseProposer; only its 3 gemma4-specific methods are new. Our base keeps SpecDecodeBaseProposer in eagle.py rather than upstream's separate llm_base_proposer module, so redirect the import there. All four gemma-4 modules (gemma4, gemma4_mtp, spec_decode/gemma4) now import clean against base d4f98f3b1, verified on the live .venv-v110 runtime. Co-Authored-By: RivetOS Claude --- vllm/v1/spec_decode/gemma4.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/gemma4.py b/vllm/v1/spec_decode/gemma4.py index 7f67ae9f49..242cd2da24 100644 --- a/vllm/v1/spec_decode/gemma4.py +++ b/vllm/v1/spec_decode/gemma4.py @@ -22,7 +22,9 @@ KVCacheSpec, UniformTypeKVCacheSpecs, ) -from vllm.v1.spec_decode.llm_base_proposer import SpecDecodeBaseProposer +# NOTE(rivet): upstream keeps SpecDecodeBaseProposer in v1/spec_decode/ +# llm_base_proposer; this base keeps it in eagle.py. Redirect accordingly. +from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer from vllm.v1.worker.utils import AttentionGroup logger = init_logger(__name__) From c554f752a12035f80574e3b1e9301c5a914c035d Mon Sep 17 00:00:00 2001 From: Phil Date: Fri, 5 Jun 2026 21:37:31 +0000 Subject: [PATCH 11/25] [gemma4] Config wiring: arch convertors + speculative gemma4_mtp recognition - model_arch_config_convertor: backport Gemma4ModelArchConfigConvertor (dual head_dim/global_head_dim sizing) + add Gemma4MTPModelArchConfigConvertor (speculator buffer sized to backbone_hidden_size); register gemma4/ gemma4_text/gemma4_mtp. - speculative: add gemma4_mtp to MTP types; remap model_type gemma4_assistant -> gemma4_mtp (n_predict=1, Gemma4MTPModel, zero cross-model KV-shared layers); add use_gemma4_mtp(). Co-Authored-By: RivetOS Claude --- vllm/config/speculative.py | 20 ++++++++++++ .../model_arch_config_convertor.py | 31 +++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index c7587a0297..cd7db05788 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -42,6 +42,7 @@ "mtp", "pangu_ultra_moe_mtp", "step3p5_mtp", + "gemma4_mtp", ] DFlashModelTypes = Literal["dflash"] EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes, DFlashModelTypes] @@ -289,6 +290,17 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]}) + if hf_config.model_type == "gemma4_assistant": + hf_config.model_type = "gemma4_mtp" + text_config = getattr(hf_config, "text_config", hf_config) + # The assistant runs all decoder layers in a single forward + # call to produce one draft token, so n_predict=1. + # num_kv_shared_layers must be 0: cross-model KV sharing is + # set up by the proposer after model construction. + if hasattr(text_config, "num_kv_shared_layers"): + text_config.num_kv_shared_layers = 0 + hf_config.update({"n_predict": 1, "architectures": ["Gemma4MTPModel"]}) + if initial_architecture == "MistralLarge3ForCausalLM": hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]}) @@ -758,6 +770,14 @@ def max_num_new_slots_for_drafting(self) -> int: slots_per_req += 1 return slots_per_req + def use_gemma4_mtp(self) -> bool: + return ( + self.method == "mtp" + and self.draft_model_config is not None + and getattr(self.draft_model_config.hf_config, "model_type", None) + == "gemma4_mtp" + ) + def use_eagle(self) -> bool: return self.method in ("eagle", "eagle3", "mtp", "dflash") diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index 25f8fb935c..11949ba2e5 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -429,6 +429,34 @@ def get_num_hidden_layers(self) -> int: return getattr(self.hf_text_config, "num_nextn_predict_layers", 1) +class Gemma4ModelArchConfigConvertor(ModelArchConfigConvertorBase): + def is_mm_prefix_lm(self) -> bool: + return ( + getattr(self.hf_text_config, "use_bidirectional_attention", None) + == "vision" + ) + + def get_head_size(self) -> int: + # Gemma4 uses dual head dimensions: head_dim (sliding attention) + # and global_head_dim (full attention). Return the largest so + # that attention backends allocate buffers large enough for both. + head_dim = getattr(self.hf_text_config, "head_dim", 0) + global_head_dim = getattr(self.hf_text_config, "global_head_dim", 0) + return max(head_dim, global_head_dim) or super().get_head_size() + + +class Gemma4MTPModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_hidden_size(self) -> int: + # The speculator buffer must match the backbone (target) model's + # hidden dimension, not the draft model's smaller dimension. + return getattr( + self.hf_config, "backbone_hidden_size", super().get_hidden_size() + ) + + def get_num_hidden_layers(self) -> int: + return getattr(self.hf_text_config, "num_hidden_layers", 0) + + # hf_config.model_type -> convertor class MODEL_ARCH_CONFIG_CONVERTORS = { "mamba": MambaModelArchConfigConvertor, @@ -439,6 +467,9 @@ def get_num_hidden_layers(self) -> int: "mpt": MPTModelArchConfigConvertor, "dbrx": DbrxModelArchConfigConvertor, "falcon": FalconModelArchConfigConvertor, + "gemma4": Gemma4ModelArchConfigConvertor, + "gemma4_text": Gemma4ModelArchConfigConvertor, + "gemma4_mtp": Gemma4MTPModelArchConfigConvertor, "RefinedWeb": FalconModelArchConfigConvertor, "RefinedWebModel": FalconModelArchConfigConvertor, "nemotron-nas": NemotronNasModelArchConfigConvertor, From 78270816630ef80db045ee0e76d755b5ae3e8565 Mon Sep 17 00:00:00 2001 From: Phil Date: Fri, 5 Jun 2026 22:54:25 +0000 Subject: [PATCH 12/25] [gemma4] Wire proposer dispatch + base-proposer constant-positions path gpu_model_runner: import Gemma4Proposer; construct it for use_gemma4_mtp() (before use_eagle, since gemma4 MTP is method "mtp"); add it to the Eagle/ DFlash isinstance + union sites; capture per-group block tables for it. eagle.py (our base's home for SpecDecodeBaseProposer): add constant_draft_ positions (default False, so existing proposers incl. Deckard qwen3_5_mtp are byte-identical); extract the per-step slot-mapping/metadata update into _update_positions_dependent_metadata; guard it + the attn-metadata rebuild so the Gemma4 constant-positions drafter builds once and reuses. Verified on live .venv-v110: Gemma4ForCausalLM + Gemma4MTPModel register; gpu_model_runner + Gemma4Proposer import clean. Co-Authored-By: RivetOS Claude --- vllm/v1/spec_decode/eagle.py | 133 +++-- vllm/v1/worker/gpu_model_runner.py | 902 +++++++++++++++-------------- 2 files changed, 541 insertions(+), 494 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5c2aa101b8..5ce48b6ea1 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -137,6 +137,12 @@ def __init__( ) self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0 + # When True, all draft steps reuse the same position as the first step + # instead of advancing by one each iteration. Used by draft models with + # Q-only attention that share KV with the target and always predict from + # the same position (e.g. Gemma4 MTP). + self.constant_draft_positions: bool = False + self.parallel_drafting_token_id: int = 0 self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None if self.parallel_drafting: @@ -638,6 +644,12 @@ def propose( positions = self.positions[token_indices_to_sample] hidden_states = hidden_states[token_indices_to_sample] + if self.constant_draft_positions: + # Write the sampling positions into the front of the positions + # buffer so subsequent loop iterations (which read via + # _get_positions) use the correct values. + self.positions[:batch_size] = positions + if any(isinstance(md, TreeAttentionMetadata) for md in per_group_attn_metadata): # Draft using tree attention - requires full logits for top-k logits = self.model.compute_logits(sample_hidden_states) @@ -709,57 +721,25 @@ def propose( # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. input_ids = draft_token_ids_list[-1].int() - # Use fused kernel for slot mapping and metadata updates. - # Write clamped positions directly into the positions buffer to - # avoid an extra D2D copy for the common (non-mrope) case. - positions_1d = positions[0] if self.uses_mrope else positions - if self.uses_mrope: - out_pos = self.mrope_positions[0, :batch_size] - elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0: - out_pos = self.xdrope_positions[0, :batch_size] - else: - out_pos = self.positions[:batch_size] - eagle_step_update_slot_mapping_and_metadata( - positions_1d=positions_1d, - block_table_tensor=common_attn_metadata.block_table_tensor, - seq_lens=common_attn_metadata.seq_lens, - block_size=block_size, - max_model_len=self.max_model_len, - out_clamped_positions=out_pos, - out_slot_mapping=self._slot_mapping_buffer[:input_batch_size], - input_batch_size=input_batch_size, - ) - common_attn_metadata.slot_mapping = self._slot_mapping_buffer[:batch_size] - if self.uses_mrope: - self.mrope_positions[1:, :batch_size] = self.mrope_positions[ - 0, :batch_size - ] - positions = self.mrope_positions[:, :batch_size] - elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0: - self.xdrope_positions[1:, :batch_size] = self.xdrope_positions[ - 0, :batch_size - ] - positions = self.xdrope_positions[0, :batch_size] - else: - positions = self.positions[:batch_size] - # Increment the maximum sequence length. We increment max_seq_len - # unconditionally even though some seq_lens may have been capped above, - # as max_seq_len serves as an upper bound for sequence lengths. - common_attn_metadata.max_seq_len = min( - common_attn_metadata.max_seq_len + 1, self.max_model_len - ) - # Also update the CPU-side shadow; NOTE: this is hacky and should be - # removed in when common_attn_metadata.seq_lens_cpu is deprecated. - if common_attn_metadata._seq_lens_cpu is not None: - common_attn_metadata._seq_lens_cpu += 1 - if common_attn_metadata._num_computed_tokens_cpu is not None: - common_attn_metadata._num_computed_tokens_cpu += 1 + if not self.constant_draft_positions: + positions = self._update_positions_dependent_metadata( + positions, + common_attn_metadata, + batch_size, + input_batch_size, + block_size, + ) - # Rebuild attention metadata - _, per_layer_attn_metadata = self.build_per_group_and_layer_attn_metadata( - common_attn_metadata, draft_index=token_index + 1 - ) + # Rebuild attention metadata. When draft positions are constant + # (e.g. Gemma4 MTP), common_attn_metadata is invariant across loop + # iterations so we build once (token_index == 0) and reuse. + if not self.constant_draft_positions or token_index == 0: + _, per_layer_attn_metadata = ( + self.build_per_group_and_layer_attn_metadata( + common_attn_metadata, draft_index=token_index + 1 + ) + ) # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids @@ -812,6 +792,61 @@ def propose( self._last_draft_probs = torch.stack(draft_probs_list, dim=1).contiguous() return draft_token_ids + def _update_positions_dependent_metadata( + self, + positions: torch.Tensor, + common_attn_metadata, + batch_size: int, + input_batch_size: int, + block_size: int, + ) -> torch.Tensor: + """Update positions, slot mappings, and sequence metadata for the next + draft step. Returns the updated positions tensor. Extracted from + propose() so the Gemma4-MTP constant-positions path can skip it.""" + # Use fused kernel for slot mapping and metadata updates. Write clamped + # positions directly into the positions buffer to avoid an extra D2D + # copy for the common (non-mrope) case. + positions_1d = positions[0] if self.uses_mrope else positions + if self.uses_mrope: + out_pos = self.mrope_positions[0, :batch_size] + elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0: + out_pos = self.xdrope_positions[0, :batch_size] + else: + out_pos = self.positions[:batch_size] + eagle_step_update_slot_mapping_and_metadata( + positions_1d=positions_1d, + block_table_tensor=common_attn_metadata.block_table_tensor, + seq_lens=common_attn_metadata.seq_lens, + block_size=block_size, + max_model_len=self.max_model_len, + out_clamped_positions=out_pos, + out_slot_mapping=self._slot_mapping_buffer[:input_batch_size], + input_batch_size=input_batch_size, + ) + common_attn_metadata.slot_mapping = self._slot_mapping_buffer[:batch_size] + if self.uses_mrope: + self.mrope_positions[1:, :batch_size] = self.mrope_positions[ + 0, :batch_size + ] + positions = self.mrope_positions[:, :batch_size] + elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0: + self.xdrope_positions[1:, :batch_size] = self.xdrope_positions[ + 0, :batch_size + ] + positions = self.xdrope_positions[0, :batch_size] + else: + positions = self.positions[:batch_size] + # Increment max_seq_len unconditionally (upper bound for seq lengths). + common_attn_metadata.max_seq_len = min( + common_attn_metadata.max_seq_len + 1, self.max_model_len + ) + # Update the CPU-side shadow. + if common_attn_metadata._seq_lens_cpu is not None: + common_attn_metadata._seq_lens_cpu += 1 + if common_attn_metadata._num_computed_tokens_cpu is not None: + common_attn_metadata._num_computed_tokens_cpu += 1 + return positions + def set_inputs_first_pass( self, target_token_ids: torch.Tensor, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3d26a96998..494f386446 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,12 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import functools -import gc -import itertools -import os -import threading -import time +import functools +import gc +import itertools +import os +import threading +import time from collections import defaultdict from collections.abc import Iterable, Iterator, Sequence from contextlib import contextmanager @@ -113,13 +113,13 @@ CommonAttentionMetadata, MultipleOf, ) -from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder -from vllm.v1.attention.backends.utils import ( - PAD_SLOT_ID, - create_fast_prefill_custom_backend, - get_dcp_local_seq_lens, - reorder_batch_to_split_decodes_and_prefills, -) +from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder +from vllm.v1.attention.backends.utils import ( + PAD_SLOT_ID, + create_fast_prefill_custom_backend, + get_dcp_local_seq_lens, + reorder_batch_to_split_decodes_and_prefills, +) from vllm.v1.core.sched.output import NewRequestData from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import ( @@ -152,11 +152,12 @@ from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.logits_processor.interface import LogitsProcessor from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.rejection_sampler import RejectionSampler -from vllm.v1.sample.sampler import Sampler -from vllm.v1.spec_decode.dflash import DFlashProposer -from vllm.v1.spec_decode.draft_model import DraftModelProposer -from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.sample.rejection_sampler import RejectionSampler +from vllm.v1.sample.sampler import Sampler +from vllm.v1.spec_decode.dflash import DFlashProposer +from vllm.v1.spec_decode.draft_model import DraftModelProposer +from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.gemma4 import Gemma4Proposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer @@ -312,12 +313,12 @@ def get_output(self) -> ModelRunnerOutput: return self._model_runner_output -class ExecuteModelState(NamedTuple): +class ExecuteModelState(NamedTuple): """Ephemeral cached state transferred between execute_model() and sample_tokens(), after execute_model() returns None.""" scheduler_output: "SchedulerOutput" - logits: torch.Tensor + logits: torch.Tensor spec_decode_metadata: SpecDecodeMetadata | None spec_decode_common_attn_metadata: CommonAttentionMetadata | None hidden_states: torch.Tensor @@ -375,14 +376,14 @@ def __init__( # Always set to false after the first forward pass self.calculate_kv_scales = self.cache_config.calculate_kv_scales self.dcp_world_size = self.parallel_config.decode_context_parallel_size - self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group - self.max_num_tokens = scheduler_config.max_num_batched_tokens - self.max_num_reqs = scheduler_config.max_num_seqs - self.max_spec_state_slots = 1 + ( - self.speculative_config.num_speculative_tokens - if self.speculative_config is not None - else 0 - ) + self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group + self.max_num_tokens = scheduler_config.max_num_batched_tokens + self.max_num_reqs = scheduler_config.max_num_seqs + self.max_spec_state_slots = 1 + ( + self.speculative_config.num_speculative_tokens + if self.speculative_config is not None + else 0 + ) # Broadcast PP output for external_launcher (torchrun) # to make sure we are synced across pp ranks @@ -453,29 +454,31 @@ def __init__( if self.speculative_config and get_pp_group().is_last_rank: self.drafter: ( NgramProposer # noqa: F823 - | SuffixDecodingProposer - | EagleProposer - | DFlashProposer - | DraftModelProposer - | MedusaProposer - ) - if self.speculative_config.method == "ngram": - from vllm.v1.spec_decode.ngram_proposer import NgramProposer - - self.drafter = NgramProposer(self.vllm_config) - elif self.speculative_config.use_dflash(): - self.drafter = DFlashProposer(self.vllm_config, self.device, self) - self.use_aux_hidden_state_outputs = ( - os.getenv("VLLM_DFLASH_DISABLE_AUX_OUTPUTS", "0") != "1" - ) - elif self.speculative_config.uses_draft_model(): - self.drafter = DraftModelProposer( - vllm_config=self.vllm_config, + | SuffixDecodingProposer + | EagleProposer + | DFlashProposer + | DraftModelProposer + | MedusaProposer + ) + if self.speculative_config.method == "ngram": + from vllm.v1.spec_decode.ngram_proposer import NgramProposer + + self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.use_dflash(): + self.drafter = DFlashProposer(self.vllm_config, self.device, self) + self.use_aux_hidden_state_outputs = ( + os.getenv("VLLM_DFLASH_DISABLE_AUX_OUTPUTS", "0") != "1" + ) + elif self.speculative_config.uses_draft_model(): + self.drafter = DraftModelProposer( + vllm_config=self.vllm_config, device=self.device, runner=self, ) elif self.speculative_config.method == "suffix": self.drafter = SuffixDecodingProposer(self.vllm_config) + elif self.speculative_config.use_gemma4_mtp(): + self.drafter = Gemma4Proposer(self.vllm_config, self.device, self) elif self.speculative_config.use_eagle(): self.drafter = EagleProposer(self.vllm_config, self.device, self) if self.speculative_config.method == "eagle3": @@ -599,16 +602,16 @@ def __init__( self.num_decode_draft_tokens = self._make_buffer( self.max_num_reqs, dtype=torch.int32 ) - self.num_accepted_tokens = self._make_buffer( - self.max_num_reqs, dtype=torch.int64 - ) - self.current_mamba_state_block_ids = self._make_buffer( - self.max_num_reqs, - self.max_spec_state_slots, - dtype=torch.int32, - ) - - # Only relevant for multimodal models + self.num_accepted_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) + self.current_mamba_state_block_ids = self._make_buffer( + self.max_num_reqs, + self.max_spec_state_slots, + dtype=torch.int32, + ) + + # Only relevant for multimodal models if self.supports_mm_inputs: # Double buffer to avoid race condition: previous iteration's async # copy may still be reading from CPU while current iteration writes. @@ -682,14 +685,14 @@ def __init__( # KVCacheConfig of the scheduler. self.runner_only_attn_layers: set[str] = set() - # Cached outputs. - self._draft_token_ids: list[list[int]] | torch.Tensor | None = None - self._draft_token_req_ids: list[str] | None = None - self._draft_probs: torch.Tensor | None = None - self._draft_prob_req_ids: list[str] | None = None - self.transfer_event = torch.Event() - self.sampled_token_ids_pinned_cpu = torch.empty( - (self.max_num_reqs, 1), + # Cached outputs. + self._draft_token_ids: list[list[int]] | torch.Tensor | None = None + self._draft_token_req_ids: list[str] | None = None + self._draft_probs: torch.Tensor | None = None + self._draft_prob_req_ids: list[str] | None = None + self.transfer_event = torch.Event() + self.sampled_token_ids_pinned_cpu = torch.empty( + (self.max_num_reqs, 1), dtype=torch.int64, device="cpu", pin_memory=self.pin_memory, @@ -856,7 +859,7 @@ def _init_model_kwargs(self): ) return model_kwargs - def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ Update the order of requests in the batch based on the attention backend's needs. For example, some attention backends (namely MLA) may @@ -875,25 +878,25 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: return if self.reorder_batch_threshold is not None: - reorder_batch_to_split_decodes_and_prefills( - self.input_batch, - scheduler_output, - decode_threshold=self.reorder_batch_threshold, - ) - - # Note: used for model runner override. - @staticmethod - def _split_aux_model_output(model_output): - if not isinstance(model_output, tuple): - return model_output, None - if len(model_output) == 2 and isinstance(model_output[1], list): - return model_output[0], model_output[1] - return model_output[0], list(model_output[1:]) - - # Note: used for model runner override. - def _init_device_properties(self) -> None: - """Initialize attributes from torch.cuda.get_device_properties""" - self.device_properties = torch.cuda.get_device_properties(self.device) + reorder_batch_to_split_decodes_and_prefills( + self.input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold, + ) + + # Note: used for model runner override. + @staticmethod + def _split_aux_model_output(model_output): + if not isinstance(model_output, tuple): + return model_output, None + if len(model_output) == 2 and isinstance(model_output[1], list): + return model_output[0], model_output[1] + return model_output[0], list(model_output[1:]) + + # Note: used for model runner override. + def _init_device_properties(self) -> None: + """Initialize attributes from torch.cuda.get_device_properties""" + self.device_properties = torch.cuda.get_device_properties(self.device) self.num_sms = self.device_properties.multi_processor_count # Note: used for model runner override. @@ -1181,12 +1184,12 @@ def _update_states_after_model_execute( .argmax(-1) .cpu() .numpy() - ) - for i, num_tokens in enumerate(num_accepted_tokens): - self.input_batch.num_accepted_tokens_cpu[i] = num_tokens - self.input_batch.spec_num_accepted_tokens_cpu[i] = num_tokens - if self.cache_config.mamba_cache_mode == "align": - mamba_utils.postprocess_mamba( + ) + for i, num_tokens in enumerate(num_accepted_tokens): + self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + self.input_batch.spec_num_accepted_tokens_cpu[i] = num_tokens + if self.cache_config.mamba_cache_mode == "align": + mamba_utils.postprocess_mamba( scheduler_output, self.kv_cache_config, self.input_batch, @@ -1735,18 +1738,18 @@ def _build_attention_metadata( else: max_seq_len = self.seq_lens.np[:num_reqs].max().item() - spec_sequence_masks_cpu = None - if use_spec_decode: - self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.spec_num_accepted_tokens_cpu[:num_reqs] - ) - self.num_accepted_tokens.np[num_reqs:].fill(1) - self.num_accepted_tokens.copy_to_gpu() - spec_sequence_masks_cpu = self.num_decode_draft_tokens.cpu[ - :num_reqs_padded - ] >= 0 - - kv_cache_groups = self.kv_cache_config.kv_cache_groups + spec_sequence_masks_cpu = None + if use_spec_decode: + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.spec_num_accepted_tokens_cpu[:num_reqs] + ) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() + spec_sequence_masks_cpu = self.num_decode_draft_tokens.cpu[ + :num_reqs_padded + ] >= 0 + + kv_cache_groups = self.kv_cache_config.kv_cache_groups def _get_block_table(kv_cache_gid: int): assert num_reqs_padded is not None and num_tokens_padded is not None @@ -1804,44 +1807,44 @@ def _get_block_table(kv_cache_gid: int): :num_reqs_padded ] - if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill: - cm_base.num_logits_indices = logits_indices.size(0) - cm_base.logits_indices_padded = self._prepare_kv_sharing_fast_prefill( - logits_indices - ) - - current_mamba_state_block_ids_by_gid: dict[int, torch.Tensor] = {} - - def _get_current_mamba_state_block_ids( - kv_cache_gid: int, - ) -> torch.Tensor | None: - if not use_spec_decode: - return None - if kv_cache_gid in current_mamba_state_block_ids_by_gid: - return current_mamba_state_block_ids_by_gid[kv_cache_gid] - - state_block_ids = self.current_mamba_state_block_ids - state_block_ids.cpu[:num_reqs_padded].fill_(PAD_SLOT_ID) - for req_idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): - state_block_idx = self.mamba_state_idx.get(req_id) - if state_block_idx is None: - continue - req_state = self.requests[req_id] - block_ids = req_state.block_ids[kv_cache_gid] - for offset in range(self.max_spec_state_slots): - block_idx = state_block_idx + offset - if 0 <= block_idx < len(block_ids): - state_block_ids.cpu[req_idx, offset] = block_ids[block_idx] - else: - break - state_block_ids.copy_to_gpu(num_reqs_padded) - current_mamba_state_block_ids_by_gid[kv_cache_gid] = ( - state_block_ids.gpu[:num_reqs_padded] - ) - return current_mamba_state_block_ids_by_gid[kv_cache_gid] - - # Cache attention metadata builds across hybrid KV-cache groups - # The only thing that changes between different hybrid KV-cache groups when the + if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill: + cm_base.num_logits_indices = logits_indices.size(0) + cm_base.logits_indices_padded = self._prepare_kv_sharing_fast_prefill( + logits_indices + ) + + current_mamba_state_block_ids_by_gid: dict[int, torch.Tensor] = {} + + def _get_current_mamba_state_block_ids( + kv_cache_gid: int, + ) -> torch.Tensor | None: + if not use_spec_decode: + return None + if kv_cache_gid in current_mamba_state_block_ids_by_gid: + return current_mamba_state_block_ids_by_gid[kv_cache_gid] + + state_block_ids = self.current_mamba_state_block_ids + state_block_ids.cpu[:num_reqs_padded].fill_(PAD_SLOT_ID) + for req_idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + state_block_idx = self.mamba_state_idx.get(req_id) + if state_block_idx is None: + continue + req_state = self.requests[req_id] + block_ids = req_state.block_ids[kv_cache_gid] + for offset in range(self.max_spec_state_slots): + block_idx = state_block_idx + offset + if 0 <= block_idx < len(block_ids): + state_block_ids.cpu[req_idx, offset] = block_ids[block_idx] + else: + break + state_block_ids.copy_to_gpu(num_reqs_padded) + current_mamba_state_block_ids_by_gid[kv_cache_gid] = ( + state_block_ids.gpu[:num_reqs_padded] + ) + return current_mamba_state_block_ids_by_gid[kv_cache_gid] + + # Cache attention metadata builds across hybrid KV-cache groups + # The only thing that changes between different hybrid KV-cache groups when the # same metadata builder and KVCacheSpec is the same is the block table, so we # can cache the attention metadata builds and just update the block table using # `builder.update_block_table` if the builder supports it. @@ -1869,21 +1872,21 @@ def _build_attn_group_metadata( ) extra_attn_metadata_args = {} - if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): - assert ubid is None, "UBatching not supported with GDN yet" - current_state_block_ids = None - if self.cache_config.mamba_cache_mode == "align": - current_state_block_ids = _get_current_mamba_state_block_ids( - kv_cache_gid - ) - extra_attn_metadata_args = dict( - num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded], - num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ - :num_reqs_padded - ], - spec_sequence_masks_cpu=spec_sequence_masks_cpu, - current_state_block_ids=current_state_block_ids, - ) + if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): + assert ubid is None, "UBatching not supported with GDN yet" + current_state_block_ids = None + if self.cache_config.mamba_cache_mode == "align": + current_state_block_ids = _get_current_mamba_state_block_ids( + kv_cache_gid + ) + extra_attn_metadata_args = dict( + num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs_padded], + num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ + :num_reqs_padded + ], + spec_sequence_masks_cpu=spec_sequence_masks_cpu, + current_state_block_ids=current_state_block_ids, + ) if for_cudagraph_capture: attn_metadata_i = builder.build_for_cudagraph_capture( @@ -1917,13 +1920,13 @@ def _build_attn_group_metadata( for layer_name in attn_group.layer_names: attn_metadata_dict[layer_name] = attn_metadata_i - # Prepare the attention metadata for each KV cache group and make layers - # in the same group share the same metadata. - spec_decode_common_attn_metadata = None - dflash_common_attn_metadata_by_gid: dict[int, CommonAttentionMetadata] | None - dflash_common_attn_metadata_by_gid = None - for kv_cache_gid, kv_cache_group in enumerate(kv_cache_groups): - cm = copy(cm_base) # shallow copy + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + spec_decode_common_attn_metadata = None + dflash_common_attn_metadata_by_gid: dict[int, CommonAttentionMetadata] | None + dflash_common_attn_metadata_by_gid = None + for kv_cache_gid, kv_cache_group in enumerate(kv_cache_groups): + cm = copy(cm_base) # shallow copy # Basically only the encoder seq_lens, block_table and slot_mapping change # for each kv_cache_group. @@ -1933,24 +1936,31 @@ def _build_attn_group_metadata( num_reqs_padded, for_cudagraph_capture=for_cudagraph_capture, ) - if kv_cache_gid > 0: - cm.block_table_tensor = _get_block_table(kv_cache_gid) - cm.slot_mapping = slot_mappings[kv_cache_gid] - - if self.speculative_config and spec_decode_common_attn_metadata is None: - if isinstance(self.drafter, (EagleProposer, DFlashProposer)): - if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: - spec_decode_common_attn_metadata = cm - else: - spec_decode_common_attn_metadata = cm - if self.speculative_config and isinstance(self.drafter, DFlashProposer): - if set(self.drafter.attn_layer_names) & set(kv_cache_group.layer_names): - if dflash_common_attn_metadata_by_gid is None: - dflash_common_attn_metadata_by_gid = {} - dflash_common_attn_metadata_by_gid[kv_cache_gid] = cm - - for attn_gid in range(len(self.attn_groups[kv_cache_gid])): - if ubatch_slices is not None: + if kv_cache_gid > 0: + cm.block_table_tensor = _get_block_table(kv_cache_gid) + cm.slot_mapping = slot_mappings[kv_cache_gid] + + if self.speculative_config and spec_decode_common_attn_metadata is None: + if isinstance( + self.drafter, (EagleProposer, DFlashProposer, Gemma4Proposer) + ): + if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: + spec_decode_common_attn_metadata = cm + else: + spec_decode_common_attn_metadata = cm + if self.speculative_config and isinstance(self.drafter, DFlashProposer): + if set(self.drafter.attn_layer_names) & set(kv_cache_group.layer_names): + if dflash_common_attn_metadata_by_gid is None: + dflash_common_attn_metadata_by_gid = {} + dflash_common_attn_metadata_by_gid[kv_cache_gid] = cm + # Capture per-group block tables for multi-group proposers (Gemma4 MTP). + if self.speculative_config and isinstance(self.drafter, Gemma4Proposer): + self.drafter.set_per_group_block_table( + kv_cache_gid, cm.block_table_tensor + ) + + for attn_gid in range(len(self.attn_groups[kv_cache_gid])): + if ubatch_slices is not None: for ubid, _cm in enumerate(split_attn_metadata(ubatch_slices, cm)): _build_attn_group_metadata(kv_cache_gid, attn_gid, _cm, ubid) @@ -1983,21 +1993,21 @@ def _build_attn_group_metadata( # Currently the drafter still only uses piecewise cudagraphs (and modifies # the attention metadata in directly), and therefore does not want to use # padded attention metadata. - spec_decode_common_attn_metadata = ( - spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) - ) - if dflash_common_attn_metadata_by_gid is not None: - if num_reqs != num_reqs_padded or num_tokens != num_tokens_padded: - dflash_common_attn_metadata_by_gid = { - gid: metadata.unpadded(num_tokens, num_reqs) - for gid, metadata in dflash_common_attn_metadata_by_gid.items() - } - assert isinstance(self.drafter, DFlashProposer) - self.drafter.set_common_attn_metadata_by_kv_cache_group( - dflash_common_attn_metadata_by_gid - ) - - return attn_metadata, spec_decode_common_attn_metadata + spec_decode_common_attn_metadata = ( + spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) + ) + if dflash_common_attn_metadata_by_gid is not None: + if num_reqs != num_reqs_padded or num_tokens != num_tokens_padded: + dflash_common_attn_metadata_by_gid = { + gid: metadata.unpadded(num_tokens, num_reqs) + for gid, metadata in dflash_common_attn_metadata_by_gid.items() + } + assert isinstance(self.drafter, DFlashProposer) + self.drafter.set_common_attn_metadata_by_kv_cache_group( + dflash_common_attn_metadata_by_gid + ) + + return attn_metadata, spec_decode_common_attn_metadata def _compute_cascade_attn_prefix_lens( self, @@ -2933,11 +2943,11 @@ def _preprocess( ec_connector_output, ) - def _sample( - self, - logits: torch.Tensor | None, - spec_decode_metadata: SpecDecodeMetadata | None, - ) -> SamplerOutput: + def _sample( + self, + logits: torch.Tensor | None, + spec_decode_metadata: SpecDecodeMetadata | None, + ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata # Update output token ids with tokens sampled in last step @@ -2955,15 +2965,15 @@ def _sample( draft_token_ids_cpu, _ = self._get_draft_token_ids_cpu() self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu) - sampler_output = self.rejection_sampler( - spec_decode_metadata, - None, # draft_probs - logits, - sampling_metadata, - ) - return sampler_output - - def _bookkeeping_sync( + sampler_output = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + logits, + sampling_metadata, + ) + return sampler_output + + def _bookkeeping_sync( self, scheduler_output: "SchedulerOutput", sampler_output: SamplerOutput, @@ -2981,8 +2991,8 @@ def _bookkeeping_sync( list[int], ]: num_nans_in_logits = {} - if envs.VLLM_COMPUTE_NANS_IN_LOGITS: - num_nans_in_logits = self._get_nans_in_logits(logits) + if envs.VLLM_COMPUTE_NANS_IN_LOGITS: + num_nans_in_logits = self._get_nans_in_logits(logits) num_reqs = self.input_batch.num_reqs discard_sampled_tokens_req_indices = np.nonzero( @@ -3535,26 +3545,26 @@ def execute_model( ) pad_attn = cudagraph_mode == CUDAGraphMode.FULL - if self.cache_config.mamba_cache_mode == "align": - # preprocess_mamba may reset per-request accepted counts to 1 - # after copying a running state to a new aligned block. Keep - # the GPU buffer in sync before GDN/spec metadata consumes it. - mamba_utils.preprocess_mamba( - scheduler_output, - self.kv_cache_config, - self.cache_config, + if self.cache_config.mamba_cache_mode == "align": + # preprocess_mamba may reset per-request accepted counts to 1 + # after copying a running state to a new aligned block. Keep + # the GPU buffer in sync before GDN/spec metadata consumes it. + mamba_utils.preprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.cache_config, self.mamba_state_idx, self.input_batch, self.requests, - self.compilation_config.static_forward_context, - self.model.get_mamba_state_copy_func(), - ) - self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs] - ) - self.num_accepted_tokens.copy_to_gpu() - - use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 + self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), + ) + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) + self.num_accepted_tokens.copy_to_gpu() + + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices slot_mappings_by_group, slot_mappings = self._get_slot_mappings( @@ -3662,8 +3672,8 @@ def execute_model( kv_connector_output, ) - sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states) + sample_hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(sample_hidden_states) else: # Rare case. assert not self.is_pooling_model @@ -3748,14 +3758,14 @@ def sample_tokens( # Clear ephemeral state. self.execute_model_state = None - # Apply structured output bitmasks if present. - if grammar_output is not None: - apply_grammar_bitmask( - scheduler_output, grammar_output, self.input_batch, logits - ) - - with record_function_or_nullcontext("gpu_model_runner: sample"): - sampler_output = self._sample(logits, spec_decode_metadata) + # Apply structured output bitmasks if present. + if grammar_output is not None: + apply_grammar_bitmask( + scheduler_output, grammar_output, self.input_batch, logits + ) + + with record_function_or_nullcontext("gpu_model_runner: sample"): + sampler_output = self._sample(logits, spec_decode_metadata) self._update_states_after_model_execute( sampler_output.sampled_token_ids, scheduler_output @@ -3770,13 +3780,13 @@ def sample_tokens( sampler_output.sampled_token_ids ) - self._draft_token_ids = None - self._draft_token_req_ids = None - self._draft_probs = None - self._draft_prob_req_ids = None - self.input_batch.prev_sampled_token_ids = None - - def propose_draft_token_ids(sampled_token_ids): + self._draft_token_ids = None + self._draft_token_req_ids = None + self._draft_probs = None + self._draft_prob_req_ids = None + self.input_batch.prev_sampled_token_ids = None + + def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None with record_function_or_nullcontext("gpu_model_runner: draft"): self._draft_token_ids = self.propose_draft_token_ids( @@ -3787,15 +3797,15 @@ def propose_draft_token_ids(sampled_token_ids): sample_hidden_states, aux_hidden_states, spec_decode_metadata, - spec_decode_common_attn_metadata, - slot_mappings, - ) - if hasattr(self.drafter, "take_last_draft_probs"): - draft_probs = self.drafter.take_last_draft_probs() - if draft_probs is not None: - self._draft_probs = draft_probs - self._draft_prob_req_ids = self.input_batch.req_ids.copy() - self._copy_draft_token_ids_to_cpu(scheduler_output) + spec_decode_common_attn_metadata, + slot_mappings, + ) + if hasattr(self.drafter, "take_last_draft_probs"): + draft_probs = self.drafter.take_last_draft_probs() + if draft_probs is not None: + self._draft_probs = draft_probs + self._draft_prob_req_ids = self.input_batch.req_ids.copy() + self._copy_draft_token_ids_to_cpu(scheduler_output) spec_config = self.speculative_config propose_drafts_after_bookkeeping = False @@ -3804,28 +3814,28 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens <= self.effective_drafter_max_model_len ) - use_gpu_toks = ( - spec_config.use_eagle() - or spec_config.use_dflash() - or spec_config.uses_draft_model() - ) and not spec_config.disable_padded_drafter_batch - if use_gpu_toks: - # EAGLE/DraftModel speculative decoding can use the GPU sampled tokens - # as inputs, and does not need to wait for bookkeeping to finish. - assert isinstance( - self.drafter, EagleProposer | DFlashProposer | DraftModelProposer - ) + use_gpu_toks = ( + spec_config.use_eagle() + or spec_config.use_dflash() + or spec_config.uses_draft_model() + ) and not spec_config.disable_padded_drafter_batch + if use_gpu_toks: + # EAGLE/DraftModel speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. + assert isinstance( + self.drafter, EagleProposer | DFlashProposer | DraftModelProposer | Gemma4Proposer + ) sampled_token_ids = sampler_output.sampled_token_ids if input_fits_in_drafter: propose_draft_token_ids(sampled_token_ids) elif self.valid_sampled_token_count_event is not None: - assert spec_decode_common_attn_metadata is not None - next_token_ids, valid_sampled_tokens_count = ( - self.drafter.prepare_next_token_ids_padded( - sampled_token_ids, - self.requests, - self.input_batch, - self.discard_request_mask.gpu, + assert spec_decode_common_attn_metadata is not None + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, ) ) self._copy_valid_sampled_token_count( @@ -3992,49 +4002,49 @@ def _copy_draft_token_ids_to_cpu( self.draft_token_ids_cpu[:num_reqs] = 0 self.draft_token_ids_event.record() - def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]: - if isinstance(self._draft_token_ids, list): - return self._draft_token_ids, self.input_batch.req_ids - req_ids = self._draft_token_req_ids + def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]: + if isinstance(self._draft_token_ids, list): + return self._draft_token_ids, self.input_batch.req_ids + req_ids = self._draft_token_req_ids if req_ids is None: return [], [] assert self.draft_token_ids_event is not None assert self.draft_token_ids_cpu is not None - self.draft_token_ids_event.synchronize() - return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids - - def _get_spec_decode_draft_probs( - self, spec_decode_metadata: SpecDecodeMetadata - ) -> torch.Tensor | None: - if self._draft_probs is None or self._draft_prob_req_ids is None: - return None - - row_by_req_id = { - req_id: idx for idx, req_id in enumerate(self._draft_prob_req_ids) - } - draft_probs_rows: list[torch.Tensor] = [] - for req_id, num_draft in zip( - self.input_batch.req_ids, spec_decode_metadata.num_draft_tokens - ): - if num_draft == 0: - continue - row_idx = row_by_req_id.get(req_id) - if row_idx is None: - logger.warning_once( - "Missing draft probabilities for request %s; falling back " - "to deterministic draft rejection sampling.", - req_id, - ) - return None - draft_probs_rows.append(self._draft_probs[row_idx, :num_draft]) - - if not draft_probs_rows: - return None - return torch.cat(draft_probs_rows, dim=0).contiguous() - - def _copy_valid_sampled_token_count( - self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor - ) -> None: + self.draft_token_ids_event.synchronize() + return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids + + def _get_spec_decode_draft_probs( + self, spec_decode_metadata: SpecDecodeMetadata + ) -> torch.Tensor | None: + if self._draft_probs is None or self._draft_prob_req_ids is None: + return None + + row_by_req_id = { + req_id: idx for idx, req_id in enumerate(self._draft_prob_req_ids) + } + draft_probs_rows: list[torch.Tensor] = [] + for req_id, num_draft in zip( + self.input_batch.req_ids, spec_decode_metadata.num_draft_tokens + ): + if num_draft == 0: + continue + row_idx = row_by_req_id.get(req_id) + if row_idx is None: + logger.warning_once( + "Missing draft probabilities for request %s; falling back " + "to deterministic draft rejection sampling.", + req_id, + ) + return None + draft_probs_rows.append(self._draft_probs[row_idx, :num_draft]) + + if not draft_probs_rows: + return None + return torch.cat(draft_probs_rows, dim=0).contiguous() + + def _copy_valid_sampled_token_count( + self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor + ) -> None: if self.valid_sampled_token_count_event is None: return @@ -4121,14 +4131,14 @@ def propose_draft_token_ids( sampling_metadata=sampling_metadata, slot_mappings=slot_mappings, ) - elif ( - spec_config.use_eagle() - or spec_config.use_dflash() - or spec_config.uses_draft_model() - ): - assert isinstance( - self.drafter, EagleProposer | DFlashProposer | DraftModelProposer - ) + elif ( + spec_config.use_eagle() + or spec_config.use_dflash() + or spec_config.uses_draft_model() + ): + assert isinstance( + self.drafter, EagleProposer | DFlashProposer | DraftModelProposer | Gemma4Proposer + ) if spec_config.disable_padded_drafter_batch: # When padded-batch is disabled, the sampled_token_ids should be @@ -4152,39 +4162,39 @@ def propose_draft_token_ids( assert isinstance(sampled_token_ids, torch.Tensor), ( "sampled_token_ids should be a torch.Tensor when" "padded-batch is enabled." - ) - next_token_ids, valid_sampled_tokens_count = ( - self.drafter.prepare_next_token_ids_padded( - sampled_token_ids, - self.requests, - self.input_batch, - self.discard_request_mask.gpu, - ) - ) - if self.use_async_scheduling or not spec_config.use_dflash(): - self._copy_valid_sampled_token_count( - next_token_ids, valid_sampled_tokens_count - ) + ) + next_token_ids, valid_sampled_tokens_count = ( + self.drafter.prepare_next_token_ids_padded( + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + ) + if self.use_async_scheduling or not spec_config.use_dflash(): + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count + ) num_rejected_tokens_gpu = None if spec_decode_metadata is None: token_indices_to_sample = None # input_ids can be None for multimodal models. - target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] - target_positions = self._get_positions(num_scheduled_tokens) - if self.use_aux_hidden_state_outputs: - if aux_hidden_states is None: - aux_layers = self._get_eagle3_aux_layers_from_config() or (0,) - target_hidden_states = hidden_states[ - :num_scheduled_tokens - ].repeat(1, len(aux_layers)) - else: - target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1, - ) - else: - target_hidden_states = hidden_states[:num_scheduled_tokens] + target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] + target_positions = self._get_positions(num_scheduled_tokens) + if self.use_aux_hidden_state_outputs: + if aux_hidden_states is None: + aux_layers = self._get_eagle3_aux_layers_from_config() or (0,) + target_hidden_states = hidden_states[ + :num_scheduled_tokens + ].repeat(1, len(aux_layers)) + else: + target_hidden_states = torch.cat( + [h[:num_scheduled_tokens] for h in aux_hidden_states], + dim=-1, + ) + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] else: if spec_config.disable_padded_drafter_batch: token_indices_to_sample = None @@ -4193,23 +4203,23 @@ def propose_draft_token_ids( sampled_token_ids, spec_decode_metadata.num_draft_tokens, ) - target_token_ids = self.input_ids.gpu[token_indices] - target_positions = self._get_positions(token_indices) - if self.use_aux_hidden_state_outputs: - if aux_hidden_states is None: - aux_layers = self._get_eagle3_aux_layers_from_config() or ( - 0, - ) - target_hidden_states = hidden_states[token_indices].repeat( - 1, len(aux_layers) - ) - else: - target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], - dim=-1, - ) - else: - target_hidden_states = hidden_states[token_indices] + target_token_ids = self.input_ids.gpu[token_indices] + target_positions = self._get_positions(token_indices) + if self.use_aux_hidden_state_outputs: + if aux_hidden_states is None: + aux_layers = self._get_eagle3_aux_layers_from_config() or ( + 0, + ) + target_hidden_states = hidden_states[token_indices].repeat( + 1, len(aux_layers) + ) + else: + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], + dim=-1, + ) + else: + target_hidden_states = hidden_states[token_indices] else: ( common_attn_metadata, @@ -4222,23 +4232,23 @@ def propose_draft_token_ids( ) total_num_tokens = common_attn_metadata.num_actual_tokens # When padding the batch, token_indices is just a range - target_token_ids = self.input_ids.gpu[:total_num_tokens] - target_positions = self._get_positions(total_num_tokens) - if self.use_aux_hidden_state_outputs: - if aux_hidden_states is None: - aux_layers = self._get_eagle3_aux_layers_from_config() or ( - 0, - ) - target_hidden_states = hidden_states[ - :total_num_tokens - ].repeat(1, len(aux_layers)) - else: - target_hidden_states = torch.cat( - [h[:total_num_tokens] for h in aux_hidden_states], - dim=-1, - ) - else: - target_hidden_states = hidden_states[:total_num_tokens] + target_token_ids = self.input_ids.gpu[:total_num_tokens] + target_positions = self._get_positions(total_num_tokens) + if self.use_aux_hidden_state_outputs: + if aux_hidden_states is None: + aux_layers = self._get_eagle3_aux_layers_from_config() or ( + 0, + ) + target_hidden_states = hidden_states[ + :total_num_tokens + ].repeat(1, len(aux_layers)) + else: + target_hidden_states = torch.cat( + [h[:total_num_tokens] for h in aux_hidden_states], + dim=-1, + ) + else: + target_hidden_states = hidden_states[:total_num_tokens] if self.supports_mm_inputs and self.drafter.supports_mm_inputs: mm_embed_inputs = self._gather_mm_embeddings( @@ -4250,13 +4260,13 @@ def propose_draft_token_ids( draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - token_indices_to_sample=token_indices_to_sample, - sampling_metadata=sampling_metadata, - common_attn_metadata=common_attn_metadata, - mm_embed_inputs=mm_embed_inputs, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + token_indices_to_sample=token_indices_to_sample, + sampling_metadata=sampling_metadata, + common_attn_metadata=common_attn_metadata, + mm_embed_inputs=mm_embed_inputs, num_rejected_tokens_gpu=num_rejected_tokens_gpu, slot_mappings=slot_mappings, ) @@ -4457,18 +4467,18 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: Tuple of layer indices if found in draft model config, None otherwise. """ - if not (self.speculative_config and self.speculative_config.draft_model_config): - return None - - hf_config = self.speculative_config.draft_model_config.hf_config - layer_ids = getattr(hf_config, "eagle_aux_hidden_state_layer_ids", None) - if not layer_ids: - dflash_config = getattr(hf_config, "dflash_config", None) - if dflash_config and isinstance(dflash_config, dict): - layer_ids = dflash_config.get("target_layer_ids") - if layer_ids and isinstance(layer_ids, (list, tuple)): - return tuple(layer_ids) - + if not (self.speculative_config and self.speculative_config.draft_model_config): + return None + + hf_config = self.speculative_config.draft_model_config.hf_config + layer_ids = getattr(hf_config, "eagle_aux_hidden_state_layer_ids", None) + if not layer_ids: + dflash_config = getattr(hf_config, "dflash_config", None) + if dflash_config and isinstance(dflash_config, dict): + layer_ids = dflash_config.get("target_layer_ids") + if layer_ids and isinstance(layer_ids, (list, tuple)): + return tuple(layer_ids) + return None def reload_weights( @@ -5028,21 +5038,21 @@ def _dummy_run( intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, **model_kwargs, - ) - - if self.use_aux_hidden_state_outputs: - hidden_states, _ = self._split_aux_model_output(outputs) - else: - hidden_states = outputs - - if self.speculative_config and ( - self.speculative_config.use_eagle() - or self.speculative_config.use_dflash() - or self.speculative_config.uses_draft_model() - ): - assert isinstance( - self.drafter, EagleProposer | DFlashProposer | DraftModelProposer - ) + ) + + if self.use_aux_hidden_state_outputs: + hidden_states, _ = self._split_aux_model_output(outputs) + else: + hidden_states = outputs + + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.use_dflash() + or self.speculative_config.uses_draft_model() + ): + assert isinstance( + self.drafter, EagleProposer | DFlashProposer | DraftModelProposer | Gemma4Proposer + ) assert self.speculative_config is not None # Eagle currently only supports PIECEWISE cudagraphs. # Therefore only use cudagraphs if the main model uses PIECEWISE @@ -5580,21 +5590,21 @@ def initialize_metadata_builders( else self.parallel_config.num_ubatches, ) # Calculate reorder batch threshold (if needed) - # Note (tdoublep): do this *after* constructing builders, - # because some of them change the threshold at init time. - self.calculate_reorder_batch_threshold() - - if self.speculative_config and ( - self.speculative_config.use_eagle() - or self.speculative_config.use_dflash() - or self.speculative_config.uses_draft_model() - ): - assert isinstance( - self.drafter, EagleProposer | DFlashProposer | DraftModelProposer - ) - self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes) - - def _check_and_update_cudagraph_mode( + # Note (tdoublep): do this *after* constructing builders, + # because some of them change the threshold at init time. + self.calculate_reorder_batch_threshold() + + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.use_dflash() + or self.speculative_config.uses_draft_model() + ): + assert isinstance( + self.drafter, EagleProposer | DFlashProposer | DraftModelProposer | Gemma4Proposer + ) + self.drafter.initialize_attn_backend(kv_cache_config, kernel_block_sizes) + + def _check_and_update_cudagraph_mode( self, attention_backends: list[set[type[AttentionBackend]]], kv_cache_groups: list[KVCacheGroupSpec], @@ -5749,13 +5759,15 @@ def _check_and_update_cudagraph_mode( cudagraph_mode, self.uniform_decode_query_len ) - # Initialize eagle/dflash cudagraph dispatcher if using spec decode. - if self.speculative_config and ( - self.speculative_config.use_eagle() - or self.speculative_config.use_dflash() - ): - assert isinstance(self.drafter, EagleProposer | DFlashProposer) - self.drafter.initialize_cudagraph_keys(cudagraph_mode) + # Initialize eagle/dflash cudagraph dispatcher if using spec decode. + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.use_dflash() + ): + assert isinstance( + self.drafter, EagleProposer | DFlashProposer | Gemma4Proposer + ) + self.drafter.initialize_cudagraph_keys(cudagraph_mode) def calculate_reorder_batch_threshold(self) -> None: """ @@ -6235,14 +6247,14 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_cache_config, kernel_block_sizes ) - if self.speculative_config and ( - self.speculative_config.use_eagle() - or self.speculative_config.use_dflash() - or self.speculative_config.uses_draft_model() - ): - assert isinstance( - self.drafter, EagleProposer | DFlashProposer | DraftModelProposer - ) + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.use_dflash() + or self.speculative_config.uses_draft_model() + ): + assert isinstance( + self.drafter, EagleProposer | DFlashProposer | DraftModelProposer | Gemma4Proposer + ) # validate all draft model layers belong to the same kv cache # group self.drafter.validate_same_kv_cache_group(kv_cache_config) From a8b86148665717b5bf4e9cd3e8a508926d274306 Mon Sep 17 00:00:00 2001 From: Phil Date: Fri, 5 Jun 2026 23:40:42 +0000 Subject: [PATCH 13/25] [gemma4] Add image+audio multimodal (tower-based Gemma4ForConditionalGeneration) - Vendor gemma4_mm.py (1706 LOC): SigLIP vision tower + audio tower + multimodal embedders on top of Gemma4ForCausalLM; register Gemma4ForConditionalGeneration -> gemma4_mm. - Backport recursive_replace_linear into models/transformers/utils.py (deps replace_linear_class + maybe_prefix already present). - Redirect MultiModalDataDict import to vllm.multimodal (this base's home; upstream re-exports via vllm.inputs). - Skip gemma4_unified.py: that's the encoder-free 12B Unified variant (needs transformers.models.gemma4_unified, absent in transformers 5.7); not our 31B. Verified on live .venv-v110: gemma4_mm imports clean; Gemma4ForConditional Generation registers. Co-Authored-By: RivetOS Claude --- vllm/model_executor/models/gemma4_mm.py | 1708 +++++++++++++++++ vllm/model_executor/models/registry.py | 1 + .../models/transformers/utils.py | 28 + 3 files changed, 1737 insertions(+) create mode 100644 vllm/model_executor/models/gemma4_mm.py diff --git a/vllm/model_executor/models/gemma4_mm.py b/vllm/model_executor/models/gemma4_mm.py new file mode 100644 index 0000000000..49846b121e --- /dev/null +++ b/vllm/model_executor/models/gemma4_mm.py @@ -0,0 +1,1708 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Gemma 4 multimodal model (image + audio + video support). + +Adds vision tower, audio tower, and multimodal embedders on top of the +text-only Gemma4ForCausalLM. The vision/audio encoders are loaded via +AutoModel.from_config and run in eager mode while the language model uses +the vLLM-optimized path. + +Video support: Gemma4 does **not** have a native video tower. Videos are +decomposed into timestamped image frames (up to 32 frames at 70 soft tokens +each) and fed through the same vision tower as regular images. The +processor inserts ``mm:ss`` timestamps between frames so the model can +reason about temporal order. +""" + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Annotated, Any, Literal + +import numpy as np +import torch +from PIL import Image as PILImage +from torch import nn +from transformers import AutoModel, BatchFeature +from transformers.models.gemma4 import ( + Gemma4Config, + Gemma4Processor, + Gemma4VisionConfig, +) +from transformers.models.gemma4.configuration_gemma4 import ( + Gemma4AudioConfig, + Gemma4TextConfig, +) + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions +# NOTE(rivet): this base exports MultiModalDataDict from vllm.multimodal +# (upstream re-exports it via vllm.inputs). +from vllm.multimodal import MultiModalDataDict +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.transformers.utils import recursive_replace_linear +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + ImageProcessorItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import BaseDummyInputsBuilder +from vllm.multimodal.processing.processor import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import ( + MultiModalEmbeddings, + SupportsEagle3, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationConfig + +logger = init_logger(__name__) + +# Video constants — match transformers Gemma4VideoProcessor defaults. +_SUPPORTED_SOFT_TOKENS = (70, 140, 280, 560, 1120) +_VIDEO_MAX_SOFT_TOKENS = 70 # soft tokens per video frame (vs 280 for images) +_VIDEO_MAX_FRAMES = 32 # max sampled frames per video + + +def _get_max_soft_tokens( + merged_kwargs: Mapping[str, object], +) -> tuple[object | None, bool]: + """Return configured image max_soft_tokens and whether it is top-level.""" + val = merged_kwargs.get("max_soft_tokens") + if val is not None: + return val, True + + images_kwargs = merged_kwargs.get("images_kwargs") + if isinstance(images_kwargs, Mapping): + return images_kwargs.get("max_soft_tokens"), False + + return None, False + + +# --------------------------------------------------------------------------- +# Input schema +# --------------------------------------------------------------------------- + + +class Gemma4ImagePixelInputs(TensorSchema): + """ + Pre-patchified image inputs from the Gemma4 image processor. + + Dimensions: + - bn: Batch size * number of images + - np: Number of patches (max_patches = max_soft_tokens * pooling_kernel_size²) + - pp: Patch pixels (patch_size² * 3) + + The Gemma4 image processor outputs pixel_values as + (batch, max_patches, patch_pixels) — already patchified with + zero-padding for patches beyond the real image content. + pixel_position_ids provides (x, y) coordinates per patch, + with (-1, -1) for padding patches. + """ + + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[ + torch.Tensor | list[torch.Tensor], + TensorShape("bn", "np", "pp", dynamic_dims={"np"}), + ] + pixel_position_ids: Annotated[ + torch.Tensor | list[torch.Tensor], + TensorShape("bn", "np", 2, dynamic_dims={"np"}), + ] + + +class Gemma4AudioInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of audios + - s: Sequence length (MEL spectrogram frames) + - f: Number of features (MEL bins) + """ + + type: Literal["audio"] = "audio" + input_features_padded: Annotated[ + torch.Tensor, TensorShape("bn", "s", "f", dynamic_dims={"s"}) + ] + input_features_mask: Annotated[ + torch.Tensor, TensorShape("bn", "s", dynamic_dims={"s"}) + ] + + +Gemma4ImageInputs = Gemma4ImagePixelInputs + + +class Gemma4VideoInputs(TensorSchema): + """Video frame inputs — same tensor format as image inputs. + + Gemma4 has no separate video tower; video frames are processed + through the vision tower at lower resolution (max_soft_tokens=70). + """ + + type: Literal["pixel_values_videos"] = "pixel_values_videos" + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("bn", "np", "pp"), + ] + pixel_position_ids_videos: Annotated[ + torch.Tensor, + TensorShape("bn", "np", 2), + ] + + +# --------------------------------------------------------------------------- +# Processing info +# --------------------------------------------------------------------------- + + +class Gemma4ProcessingInfo(BaseProcessingInfo): + def get_hf_config(self): + return self.ctx.get_hf_config(Gemma4Config) + + def get_default_tok_params(self): + """Gemma4's chat template already embeds a literal ```` token in + the rendered text. If ``add_special_tokens=True`` (the base-class + default), the tokenizer prepends *another* BOS, producing a + ``[2, 2, ...]`` double-BOS sequence that the model was not trained on. + + Setting ``add_special_tokens=False`` here prevents the duplicate and + ensures both ``llm.generate()`` and the chat/completions API behave + correctly for IT models. For PT models (without chat template), we + keep the default (True) to ensure BOS is added for raw prompts. + """ + tokenizer = self.ctx.get_tokenizer() + has_chat_template = getattr(tokenizer, "chat_template", None) is not None + + params = super().get_default_tok_params() + if has_chat_template: + params = params.with_kwargs(add_special_tokens=False) + return params + + def get_hf_processor(self, **kwargs: object) -> Gemma4Processor: + return self.ctx.get_hf_processor( + Gemma4Processor, + **kwargs, + ) + + def validate_num_items(self, modality: str, num_items: int) -> None: + if ( + modality == "audio" + and num_items > 0 + and self.get_hf_config().audio_config is None + ): + model = self.ctx.model_config.model + raise ValueError( + f"Audio input was provided but the model " + f"'{model}' does not have an audio tower. " + f"Audio inference is only supported for Gemma4 " + f"models that include an audio_config " + f"(i.e., models that include an audio_config)." + ) + super().validate_num_items(modality, num_items) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + limits: dict[str, int | None] = {"image": None} + if self.get_hf_config().audio_config is not None: + limits["audio"] = None + limits["video"] = None + return limits + + def get_mm_max_tokens_per_item( + self, seq_len: int, mm_counts: Mapping[str, int] + ) -> Mapping[str, int] | None: + config = self.get_hf_config() + # Upper bound: the pooler outputs max_soft_tokens slots per image. + # After padding is stripped the actual count is ≤ this value, but + # vLLM needs the max for memory planning. + tokens_per_image = config.vision_config.default_output_length + merged_kwargs = self.ctx.get_merged_mm_kwargs({}) + val, _ = _get_max_soft_tokens(merged_kwargs) + if isinstance(val, int) and val in _SUPPORTED_SOFT_TOKENS: + tokens_per_image = val + tokens: dict[str, int] = {"image": tokens_per_image} + if config.audio_config is not None: + # Audio max tokens from the processor's audio_seq_length. + processor = self.get_hf_processor() + tokens["audio"] = processor.audio_seq_length + # Video: each frame ≤ 70 soft tokens + boi + eoi + ~6 ts tokens. + num_frames = _VIDEO_MAX_FRAMES + mm_config = self.ctx.model_config.get_multimodal_config() + video_opts = mm_config.limit_per_prompt.get("video") + if ( + isinstance(video_opts, VideoDummyOptions) + and video_opts.num_frames is not None + ): + num_frames = min(num_frames, video_opts.num_frames) + tokens["video"] = num_frames * (_VIDEO_MAX_SOFT_TOKENS + 2 + 6) + return tokens + + def get_data_parser(self) -> MultiModalDataParser: + config = self.get_hf_config() + kwargs: dict[str, Any] = {"video_needs_metadata": True} + if getattr(config, "audio_config", None) is not None: + processor = self.get_hf_processor() + kwargs["target_sr"] = processor.feature_extractor.sampling_rate + return MultiModalDataParser(**kwargs) + + def _compute_num_soft_tokens( + self, + image_width: int, + image_height: int, + max_soft_tokens: int | None = None, + ) -> int: + """Compute the number of soft tokens the vision tower produces + for an image of the given dimensions, after padding is stripped. + + Args: + max_soft_tokens: Override for the vision config's + ``default_output_length``. When *None*, the value from + the model config is used. + """ + vision_cfg = self.get_hf_config().vision_config + patch_size = vision_cfg.patch_size + pooling_kernel_size = vision_cfg.pooling_kernel_size + + if max_soft_tokens is None: + max_soft_tokens = vision_cfg.default_output_length + + unit = patch_size * pooling_kernel_size + max_patches = max_soft_tokens * pooling_kernel_size**2 + num_patches_orig = (image_height / patch_size) * (image_width / patch_size) + scale = math.sqrt(max_patches / num_patches_orig) + target_h = max(unit, int(math.floor(image_height * scale / unit)) * unit) + target_w = max(unit, int(math.floor(image_width * scale / unit)) * unit) + num_patches = (target_h // patch_size) * (target_w // patch_size) + # Clamp to ``max_soft_tokens``: extreme aspect ratios (e.g. 3x900) + # cause the floor() above to round one dim up to ``unit`` while the + # other scales freely, which over-shoots ``max_patches``. The HF + # Gemma 4 image processor caps its vision-tower output at + # ``max_soft_tokens``, so without this clamp the prompt-side + # placeholder count exceeds the encoder output and + # ``_merge_multimodal_embeddings`` crashes. + return min(num_patches // (pooling_kernel_size**2), max_soft_tokens) + + def get_image_repl( + self, + *, + image_width: int, + image_height: int, + processor: Gemma4Processor | None, + max_soft_tokens: int | None = None, + ) -> PromptUpdateDetails[list[int]]: + """Return the dynamic image token sequence for this image. + + Computes the exact number of soft tokens the vision tower will + produce after stripping padding. + + Args: + max_soft_tokens: Override for the default token budget. + When *None*, falls back to the model config value. + """ + if processor is None: + processor = self.get_hf_processor() + + num_soft = self._compute_num_soft_tokens( + image_width, + image_height, + max_soft_tokens=max_soft_tokens, + ) + config = self.get_hf_config() + token_ids = ( + [config.boi_token_id] + + [processor.image_token_id] * num_soft + + [config.eoi_token_id] + ) + return PromptUpdateDetails.select_token_id(token_ids, processor.image_token_id) + + @staticmethod + def _compute_audio_num_tokens( + num_samples: int, sampling_rate: int, audio_seq_length: int + ) -> int: + """Replicate the audio encoder's sequence-length arithmetic. + + Mirrors: mel framing (_unfold in Gemma4AudioFeatureExtractor) + followed by two Conv2d subsampling layers (kernel=3, stride=2, + semicausal padding top=1, bottom=1), capped at audio_seq_length. + """ + frame_length = int(round(sampling_rate * 20.0 / 1000.0)) + hop_length = int(round(sampling_rate * 10.0 / 1000.0)) + frame_size_for_unfold = frame_length + 1 + pad_left = frame_length // 2 + padded_samples = num_samples + pad_left + num_mel_frames = (padded_samples - frame_size_for_unfold) // hop_length + 1 + if num_mel_frames <= 0: + return 0 + t = num_mel_frames + for _ in range(2): + t = (t + 2 - 3) // 2 + 1 + return min(t, audio_seq_length) + + def get_audio_repl( + self, + *, + audio_len: int, + processor: Gemma4Processor | None, + ) -> PromptUpdateDetails[list[int]]: + """Return the dynamic audio token sequence for this audio. + + Computes the number of soft tokens from the audio waveform + length by replicating the audio encoder's sequence-length + arithmetic (mel framing + two Conv2d subsampling layers). + """ + if processor is None: + processor = self.get_hf_processor() + + sampling_rate = processor.feature_extractor.sampling_rate + num_tokens = self._compute_audio_num_tokens( + audio_len, sampling_rate, processor.audio_seq_length + ) + config = self.get_hf_config() + token_ids = ( + [config.boa_token_id] + + [processor.audio_token_id] * num_tokens + + [getattr(config, "eoa_token_id", config.eoa_token_index)] + ) + return PromptUpdateDetails.select_token_id(token_ids, processor.audio_token_id) + + def get_video_repl( + self, + *, + timestamps: list[float], + num_soft_tokens_per_frame: list[int], + processor: Gemma4Processor, + ) -> PromptUpdateDetails[list[int]]: + """Build the full token replacement for one video. + + Produces the same interleaved sequence as the HF Gemma4Processor: + mm:ss <|video|>*N mm:ss <|video|>*N ... + """ + tokenizer = self.ctx.get_tokenizer() + config = self.get_hf_config() + + boi_token_id = config.boi_token_id + eoi_token_id = config.eoi_token_id + video_token_id = processor.video_token_id + + all_token_ids: list[int] = [] + for i, (ts, n_tokens) in enumerate(zip(timestamps, num_soft_tokens_per_frame)): + # mm:ss timestamp — matches transformers: int-truncated, + # zero-padded. + minutes = int(ts // 60) + seconds = int(ts % 60) + ts_str = f"{minutes:02d}:{seconds:02d}" + + prefix = f" {ts_str} " if i > 0 else f"{ts_str} " + ts_token_ids = tokenizer.encode(prefix, add_special_tokens=False) + all_token_ids.extend(ts_token_ids) + + all_token_ids.append(boi_token_id) + all_token_ids.extend([video_token_id] * n_tokens) + all_token_ids.append(eoi_token_id) + + return PromptUpdateDetails.select_token_id(all_token_ids, video_token_id) + + +# --------------------------------------------------------------------------- +# Dummy inputs builder +# --------------------------------------------------------------------------- + + +class Gemma4DummyInputsBuilder(BaseDummyInputsBuilder[Gemma4ProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_audios = mm_counts.get("audio", 0) + num_videos = mm_counts.get("video", 0) + processor = self.info.get_hf_processor() + # Use image_token (<|image|>) with tab prefix — this is what the + # Gemma4 chat template inserts per image (\t<|image|>). + # _get_prompt_updates targets image_token and expands it to the + # full_image_sequence. + text = ("\t" + processor.image_token) * num_images + if num_audios > 0 and processor.audio_token: + text += processor.audio_token * num_audios + if num_videos > 0: + text += processor.video_token * num_videos + return text + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_audios = mm_counts.get("audio", 0) + num_videos = mm_counts.get("video", 0) + processor = self.info.get_hf_processor() + image_processor = processor.image_processor + # Use processor's configured image size for dummies. + # Gemma4ImageProcessor sets size=None (it uses patch_size / + # max_soft_tokens instead of the standard size dict), so we + # guard against None with `or {}`. + size = getattr(image_processor, "size", None) or {} + img_width = size.get("width", 224) + img_height = size.get("height", 224) + + image_overrides = mm_options.get("image") if mm_options else None + audio_overrides = mm_options.get("audio") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None + + data: MultiModalDataDict = { + "image": self._get_dummy_images( + width=img_width, + height=img_height, + num_images=num_images, + overrides=image_overrides, + ), + } + + if num_audios > 0: + audio_len = processor.feature_extractor.fft_length + data["audio"] = self._get_dummy_audios( + length=audio_len, + num_audios=num_audios, + overrides=audio_overrides, + ) + + if num_videos > 0: + data["video"] = self._get_dummy_videos( + width=img_width, + height=img_height, + num_frames=_VIDEO_MAX_FRAMES, + num_videos=num_videos, + overrides=video_overrides, + ) + + return data + + def _get_dummy_videos( + self, + *, + width: int, + height: int, + num_frames: int, + num_videos: int, + overrides: VideoDummyOptions | None = None, + ) -> list[VideoItem]: + num_frames = max(num_frames, 2) + videos = super()._get_dummy_videos( + width=width, + height=height, + num_frames=num_frames, + num_videos=num_videos, + overrides=overrides, + ) + videos = [v.copy() for v in videos] + + video_items: list[VideoItem] = [] + for video in videos: + video_num_frames = video.shape[0] + video_metadata = { + "fps": 2.0, + "duration": video_num_frames / 2.0, + "total_num_frames": video_num_frames, + "frames_indices": list(range(video_num_frames)), + "video_backend": "opencv", + "do_sample_frames": False, + } + video_items.append((video, video_metadata)) + + return video_items + + +# --------------------------------------------------------------------------- +# Multimodal processor +# --------------------------------------------------------------------------- + + +class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]): + def _apply_hf_processor_text_only( + self, + prompt_text: str, + tokenization_kwargs: Mapping[str, object], + ) -> list[int]: + # Bypass the HF processor and tokenize directly. The HF + # processor expands multimodal placeholders (<|video|>, etc.) + # via get_text_with_replacements, which raises StopIteration + # when the prompt contains placeholders without matching data. + # The text-only path only needs token IDs, so the tokenizer + # alone is sufficient. + processor = self.info.get_hf_processor() + text_inputs = processor.tokenizer([prompt_text], **tokenization_kwargs) + input_ids = text_inputs["input_ids"] + if not isinstance(input_ids, list): + input_ids = input_ids.tolist() + (prompt_ids,) = input_ids + return prompt_ids + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + merged_kwargs = self.info.ctx.get_merged_mm_kwargs(mm_kwargs) + val, is_top_level_max_soft_tokens = _get_max_soft_tokens(merged_kwargs) + + if val is not None and val not in _SUPPORTED_SOFT_TOKENS: + raise ValueError( + f"Unsupported max_soft_tokens value: {val}. " + f"Valid values are {_SUPPORTED_SOFT_TOKENS}." + ) + + mm_data = dict(mm_data) + + # ---- VIDEO HANDLING ---- + # Gemma4 decomposes video into timestamped image frames. + # Each frame is processed with max_soft_tokens=70 through the + # same vision tower, matching transformers processing_gemma4.py. + video_outputs: dict[str, Any] = {} + if videos := mm_data.pop("videos", []): + processor = self.info.get_hf_processor() + + all_video_pixel_values: list[torch.Tensor] = [] + all_video_position_ids: list[torch.Tensor] = [] + video_num_soft_tokens_per_video: list[list[int]] = [] + video_timestamps_per_video: list[list[float]] = [] + video_frame_counts: list[int] = [] + + video_replacements: list[str] = [] + + for item in videos: + video_array, metadata = item + + # Convert frames to PIL images + if isinstance(video_array, np.ndarray): + frames = [ + PILImage.fromarray(video_array[i]) + for i in range(video_array.shape[0]) + ] + else: + frames = list(video_array) + + # Compute timestamps from metadata (same as transformers) + fps = metadata.get("fps") or 24 + frame_indices = metadata.get("frames_indices", list(range(len(frames)))) + timestamps = [idx / fps for idx in frame_indices] + + # Process frames as images with max_soft_tokens=70 + video_mm_kwargs = dict(mm_kwargs) + video_mm_kwargs["max_soft_tokens"] = _VIDEO_MAX_SOFT_TOKENS + + dummy_prompt = ("\t" + processor.image_token) * len(frames) + + frame_outputs = super()._call_hf_processor( + prompt=dummy_prompt, + mm_data={"images": frames}, + mm_kwargs=video_mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + # Remap HF key name + if "image_position_ids" in frame_outputs: + frame_outputs["pixel_position_ids"] = frame_outputs.pop( + "image_position_ids" + ) + + all_video_pixel_values.append(frame_outputs["pixel_values"]) + all_video_position_ids.append(frame_outputs["pixel_position_ids"]) + + # Compute soft tokens per frame + num_soft_per_frame = [] + for img in frames: + w, h = img.size + n = self.info._compute_num_soft_tokens( + w, h, max_soft_tokens=_VIDEO_MAX_SOFT_TOKENS + ) + num_soft_per_frame.append(n) + + video_num_soft_tokens_per_video.append(num_soft_per_frame) + video_timestamps_per_video.append(timestamps) + video_frame_counts.append(len(frames)) + + # Build expanded replacement text for this video. + ts_strs = [f"{int(s // 60):02d}:{int(s % 60):02d}" for s in timestamps] + replacement = " ".join( + f"{t} {processor.boi_token}" + f"{processor.video_token * n}" + f"{processor.eoi_token}" + for t, n in zip(ts_strs, num_soft_per_frame) + ) + video_replacements.append(replacement) + + # Replace all <|video|> placeholders at once. We split on + # video_token to get N+1 parts, then interleave with the + # N replacement strings. This avoids the iterative + # split-replace bug where replacement text (which itself + # contains <|video|> tokens) collides with later splits. + vt = processor.video_token + parts = prompt.split(vt, len(video_replacements)) + + # NOTE: len(parts) <= len(video_replacements) + 1 + parts_with_repl: list[str] = [] + for part, repl in zip(parts, video_replacements): + parts_with_repl.extend([part, repl]) + parts_with_repl.extend(parts[len(video_replacements) :]) + + prompt = "".join(parts_with_repl) + + video_outputs = { + "pixel_values_videos": torch.cat(all_video_pixel_values, dim=0), + "pixel_position_ids_videos": torch.cat(all_video_position_ids, dim=0), + "video_frame_counts": torch.tensor(video_frame_counts), + "video_num_soft_tokens": video_num_soft_tokens_per_video, + "video_timestamps": video_timestamps_per_video, + } + + # The processor accepts 'audio' not 'audios'. + if "audios" in mm_data: + mm_data["audio"] = mm_data.pop("audios") + + # Warn if any audio waveform exceeds the model's max duration. + if "audio" in mm_data: + processor = self.info.get_hf_processor() + sr = processor.feature_extractor.sampling_rate + max_tokens = processor.audio_seq_length + ms_per_tok = processor.audio_ms_per_token + max_duration_s = max_tokens * ms_per_tok / 1000.0 + audios = mm_data["audio"] + if not isinstance(audios, (list, tuple)): + audios = [audios] + for i, waveform in enumerate(audios): + duration_s = len(waveform) / sr + if duration_s > max_duration_s: + logger.warning( + "Audio duration exceeds max: %f > %f seconds", + duration_s, + max_duration_s, + ) + # vLLM's call_hf_processor (context.py) re-merges + # mm_processor_kwargs from the model config on every call via: + # config_kwargs | incoming_kwargs (right side wins) + # + # If we strip max_soft_tokens from incoming, the re-merge puts + # back the config's global default (e.g. 280), ignoring any + # per-prompt override. Instead, we keep it in the kwargs with + # the validated per-prompt value so it wins during the merge. + # + # NOTE: This requires a corresponding type annotation on the + # HF side (Gemma4ProcessorKwargs.images_kwargs) so that + # _merge_kwargs routes max_soft_tokens into images_kwargs. + patched_mm_kwargs = dict(mm_kwargs) + if val is not None and is_top_level_max_soft_tokens: + patched_mm_kwargs["max_soft_tokens"] = val + + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + patched_mm_kwargs, + tok_kwargs, + ) + + # HF uses 'image_position_ids'; vLLM uses 'pixel_position_ids'. + # Remap here to keep a single translation point. + if "image_position_ids" in processed_outputs: + processed_outputs["pixel_position_ids"] = processed_outputs.pop( + "image_position_ids" + ) + + if "input_features" in processed_outputs: + # Unpad per-item so each item's cache entry is + # self-contained. The batched() field config in + # _get_mm_fields_config will re-pad all fields to the + # batch's max length at batch time, ensuring consistent + # padding regardless of cache history. + masks = processed_outputs["input_features_mask"] + unpadded_features = [ + f[mask] + for f, mask in zip( + processed_outputs["input_features"], + masks, + ) + ] + unpadded_masks = [mask[mask] for mask in masks] + processed_outputs["input_features"] = unpadded_features + processed_outputs["input_features_padded"] = unpadded_features + processed_outputs["input_features_mask"] = unpadded_masks + + # Merge video outputs into the final result + combined_outputs = dict(processed_outputs, **video_outputs) + return BatchFeature(combined_outputs) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + fields = dict( + pixel_values=MultiModalFieldConfig.batched("image"), + pixel_position_ids=MultiModalFieldConfig.batched("image"), + input_features_padded=MultiModalFieldConfig.batched("audio"), + input_features_mask=MultiModalFieldConfig.batched("audio"), + ) + + # Video fields: frames stored flat, split per video by + # video_frame_counts. + video_frame_counts = hf_inputs.get("video_frame_counts") + if video_frame_counts is not None: + vfc = video_frame_counts + if not isinstance(vfc, torch.Tensor): + vfc = torch.tensor(vfc) + fields.update( + pixel_values_videos=( + MultiModalFieldConfig.flat_from_sizes("video", vfc) + ), + pixel_position_ids_videos=( + MultiModalFieldConfig.flat_from_sizes("video", vfc) + ), + video_frame_counts=MultiModalFieldConfig.batched( + "video", + ), + video_num_soft_tokens=MultiModalFieldConfig.batched( + "video", keep_on_cpu=True + ), + video_timestamps=MultiModalFieldConfig.batched( + "video", keep_on_cpu=True + ), + ) + + return fields + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + prompt_updates = [] + + if "image" in mm_items: + # Target image_token (<|image|>) — the single placeholder the + # Gemma4 chat template inserts once per image in the prompt. + # vLLM tokenizes the prompt without token expansion, so only + # one image_token exists per image in the token stream. + # The replacement expands it to the full image sequence + # (boi + N×image_token + eoi, where N = max_soft_tokens). + image_token = hf_processor.image_token + + def get_replacement_image(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + # Resolve the effective max_soft_tokens by merging + # per-prompt kwargs with the config-level defaults, + # consistent with how _call_hf_processor resolves it. + # Without this merge, a missing per-prompt override + # would fall back to vision_cfg.default_output_length + # instead of the config's mm_processor_kwargs default. + merged_kwargs = self.info.ctx.get_merged_mm_kwargs( + hf_processor_mm_kwargs, + ) + val, _ = _get_max_soft_tokens(merged_kwargs) + max_soft_tokens = ( + val + if isinstance(val, int) and val in _SUPPORTED_SOFT_TOKENS + else None + ) + return self.info.get_image_repl( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + max_soft_tokens=max_soft_tokens, + ) + + prompt_updates.append( + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement_image, + ) + ) + + if "video" in mm_items: + video_token = hf_processor.video_token + + def get_replacement_video(item_idx: int): + out_item = out_mm_kwargs["video"][item_idx] + timestamps = out_item["video_timestamps"].data + num_soft = out_item["video_num_soft_tokens"].data + return self.info.get_video_repl( + timestamps=timestamps, + num_soft_tokens_per_frame=num_soft, + processor=hf_processor, + ) + + prompt_updates.append( + PromptReplacement( + modality="video", + target=video_token, + replacement=get_replacement_video, + ) + ) + + if "audio" in mm_items: + audio_token = hf_processor.audio_token + + def get_replacement_audio(item_idx: int): + audios = mm_items.get_items("audio", AudioProcessorItems) + audio_len = audios.get_audio_length(item_idx) + return self.info.get_audio_repl( + audio_len=audio_len, + processor=hf_processor, + ) + + prompt_updates.append( + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_audio, + ) + ) + + return prompt_updates + + # NOTE: Gemma3/Gemma3n override _apply_token_matches and + # _find_mm_placeholders to merge adjacent newline tokens that arise + # when full_image_sequence contains "\n\n" wrappers. Gemma4's + # full_image_sequence has NO newlines (just BOI + 280×image_token + + # EOI), so the base class implementations work correctly as-is. + + +# --------------------------------------------------------------------------- +# Multimodal embedder +# --------------------------------------------------------------------------- + + +class Gemma4MultimodalEmbedder(nn.Module): + """Projects vision/audio soft tokens into LM embedding space. + + Architecture: + inputs_embeds → embedding_projection → embedding_post_projection_norm + + Unlike Gemma3n which has separate hard/soft embedding paths with + per-path normalization and a learned embedding table, Gemma4 uses a + simplified 2-layer design: a linear projection followed by RMSNorm + (without learnable scale). The checkpoint confirms this — only + ``embedding_projection.weight`` exists; there is no embedding table + or pre-projection norm weights. + """ + + def __init__( + self, + multimodal_config: Gemma4VisionConfig | Gemma4AudioConfig, + text_config: Gemma4TextConfig, + *, + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", + ): + super().__init__() + + self.eps = multimodal_config.rms_norm_eps + self.text_hidden_size = text_config.hidden_size + + # Audio tower uses output_proj_dims (1536) rather than hidden_size + # (1024); vision uses hidden_size (768) directly. + embedding_dim = ( + getattr(multimodal_config, "output_proj_dims", None) + or multimodal_config.hidden_size + ) + + self.embedding_pre_projection_norm = RMSNorm( + embedding_dim, + eps=self.eps, + has_weight=False, + ) + + self.embedding_projection = ReplicatedLinear( + embedding_dim, + self.text_hidden_size, + bias=False, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "embedding_projection"), + ) + + def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: + """Project soft tokens from a multimodal tower into LM space.""" + embs_normed = self.embedding_pre_projection_norm(inputs_embeds) + embs_proj, _ = self.embedding_projection(embs_normed) + return embs_proj + + +# --------------------------------------------------------------------------- +# Main model +# --------------------------------------------------------------------------- + + +@MULTIMODAL_REGISTRY.register_processor( + Gemma4MultiModalProcessor, + info=Gemma4ProcessingInfo, + dummy_inputs=Gemma4DummyInputsBuilder, +) +class Gemma4ForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsQuant, + SupportsPP, + SupportsLoRA, + SupportsEagle3, +): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # Maps checkpoint prefixes to vLLM module paths. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # vision tower + "model.vision_tower": "vision_tower", + "model.embed_vision": "embed_vision", + # audio tower + "model.audio_tower.": "audio_tower.", + "model.embed_audio.": "embed_audio.", + # backbone + "model.language_model.": "language_model.model.", + "lm_head.": "language_model.lm_head.", + "model": "language_model.model", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + self.model_dtype = vllm_config.model_config.dtype + + # Only quantize towers when the quant method supports their + # dimensions. BNB/torchao handle arbitrary sizes; other methods + # (Marlin, FP8, …) require dimensions divisible by 64, which + # the vision tower (intermediate_size=4304) does not satisfy. + # TODO(mgoin): remove this by fixing kernel padding. + if quant_config and quant_config.get_name() in [ + "bitsandbytes", + "torchao", + "compressed-tensors", + ]: + tower_quant = quant_config + else: + vision_cfg = config.vision_config + quantizable = ( + vision_cfg.hidden_size % 64 == 0 + and vision_cfg.intermediate_size % 64 == 0 + ) + tower_quant = quant_config if quantizable else None + + # ---- Vision tower (shared by image and video) ---- + with self._mark_tower_model(vllm_config, {"image", "video"}): + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.embed_vision = Gemma4MultimodalEmbedder( + config.vision_config, + config.text_config, + quant_config=tower_quant, + prefix=maybe_prefix(prefix, "embed_vision"), + ) + recursive_replace_linear( + self.vision_tower, + tower_quant, + prefix=maybe_prefix(prefix, "vision_tower"), + ) + + # ---- Audio tower (variants with audio_config) ---- + if config.audio_config is not None: + with self._mark_tower_model(vllm_config, "audio"): + self.audio_tower = AutoModel.from_config(config=config.audio_config) + # AutoModel.from_config does NOT call post_init(), + # which is needed to initialize buffers that are absent + # from the checkpoint (e.g. inv_timescales for relative + # position embeddings, softcap, gradient_clipping). + self.audio_tower.post_init() + self.embed_audio = Gemma4MultimodalEmbedder( + config.audio_config, + config.text_config, + quant_config=tower_quant, + prefix=maybe_prefix(prefix, "embed_audio"), + ) + recursive_replace_linear( + self.audio_tower, + tower_quant, + prefix=maybe_prefix(prefix, "audio_tower"), + ) + else: + self.audio_tower = None + self.embed_audio = None + + # ---- Language model (vLLM optimised) ---- + with self._mark_language_model(vllm_config): + self.language_model: Gemma4ForCausalLM = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Gemma4ForCausalLM"], + ) + + # Pre-allocate PLE buffer for CUDA graph compatibility. + # Some variants have hidden_size_per_layer_input=None (no PLE). + ple_dim = config.text_config.hidden_size_per_layer_input + if ple_dim is not None and ple_dim > 0: + embed = self.language_model.model.embed_tokens + self.per_layer_embeddings = torch.zeros( + vllm_config.scheduler_config.max_num_batched_tokens, + config.text_config.num_hidden_layers, + ple_dim, + device=next(embed.parameters()).device, + dtype=vllm_config.model_config.dtype, + ) + else: + self.per_layer_embeddings = None + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + # --- Precompute full-attention layer indices for bidi clearing --- + self._full_attn_layer_idxs: frozenset[int] = frozenset() + text_config = config.text_config + if getattr(text_config, "use_bidirectional_attention", None) == "vision": + layer_types = getattr(text_config, "layer_types", None) + if layer_types: + self._full_attn_layer_idxs = frozenset( + i for i, lt in enumerate(layer_types) if lt != "sliding_attention" + ) + + # --- MixtureOfExperts delegation to language_model --- + self.expert_weights = self.language_model.expert_weights + self.moe_layers = self.language_model.moe_layers + self.num_moe_layers = self.language_model.num_moe_layers + self.num_logical_experts = self.language_model.num_logical_experts + self.num_physical_experts = self.language_model.num_physical_experts + self.num_local_physical_experts = self.language_model.num_local_physical_experts + self.num_routed_experts = self.language_model.num_routed_experts + self.num_expert_groups = self.language_model.num_expert_groups + self.num_shared_experts = self.language_model.num_shared_experts + self.num_redundant_experts = self.language_model.num_redundant_experts + + gen_cfg = vllm_config.model_config.try_get_generation_config() + self._suppress_token_ids = gen_cfg.get("suppress_tokens") if gen_cfg else None + + # ------------------------------------------------------------------ # + # Input parsing + # ------------------------------------------------------------------ # + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Gemma4ImageInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + pixel_position_ids = kwargs.pop("pixel_position_ids", None) + image_embeds = kwargs.pop("image_embeds", None) + assert image_embeds is None, "Gemma4 does not support image_embeds." + if pixel_values is None: + return None + return Gemma4ImagePixelInputs( + pixel_values=pixel_values, + pixel_position_ids=pixel_position_ids, + ) + + def _parse_and_validate_audio_input( + self, **kwargs: object + ) -> Gemma4AudioInputs | None: + input_features_padded = kwargs.pop("input_features_padded", None) + if input_features_padded is None: + return None + input_features_mask = kwargs.pop("input_features_mask", None) + if input_features_mask is None: + return None + return Gemma4AudioInputs( + input_features_padded=input_features_padded, + input_features_mask=input_features_mask, + ) + + def _parse_and_validate_video_input( + self, **kwargs: object + ) -> dict[str, torch.Tensor] | None: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + pixel_position_ids_videos = kwargs.pop("pixel_position_ids_videos", None) + video_frame_counts = kwargs.pop("video_frame_counts", None) + if pixel_values_videos is None: + return None + return { + "pixel_values_videos": pixel_values_videos, + "pixel_position_ids_videos": pixel_position_ids_videos, + "video_frame_counts": video_frame_counts, + } + + def _parse_and_validate_multimodal_inputs( + self, **kwargs: object + ) -> dict[str, Gemma4ImageInputs | Gemma4AudioInputs | Gemma4VideoInputs | None]: + mm_input_by_modality = {} + for input_key in list(kwargs): + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key == "pixel_values_videos" + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) + if ( + input_key == "input_features_padded" + and "audio" not in mm_input_by_modality + ): + mm_input_by_modality["audio"] = self._parse_and_validate_audio_input( + **kwargs + ) + return mm_input_by_modality + + @staticmethod + def _encoder_chunk( + patches_per_item: int, + free_bytes: int, + total_bytes: int, + position_embedding_size: int, + ) -> int: + """Max chunk size whose F.one_hot transient fits in the budget. + + The dominant transient inside HF's ``Gemma4VisionPatchEmbedder. + _position_embeddings`` is + ``F.one_hot(clamped_positions, num_classes=position_embedding_size)`` + with shape ``(chunk, patches, 2, position_embedding_size)``, + int64, plus its simultaneous cast to the position embedding + table dtype. That, not the encoder residual stream, sets peak + memory. + """ + if patches_per_item <= 0: + return 1 + # Half of currently-free, capped at 10% of total so we leave room + # for the rest of profile_run / the subsequent encoder + pooler. + budget = min(free_bytes // 2, total_bytes // 10) + if budget <= 0: + return 1 + # F.one_hot allocates (chunk, patches, 2, pos_emb_size) int64 + # (the inner 2 is the (x, y) coordinate axis, 8 is sizeof(int64)). + # Outer 2x covers the int64 buffer and its concurrent bf16 cast + # plus the matmul output that live alongside it at peak. + cost = patches_per_item * 4 * position_embedding_size * 8 + return max(1, budget // cost) if cost > 0 else 1 + + # ------------------------------------------------------------------ # + # Image processing + # ------------------------------------------------------------------ # + + def _process_image_input( + self, + image_input: Gemma4ImageInputs, + ) -> list[torch.Tensor]: + """Batch-encode images through the vision tower. + + Groups images by patch count (resolution bucket) so each + encoder call processes a uniform-shape batch with no + cross-resolution padding. Pooling and projection are then + applied over a single concatenated tensor for all images. + """ + pixel_values = image_input["pixel_values"] + pixel_position_ids = image_input["pixel_position_ids"] + + vt = self.vision_tower + vision_cfg = self.config.vision_config + pooling_k2 = vision_cfg.pooling_kernel_size**2 + + # Concurrent requests with different image resolutions may + # arrive as a list of per-image tensors, while same-resolution + # batches may arrive as a stacked tensor. + buckets: dict[int, list[tuple[int, torch.Tensor, torch.Tensor]]] = {} + total_images = ( + len(pixel_values) + if isinstance(pixel_values, list) + else pixel_values.shape[0] + ) + + for idx in range(total_images): + pv = pixel_values[idx] + pp = pixel_position_ids[idx] + buckets.setdefault(pv.shape[0], []).append((idx, pv, pp)) + + # Encode each resolution bucket in memory-safe chunks. Re-read + # free memory per bucket because the previous bucket's encoder + # pass has already allocated activations we should account for. + last_hidden_states_map: dict[int, torch.Tensor] = {} + for patches, items in buckets.items(): + free, total = current_platform.mem_get_info() + max_batch_size = min( + len(items), + self._encoder_chunk( + patches, free, total, vision_cfg.position_embedding_size + ), + ) + + for chunk_idx in range(0, len(items), max_batch_size): + chunk_items = items[chunk_idx : chunk_idx + max_batch_size] + + pv_tensor = torch.cat( + [item[1].unsqueeze(0) for item in chunk_items], dim=0 + ) + pp_tensor = torch.cat( + [item[2].unsqueeze(0) for item in chunk_items], dim=0 + ) + pad_tensor = (pp_tensor == -1).all(dim=-1) + + inputs_embeds = vt.patch_embedder( + pv_tensor, + pp_tensor, + pad_tensor, + ).to(self.model_dtype) + encoder_outputs = vt.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~pad_tensor, + pixel_position_ids=pp_tensor, + ) + hidden_states = encoder_outputs.last_hidden_state + + for i, (orig_idx, _, _) in enumerate(chunk_items): + last_hidden_states_map[orig_idx] = hidden_states[i] + + # Pool per image to strip padding and reduce spatial resolution. + all_valid_states: list[torch.Tensor] = [None] * total_images # type: ignore[list-item] + valid_lens = [0] * total_images + + for orig_idx in range(total_images): + chunk_hidden = last_hidden_states_map[orig_idx] + output_length = chunk_hidden.shape[0] // pooling_k2 + + single_hidden = chunk_hidden.unsqueeze(0) + single_pos_ids = pixel_position_ids[orig_idx].unsqueeze(0) + padding_positions = (single_pos_ids == -1).all(dim=-1) + + pooled_states, valid_mask = vt.pooler( + hidden_states=single_hidden, + pixel_position_ids=single_pos_ids, + padding_positions=padding_positions, + output_length=output_length, + ) + valid_states = pooled_states[valid_mask] + + if getattr(vt.config, "standardize", False): + valid_states = (valid_states - vt.std_bias) * vt.std_scale + + all_valid_states[orig_idx] = valid_states + valid_lens[orig_idx] = valid_states.shape[0] + + # Project all images in a single batched call. + flat_valid_states = torch.cat(all_valid_states, dim=0).to(self.model_dtype) + flat_proj_embs = self.embed_vision( + inputs_embeds=flat_valid_states.unsqueeze(0) + ).squeeze(0) + + # Split back into per-image tensors (slicing returns views). + per_image_embeddings: list[torch.Tensor] = [] + offset = 0 + for length in valid_lens: + per_image_embeddings.append(flat_proj_embs[offset : offset + length]) + offset += length + + return per_image_embeddings + + # ------------------------------------------------------------------ # + # Video processing (frames through vision tower) + # ------------------------------------------------------------------ # + + def _process_video_input( + self, + video_input: dict[str, torch.Tensor], + ) -> list[torch.Tensor]: + """Batch-encode video frames through the vision tower. + + Gemma4 has no separate video tower; video frames are images at + lower resolution (max_soft_tokens=70). All frames across all + videos in the batch are encoded together in chunks, then pooled + and projected in a single batched call. + + Returns one concatenated embedding tensor per video (not per + frame), matching the flat_from_sizes grouping that vLLM expects + for embed_multimodal. + """ + pixel_values = video_input["pixel_values_videos"] + pixel_position_ids = video_input["pixel_position_ids_videos"] + frame_counts = video_input["video_frame_counts"] + + vt = self.vision_tower + vision_cfg = self.config.vision_config + pooling_k2 = vision_cfg.pooling_kernel_size**2 + + if isinstance(frame_counts, torch.Tensor): + fc_list = frame_counts.tolist() + else: + fc_list = list(frame_counts) + + total_frames = pixel_values.shape[0] + free, total = current_platform.mem_get_info() + max_batch_size = min( + total_frames, + self._encoder_chunk( + pixel_values.shape[1], + free, + total, + vision_cfg.position_embedding_size, + ), + ) + + padding_positions = (pixel_position_ids == -1).all(dim=-1) + + # Encode frames in chunks bounded by _encoder_chunk. + last_hidden_states_list: list[torch.Tensor] = [] + for i in range(0, total_frames, max_batch_size): + pv_chunk = pixel_values[i : i + max_batch_size] + pp_chunk = pixel_position_ids[i : i + max_batch_size] + pad_chunk = padding_positions[i : i + max_batch_size] + + inputs_embeds = vt.patch_embedder( + pv_chunk, + pp_chunk, + pad_chunk, + ).to(self.model_dtype) + encoder_outputs = vt.encoder( + inputs_embeds=inputs_embeds, + attention_mask=~pad_chunk, + pixel_position_ids=pp_chunk, + ) + last_hidden_states_list.append(encoder_outputs.last_hidden_state) + + last_hidden_states = torch.cat(last_hidden_states_list, dim=0) + + # Pool per frame to strip padding and reduce spatial resolution. + output_length = pixel_values.shape[1] // pooling_k2 + all_frame_valid_states: list[torch.Tensor] = [] + frame_valid_lens: list[int] = [] + + for i in range(total_frames): + single_hidden = last_hidden_states[i].unsqueeze(0) + single_pos_ids = pixel_position_ids[i].unsqueeze(0) + single_pad_pos = padding_positions[i].unsqueeze(0) + + pooled_states, valid_mask = vt.pooler( + hidden_states=single_hidden, + pixel_position_ids=single_pos_ids, + padding_positions=single_pad_pos, + output_length=output_length, + ) + valid_states = pooled_states[valid_mask] + + if getattr(vt.config, "standardize", False): + valid_states = (valid_states - vt.std_bias) * vt.std_scale + + all_frame_valid_states.append(valid_states) + frame_valid_lens.append(valid_states.shape[0]) + + # Project all frames in a single batched call. + flat_valid_states = torch.cat(all_frame_valid_states, dim=0).to( + self.model_dtype + ) + flat_proj_embs = self.embed_vision( + inputs_embeds=flat_valid_states.unsqueeze(0) + ).squeeze(0) + + # Regroup into per-video tensors (slicing returns views). + per_video_embeddings: list[torch.Tensor] = [] + frame_idx = 0 + offset = 0 + for count in fc_list: + video_tokens = sum(frame_valid_lens[frame_idx : frame_idx + count]) + per_video_embeddings.append(flat_proj_embs[offset : offset + video_tokens]) + offset += video_tokens + frame_idx += count + + return per_video_embeddings + + # ------------------------------------------------------------------ # + # Audio processing + # ------------------------------------------------------------------ # + + def _process_audio_input( + self, + audio_input: Gemma4AudioInputs, + ) -> list[torch.Tensor]: + input_features = audio_input["input_features_padded"].squeeze(1) + input_features_mask = audio_input["input_features_mask"].squeeze(1) + + # Run audio tower — mask convention: True=valid, False=padding. + audio_outputs = self.audio_tower(input_features, input_features_mask) + if isinstance(audio_outputs, tuple): + audio_encodings, audio_mask = audio_outputs + else: + audio_encodings = audio_outputs.last_hidden_state + audio_mask = audio_outputs.attention_mask + + # Project into LM embedding space. + audio_features = self.embed_audio(inputs_embeds=audio_encodings) + + # Strip padding per-batch element: only keep valid (non-padding) + # tokens. + per_audio = [] + for enc, mask in zip(audio_features, audio_mask, strict=True): + per_audio.append(enc[mask]) # [num_real, hidden_size] + + return per_audio + + # ------------------------------------------------------------------ # + # MultiModalEmbeddings interface + # ------------------------------------------------------------------ # + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) + multimodal_embeddings: list[torch.Tensor] = [] + + for modality, multimodal_input in mm_input_by_modality.items(): + if multimodal_input is None: + continue + if modality == "image": + multimodal_embeddings.extend( + self._process_image_input(multimodal_input) + ) + elif modality == "video": + multimodal_embeddings.extend( + self._process_video_input(multimodal_input) + ) + elif modality == "audio": + multimodal_embeddings.extend( + self._process_audio_input(multimodal_input) + ) + + return multimodal_embeddings + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + ) -> torch.Tensor: + # Cache per-layer embeddings (PLE) for the language model's + # forward pass. During profiling embed_input_ids is not called, + # so the pre-allocated zeros are used instead. + if self.per_layer_embeddings is not None: + # Mask multimodal tokens (image/audio) to 0 for PLE + # computation (using token_type_ids == 0 as text_mask). + # Replicate this: map image token positions to token 0. + if is_multimodal is not None: + ple_input_ids = torch.where( + is_multimodal.to(input_ids.device, non_blocking=True), + torch.zeros_like(input_ids), + input_ids, + ) + else: + ple_input_ids = input_ids + + per_layer_inputs = self.language_model.model.get_per_layer_inputs( + ple_input_ids + ) + if per_layer_inputs is not None: + per_layer_inputs = per_layer_inputs.reshape( + -1, + self.config.text_config.num_hidden_layers, + self.config.text_config.hidden_size_per_layer_input, + ) + self.per_layer_embeddings[: per_layer_inputs.shape[0]].copy_( + per_layer_inputs + ) + + if multimodal_embeddings is None or is_multimodal is None: + return super().embed_input_ids(input_ids) + + return super().embed_input_ids( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + # ------------------------------------------------------------------ # + # Forward + # ------------------------------------------------------------------ # + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + # Select the pre-cached PLEs for this batch (None when PLE + # is disabled for variants without PLE). + per_layer_inputs = ( + self.per_layer_embeddings[: inputs_embeds.shape[0]] + if self.per_layer_embeddings is not None and inputs_embeds is not None + else None + ) + + # Gemma4 bidi: clear mm_prefix_range for full_attention layers. + # Must run here (outside @support_torch_compile boundary) because + # _run_decoder_layers is inside a compiled graph where Python + # side effects are eliminated. + self._clear_mm_prefix_for_full_attn_layers() + + hidden_states = self.language_model.model( + input_ids, + positions, + per_layer_inputs=per_layer_inputs, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.language_model.compute_logits(hidden_states) + if logits is not None and self._suppress_token_ids: + logits[:, self._suppress_token_ids] = -float("inf") + return logits + + # ------------------------------------------------------------------ # + # Bidirectional attention helpers + # ------------------------------------------------------------------ # + + def _clear_mm_prefix_for_full_attn_layers(self) -> None: + """Clear mm_prefix_range for non-sliding layers. + + Gemma4 with use_bidirectional_attention='vision' applies + bidirectional attention only to sliding_attention layers. + Full attention layers use plain causal masking. + + Uses _full_attn_layer_idxs (precomputed in __init__) for O(1) + lookup instead of per-call regex parsing. + """ + if not self._full_attn_layer_idxs: + return + + from vllm.forward_context import get_forward_context + + attn_metadata = get_forward_context().attn_metadata + if attn_metadata is None: + return + + def _process(metadata_dict: dict) -> None: + for layer_name, metadata in metadata_dict.items(): + if ".layers." not in layer_name: + continue + try: + layer_idx = int(layer_name.split(".layers.")[1].split(".")[0]) + except (ValueError, IndexError): + continue + if layer_idx in self._full_attn_layer_idxs: + if hasattr(metadata, "mm_prefix_range"): + metadata.mm_prefix_range = None + if hasattr(metadata, "mm_prefix_range_tensor"): + metadata.mm_prefix_range_tensor = None + + if isinstance(attn_metadata, list): + for ub_metadata in attn_metadata: + _process(ub_metadata) + elif isinstance(attn_metadata, dict): + _process(attn_metadata) + + # ------------------------------------------------------------------ # + # Weight loading + # ------------------------------------------------------------------ # + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Some checkpoints have vestigial embed_vision.embedding and + # embed_audio.embedding weights from the Gemma3n architecture + # that are not used by Gemma4's MultimodalEmbedder (which only + # has embedding_projection + embedding_post_projection_norm). + ignore_prefixes = [ + "embed_vision.embedding.", + "embed_audio.embedding.", + ] + # Models without audio tower should skip audio weights entirely. + if self.audio_tower is None: + ignore_prefixes.extend( + [ + "audio_tower.", + "embed_audio.", + ] + ) + loader = AutoWeightsLoader( + self, + ignore_unexpected_prefixes=ignore_prefixes, + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + # ------------------------------------------------------------------ # + # LoRA / multimodal mapping + # ------------------------------------------------------------------ # + + def get_mm_mapping(self) -> MultiModelKeys: + """Get the module prefix mapping for multimodal models.""" + connectors = ["embed_vision"] + tower_models = ["vision_tower"] + if self.audio_tower is not None: + connectors.append("embed_audio") + tower_models.append("audio_tower") + + return MultiModelKeys.from_string_field( + language_model="language_model", + connector=connectors, + tower_model=tower_models, + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality == "image": + return "" + if modality == "audio": + return "" + if modality == "video": + return "<|video|>" + raise ValueError(f"Unsupported modality: {modality}") diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 9ce3af62c8..d935f2c34f 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -330,6 +330,7 @@ ), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 + "Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"), # noqa: E501 "Gemma3nForConditionalGeneration": ( "gemma3n_mm", "Gemma3nForConditionalGeneration", diff --git a/vllm/model_executor/models/transformers/utils.py b/vllm/model_executor/models/transformers/utils.py index 0aca8a64ee..e83cb86cbc 100644 --- a/vllm/model_executor/models/transformers/utils.py +++ b/vllm/model_executor/models/transformers/utils.py @@ -141,6 +141,34 @@ def replace_linear_class( VllmConv = Conv2dLayer | Conv3dLayer +def recursive_replace_linear( + model: nn.Module, + quant_config: "QuantizationConfig | None", + prefix: str = "", +): + """Recursively replace linear modules in the model as needed.""" + from vllm.model_executor.models.utils import maybe_prefix + + def _recursive_replace(module: nn.Module, prefix: str): + for child_name, child_module in module.named_children(): + new_module = child_module + qual_name = maybe_prefix(prefix, child_name) + # Replace modules as needed + if isinstance(child_module, nn.Linear): + new_module = replace_linear_class( + child_module, + "replicate", + quant_config, + prefix=qual_name, + ) + else: + _recursive_replace(child_module, prefix=qual_name) + if new_module is not child_module: + setattr(module, child_name, new_module) + + _recursive_replace(model, prefix=prefix) + + def replace_conv_class(conv: TorchConv) -> VllmConv | TorchConv: """Replace a Transformers Conv2d/Conv3d with vLLM's Conv2d/Conv3d. From 2aeb1cba22dd6a9a4b39c21c3b958f65bf770f59 Mon Sep 17 00:00:00 2001 From: Phil Date: Sat, 6 Jun 2026 00:28:57 +0000 Subject: [PATCH 14/25] [gemma4] Runtime serve fixes: activation, proportional RoPE, MM guard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Found while bringing up google/gemma-4-31B-it-qat-w4a16-ct on the V100 SM70 W4A16 path (loads via SM70TurboMindLinearKernel; full graph builds): - activation: register gelu_pytorch_tanh -> GeluAndMul(approximate="tanh") (gemma4 looks it up by name via the generic act-and-mul registry; gemma3 special-cased it inline). - rotary: vendor rotary_embedding/gemma4_rope.py (Gemma4RotaryEmbedding) and wire the get_rope `proportional` branch (gemma4 global/full attention). - gemma4_mm: guard get_merged_mm_kwargs (newer InputProcessingContext method absent on this base) -> fall back to the config default soft-token count. Known remaining: vision-tower weight mapping (std_bias / HF-built tower param names) — MM-specific, LM weights load fine. Co-Authored-By: RivetOS Claude --- vllm/model_executor/layers/activation.py | 3 + .../layers/rotary_embedding/__init__.py | 12 +++ .../layers/rotary_embedding/gemma4_rope.py | 84 +++++++++++++++++++ vllm/model_executor/models/gemma4_mm.py | 14 +++- 4 files changed, 109 insertions(+), 4 deletions(-) create mode 100644 vllm/model_executor/layers/rotary_embedding/gemma4_rope.py diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 17b5072985..1ba0abc9d7 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -680,6 +680,9 @@ def get_act_fn(act_fn_name: str) -> nn.Module: "gelu": lambda: GeluAndMul(), "silu": lambda: SiluAndMul(), "geglu": lambda: GeluAndMul(), + # Gemma uses tanh-approx GeGLU; gemma4 looks it up by name (gemma3 + # special-cased it inline). Register it for the generic path. + "gelu_pytorch_tanh": lambda: GeluAndMul(approximate="tanh"), "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs), } ) diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index 2d50ff550d..e334cb3681 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -13,6 +13,7 @@ from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding from .fope import FourierRotaryEmbedding from .linear_scaling_rope import LinearScalingRotaryEmbedding +from .gemma4_rope import Gemma4RotaryEmbedding from .llama3_rope import Llama3RotaryEmbedding from .llama4_vision_rope import Llama4VisionRotaryEmbedding from .mrope import MRotaryEmbedding @@ -151,6 +152,17 @@ def get_rope( high_freq_factor, original_max_position, ) + elif scaling_type == "proportional": + # Proportional RoPE is used by Gemma4 for global (full) attention + # (sparse/fractional RoPE with cross-mixing between halves). + rotary_emb = Gemma4RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + ) elif scaling_type == "mllama4": rotary_emb = Llama4VisionRotaryEmbedding( head_size, rotary_dim, max_position, base, is_neox_style, dtype diff --git a/vllm/model_executor/layers/rotary_embedding/gemma4_rope.py b/vllm/model_executor/layers/rotary_embedding/gemma4_rope.py new file mode 100644 index 0000000000..48253f469c --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/gemma4_rope.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Gemma4-specific Rotary Positional Embeddings (proportional scaling). + +Gemma4 uses "proportional" RoPE which computes inv_freq frequencies scaled +by head_dim (not rotary_dim), and zero-pads for non-rotated dimensions when +partial_rotary_factor < 1. The actual rotation uses standard neox-style +rotate_half, matching HF transformers' apply_rotary_pos_emb. +""" + +import torch + +from .base import RotaryEmbedding + + +class Gemma4RotaryEmbedding(RotaryEmbedding): + """Gemma4 proportional RoPE. + + Extends RotaryEmbedding (which provides standard neox-style rotation + via ops.rotary_embedding CUDA kernel) but overrides the inv_freq + computation to match HF's _compute_proportional_rope_parameters: + - Frequency exponents use head_dim (not rotary_dim) as denominator + - Non-rotated dims are zero-padded (cos=1, sin=0 = identity rotation) + + When partial_rotary_factor=1.0 (the default for some variants), ALL dims are + rotated and this is equivalent to standard RotaryEmbedding with + head_dim-scaled frequencies. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: float, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + # Number of rotation angle pairs (from partial_rotary_factor) + self.rope_angles = rotary_dim // 2 + # Non-rotated angle pairs per half + self.nope_angles = (head_size // 2) - self.rope_angles + + # Important: set rotary_dim = head_size so the base class's + # forward_static applies rotation to ALL dims of the cos/sin cache. + # The non-rotated dims will have cos=1, sin=0 (identity) thanks + # to our _compute_inv_freq zero-padding. + super().__init__( + head_size, + head_size, # rotary_dim = head_size (full application) + max_position_embeddings, + base, + is_neox_style, + dtype, + ) + + def _compute_inv_freq(self, base: float) -> torch.Tensor: + """Compute frequencies matching HF proportional RoPE. + + Key difference from base: exponent denominator is head_size (not + rotary_dim), and non-rotated dims are zero-padded. + """ + # HF formula: base ** (arange(0, 2*rope_angles, 2) / head_dim) + freq_exponents = ( + torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float) / self.head_size + ) + inv_freq = 1.0 / (base**freq_exponents) + + # Zero-pad for non-rotated dims (identity rotation: cos=1, sin=0) + if self.nope_angles > 0: + inv_freq = torch.cat( + [ + inv_freq, + torch.zeros(self.nope_angles, dtype=torch.float), + ] + ) + return inv_freq + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", rope_angles={self.rope_angles}, nope_angles={self.nope_angles}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s diff --git a/vllm/model_executor/models/gemma4_mm.py b/vllm/model_executor/models/gemma4_mm.py index 49846b121e..ddaa843297 100644 --- a/vllm/model_executor/models/gemma4_mm.py +++ b/vllm/model_executor/models/gemma4_mm.py @@ -244,10 +244,16 @@ def get_mm_max_tokens_per_item( # After padding is stripped the actual count is ≤ this value, but # vLLM needs the max for memory planning. tokens_per_image = config.vision_config.default_output_length - merged_kwargs = self.ctx.get_merged_mm_kwargs({}) - val, _ = _get_max_soft_tokens(merged_kwargs) - if isinstance(val, int) and val in _SUPPORTED_SOFT_TOKENS: - tokens_per_image = val + # NOTE(rivet): get_merged_mm_kwargs is a newer InputProcessingContext + # method this base lacks; it only overrides the default soft-token count. + # Degrade gracefully to the config default when unavailable. + try: + merged_kwargs = self.ctx.get_merged_mm_kwargs({}) + val, _ = _get_max_soft_tokens(merged_kwargs) + if isinstance(val, int) and val in _SUPPORTED_SOFT_TOKENS: + tokens_per_image = val + except AttributeError: + pass tokens: dict[str, int] = {"image": tokens_per_image} if config.audio_config is not None: # Audio max tokens from the processor's audio_seq_length. From b99d0db1140b0e03d20bc35afe4afb6c936af870 Mon Sep 17 00:00:00 2001 From: Phil Date: Sat, 6 Jun 2026 00:52:46 +0000 Subject: [PATCH 15/25] [gemma4] Load vision-tower std_bias/std_scale persistent buffers The HF Gemma4VisionModel (standardize=True) registers std_bias/std_scale as persistent BUFFERS at the tower root; they're in the checkpoint and used at runtime ((states - std_bias) * std_scale). vLLM's AutoWeightsLoader only loads nn.Parameters (+ a BatchNorm-only buffer rescue), so it raised "There is no module or parameter named vision_tower.std_bias". Fix: in Gemma4ForConditionalGeneration.load_weights, intercept the two checkpoint keys model.vision_tower.std_{bias,scale}, copy_ them into the registered buffers, and pass the rest to AutoWeightsLoader unchanged (LM load path byte-identical). Also fix the cosmetic 'vision_towerencoder' missing-dot in the AutoWeightsLoader error message (base_prefix + k -> _get_qualname). Co-Authored-By: RivetOS Claude --- vllm/model_executor/models/gemma4_mm.py | 42 ++++++++++++++++++++++++- vllm/model_executor/models/utils.py | 3 +- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/gemma4_mm.py b/vllm/model_executor/models/gemma4_mm.py index ddaa843297..34bcff2d6a 100644 --- a/vllm/model_executor/models/gemma4_mm.py +++ b/vllm/model_executor/models/gemma4_mm.py @@ -1679,11 +1679,51 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: "embed_audio.", ] ) + + # ---- Persistent-buffer fix for the SigLIP-style vision tower ---- + # The HF Gemma4VisionModel registers `std_bias` / `std_scale` as + # *persistent buffers* at the tower root (config.vision_config. + # standardize=True). They are present in the checkpoint + # (model.vision_tower.std_{bias,scale}) and used at runtime in + # _process_image_input -> (states - std_bias) * std_scale. + # vLLM's AutoWeightsLoader only routes weights to child submodules + # or to nn.Parameters from named_parameters(recurse=False); generic + # persistent buffers are invisible to it (the only buffer rescue, + # _add_loadable_non_param_tensors, handles nn.BatchNorm* only). So + # these two tensors would hit the else-branch and raise + # "There is no module or parameter named vision_tower.std_bias". + # Load them here directly, then hand the rest to AutoWeightsLoader. + # Keyed on the *checkpoint* name because the WeightsMapper is applied + # inside loader.load_weights(), so `weights` still carries the + # original "model.vision_tower.*" names at this point. + buffer_targets = {} + if getattr(self, "vision_tower", None) is not None: + for buf_name in ("std_bias", "std_scale"): + if hasattr(self.vision_tower, buf_name): + buffer_targets[f"model.vision_tower.{buf_name}"] = ( + f"vision_tower.{buf_name}", + getattr(self.vision_tower, buf_name), + ) + + loaded_buffers: set[str] = set() + remaining_weights = [] + for name, weight in weights: + target = buffer_targets.get(name) + if target is not None: + mapped_name, buf = target + buf.data.copy_(weight.to(buf.dtype)) + loaded_buffers.add(mapped_name) + continue + remaining_weights.append((name, weight)) + loader = AutoWeightsLoader( self, ignore_unexpected_prefixes=ignore_prefixes, ) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + loaded = loader.load_weights( + remaining_weights, mapper=self.hf_to_vllm_mapper + ) + return loaded | loaded_buffers # ------------------------------------------------------------------ # # LoRA / multimodal mapping diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index c6ea136ddf..6906584aa2 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -315,7 +315,8 @@ def _load_module( continue desc_param_keys = { - base_prefix + k for k, _ in module.named_parameters(recurse=True) + self._get_qualname(base_prefix, k) + for k, _ in module.named_parameters(recurse=True) } msg = ( f"There is no module or parameter named {prefix!r} " From 694a9c2c75e459b85571068ae87098f9090e7046 Mon Sep 17 00:00:00 2001 From: Phil Date: Sat, 6 Jun 2026 01:04:25 +0000 Subject: [PATCH 16/25] [gemma4] Fix inference crash: don't assign read-only mm_prefix_range_tensor In _clear_mm_prefix_range for full-attention layers, the port set both metadata.mm_prefix_range = None and metadata.mm_prefix_range_tensor = None. But in this fork's TritonAttentionMetadata, mm_prefix_range_tensor is a read-only @property derived from mm_prefix_range (no setter) -> AttributeError at first forward. Clearing the source dict already nulls the derived tensor; drop the broken assignment. With this + serving on TRITON_ATTN (gemma-4's 512-dim global-attention layers exceed FLASH_ATTN_V100's D<=256 cap), gemma-4-31B-it-qat-w4a16-ct generates coherent text on the dual V100 via the SM70 TurboMind W4A16 path. Co-Authored-By: RivetOS Claude --- vllm/model_executor/models/gemma4_mm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/gemma4_mm.py b/vllm/model_executor/models/gemma4_mm.py index 34bcff2d6a..a232e09851 100644 --- a/vllm/model_executor/models/gemma4_mm.py +++ b/vllm/model_executor/models/gemma4_mm.py @@ -1647,10 +1647,11 @@ def _process(metadata_dict: dict) -> None: except (ValueError, IndexError): continue if layer_idx in self._full_attn_layer_idxs: + # mm_prefix_range_tensor is a read-only @property derived + # from mm_prefix_range in this fork's TritonAttentionMetadata + # (no setter); clearing the source dict suffices. if hasattr(metadata, "mm_prefix_range"): metadata.mm_prefix_range = None - if hasattr(metadata, "mm_prefix_range_tensor"): - metadata.mm_prefix_range_tensor = None if isinstance(attn_metadata, list): for ub_metadata in attn_metadata: From a54babb757f765aa0b069f7eed6e85eb93cfec44 Mon Sep 17 00:00:00 2001 From: Phil Date: Sat, 6 Jun 2026 01:18:03 +0000 Subject: [PATCH 17/25] [gemma4][perf] FlashAttnV100Backend.supports_head_size override (mixed backend) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit FlashAttnV100Backend defined get_supported_head_sizes()->[64,128,256] but that method is never called by validate_configuration() — which uses supports_head_size(). FA-V100 inherited TritonAttentionBackend.supports_head_size (head_size >= 32), so it wrongly validated head_size=512 and would hard-crash the Volta CUDA kernel (TORCH_CHECK D<=256). Add the override returning {64,128,256}. This lets vLLM's per-layer backend auto-selection route gemma-4's 50 sliding (head_dim=256) layers to the fast FA-V100 kernel and fall through to TRITON_ATTN only for the 10 global (head_dim=512) layers — instead of forcing Triton on all 60. Deploy by DROPPING --attention-backend so auto-select runs. Co-Authored-By: RivetOS Claude --- vllm/v1/attention/backends/flash_attn_v100.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm/v1/attention/backends/flash_attn_v100.py b/vllm/v1/attention/backends/flash_attn_v100.py index 1e3ffe3f18..0fd10026ac 100644 --- a/vllm/v1/attention/backends/flash_attn_v100.py +++ b/vllm/v1/attention/backends/flash_attn_v100.py @@ -970,3 +970,14 @@ def get_name() -> str: def get_supported_head_sizes() -> list[int]: # Keep this aligned with the dense prefill kernel dispatch table. return [64, 128, 256] + + @classmethod + def supports_head_size(cls, head_size: int) -> bool: + # NOTE(rivet): validate_configuration() calls supports_head_size(), + # NOT get_supported_head_sizes(). Without this override we'd inherit + # TritonAttentionBackend.supports_head_size (head_size >= 32) and + # wrongly accept head_size=512, then hard-crash the Volta CUDA kernel + # (TORCH_CHECK D<=256). The Volta FA kernel only handles 64/128/256. + # With this, auto-selection routes gemma-4's 256-dim sliding layers + # here and falls through to TRITON_ATTN for its 512-dim global layers. + return head_size in (64, 128, 256) From efb414a3a408d627e8f45e541318eaa245d32a14 Mon Sep 17 00:00:00 2001 From: Phil Date: Sun, 7 Jun 2026 03:43:46 +0000 Subject: [PATCH 18/25] [gemma4] Register gemma4_assistant config so MTP drafter loads MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The gemma4 MTP drafter checkpoint advertises model_type "gemma4_assistant", which Transformers' AutoConfig does not recognize. SpeculativeConfig's gemma4_assistant -> gemma4_mtp remap (config/speculative.py) runs only after the draft config is loaded, so get_config fell through to AutoConfig and raised a ValidationError before the remap could fire. Resolve gemma4_assistant to the multimodal Gemma4Config via _CONFIG_REGISTRY (it carries the .text_config the remap expects). Verified on 2×V100: the drafter now loads, the engine boots healthy, and serves. (Output-quality / 0%-draft-acceptance is a separate downstream issue in the KV-sharing path.) Co-Authored-By: RivetOS Claude --- vllm/transformers_utils/config.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4da3191feb..ddebc1b5e1 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -106,6 +106,20 @@ def __getitem__(self, key): tarsier2="Tarsier2Config", ) +# The gemma4 MTP drafter checkpoint advertises model_type "gemma4_assistant", +# which Transformers' AutoConfig does not recognize. The gemma4_assistant -> +# gemma4_mtp remap in SpeculativeConfig runs only *after* the config is loaded, +# so without this entry get_config falls through to AutoConfig and raises before +# the remap can fire. Resolve it to the multimodal Gemma4Config (it carries the +# .text_config the remap expects). LazyConfigDict.__getitem__ returns a type +# value as-is, so registering the class directly is fine. +try: + from transformers import Gemma4Config as _Gemma4Config + + _CONFIG_REGISTRY["gemma4_assistant"] = _Gemma4Config +except ImportError: + pass + _CONFIG_ATTRS_MAPPING: dict[str, str] = { "llm_config": "text_config", } From f3376f602aaee3d5cd2ce5fd8d6188c6ab92c0ad Mon Sep 17 00:00:00 2001 From: Phil Date: Sun, 7 Jun 2026 12:25:11 +0000 Subject: [PATCH 19/25] [gemma4][mtp] Fix target KV-cache corruption from draft layers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The gemma4 MTP draft layers are Q-only and must read the target's KV read-only via cross-model KV sharing. _setup_gemma4_kv_sharing set kv_sharing_target_layer_name on the attention *module* only, but the backend impl captured that value at construction (None for draft layers, which are built before target layer names are known). The KV-write gate checks the *impl's* copy (e.g. triton_attn.py: `if self.kv_sharing_target_layer_name is None: `), so the draft layers wrote their draft K/V into the target's shared layer-N slots, poisoning the target's verify pass — output was correct for the first decoded token then degenerated into garbage, with ~0% draft acceptance. Propagate the target-layer name to attn.impl as well so the write is correctly skipped. Verified on 2xV100 (gemma-4-31b-qat-w4a16 + assistant drafter): output now matches target-only generation and draft acceptance is 46-73% (mean accept len 1.47-1.73). Co-Authored-By: RivetOS Claude --- vllm/v1/spec_decode/gemma4.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/v1/spec_decode/gemma4.py b/vllm/v1/spec_decode/gemma4.py index 242cd2da24..05aff16603 100644 --- a/vllm/v1/spec_decode/gemma4.py +++ b/vllm/v1/spec_decode/gemma4.py @@ -334,6 +334,18 @@ def _setup_gemma4_kv_sharing( target_idx = candidates[-1] target_layer_name = f"{target_prefix}.{target_idx}.self_attn.attn" attn.kv_sharing_target_layer_name = target_layer_name + # The backend impl captured kv_sharing_target_layer_name at + # construction (None for these draft layers), and the KV-write gate + # checks the *impl's* copy (e.g. triton_attn.py: `if + # self.kv_sharing_target_layer_name is None: `). Setting + # only the module attr leaves the impl thinking it owns the cache, so + # the Q-only draft layer writes its draft K/V into the *target's* + # shared layer-N slots and corrupts the target's verify pass (output + # garbles after the first decoded token). Propagate to the impl so + # the write is correctly skipped — the draft layer reads the target's + # KV read-only, as intended. + if getattr(attn, "impl", None) is not None: + attn.impl.kv_sharing_target_layer_name = target_layer_name logger.info( "Gemma4 MTP: draft layer %d (%s) -> %s", draft_idx, From 474f985595214de95b805afe99be834ed78c8ebd Mon Sep 17 00:00:00 2001 From: Phil Date: Sun, 7 Jun 2026 14:15:24 +0000 Subject: [PATCH 20/25] [gemma4] Load tower clip buffers + MTP ordered-embedding buffer The E2B/E4B models (and their MTP assistants) exercise gemma4 efficiency features the 31B did not, all hitting the same root cause: vLLM's AutoWeightsLoader cannot place plain register_buffer tensors. - gemma4_mm.py: Gemma4ClippableLinear (use_clipped_linears=True, in BOTH the vision and audio encoders) registers input_min/input_max/output_min/ output_max activation clamps as buffers. Generalize the existing vision std_bias/std_scale buffer-load to walk every persistent buffer in both towers, so the audio tower loads too. - gemma4_mtp.py: assistants with use_ordered_embeddings=True (E2B/E4B drafters) carry masked_embedding.token_ordering (token->centroid map) as a buffer. Load masked_embedding.* buffers before AutoWeightsLoader. Verified on 2xV100: gemma-4-E2B-it and gemma-4-E4B-it both serve as full multimodal targets with their matched MTP assistants, clean output, draft acceptance 62-76%. Co-Authored-By: RivetOS Claude --- vllm/model_executor/models/gemma4_mm.py | 56 ++++++++++++++---------- vllm/model_executor/models/gemma4_mtp.py | 27 +++++++++++- 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/models/gemma4_mm.py b/vllm/model_executor/models/gemma4_mm.py index a232e09851..62469e13e8 100644 --- a/vllm/model_executor/models/gemma4_mm.py +++ b/vllm/model_executor/models/gemma4_mm.py @@ -1681,30 +1681,34 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ] ) - # ---- Persistent-buffer fix for the SigLIP-style vision tower ---- - # The HF Gemma4VisionModel registers `std_bias` / `std_scale` as - # *persistent buffers* at the tower root (config.vision_config. - # standardize=True). They are present in the checkpoint - # (model.vision_tower.std_{bias,scale}) and used at runtime in - # _process_image_input -> (states - std_bias) * std_scale. - # vLLM's AutoWeightsLoader only routes weights to child submodules - # or to nn.Parameters from named_parameters(recurse=False); generic - # persistent buffers are invisible to it (the only buffer rescue, - # _add_loadable_non_param_tensors, handles nn.BatchNorm* only). So - # these two tensors would hit the else-branch and raise - # "There is no module or parameter named vision_tower.std_bias". - # Load them here directly, then hand the rest to AutoWeightsLoader. - # Keyed on the *checkpoint* name because the WeightsMapper is applied - # inside loader.load_weights(), so `weights` still carries the - # original "model.vision_tower.*" names at this point. + # ---- Persistent-buffer fix for the vision/audio towers ---- + # The HF Gemma4 towers register several *persistent buffers* that + # AutoWeightsLoader cannot see (it routes only to child submodules and + # nn.Parameters; generic buffers are invisible — its only buffer rescue, + # _add_loadable_non_param_tensors, handles nn.BatchNorm* only): + # - SigLIP vision tower: std_bias / std_scale (standardize=True) + # - Gemma4ClippableLinear (use_clipped_linears=True, in BOTH the + # vision and audio encoders): input_min / input_max / output_min / + # output_max activation clamps, one set per clipped linear. + # Without this they hit the else-branch and raise e.g. "There is no + # module or parameter named vision_tower...down_proj.input_max". + # Load every checkpoint-backed buffer in each tower here, then hand the + # rest to AutoWeightsLoader. Keyed on the *checkpoint* name + # (model..*) because the WeightsMapper is applied inside + # loader.load_weights(), so `weights` still carries the original names. + # named_buffers() also surfaces non-persistent buffers (e.g. + # inv_timescales, softcap); those simply never match a checkpoint key + # and keep their post_init() values. buffer_targets = {} - if getattr(self, "vision_tower", None) is not None: - for buf_name in ("std_bias", "std_scale"): - if hasattr(self.vision_tower, buf_name): - buffer_targets[f"model.vision_tower.{buf_name}"] = ( - f"vision_tower.{buf_name}", - getattr(self.vision_tower, buf_name), - ) + for tower_attr in ("vision_tower", "audio_tower"): + tower = getattr(self, tower_attr, None) + if tower is None: + continue + for buf_name, buf in tower.named_buffers(): + buffer_targets[f"model.{tower_attr}.{buf_name}"] = ( + f"{tower_attr}.{buf_name}", + buf, + ) loaded_buffers: set[str] = set() remaining_weights = [] @@ -1712,7 +1716,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: target = buffer_targets.get(name) if target is not None: mapped_name, buf = target - buf.data.copy_(weight.to(buf.dtype)) + w = weight.to(buf.dtype) + if buf.shape == w.shape: + buf.data.copy_(w) + else: + buf.data = w.to(buf.device) loaded_buffers.add(mapped_name) continue remaining_weights.append((name, weight)) diff --git a/vllm/model_executor/models/gemma4_mtp.py b/vllm/model_executor/models/gemma4_mtp.py index 03961cac19..a848e4bd94 100644 --- a/vllm/model_executor/models/gemma4_mtp.py +++ b/vllm/model_executor/models/gemma4_mtp.py @@ -623,5 +623,30 @@ def get_top_tokens( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: self._stable_full_lm_head_weight = None + # The masked embedder's `token_ordering` (token->centroid map) is a + # register_buffer, which AutoWeightsLoader cannot place (it routes only + # to child submodules / nn.Parameters). Without this, assistants with + # use_ordered_embeddings=True (e.g. the E2B/E4B drafters) raise "no + # module or parameter named masked_embedding.token_ordering". Load any + # masked_embedding.* buffer here, then hand the rest to AutoWeightsLoader. + # Scoped to masked_embedding.* so the predictor's own buffer-aware + # load_weights is left untouched. Checkpoint keys here are pre-mapper, + # but masked_embedding.* is identity under hf_to_vllm_mapper. + me_buffers = { + n: b for n, b in self.named_buffers() if n.startswith("masked_embedding.") + } + loaded_buffers: set[str] = set() + remaining = [] + for name, w in weights: + buf = me_buffers.get(name) + if buf is not None: + w = w.to(buf.dtype) + if buf.shape == w.shape: + buf.data.copy_(w) + else: + buf.data = w.to(buf.device) + loaded_buffers.add(name) + continue + remaining.append((name, w)) loader = AutoWeightsLoader(self) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + return loader.load_weights(remaining, mapper=self.hf_to_vllm_mapper) | loaded_buffers From 2b93dd9ce501c33fcff4ede681e4628c889a54e0 Mon Sep 17 00:00:00 2001 From: Phil Date: Mon, 8 Jun 2026 01:47:49 +0000 Subject: [PATCH 21/25] [gemma4][FA] Sliding-window support in Volta FLASH_ATTN_V100 kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Plumb a `window` param (attended-token count; -1 = unlimited) through the decode-paged, prefill-paged, and dense forward Volta kernels + relax the backend gate so causal sliding-window (right==0) layers run on flash instead of falling back to Triton. - flash_decode_paged.cu: per-token window mask (warp-uniform skip of the dot) + whole-partition early-out that writes neutral stats so the cross-partition reduce stays correct. - fused_mha_forward_paged.cu + fused_mha_forward.cu: window term in the causal mask (global_q_pos - global_n < window); relax dense kernel's window_size_left==-1 TORCH_CHECK + the python guard in flash_attn_interface. - flash_attn_v100.py backend: gate accepts (-1,-1) or right==0 windows; add _flash_window = sliding_window[0]+1; pass to all paged/dense flash calls. - test_window.py: standalone fp32-reference parity (decode + paged/dense prefill, full & windowed, D=128/256, edge cases) — all pass, err ~1e-3. Validated correct in-model (coherent output, sliding->flash / global->triton auto-route). NOTE: benchmarks show flash-auto currently LOSES to forced-Triton at long context (dead-partition launch overhead masks the window work-reduction). Kept available behind --attention-backend; see memory v100-vllm-own-fork. Co-Authored-By: RivetOS Claude --- .../flash_attn_v100/flash_attn_interface.py | 8 + flash-attention-v100/include/fused_mha.h | 6 +- .../kernel/flash_decode_paged.cu | 31 +++- .../kernel/fused_mha_forward.cu | 31 ++-- .../kernel/fused_mha_forward_paged.cu | 21 ++- flash-attention-v100/test_window.py | 138 ++++++++++++++++++ vllm/v1/attention/backends/flash_attn_v100.py | 26 +++- 7 files changed, 238 insertions(+), 23 deletions(-) create mode 100644 flash-attention-v100/test_window.py diff --git a/flash-attention-v100/flash_attn_v100/flash_attn_interface.py b/flash-attention-v100/flash_attn_v100/flash_attn_interface.py index 8da9a357ea..0134a360f6 100644 --- a/flash-attention-v100/flash_attn_v100/flash_attn_interface.py +++ b/flash-attention-v100/flash_attn_v100/flash_attn_interface.py @@ -184,6 +184,10 @@ def forward( if causal and (window_size_left != -1 or window_size_right != -1): if window_size_left > 0 and window_size_right > 0: window_size_left, window_size_right = -1, -1 + elif window_size_left >= 0 and window_size_right == 0: + # Causal sliding-window: query attends to window_size_left + 1 + # tokens. Supported by the Volta dense kernel. + pass else: raise NotImplementedError(f"Unsupported window_size={window_size} with causal=True") @@ -292,6 +296,7 @@ def flash_attn_decode_paged( kv_cache_dtype: str = "auto", k_scale: float = 1.0, v_scale: float = 1.0, + window: int = -1, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** -0.5 @@ -319,6 +324,7 @@ def flash_attn_decode_paged( kv_cache_dtype, float(k_scale), float(v_scale), + int(window), ) def flash_attn_prefill_paged( @@ -333,6 +339,7 @@ def flash_attn_prefill_paged( k_scale: float = 1.0, v_scale: float = 1.0, causal: bool = True, + window: int = -1, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** -0.5 @@ -357,6 +364,7 @@ def flash_attn_prefill_paged( float(k_scale), float(v_scale), causal, + int(window), ) return out_.permute(0, 2, 1, 3).contiguous() diff --git a/flash-attention-v100/include/fused_mha.h b/flash-attention-v100/include/fused_mha.h index 4d408d1ec8..f07f41df13 100644 --- a/flash-attention-v100/include/fused_mha.h +++ b/flash-attention-v100/include/fused_mha.h @@ -37,7 +37,8 @@ at::Tensor flash_attention_decode_paged( const int partition_size, const std::string& kv_cache_dtype, const float k_scale, - const float v_scale + const float v_scale, + const int window ); at::Tensor flash_attention_prefill_paged( @@ -51,7 +52,8 @@ at::Tensor flash_attention_prefill_paged( const std::string& kv_cache_dtype, const float k_scale, const float v_scale, - const bool is_causal + const bool is_causal, + const int window ); std::vector flash_attention_backward( diff --git a/flash-attention-v100/kernel/flash_decode_paged.cu b/flash-attention-v100/kernel/flash_decode_paged.cu index 5e7ab84b25..90eb3b42cd 100644 --- a/flash-attention-v100/kernel/flash_decode_paged.cu +++ b/flash-attention-v100/kernel/flash_decode_paged.cu @@ -180,7 +180,8 @@ __global__ void flash_attention_decode_partition_kernel( const int64_t v_head_stride, const float softmax_scale, const float k_scale, - const float v_scale) { + const float v_scale, + const int window) { const int batch_idx = blockIdx.x; const int head_idx = blockIdx.y; const int partition_idx = blockIdx.z; @@ -197,6 +198,19 @@ __global__ void flash_attention_decode_partition_kernel( } const int part_tokens = min(PARTITION_SIZE, seq_len - start_token_idx); + // Sliding-window: the decode query sits at seq_len-1 and attends only to keys + // in [seq_len-window, seq_len-1]. If the whole partition predates the window, + // it contributes nothing -- write neutral stats so the reduce step skips it. + if (window >= 0 && start_token_idx + part_tokens <= seq_len - window) { + if (threadIdx.x == 0) { + const int64_t stats_index = + static_cast(batch_idx) * stats_stride0 + + static_cast(head_idx) * stats_stride1 + partition_idx; + max_logits[stats_index] = -1.0e20f; + exp_sums[stats_index] = 0.f; + } + return; + } const int q_per_kv = num_heads_q / num_heads_kv; const int kv_head_idx = head_idx / q_per_kv; const int lane = threadIdx.x % kWarpSize; @@ -227,6 +241,12 @@ __global__ void flash_attention_decode_partition_kernel( float local_max = -1.0e20f; for (int token_local = warp_idx; token_local < part_tokens; token_local += kWarpsPerBlock) { + // Per-token sliding-window mask (token_local is warp-uniform, so the branch + // is uniform across the warp and we can skip the dot product entirely). + if (window >= 0 && start_token_idx + token_local < seq_len - window) { + if (lane == 0) scores_shared[token_local] = -1.0e20f; + continue; + } const int physical_block = block_idx_shared[token_local]; const int block_offset = block_offset_shared[token_local]; const int64_t k_index = @@ -386,6 +406,7 @@ void launch_flash_attention_decode_paged( const float softmax_scale, const float k_scale, const float v_scale, + const int window, cudaStream_t stream) { const int batch_size = q.size(0); const int num_heads_q = q.size(1); @@ -431,7 +452,8 @@ void launch_flash_attention_decode_paged( v_cache.stride(2), softmax_scale, k_scale, - v_scale); + v_scale, + window); flash_attention_decode_reduce_kernel<<>>( reinterpret_cast(tmp_out.data_ptr()), @@ -467,7 +489,8 @@ at::Tensor flash_attention_decode_paged( const int partition_size, const std::string& kv_cache_dtype, const float k_scale, - const float v_scale) { + const float v_scale, + const int window) { TORCH_CHECK(q.is_cuda(), "q must be on CUDA"); TORCH_CHECK(k_cache.is_cuda() && v_cache.is_cuda(), "k/v cache must be on CUDA"); TORCH_CHECK(block_table.is_cuda() && seq_lens.is_cuda(), "block_table and seq_lens must be on CUDA"); @@ -536,7 +559,7 @@ at::Tensor flash_attention_decode_paged( #define LAUNCH_TYPED(HDIM, PARTITION, KV_DTYPE_CODE) \ launch_flash_attention_decode_paged( \ q, k_cache, v_cache, out, block_table, seq_lens, tmp_out, max_logits, \ - exp_sums, softmax_scale, k_scale, v_scale, stream) + exp_sums, softmax_scale, k_scale, v_scale, window, stream) #define LAUNCH_BY_KV_DTYPE(HDIM, PARTITION) \ do { \ diff --git a/flash-attention-v100/kernel/fused_mha_forward.cu b/flash-attention-v100/kernel/fused_mha_forward.cu index 16e4ffb8ed..6a8f12b664 100644 --- a/flash-attention-v100/kernel/fused_mha_forward.cu +++ b/flash-attention-v100/kernel/fused_mha_forward.cu @@ -109,7 +109,8 @@ flash_attention_forward_kernel( const int H_KV, const int M, const int N, - const float softmax_scale + const float softmax_scale, + const int window ) { using Config = KernelConfig; constexpr int BLOCK_M = Config::BLOCK_M; @@ -269,9 +270,14 @@ flash_attention_forward_kernel( const bool is_valid = (global_m < start_row + valid_q_rows) && (global_n < start_col + valid_k_rows); + // Causal + optional sliding-window: mask future keys and + // keys older than `window` tokens from the query. + const bool masked = + (global_n > global_q_pos) || + (window >= 0 && global_q_pos - global_n >= window); acc_frag.x[i] = is_valid - ? ((global_n > global_q_pos) ? NEG_INF : acc_frag.x[i] * softmax_scale) + ? (masked ? NEG_INF : acc_frag.x[i] * softmax_scale) : NEG_INF; } } else { @@ -508,6 +514,7 @@ void launcher_flash_attention_forward( torch::Tensor& softmax_lse, float softmax_scale, bool is_causal, + int window, cudaStream_t stream ) { using Config = KernelConfig; @@ -538,7 +545,7 @@ void launcher_flash_attention_forward( reinterpret_cast(V.data_ptr()), reinterpret_cast<__half*>(Out.data_ptr()), softmax_lse.data_ptr(), - B, H, H_KV, M, N, softmax_scale + B, H, H_KV, M, N, softmax_scale, window ); } else { flash_attention_forward_kernel<<>>( @@ -547,7 +554,7 @@ void launcher_flash_attention_forward( reinterpret_cast(V.data_ptr()), reinterpret_cast<__half*>(Out.data_ptr()), softmax_lse.data_ptr(), - B, H, H_KV, M, N, softmax_scale + B, H, H_KV, M, N, softmax_scale, window ); } } @@ -570,8 +577,12 @@ std::vector flash_attention_forward( TORCH_CHECK(!alibi_slopes_.has_value(), "alibi_slopes not supported"); TORCH_CHECK(p_dropout == 0.f, "dropout not supported"); - TORCH_CHECK(window_size_left == -1, "window_size_left not supported"); + TORCH_CHECK(window_size_left == -1 || (is_causal && window_size_left >= 0), + "window_size_left only supported with causal=True"); TORCH_CHECK(window_size_right == -1 || (is_causal && window_size_right == 0), "window not supported"); + // Attended-token count for the kernel: left==-1 means unlimited, otherwise + // a query attends to window_size_left + 1 tokens (itself + left preceding). + const int window = (window_size_left < 0) ? -1 : window_size_left + 1; TORCH_CHECK(softcap == 0.f, "softcap not supported"); TORCH_CHECK(!return_softmax, "return_softmax not supported"); TORCH_CHECK(!gen_.has_value(), "Generator not supported"); @@ -605,11 +616,11 @@ std::vector flash_attention_forward( TORCH_CHECK(sm70, "Kernel supports only Volta GPUs."); switch (D) { - case 16: launcher_flash_attention_forward<16>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, stream); break; - case 32: launcher_flash_attention_forward<32>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, stream); break; - case 64: launcher_flash_attention_forward<64>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, stream); break; - case 128: launcher_flash_attention_forward<128>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, stream); break; - case 256: launcher_flash_attention_forward<256>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, stream); break; + case 16: launcher_flash_attention_forward<16>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break; + case 32: launcher_flash_attention_forward<32>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break; + case 64: launcher_flash_attention_forward<64>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break; + case 128: launcher_flash_attention_forward<128>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break; + case 256: launcher_flash_attention_forward<256>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break; default: TORCH_CHECK(false, "Unsupported D: ", D); } diff --git a/flash-attention-v100/kernel/fused_mha_forward_paged.cu b/flash-attention-v100/kernel/fused_mha_forward_paged.cu index 73d56ff732..b2052bb7e0 100644 --- a/flash-attention-v100/kernel/fused_mha_forward_paged.cu +++ b/flash-attention-v100/kernel/fused_mha_forward_paged.cu @@ -136,7 +136,8 @@ flash_attention_forward_kernel_paged( const int64_t v_head_stride, const float softmax_scale, const float k_scale, - const float v_scale + const float v_scale, + const int window ) { using Config = KernelConfig; using Traits = FlashV100Traits; @@ -396,7 +397,11 @@ flash_attention_forward_kernel_paged( const int global_q_pos = global_m + causal_q_offset; const bool is_valid = (global_m < start_row + valid_q_rows) && (global_n < start_col + valid_k_rows); - const bool is_causal_valid = (global_q_pos >= global_n); + // Causal + optional sliding-window: key must satisfy + // global_q_pos - window < global_n <= global_q_pos. + const bool is_causal_valid = + (global_q_pos >= global_n) && + (window < 0 || global_q_pos - global_n < window); acc_frag.x[i] = is_valid ? (is_causal_valid ? acc_frag.x[i] * softmax_scale : NEG_INF) @@ -716,6 +721,7 @@ void launcher_flash_attention_forward_paged( bool is_causal, float k_scale, float v_scale, + int window, cudaStream_t stream ) { using Config = KernelConfig; @@ -775,7 +781,8 @@ void launcher_flash_attention_forward_paged( v_head_stride, softmax_scale, k_scale, - v_scale + v_scale, + window ); } else { flash_attention_forward_kernel_paged @@ -801,7 +808,8 @@ void launcher_flash_attention_forward_paged( v_head_stride, softmax_scale, k_scale, - v_scale + v_scale, + window ); } } @@ -817,7 +825,8 @@ at::Tensor flash_attention_prefill_paged( const std::string& kv_cache_dtype, const float k_scale, const float v_scale, - const bool is_causal + const bool is_causal, + const int window ) { TORCH_CHECK(q.dtype() == torch::kFloat16, "q must be fp16"); const int kv_dtype_code = kv_cache_dtype_code_from_string(kv_cache_dtype); @@ -867,7 +876,7 @@ at::Tensor flash_attention_prefill_paged( #define LAUNCH_PAGED_TYPED(HDIM, KV_DTYPE_CODE) \ launcher_flash_attention_forward_paged( \ q, k_cache, v_cache, out_fp16, softmax_lse, block_table, seq_lens, \ - softmax_scale, is_causal, k_scale, v_scale, stream) + softmax_scale, is_causal, k_scale, v_scale, window, stream) #define LAUNCH_PAGED_BY_KV(HDIM) \ do { \ diff --git a/flash-attention-v100/test_window.py b/flash-attention-v100/test_window.py new file mode 100644 index 0000000000..0d22d24187 --- /dev/null +++ b/flash-attention-v100/test_window.py @@ -0,0 +1,138 @@ +"""Standalone correctness test for sliding-window FLASH_ATTN_V100 (Phase 1). + +Compares the windowed paged decode + prefill kernels against an fp32 torch +reference. Run on a V100 (SM70). No vLLM, no model load. +""" +import torch +from flash_attn_v100 import (flash_attn_decode_paged, flash_attn_prefill_paged, + flash_attn_func) + +torch.manual_seed(0) +DEV = "cuda" + + +def build_paged(k_cont, v_cont, block_size): + """k_cont/v_cont: [S, Hkv, D] -> paged [num_blocks, block_size, Hkv, D] + block_table [1, nb].""" + S, Hkv, D = k_cont.shape + nb = (S + block_size - 1) // block_size + k_cache = torch.zeros((nb, block_size, Hkv, D), dtype=k_cont.dtype, device=DEV) + v_cache = torch.zeros((nb, block_size, Hkv, D), dtype=v_cont.dtype, device=DEV) + for b in range(nb): + s = b * block_size + e = min(s + block_size, S) + k_cache[b, : e - s] = k_cont[s:e] + v_cache[b, : e - s] = v_cont[s:e] + block_table = torch.arange(nb, dtype=torch.int32, device=DEV).view(1, nb) + return k_cache, v_cache, block_table + + +def ref_attn(q, k, v, scale, window, causal_qpos=None): + """q:[Hq,D] (single decode query), k/v:[S,Hkv,D]. Returns [Hq,D] fp32. + Decode query sits at position S-1. window<0 = full.""" + Hq, D = q.shape + S, Hkv, _ = k.shape + qpk = Hq // Hkv + out = torch.zeros((Hq, D), dtype=torch.float32, device=DEV) + qpos = S - 1 + for h in range(Hq): + kh = h // qpk + scores = (q[h].float() @ k[:, kh].float().T) * scale # [S] + mask = torch.ones(S, dtype=torch.bool, device=DEV) + if window >= 0: + mask &= torch.arange(S, device=DEV) >= (qpos - window + 1) + scores = scores.masked_fill(~mask, float("-inf")) + p = torch.softmax(scores, dim=-1) + out[h] = p @ v[:, kh].float() + return out + + +def test_decode(D, S, window, block_size=16, Hq=8, Hkv=2): + scale = D ** -0.5 + q = torch.randn(1, Hq, D, dtype=torch.float16, device=DEV) + k = torch.randn(S, Hkv, D, dtype=torch.float16, device=DEV) + v = torch.randn(S, Hkv, D, dtype=torch.float16, device=DEV) + k_cache, v_cache, block_table = build_paged(k, v, block_size) + seq_lens = torch.tensor([S], dtype=torch.int32, device=DEV) + out = flash_attn_decode_paged( + q, k_cache, v_cache, block_table, seq_lens, + softmax_scale=scale, kv_cache_dtype="auto", window=window, + ) + ref = ref_attn(q[0], k, v, scale, window) + got = out[0].float() + err = (got - ref).abs().max().item() + print(f" decode D={D} S={S} win={window:>5}: max_abs_err={err:.5f} {'OK' if err < 2e-2 else 'FAIL'}") + return err < 2e-2 + + +def test_prefill(D, S, window, block_size=16, Hq=8, Hkv=2): + """Prefill: M=S queries, causal + window. Compare last-row + a mid row.""" + scale = D ** -0.5 + q = torch.randn(1, S, Hq, D, dtype=torch.float16, device=DEV) # [B,M,H,D] + k = torch.randn(S, Hkv, D, dtype=torch.float16, device=DEV) + v = torch.randn(S, Hkv, D, dtype=torch.float16, device=DEV) + k_cache, v_cache, block_table = build_paged(k, v, block_size) + seq_lens = torch.tensor([S], dtype=torch.int32, device=DEV) + out = flash_attn_prefill_paged( + q, k_cache, v_cache, block_table, seq_lens, + softmax_scale=scale, kv_cache_dtype="auto", causal=True, window=window, + ) # [B,M,H,D] + qpk = Hq // Hkv + worst = 0.0 + for qi in (S - 1, S // 2, min(window + 3, S - 1)): + for h in range(Hq): + kh = h // qpk + scores = (q[0, qi, h].float() @ k[:, kh].float().T) * scale + idx = torch.arange(S, device=DEV) + mask = idx <= qi + if window >= 0: + mask &= idx >= (qi - window + 1) + scores = scores.masked_fill(~mask, float("-inf")) + p = torch.softmax(scores, dim=-1) + ref = p @ v[:, kh].float() + err = (out[0, qi, h].float() - ref).abs().max().item() + worst = max(worst, err) + print(f" prefill D={D} S={S} win={window:>5}: max_abs_err={worst:.5f} {'OK' if worst < 3e-2 else 'FAIL'}") + return worst < 3e-2 + + +def test_dense(D, S, window, Hq=8, Hkv=2): + """Dense (non-paged) flash_attn_func, causal + window. q/k/v: [B,M,H,D].""" + scale = D ** -0.5 + q = torch.randn(1, S, Hq, D, dtype=torch.float16, device=DEV) + k = torch.randn(1, S, Hkv, D, dtype=torch.float16, device=DEV) + v = torch.randn(1, S, Hkv, D, dtype=torch.float16, device=DEV) + ws = (-1, -1) if window < 0 else (window - 1, 0) + out = flash_attn_func(q, k, v, causal=True, softmax_scale=scale, window_size=ws) # [B,M,H,D] + qpk = Hq // Hkv + worst = 0.0 + for qi in (S - 1, S // 2, min(window + 3, S - 1) if window > 0 else S // 3): + for h in range(Hq): + kh = h // qpk + scores = (q[0, qi, h].float() @ k[0, :, kh].float().T) * scale + idx = torch.arange(S, device=DEV) + mask = idx <= qi + if window >= 0: + mask &= idx >= (qi - window + 1) + scores = scores.masked_fill(~mask, float("-inf")) + p = torch.softmax(scores, dim=-1) + ref = p @ v[0, :, kh].float() + err = (out[0, qi, h].float() - ref).abs().max().item() + worst = max(worst, err) + print(f" dense D={D} S={S} win={window:>5}: max_abs_err={worst:.5f} {'OK' if worst < 3e-2 else 'FAIL'}") + return worst < 3e-2 + + +if __name__ == "__main__": + ok = True + print("== full attention (window=-1) regression ==") + for D in (128, 256): + ok &= test_decode(D, 100, -1) + ok &= test_prefill(D, 100, -1) + ok &= test_dense(D, 100, -1) + print("== sliding window ==") + for D in (128, 256): + for S, W in [(100, 32), (100, 64), (50, 64), (200, 48), (33, 32)]: + ok &= test_decode(D, S, W) + ok &= test_prefill(D, S, W) + ok &= test_dense(D, S, W) + print("ALL PASS" if ok else "SOME FAILED") diff --git a/vllm/v1/attention/backends/flash_attn_v100.py b/vllm/v1/attention/backends/flash_attn_v100.py index 0fd10026ac..6be22a9550 100644 --- a/vllm/v1/attention/backends/flash_attn_v100.py +++ b/vllm/v1/attention/backends/flash_attn_v100.py @@ -366,16 +366,34 @@ def _supports_flash_v100_path(self) -> bool: not self.kv_cache_dtype.startswith("fp8") or self.kv_cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2") ) + # Causal sliding-window (left>=0, right==0) is supported by the kernels + # via the `window` param. Full attention is (-1, -1). Bidirectional + # (right != 0) windows are not supported. + supported_window = ( + self.sliding_window == (-1, -1) + or self.sliding_window[1] == 0 + ) return ( self.use_flash_v100 and self.attn_type == AttentionType.DECODER and self.alibi_slopes is None and self.logits_soft_cap == 0 and self.sinks is None - and self.sliding_window == (-1, -1) + and supported_window and supported_kv_dtype ) + @property + def _flash_window(self) -> int: + """Number of attended tokens for the kernels (-1 = unlimited). + + self.sliding_window is (left, right) with left == sliding_window - 1, + so the attended-token count is left + 1 == the model's sliding_window. + """ + if self.sliding_window == (-1, -1): + return -1 + return self.sliding_window[0] + 1 + def _small_query_decode_enabled( self, attn_metadata: TritonAttentionMetadata, @@ -615,6 +633,7 @@ def _flash_v100_prefill( v_batch, causal=causal, softmax_scale=self.scale, + window_size=self.sliding_window, ) out_view[tok_start:tok_end].copy_( out_batch.view(tok_end - tok_start, out_batch.shape[2], out_batch.shape[3]) @@ -658,6 +677,7 @@ def _flash_v100_decode( kv_cache_dtype=self.kv_cache_dtype, k_scale=float(layer._k_scale_float), v_scale=float(layer._v_scale_float), + window=self._flash_window, ) return output @@ -761,6 +781,7 @@ def _flash_v100_small_query_prefill_as_decode( kv_cache_dtype=self.kv_cache_dtype, k_scale=float(layer._k_scale_float), v_scale=float(layer._v_scale_float), + window=self._flash_window, ) return output @@ -843,6 +864,7 @@ def _flash_v100_prefill_with_prefix( k_scale=float(layer._k_scale_float), v_scale=float(layer._v_scale_float), causal=causal, + window=self._flash_window, ) if debug_compare and not _logged_prefill_compare: seq_len = int(seq_lens[i].item()) @@ -868,6 +890,7 @@ def _flash_v100_prefill_with_prefix( v_cont.unsqueeze(0), causal=causal, softmax_scale=self.scale, + window_size=self.sliding_window, ) diff = (out_seq - ref_out).abs() nan_count = int(torch.isnan(out_seq).sum().item()) @@ -942,6 +965,7 @@ def _flash_v100_prefill_with_prefix( v_cont.unsqueeze(0), causal=causal, softmax_scale=self.scale, + window_size=self.sliding_window, ) out_view[start:end].copy_(out_seq.squeeze(0)) From dbbd2df702c7fd4154dcc5e5bb00a9e8a5a1a8b9 Mon Sep 17 00:00:00 2001 From: Phil Date: Mon, 8 Jun 2026 02:45:49 +0000 Subject: [PATCH 22/25] [gemma4][FA] head_dim-512 decode + prefill window tile-skip MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pure-FA (all 60 gemma layers on FLASH_ATTN_V100) — decode now beats Triton at every context, verified correct on long-context retrieval. - flash_decode_paged.cu: add head_dim 512 case. Decode kernel is head-dim- generic (dot_qk_cache loops, no WMMA), so 512 is just the switch entry; q_shared[512]=1KB fits smem at M=1. Lets the global layers' DECODE run on FA. - flash_attn_v100.py: supports_head_size/get_supported_head_sizes += 512; forward() gates 512 PREFILL to Triton (flash_prefill_ok=head_size<=256) since the prefill kernels cap at 256 (smem); decode + small-query verifier stay FA. - fused_mha_forward{,_paged}.cu: sliding-window lower-bound tile skip — skip key-tiles entirely older than `window` (per-element mask already guarantees correctness; this just stops computing fully-masked tiles). Sliding prefill O(N^2) -> O(N*W). - test_window.py: + D=512 decode cases (full + windowed, incl S=8192). All pass. Measured (c1, decode TPOT): Triton 48/116/181/310ms @512/2k/4k/8k -> pure-FA 22/24/24/26ms FLAT (~12x at 8k). Short throughput c1 39.7 / c4 102.9 tok/s (was 17.65/54.65). Long-ctx retrieval correct (fact at pos 0 of 3359 toks). Remaining bottleneck: global-512 PREFILL on Triton dominates TTFT (needs split-D prefill-512). See memory v100-vllm-own-fork. Co-Authored-By: RivetOS Claude --- .../kernel/flash_decode_paged.cu | 3 ++ .../kernel/fused_mha_forward.cu | 13 ++++++- .../kernel/fused_mha_forward_paged.cu | 12 ++++++- flash-attention-v100/test_window.py | 3 ++ vllm/v1/attention/backends/flash_attn_v100.py | 34 ++++++++++++++----- 5 files changed, 54 insertions(+), 11 deletions(-) diff --git a/flash-attention-v100/kernel/flash_decode_paged.cu b/flash-attention-v100/kernel/flash_decode_paged.cu index 90eb3b42cd..1a671a41bc 100644 --- a/flash-attention-v100/kernel/flash_decode_paged.cu +++ b/flash-attention-v100/kernel/flash_decode_paged.cu @@ -614,6 +614,9 @@ at::Tensor flash_attention_decode_paged( case 256: LAUNCH_BY_PARTITION(256); break; + case 512: + LAUNCH_BY_PARTITION(512); + break; default: TORCH_CHECK(false, "Unsupported head_dim for paged decode: ", head_dim); } diff --git a/flash-attention-v100/kernel/fused_mha_forward.cu b/flash-attention-v100/kernel/fused_mha_forward.cu index 6a8f12b664..a1c0611515 100644 --- a/flash-attention-v100/kernel/fused_mha_forward.cu +++ b/flash-attention-v100/kernel/fused_mha_forward.cu @@ -198,7 +198,18 @@ flash_attention_forward_kernel( } __syncthreads(); - for (int block_n = 0; block_n < num_n_tiles; ++block_n) { + // Sliding-window lower-bound tile skip: keys older than `window` from the + // earliest query in this row-block are fully masked, so skip those tiles + // entirely (the per-element mask still guarantees correctness). + int first_n_tile = 0; + if constexpr (IS_CAUSAL) { + if (window >= 0) { + const int earliest_key = (start_row + causal_q_offset) - window + 1; + if (earliest_key > 0) first_n_tile = earliest_key / BLOCK_N; + } + } + + for (int block_n = first_n_tile; block_n < num_n_tiles; ++block_n) { const int start_col = block_n * BLOCK_N; if (start_col >= N) break; const int valid_k_rows = min(BLOCK_N, N - start_col); diff --git a/flash-attention-v100/kernel/fused_mha_forward_paged.cu b/flash-attention-v100/kernel/fused_mha_forward_paged.cu index b2052bb7e0..6e32188534 100644 --- a/flash-attention-v100/kernel/fused_mha_forward_paged.cu +++ b/flash-attention-v100/kernel/fused_mha_forward_paged.cu @@ -248,7 +248,17 @@ flash_attention_forward_kernel_paged( } __syncthreads(); - for (int block_n = 0; block_n < num_n_tiles; ++block_n) { + // Sliding-window lower-bound tile skip: keys older than `window` from the + // earliest query in this row-block are fully masked -> skip those tiles. + int first_n_tile = 0; + if constexpr (IS_CAUSAL) { + if (window >= 0) { + const int earliest_key = (start_row + causal_q_offset) - window + 1; + if (earliest_key > 0) first_n_tile = earliest_key / BLOCK_N; + } + } + + for (int block_n = first_n_tile; block_n < num_n_tiles; ++block_n) { const int start_col = block_n * BLOCK_N; if (start_col >= actual_N) break; const int valid_k_rows = min(BLOCK_N, actual_N - start_col); diff --git a/flash-attention-v100/test_window.py b/flash-attention-v100/test_window.py index 0d22d24187..b72240c5e9 100644 --- a/flash-attention-v100/test_window.py +++ b/flash-attention-v100/test_window.py @@ -135,4 +135,7 @@ def test_dense(D, S, window, Hq=8, Hkv=2): ok &= test_decode(D, S, W) ok &= test_prefill(D, S, W) ok &= test_dense(D, S, W) + print("== D=512 decode (global layers: full + windowed) ==") + for S, W in [(100, -1), (8192, -1), (100, 32), (200, 48), (8192, 1024)]: + ok &= test_decode(512, S, W) print("ALL PASS" if ok else "SOME FAILED") diff --git a/vllm/v1/attention/backends/flash_attn_v100.py b/vllm/v1/attention/backends/flash_attn_v100.py index 6be22a9550..8d27d32e29 100644 --- a/vllm/v1/attention/backends/flash_attn_v100.py +++ b/vllm/v1/attention/backends/flash_attn_v100.py @@ -470,6 +470,10 @@ def forward( is_prefill = attn_metadata.max_query_len > 1 is_capturing = query.is_cuda and torch.cuda.is_current_stream_capturing() + # Decode kernel handles head_dim 512; the prefill kernels (dense + paged) + # cap at 256. So for 512-dim layers we route true prefill to Triton but + # keep the small-query verifier (which uses the decode kernel) on FA. + flash_prefill_ok = self.head_size <= 256 if is_prefill: if is_capturing: @@ -508,6 +512,13 @@ def forward( has_prefix_context and self._small_query_decode_enabled(attn_metadata) ) + if has_prefix_context and not (smallq_decode or flash_prefill_ok): + # 512-dim prefix/chunked prefill can't use the 256-cap prefill + # kernel and isn't small-query; fall back to Triton prefill. + return super().forward( + layer, query, key, value, kv_cache, attn_metadata, + output, output_scale, output_block_scale, + ) if has_prefix_context: if not _logged_prefill_prefix_flash: if smallq_decode: @@ -534,6 +545,12 @@ def forward( attn_metadata, output, ) + if not flash_prefill_ok: + # 512-dim no-prefix prefill: dense kernel caps at 256 -> Triton. + return super().forward( + layer, query, key, value, kv_cache, attn_metadata, + output, output_scale, output_block_scale, + ) if not _logged_prefill_flash: logger.info( "FLASH_ATTN_V100 prefill path active (no prefix/chunked context)." @@ -992,16 +1009,15 @@ def get_name() -> str: @staticmethod def get_supported_head_sizes() -> list[int]: - # Keep this aligned with the dense prefill kernel dispatch table. - return [64, 128, 256] + # Decode kernel handles 512 (head-dim-generic GEMV); prefill kernels + # cap at 256 (smem), so 512 prefill falls back to Triton in forward(). + return [64, 128, 256, 512] @classmethod def supports_head_size(cls, head_size: int) -> bool: # NOTE(rivet): validate_configuration() calls supports_head_size(), - # NOT get_supported_head_sizes(). Without this override we'd inherit - # TritonAttentionBackend.supports_head_size (head_size >= 32) and - # wrongly accept head_size=512, then hard-crash the Volta CUDA kernel - # (TORCH_CHECK D<=256). The Volta FA kernel only handles 64/128/256. - # With this, auto-selection routes gemma-4's 256-dim sliding layers - # here and falls through to TRITON_ATTN for its 512-dim global layers. - return head_size in (64, 128, 256) + # NOT get_supported_head_sizes(). The Volta FA DECODE kernel is + # head-dim-generic and now handles 512 (gemma-4 global layers); the + # PREFILL kernels still cap at 256 (512 tile blows 96KB smem), so + # forward() routes 512-dim prefill to Triton while decode stays on FA. + return head_size in (64, 128, 256, 512) From 47b5bbc18f63b2778a46ab3891c92be8e25f6fdf Mon Sep 17 00:00:00 2001 From: Phil Date: Mon, 8 Jun 2026 03:04:14 +0000 Subject: [PATCH 23/25] [gemma4][FA] head_dim-512 prefill (split via small blocks) -> fully-FA Move the gemma global layers' PREFILL onto FA too. The WMMA prefill body already accumulates QK over D in WMMA_K chunks and loops PV over D, so it's head-dim-generic; the only blocker for 512 was fitting the wider Q/K/V/O tiles in 96KB smem. Add a Config<512> with BLOCK_M=16/BLOCK_N=32/WARPS=16 (~84KB), which the existing body uses unchanged -- no split-D kernel rewrite needed. - fused_mha_forward{,_paged}.cu + flash_v100_traits.cuh: BLOCK_M/N_512 consts, D==512 in the config/traits ternaries, dispatch case 512, relax D<=256 check. WARPS_512=16 keeps THREADS_PER_BLOCK=512 consistent with the traits-based paged KV loader (which assumes 16 warps). - flash_attn_v100.py: flash_prefill_ok = head_size<=512 (512 prefill now FA). - test_window.py: + D=512 prefill (paged + dense, full + windowed). All pass ~1e-3. Now ALL 60 layers (decode AND prefill) run on FLASH_ATTN_V100 -- zero Triton in the attention path. Measured (c1) vs original Triton baseline: TTFT 8192: 63s -> 8.3s (7.6x); decode TPOT 8192: 310ms -> 25ms (12x, flat); throughput 8192: 1.25 -> 11.2 tok/s (9x); short c1 39.7 / c4 102.9 tok/s. Correctness: test_window all-pass + long-ctx retrieval (fact@pos0 of 3359 toks). Co-Authored-By: RivetOS Claude --- .../kernel/flash_v100_traits.cuh | 8 ++++++-- .../kernel/fused_mha_forward.cu | 16 ++++++++++++---- .../kernel/fused_mha_forward_paged.cu | 18 +++++++++++++----- flash-attention-v100/test_window.py | 4 ++++ vllm/v1/attention/backends/flash_attn_v100.py | 7 +++---- 5 files changed, 38 insertions(+), 15 deletions(-) diff --git a/flash-attention-v100/kernel/flash_v100_traits.cuh b/flash-attention-v100/kernel/flash_v100_traits.cuh index 909ea97db4..928a42c635 100644 --- a/flash-attention-v100/kernel/flash_v100_traits.cuh +++ b/flash-attention-v100/kernel/flash_v100_traits.cuh @@ -16,18 +16,22 @@ struct FlashV100Traits { static constexpr int BLOCK_N_128 = 176; static constexpr int BLOCK_M_256 = 32; static constexpr int BLOCK_N_256 = 64; + static constexpr int BLOCK_M_512 = 16; + static constexpr int BLOCK_N_512 = 32; static constexpr int WARPS_PER_BLOCK = 16; static constexpr int THREADS_PER_WARP = 32; static constexpr int BLOCK_M = (D == 16) ? BLOCK_M_16 : (D == 32) ? BLOCK_M_32 : (D == 64) ? BLOCK_M_64 : - (D == 128) ? BLOCK_M_128 : BLOCK_M_256; + (D == 128) ? BLOCK_M_128 : + (D == 512) ? BLOCK_M_512 : BLOCK_M_256; static constexpr int BLOCK_N = (D == 16) ? BLOCK_N_16 : (D == 32) ? BLOCK_N_32 : (D == 64) ? BLOCK_N_64 : - (D == 128) ? BLOCK_N_128 : BLOCK_N_256; + (D == 128) ? BLOCK_N_128 : + (D == 512) ? BLOCK_N_512 : BLOCK_N_256; static constexpr int THREADS_PER_BLOCK = WARPS_PER_BLOCK * THREADS_PER_WARP; static constexpr int THREADS_PER_ROW = THREADS_PER_BLOCK / BLOCK_M; diff --git a/flash-attention-v100/kernel/fused_mha_forward.cu b/flash-attention-v100/kernel/fused_mha_forward.cu index a1c0611515..f7883a8967 100644 --- a/flash-attention-v100/kernel/fused_mha_forward.cu +++ b/flash-attention-v100/kernel/fused_mha_forward.cu @@ -49,11 +49,18 @@ using namespace nvcuda::wmma; #define BLOCK_N_256 64 #define WARPS_256 16 +// head_dim 512 (gemma global layers). Small blocks so the 512-wide Q/K/V/O +// tiles fit in 96KB smem (q+kv+o dominate). The QK k-loop already accumulates +// over all of D in WMMA_K chunks, so no body change is needed. +#define BLOCK_M_512 16 +#define BLOCK_N_512 32 +#define WARPS_512 16 + template struct KernelConfig { - static constexpr int BLOCK_M = (D == 16) ? BLOCK_M_16 : (D == 32) ? BLOCK_M_32 : (D == 64) ? BLOCK_M_64 : (D == 128) ? BLOCK_M_128 : BLOCK_M_256; - static constexpr int BLOCK_N = (D == 16) ? BLOCK_N_16 : (D == 32) ? BLOCK_N_32 : (D == 64) ? BLOCK_N_64 : (D == 128) ? BLOCK_N_128 : BLOCK_N_256; - static constexpr int WARPS_PER_BLOCK = (D == 16) ? WARPS_16 : (D == 32) ? WARPS_32 : (D == 64) ? WARPS_64 : (D == 128) ? WARPS_128 : WARPS_256; + static constexpr int BLOCK_M = (D == 16) ? BLOCK_M_16 : (D == 32) ? BLOCK_M_32 : (D == 64) ? BLOCK_M_64 : (D == 128) ? BLOCK_M_128 : (D == 512) ? BLOCK_M_512 : BLOCK_M_256; + static constexpr int BLOCK_N = (D == 16) ? BLOCK_N_16 : (D == 32) ? BLOCK_N_32 : (D == 64) ? BLOCK_N_64 : (D == 128) ? BLOCK_N_128 : (D == 512) ? BLOCK_N_512 : BLOCK_N_256; + static constexpr int WARPS_PER_BLOCK = (D == 16) ? WARPS_16 : (D == 32) ? WARPS_32 : (D == 64) ? WARPS_64 : (D == 128) ? WARPS_128 : (D == 512) ? WARPS_512 : WARPS_256; static constexpr int THREADS_PER_BLOCK = WARPS_PER_BLOCK * MAX_THREADS_PER_WARP; static constexpr int THREADS_PER_ROW = THREADS_PER_BLOCK / BLOCK_M; @@ -608,7 +615,7 @@ std::vector flash_attention_forward( const int B = sizes[0], H = sizes[1], M = sizes[2], D = sizes[3]; const int H_KV = k.size(1); const int N = k.size(2); - TORCH_CHECK(D <= 256 && D % 8 == 0 && D % 2 == 0, "D must be even, <=256, multiple of 8"); + TORCH_CHECK((D <= 256 || D == 512) && D % 8 == 0 && D % 2 == 0, "D must be even, multiple of 8, and <=256 or ==512"); TORCH_CHECK(H_KV > 0, "num_kv_heads must be positive"); TORCH_CHECK(H % H_KV == 0, "num_attention_heads must be divisible by num_kv_heads"); TORCH_CHECK(k.size(0) == B && v.size(0) == B, "K/V batch size must match Q"); @@ -632,6 +639,7 @@ std::vector flash_attention_forward( case 64: launcher_flash_attention_forward<64>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break; case 128: launcher_flash_attention_forward<128>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break; case 256: launcher_flash_attention_forward<256>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break; + case 512: launcher_flash_attention_forward<512>(q, k, v, out_fp16, softmax_lse, softmax_scale, is_causal, window, stream); break; default: TORCH_CHECK(false, "Unsupported D: ", D); } diff --git a/flash-attention-v100/kernel/fused_mha_forward_paged.cu b/flash-attention-v100/kernel/fused_mha_forward_paged.cu index 6e32188534..6e60d86c70 100644 --- a/flash-attention-v100/kernel/fused_mha_forward_paged.cu +++ b/flash-attention-v100/kernel/fused_mha_forward_paged.cu @@ -65,11 +65,16 @@ int kv_cache_dtype_code_from_string(const std::string& kv_cache_dtype) { #define BLOCK_N_256 64 #define WARPS_256 16 +// head_dim 512 (gemma global layers): small blocks so 512-wide tiles fit smem. +#define BLOCK_M_512 16 +#define BLOCK_N_512 32 +#define WARPS_512 16 + template struct KernelConfig { - static constexpr int BLOCK_M = (D == 16) ? BLOCK_M_16 : (D == 32) ? BLOCK_M_32 : (D == 64) ? BLOCK_M_64 : (D == 128) ? BLOCK_M_128 : BLOCK_M_256; - static constexpr int BLOCK_N = (D == 16) ? BLOCK_N_16 : (D == 32) ? BLOCK_N_32 : (D == 64) ? BLOCK_N_64 : (D == 128) ? BLOCK_N_128 : BLOCK_N_256; - static constexpr int WARPS_PER_BLOCK = (D == 16) ? WARPS_16 : (D == 32) ? WARPS_32 : (D == 64) ? WARPS_64 : (D == 128) ? WARPS_128 : WARPS_256; + static constexpr int BLOCK_M = (D == 16) ? BLOCK_M_16 : (D == 32) ? BLOCK_M_32 : (D == 64) ? BLOCK_M_64 : (D == 128) ? BLOCK_M_128 : (D == 512) ? BLOCK_M_512 : BLOCK_M_256; + static constexpr int BLOCK_N = (D == 16) ? BLOCK_N_16 : (D == 32) ? BLOCK_N_32 : (D == 64) ? BLOCK_N_64 : (D == 128) ? BLOCK_N_128 : (D == 512) ? BLOCK_N_512 : BLOCK_N_256; + static constexpr int WARPS_PER_BLOCK = (D == 16) ? WARPS_16 : (D == 32) ? WARPS_32 : (D == 64) ? WARPS_64 : (D == 128) ? WARPS_128 : (D == 512) ? WARPS_512 : WARPS_256; static constexpr int THREADS_PER_BLOCK = WARPS_PER_BLOCK * MAX_THREADS_PER_WARP; static constexpr int THREADS_PER_ROW = THREADS_PER_BLOCK / BLOCK_M; @@ -869,8 +874,8 @@ at::Tensor flash_attention_prefill_paged( const int D = q.size(3); const int num_kv_heads = k_cache.size(2); - TORCH_CHECK(D <= 256 && D % 8 == 0 && D % 2 == 0, - "D must be even, <=256, multiple of 8"); + TORCH_CHECK((D <= 256 || D == 512) && D % 8 == 0 && D % 2 == 0, + "D must be even, multiple of 8, and <=256 or ==512"); TORCH_CHECK(H % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"); @@ -921,6 +926,9 @@ at::Tensor flash_attention_prefill_paged( case 256: LAUNCH_PAGED_BY_KV(256); break; + case 512: + LAUNCH_PAGED_BY_KV(512); + break; default: TORCH_CHECK(false, "Unsupported D: ", D); } diff --git a/flash-attention-v100/test_window.py b/flash-attention-v100/test_window.py index b72240c5e9..a5780acd33 100644 --- a/flash-attention-v100/test_window.py +++ b/flash-attention-v100/test_window.py @@ -138,4 +138,8 @@ def test_dense(D, S, window, Hq=8, Hkv=2): print("== D=512 decode (global layers: full + windowed) ==") for S, W in [(100, -1), (8192, -1), (100, 32), (200, 48), (8192, 1024)]: ok &= test_decode(512, S, W) + print("== D=512 prefill (global layers: full attention + windowed) ==") + for S, W in [(100, -1), (200, -1), (300, -1), (100, 48), (200, 64)]: + ok &= test_prefill(512, S, W) + ok &= test_dense(512, S, W) print("ALL PASS" if ok else "SOME FAILED") diff --git a/vllm/v1/attention/backends/flash_attn_v100.py b/vllm/v1/attention/backends/flash_attn_v100.py index 8d27d32e29..a43c497dc3 100644 --- a/vllm/v1/attention/backends/flash_attn_v100.py +++ b/vllm/v1/attention/backends/flash_attn_v100.py @@ -470,10 +470,9 @@ def forward( is_prefill = attn_metadata.max_query_len > 1 is_capturing = query.is_cuda and torch.cuda.is_current_stream_capturing() - # Decode kernel handles head_dim 512; the prefill kernels (dense + paged) - # cap at 256. So for 512-dim layers we route true prefill to Triton but - # keep the small-query verifier (which uses the decode kernel) on FA. - flash_prefill_ok = self.head_size <= 256 + # FA kernels now handle head_dim up to 512 for both decode and prefill + # (split via small blocks to fit smem). 512 prefill stays on FA. + flash_prefill_ok = self.head_size <= 512 if is_prefill: if is_capturing: From a2f453cd7c459d41e80e0d02ad6884ce1a98634d Mon Sep 17 00:00:00 2001 From: Phil Date: Mon, 8 Jun 2026 14:31:32 +0000 Subject: [PATCH 24/25] [FA_V100] Guard paged prefill against V100 smem overflow at long context The paged prefill kernel (fused_mha_forward_paged.cu) copies the per-sequence block table into shared memory, so its smem is TOTAL_SMEM[head_dim] + align128(max_num_blocks * 4), where max_num_blocks = block_table width = ceil(max_model_len / page_block_size). The per-D base is already ~84-96KB, so at long max_model_len the block table alone pushes total smem past V100's 96KB ceiling and the kernel's TORCH_CHECK aborts the worker, killing the server. head_dim 256 at 177k ctx needs 138368 bytes (11088 blocks); it only stayed hidden because no-prefix prefill uses the dense kernel and short-context servers (e.g. 11k) fit. Add _paged_prefill_smem_fits() mirroring the kernel's smem formula and gate the paged prefill call in _flash_v100_prefill_with_prefix on it. When it does not fit, fall through to the existing gather + dense path, which is smem-safe at any context length and still fully on FA (no Triton fallback). Verified: 31b @ 177k survives an 8.4k chunked-prefill prompt (previously crashed); e2b @ 11k still uses the paged kernel. Co-Authored-By: RivetOS Claude --- vllm/v1/attention/backends/flash_attn_v100.py | 49 ++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/flash_attn_v100.py b/vllm/v1/attention/backends/flash_attn_v100.py index a43c497dc3..2d58d7aa36 100644 --- a/vllm/v1/attention/backends/flash_attn_v100.py +++ b/vllm/v1/attention/backends/flash_attn_v100.py @@ -35,6 +35,19 @@ _logged_prefill_smallq_decode = False _logged_decode_flash = False _logged_prefill_compare = False +_warned_paged_prefill_smem = False + +# V100 dynamic shared-memory ceiling (bytes). +_FLASH_V100_MAX_SMEM = 98304 + +# Base (block-table-independent) shared memory of the PAGED prefill kernel, +# per head_dim, mirroring KernelConfig::TOTAL_SMEM in +# flash-attention-v100/kernel/fused_mha_forward_paged.cu. The kernel ALSO +# stores the per-sequence block table in smem (extra = align128(max_num_blocks +# * 4)), so total smem grows with max_model_len / page_block_size. These bases +# are already ~84-96KB, leaving little headroom: long-context servers overflow +# the 96KB ceiling and must fall back to the gather+dense prefill path. +_PAGED_PREFILL_BASE_SMEM = {64: 81408, 128: 97792, 256: 93952, 512: 85888} def _get_flash_ops(): @@ -801,6 +814,28 @@ def _flash_v100_small_query_prefill_as_decode( ) return output + def _paged_prefill_smem_fits( + self, + attn_metadata: TritonAttentionMetadata, + ) -> bool: + """Whether the paged prefill kernel's smem fits V100's 96KB ceiling. + + The kernel copies the per-sequence block table into shared memory, so + smem = TOTAL_SMEM[head_dim] + align128(max_num_blocks * 4). For long + max_model_len the block table alone blows the budget (e.g. head_dim 256 + at 177k ctx needs ~135KB). When it does not fit, the caller uses the + gather + dense prefill path, which is smem-safe at any context length. + """ + base = _PAGED_PREFILL_BASE_SMEM.get(self.head_size) + if base is None: + return False + block_table = getattr(attn_metadata, "block_table", None) + if block_table is None or block_table.ndim < 2: + return False + max_num_blocks = int(block_table.shape[1]) + extra = (max_num_blocks * 4 + 127) & ~127 + return base + extra <= _FLASH_V100_MAX_SMEM + def _flash_v100_prefill_with_prefix( self, layer: torch.nn.Module, @@ -835,6 +870,18 @@ def _flash_v100_prefill_with_prefix( head_dim = key_cache.shape[3] debug_compare = (os.getenv("VLLM_FLASH_V100_DEBUG_PREFILL_COMPARE", "0") == "1") + # The paged prefill kernel stores the block table in smem; at long + # max_model_len it overflows V100's 96KB. When it won't fit, use the + # gather + dense path below (smem-safe at any context length). + paged_smem_fits = self._paged_prefill_smem_fits(attn_metadata) + if not paged_smem_fits: + global _warned_paged_prefill_smem + if not _warned_paged_prefill_smem: + logger.info( + "FLASH_ATTN_V100 paged prefill smem exceeds 96KB at this " + "max_model_len; using gather+dense prefill (still FA)." + ) + _warned_paged_prefill_smem = True query_lens = query_start_loc[1:] - query_start_loc[:-1] max_query_len = int(query_lens.max().item()) if num_seqs > 0 else 0 @@ -868,7 +915,7 @@ def _flash_v100_prefill_with_prefix( if end <= start: continue - if self.use_flash_v100_prefill_paged: + if self.use_flash_v100_prefill_paged and paged_smem_fits: out_seq = self.flash_attn_prefill_paged( query[start:end].unsqueeze(0), key_cache, From 68d48fc80ed6a4731bde664d469a4b5a24127ef8 Mon Sep 17 00:00:00 2001 From: Phil Date: Mon, 8 Jun 2026 14:31:51 +0000 Subject: [PATCH 25/25] [FA_V100] Run gemma E2B/E4B KV-shared layers on FA (read target cache) Gemma E2B/E4B use KV-cache sharing (num_kv_shared_layers): the last N decoder layers reuse an earlier layer's KV instead of projecting their own. Per gemma4.py, a shared layer applies RoPE to Q only and passes raw, un-normed/un-RoPE'd K/V to Attention, relying on it reading the TARGET layer's cache via kv_sharing_target_layer_name. The FA_V100 no-prefix prefill path consumed the passed K/V directly, so shared layers attended to junk -> coherent-but-wrong output (e2b: "capital of France" -> "Hanoi"; raw template -> degenerate loop). Decode was already correct (the paged decode kernel reads kv_cache directly), and the kernel math was fine (parity passes at num_kv_heads=1) -- the bug was purely no-prefix prefill. Route a shared layer's no-prefix prefill through the prefix path, which reads the TARGET cache (aliased into kv_cache, already written by the earlier layer this pass) via the paged kernel when smem-safe, else gather+dense. 31b is unaffected (num_kv_shared_layers=0). Verified: e2b now runs fully on FA with correct output. Co-Authored-By: RivetOS Claude --- vllm/v1/attention/backends/flash_attn_v100.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/v1/attention/backends/flash_attn_v100.py b/vllm/v1/attention/backends/flash_attn_v100.py index 2d58d7aa36..523d534d93 100644 --- a/vllm/v1/attention/backends/flash_attn_v100.py +++ b/vllm/v1/attention/backends/flash_attn_v100.py @@ -557,6 +557,18 @@ def forward( attn_metadata, output, ) + if getattr(self, "kv_sharing_target_layer_name", None) is not None: + # KV-shared layer (gemma E2B/E4B): it applies RoPE to Q only and + # passes raw, un-normed/un-RoPE'd K/V; the real K/V live in the + # TARGET layer's cache, aliased into kv_cache and already written + # by that earlier layer this forward pass. Read them via the + # prefix path (paged kernel when smem-safe, else gather+dense) + # instead of the dense passed-K/V path, which would attend to + # junk. (Decode already reads kv_cache directly.) + self._reset_decode_cache() + return self._flash_v100_prefill_with_prefix( + layer, query, kv_cache, attn_metadata, output + ) if not flash_prefill_ok: # 512-dim no-prefix prefill: dense kernel caps at 256 -> Triton. return super().forward(