Skip to content
Open
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ endif()
if(USE_NPU)
# USE_NPU_TORCH: Temporary flag used for debugging qwen3 torch NPU graph
# capture. This variable may be removed in the future.
# add_definitions(-DUSE_NPU_TORCH)
add_definitions(-DUSE_NPU_TORCH)
add_definitions(-DUSE_NPU)
add_definitions(-DBUILD_LIBTORCH)
add_definitions(-DTORCH_SETCUSTOMHANDLER=ON)
Expand All @@ -338,6 +338,7 @@ if(USE_NPU)
$ENV{NPU_HOME_PATH}/include
$ENV{ATB_HOME_PATH}/include
$ENV{NPU_HOME_PATH}/opp/vendors/xllm/op_api/include/
${CMAKE_CURRENT_SOURCE_DIR}/third_party/torch_npu_ops/
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to check whether there is a better way to do this

)
# Keep third-party warnings suppressed while providing headers expected by
# legacy includes like "atb_speed/log.h" from xllm_atb_layers.
Expand Down
2 changes: 1 addition & 1 deletion third_party/torch_npu_ops
Submodule torch_npu_ops updated from 907735 to bf90ef
12 changes: 11 additions & 1 deletion xllm/core/distributed_runtime/comm_channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,23 @@ bool CommChannel::allocate_kv_cache(
}

// Append the index shape when one is present.
if (kv_cache_shape.size() > 2) {
if (kv_cache_shape.size() == 3) {
shape->mutable_index_shape()->Reserve(kv_cache_shape[2].size());
for (size_t i = 0; i < kv_cache_shape[2].size(); ++i) {
shape->add_index_shape(kv_cache_shape[2][i]);
}
}

if (kv_cache_shape.size() == 4) {
shape->mutable_conv_shape()->Reserve(kv_cache_shape[2].size());
shape->mutable_ssm_shape()->Reserve(kv_cache_shape[3].size());
for (size_t i = 0; i < kv_cache_shape[2].size(); ++i) {
shape->add_conv_shape(kv_cache_shape[2][i]);
}
for (size_t i = 0; i < kv_cache_shape[3].size(); ++i) {
shape->add_ssm_shape(kv_cache_shape[3][i]);
}
}
proto::Status s;
brpc::Controller cntl;
stub_->AllocateKVCache(&cntl, &request, &s, nullptr);
Expand Down
1 change: 1 addition & 0 deletions xllm/core/distributed_runtime/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ class Engine {
int64_t cache_size_in_bytes = 0;
int64_t slot_size = 0;
int64_t index_slot_size = 0;
int64_t linear_slot_size = 0;
int64_t n_layers = 0;
};

Expand Down
31 changes: 29 additions & 2 deletions xllm/core/distributed_runtime/llm_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,12 @@ bool LLMEngine::init_model(MasterStatus master_status) {
n_local_q_heads_ = std::max<int64_t>(1, n_heads / world_size);
head_dim_ = args_.head_dim();
dtype_ = util::parse_dtype(args_.dtype(), options_.devices()[0]);

if (args_.full_attention_interval() > 1) {
const int64_t linear_n_k_heads = args_.linear_num_key_heads();
const int64_t linear_n_v_heads = args_.linear_num_value_heads();
n_local_linear_k_heads_ = std::max<int64_t>(1, linear_n_k_heads / world_size);
n_local_linear_v_heads_ = std::max<int64_t>(1, linear_n_v_heads / world_size);
}
// key + value for all layers
LOG(INFO) << "Block info, block_size: " << options_.block_size()
<< ", n_local_kv_heads: " << n_local_kv_heads_
Expand Down Expand Up @@ -383,7 +388,7 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
int64_t index_slot_size = 0;
int64_t scale_slot_size =
0; // Extra overhead for scale tensors in quant mode

int64_t linear_slot_size = 0;
if (FLAGS_enable_mla) {
#if defined(USE_NPU)
if (args_.model_type() == "deepseek_v3" && FLAGS_enable_prefix_cache) {
Expand Down Expand Up @@ -425,8 +430,16 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
scale_slot_size = 2 * sizeof(float) * n_local_kv_heads_;
}
}
if (args_.linear_num_value_heads() > 0) {
int64_t head_k_dim = args_.linear_key_head_dim();
int64_t head_v_dim = args_.linear_value_head_dim();
int64_t linear_ssm_slot_size = dtype_size * n_local_linear_v_heads_ * head_k_dim * head_v_dim;
int64_t linear_conv_slot_size = dtype_size * (head_k_dim * n_local_linear_k_heads_ * 2 + head_v_dim * n_local_linear_v_heads_) * (args_.linear_conv_kernel_dim() -1 );
linear_slot_size = linear_ssm_slot_size + linear_conv_slot_size;
}
kv_cache_cap.slot_size = slot_size;
kv_cache_cap.index_slot_size = index_slot_size;
kv_cache_cap.linear_slot_size = linear_slot_size;
kv_cache_cap.n_layers = args_.n_layers();
#if !defined(USE_NPU)
// this adoption is because the allocation of kv cache is based on
Expand Down Expand Up @@ -462,6 +475,7 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
CHECK_GT(kv_cache_cap.n_blocks, 0) << "no memory for kv cache";
const int32_t block_size = options_.block_size();
bool enable_lighting_indexer = args_.index_n_heads() > 1;
bool enable_linear_attention = args_.full_attention_interval() > 1;

// init kv cache for each worker
std::vector<std::vector<int64_t>> kv_cache_shape;
Expand Down Expand Up @@ -501,6 +515,15 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
kv_cache_shape.emplace_back(std::vector<int64_t>{
kv_cache_cap.n_blocks, block_size, 1, args_.index_head_dim()});
}
if (enable_linear_attention) {
kv_cache_shape.emplace_back(std::vector<int64_t>{
kv_cache_cap.n_blocks,
args_.linear_key_head_dim() * n_local_linear_k_heads_ * 2 +
args_.linear_key_head_dim() * n_local_linear_v_heads_, args_.linear_conv_kernel_dim() - 1});
kv_cache_shape.emplace_back(std::vector<int64_t>{
kv_cache_cap.n_blocks, n_local_linear_v_heads_, args_.linear_key_head_dim(),
args_.linear_value_head_dim()});
}
#if defined(USE_MLU)
// transpose kv_cache layout for mlu
// default layout: [n_blocks, block_size, n_head, head_dim]
Expand All @@ -525,6 +548,10 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
LOG(INFO) << "Initializing indexer cache with shape: [" << kv_cache_shape[2]
<< "]";
}
if (enable_linear_attention) {
LOG(INFO) << "Initializing conv cache with shape: [" << kv_cache_shape[2] << "]";
LOG(INFO) << "Initializing ssm cache with shape: [" << kv_cache_shape[3] << "]";
}

