From 60294bf8794f2e350422943549883de8ae962b73 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 5 May 2026 10:05:14 -0700 Subject: [PATCH 01/12] workable code --- .../qwen_vl/data/energon/task_encoder.py | 68 ++++++++++++++++++- .../bridge/recipes/qwen_vl/qwen3_vl.py | 36 ++++++++-- .../bridge/training/utils/visual_inputs.py | 19 +++++- .../recipes/qwen_vl/test_qwen3_vl_recipes.py | 4 ++ 4 files changed, 119 insertions(+), 8 deletions(-) diff --git a/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py b/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py index 50178a0749..ef33dc758b 100644 --- a/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py +++ b/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py @@ -21,7 +21,7 @@ import numpy as np import torch -from megatron.energon import Batch, DefaultTaskEncoder +from megatron.energon import Batch, DefaultTaskEncoder, SkipSample from transformers import BatchEncoding from megatron.bridge.data.energon.task_encoder_utils import ( @@ -39,6 +39,7 @@ from megatron.bridge.training.utils.visual_inputs import Qwen2_5_VLVisualInputs + def process_vision( processor, images, videos, fps=None, model_version: str = "qwen-vl", min_pixels=None, max_pixels=None ): @@ -56,7 +57,12 @@ def process_vision( image_grid_thw = None if videos is not None: - videos_inputs = processor(images=None, text="", videos=videos, return_tensors="pt") + # DEBUGGING + # videos_inputs = processor(images=None, text="", videos=videos, return_tensors="pt") + # Pre-decoded frames from WDS are already at the desired sampling rate. + # do_sample_frames=False prevents the processor from re-sampling them under + # a spurious 24 fps assumption, which would reduce most clips to T=2. + videos_inputs = processor.video_processor(videos=videos, return_tensors="pt", do_sample_frames=False) video_grid_thw = videos_inputs.get("video_grid_thw", None) else: videos_inputs = {} @@ -168,6 +174,9 @@ def __init__( max_padding_length: int = 4096, min_pixels: int = 200704, max_pixels: int = 1003520, + max_num_images: int | None = 10, + max_num_frames: int | None = 60, + max_visual_tokens: int | None = 16384, ): super().__init__() @@ -176,6 +185,9 @@ def __init__( self.seq_length = max_padding_length self.min_pixels = min_pixels self.max_pixels = max_pixels + self.max_num_images = max_num_images + self.max_num_frames = max_num_frames + self.max_visual_tokens = max_visual_tokens self.temporal_patch_size = temporal_patch_size self.merge_size = spatial_merge_size @@ -202,6 +214,34 @@ def encode_sample(self, sample: ChatMLSample): videos_for_processing = ( _videos_to_pil(sample.videos) if sample.videos is not None and len(sample.videos) > 0 else None ) + + if self.max_num_images is not None and imgs_for_processing is not None: + if len(imgs_for_processing) > self.max_num_images: + logging.warning( + "Skipping sample %s: %d images exceeds max_num_images=%d", + sample.__key__, + len(imgs_for_processing), + self.max_num_images, + ) + print(f"[DEBUG] (task_encoder.py) Skipping sample {sample.__key__} because it has {len(imgs_for_processing)} images, which exceeds max_num_images={self.max_num_images}") + raise SkipSample() + + if self.max_num_frames is not None and videos_for_processing is not None: + clipped = [] + for v in videos_for_processing: + if len(v) > self.max_num_frames: + logging.warning( + "Truncating %d frames to max_num_frames=%d for sample %s", + len(v), + self.max_num_frames, + sample.__key__, + ) + print(f"[DEBUG] (task_encoder.py) Truncating 
{len(v)} frames to max_num_frames={self.max_num_frames} for sample {sample.__key__}") + clipped.append(v[: self.max_num_frames]) + else: + clipped.append(v) + videos_for_processing = clipped + processed_vision = process_vision( self.image_processor, imgs_for_processing, @@ -214,6 +254,21 @@ def encode_sample(self, sample: ChatMLSample): flattened_imgs = processed_vision["image_inputs"] flattened_videos = processed_vision["video_inputs"] + merge_length = self.merge_size**2 + image_tokens = int(image_thw_grids.prod(dim=-1).sum().item()) // merge_length if image_thw_grids is not None else 0 + video_tokens = int(video_thw_grids.prod(dim=-1).sum().item()) // merge_length if video_thw_grids is not None else 0 + total_visual_tokens = image_tokens + video_tokens + if self.max_visual_tokens is not None: + if total_visual_tokens > self.max_visual_tokens: + logging.warning( + "Skipping sample %s: %d visual tokens exceeds max_visual_tokens=%d", + sample.__key__, + total_visual_tokens, + self.max_visual_tokens, + ) + print(f"[DEBUG] (task_encoder.py) Skipping sample {sample.__key__} because it has {total_visual_tokens} visual tokens, which exceeds max_visual_tokens={self.max_visual_tokens}") + raise SkipSample() + # Normalize conversation to [{"role": ..., "content": ...}, ...] conversation = cook_chatml_sample(sample.conversation) @@ -287,7 +342,12 @@ def encode_sample(self, sample: ChatMLSample): target_length = input_ids.shape[0] if target_length > self.seq_len: - logging.warning(f"Long sequence with length {target_length} found, dropped...") + if total_visual_tokens > self.seq_len: + logging.warning( + f"Long sequence with length {target_length} and visual tokens {total_visual_tokens} exceeds seq_len={self.seq_len}, truncation will affect visual tokens, dropping sample." + ) + print(f"[DEBUG] (task_encoder.py) Long sequence with length {target_length} and visual tokens {total_visual_tokens} exceeds seq_len={self.seq_len}, truncation will affect visual tokens, dropping sample.") + # raise SkipSample() final_input_ids = np.zeros(target_length, dtype=input_ids.dtype) final_input_masks = final_input_ids.copy() @@ -435,7 +495,9 @@ def encode_batch(self, batch: QwenVLTaskBatch) -> dict: raw["visual_inputs"] = Qwen2_5_VLVisualInputs( pixel_values=batch.pixel_values, + pixel_values_videos=batch.pixel_values_videos, image_grid_thw=batch.image_grid_thw, + video_grid_thw=batch.video_grid_thw, ) return raw diff --git a/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py b/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py index 5f84009620..72e631e657 100644 --- a/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py +++ b/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py @@ -20,6 +20,7 @@ from __future__ import annotations import os +from dataclasses import dataclass from typing import Optional, Union import torch @@ -28,6 +29,7 @@ from megatron.bridge import AutoBridge from megatron.bridge.data.energon.energon_provider import EnergonProvider +from megatron.bridge.data.utils import DatasetBuildContext from megatron.bridge.data.vlm_datasets import MockVLMConversationProvider from megatron.bridge.peft.base import PEFT from megatron.bridge.recipes.common import _peft_common_vlm, _sft_common_vlm @@ -265,10 +267,36 @@ def qwen3_vl_235b_a22b_pretrain_mock_config(**user_kwargs: Unpack[Qwen3VLCommonK return _qwen3_vl_common(**combined_kwargs) +@dataclass(kw_only=True) +class QwenVLEnergonProvider(EnergonProvider): + """EnergonProvider subclass that exposes task-encoder knobs as CLI-overridable fields. 
+ + The task encoder is constructed eagerly (same as before), but build_datasets + syncs these fields onto it after CLI overrides have been applied. + """ + + min_pixels: int = 200704 + max_pixels: int = 1003520 + max_num_images: int | None = 10 + max_num_frames: int | None = 60 + max_visual_tokens: int | None = 16384 + + def build_datasets(self, context: DatasetBuildContext): + if self.task_encoder is not None: + self.task_encoder.seq_len = self.seq_length + self.task_encoder.seq_length = self.seq_length + self.task_encoder.min_pixels = self.min_pixels + self.task_encoder.max_pixels = self.max_pixels + self.task_encoder.max_num_images = self.max_num_images + self.task_encoder.max_num_frames = self.max_num_frames + self.task_encoder.max_visual_tokens = self.max_visual_tokens + return super().build_datasets(context) + + def _make_energon_dataset( hf_path: str, seq_length: int, micro_batch_size: int, global_batch_size: int -) -> EnergonProvider: - """Create an EnergonProvider dataset config for Qwen3-VL recipes.""" +) -> QwenVLEnergonProvider: + """Create a QwenVLEnergonProvider dataset config for Qwen3-VL recipes.""" tokenizer = AutoTokenizer.from_pretrained(hf_path) # Use Qwen3VLProcessor to match the HF flow (which uses AutoProcessor). # This processor accepts both images and videos kwargs. @@ -278,7 +306,7 @@ def _make_energon_dataset( image_processor=image_processor, max_padding_length=seq_length, ) - return EnergonProvider( + return QwenVLEnergonProvider( path="", # Must be set via CLI override: dataset.path= seq_length=seq_length, micro_batch_size=micro_batch_size, @@ -1140,5 +1168,5 @@ def qwen3_vl_8b_peft_energon_config(peft_scheme: str | PEFT = "lora") -> ConfigC """ cfg = qwen3_vl_8b_peft_config(peft_scheme=peft_scheme) hf_path = "Qwen/Qwen3-VL-8B-Instruct" - cfg.dataset = _make_energon_dataset(hf_path, 4096, cfg.train.micro_batch_size, cfg.train.global_batch_size) + cfg.dataset = _make_energon_dataset(hf_path, cfg.model.seq_length, cfg.train.micro_batch_size, cfg.train.global_batch_size) return cfg diff --git a/src/megatron/bridge/training/utils/visual_inputs.py b/src/megatron/bridge/training/utils/visual_inputs.py index 638dc86d66..bc19ea9ba7 100644 --- a/src/megatron/bridge/training/utils/visual_inputs.py +++ b/src/megatron/bridge/training/utils/visual_inputs.py @@ -61,9 +61,15 @@ class Qwen2_5_VLVisualInputs: # Image tensors, e.g., Qwen2.5-VL processor output. pixel_values: Optional[torch.Tensor] = None - # Per-image temporal/spatial grid metadata (T, H, W) for videos, Qwen2.5-VL. + # Video tensors, e.g., Qwen2.5-VL processor output. + pixel_values_videos: Optional[torch.Tensor] = None + + # Per-image (T, H, W) grid metadata. image_grid_thw: Optional[torch.Tensor] = None + # Per-video (T, H, W) grid metadata. + video_grid_thw: Optional[torch.Tensor] = None + def as_model_kwargs(self) -> dict[str, torch.Tensor]: """Return a mapping of non-None fields suitable for model forward kwargs.""" result: dict[str, torch.Tensor] = {} @@ -77,7 +83,9 @@ def normalized_for_model(self) -> dict[str, torch.Tensor]: """Return non-None fields with shapes normalized for model expectations. 
- pixel_values: [B, N, C, H, W] -> [B*N, C, H, W] + - pixel_values_videos: [B, N, C, H, W] -> [B*N, C, H, W] - image_grid_thw: [B, N, 3] -> [B*N, 3] + - video_grid_thw: [B, N, 3] -> [B*N, 3] """ kwargs = self.as_model_kwargs() @@ -86,10 +94,19 @@ def normalized_for_model(self) -> dict[str, torch.Tensor]: b, n, c, h, w = pixel_values.shape kwargs["pixel_values"] = pixel_values.view(b * n, c, h, w) + pixel_values_videos = kwargs.get("pixel_values_videos") + if isinstance(pixel_values_videos, torch.Tensor) and pixel_values_videos.dim() == 5: + b, n, c, h, w = pixel_values_videos.shape + kwargs["pixel_values_videos"] = pixel_values_videos.view(b * n, c, h, w) + image_grid_thw = kwargs.get("image_grid_thw") if isinstance(image_grid_thw, torch.Tensor) and image_grid_thw.dim() == 3: kwargs["image_grid_thw"] = image_grid_thw.view(-1, image_grid_thw.size(-1)) + video_grid_thw = kwargs.get("video_grid_thw") + if isinstance(video_grid_thw, torch.Tensor) and video_grid_thw.dim() == 3: + kwargs["video_grid_thw"] = video_grid_thw.view(-1, video_grid_thw.size(-1)) + return kwargs diff --git a/tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py b/tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py index 8b271447a4..4df792df2c 100644 --- a/tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py +++ b/tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py @@ -568,8 +568,12 @@ def test_qwen3_vl_8b_peft_energon_task_encoder(monkeypatch: pytest.MonkeyPatch): cfg = _qwen3_vl_module.qwen3_vl_8b_peft_energon_config() from megatron.bridge.recipes.qwen_vl.data.energon.task_encoder import QwenVLTaskEncoder + from megatron.bridge.recipes.qwen_vl.qwen3_vl import QwenVLEnergonProvider + assert isinstance(cfg.dataset, QwenVLEnergonProvider) assert isinstance(cfg.dataset.task_encoder, QwenVLTaskEncoder) + assert cfg.dataset.min_pixels == 200704 + assert cfg.dataset.max_pixels == 1003520 # ============================================================================= From 33ddcd98cd6961c9c5ea676b8036d841eb9e0c4a Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 6 May 2026 07:59:29 -0700 Subject: [PATCH 02/12] adding inference code for Qwen3 for multi-images and video --- .../conversion/hf_to_megatron_generate_vlm.py | 49 ++++++++++++---- examples/conversion/vlm_generate_utils.py | 56 +++++++++++++++++++ 2 files changed, 95 insertions(+), 10 deletions(-) diff --git a/examples/conversion/hf_to_megatron_generate_vlm.py b/examples/conversion/hf_to_megatron_generate_vlm.py index 6392aa34c6..bb6565069a 100644 --- a/examples/conversion/hf_to_megatron_generate_vlm.py +++ b/examples/conversion/hf_to_megatron_generate_vlm.py @@ -38,6 +38,8 @@ pad_input_ids_to_tp_multiple, patch_kimi_vision_processor, process_image_inputs, + process_multi_image_inputs, + process_video_inputs, to_cuda, ) @@ -64,6 +66,8 @@ def __init__( image_grid_thw=None, image_sizes=None, mm_token_type_ids=None, + pixel_values_videos=None, + video_grid_thw=None, ): self.batch = dict( tokens=input_ids, @@ -78,6 +82,10 @@ def __init__( self.batch["image_sizes"] = image_sizes if mm_token_type_ids is not None: self.batch["mm_token_type_ids"] = mm_token_type_ids + if pixel_values_videos is not None: + self.batch["pixel_values_videos"] = pixel_values_videos + if video_grid_thw is not None: + self.batch["video_grid_thw"] = video_grid_thw self._yielded = False def __iter__(self): @@ -98,7 +106,7 @@ def vlm_forward_step(data_iterator, model, **kwargs) -> torch.Tensor: "position_ids": batch["position_ids"], "attention_mask": batch.get("attention_mask"), } - 
for key in ("pixel_values", "image_grid_thw", "image_sizes", "mm_token_type_ids"): + for key in ("pixel_values", "image_grid_thw", "image_sizes", "mm_token_type_ids", "pixel_values_videos", "video_grid_thw"): if key in batch: forward_args[key] = batch[key] @@ -211,19 +219,33 @@ def _disable_mtp(m): # ------------------------------------------------------------------ # Process inputs # ------------------------------------------------------------------ - input_ids_raw, pixel_values, image_grid_thw, image_sizes, mm_token_type_ids = process_image_inputs( - processor, - args.image_path, - args.prompt, - is_kimi=is_kimi, - image_token_id=image_token_id, - ) + pixel_values = image_grid_thw = image_sizes = mm_token_type_ids = None + pixel_values_videos = video_grid_thw = None + + if args.video_path: + input_ids_raw, pixel_values_videos, video_grid_thw = process_video_inputs( + processor, args.video_path, args.prompt, fps=args.video_fps + ) + elif args.image_paths: + input_ids_raw, pixel_values, image_grid_thw = process_multi_image_inputs( + processor, args.image_paths, args.prompt + ) + else: + input_ids_raw, pixel_values, image_grid_thw, image_sizes, mm_token_type_ids = process_image_inputs( + processor, + args.image_path, + args.prompt, + is_kimi=is_kimi, + image_token_id=image_token_id, + ) input_ids_raw = input_ids_raw.cuda() pixel_values = to_cuda(pixel_values) image_grid_thw = to_cuda(image_grid_thw) image_sizes = to_cuda(image_sizes) mm_token_type_ids = to_cuda(mm_token_type_ids) + pixel_values_videos = to_cuda(pixel_values_videos) + video_grid_thw = to_cuda(video_grid_thw) # ------------------------------------------------------------------ # Greedy generation loop @@ -250,7 +272,8 @@ def _disable_mtp(m): fwd_bwd_function = get_forward_backward_func() iterator = SingleBatchIterator( - input_ids, position_ids, None, pixel_values, image_grid_thw, image_sizes, mm_ids_padded + input_ids, position_ids, None, pixel_values, image_grid_thw, image_sizes, mm_ids_padded, + pixel_values_videos, video_grid_thw, ) output = fwd_bwd_function( @@ -323,7 +346,13 @@ def _disable_mtp(m): "--pp_layout", type=str, default=None, help="Pipeline model parallel layout (e.g. 'Et*15|t*15|t*16|t*15L')" ) parser.add_argument("--megatron_model_path", type=str, default=None, help="Path to Megatron model checkpoint") - parser.add_argument("--image_path", type=str, default=None, help="Path or URL to image (optional).") + parser.add_argument("--image_path", type=str, default=None, help="Path or URL to a single image (optional).") + parser.add_argument("--image_paths", type=str, nargs="+", default=None, + help="Paths to N image files in order (multi-image; Qwen-family only).") + parser.add_argument("--video_path", type=str, default=None, + help="Path to a video file (Qwen-family only).") + parser.add_argument("--video_fps", type=float, default=2.0, + help="Frames per second to sample from the video (default: 2.0).") parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code for HF model loading") args = parser.parse_args() diff --git a/examples/conversion/vlm_generate_utils.py b/examples/conversion/vlm_generate_utils.py index f0f0d7b280..8650daf0a9 100644 --- a/examples/conversion/vlm_generate_utils.py +++ b/examples/conversion/vlm_generate_utils.py @@ -122,6 +122,62 @@ def to_cuda(x): return x.cuda() +def process_multi_image_inputs(processor, image_paths: list[str], prompt: str): + """Process N ordered images + prompt into model inputs (Qwen-family). 
+ + Returns: + (input_ids, pixel_values, image_grid_thw) + """ + if not _HAS_QWEN_VL_UTILS: + raise ImportError("qwen-vl-utils required: pip install qwen-vl-utils") + pils = [load_image(p).convert("RGB") for p in image_paths] + messages = [ + { + "role": "user", + "content": [{"type": "image", "image": p} for p in pils] + + [{"type": "text", "text": prompt}], + } + ] + image_inputs, video_inputs = process_vision_info(messages) + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") + return inputs.input_ids, inputs.get("pixel_values"), inputs.get("image_grid_thw") + + +def process_video_inputs(processor, video_path: str, prompt: str, *, fps: float = 2.0): + """Process a video + prompt into model inputs (Qwen-family). + + Frame decoding mirrors the Qwen3-VL training pipeline: fetch_video decodes at + ``fps``, then video_processor is called with do_sample_frames=False to use the + pre-decoded frames as-is. + + Returns: + (input_ids, pixel_values_videos, video_grid_thw) + """ + if not _HAS_QWEN_VL_UTILS: + raise ImportError("qwen-vl-utils required: pip install qwen-vl-utils") + from qwen_vl_utils import fetch_video + + frames = fetch_video({"video": video_path, "fps": fps}) + messages = [ + { + "role": "user", + "content": [{"type": "video"}, {"type": "text", "text": prompt}], + } + ] + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + text_inputs = processor(text=[text], padding=True, return_tensors="pt") + video_proc = processor.video_processor(videos=[frames], return_tensors="pt", do_sample_frames=False) + # processor(text=...) without videos produces a single <|video_pad|> placeholder (id 151656). + # Pre-expand to match actual vision feature count so PP send/recv shapes are correct. 
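+    # For reference: the expansion count per video works out to
+    # prod(T, H, W) // merge_size**2, matching the visual-token arithmetic
+    # used by the training task encoder.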
+ input_ids = pre_expand_image_tokens( + text_inputs["input_ids"], + video_proc["video_grid_thw"], + image_token_id=151656, # <|video_pad|> for Qwen-VL family + ) + return input_ids, video_proc.get("pixel_values_videos"), video_proc.get("video_grid_thw") + + def process_image_inputs( processor, image_path: Optional[str], From 9d1fed0caa83ff78769375e7f9fccbc996d9bc16 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 6 May 2026 08:55:56 -0700 Subject: [PATCH 03/12] style: fix ruff-format line-length violations flagged by CI Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Huy Vu2 --- examples/conversion/hf_to_megatron_generate_vlm.py | 10 +++++++++- .../recipes/qwen_vl/data/energon/task_encoder.py | 8 ++++++-- src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py | 4 +++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/examples/conversion/hf_to_megatron_generate_vlm.py b/examples/conversion/hf_to_megatron_generate_vlm.py index 7c53b9f084..87873c1b50 100644 --- a/examples/conversion/hf_to_megatron_generate_vlm.py +++ b/examples/conversion/hf_to_megatron_generate_vlm.py @@ -109,7 +109,15 @@ def vlm_forward_step(data_iterator, model, **kwargs) -> torch.Tensor: "position_ids": batch["position_ids"], "attention_mask": batch.get("attention_mask"), } - for key in ("pixel_values", "image_grid_thw", "image_sizes", "mm_token_type_ids", "pixel_values_videos", "video_grid_thw", "image_position_ids"): + for key in ( + "pixel_values", + "image_grid_thw", + "image_sizes", + "mm_token_type_ids", + "pixel_values_videos", + "video_grid_thw", + "image_position_ids", + ): if key in batch: forward_args[key] = batch[key] diff --git a/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py b/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py index ef33dc758b..2cbf0177ce 100644 --- a/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py +++ b/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py @@ -266,7 +266,9 @@ def encode_sample(self, sample: ChatMLSample): total_visual_tokens, self.max_visual_tokens, ) - print(f"[DEBUG] (task_encoder.py) Skipping sample {sample.__key__} because it has {total_visual_tokens} visual tokens, which exceeds max_visual_tokens={self.max_visual_tokens}") + print( + f"[DEBUG] (task_encoder.py) Skipping sample {sample.__key__} because it has {total_visual_tokens} visual tokens, which exceeds max_visual_tokens={self.max_visual_tokens}" + ) raise SkipSample() # Normalize conversation to [{"role": ..., "content": ...}, ...] @@ -346,7 +348,9 @@ def encode_sample(self, sample: ChatMLSample): logging.warning( f"Long sequence with length {target_length} and visual tokens {total_visual_tokens} exceeds seq_len={self.seq_len}, truncation will affect visual tokens, dropping sample." ) - print(f"[DEBUG] (task_encoder.py) Long sequence with length {target_length} and visual tokens {total_visual_tokens} exceeds seq_len={self.seq_len}, truncation will affect visual tokens, dropping sample.") + print( + f"[DEBUG] (task_encoder.py) Long sequence with length {target_length} and visual tokens {total_visual_tokens} exceeds seq_len={self.seq_len}, truncation will affect visual tokens, dropping sample." 
+ ) # raise SkipSample() final_input_ids = np.zeros(target_length, dtype=input_ids.dtype) final_input_masks = final_input_ids.copy() diff --git a/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py b/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py index d1e2a13da7..45c0a78dda 100644 --- a/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py +++ b/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py @@ -1168,5 +1168,7 @@ def qwen3_vl_8b_peft_energon_config(peft_scheme: str | PEFT = "lora") -> ConfigC """ cfg = qwen3_vl_8b_peft_config(peft_scheme=peft_scheme) hf_path = "Qwen/Qwen3-VL-8B-Instruct" - cfg.dataset = _make_energon_dataset(hf_path, cfg.model.seq_length, cfg.train.micro_batch_size, cfg.train.global_batch_size) + cfg.dataset = _make_energon_dataset( + hf_path, cfg.model.seq_length, cfg.train.micro_batch_size, cfg.train.global_batch_size + ) return cfg From 67e6f2403b95ff51ecd57ae3fb3b10456880e418 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 6 May 2026 10:04:33 -0700 Subject: [PATCH 04/12] style: apply ruff-format reformats and remove debug prints Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Huy Vu2 --- .../conversion/hf_to_megatron_generate_vlm.py | 17 +++++++++++------ examples/conversion/vlm_generate_utils.py | 3 +-- .../qwen_vl/data/energon/task_encoder.py | 11 +---------- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/examples/conversion/hf_to_megatron_generate_vlm.py b/examples/conversion/hf_to_megatron_generate_vlm.py index 87873c1b50..9ea04880e4 100644 --- a/examples/conversion/hf_to_megatron_generate_vlm.py +++ b/examples/conversion/hf_to_megatron_generate_vlm.py @@ -376,12 +376,17 @@ def _disable_mtp(m): ) parser.add_argument("--megatron_model_path", type=str, default=None, help="Path to Megatron model checkpoint") parser.add_argument("--image_path", type=str, default=None, help="Path or URL to a single image (optional).") - parser.add_argument("--image_paths", type=str, nargs="+", default=None, - help="Paths to N image files in order (multi-image; Qwen-family only).") - parser.add_argument("--video_path", type=str, default=None, - help="Path to a video file (Qwen-family only).") - parser.add_argument("--video_fps", type=float, default=2.0, - help="Frames per second to sample from the video (default: 2.0).") + parser.add_argument( + "--image_paths", + type=str, + nargs="+", + default=None, + help="Paths to N image files in order (multi-image; Qwen-family only).", + ) + parser.add_argument("--video_path", type=str, default=None, help="Path to a video file (Qwen-family only).") + parser.add_argument( + "--video_fps", type=float, default=2.0, help="Frames per second to sample from the video (default: 2.0)." 
+ ) parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code for HF model loading") args = parser.parse_args() diff --git a/examples/conversion/vlm_generate_utils.py b/examples/conversion/vlm_generate_utils.py index 5a5e383311..db47fcf670 100644 --- a/examples/conversion/vlm_generate_utils.py +++ b/examples/conversion/vlm_generate_utils.py @@ -144,8 +144,7 @@ def process_multi_image_inputs(processor, image_paths: list[str], prompt: str): messages = [ { "role": "user", - "content": [{"type": "image", "image": p} for p in pils] - + [{"type": "text", "text": prompt}], + "content": [{"type": "image", "image": p} for p in pils] + [{"type": "text", "text": prompt}], } ] image_inputs, video_inputs = process_vision_info(messages) diff --git a/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py b/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py index 2cbf0177ce..12c1e4ed4b 100644 --- a/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py +++ b/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py @@ -57,8 +57,6 @@ def process_vision( image_grid_thw = None if videos is not None: - # DEBUGGING - # videos_inputs = processor(images=None, text="", videos=videos, return_tensors="pt") # Pre-decoded frames from WDS are already at the desired sampling rate. # do_sample_frames=False prevents the processor from re-sampling them under # a spurious 24 fps assumption, which would reduce most clips to T=2. @@ -223,7 +221,6 @@ def encode_sample(self, sample: ChatMLSample): len(imgs_for_processing), self.max_num_images, ) - print(f"[DEBUG] (task_encoder.py) Skipping sample {sample.__key__} because it has {len(imgs_for_processing)} images, which exceeds max_num_images={self.max_num_images}") raise SkipSample() if self.max_num_frames is not None and videos_for_processing is not None: @@ -266,9 +263,6 @@ def encode_sample(self, sample: ChatMLSample): total_visual_tokens, self.max_visual_tokens, ) - print( - f"[DEBUG] (task_encoder.py) Skipping sample {sample.__key__} because it has {total_visual_tokens} visual tokens, which exceeds max_visual_tokens={self.max_visual_tokens}" - ) raise SkipSample() # Normalize conversation to [{"role": ..., "content": ...}, ...] @@ -348,10 +342,7 @@ def encode_sample(self, sample: ChatMLSample): logging.warning( f"Long sequence with length {target_length} and visual tokens {total_visual_tokens} exceeds seq_len={self.seq_len}, truncation will affect visual tokens, dropping sample." ) - print( - f"[DEBUG] (task_encoder.py) Long sequence with length {target_length} and visual tokens {total_visual_tokens} exceeds seq_len={self.seq_len}, truncation will affect visual tokens, dropping sample." 
- ) - # raise SkipSample() + raise SkipSample() final_input_ids = np.zeros(target_length, dtype=input_ids.dtype) final_input_masks = final_input_ids.copy() From c1555c0961c46f3eb47d053cd9f97e8cce9d3cf8 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 6 May 2026 10:13:38 -0700 Subject: [PATCH 05/12] style: fix remaining ruff-format violations in task_encoder.py Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Huy Vu2 --- .../recipes/qwen_vl/data/energon/task_encoder.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py b/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py index 12c1e4ed4b..67c86c5f77 100644 --- a/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py +++ b/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py @@ -39,7 +39,6 @@ from megatron.bridge.training.utils.visual_inputs import Qwen2_5_VLVisualInputs - def process_vision( processor, images, videos, fps=None, model_version: str = "qwen-vl", min_pixels=None, max_pixels=None ): @@ -233,7 +232,9 @@ def encode_sample(self, sample: ChatMLSample): self.max_num_frames, sample.__key__, ) - print(f"[DEBUG] (task_encoder.py) Truncating {len(v)} frames to max_num_frames={self.max_num_frames} for sample {sample.__key__}") + print( + f"[DEBUG] (task_encoder.py) Truncating {len(v)} frames to max_num_frames={self.max_num_frames} for sample {sample.__key__}" + ) clipped.append(v[: self.max_num_frames]) else: clipped.append(v) @@ -252,8 +253,12 @@ def encode_sample(self, sample: ChatMLSample): flattened_videos = processed_vision["video_inputs"] merge_length = self.merge_size**2 - image_tokens = int(image_thw_grids.prod(dim=-1).sum().item()) // merge_length if image_thw_grids is not None else 0 - video_tokens = int(video_thw_grids.prod(dim=-1).sum().item()) // merge_length if video_thw_grids is not None else 0 + image_tokens = ( + int(image_thw_grids.prod(dim=-1).sum().item()) // merge_length if image_thw_grids is not None else 0 + ) + video_tokens = ( + int(video_thw_grids.prod(dim=-1).sum().item()) // merge_length if video_thw_grids is not None else 0 + ) total_visual_tokens = image_tokens + video_tokens if self.max_visual_tokens is not None: if total_visual_tokens > self.max_visual_tokens: From 859eb867bbfe5f2d8d98413be1334723d3868e18 Mon Sep 17 00:00:00 2001 From: Huy Vu Date: Thu, 7 May 2026 08:22:59 -0700 Subject: [PATCH 06/12] [recipe] test: Update QwenVL task encoder test mocks for torch tensors Adapts the unit tests to the refactored encoder which now computes visual-token counts via .prod(dim=-1) (torch syntax) on the processor's image_grid_thw / video_grid_thw outputs. The mocks previously returned np.array, causing TypeError. Also bumps max_padding_length to 512 so the expanded sequence length stays within seq_len and avoids the new SkipSample() path. 
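For illustration, the mismatch looks roughly like this:

    np.array([[1, 28, 28]]).prod(dim=-1)      # TypeError: numpy expects axis=, not dim=
    torch.tensor([[1, 28, 28]]).prod(dim=-1)  # tensor([784])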
Signed-off-by: Huy Vu
---
 .../recipes/qwen_vl/data/energon/test_task_encoder.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/unit_tests/recipes/qwen_vl/data/energon/test_task_encoder.py b/tests/unit_tests/recipes/qwen_vl/data/energon/test_task_encoder.py
index bb4e83abbd..d6e019f643 100644
--- a/tests/unit_tests/recipes/qwen_vl/data/energon/test_task_encoder.py
+++ b/tests/unit_tests/recipes/qwen_vl/data/energon/test_task_encoder.py
@@ -164,7 +164,7 @@ def setUp(self):
         self.encoder = QwenVLTaskEncoder(
             tokenizer=self.tokenizer,
             image_processor=self.image_processor,
-            max_padding_length=128,
+            max_padding_length=512,
             patch_size=14,
             spatial_merge_size=2,
         )
@@ -184,10 +184,10 @@ def test_encode_sample(self):
         def processor_side_effect(images=None, videos=None, **kwargs):
             res = {}
             if images:
-                res["image_grid_thw"] = np.array([[1, 28, 28]])  # 1 tile, 28x28
+                res["image_grid_thw"] = torch.tensor([[1, 28, 28]])  # 1 tile, 28x28
                 res["pixel_values"] = torch.randn(1, 3, 28, 28)
             if videos:
-                res["video_grid_thw"] = np.array([[1, 28, 28]])
+                res["video_grid_thw"] = torch.tensor([[1, 28, 28]])
                 res["pixel_values_videos"] = torch.randn(1, 3, 28, 28)
             return res
 
@@ -243,10 +243,10 @@ def test_encode_sample_from_value_format(self):
         def processor_side_effect(images=None, videos=None, **kwargs):
             res = {}
             if images:
-                res["image_grid_thw"] = np.array([[1, 28, 28]])
+                res["image_grid_thw"] = torch.tensor([[1, 28, 28]])
                 res["pixel_values"] = torch.randn(1, 3, 28, 28)
             if videos:
-                res["video_grid_thw"] = np.array([[1, 28, 28]])
+                res["video_grid_thw"] = torch.tensor([[1, 28, 28]])
                 res["pixel_values_videos"] = torch.randn(1, 3, 28, 28)
             return res
 

From bca268d8e5096fc6486487cf64e649e844da8c4c Mon Sep 17 00:00:00 2001
From: Huy Vu2
Date: Thu, 7 May 2026 08:38:38 -0700
Subject: [PATCH 07/12] [docs, recipe] docs: Document Qwen3-VL visual token budget controls

Adds README section describing the three composable controls that bound
GPU cost per sample (min/max_pixels, max_num_images/max_num_frames,
max_visual_tokens) and asserts the PEFT energon recipe defaults so the
documented contract is enforced by tests.

Co-Authored-By: Claude Opus 4.7 (1M context)
Signed-off-by: Huy Vu2
---
 examples/models/vlm/qwen3_vl/README.md                    | 14 ++++++++++++++
 tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py |  4 +++-
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/examples/models/vlm/qwen3_vl/README.md b/examples/models/vlm/qwen3_vl/README.md
index 38817cff96..21f6c22c3e 100644
--- a/examples/models/vlm/qwen3_vl/README.md
+++ b/examples/models/vlm/qwen3_vl/README.md
@@ -127,6 +127,20 @@ field_map:
 
 Then, update the dataset path (`dataset.path=/path/to/energon/dataset`) in [peft_energon.sh](peft_energon.sh) and run the script.
 
+#### Controlling the visual-token compute budget
+Three independent, CLI-overridable controls bound each sample's GPU cost, and they compose:
+- **`dataset.min_pixels` / `dataset.max_pixels`** — lower and upper bounds on the resolution of each image or frame, in pixels (defaults `200704` / `1003520`).
+- **`dataset.max_num_images` / `dataset.max_num_frames`** — limits on the number of images and video frames per sample (defaults `10` / `60`). Too many images → the sample is dropped. Too many frames → the frame list is truncated.
+- **`dataset.max_visual_tokens`** — cap on the total visual tokens across all images and frames in a sample, computed post-rescaling as `prod(T, H, W) // merge_size²` (default `16384`; set to `None` to disable). This catches cases the other two controls miss (a few images at very high resolution, or many images at low resolution). Samples that exceed the cap are dropped.
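+For a rough sense of scale, here is a sketch of the budget arithmetic, assuming the Qwen-VL defaults of 14×14 patches and a 2×2 spatial merge (the values used in this recipe's tests):
+
+```python
+# Illustrative only: visual-token cost of one image at dataset.max_pixels.
+patches = 1003520 // (14 * 14)         # 5120 patches per max-resolution image
+tokens_per_image = patches // (2 * 2)  # 1280 tokens after the 2x2 spatial merge
+budget_images = 16384 // tokens_per_image  # ~12 such images fit the default cap
+```
+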
### Expected Training Dynamics We provide a [Weights & Biases report](https://api.wandb.ai/links/nvidia-nemo-fw-public/lczz4ixx) for the expected loss curves and grad norms. diff --git a/tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py b/tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py index 4df792df2c..a101800a1c 100644 --- a/tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py +++ b/tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py @@ -574,7 +574,9 @@ def test_qwen3_vl_8b_peft_energon_task_encoder(monkeypatch: pytest.MonkeyPatch): assert isinstance(cfg.dataset.task_encoder, QwenVLTaskEncoder) assert cfg.dataset.min_pixels == 200704 assert cfg.dataset.max_pixels == 1003520 - + assert cfg.dataset.max_num_images == 10 + assert cfg.dataset.max_num_frames == 60 + assert cfg.dataset.max_visual_tokens == 16384 # ============================================================================= # Qwen3-VL Pretrain Mock Config Tests From 3b48e4e0d9afd5c426f8a39a166a45c70f797da8 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Thu, 7 May 2026 08:45:30 -0700 Subject: [PATCH 08/12] [recipe] fix: Add missing blank line before module-level comment block Pre-commit / ruff format requires two blank lines between a function and the following module-level block. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Huy Vu2 --- tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py b/tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py index a101800a1c..d1974c8f4d 100644 --- a/tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py +++ b/tests/unit_tests/recipes/qwen_vl/test_qwen3_vl_recipes.py @@ -578,6 +578,7 @@ def test_qwen3_vl_8b_peft_energon_task_encoder(monkeypatch: pytest.MonkeyPatch): assert cfg.dataset.max_num_frames == 60 assert cfg.dataset.max_visual_tokens == 16384 + # ============================================================================= # Qwen3-VL Pretrain Mock Config Tests # ============================================================================= From 84274773c7f461f0690f69eebd021c2adb7b38a9 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Thu, 7 May 2026 14:19:02 -0700 Subject: [PATCH 09/12] [recipe] test: Add unit tests for QwenVL task encoder limits, provider config sync, and visual inputs video reshape Covers three pieces of recently added behavior: - Per-sample budget limits in QwenVLTaskEncoder (max_num_images skip, max_num_frames truncation, default values). - QwenVLEnergonProvider.build_datasets propagating CLI-overridable knobs onto the task encoder before delegating to the parent. - Qwen2_5_VLVisualInputs.normalized_for_model handling video tensors and mixed image/video shapes, including already-flat passthrough. 
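A sketch of the last case: a [B, N, C, H, W] video tensor flattens to
[B*N, C, H, W], while an already-flat tensor passes through unchanged:

    inputs = Qwen2_5_VLVisualInputs(pixel_values_videos=torch.randn(2, 3, 3, 28, 28))
    assert inputs.normalized_for_model()["pixel_values_videos"].shape == (6, 3, 28, 28)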
Co-Authored-By: Claude Opus 4.7 (1M context)
Signed-off-by: Huy Vu2
---
 .../qwen_vl/data/energon/test_task_encoder.py | 208 ++++++++++++++++++
 .../qwen_vl/test_qwen_vl_energon_provider.py  | 156 +++++++++++++
 .../training/utils/test_visual_inputs.py      |  72 ++++++
 3 files changed, 436 insertions(+)
 create mode 100644 tests/unit_tests/recipes/qwen_vl/test_qwen_vl_energon_provider.py

diff --git a/tests/unit_tests/recipes/qwen_vl/data/energon/test_task_encoder.py b/tests/unit_tests/recipes/qwen_vl/data/energon/test_task_encoder.py
index d6e019f643..f2e4d2371f 100644
--- a/tests/unit_tests/recipes/qwen_vl/data/energon/test_task_encoder.py
+++ b/tests/unit_tests/recipes/qwen_vl/data/energon/test_task_encoder.py
@@ -21,6 +21,7 @@
 import numpy as np
 import pytest
 import torch
+from megatron.energon import SkipSample
 from PIL import Image
 
 from megatron.bridge.recipes.qwen_vl.data.energon.task_encoder import (
@@ -361,5 +362,212 @@ def test_encode_batch(self):
         self.assertNotIn("__subflavors__", encoded_dict)
 
 
+class TestQwenVLTaskEncoderLimits(unittest.TestCase):
+    """Tests for the per-sample budget limits added to QwenVLTaskEncoder."""
+
+    def setUp(self):
+        self.tokenizer = MagicMock()
+        self.tokenizer.pad_token_id = 0
+        self.tokenizer.eos_token_id = 1
+        self.tokenizer.image_token_id = 151655
+        self.tokenizer.video_token_id = 151656
+        self.tokenizer.convert_tokens_to_ids.side_effect = lambda x: {
+            "<|image_pad|>": 151655,
+            "