Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.github/CODEOWNERS @Ying1123 @guapisolo
/miles/ @guapisolo @Rockdu
5 changes: 3 additions & 2 deletions miles/backends/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,11 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty
checkpoint_payload = checkpoint.load(self)

# sglang-d now supports /update_weights_from_tensor (PR #20464).
update_weight_target_module = self.train_pipeline_config.update_weight_target_module
self.weight_updater = (
DiffusionUpdateWeightFromTensorLoRA(self.args, self.model)
DiffusionUpdateWeightFromTensorLoRA(self.args, self.model, update_weight_target_module)
if self.args.use_lora
else DiffusionUpdateWeightFromTensor(self.args, self.model)
else DiffusionUpdateWeightFromTensor(self.args, self.model, update_weight_target_module)
)

checkpoint.finalize_load(self, checkpoint_payload)
Expand Down
16 changes: 8 additions & 8 deletions miles/backends/fsdp_utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def load(actor: Any) -> dict[str, Any] | None:
Loads model weights and optionally optimizer state from separate directories.
This allows loading weights without optimizer or deleting optimizer before loading.
"""
load_root = getattr(actor.args, "load", None)
load_root = actor.args.load
if load_root is None:
return None

Expand All @@ -128,7 +128,7 @@ def load(actor: Any) -> dict[str, Any] | None:
logger.info(f"[FSDP] Checkpoint directory {root_path} not found; skipping load.")
return None

target_step = getattr(actor.args, "ckpt_step", None)
target_step = actor.args.ckpt_step
if target_step is None:
tracker_file = root_path / "latest_checkpointed_iteration.txt"
if not tracker_file.exists():
Expand All @@ -147,7 +147,7 @@ def load(actor: Any) -> dict[str, Any] | None:
return None

# Load model weights (always)
lora_only = bool(getattr(actor.args, "use_lora", False))
lora_only = actor.args.use_lora
model_state = ModelState(actor.model, lora_only=lora_only)
state_dict = {"model_state": model_state}

Expand All @@ -159,7 +159,7 @@ def load(actor: Any) -> dict[str, Any] | None:
return None

# Load optimizer state (optional)
load_optimizer = not getattr(actor.args, "no_load_optim", False) and hasattr(actor, "optimizer")
load_optimizer = not actor.args.no_load_optim and hasattr(actor, "optimizer")
if load_optimizer and optimizer_dir.exists():
allowed_missing = getattr(getattr(actor, "train_pipeline_config", None), "optimizer_state_allowed_missing", [])
optimizer_state = OptimizerState(actor.model, actor.optimizer, allowed_missing=allowed_missing)
Expand Down Expand Up @@ -204,7 +204,7 @@ def finalize_load(actor: Any, checkpoint_payload: dict[str, Any] | None) -> None
dist.barrier()
return

if checkpoint_payload.get("rng") is not None and not getattr(actor.args, "no_load_rng", False):
if checkpoint_payload.get("rng") is not None and not actor.args.no_load_rng:
rng_state = checkpoint_payload["rng"]
if "torch" in rng_state:
torch.set_rng_state(rng_state["torch"])
Expand All @@ -220,7 +220,7 @@ def finalize_load(actor: Any, checkpoint_payload: dict[str, Any] | None) -> None
if next_rollout is not None:
actor.args.start_rollout_id = next_rollout
elif iteration is not None:
if getattr(actor.args, "start_rollout_id", None) is None:
if actor.args.start_rollout_id is None:
actor.args.start_rollout_id = iteration

torch.cuda.synchronize()
Expand Down Expand Up @@ -250,13 +250,13 @@ def save(actor: Any, iteration: int) -> None:
dist.barrier()

# Save model weights
lora_only = bool(getattr(actor.args, "use_lora", False))
lora_only = actor.args.use_lora
model_state = ModelState(actor.model, lora_only=lora_only)
state_dict = {"model_state": model_state}
dcp.save(state_dict, checkpoint_id=str(model_dir))

# Save optimizer state (skip if --no-save-optim is set)
save_optimizer_state = not getattr(actor.args, "no_save_optim", False)
save_optimizer_state = not actor.args.no_save_optim
if save_optimizer_state and hasattr(actor, "optimizer") and actor.optimizer is not None:
allowed_missing = getattr(getattr(actor, "train_pipeline_config", None), "optimizer_state_allowed_missing", [])
optimizer_state = OptimizerState(actor.model, actor.optimizer, allowed_missing=allowed_missing)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class TrainPipelineConfig(abc.ABC):

lora_target_modules: list[str] = ["to_q", "to_k", "to_v", "to_out.0"]
optimizer_state_allowed_missing: list[str] = []
update_weight_target_module: str = "transformer"

def prepare_trajectory(
self,
Expand Down
10 changes: 4 additions & 6 deletions miles/backends/fsdp_utils/diffusion_update_weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,11 @@
class DiffusionUpdateWeight(abc.ABC):
"""Base updater used by diffusion training actors."""

def __init__(self, args: Namespace, model: torch.nn.Module) -> None:
def __init__(self, args: Namespace, model: torch.nn.Module, target_module: str) -> None:
self.args = args
self.model = model
self.weight_version = 0
# Name of the sglang-d pipeline module to target. Defaults to "transformer",
# which is the DiT component for diffusers-based pipelines.
self.target_module = getattr(args, "diffusion_target_module", "transformer")
self.target_module = target_module

@abc.abstractmethod
def connect_rollout_engines(
Expand Down Expand Up @@ -176,8 +174,8 @@ class DiffusionUpdateWeightFromTensorLoRA(DiffusionUpdateWeightFromTensor):
on the fly during sync (no in-place mutation of the FSDP model).
"""

