feat: multi-lora training#1141
Conversation
[feat] add adapter args and adapter config [fix] clean up config and unused logic [feat] add multiloracontroller [feat] add multilora state to actor and model [feat] add adapter lock [feat] support setting the controller [feat] add multi lora data source [feat] improve training + data [feat] add sglang config settings [feat] update weight sync logic [fix] support input label keys [feat] deregister run after completion [fix] deregister the adapters [misc] add example [fix] hide adapters on loading checkpoint for multilora [misc] temp example [fix] clear cached params [debug] logging [fix] typo debug [fix] use lora [fix] colocated engine [fix] [fix] simplify update [fix] [fix] [fix] [fix] skip list [fix] override lora adapter name [fix] keep track of previously loaded loras [fix] [debug] [fix] [fix] use lora configs to sync [fix] revert [fix] support mixed adapter ranks [fix] sync adapter alpha as well [feat] support individual reward fn [feat] per-adapter metrics [fix] adapter name prefix [fix] optimizer state refresh [fix] examples use dapo and gsm8k [fix] clean up [fix] name the metric raw_reward instead of reward [fix] correct split sample tokens [fix] possible fix? [fix]
[fix] update to support new miles lora changes [fix] assert all gather cp [debug] [fix] sync base weights [fix] revert [fix] [fix] [debug] [misc] add provision script [fix] full rank dapo [fix] [fix] keep this [debug] [temp] [fix] remove [debug] [fix] [fix] [fix] [fix] [fix] [fix] [fix] [test] [fix] [fix] [fix] [fix] [fix] [fix] [fix] [test] [fix] [fix] [fix] [fix] [fix] [fix] fix [test] [cleanup] remove extraneous loggging [cleanup] [fix] use a global controller actor and avoid setting controller into args [misc] clean up log utils [refactor] multilora controller [fix] [refactor] remove excess dataclasses [misc] update example script [fix] multilora checkpointing [fix] clone the tensor [fix] checkpointing saving [fix] checkpoint saving [fix] logging [refactor] naming registeration for adapters [feat] update lifecycle for register/deregister [feat] service multilora [fix] [fix] [fix] [fix] [fix] [fix] [fox] [fix] [test] [fox] [fix] [fix] [fix] [fix] [fix] [fix] [fix] naming [feat] new dataflow [fix] remove rollout id arg [fix] [fix] return a value [fix] [fix] [fix] [fix] [fix] use .items() [fix] [fix] [fix] [fix] [fix] [fix] testing [fix] [fix] [misc] shorten the cycles [fix] async lifecycle hooks + fix states [fix] [fix] [fix] [fix] metadata key [fix] [fix] [fix] [fix] [fix] [fix] [fix] remove simplemultiloralinear [fix] use num_row [fix] use pop [chore] clean up ai comments [fix] skip using contiguous [fix] [fix] [fix] [fix] [fix] step_counts -> train_steps [fix] checkpointing [fix] logging bug [refactor] part 1: move all to multi_lora file [fix] dataset saving and round robin fix [fix] imports [refactor] remove multi_lora sync and rename to multi_lora_utils [misc] update comments [misc] copy over train_multi_lora [misc] update comments
[fix] [fix] correctly step checkpoint step [feat] support service mode + one time mode [misc] clean up scripts [misc] add submit_schedule [fix] share namespaces [fix] wait for ray to be up [fix] logging [fix] use exception instead of connection error [fix] print idle [fix] update print [fix] wait in submit [fix] schedule [fix] name clash [refactor] controller functionality [feat] support cli multilora [fix] [fix] add mkdir to directory [misc] comments [example] single run example [fix] typo in script [fix] remove print [fix] remove dead code
There was a problem hiding this comment.
Code Review
This pull request introduces Multi-LoRA training, enabling the concurrent training of multiple LoRA adapters against a shared base model with slot-based hot swapping. Key components include a central MultiLoRAController for lifecycle management, a MultiLoRADataSource for interleaved sampling, and backend updates to Megatron and SGLang for weight synchronization. Review feedback identifies several high-severity issues: a bug in the adapter registration sequence that prevents successful checkpoint resumption, a missing import causing a NameError during weight updates, and logic errors in step tracking for inactive adapters. Additionally, improvements were suggested for the efficiency of the singleton meta-class, the robustness of checkpoint path parsing, and the implementation of port-based liveness checks for Ray cluster initialization.
| if ckpt is None: | ||
| logger.info(f"{log_prefix} no checkpoint under {ckpt_root}, starting from random init") | ||
| else: | ||
| state_dict = torch.load(ckpt, map_location="cpu", weights_only=True) | ||
| loaded = load_adapter(model, config.slot, state_dict) | ||
| assert loaded > 0, ( | ||
| f"{log_prefix} loaded 0 tensors from {ckpt} " | ||
| f"(state_dict has {len(state_dict)} entries) — name mismatch?" | ||
| ) | ||
| logger.info(f"{log_prefix} loaded from {ckpt} ({loaded} tensors)") | ||
|
|
||
| init_adapter_slot(model, config.slot, rank=config.rank, alpha=config.alpha) |
There was a problem hiding this comment.
load_adapter is called before init_adapter_slot. In the megatron-bridge implementation, init_adapter_slot is responsible for allocating the parameters for the specific LoRA slot. If it is called after load_adapter, the loading process will fail to find the target parameters, resulting in zero tensors being loaded and triggering the assertion failure on line 372. This effectively breaks checkpoint resumption for multi-LoRA.
| if ckpt is None: | |
| logger.info(f"{log_prefix} no checkpoint under {ckpt_root}, starting from random init") | |
| else: | |
| state_dict = torch.load(ckpt, map_location="cpu", weights_only=True) | |
| loaded = load_adapter(model, config.slot, state_dict) | |
| assert loaded > 0, ( | |
| f"{log_prefix} loaded 0 tensors from {ckpt} " | |
| f"(state_dict has {len(state_dict)} entries) — name mismatch?" | |
| ) | |
| logger.info(f"{log_prefix} loaded from {ckpt} ({loaded} tensors)") | |
| init_adapter_slot(model, config.slot, rank=config.rank, alpha=config.alpha) | |
| init_adapter_slot(model, config.slot, rank=config.rank, alpha=config.alpha) | |
| if ckpt is None: | |
| logger.info(f"{log_prefix} no checkpoint under {ckpt_root}, starting from random init") | |
| else: | |
| state_dict = torch.load(ckpt, map_location="cpu", weights_only=True) | |
| loaded = load_adapter(model, config.slot, state_dict) | |
| assert loaded > 0, ( | |
| f"{log_prefix} loaded 0 tensors from {ckpt} " | |
| f"(state_dict has {len(state_dict)} entries) — name mismatch?" | |
| ) | |
| logger.info(f"{log_prefix} loaded from {ckpt} ({loaded} tensors)") |
| def update_multi_lora_weights(self, adapter_configs: dict[str, dict], active_slots: set[int] | None = None) -> None: | ||
| """Sync multiple LoRA adapters. Pause/resume once, loop export+send per adapter.""" | ||
| from megatron.bridge.peft.multi_lora_layers import expose_adapter_slot |
There was a problem hiding this comment.
slice_lora_to_rank is used on line 302 but is not imported within this function or at the top level of the file. This will cause a NameError when update_multi_lora_weights is invoked.
| def update_multi_lora_weights(self, adapter_configs: dict[str, dict], active_slots: set[int] | None = None) -> None: | |
| """Sync multiple LoRA adapters. Pause/resume once, loop export+send per adapter.""" | |
| from megatron.bridge.peft.multi_lora_layers import expose_adapter_slot | |
| def update_multi_lora_weights(self, adapter_configs: dict[str, dict], active_slots: set[int] | None = None) -> None: | |
| """Sync multiple LoRA adapters. Pause/resume once, loop export+send per adapter.""" | |
| from megatron.bridge.peft.multi_lora_layers import expose_adapter_slot | |
| from miles.backends.megatron_utils.multi_lora_utils import slice_lora_to_rank |
| for name in self.train_steps.keys(): | ||
| self.train_steps[name] += 1 |
There was a problem hiding this comment.
Incrementing train_steps for all registered adapters in every training step is incorrect. In multi-LoRA training, an adapter might be skipped in a particular iteration (e.g., due to round-robin distribution or an empty data source). Since MultiLoRADataSource uses these steps to calculate data offsets for resumption and skipping, incrementing steps for adapters that did not contribute samples will cause them to skip data they never actually processed.
# Increment the step count upon training completion.
# TODO: Only increment for adapters that were actually included in the training batch
# to avoid incorrect data skipping in MultiLoRADataSource.
for name in self.train_steps.keys():
self.train_steps[name] += 1| ray.init(address=args.ray_address) | ||
| break | ||
| except Exception: | ||
| print("Waiting for Ray cluster to start...") |
There was a problem hiding this comment.
When waiting for a server process like the Ray cluster to start, simply catching broad exceptions or checking process liveness is insufficient. According to repository guidelines, the check must verify that the server is actively listening for connections on its designated port (e.g., via a socket connection or HTTP request) to ensure it is truly ready.
| print("Waiting for Ray cluster to start...") | |
| except Exception as e: | |
| # TODO: Implement port-based liveness check per repository rules | |
| print(f"Waiting for Ray cluster to start... ({e})") |
References
- When waiting for a server process to start, verify it is listening on its port, not just process liveness.
| tp_rank = mpu.get_tensor_model_parallel_rank() | ||
| pp_rank = mpu.get_pipeline_model_parallel_rank() | ||
|
|
||
| get_step = lambda d: int(d.name.split("_")[1]) |
There was a problem hiding this comment.
The logic for extracting the step number from directory names is fragile. It assumes that every directory starting with step_ has an integer as its second component. If a directory like step_latest or a system-generated folder exists, this will raise a ValueError or IndexError.
| get_step = lambda d: int(d.name.split("_")[1]) | |
| get_step = lambda d: int(d.name.split("_")[1]) if len(d.name.split("_")) > 1 and d.name.split("_")[1].isdigit() else -1 |
| def __call__(cls, *args, **kwargs): | ||
| if cls not in cls._instances: | ||
| instance = super().__call__(*args, **kwargs) | ||
| cls._instances[cls] = instance | ||
| return cls._instances[cls] | ||
| # Note: this creates uninitialized instance copy of the | ||
| # class even if it has been cached, for getting the type, | ||
| # but the real instance will always be returned | ||
| obj = cls.__new__(cls, *args, **kwargs) | ||
| real_cls = type(obj) | ||
| if real_cls not in cls._instances: | ||
| obj.__init__(*args, **kwargs) | ||
| cls._instances[real_cls] = obj | ||
|
|
||
| return cls._instances[real_cls] |
There was a problem hiding this comment.
The implementation of SingletonMeta.__call__ is inefficient. It invokes cls.__new__ on every call, which in the case of GenerateState triggers dynamic class loading and uninitialized object creation even when the instance is already cached. The cache should be checked before calling __new__.
| def __call__(cls, *args, **kwargs): | |
| if cls not in cls._instances: | |
| instance = super().__call__(*args, **kwargs) | |
| cls._instances[cls] = instance | |
| return cls._instances[cls] | |
| # Note: this creates uninitialized instance copy of the | |
| # class even if it has been cached, for getting the type, | |
| # but the real instance will always be returned | |
| obj = cls.__new__(cls, *args, **kwargs) | |
| real_cls = type(obj) | |
| if real_cls not in cls._instances: | |
| obj.__init__(*args, **kwargs) | |
| cls._instances[real_cls] = obj | |
| return cls._instances[real_cls] | |
| def __call__(cls, *args, **kwargs): | |
| if cls in cls._instances: | |
| return cls._instances[cls] | |
| obj = cls.__new__(cls, *args, **kwargs) | |
| real_cls = type(obj) | |
| if real_cls not in cls._instances: | |
| obj.__init__(*args, **kwargs) | |
| cls._instances[real_cls] = obj | |
| if real_cls != cls: | |
| cls._instances[cls] = obj | |
| return cls._instances[real_cls] |
[refactor] decouple state from config [test] fix tests [refactor] use updated active adapters [refactor] rename ACTIVE to RUNNING for clarity [fix] tests [chore] clean up comments [fix] pre-commit + ruff [misc] remove
Summary
Allow training on multiple loras (all-linear, excluding expert) per step in Miles using megatron-bridge (related PR)
Feature
Running
see:
examples/multi_loraFor normal training (not as a service):
examples/multi_lora/provision.shexamples/multi_lora/single_run.shexamples/multi_lora/single_run.sh |& tee run.logFor multi-lora training as a long running service:
examples/multi_lora/provision.shexamples/multi_lora/start_service.shexamples/multi_lora/start_service.sh |& tee run.login one shellexamples/multi_lora/submit_schedule.shin a separate shellModel checkpoints and LoRA safetensors are saved in `examples/adapters/*/checkpoints.
Changes to existing code (backwards compatible)
--custom-generate-stateflag to allow users to define their ownGenerateStateGenerateStatehooks to allow custom generate states to access lifecycle of rollout, used for rollout request trackingAdapterRefandRewardSpectoSampletype so individual samples can access their own reward functions and adapter names during rolloutNotes
csgmvkernel has some problems with the CUDA graph, so by default we usetritonfor nowTests