diff --git a/examples/gr00t_n1_5/conf/serve.yaml b/examples/gr00t_n1_5/conf/serve.yaml index a30847fdaa..6d53c33735 100644 --- a/examples/gr00t_n1_5/conf/serve.yaml +++ b/examples/gr00t_n1_5/conf/serve.yaml @@ -7,7 +7,7 @@ experiment: exp_dir: outputs/${experiment.exp_name} task: type: serve - entrypoint: flagscale/serve/run_serve_gr00t_n1_5.py + entrypoint: flagscale/serve/run_serve_qwen_gr00t.py runner: hostfile: null deploy: diff --git a/examples/gr00t_n1_5/conf/serve/gr00t_n1_5.yaml b/examples/gr00t_n1_5/conf/serve/gr00t_n1_5.yaml index 45279c2e35..bbb0f42a56 100644 --- a/examples/gr00t_n1_5/conf/serve/gr00t_n1_5.yaml +++ b/examples/gr00t_n1_5/conf/serve/gr00t_n1_5.yaml @@ -3,5 +3,5 @@ host: 0.0.0.0 port: 5000 model_variant: Gr00tN15 - model: /workspace/models/gr00t_n1_5_train/checkpoints/last + model: ./outputs/gr00t_n1_5_train/checkpoints/last/pretrained_model device: "cuda" diff --git a/examples/gr00t_n1_5/conf/train.yaml b/examples/gr00t_n1_5/conf/train.yaml index 89ef63badb..6becf5d0b1 100644 --- a/examples/gr00t_n1_5/conf/train.yaml +++ b/examples/gr00t_n1_5/conf/train.yaml @@ -1,6 +1,6 @@ defaults: - _self_ - - train: gr00t_n1_5 + - train: libero_spatial_demo experiment: exp_name: gr00t_n1_5_train diff --git a/examples/gr00t_n1_5/conf/train/libero_spatial_demo.yaml b/examples/gr00t_n1_5/conf/train/libero_spatial_demo.yaml new file mode 100644 index 0000000000..0743243c15 --- /dev/null +++ b/examples/gr00t_n1_5/conf/train/libero_spatial_demo.yaml @@ -0,0 +1,123 @@ +system: + batch_size: 32 + train_steps: 1000 + log_freq: 10 + grad_clip_norm: 10.0 + use_amp: true + shuffle: true + num_workers: 8 + + checkpoint: + output_directory: ${experiment.exp_dir} + save_checkpoint: true + save_freq: 500 + # Path to a checkpoint directory to resume training from (e.g. 
outputs/gr00t_n1_5_train/checkpoints/001000) + # resume_from: + +model: + model_name: gr00t_n1_5 + # Path or HuggingFace model ID for the pretrained GR00T N1.5 model + checkpoint_dir: /workspace/models/nvidia/GR00T-N1.5-3B + + # Fine-tuning control + tune_llm: true + tune_visual: true + tune_projector: true + tune_diffusion_model: true + compute_dtype: bfloat16 + + # Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1') + embodiment_tag: new_embodiment + + # Number of future action steps predicted per forward pass. + # Determines action_delta_indices = [0, 1, ..., chunk_size - 1]. + chunk_size: 16 + + # Padding dimensions — shorter state/action sequences are zero-padded to these sizes + max_state_dim: 64 + max_action_dim: 32 + + normalization_mapping: + VISUAL: IDENTITY + STATE: MIN_MAX + ACTION: MIN_MAX + + # LoRA fine-tuning (lora_rank: 0 disables LoRA) + # lora_rank: 0 + # lora_alpha: 16 + # lora_dropout: 0.1 + # lora_full_model: false + + # ============================================================ + # Module Freezing Configuration + # ============================================================ + # Freezing logic: freeze_patterns are applied first, then keep_patterns override. + # Patterns are regex matched against full parameter names. + # + # Common patterns for GR00T N1.5: + # - "_groot_model\\.backbone\\..*" # Entire backbone (VLM + vision) + # - "_groot_model\\.action_head\\..*" # Action diffusion head + # + # freeze: + # freeze_patterns: + # - "_groot_model\\.backbone\\..*" + + optimizer: + name: AdamW + lr: 1.0e-4 + betas: [0.95, 0.999] + eps: 1.0e-08 + weight_decay: 1.0e-05 + scheduler: + name: cosine_decay_with_warmup + warmup_steps: 500 + decay_steps: 10000 + peak_lr: 1.0e-4 + decay_lr: 1.0e-5 + +data: + dataset_type: lerobot + data_path: /workspace/datasets/tailong-wu/libero_spatial_no_noops_1.0.0_lerobot_v3.0 + tolerance_s: 0.0001 + preprocessor: + name: policy_preprocessor + steps: + # 1. 
Rename keys if needed (e.g., dataset-specific camera names) + - registry_name: rename_observations_processor + config: + rename_map: {} + # 2. Add batch dimension for single samples + - registry_name: to_batch_processor + config: {} + # 3. Pack video/state/action/language/embodiment; apply optional min-max normalization before padding + - registry_name: groot_pack_inputs + config: + state_horizon: 1 + action_horizon: 16 + max_state_dim: 64 + max_action_dim: 32 + language_key: task + embodiment_tag: new_embodiment + normalize_min_max: true + # 4. Eagle encode (creates eagle_content) + - registry_name: groot_eagle_encode + config: {} + # 5. Collate eagle_content -> eagle_* tensors + - registry_name: groot_eagle_collate + config: {} + # 6. Move to device + - registry_name: device_processor + config: + device: cuda + float_dtype: null + postprocessor: + name: policy_postprocessor + steps: + - registry_name: groot_action_unpack_unnormalize + config: + env_action_dim: 7 + normalize_min_max: true + - registry_name: device_processor + config: + device: cpu + float_dtype: null diff --git a/examples/pi0_5/conf/serve/pi0_5.yaml b/examples/pi0_5/conf/serve/pi0_5.yaml index 4eb6bff94f..20011ba1aa 100644 --- a/examples/pi0_5/conf/serve/pi0_5.yaml +++ b/examples/pi0_5/conf/serve/pi0_5.yaml @@ -3,7 +3,7 @@ host: 0.0.0.0 port: 5000 model_variant: "pi0.5" - model: /workspace/models/pi0_5_train/checkpoints/last/pretrained_model + model: ./outputs/pi0_5_train/checkpoints/last/pretrained_model device: "cuda" # Maps client-sent observation keys to the keys the model was trained with. 
# Format: {key_from_client: key_expected_by_model} diff --git a/examples/pi0_5/conf/train.yaml b/examples/pi0_5/conf/train.yaml index 60e88123b4..04c054b20e 100644 --- a/examples/pi0_5/conf/train.yaml +++ b/examples/pi0_5/conf/train.yaml @@ -1,6 +1,6 @@ defaults: - _self_ - - train: pi0_5 + - train: libero_spatial_demo experiment: exp_name: pi0_5_train diff --git a/examples/pi0_5/conf/train/libero_spatial_demo.yaml b/examples/pi0_5/conf/train/libero_spatial_demo.yaml new file mode 100644 index 0000000000..c9984e63d2 --- /dev/null +++ b/examples/pi0_5/conf/train/libero_spatial_demo.yaml @@ -0,0 +1,48 @@ +system: + batch_size: 16 + train_steps: 2000 + log_freq: 10 + grad_clip_norm: 1.0 + use_amp: true + shuffle: true + num_workers: 4 + + checkpoint: + output_directory: ${experiment.exp_dir} + # Whether to save checkpoint + save_checkpoint: true + # Number of steps between checkpoints + save_freq: 500 + # TODO(yupu): Support resuming from checkpoint + +model: + model_name: pi0.5 + # Path to the pretrained pi05_base model checkpoint + checkpoint_dir: /workspace/models/lerobot/pi05_libero_base + # Path to paligemma tokenizer + tokenizer_path: /workspace/models/google/paligemma-3b-pt-224 + tokenizer_max_length: 200 + gradient_checkpointing: true + freeze_vision_encoder: false + + optimizer: + name: AdamW + lr: 2.5e-5 + betas: [0.9, 0.95] + eps: 1.0e-8 + weight_decay: 0.01 + scheduler: + warmup_steps: 1000 + decay_steps: 30000 + decay_lr: 2.5e-6 + +data: + # Path to the training data + data_path: /workspace/datasets/tailong-wu/libero_spatial_no_noops_1.0.0_lerobot_v3.0 + tolerance_s: 0.0001 + use_imagenet_stats: true + # To match the input features naming from the dataset to the policy config + rename_map: + "observation.images.wrist_image": "observation.images.image2" + # By default, Pi0.5 uses quantiles for state and action normalization, if false, it uses mean and std instead + use_quantiles: false diff --git a/examples/qwen_gr00t/conf/serve/qwen_gr00t.yaml 
b/examples/qwen_gr00t/conf/serve/qwen_gr00t.yaml index c0dec5eaf7..92dc72b2a3 100644 --- a/examples/qwen_gr00t/conf/serve/qwen_gr00t.yaml +++ b/examples/qwen_gr00t/conf/serve/qwen_gr00t.yaml @@ -3,15 +3,15 @@ host: 0.0.0.0 port: 5000 model_variant: QwenGr00t - model: /share/project/fengyupu/github/FlagScale_2/outputs/260320_qwen_gr00t_train_libero_goal_old_dataset/checkpoints/last + model: ./outputs/qwen_gr00t_train/checkpoints/last/pretrained_model device: "cuda" # Maps client-sent observation keys to the keys the model was trained with. # Format: {key_from_client: key_expected_by_model} - rename_map: - "observation/image": "observation.images.image" - "observation/wrist_image": "observation.images.wrist_image" - "observation/state": "observation.state" - "prompt": "task" + # rename_map: + # "observation/image": "observation.images.image" + # "observation/wrist_image": "observation.images.wrist_image" + # "observation/state": "observation.state" + # "prompt": "task" serve_preprocessor: steps: - registry_name: image_resize_processor diff --git a/examples/qwen_gr00t/conf/train.yaml b/examples/qwen_gr00t/conf/train.yaml index 1b98e3c9b6..455778b002 100644 --- a/examples/qwen_gr00t/conf/train.yaml +++ b/examples/qwen_gr00t/conf/train.yaml @@ -1,6 +1,6 @@ defaults: - _self_ - - train: qwen_gr00t + - train: libero_spatial_demo experiment: exp_name: qwen_gr00t_train diff --git a/examples/qwen_gr00t/conf/train/libero_spatial_demo.yaml b/examples/qwen_gr00t/conf/train/libero_spatial_demo.yaml new file mode 100644 index 0000000000..f6bcfa5e62 --- /dev/null +++ b/examples/qwen_gr00t/conf/train/libero_spatial_demo.yaml @@ -0,0 +1,150 @@ +system: + batch_size: 8 + train_steps: 4000 + log_freq: 1 + grad_clip_norm: 1.0 + use_amp: true + shuffle: true + num_workers: 4 + # Weight applied to VLM language modelling loss when co-training with vlm_data_path. + # Set to 0 or omit vlm_data_path to disable co-training. 
+ vlm_loss_scale: 0.1 + + checkpoint: + output_directory: ${experiment.exp_dir} + # Whether to save checkpoint + save_checkpoint: true + # Number of steps between checkpoints + save_freq: 1000 + # Path to a checkpoint directory to resume training from (e.g. /path/to/checkpoints/005000) + # resume_from: + +model: + model_name: qwen_gr00t + vlm: + type: qwen3-vl + base_vlm: /workspace/models/Qwen/Qwen3-VL-4B-Instruct/ + attn_implementation: flash_attention_2 + action_model: + # Whether to condition the action model on proprioceptive state (observation.state) + use_state: false + type: gr00t_action_head + action_model_type: DiT-B + hidden_size: 1024 + add_pos_embed: True + max_seq_len: 1024 + action_dim: 7 + state_dim: 7 + future_action_window_size: 7 + action_horizon: 8 + repeated_diffusion_steps: 4 + noise_beta_alpha: 1.5 + noise_beta_beta: 1.0 + noise_s: 0.999 + num_timestep_buckets: 1000 + num_inference_timesteps: 4 + num_target_vision_tokens: 32 + diffusion_model_cfg: + cross_attention_dim: 2048 + dropout: 0.2 + final_dropout: True + interleave_self_attention: True + norm_type: ada_norm + num_layers: 16 + output_dim: 1024 + positional_embeddings: None + + prompt_template: "Your task is {instruction}. To identify the key objects for your task. Locate their bounding boxes in [x1,y1,x2,y2] format." + + normalization_mapping: + VISUAL: IDENTITY + STATE: MIN_MAX + ACTION: MIN_MAX + + optimizer: + name: AdamW + lr: 2.5e-5 + betas: [0.9, 0.95] + eps: 1.0e-08 + weight_decay: 1.0e-08 + param_groups: + vlm: + lr: 1.0e-05 + action_model: + lr: 1.0e-04 + scheduler: + name: cosine_with_min_lr + warmup_steps: 200 + scheduler_kwargs: + min_lr: 1.0e-06 + # Legacy fields kept for BC + decay_steps: 30000 + decay_lr: 2.5e-6 + + # ============================================================ + # Module Freezing Configuration + # ============================================================ + # Freezing logic: freeze_patterns are applied first, then keep_patterns override. 
+ # Patterns are regex matched against full parameter names. + # + # Common patterns for QwenGR00T: + # - "qwen_vl_interface\\..*" # Entire VLM + # - "qwen_vl_interface\\.model\\.visual\\..*" # Vision encoder + # - "qwen_vl_interface\\.model\\.model\\..*" # Language model + # - "qwen_vl_interface\\.model\\.model\\.layers\\.[0-9]\\." # LLM layers 0-9 + # - "action_model\\..*" # Action head + # - "action_model\\.model\\.transformer_blocks\\.[0-7]\\." # DiT blocks 0-7 + # + # freeze: + # # SCENARIO A: Freeze VLM, train only action head + # freeze_patterns: + # - "qwen_vl_interface\\..*" + # + # # SCENARIO B: Freeze VLM but keep projector trainable + # # freeze_patterns: + # # - "qwen_vl_interface\\..*" + # # keep_patterns: + # # - "qwen_vl_interface\\.model\\.visual\\.merger\\..*" + # + # # SCENARIO C: Freeze everything except action decoder + # # freeze_patterns: + # # - ".*" + # # keep_patterns: + # # - "action_model\\.action_decoder\\..*" + +data: + dataset_type: lerobot + wds: + vision_root: "" + action_key: eepose + state_key: eepose + # Path to the training data + data_path: /workspace/datasets/tailong-wu/libero_spatial_no_noops_1.0.0_lerobot_v3.0 + # Path to VLM co-training data (WDS/Energon format). Leave unset to disable co-training. 
+ # vlm_data_path: /workspace/datasets/vlm_cotrain/ + tolerance_s: 0.0001 + preprocessor: + name: policy_preprocessor + steps: + - registry_name: rename_observations_processor + config: + rename_map: {} + - registry_name: to_batch_processor + config: {} + - registry_name: device_processor + config: + device: cuda + float_dtype: null + - registry_name: normalizer_processor + config: + eps: 1e-8 + features: {} + # norm_map is injected at runtime from model.normalization_mapping + postprocessor: + name: policy_postprocessor + steps: + - registry_name: unnormalizer_processor + config: + eps: 1e-8 + features: {} + # norm_map is injected at runtime from model.normalization_mapping diff --git a/flagscale/models/vla/base_policy.py b/flagscale/models/vla/base_policy.py index 2d2eb02656..188f883628 100644 --- a/flagscale/models/vla/base_policy.py +++ b/flagscale/models/vla/base_policy.py @@ -161,9 +161,9 @@ def from_pretrained(cls, pretrained_path, device="cpu", *, config=None): strict=False, ) if missing: - logger.warning(f"Missing keys when loading checkpoint: {len(missing)} keys") + logger.warning(f"Missing keys when loading checkpoint: {len(missing)} keys: {missing}") if unexpected: - logger.warning(f"Unexpected keys in checkpoint: {len(unexpected)} keys") + logger.warning(f"Unexpected keys in checkpoint: {len(unexpected)} keys: {unexpected}") model.to(device) model.eval() diff --git a/flagscale/models/vla/gr00t_n1_5/gr00t_n1.py b/flagscale/models/vla/gr00t_n1_5/gr00t_n1.py index 1800fe4d43..6e3ffda76a 100644 --- a/flagscale/models/vla/gr00t_n1_5/gr00t_n1.py +++ b/flagscale/models/vla/gr00t_n1_5/gr00t_n1.py @@ -164,6 +164,7 @@ def __init__( self.eagle_model.language_model.model.layers.pop(-1) self.select_layer = select_layer + self._is_fsdp2: bool | None = None self.set_trainable_parameters(tune_llm, tune_visual) def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool): @@ -223,14 +224,23 @@ def forward(self, vl_input: BatchFeature) -> BatchFeature: # YL 
(TODO HACK): to resolve DDP issue when tune_visual=True # Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility + # Skip under FSDP2 — it doesn't require all params to participate in forward, + # and the DTensor params cause mixed Tensor/DTensor errors in both forward and backward. if self.training and self.tune_visual: - dummy_term = torch.tensor( - 0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True - ) - for param in self.eagle_model.vision_model.parameters(): - if param.requires_grad: - dummy_term = dummy_term + 0.0 * param.sum() - eagle_embeds = eagle_embeds + dummy_term + if self._is_fsdp2 is None: + from torch.distributed.tensor import DTensor + + self._is_fsdp2 = any( + isinstance(p, DTensor) for p in self.eagle_model.vision_model.parameters() + ) + if not self._is_fsdp2: + dummy_term = torch.tensor( + 0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True + ) + for param in self.eagle_model.vision_model.parameters(): + if param.requires_grad: + dummy_term = dummy_term + 0.0 * param.sum() + eagle_embeds = eagle_embeds + dummy_term return BatchFeature( data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask} diff --git a/flagscale/models/vla/qwen_gr00t/modeling_qwen_gr00t.py b/flagscale/models/vla/qwen_gr00t/modeling_qwen_gr00t.py index 07fdf5e4f6..178b7fa7da 100644 --- a/flagscale/models/vla/qwen_gr00t/modeling_qwen_gr00t.py +++ b/flagscale/models/vla/qwen_gr00t/modeling_qwen_gr00t.py @@ -20,7 +20,6 @@ from safetensors.torch import save_file from .configuration_qwen_gr00t import QwenGr00tConfig -from flagscale.logger import logger from flagscale.models.utils.constants import ( ACTION, OBS_STATE, @@ -182,12 +181,6 @@ def predict_action(self, batch: list[dict] | dict) -> dict: else: state = None else: # lerobot: single dict with batched tensors - logger.info(f"[predict_action] batch keys={list(batch.keys())}") - logger.info(f"[predict_action] 
image_features keys={list(self.image_features.keys())}") - for k in self.image_features: - if k in batch: - v = batch[k] - logger.info(f"[predict_action] image key={k} shape={v.shape} dtype={v.dtype}") images, instructions = self.vlm.prepare_input( batch, image_feature_keys=list(self.image_features.keys()) ) @@ -200,10 +193,6 @@ def predict_action(self, batch: list[dict] | dict) -> dict: # last_hidden_state: [B, seq_len, H] last_hidden = vlm_output["hidden_states"][-1] # [B, L, H] - logger.info( - f"[predict_action] last_hidden shape={last_hidden.shape} dtype={last_hidden.dtype}" - ) - if state is not None: state = state.to(device=last_hidden.device, dtype=last_hidden.dtype) @@ -213,11 +202,6 @@ def predict_action(self, batch: list[dict] | dict) -> dict: action_input = {"state": state} output = self.action_model.predict_action(vlm_output_for_action, action_input) - logger.info(f"[predict_action] output keys={list(output.keys())}") - for k, v in output.items(): - if isinstance(v, torch.Tensor): - logger.info(f"[predict_action] output {k} shape={v.shape} dtype={v.dtype}") - # Assume the output of the action model is dict mapping `ACTION` to the normalized actions return output diff --git a/flagscale/models/vla/vlm/qwenvl_backbone.py b/flagscale/models/vla/vlm/qwenvl_backbone.py index 709498445b..7c78dca626 100644 --- a/flagscale/models/vla/vlm/qwenvl_backbone.py +++ b/flagscale/models/vla/vlm/qwenvl_backbone.py @@ -18,7 +18,6 @@ Qwen3VLForConditionalGeneration, ) -from flagscale.logger import logger from flagscale.models.vla.registry import register_vlm from flagscale.platforms.platform_manager import get_platform @@ -91,14 +90,9 @@ def prepare_input( if isinstance(instructions, str): instructions = [instructions] - logger.info(f"[prepare_input] image_feature_keys={image_feature_keys}") batch_images: list[list[Image.Image]] | None = None for key in image_feature_keys: imgs = batch[key] - if isinstance(imgs, torch.Tensor): - logger.info( - f"[prepare_input] key={key} 
tensor shape={imgs.shape} dtype={imgs.dtype}" - ) if isinstance(imgs, torch.Tensor) and imgs.ndim == 3: imgs = [imgs] key_images = [_to_pil(img) for img in imgs] @@ -111,9 +105,6 @@ def prepare_input( for idx, sample_images in enumerate(batch_images): batch_images[idx] = [img for img in sample_images if img is not None] - logger.info( - f"[prepare_input] batch_size={len(batch_images)} images_per_sample={[len(s) for s in batch_images]} pil_size={batch_images[0][0].size if batch_images else None}" - ) return batch_images, instructions def build_qwenvl_inputs( @@ -139,10 +130,6 @@ def _build_messages( return messages def forward(self, batch: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]: - logger.info( - f"[VLM.forward] input keys={list(batch.keys())} " - + " ".join(f"{k}={v.shape}" for k, v in batch.items() if isinstance(v, torch.Tensor)) - ) with torch.autocast(get_platform().amp_device_type(), dtype=torch.bfloat16): outputs = self.model( **batch, @@ -150,9 +137,6 @@ def forward(self, batch: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.T return_dict=True, **kwargs, ) - logger.info( - f"[VLM.forward] hidden_states: {len(outputs.hidden_states)} layers, last={outputs.hidden_states[-1].shape}" - ) # TODO: (yupu) We should output the original outputs, not just the hidden states. return {"hidden_states": outputs.hidden_states} @@ -197,13 +181,6 @@ def build_qwenvl_inputs( text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt" ) - logger.info( - "[Qwen25.build_qwenvl_inputs] " - + " ".join( - f"{k}={v.shape}" for k, v in batch_input.items() if isinstance(v, torch.Tensor) - ) - ) - # Use current CUDA device instead of self.model.device, which returns # a DTensor device under FSDP2 and causes mixed Tensor/DTensor errors. 
return batch_input.to(get_platform().device()) @@ -245,13 +222,6 @@ def build_qwenvl_inputs( return_tensors="pt", ) - logger.info( - "[Qwen3.build_qwenvl_inputs] " - + " ".join( - f"{k}={v.shape}" for k, v in batch_inputs.items() if isinstance(v, torch.Tensor) - ) - ) - # Use current CUDA device instead of self.model.device, which returns # a DTensor device under FSDP2 and causes mixed Tensor/DTensor errors. return batch_inputs.to(get_platform().device()) diff --git a/flagscale/serve/run_serve_gr00t_n1_5.py b/flagscale/serve/run_serve_gr00t_n1_5.py deleted file mode 100644 index 271869a38f..0000000000 --- a/flagscale/serve/run_serve_gr00t_n1_5.py +++ /dev/null @@ -1,110 +0,0 @@ -# Mainly adopted from: -# https://github.com/starVLA/starVLA/blob/3f7feefbc5fc25890ad3a7d262b8a0aea1339aa7/deployment/model_server/server_policy.py - -import argparse -import importlib -import time - -import torch -from omegaconf import DictConfig, ListConfig, OmegaConf - -import flagscale.models.vla.gr00t_n1_5.processor_gr00t # noqa: F401 register GR00T processor steps -from flagscale.logger import logger -from flagscale.models.utils.constants import ACTION -from flagscale.serve.websocket_policy_server import WebsocketPolicyServer -from flagscale.train.utils.train_utils import load_checkpoint - - -class Policy: - def __init__(self, config: DictConfig | ListConfig): - self.config_engine = config["engine_args"] - - self.host = self.config_engine.get("host", "0.0.0.0") - self.port = self.config_engine.get("port", 5000) - self.model = None - self.preprocessor = None - self.postprocessor = None - - self.load_model() - - def load_model(self): - t_s = time.perf_counter() - model_variant = self.config_engine.model_variant - policy = getattr(importlib.import_module("flagscale.models.vla"), model_variant) - self.model, self.preprocessor, self.postprocessor = load_checkpoint( - self.config_engine.model, policy, self.config_engine.device - ) - logger.info(f"Policy model loading latency: 
{time.perf_counter() - t_s:.2f}s") - - def inference(self, batch): - logger.info("Start to inference") - # { - # "observation.images.image": np.ndarray, shape=(224, 224, 3), dtype=uint8, - # "observation.images.wrist_image": np.ndarray, shape=(224, 224, 3), dtype=uint8, - # "observation.state": np.ndarray, shape=(N,), dtype=float32, - # "task": str, - # } - # NOTE: Images must be 224x224 resolution (uint8 HWC format). - # TODO: (yupu) Add explicit numpy-to-tensor conversion here before preprocessing, - # instead of relying on ad-hoc conversions inside each processor step. - - # Debug: log incoming keys and state info - logger.info(f"incoming keys: {list(batch.keys())}") - if "observation.state" in batch: - s = batch["observation.state"] - logger.info( - f"observation.state: type={type(s).__name__}, shape={s.shape if hasattr(s, 'shape') else 'N/A'}, values={s}" - ) - - batch = self.preprocessor(batch) - - with torch.no_grad(): - action = self.model.predict_action(batch) - a_raw = action[ACTION] - logger.info( - f"action before postprocessor: shape={a_raw.shape}, first_step_7={a_raw[0, 0, :7]}" - ) - - logger.info("Applying postprocessor...") - action = self.postprocessor(action) - - # Convert to numpy for msgpack serialization - action[ACTION] = action[ACTION].detach().cpu().numpy() - - # Debug: log action shape and first-timestep values after postprocessing - a = action[ACTION] - logger.info(f"action after postprocessor: shape={a.shape}, first_step={a[0, 0]}") - - return action - - -def parse_config() -> DictConfig | ListConfig: - """Parse the configuration file""" - - parser = argparse.ArgumentParser() - parser.add_argument( - "--config-path", type=str, required=True, help="Path to the configuration YAML file" - ) - parser.add_argument("--log-dir", type=str, required=True, help="Path to the log") - args = parser.parse_args() - config = OmegaConf.load(args.config_path) - return config - - -def main(config): - policy = Policy(config) - logger.info("Done") - # start 
websocket server - server = WebsocketPolicyServer( - policy=policy, - host=policy.host, - port=policy.port, - metadata={"env": "simpler_env"}, - ) - logger.info("Server running ...") - server.serve_forever() - - -if __name__ == "__main__": - parsed_cfg = parse_config() - main(parsed_cfg["serve"][0]) diff --git a/flagscale/serve/run_serve_qwen_gr00t.py b/flagscale/serve/run_serve_qwen_gr00t.py index b0adfb80f0..886598b985 100644 --- a/flagscale/serve/run_serve_qwen_gr00t.py +++ b/flagscale/serve/run_serve_qwen_gr00t.py @@ -8,6 +8,7 @@ import torch from omegaconf import DictConfig, ListConfig, OmegaConf +import flagscale.models.vla.gr00t_n1_5.processor_gr00t # noqa: F401 — registers GR00T processor steps import flagscale.serve.processor # noqa: F401 — registers serve-specific processor steps from flagscale.logger import logger from flagscale.models.utils.constants import ACTION, OBS_IMAGES, OBS_STATE @@ -100,14 +101,6 @@ def validate_batch(batch: dict) -> list[str]: return errors -def debug_print(batch): - for k, v in batch.items(): - if hasattr(v, "shape"): - logger.info(f" {k}: shape={v.shape} dtype={v.dtype}") - else: - logger.info(f" {k}: type={type(v).__name__} value={repr(v)[:120]}") - - class Policy: """VLA policy server wrapping a TrainablePolicy with pre/post-processing. @@ -191,53 +184,28 @@ def inference(self, batch: dict) -> dict: logger.info(f"Raw batch keys: {list(batch.keys())}") if self.rename_map: batch = {self.rename_map.get(k, k): v for k, v in batch.items()} logger.info(f"After rename keys: {list(batch.keys())}") errors = validate_batch(batch) if errors: for err in errors: - # TODO: (yupu) Response with error status? 
logger.warning(f"Batch validation: {err}") - debug_print(batch) - # for k, v in batch.items(): - # if hasattr(v, "shape"): - # logger.info(f" {k}: shape={v.shape} dtype={v.dtype}") - # else: - # logger.info(f" {k}: type={type(v).__name__} value={repr(v)[:120]}") - if self.serve_preprocessor: batch = self.serve_preprocessor(batch) - logger.info("After serve_preprocessor:") - debug_print(batch) - # for k, v in batch.items(): - # if hasattr(v, "shape"): - # logger.info(f" {k}: shape={v.shape} dtype={v.dtype}") batch = self.preprocessor(batch) - logger.info("After preprocessor:") - debug_print(batch) - # for k, v in batch.items(): - # if hasattr(v, "shape"): - # logger.info(f" {k}: shape={v.shape} dtype={v.dtype}") with torch.no_grad(): action = self.model.predict_action(batch) - logger.info(f"Raw action keys: {list(action.keys())}") - debug_print(action) - # for k, v in action.items(): - # if hasattr(v, "shape"): - # logger.info(f" {k}: shape={v.shape} dtype={v.dtype} first_step={v[0,0,:7]}") - action = self.postprocessor(action) # Convert to numpy for msgpack serialization; squeeze batch dim [1,T,D] → [T,D] action[ACTION] = action[ACTION].squeeze(0).detach().cpu().numpy() # TODO: (yupu): rename_map for output key action["actions"] = action[ACTION] - logger.info(f"Final action shape: {action[ACTION].shape}, first_step={action[ACTION][0]}") return action @@ -258,7 +226,6 @@ def parse_config() -> DictConfig | ListConfig: def main(config: DictConfig | ListConfig) -> None: """Start the websocket policy server.""" policy = Policy(config) - # start websocket server server = WebsocketPolicyServer( policy=policy, host=policy.host, @@ -271,7 +238,4 @@ def main(config: DictConfig | ListConfig) -> None: if __name__ == "__main__": parsed_cfg = parse_config() - if isinstance(parsed_cfg, ListConfig): - main(parsed_cfg[0]) - else: - main(parsed_cfg["serve"][0]) + main(parsed_cfg["serve"][0]) diff --git a/flagscale/train/train_gr00t_n1_5.py b/flagscale/train/train_gr00t_n1_5.py index 
5bce84be13..35412aeea7 100644 --- a/flagscale/train/train_gr00t_n1_5.py +++ b/flagscale/train/train_gr00t_n1_5.py @@ -15,7 +15,7 @@ import torch import torch.distributed as dist from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy -from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.checkpoint.state_dict import get_model_state_dict, get_optimizer_state_dict, StateDictOptions from torch.optim import Optimizer @@ -71,17 +71,17 @@ def set_seed(seed: int): if get_platform().name() == "cuda": torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = False + torch.backends.cudnn.deterministic = False torch.backends.cuda.matmul.allow_tf32 = False -def apply_fsdp2(policy, device_mesh): +def apply_fsdp2(policy: TrainablePolicy, device_mesh: DeviceMesh) -> None: """Apply FSDP2 sharding to Gr00tN15. Uses a MixedPrecisionPolicy that matches DeepSpeed bf16 behavior: bf16.enabled=true + ZeRO-2 → param_dtype=bf16, reduce_dtype=bf16, reshard=False """ - # Cast everything to fp32 first so the root param group has uniform dtype. - policy = policy.float() + # Cast everything to bf16 so the root param group has uniform dtype. 
+ policy = policy.bfloat16() # TODO: (yupu) check `reduce_dtype=torch.float32` mp_policy = MixedPrecisionPolicy( @@ -91,7 +91,7 @@ def apply_fsdp2(policy, device_mesh): fsdp_config = {"mesh": device_mesh, "mp_policy": mp_policy} - # reshard_after_forward=False keeps params unsharded during forward+backward - reshard = False + # reshard_after_forward=True frees each unit's unsharded params after forward and re-gathers them for backward + reshard = True for unit in policy.fsdp_units(): fully_shard(unit, **fsdp_config, reshard_after_forward=reshard) @@ -419,7 +419,7 @@ def main(config: TrainConfig, seed: int): set_seed(seed) policy_config = PreTrainedConfig.from_train_config(config) - + local_rank = int(os.environ["LOCAL_RANK"]) get_platform().set_device(local_rank) dist.init_process_group(backend=get_platform().dist_backend()) @@ -540,7 +540,7 @@ def main(config: TrainConfig, seed: int): effective_batch_size = config.system.batch_size * world_size train_tracker = MetricsTracker( - effective_batch_size, + config.system.batch_size, num_frames, num_episodes, train_metrics, diff --git a/flagscale/train/train_pi.py b/flagscale/train/train_pi.py index bc73786ce7..15a966070a 100644 --- a/flagscale/train/train_pi.py +++ b/flagscale/train/train_pi.py @@ -74,7 +74,7 @@ def set_seed(seed: int): if get_platform().name() == "cuda": torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = False + torch.backends.cudnn.deterministic = False torch.backends.cuda.matmul.allow_tf32 = False @@ -644,7 +644,7 @@ def main(config: TrainConfig, seed: int): step = 0 train_tracker = MetricsTracker( - effective_batch_size, + config.system.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, diff --git a/flagscale/train/train_qwen_gr00t.py b/flagscale/train/train_qwen_gr00t.py index aaf5ab3e94..f457b57c8e 100644 --- a/flagscale/train/train_qwen_gr00t.py +++ b/flagscale/train/train_qwen_gr00t.py @@ -71,8 +71,8 @@ def apply_fsdp2(policy, device_mesh): Uses a MixedPrecisionPolicy that matches DeepSpeed bf16 behavior: bf16.enabled=true + 
ZeRO-2 → param_dtype=bf16, reduce_dtype=bf16, reshard=False """ - # Cast everything to fp32 first so the root param group has uniform dtype. - policy = policy.float() + # Cast everything to bf16 first so the root param group has uniform dtype. + policy = policy.bfloat16() # `reduce_dtype=torch.float32` would make evaluation on libero_goal drop to 94.8% (from 97.0%) mp_policy = MixedPrecisionPolicy( @@ -592,7 +592,7 @@ def main(config: TrainConfig, seed: int): effective_batch_size = config.system.batch_size * world_size train_tracker = MetricsTracker( - effective_batch_size, + config.system.batch_size, num_frames, num_episodes, train_metrics,