69 changes: 65 additions & 4 deletions flagscale/train/train_qwen_gr00t.py
@@ -85,10 +85,50 @@ def apply_fsdp2(policy, device_mesh):


def make_dataset(cfg: DataConfig):
    # TODO: (yupu) Remove hard-coded video backend
    # In limited testing, `torchcodec` appears more robust than `pyav`,
    # which crashes intermittently.
video_backend = "torchcodec"
# TODO: (yupu) Support image transforms
enable_image_transform = False
    # Respect the config first; keep `torchcodec` as the safe default.
video_backend = (
getattr(getattr(cfg, "vla_data", None), "video_backend", None)
or getattr(cfg, "video_backend", None)
or "torchcodec"
)
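    # Precedence sketch (illustrative values, not from the PR): with
    # cfg.vla_data.video_backend = "pyav" and cfg.video_backend = "torchcodec",
    # the nested key wins and "pyav" is used; if neither attribute exists,
    # the chain falls back to "torchcodec".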

# image_transforms = ImageTransforms(cfg.image_transforms) if enable_image_transform else None

    # Match starVLA: convert to uint8, resize via PIL, then normalize back to [0, 1].
    def _resize_like_starvla(frames: torch.Tensor) -> torch.Tensor:
        import numpy as np
        from PIL import Image

        if not isinstance(frames, torch.Tensor):
            return frames
        # Promote a single frame to a batch of one so both cases share one path.
        is_single = False
        if frames.dim() == 3:
            frames = frames.unsqueeze(0)
            is_single = True
        if frames.dim() != 4:
            return frames

        resized_frames = []
        for frame in frames:
            # Detect layout: prefer HWC; fall back to CHW when the channel dim leads.
            channel_last = frame.shape[-1] in (1, 3, 4)
            if channel_last:
                frame_hwc = frame
            elif frame.shape[0] in (1, 3, 4):
                frame_hwc = frame.permute(1, 2, 0)
            else:
                # Ambiguous layout; treat as HWC.
                frame_hwc = frame
                channel_last = True
            # float [0, 1] -> uint8 HWC, the dtype/layout PIL expects.
            # Note: assumes 3-channel frames; an (H, W, 1) array would need a
            # squeeze before Image.fromarray.
            frame_uint8 = (frame_hwc * 255).round().clamp(0, 255).to(torch.uint8)
            pil = Image.fromarray(frame_uint8.cpu().numpy()).resize(
                (224, 224), resample=Image.BILINEAR
            )
            out = torch.from_numpy(np.array(pil)).to(frames.device).float() / 255.0
            if not channel_last:
                # Restore the original CHW layout.
                out = out.permute(2, 0, 1)
            resized_frames.append(out)
        output = torch.stack(resized_frames, dim=0)
        return output[0] if is_single else output
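    # A quick shape-contract check for the helper above (a sketch, not part of
    # the PR; assumes `_resize_like_starvla` is reachable, e.g. hoisted to
    # module level for testing):
    #
    #     chw = torch.rand(3, 480, 640)       # single CHW frame in [0, 1]
    #     assert _resize_like_starvla(chw).shape == (3, 224, 224)      # CHW preserved
    #
    #     thwc = torch.rand(5, 480, 640, 3)   # batch of HWC frames
    #     assert _resize_like_starvla(thwc).shape == (5, 224, 224, 3)  # HWC preserved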

def _resize_to_uint8_hwc(frame: torch.Tensor) -> torch.Tensor:
"""float32 CHW [0,1] from torchcodec → uint8 HWC 224x224 via PIL resize."""
@@ -201,6 +241,26 @@ def make_policy(
if ft.type == FeatureType.ACTION
}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
# kwargs["config"] = config.model

    # PI0 finetuning always starts from a pretrained policy; override the config if
    # needed (for example, to vary inference-time hyperparameters).
# kwargs["pretrained_name_or_path"] = cfg.pretrained_path
# policy = policy_cls.from_pretrained(cfg.pretrained_path, config=cfg)

# Keep visual key order deterministic across dataset/config variants.
visual_keys = [key for key, ft in input_features.items() if ft.type is FeatureType.VISUAL]
preferred_image_order = list(
getattr(
config.data,
"image_key_order",
["observation.images.image", "observation.images.wrist_image"],
)
)
ordered_visual_keys = [key for key in preferred_image_order if key in visual_keys]
ordered_visual_keys.extend(key for key in visual_keys if key not in ordered_visual_keys)
config.data.vla_data.image_features = ordered_visual_keys
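    # Worked example of the ordering above (illustrative keys; "extra_cam" is
    # hypothetical):
    #   visual_keys     = ["observation.images.wrist_image",
    #                      "observation.images.image",
    #                      "observation.images.extra_cam"]
    #   preferred order = ["observation.images.image", "observation.images.wrist_image"]
    #   result          = ["observation.images.image", "observation.images.wrist_image",
    #                      "observation.images.extra_cam"]
    # Preferred keys come first in their stated order; any remaining visual keys
    # keep their original relative order at the tail.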

policy = QwenGr00t(config=config)
policy.input_features = input_features
@@ -522,6 +582,7 @@ def main(config: TrainConfig, seed: int):
optimizer, lr_scheduler = setup_optimizer_and_scheduler(policy, config)

dist.barrier()
policy.train()

train_metrics = {
"loss": AverageMeter("loss", ":.3f"),