From 0cf60fae6aa3088f8e0167ec24ec850c567015ab Mon Sep 17 00:00:00 2001 From: gc-fu Date: Mon, 25 May 2026 12:26:43 +0000 Subject: [PATCH] update vllm patch to v0.14.0-b8.3 and pin transformers==5.8.0 - Regenerate vllm_for_multi_arc.patch from intel-sandbox/llm-scaler-vllm-xpu branch v0.14.0-b8.3 (commit d03ae58) - Pin transformers to 5.8.0 instead of installing from git main to avoid version incompatibility issues Co-Authored-By: Claude Opus 4.6 (1M context) --- vllm/docker/Dockerfile | 2 +- vllm/patches/vllm_for_multi_arc.patch | 1533 +++++++++++++++++++------ 2 files changed, 1165 insertions(+), 370 deletions(-) diff --git a/vllm/docker/Dockerfile b/vllm/docker/Dockerfile index 608cffad..7dab597b 100644 --- a/vllm/docker/Dockerfile +++ b/vllm/docker/Dockerfile @@ -107,7 +107,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ # Pin transformers version to avoid conflict in vLLM RUN --mount=type=cache,target=/root/.cache/pip \ pip install librosa soundfile decord && \ - pip install git+https://github.com/huggingface/transformers.git && \ + pip install transformers==5.8.0 && \ pip install ijson COPY ./patches/vllm_xpu_kernels.patch /tmp/ diff --git a/vllm/patches/vllm_for_multi_arc.patch b/vllm/patches/vllm_for_multi_arc.patch index a6ab5496..41646ed7 100644 --- a/vllm/patches/vllm_for_multi_arc.patch +++ b/vllm/patches/vllm_for_multi_arc.patch @@ -6504,10 +6504,10 @@ index 475bd8536..d9a597115 100644 + return res diff --git a/vllm/model_executor/layers/quantization/sym_int4.py b/vllm/model_executor/layers/quantization/sym_int4.py new file mode 100644 -index 000000000..0980c669e +index 000000000..aefe93764 --- /dev/null +++ b/vllm/model_executor/layers/quantization/sym_int4.py -@@ -0,0 +1,534 @@ +@@ -0,0 +1,939 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Tuple, Callable, Union @@ -6532,14 +6532,56 @@ index 000000000..0980c669e +from vllm.utils.math_utils import round_up + +from vllm.envs import 
VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT, VLLM_QUANTIZE_Q40_LIB ++from vllm.model_executor.layers.quantization.fp8 import ( ++ CopyNumelCounter, _copy_missing_attrs, ++) ++from vllm.logger import init_logger +import ctypes +import os +from packaging import version + ++logger = init_logger(__name__) ++ +MIN_IPEX_VERSION = "2.5.0" +QK4_GROUP_SIZE: int = 128 +QK4_PACK_FACTOR: int = 8 + ++# Optional XPU-side Q4_0 quantization kernel. When available, streaming INT4 ++# quantization can quantize BF16 weights directly on XPU — skipping the ++# transient D→H copy + CPU ggml + H→D that otherwise causes a ~1 side-sized ++# BF16 peak on XPU per layer. Falls back to the CPU path silently if the ++# module is missing or if VLLM_INT4_DISABLE_XPU_QUANT=1 is set. ++_HAS_XPU_Q4_0 = False ++try: ++ import custom_esimd_kernels_vllm.q4_0_quant_ops # noqa: F401 ++ _HAS_XPU_Q4_0 = True ++except ImportError: ++ pass ++ ++ ++def _use_xpu_quant() -> bool: ++ if not _HAS_XPU_Q4_0: ++ return False ++ if os.environ.get("VLLM_INT4_DISABLE_XPU_QUANT", "0") == "1": ++ return False ++ return True ++ ++ ++def _xpu_q4_0_quantize(bf16_xpu: torch.Tensor): ++ """Quantize a 2D BF16/FP16 XPU tensor to GGML Q4_0 layout, returning ++ ``(qweight_xpu [M, K/8] int32, scale_xpu [M, K/128] float16)``. 
++ Allocates outputs on the same XPU device as the input.""" ++ assert bf16_xpu.dim() == 2, bf16_xpu.shape ++ M, K = bf16_xpu.shape ++ dev = bf16_xpu.device ++ qweight = torch.empty( ++ M, K // QK4_PACK_FACTOR, dtype=torch.int32, device=dev) ++ scale = torch.empty( ++ M, K // QK4_GROUP_SIZE, dtype=torch.float16, device=dev) ++ torch.ops.custom_esimd_kernels_vllm.q4_0_quantize( ++ bf16_xpu.contiguous(), qweight, scale) ++ return qweight, scale ++ +_QLIB_CACHE = None + +def _get_quant_lib(): @@ -6697,14 +6739,100 @@ index 000000000..0980c669e + layer.orig_dtype = params_dtype + + weight_dtype = params_dtype -+ weight = ModelWeightParameter(data=torch.empty( -+ output_size_per_partition, -+ input_size_per_partition, -+ dtype=weight_dtype, -+ device="cpu"), -+ input_dim=1, -+ output_dim=0, -+ weight_loader=weight_loader) ++ ++ if VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT: ++ # Legacy path: buffer BF16 weights on CPU; they will later be ++ # moved to XPU by device_loading_context and quantized in ++ # process_weights_after_loading. ++ weight = ModelWeightParameter( ++ data=torch.empty( ++ output_size_per_partition, ++ input_size_per_partition, ++ dtype=weight_dtype, ++ device="cpu", ++ ), ++ input_dim=1, ++ output_dim=0, ++ weight_loader=weight_loader, ++ ) ++ layer.register_parameter("weight", weight) ++ return ++ ++ # Streaming path: allocate on meta; the patched weight_loader will ++ # materialize on XPU on the first shard, quantize once fully loaded, ++ # and release the BF16 intermediate. ++ orig_weight_loader = weight_loader ++ layer._load_device = torch.get_default_device() ++ ++ def patched_weight_loader(param, loaded_weight, *args, **kwargs): ++ # First call: materialize the meta placeholder on the target ++ # device and re-register the parameter so later loads write real ++ # storage. `_streaming_patched` lets qwen3_5.load_weights route ++ # tuple shard_id to us instead of bypassing via weight_loader_v2. 
++ if not hasattr(layer, "_loaded_numel"): ++ layer._loaded_numel = 0 ++ materialized = ModelWeightParameter( ++ data=torch.empty_like( ++ layer.weight, device=layer._load_device, ++ ), ++ input_dim=1, ++ output_dim=0, ++ weight_loader=patched_weight_loader, ++ ) ++ _copy_missing_attrs(layer.weight, materialized) ++ materialized._streaming_patched = True ++ layer.register_parameter("weight", materialized) ++ ++ # Always refresh to the live parameter after potential ++ # re-registration; stale references can come from params_dict ++ # snapshots captured before materialization. ++ param = layer.weight ++ ++ # Detect tuple shard_id (GDN fused projections use ++ # MergedColumnParallelLinear.weight_loader_v2 which accepts ++ # tuple). The INT4 v1 weight_loader does not handle tuple, so ++ # we dispatch through layer.weight_loader_v2 directly (Path X). ++ shard_id = kwargs.get("loaded_shard_id") ++ if shard_id is None: ++ shard_id = kwargs.get("shard_id") ++ if shard_id is None and args: ++ shard_id = args[0] ++ ++ copy_numel_counter = CopyNumelCounter() ++ with copy_numel_counter: ++ if (isinstance(shard_id, tuple) ++ and hasattr(layer, "weight_loader_v2")): ++ layer.weight_loader_v2(param, loaded_weight, shard_id) ++ res = None ++ else: ++ res = orig_weight_loader( ++ param, loaded_weight, *args, **kwargs, ++ ) ++ layer._loaded_numel += copy_numel_counter.copied_numel ++ ++ # Fully loaded: quantize in place and release the BF16 copy ++ # before the next layer begins loading. 
++ if layer._loaded_numel >= layer.weight.numel(): ++ _quantize_linear_int4_inplace(layer) ++ del layer._loaded_numel ++ if hasattr(layer, "_load_device"): ++ del layer._load_device ++ layer._already_called_process_weights_after_loading = True ++ ++ return res ++ ++ weight = ModelWeightParameter( ++ data=torch.empty( ++ output_size_per_partition, ++ input_size_per_partition, ++ dtype=weight_dtype, ++ device="meta", ++ ), ++ input_dim=1, ++ output_dim=0, ++ weight_loader=patched_weight_loader, ++ ) ++ weight._streaming_patched = True + layer.register_parameter("weight", weight) + + def apply(self, @@ -6719,75 +6847,134 @@ index 000000000..0980c669e + return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) + + def process_weights_after_loading(self, layer: Module) -> None: -+ weight = layer.weight.float() -+ out_features = layer.weight.shape[0] -+ in_features = layer.weight.shape[1] ++ if getattr(layer, "_already_called_process_weights_after_loading", ++ False): ++ return ++ ++ # Fallback for layers that never saw their streaming weight_loader ++ # (e.g. tied lm_head, skipped vision branches). Materialize zeros so ++ # downstream quantization does not read uninitialized memory. 
++ if (hasattr(layer, "weight") and layer.weight is not None ++ and layer.weight.device == torch.device("meta")): ++ device = getattr(layer, "_load_device", torch.device("xpu")) ++ wl = getattr(layer.weight, "weight_loader", ++ lambda p, w, *a, **k: None) ++ materialized = ModelWeightParameter( ++ data=torch.zeros_like(layer.weight, device=device), ++ input_dim=1, output_dim=0, weight_loader=wl) ++ layer.register_parameter("weight", materialized) ++ if hasattr(layer, "_load_device"): ++ del layer._load_device + -+ qweight = torch.zeros((out_features, in_features // QK4_PACK_FACTOR), dtype=torch.int32, device=layer.weight.device) -+ scale = torch.zeros((out_features, in_features // QK4_GROUP_SIZE), dtype=torch.float16, device=layer.weight.device) ++ _quantize_linear_int4_inplace(layer) + -+ # transpose=False: C output is [N, K/8] contiguous (GGML row-major). -+ qweight, scale = ggml_quantize_tensor( -+ weight, qweight, scale, out_features, in_features, -+ block_size=QK4_GROUP_SIZE, transpose=False, -+ ) + -+ # Store [N, K/8] and [N, K/GROUP_SIZE] contiguous on XPU for ESIMD. -+ # IPEX needs [K/8, N] (transposed), so we give it a .t() view. -+ # Note: IPEX transpose_xetla_woq_format mutates the .data of the -+ # tensor it receives, so the ESIMD copy must be separate storage. -+ qweight_xpu = qweight.to("xpu") # [N, K/8] contiguous -+ scale_xpu = scale.to("xpu") # [N, K/GROUP_SIZE] contiguous ++def _quantize_linear_int4_inplace(layer: Module) -> None: ++ """Quantize a Linear layer's BF16/FP16 weight to INT4 and wire up IPEX. + -+ # ESIMD kernel weights: [N, K/2] uint8, [N, K/GROUP_SIZE] fp16 -+ layer.weight_esimd = Parameter( -+ qweight_xpu.view(torch.uint8), requires_grad=False) -+ layer.scale_esimd = Parameter(scale_xpu, requires_grad=False) ++ Works for both paths: ++ - Legacy CPU offload: ``layer.weight`` is BF16 on CPU (or temporarily ++ moved to XPU by ``device_loading_context``). 
++ - Streaming: ``layer.weight`` is BF16 on XPU; quantized on-device if the XPU ++ Q4_0 kernel is enabled, else on CPU via ``ggml_quantize_tensor``; XPU BF16 is released. ++ """ ++ bf16 = layer.weight.data ++ out_features = bf16.shape[0] ++ in_features = bf16.shape[1] ++ ++ if bf16.device.type == "xpu" and _use_xpu_quant(): ++ # XPU streaming path: quantize directly on device — no CPU round trip. ++ # This keeps XPU peak close to OFFLOAD=1 because we never hold both ++ # BF16 + FP32 nor do a D→H bounce of the full side. ++ qweight_xpu, scale_xpu = _xpu_q4_0_quantize(bf16) ++ layer.weight = None ++ del bf16 ++ try: ++ torch.xpu.empty_cache() ++ except Exception: ++ pass ++ else: ++ if bf16.device.type == "xpu": ++ # CPU-fallback streaming path: pull to CPU before BF16→FP32 upcast ++ # so the FP32 temporary never lives on XPU. Doing ++ # ``bf16.float().cpu()`` instead would briefly hold BF16 + FP32 ++ # (= 3× BF16) on XPU for this layer. ++ bf16_cpu = bf16.cpu() ++ layer.weight = None ++ del bf16 ++ try: ++ torch.xpu.empty_cache() ++ except Exception: ++ pass ++ weight_cpu = bf16_cpu.float().contiguous() ++ del bf16_cpu ++ else: ++ weight_cpu = bf16.float().contiguous() + -+ # For IPEX: transposed view -+ layer.weight = Parameter(qweight_xpu.t(), requires_grad=False) -+ layer.weight_scale = Parameter(scale_xpu.t(), requires_grad=False) ++ qweight = torch.zeros( ++ (out_features, in_features // QK4_PACK_FACTOR), ++ dtype=torch.int32, device="cpu", ++ ) ++ scale = torch.zeros( ++ (out_features, in_features // QK4_GROUP_SIZE), ++ dtype=torch.float16, device="cpu", ++ ) ++ qweight, scale = ggml_quantize_tensor( ++ weight_cpu, qweight, scale, out_features, in_features, ++ block_size=QK4_GROUP_SIZE, transpose=False, ++ ) ++ del weight_cpu + ++ # Move quantized weights to XPU. IPEX needs [K/8, N] (transposed); ++ # ESIMD needs [N, K/2] uint8.
++ qweight_xpu = qweight.to("xpu") ++ scale_xpu = scale.to("xpu") ++ del qweight, scale + ++ layer.weight_esimd = Parameter( ++ qweight_xpu.view(torch.uint8), requires_grad=False) ++ layer.scale_esimd = Parameter(scale_xpu, requires_grad=False) ++ layer.weight = Parameter(qweight_xpu.t(), requires_grad=False) ++ layer.weight_scale = Parameter(scale_xpu.t(), requires_grad=False) + -+ try: -+ import intel_extension_for_pytorch as ipex -+ if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION): -+ raise ImportError( -+ f"intel_extension_for_pytorch version is wrong. " -+ f"Current: {ipex.__version__}, Required: >={MIN_IPEX_VERSION}") -+ except ImportError as err: ++ try: ++ import intel_extension_for_pytorch as ipex ++ if version.parse(ipex.__version__) < version.parse(MIN_IPEX_VERSION): + raise ImportError( -+ "Please install " -+ f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " -+ f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" -+ " to use IPEX-AWQ linear method.") from err -+ -+ lowp_mode = ipex.quantization.WoqLowpMode.INT8 -+ weight_dtype = ipex.quantization.WoqWeightDtype.INT4 -+ act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK -+ qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( -+ weight_dtype=weight_dtype, -+ lowp_mode=lowp_mode, -+ act_quant_mode=act_quant_mode, -+ group_size=QK4_GROUP_SIZE, -+ ) -+ layer.ipex_output_size = layer.weight.shape[-1] -+ g_idx = None -+ layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \ -+ IPEXWeightOnlyQuantizedLinear.from_weight( -+ layer.weight, # weight should be on xpu... -+ layer.weight_scale, -+ torch.tensor([8], device=layer.weight.device, dtype=torch.int8), -+ layer.weight.size(0), -+ layer.ipex_output_size, -+ qconfig=qconfig, -+ g_idx=g_idx, -+ bias=None, -+ group_size=QK4_GROUP_SIZE, -+ # For GPTQ layout -+ quant_method=0 -+ ) ++ f"intel_extension_for_pytorch version is wrong. 
" ++ f"Current: {ipex.__version__}, Required: >={MIN_IPEX_VERSION}") ++ except ImportError as err: ++ raise ImportError( ++ "Please install " ++ f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via " ++ f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`" ++ " to use IPEX-AWQ linear method.") from err ++ ++ lowp_mode = ipex.quantization.WoqLowpMode.INT8 ++ weight_dtype = ipex.quantization.WoqWeightDtype.INT4 ++ act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK ++ qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( ++ weight_dtype=weight_dtype, ++ lowp_mode=lowp_mode, ++ act_quant_mode=act_quant_mode, ++ group_size=QK4_GROUP_SIZE, ++ ) ++ layer.ipex_output_size = layer.weight.shape[-1] ++ g_idx = None ++ layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \ ++ IPEXWeightOnlyQuantizedLinear.from_weight( ++ layer.weight, ++ layer.weight_scale, ++ torch.tensor([8], device=layer.weight.device, dtype=torch.int8), ++ layer.weight.size(0), ++ layer.ipex_output_size, ++ qconfig=qconfig, ++ g_idx=g_idx, ++ bias=None, ++ group_size=QK4_GROUP_SIZE, ++ # For GPTQ layout ++ quant_method=0, ++ ) + + +def _to_cutlass_nmajor(qweight_nk: torch.Tensor): @@ -6865,6 +7052,174 @@ index 000000000..0980c669e + return out + + ++def _quantize_moe_int4_side_inplace(layer: Module, which: str) -> None: ++ """Quantize a single side (w13 or w2) of an MoE layer to INT4. ++ ++ Loads BF16 weights (wherever they currently live), sends them to CPU ++ one side at a time for ``ggml_quantize_tensor``, moves the quantized ++ int32 qweight + fp16 scales to XPU, and registers them back on the ++ layer under ``{which}_weight`` / ``{which}_scales``. The BF16 source is ++ released before the XPU allocation to minimize peak memory.
++ """ ++ assert which in ("w13", "w2"), f"unknown side {which!r}" ++ ++ d_model = layer.hidden_size ++ d_ff = layer.d_ff ++ num_loop = getattr(layer, "local_num_experts", layer.num_experts) ++ ++ weight_name = f"{which}_weight" ++ scale_name = f"{which}_scales" ++ ++ src = getattr(layer, weight_name).data ++ if src.dim() != 3: ++ raise RuntimeError( ++ f"expected 3D MoE weight for {weight_name}, got {src.shape}") ++ ++ if which == "w13": ++ out_features, in_features = 2 * d_ff, d_model ++ else: ++ out_features, in_features = d_model, d_ff ++ ++ E_all = src.shape[0] ++ ++ # XPU streaming fast path: quantize each expert directly on device, one at ++ # a time, and drop its BF16 slice when done. Peak XPU footprint per layer ++ # is now (accumulated int4) + 1 BF16 expert, not a full BF16 side. ++ if src.device.type == "xpu" and _use_xpu_quant(): ++ qweight_xpu = torch.empty( ++ E_all, out_features, in_features // QK4_PACK_FACTOR, ++ dtype=torch.int32, device=src.device) ++ scales_xpu = torch.empty( ++ E_all, out_features, in_features // QK4_GROUP_SIZE, ++ dtype=torch.float16, device=src.device) ++ for e in range(num_loop): ++ q_e, s_e = _xpu_q4_0_quantize(src[e]) ++ qweight_xpu[e].copy_(q_e) ++ scales_xpu[e].copy_(s_e) ++ del q_e, s_e ++ # Experts outside [0, num_loop) are not local; their slices stay ++ # zero-initialized (matches legacy path behavior). ++ setattr(layer, weight_name, None) ++ del src ++ try: ++ torch.xpu.empty_cache() ++ except Exception: ++ pass ++ setattr(layer, weight_name, torch.nn.Parameter( ++ qweight_xpu, requires_grad=False)) ++ setattr(layer, scale_name, torch.nn.Parameter( ++ scales_xpu, requires_grad=False)) ++ return ++ ++ # CPU fallback path: keep the existing D→H + ggml + H→D behavior. ++ # Streaming path: weights live on XPU. 
Move BF16 to CPU once (BF16 keeps ++ # the memory footprint half of FP32), release XPU storage immediately, ++ # then quantize per-expert with transient FP32 buffers — matching the ++ # legacy path's per-expert peak shape. ++ # Legacy path: weights already on CPU as BF16; alias directly and drop ++ # the parameter slot so we can overwrite with the quantized result. ++ if src.device.type == "xpu": ++ src_cpu = src.cpu().contiguous() ++ setattr(layer, weight_name, None) ++ del src ++ try: ++ torch.xpu.empty_cache() ++ except Exception: ++ pass ++ else: ++ src_cpu = src ++ ++ E = src_cpu.shape[0] ++ qweight = torch.empty( ++ E, out_features, in_features // QK4_PACK_FACTOR, ++ dtype=torch.int32, device="cpu", ++ ) ++ scales = torch.empty( ++ E, out_features, in_features // QK4_GROUP_SIZE, ++ dtype=torch.float16, device="cpu", ++ ) ++ ++ def _quantize_expert(e): ++ # Per-expert FP32 copy so peak CPU footprint is BF16_full + ++ # FP32_one_expert + qweight_accum + scales_accum. ++ expert_fp32 = src_cpu[e].float().contiguous() ++ q_buf = torch.zeros( ++ (out_features, in_features // QK4_PACK_FACTOR), ++ dtype=torch.int32, device="cpu", ++ ) ++ s_buf = torch.zeros( ++ (out_features, in_features // QK4_GROUP_SIZE), ++ dtype=torch.float16, device="cpu", ++ ) ++ q, s = ggml_quantize_tensor( ++ expert_fp32, q_buf, s_buf, out_features, in_features, ++ block_size=QK4_GROUP_SIZE, transpose=False, ++ ) ++ qweight[e].copy_(q) ++ scales[e].copy_(s) ++ ++ with ThreadPoolExecutor() as executor: ++ list(executor.map(_quantize_expert, range(num_loop))) ++ ++ # Release the BF16 source now that all experts are quantized. For the ++ # legacy path this also drops the original parameter storage. 
++ del src_cpu ++ if getattr(layer, weight_name, None) is not None: ++ setattr(layer, weight_name, None) ++ ++ qweight_xpu = qweight.to("xpu") ++ scales_xpu = scales.to("xpu") ++ del qweight, scales ++ ++ setattr(layer, weight_name, torch.nn.Parameter( ++ qweight_xpu, requires_grad=False)) ++ setattr(layer, scale_name, torch.nn.Parameter( ++ scales_xpu, requires_grad=False)) ++ ++ ++def _setup_moe_int4_kernel(layer: Module, method) -> None: ++ """Finalize MoE INT4 layout: CUTLASS N-major repack (optional) and IPEX ++ fusion kernel. Idempotent — safe to call twice. ++ """ ++ if getattr(layer, "_moe_int4_kernel_ready", False): ++ return ++ ++ import intel_extension_for_pytorch as ipex ++ ++ use_esimd = os.environ.get("USE_ESIMD_MOE_PREFILL", "1") == "1" ++ method._use_esimd_prefill = use_esimd ++ ++ if use_esimd: ++ from vllm_xpu_kernels.fused_moe_interface import implement_zp ++ E = layer.num_experts ++ ++ w13_qweight = _to_cutlass_nmajor(layer.w13_weight.data) ++ w2_qweight = _to_cutlass_nmajor(layer.w2_weight.data) ++ w13_tmp = torch.empty_like(w13_qweight) ++ w2_tmp = torch.empty_like(w2_qweight) ++ for i in range(E): ++ w13_tmp[i] = implement_zp(w13_qweight[i]) ++ w2_tmp[i] = implement_zp(w2_qweight[i]) ++ layer.w13_weight = torch.nn.Parameter( ++ w13_tmp.contiguous(), requires_grad=False) ++ layer.w2_weight = torch.nn.Parameter( ++ w2_tmp.contiguous(), requires_grad=False) ++ layer.ipex_fusion = None ++ # Mark as already converted so xpu_fused_moe skips implement_zp. 
++ layer.w13_weight.xpu_fused_moe = True ++ layer.w2_weight.xpu_fused_moe = True ++ else: ++ layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( ++ layer.w13_weight, ++ layer.w2_weight, ++ w1_scale_inv=layer.w13_scales, ++ w2_scale_inv=layer.w2_scales, ++ is_int4=True, ++ ) ++ ++ layer._moe_int4_kernel_ready = True ++ ++ +class XPUGPTQInt4LinearMoEMethod(FusedMoEMethodBase): + def __init__( + self, @@ -6885,7 +7240,6 @@ index 000000000..0980c669e + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): -+ # Just normally loads the weights, obey VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT... + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts @@ -6906,114 +7260,165 @@ index 000000000..0980c669e + elif tp_size == 16: + intermediate_size_per_partition = round_up(intermediate_size_per_partition, 128) + layer.d_ff = intermediate_size_per_partition ++ ++ if VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT: ++ # Legacy path: full BF16 buffers on CPU; device_loading_context ++ # will move them onto XPU prior to process_weights_after_loading. ++ weight_device = "cpu" ++ else: ++ # Streaming path: meta placeholders, materialized JIT on first ++ # shard and quantized immediately when each side is full. 
++ weight_device = "meta" ++ layer._load_device = torch.get_default_device() ++ ++ orig_weight_loader = extra_weight_attrs.get("weight_loader") ++ + # w13 shape: [d_ff * 2, d_model] -+ w13_weight = torch.nn.Parameter(torch.empty( -+ num_experts, -+ 2 * intermediate_size_per_partition, -+ hidden_size, -+ dtype=params_dtype, -+ device="cpu"), -+ requires_grad=False) ++ w13_weight = torch.nn.Parameter( ++ torch.empty( ++ num_experts, ++ 2 * intermediate_size_per_partition, ++ hidden_size, ++ dtype=params_dtype, ++ device=weight_device, ++ ), ++ requires_grad=False, ++ ) + layer.register_parameter("w13_weight", w13_weight) -+ set_weight_attrs(w13_weight, extra_weight_attrs) + + # w2 shape: [d_model, d_ff] -+ w2_weight = torch.nn.Parameter(torch.empty( -+ num_experts, -+ hidden_size, -+ intermediate_size_per_partition, -+ dtype=params_dtype, -+ device="cpu"), -+ requires_grad=False) ++ w2_weight = torch.nn.Parameter( ++ torch.empty( ++ num_experts, ++ hidden_size, ++ intermediate_size_per_partition, ++ dtype=params_dtype, ++ device=weight_device, ++ ), ++ requires_grad=False, ++ ) + layer.register_parameter("w2_weight", w2_weight) -+ set_weight_attrs(w2_weight, extra_weight_attrs) + -+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: -+ import intel_extension_for_pytorch as ipex -+ E = layer.num_experts -+ d_model = layer.hidden_size -+ d_ff = layer.d_ff -+ -+ assert d_model % QK4_PACK_FACTOR == 0 and d_ff % QK4_PACK_FACTOR == 0, "INT4 packing requires feature dims % 8 == 0" -+ assert d_model % QK4_GROUP_SIZE == 0 and d_ff % QK4_GROUP_SIZE == 0, f"group_size={QK4_GROUP_SIZE} requires dims % {QK4_GROUP_SIZE} == 0" -+ -+ # Allocating CPU tensors -+ w13_qweight = torch.empty(E, 2 * d_ff, d_model // QK4_PACK_FACTOR, dtype=torch.int32, device="cpu") -+ w2_qweight = torch.empty(E, d_model, d_ff // QK4_PACK_FACTOR, dtype=torch.int32, device="cpu") -+ w13_scales = torch.empty(E, 2 * d_ff, d_model // QK4_GROUP_SIZE, dtype=torch.float16, device="cpu") -+ 
w2_scales = torch.empty(E, d_model, d_ff // QK4_GROUP_SIZE, dtype=torch.float16, device="cpu") -+ -+ # Quantize per expert (parallelized across experts) -+ num_loop = getattr(layer, "local_num_experts", E) -+ -+ def _quantize_expert(e): -+ # Ensure fp32 contiguous -+ w13_e = layer.w13_weight[e].float().contiguous().to("cpu") -+ w2_e = layer.w2_weight[e].float().contiguous().to("cpu") -+ -+ # --- w13 --- (transpose=False: C output is already [out, in//8]) -+ q13_buf = torch.zeros((2 * d_ff, d_model // QK4_PACK_FACTOR), dtype=torch.int32, device="cpu") -+ s13_buf = torch.zeros((2 * d_ff, d_model // QK4_GROUP_SIZE), dtype=torch.float16, device="cpu") -+ q13, s13 = ggml_quantize_tensor( -+ w13_e, q13_buf, s13_buf, 2 * d_ff, d_model, -+ block_size=QK4_GROUP_SIZE, transpose=False, -+ ) -+ w13_qweight[e].copy_(q13) -+ w13_scales[e].copy_(s13) -+ -+ # --- w2 --- -+ q2_buf = torch.zeros((d_model, d_ff // QK4_PACK_FACTOR), dtype=torch.int32, device="cpu") -+ s2_buf = torch.zeros((d_model, d_ff // QK4_GROUP_SIZE), dtype=torch.float16, device="cpu") -+ q2, s2 = ggml_quantize_tensor( -+ w2_e, q2_buf, s2_buf, d_model, d_ff, -+ block_size=QK4_GROUP_SIZE, transpose=False, -+ ) -+ w2_qweight[e].copy_(q2) -+ w2_scales[e].copy_(s2) ++ if VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT: ++ set_weight_attrs(w13_weight, extra_weight_attrs) ++ set_weight_attrs(w2_weight, extra_weight_attrs) ++ return + -+ with ThreadPoolExecutor() as executor: -+ list(executor.map(_quantize_expert, range(num_loop))) ++ # Streaming: install a patched weight_loader that (a) materializes ++ # w13 / w2 on the target device the first time its respective shard ++ # arrives, (b) counts loaded elements independently, and (c) kicks ++ # off per-side INT4 quantization the moment each side is fully ++ # loaded — releasing the BF16 intermediate before the next layer ++ # starts loading. Mirrors the FP8 MoE streaming path. 
++ def patched_moe_weight_loader(param, loaded_weight, *args, **kwargs): ++ shard_id = kwargs.get("shard_id") ++ if shard_id is None and len(args) >= 2: ++ shard_id = args[1] ++ is_w13 = shard_id in ("w1", "w3") ++ ++ if is_w13 and not hasattr(layer, "_w13_materialized"): ++ layer._w13_materialized = True ++ layer._w13_loaded_numel = 0 ++ new_w13 = torch.nn.Parameter( ++ torch.empty_like( ++ layer.w13_weight, device=layer._load_device, ++ ), ++ requires_grad=False, ++ ) ++ new_attrs = dict(extra_weight_attrs) ++ new_attrs["weight_loader"] = patched_moe_weight_loader ++ set_weight_attrs(new_w13, new_attrs) ++ layer.register_parameter("w13_weight", new_w13) ++ ++ if (not is_w13) and not hasattr(layer, "_w2_materialized"): ++ layer._w2_materialized = True ++ layer._w2_loaded_numel = 0 ++ new_w2 = torch.nn.Parameter( ++ torch.empty_like( ++ layer.w2_weight, device=layer._load_device, ++ ), ++ requires_grad=False, ++ ) ++ new_attrs = dict(extra_weight_attrs) ++ new_attrs["weight_loader"] = patched_moe_weight_loader ++ set_weight_attrs(new_w2, new_attrs) ++ layer.register_parameter("w2_weight", new_w2) ++ ++ if (hasattr(layer, "_w13_materialized") ++ and hasattr(layer, "_w2_materialized") ++ and hasattr(layer, "_load_device")): ++ del layer._load_device + -+ # Move to XPU -+ w13_qweight = w13_qweight.to("xpu") -+ w2_qweight = w2_qweight.to("xpu") -+ w13_scales = w13_scales.to("xpu") -+ w2_scales = w2_scales.to("xpu") ++ param = layer.w13_weight if is_w13 else layer.w2_weight + -+ # When USE_ESIMD_MOE_PREFILL=1, repack weights into CUTLASS N-major -+ # uint8 layout and apply implement_zp (unsigned→signed int4 encoding). 
-+ self._use_esimd_prefill = os.environ.get("USE_ESIMD_MOE_PREFILL", "1") == "1" -+ if self._use_esimd_prefill: -+ from vllm_xpu_kernels.fused_moe_interface import implement_zp -+ w13_qweight = _to_cutlass_nmajor(w13_qweight) -+ w2_qweight = _to_cutlass_nmajor(w2_qweight) -+ w13_tmp = torch.empty_like(w13_qweight) -+ w2_tmp = torch.empty_like(w2_qweight) -+ for i in range(E): -+ w13_tmp[i] = implement_zp(w13_qweight[i]) -+ w2_tmp[i] = implement_zp(w2_qweight[i]) -+ w13_qweight = w13_tmp.contiguous() -+ w2_qweight = w2_tmp.contiguous() -+ -+ # Override parameters -+ layer.w13_weight = torch.nn.Parameter(w13_qweight, requires_grad=False) -+ layer.w2_weight = torch.nn.Parameter(w2_qweight, requires_grad=False) -+ layer.w13_scales = torch.nn.Parameter(w13_scales, requires_grad=False) -+ layer.w2_scales = torch.nn.Parameter(w2_scales, requires_grad=False) -+ -+ if not self._use_esimd_prefill: -+ layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE( -+ layer.w13_weight, -+ layer.w2_weight, -+ w1_scale_inv=layer.w13_scales, -+ w2_scale_inv=layer.w2_scales, -+ is_int4=True -+ ) -+ else: -+ layer.ipex_fusion = None -+ # Mark as already converted so xpu_fused_moe skips implement_zp -+ layer.w13_weight.xpu_fused_moe = True -+ layer.w2_weight.xpu_fused_moe = True ++ copy_numel_counter = CopyNumelCounter() ++ with copy_numel_counter: ++ res = orig_weight_loader( ++ param, loaded_weight, *args, **kwargs, ++ ) ++ ++ if is_w13: ++ layer._w13_loaded_numel += copy_numel_counter.copied_numel ++ if layer._w13_loaded_numel >= layer.w13_weight.numel(): ++ _quantize_moe_int4_side_inplace(layer, "w13") ++ del layer._w13_loaded_numel ++ else: ++ layer._w2_loaded_numel += copy_numel_counter.copied_numel ++ if layer._w2_loaded_numel >= layer.w2_weight.numel(): ++ _quantize_moe_int4_side_inplace(layer, "w2") ++ del layer._w2_loaded_numel ++ ++ # When both sides are fully quantized, set up the runtime kernel ++ # and short-circuit the later process_weights_after_loading. 
++ if (not hasattr(layer, "_w13_loaded_numel") ++ and not hasattr(layer, "_w2_loaded_numel") ++ and getattr(layer, "_w13_materialized", False) ++ and getattr(layer, "_w2_materialized", False)): ++ _setup_moe_int4_kernel(layer, self) ++ layer._already_called_process_weights_after_loading = True ++ ++ return res ++ ++ patched_attrs = dict(extra_weight_attrs) ++ patched_attrs["weight_loader"] = patched_moe_weight_loader ++ set_weight_attrs(w13_weight, patched_attrs) ++ set_weight_attrs(w2_weight, patched_attrs) ++ ++ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ++ if getattr(layer, "_already_called_process_weights_after_loading", ++ False): ++ # Streaming already finished quantizing and setting up the ++ # kernel. Still guarantee idempotent kernel setup in case the ++ # caller invoked this after a state reload. ++ _setup_moe_int4_kernel(layer, self) ++ return ++ ++ assert layer.hidden_size % QK4_PACK_FACTOR == 0 \ ++ and layer.d_ff % QK4_PACK_FACTOR == 0, \ ++ "INT4 packing requires feature dims % 8 == 0" ++ assert layer.hidden_size % QK4_GROUP_SIZE == 0 \ ++ and layer.d_ff % QK4_GROUP_SIZE == 0, \ ++ f"group_size={QK4_GROUP_SIZE} requires dims % {QK4_GROUP_SIZE} == 0" ++ ++ # meta fallback for layers whose weight_loader was never invoked ++ # (e.g. tied / skipped branches). ++ for name in ("w13_weight", "w2_weight"): ++ p = getattr(layer, name, None) ++ if p is not None and p.device == torch.device("meta"): ++ dev = getattr(layer, "_load_device", torch.device("xpu")) ++ setattr(layer, name, torch.nn.Parameter( ++ torch.zeros_like(p, device=dev), requires_grad=False, ++ )) ++ if hasattr(layer, "_load_device"): ++ del layer._load_device ++ ++ # Legacy (OFFLOAD=1) path: w13 / w2 still hold BF16. Quantize each ++ # side (releasing its BF16 before moving on), then build the kernel. 
++ if layer.w13_weight.dtype != torch.int32: ++ _quantize_moe_int4_side_inplace(layer, "w13") ++ if layer.w2_weight.dtype != torch.int32: ++ _quantize_moe_int4_side_inplace(layer, "w2") ++ _setup_moe_int4_kernel(layer, self) + + + def apply( @@ -10165,10 +10570,10 @@ index 3b0dce7fc..e785eca41 100644 diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py new file mode 100644 -index 000000000..f20b837b7 +index 000000000..ae061004a --- /dev/null +++ b/vllm/model_executor/models/qwen3_5.py -@@ -0,0 +1,1622 @@ +@@ -0,0 +1,1780 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + @@ -10477,6 +10882,10 @@ index 000000000..f20b837b7 + quant_config is not None + and quant_config.get_name() == "sym_int4" + ) ++ # esimd_gemm_int4_pgrp requires N%16==0; in_proj_ba may not satisfy ++ # this at higher TP (e.g. 27B TP=4 → N=24). ++ _ba_n = (self.num_v_heads * 2) // self.tp_size ++ self._int4_gemm_ok = self._is_sym_int4 and (_ba_n % 16 == 0) + + # Fused norm + out_proj GEMV for decode (FP8 or INT4) + self._use_fused_out_proj = hasattr(self.out_proj, 'weight_scale') @@ -10484,6 +10893,13 @@ index 000000000..f20b837b7 + (1, self.hidden_size), + dtype=torch.float16, device=_dev) + self._norm_weight_fp16 = None # lazily cached after weights loaded ++ # Disable lazy fp16 norm-weight caching only for 27B on TP=4 to work ++ # around an XPU allocator bug that corrupts cached storage at certain ++ # max_model_len values. Other configs keep the cache to avoid the ++ # per-forward .half().clone() overhead. ++ self._disable_norm_cache = ( ++ config.hidden_size == 5120 and self.tp_size == 4 ++ ) + + # Pre-compute gather indices to convert GEMV output from + # sequential [q|k|v|z] to GQA-interleaved [q_g0|k_g0|v_g0|z_g0|...] 
@@ -10513,6 +10929,24 @@ index 000000000..f20b837b7 + idx_ba.extend(range(nv_tp + g * vpg, nv_tp + (g + 1) * vpg)) + self._gather_ba = torch.tensor(idx_ba, dtype=torch.long, device=_dev) + ++ # Pre-allocate BSZ>1 decode buffers (avoids per-forward torch.zeros/empty) ++ _mb = int(os.environ.get("MAX_DECODE_BSZ", "64")) ++ self._max_bsz = _mb ++ self._m_qkvz = torch.empty( ++ (_mb, (self.key_dim * 2 + self.value_dim * 2) // self.tp_size), ++ dtype=torch.float16, device=_dev) ++ self._m_ba = torch.empty( ++ (_mb, (self.num_v_heads * 2) // self.tp_size), ++ dtype=torch.float16, device=_dev) ++ self._m_attn_out = torch.empty( ++ (_mb, self.num_v_heads // self.tp_size, self.head_v_dim), ++ dtype=torch.float16, device=_dev) ++ self._m_z = torch.empty( ++ (_mb, self.num_v_heads // self.tp_size, self.head_v_dim), ++ dtype=torch.float16, device=_dev) ++ self._m_outproj = torch.empty( ++ (_mb, self.hidden_size), dtype=torch.float16, device=_dev) ++ + def fix_query_key_value_ordering( + self, + mixed_qkvz: torch.Tensor, @@ -10562,13 +10996,26 @@ index 000000000..f20b837b7 + num_tokens = hidden_states.size(0) + is_decode = (num_tokens == 1) + ++ # Resolve this layer's attn_metadata up front so Part 1 can pick the ++ # right projected-states layout (sequential for ESIMD GDN vs ++ # interleaved for the C++ prefill op). 
++ # forward_context = get_forward_context() ++ # attn_metadata = forward_context.attn_metadata ++ # if attn_metadata is not None: ++ # attn_metadata = attn_metadata[self.prefix] ++ # _use_esimd_gdn = is_decode or ( ++ # attn_metadata is not None ++ # and attn_metadata.num_prefills == 0 ++ # and attn_metadata.num_decodes > 0 ++ # and attn_metadata.num_decodes <= 128 ++ # ) ++ + # ============================================================ + # Part 1: Input Projection + # ============================================================ + if _PROFILE_ATTN and is_decode: + torch.xpu.synchronize() + _t0 = time.perf_counter() -+ + if is_decode and self._use_esimd_proj: + # M=1, FP8: single fused GEMV for qkvz + ba projections + from custom_esimd_kernels_vllm import esimd_gemv_fp8_pert, esimd_gemv_fp8_pert_fused2 @@ -10608,14 +11055,18 @@ index 000000000..f20b837b7 + projected_states_qkvz = qkvz_merged + projected_states_ba = ba_merged + elif num_tokens <= 64 and self._use_esimd_proj: -+ # M=2-64, FP8: ESIMD GEMM, gather to interleaved for prefill kernel ++ # M=2-64, FP8: ESIMD GEMM. 
+ from custom_esimd_kernels_vllm import esimd_gemm_fp8_pert -+ N_qkvz = self.in_proj_qkvz.weight.shape[0] -+ N_ba = self.in_proj_ba.weight.shape[0] -+ qkvz_merged = torch.empty( -+ (num_tokens, N_qkvz), dtype=torch.float16, device=hidden_states.device) -+ ba_merged = torch.empty( -+ (num_tokens, N_ba), dtype=torch.float16, device=hidden_states.device) ++ if num_tokens <= self._max_bsz: ++ qkvz_merged = self._m_qkvz[:num_tokens] ++ ba_merged = self._m_ba[:num_tokens] ++ else: ++ N_qkvz = self.in_proj_qkvz.weight.shape[0] ++ N_ba = self.in_proj_ba.weight.shape[0] ++ qkvz_merged = torch.empty( ++ (num_tokens, N_qkvz), dtype=torch.float16, device=hidden_states.device) ++ ba_merged = torch.empty( ++ (num_tokens, N_ba), dtype=torch.float16, device=hidden_states.device) + esimd_gemm_fp8_pert( + hidden_states, self.in_proj_qkvz.weight, + self.in_proj_qkvz.weight_scale, qkvz_merged) @@ -10624,8 +11075,27 @@ index 000000000..f20b837b7 + self.in_proj_ba.weight_scale, ba_merged) + projected_states_qkvz = qkvz_merged[:, self._gather_qkvz] + projected_states_ba = ba_merged[:, self._gather_ba] ++ elif num_tokens <= 64 and self._int4_gemm_ok and num_tokens <= self._max_bsz: ++ # M=2-64, INT4: ESIMD DPAS GEMM (esimd_gemm_int4_pgrp). ++ # Kernel requires N % 16 == 0 and K % 128 == 0. ++ from custom_esimd_kernels_vllm import esimd_gemm_int4_pgrp ++ qkvz_merged = self._m_qkvz[:num_tokens] ++ ba_merged = self._m_ba[:num_tokens] ++ esimd_gemm_int4_pgrp( ++ hidden_states, ++ self.in_proj_qkvz.weight_esimd, ++ self.in_proj_qkvz.scale_esimd, ++ qkvz_merged) ++ esimd_gemm_int4_pgrp( ++ hidden_states, ++ self.in_proj_ba.weight_esimd, ++ self.in_proj_ba.scale_esimd, ++ ba_merged) ++ projected_states_qkvz = qkvz_merged[:, self._gather_qkvz] ++ projected_states_ba = ba_merged[:, self._gather_ba] + else: -+ # Fallback: standard Linear, gather to interleaved for prefill kernel ++ # Fallback: standard Linear. ++ # Same layout choice as above. 
+ qkvz_merged, _ = self.in_proj_qkvz(hidden_states) + ba_merged, _ = self.in_proj_ba(hidden_states) + projected_states_qkvz = qkvz_merged[:, self._gather_qkvz] @@ -10643,6 +11113,9 @@ index 000000000..f20b837b7 + if is_decode: + core_attn_out = self._decode_attn_out_buf + z_out = self._decode_z_out_buf ++ elif num_tokens <= self._max_bsz: ++ core_attn_out = self._m_attn_out[:num_tokens] ++ z_out = self._m_z[:num_tokens] + else: + core_attn_out = torch.zeros( + (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), @@ -10652,7 +11125,6 @@ index 000000000..f20b837b7 + z_out = torch.empty_like(core_attn_out) + if attn_metadata is not None: + attn_metadata = attn_metadata[self.prefix] -+ + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0] + ssm_state = self_kv_cache[1] @@ -10744,23 +11216,28 @@ index 000000000..f20b837b7 + from vllm.distributed import tensor_model_parallel_all_reduce + nv_tp = self._cached_nv_tp + hv = self.head_v_dim -+ if self._norm_weight_fp16 is None: -+ # .clone().contiguous() to own private storage — the XPU -+ # cache allocator has been observed reusing a shared fp16 -+ # half() result's storage for other buffers mid-inference, -+ # corrupting the norm weight and producing NaN cascades. -+ self._norm_weight_fp16 = ( -+ self.norm.weight.data.half().clone().contiguous() -+ ) ++ if self._disable_norm_cache: ++ # 27B+TP4: recompute every forward to dodge XPU allocator bug. ++ norm_w_fp16 = self.norm.weight.data.half().clone().contiguous() ++ else: ++ if self._norm_weight_fp16 is None: ++ # .clone().contiguous() to own private storage — the XPU ++ # cache allocator has been observed reusing a shared fp16 ++ # half() result's storage for other buffers mid-inference, ++ # corrupting the norm weight and producing NaN cascades. 
++ self._norm_weight_fp16 = ( ++ self.norm.weight.data.half().clone().contiguous() ++ ) ++ norm_w_fp16 = self._norm_weight_fp16 + if self._is_sym_int4: + from custom_esimd_kernels_vllm import esimd_norm_gemv_int4_pert + # .t() gives contiguous (N, K/8) layout for block_load + esimd_norm_gemv_int4_pert( + core_attn_out.view(nv_tp, hv), + z_out.view(nv_tp, hv), -+ self._norm_weight_fp16, -+ self.out_proj.weight.t(), -+ self.out_proj.weight_scale.t(), ++ norm_w_fp16, ++ self.out_proj.weight_esimd.view(torch.int32), ++ self.out_proj.scale_esimd, + self._decode_outproj_buf, + nv_tp, hv, self.norm.eps, + ) @@ -10769,7 +11246,7 @@ index 000000000..f20b837b7 + esimd_norm_gemv_fp8_pert( + core_attn_out.view(nv_tp, hv), + z_out.view(nv_tp, hv), -+ self._norm_weight_fp16, ++ norm_w_fp16, + self.out_proj.weight, + self.out_proj.weight_scale, + self._decode_outproj_buf, @@ -10778,7 +11255,7 @@ index 000000000..f20b837b7 + output[:1] = tensor_model_parallel_all_reduce( + self._decode_outproj_buf) + elif is_decode: -+ # Decode non-FP8 fallback ++ # Decode non-FP8/sym_int4 fallback + nv_tp = self._cached_nv_tp + hv = self.head_v_dim + core_attn_out_2d = core_attn_out.view(nv_tp, hv) @@ -10786,9 +11263,44 @@ index 000000000..f20b837b7 + core_attn_out_2d = self.norm(core_attn_out_2d, z_out_2d) + core_attn_out_flat = core_attn_out_2d.view(1, nv_tp * hv) + output[:1], _ = self.out_proj(core_attn_out_flat) ++ elif num_tokens <= 64 and self._is_sym_int4: ++ # BSZ=2-64: ESIMD RMSNormGated (single kernel, quantization-agnostic) ++ from custom_esimd_kernels_vllm import esimd_rms_norm_gated ++ from vllm.distributed import tensor_model_parallel_all_reduce ++ if self._disable_norm_cache: ++ # 27B+TP4: recompute every forward to dodge XPU allocator bug. 
++ norm_w_fp16 = self.norm.weight.data.half().clone().contiguous() ++ else: ++ if self._norm_weight_fp16 is None: ++ self._norm_weight_fp16 = ( ++ self.norm.weight.data.half().clone().contiguous() ++ ) ++ norm_w_fp16 = self._norm_weight_fp16 ++ x_flat = core_attn_out.reshape(-1, core_attn_out.shape[-1]) ++ z_flat = z_out.reshape(-1, z_out.shape[-1]) ++ normed = torch.empty_like(x_flat) ++ esimd_rms_norm_gated(x_flat, z_flat, norm_w_fp16, normed, self.norm.eps) ++ core_attn_out = normed.reshape(num_tokens, -1) ++ if self._is_sym_int4: ++ from custom_esimd_kernels_vllm import esimd_gemm_int4_pgrp ++ out_buf = self._m_outproj[:num_tokens] ++ esimd_gemm_int4_pgrp( ++ core_attn_out, self.out_proj.weight_esimd, ++ self.out_proj.scale_esimd, out_buf) ++ output[:num_tokens] = tensor_model_parallel_all_reduce(out_buf) ++ elif self._use_fused_out_proj: ++ from custom_esimd_kernels_vllm import esimd_gemm_fp8_pert ++ out_buf = self._m_outproj[:num_tokens] ++ esimd_gemm_fp8_pert( ++ core_attn_out, self.out_proj.weight, ++ self.out_proj.weight_scale, out_buf) ++ output[:num_tokens] = tensor_model_parallel_all_reduce(out_buf) ++ else: ++ core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], -1) ++ output[:num_tokens], _ = self.out_proj(core_attn_out) + else: ++ # BSZ>64: fallback to PyTorch RMSNormGated + z_shape_og = z_out.shape -+ # Reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z_out = z_out.reshape(-1, z_out.shape[-1]) + core_attn_out = self.norm(core_attn_out, z_out) @@ -10832,64 +11344,103 @@ index 000000000..f20b837b7 + projected_states_ba: torch.Tensor, + output: torch.Tensor, + ): -+ """Decode-only fast path: input projection already done by caller. 
-+ Uses sequential [q|k|v|z] layout with esimd_gdn_conv_fused_seq.""" -+ from vllm.distributed import tensor_model_parallel_all_reduce ++ """Full version matching forward_xpu decode path exactly.""" + -+ # Part 2: Core Attention ++ # ---- Part 2: Core Attention (copied from forward_xpu decode) ---- + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata ++ + core_attn_out = self._decode_attn_out_buf + core_attn_out.zero_() + z_out = self._decode_z_out_buf ++ + if attn_metadata is not None: + attn_metadata = attn_metadata[self.prefix] ++ + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + conv_state = self_kv_cache[0] + ssm_state = self_kv_cache[1] ++ + if self._cached_conv_weights is None: + self._cached_conv_weights = self.conv1d.weight.view( -+ self.conv1d.weight.size(0), self.conv1d.weight.size(2)) ++ self.conv1d.weight.size(0), self.conv1d.weight.size(2) ++ ) ++ + N_dec = attn_metadata.num_decodes + state_idx = attn_metadata.non_spec_state_indices_tensor[:N_dec] ++ ++ + esimd_gdn_conv_fused_seq( -+ projected_states_qkvz, conv_state, -+ self._cached_conv_weights, self.conv_bias_zeros, state_idx, -+ self.A_log, self.dt_bias, projected_states_ba, -+ ssm_state, state_idx, core_attn_out, z_out, -+ N_dec, self._cached_nk_tp, self._cached_nv_tp, -+ self.head_k_dim, self.head_v_dim, self._cached_attn_scale, ++ projected_states_qkvz, ++ conv_state, ++ self._cached_conv_weights, ++ self.conv_bias_zeros, ++ state_idx, ++ self.A_log, ++ self.dt_bias, ++ projected_states_ba, ++ ssm_state, ++ state_idx, ++ core_attn_out, ++ z_out, ++ N_dec, ++ self._cached_nk_tp, ++ self._cached_nv_tp, ++ self.head_k_dim, ++ self.head_v_dim, ++ self._cached_attn_scale, + ) + -+ # Part 3: Output Projection (fused norm + GEMV) ++ # ---- Part 3: Output Projection (copied from forward_xpu decode) ---- + nv_tp = self._cached_nv_tp + hv = self.head_v_dim -+ if self._norm_weight_fp16 is None: -+ # .clone().contiguous() to own private storage — the XPU 
-+ # cache allocator has been observed reusing a shared fp16 -+ # half() result's storage for other buffers mid-inference, -+ # corrupting the norm weight and producing NaN cascades. -+ self._norm_weight_fp16 = ( -+ self.norm.weight.data.half().clone().contiguous() -+ ) -+ if self._is_sym_int4: -+ from custom_esimd_kernels_vllm import esimd_norm_gemv_int4_pert -+ esimd_norm_gemv_int4_pert( -+ core_attn_out.view(nv_tp, hv), z_out.view(nv_tp, hv), -+ self._norm_weight_fp16, -+ self.out_proj.weight.t(), self.out_proj.weight_scale.t(), -+ self._decode_outproj_buf, nv_tp, hv, self.norm.eps, -+ ) ++ if self._use_fused_out_proj: ++ from vllm.distributed import tensor_model_parallel_all_reduce ++ if self._disable_norm_cache: ++ # 27B+TP4: recompute every forward to dodge XPU allocator bug. ++ norm_w_fp16 = self.norm.weight.data.half().clone().contiguous() ++ else: ++ if self._norm_weight_fp16 is None: ++ # .clone().contiguous() to own private storage — the XPU ++ # cache allocator has been observed reusing a shared fp16 ++ # half() result's storage for other buffers mid-inference, ++ # corrupting the norm weight and producing NaN cascades. 
++ self._norm_weight_fp16 = ( ++ self.norm.weight.data.half().clone().contiguous() ++ ) ++ norm_w_fp16 = self._norm_weight_fp16 ++ if self._is_sym_int4: ++ from custom_esimd_kernels_vllm import esimd_norm_gemv_int4_pert ++ esimd_norm_gemv_int4_pert( ++ core_attn_out.view(nv_tp, hv), ++ z_out.view(nv_tp, hv), ++ norm_w_fp16, ++ self.out_proj.weight_esimd.view(torch.int32), ++ self.out_proj.scale_esimd, ++ self._decode_outproj_buf, ++ nv_tp, hv, self.norm.eps, ++ ) ++ else: ++ from custom_esimd_kernels_vllm import esimd_norm_gemv_fp8_pert ++ esimd_norm_gemv_fp8_pert( ++ core_attn_out.view(nv_tp, hv), ++ z_out.view(nv_tp, hv), ++ norm_w_fp16, ++ self.out_proj.weight, ++ self.out_proj.weight_scale, ++ self._decode_outproj_buf, ++ nv_tp, hv, self.norm.eps, ++ ) ++ output[:1] = tensor_model_parallel_all_reduce( ++ self._decode_outproj_buf) + else: -+ from custom_esimd_kernels_vllm import esimd_norm_gemv_fp8_pert -+ esimd_norm_gemv_fp8_pert( -+ core_attn_out.view(nv_tp, hv), z_out.view(nv_tp, hv), -+ self._norm_weight_fp16, -+ self.out_proj.weight, self.out_proj.weight_scale, -+ self._decode_outproj_buf, nv_tp, hv, self.norm.eps, -+ ) -+ output[:1] = tensor_model_parallel_all_reduce( -+ self._decode_outproj_buf) ++ from vllm.distributed import tensor_model_parallel_all_reduce ++ core_attn_out_2d = core_attn_out.view(nv_tp, hv) ++ z_out_2d = z_out.view(nv_tp, hv) ++ core_attn_out_2d = self.norm(core_attn_out_2d, z_out_2d) ++ core_attn_out_flat = core_attn_out_2d.view(1, nv_tp * hv) ++ output[:], _ = self.out_proj(core_attn_out_flat) ++ + + def forward_cuda( + self, @@ -11035,7 +11586,20 @@ index 000000000..f20b837b7 + ), + ) + -+ # Detect dense MLP with FP8 quantization for ESIMD fast path ++ # Disable lazy fp16 norm-weight caching only for 27B on TP=4 to work ++ # around an XPU allocator bug that corrupts cached storage at certain ++ # max_model_len values. Other configs keep the cache to avoid the ++ # per-forward .half().clone() overhead. 
++ self._disable_norm_cache = ( ++ config.hidden_size == 5120 ++ and get_tensor_model_parallel_world_size() == 4 ++ ) ++ ++ # Detect dense MLP with FP8 quantization for ESIMD fast path. ++ # INT4 dense MLP does not currently have a working fast path (the ++ # fused resadd+norm+GEMV kernel is numerically broken for the 27B ++ # shape), so INT4 falls back to self.mlp(...) → IPEX qlinear. ++ # This matches the pre-ec3f74a31 behavior for Qwen3.5-27B INT4. + _quant_name = quant_config.get_name() if quant_config is not None else "" + self._dense_mlp_fp8 = ( + isinstance(self.mlp, Qwen3NextMLP) @@ -11043,7 +11607,7 @@ index 000000000..f20b837b7 + and _quant_name == "fp8" + and os.environ.get("DISABLE_ESIMD_DENSE", "0") != "1" + ) -+ self._dense_mlp_is_int4 = (_quant_name == "sym_int4") if self._dense_mlp_fp8 else False ++ self._dense_mlp_is_int4 = False + self._max_bsz = 0 # default when no FP8 dense MLP path; overwritten below + if self._dense_mlp_fp8: + _dev = current_platform.current_device() @@ -11067,14 +11631,14 @@ index 000000000..f20b837b7 + # Lazily cached after weights are loaded + self._dense_post_norm_w_fp16 = None + -+ # Detect MoE with FP8 for ESIMD fused norm+router path -+ self._moe_fp8 = ( ++ # Detect MoE with FP8/INT4 for ESIMD fused norm+router path ++ self._moe_esimd_enabled = ( + isinstance(self.mlp, Qwen3NextSparseMoeBlock) -+ and hasattr(self.mlp, 'use_fp8') -+ and self.mlp.use_fp8 ++ and hasattr(self.mlp, 'gate') ++ and _quant_name in ("fp8", "sym_int4") + ) -+ self._moe_is_int4 = (_quant_name == "sym_int4") if self._moe_fp8 else False -+ if self._moe_fp8: ++ self._moe_is_int4 = (_quant_name == "sym_int4") if self._moe_esimd_enabled else False ++ if self._moe_esimd_enabled: + _dev = current_platform.current_device() + n_exp = self.mlp.gate.weight.shape[0] + self._post_norm_w_fp16 = None # lazily cached @@ -11083,13 +11647,12 @@ index 000000000..f20b837b7 + self._normed_buf = torch.empty( + 1, config.hidden_size, dtype=torch.float16, device=_dev) + 
-+ # Detect fused input_norm + input_proj opportunity (FP8, decode) -+ _is_fp8 = (quant_config is not None -+ and quant_config.get_name() == "fp8") ++ # Detect fused input_norm + input_proj opportunity (FP8/INT4, decode) + self._fused_input_norm = ( -+ _is_fp8 and not self.layer_scale ++ _quant_name in ("fp8", "sym_int4") and not self.layer_scale + and os.environ.get("DISABLE_ESIMD_FUSED_INPUT", "0") != "1" + ) ++ self._is_fused_int4 = (_quant_name == "sym_int4") if self._fused_input_norm else False + if self._fused_input_norm: + _dev = current_platform.current_device() + _hidden = config.hidden_size @@ -12974,7 +13537,7 @@ index f2f354604..ad88271b9 100644 if weight_name not in name: continue diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py -index 0f73a7746..69176f66a 100644 +index 0f73a7746..9b7608cbb 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -2,10 +2,22 @@ @@ -12994,7 +13557,7 @@ index 0f73a7746..69176f66a 100644 +from custom_esimd_kernels_vllm import esimd_resadd_norm_gemv_fp8_pert +from custom_esimd_kernels_vllm import esimd_resadd_norm_gemv2_fp8_pert +from custom_esimd_kernels_vllm import esimd_qkv_split_norm_rope -+from custom_esimd_kernels_vllm import esimd_gdn_conv_fused ++from custom_esimd_kernels_vllm import esimd_gdn_conv_fused, esimd_gdn_conv_fused_seq +from custom_esimd_kernels_vllm import esimd_rms_norm_gated +from custom_esimd_kernels_vllm import esimd_fused_add_rms_norm_batched from einops import rearrange @@ -13160,7 +13723,7 @@ index 0f73a7746..69176f66a 100644 if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) -@@ -366,6 +490,80 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): +@@ -366,6 +490,88 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") 
compilation_config.static_forward_context[prefix] = self @@ -13199,8 +13762,16 @@ index 0f73a7746..69176f66a 100644 + ) + # Cache fp16 norm weight (norm.weight is float32, kernel needs fp16) + self._norm_weight_fp16 = None # lazily cached after weights loaded ++ # Disable lazy fp16 norm-weight caching only for 27B on TP=4 to work ++ # around an XPU allocator bug that corrupts cached storage at certain ++ # max_model_len values. Other configs (incl. Qwen3-Coder-Next) keep ++ # the cache to avoid the per-forward .half().clone() overhead. ++ self._disable_norm_cache = ( ++ config.hidden_size == 5120 and self.tp_size == 4 ++ ) + # FP8 flag for ESIMD GEMM path + self._is_fp8 = quant_config is not None and quant_config.get_name() == "fp8" ++ self._is_sym_int4 = quant_config is not None and quant_config.get_name() == "sym_int4" + # BSZ>1 pre-allocated buffers (separate from BSZ=1 to avoid any interference) + _mb = int(os.environ.get("MAX_DECODE_BSZ", "64")) + self._max_bsz = _mb @@ -13241,7 +13812,7 @@ index 0f73a7746..69176f66a 100644 def fix_query_key_value_ordering( self, -@@ -441,6 +639,387 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): +@@ -441,6 +647,419 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): self, hidden_states: torch.Tensor, output: torch.Tensor, @@ -13415,18 +13986,23 @@ index 0f73a7746..69176f66a 100644 + # Decode fast path: fused RMSNormGated + out_proj GEMV + nv_tp = self._cached_nv_tp + hv = self.head_v_dim -+ if self._norm_weight_fp16 is None: -+ # .clone().contiguous() to own private storage — the XPU -+ # cache allocator has been observed reusing a shared fp16 -+ # half() result's storage for other buffers mid-inference, -+ # corrupting the norm weight and producing NaN cascades. -+ self._norm_weight_fp16 = ( -+ self.norm.weight.data.half().clone().contiguous() -+ ) ++ if self._disable_norm_cache: ++ # 27B+TP4: recompute every forward to dodge XPU allocator bug. 
++ norm_w_fp16 = self.norm.weight.data.half().clone().contiguous() ++ else: ++ if self._norm_weight_fp16 is None: ++ # .clone().contiguous() to own private storage — the XPU ++ # cache allocator has been observed reusing a shared fp16 ++ # half() result's storage for other buffers mid-inference, ++ # corrupting the norm weight and producing NaN cascades. ++ self._norm_weight_fp16 = ( ++ self.norm.weight.data.half().clone().contiguous() ++ ) ++ norm_w_fp16 = self._norm_weight_fp16 + esimd_norm_gemv_fp8_pert( + self._decode_attn_out_view, + self._decode_z_view, -+ self._norm_weight_fp16, ++ norm_w_fp16, + self.out_proj.weight, + self.out_proj.weight_scale, + self._decode_outproj_buf, @@ -13437,18 +14013,23 @@ index 0f73a7746..69176f66a 100644 + output[:1] = tensor_model_parallel_all_reduce(self._decode_outproj_buf) + elif num_tokens <= 64 and self._is_fp8: + # BSZ=2-64: ESIMD norm then ESIMD GEMM for out_proj -+ if self._norm_weight_fp16 is None: -+ # .clone().contiguous() to own private storage — the XPU -+ # cache allocator has been observed reusing a shared fp16 -+ # half() result's storage for other buffers mid-inference, -+ # corrupting the norm weight and producing NaN cascades. -+ self._norm_weight_fp16 = ( -+ self.norm.weight.data.half().clone().contiguous() -+ ) ++ if self._disable_norm_cache: ++ # 27B+TP4: recompute every forward to dodge XPU allocator bug. ++ norm_w_fp16 = self.norm.weight.data.half().clone().contiguous() ++ else: ++ if self._norm_weight_fp16 is None: ++ # .clone().contiguous() to own private storage — the XPU ++ # cache allocator has been observed reusing a shared fp16 ++ # half() result's storage for other buffers mid-inference, ++ # corrupting the norm weight and producing NaN cascades. 
++ self._norm_weight_fp16 = ( ++ self.norm.weight.data.half().clone().contiguous() ++ ) ++ norm_w_fp16 = self._norm_weight_fp16 + x_flat = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z_flat = z.reshape(-1, z.shape[-1]) + normed = torch.empty_like(x_flat) -+ esimd_rms_norm_gated(x_flat, z_flat, self._norm_weight_fp16, ++ esimd_rms_norm_gated(x_flat, z_flat, norm_w_fp16, + normed, self.norm.eps) + core_attn_out = normed.reshape(num_tokens, -1) + if num_tokens <= self._max_bsz: @@ -13520,25 +14101,42 @@ index 0f73a7746..69176f66a 100644 + # Part 3: Output Projection (fused norm + GEMV) + nv_tp = self._cached_nv_tp + hv = self.head_v_dim -+ if self._norm_weight_fp16 is None: -+ # .clone().contiguous() to own private storage — the XPU -+ # cache allocator has been observed reusing a shared fp16 -+ # half() result's storage for other buffers mid-inference, -+ # corrupting the norm weight and producing NaN cascades. -+ self._norm_weight_fp16 = ( -+ self.norm.weight.data.half().clone().contiguous() ++ if self._disable_norm_cache: ++ # 27B+TP4: recompute every forward to dodge XPU allocator bug. ++ norm_w_fp16 = self.norm.weight.data.half().clone().contiguous() ++ else: ++ if self._norm_weight_fp16 is None: ++ # .clone().contiguous() to own private storage — the XPU ++ # cache allocator has been observed reusing a shared fp16 ++ # half() result's storage for other buffers mid-inference, ++ # corrupting the norm weight and producing NaN cascades. 
++ self._norm_weight_fp16 = ( ++ self.norm.weight.data.half().clone().contiguous() ++ ) ++ norm_w_fp16 = self._norm_weight_fp16 ++ if getattr(self, '_is_sym_int4', False): ++ from custom_esimd_kernels_vllm import esimd_norm_gemv_int4_pert ++ esimd_norm_gemv_int4_pert( ++ core_attn_out.view(nv_tp, hv), ++ z.view(nv_tp, hv), ++ norm_w_fp16, ++ self.out_proj.weight_esimd.view(torch.int32), ++ self.out_proj.scale_esimd, ++ self._decode_outproj_buf, ++ nv_tp, hv, self.norm.eps, ++ ) ++ else: ++ esimd_norm_gemv_fp8_pert( ++ core_attn_out.view(nv_tp, hv), ++ z.view(nv_tp, hv), ++ norm_w_fp16, ++ self.out_proj.weight, ++ self.out_proj.weight_scale, ++ self._decode_outproj_buf, ++ nv_tp, ++ hv, ++ self.norm.eps, + ) -+ esimd_norm_gemv_fp8_pert( -+ core_attn_out.view(nv_tp, hv), -+ z.view(nv_tp, hv), -+ self._norm_weight_fp16, -+ self.out_proj.weight, -+ self.out_proj.weight_scale, -+ self._decode_outproj_buf, -+ nv_tp, -+ hv, -+ self.norm.eps, -+ ) + output[:1] = tensor_model_parallel_all_reduce(self._decode_outproj_buf) + + def forward_xpu_batched_precomputed_proj( @@ -13597,18 +14195,23 @@ index 0f73a7746..69176f66a 100644 + ) + + # Part 3: Output Projection — ESIMD norm + ESIMD GEMM -+ if self._norm_weight_fp16 is None: -+ # .clone().contiguous() to own private storage — the XPU -+ # cache allocator has been observed reusing a shared fp16 -+ # half() result's storage for other buffers mid-inference, -+ # corrupting the norm weight and producing NaN cascades. -+ self._norm_weight_fp16 = ( -+ self.norm.weight.data.half().clone().contiguous() -+ ) ++ if self._disable_norm_cache: ++ # 27B+TP4: recompute every forward to dodge XPU allocator bug. 
++ norm_w_fp16 = self.norm.weight.data.half().clone().contiguous() ++ else: ++ if self._norm_weight_fp16 is None: ++ # .clone().contiguous() to own private storage — the XPU ++ # cache allocator has been observed reusing a shared fp16 ++ # half() result's storage for other buffers mid-inference, ++ # corrupting the norm weight and producing NaN cascades. ++ self._norm_weight_fp16 = ( ++ self.norm.weight.data.half().clone().contiguous() ++ ) ++ norm_w_fp16 = self._norm_weight_fp16 + x_flat = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z_flat = z.reshape(-1, z.shape[-1]) + normed = torch.empty_like(x_flat) -+ esimd_rms_norm_gated(x_flat, z_flat, self._norm_weight_fp16, ++ esimd_rms_norm_gated(x_flat, z_flat, norm_w_fp16, + normed, self.norm.eps) + core_attn_out = normed.reshape(num_tokens, -1) + if num_tokens <= self._max_bsz: @@ -13629,7 +14232,7 @@ index 0f73a7746..69176f66a 100644 ): """ Forward pass with three parts: -@@ -778,42 +1357,345 @@ class Qwen3NextAttention(nn.Module): +@@ -778,42 +1397,360 @@ class Qwen3NextAttention(nn.Module): self.q_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Qwen3NextRMSNorm(self.head_dim, eps=config.rms_norm_eps) @@ -13903,12 +14506,21 @@ index 0f73a7746..69176f66a 100644 - output[:], _ = self.o_proj(attn_output) + # o_proj GEMV + all_reduce -+ esimd_gemv_fp8_pert( -+ attn_output, -+ self.o_proj.weight, -+ self.o_proj.weight_scale, -+ self._decode_o_buf, -+ ) ++ if self._is_sym_int4: ++ from custom_esimd_kernels_vllm import esimd_gemv_int4 ++ esimd_gemv_int4( ++ attn_output, ++ self.o_proj.weight_esimd, ++ self.o_proj.scale_esimd, ++ self._decode_o_buf, ++ ) ++ else: ++ esimd_gemv_fp8_pert( ++ attn_output, ++ self.o_proj.weight, ++ self.o_proj.weight_scale, ++ self._decode_o_buf, ++ ) + output[:1] = tensor_model_parallel_all_reduce(self._decode_o_buf) + + def forward_batched_precomputed_qkv( @@ -13987,17 +14599,32 @@ index 0f73a7746..69176f66a 100644 + dtype=torch.float16, + 
device=qkv.device, + ) -+ esimd_gemm_fp8_pert( -+ attn_output, self.o_proj.weight, self.o_proj.weight_scale, o_out -+ ) ++ if self._is_sym_int4: ++ from custom_esimd_kernels_vllm import esimd_gemm_int4_pgrp ++ esimd_gemm_int4_pgrp( ++ attn_output, self.o_proj.weight_esimd, ++ self.o_proj.scale_esimd, o_out) ++ else: ++ esimd_gemm_fp8_pert( ++ attn_output, self.o_proj.weight, self.o_proj.weight_scale, o_out ++ ) + output[:_ntoks] = tensor_model_parallel_all_reduce(o_out) class Qwen3NextDecoderLayer(nn.Module): -@@ -900,6 +1782,107 @@ class Qwen3NextDecoderLayer(nn.Module): +@@ -900,6 +1837,116 @@ class Qwen3NextDecoderLayer(nn.Module): ), ) ++ # Disable lazy fp16 norm-weight caching only for 27B on TP=4 to work ++ # around an XPU allocator bug that corrupts cached storage at certain ++ # max_model_len values. Other configs (incl. Qwen3-Coder-Next) keep ++ # the cache to avoid the per-forward .half().clone() overhead. ++ self._disable_norm_cache = ( ++ config.hidden_size == 5120 ++ and get_tensor_model_parallel_world_size() == 4 ++ ) ++ + # Detect dense MLP with FP8/INT4 quantization for ESIMD fast path + _quant_name = quant_config.get_name() if quant_config is not None else "" + self._dense_mlp_fp8 = ( @@ -14033,13 +14660,13 @@ index 0f73a7746..69176f66a 100644 + self._dense_post_norm_w_fp16 = None + + # Detect MoE with FP8/INT4 for ESIMD fused norm+router path -+ self._moe_fp8 = ( ++ self._moe_esimd_enabled = ( + isinstance(self.mlp, Qwen3NextSparseMoeBlock) + and hasattr(self.mlp, 'gate') + and _quant_name in ("fp8", "sym_int4") + ) -+ self._moe_is_int4 = (_quant_name == "sym_int4") if self._moe_fp8 else False -+ if self._moe_fp8: ++ self._moe_is_int4 = (_quant_name == "sym_int4") if self._moe_esimd_enabled else False ++ if self._moe_esimd_enabled: + _dev = current_platform.current_device() + n_exp = self.mlp.gate.weight.shape[0] + self._post_norm_w_fp16 = None # lazily cached @@ -14102,7 +14729,7 @@ index 0f73a7746..69176f66a 100644 def forward( self, 
hidden_states: torch.Tensor, -@@ -907,27 +1890,114 @@ class Qwen3NextDecoderLayer(nn.Module): +@@ -907,27 +1954,161 @@ class Qwen3NextDecoderLayer(nn.Module): positions: torch.Tensor = None, **kwargs: object, ): @@ -14112,41 +14739,54 @@ index 0f73a7746..69176f66a 100644 - else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + n_tok = hidden_states.shape[0] - -- self_attention_output = torch.empty_like(hidden_states) -- if self.layer_type == "linear_attention": -- self.linear_attn( -- hidden_states=hidden_states, -- output=self_attention_output, -- ) -- elif self.layer_type == "full_attention": -- self.self_attn( -- hidden_states=hidden_states, -- output=self_attention_output, -- positions=positions, -- ) ++ + # ---- Fused input_norm + input_proj (BSZ=1: GEMV, BSZ>1: GEMM) ---- ++ _int4_fused = getattr(self, '_is_fused_int4', False) + if ( + self._fused_input_norm + and residual is not None + and n_tok <= self._input_max_bsz + ): -+ if self._input_norm_w_fp16 is None: -+ # See _norm_weight_fp16 comment — private storage is required -+ # or the XPU cache allocator may hand this block to another -+ # tensor mid-inference and corrupt the cached norm weight. -+ self._input_norm_w_fp16 = ( ++ if self._disable_norm_cache: ++ # 27B+TP4: recompute every forward to dodge XPU allocator bug. ++ input_norm_w_fp16 = ( + (self.input_layernorm.weight.data + 1.0) + .half().clone().contiguous() + ) ++ else: ++ if self._input_norm_w_fp16 is None: ++ # See _norm_weight_fp16 comment — private storage is required ++ # or the XPU cache allocator may hand this block to another ++ # tensor mid-inference and corrupt the cached norm weight. 
++ self._input_norm_w_fp16 = ( ++ (self.input_layernorm.weight.data + 1.0) ++ .half().clone().contiguous() ++ ) ++ input_norm_w_fp16 = self._input_norm_w_fp16 + esimd_fused_add_rms_norm_batched( -+ hidden_states, residual, self._input_norm_w_fp16, ++ hidden_states, residual, input_norm_w_fp16, + self.input_layernorm.variance_epsilon) + + if self.layer_type == "linear_attention": + gdn = self.linear_attn -+ if n_tok == 1: -+ # BSZ=1: fused 2-GEMV ++ if n_tok == 1 and _int4_fused: ++ # BSZ=1, INT4: fused input_proj + precomputed Part2+3 ++ from custom_esimd_kernels_vllm import esimd_gemv_int4_fused2 ++ esimd_gemv_int4_fused2( ++ hidden_states, ++ gdn.in_proj_qkvz.weight_esimd, ++ gdn.in_proj_qkvz.scale_esimd, ++ self._fused_qkvz_buf, ++ gdn.in_proj_ba.weight_esimd, ++ gdn.in_proj_ba.scale_esimd, ++ self._fused_ba_buf, ++ ) ++ self_attention_output = self._m_attn_output[:1] ++ gdn.forward_xpu_with_precomputed_proj( ++ self._fused_qkvz_buf, self._fused_ba_buf, self_attention_output ++ ) ++ elif n_tok == 1: ++ # BSZ=1, FP8: fused 2-GEMV + esimd_gemv_fp8_pert_fused2( + hidden_states, + gdn.in_proj_qkvz.weight, @@ -14176,11 +14816,36 @@ index 0f73a7746..69176f66a 100644 + hidden_states=hidden_states, + output=self_attention_output, + ) -+ + +- self_attention_output = torch.empty_like(hidden_states) +- if self.layer_type == "linear_attention": +- self.linear_attn( +- hidden_states=hidden_states, +- output=self_attention_output, +- ) +- elif self.layer_type == "full_attention": +- self.self_attn( +- hidden_states=hidden_states, +- output=self_attention_output, +- positions=positions, +- ) + elif self.layer_type == "full_attention": + attn = self.self_attn -+ if n_tok == 1: -+ # BSZ=1: GEMV ++ if n_tok == 1 and _int4_fused: ++ # BSZ=1, INT4: GEMV (INT4 kernel) ++ from custom_esimd_kernels_vllm import esimd_gemv_int4 ++ esimd_gemv_int4( ++ hidden_states, ++ attn.qkv_proj.weight_esimd, ++ attn.qkv_proj.scale_esimd, ++ self._fused_qkv_buf, ++ ) ++ self_attention_output = 
self._m_attn_output[:1] ++ attn.forward_with_precomputed_qkv( ++ self._fused_qkv_buf, positions, self_attention_output ++ ) ++ elif n_tok == 1: ++ # BSZ=1, FP8: GEMV + esimd_gemv_fp8_pert( + hidden_states, + attn.qkv_proj.weight, @@ -14194,12 +14859,21 @@ index 0f73a7746..69176f66a 100644 + else: + # BSZ>1: GEMM + qkv_buf = self._m_fused_qkv[:n_tok] -+ esimd_gemm_fp8_pert( -+ hidden_states, -+ attn.qkv_proj.weight, -+ attn.qkv_proj.weight_scale, -+ qkv_buf, -+ ) ++ if _int4_fused: ++ from custom_esimd_kernels_vllm import esimd_gemm_int4_pgrp ++ esimd_gemm_int4_pgrp( ++ hidden_states, ++ attn.qkv_proj.weight_esimd, ++ attn.qkv_proj.scale_esimd, ++ qkv_buf, ++ ) ++ else: ++ esimd_gemm_fp8_pert( ++ hidden_states, ++ attn.qkv_proj.weight, ++ attn.qkv_proj.weight_scale, ++ qkv_buf, ++ ) + self_attention_output = self._m_attn_output[:n_tok] + attn.forward_batched_precomputed_qkv( + qkv_buf, positions, self_attention_output @@ -14236,7 +14910,7 @@ index 0f73a7746..69176f66a 100644 if self.layer_scale: if len(hidden_states.shape) == 2: -@@ -939,9 +2009,114 @@ class Qwen3NextDecoderLayer(nn.Module): +@@ -939,9 +2120,139 @@ class Qwen3NextDecoderLayer(nn.Module): self.attn_layer_scale.to(hidden_states.dtype) + 1 ) @@ -14247,16 +14921,24 @@ index 0f73a7746..69176f66a 100644 + + # ---- Dense MLP ESIMD fast path (FP8 decode / small batch) ---- + if self._dense_mlp_fp8 and not self.layer_scale and n_tokens <= 128: -+ if self._dense_post_norm_w_fp16 is None: -+ # See _norm_weight_fp16 comment. Observed on 27B FP8 layer.46: -+ # without clone+contiguous, this cache's storage is reused -+ # for another buffer after some decode steps, giving garbage -+ # RMSNorm weight (e.g. max_abs=58560 mean=-548), which then -+ # blows normed to +/-Inf and cascades NaN to logits ('!!!!'). -+ self._dense_post_norm_w_fp16 = ( ++ if self._disable_norm_cache: ++ # 27B+TP4: recompute every forward to dodge XPU allocator bug. 
++ dense_post_norm_w_fp16 = ( + (self.post_attention_layernorm.weight.data + 1.0) + .half().clone().contiguous() + ) ++ else: ++ if self._dense_post_norm_w_fp16 is None: ++ # See _norm_weight_fp16 comment. Observed on 27B FP8 layer.46: ++ # without clone+contiguous, this cache's storage is reused ++ # for another buffer after some decode steps, giving garbage ++ # RMSNorm weight (e.g. max_abs=58560 mean=-548), which then ++ # blows normed to +/-Inf and cascades NaN to logits ('!!!!'). ++ self._dense_post_norm_w_fp16 = ( ++ (self.post_attention_layernorm.weight.data + 1.0) ++ .half().clone().contiguous() ++ ) ++ dense_post_norm_w_fp16 = self._dense_post_norm_w_fp16 + + # ESIMD norm + 2x ESIMD GEMM (M=1..128). + # The FP8 GEMV kernels only reach ~23% HBM bandwidth at M=1 on @@ -14264,7 +14946,7 @@ index 0f73a7746..69176f66a 100644 + # through GEMM as well. Slicing _m_gate_up/_m_down with [:1] + # is valid since _max_bsz >= 1. + esimd_fused_add_rms_norm_batched( -+ hidden_states, residual, self._dense_post_norm_w_fp16, ++ hidden_states, residual, dense_post_norm_w_fp16, + self.post_attention_layernorm.variance_epsilon + ) + if n_tokens <= self._max_bsz: @@ -14301,21 +14983,30 @@ index 0f73a7746..69176f66a 100644 + hidden_states = tensor_model_parallel_all_reduce(down_out) + + # ---- MoE fused post_attn_norm + router for decode (bsz=1) ---- -+ elif n_tokens == 1 and self._moe_fp8 and not self.layer_scale: -+ if self._post_norm_w_fp16 is None: -+ # See _dense_post_norm_w_fp16 comment. -+ self._post_norm_w_fp16 = ( ++ elif n_tokens == 1 and self._moe_esimd_enabled and not self.layer_scale: ++ if self._disable_norm_cache: ++ # 27B+TP4: recompute every forward to dodge XPU allocator bug. ++ post_norm_w_fp16 = ( + (self.post_attention_layernorm.weight.data + 1.0) + .half().clone().contiguous() + ) ++ else: ++ if self._post_norm_w_fp16 is None: ++ # See _dense_post_norm_w_fp16 comment. 
++ self._post_norm_w_fp16 = ( ++ (self.post_attention_layernorm.weight.data + 1.0) ++ .half().clone().contiguous() ++ ) ++ post_norm_w_fp16 = self._post_norm_w_fp16 + if self._moe_is_int4: + from custom_esimd_kernels_vllm import esimd_resadd_norm_gemv_int4_pert -+ # .t() gives contiguous (N, K/8) layout for block_load ++ # Use weight_esimd/scale_esimd to avoid sharing storage with ++ # IPEX-transposed weight (mirrors commit f1584b171 in qwen3_5). + esimd_resadd_norm_gemv_int4_pert( + hidden_states, residual, -+ self._post_norm_w_fp16, -+ self.mlp.gate.weight.t(), -+ self.mlp.gate.weight_scale.t(), ++ post_norm_w_fp16, ++ self.mlp.gate.weight_esimd.view(torch.int32), ++ self.mlp.gate.scale_esimd, + self._router_buf, + self._normed_buf, + self.post_attention_layernorm.variance_epsilon, @@ -14323,7 +15014,7 @@ index 0f73a7746..69176f66a 100644 + else: + esimd_resadd_norm_gemv_fp8_pert( + hidden_states, residual, -+ self._post_norm_w_fp16, ++ post_norm_w_fp16, + self.mlp.gate.weight, + self.mlp.gate.weight_scale, + self._router_buf, @@ -14337,15 +15028,23 @@ index 0f73a7746..69176f66a 100644 + + # ---- Standard path (MoE BSZ>1 or fallback) ---- + else: -+ if self._moe_fp8 and n_tokens > 1: -+ if self._post_norm_w_fp16 is None: -+ # See _dense_post_norm_w_fp16 comment. -+ self._post_norm_w_fp16 = ( ++ if self._moe_esimd_enabled and n_tokens > 1: ++ if self._disable_norm_cache: ++ # 27B+TP4: recompute every forward to dodge XPU allocator bug. ++ post_norm_w_fp16 = ( + (self.post_attention_layernorm.weight.data + 1.0) + .half().clone().contiguous() + ) ++ else: ++ if self._post_norm_w_fp16 is None: ++ # See _dense_post_norm_w_fp16 comment. 
++ self._post_norm_w_fp16 = ( ++ (self.post_attention_layernorm.weight.data + 1.0) ++ .half().clone().contiguous() ++ ) ++ post_norm_w_fp16 = self._post_norm_w_fp16 + esimd_fused_add_rms_norm_batched( -+ hidden_states, residual, self._post_norm_w_fp16, ++ hidden_states, residual, post_norm_w_fp16, + self.post_attention_layernorm.variance_epsilon) + else: + hidden_states, residual = self.post_attention_layernorm( @@ -14354,7 +15053,7 @@ index 0f73a7746..69176f66a 100644 if self.layer_scale: if len(hidden_states.shape) == 2: -@@ -965,7 +2140,7 @@ class Qwen3NextModel(nn.Module): +@@ -965,7 +2276,7 @@ class Qwen3NextModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -14363,7 +15062,7 @@ index 0f73a7746..69176f66a 100644 parallel_config = vllm_config.parallel_config eplb_config = parallel_config.eplb_config -@@ -1042,7 +2217,7 @@ class Qwen3NextModel(nn.Module): +@@ -1042,7 +2353,7 @@ class Qwen3NextModel(nn.Module): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", @@ -14372,7 +15071,7 @@ index 0f73a7746..69176f66a 100644 num_redundant_experts=self.num_redundant_experts, ) -@@ -1201,15 +2376,17 @@ class Qwen3NextForCausalLM( +@@ -1201,15 +2512,17 @@ class Qwen3NextForCausalLM( } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -14394,7 +15093,7 @@ index 0f73a7746..69176f66a 100644 self.quant_config = vllm_config.quant_config super().__init__() -@@ -1263,7 +2440,7 @@ class Qwen3NextForCausalLM( +@@ -1263,7 +2576,7 @@ class Qwen3NextForCausalLM( cls, vllm_config: "VllmConfig" ) -> tuple[tuple[int, int], tuple[int, int]]: parallel_config = vllm_config.parallel_config @@ -14403,7 +15102,7 @@ index 0f73a7746..69176f66a 100644 tp_size = parallel_config.tensor_parallel_size num_spec = ( vllm_config.speculative_config.num_speculative_tokens -@@ -1280,6 +2457,10 @@ class Qwen3NextForCausalLM( +@@ -1280,6 +2593,10 @@ class Qwen3NextForCausalLM( num_spec, ) @@ 
-15984,10 +16683,106 @@ index 0fd3d6eb3..dc75478df 100644 def get_kv_cache_stride_order( include_num_layers_dimension: bool = False, diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py -index 130d85efb..2f5eb177c 100644 +index 130d85efb..e5dc2aa9b 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py -@@ -92,12 +92,18 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: +@@ -1,6 +1,8 @@ + # SPDX-License-Identifier: Apache-2.0 + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + ++import os ++ + from vllm.logger import init_logger + from vllm.platforms import current_platform + +@@ -17,9 +19,85 @@ elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops + + reshape_and_cache_flash = ipex_ops.reshape_and_cache_flash +- flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func + get_scheduler_metadata = ipex_ops.get_scheduler_metadata + ++ # Supported values: "cutlass" (default), "xetla" (ipex fallback) ++ _XPU_FLASH_BACKEND = os.environ.get( ++ "VLLM_XPU_FLASH_ATTN_BACKEND", "cutlass").lower() ++ ++ if _XPU_FLASH_BACKEND == "cutlass": ++ try: ++ from vllm_xpu_kernels import ( ++ flash_attn_varlen_func as _cutlass_flash_attn_varlen_func, ++ ) ++ logger.info( ++ "Using cutlass flash attention backend for XPU (TTFT).") ++ except ImportError as e: ++ logger.warning( ++ "VLLM_XPU_FLASH_ATTN_BACKEND=cutlass but " ++ "vllm_xpu_kernels not available: %s. 
" ++ "Falling back to xetla (ipex).", e) ++ _XPU_FLASH_BACKEND = "xetla" ++ ++ if _XPU_FLASH_BACKEND == "cutlass": ++ ++ def flash_attn_varlen_func( ++ q, ++ k, ++ v, ++ cu_seqlens_q, ++ max_seqlen_q, ++ max_seqlen_k, ++ softmax_scale=None, ++ causal=False, ++ out=None, ++ block_table=None, ++ alibi_slopes=None, ++ window_size=None, ++ softcap=0.0, ++ seqused_k=None, ++ cu_seqlens_k=None, ++ dropout_p=0.0, ++ scheduler_metadata=None, ++ fa_version=2, ++ q_descale=None, ++ k_descale=None, ++ v_descale=None, ++ num_splits=0, ++ s_aux=None, ++ return_softmax_lse=False, ++ ): ++ result = _cutlass_flash_attn_varlen_func( ++ q=q, ++ k=k, ++ v=v, ++ max_seqlen_q=max_seqlen_q, ++ cu_seqlens_q=cu_seqlens_q, ++ max_seqlen_k=max_seqlen_k, ++ cu_seqlens_k=cu_seqlens_k, ++ seqused_k=seqused_k, ++ softmax_scale=softmax_scale, ++ causal=causal, ++ window_size=window_size, ++ softcap=softcap if softcap else 0.0, ++ alibi_slopes=alibi_slopes, ++ block_table=block_table, ++ return_softmax_lse=return_softmax_lse, ++ out=out, ++ k_descale=k_descale, ++ v_descale=v_descale, ++ s_aux=s_aux, ++ num_splits=num_splits, ++ ) ++ if return_softmax_lse: ++ return result ++ if isinstance(result, tuple): ++ return result[0] ++ return result ++ ++ else: ++ flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func ++ + elif current_platform.is_rocm(): + try: + from flash_attn import flash_attn_varlen_func # noqa: F401 +@@ -92,12 +170,18 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None: def flash_attn_supports_fp8() -> bool: