feat: multi-lora training #1141

Draft
mathewjhan wants to merge 32 commits into radixark:main from mathewjhan:feat/multilora-rebase

@mathewjhan mathewjhan commented May 16, 2026

Summary

Allow training multiple LoRAs (all-linear, excluding experts) per step in Miles using megatron-bridge (related PR).

Feature

  • Train multiple LoRAs in a single training step (currently colocated only, until "[WIP][lora] support disaggregate model lora training" #988 is merged)
  • Run as a long-running service that supports loading and unloading LoRAs online
  • Run as normal training, which stops when no LoRAs are left to train

Running

see: examples/multi_lora

For normal training (not as a service):

  1. run examples/multi_lora/provision.sh
  2. configure W&B credentials and settings in examples/multi_lora/single_run.sh
  3. run examples/multi_lora/single_run.sh |& tee run.log

For multi-lora training as a long-running service:

  1. run examples/multi_lora/provision.sh
  2. configure W&B credentials and settings in examples/multi_lora/start_service.sh
  3. run examples/multi_lora/start_service.sh |& tee run.log in one shell
  4. run examples/multi_lora/submit_schedule.sh in a separate shell

Model checkpoints and LoRA safetensors are saved in `examples/adapters/*/checkpoints`.

Changes to existing code (backwards compatible)

  • Support --custom-generate-state flag to allow users to define their own GenerateState
  • Add GenerateState hooks to allow custom generate states to access lifecycle of rollout, used for rollout request tracking
  • Add AdapterRef and RewardSpec to Sample type so individual samples can access their own reward functions and adapter names during rollout

Notes

  • Optimizer and scheduler state are not checkpointed yet; this can be added in a future PR
  • Per-adapter dataset checkpoint loading doesn't fully work yet, because the data source API doesn't support per-adapter loading
  • MultiLoRA is not applied to experts for now, because it requires more complex bookkeeping (the [adapter index, routed experts] pairs must be tracked together)
  • The sglang LoRA csgmv kernel has some problems with CUDA graphs, so we use triton by default for now
  • Loading from *.bin/*.safetensors is not supported yet (can be added in a future PR); for now you can only resume from a Megatron checkpoint or train from scratch

Tests

  • e2e Qwen3-4B test using 2 LoRA adapters trained on gsm8k and dapo_math
  • MultiLoRAController tests
  • AdapterConfig tests

maocheng23 and others added 15 commits May 12, 2026 16:16
[feat] add adapter args and adapter config

[fix] clean up config and unused logic

[feat] add multiloracontroller

[feat] add multilora state to actor and model

[feat] add adapter lock

[feat] support setting the controller

[feat] add multi lora data source

[feat] improve training + data

[feat] add sglang config settings

[feat] update weight sync logic

[fix] support input label keys

[feat] deregister run after completion

[fix] deregister the adapters

[misc] add example

[fix] hide adapters on loading checkpoint for multilora

[misc] temp example

[fix] clear cached params

[debug]

logging

[fix] typo

debug

[fix] use lora

[fix] colocated engine

[fix]

[fix] simplify update

[fix]

[fix]

[fix]

[fix] skip list

[fix] override lora adapter name

[fix] keep track of previously loaded loras

[fix]

[debug]

[fix]

[fix] use lora configs to sync

[fix] revert

[fix] support mixed adapter ranks

[fix] sync adapter alpha as well

[feat] support individual reward fn

[feat] per-adapter metrics

[fix] adapter name prefix

[fix] optimizer state refresh

[fix] examples use dapo and gsm8k

[fix] clean up

[fix] name the metric raw_reward instead of reward

[fix] correct split sample tokens

[fix] possible fix?

[fix]
[fix] update to support new miles lora changes

[fix] assert all gather cp

[debug]

[fix] sync base weights

[fix] revert

[fix]

[fix]

[debug]

[misc] add provision script

[fix] full rank dapo

[fix]

[fix] keep this

[debug]

[temp]

[fix] remove

[debug]

[fix]

[fix]

[fix]

[fix]

[fix]

[fix]

[fix]

[test]

[fix]

[fix]

[fix]

[fix]

[fix]

[fix]

[fix]

[test]

[fix]

[fix]

[fix]

[fix]

[fix]

[fix]

fix

[test]

[cleanup] remove extraneous loggging

[cleanup]

[fix] use a global controller actor and avoid setting controller into args

[misc] clean up log utils

[refactor] multilora controller

[fix]

[refactor] remove excess dataclasses

[misc] update example script

[fix] multilora checkpointing

[fix] clone the tensor

[fix] checkpointing saving

[fix] checkpoint saving

[fix] logging

[refactor] naming registeration for adapters

[feat] update lifecycle for register/deregister

[feat] service multilora

[fix]

[fix]

[fix]

[fix]

[fix]

[fix]

[fox]

[fix]

[test]

[fox]

[fix]

[fix]

[fix]

[fix]

[fix]

[fix]

[fix] naming

[feat] new dataflow

[fix] remove rollout id arg

[fix]

[fix] return a value

[fix]

[fix]

[fix]

[fix]

[fix] use .items()

[fix]

[fix]

[fix]

[fix]

[fix]

[fix] testing

[fix]

[fix]

[misc] shorten the cycles

[fix] async lifecycle hooks + fix states

[fix]

[fix]

[fix]

[fix] metadata key

[fix]

[fix]

[fix]

[fix]

[fix]

[fix]

[fix] remove simplemultiloralinear

[fix] use num_row

[fix] use pop

[chore] clean up ai comments

[fix] skip using contiguous

[fix]

[fix]

[fix]

[fix]

[fix] step_counts -> train_steps

[fix] checkpointing

[fix] logging bug

[refactor] part 1: move all to multi_lora file

[fix] dataset saving and round robin fix

[fix] imports

[refactor] remove multi_lora sync and rename to multi_lora_utils

[misc] update comments

[misc] copy over train_multi_lora

[misc] update comments
[fix]

[fix] correctly step checkpoint step

[feat] support service mode + one time mode

[misc] clean up scripts

[misc] add submit_schedule

[fix] share namespaces

[fix] wait for ray to be up

[fix] logging

[fix] use exception instead of connection error

[fix] print idle

[fix] update print

[fix] wait in submit

[fix] schedule

[fix] name clash

[refactor] controller functionality

[feat] support cli multilora

[fix]

[fix] add mkdir to directory

[misc] comments

[example] single run example

[fix] typo in script

[fix] remove print

[fix] remove dead code

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces Multi-LoRA training, enabling the concurrent training of multiple LoRA adapters against a shared base model with slot-based hot swapping. Key components include a central MultiLoRAController for lifecycle management, a MultiLoRADataSource for interleaved sampling, and backend updates to Megatron and SGLang for weight synchronization. Review feedback identifies several high-severity issues: a bug in the adapter registration sequence that prevents successful checkpoint resumption, a missing import causing a NameError during weight updates, and logic errors in step tracking for inactive adapters. Additionally, improvements were suggested for the efficiency of the singleton meta-class, the robustness of checkpoint path parsing, and the implementation of port-based liveness checks for Ray cluster initialization.

Comment on lines +367 to +378
    if ckpt is None:
        logger.info(f"{log_prefix} no checkpoint under {ckpt_root}, starting from random init")
    else:
        state_dict = torch.load(ckpt, map_location="cpu", weights_only=True)
        loaded = load_adapter(model, config.slot, state_dict)
        assert loaded > 0, (
            f"{log_prefix} loaded 0 tensors from {ckpt} "
            f"(state_dict has {len(state_dict)} entries) — name mismatch?"
        )
        logger.info(f"{log_prefix} loaded from {ckpt} ({loaded} tensors)")

    init_adapter_slot(model, config.slot, rank=config.rank, alpha=config.alpha)

high

load_adapter is called before init_adapter_slot. In the megatron-bridge implementation, init_adapter_slot is responsible for allocating the parameters for the specific LoRA slot. If it is called after load_adapter, the loading process will fail to find the target parameters, resulting in zero tensors being loaded and triggering the assertion failure on line 372. This effectively breaks checkpoint resumption for multi-LoRA.

Suggested change
    - if ckpt is None:
    -     logger.info(f"{log_prefix} no checkpoint under {ckpt_root}, starting from random init")
    - else:
    -     state_dict = torch.load(ckpt, map_location="cpu", weights_only=True)
    -     loaded = load_adapter(model, config.slot, state_dict)
    -     assert loaded > 0, (
    -         f"{log_prefix} loaded 0 tensors from {ckpt} "
    -         f"(state_dict has {len(state_dict)} entries) — name mismatch?"
    -     )
    -     logger.info(f"{log_prefix} loaded from {ckpt} ({loaded} tensors)")
    - init_adapter_slot(model, config.slot, rank=config.rank, alpha=config.alpha)
    + init_adapter_slot(model, config.slot, rank=config.rank, alpha=config.alpha)
    + if ckpt is None:
    +     logger.info(f"{log_prefix} no checkpoint under {ckpt_root}, starting from random init")
    + else:
    +     state_dict = torch.load(ckpt, map_location="cpu", weights_only=True)
    +     loaded = load_adapter(model, config.slot, state_dict)
    +     assert loaded > 0, (
    +         f"{log_prefix} loaded 0 tensors from {ckpt} "
    +         f"(state_dict has {len(state_dict)} entries) — name mismatch?"
    +     )
    +     logger.info(f"{log_prefix} loaded from {ckpt} ({loaded} tensors)")
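
The ordering issue described in this comment can be illustrated with a toy model. This is a minimal sketch, not the megatron-bridge implementation: `init_slot` and `load_slot` are hypothetical stand-ins for `init_adapter_slot` and `load_adapter`, and the model is a plain dict keyed by slot. It assumes, as the comment claims, that loading only fills parameters that have already been allocated.

```python
def init_slot(model: dict, slot: int, rank: int) -> None:
    """Hypothetical stand-in: allocate zeroed LoRA parameters for one slot."""
    model[slot] = {"lora_A": [0.0] * rank, "lora_B": [0.0] * rank}


def load_slot(model: dict, slot: int, state_dict: dict) -> int:
    """Hypothetical stand-in: copy matching tensors, return how many were loaded."""
    params = model.get(slot, {})
    loaded = 0
    for name, value in state_dict.items():
        if name in params:  # only already-allocated parameters can be filled
            params[name] = list(value)
            loaded += 1
    return loaded


checkpoint = {"lora_A": [1.0, 2.0], "lora_B": [3.0, 4.0]}

# Load before init: the slot does not exist yet, so nothing matches.
model_wrong = {}
loaded_wrong = load_slot(model_wrong, 0, checkpoint)

# Init before load: both tensors find their target parameters.
model_right = {}
init_slot(model_right, 0, rank=2)
loaded_right = load_slot(model_right, 0, checkpoint)
```

Under that assumption, the first ordering loads nothing and would trip an `assert loaded > 0` check, while the second loads every tensor in the checkpoint.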

Comment on lines +256 to +258
    def update_multi_lora_weights(self, adapter_configs: dict[str, dict], active_slots: set[int] | None = None) -> None:
        """Sync multiple LoRA adapters. Pause/resume once, loop export+send per adapter."""
        from megatron.bridge.peft.multi_lora_layers import expose_adapter_slot

high

slice_lora_to_rank is used on line 302 but is not imported within this function or at the top level of the file. This will cause a NameError when update_multi_lora_weights is invoked.

Suggested change
    - def update_multi_lora_weights(self, adapter_configs: dict[str, dict], active_slots: set[int] | None = None) -> None:
    -     """Sync multiple LoRA adapters. Pause/resume once, loop export+send per adapter."""
    -     from megatron.bridge.peft.multi_lora_layers import expose_adapter_slot
    + def update_multi_lora_weights(self, adapter_configs: dict[str, dict], active_slots: set[int] | None = None) -> None:
    +     """Sync multiple LoRA adapters. Pause/resume once, loop export+send per adapter."""
    +     from megatron.bridge.peft.multi_lora_layers import expose_adapter_slot
    +     from miles.backends.megatron_utils.multi_lora_utils import slice_lora_to_rank

Comment on lines +232 to +233
    for name in self.train_steps.keys():
        self.train_steps[name] += 1

high

Incrementing train_steps for all registered adapters in every training step is incorrect. In multi-LoRA training, an adapter might be skipped in a particular iteration (e.g., due to round-robin distribution or an empty data source). Since MultiLoRADataSource uses these steps to calculate data offsets for resumption and skipping, incrementing steps for adapters that did not contribute samples will cause them to skip data they never actually processed.

        # Increment the step count upon training completion.
        # TODO: Only increment for adapters that were actually included in the training batch
        # to avoid incorrect data skipping in MultiLoRADataSource.
        for name in self.train_steps.keys():
            self.train_steps[name] += 1
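
One way the TODO above might be addressed is sketched below. It assumes the trainer can see which adapter names contributed samples to the current batch; `advance_train_steps` and `batch_adapter_names` are hypothetical names, not identifiers from this PR.

```python
from typing import Dict, Iterable


def advance_train_steps(
    train_steps: Dict[str, int], batch_adapter_names: Iterable[str]
) -> Dict[str, int]:
    """Increment the step count only for adapters present in this batch."""
    present = set(batch_adapter_names)  # duplicates collapse: one step per batch
    for name in train_steps:
        if name in present:
            train_steps[name] += 1
    return train_steps


# Hypothetical usage: only "gsm8k" contributed samples to this batch.
steps = advance_train_steps({"gsm8k": 3, "dapo_math": 5}, ["gsm8k", "gsm8k"])
```

Adapters that were skipped keep their previous count, so any data offset derived from it would still match the samples they actually consumed.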

            ray.init(address=args.ray_address)
            break
        except Exception:
            print("Waiting for Ray cluster to start...")

medium

When waiting for a server process like the Ray cluster to start, simply catching broad exceptions or checking process liveness is insufficient. According to repository guidelines, the check must verify that the server is actively listening for connections on its designated port (e.g., via a socket connection or HTTP request) to ensure it is truly ready.

Suggested change
    - print("Waiting for Ray cluster to start...")
    + except Exception as e:
    +     # TODO: Implement port-based liveness check per repository rules
    +     print(f"Waiting for Ray cluster to start... ({e})")
References
  1. When waiting for a server process to start, verify it is listening on its port, not just process liveness.
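
A port-based readiness check of the kind this guideline describes could look like the sketch below. `wait_for_port` is a hypothetical helper, not part of Miles or Ray, and it assumes the caller knows the host and port of the Ray head node. It uses only the standard library.

```python
import socket
import time


def wait_for_port(
    host: str, port: int, timeout_s: float = 60.0, interval_s: float = 1.0
) -> bool:
    """Return True once a TCP connection to (host, port) succeeds, else False on timeout."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            # A successful connect means something is listening on the port.
            with socket.create_connection((host, port), timeout=interval_s):
                return True
        except OSError:
            # Connection refused or timed out: the server is not ready yet.
            time.sleep(interval_s)
    return False
```

A successful connect only shows that something is accepting connections on the port, so the script would still call `ray.init` afterwards and handle its errors.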

tp_rank = mpu.get_tensor_model_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()

get_step = lambda d: int(d.name.split("_")[1])

medium

The logic for extracting the step number from directory names is fragile. It assumes that every directory starting with step_ has an integer as its second component. If a directory like step_latest or a system-generated folder exists, this will raise a ValueError or IndexError.

Suggested change
    - get_step = lambda d: int(d.name.split("_")[1])
    + get_step = lambda d: int(d.name.split("_")[1]) if len(d.name.split("_")) > 1 and d.name.split("_")[1].isdigit() else -1
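
A regex-based alternative to the one-line lambda is sketched below. It assumes checkpoint directories are named exactly `step_<integer>`, so a name with a further suffix is treated as a non-checkpoint, unlike the `split("_")[1]` logic. `latest_step_dir` is a hypothetical helper added for illustration.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Assumed naming scheme: "step_" followed only by decimal digits.
_STEP_DIR = re.compile(r"step_(\d+)")


def get_step(path: Path) -> int:
    """Return the step number encoded in the directory name, or -1 if it has none."""
    match = _STEP_DIR.fullmatch(path.name)
    return int(match.group(1)) if match else -1


def latest_step_dir(dirs: Iterable[Path]) -> Optional[Path]:
    """Pick the directory with the highest step, ignoring non-matching names."""
    candidates = [d for d in dirs if get_step(d) >= 0]
    return max(candidates, key=get_step, default=None)
```

Filtering before taking the maximum means a stray `step_latest` directory is skipped rather than compared with a sentinel value.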

Comment thread miles/utils/misc.py
Comment on lines 71 to +81
      def __call__(cls, *args, **kwargs):
    -     if cls not in cls._instances:
    -         instance = super().__call__(*args, **kwargs)
    -         cls._instances[cls] = instance
    -     return cls._instances[cls]
    +     # Note: this creates uninitialized instance copy of the
    +     # class even if it has been cached, for getting the type,
    +     # but the real instance will always be returned
    +     obj = cls.__new__(cls, *args, **kwargs)
    +     real_cls = type(obj)
    +     if real_cls not in cls._instances:
    +         obj.__init__(*args, **kwargs)
    +         cls._instances[real_cls] = obj
    +
    +     return cls._instances[real_cls]

medium

The implementation of SingletonMeta.__call__ is inefficient. It invokes cls.__new__ on every call, which in the case of GenerateState triggers dynamic class loading and uninitialized object creation even when the instance is already cached. The cache should be checked before calling __new__.

Suggested change
    - def __call__(cls, *args, **kwargs):
    -     if cls not in cls._instances:
    -         instance = super().__call__(*args, **kwargs)
    -         cls._instances[cls] = instance
    -     return cls._instances[cls]
    -     # Note: this creates uninitialized instance copy of the
    -     # class even if it has been cached, for getting the type,
    -     # but the real instance will always be returned
    -     obj = cls.__new__(cls, *args, **kwargs)
    -     real_cls = type(obj)
    -     if real_cls not in cls._instances:
    -         obj.__init__(*args, **kwargs)
    -         cls._instances[real_cls] = obj
    -     return cls._instances[real_cls]
    + def __call__(cls, *args, **kwargs):
    +     if cls in cls._instances:
    +         return cls._instances[cls]
    +     obj = cls.__new__(cls, *args, **kwargs)
    +     real_cls = type(obj)
    +     if real_cls not in cls._instances:
    +         obj.__init__(*args, **kwargs)
    +         cls._instances[real_cls] = obj
    +     if real_cls != cls:
    +         cls._instances[cls] = obj
    +     return cls._instances[real_cls]
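
The cache-first idea can be tried in isolation with the sketch below. `GenerateStateBase` and `CustomState` are hypothetical stand-ins for a class whose `__new__` returns an instance of a dynamically chosen subclass; they are not the Miles classes. The sketch differs from the suggestion above in one way: it stores the already-initialized cached instance under `cls` rather than the freshly created `obj`. That way a later call cannot return an object whose `__init__` never ran when the subclass instance was created first.

```python
class SingletonMeta(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Fast path: skip __new__ entirely once this class has been resolved.
        if cls in cls._instances:
            return cls._instances[cls]
        obj = cls.__new__(cls, *args, **kwargs)
        real_cls = type(obj)  # may be a subclass chosen inside __new__
        if real_cls not in cls._instances:
            obj.__init__(*args, **kwargs)
            cls._instances[real_cls] = obj
        # Alias the requested class to the initialized instance, never to a
        # throwaway uninitialized object.
        cls._instances[cls] = cls._instances[real_cls]
        return cls._instances[cls]


class GenerateStateBase(metaclass=SingletonMeta):
    new_calls = 0  # counts how often __new__ runs, for demonstration only

    def __new__(cls, *args, **kwargs):
        GenerateStateBase.new_calls += 1
        # Hypothetical dispatch: the base class resolves to a custom subclass.
        target = CustomState if cls is GenerateStateBase else cls
        return object.__new__(target)

    def __init__(self, value):
        self.value = value


class CustomState(GenerateStateBase):
    pass


first = CustomState(1)         # creates and initializes the real instance
second = GenerateStateBase(2)  # runs __new__ once more, then reuses `first`
third = GenerateStateBase(3)   # served from the cache without calling __new__
```

After the second call, further calls through either class return the cached instance without creating another object.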

@yushengsu-thu yushengsu-thu self-assigned this May 17, 2026
[refactor] decouple state from config

[test] fix tests

[refactor] use updated active adapters

[refactor] rename ACTIVE to RUNNING for clarity

[fix] tests

[chore] clean up comments

[fix] pre-commit + ruff

[misc] remove
@mathewjhan mathewjhan marked this pull request as draft May 22, 2026 20:59