3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,3 +1,6 @@
[submodule "third_party/spdlog"]
path = third_party/spdlog
url = https://github.com/gabime/spdlog.git
[submodule "third_party/json"]
path = third_party/json
url = https://github.com/nlohmann/json.git
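Note on the submodule change above: with nlohmann/json added alongside spdlog, fresh checkouts will need the usual `git submodule update --init --recursive` (or an equivalent step in the build setup) before the new config code can be built.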
Empty file added =0.34.0,
Empty file.
88 changes: 88 additions & 0 deletions csrc/config/model_config.cpp
@@ -0,0 +1,88 @@
#include "model_config.hpp"

namespace infinilm::config {
ModelConfig::ModelConfig(const std::string &path) {
std::ifstream file(path);
if (file.is_open()) {
file >> config_json;
file.close();
} else {
throw std::runtime_error("Could not open config file: " + path);
}
this->quant_config = QuantConfig(config_json["quantization_config"]);
}

infinicore::nn::QuantScheme
ModelConfig::get_quant_scheme() const {
return quant_config.get_quant_scheme();
}

std::shared_ptr<infinicore::nn::RoPE::ScalingConfig>
ModelConfig::get_rope_scaling() const {
if (!config_json.contains("rope_scaling") || config_json["rope_scaling"].is_null()) {
return nullptr;
}

const auto &rope_scaling = config_json["rope_scaling"];
if (!rope_scaling.is_object()) {
throw std::runtime_error("rope_scaling must be an object");
}

if (!rope_scaling.contains("type")) {
throw std::runtime_error("rope_scaling must contain 'type' field");
}

std::string type_str = rope_scaling["type"].get<std::string>();
if (type_str == "longrope") {
// Required fields for LongRopeConfig
if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) {
throw std::runtime_error(
"LongRopeConfig requires 'short_factor', 'long_factor', and 'original_max_position_embeddings'");
}

auto short_factor = rope_scaling["short_factor"].get<std::vector<float>>();
auto long_factor = rope_scaling["long_factor"].get<std::vector<float>>();
size_t original_max_position_embeddings = rope_scaling["original_max_position_embeddings"].get<size_t>();

float factor = 1.0f;
if (rope_scaling.contains("factor")) {
factor = rope_scaling["factor"].get<float>();
}

return std::make_shared<infinicore::nn::RoPE::LongRopeConfig>(
std::move(short_factor),
std::move(long_factor),
original_max_position_embeddings,
factor);
} else if (type_str == "default" || type_str == "none") {
// Default scaling, no scaling applied
return nullptr;
} else {
throw std::runtime_error("Unsupported rope_scaling type: " + type_str);
}
}

infinicore::DataType
ModelConfig::get_dtype() const {
try {
std::string dtype_str = this->get<std::string>("torch_dtype");
if (dtype_str == "float32") {
return infinicore::DataType::F32;
} else if (dtype_str == "float16") {
return infinicore::DataType::F16;
} else if (dtype_str == "bfloat16") {
return infinicore::DataType::BF16;
} else if (dtype_str == "int8") {
return infinicore::DataType::I8;
} else {
throw std::runtime_error("Unsupported dtype string: " + dtype_str);
}
} catch (const std::exception &e) {
throw std::runtime_error("Error getting dtype from config: " + std::string(e.what()));
}
}
} // namespace infinilm::config
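For reference, a minimal sketch of a config.json whose rope_scaling block takes the "longrope" path in get_rope_scaling() above. The file path, field values, and the throwaway main() are illustrative only, not taken from any real model.

// Illustrative only: writes a made-up config.json and exercises ModelConfig.
// Field values and the /tmp path are placeholders; real models ship their own config.json.
#include "model_config.hpp"

#include <cstdio>
#include <fstream>

int main() {
    const char *path = "/tmp/example_config.json";
    {
        std::ofstream out(path);
        out << R"({
            "torch_dtype": "bfloat16",
            "quantization_config": null,
            "rope_scaling": {
                "type": "longrope",
                "short_factor": [1.0, 1.0, 1.0, 1.0],
                "long_factor": [2.0, 2.0, 2.0, 2.0],
                "original_max_position_embeddings": 4096,
                "factor": 8.0
            }
        })";
    }

    infinilm::config::ModelConfig cfg(path);
    auto scaling = cfg.get_rope_scaling(); // non-null LongRopeConfig for type "longrope"
    auto dtype = cfg.get_dtype();          // DataType::BF16 for "bfloat16"
    (void)scaling;
    (void)dtype;

    std::remove(path);
    return 0;
}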
63 changes: 63 additions & 0 deletions csrc/config/model_config.hpp
@@ -0,0 +1,63 @@
#pragma once

#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "quant_config.hpp"
#include <fstream>
#include <string>

namespace infinilm::config {
class ModelConfig {
// Model config is implemented using nlohmann/json and is primarily used for advanced configuration
// beyond the standard model config. It is initialized via ModelConfig(const std::string& path)
// and passed through the InferEngine during inference.
public:
ModelConfig() = default;
// Not Implemented
// ModelConfig(const nlohmann::json &json) : config_json(json) {};
ModelConfig(const std::string &path);

// Template Function to get a value by key with type safety
template <typename T>
T get(const std::string &key) const {
if (!config_json.contains(key)) {
throw std::out_of_range("Key '" + key + "' not found in config.");
}
try {
return config_json.at(key).get<T>();
} catch (const nlohmann::json::type_error &e) {
throw std::runtime_error("Type conversion failed for key '" + key + "': " + std::string(e.what()));
}
}

template <typename T>
T get_or(const std::string &key, const T &default_value) const {
if (!config_json.contains(key) || config_json.at(key).is_null()) {
return default_value;
}
try {
return config_json.at(key).get<T>();
} catch (const nlohmann::json::type_error &) {
// If type conversion fails, return default value
return default_value;
}
}
size_t get_kv_dim() const {
return get<size_t>("hidden_size") * get<size_t>("num_key_value_heads") / get<size_t>("num_attention_heads");
}
size_t get_head_dim() const {
if (config_json.contains("head_dim")) {
return get<size_t>("head_dim");
}
return get<size_t>("hidden_size") / get<size_t>("num_attention_heads");
}

infinicore::DataType get_dtype() const;
infinicore::nn::QuantScheme get_quant_scheme() const;
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> get_rope_scaling() const;

private:
nlohmann::json config_json;
QuantConfig quant_config;
};
} // namespace infinilm::config
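A short usage sketch of the typed accessors, assuming a ModelConfig already loaded from a config.json with the usual Hugging Face style keys; the key names and fallback values below are illustrative.

// Sketch: get<T>() throws on a missing key or a failed type conversion,
// while get_or<T>() falls back to the supplied default.
#include "model_config.hpp"

void inspect(const infinilm::config::ModelConfig &cfg) {
    // Throws std::out_of_range if "hidden_size" is absent.
    size_t hidden_size = cfg.get<size_t>("hidden_size");

    // Returns 1e-6f if "rms_norm_eps" is absent, null, or not convertible.
    float eps = cfg.get_or<float>("rms_norm_eps", 1e-6f);

    // Falls back to hidden_size / num_attention_heads when "head_dim" is not given.
    size_t head_dim = cfg.get_head_dim();

    // hidden_size * num_key_value_heads / num_attention_heads.
    size_t kv_dim = cfg.get_kv_dim();

    (void)hidden_size;
    (void)eps;
    (void)head_dim;
    (void)kv_dim;
}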
22 changes: 22 additions & 0 deletions csrc/config/quant_config.cpp
@@ -0,0 +1,22 @@
#include "quant_config.hpp"

namespace infinilm::config {
QuantConfig::QuantConfig(const nlohmann::json &json) : quantization_config(json) {
this->quantization_method = get_quantization_method();
}

std::shared_ptr<infinilm::quantization::BaseQuantization>
QuantConfig::get_quantization_method() const {
if (quantization_config.is_null()) {
return nullptr;
}

// Determine the quantization scheme from the JSON config
if (quantization_config.value("quant_method", "") == "compressed-tensors") {
return std::make_shared<infinilm::quantization::CompressedTensors>(quantization_config);
}
// Add other schemes as needed

return nullptr; // Default case if no matching scheme
}
} // namespace infinilm::config
28 changes: 28 additions & 0 deletions csrc/config/quant_config.hpp
@@ -0,0 +1,28 @@
#pragma once
#include "../quantization/quantization.hpp"
#include "nlohmann/json.hpp"

namespace infinilm::config {

class QuantConfig {
// QuantConfig is used to store and parse the "quantization_config" field from config.json.
// This is currently a basic version and will be extended in the future.
public:
QuantConfig() = default;
QuantConfig(const nlohmann::json &json);

infinicore::nn::QuantScheme get_quant_scheme() const {
if (quantization_method != nullptr) {
return quantization_method->get_quant_scheme();
} else {
return infinicore::nn::QuantScheme::NONE;
}
}

private:
nlohmann::json quantization_config;
std::shared_ptr<infinilm::quantization::BaseQuantization> get_quantization_method() const;
std::shared_ptr<infinilm::quantization::BaseQuantization> quantization_method;
};

} // namespace infinilm::config
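A minimal sketch of how the scheme gets resolved, under the assumption that a bare {"quant_method": "compressed-tensors"} object is enough to construct the CompressedTensors backend; real configs carry additional scheme-specific fields defined in quantization.hpp rather than in this diff.

// Illustrative: a null quantization_config resolves to QuantScheme::NONE, while
// quant_method == "compressed-tensors" selects the CompressedTensors backend.
// Any extra fields CompressedTensors may require are omitted here (hypothetical input).
#include "quant_config.hpp"

void resolve_schemes() {
    using infinilm::config::QuantConfig;

    // config.json without a quantization_config block -> NONE.
    QuantConfig none_cfg{nlohmann::json(nullptr)};
    auto none_scheme = none_cfg.get_quant_scheme(); // infinicore::nn::QuantScheme::NONE

    // compressed-tensors block -> scheme reported by CompressedTensors.
    nlohmann::json quant_json = {{"quant_method", "compressed-tensors"}};
    QuantConfig ct_cfg{quant_json};
    auto ct_scheme = ct_cfg.get_quant_scheme();

    (void)none_scheme;
    (void)ct_scheme;
}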
45 changes: 43 additions & 2 deletions csrc/engine/infer_engine.cpp
@@ -6,14 +6,26 @@ namespace infinilm::engine {
//------------------------------------------------------
// Constructor
//------------------------------------------------------
/**
* @deprecated This constructor is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the overload below
*
* Replacement: Use the overload below that takes a model path and loads config.json via infinilm::config::ModelConfig
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
InferEngine::InferEngine(
const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config,
bool enable_graph_compiling) // Changed parameter
: communication_group_(distributed_config, device_type),
model_config_(config) {
legacy_model_config_(config) {

if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy();
@@ -24,7 +36,7 @@ InferEngine::InferEngine(
workers_.reserve(world_size);
for (int r = 0; r < world_size; ++r) {
workers_.emplace_back(std::make_unique<RankWorker>(
model_config_,
legacy_model_config_,
communication_group_.get_rank_info(r),
cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(),
@@ -35,6 +47,35 @@ InferEngine::InferEngine(
this->compile();
}

InferEngine::InferEngine(
const std::string &model_path,
const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config,
bool enable_graph_compiling) // Changed parameter
: communication_group_(distributed_config, device_type) {
if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy();
}

// Load the model config from <model_path>/config.json; the ModelConfig constructor throws if the file cannot be opened
this->model_config_ = std::make_shared<infinilm::config::ModelConfig>(model_path + "/config.json");
// Create one RankWorker per rank
int world_size = communication_group_.get_world_size();
barrier_ = std::make_unique<RankBarrier>((size_t)world_size);
workers_.reserve(world_size);
for (int r = 0; r < world_size; ++r) {
workers_.emplace_back(std::make_unique<RankWorker>(
model_config_,
communication_group_.get_rank_info(r),
cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(),
enable_graph_compiling));
}
// Compile the model on all workers
this->compile();
}

//------------------------------------------------------
// load_param
//------------------------------------------------------
23 changes: 22 additions & 1 deletion csrc/engine/infer_engine.hpp
@@ -1,5 +1,6 @@
#pragma once

#include "../config/model_config.hpp"
#include "../models/infinilm_model.hpp"
#include "../models/llama/llama_config.hpp"
#include "distributed/distributed.hpp"
@@ -19,13 +20,32 @@ class InferEngine {
using Output = RankWorker::Output;

// Updated constructor: accept CacheConfig instead of CacheType
/**
* @deprecated This constructor is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the overload below
*
* Replacement: Use the overload declared below that takes a model path and loads config.json via infinilm::config::ModelConfig
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
InferEngine(
const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false);

InferEngine(
const std::string &model_path = "",
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false);

// Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param);

@@ -50,8 +70,9 @@ class InferEngine {
std::vector<std::unique_ptr<RankWorker>> workers_;
std::unique_ptr<RankBarrier> barrier_;
distributed::CommunicationGroup communication_group_;
const InfinilmModel::Config &model_config_;
std::unique_ptr<cache::CacheConfig> cache_config_;
// Stored by value: a reference member with a temporary as its default initializer would be
// ill-formed/dangling once the path-based constructor is used. Assumes InfinilmModel::Config
// is default-constructible and copyable.
InfinilmModel::Config legacy_model_config_{};
std::shared_ptr<infinilm::config::ModelConfig> model_config_;
};

} // namespace infinilm::engine
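A sketch of the migration the deprecation notes ask for, assuming a local model directory that contains a config.json; the directory name is a placeholder and the remaining arguments (distributed config, device type, cache config, graph compiling) are left at their defaults.

// Illustrative migration: prefer the path-based constructor, which builds an
// infinilm::config::ModelConfig from <model_path>/config.json internally.
// "/models/my-model" is a placeholder path.
#include "infer_engine.hpp"

void build_engine() {
    // Deprecated path (removal target v0.2.0): caller assembles InfinilmModel::Config by hand.
    // InfinilmModel::Config legacy_cfg = /* filled in manually */;
    // infinilm::engine::InferEngine legacy_engine(legacy_cfg);

    // Preferred path: point the engine at the model directory.
    infinilm::engine::InferEngine engine("/models/my-model");
}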
40 changes: 38 additions & 2 deletions csrc/engine/rank_worker.cpp
@@ -4,17 +4,54 @@

#include "infinicore/ops.hpp"

#include <iostream>
#include <spdlog/spdlog.h>
#include <stdexcept>

namespace infinilm::engine {

/**
* @deprecated This constructor is deprecated and will be REMOVED in the next major release (v0.2.0).
*
* ⚠️ DEVELOPMENT POLICY:
* - NO new development or feature additions permitted on this interface
* - Only critical bug fixes (security/stability) allowed until removal
* - All new code MUST migrate to the overload below
*
* Replacement: Use the overload below that takes a std::shared_ptr<infinilm::config::ModelConfig>
* Reason: Legacy signature lacks support for dynamic quantization modes.
* Removal target: v0.2.0 (Q2 2026)
*/
RankWorker::RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling)
: legacy_model_config_(model_config),
rank_info_(rank_info),
enable_graph_compiling_(enable_graph_compiling),
job_cmd_(Command::INIT),
has_job_(false),
job_done_(false),
should_exit_(false),
init_done_(false),
barrier_(barrier) {
if (cache_config != nullptr) {
pending_cache_config_ = cache_config->unique_copy();
}
// start the thread
thread_ = std::thread(&RankWorker::thread_loop, this);

// Wait until the worker thread finishes initialization (model created)
std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return init_done_; });
}

RankWorker::RankWorker(
std::shared_ptr<infinilm::config::ModelConfig> model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling)
: model_config_(model_config),
rank_info_(rank_info),
enable_graph_compiling_(enable_graph_compiling),
@@ -30,7 +67,6 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config,
}
// start the thread
thread_ = std::thread(&RankWorker::thread_loop, this);

// Wait until the worker thread finishes initialization (model created)
std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return init_done_; });