diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..c1250ba --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,2 @@ +.github/CODEOWNERS @Ying1123 @guapisolo +/miles/ @guapisolo @Rockdu diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index ebd929a..8222e3b 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -120,10 +120,11 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty checkpoint_payload = checkpoint.load(self) # sglang-d now supports /update_weights_from_tensor (PR #20464). + update_weight_target_module = self.train_pipeline_config.update_weight_target_module self.weight_updater = ( - DiffusionUpdateWeightFromTensorLoRA(self.args, self.model) + DiffusionUpdateWeightFromTensorLoRA(self.args, self.model, update_weight_target_module) if self.args.use_lora - else DiffusionUpdateWeightFromTensor(self.args, self.model) + else DiffusionUpdateWeightFromTensor(self.args, self.model, update_weight_target_module) ) checkpoint.finalize_load(self, checkpoint_payload) diff --git a/miles/backends/fsdp_utils/checkpoint.py b/miles/backends/fsdp_utils/checkpoint.py index 38c539f..9c4f848 100644 --- a/miles/backends/fsdp_utils/checkpoint.py +++ b/miles/backends/fsdp_utils/checkpoint.py @@ -119,7 +119,7 @@ def load(actor: Any) -> dict[str, Any] | None: Loads model weights and optionally optimizer state from separate directories. This allows loading weights without optimizer or deleting optimizer before loading. """ - load_root = getattr(actor.args, "load", None) + load_root = actor.args.load if load_root is None: return None @@ -128,7 +128,7 @@ def load(actor: Any) -> dict[str, Any] | None: logger.info(f"[FSDP] Checkpoint directory {root_path} not found; skipping load.") return None - target_step = getattr(actor.args, "ckpt_step", None) + target_step = actor.args.ckpt_step if target_step is None: tracker_file = root_path / "latest_checkpointed_iteration.txt" if not tracker_file.exists(): @@ -147,7 +147,7 @@ def load(actor: Any) -> dict[str, Any] | None: return None # Load model weights (always) - lora_only = bool(getattr(actor.args, "use_lora", False)) + lora_only = actor.args.use_lora model_state = ModelState(actor.model, lora_only=lora_only) state_dict = {"model_state": model_state} @@ -159,7 +159,7 @@ def load(actor: Any) -> dict[str, Any] | None: return None # Load optimizer state (optional) - load_optimizer = not getattr(actor.args, "no_load_optim", False) and hasattr(actor, "optimizer") + load_optimizer = not actor.args.no_load_optim and hasattr(actor, "optimizer") if load_optimizer and optimizer_dir.exists(): allowed_missing = getattr(getattr(actor, "train_pipeline_config", None), "optimizer_state_allowed_missing", []) optimizer_state = OptimizerState(actor.model, actor.optimizer, allowed_missing=allowed_missing) @@ -204,7 +204,7 @@ def finalize_load(actor: Any, checkpoint_payload: dict[str, Any] | None) -> None dist.barrier() return - if checkpoint_payload.get("rng") is not None and not getattr(actor.args, "no_load_rng", False): + if checkpoint_payload.get("rng") is not None and not actor.args.no_load_rng: rng_state = checkpoint_payload["rng"] if "torch" in rng_state: torch.set_rng_state(rng_state["torch"]) @@ -220,7 +220,7 @@ def finalize_load(actor: Any, checkpoint_payload: dict[str, Any] | None) -> None if next_rollout is not None: actor.args.start_rollout_id = next_rollout elif iteration is not None: - if getattr(actor.args, "start_rollout_id", None) is None: + if actor.args.start_rollout_id is None: actor.args.start_rollout_id = iteration torch.cuda.synchronize() @@ -250,13 +250,13 @@ def save(actor: Any, iteration: int) -> None: dist.barrier() # Save model weights - lora_only = bool(getattr(actor.args, "use_lora", False)) + lora_only = actor.args.use_lora model_state = ModelState(actor.model, lora_only=lora_only) state_dict = {"model_state": model_state} dcp.save(state_dict, checkpoint_id=str(model_dir)) # Save optimizer state (skip if --no-save-optim is set) - save_optimizer_state = not getattr(actor.args, "no_save_optim", False) + save_optimizer_state = not actor.args.no_save_optim if save_optimizer_state and hasattr(actor, "optimizer") and actor.optimizer is not None: allowed_missing = getattr(getattr(actor, "train_pipeline_config", None), "optimizer_state_allowed_missing", []) optimizer_state = OptimizerState(actor.model, actor.optimizer, allowed_missing=allowed_missing) diff --git a/miles/backends/fsdp_utils/configs/train_pipeline_config.py b/miles/backends/fsdp_utils/configs/train_pipeline_config.py index f380bdf..da29665 100644 --- a/miles/backends/fsdp_utils/configs/train_pipeline_config.py +++ b/miles/backends/fsdp_utils/configs/train_pipeline_config.py @@ -48,6 +48,7 @@ class TrainPipelineConfig(abc.ABC): lora_target_modules: list[str] = ["to_q", "to_k", "to_v", "to_out.0"] optimizer_state_allowed_missing: list[str] = [] + update_weight_target_module: str = "transformer" def prepare_trajectory( self, diff --git a/miles/backends/fsdp_utils/diffusion_update_weight_utils.py b/miles/backends/fsdp_utils/diffusion_update_weight_utils.py index 2dac607..327aab0 100644 --- a/miles/backends/fsdp_utils/diffusion_update_weight_utils.py +++ b/miles/backends/fsdp_utils/diffusion_update_weight_utils.py @@ -30,13 +30,11 @@ class DiffusionUpdateWeight(abc.ABC): """Base updater used by diffusion training actors.""" - def __init__(self, args: Namespace, model: torch.nn.Module) -> None: + def __init__(self, args: Namespace, model: torch.nn.Module, target_module: str) -> None: self.args = args self.model = model self.weight_version = 0 - # Name of the sglang-d pipeline module to target. Defaults to "transformer", - # which is the DiT component for diffusers-based pipelines. - self.target_module = getattr(args, "diffusion_target_module", "transformer") + self.target_module = target_module @abc.abstractmethod def connect_rollout_engines( @@ -176,8 +174,8 @@ class DiffusionUpdateWeightFromTensorLoRA(DiffusionUpdateWeightFromTensor): on the fly during sync (no in-place mutation of the FSDP model). """ - def __init__(self, args, model): - super().__init__(args, model) + def __init__(self, args, model, target_module: str): + super().__init__(args, model, target_module) self._lora_index: dict[str, tuple] = {} for name, module in model.named_modules(): if hasattr(module, "lora_A") and hasattr(module, "lora_B"): diff --git a/miles/backends/sglang_diffusion_utils/sglang_diffusion_engine.py b/miles/backends/sglang_diffusion_utils/sglang_diffusion_engine.py index 87f0477..3b99865 100644 --- a/miles/backends/sglang_diffusion_utils/sglang_diffusion_engine.py +++ b/miles/backends/sglang_diffusion_utils/sglang_diffusion_engine.py @@ -152,7 +152,7 @@ def _format_v6_uri(addr): def _init_normal(self, server_args_dict): logger.info(f"Launch HttpServerEngineAdapter at: {self.server_host}:{self.server_port}") self._pin_to_assigned_gpu() - apply_sgld_monkey_patches = bool(getattr(self.args, "apply_sgld_monkey_patches", False)) + apply_sgld_monkey_patches = self.args.apply_sgld_monkey_patches if apply_sgld_monkey_patches: logger.info( "Launching sglang-d with sgl-d → diffusers monkey patches " diff --git a/miles/backends/training_utils/loss.py b/miles/backends/training_utils/loss.py index 945e8fe..6589692 100644 --- a/miles/backends/training_utils/loss.py +++ b/miles/backends/training_utils/loss.py @@ -588,18 +588,7 @@ def policy_loss_function( max_seq_lens, ) - # Determine pg_loss reducer: use custom if specified, otherwise default - if getattr(args, "custom_pg_loss_reducer_function_path", None) is not None: - custom_pg_loss_reducer_func = load_function(args.custom_pg_loss_reducer_function_path) - # Determine which loss_masks to use for pg_loss reducer - pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"] - pg_loss_reducer = custom_pg_loss_reducer_func( - total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss - ) - else: - pg_loss_reducer = sum_of_sample_mean - - pg_loss = pg_loss_reducer(pg_loss) + pg_loss = sum_of_sample_mean(pg_loss) pg_clipfrac = sum_of_sample_mean(pg_clipfrac) ppo_kl = sum_of_sample_mean(ppo_kl) diff --git a/miles/ray/placement_group.py b/miles/ray/placement_group.py index 1416c72..8a5d5d5 100644 --- a/miles/ray/placement_group.py +++ b/miles/ray/placement_group.py @@ -158,14 +158,14 @@ def create_rollout_manager(args, pg): logger.info( "Creating rollout manager (diffusion=%s, num_gpus=%s)", use_diffusion_rollout, - 0 if (use_diffusion_rollout and getattr(args, "rollout_num_gpus", 1) > 1) else (1 if use_diffusion_rollout else 0), + 0 if (use_diffusion_rollout and args.rollout_num_gpus > 1) else (1 if use_diffusion_rollout else 0), ) scheduling_strategy = None if use_diffusion_rollout: pg_tuple = pg # If rollout uses multiple GPUs, do NOT bind RolloutManager to the rollout PG. # Otherwise it consumes a GPU bundle and starves rollout workers. - if getattr(args, "rollout_num_gpus", 1) <= 1: + if args.rollout_num_gpus <= 1: pg, reordered_bundle_indices, _ = pg_tuple bundle_index = reordered_bundle_indices[0] if reordered_bundle_indices else 0 scheduling_strategy = PlacementGroupSchedulingStrategy( @@ -176,7 +176,7 @@ def create_rollout_manager(args, pg): rollout_manager = RolloutManager.options( num_cpus=1, - num_gpus=0 if (use_diffusion_rollout and getattr(args, "rollout_num_gpus", 1) > 1) else (1 if use_diffusion_rollout else 0), + num_gpus=0 if (use_diffusion_rollout and args.rollout_num_gpus > 1) else (1 if use_diffusion_rollout else 0), scheduling_strategy=scheduling_strategy, ).remote(args, pg_tuple if use_diffusion_rollout else pg) diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 603955f..3e21e4d 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -1,4 +1,3 @@ -from ast import Raise import itertools import logging import multiprocessing @@ -11,7 +10,7 @@ import ray import torch from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH, GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS +from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS from miles.backends.sglang_diffusion_utils.sglang_diffusion_engine import SGLangDiffusionEngine from miles.rollout.base_types import call_rollout_fn @@ -58,10 +57,7 @@ def __init__(self, args, pg): self.data_source = data_source_cls(args) logger.info("RolloutManager: data source loaded, loading rollout functions...") - import sys - print("[DEBUG] RolloutManager: loading generate_rollout...", flush=True) self.generate_rollout = load_function(self.args.rollout_function_path) - print("[DEBUG] RolloutManager: loading eval_generate_rollout...", flush=True) self.eval_generate_rollout = load_function(self.args.eval_function_path) self.custom_reward_post_process_func = ( load_function(self.args.custom_reward_post_process_path) @@ -73,12 +69,10 @@ def __init__(self, args, pg): if self.args.custom_convert_samples_to_train_data_path is not None else None ) - print(f"[DEBUG] RolloutManager: import {self.args.rollout_function_path} done", flush=True) logger.info(f"import {self.args.rollout_function_path} as generate_rollout function.") logger.info(f"import {self.args.eval_function_path} as eval_generate_rollout function.") - print(f"[DEBUG] RolloutManager rollout_num_gpus={getattr(self.args, 'rollout_num_gpus', None)}", flush=True) - logger.info("RolloutManager rollout_num_gpus=%s", getattr(self.args, "rollout_num_gpus", None)) + logger.info("RolloutManager rollout_num_gpus=%s", self.args.rollout_num_gpus) if self.args.debug_train_only: self.all_rollout_engines = [] @@ -88,17 +82,12 @@ def __init__(self, args, pg): num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node) num_engines = args.rollout_num_gpus // num_gpu_per_engine self.all_rollout_engines = [None] * num_engines - print(f"[DEBUG] RolloutManager: calling init_rollout_engines with {num_engines} engines...", flush=True) self.num_new_engines = init_rollout_engines(args, pg, self.all_rollout_engines) - print(f"[DEBUG] RolloutManager: init_rollout_engines returned, started {len(self.all_rollout_engines)}", flush=True) logger.info("RolloutManager started %s rollout engines", len(self.all_rollout_engines)) - print("[DEBUG] RolloutManager: creating lock...", flush=True) logger.info("RolloutManager: creating lock...") self.nodes_per_engine = max(1, args.rollout_num_gpus_per_engine // args.num_gpus_per_node) self.rollout_engine_lock = Lock.options(num_cpus=1, num_gpus=0).remote() self.rollout_id = -1 - self._diffusion_offload_fn = None - self._diffusion_onload_fn = None self._metric_checker = MetricChecker.maybe_create(args) self._health_monitor = None if self.args.use_fault_tolerance: @@ -173,7 +162,7 @@ def eval(self, rollout_id): data = result.data self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=True) metrics = _log_eval_rollout_data(rollout_id, self.args, data, result.metrics) - max_images = int(getattr(self.args, "diffusion_log_images", 0) or 0) + max_images = self.args.diffusion_log_images if max_images > 0: self._log_images( { @@ -184,7 +173,7 @@ def eval(self, rollout_id): max_images=max_images, step_key="eval/step", step_value=compute_rollout_step(self.args, rollout_id), - reward_key=self.args.eval_reward_key or self.args.reward_key, + reward_key=self.args.eval_reward_key, ) if self._metric_checker is not None: self._metric_checker.on_eval(metrics) @@ -349,23 +338,17 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl **_reward_stats_dict(norm_t, "rollout/reward/norm_"), } # Per-prompt (group) stats — meaningful for GRPO-style algorithms. - if getattr(self.args, "advantage_estimator", None) == "grpo" and self.args.n_samples_per_prompt > 1: + if self.args.advantage_estimator == "grpo" and self.args.n_samples_per_prompt > 1: groups_raw = raw_t.view(-1, self.args.n_samples_per_prompt) reward_stats["rollout/reward/group_mean_avg"] = float(groups_raw.mean(dim=-1).mean()) if groups_raw.shape[-1] > 1: reward_stats["rollout/reward/group_std_avg"] = float(groups_raw.std(dim=-1, unbiased=False).mean()) - print( - f"[reward stats] raw mean={raw_t.mean():.4f} std={raw_t.std():.4f} min={raw_t.min():.4f} max={raw_t.max():.4f} | " - f"normalized mean={norm_t.mean():.4f} std={norm_t.std():.4f} min={norm_t.min():.4f} max={norm_t.max():.4f}", - flush=True, - ) - reward_stats["rollout/step"] = compute_rollout_step(self.args, self.rollout_id) tracking_utils.log(self.args, reward_stats, step_key="rollout/step") - max_images = int(getattr(self.args, "diffusion_log_images", 0) or 0) - interval = max(1, int(getattr(self.args, "diffusion_log_image_interval", 1) or 1)) + max_images = self.args.diffusion_log_images + interval = self.args.diffusion_log_image_interval if max_images > 0 and self.rollout_id % interval == 0: self._log_images( {"rollout_media/sample_images": samples}, @@ -469,7 +452,6 @@ def init_rollout_engines(args, pg, all_rollout_engines): assert len(all_rollout_engines) == num_engines pg, reordered_bundle_indices, reordered_gpu_ids = pg - print(f"[DEBUG] init_rollout_engines: reordered_bundle_indices={reordered_bundle_indices}, reordered_gpu_ids={reordered_gpu_ids}", flush=True) # use diffusion SGLang rollout engines for miles diffusion RolloutRayActor = ray.remote(SGLangDiffusionEngine) @@ -484,7 +466,6 @@ def init_rollout_engines(args, pg, all_rollout_engines): # Get the base GPU ID from placement group base_gpu_id = int(reordered_gpu_ids[i * num_gpu_per_engine]) - print(f"[DEBUG] Engine {i}: base_gpu_id={base_gpu_id}, bundle_index={reordered_bundle_indices[i * num_gpu_per_engine]}", flush=True) scheduling_strategy = PlacementGroupSchedulingStrategy( placement_group=pg, diff --git a/miles/rollout/rm_hub/ocr.py b/miles/rollout/rm_hub/ocr.py index fd1adb9..ebd8ea8 100644 --- a/miles/rollout/rm_hub/ocr.py +++ b/miles/rollout/rm_hub/ocr.py @@ -90,7 +90,7 @@ class AsyncOcrPool(metaclass=SingletonMeta): def __init__(self, args) -> None: if not ray.is_initialized(): raise RuntimeError("Ray is not initialized. OCR RM requires Ray for OcrRewardActor.") - num_workers = int(getattr(args, "ocr_num_workers", 4) or 4) + num_workers = args.ocr_num_workers if num_workers <= 0: raise ValueError(f"ocr_num_workers must be > 0, got {num_workers}") self._actors = [OcrRewardActor.options(num_cpus=1).remote(use_gpu=False) for _ in range(num_workers)] diff --git a/miles/rollout/sglang_diffusion_rollout.py b/miles/rollout/sglang_diffusion_rollout.py index 86d2105..4a79412 100644 --- a/miles/rollout/sglang_diffusion_rollout.py +++ b/miles/rollout/sglang_diffusion_rollout.py @@ -37,18 +37,18 @@ def build_rollout_sampling_params( ) -> dict[str, Any]: """Build static fields in JSON body for ``POST /rollout/generate`` (``RolloutImageRequest``). """ - neg = getattr(args, "diffusion_negative_prompt", None) - eval_steps = getattr(args, "diffusion_eval_num_steps", None) - num_steps = int(eval_steps) if evaluation and eval_steps is not None else args.diffusion_num_steps + neg = args.diffusion_negative_prompt + eval_steps = args.diffusion_eval_num_steps + num_steps = eval_steps if evaluation and eval_steps is not None else args.diffusion_num_steps sampling_params: dict[str, Any] = { - "generator_device": getattr(args, "diffusion_generator_device", "cuda"), + "generator_device": args.diffusion_generator_device, "negative_prompt": neg, - "width": getattr(args, "diffusion_width", None), - "height": getattr(args, "diffusion_height", None), + "width": args.diffusion_width, + "height": args.diffusion_height, "num_inference_steps": num_steps, - "guidance_scale": getattr(args, "diffusion_guidance_scale", None), - "true_cfg_scale": getattr(args, "diffusion_true_cfg_scale", None), + "guidance_scale": args.diffusion_guidance_scale, + "true_cfg_scale": args.diffusion_true_cfg_scale, } if evaluation: @@ -57,10 +57,10 @@ def build_rollout_sampling_params( sampling_params.update( { "rollout": True, - "rollout_sde_type": getattr(args, "diffusion_sde_type", "sde"), - "rollout_noise_level": float(getattr(args, "diffusion_noise_level", 0.7)), - "rollout_log_prob_no_const": bool(getattr(args, "diffusion_log_prob_no_const", False)), - "rollout_debug_mode": bool(getattr(args, "diffusion_debug_mode", False)), + "rollout_sde_type": args.diffusion_sde_type, + "rollout_noise_level": args.diffusion_noise_level, + "rollout_log_prob_no_const": args.diffusion_log_prob_no_const, + "rollout_debug_mode": args.diffusion_debug_mode, "rollout_return_denoising_env": True, "rollout_return_dit_trajectory": True, } @@ -97,10 +97,10 @@ def __init__(self, args: Namespace) -> None: self.sampling_params = build_rollout_sampling_params(args) self.step_strategy_fn = ( load_function(args.diffusion_step_strategy_path) - if getattr(args, "diffusion_step_strategy_path", None) + if args.diffusion_step_strategy_path else None ) - self.dp_counts = [0] * (args.sglang_dp_size or 1) + self.dp_counts = [0] * args.sglang_dp_size self.dp_rank = 0 self.node_id = ray.get_runtime_context().get_node_id() self.response_parser_actor = RolloutImageResponseParserActor.options( @@ -232,9 +232,9 @@ async def generate_and_rm_group( # N-spaced base so sgl-d's seed→[seed+0..seed+N-1] expansion stays disjoint # per (rollout, prompt-group); group_index is monotonic across the run. - n_per_prompt = int(args.n_samples_per_prompt) + n_per_prompt = args.n_samples_per_prompt group_index = int(getattr(group[0], "group_index", 0) or 0) - seed_base = (int(args.rollout_seed) + group_index * n_per_prompt) % (2**31) + seed_base = (args.rollout_seed + group_index * n_per_prompt) % (2**31) tasks = [] for idx in range(0, len(group), args.diffusion_microgroup_size): @@ -362,7 +362,7 @@ async def eval_rollout(args: Namespace, rollout_id: int) -> tuple[dict[str, dict assert not args.group_rm, "Group RM is not supported for eval rollout" coros = [] - for dataset_config in getattr(args, "eval_datasets", []) or []: + for dataset_config in args.eval_datasets: coros.append(eval_rollout_single_dataset(args, rollout_id, dataset_config)) results_list = await asyncio.gather(*coros) results = {} @@ -440,7 +440,7 @@ async def eval_rollout_single_dataset( data.sort(key=lambda sample: sample.index) - reward_key = args.eval_reward_key or args.reward_key + reward_key = args.eval_reward_key return { dataset_config.name: { "rewards": [sample.reward if not reward_key else sample.reward[reward_key] for sample in data], diff --git a/miles/rollout/step_strategy_hub.py b/miles/rollout/step_strategy_hub.py index 257c6ad..a2733fc 100644 --- a/miles/rollout/step_strategy_hub.py +++ b/miles/rollout/step_strategy_hub.py @@ -22,8 +22,8 @@ def sde_window( to the window for loss / backprop. Keeping the full trajectory avoids the sglang-d-side trailing ``x_final`` aliasing issue when the window ends before the last denoising step.""" - window_size = int(args.diffusion_sde_window_size) - range_raw = getattr(args, "diffusion_sde_window_range", None) + window_size = args.diffusion_sde_window_size + range_raw = args.diffusion_sde_window_range if range_raw: parts = [int(x) for x in str(range_raw).split(",")] lo, hi = parts[0], parts[1] diff --git a/miles/router/router.py b/miles/router/router.py index a4f52ad..dca0442 100644 --- a/miles/router/router.py +++ b/miles/router/router.py @@ -41,13 +41,13 @@ def __init__(self, args, verbose=False): self.dead_workers: set[str] = set() self.max_weight_version = None - max_connections = getattr(args, "miles_router_max_connections", None) + max_connections = args.miles_router_max_connections if max_connections is None: max_connections = ( args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine ) - timeout = getattr(args, "miles_router_timeout", None) + timeout = args.miles_router_timeout self.client = httpx.AsyncClient( limits=httpx.Limits(max_connections=max_connections), diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index a5a08d4..bbe76da 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -720,6 +720,24 @@ def add_algo_arguments(parser): reset_arg(parser, "--save", type=str, default=None) reset_arg(parser, "--save-interval", type=int, default=None) reset_arg(parser, "--async-save", action="store_true") + parser.add_argument( + "--ckpt-step", + type=int, + default=None, + help="Checkpoint iteration to load from --load. Defaults to latest_checkpointed_iteration.txt.", + ) + parser.add_argument( + "--no-load-optim", + action="store_true", + default=False, + help="Do not load optimizer state when resuming from --load.", + ) + parser.add_argument( + "--no-load-rng", + action="store_true", + default=False, + help="Do not restore RNG state when resuming from --load.", + ) reset_arg( parser, "--no-save-optim", @@ -1213,6 +1231,11 @@ def miles_validate_args(args): if args.eval_reward_key is None: args.eval_reward_key = args.reward_key + if args.diffusion_log_image_interval < 1: + raise ValueError( + f"diffusion_log_image_interval must be >= 1, got {args.diffusion_log_image_interval}" + ) + if args.dump_details is not None: args.save_debug_rollout_data = f"{args.dump_details}/rollout_data/{{rollout_id}}.pt" args.save_debug_train_data = f"{args.dump_details}/train_data/{{rollout_id}}_{{rank}}.pt" diff --git a/miles/utils/wandb_utils.py b/miles/utils/wandb_utils.py index ea223f6..1790ae8 100644 --- a/miles/utils/wandb_utils.py +++ b/miles/utils/wandb_utils.py @@ -93,7 +93,7 @@ def _compute_config_for_logging(args): # https://docs.wandb.ai/guides/track/log/distributed-training/#track-all-processes-to-a-single-run def init_wandb_secondary(args, router_addr=None): - wandb_run_id = getattr(args, "wandb_run_id", None) + wandb_run_id = args.wandb_run_id if wandb_run_id is None: return