
feat: support rolling weight loading. #1024

Open
Clement-Wang26 wants to merge 1 commit into jd-opensource:main from Clement-Wang26:rolling_load

Conversation


@Clement-Wang26 Clement-Wang26 commented Mar 9, 2026

Background

  • The previous path kept all decoder-layer weights resident in HBM, which consumes substantial device memory.
  • This change introduces rolling loading so only a limited number of decoder slots stay in HBM while layers are streamed in asynchronously.

Core Changes

  • Added rolling/manual-loader flags and validation:
    • enable_manual_loader
    • enable_rolling_load
    • rolling_load_num_cached_layers
    • rolling_load_num_rolling_slots
  • Added RollingWeightBuffer:
    • Manages N device slots (slot = layer_index % num_slots).
    • Supports both XTensor allocation and non-XTensor aclrtMalloc.
    • Supports address refresh after wakeup (refresh_address).
  • Added RollingLoadManager:
    • Owns rolling lifecycle: init_rolling_load, before_layer, after_layer, finalize.
    • Uses per-layer compute/H2D events to pipeline copy and compute safely.
    • Supports mixed fixed+rolling slot policy and dirty/refill restore behavior.
  • Extended BaseManualLoader:
    • Added rolling slot binding (set_rolling_buffer).
    • Added refresh_rolling_device_at_weights to rebuild device pointers and AT views.
    • Added stream-aware async H2D API.
    • Device storage allocation now supports rolling slot / xtensor / aclrt paths.
  • Extended layer/model abstraction for rolling:
    • Added get_manual_loader and rolling AT refresh hooks in base layer/model.
    • Added rolling-related optional interfaces in CausalLM traits:
      • get_decoder_loaders
      • set_rolling_load_manager
      • init_rolling_model_state
    • Wired these interfaces into NPU model implementations.
  • Integrated rolling hooks into decoder forward:
    • Call before_layer/after_layer per layer.
    • Call finalize(last_executed_layer) at end to support partial/aborted forwards.
  • Integrated runtime wiring in WorkerImpl:
    • Added dedicated load_stream_ and rolling_load_manager_.
    • Rolling path uses lazy_load_model + init_rolling_runtime_state.
    • LIGHT_SLEEP wakeup refreshes rolling buffer address and re-inits rolling state.
    • Remote wakeup explicitly rejects rolling mode.
  • Updated XTensor budgeting/model loader APIs:
    • Added get_non_decoder_weight_size() in ModelLoader.
    • XTensor weight budget now uses non-decoder size + rolling buffer size.
    • Renamed allocator free API to free_weight_allocation for clearer semantics.
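
The slot mapping described above can be sketched as follows. This is a hypothetical simplification (`RollingWeightBufferSketch` is not the actual class); only the `slot = layer_index % num_slots` rule is taken from the PR:

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical sketch of the rolling-slot mapping; the real
// RollingWeightBuffer also owns device memory via XTensor or aclrtMalloc.
struct RollingWeightBufferSketch {
  std::size_t num_slots;

  // Each decoder layer reuses one of num_slots device slots, round-robin.
  std::size_t slot_of(std::size_t layer_index) const {
    return layer_index % num_slots;
  }

  // Two layers contend for the same slot exactly when their indices are
  // congruent modulo num_slots; the manager's per-layer compute/H2D events
  // must serialize such reuses before the next H2D copy may begin.
  bool same_slot(std::size_t a, std::size_t b) const {
    return slot_of(a) == slot_of(b);
  }
};
```

With 4 slots, layer 5 reuses the slot of layer 1, so its H2D copy must wait on layer 1's compute event.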

Behavior / Compatibility

  • Rolling load is NPU-only.
  • enable_xtensor or enable_rolling_load requires enable_manual_loader.
  • In rolling mode, decoder weights use host-pinned memory as source and are copied to HBM slots on demand.
  • Remote weight wakeup is not supported when rolling is enabled.
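
A minimal sketch of the flag dependency stated above (flag names come from the PR; the struct and validation function are hypothetical):

```cpp
#include <cassert>

// Hypothetical config holder; only the dependency rule is from the PR:
// enable_xtensor or enable_rolling_load requires enable_manual_loader.
struct LoaderFlags {
  bool enable_manual_loader = false;
  bool enable_xtensor = false;
  bool enable_rolling_load = false;
};

bool flags_valid(const LoaderFlags& f) {
  if ((f.enable_xtensor || f.enable_rolling_load) && !f.enable_manual_loader) {
    return false;  // rolling/xtensor paths depend on the manual loader
  }
  return true;
}
```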

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant feature: rolling weight loading, designed to reduce device memory (HBM) consumption by loading decoder-layer weights on demand. The changes are extensive, touching configuration flags, memory allocation, model loading, and the forward-pass logic in various models. A RollingLoadManager is introduced to orchestrate the just-in-time H2D transfer of weights from pinned host memory to a rolling buffer on the device.

My review identified two critical issues. The first is a potential memory under-allocation for XTensor due to an integer division that should be a ceiling division. The second is a use-after-free bug where host-pinned memory, essential for rolling load, is prematurely deallocated during a LIGHT_SLEEP cycle. I have provided code suggestions to address both critical issues.
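
To illustrate the first issue: floor division under-allocates whenever the size is not an exact multiple of the divisor, while a ceiling division rounds up. This is a generic sketch, not the actual xLLM code:

```cpp
#include <cassert>
#include <cstddef>

// With floor division, 10 bytes over 4-byte pages yields 2 pages (8 bytes),
// under-allocating; ceiling division yields the required 3 pages.
std::size_t ceil_div(std::size_t numerator, std::size_t denominator) {
  return (numerator + denominator - 1) / denominator;
}
```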

```cpp
// Free weight allocation (called by sleep), including both contiguous
// GlobalXTensor and fallback XTensor allocations.
// Returns the number of pages freed.
size_t free_weight_allocation(const std::string& model_id);
```

This is not a good function name: it contains both free and allocation, which are antonyms.

```cpp
    {"float64", torch::kFloat64},
    {"int8", torch::kInt8},
};
auto it = kDtypeMap.find(args_.dtype());
```

We could move this lookup into utils/tensor_helper.h as a standalone function, e.g.:

```cpp
constexpr torch::ScalarType get_scalar_type() {
```
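
A self-contained sketch of such a helper follows. A stand-in enum replaces `torch::ScalarType` so the example compiles on its own, and the lookup returns `std::optional` rather than being `constexpr` (a runtime map lookup cannot be `constexpr`):

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <unordered_map>

// Stand-in for torch::ScalarType so this sketch is self-contained.
enum class ScalarType { kFloat64, kInt8 };

// One shared lookup function instead of a per-call-site map + find.
std::optional<ScalarType> get_scalar_type(const std::string& dtype) {
  static const std::unordered_map<std::string, ScalarType> kDtypeMap = {
      {"float64", ScalarType::kFloat64},
      {"int8", ScalarType::kInt8},
  };
  auto it = kDtypeMap.find(dtype);
  if (it == kDtypeMap.end()) return std::nullopt;
  return it->second;
}
```

Callers then handle the unknown-dtype case via the empty optional instead of checking an iterator against `end()`.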

```cpp
      << args_.dtype() << ", falling back to total_weight_size";
  return get_total_weight_size();
}
int64_t bytes_per_elem = torch::elementSize(it->second);
```

Refer to:

```cpp
inline int32_t get_dtype_size(torch::ScalarType dtype) {
```

```cpp
virtual void copy_weights_to_device_async();

// Async H2D using the specified ACL stream (used by RollingLoadManager).
virtual void copy_weights_to_device_async(aclrtStream stream);
```

Why not use a pointer for aclrtStream?

```diff
@@ -0,0 +1,173 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
```

The copyright year should be 2026.


```cpp
namespace xllm {

class Stream;
```

What is this Stream forward declaration used for?

```cpp
      << options.src_weight_segments.size();
  return wakeup_from_remote_weights(options);
#endif
  LOG(ERROR) << "Remote weight wakeup requires USE_NPU build";
```

Suggested wording: "Remote weight wakeup only supports NPU devices."
