From e13f43e7d6dbdaad9b887038b3a1445961f7adb4 Mon Sep 17 00:00:00 2001 From: "ext.wangxiaochi1" Date: Thu, 26 Feb 2026 15:10:01 +0800 Subject: [PATCH 01/13] add qwen3-next decoder layer --- xllm/core/framework/state_dict/utils.cpp | 55 ++ xllm/core/framework/state_dict/utils.h | 10 + xllm/core/layers/CMakeLists.txt | 2 + xllm/core/layers/common/CMakeLists.txt | 10 + xllm/core/layers/common/linear.cpp | 25 + xllm/core/layers/common/linear.h | 5 + .../common/partial_rotary_embedding.cpp | 73 +++ .../layers/common/partial_rotary_embedding.h | 52 ++ .../layers/common/qwen3_next_attention.cpp | 188 ++++++ .../core/layers/common/qwen3_next_attention.h | 73 +++ .../common/qwen3_next_gated_delta_net.cpp | 558 ++++++++++++++++++ .../common/qwen3_next_gated_delta_net.h | 81 +++ .../layers/common/qwen3_next_rms_norm.cpp | 47 ++ xllm/core/layers/common/qwen3_next_rms_norm.h | 44 ++ xllm/core/layers/common/rms_norm_gated.cpp | 57 ++ xllm/core/layers/common/rms_norm_gated.h | 45 ++ xllm/core/layers/qwen3_next_decoder_layer.cpp | 139 +++++ xllm/core/layers/qwen3_next_decoder_layer.h | 61 ++ 18 files changed, 1525 insertions(+) create mode 100644 xllm/core/layers/common/partial_rotary_embedding.cpp create mode 100644 xllm/core/layers/common/partial_rotary_embedding.h create mode 100644 xllm/core/layers/common/qwen3_next_attention.cpp create mode 100644 xllm/core/layers/common/qwen3_next_attention.h create mode 100644 xllm/core/layers/common/qwen3_next_gated_delta_net.cpp create mode 100644 xllm/core/layers/common/qwen3_next_gated_delta_net.h create mode 100644 xllm/core/layers/common/qwen3_next_rms_norm.cpp create mode 100644 xllm/core/layers/common/qwen3_next_rms_norm.h create mode 100644 xllm/core/layers/common/rms_norm_gated.cpp create mode 100644 xllm/core/layers/common/rms_norm_gated.h create mode 100644 xllm/core/layers/qwen3_next_decoder_layer.cpp create mode 100644 xllm/core/layers/qwen3_next_decoder_layer.h diff --git a/xllm/core/framework/state_dict/utils.cpp 
b/xllm/core/framework/state_dict/utils.cpp
index 969c07175..482c34de4 100644
--- a/xllm/core/framework/state_dict/utils.cpp
+++ b/xllm/core/framework/state_dict/utils.cpp
@@ -311,6 +311,61 @@ void load_merged_weight(const StateDict& state_dict,
   weight_is_loaded = true;
 }
 
+void load_merged_weight_v2(const StateDict& state_dict,
+                           const std::string& name,
+                           int64_t dim,
+                           int32_t rank,
+                           int32_t world_size,
+                           int32_t shard_tensor_count,
+                           const std::vector<int64_t>& shard_sizes,
+                           torch::Tensor& weight,
+                           bool& weight_is_loaded) {
+  if (weight_is_loaded) {
+    return;
+  }
+  const auto& tensor = state_dict.get_tensor(name);
+  if (!tensor.defined()) {
+    return;
+  }
+
+  // Check that shard_tensor_count matches the size of shard_sizes vector
+  CHECK_EQ(shard_tensor_count, static_cast<int32_t>(shard_sizes.size()))
+      << "shard_tensor_count does not match shard_sizes vector size for "
+      << state_dict.prefix() << name;
+
+  // Calculate total expected size
+  int64_t total_expected_size = 0;
+  for (int64_t size : shard_sizes) {
+    total_expected_size += size * world_size;
+  }
+  CHECK_EQ(tensor.size(dim), total_expected_size)
+      << name << "[" << dim << "] size mismatch for "
+      << state_dict.prefix() << name;
+
+  std::vector<torch::Tensor> shard_tensors;
+
+  for (size_t shard_id = 0; shard_id < shard_tensor_count; shard_id++) {
+    int64_t shard_size = shard_sizes[shard_id];
+
+    // Calculate the offset for this shard and rank
+    // First, skip all the data from previous shards for all ranks
+    int64_t shard_offset = 0;
+    for (int32_t prev_shard = 0; prev_shard < shard_id; ++prev_shard) {
+      shard_offset += shard_sizes[prev_shard] * world_size;
+    }
+
+    // Then add the offset within this shard for the current rank
+    shard_offset += rank * shard_size;
+
+    shard_tensors.push_back(
+        tensor.slice(dim, shard_offset, shard_offset + shard_size));
+  }
+
+  auto merged_weight = torch::cat(shard_tensors, dim);
+  CHECK_EQ(weight.sizes(), merged_weight.sizes())
+      << "weight size mismatch for " << state_dict.prefix() << name;
+
+  weight.copy_(merged_weight);
+  weight_is_loaded = true;
+}
+
 } // namespace weight
 } // namespace xllm
diff --git a/xllm/core/framework/state_dict/utils.h b/xllm/core/framework/state_dict/utils.h
index 15f43d473..92e24bd89 100644
--- a/xllm/core/framework/state_dict/utils.h
+++ b/xllm/core/framework/state_dict/utils.h
@@ -115,6 +115,16 @@ void load_merged_weight(const StateDict& state_dict,
                         torch::Tensor& weight,
                         bool& weight_is_loaded);
 
+void load_merged_weight_v2(const StateDict& state_dict,
+                           const std::string& name,
+                           int64_t dim,
+                           int32_t rank,
+                           int32_t world_size,
+                           int32_t shard_tensor_count,
+                           const std::vector<int64_t>& shard_sizes,
+                           torch::Tensor& weight,
+                           bool& weight_is_loaded);
+
 } // namespace weight
 
 // helper macros for defining and loading weights