// initialize block manager
BlockManagerPool::Options options;
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/distributed_runtime/llm_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class LLMEngine : public Engine {
int64_t n_local_kv_heads_ = 0;
int64_t n_local_q_heads_ = 0;
int64_t head_dim_ = 0;
int64_t n_local_linear_v_heads_ = 0;
int64_t n_local_linear_k_heads_ = 0;

// common frequently used args
uint32_t dp_size_;
Expand Down
66 changes: 46 additions & 20 deletions xllm/core/distributed_runtime/worker_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,19 +286,29 @@ void WorkerService::AllocateKVCache(
threadpool_->schedule([this, controller, request, response, done]() mutable {
brpc::ClosureGuard done_guard(done);
std::vector<std::vector<int64_t>> kv_cache_shape;
// Reserve for key, value, and optionally index shape
kv_cache_shape.reserve(3);
kv_cache_shape.emplace_back(
std::vector<int64_t>(request->kv_cache_shape().key_shape().begin(),
request->kv_cache_shape().key_shape().end()));
kv_cache_shape.emplace_back(
std::vector<int64_t>(request->kv_cache_shape().value_shape().begin(),
request->kv_cache_shape().value_shape().end()));
const bool has_index_shape = request->kv_cache_shape().index_shape_size() > 0;
const bool has_conv_shape = request->kv_cache_shape().conv_shape_size() > 0;
const bool has_ssm_shape = request->kv_cache_shape().ssm_shape_size() > 0;
CHECK(!(has_index_shape && (has_conv_shape || has_ssm_shape)))
<< "KVCacheShape does not support index_shape with conv/ssm shapes "
<< "simultaneously.";
Comment on lines +292 to +294
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The AllocateKVCache RPC method uses CHECK macros to validate the consistency of the request data. If a request is sent with conflicting shape information (e.g., both index_shape and conv_shape are present), the CHECK macro will fail and cause the worker process to abort. This allows an attacker who can reach the worker's RPC interface to crash the worker process, leading to a denial of service. It is recommended to replace CHECK macros with proper error handling that returns an error status to the caller instead of crashing the process.

// Reserve for key, value, and optional extra shapes
kv_cache_shape.reserve(has_conv_shape || has_ssm_shape ? 4 : 3);
kv_cache_shape.emplace_back(std::vector<int64_t>(
request->kv_cache_shape().key_shape().begin(), request->kv_cache_shape().key_shape().end()));
kv_cache_shape.emplace_back(std::vector<int64_t>(
request->kv_cache_shape().value_shape().begin(), request->kv_cache_shape().value_shape().end()));
// add index shape if exists
if (request->kv_cache_shape().index_shape_size() > 0) {
kv_cache_shape.emplace_back(
std::vector<int64_t>(request->kv_cache_shape().index_shape().begin(),
request->kv_cache_shape().index_shape().end()));
if (has_index_shape) {
kv_cache_shape.emplace_back(std::vector<int64_t>(
request->kv_cache_shape().index_shape().begin(), request->kv_cache_shape().index_shape().end()));
} else if (has_conv_shape || has_ssm_shape) {
CHECK(has_conv_shape && has_ssm_shape)
<< "conv_shape and ssm_shape must be provided together.";
Comment on lines +306 to +307
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The AllocateKVCache RPC method uses a CHECK macro to ensure that conv_shape and ssm_shape are provided together. If a request is sent with only one of these shapes, the CHECK macro will fail and cause the worker process to abort. This is a denial of service vector. It is recommended to validate the request data and return an error status to the caller instead of using CHECK.

Comment on lines +292 to +307
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The AllocateKVCache RPC handler uses the CHECK macro to validate input parameters (has_index_shape, has_conv_shape, has_ssm_shape). The CHECK macro causes the entire process to abort if the condition is not met. Since this logic operates on untrusted data received over the network, it creates a Denial of Service (DoS) vulnerability. An attacker can send a malformed request to crash the worker process. Input validation should be performed using conditional logic that returns an error response (e.g., via controller->SetFailed()) instead of crashing the process.

kv_cache_shape.emplace_back(std::vector<int64_t>(
request->kv_cache_shape().conv_shape().begin(), request->kv_cache_shape().conv_shape().end()));
kv_cache_shape.emplace_back(std::vector<int64_t>(
request->kv_cache_shape().ssm_shape().begin(), request->kv_cache_shape().ssm_shape().end()));
}

auto future = worker_->allocate_kv_cache_async(kv_cache_shape);
Expand All @@ -316,18 +326,34 @@ void WorkerService::AllocateKVCacheWithTransfer(
threadpool_->schedule([this, controller, req, resp, done]() mutable {
brpc::ClosureGuard done_guard(done);
std::vector<std::vector<int64_t>> kv_cache_shape;
kv_cache_shape.reserve(2);
const auto& shape_req = req->kv_cache_shape();
const bool has_index_shape = shape_req.index_shape_size() > 0;
const bool has_conv_shape = shape_req.conv_shape_size() > 0;
const bool has_ssm_shape = shape_req.ssm_shape_size() > 0;
CHECK(!(has_index_shape && (has_conv_shape || has_ssm_shape)))
<< "KVCacheShape does not support index_shape with conv/ssm shapes "
<< "simultaneously.";
Comment on lines +333 to +335
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The AllocateKVCacheWithTransfer RPC method uses CHECK macros to validate the consistency of the request data. If a request is sent with conflicting shape information (e.g., both index_shape and conv_shape are present), the CHECK macro will fail and cause the worker process to abort. This allows an attacker to crash the worker process. It is recommended to replace CHECK macros with proper error handling.

kv_cache_shape.reserve(has_conv_shape || has_ssm_shape ? 4 : 3);
kv_cache_shape.emplace_back(
std::vector<int64_t>(req->kv_cache_shape().key_shape().begin(),
req->kv_cache_shape().key_shape().end()));
std::vector<int64_t>(shape_req.key_shape().begin(),
shape_req.key_shape().end()));
kv_cache_shape.emplace_back(
std::vector<int64_t>(req->kv_cache_shape().value_shape().begin(),
req->kv_cache_shape().value_shape().end()));
std::vector<int64_t>(shape_req.value_shape().begin(),
shape_req.value_shape().end()));
// add index shape if exists
if (req->kv_cache_shape().index_shape_size() > 0) {
if (has_index_shape) {
kv_cache_shape.emplace_back(
std::vector<int64_t>(shape_req.index_shape().begin(),
shape_req.index_shape().end()));
} else if (has_conv_shape || has_ssm_shape) {
CHECK(has_conv_shape && has_ssm_shape)
<< "conv_shape and ssm_shape must be provided together.";
Comment on lines +349 to +350
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The AllocateKVCacheWithTransfer RPC method uses a CHECK macro to ensure that conv_shape and ssm_shape are provided together. If a request is sent with only one of these shapes, the CHECK macro will fail and cause the worker process to abort. This is a denial of service vector. It is recommended to validate the request data and return an error status to the caller instead of using CHECK.

Comment on lines +333 to +350
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The AllocateKVCacheWithTransfer RPC handler uses CHECK macros for shape validation, which is a Denial of Service (DoS) vector allowing a remote caller to crash the worker by providing inconsistent shape flags. Proper error handling should be implemented to return a failure status without aborting the process. Additionally, there is a copy-paste error when constructing the kv_cache_shape for linear attention, where index_shape is incorrectly used instead of ssm_shape. This can lead to incorrect shape information, causing allocation failures or memory corruption.

kv_cache_shape.emplace_back(
std::vector<int64_t>(shape_req.conv_shape().begin(),
shape_req.conv_shape().end()));
kv_cache_shape.emplace_back(
std::vector<int64_t>(req->kv_cache_shape().index_shape().begin(),
req->kv_cache_shape().index_shape().end()));
std::vector<int64_t>(shape_req.ssm_shape().begin(),
shape_req.ssm_shape().end()));
}

auto future =
Expand Down
12 changes: 12 additions & 0 deletions xllm/core/framework/kv_cache/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,21 @@ KVCache::KVCache(torch::Tensor key_cache,
index_cache_(std::move(index_cache)),
key_cache_scale_(std::move(key_cache_scale)),
value_cache_scale_(std::move(value_cache_scale)) {}

// Constructs a KVCache for linear-attention layers: alongside the standard
// key/value caches it stores a convolution-state cache and an SSM-state
// cache. The index and scale tensor members are left default-constructed
// (undefined) by this overload.
KVCache::KVCache(torch::Tensor key_cache,
                 torch::Tensor value_cache,
                 torch::Tensor conv_cache,
                 torch::Tensor ssm_cache)
    : key_cache_(std::move(key_cache)),
      value_cache_(std::move(value_cache)),
      conv_cache_(std::move(conv_cache)),
      ssm_cache_(std::move(ssm_cache)) {}

// Accessors for the underlying cache tensors. Tensors are returned by value;
// copying a torch::Tensor copies only the handle, not the storage.
torch::Tensor KVCache::get_k_cache() const {
  return key_cache_;
}

torch::Tensor KVCache::get_v_cache() const {
  return value_cache_;
}

torch::Tensor KVCache::get_index_cache() const {
  return index_cache_;
}

torch::Tensor KVCache::get_conv_cache() const {
  return conv_cache_;
}

torch::Tensor KVCache::get_ssm_cache() const {
  return ssm_cache_;
}

std::optional<torch::Tensor> KVCache::get_k_cache_scale() const {
if (!key_cache_scale_.defined() || key_cache_scale_.numel() == 0) {
Expand Down
11 changes: 11 additions & 0 deletions xllm/core/framework/kv_cache/kv_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.

#include "common/global_flags.h"
#include "framework/model/model_input_params.h"
#include "framework/xtensor/xtensor.h"

namespace xllm {
class KVCache final {
Expand All @@ -37,6 +38,12 @@ class KVCache final {
torch::Tensor index_cache,
torch::Tensor key_cache_scale,
torch::Tensor value_cache_scale);
KVCache(torch::Tensor key_cache,
torch::Tensor value_cache,
torch::Tensor conv_cache,
torch::Tensor ssm_cache);
KVCache(std::shared_ptr<XTensor> key_xtensor,
std::shared_ptr<XTensor> value_xtensor);
~KVCache() = default;

// TODO: pass in kv_shape and options instead
Expand All @@ -48,6 +55,8 @@ class KVCache final {
std::optional<torch::Tensor> get_k_cache_scale() const;
std::optional<torch::Tensor> get_v_cache_scale() const;

torch::Tensor get_conv_cache() const;
torch::Tensor get_ssm_cache() const;
std::vector<std::vector<int64_t>> get_shapes();

bool empty() const {
Expand All @@ -64,6 +73,8 @@ class KVCache final {
// scale tensors for quantized KV cache (int8)
torch::Tensor key_cache_scale_;
torch::Tensor value_cache_scale_;
torch::Tensor conv_cache_;
torch::Tensor ssm_cache_;
};

} // namespace xllm
11 changes: 11 additions & 0 deletions xllm/core/framework/model/model_args.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,17 @@ struct ModelArgs {
PROPERTY(int32_t, rope_scaling) = -1;
PROPERTY(float, router_aux_loss_coef) = 0.001f;

// Qwen3-Next arguments: initialized to 0 here and populated when the model
// config file is loaded.
PROPERTY(bool, attn_output_gate) = false;
PROPERTY(int32_t, full_attention_interval) = 0;
PROPERTY(int32_t, linear_conv_kernel_dim) = 0;
PROPERTY(int32_t, linear_key_head_dim) = 0;
PROPERTY(int32_t, linear_value_head_dim) = 0;
PROPERTY(int64_t, linear_num_key_heads) = 0;
PROPERTY(int32_t, linear_num_value_heads) = 0;
PROPERTY(int32_t, shared_expert_intermediate_size) = 0;
PROPERTY(float, partial_rotary_factor) = 0.0f;

// Vision model's dropout
PROPERTY(float, mm_dropout) = 0.0f;

Expand Down
15 changes: 15 additions & 0 deletions xllm/core/framework/parallel_state/npu_process_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ void ProcessGroupImpl::allgather(const torch::Tensor& input,
check_input(input);
torch::DeviceGuard device_guard(device());

if (pg_) {
std::vector<torch::Tensor> input_tensors = {input};
std::vector<std::vector<torch::Tensor>> output_tensors = {outputs};
pg_->allgather(output_tensors, input_tensors)->wait();
return;
}
CHECK(comm_ != nullptr) << "HCCL comm is not initialized.";

torch::Tensor flattened_output = flatten_for_scatter_gather(outputs);

const auto count = input.numel();
Expand Down Expand Up @@ -170,6 +178,13 @@ void ProcessGroupImpl::allreduce(torch::Tensor& input) {
check_input(input);
torch::DeviceGuard device_guard(device());

if (pg_) {
std::vector<torch::Tensor> input_tensors = {input};
pg_->allreduce(input_tensors)->wait();
return;
}
CHECK(comm_ != nullptr) << "HCCL comm is not initialized.";

const auto count = input.numel();
const auto data_type = to_hccl_data_type(input);

Expand Down
55 changes: 55 additions & 0 deletions xllm/core/framework/state_dict/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,61 @@ void load_merged_weight(const StateDict& state_dict,
weight_is_loaded = true;
}

void load_merged_weight_v2(const StateDict& state_dict,
const std::string& name,
int64_t dim,
int32_t rank,
int32_t world_size,
int32_t shard_tensor_count,
const std::vector<int64_t>& shard_sizes,
torch::Tensor& weight,
bool& weight_is_loaded) {
if (weight_is_loaded) {
return;
}
const auto& tensor = state_dict.get_tensor(name);
if (!tensor.defined()) {
return;
}

// Check that shard_tensor_count matches the size of shard_sizes vector
CHECK_EQ(shard_tensor_count, static_cast<int32_t>(shard_sizes.size()))
<< "shard_tensor_count does not match shard_sizes vector size for " << state_dict.prefix() << name;

// Calculate total expected size
int64_t total_expected_size = 0;
for (int64_t size : shard_sizes) {
total_expected_size += size * world_size;
}
CHECK_EQ(tensor.size(dim), total_expected_size)
<< name << "[" << dim << "] size mismatch for " << state_dict.prefix() << name;

std::vector<torch::Tensor> shard_tensors;

for (size_t shard_id = 0; shard_id < shard_tensor_count; shard_id++) {
int64_t shard_size = shard_sizes[shard_id];

// Calculate the offset for this shard and rank
// First, skip all the data from previous shards for all ranks
int64_t shard_offset = 0;
for (int32_t prev_shard = 0; prev_shard < shard_id; ++prev_shard) {
shard_offset += shard_sizes[prev_shard] * world_size;
}

// Then add the offset within this shard for the current rank
shard_offset += rank * shard_size;

shard_tensors.push_back(
tensor.slice(dim, shard_offset, shard_offset + shard_size));
}

auto merged_weight = torch::cat(shard_tensors, dim);
CHECK_EQ(weight.sizes(), merged_weight.sizes())
<< "weight size mismatch for " << state_dict.prefix() << name;
weight.copy_(merged_weight);
weight_is_loaded = true;
}

} // namespace weight

} // namespace xllm
Loading
Loading