From b0da78eb67164d9a6e3ecb1b2d9a267edd21ebef Mon Sep 17 00:00:00 2001 From: Ur Rahman Date: Tue, 14 Apr 2026 02:25:05 -0700 Subject: [PATCH 01/11] Add Qwen3.5-MoE (35B-A3B) model support - Added Qwen35MoeTextModel builder class for Qwen3_5MoeForConditionalGeneration architecture with 256 experts, shared expert, and SwiGLU activation - Registered qwen3_5_moe as VLM type in C++ runtime (model_type.h, model.cpp) - Added architecture dispatch in builder.py for Qwen3_5MoeForConditionalGeneration - Key implementation details: - Repacks HF concatenated gate_up_proj to ORT interleaved format (swiglu_fusion=1) - Shared expert implemented as separate SiLU MLP path with sigmoid gating - Router uses bias-free MatMul matching Qwen3_5MoeTopKRouter - QMoE symmetric blockwise quantization without explicit zero_points - Also includes existing gemma.py rope_local_base_freq fix for TranslateGemma --- src/models/model.cpp | 1 + src/models/model_type.h | 4 +- src/python/py/models/builder.py | 3 + src/python/py/models/builders/__init__.py | 3 +- src/python/py/models/builders/qwen.py | 218 ++++++++++++++++++++++ 5 files changed, 226 insertions(+), 3 deletions(-) diff --git a/src/models/model.cpp b/src/models/model.cpp index 86b2e478d..9da524675 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -951,6 +951,7 @@ MultiModalProcessor::MultiModalProcessor(Config& config, const SessionInfo& sess {"qwen2_5_vl", Processor::Create}, {"qwen3_vl", Processor::Create}, {"qwen3_5", Processor::Create}, + {"qwen3_5_moe", Processor::Create}, {"videochat_flash_qwen", Processor::Create}} { auto processor = processor_factory_.find(config.model.type); if (processor != processor_factory_.end()) { diff --git a/src/models/model_type.h b/src/models/model_type.h index dc822fe00..d887fe459 100644 --- a/src/models/model_type.h +++ b/src/models/model_type.h @@ -21,13 +21,13 @@ struct ModelType { inline static bool IsVLM(const std::string& model_type) { // Vision-language model (VLM) - static constexpr 
std::array VLM = {"fara", "gemma3", "mistral3", "phi3v", "qwen2_5_vl", "qwen3_vl", "qwen3_5", "videochat_flash_qwen"}; + static constexpr std::array VLM = {"fara", "gemma3", "mistral3", "phi3v", "qwen2_5_vl", "qwen3_vl", "qwen3_5", "qwen3_5_moe", "videochat_flash_qwen"}; return std::find(VLM.begin(), VLM.end(), model_type) != VLM.end(); } inline static bool IsQwenVLFamily(const std::string& model_type) { // Qwen-VL family: models requiring 3D mRoPE position IDs - return model_type == "fara" || model_type == "qwen2_5_vl" || model_type == "qwen3_vl" || model_type == "qwen3_5"; + return model_type == "fara" || model_type == "qwen2_5_vl" || model_type == "qwen3_vl" || model_type == "qwen3_5" || model_type == "qwen3_5_moe"; } inline static bool IsPixtralFamily(const std::string& model_type) { diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 88cd14756..2aab41b83 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -46,6 +46,7 @@ Qwen3VLTextModel, Qwen25VLTextModel, Qwen35TextModel, + Qwen35MoeTextModel, QwenModel, SmolLM3Model, VideoChatFlashQwenModel, @@ -310,6 +311,8 @@ def create_model( onnx_model = Qwen3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "Qwen3_5ForConditionalGeneration": onnx_model = Qwen35TextModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) + elif config.architectures[0] == "Qwen3_5MoeForConditionalGeneration": + onnx_model = Qwen35MoeTextModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "Qwen3VLForConditionalGeneration": text_config = config.text_config for key in text_config: diff --git a/src/python/py/models/builders/__init__.py b/src/python/py/models/builders/__init__.py index 1e6330b31..ae2cf33ed 100644 --- a/src/python/py/models/builders/__init__.py +++ b/src/python/py/models/builders/__init__.py @@ -29,7 +29,7 @@ 
Phi4MMModel, PhiModel, ) -from .qwen import Qwen3Model, Qwen25VLTextModel, Qwen3VLTextModel, Qwen35TextModel, QwenModel, VideoChatFlashQwenModel +from .qwen import Qwen3Model, Qwen25VLTextModel, Qwen3VLTextModel, Qwen35TextModel, Qwen35MoeTextModel, QwenModel, VideoChatFlashQwenModel from .smollm import SmolLM3Model from .whisper import WhisperModel @@ -62,6 +62,7 @@ "Qwen3VLTextModel", "Qwen25VLTextModel", "Qwen35TextModel", + "Qwen35MoeTextModel", "QwenModel", "SmolLM3Model", "VideoChatFlashQwenModel", diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 37499c01a..939626d9f 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -2067,3 +2067,221 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): del self.input_names["past_key_values.value"] del self.output_names["present.key"] del self.output_names["present.value"] + + +class Qwen35MoeTextModel(Qwen35TextModel): + """Qwen3.5 MoE hybrid model builder. + + Extends ``Qwen35TextModel`` with Mixture-of-Experts MLP layers. + Each decoder layer replaces the dense MLP with: + - A router that selects top-k experts from ``num_experts`` candidates + - Packed routed expert weights (gate_up_proj + down_proj) + - A shared expert (always-active) with its own gating signal + + The attention side (GatedDeltaNet linear + gated full) is inherited + unchanged from the parent class. + """ + + def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): + # Map Qwen3.5-MoE config attributes to what the base class expects. 
+ if hasattr(config, "text_config"): + tc = config.text_config + # Base class reads num_local_experts; MoE config uses num_experts + if hasattr(tc, "num_experts") and not hasattr(tc, "num_local_experts"): + tc.num_local_experts = tc.num_experts + # Base class reads intermediate_size; MoE has moe_intermediate_size + if not hasattr(tc, "intermediate_size") and hasattr(tc, "moe_intermediate_size"): + tc.intermediate_size = tc.moe_intermediate_size + + super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) + + # MoE attributes specific to Qwen3.5-MoE + self.moe_attrs["activation_type"] = "swiglu" + self.moe_attrs["swiglu_fusion"] = 1 + self.moe_attrs["normalize_routing_weights"] = True + + self.moe_intermediate_size = getattr(config, "moe_intermediate_size", 512) + self.shared_expert_intermediate_size = getattr(config, "shared_expert_intermediate_size", self.moe_intermediate_size) + + # Override the INT8 node list for linear-attention layers: MoE layers + # use the MoE/QMoE op for the MLP, not individual MatMul nodes. + # Remove the MLP MatMul INT8 overrides that the parent class added + # since those node names don't exist in MoE layers. 
+ algo_config = self.quant_attrs["int4"].get("algo_config") + if algo_config is not None and hasattr(algo_config, "customized_weight_config"): + keys_to_remove = [k for k in algo_config.customized_weight_config if "/mlp/" in k] + for k in keys_to_remove: + del algo_config.customized_weight_config[k] + + def make_layer(self, layer_id, layer): + """Override to use MoE instead of dense MLP.""" + attn_module = layer.linear_attn if self.layer_types[layer_id] == "linear_attention" else layer.self_attn + self.make_layernorm( + layer_id, + layer.input_layernorm, + skip=not self.layernorm_attrs["first_layernorm"], + simple=self.layernorm_attrs["simple"], + location="input", + ) + self.make_attention(layer_id, attn_module, root_input=self.layernorm_attrs["output_0"]) + self.make_layernorm( + layer_id, + layer.post_attention_layernorm, + skip=True, + simple=self.layernorm_attrs["simple"], + location="post_attention", + ) + self.make_moe(layer_id, layer.mlp, root_input=self.layernorm_attrs["output_0"]) + + self.layernorm_attrs["first_layernorm"] = False + if layer_id == self.num_layers - 1: + self.layernorm_attrs["last_layernorm"] = True + + def make_moe(self, layer_id, mlp, root_input): + """Build the MoE + shared expert subgraph for one decoder layer. + + Structure: + root_input ──┬── router(gate) ──┐ + │ v + ├──────────> MoE/QMoE (routed experts) + │ │ + ├── shared_expert ──┤ + │ │ │ + │ shared_gate ────┤ + │ v + └──────────> Add (moe_out + gated_shared_out) + """ + if self.ep in {"cpu", "cuda", "trt-rtx", "webgpu"}: + self._make_moe_fused(layer_id, mlp, root_input) + else: + raise NotImplementedError( + f"MoE export for EP '{self.ep}' is not yet supported for Qwen3.5-MoE. " + f"Supported EPs: cpu, cuda, trt-rtx, webgpu." 
+ ) + + def _make_moe_fused(self, layer_id, mlp, root_input): + basename = f"/model/layers.{layer_id}/moe" + op_type = self.moe_attrs["op_type"] + moe_weight_type = f"{'q' if op_type == 'QMoE' else ''}weight" + + # --- Router (gate) --- + # Qwen3.5-MoE uses mlp.gate (no bias), unlike GPTOSS mlp.router (with bias) + router_basename = f"{basename}/router/MatMul" + router_matmul_name = self.make_matmul(mlp.gate, router_basename, root_input) + router_reshape_name = f"{basename}/router/Reshape" + router_reshape_inputs = [ + f"{router_matmul_name}/output_0", + f"/model/constants/INT64/{[-1, self.moe_attrs['num_experts']]}", + ] + self.make_reshape( + router_reshape_name, + router_reshape_inputs, + dtype=self.io_dtype, + shape=["batch_size * sequence_length", self.moe_attrs["num_experts"]], + ) + + # --- Routed expert weights --- + gate_up_proj_weight = f"model.layers.{layer_id}.moe.experts.gate_up_proj.{moe_weight_type}" + gate_up_proj_scales = f"model.layers.{layer_id}.moe.experts.gate_up_proj.scales" + gate_up_proj_bias = f"model.layers.{layer_id}.moe.experts.gate_up_proj.bias" + down_proj_weight = f"model.layers.{layer_id}.moe.experts.down_proj.{moe_weight_type}" + down_proj_scales = f"model.layers.{layer_id}.moe.experts.down_proj.scales" + down_proj_bias = f"model.layers.{layer_id}.moe.experts.down_proj.bias" + + # Repack gate_up_proj from HF concatenated [gate|up] to ORT interleaved + # format [g0,u0,g1,u1,...] required by swiglu_fusion=1. 
+ # HF layout: [E, 2*inter, hidden] where rows 0..inter-1 = gate, inter..2*inter-1 = up + # ORT layout: [E, 2*inter, hidden] where row 2i = gate[i], row 2i+1 = up[i] + raw_gate_up = mlp.experts.gate_up_proj # [E, 2*inter, hidden] + half = raw_gate_up.shape[1] // 2 + gate_part = raw_gate_up[:, :half, :] # [E, inter, hidden] + up_part = raw_gate_up[:, half:, :] # [E, inter, hidden] + interleaved_gate_up = torch.stack([gate_part, up_part], dim=2).reshape_as(raw_gate_up) + # down_proj stays as-is: [E, hidden, inter] + + if op_type == "MoE": + self.make_initializer(interleaved_gate_up, gate_up_proj_weight, to=self.io_dtype) + self.make_initializer(mlp.experts.down_proj, down_proj_weight, to=self.io_dtype) + else: + # QMoE: quantize each expert + gate_up_qw_list, gate_up_sc_list = [], [] + down_qw_list, down_sc_list = [], [] + + for i in range(self.moe_attrs["num_experts"]): + qw1, sc1 = self.make_qmoe_weights(interleaved_gate_up[i]) + gate_up_qw_list.append(qw1) + gate_up_sc_list.append(sc1) + qw2, sc2 = self.make_qmoe_weights(mlp.experts.down_proj[i]) + down_qw_list.append(qw2) + down_sc_list.append(sc2) + + gate_up_qw = torch.stack(gate_up_qw_list, dim=0).to(torch.uint8) + gate_up_sc = torch.stack(gate_up_sc_list, dim=0) + down_qw = torch.stack(down_qw_list, dim=0).to(torch.uint8) + down_sc = torch.stack(down_sc_list, dim=0) + + self.make_initializer(gate_up_qw, gate_up_proj_weight) + self.make_initializer(down_qw, down_proj_weight) + self.make_initializer(gate_up_sc, gate_up_proj_scales, to=self.io_dtype) + self.make_initializer(down_sc, down_proj_scales, to=self.io_dtype) + + # Zero biases for routed experts + num_e = self.moe_attrs["num_experts"] + self.make_initializer( + torch.zeros(num_e, 2 * self.moe_intermediate_size), + gate_up_proj_bias, to=self.io_dtype, + ) + self.make_initializer( + torch.zeros(num_e, self.hidden_size), + down_proj_bias, to=self.io_dtype, + ) + + # --- MoE op (routed experts only) --- + moe_name = f"{basename}/{op_type}" + 
self.make_moe_op( + moe_name, + root_input=root_input, + router_probs=f"{router_reshape_name}/output_0", + weight1=gate_up_proj_weight, + scales1=gate_up_proj_scales if op_type == "QMoE" else "", + bias1=gate_up_proj_bias, + weight2=down_proj_weight, + scales2=down_proj_scales if op_type == "QMoE" else "", + bias2=down_proj_bias, + ) + + # --- Shared expert --- + shared_basename = f"/model/layers.{layer_id}/shared_expert" + + gate_matmul = self.make_matmul(mlp.shared_expert.gate_proj, f"{shared_basename}/gate_proj/MatMul", root_input) + up_matmul = self.make_matmul(mlp.shared_expert.up_proj, f"{shared_basename}/up_proj/MatMul", root_input) + + # SiLU(gate) * up + silu_name = f"{shared_basename}/gate_proj/Sigmoid" + self.make_node("Sigmoid", inputs=[f"{gate_matmul}/output_0"], outputs=[f"{silu_name}/output_0"], name=silu_name) + self.make_value(f"{silu_name}/output_0", self.io_dtype, shape=["batch_size", "sequence_length", self.shared_expert_intermediate_size]) + + mul_gate_name = f"{shared_basename}/gate_proj/Mul" + self.make_node("Mul", inputs=[f"{gate_matmul}/output_0", f"{silu_name}/output_0"], outputs=[f"{mul_gate_name}/output_0"], name=mul_gate_name) + self.make_value(f"{mul_gate_name}/output_0", self.io_dtype, shape=["batch_size", "sequence_length", self.shared_expert_intermediate_size]) + + mul_up_name = f"{shared_basename}/Mul" + self.make_node("Mul", inputs=[f"{mul_gate_name}/output_0", f"{up_matmul}/output_0"], outputs=[f"{mul_up_name}/output_0"], name=mul_up_name) + self.make_value(f"{mul_up_name}/output_0", self.io_dtype, shape=["batch_size", "sequence_length", self.shared_expert_intermediate_size]) + + down_matmul = self.make_matmul(mlp.shared_expert.down_proj, f"{shared_basename}/down_proj/MatMul", f"{mul_up_name}/output_0") + + shared_gate_matmul = self.make_matmul(mlp.shared_expert_gate, f"{shared_basename}_gate/MatMul", root_input) + shared_gate_sigmoid = f"{shared_basename}_gate/Sigmoid" + self.make_node("Sigmoid", 
inputs=[f"{shared_gate_matmul}/output_0"], outputs=[f"{shared_gate_sigmoid}/output_0"], name=shared_gate_sigmoid) + self.make_value(f"{shared_gate_sigmoid}/output_0", self.io_dtype, shape=["batch_size", "sequence_length", 1]) + + gated_shared_name = f"{shared_basename}/GatedMul" + self.make_node("Mul", inputs=[f"{down_matmul}/output_0", f"{shared_gate_sigmoid}/output_0"], outputs=[f"{gated_shared_name}/output_0"], name=gated_shared_name) + self.make_value(f"{gated_shared_name}/output_0", self.io_dtype, shape=["batch_size", "sequence_length", self.hidden_size]) + + combine_name = f"{basename}/Add" + self.make_node("Add", inputs=[f"{moe_name}/output_0", f"{gated_shared_name}/output_0"], outputs=[f"{combine_name}/output_0"], name=combine_name) + self.make_value(f"{combine_name}/output_0", self.io_dtype, shape=["batch_size", "sequence_length", self.hidden_size]) + + self.layernorm_attrs["skip_input"] = f"{combine_name}/output_0" From 47b9425f26fa1cfaac0ae1c836a24f9b8cd67d08 Mon Sep 17 00:00:00 2001 From: "Ur Rahman, Tanzeel" Date: Wed, 22 Apr 2026 14:56:27 +0530 Subject: [PATCH 02/11] Updated License --- src/models/model.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models/model.cpp b/src/models/model.cpp index 9da524675..203d01a84 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. // -// Modifications Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +// Modifications Copyright (C) 2024-2026 Advanced Micro Devices, Inc. All rights reserved. // Portions of this file consist of AI generated content. 
#include #include From d4ccaa912d4c796e07d4fa7dcbc8b275ac9722db Mon Sep 17 00:00:00 2001 From: "Ur Rahman, Tanzeel" Date: Wed, 22 Apr 2026 15:01:25 +0530 Subject: [PATCH 03/11] Updated License --- src/models/model_type.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/models/model_type.h b/src/models/model_type.h index d887fe459..d019f54ce 100644 --- a/src/models/model_type.h +++ b/src/models/model_type.h @@ -2,6 +2,10 @@ // Licensed under the MIT License. // Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved // -------------------------------------------------------------------------- +// ------------------------------------------------------------------------- +// Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +// Portions of this file consist of AI generated content. +// -------------------------------------------------------------------------- #pragma once From 27295ed2158f101b11117a55d60ca3acba7dcf45 Mon Sep 17 00:00:00 2001 From: "Ur Rahman, Tanzeel" Date: Wed, 22 Apr 2026 15:06:05 +0530 Subject: [PATCH 04/11] Update Licence --- src/python/py/models/builders/qwen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 939626d9f..6d0c787f4 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -7,7 +7,6 @@ # Portions of this file consist of AI generated content. import os - import numpy as np import onnx_ir as ir import torch From 7294448232e76d35562d83acc190a469b69c82e1 Mon Sep 17 00:00:00 2001 From: "Ur Rahman, Tanzeel" Date: Wed, 22 Apr 2026 17:28:57 +0530 Subject: [PATCH 05/11] Update Licence --- src/models/model_type.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/models/model_type.h b/src/models/model_type.h index d019f54ce..c6e4973fe 100644 --- a/src/models/model_type.h +++ b/src/models/model_type.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved // -------------------------------------------------------------------------- +// Modifications Copyright(C) 2025 Advanced Micro Devices, Inc. All rights reserved. // ------------------------------------------------------------------------- // Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. // Portions of this file consist of AI generated content. From 23a5462cc252654b5f4e2a2c06130c79406f5fac Mon Sep 17 00:00:00 2001 From: "Jain, Vishal" Date: Wed, 22 Apr 2026 18:55:40 +0530 Subject: [PATCH 06/11] Update Copyright notice for modified file --- src/python/py/models/builders/qwen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 6d0c787f4..032e94e1f 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +# ------------------------------------------------------ # Modifications Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. # Portions of this file consist of AI generated content. From 5dde7dd0bed9cf586836d886a02273234849d93b Mon Sep 17 00:00:00 2001 From: "Jain, Vishal" Date: Wed, 22 Apr 2026 18:58:00 +0530 Subject: [PATCH 07/11] Update Copyright --- src/models/model_type.h | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/models/model_type.h b/src/models/model_type.h index c6e4973fe..a19ec535d 100644 --- a/src/models/model_type.h +++ b/src/models/model_type.h @@ -1,10 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. 
All rights reserved // -------------------------------------------------------------------------- -// Modifications Copyright(C) 2025 Advanced Micro Devices, Inc. All rights reserved. -// ------------------------------------------------------------------------- -// Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +// Modifications Copyright (C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. // Portions of this file consist of AI generated content. // -------------------------------------------------------------------------- From dcd2e33e5cc63a1a8808940694275f5592cc2f2e Mon Sep 17 00:00:00 2001 From: tanzeel-amd Date: Tue, 19 May 2026 01:12:08 -0700 Subject: [PATCH 08/11] Fix Qwen35MoeTextModel to emit correct model type in genai_config.json --- src/python/py/models/builders/qwen.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 032e94e1f..f8b7bfe66 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -2112,6 +2112,23 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): for k in keys_to_remove: del algo_config.customized_weight_config[k] + def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): + """Override to emit ``model.type = "qwen3_5_moe"`` in genai_config.json. + + The parent class hardcodes ``Qwen3_5ForConditionalGeneration`` which + lowercases to ``qwen3_5``, but the C++ runtime VLM/QwenVLFamily + registrations require ``qwen3_5_moe``. 
+ """ + super().make_genai_config(model_name_or_path, extra_kwargs, out_dir) + + import json + from pathlib import Path + + config_path = Path(out_dir) / "genai_config.json" + config = json.loads(config_path.read_text()) + config["model"]["type"] = "qwen3_5_moe" + config_path.write_text(json.dumps(config, indent=4)) + def make_layer(self, layer_id, layer): """Override to use MoE instead of dense MLP.""" attn_module = layer.linear_attn if self.layer_types[layer_id] == "linear_attention" else layer.self_attn From 610bb03cf67a9dd035fda84cc87338c8e230dbad Mon Sep 17 00:00:00 2001 From: tanzeel-amd Date: Tue, 19 May 2026 23:51:22 -0700 Subject: [PATCH 09/11] Address review: consolidate copyright, use std::find for IsQwenVLFamily, set model_type in __init__ - model_type.h: Merge duplicate copyright lines into 2025-2026 range - model_type.h: Rewrite IsQwenVLFamily to use std::array + std::find consistent with other methods - qwen.py: Set model_type in __init__ for both Qwen35TextModel and Qwen35MoeTextModel instead of hardcoding in make_genai_config. Removes the make_genai_config override entirely. 
Co-authored-by: Cursor --- src/models/model_type.h | 3 ++- src/python/py/models/builders/qwen.py | 22 ++++------------------ 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/src/models/model_type.h b/src/models/model_type.h index a19ec535d..52e1bc361 100644 --- a/src/models/model_type.h +++ b/src/models/model_type.h @@ -29,7 +29,8 @@ struct ModelType { inline static bool IsQwenVLFamily(const std::string& model_type) { // Qwen-VL family: models requiring 3D mRoPE position IDs - return model_type == "fara" || model_type == "qwen2_5_vl" || model_type == "qwen3_vl" || model_type == "qwen3_5" || model_type == "qwen3_5_moe"; + static constexpr std::array QwenVL = {"fara", "qwen2_5_vl", "qwen3_vl", "qwen3_5", "qwen3_5_moe"}; + return std::find(QwenVL.begin(), QwenVL.end(), model_type) != QwenVL.end(); } inline static bool IsPixtralFamily(const std::string& model_type) { diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index f8b7bfe66..98cac1c15 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -1002,6 +1002,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) + self.model_type = "Qwen3_5_textForCausalLM" if self.is_text_only else "Qwen3_5ForConditionalGeneration" + # OffsetRMSNorm: Qwen3.5 uses (1 + weight) * RMSNorm(x). # Pre-bake the +1 into the weight initializer so the base class's # SkipSimplifiedLayerNormalization can be used directly. 
@@ -2051,7 +2053,6 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): "model_type": self.model_type, } self.num_layers = len(self.layer_types) - self.model_type = "Qwen3_5_textForCausalLM" if self.is_text_only else "Qwen3_5ForConditionalGeneration" self.input_names["past_key_values.key"] = "past_key_values.%d.key" self.input_names["past_key_values.value"] = "past_key_values.%d.value" self.output_names["present.key"] = "present.%d.key" @@ -2094,6 +2095,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) + self.model_type = "Qwen3_5_MoeForConditionalGeneration" + # MoE attributes specific to Qwen3.5-MoE self.moe_attrs["activation_type"] = "swiglu" self.moe_attrs["swiglu_fusion"] = 1 @@ -2112,23 +2115,6 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): for k in keys_to_remove: del algo_config.customized_weight_config[k] - def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): - """Override to emit ``model.type = "qwen3_5_moe"`` in genai_config.json. - - The parent class hardcodes ``Qwen3_5ForConditionalGeneration`` which - lowercases to ``qwen3_5``, but the C++ runtime VLM/QwenVLFamily - registrations require ``qwen3_5_moe``. 
- """ - super().make_genai_config(model_name_or_path, extra_kwargs, out_dir) - - import json - from pathlib import Path - - config_path = Path(out_dir) / "genai_config.json" - config = json.loads(config_path.read_text()) - config["model"]["type"] = "qwen3_5_moe" - config_path.write_text(json.dumps(config, indent=4)) - def make_layer(self, layer_id, layer): """Override to use MoE instead of dense MLP.""" attn_module = layer.linear_attn if self.layer_types[layer_id] == "linear_attention" else layer.self_attn From 0979ac2e14f56bab2c0ca72f8bdfa1e1c8f6a375 Mon Sep 17 00:00:00 2001 From: tanzeel-amd Date: Wed, 20 May 2026 03:29:21 -0700 Subject: [PATCH 10/11] Refactor MoE into base class: make_fused_moe, make_shared_expert, generic int4 config - base.py: Add make_fused_moe() supporting router with/without bias, 2-weight SwiGLU layout with interleaving, and optional shared expert. Add make_shared_expert() using wrapper methods (make_sigmoid, make_mul, etc.). Move MoE /mlp/ int4 config cleanup into make_int4_algo_config(). - qwen.py: Remove _make_moe_fused (~150 lines) and make_moe dispatcher. Replace with single make_fused_moe() call from base class. Remove int4 algo cleanup from __init__ (now in base). 
Co-authored-by: Cursor --- src/python/py/models/builders/base.py | 179 ++++++++++++++++++++++++++ src/python/py/models/builders/qwen.py | 167 +----------------------- 2 files changed, 186 insertions(+), 160 deletions(-) diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 061862f15..ed75d2fc6 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -764,6 +764,14 @@ def make_int4_algo_config(self, quant_method: str): customized_weight_config["/lm_head/MatMul"] = {"bits": 8} int4_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config) + # MoE layers use MoE/QMoE ops instead of individual MatMul nodes, + # so remove any /mlp/ MatMul overrides that don't exist in MoE models. + if int4_algo_config is not None and self.moe_attrs.get("num_experts", 0) > 0: + if hasattr(int4_algo_config, "customized_weight_config"): + keys_to_remove = [k for k in int4_algo_config.customized_weight_config if "/mlp/" in k] + for k in keys_to_remove: + del int4_algo_config.customized_weight_config[k] + return int4_algo_config def to_int4(self) -> ir.Model: @@ -3728,6 +3736,177 @@ def _symmetric_blockwise_quantize(self, weights, block_size): return qweight.cpu(), scales.cpu() + def make_fused_moe(self, layer_id, mlp, root_input, shared_expert=None, shared_expert_gate=None): + """Build a fused MoE subgraph (router + routed experts + optional shared expert). + + Supports the packed 2-weight expert layout (gate_up_proj/down_proj, interleaved + for SwiGLU when ``swiglu_fusion == 1``), routers with or without bias, and an optional + shared expert path with sigmoid gating. + + Args: + layer_id: Decoder layer index. + mlp: MoE module containing the router and expert weights. + Router: ``mlp.gate`` if present, otherwise ``mlp.router``; a bias Add is emitted only when ``bias`` is not None. + Experts: ``mlp.experts.gate_up_proj``/``mlp.experts.down_proj`` (packed 2-weight layout); + per-expert ``w1``/``w2``/``w3`` (3-weight) modules are not handled by this method.
+ root_input: The ONNX input tensor name feeding this MoE block. + shared_expert: Optional shared expert module (e.g. ``mlp.shared_expert``). + shared_expert_gate: Optional shared expert gate module (e.g. ``mlp.shared_expert_gate``). + """ + basename = f"/model/layers.{layer_id}/moe" + op_type = self.moe_attrs["op_type"] + moe_weight_type = f"{'q' if op_type == 'QMoE' else ''}weight" + + # --- Router --- + router_module = mlp.gate if hasattr(mlp, "gate") else mlp.router + has_router_bias = hasattr(router_module, "bias") and router_module.bias is not None + + router_basename = f"{basename}/router/MatMul" + router_matmul_name = self.make_matmul(router_module, router_basename, root_input) + + if has_router_bias: + router_add_name = f"{basename}/router/Add" + self.make_add_bias(router_module.bias, router_add_name, root_input=f"{router_matmul_name}/output_0") + router_out = f"{router_add_name}/output_0" + else: + router_out = f"{router_matmul_name}/output_0" + + router_reshape_name = f"{basename}/router/Reshape" + self.make_reshape( + router_reshape_name, + [router_out, f"/model/constants/INT64/{[-1, self.moe_attrs['num_experts']]}"], + dtype=self.io_dtype, + shape=["batch_size * sequence_length", self.moe_attrs["num_experts"]], + ) + + # --- Expert weights --- + gate_up_proj_weight = f"model.layers.{layer_id}.moe.experts.gate_up_proj.{moe_weight_type}" + gate_up_proj_scales = f"model.layers.{layer_id}.moe.experts.gate_up_proj.scales" + gate_up_proj_bias = f"model.layers.{layer_id}.moe.experts.gate_up_proj.bias" + down_proj_weight = f"model.layers.{layer_id}.moe.experts.down_proj.{moe_weight_type}" + down_proj_scales = f"model.layers.{layer_id}.moe.experts.down_proj.scales" + down_proj_bias = f"model.layers.{layer_id}.moe.experts.down_proj.bias" + + raw_gate_up = mlp.experts.gate_up_proj + down_proj = mlp.experts.down_proj + + # SwiGLU interleaving: repack [gate|up] to [g0,u0,g1,u1,...] 
when enabled + if self.moe_attrs.get("swiglu_fusion", 0) == 1: + half = raw_gate_up.shape[1] // 2 + gate_part = raw_gate_up[:, :half, :] + up_part = raw_gate_up[:, half:, :] + raw_gate_up = torch.stack([gate_part, up_part], dim=2).reshape_as(mlp.experts.gate_up_proj) + + if op_type == "MoE": + self.make_initializer(raw_gate_up, gate_up_proj_weight, to=self.io_dtype) + self.make_initializer(down_proj, down_proj_weight, to=self.io_dtype) + else: + gate_up_qw_list, gate_up_sc_list = [], [] + down_qw_list, down_sc_list = [], [] + + for i in range(self.moe_attrs["num_experts"]): + qw1, sc1 = self.make_qmoe_weights(raw_gate_up[i]) + gate_up_qw_list.append(qw1) + gate_up_sc_list.append(sc1) + qw2, sc2 = self.make_qmoe_weights(down_proj[i]) + down_qw_list.append(qw2) + down_sc_list.append(sc2) + + self.make_initializer(torch.stack(gate_up_qw_list, dim=0).to(torch.uint8), gate_up_proj_weight) + self.make_initializer(torch.stack(down_qw_list, dim=0).to(torch.uint8), down_proj_weight) + self.make_initializer(torch.stack(gate_up_sc_list, dim=0), gate_up_proj_scales, to=self.io_dtype) + self.make_initializer(torch.stack(down_sc_list, dim=0), down_proj_scales, to=self.io_dtype) + + # Biases + num_e = self.moe_attrs["num_experts"] + if hasattr(mlp.experts, "gate_up_proj_bias"): + self.make_initializer(mlp.experts.gate_up_proj_bias, gate_up_proj_bias, to=self.io_dtype) + self.make_initializer(mlp.experts.down_proj_bias, down_proj_bias, to=self.io_dtype) + else: + moe_inter = raw_gate_up.shape[1] + self.make_initializer(torch.zeros(num_e, moe_inter), gate_up_proj_bias, to=self.io_dtype) + self.make_initializer(torch.zeros(num_e, self.hidden_size), down_proj_bias, to=self.io_dtype) + + # --- MoE/QMoE op --- + moe_name = f"{basename}/{op_type}" + self.make_moe_op( + moe_name, + root_input=root_input, + router_probs=f"{router_reshape_name}/output_0", + weight1=gate_up_proj_weight, + scales1=gate_up_proj_scales if op_type == "QMoE" else "", + bias1=gate_up_proj_bias, + 
weight2=down_proj_weight, + scales2=down_proj_scales if op_type == "QMoE" else "", + bias2=down_proj_bias, + ) + + moe_output = f"{moe_name}/output_0" + + # --- Optional shared expert --- + if shared_expert is not None: + shared_output = self.make_shared_expert( + layer_id, shared_expert, shared_expert_gate, root_input, + ) + combine_name = f"{basename}/Add" + self.make_add( + combine_name, + [moe_output, shared_output], + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) + moe_output = f"{combine_name}/output_0" + + self.layernorm_attrs["skip_input"] = moe_output + + def make_shared_expert(self, layer_id, shared_expert, shared_expert_gate, root_input): + """Build a shared expert SiLU-MLP with optional sigmoid gating. + + Structure: gate_proj -> SiLU -> Mul(gate, up) -> down_proj -> sigmoid_gate -> Mul + """ + basename = f"/model/layers.{layer_id}/shared_expert" + shared_inter = getattr(self, "shared_expert_intermediate_size", self.intermediate_size) + + gate_matmul = self.make_matmul(shared_expert.gate_proj, f"{basename}/gate_proj/MatMul", root_input) + up_matmul = self.make_matmul(shared_expert.up_proj, f"{basename}/up_proj/MatMul", root_input) + + # SiLU(gate) = gate * sigmoid(gate) + silu_sigmoid_name = f"{basename}/gate_proj/Sigmoid" + self.make_sigmoid(silu_sigmoid_name, f"{gate_matmul}/output_0", self.io_dtype, + shape=["batch_size", "sequence_length", shared_inter]) + + silu_mul_name = f"{basename}/gate_proj/Mul" + self.make_mul(silu_mul_name, + [f"{gate_matmul}/output_0", f"{silu_sigmoid_name}/output_0"], + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", shared_inter]) + + # SiLU(gate) * up + gate_up_mul_name = f"{basename}/Mul" + self.make_mul(gate_up_mul_name, + [f"{silu_mul_name}/output_0", f"{up_matmul}/output_0"], + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", shared_inter]) + + down_matmul = self.make_matmul(shared_expert.down_proj, f"{basename}/down_proj/MatMul", + 
f"{gate_up_mul_name}/output_0") + + # Optional sigmoid gating on the shared expert output + if shared_expert_gate is not None: + gate_matmul_name = self.make_matmul(shared_expert_gate, f"{basename}_gate/MatMul", root_input) + gate_sigmoid_name = f"{basename}_gate/Sigmoid" + self.make_sigmoid(gate_sigmoid_name, f"{gate_matmul_name}/output_0", self.io_dtype, + shape=["batch_size", "sequence_length", 1]) + + gated_mul_name = f"{basename}/GatedMul" + self.make_mul(gated_mul_name, + [f"{down_matmul}/output_0", f"{gate_sigmoid_name}/output_0"], + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size]) + return f"{gated_mul_name}/output_0" + + return f"{down_matmul}/output_0" + def make_block_sparse_moe(self, layer_id, bsm, root_input): # Make nodes for the QMoE block-sparse subgraph # diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 98cac1c15..1eb26b633 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -2105,16 +2105,6 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.moe_intermediate_size = getattr(config, "moe_intermediate_size", 512) self.shared_expert_intermediate_size = getattr(config, "shared_expert_intermediate_size", self.moe_intermediate_size) - # Override the INT8 node list for linear-attention layers: MoE layers - # use the MoE/QMoE op for the MLP, not individual MatMul nodes. - # Remove the MLP MatMul INT8 overrides that the parent class added - # since those node names don't exist in MoE layers. 
- algo_config = self.quant_attrs["int4"].get("algo_config") - if algo_config is not None and hasattr(algo_config, "customized_weight_config"): - keys_to_remove = [k for k in algo_config.customized_weight_config if "/mlp/" in k] - for k in keys_to_remove: - del algo_config.customized_weight_config[k] - def make_layer(self, layer_id, layer): """Override to use MoE instead of dense MLP.""" attn_module = layer.linear_attn if self.layer_types[layer_id] == "linear_attention" else layer.self_attn @@ -2133,157 +2123,14 @@ def make_layer(self, layer_id, layer): simple=self.layernorm_attrs["simple"], location="post_attention", ) - self.make_moe(layer_id, layer.mlp, root_input=self.layernorm_attrs["output_0"]) + self.make_fused_moe( + layer_id, + layer.mlp, + root_input=self.layernorm_attrs["output_0"], + shared_expert=layer.mlp.shared_expert, + shared_expert_gate=layer.mlp.shared_expert_gate, + ) self.layernorm_attrs["first_layernorm"] = False if layer_id == self.num_layers - 1: self.layernorm_attrs["last_layernorm"] = True - - def make_moe(self, layer_id, mlp, root_input): - """Build the MoE + shared expert subgraph for one decoder layer. - - Structure: - root_input ──┬── router(gate) ──┐ - │ v - ├──────────> MoE/QMoE (routed experts) - │ │ - ├── shared_expert ──┤ - │ │ │ - │ shared_gate ────┤ - │ v - └──────────> Add (moe_out + gated_shared_out) - """ - if self.ep in {"cpu", "cuda", "trt-rtx", "webgpu"}: - self._make_moe_fused(layer_id, mlp, root_input) - else: - raise NotImplementedError( - f"MoE export for EP '{self.ep}' is not yet supported for Qwen3.5-MoE. " - f"Supported EPs: cpu, cuda, trt-rtx, webgpu." 
- ) - - def _make_moe_fused(self, layer_id, mlp, root_input): - basename = f"/model/layers.{layer_id}/moe" - op_type = self.moe_attrs["op_type"] - moe_weight_type = f"{'q' if op_type == 'QMoE' else ''}weight" - - # --- Router (gate) --- - # Qwen3.5-MoE uses mlp.gate (no bias), unlike GPTOSS mlp.router (with bias) - router_basename = f"{basename}/router/MatMul" - router_matmul_name = self.make_matmul(mlp.gate, router_basename, root_input) - router_reshape_name = f"{basename}/router/Reshape" - router_reshape_inputs = [ - f"{router_matmul_name}/output_0", - f"/model/constants/INT64/{[-1, self.moe_attrs['num_experts']]}", - ] - self.make_reshape( - router_reshape_name, - router_reshape_inputs, - dtype=self.io_dtype, - shape=["batch_size * sequence_length", self.moe_attrs["num_experts"]], - ) - - # --- Routed expert weights --- - gate_up_proj_weight = f"model.layers.{layer_id}.moe.experts.gate_up_proj.{moe_weight_type}" - gate_up_proj_scales = f"model.layers.{layer_id}.moe.experts.gate_up_proj.scales" - gate_up_proj_bias = f"model.layers.{layer_id}.moe.experts.gate_up_proj.bias" - down_proj_weight = f"model.layers.{layer_id}.moe.experts.down_proj.{moe_weight_type}" - down_proj_scales = f"model.layers.{layer_id}.moe.experts.down_proj.scales" - down_proj_bias = f"model.layers.{layer_id}.moe.experts.down_proj.bias" - - # Repack gate_up_proj from HF concatenated [gate|up] to ORT interleaved - # format [g0,u0,g1,u1,...] required by swiglu_fusion=1. 
- # HF layout: [E, 2*inter, hidden] where rows 0..inter-1 = gate, inter..2*inter-1 = up - # ORT layout: [E, 2*inter, hidden] where row 2i = gate[i], row 2i+1 = up[i] - raw_gate_up = mlp.experts.gate_up_proj # [E, 2*inter, hidden] - half = raw_gate_up.shape[1] // 2 - gate_part = raw_gate_up[:, :half, :] # [E, inter, hidden] - up_part = raw_gate_up[:, half:, :] # [E, inter, hidden] - interleaved_gate_up = torch.stack([gate_part, up_part], dim=2).reshape_as(raw_gate_up) - # down_proj stays as-is: [E, hidden, inter] - - if op_type == "MoE": - self.make_initializer(interleaved_gate_up, gate_up_proj_weight, to=self.io_dtype) - self.make_initializer(mlp.experts.down_proj, down_proj_weight, to=self.io_dtype) - else: - # QMoE: quantize each expert - gate_up_qw_list, gate_up_sc_list = [], [] - down_qw_list, down_sc_list = [], [] - - for i in range(self.moe_attrs["num_experts"]): - qw1, sc1 = self.make_qmoe_weights(interleaved_gate_up[i]) - gate_up_qw_list.append(qw1) - gate_up_sc_list.append(sc1) - qw2, sc2 = self.make_qmoe_weights(mlp.experts.down_proj[i]) - down_qw_list.append(qw2) - down_sc_list.append(sc2) - - gate_up_qw = torch.stack(gate_up_qw_list, dim=0).to(torch.uint8) - gate_up_sc = torch.stack(gate_up_sc_list, dim=0) - down_qw = torch.stack(down_qw_list, dim=0).to(torch.uint8) - down_sc = torch.stack(down_sc_list, dim=0) - - self.make_initializer(gate_up_qw, gate_up_proj_weight) - self.make_initializer(down_qw, down_proj_weight) - self.make_initializer(gate_up_sc, gate_up_proj_scales, to=self.io_dtype) - self.make_initializer(down_sc, down_proj_scales, to=self.io_dtype) - - # Zero biases for routed experts - num_e = self.moe_attrs["num_experts"] - self.make_initializer( - torch.zeros(num_e, 2 * self.moe_intermediate_size), - gate_up_proj_bias, to=self.io_dtype, - ) - self.make_initializer( - torch.zeros(num_e, self.hidden_size), - down_proj_bias, to=self.io_dtype, - ) - - # --- MoE op (routed experts only) --- - moe_name = f"{basename}/{op_type}" - 
self.make_moe_op( - moe_name, - root_input=root_input, - router_probs=f"{router_reshape_name}/output_0", - weight1=gate_up_proj_weight, - scales1=gate_up_proj_scales if op_type == "QMoE" else "", - bias1=gate_up_proj_bias, - weight2=down_proj_weight, - scales2=down_proj_scales if op_type == "QMoE" else "", - bias2=down_proj_bias, - ) - - # --- Shared expert --- - shared_basename = f"/model/layers.{layer_id}/shared_expert" - - gate_matmul = self.make_matmul(mlp.shared_expert.gate_proj, f"{shared_basename}/gate_proj/MatMul", root_input) - up_matmul = self.make_matmul(mlp.shared_expert.up_proj, f"{shared_basename}/up_proj/MatMul", root_input) - - # SiLU(gate) * up - silu_name = f"{shared_basename}/gate_proj/Sigmoid" - self.make_node("Sigmoid", inputs=[f"{gate_matmul}/output_0"], outputs=[f"{silu_name}/output_0"], name=silu_name) - self.make_value(f"{silu_name}/output_0", self.io_dtype, shape=["batch_size", "sequence_length", self.shared_expert_intermediate_size]) - - mul_gate_name = f"{shared_basename}/gate_proj/Mul" - self.make_node("Mul", inputs=[f"{gate_matmul}/output_0", f"{silu_name}/output_0"], outputs=[f"{mul_gate_name}/output_0"], name=mul_gate_name) - self.make_value(f"{mul_gate_name}/output_0", self.io_dtype, shape=["batch_size", "sequence_length", self.shared_expert_intermediate_size]) - - mul_up_name = f"{shared_basename}/Mul" - self.make_node("Mul", inputs=[f"{mul_gate_name}/output_0", f"{up_matmul}/output_0"], outputs=[f"{mul_up_name}/output_0"], name=mul_up_name) - self.make_value(f"{mul_up_name}/output_0", self.io_dtype, shape=["batch_size", "sequence_length", self.shared_expert_intermediate_size]) - - down_matmul = self.make_matmul(mlp.shared_expert.down_proj, f"{shared_basename}/down_proj/MatMul", f"{mul_up_name}/output_0") - - shared_gate_matmul = self.make_matmul(mlp.shared_expert_gate, f"{shared_basename}_gate/MatMul", root_input) - shared_gate_sigmoid = f"{shared_basename}_gate/Sigmoid" - self.make_node("Sigmoid", 
inputs=[f"{shared_gate_matmul}/output_0"], outputs=[f"{shared_gate_sigmoid}/output_0"], name=shared_gate_sigmoid) - self.make_value(f"{shared_gate_sigmoid}/output_0", self.io_dtype, shape=["batch_size", "sequence_length", 1]) - - gated_shared_name = f"{shared_basename}/GatedMul" - self.make_node("Mul", inputs=[f"{down_matmul}/output_0", f"{shared_gate_sigmoid}/output_0"], outputs=[f"{gated_shared_name}/output_0"], name=gated_shared_name) - self.make_value(f"{gated_shared_name}/output_0", self.io_dtype, shape=["batch_size", "sequence_length", self.hidden_size]) - - combine_name = f"{basename}/Add" - self.make_node("Add", inputs=[f"{moe_name}/output_0", f"{gated_shared_name}/output_0"], outputs=[f"{combine_name}/output_0"], name=combine_name) - self.make_value(f"{combine_name}/output_0", self.io_dtype, shape=["batch_size", "sequence_length", self.hidden_size]) - - self.layernorm_attrs["skip_input"] = f"{combine_name}/output_0" From 51155a069a1a7c17dda1df53728995d97231f99d Mon Sep 17 00:00:00 2001 From: tanzeel-amd Date: Thu, 21 May 2026 06:57:20 -0700 Subject: [PATCH 11/11] Revert MoE methods back to Qwen class per review, keep wrapper methods Per reviewer feedback, MoE builders in this codebase follow a model-specific pattern rather than a shared base class method. Moved make_moe, make_shared_expert, and int4 config cleanup back to Qwen35MoeTextModel. Retained use of wrapper methods (make_sigmoid, make_mul, make_add) instead of raw make_node/make_value. 
Co-authored-by: Cursor --- src/python/py/models/builders/base.py | 179 -------------------------- src/python/py/models/builders/qwen.py | 132 ++++++++++++++++++- 2 files changed, 125 insertions(+), 186 deletions(-) diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index ed75d2fc6..061862f15 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -764,14 +764,6 @@ def make_int4_algo_config(self, quant_method: str): customized_weight_config["/lm_head/MatMul"] = {"bits": 8} int4_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config) - # MoE layers use MoE/QMoE ops instead of individual MatMul nodes, - # so remove any /mlp/ MatMul overrides that don't exist in MoE models. - if int4_algo_config is not None and self.moe_attrs.get("num_experts", 0) > 0: - if hasattr(int4_algo_config, "customized_weight_config"): - keys_to_remove = [k for k in int4_algo_config.customized_weight_config if "/mlp/" in k] - for k in keys_to_remove: - del int4_algo_config.customized_weight_config[k] - return int4_algo_config def to_int4(self) -> ir.Model: @@ -3736,177 +3728,6 @@ def _symmetric_blockwise_quantize(self, weights, block_size): return qweight.cpu(), scales.cpu() - def make_fused_moe(self, layer_id, mlp, root_input, shared_expert=None, shared_expert_gate=None): - """Build a fused MoE subgraph (router + routed experts + optional shared expert). - - Supports both 2-weight (gate_up_proj/down_proj with SwiGLU) and 3-weight - (w1/w2/w3) expert layouts, routers with or without bias, and an optional - shared expert path with sigmoid gating. - - Args: - layer_id: Decoder layer index. - mlp: MoE module containing the router and expert weights. - Router: ``mlp.gate`` (no bias) or ``mlp.router`` (with bias). - Experts: ``mlp.experts.gate_up_proj``/``mlp.experts.down_proj`` (2-weight) - or ``mlp.experts[i].w1``/``w2``/``w3`` (3-weight). 
- root_input: The ONNX input tensor name feeding this MoE block. - shared_expert: Optional shared expert module (e.g. ``mlp.shared_expert``). - shared_expert_gate: Optional shared expert gate module (e.g. ``mlp.shared_expert_gate``). - """ - basename = f"/model/layers.{layer_id}/moe" - op_type = self.moe_attrs["op_type"] - moe_weight_type = f"{'q' if op_type == 'QMoE' else ''}weight" - - # --- Router --- - router_module = mlp.gate if hasattr(mlp, "gate") else mlp.router - has_router_bias = hasattr(router_module, "bias") and router_module.bias is not None - - router_basename = f"{basename}/router/MatMul" - router_matmul_name = self.make_matmul(router_module, router_basename, root_input) - - if has_router_bias: - router_add_name = f"{basename}/router/Add" - self.make_add_bias(router_module.bias, router_add_name, root_input=f"{router_matmul_name}/output_0") - router_out = f"{router_add_name}/output_0" - else: - router_out = f"{router_matmul_name}/output_0" - - router_reshape_name = f"{basename}/router/Reshape" - self.make_reshape( - router_reshape_name, - [router_out, f"/model/constants/INT64/{[-1, self.moe_attrs['num_experts']]}"], - dtype=self.io_dtype, - shape=["batch_size * sequence_length", self.moe_attrs["num_experts"]], - ) - - # --- Expert weights --- - gate_up_proj_weight = f"model.layers.{layer_id}.moe.experts.gate_up_proj.{moe_weight_type}" - gate_up_proj_scales = f"model.layers.{layer_id}.moe.experts.gate_up_proj.scales" - gate_up_proj_bias = f"model.layers.{layer_id}.moe.experts.gate_up_proj.bias" - down_proj_weight = f"model.layers.{layer_id}.moe.experts.down_proj.{moe_weight_type}" - down_proj_scales = f"model.layers.{layer_id}.moe.experts.down_proj.scales" - down_proj_bias = f"model.layers.{layer_id}.moe.experts.down_proj.bias" - - raw_gate_up = mlp.experts.gate_up_proj - down_proj = mlp.experts.down_proj - - # SwiGLU interleaving: repack [gate|up] to [g0,u0,g1,u1,...] 
when enabled - if self.moe_attrs.get("swiglu_fusion", 0) == 1: - half = raw_gate_up.shape[1] // 2 - gate_part = raw_gate_up[:, :half, :] - up_part = raw_gate_up[:, half:, :] - raw_gate_up = torch.stack([gate_part, up_part], dim=2).reshape_as(mlp.experts.gate_up_proj) - - if op_type == "MoE": - self.make_initializer(raw_gate_up, gate_up_proj_weight, to=self.io_dtype) - self.make_initializer(down_proj, down_proj_weight, to=self.io_dtype) - else: - gate_up_qw_list, gate_up_sc_list = [], [] - down_qw_list, down_sc_list = [], [] - - for i in range(self.moe_attrs["num_experts"]): - qw1, sc1 = self.make_qmoe_weights(raw_gate_up[i]) - gate_up_qw_list.append(qw1) - gate_up_sc_list.append(sc1) - qw2, sc2 = self.make_qmoe_weights(down_proj[i]) - down_qw_list.append(qw2) - down_sc_list.append(sc2) - - self.make_initializer(torch.stack(gate_up_qw_list, dim=0).to(torch.uint8), gate_up_proj_weight) - self.make_initializer(torch.stack(down_qw_list, dim=0).to(torch.uint8), down_proj_weight) - self.make_initializer(torch.stack(gate_up_sc_list, dim=0), gate_up_proj_scales, to=self.io_dtype) - self.make_initializer(torch.stack(down_sc_list, dim=0), down_proj_scales, to=self.io_dtype) - - # Biases - num_e = self.moe_attrs["num_experts"] - if hasattr(mlp.experts, "gate_up_proj_bias"): - self.make_initializer(mlp.experts.gate_up_proj_bias, gate_up_proj_bias, to=self.io_dtype) - self.make_initializer(mlp.experts.down_proj_bias, down_proj_bias, to=self.io_dtype) - else: - moe_inter = raw_gate_up.shape[1] - self.make_initializer(torch.zeros(num_e, moe_inter), gate_up_proj_bias, to=self.io_dtype) - self.make_initializer(torch.zeros(num_e, self.hidden_size), down_proj_bias, to=self.io_dtype) - - # --- MoE/QMoE op --- - moe_name = f"{basename}/{op_type}" - self.make_moe_op( - moe_name, - root_input=root_input, - router_probs=f"{router_reshape_name}/output_0", - weight1=gate_up_proj_weight, - scales1=gate_up_proj_scales if op_type == "QMoE" else "", - bias1=gate_up_proj_bias, - 
weight2=down_proj_weight, - scales2=down_proj_scales if op_type == "QMoE" else "", - bias2=down_proj_bias, - ) - - moe_output = f"{moe_name}/output_0" - - # --- Optional shared expert --- - if shared_expert is not None: - shared_output = self.make_shared_expert( - layer_id, shared_expert, shared_expert_gate, root_input, - ) - combine_name = f"{basename}/Add" - self.make_add( - combine_name, - [moe_output, shared_output], - dtype=self.io_dtype, - shape=["batch_size", "sequence_length", self.hidden_size], - ) - moe_output = f"{combine_name}/output_0" - - self.layernorm_attrs["skip_input"] = moe_output - - def make_shared_expert(self, layer_id, shared_expert, shared_expert_gate, root_input): - """Build a shared expert SiLU-MLP with optional sigmoid gating. - - Structure: gate_proj -> SiLU -> Mul(gate, up) -> down_proj -> sigmoid_gate -> Mul - """ - basename = f"/model/layers.{layer_id}/shared_expert" - shared_inter = getattr(self, "shared_expert_intermediate_size", self.intermediate_size) - - gate_matmul = self.make_matmul(shared_expert.gate_proj, f"{basename}/gate_proj/MatMul", root_input) - up_matmul = self.make_matmul(shared_expert.up_proj, f"{basename}/up_proj/MatMul", root_input) - - # SiLU(gate) = gate * sigmoid(gate) - silu_sigmoid_name = f"{basename}/gate_proj/Sigmoid" - self.make_sigmoid(silu_sigmoid_name, f"{gate_matmul}/output_0", self.io_dtype, - shape=["batch_size", "sequence_length", shared_inter]) - - silu_mul_name = f"{basename}/gate_proj/Mul" - self.make_mul(silu_mul_name, - [f"{gate_matmul}/output_0", f"{silu_sigmoid_name}/output_0"], - dtype=self.io_dtype, - shape=["batch_size", "sequence_length", shared_inter]) - - # SiLU(gate) * up - gate_up_mul_name = f"{basename}/Mul" - self.make_mul(gate_up_mul_name, - [f"{silu_mul_name}/output_0", f"{up_matmul}/output_0"], - dtype=self.io_dtype, - shape=["batch_size", "sequence_length", shared_inter]) - - down_matmul = self.make_matmul(shared_expert.down_proj, f"{basename}/down_proj/MatMul", - 
f"{gate_up_mul_name}/output_0") - - # Optional sigmoid gating on the shared expert output - if shared_expert_gate is not None: - gate_matmul_name = self.make_matmul(shared_expert_gate, f"{basename}_gate/MatMul", root_input) - gate_sigmoid_name = f"{basename}_gate/Sigmoid" - self.make_sigmoid(gate_sigmoid_name, f"{gate_matmul_name}/output_0", self.io_dtype, - shape=["batch_size", "sequence_length", 1]) - - gated_mul_name = f"{basename}/GatedMul" - self.make_mul(gated_mul_name, - [f"{down_matmul}/output_0", f"{gate_sigmoid_name}/output_0"], - dtype=self.io_dtype, - shape=["batch_size", "sequence_length", self.hidden_size]) - return f"{gated_mul_name}/output_0" - - return f"{down_matmul}/output_0" - def make_block_sparse_moe(self, layer_id, bsm, root_input): # Make nodes for the QMoE block-sparse subgraph # diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 1eb26b633..459c08c5d 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -2105,6 +2105,14 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): self.moe_intermediate_size = getattr(config, "moe_intermediate_size", 512) self.shared_expert_intermediate_size = getattr(config, "shared_expert_intermediate_size", self.moe_intermediate_size) + # MoE layers use MoE/QMoE ops instead of individual MatMul nodes, + # so remove any /mlp/ MatMul overrides that don't apply. 
+ algo_config = self.quant_attrs["int4"].get("algo_config") + if algo_config is not None and hasattr(algo_config, "customized_weight_config"): + keys_to_remove = [k for k in algo_config.customized_weight_config if "/mlp/" in k] + for k in keys_to_remove: + del algo_config.customized_weight_config[k] + def make_layer(self, layer_id, layer): """Override to use MoE instead of dense MLP.""" attn_module = layer.linear_attn if self.layer_types[layer_id] == "linear_attention" else layer.self_attn @@ -2123,14 +2131,124 @@ def make_layer(self, layer_id, layer): simple=self.layernorm_attrs["simple"], location="post_attention", ) - self.make_fused_moe( - layer_id, - layer.mlp, - root_input=self.layernorm_attrs["output_0"], - shared_expert=layer.mlp.shared_expert, - shared_expert_gate=layer.mlp.shared_expert_gate, - ) + self.make_moe(layer_id, layer.mlp, root_input=self.layernorm_attrs["output_0"]) self.layernorm_attrs["first_layernorm"] = False if layer_id == self.num_layers - 1: self.layernorm_attrs["last_layernorm"] = True + + def make_moe(self, layer_id, mlp, root_input): + """Build MoE + shared expert subgraph for one decoder layer.""" + basename = f"/model/layers.{layer_id}/moe" + op_type = self.moe_attrs["op_type"] + moe_weight_type = f"{'q' if op_type == 'QMoE' else ''}weight" + + # --- Router (bias-free gate) --- + router_basename = f"{basename}/router/MatMul" + router_matmul_name = self.make_matmul(mlp.gate, router_basename, root_input) + router_reshape_name = f"{basename}/router/Reshape" + self.make_reshape( + router_reshape_name, + [f"{router_matmul_name}/output_0", + f"/model/constants/INT64/{[-1, self.moe_attrs['num_experts']]}"], + dtype=self.io_dtype, + shape=["batch_size * sequence_length", self.moe_attrs["num_experts"]], + ) + + # --- Routed expert weights --- + gate_up_proj_weight = f"model.layers.{layer_id}.moe.experts.gate_up_proj.{moe_weight_type}" + gate_up_proj_scales = f"model.layers.{layer_id}.moe.experts.gate_up_proj.scales" + gate_up_proj_bias = 
f"model.layers.{layer_id}.moe.experts.gate_up_proj.bias" + down_proj_weight = f"model.layers.{layer_id}.moe.experts.down_proj.{moe_weight_type}" + down_proj_scales = f"model.layers.{layer_id}.moe.experts.down_proj.scales" + down_proj_bias = f"model.layers.{layer_id}.moe.experts.down_proj.bias" + + # Repack HF concatenated [gate|up] to ORT interleaved [g0,u0,g1,u1,...] for swiglu_fusion=1 + raw_gate_up = mlp.experts.gate_up_proj + half = raw_gate_up.shape[1] // 2 + interleaved = torch.stack([raw_gate_up[:, :half, :], raw_gate_up[:, half:, :]], dim=2).reshape_as(raw_gate_up) + + if op_type == "MoE": + self.make_initializer(interleaved, gate_up_proj_weight, to=self.io_dtype) + self.make_initializer(mlp.experts.down_proj, down_proj_weight, to=self.io_dtype) + else: + gate_up_qw_list, gate_up_sc_list = [], [] + down_qw_list, down_sc_list = [], [] + for i in range(self.moe_attrs["num_experts"]): + qw1, sc1 = self.make_qmoe_weights(interleaved[i]) + gate_up_qw_list.append(qw1) + gate_up_sc_list.append(sc1) + qw2, sc2 = self.make_qmoe_weights(mlp.experts.down_proj[i]) + down_qw_list.append(qw2) + down_sc_list.append(sc2) + self.make_initializer(torch.stack(gate_up_qw_list, dim=0).to(torch.uint8), gate_up_proj_weight) + self.make_initializer(torch.stack(down_qw_list, dim=0).to(torch.uint8), down_proj_weight) + self.make_initializer(torch.stack(gate_up_sc_list, dim=0), gate_up_proj_scales, to=self.io_dtype) + self.make_initializer(torch.stack(down_sc_list, dim=0), down_proj_scales, to=self.io_dtype) + + num_e = self.moe_attrs["num_experts"] + self.make_initializer(torch.zeros(num_e, 2 * self.moe_intermediate_size), gate_up_proj_bias, to=self.io_dtype) + self.make_initializer(torch.zeros(num_e, self.hidden_size), down_proj_bias, to=self.io_dtype) + + # --- MoE/QMoE op --- + moe_name = f"{basename}/{op_type}" + self.make_moe_op( + moe_name, + root_input=root_input, + router_probs=f"{router_reshape_name}/output_0", + weight1=gate_up_proj_weight, + scales1=gate_up_proj_scales if 
op_type == "QMoE" else "", + bias1=gate_up_proj_bias, + weight2=down_proj_weight, + scales2=down_proj_scales if op_type == "QMoE" else "", + bias2=down_proj_bias, + ) + + # --- Shared expert --- + shared_output = self.make_shared_expert(layer_id, mlp.shared_expert, mlp.shared_expert_gate, root_input) + combine_name = f"{basename}/Add" + self.make_add( + combine_name, + [f"{moe_name}/output_0", shared_output], + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size], + ) + self.layernorm_attrs["skip_input"] = f"{combine_name}/output_0" + + def make_shared_expert(self, layer_id, shared_expert, shared_expert_gate, root_input): + """Build shared expert SiLU-MLP with sigmoid gating.""" + basename = f"/model/layers.{layer_id}/shared_expert" + + gate_matmul = self.make_matmul(shared_expert.gate_proj, f"{basename}/gate_proj/MatMul", root_input) + up_matmul = self.make_matmul(shared_expert.up_proj, f"{basename}/up_proj/MatMul", root_input) + + silu_sigmoid_name = f"{basename}/gate_proj/Sigmoid" + self.make_sigmoid(silu_sigmoid_name, f"{gate_matmul}/output_0", self.io_dtype, + shape=["batch_size", "sequence_length", self.shared_expert_intermediate_size]) + + silu_mul_name = f"{basename}/gate_proj/Mul" + self.make_mul(silu_mul_name, + [f"{gate_matmul}/output_0", f"{silu_sigmoid_name}/output_0"], + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.shared_expert_intermediate_size]) + + gate_up_mul_name = f"{basename}/Mul" + self.make_mul(gate_up_mul_name, + [f"{silu_mul_name}/output_0", f"{up_matmul}/output_0"], + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.shared_expert_intermediate_size]) + + down_matmul = self.make_matmul(shared_expert.down_proj, f"{basename}/down_proj/MatMul", + f"{gate_up_mul_name}/output_0") + + gate_matmul_name = self.make_matmul(shared_expert_gate, f"{basename}_gate/MatMul", root_input) + gate_sigmoid_name = f"{basename}_gate/Sigmoid" + self.make_sigmoid(gate_sigmoid_name, 
f"{gate_matmul_name}/output_0", self.io_dtype, + shape=["batch_size", "sequence_length", 1]) + + gated_mul_name = f"{basename}/GatedMul" + self.make_mul(gated_mul_name, + [f"{down_matmul}/output_0", f"{gate_sigmoid_name}/output_0"], + dtype=self.io_dtype, + shape=["batch_size", "sequence_length", self.hidden_size]) + return f"{gated_mul_name}/output_0"