diff --git a/src/models/model.cpp b/src/models/model.cpp index 86b2e478d..203d01a84 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. // -// Modifications Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. +// Modifications Copyright (C) 2024-2026 Advanced Micro Devices, Inc. All rights reserved. // Portions of this file consist of AI generated content. #include #include @@ -951,6 +951,7 @@ MultiModalProcessor::MultiModalProcessor(Config& config, const SessionInfo& sess {"qwen2_5_vl", Processor::Create}, {"qwen3_vl", Processor::Create}, {"qwen3_5", Processor::Create}, + {"qwen3_5_moe", Processor::Create}, {"videochat_flash_qwen", Processor::Create}} { auto processor = processor_factory_.find(config.model.type); if (processor != processor_factory_.end()) { diff --git a/src/models/model_type.h b/src/models/model_type.h index dc822fe00..52e1bc361 100644 --- a/src/models/model_type.h +++ b/src/models/model_type.h @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved +// -------------------------------------------------------------------------- +// Modifications Copyright (C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. +// Portions of this file consist of AI generated content. 
// -------------------------------------------------------------------------- #pragma once @@ -21,13 +23,14 @@ struct ModelType { inline static bool IsVLM(const std::string& model_type) { // Vision-language model (VLM) - static constexpr std::array VLM = {"fara", "gemma3", "mistral3", "phi3v", "qwen2_5_vl", "qwen3_vl", "qwen3_5", "videochat_flash_qwen"}; + static constexpr std::array VLM = {"fara", "gemma3", "mistral3", "phi3v", "qwen2_5_vl", "qwen3_vl", "qwen3_5", "qwen3_5_moe", "videochat_flash_qwen"}; return std::find(VLM.begin(), VLM.end(), model_type) != VLM.end(); } inline static bool IsQwenVLFamily(const std::string& model_type) { // Qwen-VL family: models requiring 3D mRoPE position IDs - return model_type == "fara" || model_type == "qwen2_5_vl" || model_type == "qwen3_vl" || model_type == "qwen3_5"; + static constexpr std::array QwenVL = {"fara", "qwen2_5_vl", "qwen3_vl", "qwen3_5", "qwen3_5_moe"}; + return std::find(QwenVL.begin(), QwenVL.end(), model_type) != QwenVL.end(); } inline static bool IsPixtralFamily(const std::string& model_type) { diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 88cd14756..2aab41b83 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -46,6 +46,7 @@ Qwen3VLTextModel, Qwen25VLTextModel, Qwen35TextModel, + Qwen35MoeTextModel, QwenModel, SmolLM3Model, VideoChatFlashQwenModel, @@ -310,6 +311,8 @@ def create_model( onnx_model = Qwen3Model(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "Qwen3_5ForConditionalGeneration": onnx_model = Qwen35TextModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) + elif config.architectures[0] == "Qwen3_5MoeForConditionalGeneration": + onnx_model = Qwen35MoeTextModel(config, io_dtype, onnx_dtype, execution_provider, cache_dir, extra_options) elif config.architectures[0] == "Qwen3VLForConditionalGeneration": text_config = config.text_config for 
key in text_config: diff --git a/src/python/py/models/builders/__init__.py b/src/python/py/models/builders/__init__.py index 1e6330b31..ae2cf33ed 100644 --- a/src/python/py/models/builders/__init__.py +++ b/src/python/py/models/builders/__init__.py @@ -29,7 +29,7 @@ Phi4MMModel, PhiModel, ) -from .qwen import Qwen3Model, Qwen25VLTextModel, Qwen3VLTextModel, Qwen35TextModel, QwenModel, VideoChatFlashQwenModel +from .qwen import Qwen3Model, Qwen25VLTextModel, Qwen3VLTextModel, Qwen35TextModel, Qwen35MoeTextModel, QwenModel, VideoChatFlashQwenModel from .smollm import SmolLM3Model from .whisper import WhisperModel @@ -62,6 +62,7 @@ "Qwen3VLTextModel", "Qwen25VLTextModel", "Qwen35TextModel", + "Qwen35MoeTextModel", "QwenModel", "SmolLM3Model", "VideoChatFlashQwenModel", diff --git a/src/python/py/models/builders/qwen.py b/src/python/py/models/builders/qwen.py index 37499c01a..459c08c5d 100644 --- a/src/python/py/models/builders/qwen.py +++ b/src/python/py/models/builders/qwen.py @@ -2,12 +2,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +# ------------------------------------------------------ # Modifications Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved. # Portions of this file consist of AI generated content. import os - import numpy as np import onnx_ir as ir import torch @@ -1003,6 +1002,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) + self.model_type = "Qwen3_5_textForCausalLM" if self.is_text_only else "Qwen3_5ForConditionalGeneration" + # OffsetRMSNorm: Qwen3.5 uses (1 + weight) * RMSNorm(x). # Pre-bake the +1 into the weight initializer so the base class's # SkipSimplifiedLayerNormalization can be used directly. 
class Qwen35MoeTextModel(Qwen35TextModel):
    """Qwen3.5 MoE hybrid model builder.

    Extends ``Qwen35TextModel`` with Mixture-of-Experts MLP layers.
    Each decoder layer replaces the dense MLP with:
      - A router that selects top-k experts from ``num_experts`` candidates
      - Packed routed expert weights (gate_up_proj + down_proj)
      - A shared expert (always-active) with its own gating signal

    The attention side (GatedDeltaNet linear + gated full) is inherited
    unchanged from the parent class.
    """

    @staticmethod
    def _config_value(config, name, default):
        """Look up ``name`` on ``config``, then on ``config.text_config``.

        Args:
            config: Model configuration object; may or may not carry a nested
                ``text_config`` attribute.
            name: Attribute name to look up.
            default: Value returned when neither object defines ``name``.

        Returns:
            The first value found (top-level config wins), otherwise ``default``.
        """
        for source in (config, getattr(config, "text_config", None)):
            if source is not None and hasattr(source, name):
                return getattr(source, name)
        return default

    def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
        """Alias MoE config fields for the base class, then set MoE attributes.

        All arguments are forwarded unchanged to ``Qwen35TextModel.__init__``.
        When ``config.text_config`` exists it is mutated in place: missing
        ``num_local_experts`` / ``intermediate_size`` attributes are filled
        from ``num_experts`` / ``moe_intermediate_size``.
        """
        # Map Qwen3.5-MoE config attributes to what the base class expects.
        if hasattr(config, "text_config"):
            tc = config.text_config
            # Base class reads num_local_experts; MoE config uses num_experts
            if hasattr(tc, "num_experts") and not hasattr(tc, "num_local_experts"):
                tc.num_local_experts = tc.num_experts
            # Base class reads intermediate_size; MoE has moe_intermediate_size
            if not hasattr(tc, "intermediate_size") and hasattr(tc, "moe_intermediate_size"):
                tc.intermediate_size = tc.moe_intermediate_size

        super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)

        self.model_type = "Qwen3_5_MoeForConditionalGeneration"

        # MoE attributes specific to Qwen3.5-MoE
        self.moe_attrs["activation_type"] = "swiglu"
        self.moe_attrs["swiglu_fusion"] = 1
        self.moe_attrs["normalize_routing_weights"] = True

        # The MoE sizes may live on the nested text_config rather than on the
        # top-level config, so consult both before falling back to a default.
        self.moe_intermediate_size = self._config_value(config, "moe_intermediate_size", 512)
        self.shared_expert_intermediate_size = self._config_value(
            config, "shared_expert_intermediate_size", self.moe_intermediate_size
        )

        # MoE layers use MoE/QMoE ops instead of individual MatMul nodes,
        # so remove any /mlp/ MatMul overrides that don't apply.
        algo_config = self.quant_attrs["int4"].get("algo_config")
        if algo_config is not None and hasattr(algo_config, "customized_weight_config"):
            # Collect the keys first: deleting while iterating the dict would raise.
            keys_to_remove = [k for k in algo_config.customized_weight_config if "/mlp/" in k]
            for k in keys_to_remove:
                del algo_config.customized_weight_config[k]

    def make_layer(self, layer_id, layer):
        """Build one decoder layer: layernorm, attention, layernorm, MoE.

        Args:
            layer_id: Index of the decoder layer; also indexes ``self.layer_types``.
            layer: Source layer module providing ``input_layernorm``,
                ``post_attention_layernorm``, ``mlp`` and either ``linear_attn``
                or ``self_attn``.
        """
        # Pick the attention module that matches this layer's declared type.
        attn_module = layer.linear_attn if self.layer_types[layer_id] == "linear_attention" else layer.self_attn
        self.make_layernorm(
            layer_id,
            layer.input_layernorm,
            skip=not self.layernorm_attrs["first_layernorm"],
            simple=self.layernorm_attrs["simple"],
            location="input",
        )
        self.make_attention(layer_id, attn_module, root_input=self.layernorm_attrs["output_0"])
        self.make_layernorm(
            layer_id,
            layer.post_attention_layernorm,
            skip=True,
            simple=self.layernorm_attrs["simple"],
            location="post_attention",
        )
        self.make_moe(layer_id, layer.mlp, root_input=self.layernorm_attrs["output_0"])

        self.layernorm_attrs["first_layernorm"] = False
        if layer_id == self.num_layers - 1:
            self.layernorm_attrs["last_layernorm"] = True

    def make_moe(self, layer_id, mlp, root_input):
        """Build MoE + shared expert subgraph for one decoder layer.

        Args:
            layer_id: Index of the decoder layer, used in node/initializer names.
            mlp: Source MoE module providing ``gate``, ``experts.gate_up_proj``,
                ``experts.down_proj``, ``shared_expert`` and ``shared_expert_gate``.
            root_input: Name of the graph value fed to the router, the MoE op
                and the shared expert.

        Side effects:
            Sets ``self.layernorm_attrs["skip_input"]`` to the output of the
            final Add that sums the routed and shared expert results.
        """
        basename = f"/model/layers.{layer_id}/moe"
        op_type = self.moe_attrs["op_type"]
        is_quantized = op_type == "QMoE"
        moe_weight_type = f"{'q' if is_quantized else ''}weight"
        num_experts = self.moe_attrs["num_experts"]

        # --- Router (bias-free gate) ---
        router_basename = f"{basename}/router/MatMul"
        router_matmul_name = self.make_matmul(mlp.gate, router_basename, root_input)
        router_reshape_name = f"{basename}/router/Reshape"
        self.make_reshape(
            router_reshape_name,
            [f"{router_matmul_name}/output_0",
             f"/model/constants/INT64/{[-1, self.moe_attrs['num_experts']]}"],
            dtype=self.io_dtype,
            shape=["batch_size * sequence_length", self.moe_attrs["num_experts"]],
        )

        # --- Routed expert weights ---
        gate_up_proj_weight = f"model.layers.{layer_id}.moe.experts.gate_up_proj.{moe_weight_type}"
        gate_up_proj_scales = f"model.layers.{layer_id}.moe.experts.gate_up_proj.scales"
        gate_up_proj_bias = f"model.layers.{layer_id}.moe.experts.gate_up_proj.bias"
        down_proj_weight = f"model.layers.{layer_id}.moe.experts.down_proj.{moe_weight_type}"
        down_proj_scales = f"model.layers.{layer_id}.moe.experts.down_proj.scales"
        down_proj_bias = f"model.layers.{layer_id}.moe.experts.down_proj.bias"

        # Repack HF concatenated [gate|up] to ORT interleaved [g0,u0,g1,u1,...] for swiglu_fusion=1.
        # Stacking the two halves on a new axis 2 and flattening back pairs row i
        # of the gate half with row i of the up half along dim 1.
        raw_gate_up = mlp.experts.gate_up_proj
        half = raw_gate_up.shape[1] // 2
        interleaved = torch.stack([raw_gate_up[:, :half, :], raw_gate_up[:, half:, :]], dim=2).reshape_as(raw_gate_up)

        if op_type == "MoE":
            self.make_initializer(interleaved, gate_up_proj_weight, to=self.io_dtype)
            self.make_initializer(mlp.experts.down_proj, down_proj_weight, to=self.io_dtype)
        else:
            # Quantize each expert separately, then stack along a leading expert axis.
            gate_up_qw_list, gate_up_sc_list = [], []
            down_qw_list, down_sc_list = [], []
            for i in range(num_experts):
                qw1, sc1 = self.make_qmoe_weights(interleaved[i])
                gate_up_qw_list.append(qw1)
                gate_up_sc_list.append(sc1)
                qw2, sc2 = self.make_qmoe_weights(mlp.experts.down_proj[i])
                down_qw_list.append(qw2)
                down_sc_list.append(sc2)
            self.make_initializer(torch.stack(gate_up_qw_list, dim=0).to(torch.uint8), gate_up_proj_weight)
            self.make_initializer(torch.stack(down_qw_list, dim=0).to(torch.uint8), down_proj_weight)
            self.make_initializer(torch.stack(gate_up_sc_list, dim=0), gate_up_proj_scales, to=self.io_dtype)
            self.make_initializer(torch.stack(down_sc_list, dim=0), down_proj_scales, to=self.io_dtype)

        # Zero biases. The gate/up width comes from the repacked tensor itself,
        # so it always matches the weights emitted above even when the config
        # does not report moe_intermediate_size.
        self.make_initializer(torch.zeros(num_experts, interleaved.shape[1]), gate_up_proj_bias, to=self.io_dtype)
        self.make_initializer(torch.zeros(num_experts, self.hidden_size), down_proj_bias, to=self.io_dtype)

        # --- MoE/QMoE op ---
        moe_name = f"{basename}/{op_type}"
        self.make_moe_op(
            moe_name,
            root_input=root_input,
            router_probs=f"{router_reshape_name}/output_0",
            weight1=gate_up_proj_weight,
            scales1=gate_up_proj_scales if is_quantized else "",
            bias1=gate_up_proj_bias,
            weight2=down_proj_weight,
            scales2=down_proj_scales if is_quantized else "",
            bias2=down_proj_bias,
        )

        # --- Shared expert ---
        shared_output = self.make_shared_expert(layer_id, mlp.shared_expert, mlp.shared_expert_gate, root_input)
        combine_name = f"{basename}/Add"
        self.make_add(
            combine_name,
            [f"{moe_name}/output_0", shared_output],
            dtype=self.io_dtype,
            shape=["batch_size", "sequence_length", self.hidden_size],
        )
        self.layernorm_attrs["skip_input"] = f"{combine_name}/output_0"

    def make_shared_expert(self, layer_id, shared_expert, shared_expert_gate, root_input):
        """Build shared expert SiLU-MLP with sigmoid gating.

        Emits ``down_proj(silu(gate_proj(x)) * up_proj(x)) * sigmoid(shared_expert_gate(x))``
        where ``x`` is ``root_input`` and SiLU is expressed as ``v * sigmoid(v)``.

        Args:
            layer_id: Index of the decoder layer, used in node names.
            shared_expert: Module providing ``gate_proj``, ``up_proj`` and ``down_proj``.
            shared_expert_gate: Projection producing the scalar gate logit.
            root_input: Name of the graph value fed to every projection.

        Returns:
            Name of the graph value holding the gated shared-expert output.
        """
        basename = f"/model/layers.{layer_id}/shared_expert"
        inner_shape = ["batch_size", "sequence_length", self.shared_expert_intermediate_size]

        gate_matmul = self.make_matmul(shared_expert.gate_proj, f"{basename}/gate_proj/MatMul", root_input)
        up_matmul = self.make_matmul(shared_expert.up_proj, f"{basename}/up_proj/MatMul", root_input)

        # SiLU(gate) = gate * sigmoid(gate)
        silu_sigmoid_name = f"{basename}/gate_proj/Sigmoid"
        self.make_sigmoid(silu_sigmoid_name, f"{gate_matmul}/output_0", self.io_dtype,
                          shape=inner_shape)

        silu_mul_name = f"{basename}/gate_proj/Mul"
        self.make_mul(silu_mul_name,
                      [f"{gate_matmul}/output_0", f"{silu_sigmoid_name}/output_0"],
                      dtype=self.io_dtype,
                      shape=inner_shape)

        gate_up_mul_name = f"{basename}/Mul"
        self.make_mul(gate_up_mul_name,
                      [f"{silu_mul_name}/output_0", f"{up_matmul}/output_0"],
                      dtype=self.io_dtype,
                      shape=inner_shape)

        down_matmul = self.make_matmul(shared_expert.down_proj, f"{basename}/down_proj/MatMul",
                                       f"{gate_up_mul_name}/output_0")

        # Sigmoid gate computed from the same input; its last dimension is 1.
        gate_matmul_name = self.make_matmul(shared_expert_gate, f"{basename}_gate/MatMul", root_input)
        gate_sigmoid_name = f"{basename}_gate/Sigmoid"
        self.make_sigmoid(gate_sigmoid_name, f"{gate_matmul_name}/output_0", self.io_dtype,
                          shape=["batch_size", "sequence_length", 1])

        gated_mul_name = f"{basename}/GatedMul"
        self.make_mul(gated_mul_name,
                      [f"{down_matmul}/output_0", f"{gate_sigmoid_name}/output_0"],
                      dtype=self.io_dtype,
                      shape=["batch_size", "sequence_length", self.hidden_size])
        return f"{gated_mul_name}/output_0"