feat: support rolling weight loading. #1024
Clement-Wang26 wants to merge 1 commit into jd-opensource:main from
Conversation
Code Review
This pull request introduces a significant feature: rolling weight loading, designed to reduce device memory (HBM) consumption by loading decoder layer weights on-demand. The changes are extensive, touching configuration flags, memory allocation, model loading, and the forward pass logic in various models. A RollingLoadManager is introduced to orchestrate the just-in-time H2D transfer of weights from pinned host memory to a rolling buffer on the device. My review identified two critical issues. The first is a potential memory under-allocation for XTensor due to integer division that should be a ceiling division. The second is a use-after-free bug where host-pinned memory, essential for rolling load, is prematurely deallocated during a LIGHT_SLEEP cycle. I have provided code suggestions to address both critical issues.
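The ceiling-division issue flagged above can be shown in isolation. This is a minimal sketch, not the actual patch; `ceil_div` is a hypothetical helper name:

```cpp
#include <cassert>
#include <cstdint>

// Integer division truncates toward zero, so sizing an allocation with
// `total_bytes / elem_size` under-allocates whenever there is a remainder.
// Rounding up avoids the under-allocation:
inline int64_t ceil_div(int64_t numerator, int64_t denominator) {
  return (numerator + denominator - 1) / denominator;
}
```

Any size computation that divides a byte count by an element or page size should use this form when the dividend is not guaranteed to be a multiple of the divisor.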
(force-pushed from 3b35397 to d127ddf)
// Free weight allocation (called by sleep), including both contiguous
// GlobalXTensor and fallback XTensor allocations.
// Returns the number of pages freed.
size_t free_weight_allocation(const std::string& model_id);
This isn't a good function name: it combines "free" and "allocation", which are antonyms.
    {"float64", torch::kFloat64},
    {"int8", torch::kInt8},
};
auto it = kDtypeMap.find(args_.dtype());
We could move this lookup into utils/tensor_helper.h as a standalone function, for example:
xllm/xllm/core/util/tensor_helper.h
Line 281 in 7e2c710
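A minimal sketch of what such a shared helper might look like. The function name `parse_dtype` is hypothetical, and a placeholder `ScalarType` enum stands in for `torch::ScalarType` so the sketch is self-contained:

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <unordered_map>

// Placeholder for torch::ScalarType so this sketch stands alone.
enum class ScalarType { Float64, Int8 };

// Hypothetical free function for utils/tensor_helper.h: one shared
// name-to-dtype lookup instead of a map duplicated at each call site.
// Returns std::nullopt for unknown names so callers can fall back
// gracefully (e.g. to get_total_weight_size()).
inline std::optional<ScalarType> parse_dtype(const std::string& name) {
  static const std::unordered_map<std::string, ScalarType> kDtypeMap = {
      {"float64", ScalarType::Float64},
      {"int8", ScalarType::Int8},
  };
  auto it = kDtypeMap.find(name);
  if (it == kDtypeMap.end()) return std::nullopt;
  return it->second;
}
```

Returning `std::optional` keeps the fallback decision at the call site, matching the existing "falling back to total_weight_size" behavior.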
      << args_.dtype() << ", falling back to total_weight_size";
  return get_total_weight_size();
}
int64_t bytes_per_elem = torch::elementSize(it->second);
refer to
xllm/xllm/core/util/tensor_helper.h
Line 329 in 7e2c710
virtual void copy_weights_to_device_async();

// Async H2D using the specified ACL stream (used by RollingLoadManager).
virtual void copy_weights_to_device_async(aclrtStream stream);
Why not use a pointer for aclrtStream?
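For context on this question: in Ascend's ACL headers, `aclrtStream` is itself an opaque handle (a `void*` typedef), so passing it by value already passes a pointer. A standalone sketch of the two overloads, using a placeholder typedef and an illustrative `Loader` struct (not the PR's actual class):

```cpp
#include <cassert>

// Placeholder mirroring ACL's opaque stream handle; in the real CANN
// headers aclrtStream is already a typedef for void*.
using aclrtStream = void*;

struct Loader {
  // Default path: async H2D copy on an internally owned stream.
  void copy_weights_to_device_async() {
    copy_weights_to_device_async(default_stream_);
  }
  // Overload used by a rolling-load manager: caller supplies the stream.
  // Taking aclrtStream by value is cheap because the handle is a pointer.
  void copy_weights_to_device_async(aclrtStream stream) {
    last_stream_ = stream;  // stand-in for issuing the real async copy
  }
  aclrtStream default_stream_ = nullptr;
  aclrtStream last_stream_ = nullptr;
};
```

Given that, taking `aclrtStream*` would add a level of indirection without benefit unless the callee needs to reassign the caller's handle.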
@@ -0,0 +1,173 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

namespace xllm {

class Stream;
xllm/core/runtime/worker_impl.cpp (Outdated)
      << options.src_weight_segments.size();
  return wakeup_from_remote_weights(options);
#endif
  LOG(ERROR) << "Remote weight wakeup requires USE_NPU build";
Remote weight wakeup only supports NPU devices.
(force-pushed from d127ddf to ba900c0)
Background
Core Changes
- New config flags: `enable_manual_loader`, `enable_rolling_load`, `rolling_load_num_cached_layers`, `rolling_load_num_rolling_slots`.
- `RollingWeightBuffer`: rolling device slots shared across decoder layers (`slot = layer_index % num_slots`), allocated via `aclrtMalloc`, with a `refresh_address` hook.
- `RollingLoadManager`: drives `init_rolling_load`, `before_layer`, `after_layer`, and `finalize`.
- `BaseManualLoader`: rolling-buffer binding (`set_rolling_buffer`) and `refresh_rolling_device_at_weights` to rebuild device pointers and AT views.
- `get_manual_loader` and rolling AT refresh hooks in the base layer/model.
- Model side: `get_decoder_loaders`, `set_rolling_load_manager`, `init_rolling_model_state`; `before_layer`/`after_layer` per layer; `finalize(last_executed_layer)` at the end to support partial/aborted forwards.
- `WorkerImpl`: `load_stream_` and `rolling_load_manager_`; `lazy_load_model` + `init_rolling_runtime_state`.
- `get_non_decoder_weight_size()` in `ModelLoader`.
- `free_weight_allocation` for clearer semantics.

Behavior / Compatibility
- `enable_xtensor` or `enable_rolling_load` requires `enable_manual_loader`.
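The slot mapping described above (`slot = layer_index % num_slots`) can be sketched minimally. The class and member names here are illustrative, not the PR's actual implementation, and the real buffer would wrap `aclrtMalloc`-backed device memory:

```cpp
#include <cassert>
#include <vector>

// Minimal sketch of a rolling weight buffer: num_slots device-resident
// slots are shared round-robin by all decoder layers, so layer i lands
// in slot i % num_slots. A slot must be reloaded (H2D copy from pinned
// host memory) whenever a different layer last occupied it.
class RollingWeightBuffer {
 public:
  explicit RollingWeightBuffer(int num_slots)
      : resident_layer_(num_slots, -1) {}

  int slot_for(int layer_index) const {
    return layer_index % static_cast<int>(resident_layer_.size());
  }

  // True if an H2D copy is needed before executing this layer.
  bool needs_load(int layer_index) const {
    return resident_layer_[slot_for(layer_index)] != layer_index;
  }

  // Called once the (hypothetical) async copy for this layer completes.
  void mark_loaded(int layer_index) {
    resident_layer_[slot_for(layer_index)] = layer_index;
  }

 private:
  std::vector<int> resident_layer_;  // layer held by each slot; -1 if empty
};
```

This also shows why the pinned host copies must outlive sleep cycles, as the review notes: any `needs_load` hit triggers a fresh H2D transfer from that host memory.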