diff --git a/xllm/core/layers/CMakeLists.txt b/xllm/core/layers/CMakeLists.txt
index 3e4556a1e..250e7763d 100644
--- a/xllm/core/layers/CMakeLists.txt
+++ b/xllm/core/layers/CMakeLists.txt
@@ -43,12 +43,14 @@ cc_library(
     qwen3_vision_layer.h
     qwen3_decoder_layer.h
     qwen3_moe_decoder_layer.h
+    qwen3_next_decoder_layer.h
   SRCS
     qwen2_vision_layer.cpp
     qwen2_decoder_layer.cpp
     qwen2_5_vision_layer.cpp
     qwen3_vision_layer.cpp
     qwen3_moe_decoder_layer.cpp
+    qwen3_next_decoder_layer.cpp
   DEPS
    $<$:ilu_layers>
    :common_layers
diff --git a/xllm/core/layers/common/CMakeLists.txt b/xllm/core/layers/common/CMakeLists.txt
index a1c476600..5257b8216 100755
--- a/xllm/core/layers/common/CMakeLists.txt
+++ b/xllm/core/layers/common/CMakeLists.txt
@@ -6,6 +6,11 @@ cc_library(
   HDRS
     qwen2_attention.h
     qwen2_vision_attention.h
+    qwen3_next_attention.h
+    qwen3_next_gated_delta_net.h
+    qwen3_next_rms_norm.h
+    partial_rotary_embedding.h
+    rms_norm_gated.h
    rms_norm.h
    rotary_embedding.h
    rotary_embedding_util.h
@@ -23,6 +28,11 @@ cc_library(
   SRCS
     qwen2_attention.cpp
     qwen2_vision_attention.cpp
+    qwen3_next_attention.cpp
+    qwen3_next_gated_delta_net.cpp
+    qwen3_next_rms_norm.cpp
+    partial_rotary_embedding.cpp
+    rms_norm_gated.cpp
    rms_norm.cpp
    rotary_embedding.cpp
    rotary_embedding_util.cpp
diff --git a/xllm/core/layers/common/linear.cpp b/xllm/core/layers/common/linear.cpp
index 2f82fe2ec..18a26d912 100644
--- a/xllm/core/layers/common/linear.cpp
+++ b/xllm/core/layers/common/linear.cpp
@@ -552,6 +552,31 @@ std::optional ColumnParallelLinearImpl::get_input_scale() const {
   return std::nullopt;
 }
 
+// load_state_dict for merged weights with variable shard sizes
+void ColumnParallelLinearImpl::load_state_dict(
+    const StateDict& state_dict,
+    int32_t shard_tensor_count,
+    const std::vector<int64_t>& shard_sizes) {
+  const int64_t rank = rank_;
+  const int64_t world_size = world_size_;
+
+  // load and merge the weights on dim 0 with variable shard sizes
+  if (quant_args_.quant_method() == "smoothquant") {
+    // For smoothquant, load quantized weights with variable shard sizes
+    LOAD_MERGED_WEIGHT_V2(qweight, 0);
+    LOAD_MERGED_WEIGHT_V2(per_channel_scale, 0);
+  } else {
+    // For regular weights, use the new merged weight loading with variable shard sizes
+    LOAD_MERGED_WEIGHT_V2(weight, 0);
+  }
+
+  if (bias_.defined()) {
+    // For bias, we might need to handle it differently based on the use case
+    // For now, we'll use the same approach if bias is also sharded
+    LOAD_MERGED_WEIGHT_V2(bias, 0);
+  }
+}
+
 QKVParallelLinearImpl::QKVParallelLinearImpl(
     int64_t hidden_size,
     int64_t num_heads,
diff --git a/xllm/core/layers/common/linear.h b/xllm/core/layers/common/linear.h
index 04409f5cb..9d1344ebf 100644
--- a/xllm/core/layers/common/linear.h
+++ b/xllm/core/layers/common/linear.h
@@ -71,6 +71,11 @@ class ColumnParallelLinearImpl : public torch::nn::Module {
   void load_state_dict(const StateDict& state_dict,
                        const std::vector<std::string>& prefixes);
 
+  // load_state_dict for merged weights with variable shard sizes
+  void load_state_dict(const StateDict& state_dict,
+                       int32_t shard_tensor_count,
+                       const std::vector<int64_t>& shard_sizes);
+
   void pretty_print(std::ostream& stream) const {
     stream << name() << " " << weight_.sizes() << " " <<
weight_.device(); } diff --git a/xllm/core/layers/common/partial_rotary_embedding.cpp b/xllm/core/layers/common/partial_rotary_embedding.cpp new file mode 100644 index 000000000..e44d12c24 --- /dev/null +++ b/xllm/core/layers/common/partial_rotary_embedding.cpp @@ -0,0 +1,73 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "partial_rotary_embedding.h" + +#include "kernels/ops_api.h" +#include "platform/device.h" +namespace xllm { +namespace layer { + +PartialRotaryEmbeddingImpl::PartialRotaryEmbeddingImpl(int64_t rotary_dim, + int64_t max_position_embeddings, + int64_t rope_theta, + int64_t head_size, + bool is_neox_style, + bool interleaved, + const torch::TensorOptions& options) + : head_size_(head_size), + rotary_dim_(rotary_dim), + is_neox_style_(is_neox_style), + interleaved_(interleaved) { + auto dev_options = torch::TensorOptions().device(Device::type_torch()); + + auto inv_freq_t = torch::arange(/*start=*/0, + /*end=*/rotary_dim_, + /*step=*/2, + torch::TensorOptions().dtype(torch::kFloat)); + inv_freq_t = inv_freq_t.to(dev_options); + auto inv_freq = + 1.0 / + torch::pow(rope_theta, inv_freq_t / static_cast(rotary_dim_)); + + auto t = torch::arange(0, max_position_embeddings, 1, torch::kFloat32); + t = t.to(dev_options); + + const auto freqs = torch::einsum("i,j->ij", {t, inv_freq}); + const auto cos_sin = + 
torch::cat({freqs.cos(), freqs.sin()}, /*dim=*/-1).contiguous(); + cos_sin_cache_ = register_buffer("cos_sin_cache", cos_sin.to(options)); +} + +void PartialRotaryEmbeddingImpl::forward(const torch::Tensor& positions, + torch::Tensor& q, + torch::Tensor& k) { + xllm::kernel::PartialRotaryEmbedding partial_rotary_params; + partial_rotary_params.positions = positions; + partial_rotary_params.query = q; + partial_rotary_params.key = k; + partial_rotary_params.head_size = head_size_; + partial_rotary_params.rotary_dim = rotary_dim_; + partial_rotary_params.cos_sin_cache = cos_sin_cache_; + partial_rotary_params.is_neox_style = is_neox_style_; + auto [q_rot, k_rot] = xllm::kernel::partial_rotary_embedding(partial_rotary_params); + + q = q_rot; + k = k_rot; +} + +} // namespace layer +} // namespace xllm + diff --git a/xllm/core/layers/common/partial_rotary_embedding.h b/xllm/core/layers/common/partial_rotary_embedding.h new file mode 100644 index 000000000..2851888c5 --- /dev/null +++ b/xllm/core/layers/common/partial_rotary_embedding.h @@ -0,0 +1,52 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include + +#include + +namespace xllm { +namespace layer { +class PartialRotaryEmbeddingImpl : public torch::nn::Module { + public: + PartialRotaryEmbeddingImpl(int64_t rotary_dim, + int64_t max_position_embeddings, + int64_t rope_theta, + int64_t head_size, + bool is_neox_style, + bool interleaved, + const torch::TensorOptions& options); + + void forward(const torch::Tensor& positions, + torch::Tensor& q, + torch::Tensor& k); + + torch::Tensor get_cos_sin_cache() { return cos_sin_cache_; } + + private: + int64_t head_size_; + int64_t rotary_dim_; + bool is_neox_style_; + bool interleaved_; + torch::Tensor cos_sin_cache_; +}; +TORCH_MODULE(PartialRotaryEmbedding); + +} // namespace layer +} // namespace xllm + diff --git a/xllm/core/layers/common/qwen3_next_attention.cpp b/xllm/core/layers/common/qwen3_next_attention.cpp new file mode 100644 index 000000000..f78983286 --- /dev/null +++ b/xllm/core/layers/common/qwen3_next_attention.cpp @@ -0,0 +1,188 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "qwen3_next_attention.h" +#include +#include +#include +namespace xllm { +namespace layer { + +Qwen3NextAttentionImpl::Qwen3NextAttentionImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options, + int32_t layer_id) { + const int64_t tp_size = parallel_args.tp_group_->world_size(); + const int64_t total_num_heads = args.n_heads(); + const int64_t total_num_kv_heads = args.n_kv_heads().value_or(args.n_heads()); + layer_id_ = layer_id; + rank_ = parallel_args.tp_group_->rank(); + CHECK(total_num_heads % tp_size == 0); + num_heads_ = total_num_heads / tp_size; + + if (total_num_kv_heads >= tp_size) { + CHECK(total_num_kv_heads % tp_size == 0); + num_kv_heads_ = total_num_kv_heads / tp_size; + num_kv_head_replicas_ = 1; + } else { + CHECK(tp_size % total_num_kv_heads == 0); + num_kv_heads_ = 1; + num_kv_head_replicas_ = tp_size / total_num_kv_heads; + } + + head_dim_ = args.head_dim(); + q_size_ = num_heads_ * head_dim_; + kv_size_ = num_kv_heads_ * head_dim_; + scaling_ = 1.0f / std::sqrt(static_cast(head_dim_)); + attn_output_gate_ = args.attn_output_gate(); + // 1. QKV linear + qkv_proj_ = register_module("qkv_proj", + QKVParallelLinear(args.hidden_size(), + attn_output_gate_ ? num_heads_ * 2 : num_heads_, + num_kv_heads_, + args.head_dim(), + num_kv_head_replicas_, + /*bias=*/args.attention_bias(), + /*gather_output=*/false, + parallel_args, + options)); + + // 2. O proj + o_proj_ = register_module("o_proj", + RowParallelLinear(total_num_heads * head_dim_, + args.hidden_size(), + /*bias=*/false, + /*input_is_parallelized=*/true, + /*if_reduce_results=*/true, + quant_args, + parallel_args, + options)); + + // 3. Q norm + q_norm_ = register_module("q_norm", + Qwen3NextRMSNorm(head_dim_, args.rms_norm_eps(), options)); + + // 4. 
K norm + k_norm_ = register_module("k_norm", + Qwen3NextRMSNorm(head_dim_, args.rms_norm_eps(), options)); + + // 5. Rotary embedding + const int rotary_dim = static_cast(head_dim_ * args.partial_rotary_factor()); + rotary_emb_ = register_module( + "rotary_emb", + PartialRotaryEmbedding(rotary_dim, + args.max_position_embeddings(), + args.rope_theta(), + head_dim_, + true, + false, + options)); + + // 6. Attention + attn_ = register_module("attn", + Attention(num_heads_, + head_dim_, + scaling_, + num_kv_heads_, + args.sliding_window())); +} + +torch::Tensor Qwen3NextAttentionImpl::forward( + const torch::Tensor& positions, + const torch::Tensor& hidden_states, + const AttentionMetadata& attn_metadata, + KVCache& kv_cache) { + // 1. qkv projection + auto qkv = qkv_proj_->forward(hidden_states); + torch::Tensor q, k, v; + torch::Tensor gate; + + if (attn_output_gate_) { + // Split qkv for attn_output_gate case: [q_size*2, kv_size, kv_size] + auto q_gate = qkv.slice(/*dim=*/-1, 0, q_size_ * 2); + k = qkv.slice(/*dim=*/-1, q_size_ * 2, q_size_ * 2 + kv_size_); + v = qkv.slice(/*dim=*/-1, q_size_ * 2 + kv_size_, q_size_ * 2 + kv_size_ * 2); + v = v.contiguous(); + + std::vector orig_shape; + int64_t q_gate_dim = q_gate.dim(); + orig_shape = std::vector( + q_gate.sizes().slice(0, q_gate_dim - 1).begin(), + q_gate.sizes().slice(0, q_gate_dim - 1).end() + ); + + std::vector new_shape = orig_shape; + new_shape.push_back(num_heads_); + int64_t orig_total = 1; + for (auto d : orig_shape) orig_total *= d; + int64_t last_dim = q_gate.numel() / (orig_total * num_heads_); + new_shape.push_back(last_dim); + + torch::Tensor q_gate_reshaped = q_gate.reshape(new_shape); + + auto chunks = torch::chunk(q_gate_reshaped, 2, /*dim=*/-1); + q = chunks[0]; + gate = chunks[1]; + + std::vector q_new_shape = orig_shape; + q_new_shape.push_back(q.numel() / orig_total); + q = q.reshape(q_new_shape); + + std::vector gate_new_shape = orig_shape; + gate_new_shape.push_back(gate.numel() / 
orig_total); + gate = gate.reshape(gate_new_shape); + } else { + // Normal case: [q_size, kv_size, kv_size] + q = qkv.slice(/*dim=*/-1, 0, q_size_); + k = qkv.slice(/*dim=*/-1, q_size_, q_size_ + kv_size_); + v = qkv.slice(/*dim=*/-1, q_size_ + kv_size_, q_size_ + 2 * kv_size_); + } + + const int64_t T = q.size(0); + + auto q_reshaped = q.reshape({T, num_heads_, head_dim_}); + auto q_normed = q_norm_->forward(q_reshaped); + auto k_reshaped = k.reshape({T, num_kv_heads_, head_dim_}); + auto k_normed = k_norm_->forward(k_reshaped); + + q = q_normed.view({T, q_size_}); + k = k_normed.view({T, kv_size_}); + + rotary_emb_->forward(positions,q,k); + auto out = std::get<0>(attn_->forward(attn_metadata, q, k, v, kv_cache)); + + if (attn_output_gate_) { + gate = torch::sigmoid(gate); + out = out * gate; + } + + out = o_proj_->forward(out); + return out; +} + +void Qwen3NextAttentionImpl::load_state_dict(const StateDict& state_dict) { + qkv_proj_->load_state_dict(state_dict); + o_proj_->load_state_dict(state_dict.get_dict_with_prefix("o_proj.")); + if (auto w = state_dict.get_tensor("q_norm.weight"); w.defined()) { + q_norm_->load_state_dict(StateDict({{"weight", w}})); + } + if (auto w = state_dict.get_tensor("k_norm.weight"); w.defined()) { + k_norm_->load_state_dict(StateDict({{"weight", w}})); + } +} + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/common/qwen3_next_attention.h b/xllm/core/layers/common/qwen3_next_attention.h new file mode 100644 index 000000000..a7059216b --- /dev/null +++ b/xllm/core/layers/common/qwen3_next_attention.h @@ -0,0 +1,73 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include "attention.h" +#include "framework/kv_cache/kv_cache.h" +#include "framework/model/model_args.h" +#include "framework/parallel_state/parallel_args.h" +#include "framework/quant_args.h" +#include "framework/state_dict/state_dict.h" +#include "linear.h" +#include "qwen3_next_rms_norm.h" +#include "partial_rotary_embedding.h" + +namespace xllm { +namespace layer { + +class Qwen3NextAttentionImpl : public torch::nn::Module { + public: + Qwen3NextAttentionImpl() = default; + Qwen3NextAttentionImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options, + int32_t layer_id); + + torch::Tensor forward(const torch::Tensor& positions, + const torch::Tensor& hidden_states, + const AttentionMetadata& attn_metadata, + KVCache& kv_cache); + + void load_state_dict(const StateDict& state_dict); + + private: + int64_t num_heads_; + int64_t num_kv_heads_; + int64_t num_kv_head_replicas_; + int64_t head_dim_; + int64_t q_size_; + int64_t kv_size_; + float scaling_; + bool attn_output_gate_; + int32_t layer_id_; + int32_t rank_; + + QKVParallelLinear qkv_proj_{nullptr}; + RowParallelLinear o_proj_{nullptr}; + + Qwen3NextRMSNorm q_norm_{nullptr}; + Qwen3NextRMSNorm k_norm_{nullptr}; + + Attention attn_{nullptr}; + PartialRotaryEmbedding rotary_emb_{nullptr}; +}; +TORCH_MODULE(Qwen3NextAttention); + +} // namespace layer +} // namespace xllm diff --git 
a/xllm/core/layers/common/qwen3_next_gated_delta_net.cpp b/xllm/core/layers/common/qwen3_next_gated_delta_net.cpp new file mode 100644 index 000000000..b208c44d0 --- /dev/null +++ b/xllm/core/layers/common/qwen3_next_gated_delta_net.cpp @@ -0,0 +1,558 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + https://github.com/jd-opensource/xllm/blob/main/LICENSE +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "qwen3_next_gated_delta_net.h" + +#include +#include "xllm/core/kernels/ops_api.h" + +#include +#include + +#include +#include +#include + +namespace xllm { +namespace layer { + + +namespace { +torch::Tensor l2norm(const torch::Tensor& x, int64_t dim, double eps = 1e-6) { + auto norm = torch::sqrt(torch::sum(torch::square(x), dim, true) + eps); + return x / norm; +} + +std::tuple torch_recurrent_gated_delta_rule( + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor g, + torch::Tensor beta, + std::optional initial_state, + bool output_final_state = true, + bool use_qk_l2norm_in_kernel = true +) { + auto initial_dtype = query.dtype(); + + if (use_qk_l2norm_in_kernel) { + query = l2norm(query, -1, 1e-6); + key = l2norm(key, -1, 1e-6); + } + + auto to_float32_and_transpose = [](torch::Tensor x) { + return x.transpose(1, 2).contiguous().to(torch::kFloat32); + }; + query = to_float32_and_transpose(query); + key = to_float32_and_transpose(key); + value = to_float32_and_transpose(value); + beta = 
to_float32_and_transpose(beta); + g = to_float32_and_transpose(g); + + int64_t batch_size = key.size(0); + int64_t num_heads = key.size(1); + int64_t sequence_length = key.size(2); + int64_t k_head_dim = key.size(3); + int64_t v_head_dim = value.size(3); + + float scale_val = 1.0 / std::sqrt(static_cast(query.size(-1))); + torch::Tensor scale = torch::tensor(scale_val, query.options()); + query = query * scale; + torch::Tensor core_attn_out = torch::zeros({batch_size, num_heads, sequence_length, v_head_dim}, + torch::TensorOptions().dtype(torch::kFloat32).device(value.device())); + torch::Tensor last_recurrent_state; + if (!initial_state.has_value()) { + last_recurrent_state = torch::zeros({batch_size, num_heads, k_head_dim, v_head_dim}, + torch::TensorOptions().dtype(torch::kFloat32).device(value.device())); + } else { + last_recurrent_state = initial_state.value().to(value.device(), torch::kFloat32); + } + + for (int64_t i = 0; i < sequence_length; ++i) { + torch::Tensor q_t = query.select(2, i); + torch::Tensor k_t = key.select(2, i); + torch::Tensor v_t = value.select(2, i); + torch::Tensor g_t = g.select(2, i).exp().unsqueeze(-1).unsqueeze(-1); + torch::Tensor beta_t = beta.select(2, i).unsqueeze(-1); + last_recurrent_state = last_recurrent_state * g_t; + torch::Tensor kv_mem = torch::sum(last_recurrent_state * k_t.unsqueeze(-1), -2); + torch::Tensor delta = (v_t - kv_mem) * beta_t; + last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2); + core_attn_out.select(2, i) = torch::sum(last_recurrent_state * q_t.unsqueeze(-1), -2); + } + + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype); + return std::make_tuple(core_attn_out, last_recurrent_state); +} + +std::tuple torch_chunk_gated_delta_rule( + torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor g, + torch::Tensor beta, + int64_t chunk_size = 64, + c10::optional initial_state = c10::nullopt, + bool output_final_state = true, 
+ bool use_qk_l2norm_in_kernel = true) { + auto initial_dtype = query.dtype(); + if (use_qk_l2norm_in_kernel) { + query = l2norm(query, -1, 1e-6); + key = l2norm(key, -1, 1e-6); + } + auto to_float32 = [](torch::Tensor x) { + return x.transpose(1, 2).contiguous().to(torch::kFloat32); + }; + + query = to_float32(query); + key = to_float32(key); + value = to_float32(value); + beta = to_float32(beta); + g = to_float32(g); + + auto batch_size = query.size(0); + auto num_heads = query.size(1); + auto sequence_length = query.size(2); + auto k_head_dim = key.size(-1); + auto v_head_dim = value.size(-1); + + int64_t pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size; + query = torch::nn::functional::pad(query, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); + key = torch::nn::functional::pad(key, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); + value = torch::nn::functional::pad(value, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); + beta = torch::nn::functional::pad(beta, torch::nn::functional::PadFuncOptions({0, pad_size})); + g = torch::nn::functional::pad(g, torch::nn::functional::PadFuncOptions({0, pad_size})); + + int64_t total_sequence_length = sequence_length + pad_size; + float scale = 1.0 / std::sqrt(static_cast(query.size(-1))); + query = query * scale; + auto v_beta = value * beta.unsqueeze(-1); + auto k_beta = key * beta.unsqueeze(-1); + auto reshape_to_chunks = [chunk_size](torch::Tensor x) { + auto shape = x.sizes(); + std::vector new_shape = { + shape[0], shape[1], + shape[2] / chunk_size, chunk_size, + shape[3] + }; + return x.reshape(new_shape); + }; + + query = reshape_to_chunks(query); + key = reshape_to_chunks(key); + value = reshape_to_chunks(value); + k_beta = reshape_to_chunks(k_beta); + v_beta = reshape_to_chunks(v_beta); + + auto g_shape = g.sizes(); + std::vector g_new_shape = { + g_shape[0], g_shape[1], + g_shape[2] / chunk_size, chunk_size + }; + g = g.reshape(g_new_shape); + auto mask = 
torch::triu( + torch::ones({chunk_size, chunk_size}, torch::TensorOptions().dtype(torch::kBool).device(query.device())), + 0 + ); + + g = g.cumsum(-1); + auto g_diff = g.unsqueeze(-1) - g.unsqueeze(-2); + auto decay_mask = g_diff.tril().exp().to(torch::kFloat32); + decay_mask = decay_mask.tril(); + auto attn = -(torch::matmul(k_beta, key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0.0); + for (int64_t i = 1; i < chunk_size; ++i) { + if (!attn.is_contiguous()) { + attn = attn.contiguous(); + } + auto row = attn.slice(-2, i, i+1).slice(-1, 0, i).squeeze(-2).clone().contiguous(); + auto sub = attn.slice(-2, 0, i).slice(-1, 0, i).clone().contiguous(); + auto row_unsq = row.unsqueeze(-1).contiguous(); + auto row_sub_mul = (row_unsq * sub).contiguous(); + auto row_sub_sum = row_sub_mul.sum(-2).contiguous(); + auto row_final = (row + row_sub_sum).contiguous(); + attn.index_put_( + { + torch::indexing::Ellipsis, + torch::indexing::Slice(i, i+1), + torch::indexing::Slice(0, i) + }, + row_final.unsqueeze(-2) + ); + } + + attn = attn + torch::eye(chunk_size, torch::TensorOptions().dtype(attn.dtype()).device(attn.device())); + value = torch::matmul(attn, v_beta); + auto k_cumdecay = torch::matmul(attn, (k_beta * g.exp().unsqueeze(-1))); + torch::Tensor last_recurrent_state; + if (!initial_state.has_value()) { + last_recurrent_state = torch::zeros( + {batch_size, num_heads, k_head_dim, v_head_dim}, + torch::TensorOptions().dtype(value.dtype()).device(value.device()) + ); + } else { + last_recurrent_state = initial_state.value().to(value); + } + auto core_attn_out = torch::zeros_like(value); + mask = torch::triu( + torch::ones({chunk_size, chunk_size}, torch::TensorOptions().dtype(torch::kBool).device(query.device())), + 1 + ); + int64_t num_chunks = total_sequence_length / chunk_size; + for (int64_t i = 0; i < num_chunks; ++i) { + auto q_i = query.select(2, i); + auto k_i = key.select(2, i); + auto v_i = value.select(2, i); + auto attn_i = (torch::matmul(q_i, 
k_i.transpose(-1, -2)) * decay_mask.select(2, i)).masked_fill_(mask, 0.0); + auto v_prime = torch::matmul(k_cumdecay.select(2, i), last_recurrent_state); + auto v_new = v_i - v_prime; + auto attn_inter = torch::matmul(q_i * g.select(2, i).unsqueeze(-1).exp(), last_recurrent_state); + core_attn_out.select(2, i) = attn_inter + torch::matmul(attn_i, v_new); + auto g_i_last = g.select(2, i).select(-1, -1).unsqueeze(-1); + auto g_exp_term = (g_i_last - g.select(2, i)).exp().unsqueeze(-1); + auto k_g_exp = (k_i * g_exp_term).transpose(-1, -2).contiguous(); + last_recurrent_state = + last_recurrent_state * g_i_last.unsqueeze(-1).exp() + + torch::matmul(k_g_exp, v_new); + } + auto core_attn_out_shape = core_attn_out.sizes(); + std::vector reshape_shape = { + core_attn_out_shape[0], core_attn_out_shape[1], + core_attn_out_shape[2] * core_attn_out_shape[3], + core_attn_out_shape[4] + }; + core_attn_out = core_attn_out.reshape(reshape_shape); + core_attn_out = core_attn_out.slice(2, 0, sequence_length); + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype); + return std::make_tuple(core_attn_out, last_recurrent_state); +} +} // namespace + +Qwen3NextGatedDeltaNetImpl::Qwen3NextGatedDeltaNetImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { + const int64_t total_num_heads = args.n_heads(); + const int64_t total_num_kv_heads = args.n_kv_heads().value_or(args.n_heads()); + tp_size_ = parallel_args.tp_group_->world_size(); + rank_ = parallel_args.tp_group_->rank(); + num_k_heads_ = args.linear_num_key_heads(); + num_v_heads_ = args.linear_num_value_heads(); + head_k_dim_ = args.linear_key_head_dim(); + head_v_dim_ = args.linear_value_head_dim(); + k_size_ = num_k_heads_ * head_k_dim_; + v_size_ = num_v_heads_ * head_v_dim_; + conv_kernel_size_ = args.linear_conv_kernel_dim(); + // 0. 
QKVZ parallel linear + conv1d_ = register_module("conv1d", + ColumnParallelLinear(args.linear_conv_kernel_dim(), + k_size_ * 2 + v_size_, + /*bias=*/false, + /*gather_output=*/false, + quant_args, + parallel_args, + options)); + + + // 1. QKVZ parallel linear + qkvz_proj_ = register_module("in_proj_qkvz", + ColumnParallelLinear(args.hidden_size(), + k_size_ * 2 + v_size_ * 2, + /*bias=*/false, + /*gather_output=*/false, + quant_args, + parallel_args, + options)); + // 2. Output projection + ba_proj_ = register_module("in_proj_ba", + ColumnParallelLinear(args.hidden_size(), + num_v_heads_ * 2, + /*bias=*/false, + /*gather_output=*/false, + quant_args, + parallel_args, + options)); + auto opts = options.dtype(torch::kFloat32); + dt_bias_ = register_parameter("dt_bias", torch::ones({num_v_heads_ / tp_size_}, opts), /*requires_grad=*/false); + + A_log_ = register_parameter("A_log", torch::empty({num_v_heads_ / tp_size_}, opts), /*requires_grad=*/false); + // 3. Output projection + o_proj_ = register_module("out_proj", + RowParallelLinear(v_size_, + args.hidden_size(), + /*bias=*/false, + /*input_is_parallelized=*/true, + /*if_reduce_results=*/true, + quant_args, + parallel_args, + options)); + + // 4. RMSNorm + norm_ = register_module("norm", RmsNormGated(head_v_dim_, args.rms_norm_eps(), options)); + +} + +torch::Tensor Qwen3NextGatedDeltaNetImpl::forward( + const torch::Tensor& hidden_states, + const AttentionMetadata& attn_metadata, + KVCache& kv_cache, + const ModelInputParams& input_params) { + + auto qkvz = qkvz_proj_->forward(hidden_states); + auto qkvz_reshaped = reshape_qkvz_with_pad(attn_metadata, qkvz); + auto [q, k, v, z] = process_qkvz_tensor(qkvz_reshaped); + auto ba = ba_proj_->forward(hidden_states); + auto ba_reshaped = reshape_qkvz_with_pad(attn_metadata, ba); + auto [b, a] = process_ba_tensor(ba_reshaped); + auto rearrange_merge = [](const torch::Tensor& t) { + TORCH_CHECK(t.dim() > 2, "Tensor must have at least 2 dims! 
but got ", t.dim()); + std::vector new_shape; + int64_t slice_end = t.dim() - 2; + auto valid_slice = t.sizes().slice(0, slice_end); + new_shape = std::vector(valid_slice.begin(), valid_slice.end()); + int64_t last_two_dim = t.size(slice_end) * t.size(slice_end + 1); + new_shape.push_back(last_two_dim); + return t.reshape(new_shape); + }; + + q = rearrange_merge(q); + k = rearrange_merge(k); + v = rearrange_merge(v); + + torch::Tensor mixed_qkv = torch::cat({q, k, v}, q.dim() - 1); + mixed_qkv = mixed_qkv.transpose(1,2); + int64_t seq_len = mixed_qkv.size(2); + torch::Tensor conv_cache = kv_cache.get_conv_cache(); + torch::Tensor ssm_cache = kv_cache.get_ssm_cache(); + torch::Tensor g, beta, core_attn_out, last_recurrent_state; + auto device = mixed_qkv.device(); + auto conv_weight = conv1d_->weight(); + + if (attn_metadata.is_prefill) { + torch::Tensor conv_state = (seq_len < conv_kernel_size_-1) ? torch::pad(mixed_qkv, {0, conv_kernel_size_-1-seq_len}) : (seq_len > conv_kernel_size_-1) ? 
mixed_qkv.narrow(-1, seq_len-conv_kernel_size_+1, conv_kernel_size_-1): mixed_qkv; + conv_cache.index_put_({input_params.block_tables.select(1,0)}, conv_state.to(conv_cache.dtype())); + torch::Tensor bias; + auto conv_output = torch::conv1d( + mixed_qkv, + conv_weight.unsqueeze(1).to(device), + bias, + /*stride=*/std::vector{1}, + /*padding=*/std::vector{3}, + /*dilation=*/std::vector{1}, + /*groups=*/static_cast(mixed_qkv.size(1)) + ); + mixed_qkv = torch::silu(conv_output.slice(2,0,seq_len)); + + } else { + xllm::kernel::CausalConv1dUpdateParams params; + params.x = mixed_qkv; + params.conv_state = conv_cache; + params.weight = conv_weight; + params.conv_state_indices = attn_metadata.block_table.select(1,0); + mixed_qkv = xllm::kernel::causal_conv1d_update(params); + } + + if (attn_metadata.is_prefill) { + beta = torch::sigmoid(b); + torch::Tensor A_log_exp = A_log_.exp(); + torch::Tensor a_float = a.to(torch::kFloat32); + torch::Tensor a_plus_dt = a_float + dt_bias_; + torch::Tensor softplus_out = torch::nn::functional::softplus( + a_plus_dt, + torch::nn::functional::SoftplusFuncOptions().beta(1.0).threshold(20.0) + ); + g = -A_log_exp * softplus_out; + g = g.to(a.dtype()).contiguous(); + } else { + xllm::kernel::FusedGdnGatingParams gdn_params; + gdn_params.A_log = A_log_; + gdn_params.a = a.view({-1, a.size(-1)}); + gdn_params.b = b.view({-1, b.size(-1)}); + gdn_params.dt_bias = dt_bias_; + gdn_params.beta = 1.0f; + gdn_params.threshold = 20.0f; + std::tie(g, beta) = xllm::kernel::fused_gdn_gating(gdn_params); + } + + auto [processed_q, processed_k, processed_v] = process_mixed_qkv(mixed_qkv); + int64_t repeat_times = num_v_heads_ / num_k_heads_; + if (repeat_times > 1) { + processed_q = processed_q.repeat_interleave(repeat_times, 2); + processed_k = processed_k.repeat_interleave(repeat_times, 2); + } + if (attn_metadata.is_prefill) { + std::tie(core_attn_out, last_recurrent_state) = torch_chunk_gated_delta_rule(processed_q, processed_k, processed_v, g, beta); 
+ ssm_cache.index_put_({input_params.block_tables.select(1,0)}, last_recurrent_state.to(ssm_cache.dtype())); + } else { + auto ssm_state = torch::index_select(ssm_cache, 0, attn_metadata.block_table.select(1,0)); + std::tie(core_attn_out, last_recurrent_state) = torch_recurrent_gated_delta_rule(processed_q, processed_k, processed_v, g, beta, ssm_state); + ssm_cache.index_put_({attn_metadata.block_table.select(1,0)}, last_recurrent_state.to(ssm_cache.dtype())); + } + + auto z_reshaped = z.view({-1, z.size(-1)}); + auto core_attn_out_reshaped = core_attn_out.view({-1, core_attn_out.size(-1)}); + auto norm_out = norm_->forward(core_attn_out_reshaped, z_reshaped); + auto z_shape_og = z.sizes().vec(); + norm_out = norm_out.view(z_shape_og); + norm_out = norm_out.view({-1, norm_out.size(2), norm_out.size(3)}); + + auto rearranged_norm = rearrange_merge(norm_out); + rearranged_norm = reshape_qkvz_unpad(attn_metadata, rearranged_norm); + auto attn_output = o_proj_->forward(rearranged_norm); + return attn_output; +} + +torch::Tensor Qwen3NextGatedDeltaNetImpl::reshape_qkvz_unpad(const AttentionMetadata& attn_metadata, const torch::Tensor& padded_qkvz) { + if (!attn_metadata.is_prefill) { + return padded_qkvz; + } + std::vector valid_batches; + int64_t bs = attn_metadata.query_start_loc.size(0); + int64_t max_len = attn_metadata.max_query_len; + const auto& ori_seq_lens = attn_metadata.query_start_loc; + auto reshaped_qkvz = padded_qkvz.view({bs, max_len, -1}); + for (int64_t b = 0; b < bs; ++b) { + int64_t ori_len = ori_seq_lens[b].item(); + torch::Tensor valid_batch = reshaped_qkvz[b].slice(0, 0, ori_len); + valid_batches.push_back(valid_batch); + } + return torch::cat(valid_batches, 0).contiguous(); +} + +torch::Tensor Qwen3NextGatedDeltaNetImpl::reshape_qkvz_with_pad(const AttentionMetadata& attn_metadata, const torch::Tensor& qkvz) { + int64_t bs = attn_metadata.query_start_loc.size(0); + int64_t max_len = attn_metadata.max_query_len; + const auto& start_loc = 
attn_metadata.query_start_loc; + if (!attn_metadata.is_prefill) { + return qkvz.view({bs, -1, qkvz.size(-1)}); + } + std::vector batches; + int64_t idx = 0; + for (int64_t b = 0; b < bs; ++b) { + int64_t cur_len = start_loc[b].item(); + torch::Tensor batch = qkvz.slice(0, idx, idx + cur_len).contiguous(); + idx = idx + cur_len; + if (batch.size(0) != max_len) { + batch = batch.size(0) > max_len + ? batch.slice(0, 0, max_len).contiguous() + : torch::nn::functional::pad( + batch, + torch::nn::functional::PadFuncOptions({0, 0, 0, max_len - batch.size(0)}) + ).contiguous(); + } + batches.push_back(batch); + } + auto ret = torch::stack(batches, 0).contiguous(); + return ret; +} + + +std::tuple +Qwen3NextGatedDeltaNetImpl::process_mixed_qkv(torch::Tensor& mixed_qkv) { + mixed_qkv = mixed_qkv.transpose(1,2); + int64_t batch_size = mixed_qkv.size(0); + int64_t seq_len = mixed_qkv.size(1); + std::vector split_sizes = {k_size_ / tp_size_, k_size_ / tp_size_, v_size_ / tp_size_}; + auto processed_qkv = torch::split(mixed_qkv, split_sizes, 2); + auto processed_q = processed_qkv[0]; + auto processed_k = processed_qkv[1]; + auto processed_v = processed_qkv[2]; + processed_q = processed_q.view({batch_size, seq_len, num_k_heads_ / tp_size_, head_k_dim_}); + processed_k = processed_k.view({batch_size, seq_len, num_k_heads_ / tp_size_, head_k_dim_}); + processed_v = processed_v.view({batch_size, seq_len, num_v_heads_ / tp_size_, head_v_dim_}); + return std::make_tuple(processed_q, processed_k, processed_v); +} + +std::tuple +Qwen3NextGatedDeltaNetImpl::process_qkvz_tensor(const torch::Tensor& qkvz) { + + std::vector new_tensor_shape_qkvz = [&]() { + std::vector dims; + dims.push_back(qkvz.size(0)); + if (qkvz.dim() >= 3) { + dims.push_back(qkvz.size(1)); + } + + int64_t dim1 = num_k_heads_ / tp_size_; + int64_t dim2 = head_k_dim_ + head_k_dim_ + (head_v_dim_ + head_v_dim_) * num_v_heads_ / num_k_heads_; + dims.push_back(dim1); + dims.push_back(dim2); + + return dims; + }(); + + auto 
reshaped_qkvz = qkvz.view(new_tensor_shape_qkvz); + auto qkvz_split = torch::split(reshaped_qkvz, + {head_k_dim_, head_k_dim_, + num_v_heads_ * head_v_dim_ / num_k_heads_, + num_v_heads_ * head_v_dim_ / num_k_heads_}, reshaped_qkvz.dim()-1); + + auto q = qkvz_split[0].contiguous(); + auto k = qkvz_split[1].contiguous(); + auto v = qkvz_split[2].contiguous(); + auto z = qkvz_split[3].contiguous(); + + + v = v.view({v.size(0), v.size(1), -1, head_v_dim_}); + z = z.view({z.size(0), z.size(1), -1, head_v_dim_}); + + return std::make_tuple(q, k, v, z); +} + + +std::tuple +Qwen3NextGatedDeltaNetImpl::process_ba_tensor(const torch::Tensor& ba) { + + std::vector new_tensor_shape_ba = [&]() { + std::vector dims; + dims.push_back(ba.size(0)); + dims.push_back(ba.size(1)); + int64_t dim1 = num_k_heads_ / tp_size_; + int64_t dim2 = 2 * num_v_heads_ / num_k_heads_; + dims.push_back(dim1); + dims.push_back(dim2); + return dims; + }(); + + auto reshaped_ba = ba.view(new_tensor_shape_ba); + auto ba_split = torch::split(reshaped_ba, + {num_v_heads_ / num_k_heads_, num_v_heads_ / num_k_heads_}, reshaped_ba.dim()-1); + + auto b = ba_split[0].contiguous(); + auto a = ba_split[1].contiguous(); + + b = b.reshape({b.size(0), b.size(1), num_v_heads_ / tp_size_}); + a = a.reshape({a.size(0), a.size(1), num_v_heads_ / tp_size_}); + + return std::make_tuple(b, a); +} + +void Qwen3NextGatedDeltaNetImpl::load_state_dict(const StateDict& state_dict) { + const int64_t rank = rank_; + const int64_t world_size = tp_size_; + const int32_t shard_tensor_count = 3; + const std::vector shard_sizes = {k_size_ / tp_size_, k_size_ / tp_size_, v_size_ / tp_size_}; + qkvz_proj_->load_state_dict(state_dict.get_dict_with_prefix("in_proj_qkvz.")); + ba_proj_->load_state_dict(state_dict.get_dict_with_prefix("in_proj_ba.")); + + if (auto w = state_dict.get_tensor("conv1d.weight"); w.defined()) { + conv1d_->load_state_dict(StateDict({{"weight", w.squeeze(1)}}), shard_tensor_count, shard_sizes); + } + 
o_proj_->load_state_dict(state_dict.get_dict_with_prefix("out_proj.")); + if (auto w = state_dict.get_tensor("norm.weight"); w.defined()) { + norm_->load_state_dict(StateDict({{"weight", w}})); + } + LOAD_SHARDED_WEIGHT(dt_bias, 0); + LOAD_SHARDED_WEIGHT(A_log, 0); +} + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/common/qwen3_next_gated_delta_net.h b/xllm/core/layers/common/qwen3_next_gated_delta_net.h new file mode 100644 index 000000000..8f4120af0 --- /dev/null +++ b/xllm/core/layers/common/qwen3_next_gated_delta_net.h @@ -0,0 +1,81 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include + +#include "attention.h" +#include "linear.h" +#include "rms_norm_gated.h" +#include "framework/kv_cache/kv_cache.h" +#include "framework/model/model_args.h" +#include "framework/parallel_state/parallel_args.h" +#include "framework/quant_args.h" +#include "framework/state_dict/state_dict.h" +#include "framework/state_dict/utils.h" + +namespace xllm { +namespace layer { + +class Qwen3NextGatedDeltaNetImpl : public torch::nn::Module { + public: + Qwen3NextGatedDeltaNetImpl() = default; + Qwen3NextGatedDeltaNetImpl(const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options); + + torch::Tensor forward(const torch::Tensor& hidden_states, + const AttentionMetadata& attn_metadata, + KVCache& kv_cache, + const ModelInputParams& input_params); + + void load_state_dict(const StateDict& state_dict); + + private: + std::tuple + process_qkvz_tensor(const torch::Tensor& qkvz); + std::tuple process_ba_tensor(const torch::Tensor& ba); + std::tuple process_mixed_qkv(torch::Tensor& mixed_qkv); + + torch::Tensor reshape_qkvz_with_pad(const AttentionMetadata& attn_metadata, const torch::Tensor& qkvz); + torch::Tensor reshape_qkvz_unpad(const AttentionMetadata& attn_metadata, const torch::Tensor& padded_qkvz); + + int64_t num_k_heads_; + int64_t num_v_heads_; + int64_t num_kv_head_replicas_; + int64_t head_k_dim_; + int64_t head_v_dim_; + int64_t k_size_; + int64_t v_size_; + int64_t tp_size_; + int64_t rank_; + int32_t conv_kernel_size_; + + ColumnParallelLinear qkvz_proj_{nullptr}; + ColumnParallelLinear ba_proj_{nullptr}; + ColumnParallelLinear conv1d_{nullptr}; + + RowParallelLinear o_proj_{nullptr}; + + RmsNormGated norm_{nullptr}; + DEFINE_WEIGHT(dt_bias); + DEFINE_WEIGHT(A_log); +}; +TORCH_MODULE(Qwen3NextGatedDeltaNet); + +} // namespace layer +} // namespace xllm diff --git 
a/xllm/core/layers/common/qwen3_next_rms_norm.cpp b/xllm/core/layers/common/qwen3_next_rms_norm.cpp new file mode 100644 index 000000000..6e3df11e9 --- /dev/null +++ b/xllm/core/layers/common/qwen3_next_rms_norm.cpp @@ -0,0 +1,47 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "qwen3_next_rms_norm.h" + +#include + +namespace xllm { +namespace layer { + +Qwen3NextRMSNormImpl::Qwen3NextRMSNormImpl(int64_t dim, + double eps, + const torch::TensorOptions& options) + : norm_dim_(dim), eps_(eps) { + weight_ = register_parameter("weight", torch::empty({dim}, options),false); +} + +torch::Tensor Qwen3NextRMSNormImpl::forward(torch::Tensor& input) { + auto input_dtype = input.dtype(); + input = input.to(torch::kFloat32); + + // Calculate RMS + auto variance = torch::mean(torch::pow(input, 2), -1, true); + auto normalized = input * torch::rsqrt(variance + eps_); + + // Apply weight and convert back to original dtype + return (normalized * (1.0f + weight_.to(torch::kFloat32))).to(input_dtype); +} + +void Qwen3NextRMSNormImpl::load_state_dict(const StateDict& state_dict) { + LOAD_WEIGHT(weight); +} + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/common/qwen3_next_rms_norm.h b/xllm/core/layers/common/qwen3_next_rms_norm.h new file mode 100644 index 000000000..724847dfc --- /dev/null +++ 
b/xllm/core/layers/common/qwen3_next_rms_norm.h @@ -0,0 +1,44 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include "framework/state_dict/state_dict.h" +#include "framework/state_dict/utils.h" + +namespace xllm { +namespace layer { + +class Qwen3NextRMSNormImpl : public torch::nn::Module { + public: + Qwen3NextRMSNormImpl(int64_t dim, + double eps, + const torch::TensorOptions& options); + + torch::Tensor forward(torch::Tensor& input); + + void load_state_dict(const StateDict& state_dict); + + private: + DEFINE_WEIGHT(weight); + int64_t norm_dim_; + double eps_; +}; +TORCH_MODULE(Qwen3NextRMSNorm); + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/common/rms_norm_gated.cpp b/xllm/core/layers/common/rms_norm_gated.cpp new file mode 100644 index 000000000..c8682f92d --- /dev/null +++ b/xllm/core/layers/common/rms_norm_gated.cpp @@ -0,0 +1,57 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "rms_norm_gated.h" + +#include + +#include "framework/state_dict/utils.h" +#include "xllm/core/kernels/ops_api.h" + +namespace xllm { +namespace layer { + +RmsNormGatedImpl::RmsNormGatedImpl(int64_t dim, + double eps, + const torch::TensorOptions& options) + : norm_dim_(dim), eps_(eps) { + weight_ = register_parameter("weight", torch::empty({dim}, options), /*requires_grad=*/false); +} + +torch::Tensor RmsNormGatedImpl::forward(torch::Tensor& input, std::optional gate) { + xllm::kernel::GatedLayerNormParams params; + auto input_type = input.dtype(); + input = input.to(torch::kFloat32); + params.x = input; + params.weight = weight_.to(torch::kFloat32); + torch::Tensor bias; + params.bias = bias; + params.eps = eps_; + if (gate.has_value()) { + gate = gate.value().to(torch::kFloat32); + params.z = gate; + } + params.group_size = input.size(-1); + params.is_rms_norm = true; + auto ret = xllm::kernel::gated_layer_norm(params); + return ret.to(input_type); +} + +void RmsNormGatedImpl::load_state_dict(const StateDict& state_dict) { + LOAD_WEIGHT(weight); +} + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/common/rms_norm_gated.h b/xllm/core/layers/common/rms_norm_gated.h new file mode 100644 index 000000000..fe1aed06b --- /dev/null +++ b/xllm/core/layers/common/rms_norm_gated.h @@ -0,0 +1,45 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include "framework/state_dict/state_dict.h" +#include "framework/state_dict/utils.h" + +namespace xllm { +namespace layer { + +class RmsNormGatedImpl : public torch::nn::Module { + public: + RmsNormGatedImpl(int64_t dim, + double eps, + const torch::TensorOptions& options); + + torch::Tensor forward(torch::Tensor& input, + std::optional gate = std::nullopt); + + void load_state_dict(const StateDict& state_dict); + + private: + DEFINE_WEIGHT(weight); + int64_t norm_dim_; + double eps_; +}; +TORCH_MODULE(RmsNormGated); + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/qwen3_next_decoder_layer.cpp b/xllm/core/layers/qwen3_next_decoder_layer.cpp new file mode 100644 index 000000000..8135217cb --- /dev/null +++ b/xllm/core/layers/qwen3_next_decoder_layer.cpp @@ -0,0 +1,139 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "qwen3_next_decoder_layer.h" + +#include + +namespace xllm { +namespace layer { + +Qwen3NextDecoderLayerImpl::Qwen3NextDecoderLayerImpl(const ModelContext& context, + int32_t layer_id) { + const auto& model_args = context.get_model_args(); + const auto& quant_args = context.get_quant_args(); + const auto& parallel_args = context.get_parallel_args(); + const auto& options = context.get_tensor_options(); + // Initialize attention layers + if ((layer_id + 1) % 4 == 0) { + attention_ = register_module( + "self_attn", + Qwen3NextAttention(model_args, quant_args, parallel_args, options, layer_id)); + } else { + linear_attention_ = register_module( + "linear_attn", + Qwen3NextGatedDeltaNet(model_args, quant_args, parallel_args, options)); + } + + // Initialize norm layers + input_norm_ = register_module( + "input_layernorm", + Qwen3NextRMSNorm(model_args.hidden_size(), model_args.rms_norm_eps(), options)); + + post_norm_ = register_module( + "post_attention_layernorm", + Qwen3NextRMSNorm(model_args.hidden_size(), model_args.rms_norm_eps(), options)); + + // Initialize mlp + auto mlp_only_layers = model_args.mlp_only_layers(); + if ((std::count(mlp_only_layers.begin(), mlp_only_layers.end(), layer_id) == + 0) && + model_args.num_experts() > 0 && + (layer_id + 1) % model_args.decoder_sparse_step() == 0) { + moe_mlp_ = register_module("mlp", + FusedMoE(model_args, + FusedMoEArgs{.is_gated = true}, + quant_args, + parallel_args, + options)); + } else { + mlp_ = register_module("mlp", + DenseMLP(model_args.hidden_size(), + model_args.intermediate_size(), + false, + false, + model_args.hidden_act(), + quant_args, + parallel_args, + options)); + } +} + +void Qwen3NextDecoderLayerImpl::load_state_dict(const StateDict& state_dict) { + if (attention_) { +
attention_->load_state_dict(state_dict.get_dict_with_prefix("self_attn.")); + } else { + linear_attention_->load_state_dict(state_dict.get_dict_with_prefix("linear_attn.")); + } + input_norm_->load_state_dict( + state_dict.get_dict_with_prefix("input_layernorm.")); + post_norm_->load_state_dict( + state_dict.get_dict_with_prefix("post_attention_layernorm.")); + if (moe_mlp_) { + moe_mlp_->load_state_dict(state_dict.get_dict_with_prefix("mlp.")); + } else { + mlp_->load_state_dict(state_dict.get_dict_with_prefix("mlp.")); + } +} + +torch::Tensor Qwen3NextDecoderLayerImpl::forward( + torch::Tensor& x, + torch::Tensor& positions, + const AttentionMetadata& attn_metadata, + KVCache& kv_cache, + const ModelInputParams& input_params) { + // Pre-attention norm + torch::Tensor residual = x; + x = input_norm_(x); + + // Attention + if (attention_) { + x = attention_->forward(positions, x, attn_metadata, kv_cache); + } else { + //x = x; + x = linear_attention_->forward(x, attn_metadata, kv_cache, input_params); + } + + auto orig_dtype = x.dtype(); + if (orig_dtype == torch::kBFloat16) { + x = x.to(torch::kFloat32); + residual = residual.to(torch::kFloat32); + } + x = x + residual; + + // Post-attention norm + residual = x; + x = x.to(orig_dtype); + x = post_norm_(x); + + // MLP forward + if (moe_mlp_) { + x = moe_mlp_(x, input_params); + } else { + x = mlp_(x); + } + + orig_dtype = x.dtype(); + if (orig_dtype == torch::kBFloat16) { + x = x.to(torch::kFloat32); + residual = residual.to(torch::kFloat32); + } + x = x + residual; + x = x.to(orig_dtype); + return x; +} + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/qwen3_next_decoder_layer.h b/xllm/core/layers/qwen3_next_decoder_layer.h new file mode 100644 index 000000000..8dc855a2d --- /dev/null +++ b/xllm/core/layers/qwen3_next_decoder_layer.h @@ -0,0 +1,61 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include + +#include "common/dense_mlp.h" +#include "common/qwen3_next_rms_norm.h" +#include "common/qwen3_next_attention.h" +#include "common/qwen3_next_gated_delta_net.h" +#include "layers/npu/fused_moe.h" +#include "framework/kv_cache/kv_cache.h" +#include "framework/model/model_args.h" +#include "framework/model/model_input_params.h" +#include "framework/model_context.h" +#include "framework/state_dict/state_dict.h" + + +namespace xllm { +namespace layer { + +class Qwen3NextDecoderLayerImpl : public torch::nn::Module { + public: + explicit Qwen3NextDecoderLayerImpl(const ModelContext& context, int32_t layer_id); + + void load_state_dict(const StateDict& state_dict); + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& positions, + const AttentionMetadata& attn_metadata, + KVCache& kv_cache, + const ModelInputParams& input_params); + + private: + Qwen3NextAttention attention_{nullptr}; + Qwen3NextGatedDeltaNet linear_attention_{nullptr}; + + DenseMLP mlp_{nullptr}; + FusedMoE moe_mlp_{nullptr}; + + Qwen3NextRMSNorm input_norm_{nullptr}; + Qwen3NextRMSNorm post_norm_{nullptr}; +}; + +} // namespace layer +} // namespace xllm From 7d7787a1437d274f9e681a93f178478bbc433932 Mon Sep 17 00:00:00 2001 From: "ext.wangxiaochi1" Date: Thu, 26 Feb 2026 19:34:52 +0800 Subject: [PATCH 02/13] support linear attention 
cache --- third_party/torch_npu_ops | 2 +- .../core/distributed_runtime/comm_channel.cpp | 12 +- xllm/core/distributed_runtime/engine.h | 1 + xllm/core/distributed_runtime/llm_engine.cpp | 57 ++- xllm/core/distributed_runtime/llm_engine.h | 2 + xllm/core/framework/kv_cache/kv_cache.cpp | 15 + xllm/core/framework/kv_cache/kv_cache.h | 13 + xllm/core/framework/model/model_args.h | 11 + xllm/core/runtime/worker_impl.cpp | 13 +- xllm/models/llm/qwen3_next.h | 349 ++++++++++++++++++ xllm/models/models.h | 1 + xllm/proto/worker.proto | 2 + 12 files changed, 465 insertions(+), 13 deletions(-) create mode 100644 xllm/models/llm/qwen3_next.h diff --git a/third_party/torch_npu_ops b/third_party/torch_npu_ops index 90773524d..e7d254285 160000 --- a/third_party/torch_npu_ops +++ b/third_party/torch_npu_ops @@ -1 +1 @@ -Subproject commit 90773524d2d69220fc80f7845b4570eabfccfd0e +Subproject commit e7d254285a3e491abb7ba14e723be0d2909df3a5 diff --git a/xllm/core/distributed_runtime/comm_channel.cpp b/xllm/core/distributed_runtime/comm_channel.cpp index 56f613193..b209d28ee 100644 --- a/xllm/core/distributed_runtime/comm_channel.cpp +++ b/xllm/core/distributed_runtime/comm_channel.cpp @@ -87,13 +87,23 @@ bool CommChannel::allocate_kv_cache( } // add index shape if exists - if (kv_cache_shape.size() > 2) { + if (kv_cache_shape.size() == 3) { shape->mutable_index_shape()->Reserve(kv_cache_shape[2].size()); for (size_t i = 0; i < kv_cache_shape[2].size(); ++i) { shape->add_index_shape(kv_cache_shape[2][i]); } } + if (kv_cache_shape.size() == 4) { + shape->mutable_conv_shape()->Reserve(kv_cache_shape[2].size()); + shape->mutable_ssm_shape()->Reserve(kv_cache_shape[3].size()); + for (size_t i = 0; i < kv_cache_shape[2].size(); ++i) { + shape->add_conv_shape(kv_cache_shape[2][i]); + } + for (size_t i = 0; i < kv_cache_shape[3].size(); ++i) { + shape->add_ssm_shape(kv_cache_shape[3][i]); + } + } proto::Status s; brpc::Controller cntl; stub_->AllocateKVCache(&cntl, &request, &s, nullptr); diff
--git a/xllm/core/distributed_runtime/engine.h b/xllm/core/distributed_runtime/engine.h index 3b8c0db63..206212cc9 100644 --- a/xllm/core/distributed_runtime/engine.h +++ b/xllm/core/distributed_runtime/engine.h @@ -180,6 +180,7 @@ class Engine { int64_t cache_size_in_bytes = 0; int64_t slot_size = 0; int64_t index_slot_size = 0; + int64_t linear_slot_size = 0; int64_t n_layers = 0; }; diff --git a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp index 59d5d1732..8726dfa79 100644 --- a/xllm/core/distributed_runtime/llm_engine.cpp +++ b/xllm/core/distributed_runtime/llm_engine.cpp @@ -193,7 +193,12 @@ bool LLMEngine::init_model(MasterStatus master_status) { n_local_q_heads_ = std::max(1, n_heads / world_size); head_dim_ = args_.head_dim(); dtype_ = util::parse_dtype(args_.dtype(), options_.devices()[0]); - + if (args_.full_attention_interval() > 1) { + const int64_t linear_n_k_heads = args_.linear_num_key_heads(); + const int64_t linear_n_v_heads = args_.linear_num_value_heads(); + n_local_linear_k_heads_ = std::max(1, linear_n_k_heads / world_size); + n_local_linear_v_heads_ = std::max(1, linear_n_v_heads / world_size); + } // key + value for all layers LOG(INFO) << "Block info, block_size: " << options_.block_size() << ", n_local_kv_heads: " << n_local_kv_heads_ @@ -383,7 +388,7 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() { int64_t index_slot_size = 0; int64_t scale_slot_size = 0; // Extra overhead for scale tensors in quant mode - + int64_t linear_slot_size = 0; if (FLAGS_enable_mla) { #if defined(USE_NPU) if (args_.model_type() == "deepseek_v3" && FLAGS_enable_prefix_cache) { @@ -425,8 +430,16 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() { scale_slot_size = 2 * sizeof(float) * n_local_kv_heads_; } } + if (args_.linear_num_value_heads() > 0) { + int64_t head_k_dim = args_.linear_key_head_dim(); + int64_t head_v_dim = args_.linear_value_head_dim(); + int64_t
linear_ssm_slot_size = dtype_size * n_local_linear_v_heads_ * head_k_dim * head_v_dim; + int64_t linear_conv_slot_size = dtype_size * (head_k_dim * n_local_linear_k_heads_ * 2 + head_v_dim * n_local_linear_v_heads_) * (args_.linear_conv_kernel_dim() -1 ); + linear_slot_size = linear_ssm_slot_size + linear_conv_slot_size; + } kv_cache_cap.slot_size = slot_size; kv_cache_cap.index_slot_size = index_slot_size; + kv_cache_cap.linear_slot_size = linear_slot_size; kv_cache_cap.n_layers = args_.n_layers(); #if !defined(USE_NPU) // this adoption is because the allocation of kv cache is based on @@ -441,13 +454,25 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() { } #endif - // compute kv cache n_blocks - const int32_t block_size = options_.block_size(); - const int64_t block_size_in_bytes = - block_size * (slot_size + index_slot_size + scale_slot_size); - kv_cache_cap.n_blocks = kv_cache_cap.cache_size_in_bytes / - (kv_cache_cap.n_layers * block_size_in_bytes); - CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache"; + if (!FLAGS_enable_continuous_kvcache) { + // compute kv cache n_blocks + const int32_t block_size = options_.block_size(); + const int64_t block_size_in_bytes = + block_size * (slot_size + index_slot_size + scale_slot_size) + linear_slot_size; + kv_cache_cap.n_blocks = kv_cache_cap.cache_size_in_bytes / + (kv_cache_cap.n_layers * block_size_in_bytes); + CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache"; + } else { + int32_t n_pages = + kv_cache_cap.cache_size_in_bytes / FLAGS_phy_page_granularity_size; + if (FLAGS_enable_mla) { + n_pages -= n_pages % (kv_cache_cap.n_layers); + } else { + n_pages -= n_pages % (2 * kv_cache_cap.n_layers); + } + kv_cache_cap.n_pages = n_pages; + CHECK_GT(kv_cache_cap.n_pages, 0) << "no n_pages for kv cache"; + } return kv_cache_cap; } @@ -462,6 +487,7 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) { CHECK_GT(kv_cache_cap.n_blocks, 0) << "no memory 
for kv cache"; const int32_t block_size = options_.block_size(); bool enable_lighting_indexer = args_.index_n_heads() > 1; + bool enable_linear_attention = args_.full_attention_interval() > 1; // init kv cache for each worker std::vector> kv_cache_shape; @@ -501,6 +527,15 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) { kv_cache_shape.emplace_back(std::vector{ kv_cache_cap.n_blocks, block_size, 1, args_.index_head_dim()}); } + if (enable_linear_attention) { + kv_cache_shape.emplace_back(std::vector{ + kv_cache_cap.n_blocks, + args_.linear_key_head_dim() * n_local_linear_k_heads_ * 2 + + args_.linear_value_head_dim() * n_local_linear_v_heads_, args_.linear_conv_kernel_dim() - 1}); + kv_cache_shape.emplace_back(std::vector{ + kv_cache_cap.n_blocks, n_local_linear_v_heads_, args_.linear_key_head_dim(), + args_.linear_value_head_dim()}); + } #if defined(USE_MLU) // transpose kv_cache layout for mlu // default layout: [n_blocks, block_size, n_head, head_dim] @@ -525,6 +560,10 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) { LOG(INFO) << "Initializing indexer cache with shape: [" << kv_cache_shape[2] << "]"; } + if (enable_linear_attention) { + LOG(INFO) << "Initializing conv cache with shape: [" << kv_cache_shape[2] << "]"; + LOG(INFO) << "Initializing ssm cache with shape: [" << kv_cache_shape[3] << "]"; + } // initialize block manager BlockManagerPool::Options options; diff --git a/xllm/core/distributed_runtime/llm_engine.h b/xllm/core/distributed_runtime/llm_engine.h index 2cb0e1e9d..2e155a20b 100644 --- a/xllm/core/distributed_runtime/llm_engine.h +++ b/xllm/core/distributed_runtime/llm_engine.h @@ -161,6 +161,8 @@ class LLMEngine : public Engine { int64_t n_local_kv_heads_ = 0; int64_t n_local_q_heads_ = 0; int64_t head_dim_ = 0; + int64_t n_local_linear_v_heads_ = 0; + int64_t n_local_linear_k_heads_ = 0; // common frequently used args uint32_t dp_size_; diff --git
a/xllm/core/framework/kv_cache/kv_cache.cpp b/xllm/core/framework/kv_cache/kv_cache.cpp index b3c427a6a..f2af38ead 100644 --- a/xllm/core/framework/kv_cache/kv_cache.cpp +++ b/xllm/core/framework/kv_cache/kv_cache.cpp @@ -37,9 +37,24 @@ KVCache::KVCache(torch::Tensor key_cache, index_cache_(std::move(index_cache)), key_cache_scale_(std::move(key_cache_scale)), value_cache_scale_(std::move(value_cache_scale)) {} + +KVCache::KVCache(torch::Tensor key_cache, + torch::Tensor value_cache, + torch::Tensor conv_cache, + torch::Tensor ssm_cache) + : key_cache_(std::move(key_cache)), + value_cache_(std::move(value_cache)), + conv_cache_(std::move(conv_cache)), + ssm_cache_(std::move(ssm_cache)) {} + +KVCache::KVCache(std::shared_ptr key_xtensor, + std::shared_ptr value_xtensor) + : key_xtensor_(key_xtensor), value_xtensor_(value_xtensor) {} torch::Tensor KVCache::get_k_cache() const { return key_cache_; } torch::Tensor KVCache::get_v_cache() const { return value_cache_; } torch::Tensor KVCache::get_index_cache() const { return index_cache_; } +torch::Tensor KVCache::get_conv_cache() const { return conv_cache_; } +torch::Tensor KVCache::get_ssm_cache() const { return ssm_cache_; } std::optional KVCache::get_k_cache_scale() const { if (!key_cache_scale_.defined() || key_cache_scale_.numel() == 0) { diff --git a/xllm/core/framework/kv_cache/kv_cache.h b/xllm/core/framework/kv_cache/kv_cache.h index 1e8e53153..5a407c8ee 100644 --- a/xllm/core/framework/kv_cache/kv_cache.h +++ b/xllm/core/framework/kv_cache/kv_cache.h @@ -37,6 +37,12 @@ class KVCache final { torch::Tensor index_cache, torch::Tensor key_cache_scale, torch::Tensor value_cache_scale); + KVCache(torch::Tensor key_cache, + torch::Tensor value_cache, + torch::Tensor conv_cache, + torch::Tensor ssm_cache); + KVCache(std::shared_ptr key_xtensor, + std::shared_ptr value_xtensor); ~KVCache() = default; // TODO: pass in kv_shape and options instead @@ -48,6 +54,8 @@ class KVCache final { std::optional get_k_cache_scale() 
const; std::optional get_v_cache_scale() const; + torch::Tensor get_conv_cache() const; + torch::Tensor get_ssm_cache() const; std::vector> get_shapes(); bool empty() const { @@ -64,6 +72,11 @@ class KVCache final { // scale tensors for quantized KV cache (int8) torch::Tensor key_cache_scale_; torch::Tensor value_cache_scale_; + torch::Tensor conv_cache_; + torch::Tensor ssm_cache_; + // for continuous kvcache + std::shared_ptr key_xtensor_; + std::shared_ptr value_xtensor_; }; } // namespace xllm diff --git a/xllm/core/framework/model/model_args.h b/xllm/core/framework/model/model_args.h index b8c4178c3..31672cd98 100644 --- a/xllm/core/framework/model/model_args.h +++ b/xllm/core/framework/model/model_args.h @@ -171,6 +171,17 @@ struct ModelArgs { PROPERTY(int32_t, rope_scaling) = -1; PROPERTY(float, router_aux_loss_coef) = 0.001f; + // qwen3 next + PROPERTY(bool, attn_output_gate) = true; + PROPERTY(int32_t, full_attention_interval) = 4; + PROPERTY(int32_t, linear_conv_kernel_dim) = 4; + PROPERTY(int32_t, linear_key_head_dim) = 128; + PROPERTY(int32_t, linear_value_head_dim) = 128; + PROPERTY(int64_t, linear_num_key_heads) = 16; + PROPERTY(int32_t, linear_num_value_heads) = 32; + PROPERTY(int32_t, shared_expert_intermediate_size) = 512; + PROPERTY(float, partial_rotary_factor) = 0.25f; + // Vision model's dropout PROPERTY(float, mm_dropout) = 0.0f; diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index 47a1d4dc5..f48160672 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -155,6 +155,7 @@ bool WorkerImpl::allocate_kv_cache( const std::vector>& kv_cache_shape) { CHECK(model_ != nullptr) << "Model is not initialized."; CHECK(kv_caches_.empty()) << "KV caches are already initialized."; + const bool enable_linear_attention = context_.get_model_args().full_attention_interval() > 1; // Check if KV cache quantization is enabled // "auto" (default): cache dtype aligns with model dtype (no 
quantization) @@ -208,7 +209,7 @@ bool WorkerImpl::allocate_kv_cache( torch::ScalarType cache_dtype = enable_kv_cache_quant ? torch::kInt8 : dtype_; for (int64_t i = 0; i < num_layers; ++i) { - torch::Tensor key_cache, value_cache, index_cache; + torch::Tensor key_cache, value_cache, index_cache, conv_cache, ssm_cache; torch::Tensor key_cache_scale, value_cache_scale; #if defined(USE_NPU) aclFormat npu_format_type = @@ -230,6 +231,14 @@ bool WorkerImpl::allocate_kv_cache( torch::dtype(dtype_).device(device_)), npu_format_type); } + if (enable_linear_attention) { + conv_cache = at_npu::native::npu_format_cast( + torch::zeros(kv_cache_shape[2], torch::dtype(dtype_).device(device_)), + 2); + ssm_cache = at_npu::native::npu_format_cast( + torch::zeros(kv_cache_shape[3], torch::dtype(dtype_).device(device_)), + 2); + } #elif defined(USE_ILU) || defined(USE_MLU) || defined(USE_MUSA) key_cache = torch::zeros(kv_cache_shape[0], torch::dtype(cache_dtype).device(device_)); @@ -272,7 +281,7 @@ bool WorkerImpl::allocate_kv_cache( key_cache_scale, value_cache_scale); } else { - kv_caches_.emplace_back(key_cache, value_cache, index_cache); + kv_caches_.emplace_back(key_cache, value_cache, index_cache, conv_cache, ssm_cache); } } } diff --git a/xllm/models/llm/qwen3_next.h b/xllm/models/llm/qwen3_next.h new file mode 100644 index 000000000..72951fbe6 --- /dev/null +++ b/xllm/models/llm/qwen3_next.h @@ -0,0 +1,349 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include "core/framework/model_context.h" +#include "core/layers/common/layer_utils.h" +#include "core/layers/qwen3_next_decoder_layer.h" +#include "llm_model_base.h" + +namespace xllm { + +using torch::indexing::None; +using ISlice = torch::indexing::Slice; + +class Qwen3NextDecoderLayerImpl : public torch::nn::Module { + public: + Qwen3NextDecoderLayerImpl(const ModelContext& context, const int32_t i) { + // register submodules + decoder_layer_ = register_module("decoder_layer", + layer::Qwen3NextDecoderLayer(context, i)); + } + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& positions, + const layer::AttentionMetadata& attn_metadata, + KVCache& kv_cache, + const ModelInputParams& input_params) { + return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params); + } + + void load_state_dict(const StateDict& state_dict) { + decoder_layer_->load_state_dict(state_dict); + } + + private: + layer::Qwen3NextDecoderLayer decoder_layer_{nullptr}; +}; +TORCH_MODULE(Qwen3NextDecoderLayer); + +class Qwen3NextModelImpl : public torch::nn::Module { + public: + Qwen3NextModelImpl(const ModelContext& context) + : device_(context.get_tensor_options().device()) { + auto options = context.get_tensor_options(); + auto model_args = context.get_model_args(); + auto parallel_args = context.get_parallel_args(); + + blocks_ = register_module("layers", torch::nn::ModuleList()); + layers_.reserve(model_args.n_layers()); + // register submodules + device_ = options.device(); + dtype_ = options.dtype().toScalarType(); + num_speculative_tokens_ = model_args.num_speculative_tokens(); + +#if defined(USE_NPU) + norm_ = register_module( + "norm", + xllm::layer::Qwen3NextRMSNorm( + model_args.hidden_size(), model_args.rms_norm_eps(), options)); 
+#if defined(USE_NPU_TORCH) + embed_tokens_ = layer::WordEmbedding(model_args.vocab_size(), + model_args.hidden_size(), + context.get_parallel_args(), + options); +#else + for (auto i = 0; i < FLAGS_micro_batch_num; i++) { + npu_embed_tokens_.push_back(layer::NpuWordEmbedding(context)); + } +#endif +#endif + int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1; + attn_mask_ = layer::AttentionMask(options.device(), + options.dtype().toScalarType(), + /*mask_value=*/mask_value); + for (int32_t i = 0; i < model_args.n_layers(); ++i) { + auto block = Qwen3NextDecoderLayer(context, i); + layers_.push_back(block); + blocks_->push_back(block); + } + + dp_size_ = parallel_args.dp_size(); + std::vector indices; + dp_local_tp_size_ = parallel_args.world_size() / dp_size_; + dp_rank_ = parallel_args.rank() / dp_local_tp_size_; + rank_ = parallel_args.rank(); + num_experts_per_tok_ = model_args.num_experts_per_tok(); + for (int i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) { + indices.push_back(i); + } + } + + // tokens: [num_tokens] + // positions: [num_tokens] token pos in the sequence + torch::Tensor forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + if (dp_size_ > 1) { + if (tokens.sizes() == 0) { + tokens = torch::tensor({1}).to(torch::kInt32).to(device_); + positions = torch::tensor({0}).to(torch::kInt32).to(device_); + } + } + +#if defined(USE_NPU) && defined(USE_NPU_TORCH) + // Create attention mask + torch::Tensor attn_mask; + max_seq_len_ = std::max(input_params.kv_max_seq_len, max_seq_len_); + + if (FLAGS_enable_chunked_prefill) { + int num_sequences = input_params.num_sequences; + if (num_sequences > 0) { + std::vector req_mask_vec; + req_mask_vec.reserve(num_sequences); + + for (int j = 0; j < num_sequences; j++) { + auto mask = attn_mask_.gen_append_mask( + input_params.q_seq_lens_vec[j], + input_params.kv_seq_lens_vec[j], + max_seq_len_, + dtype_, + device_); + 
req_mask_vec.emplace_back(mask); + } + attn_mask = torch::cat(req_mask_vec, 0); + } + } else { + attn_mask = attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_); + } + + layer::AttentionMetadata attn_metadata = + layer::AttentionMetadata::build(input_params, input_params.q_max_seq_len > 1, attn_mask); + torch::Tensor h = embed_tokens_(tokens); + for (size_t i = 0; i < layers_.size(); i++) { + auto& layer = layers_[i]; + h = layer(h, positions, attn_metadata, kv_caches[i], input_params); + } + h = norm_(h); + return h; +#endif + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + embed_tokens_->load_state_dict( + state_dict.get_dict_with_prefix("embed_tokens.")); + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->load_state_dict( + state_dict.get_dict_with_prefix("layers." + std::to_string(i) + ".")); + } + norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); + } + + std::vector get_word_embedding() { + return {npu_embed_tokens_}; + } + + void set_word_embedding( + std::vector& word_embedding) { + npu_embed_tokens_ = word_embedding; + } + + + private: + torch::nn::ModuleList blocks_{nullptr}; + std::vector layers_; + int32_t max_seq_len_ = 0; + int32_t dp_rank_; + int32_t rank_; + int32_t dp_size_; + int32_t dp_local_tp_size_; + nlohmann::json mapping_data_; + int32_t num_experts_per_tok_; + int32_t num_speculative_tokens_ = 0; + at::Device device_; + torch::Dtype dtype_; + std::vector npu_embed_tokens_; + layer::Qwen3NextRMSNorm norm_{nullptr}; + layer::AttentionMask attn_mask_; + +#if defined(USE_NPU) && defined(USE_NPU_TORCH) + layer::WordEmbedding embed_tokens_{nullptr}; +#endif +}; + +TORCH_MODULE(Qwen3NextModel); + +class Qwen3NextForCausalLMImpl : public torch::nn::Module { + public: + Qwen3NextForCausalLMImpl(const ModelContext& context) { + model_ = register_module("model", Qwen3NextModel(context)); +#if defined(USE_NPU) && defined(USE_NPU_TORCH) + lm_head_ = + 
register_module("lm_head", + layer::LmHead(context.get_model_args().hidden_size(), + context.get_model_args().vocab_size(), + /*bias=*/false, + /*gather_output=*/true, + QuantArgs{}, + context.get_parallel_args(), + context.get_tensor_options())); +#else + npu_lm_head_ = register_module("lm_head", layer::NpuLmHead(context)); +#endif + } + + // tokens: [num_tokens] + // positions: [num_tokens] token pos in the sequence + // returns: [num_tokens, hidden_size] + torch::Tensor forward(const std::vector& tokens, + const std::vector& positions, + std::vector& kv_caches, + const std::vector& input_params) { + return model_(tokens[0], positions[0], kv_caches, input_params[0]); + } + + // hidden_states: [num_tokens, hidden_size] + // seleted_idxes: [num_tokens] + // returns: [num_tokens, vocab_size] + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { +#if defined(USE_NPU) && defined(USE_NPU_TORCH) + auto h = hidden_states; + if (seleted_idxes.defined()) { + h = h.index_select(/*dim=*/0, seleted_idxes); + } + return lm_head_(h); +#endif + } + + void load_model(std::unique_ptr loader) { + for (const auto& state_dict : loader->get_state_dicts()) { + model_->load_state_dict(state_dict->get_dict_with_prefix("model.")); +#if defined(USE_NPU_TORCH) + lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); +#else + npu_lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); +#endif + } + +#if defined(USE_NPU) && !defined(USE_NPU_TORCH) + // verify + model_->verify_loaded_weights("model."); + npu_lm_head_->verify_loaded_weights("lm_head."); + + model_->merge_loaded_weights(); + npu_lm_head_->merge_loaded_weights(); +#endif + } + + virtual void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + virtual void update_expert_weight(int32_t layer_id) { return; } + +#if defined(USE_NPU) + + layer::NpuLmHead get_lm_head() { return npu_lm_head_; } + + void 
set_lm_head(layer::NpuLmHead& head) { npu_lm_head_ = head; } + + std::vector get_word_embedding() { + return model_->get_word_embedding(); + } + + void set_word_embedding( + std::vector& word_embedding) { + model_->set_word_embedding(word_embedding); + } +#endif + + private: + layer::NpuLmHead npu_lm_head_{nullptr}; + layer::LmHead lm_head_{nullptr}; + Qwen3NextModel model_{nullptr}; + +}; +TORCH_MODULE(Qwen3NextForCausalLM); + +// register the causal model +REGISTER_CAUSAL_MODEL(qwen3_next, Qwen3NextForCausalLM); + +// register the model args +REGISTER_MODEL_ARGS(qwen3_next, [&] { + LOAD_ARG_OR(model_type, "model_type", "qwen3_next"); + LOAD_ARG_OR(dtype, "torch_dtype", ""); + LOAD_ARG_OR(attention_bias, "attention_bias", false); + LOAD_ARG_OR(attention_dropout, "attention_dropout", 0.0f); + LOAD_ARG_OR(bos_token_id, "bos_token_id", 151643); + LOAD_ARG_OR(decoder_sparse_step, "decoder_sparse_step", 1); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 151645); + LOAD_ARG_OR(head_dim, "head_dim", 256); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); + LOAD_ARG_OR(hidden_size, "hidden_size", 2048); + LOAD_ARG_OR(initializer_range, "initializer_range", 0.02f); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 5120); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 262144); + LOAD_ARG_OR(max_window_layers, "max_window_layers", 28); + LOAD_ARG_OR(moe_intermediate_size, "moe_intermediate_size", 512); + LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", true); + LOAD_ARG_OR(n_heads, "num_attention_heads", 16); + LOAD_ARG_OR(num_experts, "num_experts", 512); + LOAD_ARG_OR(num_experts_per_tok, "num_experts_per_tok", 10); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 48); + LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 2); + LOAD_ARG_OR(output_router_logits, "output_router_logits", false); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6); + LOAD_ARG_OR(rope_theta, "rope_theta", 10000000.0f); + LOAD_ARG_OR(router_aux_loss_coef, "router_aux_loss_coef", 0.001f); + 
LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false); + LOAD_ARG_OR(sliding_window, "sliding_window", 4096); + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false); + LOAD_ARG_OR(vocab_size, "vocab_size", 151936); + LOAD_ARG_OR(mlp_only_layers, "mlp_only_layers", std::vector()); + + // Additional parameters for Qwen3-Next architecture + LOAD_ARG_OR(attn_output_gate, "attn_output_gate", true); + LOAD_ARG_OR(full_attention_interval, "full_attention_interval", 4); + LOAD_ARG_OR(linear_conv_kernel_dim, "linear_conv_kernel_dim", 4); + LOAD_ARG_OR(linear_key_head_dim, "linear_key_head_dim", 128); + LOAD_ARG_OR(linear_num_key_heads, "linear_num_key_heads", 16); + LOAD_ARG_OR(linear_num_value_heads, "linear_num_value_heads", 32); + LOAD_ARG_OR(linear_value_head_dim, "linear_value_head_dim", 128); + LOAD_ARG_OR(partial_rotary_factor, "partial_rotary_factor", 0.25f); + LOAD_ARG_OR(shared_expert_intermediate_size, "shared_expert_intermediate_size", 512); + + SET_ARG(stop_token_ids, std::unordered_set({args->eos_token_id()})); +}); + +} // namespace xllm + diff --git a/xllm/models/models.h b/xllm/models/models.h index b8e02d502..1e4d100ec 100644 --- a/xllm/models/models.h +++ b/xllm/models/models.h @@ -19,6 +19,7 @@ limitations under the License. // capture. This variable may be removed in the future. 
#if defined(USE_NPU) && defined(USE_NPU_TORCH) #include "llm/qwen3.h" // IWYU pragma: keep +#include "llm/qwen3_next.h" #elif defined(USE_NPU) #include "dit/pipeline_flux.h" // IWYU pragma: keep #include "dit/pipeline_flux_control.h" // IWYU pragma: keep diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto index b9ac03981..92443e2f8 100644 --- a/xllm/proto/worker.proto +++ b/xllm/proto/worker.proto @@ -33,6 +33,8 @@ message KVCacheShape { repeated int64 key_shape = 1; repeated int64 value_shape = 2; repeated int64 index_shape = 3; + repeated int64 conv_shape = 4; + repeated int64 ssm_shape = 5; } message AllocateKVCacheRequest { From a0285137d396150dc2e8a0a8f69c1cce4cb4ee53 Mon Sep 17 00:00:00 2001 From: "ext.wangxiaochi1" Date: Thu, 26 Feb 2026 19:47:23 +0800 Subject: [PATCH 03/13] add triton kernel api --- xllm/core/kernels/ops_api.cpp | 86 +++++++++++++++++++ xllm/core/kernels/ops_api.h | 11 +++ xllm/core/kernels/param.h | 64 ++++++++++++++ .../common/partial_rotary_embedding.cpp | 2 +- 4 files changed, 162 insertions(+), 1 deletion(-) diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index 401527913..ddc6edbf4 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -767,6 +767,19 @@ moe_init_routing_v2(MoeInitRoutingV2Params& params) { #endif } +std::pair fused_gdn_gating(FusedGdnGatingParams& params) { +#if defined(USE_NPU) + return npu::npu_fused_gdn_gating(params.A_log, + params.a, + params.b, + params.dt_bias, + params.beta, + params.threshold); +#else + NOT_IMPLEMENTED(); +#endif +} + std::tuple fp8_scaled_quantize( Fp8ScaledQuantizeParams& params) { #if defined(USE_CUDA) @@ -776,6 +789,27 @@ std::tuple fp8_scaled_quantize( #endif } +std::pair fused_recurrent_gated_delta_rule( + FusedRecurrentGatedDeltaRuleParams& params) { +#if defined(USE_NPU) + return npu::npu_fused_recurrent_gated_delta_rule( + params.q, + params.k, + params.v, + params.g, + params.beta, + params.scale, + 
params.initial_state, + params.inplace_final_state, + params.cu_seqlens, + params.ssm_state_indices, + params.num_accepted_tokens, + params.use_qk_l2norm_in_kernel); +#else + NOT_IMPLEMENTED(); +#endif +} + torch::Tensor fp8_scaled_matmul(Fp8ScaledMatmulParams& params) { #if defined(USE_CUDA) auto out_2d = cuda::fp8_scaled_matmul(params.a, @@ -864,4 +898,56 @@ std::tuple fused_add_rms_norm_static_fp8_quant( #endif } +torch::Tensor causal_conv1d_update(CausalConv1dUpdateParams& params) { +#if defined(USE_NPU) + return npu::npu_causal_conv1d_update( + params.x, + params.conv_state, + params.weight, + params.activation, + params.bias, + params.cache_seqlens, + params.conv_state_indices, + params.num_accepted_tokens, + params.query_start_loc, + params.max_query_len, + params.intermediate_conv_window, + params.pad_slot_id, + params.validate_data); +#else + NOT_IMPLEMENTED(); +#endif +} + +torch::Tensor gated_layer_norm(GatedLayerNormParams& params) { +#if defined(USE_NPU) + return npu::layer_norm_fwd( + params.x, + params.weight, + params.bias, + params.eps, + params.z, + params.group_size, + params.norm_before_gate, + params.is_rms_norm); +#else + NOT_IMPLEMENTED(); +#endif +} + +std::pair partial_rotary_embedding(PartialRotaryEmbeddingParams& params) { +#if defined(USE_NPU) + return npu::apply_npu_partial_rotary_embedding( + params.positions, + params.query, + params.key, + params.head_size, + params.rotary_dim, + params.cos_sin_cache, + params.is_neox_style + ); +#else + NOT_IMPLEMENTED(); +#endif +} } // namespace xllm::kernel diff --git a/xllm/core/kernels/ops_api.h b/xllm/core/kernels/ops_api.h index 02bf95a8f..c9285e756 100644 --- a/xllm/core/kernels/ops_api.h +++ b/xllm/core/kernels/ops_api.h @@ -125,4 +125,15 @@ torch::Tensor rms_norm_static_fp8_quant(RmsNormStaticFp8QuantParams& params); std::tuple fused_add_rms_norm_static_fp8_quant( FusedAddRmsNormStaticFp8QuantParams& params); +std::pair fused_gdn_gating(FusedGdnGatingParams& params); + +std::pair 
fused_recurrent_gated_delta_rule( + FusedRecurrentGatedDeltaRuleParams& params); + +torch::Tensor causal_conv1d_update(CausalConv1dUpdateParams& params); + +torch::Tensor gated_layer_norm(GatedLayerNormParams& params); + +std::pair partial_rotary_embedding(PartialRotaryEmbeddingParams& params); + } // namespace xllm::kernel diff --git a/xllm/core/kernels/param.h b/xllm/core/kernels/param.h index 4995b6c3a..a8978d372 100644 --- a/xllm/core/kernels/param.h +++ b/xllm/core/kernels/param.h @@ -1291,4 +1291,68 @@ struct FusedAddRmsNormStaticFp8QuantParams { double epsilon; }; +// NPU Fused GDN Gating parameters +struct FusedGdnGatingParams { + torch::Tensor A_log; + torch::Tensor a; + torch::Tensor b; + torch::Tensor dt_bias; + float beta = 1.0f; + float threshold = 20.0f; +}; + +// NPU Fused Recurrent Gated Delta Rule parameters +struct FusedRecurrentGatedDeltaRuleParams { + torch::Tensor q; + torch::Tensor k; + torch::Tensor v; + torch::Tensor g; + std::optional beta = std::nullopt; + std::optional scale = std::nullopt; + std::optional initial_state = std::nullopt; + bool inplace_final_state = true; + std::optional cu_seqlens = std::nullopt; + std::optional ssm_state_indices = std::nullopt; + std::optional num_accepted_tokens = std::nullopt; + bool use_qk_l2norm_in_kernel = false; +}; + +// NPU Causal Conv1d Update parameters +struct CausalConv1dUpdateParams { + torch::Tensor x; + torch::Tensor conv_state; + torch::Tensor weight; + bool activation = true; + std::optional bias = std::nullopt; + std::optional cache_seqlens = std::nullopt; + std::optional conv_state_indices = std::nullopt; + std::optional num_accepted_tokens = std::nullopt; + std::optional query_start_loc = std::nullopt; + int32_t max_query_len = -1; + std::optional intermediate_conv_window = std::nullopt; + int32_t pad_slot_id = -1; + bool validate_data = false; +}; + +struct GatedLayerNormParams { + torch::Tensor x; + torch::Tensor weight; + torch::Tensor bias; + double eps; + std::optional z = 
std::nullopt; + int64_t group_size = -1; + bool norm_before_gate = true; + bool is_rms_norm = true; +}; + + +struct PartialRotaryEmbeddingParams { + torch::Tensor positions; + torch::Tensor query; + torch::Tensor key; + int64_t head_size; + int64_t rotary_dim; + torch::Tensor cos_sin_cache; + bool is_neox_style; +}; } // namespace xllm::kernel diff --git a/xllm/core/layers/common/partial_rotary_embedding.cpp b/xllm/core/layers/common/partial_rotary_embedding.cpp index e44d12c24..cb5682ba3 100644 --- a/xllm/core/layers/common/partial_rotary_embedding.cpp +++ b/xllm/core/layers/common/partial_rotary_embedding.cpp @@ -54,7 +54,7 @@ PartialRotaryEmbeddingImpl::PartialRotaryEmbeddingImpl(int64_t rotary_dim, void PartialRotaryEmbeddingImpl::forward(const torch::Tensor& positions, torch::Tensor& q, torch::Tensor& k) { - xllm::kernel::PartialRotaryEmbedding partial_rotary_params; + xllm::kernel::PartialRotaryEmbeddingParams partial_rotary_params; partial_rotary_params.positions = positions; partial_rotary_params.query = q; partial_rotary_params.key = k; From 1be2d5ca870cde7af8955e898b83647fc2510d5d Mon Sep 17 00:00:00 2001 From: "ext.wangxiaochi1" Date: Thu, 26 Feb 2026 20:01:11 +0800 Subject: [PATCH 04/13] add rope atb ops --- xllm/core/kernels/npu/rope.cpp | 42 ++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/xllm/core/kernels/npu/rope.cpp b/xllm/core/kernels/npu/rope.cpp index 2d3b21896..a25b21ead 100644 --- a/xllm/core/kernels/npu/rope.cpp +++ b/xllm/core/kernels/npu/rope.cpp @@ -43,4 +43,46 @@ void apply_rotary(torch::Tensor& q, at_npu::native::custom_ops::npu_apply_rotary_pos_emb(q, k, cos, sin); } +std::pair apply_npu_partial_rotary_embedding(const torch::Tensor &positions, + torch::Tensor &query, + torch::Tensor &key, + int64_t head_size, + int64_t rotary_dim, + const torch::Tensor &cos_sin_cache, + bool is_neox_style) { + torch::IntArrayRef query_shape = query.sizes(); + torch::IntArrayRef key_shape = key.sizes(); + + int64_t 
num_tokens = query.size(0); + + torch::Tensor query_reshaped = query.view({num_tokens, -1, head_size}); + torch::Tensor key_reshaped = key.view({num_tokens, -1, head_size}); + + torch::Tensor q_rot = query_reshaped.slice(-1, 0, rotary_dim); + torch::Tensor q_pass = query_reshaped.slice(-1, rotary_dim, head_size); + torch::Tensor k_rot = key_reshaped.slice(-1, 0, rotary_dim); + torch::Tensor k_pass = key_reshaped.slice(-1, rotary_dim, head_size); + + torch::Tensor q_rot_contig = q_rot.contiguous().view({num_tokens, -1}); + torch::Tensor k_rot_contig = k_rot.contiguous().view({num_tokens, -1}); + atb::npu_rotary_embedding( + positions, + q_rot_contig, + k_rot_contig, + head_size, + cos_sin_cache, + is_neox_style + ); + torch::Tensor q_rot_3d = q_rot_contig.view({num_tokens, -1, rotary_dim}); + torch::Tensor k_rot_3d = k_rot_contig.view({num_tokens, -1, rotary_dim}); + + torch::Tensor q_concat = at::cat({q_rot_3d, q_pass}, -1); + torch::Tensor q_final = q_concat.reshape(query_shape); + + torch::Tensor k_concat = at::cat({k_rot_3d, k_pass}, -1); + torch::Tensor k_final = k_concat.reshape(key_shape); + + return {q_final, k_final}; +} + } // namespace xllm::kernel::npu \ No newline at end of file From 75a3417fa239c9857ab6ca37fd4325ac9326cccd Mon Sep 17 00:00:00 2001 From: xuyexiong Date: Fri, 27 Feb 2026 19:31:03 +0800 Subject: [PATCH 05/13] bugfix: fix some compile problem: LOAD_MERGED_WEIGHT_V2 / testing trea / Qwen3NextDecoderLayer. 
--- CMakeLists.txt | 32 +- setup.py | 16 +- xllm/core/framework/state_dict/utils.h | 11 + xllm/core/kernels/CMakeLists.txt | 3 +- xllm/core/kernels/npu/npu_ops_api.h | 8 + xllm/core/kernels/npu/rope.cpp | 81 ++- xllm/core/kernels/ops_api.cpp | 75 +- xllm/core/kernels/ops_api.h | 4 +- .../common/attention_metadata_builder.cpp | 2 + .../layers/common/qwen3_next_attention.cpp | 85 +-- .../common/qwen3_next_gated_delta_net.cpp | 639 ++++++++++-------- xllm/core/layers/qwen3_next_decoder_layer.cpp | 30 +- xllm/core/layers/qwen3_next_decoder_layer.h | 1 + xllm/models/llm/qwen3_next.h | 67 +- 14 files changed, 567 insertions(+), 487 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 16183d50c..491d961c8 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,6 +28,7 @@ if(USE_NPU) message(STATUS "Building for device: A2 (macro USE_A2 defined)") endif() +<<<<<<< HEAD # Override Mooncake option for mooncake transfer engine # CANN 8.5+ migration: ascend_direct_transport replaces ascend_transport set(USE_ASCEND_DIRECT ON CACHE BOOL "Enable ADXL engine for Ascend NPU" FORCE) @@ -37,6 +38,34 @@ if(USE_NPU) CACHE PATH "Path to xllm_atb_layers source tree") if(NOT EXISTS "${XLLM_ATB_LAYERS_SOURCE_DIR}") message(FATAL_ERROR "xllm_atb_layers source not found: ${XLLM_ATB_LAYERS_SOURCE_DIR}") +======= + option(INSTALL_XLLM_KERNELS "Install xllm_kernels RPM" OFF) + message(STATUS "INSTALL_XLLM_KERNELS enabled: ${INSTALL_XLLM_KERNELS}") + if(INSTALL_XLLM_KERNELS) + if(DEVICE_TYPE STREQUAL "USE_A3") + message("downloading a3 arm xllm kernels") + file(DOWNLOAD + "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.8.0/xllm_kernels-1.3.10-Linux.a3.arm.rpm" + "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" + ) + else() + if(DEVICE_ARCH STREQUAL "ARM") + message("downloading a2 arm xllm_kernels") + file(DOWNLOAD + "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.8.0/xllm_kernels-1.3.10-Linux.a2.arm.rpm" + 
"${CMAKE_BINARY_DIR}/xllm_kernels.rpm" + ) + else() + message("downloading a2 x86 xllm_kernels") + file(DOWNLOAD + "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.8.0/xllm_kernels-1.3.10-Linux.a2.x86.rpm" + "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" + ) + endif() + endif() + execute_process(COMMAND rpm -ivh --replacepkgs --replacefiles "${CMAKE_BINARY_DIR}/xllm_kernels.rpm") + file(WRITE "${CMAKE_BINARY_DIR}/.xllm_installed" "") +>>>>>>> 89ccdda (bugfix: fix some compile problem: LOAD_MERGED_WEIGHT_V2 / testing trea / Qwen3NextDecoderLayer.) endif() message(STATUS "Using xllm_atb_layers source at: ${XLLM_ATB_LAYERS_SOURCE_DIR}") @@ -315,7 +344,7 @@ endif() if(USE_NPU) # USE_NPU_TORCH: Temporary flag used for debugging qwen3 torch NPU graph # capture. This variable may be removed in the future. - # add_definitions(-DUSE_NPU_TORCH) + add_definitions(-DUSE_NPU_TORCH) add_definitions(-DUSE_NPU) add_definitions(-DBUILD_LIBTORCH) add_definitions(-DTORCH_SETCUSTOMHANDLER=ON) @@ -338,6 +367,7 @@ if(USE_NPU) $ENV{NPU_HOME_PATH}/include $ENV{ATB_HOME_PATH}/include $ENV{NPU_HOME_PATH}/opp/vendors/xllm/op_api/include/ + ${CMAKE_CURRENT_SOURCE_DIR}/third_party/torch_npu_ops/ ) # Keep third-party warnings suppressed while providing headers expected by # legacy includes like "atb_speed/log.h" from xllm_atb_layers. 
diff --git a/setup.py b/setup.py index 6a2343bce..2266eeffb 100644 --- a/setup.py +++ b/setup.py @@ -514,7 +514,21 @@ def parse_arguments() -> dict[str, Any]: default='auto', help='Device type: a2, a3, mlu, ilu, cuda or musa (case-insensitive)' ) - + + parser.add_argument( + '--dry-run', + action='store_true', + help='Dry run mode (do not execute pre_build)' + ) + + parser.add_argument( + '--install-xllm-kernels', + type=str.lower, + choices=['true', 'false', '1', '0', 'yes', 'no', 'y', 'n', 'on', 'off'], + default='false', + help='Whether to install xllm kernels' + ) + parser.add_argument( '--generate-so', type=str.lower, diff --git a/xllm/core/framework/state_dict/utils.h b/xllm/core/framework/state_dict/utils.h index 92e24bd89..43aa847e1 100644 --- a/xllm/core/framework/state_dict/utils.h +++ b/xllm/core/framework/state_dict/utils.h @@ -227,4 +227,15 @@ void load_merged_weight_v2(const StateDict& state_dict, shard_size, \ name##_, \ name##_is_loaded_); + +#define LOAD_MERGED_WEIGHT_V2(name, dim) \ + weight::load_merged_weight_v2(state_dict, \ + #name, \ + dim, \ + rank, \ + world_size, \ + shard_tensor_count, \ + shard_sizes, \ + name##_, \ + name##_is_loaded_); } // namespace xllm diff --git a/xllm/core/kernels/CMakeLists.txt b/xllm/core/kernels/CMakeLists.txt index b7b94c1ec..60ed82d5b 100644 --- a/xllm/core/kernels/CMakeLists.txt +++ b/xllm/core/kernels/CMakeLists.txt @@ -36,9 +36,10 @@ cc_library( ops_api.cpp DEPS torch + triton_adapter $<$:npu_kernels> $<$:mlu_kernels> $<$:musa_kernels> $<$:cuda_kernels> $<$:ilu_kernels> -) \ No newline at end of file +) diff --git a/xllm/core/kernels/npu/npu_ops_api.h b/xllm/core/kernels/npu/npu_ops_api.h index 995fd573b..a02c83ab9 100644 --- a/xllm/core/kernels/npu/npu_ops_api.h +++ b/xllm/core/kernels/npu/npu_ops_api.h @@ -126,4 +126,12 @@ apply_npu_moe_init_routing_v2(const torch::Tensor& x, torch::IntArrayRef active_expert_range, int row_idx_type); +std::pair apply_npu_partial_rotary_embedding(const torch::Tensor 
&positions, + torch::Tensor &query, + torch::Tensor &key, + int64_t head_size, + int64_t rotary_dim, + const torch::Tensor &cos_sin_cache, + bool is_neox_style); + } // namespace xllm::kernel::npu diff --git a/xllm/core/kernels/npu/rope.cpp b/xllm/core/kernels/npu/rope.cpp index a25b21ead..c29dbb2bc 100644 --- a/xllm/core/kernels/npu/rope.cpp +++ b/xllm/core/kernels/npu/rope.cpp @@ -43,46 +43,45 @@ void apply_rotary(torch::Tensor& q, at_npu::native::custom_ops::npu_apply_rotary_pos_emb(q, k, cos, sin); } -std::pair apply_npu_partial_rotary_embedding(const torch::Tensor &positions, - torch::Tensor &query, - torch::Tensor &key, - int64_t head_size, - int64_t rotary_dim, - const torch::Tensor &cos_sin_cache, - bool is_neox_style) { - torch::IntArrayRef query_shape = query.sizes(); - torch::IntArrayRef key_shape = key.sizes(); - - int64_t num_tokens = query.size(0); - - torch::Tensor query_reshaped = query.view({num_tokens, -1, head_size}); - torch::Tensor key_reshaped = key.view({num_tokens, -1, head_size}); - - torch::Tensor q_rot = query_reshaped.slice(-1, 0, rotary_dim); - torch::Tensor q_pass = query_reshaped.slice(-1, rotary_dim, head_size); - torch::Tensor k_rot = key_reshaped.slice(-1, 0, rotary_dim); - torch::Tensor k_pass = key_reshaped.slice(-1, rotary_dim, head_size); - - torch::Tensor q_rot_contig = q_rot.contiguous().view({num_tokens, -1}); - torch::Tensor k_rot_contig = k_rot.contiguous().view({num_tokens, -1}); - atb::npu_rotary_embedding( - positions, - q_rot_contig, - k_rot_contig, - head_size, - cos_sin_cache, - is_neox_style - ); - torch::Tensor q_rot_3d = q_rot_contig.view({num_tokens, -1, rotary_dim}); - torch::Tensor k_rot_3d = k_rot_contig.view({num_tokens, -1, rotary_dim}); - - torch::Tensor q_concat = at::cat({q_rot_3d, q_pass}, -1); - torch::Tensor q_final = q_concat.reshape(query_shape); - - torch::Tensor k_concat = at::cat({k_rot_3d, k_pass}, -1); - torch::Tensor k_final = k_concat.reshape(key_shape); - - return {q_final, k_final}; 
+std::pair apply_npu_partial_rotary_embedding( + const torch::Tensor& positions, + torch::Tensor& query, + torch::Tensor& key, + int64_t head_size, + int64_t rotary_dim, + const torch::Tensor& cos_sin_cache, + bool is_neox_style) { + torch::IntArrayRef query_shape = query.sizes(); + torch::IntArrayRef key_shape = key.sizes(); + + int64_t num_tokens = query.size(0); + + torch::Tensor query_reshaped = query.view({num_tokens, -1, head_size}); + torch::Tensor key_reshaped = key.view({num_tokens, -1, head_size}); + + torch::Tensor q_rot = query_reshaped.slice(-1, 0, rotary_dim); + torch::Tensor q_pass = query_reshaped.slice(-1, rotary_dim, head_size); + torch::Tensor k_rot = key_reshaped.slice(-1, 0, rotary_dim); + torch::Tensor k_pass = key_reshaped.slice(-1, rotary_dim, head_size); + + torch::Tensor q_rot_contig = q_rot.contiguous().view({num_tokens, -1}); + torch::Tensor k_rot_contig = k_rot.contiguous().view({num_tokens, -1}); + atb::npu_rotary_embedding(positions, + q_rot_contig, + k_rot_contig, + head_size, + cos_sin_cache, + is_neox_style); + torch::Tensor q_rot_3d = q_rot_contig.view({num_tokens, -1, rotary_dim}); + torch::Tensor k_rot_3d = k_rot_contig.view({num_tokens, -1, rotary_dim}); + + torch::Tensor q_concat = at::cat({q_rot_3d, q_pass}, -1); + torch::Tensor q_final = q_concat.reshape(query_shape); + + torch::Tensor k_concat = at::cat({k_rot_3d, k_pass}, -1); + torch::Tensor k_final = k_concat.reshape(key_shape); + + return {q_final, k_final}; } -} // namespace xllm::kernel::npu \ No newline at end of file +} // namespace xllm::kernel::npu diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index ddc6edbf4..4b6fd8ac3 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -767,7 +767,8 @@ moe_init_routing_v2(MoeInitRoutingV2Params& params) { #endif } -std::pair fused_gdn_gating(FusedGdnGatingParams& params) { +std::pair fused_gdn_gating( + FusedGdnGatingParams& params) { #if defined(USE_NPU) return 
npu::npu_fused_gdn_gating(params.A_log, params.a, @@ -780,15 +781,6 @@ std::pair fused_gdn_gating(FusedGdnGatingParams& p #endif } -std::tuple fp8_scaled_quantize( - Fp8ScaledQuantizeParams& params) { -#if defined(USE_CUDA) - return cuda::fp8_scaled_quantize(params.input, params.output, params.scale); -#else - NOT_IMPLEMENTED(); -#endif -} - std::pair fused_recurrent_gated_delta_rule( FusedRecurrentGatedDeltaRuleParams& params) { #if defined(USE_NPU) @@ -900,20 +892,19 @@ std::tuple fused_add_rms_norm_static_fp8_quant( torch::Tensor causal_conv1d_update(CausalConv1dUpdateParams& params) { #if defined(USE_NPU) - return npu::npu_causal_conv1d_update( - params.x, - params.conv_state, - params.weight, - params.activation, - params.bias, - params.cache_seqlens, - params.conv_state_indices, - params.num_accepted_tokens, - params.query_start_loc, - params.max_query_len, - params.intermediate_conv_window, - params.pad_slot_id, - params.validate_data); + return npu::npu_causal_conv1d_update(params.x, + params.conv_state, + params.weight, + params.activation, + params.bias, + params.cache_seqlens, + params.conv_state_indices, + params.num_accepted_tokens, + params.query_start_loc, + params.max_query_len, + params.intermediate_conv_window, + params.pad_slot_id, + params.validate_data); #else NOT_IMPLEMENTED(); #endif @@ -921,31 +912,29 @@ torch::Tensor causal_conv1d_update(CausalConv1dUpdateParams& params) { torch::Tensor gated_layer_norm(GatedLayerNormParams& params) { #if defined(USE_NPU) - return npu::layer_norm_fwd( - params.x, - params.weight, - params.bias, - params.eps, - params.z, - params.group_size, - params.norm_before_gate, - params.is_rms_norm); + return npu::layer_norm_fwd(params.x, + params.weight, + params.bias, + params.eps, + params.z, + params.group_size, + params.norm_before_gate, + params.is_rms_norm); #else NOT_IMPLEMENTED(); #endif } -std::pair partial_rotary_embedding(PartialRotaryEmbeddingParams& params) { +std::pair partial_rotary_embedding( + 
PartialRotaryEmbeddingParams& params) { #if defined(USE_NPU) - return npu::apply_npu_partial_rotary_embedding( - params.positions, - params.query, - params.key, - params.head_size, - params.rotary_dim, - params.cos_sin_cache, - params.is_neox_style - ); + return npu::apply_npu_partial_rotary_embedding(params.positions, + params.query, + params.key, + params.head_size, + params.rotary_dim, + params.cos_sin_cache, + params.is_neox_style); #else NOT_IMPLEMENTED(); #endif diff --git a/xllm/core/kernels/ops_api.h b/xllm/core/kernels/ops_api.h index c9285e756..ab44f21c2 100644 --- a/xllm/core/kernels/ops_api.h +++ b/xllm/core/kernels/ops_api.h @@ -16,6 +16,7 @@ limitations under the License. #pragma once #include "param.h" +#include "triton_npu/torch_api/triton_ops_api.h" namespace xllm::kernel { @@ -134,6 +135,7 @@ torch::Tensor causal_conv1d_update(CausalConv1dUpdateParams& params); torch::Tensor gated_layer_norm(GatedLayerNormParams& params); -std::pair partial_rotary_embedding(PartialRotaryEmbeddingParams& params); +std::pair partial_rotary_embedding( + PartialRotaryEmbeddingParams& params); } // namespace xllm::kernel diff --git a/xllm/core/layers/common/attention_metadata_builder.cpp b/xllm/core/layers/common/attention_metadata_builder.cpp index fdb81c7e1..808f61897 100644 --- a/xllm/core/layers/common/attention_metadata_builder.cpp +++ b/xllm/core/layers/common/attention_metadata_builder.cpp @@ -91,8 +91,10 @@ AttentionMetadata AttentionMetadataBuilder::build( #if defined(USE_NPU) // NPU path uses per-sequence lengths (not cumulative), so no diff. 
attn_metadata.kv_seq_lens = params.kv_seq_lens; + attn_metadata.q_seq_lens = params.q_seq_lens; #else attn_metadata.kv_seq_lens = torch::diff(params.kv_seq_lens); // kv seqlens + attn_metadata.q_seq_lens = torch::diff(params.q_seq_lens); // q seqlens #endif } diff --git a/xllm/core/layers/common/qwen3_next_attention.cpp b/xllm/core/layers/common/qwen3_next_attention.cpp index f78983286..3b3c5e508 100644 --- a/xllm/core/layers/common/qwen3_next_attention.cpp +++ b/xllm/core/layers/common/qwen3_next_attention.cpp @@ -14,17 +14,20 @@ limitations under the License. ==============================================================================*/ #include "qwen3_next_attention.h" -#include + #include +#include + #include namespace xllm { namespace layer { -Qwen3NextAttentionImpl::Qwen3NextAttentionImpl(const ModelArgs& args, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options, - int32_t layer_id) { +Qwen3NextAttentionImpl::Qwen3NextAttentionImpl( + const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options, + int32_t layer_id) { const int64_t tp_size = parallel_args.tp_group_->world_size(); const int64_t total_num_heads = args.n_heads(); const int64_t total_num_kv_heads = args.n_kv_heads().value_or(args.n_heads()); @@ -49,16 +52,17 @@ Qwen3NextAttentionImpl::Qwen3NextAttentionImpl(const ModelArgs& args, scaling_ = 1.0f / std::sqrt(static_cast(head_dim_)); attn_output_gate_ = args.attn_output_gate(); // 1. QKV linear - qkv_proj_ = register_module("qkv_proj", - QKVParallelLinear(args.hidden_size(), - attn_output_gate_ ? num_heads_ * 2 : num_heads_, - num_kv_heads_, - args.head_dim(), - num_kv_head_replicas_, - /*bias=*/args.attention_bias(), - /*gather_output=*/false, - parallel_args, - options)); + qkv_proj_ = register_module( + "qkv_proj", + QKVParallelLinear(args.hidden_size(), + attn_output_gate_ ? 
num_heads_ * 2 : num_heads_, + num_kv_heads_, + args.head_dim(), + num_kv_head_replicas_, + /*bias=*/args.attention_bias(), + /*gather_output=*/false, + parallel_args, + options)); // 2. O proj o_proj_ = register_module("o_proj", @@ -68,28 +72,29 @@ Qwen3NextAttentionImpl::Qwen3NextAttentionImpl(const ModelArgs& args, /*input_is_parallelized=*/true, /*if_reduce_results=*/true, quant_args, - parallel_args, + parallel_args.tp_group_, options)); // 3. Q norm - q_norm_ = register_module("q_norm", - Qwen3NextRMSNorm(head_dim_, args.rms_norm_eps(), options)); + q_norm_ = register_module( + "q_norm", Qwen3NextRMSNorm(head_dim_, args.rms_norm_eps(), options)); // 4. K norm - k_norm_ = register_module("k_norm", - Qwen3NextRMSNorm(head_dim_, args.rms_norm_eps(), options)); - + k_norm_ = register_module( + "k_norm", Qwen3NextRMSNorm(head_dim_, args.rms_norm_eps(), options)); + // 5. Rotary embedding - const int rotary_dim = static_cast(head_dim_ * args.partial_rotary_factor()); - rotary_emb_ = register_module( - "rotary_emb", - PartialRotaryEmbedding(rotary_dim, - args.max_position_embeddings(), - args.rope_theta(), - head_dim_, - true, - false, - options)); + const int rotary_dim = + static_cast(head_dim_ * args.partial_rotary_factor()); + rotary_emb_ = + register_module("rotary_emb", + PartialRotaryEmbedding(rotary_dim, + args.max_position_embeddings(), + args.rope_theta(), + head_dim_, + true, + false, + options)); // 6. 
Attention attn_ = register_module("attn", @@ -109,20 +114,20 @@ torch::Tensor Qwen3NextAttentionImpl::forward( auto qkv = qkv_proj_->forward(hidden_states); torch::Tensor q, k, v; torch::Tensor gate; - + if (attn_output_gate_) { // Split qkv for attn_output_gate case: [q_size*2, kv_size, kv_size] auto q_gate = qkv.slice(/*dim=*/-1, 0, q_size_ * 2); k = qkv.slice(/*dim=*/-1, q_size_ * 2, q_size_ * 2 + kv_size_); - v = qkv.slice(/*dim=*/-1, q_size_ * 2 + kv_size_, q_size_ * 2 + kv_size_ * 2); + v = qkv.slice( + /*dim=*/-1, q_size_ * 2 + kv_size_, q_size_ * 2 + kv_size_ * 2); v = v.contiguous(); std::vector orig_shape; int64_t q_gate_dim = q_gate.dim(); - orig_shape = std::vector( - q_gate.sizes().slice(0, q_gate_dim - 1).begin(), - q_gate.sizes().slice(0, q_gate_dim - 1).end() - ); + orig_shape = + std::vector(q_gate.sizes().slice(0, q_gate_dim - 1).begin(), + q_gate.sizes().slice(0, q_gate_dim - 1).end()); std::vector new_shape = orig_shape; new_shape.push_back(num_heads_); @@ -161,9 +166,9 @@ torch::Tensor Qwen3NextAttentionImpl::forward( q = q_normed.view({T, q_size_}); k = k_normed.view({T, kv_size_}); - rotary_emb_->forward(positions,q,k); + rotary_emb_->forward(positions, q, k); auto out = std::get<0>(attn_->forward(attn_metadata, q, k, v, kv_cache)); - + if (attn_output_gate_) { gate = torch::sigmoid(gate); out = out * gate; diff --git a/xllm/core/layers/common/qwen3_next_gated_delta_net.cpp b/xllm/core/layers/common/qwen3_next_gated_delta_net.cpp index b208c44d0..b6e450e1f 100644 --- a/xllm/core/layers/common/qwen3_next_gated_delta_net.cpp +++ b/xllm/core/layers/common/qwen3_next_gated_delta_net.cpp @@ -12,24 +12,23 @@ limitations under the License. 
#include "qwen3_next_gated_delta_net.h" +#include #include -#include "xllm/core/kernels/ops_api.h" - +#include #include -#include #include #include -#include + +#include "xllm/core/kernels/ops_api.h" namespace xllm { namespace layer { - namespace { torch::Tensor l2norm(const torch::Tensor& x, int64_t dim, double eps = 1e-6) { - auto norm = torch::sqrt(torch::sum(torch::square(x), dim, true) + eps); - return x / norm; + auto norm = torch::sqrt(torch::sum(torch::square(x), dim, true) + eps); + return x / norm; } std::tuple torch_recurrent_gated_delta_rule( @@ -40,58 +39,63 @@ std::tuple torch_recurrent_gated_delta_rule( torch::Tensor beta, std::optional initial_state, bool output_final_state = true, - bool use_qk_l2norm_in_kernel = true -) { - auto initial_dtype = query.dtype(); + bool use_qk_l2norm_in_kernel = true) { + auto initial_dtype = query.dtype(); - if (use_qk_l2norm_in_kernel) { - query = l2norm(query, -1, 1e-6); - key = l2norm(key, -1, 1e-6); - } + if (use_qk_l2norm_in_kernel) { + query = l2norm(query, -1, 1e-6); + key = l2norm(key, -1, 1e-6); + } - auto to_float32_and_transpose = [](torch::Tensor x) { - return x.transpose(1, 2).contiguous().to(torch::kFloat32); - }; - query = to_float32_and_transpose(query); - key = to_float32_and_transpose(key); - value = to_float32_and_transpose(value); - beta = to_float32_and_transpose(beta); - g = to_float32_and_transpose(g); - - int64_t batch_size = key.size(0); - int64_t num_heads = key.size(1); - int64_t sequence_length = key.size(2); - int64_t k_head_dim = key.size(3); - int64_t v_head_dim = value.size(3); - - float scale_val = 1.0 / std::sqrt(static_cast(query.size(-1))); - torch::Tensor scale = torch::tensor(scale_val, query.options()); - query = query * scale; - torch::Tensor core_attn_out = torch::zeros({batch_size, num_heads, sequence_length, v_head_dim}, + auto to_float32_and_transpose = [](torch::Tensor x) { + return x.transpose(1, 2).contiguous().to(torch::kFloat32); + }; + query = 
to_float32_and_transpose(query); + key = to_float32_and_transpose(key); + value = to_float32_and_transpose(value); + beta = to_float32_and_transpose(beta); + g = to_float32_and_transpose(g); + + int64_t batch_size = key.size(0); + int64_t num_heads = key.size(1); + int64_t sequence_length = key.size(2); + int64_t k_head_dim = key.size(3); + int64_t v_head_dim = value.size(3); + + float scale_val = 1.0 / std::sqrt(static_cast(query.size(-1))); + torch::Tensor scale = torch::tensor(scale_val, query.options()); + query = query * scale; + torch::Tensor core_attn_out = torch::zeros( + {batch_size, num_heads, sequence_length, v_head_dim}, + torch::TensorOptions().dtype(torch::kFloat32).device(value.device())); + torch::Tensor last_recurrent_state; + if (!initial_state.has_value()) { + last_recurrent_state = torch::zeros( + {batch_size, num_heads, k_head_dim, v_head_dim}, torch::TensorOptions().dtype(torch::kFloat32).device(value.device())); - torch::Tensor last_recurrent_state; - if (!initial_state.has_value()) { - last_recurrent_state = torch::zeros({batch_size, num_heads, k_head_dim, v_head_dim}, - torch::TensorOptions().dtype(torch::kFloat32).device(value.device())); - } else { - last_recurrent_state = initial_state.value().to(value.device(), torch::kFloat32); - } + } else { + last_recurrent_state = + initial_state.value().to(value.device(), torch::kFloat32); + } - for (int64_t i = 0; i < sequence_length; ++i) { - torch::Tensor q_t = query.select(2, i); - torch::Tensor k_t = key.select(2, i); - torch::Tensor v_t = value.select(2, i); - torch::Tensor g_t = g.select(2, i).exp().unsqueeze(-1).unsqueeze(-1); - torch::Tensor beta_t = beta.select(2, i).unsqueeze(-1); - last_recurrent_state = last_recurrent_state * g_t; - torch::Tensor kv_mem = torch::sum(last_recurrent_state * k_t.unsqueeze(-1), -2); - torch::Tensor delta = (v_t - kv_mem) * beta_t; - last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2); - core_attn_out.select(2, i) = 
torch::sum(last_recurrent_state * q_t.unsqueeze(-1), -2); - } + for (int64_t i = 0; i < sequence_length; ++i) { + torch::Tensor q_t = query.select(2, i); + torch::Tensor k_t = key.select(2, i); + torch::Tensor v_t = value.select(2, i); + torch::Tensor g_t = g.select(2, i).exp().unsqueeze(-1).unsqueeze(-1); + torch::Tensor beta_t = beta.select(2, i).unsqueeze(-1); + last_recurrent_state = last_recurrent_state * g_t; + torch::Tensor kv_mem = + torch::sum(last_recurrent_state * k_t.unsqueeze(-1), -2); + torch::Tensor delta = (v_t - kv_mem) * beta_t; + last_recurrent_state = + last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2); + core_attn_out.select(2, i) = + torch::sum(last_recurrent_state * q_t.unsqueeze(-1), -2); + } - core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype); - return std::make_tuple(core_attn_out, last_recurrent_state); + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype); + return std::make_tuple(core_attn_out, last_recurrent_state); } std::tuple torch_chunk_gated_delta_rule( @@ -104,186 +108,197 @@ std::tuple torch_chunk_gated_delta_rule( c10::optional initial_state = c10::nullopt, bool output_final_state = true, bool use_qk_l2norm_in_kernel = true) { - auto initial_dtype = query.dtype(); - if (use_qk_l2norm_in_kernel) { - query = l2norm(query, -1, 1e-6); - key = l2norm(key, -1, 1e-6); - } - auto to_float32 = [](torch::Tensor x) { - return x.transpose(1, 2).contiguous().to(torch::kFloat32); - }; - - query = to_float32(query); - key = to_float32(key); - value = to_float32(value); - beta = to_float32(beta); - g = to_float32(g); - - auto batch_size = query.size(0); - auto num_heads = query.size(1); - auto sequence_length = query.size(2); - auto k_head_dim = key.size(-1); - auto v_head_dim = value.size(-1); - - int64_t pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size; - query = torch::nn::functional::pad(query, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); - 
key = torch::nn::functional::pad(key, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); - value = torch::nn::functional::pad(value, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); - beta = torch::nn::functional::pad(beta, torch::nn::functional::PadFuncOptions({0, pad_size})); - g = torch::nn::functional::pad(g, torch::nn::functional::PadFuncOptions({0, pad_size})); - - int64_t total_sequence_length = sequence_length + pad_size; - float scale = 1.0 / std::sqrt(static_cast(query.size(-1))); - query = query * scale; - auto v_beta = value * beta.unsqueeze(-1); - auto k_beta = key * beta.unsqueeze(-1); - auto reshape_to_chunks = [chunk_size](torch::Tensor x) { - auto shape = x.sizes(); - std::vector new_shape = { - shape[0], shape[1], - shape[2] / chunk_size, chunk_size, - shape[3] - }; - return x.reshape(new_shape); - }; - - query = reshape_to_chunks(query); - key = reshape_to_chunks(key); - value = reshape_to_chunks(value); - k_beta = reshape_to_chunks(k_beta); - v_beta = reshape_to_chunks(v_beta); - - auto g_shape = g.sizes(); - std::vector g_new_shape = { - g_shape[0], g_shape[1], - g_shape[2] / chunk_size, chunk_size - }; - g = g.reshape(g_new_shape); - auto mask = torch::triu( - torch::ones({chunk_size, chunk_size}, torch::TensorOptions().dtype(torch::kBool).device(query.device())), - 0 - ); - - g = g.cumsum(-1); - auto g_diff = g.unsqueeze(-1) - g.unsqueeze(-2); - auto decay_mask = g_diff.tril().exp().to(torch::kFloat32); - decay_mask = decay_mask.tril(); - auto attn = -(torch::matmul(k_beta, key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0.0); - for (int64_t i = 1; i < chunk_size; ++i) { - if (!attn.is_contiguous()) { - attn = attn.contiguous(); - } - auto row = attn.slice(-2, i, i+1).slice(-1, 0, i).squeeze(-2).clone().contiguous(); - auto sub = attn.slice(-2, 0, i).slice(-1, 0, i).clone().contiguous(); - auto row_unsq = row.unsqueeze(-1).contiguous(); - auto row_sub_mul = (row_unsq * sub).contiguous(); - auto row_sub_sum = 
row_sub_mul.sum(-2).contiguous(); - auto row_final = (row + row_sub_sum).contiguous(); - attn.index_put_( - { - torch::indexing::Ellipsis, - torch::indexing::Slice(i, i+1), - torch::indexing::Slice(0, i) - }, - row_final.unsqueeze(-2) - ); - } - - attn = attn + torch::eye(chunk_size, torch::TensorOptions().dtype(attn.dtype()).device(attn.device())); - value = torch::matmul(attn, v_beta); - auto k_cumdecay = torch::matmul(attn, (k_beta * g.exp().unsqueeze(-1))); - torch::Tensor last_recurrent_state; - if (!initial_state.has_value()) { - last_recurrent_state = torch::zeros( - {batch_size, num_heads, k_head_dim, v_head_dim}, - torch::TensorOptions().dtype(value.dtype()).device(value.device()) - ); - } else { - last_recurrent_state = initial_state.value().to(value); - } - auto core_attn_out = torch::zeros_like(value); - mask = torch::triu( - torch::ones({chunk_size, chunk_size}, torch::TensorOptions().dtype(torch::kBool).device(query.device())), - 1 - ); - int64_t num_chunks = total_sequence_length / chunk_size; - for (int64_t i = 0; i < num_chunks; ++i) { - auto q_i = query.select(2, i); - auto k_i = key.select(2, i); - auto v_i = value.select(2, i); - auto attn_i = (torch::matmul(q_i, k_i.transpose(-1, -2)) * decay_mask.select(2, i)).masked_fill_(mask, 0.0); - auto v_prime = torch::matmul(k_cumdecay.select(2, i), last_recurrent_state); - auto v_new = v_i - v_prime; - auto attn_inter = torch::matmul(q_i * g.select(2, i).unsqueeze(-1).exp(), last_recurrent_state); - core_attn_out.select(2, i) = attn_inter + torch::matmul(attn_i, v_new); - auto g_i_last = g.select(2, i).select(-1, -1).unsqueeze(-1); - auto g_exp_term = (g_i_last - g.select(2, i)).exp().unsqueeze(-1); - auto k_g_exp = (k_i * g_exp_term).transpose(-1, -2).contiguous(); - last_recurrent_state = - last_recurrent_state * g_i_last.unsqueeze(-1).exp() + - torch::matmul(k_g_exp, v_new); + auto initial_dtype = query.dtype(); + if (use_qk_l2norm_in_kernel) { + query = l2norm(query, -1, 1e-6); + key = l2norm(key, 
-1, 1e-6); + } + auto to_float32 = [](torch::Tensor x) { + return x.transpose(1, 2).contiguous().to(torch::kFloat32); + }; + + query = to_float32(query); + key = to_float32(key); + value = to_float32(value); + beta = to_float32(beta); + g = to_float32(g); + + auto batch_size = query.size(0); + auto num_heads = query.size(1); + auto sequence_length = query.size(2); + auto k_head_dim = key.size(-1); + auto v_head_dim = value.size(-1); + + int64_t pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size; + query = torch::nn::functional::pad( + query, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); + key = torch::nn::functional::pad( + key, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); + value = torch::nn::functional::pad( + value, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); + beta = torch::nn::functional::pad( + beta, torch::nn::functional::PadFuncOptions({0, pad_size})); + g = torch::nn::functional::pad( + g, torch::nn::functional::PadFuncOptions({0, pad_size})); + + int64_t total_sequence_length = sequence_length + pad_size; + float scale = 1.0 / std::sqrt(static_cast(query.size(-1))); + query = query * scale; + auto v_beta = value * beta.unsqueeze(-1); + auto k_beta = key * beta.unsqueeze(-1); + auto reshape_to_chunks = [chunk_size](torch::Tensor x) { + auto shape = x.sizes(); + std::vector new_shape = { + shape[0], shape[1], shape[2] / chunk_size, chunk_size, shape[3]}; + return x.reshape(new_shape); + }; + + query = reshape_to_chunks(query); + key = reshape_to_chunks(key); + value = reshape_to_chunks(value); + k_beta = reshape_to_chunks(k_beta); + v_beta = reshape_to_chunks(v_beta); + + auto g_shape = g.sizes(); + std::vector g_new_shape = { + g_shape[0], g_shape[1], g_shape[2] / chunk_size, chunk_size}; + g = g.reshape(g_new_shape); + auto mask = torch::triu( + torch::ones( + {chunk_size, chunk_size}, + torch::TensorOptions().dtype(torch::kBool).device(query.device())), + 0); + + g = g.cumsum(-1); + auto 
g_diff = g.unsqueeze(-1) - g.unsqueeze(-2); + auto decay_mask = g_diff.tril().exp().to(torch::kFloat32); + decay_mask = decay_mask.tril(); + auto attn = -(torch::matmul(k_beta, key.transpose(-1, -2)) * decay_mask) + .masked_fill(mask, 0.0); + for (int64_t i = 1; i < chunk_size; ++i) { + if (!attn.is_contiguous()) { + attn = attn.contiguous(); } - auto core_attn_out_shape = core_attn_out.sizes(); - std::vector reshape_shape = { - core_attn_out_shape[0], core_attn_out_shape[1], - core_attn_out_shape[2] * core_attn_out_shape[3], - core_attn_out_shape[4] - }; - core_attn_out = core_attn_out.reshape(reshape_shape); - core_attn_out = core_attn_out.slice(2, 0, sequence_length); - core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype); - return std::make_tuple(core_attn_out, last_recurrent_state); + auto row = attn.slice(-2, i, i + 1) + .slice(-1, 0, i) + .squeeze(-2) + .clone() + .contiguous(); + auto sub = attn.slice(-2, 0, i).slice(-1, 0, i).clone().contiguous(); + auto row_unsq = row.unsqueeze(-1).contiguous(); + auto row_sub_mul = (row_unsq * sub).contiguous(); + auto row_sub_sum = row_sub_mul.sum(-2).contiguous(); + auto row_final = (row + row_sub_sum).contiguous(); + attn.index_put_({torch::indexing::Ellipsis, + torch::indexing::Slice(i, i + 1), + torch::indexing::Slice(0, i)}, + row_final.unsqueeze(-2)); + } + + attn = attn + + torch::eye( + chunk_size, + torch::TensorOptions().dtype(attn.dtype()).device(attn.device())); + value = torch::matmul(attn, v_beta); + auto k_cumdecay = torch::matmul(attn, (k_beta * g.exp().unsqueeze(-1))); + torch::Tensor last_recurrent_state; + if (!initial_state.has_value()) { + last_recurrent_state = torch::zeros( + {batch_size, num_heads, k_head_dim, v_head_dim}, + torch::TensorOptions().dtype(value.dtype()).device(value.device())); + } else { + last_recurrent_state = initial_state.value().to(value); + } + auto core_attn_out = torch::zeros_like(value); + mask = torch::triu( + torch::ones( + {chunk_size, 
chunk_size}, + torch::TensorOptions().dtype(torch::kBool).device(query.device())), + 1); + int64_t num_chunks = total_sequence_length / chunk_size; + for (int64_t i = 0; i < num_chunks; ++i) { + auto q_i = query.select(2, i); + auto k_i = key.select(2, i); + auto v_i = value.select(2, i); + auto attn_i = + (torch::matmul(q_i, k_i.transpose(-1, -2)) * decay_mask.select(2, i)) + .masked_fill_(mask, 0.0); + auto v_prime = torch::matmul(k_cumdecay.select(2, i), last_recurrent_state); + auto v_new = v_i - v_prime; + auto attn_inter = torch::matmul(q_i * g.select(2, i).unsqueeze(-1).exp(), + last_recurrent_state); + core_attn_out.select(2, i) = attn_inter + torch::matmul(attn_i, v_new); + auto g_i_last = g.select(2, i).select(-1, -1).unsqueeze(-1); + auto g_exp_term = (g_i_last - g.select(2, i)).exp().unsqueeze(-1); + auto k_g_exp = (k_i * g_exp_term).transpose(-1, -2).contiguous(); + last_recurrent_state = last_recurrent_state * g_i_last.unsqueeze(-1).exp() + + torch::matmul(k_g_exp, v_new); + } + auto core_attn_out_shape = core_attn_out.sizes(); + std::vector reshape_shape = { + core_attn_out_shape[0], + core_attn_out_shape[1], + core_attn_out_shape[2] * core_attn_out_shape[3], + core_attn_out_shape[4]}; + core_attn_out = core_attn_out.reshape(reshape_shape); + core_attn_out = core_attn_out.slice(2, 0, sequence_length); + core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype); + return std::make_tuple(core_attn_out, last_recurrent_state); } -} // namespace +} // namespace -Qwen3NextGatedDeltaNetImpl::Qwen3NextGatedDeltaNetImpl(const ModelArgs& args, - const QuantArgs& quant_args, - const ParallelArgs& parallel_args, - const torch::TensorOptions& options) { +Qwen3NextGatedDeltaNetImpl::Qwen3NextGatedDeltaNetImpl( + const ModelArgs& args, + const QuantArgs& quant_args, + const ParallelArgs& parallel_args, + const torch::TensorOptions& options) { const int64_t total_num_heads = args.n_heads(); const int64_t total_num_kv_heads = 
args.n_kv_heads().value_or(args.n_heads()); - tp_size_ = parallel_args.tp_group_->world_size(); + tp_size_ = parallel_args.tp_group_->world_size(); rank_ = parallel_args.tp_group_->rank(); num_k_heads_ = args.linear_num_key_heads(); num_v_heads_ = args.linear_num_value_heads(); head_k_dim_ = args.linear_key_head_dim(); - head_v_dim_ = args.linear_value_head_dim(); - k_size_ = num_k_heads_ * head_k_dim_; + head_v_dim_ = args.linear_value_head_dim(); + k_size_ = num_k_heads_ * head_k_dim_; v_size_ = num_v_heads_ * head_v_dim_; conv_kernel_size_ = args.linear_conv_kernel_dim(); // 0. QKVZ parallel linear conv1d_ = register_module("conv1d", ColumnParallelLinear(args.linear_conv_kernel_dim(), - k_size_ * 2 + v_size_, - /*bias=*/false, - /*gather_output=*/false, - quant_args, - parallel_args, - options)); - + k_size_ * 2 + v_size_, + /*bias=*/false, + /*gather_output=*/false, + quant_args, + parallel_args.tp_group_, + options)); // 1. QKVZ parallel linear qkvz_proj_ = register_module("in_proj_qkvz", - ColumnParallelLinear(args.hidden_size(), + ColumnParallelLinear(args.hidden_size(), k_size_ * 2 + v_size_ * 2, /*bias=*/false, /*gather_output=*/false, quant_args, - parallel_args, + parallel_args.tp_group_, options)); // 2. 
Output projection ba_proj_ = register_module("in_proj_ba", - ColumnParallelLinear(args.hidden_size(), + ColumnParallelLinear(args.hidden_size(), num_v_heads_ * 2, /*bias=*/false, /*gather_output=*/false, - quant_args, - parallel_args, - options)); + quant_args, + parallel_args.tp_group_, + options)); auto opts = options.dtype(torch::kFloat32); - dt_bias_ = register_parameter("dt_bias", torch::ones({num_v_heads_ / tp_size_}, opts), /*requires_grad=*/false); + dt_bias_ = register_parameter("dt_bias", + torch::ones({num_v_heads_ / tp_size_}, opts), + /*requires_grad=*/false); - A_log_ = register_parameter("A_log", torch::empty({num_v_heads_ / tp_size_}, opts), /*requires_grad=*/false); + A_log_ = register_parameter("A_log", + torch::empty({num_v_heads_ / tp_size_}, opts), + /*requires_grad=*/false); // 3. Output projection o_proj_ = register_module("out_proj", RowParallelLinear(v_size_, @@ -292,12 +307,12 @@ Qwen3NextGatedDeltaNetImpl::Qwen3NextGatedDeltaNetImpl(const ModelArgs& args, /*input_is_parallelized=*/true, /*if_reduce_results=*/true, quant_args, - parallel_args, + parallel_args.tp_group_, options)); // 4. 
RMSNorm - norm_ = register_module("norm", RmsNormGated(head_v_dim_, args.rms_norm_eps(), options)); - + norm_ = register_module( + "norm", RmsNormGated(head_v_dim_, args.rms_norm_eps(), options)); } torch::Tensor Qwen3NextGatedDeltaNetImpl::forward( @@ -305,7 +320,6 @@ torch::Tensor Qwen3NextGatedDeltaNetImpl::forward( const AttentionMetadata& attn_metadata, KVCache& kv_cache, const ModelInputParams& input_params) { - auto qkvz = qkvz_proj_->forward(hidden_states); auto qkvz_reshaped = reshape_qkvz_with_pad(attn_metadata, qkvz); auto [q, k, v, z] = process_qkvz_tensor(qkvz_reshaped); @@ -313,7 +327,8 @@ torch::Tensor Qwen3NextGatedDeltaNetImpl::forward( auto ba_reshaped = reshape_qkvz_with_pad(attn_metadata, ba); auto [b, a] = process_ba_tensor(ba_reshaped); auto rearrange_merge = [](const torch::Tensor& t) { - TORCH_CHECK(t.dim() > 2, "Tensor must have at least 2 dims! but got ", t.dim()); + TORCH_CHECK( + t.dim() > 2, "Tensor must have at least 2 dims! but got ", t.dim()); std::vector new_shape; int64_t slice_end = t.dim() - 2; auto valid_slice = t.sizes().slice(0, slice_end); @@ -328,8 +343,8 @@ torch::Tensor Qwen3NextGatedDeltaNetImpl::forward( v = rearrange_merge(v); torch::Tensor mixed_qkv = torch::cat({q, k, v}, q.dim() - 1); - mixed_qkv = mixed_qkv.transpose(1,2); - int64_t seq_len = mixed_qkv.size(2); + mixed_qkv = mixed_qkv.transpose(1, 2); + int64_t seq_len = mixed_qkv.size(2); torch::Tensor conv_cache = kv_cache.get_conv_cache(); torch::Tensor ssm_cache = kv_cache.get_ssm_cache(); torch::Tensor g, beta, core_attn_out, last_recurrent_state; @@ -337,26 +352,32 @@ torch::Tensor Qwen3NextGatedDeltaNetImpl::forward( auto conv_weight = conv1d_->weight(); if (attn_metadata.is_prefill) { - torch::Tensor conv_state = (seq_len < conv_kernel_size_-1) ? torch::pad(mixed_qkv, {0, conv_kernel_size_-1-seq_len}) : (seq_len > conv_kernel_size_-1) ? 
mixed_qkv.narrow(-1, seq_len-conv_kernel_size_+1, conv_kernel_size_-1): mixed_qkv; - conv_cache.index_put_({input_params.block_tables.select(1,0)}, conv_state.to(conv_cache.dtype())); + torch::Tensor conv_state = + (seq_len < conv_kernel_size_ - 1) + ? torch::pad(mixed_qkv, {0, conv_kernel_size_ - 1 - seq_len}) + : (seq_len > conv_kernel_size_ - 1) + ? mixed_qkv.narrow( + -1, seq_len - conv_kernel_size_ + 1, conv_kernel_size_ - 1) + : mixed_qkv; + conv_cache.index_put_({input_params.block_tables.select(1, 0)}, + conv_state.to(conv_cache.dtype())); torch::Tensor bias; - auto conv_output = torch::conv1d( - mixed_qkv, - conv_weight.unsqueeze(1).to(device), - bias, - /*stride=*/std::vector{1}, - /*padding=*/std::vector{3}, - /*dilation=*/std::vector{1}, - /*groups=*/static_cast(mixed_qkv.size(1)) - ); - mixed_qkv = torch::silu(conv_output.slice(2,0,seq_len)); + auto conv_output = + torch::conv1d(mixed_qkv, + conv_weight.unsqueeze(1).to(device), + bias, + /*stride=*/std::vector{1}, + /*padding=*/std::vector{3}, + /*dilation=*/std::vector{1}, + /*groups=*/static_cast(mixed_qkv.size(1))); + mixed_qkv = torch::silu(conv_output.slice(2, 0, seq_len)); } else { xllm::kernel::CausalConv1dUpdateParams params; params.x = mixed_qkv; params.conv_state = conv_cache; params.weight = conv_weight; - params.conv_state_indices = attn_metadata.block_table.select(1,0); + params.conv_state_indices = attn_metadata.block_table.select(1, 0); mixed_qkv = xllm::kernel::causal_conv1d_update(params); } @@ -367,8 +388,7 @@ torch::Tensor Qwen3NextGatedDeltaNetImpl::forward( torch::Tensor a_plus_dt = a_float + dt_bias_; torch::Tensor softplus_out = torch::nn::functional::softplus( a_plus_dt, - torch::nn::functional::SoftplusFuncOptions().beta(1.0).threshold(20.0) - ); + torch::nn::functional::SoftplusFuncOptions().beta(1.0).threshold(20.0)); g = -A_log_exp * softplus_out; g = g.to(a.dtype()).contiguous(); } else { @@ -389,16 +409,24 @@ torch::Tensor Qwen3NextGatedDeltaNetImpl::forward( processed_k = 
processed_k.repeat_interleave(repeat_times, 2); } if (attn_metadata.is_prefill) { - std::tie(core_attn_out, last_recurrent_state) = torch_chunk_gated_delta_rule(processed_q, processed_k, processed_v, g, beta); - ssm_cache.index_put_({input_params.block_tables.select(1,0)}, last_recurrent_state.to(ssm_cache.dtype())); + std::tie(core_attn_out, last_recurrent_state) = + torch_chunk_gated_delta_rule( + processed_q, processed_k, processed_v, g, beta); + ssm_cache.index_put_({input_params.block_tables.select(1, 0)}, + last_recurrent_state.to(ssm_cache.dtype())); } else { - auto ssm_state = torch::index_select(ssm_cache, 0, attn_metadata.block_table.select(1,0)); - std::tie(core_attn_out, last_recurrent_state) = torch_recurrent_gated_delta_rule(processed_q, processed_k, processed_v, g, beta, ssm_state); - ssm_cache.index_put_({attn_metadata.block_table.select(1,0)}, last_recurrent_state.to(ssm_cache.dtype())); + auto ssm_state = torch::index_select( + ssm_cache, 0, attn_metadata.block_table.select(1, 0)); + std::tie(core_attn_out, last_recurrent_state) = + torch_recurrent_gated_delta_rule( + processed_q, processed_k, processed_v, g, beta, ssm_state); + ssm_cache.index_put_({attn_metadata.block_table.select(1, 0)}, + last_recurrent_state.to(ssm_cache.dtype())); } auto z_reshaped = z.view({-1, z.size(-1)}); - auto core_attn_out_reshaped = core_attn_out.view({-1, core_attn_out.size(-1)}); + auto core_attn_out_reshaped = + core_attn_out.view({-1, core_attn_out.size(-1)}); auto norm_out = norm_->forward(core_attn_out_reshaped, z_reshaped); auto z_shape_og = z.sizes().vec(); norm_out = norm_out.view(z_shape_og); @@ -407,82 +435,90 @@ torch::Tensor Qwen3NextGatedDeltaNetImpl::forward( auto rearranged_norm = rearrange_merge(norm_out); rearranged_norm = reshape_qkvz_unpad(attn_metadata, rearranged_norm); auto attn_output = o_proj_->forward(rearranged_norm); - return attn_output; + return attn_output; } -torch::Tensor Qwen3NextGatedDeltaNetImpl::reshape_qkvz_unpad(const 
AttentionMetadata& attn_metadata, const torch::Tensor& padded_qkvz) { - if (!attn_metadata.is_prefill) { - return padded_qkvz; - } - std::vector valid_batches; - int64_t bs = attn_metadata.query_start_loc.size(0); - int64_t max_len = attn_metadata.max_query_len; - const auto& ori_seq_lens = attn_metadata.query_start_loc; - auto reshaped_qkvz = padded_qkvz.view({bs, max_len, -1}); - for (int64_t b = 0; b < bs; ++b) { - int64_t ori_len = ori_seq_lens[b].item(); - torch::Tensor valid_batch = reshaped_qkvz[b].slice(0, 0, ori_len); - valid_batches.push_back(valid_batch); - } - return torch::cat(valid_batches, 0).contiguous(); +torch::Tensor Qwen3NextGatedDeltaNetImpl::reshape_qkvz_unpad( + const AttentionMetadata& attn_metadata, + const torch::Tensor& padded_qkvz) { + if (!attn_metadata.is_prefill) { + return padded_qkvz; + } + std::vector valid_batches; + int64_t bs = attn_metadata.q_seq_lens.size(0); + int64_t max_len = attn_metadata.max_query_len; + const auto& ori_seq_lens = attn_metadata.q_seq_lens; + auto reshaped_qkvz = padded_qkvz.view({bs, max_len, -1}); + for (int64_t b = 0; b < bs; ++b) { + int64_t ori_len = ori_seq_lens[b].template item(); + torch::Tensor valid_batch = reshaped_qkvz[b].slice(0, 0, ori_len); + valid_batches.push_back(valid_batch); + } + return torch::cat(valid_batches, 0).contiguous(); } -torch::Tensor Qwen3NextGatedDeltaNetImpl::reshape_qkvz_with_pad(const AttentionMetadata& attn_metadata, const torch::Tensor& qkvz) { - int64_t bs = attn_metadata.query_start_loc.size(0); - int64_t max_len = attn_metadata.max_query_len; - const auto& start_loc = attn_metadata.query_start_loc; - if (!attn_metadata.is_prefill) { - return qkvz.view({bs, -1, qkvz.size(-1)}); - } - std::vector batches; - int64_t idx = 0; - for (int64_t b = 0; b < bs; ++b) { - int64_t cur_len = start_loc[b].item(); - torch::Tensor batch = qkvz.slice(0, idx, idx + cur_len).contiguous(); - idx = idx + cur_len; - if (batch.size(0) != max_len) { - batch = batch.size(0) > max_len - ? 
batch.slice(0, 0, max_len).contiguous() - : torch::nn::functional::pad( - batch, - torch::nn::functional::PadFuncOptions({0, 0, 0, max_len - batch.size(0)}) - ).contiguous(); - } - batches.push_back(batch); +torch::Tensor Qwen3NextGatedDeltaNetImpl::reshape_qkvz_with_pad( + const AttentionMetadata& attn_metadata, + const torch::Tensor& qkvz) { + int64_t bs = attn_metadata.q_seq_lens.size(0); + int64_t max_len = attn_metadata.max_query_len; + const auto& start_loc = attn_metadata.q_seq_lens; + if (!attn_metadata.is_prefill) { + return qkvz.view({bs, -1, qkvz.size(-1)}); + } + std::vector batches; + int64_t idx = 0; + for (int64_t b = 0; b < bs; ++b) { + int64_t cur_len = start_loc[b].template item(); + torch::Tensor batch = qkvz.slice(0, idx, idx + cur_len).contiguous(); + idx = idx + cur_len; + if (batch.size(0) != max_len) { + batch = batch.size(0) > max_len + ? batch.slice(0, 0, max_len).contiguous() + : torch::nn::functional::pad( + batch, + torch::nn::functional::PadFuncOptions( + {0, 0, 0, max_len - batch.size(0)})) + .contiguous(); } - auto ret = torch::stack(batches, 0).contiguous(); - return ret; + batches.push_back(batch); + } + auto ret = torch::stack(batches, 0).contiguous(); + return ret; } - std::tuple Qwen3NextGatedDeltaNetImpl::process_mixed_qkv(torch::Tensor& mixed_qkv) { - mixed_qkv = mixed_qkv.transpose(1,2); + mixed_qkv = mixed_qkv.transpose(1, 2); int64_t batch_size = mixed_qkv.size(0); int64_t seq_len = mixed_qkv.size(1); - std::vector split_sizes = {k_size_ / tp_size_, k_size_ / tp_size_, v_size_ / tp_size_}; + std::vector split_sizes = { + k_size_ / tp_size_, k_size_ / tp_size_, v_size_ / tp_size_}; auto processed_qkv = torch::split(mixed_qkv, split_sizes, 2); auto processed_q = processed_qkv[0]; auto processed_k = processed_qkv[1]; - auto processed_v = processed_qkv[2]; - processed_q = processed_q.view({batch_size, seq_len, num_k_heads_ / tp_size_, head_k_dim_}); - processed_k = processed_k.view({batch_size, seq_len, num_k_heads_ / tp_size_, 
head_k_dim_}); - processed_v = processed_v.view({batch_size, seq_len, num_v_heads_ / tp_size_, head_v_dim_}); + auto processed_v = processed_qkv[2]; + processed_q = processed_q.view( + {batch_size, seq_len, num_k_heads_ / tp_size_, head_k_dim_}); + processed_k = processed_k.view( + {batch_size, seq_len, num_k_heads_ / tp_size_, head_k_dim_}); + processed_v = processed_v.view( + {batch_size, seq_len, num_v_heads_ / tp_size_, head_v_dim_}); return std::make_tuple(processed_q, processed_k, processed_v); } -std::tuple +std::tuple Qwen3NextGatedDeltaNetImpl::process_qkvz_tensor(const torch::Tensor& qkvz) { - std::vector new_tensor_shape_qkvz = [&]() { std::vector dims; dims.push_back(qkvz.size(0)); if (qkvz.dim() >= 3) { - dims.push_back(qkvz.size(1)); + dims.push_back(qkvz.size(1)); } int64_t dim1 = num_k_heads_ / tp_size_; - int64_t dim2 = head_k_dim_ + head_k_dim_ + (head_v_dim_ + head_v_dim_) * num_v_heads_ / num_k_heads_; + int64_t dim2 = head_k_dim_ + head_k_dim_ + + (head_v_dim_ + head_v_dim_) * num_v_heads_ / num_k_heads_; dims.push_back(dim1); dims.push_back(dim2); @@ -490,16 +526,17 @@ Qwen3NextGatedDeltaNetImpl::process_qkvz_tensor(const torch::Tensor& qkvz) { }(); auto reshaped_qkvz = qkvz.view(new_tensor_shape_qkvz); - auto qkvz_split = torch::split(reshaped_qkvz, - {head_k_dim_, head_k_dim_, - num_v_heads_ * head_v_dim_ / num_k_heads_, - num_v_heads_ * head_v_dim_ / num_k_heads_}, reshaped_qkvz.dim()-1); - + auto qkvz_split = torch::split(reshaped_qkvz, + {head_k_dim_, + head_k_dim_, + num_v_heads_ * head_v_dim_ / num_k_heads_, + num_v_heads_ * head_v_dim_ / num_k_heads_}, + reshaped_qkvz.dim() - 1); + auto q = qkvz_split[0].contiguous(); auto k = qkvz_split[1].contiguous(); auto v = qkvz_split[2].contiguous(); auto z = qkvz_split[3].contiguous(); - v = v.view({v.size(0), v.size(1), -1, head_v_dim_}); z = z.view({z.size(0), z.size(1), -1, head_v_dim_}); @@ -507,10 +544,8 @@ Qwen3NextGatedDeltaNetImpl::process_qkvz_tensor(const torch::Tensor& qkvz) { return 
std::make_tuple(q, k, v, z); } - -std::tuple +std::tuple Qwen3NextGatedDeltaNetImpl::process_ba_tensor(const torch::Tensor& ba) { - std::vector new_tensor_shape_ba = [&]() { std::vector dims; dims.push_back(ba.size(0)); @@ -523,15 +558,17 @@ Qwen3NextGatedDeltaNetImpl::process_ba_tensor(const torch::Tensor& ba) { }(); auto reshaped_ba = ba.view(new_tensor_shape_ba); - auto ba_split = torch::split(reshaped_ba, - {num_v_heads_ / num_k_heads_, num_v_heads_ / num_k_heads_}, reshaped_ba.dim()-1); - + auto ba_split = + torch::split(reshaped_ba, + {num_v_heads_ / num_k_heads_, num_v_heads_ / num_k_heads_}, + reshaped_ba.dim() - 1); + auto b = ba_split[0].contiguous(); auto a = ba_split[1].contiguous(); b = b.reshape({b.size(0), b.size(1), num_v_heads_ / tp_size_}); a = a.reshape({a.size(0), a.size(1), num_v_heads_ / tp_size_}); - + return std::make_tuple(b, a); } @@ -539,12 +576,14 @@ void Qwen3NextGatedDeltaNetImpl::load_state_dict(const StateDict& state_dict) { const int64_t rank = rank_; const int64_t world_size = tp_size_; const int32_t shard_tensor_count = 3; - const std::vector shard_sizes = {k_size_ / tp_size_, k_size_ / tp_size_, v_size_ / tp_size_}; + const std::vector shard_sizes = { + k_size_ / tp_size_, k_size_ / tp_size_, v_size_ / tp_size_}; qkvz_proj_->load_state_dict(state_dict.get_dict_with_prefix("in_proj_qkvz.")); ba_proj_->load_state_dict(state_dict.get_dict_with_prefix("in_proj_ba.")); - + if (auto w = state_dict.get_tensor("conv1d.weight"); w.defined()) { - conv1d_->load_state_dict(StateDict({{"weight", w.squeeze(1)}}), shard_tensor_count, shard_sizes); + conv1d_->load_state_dict( + StateDict({{"weight", w.squeeze(1)}}), shard_tensor_count, shard_sizes); } o_proj_->load_state_dict(state_dict.get_dict_with_prefix("out_proj.")); if (auto w = state_dict.get_tensor("norm.weight"); w.defined()) { diff --git a/xllm/core/layers/qwen3_next_decoder_layer.cpp b/xllm/core/layers/qwen3_next_decoder_layer.cpp index 8135217cb..e156b8bb5 100644 --- 
a/xllm/core/layers/qwen3_next_decoder_layer.cpp +++ b/xllm/core/layers/qwen3_next_decoder_layer.cpp @@ -20,8 +20,9 @@ limitations under the License. namespace xllm { namespace layer { -Qwen3NextDecoderLayerImpl::Qwen3NextDecoderLayerImpl(const ModelContext& context, - int32_t layer_id) { +Qwen3NextDecoderLayerImpl::Qwen3NextDecoderLayerImpl( + const ModelContext& context, + int32_t layer_id) { const auto& model_args = context.get_model_args(); const auto& quant_args = context.get_quant_args(); const auto& parallel_args = context.get_parallel_args(); @@ -29,22 +30,26 @@ Qwen3NextDecoderLayerImpl::Qwen3NextDecoderLayerImpl(const ModelContext& context // Initialize attention layers if ((layer_id + 1) % 4 == 0) { attention_ = register_module( - "self_attn", - Qwen3NextAttention(model_args, quant_args, parallel_args, options, layer_id)); + "self_attn", + Qwen3NextAttention( + model_args, quant_args, parallel_args, options, layer_id)); } else { linear_attention_ = register_module( - "linear_attn", - Qwen3NextGatedDeltaNet(model_args, quant_args, parallel_args, options, layer_id)); + "linear_attn", + Qwen3NextGatedDeltaNet( + model_args, quant_args, parallel_args, options, layer_id)); } // Initialize norm layers input_norm_ = register_module( "input_layernorm", - Qwen3NextRMSNorm(model_args.hidden_size(), model_args.rms_norm_eps(), options)); + Qwen3NextRMSNorm( + model_args.hidden_size(), model_args.rms_norm_eps(), options)); post_norm_ = register_module( "post_attention_layernorm", - Qwen3NextRMSNorm(model_args.hidden_size(), model_args.rms_norm_eps(), options)); + Qwen3NextRMSNorm( + model_args.hidden_size(), model_args.rms_norm_eps(), options)); // Initialize mlp auto mlp_only_layers = model_args.mlp_only_layers(); @@ -66,7 +71,7 @@ Qwen3NextDecoderLayerImpl::Qwen3NextDecoderLayerImpl(const ModelContext& context false, model_args.hidden_act(), quant_args, - parallel_args, + parallel_args.tp_group_, options)); } } @@ -75,7 +80,8 @@ void 
Qwen3NextDecoderLayerImpl::load_state_dict(const StateDict& state_dict) { if (attention_) { attention_->load_state_dict(state_dict.get_dict_with_prefix("self_attn.")); } else { - linear_attention_->load_state_dict(state_dict.get_dict_with_prefix("linear_attn.")); + linear_attention_->load_state_dict( + state_dict.get_dict_with_prefix("linear_attn.")); } input_norm_->load_state_dict( state_dict.get_dict_with_prefix("input_layernorm.")); @@ -101,8 +107,8 @@ torch::Tensor Qwen3NextDecoderLayerImpl::forward( // Attention if (attention_) { x = attention_->forward(positions, x, attn_metadata, kv_cache); - } else { - //x = x; + } else { + // x = x; x = linear_attention_->forward(x, attn_metadata, kv_cache, input_params); } diff --git a/xllm/core/layers/qwen3_next_decoder_layer.h b/xllm/core/layers/qwen3_next_decoder_layer.h index 8dc855a2d..18bb8170f 100644 --- a/xllm/core/layers/qwen3_next_decoder_layer.h +++ b/xllm/core/layers/qwen3_next_decoder_layer.h @@ -56,6 +56,7 @@ class Qwen3NextDecoderLayerImpl : public torch::nn::Module { Qwen3NextRMSNorm input_norm_{nullptr}; Qwen3NextRMSNorm post_norm_{nullptr}; }; +TORCH_MODULE(Qwen3NextDecoderLayer); } // namespace layer } // namespace xllm diff --git a/xllm/models/llm/qwen3_next.h b/xllm/models/llm/qwen3_next.h index 72951fbe6..1f8758922 100644 --- a/xllm/models/llm/qwen3_next.h +++ b/xllm/models/llm/qwen3_next.h @@ -20,7 +20,6 @@ limitations under the License. 
#include #include "core/framework/model_context.h" -#include "core/layers/common/layer_utils.h" #include "core/layers/qwen3_next_decoder_layer.h" #include "llm_model_base.h" @@ -29,31 +28,6 @@ namespace xllm { using torch::indexing::None; using ISlice = torch::indexing::Slice; -class Qwen3NextDecoderLayerImpl : public torch::nn::Module { - public: - Qwen3NextDecoderLayerImpl(const ModelContext& context, const int32_t i) { - // register submodules - decoder_layer_ = register_module("decoder_layer", - layer::Qwen3NextDecoderLayer(context, i)); - } - - torch::Tensor forward(torch::Tensor& x, - torch::Tensor& positions, - const layer::AttentionMetadata& attn_metadata, - KVCache& kv_cache, - const ModelInputParams& input_params) { - return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params); - } - - void load_state_dict(const StateDict& state_dict) { - decoder_layer_->load_state_dict(state_dict); - } - - private: - layer::Qwen3NextDecoderLayer decoder_layer_{nullptr}; -}; -TORCH_MODULE(Qwen3NextDecoderLayer); - class Qwen3NextModelImpl : public torch::nn::Module { public: Qwen3NextModelImpl(const ModelContext& context) @@ -75,10 +49,10 @@ class Qwen3NextModelImpl : public torch::nn::Module { xllm::layer::Qwen3NextRMSNorm( model_args.hidden_size(), model_args.rms_norm_eps(), options)); #if defined(USE_NPU_TORCH) - embed_tokens_ = layer::WordEmbedding(model_args.vocab_size(), - model_args.hidden_size(), - context.get_parallel_args(), - options); + embed_tokens_ = layer::WordEmbedding(model_args.vocab_size(), + model_args.hidden_size(), + context.get_parallel_args(), + options); #else for (auto i = 0; i < FLAGS_micro_batch_num; i++) { npu_embed_tokens_.push_back(layer::NpuWordEmbedding(context)); @@ -90,7 +64,7 @@ class Qwen3NextModelImpl : public torch::nn::Module { options.dtype().toScalarType(), /*mask_value=*/mask_value); for (int32_t i = 0; i < model_args.n_layers(); ++i) { - auto block = Qwen3NextDecoderLayer(context, i); + auto block = 
layer::Qwen3NextDecoderLayer(context, i); layers_.push_back(block); blocks_->push_back(block); } @@ -123,7 +97,7 @@ class Qwen3NextModelImpl : public torch::nn::Module { // Create attention mask torch::Tensor attn_mask; max_seq_len_ = std::max(input_params.kv_max_seq_len, max_seq_len_); - + if (FLAGS_enable_chunked_prefill) { int num_sequences = input_params.num_sequences; if (num_sequences > 0) { @@ -131,12 +105,12 @@ class Qwen3NextModelImpl : public torch::nn::Module { req_mask_vec.reserve(num_sequences); for (int j = 0; j < num_sequences; j++) { - auto mask = attn_mask_.gen_append_mask( - input_params.q_seq_lens_vec[j], - input_params.kv_seq_lens_vec[j], - max_seq_len_, - dtype_, - device_); + auto mask = + attn_mask_.gen_append_mask(input_params.q_seq_lens_vec[j], + input_params.kv_seq_lens_vec[j], + max_seq_len_, + dtype_, + device_); req_mask_vec.emplace_back(mask); } attn_mask = torch::cat(req_mask_vec, 0); @@ -145,8 +119,8 @@ class Qwen3NextModelImpl : public torch::nn::Module { attn_mask = attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_); } - layer::AttentionMetadata attn_metadata = - layer::AttentionMetadata::build(input_params, input_params.q_max_seq_len > 1, attn_mask); + layer::AttentionMetadata attn_metadata = layer::AttentionMetadata::build( + input_params, input_params.q_max_seq_len > 1, attn_mask); torch::Tensor h = embed_tokens_(tokens); for (size_t i = 0; i < layers_.size(); i++) { auto& layer = layers_[i]; @@ -160,7 +134,7 @@ class Qwen3NextModelImpl : public torch::nn::Module { // load the weight from the checkpoint void load_state_dict(const StateDict& state_dict) { embed_tokens_->load_state_dict( - state_dict.get_dict_with_prefix("embed_tokens.")); + state_dict.get_dict_with_prefix("embed_tokens.")); for (int i = 0; i < layers_.size(); i++) { layers_[i]->load_state_dict( state_dict.get_dict_with_prefix("layers." 
+ std::to_string(i) + ".")); @@ -177,10 +151,9 @@ class Qwen3NextModelImpl : public torch::nn::Module { npu_embed_tokens_ = word_embedding; } - private: torch::nn::ModuleList blocks_{nullptr}; - std::vector layers_; + std::vector layers_; int32_t max_seq_len_ = 0; int32_t dp_rank_; int32_t rank_; @@ -251,7 +224,8 @@ class Qwen3NextForCausalLMImpl : public torch::nn::Module { #if defined(USE_NPU_TORCH) lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); #else - npu_lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); + npu_lm_head_->load_state_dict( + state_dict->get_dict_with_prefix("lm_head.")); #endif } @@ -291,7 +265,6 @@ class Qwen3NextForCausalLMImpl : public torch::nn::Module { layer::NpuLmHead npu_lm_head_{nullptr}; layer::LmHead lm_head_{nullptr}; Qwen3NextModel model_{nullptr}; - }; TORCH_MODULE(Qwen3NextForCausalLM); @@ -340,10 +313,10 @@ REGISTER_MODEL_ARGS(qwen3_next, [&] { LOAD_ARG_OR(linear_num_value_heads, "linear_num_value_heads", 32); LOAD_ARG_OR(linear_value_head_dim, "linear_value_head_dim", 128); LOAD_ARG_OR(partial_rotary_factor, "partial_rotary_factor", 0.25f); - LOAD_ARG_OR(shared_expert_intermediate_size, "shared_expert_intermediate_size", 512); + LOAD_ARG_OR( + shared_expert_intermediate_size, "shared_expert_intermediate_size", 512); SET_ARG(stop_token_ids, std::unordered_set({args->eos_token_id()})); }); } // namespace xllm - From dc8dd64d3d3f8ab7d6d70c89fe9789024a2a9e24 Mon Sep 17 00:00:00 2001 From: shenxiaolong <1193789086@qq.com> Date: Sun, 1 Mar 2026 17:39:17 +0800 Subject: [PATCH 06/13] bugfix: Fix build errors with qwen3-next model interfaces. 
--- xllm/core/layers/qwen3_next_decoder_layer.cpp | 4 +- xllm/models/llm/qwen3_next.h | 82 +++++++++++-------- 2 files changed, 51 insertions(+), 35 deletions(-) diff --git a/xllm/core/layers/qwen3_next_decoder_layer.cpp b/xllm/core/layers/qwen3_next_decoder_layer.cpp index e156b8bb5..476e2a99d 100644 --- a/xllm/core/layers/qwen3_next_decoder_layer.cpp +++ b/xllm/core/layers/qwen3_next_decoder_layer.cpp @@ -36,8 +36,7 @@ Qwen3NextDecoderLayerImpl::Qwen3NextDecoderLayerImpl( } else { linear_attention_ = register_module( "linear_attn", - Qwen3NextGatedDeltaNet( - model_args, quant_args, parallel_args, options, layer_id)); + Qwen3NextGatedDeltaNet(model_args, quant_args, parallel_args, options)); } // Initialize norm layers @@ -70,6 +69,7 @@ Qwen3NextDecoderLayerImpl::Qwen3NextDecoderLayerImpl( false, false, model_args.hidden_act(), + /*enable_result_reduction=*/true, quant_args, parallel_args.tp_group_, options)); diff --git a/xllm/models/llm/qwen3_next.h b/xllm/models/llm/qwen3_next.h index 1f8758922..939bd0062 100644 --- a/xllm/models/llm/qwen3_next.h +++ b/xllm/models/llm/qwen3_next.h @@ -82,10 +82,10 @@ class Qwen3NextModelImpl : public torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence - torch::Tensor forward(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { + ModelOutput forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { if (dp_size_ > 1) { if (tokens.sizes() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(device_); @@ -119,15 +119,15 @@ class Qwen3NextModelImpl : public torch::nn::Module { attn_mask = attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_); } - layer::AttentionMetadata attn_metadata = layer::AttentionMetadata::build( - input_params, input_params.q_max_seq_len > 1, attn_mask); + layer::AttentionMetadata attn_metadata = + 
layer::AttentionMetadataBuilder::build(input_params, attn_mask); torch::Tensor h = embed_tokens_(tokens); for (size_t i = 0; i < layers_.size(); i++) { auto& layer = layers_[i]; h = layer(h, positions, attn_metadata, kv_caches[i], input_params); } h = norm_(h); - return h; + return ModelOutput(h); #endif } @@ -142,14 +142,28 @@ class Qwen3NextModelImpl : public torch::nn::Module { norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); } - std::vector get_word_embedding() { - return {npu_embed_tokens_}; +#if defined(USE_NPU) && defined(USE_NPU_TORCH) + layer::WordEmbedding get_word_embedding() { return embed_tokens_; } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + embed_tokens_ = word_embedding; + } +#elif defined(USE_NPU) + layer::NpuWordEmbedding get_npu_word_embedding() { + if (npu_embed_tokens_.empty()) { + return nullptr; + } + return npu_embed_tokens_.front(); } - void set_word_embedding( - std::vector& word_embedding) { - npu_embed_tokens_ = word_embedding; + void set_npu_word_embedding(layer::NpuWordEmbedding& word_embedding) { + if (npu_embed_tokens_.empty()) { + npu_embed_tokens_.push_back(word_embedding); + return; + } + npu_embed_tokens_[0] = word_embedding; } +#endif private: torch::nn::ModuleList blocks_{nullptr}; @@ -180,15 +194,7 @@ class Qwen3NextForCausalLMImpl : public torch::nn::Module { Qwen3NextForCausalLMImpl(const ModelContext& context) { model_ = register_module("model", Qwen3NextModel(context)); #if defined(USE_NPU) && defined(USE_NPU_TORCH) - lm_head_ = - register_module("lm_head", - layer::LmHead(context.get_model_args().hidden_size(), - context.get_model_args().vocab_size(), - /*bias=*/false, - /*gather_output=*/true, - QuantArgs{}, - context.get_parallel_args(), - context.get_tensor_options())); + lm_head_ = register_module("lm_head", layer::LmHead(context)); #else npu_lm_head_ = register_module("lm_head", layer::NpuLmHead(context)); #endif @@ -197,11 +203,11 @@ class Qwen3NextForCausalLMImpl : public 
torch::nn::Module { // tokens: [num_tokens] // positions: [num_tokens] token pos in the sequence // returns: [num_tokens, hidden_size] - torch::Tensor forward(const std::vector& tokens, - const std::vector& positions, - std::vector& kv_caches, - const std::vector& input_params) { - return model_(tokens[0], positions[0], kv_caches, input_params[0]); + ModelOutput forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + return model_(tokens, positions, kv_caches, input_params); } // hidden_states: [num_tokens, hidden_size] @@ -245,20 +251,30 @@ class Qwen3NextForCausalLMImpl : public torch::nn::Module { } virtual void update_expert_weight(int32_t layer_id) { return; } -#if defined(USE_NPU) - - layer::NpuLmHead get_lm_head() { return npu_lm_head_; } +#if defined(USE_NPU) && defined(USE_NPU_TORCH) + layer::LmHead get_lm_head() { return lm_head_; } - void set_lm_head(layer::NpuLmHead& head) { npu_lm_head_ = head; } + void set_lm_head(layer::LmHead& head) { lm_head_ = head; } - std::vector get_word_embedding() { + layer::WordEmbedding get_word_embedding() { return model_->get_word_embedding(); } - void set_word_embedding( - std::vector& word_embedding) { + void set_word_embedding(layer::WordEmbedding& word_embedding) { model_->set_word_embedding(word_embedding); } +#elif defined(USE_NPU) + layer::NpuLmHead get_npu_lm_head() { return npu_lm_head_; } + + void set_npu_lm_head(layer::NpuLmHead& head) { npu_lm_head_ = head; } + + layer::NpuWordEmbedding get_npu_word_embedding() { + return model_->get_npu_word_embedding(); + } + + void set_npu_word_embedding(layer::NpuWordEmbedding& word_embedding) { + model_->set_npu_word_embedding(word_embedding); + } #endif private: From 0589a7f4d49a9e22f1c40b288a5fd8a5ac30d38b Mon Sep 17 00:00:00 2001 From: shenxiaolong <1193789086@qq.com> Date: Mon, 2 Mar 2026 10:12:22 +0800 Subject: [PATCH 07/13] feat: support qwen3-next inference on npu. 
feat: adjust cache allocation based on attention settings(support ssm_cache). feat: update torch_npu_ops commit. bugfix:layer CMake fix. bugfix:add model arguments for enhanced configuration. bugfix: QKV linear load fix. bugfix: fallback pg_comm. bugfix: set q_seq_len in prefill. bugfix: handle optional finished tensor in moe gating and FusedMoE implementations. bugfix: ensure activation output is correctly assigned in FusedMoE forward pass. --- .../distributed_runtime/worker_service.cpp | 58 ++++++++++++++++--- .../parallel_state/npu_process_group.cpp | 15 +++++ .../npu/npu_moe_gating_topk_softmax.cpp | 4 +- xllm/core/layers/CMakeLists.txt | 1 + xllm/core/layers/common/CMakeLists.txt | 4 +- .../common/attention_metadata_builder.cpp | 9 +++ .../layers/common/qwen3_next_attention.cpp | 2 +- xllm/core/layers/npu_torch/fused_moe.cpp | 3 +- xllm/core/layers/qwen3_next_decoder_layer.cpp | 2 +- xllm/core/runtime/worker_impl.cpp | 16 +++-- xllm/models/llm/qwen3_next.h | 10 ++++ 11 files changed, 106 insertions(+), 18 deletions(-) diff --git a/xllm/core/distributed_runtime/worker_service.cpp b/xllm/core/distributed_runtime/worker_service.cpp index f647393ce..1156e6a96 100644 --- a/xllm/core/distributed_runtime/worker_service.cpp +++ b/xllm/core/distributed_runtime/worker_service.cpp @@ -286,6 +286,7 @@ void WorkerService::AllocateKVCache( threadpool_->schedule([this, controller, request, response, done]() mutable { brpc::ClosureGuard done_guard(done); std::vector> kv_cache_shape; +<<<<<<< HEAD // Reserve for key, value, and optionally index shape kv_cache_shape.reserve(3); kv_cache_shape.emplace_back( @@ -299,6 +300,31 @@ void WorkerService::AllocateKVCache( kv_cache_shape.emplace_back( std::vector(request->kv_cache_shape().index_shape().begin(), request->kv_cache_shape().index_shape().end())); +======= + const bool has_index_shape = request->index_shape_size() > 0; + const bool has_conv_shape = request->conv_shape_size() > 0; + const bool has_ssm_shape = 
request->ssm_shape_size() > 0; + CHECK(!(has_index_shape && (has_conv_shape || has_ssm_shape))) + << "KVCacheShape does not support index_shape with conv/ssm shapes " + << "simultaneously."; + // Reserve for key, value, and optional extra shapes + kv_cache_shape.reserve(has_conv_shape || has_ssm_shape ? 4 : 3); + kv_cache_shape.emplace_back(std::vector( + request->key_shape().begin(), request->key_shape().end())); + kv_cache_shape.emplace_back(std::vector( + request->value_shape().begin(), request->value_shape().end())); + // add index shape if exists + if (has_index_shape) { + kv_cache_shape.emplace_back(std::vector( + request->index_shape().begin(), request->index_shape().end())); + } else if (has_conv_shape || has_ssm_shape) { + CHECK(has_conv_shape && has_ssm_shape) + << "conv_shape and ssm_shape must be provided together."; + kv_cache_shape.emplace_back(std::vector( + request->conv_shape().begin(), request->conv_shape().end())); + kv_cache_shape.emplace_back(std::vector( + request->ssm_shape().begin(), request->ssm_shape().end())); +>>>>>>> 922b77e (feat: support qwen3-next inference on npu.) } auto future = worker_->allocate_kv_cache_async(kv_cache_shape); @@ -316,18 +342,34 @@ void WorkerService::AllocateKVCacheWithTransfer( threadpool_->schedule([this, controller, req, resp, done]() mutable { brpc::ClosureGuard done_guard(done); std::vector> kv_cache_shape; - kv_cache_shape.reserve(2); + const auto& shape_req = req->kv_cache_shape(); + const bool has_index_shape = shape_req.index_shape_size() > 0; + const bool has_conv_shape = shape_req.conv_shape_size() > 0; + const bool has_ssm_shape = shape_req.ssm_shape_size() > 0; + CHECK(!(has_index_shape && (has_conv_shape || has_ssm_shape))) + << "KVCacheShape does not support index_shape with conv/ssm shapes " + << "simultaneously."; + kv_cache_shape.reserve(has_conv_shape || has_ssm_shape ? 
4 : 3); kv_cache_shape.emplace_back( - std::vector(req->kv_cache_shape().key_shape().begin(), - req->kv_cache_shape().key_shape().end())); + std::vector(shape_req.key_shape().begin(), + shape_req.key_shape().end())); kv_cache_shape.emplace_back( - std::vector(req->kv_cache_shape().value_shape().begin(), - req->kv_cache_shape().value_shape().end())); + std::vector(shape_req.value_shape().begin(), + shape_req.value_shape().end())); // add index shape if exists - if (req->kv_cache_shape().index_shape_size() > 0) { + if (has_index_shape) { + kv_cache_shape.emplace_back( + std::vector(shape_req.index_shape().begin(), + shape_req.index_shape().end())); + } else if (has_conv_shape || has_ssm_shape) { + CHECK(has_conv_shape && has_ssm_shape) + << "conv_shape and ssm_shape must be provided together."; + kv_cache_shape.emplace_back( + std::vector(shape_req.conv_shape().begin(), + shape_req.conv_shape().end())); kv_cache_shape.emplace_back( - std::vector(req->kv_cache_shape().index_shape().begin(), - req->kv_cache_shape().index_shape().end())); + std::vector(shape_req.ssm_shape().begin(), + shape_req.ssm_shape().end())); } auto future = diff --git a/xllm/core/framework/parallel_state/npu_process_group.cpp b/xllm/core/framework/parallel_state/npu_process_group.cpp index d8ff32401..b5d6405fa 100644 --- a/xllm/core/framework/parallel_state/npu_process_group.cpp +++ b/xllm/core/framework/parallel_state/npu_process_group.cpp @@ -131,6 +131,14 @@ void ProcessGroupImpl::allgather(const torch::Tensor& input, check_input(input); torch::DeviceGuard device_guard(device()); + if (pg_) { + std::vector input_tensors = {input}; + std::vector> output_tensors = {outputs}; + pg_->allgather(output_tensors, input_tensors)->wait(); + return; + } + CHECK(comm_ != nullptr) << "HCCL comm is not initialized."; + torch::Tensor flattened_output = flatten_for_scatter_gather(outputs); const auto count = input.numel(); @@ -170,6 +178,13 @@ void ProcessGroupImpl::allreduce(torch::Tensor& input) { 
check_input(input); torch::DeviceGuard device_guard(device()); + if (pg_) { + std::vector input_tensors = {input}; + pg_->allreduce(input_tensors)->wait(); + return; + } + CHECK(comm_ != nullptr) << "HCCL comm is not initialized."; + const auto count = input.numel(); const auto data_type = to_hccl_data_type(input); diff --git a/xllm/core/kernels/npu/npu_moe_gating_topk_softmax.cpp b/xllm/core/kernels/npu/npu_moe_gating_topk_softmax.cpp index d46909ae9..2b21d753f 100644 --- a/xllm/core/kernels/npu/npu_moe_gating_topk_softmax.cpp +++ b/xllm/core/kernels/npu/npu_moe_gating_topk_softmax.cpp @@ -24,8 +24,10 @@ std::tuple apply_moe_gating_topk_softmax(const torch::Tensor& x, const std::optional& finished, int k) { + const torch::Tensor finished_tensor = + finished.has_value() ? finished.value() : torch::Tensor(); return at_npu::native::custom_ops::npu_moe_gating_top_k_softmax( - x, finished.value(), k); + x, finished_tensor, k); } } // namespace xllm::kernel::npu diff --git a/xllm/core/layers/CMakeLists.txt b/xllm/core/layers/CMakeLists.txt index 250e7763d..af1a22eab 100644 --- a/xllm/core/layers/CMakeLists.txt +++ b/xllm/core/layers/CMakeLists.txt @@ -52,6 +52,7 @@ cc_library( qwen3_moe_decoder_layer.cpp qwen3_next_decoder_layer.cpp DEPS + $<$:npu_layers> $<$:ilu_layers> :common_layers :parallel_state diff --git a/xllm/core/layers/common/CMakeLists.txt b/xllm/core/layers/common/CMakeLists.txt index 5257b8216..2295595cc 100755 --- a/xllm/core/layers/common/CMakeLists.txt +++ b/xllm/core/layers/common/CMakeLists.txt @@ -14,7 +14,7 @@ cc_library( rms_norm.h rotary_embedding.h rotary_embedding_util.h - $<$,$>>:fused_moe.h> + $<$,$,$>>:fused_moe.h> dense_mlp.h linear.h word_embedding_impl.h @@ -36,7 +36,7 @@ cc_library( rms_norm.cpp rotary_embedding.cpp rotary_embedding_util.cpp - $<$,$>>:fused_moe.cpp> + $<$,$,$>>:fused_moe.cpp> dense_mlp.cpp linear.cpp word_embedding_impl.cpp diff --git a/xllm/core/layers/common/attention_metadata_builder.cpp 
b/xllm/core/layers/common/attention_metadata_builder.cpp index 808f61897..d5ba5fd5f 100644 --- a/xllm/core/layers/common/attention_metadata_builder.cpp +++ b/xllm/core/layers/common/attention_metadata_builder.cpp @@ -97,6 +97,15 @@ AttentionMetadata AttentionMetadataBuilder::build( attn_metadata.q_seq_lens = torch::diff(params.q_seq_lens); // q seqlens #endif } +#if defined(USE_NPU) + // Ensure per-sequence lengths are available for NPU kernels in prefill too. + if (params.kv_seq_lens.defined()) { + attn_metadata.kv_seq_lens = params.kv_seq_lens; + } + if (params.q_seq_lens.defined()) { + attn_metadata.q_seq_lens = params.q_seq_lens; + } +#endif attn_metadata.is_dummy = (params.q_max_seq_len == 0); if (attn_metadata.is_dummy) { diff --git a/xllm/core/layers/common/qwen3_next_attention.cpp b/xllm/core/layers/common/qwen3_next_attention.cpp index 3b3c5e508..2acaf53b3 100644 --- a/xllm/core/layers/common/qwen3_next_attention.cpp +++ b/xllm/core/layers/common/qwen3_next_attention.cpp @@ -179,7 +179,7 @@ torch::Tensor Qwen3NextAttentionImpl::forward( } void Qwen3NextAttentionImpl::load_state_dict(const StateDict& state_dict) { - qkv_proj_->load_state_dict(state_dict); + qkv_proj_->load_state_dict(state_dict, {"q_proj.", "k_proj.", "v_proj."}); o_proj_->load_state_dict(state_dict.get_dict_with_prefix("o_proj.")); if (auto w = state_dict.get_tensor("q_norm.weight"); w.defined()) { q_norm_->load_state_dict(StateDict({{"weight", w}})); diff --git a/xllm/core/layers/npu_torch/fused_moe.cpp b/xllm/core/layers/npu_torch/fused_moe.cpp index 35996c523..3e1ad0f2c 100644 --- a/xllm/core/layers/npu_torch/fused_moe.cpp +++ b/xllm/core/layers/npu_torch/fused_moe.cpp @@ -185,7 +185,7 @@ torch::Tensor FusedMoEImpl::select_experts( // prepare the parameters for select_experts xllm::kernel::MoeFusedTopkParams moe_active_topk_params; moe_active_topk_params.input = router_logits_2d; - moe_active_topk_params.finished = std::nullopt; + moe_active_topk_params.finished = torch::Tensor(); 
moe_active_topk_params.topk = topk_; auto [topk_weights, topk_ids] = xllm::kernel::moe_active_topk(moe_active_topk_params); @@ -273,6 +273,7 @@ torch::Tensor FusedMoEImpl::forward_expert( activation_params.act_mode = hidden_act_; activation_params.is_gated = is_gated_; xllm::kernel::active(activation_params); + act_out = activation_params.output; // Step 6: group gemm 2 torch::Tensor gemm2_out = create_group_gemm_output(act_out, diff --git a/xllm/core/layers/qwen3_next_decoder_layer.cpp b/xllm/core/layers/qwen3_next_decoder_layer.cpp index 476e2a99d..b0f442f2d 100644 --- a/xllm/core/layers/qwen3_next_decoder_layer.cpp +++ b/xllm/core/layers/qwen3_next_decoder_layer.cpp @@ -54,7 +54,7 @@ Qwen3NextDecoderLayerImpl::Qwen3NextDecoderLayerImpl( auto mlp_only_layers = model_args.mlp_only_layers(); if ((std::count(mlp_only_layers.begin(), mlp_only_layers.end(), layer_id) == 0) && - model_args.num_experts() > 0 && + model_args.n_routed_experts() > 0 && (layer_id + 1) % model_args.decoder_sparse_step() == 0) { moe_mlp_ = register_module("mlp", FusedMoE(model_args, diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index f48160672..23d54bd1e 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -155,7 +155,13 @@ bool WorkerImpl::allocate_kv_cache( const std::vector>& kv_cache_shape) { CHECK(model_ != nullptr) << "Model is not initialized."; CHECK(kv_caches_.empty()) << "KV caches are already initialized."; - const bool enable_linear_attention = context_.get_model_args().full_attention_interval() > 1; + const bool enable_linear_attention = + context_.get_model_args().full_attention_interval() > 1; + const bool enable_lighting_indexer = + context_.get_model_args().index_n_heads() > 0; + CHECK(!(enable_linear_attention && enable_lighting_indexer)) + << "KVCache does not support linear attention and lighting indexer " + << "simultaneously."; // Check if KV cache quantization is enabled // "auto" (default): cache 
dtype aligns with model dtype (no quantization) @@ -180,8 +186,6 @@ bool WorkerImpl::allocate_kv_cache( // create a KVCache for each layer const int64_t num_layers = get_num_layers(); - const bool enable_lighting_indexer = - context_.get_model_args().index_n_heads() > 0; kv_caches_.reserve(num_layers); if (FLAGS_enable_xtensor) { @@ -280,8 +284,12 @@ bool WorkerImpl::allocate_kv_cache( index_cache, key_cache_scale, value_cache_scale); + } else if (enable_linear_attention) { + kv_caches_.emplace_back(key_cache, value_cache, conv_cache, ssm_cache); + } else if (enable_lighting_indexer) { + kv_caches_.emplace_back(key_cache, value_cache, index_cache); } else { - kv_caches_.emplace_back(key_cache, value_cache, index_cache, conv_cache, ssm_cache); + kv_caches_.emplace_back(key_cache, value_cache); } } } diff --git a/xllm/models/llm/qwen3_next.h b/xllm/models/llm/qwen3_next.h index 939bd0062..094131747 100644 --- a/xllm/models/llm/qwen3_next.h +++ b/xllm/models/llm/qwen3_next.h @@ -332,6 +332,16 @@ REGISTER_MODEL_ARGS(qwen3_next, [&] { LOAD_ARG_OR( shared_expert_intermediate_size, "shared_expert_intermediate_size", 512); + // MoE compatibility with fused_moe implementation. + LOAD_ARG_OR(n_routed_experts, "n_routed_experts", args->num_experts()); + SET_ARG(n_shared_experts, + args->shared_expert_intermediate_size() > 0 ? 
1 : 0); + SET_ARG(scoring_func, "softmax"); + SET_ARG(topk_method, ""); + SET_ARG(n_group, -1); + SET_ARG(topk_group, 0); + SET_ARG(routed_scaling_factor, 1.0); + SET_ARG(stop_token_ids, std::unordered_set({args->eos_token_id()})); }); From 8f163759b39d5fd120cd9791a4e013e6fdbe509e Mon Sep 17 00:00:00 2001 From: xuyexiong <809602657@qq.com> Date: Wed, 4 Mar 2026 19:21:11 +0800 Subject: [PATCH 08/13] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- xllm/core/distributed_runtime/llm_engine.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp index 8726dfa79..7a79ad3e7 100644 --- a/xllm/core/distributed_runtime/llm_engine.cpp +++ b/xllm/core/distributed_runtime/llm_engine.cpp @@ -431,8 +431,8 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() { } } if (args_.linear_num_value_heads() > 0) { - int64_t head_k_dim = args_.linear_value_head_dim(); - int64_t head_v_dim = args_.linear_key_head_dim(); + int64_t head_k_dim = args_.linear_key_head_dim(); + int64_t head_v_dim = args_.linear_value_head_dim(); int64_t linear_ssm_slot_size = dtype_size * n_local_linear_v_heads_ * head_k_dim * head_v_dim; int64_t linear_conv_slot_size = dtype_size * (head_k_dim * n_local_linear_k_heads_ * 2 + head_v_dim * n_local_linear_v_heads_) * (args_.linear_conv_kernel_dim() -1 ); linear_slot_size = linear_ssm_slot_size + linear_conv_slot_size; From b06794fe0a2fd04217463d269286dbc74acdfcf9 Mon Sep 17 00:00:00 2001 From: xuyexiong Date: Thu, 5 Mar 2026 16:51:44 +0800 Subject: [PATCH 09/13] BugFix: reduce device memory usage --- xllm/core/layers/npu_torch/fused_moe.cpp | 6 ++---- xllm/models/llm/qwen3_next.h | 2 ++ 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/xllm/core/layers/npu_torch/fused_moe.cpp b/xllm/core/layers/npu_torch/fused_moe.cpp index 
3e1ad0f2c..b5c1db0ea 100644 --- a/xllm/core/layers/npu_torch/fused_moe.cpp +++ b/xllm/core/layers/npu_torch/fused_moe.cpp @@ -262,10 +262,8 @@ torch::Tensor FusedMoEImpl::forward_expert( gemm1_out = xllm::kernel::group_gemm(group_gemm_params); } - // Step 5: activation or scaled quantization(fused with activation) - torch::Tensor act_out = - is_gated_ ? gemm1_out.slice(1, 0, gemm1_out.size(1) / 2).contiguous() - : gemm1_out; + // Step 5: activation + torch::Tensor act_out; xllm::kernel::ActivationParams activation_params; activation_params.input = gemm1_out; diff --git a/xllm/models/llm/qwen3_next.h b/xllm/models/llm/qwen3_next.h index 094131747..959d4707e 100644 --- a/xllm/models/llm/qwen3_next.h +++ b/xllm/models/llm/qwen3_next.h @@ -86,6 +86,8 @@ class Qwen3NextModelImpl : public torch::nn::Module { torch::Tensor positions, std::vector& kv_caches, const ModelInputParams& input_params) { + // Disable gradient computation to reduce memory usage during inference + torch::NoGradGuard no_grad; if (dp_size_ > 1) { if (tokens.sizes() == 0) { tokens = torch::tensor({1}).to(torch::kInt32).to(device_); From 86ac26a573caa8fb2035f1e82d91d2e6fef70d9e Mon Sep 17 00:00:00 2001 From: xuyexiong Date: Thu, 12 Mar 2026 17:12:51 +0800 Subject: [PATCH 10/13] Bugfix: qwen3-next only need to update self-attention's params --- xllm/core/runtime/acl_graph_executor_impl.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index e0bb2453b..0d38ddc6d 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -61,8 +61,14 @@ GraphPersistentParam::GraphPersistentParam(const ModelArgs& args, need_update_attn_mask_(need_update_attn_mask) { // Determine whether attention plan needs to be updated based on model type // Future logic can be extended here for more complex model-specific behavior + // For qwen3_next: disable 
paged attention plan update because it uses mixed architecture + // (standard attention + linear attention). The standard attention layers (every 4th layer) + // will still work correctly without plan updates, as they use the same k/v cache structure. + // Linear attention layers use different cache types (conv_cache, ssm_cache) and don't need + // paged attention plan at all. need_update_attention_plan_ = (args.model_type() != "deepseek_v32" && - args.model_type() != "glm_moe_dsa"); + args.model_type() != "glm_moe_dsa" && + args.model_type() != "qwen3_next"); // Check if mRoPE is used (for VLM models like qwen2-vl) use_mrope_ = !args.rope_scaling_mrope_section().empty(); @@ -259,7 +265,8 @@ std::optional GraphPersistentParam::update( aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); // Update tiling tensor based on model type - if (need_update_attention_plan_) { + // For models with mixed attention types (e.g., qwen3_next), only update if k/v cache is defined + if (need_update_attention_plan_ && k_cache.defined() && v_cache.defined()) { plan_paged_attention_tiling( tokens, k_cache, v_cache, persistent_block_tables_, params, stream); } From 8515dcf707fa5817f9b1e520359b8dd3e18b3a37 Mon Sep 17 00:00:00 2001 From: xuyexiong Date: Sat, 14 Mar 2026 15:31:47 +0800 Subject: [PATCH 11/13] rebase to main --- CMakeLists.txt | 29 ----------------- setup.py | 16 +--------- third_party/torch_npu_ops | 2 +- xllm/core/distributed_runtime/llm_engine.cpp | 26 ++++----------- .../distributed_runtime/worker_service.cpp | 32 +++++-------------- xllm/core/framework/kv_cache/kv_cache.cpp | 3 -- xllm/core/framework/kv_cache/kv_cache.h | 3 -- xllm/core/framework/model/model_args.h | 20 ++++++------ xllm/core/kernels/ops_api.cpp | 8 +++++ xllm/core/runtime/acl_graph_executor_impl.cpp | 8 +---- 10 files changed, 36 insertions(+), 111 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 491d961c8..6d19789f1 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,7 
+28,6 @@ if(USE_NPU) message(STATUS "Building for device: A2 (macro USE_A2 defined)") endif() -<<<<<<< HEAD # Override Mooncake option for mooncake transfer engine # CANN 8.5+ migration: ascend_direct_transport replaces ascend_transport set(USE_ASCEND_DIRECT ON CACHE BOOL "Enable ADXL engine for Ascend NPU" FORCE) @@ -38,34 +37,6 @@ if(USE_NPU) CACHE PATH "Path to xllm_atb_layers source tree") if(NOT EXISTS "${XLLM_ATB_LAYERS_SOURCE_DIR}") message(FATAL_ERROR "xllm_atb_layers source not found: ${XLLM_ATB_LAYERS_SOURCE_DIR}") -======= - option(INSTALL_XLLM_KERNELS "Install xllm_kernels RPM" OFF) - message(STATUS "INSTALL_XLLM_KERNELS enabled: ${INSTALL_XLLM_KERNELS}") - if(INSTALL_XLLM_KERNELS) - if(DEVICE_TYPE STREQUAL "USE_A3") - message("downloading a3 arm xllm kernels") - file(DOWNLOAD - "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.8.0/xllm_kernels-1.3.10-Linux.a3.arm.rpm" - "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" - ) - else() - if(DEVICE_ARCH STREQUAL "ARM") - message("downloading a2 arm xllm_kernels") - file(DOWNLOAD - "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.8.0/xllm_kernels-1.3.10-Linux.a2.arm.rpm" - "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" - ) - else() - message("downloading a2 x86 xllm_kernels") - file(DOWNLOAD - "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.8.0/xllm_kernels-1.3.10-Linux.a2.x86.rpm" - "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" - ) - endif() - endif() - execute_process(COMMAND rpm -ivh --replacepkgs --replacefiles "${CMAKE_BINARY_DIR}/xllm_kernels.rpm") - file(WRITE "${CMAKE_BINARY_DIR}/.xllm_installed" "") ->>>>>>> 89ccdda (bugfix: fix some compile problem: LOAD_MERGED_WEIGHT_V2 / testing trea / Qwen3NextDecoderLayer.) 
endif() message(STATUS "Using xllm_atb_layers source at: ${XLLM_ATB_LAYERS_SOURCE_DIR}") diff --git a/setup.py b/setup.py index 2266eeffb..6a2343bce 100644 --- a/setup.py +++ b/setup.py @@ -514,21 +514,7 @@ def parse_arguments() -> dict[str, Any]: default='auto', help='Device type: a2, a3, mlu, ilu, cuda or musa (case-insensitive)' ) - - parser.add_argument( - '--dry-run', - action='store_true', - help='Dry run mode (do not execute pre_build)' - ) - - parser.add_argument( - '--install-xllm-kernels', - type=str.lower, - choices=['true', 'false', '1', '0', 'yes', 'no', 'y', 'n', 'on', 'off'], - default='false', - help='Whether to install xllm kernels' - ) - + parser.add_argument( '--generate-so', type=str.lower, diff --git a/third_party/torch_npu_ops b/third_party/torch_npu_ops index e7d254285..90773524d 160000 --- a/third_party/torch_npu_ops +++ b/third_party/torch_npu_ops @@ -1 +1 @@ -Subproject commit e7d254285a3e491abb7ba14e723be0d2909df3a5 +Subproject commit 90773524d2d69220fc80f7845b4570eabfccfd0e diff --git a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp index 7a79ad3e7..638cf7e69 100644 --- a/xllm/core/distributed_runtime/llm_engine.cpp +++ b/xllm/core/distributed_runtime/llm_engine.cpp @@ -454,25 +454,13 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() { } #endif - if (!FLAGS_enable_continuous_kvcache) { - // compute kv cache n_blocks - const int32_t block_size = options_.block_size(); - const int64_t block_size_in_bytes = - block_size * (slot_size + index_slot_size + scale_slot_size) + linear_slot_size; - kv_cache_cap.n_blocks = kv_cache_cap.cache_size_in_bytes / - (kv_cache_cap.n_layers * block_size_in_bytes); - CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache"; - } else { - int32_t n_pages = - kv_cache_cap.cache_size_in_bytes / FLAGS_phy_page_granularity_size; - if (FLAGS_enable_mla) { - n_pages -= n_pages % (kv_cache_cap.n_layers); - } else { - n_pages -= n_pages % (2 * 
kv_cache_cap.n_layers); - } - kv_cache_cap.n_pages = n_pages; - CHECK_GT(kv_cache_cap.n_pages, 0) << "no n_pages for kv cache"; - } + // compute kv cache n_blocks + const int32_t block_size = options_.block_size(); + const int64_t block_size_in_bytes = + block_size * (slot_size + index_slot_size + scale_slot_size); + kv_cache_cap.n_blocks = kv_cache_cap.cache_size_in_bytes / + (kv_cache_cap.n_layers * block_size_in_bytes); + CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache"; return kv_cache_cap; } diff --git a/xllm/core/distributed_runtime/worker_service.cpp b/xllm/core/distributed_runtime/worker_service.cpp index 1156e6a96..3ac2a7d03 100644 --- a/xllm/core/distributed_runtime/worker_service.cpp +++ b/xllm/core/distributed_runtime/worker_service.cpp @@ -286,45 +286,29 @@ void WorkerService::AllocateKVCache( threadpool_->schedule([this, controller, request, response, done]() mutable { brpc::ClosureGuard done_guard(done); std::vector> kv_cache_shape; -<<<<<<< HEAD - // Reserve for key, value, and optionally index shape - kv_cache_shape.reserve(3); - kv_cache_shape.emplace_back( - std::vector(request->kv_cache_shape().key_shape().begin(), - request->kv_cache_shape().key_shape().end())); - kv_cache_shape.emplace_back( - std::vector(request->kv_cache_shape().value_shape().begin(), - request->kv_cache_shape().value_shape().end())); - // add index shape if exists - if (request->kv_cache_shape().index_shape_size() > 0) { - kv_cache_shape.emplace_back( - std::vector(request->kv_cache_shape().index_shape().begin(), - request->kv_cache_shape().index_shape().end())); -======= - const bool has_index_shape = request->index_shape_size() > 0; - const bool has_conv_shape = request->conv_shape_size() > 0; - const bool has_ssm_shape = request->ssm_shape_size() > 0; + const bool has_index_shape = request->kv_cache_shape().index_shape_size() > 0; + const bool has_conv_shape = request->kv_cache_shape().conv_shape_size() > 0; + const bool has_ssm_shape = 
request->kv_cache_shape().ssm_shape_size() > 0; CHECK(!(has_index_shape && (has_conv_shape || has_ssm_shape))) << "KVCacheShape does not support index_shape with conv/ssm shapes " << "simultaneously."; // Reserve for key, value, and optional extra shapes kv_cache_shape.reserve(has_conv_shape || has_ssm_shape ? 4 : 3); kv_cache_shape.emplace_back(std::vector( - request->key_shape().begin(), request->key_shape().end())); + request->kv_cache_shape().key_shape().begin(), request->kv_cache_shape().key_shape().end())); kv_cache_shape.emplace_back(std::vector( - request->value_shape().begin(), request->value_shape().end())); + request->kv_cache_shape().value_shape().begin(), request->kv_cache_shape().value_shape().end())); // add index shape if exists if (has_index_shape) { kv_cache_shape.emplace_back(std::vector( - request->index_shape().begin(), request->index_shape().end())); + request->kv_cache_shape().index_shape().begin(), request->kv_cache_shape().index_shape().end())); } else if (has_conv_shape || has_ssm_shape) { CHECK(has_conv_shape && has_ssm_shape) << "conv_shape and ssm_shape must be provided together."; kv_cache_shape.emplace_back(std::vector( - request->conv_shape().begin(), request->conv_shape().end())); + request->kv_cache_shape().conv_shape().begin(), request->kv_cache_shape().conv_shape().end())); kv_cache_shape.emplace_back(std::vector( - request->ssm_shape().begin(), request->ssm_shape().end())); ->>>>>>> 922b77e (feat: support qwen3-next inference on npu.) 
+ request->kv_cache_shape().ssm_shape().begin(), request->kv_cache_shape().ssm_shape().end())); } auto future = worker_->allocate_kv_cache_async(kv_cache_shape); diff --git a/xllm/core/framework/kv_cache/kv_cache.cpp b/xllm/core/framework/kv_cache/kv_cache.cpp index f2af38ead..1ebde7121 100644 --- a/xllm/core/framework/kv_cache/kv_cache.cpp +++ b/xllm/core/framework/kv_cache/kv_cache.cpp @@ -47,9 +47,6 @@ KVCache::KVCache(torch::Tensor key_cache, conv_cache_(std::move(conv_cache)), ssm_cache_(std::move(ssm_cache)) {} -KVCache::KVCache(std::shared_ptr key_xtensor, - std::shared_ptr value_xtensor) - : key_xtensor_(key_xtensor), value_xtensor_(value_xtensor) {} torch::Tensor KVCache::get_k_cache() const { return key_cache_; } torch::Tensor KVCache::get_v_cache() const { return value_cache_; } torch::Tensor KVCache::get_index_cache() const { return index_cache_; } diff --git a/xllm/core/framework/kv_cache/kv_cache.h b/xllm/core/framework/kv_cache/kv_cache.h index 5a407c8ee..4d1fb144b 100644 --- a/xllm/core/framework/kv_cache/kv_cache.h +++ b/xllm/core/framework/kv_cache/kv_cache.h @@ -74,9 +74,6 @@ class KVCache final { torch::Tensor value_cache_scale_; torch::Tensor conv_cache_; torch::Tensor ssm_cache_; - // for continuous kvcache - std::shared_ptr key_xtensor_; - std::shared_ptr value_xtensor_; }; } // namespace xllm diff --git a/xllm/core/framework/model/model_args.h b/xllm/core/framework/model/model_args.h index 31672cd98..c777519b5 100644 --- a/xllm/core/framework/model/model_args.h +++ b/xllm/core/framework/model/model_args.h @@ -171,16 +171,16 @@ struct ModelArgs { PROPERTY(int32_t, rope_scaling) = -1; PROPERTY(float, router_aux_loss_coef) = 0.001f; - // qwen3 next - PROPERTY(bool, attn_output_gate) = true; - PROPERTY(int32_t, full_attention_interval) = 4; - PROPERTY(int32_t, linear_conv_kernel_dim) = 4; - PROPERTY(int32_t, linear_key_head_dim) = 128; - PROPERTY(int32_t, linear_value_head_dim) = 128; - PROPERTY(int64_t, linear_num_key_heads) = 16; - 
PROPERTY(int32_t, linear_num_value_heads) = 32; - PROPERTY(int32_t, shared_expert_intermediate_size) = 512; - PROPERTY(float, partial_rotary_factor) = 0.25f; + // qwen3 next initialized with 0, and will be loaded in model file + PROPERTY(bool, attn_output_gate) = false; + PROPERTY(int32_t, full_attention_interval) = 0; + PROPERTY(int32_t, linear_conv_kernel_dim) = 0; + PROPERTY(int32_t, linear_key_head_dim) = 0; + PROPERTY(int32_t, linear_value_head_dim) = 0; + PROPERTY(int64_t, linear_num_key_heads) = 0; + PROPERTY(int32_t, linear_num_value_heads) = 0; + PROPERTY(int32_t, shared_expert_intermediate_size) = 0; + PROPERTY(float, partial_rotary_factor) = 0.0f; // Vision model's dropout PROPERTY(float, mm_dropout) = 0.0f; diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index 4b6fd8ac3..2a058451a 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -767,6 +767,14 @@ moe_init_routing_v2(MoeInitRoutingV2Params& params) { #endif } +std::tuple fp8_scaled_quantize( + Fp8ScaledQuantizeParams& params) { +#if defined(USE_CUDA) + return cuda::fp8_scaled_quantize(params.input, params.output, params.scale); +#else + NOT_IMPLEMENTED(); +#endif + std::pair fused_gdn_gating( FusedGdnGatingParams& params) { #if defined(USE_NPU) diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index 0d38ddc6d..5f80bbbb5 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -61,14 +61,8 @@ GraphPersistentParam::GraphPersistentParam(const ModelArgs& args, need_update_attn_mask_(need_update_attn_mask) { // Determine whether attention plan needs to be updated based on model type // Future logic can be extended here for more complex model-specific behavior - // For qwen3_next: disable paged attention plan update because it uses mixed architecture - // (standard attention + linear attention). 
The standard attention layers (every 4th layer) - // will still work correctly without plan updates, as they use the same k/v cache structure. - // Linear attention layers use different cache types (conv_cache, ssm_cache) and don't need - // paged attention plan at all. need_update_attention_plan_ = (args.model_type() != "deepseek_v32" && - args.model_type() != "glm_moe_dsa" && - args.model_type() != "qwen3_next"); + args.model_type() != "glm_moe_dsa"); // Check if mRoPE is used (for VLM models like qwen2-vl) use_mrope_ = !args.rope_scaling_mrope_section().empty(); From 5c890426ae2b527e9e2d9982bd5f35fad98cca4c Mon Sep 17 00:00:00 2001 From: xuyexiong <809602657@qq.com> Date: Sat, 14 Mar 2026 15:47:45 +0800 Subject: [PATCH 12/13] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- xllm/core/distributed_runtime/comm_channel.cpp | 8 ++++---- xllm/core/distributed_runtime/llm_engine.cpp | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/xllm/core/distributed_runtime/comm_channel.cpp b/xllm/core/distributed_runtime/comm_channel.cpp index b209d28ee..0f71e7bb3 100644 --- a/xllm/core/distributed_runtime/comm_channel.cpp +++ b/xllm/core/distributed_runtime/comm_channel.cpp @@ -95,13 +95,13 @@ bool CommChannel::allocate_kv_cache( } if (kv_cache_shape.size() == 4) { - shape.mutable_conv_shape()->Reserve(kv_cache_shape[2].size()); - shape.mutable_ssm_shape()->Reserve(kv_cache_shape[3].size()); + shape->mutable_conv_shape()->Reserve(kv_cache_shape[2].size()); + shape->mutable_ssm_shape()->Reserve(kv_cache_shape[3].size()); for (size_t i = 0; i < kv_cache_shape[2].size(); ++i) { - shape.add_conv_shape(kv_cache_shape[2][i]); + shape->add_conv_shape(kv_cache_shape[2][i]); } for (size_t i = 0; i < kv_cache_shape[3].size(); ++i) { - shape.add_ssm_shape(kv_cache_shape[3][i]); + shape->add_ssm_shape(kv_cache_shape[3][i]); } } proto::Status s; diff --git 
a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp index 638cf7e69..7394ea24f 100644 --- a/xllm/core/distributed_runtime/llm_engine.cpp +++ b/xllm/core/distributed_runtime/llm_engine.cpp @@ -522,7 +522,7 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) { args_.linear_key_head_dim() * n_local_linear_v_heads_, args_.linear_conv_kernel_dim() - 1}); kv_cache_shape.emplace_back(std::vector{ kv_cache_cap.n_blocks, n_local_linear_v_heads_, args_.linear_key_head_dim(), - args_.linear_key_head_dim()}); + args_.linear_value_head_dim()}); } #if defined(USE_MLU) // transpose kv_cache layout for mlu From e1f3b0cc912b115ebe7a95156fbb18a3caa0939c Mon Sep 17 00:00:00 2001 From: xuyexiong Date: Sat, 14 Mar 2026 16:47:47 +0800 Subject: [PATCH 13/13] fix compile error --- third_party/torch_npu_ops | 2 +- xllm/core/framework/kv_cache/kv_cache.h | 1 + xllm/core/kernels/ops_api.cpp | 1 + xllm/core/kernels/ops_api.h | 3 ++- xllm/core/layers/npu_torch/fused_moe.cpp | 1 + xllm/core/layers/qwen3_next_decoder_layer.h | 16 ++++++++++++---- xllm/models/llm/qwen3_next.h | 10 ++++++++++ 7 files changed, 28 insertions(+), 6 deletions(-) diff --git a/third_party/torch_npu_ops b/third_party/torch_npu_ops index 90773524d..bf90ef22c 160000 --- a/third_party/torch_npu_ops +++ b/third_party/torch_npu_ops @@ -1 +1 @@ -Subproject commit 90773524d2d69220fc80f7845b4570eabfccfd0e +Subproject commit bf90ef22cc789be1a89541da11d2813ef2c8dd4c diff --git a/xllm/core/framework/kv_cache/kv_cache.h b/xllm/core/framework/kv_cache/kv_cache.h index 4d1fb144b..819db4802 100644 --- a/xllm/core/framework/kv_cache/kv_cache.h +++ b/xllm/core/framework/kv_cache/kv_cache.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "common/global_flags.h" #include "framework/model/model_input_params.h" +#include "framework/xtensor/xtensor.h" namespace xllm { class KVCache final { diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index 2a058451a..742737388 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -774,6 +774,7 @@ std::tuple fp8_scaled_quantize( #else NOT_IMPLEMENTED(); #endif +} std::pair fused_gdn_gating( FusedGdnGatingParams& params) { diff --git a/xllm/core/kernels/ops_api.h b/xllm/core/kernels/ops_api.h index ab44f21c2..d9370ff75 100644 --- a/xllm/core/kernels/ops_api.h +++ b/xllm/core/kernels/ops_api.h @@ -126,7 +126,8 @@ torch::Tensor rms_norm_static_fp8_quant(RmsNormStaticFp8QuantParams& params); std::tuple fused_add_rms_norm_static_fp8_quant( FusedAddRmsNormStaticFp8QuantParams& params); -std::pair fused_gdn_gating(FusedGdnGatingParams& params); +std::pair fused_gdn_gating( + FusedGdnGatingParams& params); std::pair fused_recurrent_gated_delta_rule( FusedRecurrentGatedDeltaRuleParams& params); diff --git a/xllm/core/layers/npu_torch/fused_moe.cpp b/xllm/core/layers/npu_torch/fused_moe.cpp index b5c1db0ea..56d2f20d4 100644 --- a/xllm/core/layers/npu_torch/fused_moe.cpp +++ b/xllm/core/layers/npu_torch/fused_moe.cpp @@ -187,6 +187,7 @@ torch::Tensor FusedMoEImpl::select_experts( moe_active_topk_params.input = router_logits_2d; moe_active_topk_params.finished = torch::Tensor(); moe_active_topk_params.topk = topk_; + moe_active_topk_params.scoring_func = "softmax"; auto [topk_weights, topk_ids] = xllm::kernel::moe_active_topk(moe_active_topk_params); topk_ids = topk_ids.to(torch::kInt32); diff --git a/xllm/core/layers/qwen3_next_decoder_layer.h b/xllm/core/layers/qwen3_next_decoder_layer.h index 18bb8170f..8175c3b2b 100644 --- a/xllm/core/layers/qwen3_next_decoder_layer.h +++ b/xllm/core/layers/qwen3_next_decoder_layer.h @@ -20,23 +20,31 @@ limitations under the License. 
#include #include "common/dense_mlp.h" -#include "common/qwen3_next_rms_norm.h" #include "common/qwen3_next_attention.h" #include "common/qwen3_next_gated_delta_net.h" -#include "layers/npu/fused_moe.h" +#include "common/qwen3_next_rms_norm.h" +#if defined(USE_MLU) +#include "layers/mlu/fused_moe.h" +#elif defined(USE_NPU) +#include "layers/npu_torch/fused_moe.h" +#elif defined(USE_ILU) +#include "layers/ilu/fused_moe.h" +#else +#include "layers/common/fused_moe.h" +#endif #include "framework/kv_cache/kv_cache.h" #include "framework/model/model_args.h" #include "framework/model/model_input_params.h" #include "framework/model_context.h" #include "framework/state_dict/state_dict.h" - namespace xllm { namespace layer { class Qwen3NextDecoderLayerImpl : public torch::nn::Module { public: - explicit Qwen3NextDecoderLayerImpl(const ModelContext& context, int32_t layer_id); + explicit Qwen3NextDecoderLayerImpl(const ModelContext& context, + int32_t layer_id); void load_state_dict(const StateDict& state_dict); diff --git a/xllm/models/llm/qwen3_next.h b/xllm/models/llm/qwen3_next.h index 959d4707e..86d7411e3 100644 --- a/xllm/models/llm/qwen3_next.h +++ b/xllm/models/llm/qwen3_next.h @@ -226,6 +226,16 @@ class Qwen3NextForCausalLMImpl : public torch::nn::Module { #endif } + torch::Tensor pooler(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + auto h = hidden_states; + if (seleted_idxes.defined()) { + h = h.index_select(/*dim=*/0, seleted_idxes); + } + namespace F = torch::nn::functional; + return F::normalize(h, F::NormalizeFuncOptions().p(2).dim(1)); + } + void load_model(std::unique_ptr loader) { for (const auto& state_dict : loader->get_state_dicts()) { model_->load_state_dict(state_dict->get_dict_with_prefix("model."));