def __init__(self, args, model):
super().__init__(args, model)
def __init__(self, args, model, target_module: str):
super().__init__(args, model, target_module)
self._lora_index: dict[str, tuple] = {}
for name, module in model.named_modules():
if hasattr(module, "lora_A") and hasattr(module, "lora_B"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _format_v6_uri(addr):
def _init_normal(self, server_args_dict):
logger.info(f"Launch HttpServerEngineAdapter at: {self.server_host}:{self.server_port}")
self._pin_to_assigned_gpu()
apply_sgld_monkey_patches = bool(getattr(self.args, "apply_sgld_monkey_patches", False))
apply_sgld_monkey_patches = self.args.apply_sgld_monkey_patches
if apply_sgld_monkey_patches:
logger.info(
"Launching sglang-d with sgl-d → diffusers monkey patches "
Expand Down
13 changes: 1 addition & 12 deletions miles/backends/training_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,18 +588,7 @@ def policy_loss_function(
max_seq_lens,
)

# Determine pg_loss reducer: use custom if specified, otherwise default
if getattr(args, "custom_pg_loss_reducer_function_path", None) is not None:
custom_pg_loss_reducer_func = load_function(args.custom_pg_loss_reducer_function_path)
# Determine which loss_masks to use for pg_loss reducer
pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"]
pg_loss_reducer = custom_pg_loss_reducer_func(
total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss
)
else:
pg_loss_reducer = sum_of_sample_mean

pg_loss = pg_loss_reducer(pg_loss)
pg_loss = sum_of_sample_mean(pg_loss)
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
ppo_kl = sum_of_sample_mean(ppo_kl)

Expand Down
6 changes: 3 additions & 3 deletions miles/ray/placement_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,14 @@ def create_rollout_manager(args, pg):
logger.info(
"Creating rollout manager (diffusion=%s, num_gpus=%s)",
use_diffusion_rollout,
0 if (use_diffusion_rollout and getattr(args, "rollout_num_gpus", 1) > 1) else (1 if use_diffusion_rollout else 0),
0 if (use_diffusion_rollout and args.rollout_num_gpus > 1) else (1 if use_diffusion_rollout else 0),
)
scheduling_strategy = None
if use_diffusion_rollout:
pg_tuple = pg
# If rollout uses multiple GPUs, do NOT bind RolloutManager to the rollout PG.
# Otherwise it consumes a GPU bundle and starves rollout workers.
if getattr(args, "rollout_num_gpus", 1) <= 1:
if args.rollout_num_gpus <= 1:
pg, reordered_bundle_indices, _ = pg_tuple
bundle_index = reordered_bundle_indices[0] if reordered_bundle_indices else 0
scheduling_strategy = PlacementGroupSchedulingStrategy(
Expand All @@ -176,7 +176,7 @@ def create_rollout_manager(args, pg):

rollout_manager = RolloutManager.options(
num_cpus=1,
num_gpus=0 if (use_diffusion_rollout and getattr(args, "rollout_num_gpus", 1) > 1) else (1 if use_diffusion_rollout else 0),
num_gpus=0 if (use_diffusion_rollout and args.rollout_num_gpus > 1) else (1 if use_diffusion_rollout else 0),
scheduling_strategy=scheduling_strategy,
).remote(args, pg_tuple if use_diffusion_rollout else pg)

Expand Down
33 changes: 7 additions & 26 deletions miles/ray/rollout.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from ast import Raise
import itertools
import logging
import multiprocessing
Expand All @@ -11,7 +10,7 @@
import ray
import torch
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS

from miles.backends.sglang_diffusion_utils.sglang_diffusion_engine import SGLangDiffusionEngine
from miles.rollout.base_types import call_rollout_fn
Expand Down Expand Up @@ -58,10 +57,7 @@ def __init__(self, args, pg):
self.data_source = data_source_cls(args)
logger.info("RolloutManager: data source loaded, loading rollout functions...")

import sys
print("[DEBUG] RolloutManager: loading generate_rollout...", flush=True)
self.generate_rollout = load_function(self.args.rollout_function_path)
print("[DEBUG] RolloutManager: loading eval_generate_rollout...", flush=True)
self.eval_generate_rollout = load_function(self.args.eval_function_path)
self.custom_reward_post_process_func = (
load_function(self.args.custom_reward_post_process_path)
Expand All @@ -73,12 +69,10 @@ def __init__(self, args, pg):
if self.args.custom_convert_samples_to_train_data_path is not None
else None
)
print(f"[DEBUG] RolloutManager: import {self.args.rollout_function_path} done", flush=True)
logger.info(f"import {self.args.rollout_function_path} as generate_rollout function.")
logger.info(f"import {self.args.eval_function_path} as eval_generate_rollout function.")

print(f"[DEBUG] RolloutManager rollout_num_gpus={getattr(self.args, 'rollout_num_gpus', None)}", flush=True)
logger.info("RolloutManager rollout_num_gpus=%s", getattr(self.args, "rollout_num_gpus", None))
logger.info("RolloutManager rollout_num_gpus=%s", self.args.rollout_num_gpus)

if self.args.debug_train_only:
self.all_rollout_engines = []
Expand All @@ -88,17 +82,12 @@ def __init__(self, args, pg):
num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node)
num_engines = args.rollout_num_gpus // num_gpu_per_engine
self.all_rollout_engines = [None] * num_engines
print(f"[DEBUG] RolloutManager: calling init_rollout_engines with {num_engines} engines...", flush=True)
self.num_new_engines = init_rollout_engines(args, pg, self.all_rollout_engines)
print(f"[DEBUG] RolloutManager: init_rollout_engines returned, started {len(self.all_rollout_engines)}", flush=True)
logger.info("RolloutManager started %s rollout engines", len(self.all_rollout_engines))
print("[DEBUG] RolloutManager: creating lock...", flush=True)
logger.info("RolloutManager: creating lock...")
self.nodes_per_engine = max(1, args.rollout_num_gpus_per_engine // args.num_gpus_per_node)
self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote()
self.rollout_id = -1
self._diffusion_offload_fn = None
self._diffusion_onload_fn = None
self._metric_checker = MetricChecker.maybe_create(args)
self._health_monitor = None
if self.args.use_fault_tolerance:
Expand Down Expand Up @@ -173,7 +162,7 @@ def eval(self, rollout_id):
data = result.data
self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=True)
metrics = _log_eval_rollout_data(rollout_id, self.args, data, result.metrics)
max_images = int(getattr(self.args, "diffusion_log_images", 0) or 0)
max_images = self.args.diffusion_log_images
if max_images > 0:
self._log_images(
{
Expand All @@ -184,7 +173,7 @@ def eval(self, rollout_id):
max_images=max_images,
step_key="eval/step",
step_value=compute_rollout_step(self.args, rollout_id),
reward_key=self.args.eval_reward_key or self.args.reward_key,
reward_key=self.args.eval_reward_key,
)
if self._metric_checker is not None:
self._metric_checker.on_eval(metrics)
Expand Down Expand Up @@ -349,23 +338,17 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
**_reward_stats_dict(norm_t, "rollout/reward/norm_"),
}
# Per-prompt (group) stats — meaningful for GRPO-style algorithms.
if getattr(self.args, "advantage_estimator", None) == "grpo" and self.args.n_samples_per_prompt > 1:
if self.args.advantage_estimator == "grpo" and self.args.n_samples_per_prompt > 1:
groups_raw = raw_t.view(-1, self.args.n_samples_per_prompt)
reward_stats["rollout/reward/group_mean_avg"] = float(groups_raw.mean(dim=-1).mean())
if groups_raw.shape[-1] > 1:
reward_stats["rollout/reward/group_std_avg"] = float(groups_raw.std(dim=-1, unbiased=False).mean())

print(
f"[reward stats] raw mean={raw_t.mean():.4f} std={raw_t.std():.4f} min={raw_t.min():.4f} max={raw_t.max():.4f} | "
f"normalized mean={norm_t.mean():.4f} std={norm_t.std():.4f} min={norm_t.min():.4f} max={norm_t.max():.4f}",
flush=True,
)

reward_stats["rollout/step"] = compute_rollout_step(self.args, self.rollout_id)
tracking_utils.log(self.args, reward_stats, step_key="rollout/step")

max_images = int(getattr(self.args, "diffusion_log_images", 0) or 0)
interval = max(1, int(getattr(self.args, "diffusion_log_image_interval", 1) or 1))
max_images = self.args.diffusion_log_images
interval = self.args.diffusion_log_image_interval
if max_images > 0 and self.rollout_id % interval == 0:
self._log_images(
{"rollout_media/sample_images": samples},
Expand Down Expand Up @@ -469,7 +452,6 @@ def init_rollout_engines(args, pg, all_rollout_engines):
assert len(all_rollout_engines) == num_engines

pg, reordered_bundle_indices, reordered_gpu_ids = pg
print(f"[DEBUG] init_rollout_engines: reordered_bundle_indices={reordered_bundle_indices}, reordered_gpu_ids={reordered_gpu_ids}", flush=True)

# use diffusion SGLang rollout engines for miles diffusion
RolloutRayActor = ray.remote(SGLangDiffusionEngine)
Expand All @@ -484,7 +466,6 @@ def init_rollout_engines(args, pg, all_rollout_engines):

# Get the base GPU ID from placement group
base_gpu_id = int(reordered_gpu_ids[i * num_gpu_per_engine])
print(f"[DEBUG] Engine {i}: base_gpu_id={base_gpu_id}, bundle_index={reordered_bundle_indices[i * num_gpu_per_engine]}", flush=True)

scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=pg,
Expand Down
2 changes: 1 addition & 1 deletion miles/rollout/rm_hub/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class AsyncOcrPool(metaclass=SingletonMeta):
def __init__(self, args) -> None:
if not ray.is_initialized():
raise RuntimeError("Ray is not initialized. OCR RM requires Ray for OcrRewardActor.")
num_workers = int(getattr(args, "ocr_num_workers", 4) or 4)
num_workers = args.ocr_num_workers
if num_workers <= 0:
raise ValueError(f"ocr_num_workers must be > 0, got {num_workers}")
self._actors = [OcrRewardActor.options(num_cpus=1).remote(use_gpu=False) for _ in range(num_workers)]
Expand Down
Loading