diff --git a/flagscale/train/train_qwen_gr00t.py b/flagscale/train/train_qwen_gr00t.py
index baa28897ee..95c6e0b960 100644
--- a/flagscale/train/train_qwen_gr00t.py
+++ b/flagscale/train/train_qwen_gr00t.py
@@ -85,10 +85,50 @@ def apply_fsdp2(policy, device_mesh):
 
 
 def make_dataset(cfg: DataConfig):
-    # TODO: (yupu) Remove hard-coded video backend
-    # After not much testing, It feels like that `torchcodec` is more robust than `pyav`
-    # `pyav` crashes sometimes
-    video_backend = "torchcodec"
+    # TODO: (yupu) Support image transforms
+    enable_image_transform = False
+    # Respect config first, keep torchcodec as safe default.
+    video_backend = (
+        getattr(getattr(cfg, "vla_data", None), "video_backend", None)
+        or getattr(cfg, "video_backend", None)
+        or "torchcodec"
+    )
+
+    # image_transforms = ImageTransforms(cfg.image_transforms) if enable_image_transform else None
+
+    # Match starVLA: resize uint8 via PIL, then normalize to [0,1]
+    def _resize_like_starvla(frames: torch.Tensor) -> torch.Tensor:
+        if not isinstance(frames, torch.Tensor):
+            return frames
+        is_single = False
+        if frames.dim() == 3:
+            frames = frames.unsqueeze(0)
+            is_single = True
+        if frames.dim() != 4:
+            return frames
+        from PIL import Image
+        import numpy as np
+
+        resized_frames = []
+        for frame in frames:
+            channel_last = frame.shape[-1] in (1, 3, 4)
+            if channel_last:
+                frame_hwc = frame
+            elif frame.shape[0] in (1, 3, 4):
+                frame_hwc = frame.permute(1, 2, 0)
+            else:
+                frame_hwc = frame
+                channel_last = True
+            frame_uint8 = (frame_hwc * 255).round().clamp(0, 255).to(torch.uint8)
+            pil = Image.fromarray(frame_uint8.cpu().numpy()).resize(
+                (224, 224), resample=Image.BILINEAR
+            )
+            out = torch.from_numpy(np.array(pil)).to(frames.device).float() / 255.0
+            if not channel_last:
+                out = out.permute(2, 0, 1)
+            resized_frames.append(out)
+        output = torch.stack(resized_frames, dim=0)
+        return output[0] if is_single else output
 
     def _resize_to_uint8_hwc(frame: torch.Tensor) -> torch.Tensor:
         """float32 CHW [0,1] from torchcodec → uint8 HWC 224x224 via PIL resize."""
@@ -201,6 +241,26 @@ def make_policy(
         if ft.type == FeatureType.ACTION
     }
     input_features = {key: ft for key, ft in features.items() if key not in output_features}
+    # kwargs["config"] = config.model
+
+    # PI0 finetuning, so always load a pretrained policy.
+    # Load a pretrained policy and override the config if needed (for example, if there are inference-time
+    # hyperparameters that we want to vary).
+    # kwargs["pretrained_name_or_path"] = cfg.pretrained_path
+    # policy = policy_cls.from_pretrained(cfg.pretrained_path, config=cfg)
+
+    # Keep visual key order deterministic across dataset/config variants.
+    visual_keys = [key for key, ft in input_features.items() if ft.type is FeatureType.VISUAL]
+    preferred_image_order = list(
+        getattr(
+            config.data,
+            "image_key_order",
+            ["observation.images.image", "observation.images.wrist_image"],
+        )
+    )
+    ordered_visual_keys = [key for key in preferred_image_order if key in visual_keys]
+    ordered_visual_keys.extend(key for key in visual_keys if key not in ordered_visual_keys)
+    config.data.vla_data.image_features = ordered_visual_keys
 
     policy = QwenGr00t(config=config)
     policy.input_features = input_features
@@ -522,6 +582,7 @@ def main(config: TrainConfig, seed: int):
     optimizer, lr_scheduler = setup_optimizer_and_scheduler(policy, config)
 
     dist.barrier()
+    policy.train()
 
     train_metrics = {
         "loss": AverageMeter("loss", ":.3f"),