diff --git a/examples/conversion/hf_to_megatron_generate_vlm.py b/examples/conversion/hf_to_megatron_generate_vlm.py index f25b40bb1b..9ea04880e4 100644 --- a/examples/conversion/hf_to_megatron_generate_vlm.py +++ b/examples/conversion/hf_to_megatron_generate_vlm.py @@ -38,6 +38,8 @@ pad_input_ids_to_tp_multiple, patch_kimi_vision_processor, process_image_inputs, + process_multi_image_inputs, + process_video_inputs, to_cuda, ) @@ -64,6 +66,8 @@ def __init__( image_grid_thw=None, image_sizes=None, mm_token_type_ids=None, + pixel_values_videos=None, + video_grid_thw=None, image_position_ids=None, ): self.batch = dict( @@ -79,6 +83,10 @@ def __init__( self.batch["image_sizes"] = image_sizes if mm_token_type_ids is not None: self.batch["mm_token_type_ids"] = mm_token_type_ids + if pixel_values_videos is not None: + self.batch["pixel_values_videos"] = pixel_values_videos + if video_grid_thw is not None: + self.batch["video_grid_thw"] = video_grid_thw if image_position_ids is not None: self.batch["image_position_ids"] = image_position_ids self._yielded = False @@ -101,7 +109,15 @@ def vlm_forward_step(data_iterator, model, **kwargs) -> torch.Tensor: "position_ids": batch["position_ids"], "attention_mask": batch.get("attention_mask"), } - for key in ("pixel_values", "image_grid_thw", "image_sizes", "mm_token_type_ids", "image_position_ids"): + for key in ( + "pixel_values", + "image_grid_thw", + "image_sizes", + "mm_token_type_ids", + "pixel_values_videos", + "video_grid_thw", + "image_position_ids", + ): if key in batch: forward_args[key] = batch[key] @@ -215,27 +231,41 @@ def _disable_mtp(m): # ------------------------------------------------------------------ # Process inputs # ------------------------------------------------------------------ + pixel_values = image_grid_thw = image_sizes = mm_token_type_ids = image_position_ids = None + pixel_values_videos = video_grid_thw = None + + if args.video_path: + input_ids_raw, pixel_values_videos, video_grid_thw = process_video_inputs( + processor, args.video_path, args.prompt, fps=args.video_fps + ) + elif args.image_paths: + input_ids_raw, pixel_values, image_grid_thw = process_multi_image_inputs( + processor, args.image_paths, args.prompt + ) + else: + ( + input_ids_raw, + pixel_values, + image_grid_thw, + image_sizes, + mm_token_type_ids, + image_position_ids, + ) = process_image_inputs( + processor, + args.image_path, + args.prompt, + is_gemma4=is_gemma4, + is_kimi=is_kimi, + image_token_id=image_token_id, + ) - ( - input_ids_raw, - pixel_values, - image_grid_thw, - image_sizes, - mm_token_type_ids, - image_position_ids, - ) = process_image_inputs( - processor, - args.image_path, - args.prompt, - is_gemma4=is_gemma4, - is_kimi=is_kimi, - image_token_id=image_token_id, - ) input_ids_raw = input_ids_raw.cuda() pixel_values = to_cuda(pixel_values) image_grid_thw = to_cuda(image_grid_thw) image_sizes = to_cuda(image_sizes) mm_token_type_ids = to_cuda(mm_token_type_ids) + pixel_values_videos = to_cuda(pixel_values_videos) + video_grid_thw = to_cuda(video_grid_thw) image_position_ids = to_cuda(image_position_ids) # ------------------------------------------------------------------ @@ -270,7 +300,9 @@ def _disable_mtp(m): image_grid_thw, image_sizes, mm_ids_padded, - image_position_ids, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + image_position_ids=image_position_ids, ) output = fwd_bwd_function( @@ -343,7 +375,18 @@ def _disable_mtp(m): "--pp_layout", type=str, default=None, help="Pipeline model parallel layout (e.g. 'Et*15|t*15|t*16|t*15L')" ) parser.add_argument("--megatron_model_path", type=str, default=None, help="Path to Megatron model checkpoint") - parser.add_argument("--image_path", type=str, default=None, help="Path or URL to image (optional).") + parser.add_argument("--image_path", type=str, default=None, help="Path or URL to a single image (optional).") + parser.add_argument( + "--image_paths", + type=str, + nargs="+", + default=None, + help="Paths to N image files in order (multi-image; Qwen-family only).", + ) + parser.add_argument("--video_path", type=str, default=None, help="Path to a video file (Qwen-family only).") + parser.add_argument( + "--video_fps", type=float, default=2.0, help="Frames per second to sample from the video (default: 2.0)." + ) parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code for HF model loading") args = parser.parse_args() diff --git a/examples/conversion/vlm_generate_utils.py b/examples/conversion/vlm_generate_utils.py index 7b8c193f63..db47fcf670 100644 --- a/examples/conversion/vlm_generate_utils.py +++ b/examples/conversion/vlm_generate_utils.py @@ -132,6 +132,61 @@ def to_cuda(x): return x.cuda() +def process_multi_image_inputs(processor, image_paths: list[str], prompt: str): + """Process N ordered images + prompt into model inputs (Qwen-family). + + Returns: + (input_ids, pixel_values, image_grid_thw) + """ + if not _HAS_QWEN_VL_UTILS: + raise ImportError("qwen-vl-utils required: pip install qwen-vl-utils") + pils = [load_image(p).convert("RGB") for p in image_paths] + messages = [ + { + "role": "user", + "content": [{"type": "image", "image": p} for p in pils] + [{"type": "text", "text": prompt}], + } + ] + image_inputs, video_inputs = process_vision_info(messages) + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") + return inputs.input_ids, inputs.get("pixel_values"), inputs.get("image_grid_thw") + + +def process_video_inputs(processor, video_path: str, prompt: str, *, fps: float = 2.0): + """Process a video + prompt into model inputs (Qwen-family). + + Frame decoding mirrors the Qwen3-VL training pipeline: fetch_video decodes at + ``fps``, then video_processor is called with do_sample_frames=False to use the + pre-decoded frames as-is. + + Returns: + (input_ids, pixel_values_videos, video_grid_thw) + """ + if not _HAS_QWEN_VL_UTILS: + raise ImportError("qwen-vl-utils required: pip install qwen-vl-utils") + from qwen_vl_utils import fetch_video + + frames = fetch_video({"video": video_path, "fps": fps}) + messages = [ + { + "role": "user", + "content": [{"type": "video"}, {"type": "text", "text": prompt}], + } + ] + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + text_inputs = processor(text=[text], padding=True, return_tensors="pt") + video_proc = processor.video_processor(videos=[frames], return_tensors="pt", do_sample_frames=False) + # processor(text=...) without videos produces a single <|video_pad|> placeholder (id 151656). + # Pre-expand to match actual vision feature count so PP send/recv shapes are correct. + input_ids = pre_expand_image_tokens( + text_inputs["input_ids"], + video_proc["video_grid_thw"], + image_token_id=151656, # <|video_pad|> for Qwen-VL family + ) + return input_ids, video_proc.get("pixel_values_videos"), video_proc.get("video_grid_thw") + + def process_image_inputs( processor, image_path: Optional[str], diff --git a/examples/models/vlm/qwen3_vl/README.md b/examples/models/vlm/qwen3_vl/README.md index 38817cff96..1e40758829 100644 --- a/examples/models/vlm/qwen3_vl/README.md +++ b/examples/models/vlm/qwen3_vl/README.md @@ -127,10 +127,73 @@ field_map: Then, update the dataset path (`dataset.path=/path/to/energon/dataset`) in [peft_energon.sh](peft_energon.sh) and run the script. - ### Expected Training Dynamics We provide a [Weights & Biases report](https://api.wandb.ai/links/nvidia-nemo-fw-public/lczz4ixx) for the expected loss curves and grad norms. +## Dataset with Multiple Images + +Below is the example + +1. Download the LLavA-Pretrain dataset from Hugging Face and unzip the images folder (NOTE: 79GB of disk space required): + + ``` + pip install -U "huggingface_hub[cli]" + huggingface-cli download TIGER-Lab/Mantis-Instruct \ + --include "llava_665k_multi/*" \ + --repo-type dataset \ + --local-dir /path/to/Mantis-Instruct-LLaVA + ``` + +3. Run the following script to convert the data to webdataset format: + + ``` + cd + python examples/models/vlm/qwen3_vl/prepare_mantis_energon.py \ + --source-dir/path/to/Mantis-Instruct-LLaVA \ + --output-dir /path/to/Mantis-Instruct-LLaVA/wds \ + --max-samples-per-tar 10000 + ``` + +4. Run the following command to convert to megatron-energon format: + + ``` + cd /path/to/Mantis-Instruct-LLaVA/wds + energon prepare ./ + ``` + + select the following values for the presented options: + + ``` + > Please enter a desired train/val/test split like "0.5, 0.2, 0.3" or "8,1,1": 9,1,0 + > Do you want to create a dataset.yaml interactively? [Y/n]: Y + > Please enter a number to choose a class: 9 (VQASample) + > Do you want to set a simple field_map[Y] (or write your own sample_loader [n])? [Y/n]: Y + > Please enter a webdataset field name for 'image' (): jpg + > Please enter a webdataset field name for 'context' (): json[0][value] + > Please enter a webdataset field name for 'answers' (typing.Optional[typing.List[str]], default: None): json[1][value] + > Please enter a webdataset field name for 'answer_weights' (typing.Optional[torch.Tensor], default: None): + ``` + +5. Change the file `.nv-meta/dataset.yaml` to the following: + + ```yaml + __module__: megatron.bridge.recipes.qwen_vl.data.energon.task_encoder + __class__: ChatMLWebdataset + field_map: + imgs: jpgs + conversation: json + ``` + +Then, update the dataset path (`dataset.path=/path/to/energon/dataset`) in [peft_energon.sh](peft_energon.sh) and run the script. + + + +#### Controlling visual tokens computation budget +Three independent CLI-overridable controls bound a sample's GPU cost. They compose: +- **`dataset.min_pixels` / `dataset.max_pixels`** — image/frame resolutions lower and upper bound (defaults `200704` / `1003520`). +- **`dataset.max_num_images` / `dataset.max_num_frames`** - limit count of images/frames (defaults `10` / `60`). Too many images → sample is dropped. Too many frames → frame list truncated. +- **`dataset.max_visual_tokens`** — limit total visual tokens across all images and frames in a sample, computed post-rescaling as `prod(T,H,W) // merge_size²` (default `16384`; set to `None` to disable). Catches cases the other two miss (few images at high resolution, or many at low resolution). Exceeding samples are dropped. + ## Evaluation Coming soon. diff --git a/examples/models/vlm/qwen3_vl/prepare_mantis_energon.py b/examples/models/vlm/qwen3_vl/prepare_mantis_energon.py new file mode 100644 index 0000000000..7fdd0f3287 --- /dev/null +++ b/examples/models/vlm/qwen3_vl/prepare_mantis_energon.py @@ -0,0 +1,155 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert TIGER-Lab/Mantis-Instruct to WebDataset shards for Energon-based QwenVL training. + +All subsets are merged into a single train split. + +Usage:: + + python examples/models/vlm/qwen3_vl/prepare_mantis_energon.py \\ + --source-dir /path/to/mantis_instruct/Mantis-Instruct \\ + --output-dir /path/to/mantis_energon \\ + --max-samples-per-tar 1000 + +Source layout:: + + Mantis-Instruct/ + {subset}/ + train-*.parquet # images: [{'bytes': None, 'path': 'subdir/foo.png'}, ...] + train_images.zip # image files (may include subdirectories) + ... + +Output layout:: + + / + shard-000000.tar + shard-000001.tar + ... + +Each shard entry: ``{subset}_{id}.jpgs`` (pickled list of raw image bytes) + ``{subset}_{id}.json``. +Images from each subset are extracted from the corresponding zip on first run (skipped on +subsequent runs via a ``.extracted_`` marker file). +""" + +import json +import logging +import os +import pickle +import zipfile +from argparse import ArgumentParser + +import pandas as pd +import webdataset as wds +from tqdm import tqdm + + +logger = logging.getLogger(__name__) + +_MARKER_PREFIX = ".extracted_" + + +def _ensure_extracted(subset_dir: str, zip_name: str) -> None: + zip_path = os.path.join(subset_dir, zip_name) + if not os.path.exists(zip_path): + return + marker = os.path.join(subset_dir, _MARKER_PREFIX + zip_name) + if os.path.exists(marker): + return + logger.info("Extracting %s ...", zip_path) + with zipfile.ZipFile(zip_path) as zf: + for member in tqdm(zf.namelist(), desc=f"extract {zip_name}", unit="file"): + dest = os.path.join(subset_dir, member) + if not os.path.exists(dest): + zf.extract(member, subset_dir) + open(marker, "w").close() + + +def convert(source_dir: str, output_dir: str, max_count: int) -> None: + """Convert Mantis-Instruct subsets under ``source_dir`` into WebDataset shards under ``output_dir``. + + Args: + source_dir: Path to the ``Mantis-Instruct/`` directory containing per-subset folders. + output_dir: Destination directory for the generated ``shard-*.tar`` files. + max_count: Maximum number of samples per output shard. + """ + subsets = sorted(d for d in os.listdir(source_dir) if os.path.isdir(os.path.join(source_dir, d))) + if not subsets: + raise FileNotFoundError(f"No subset directories found in {source_dir}") + + os.makedirs(output_dir, exist_ok=True) + shard_pattern = os.path.join(output_dir, "shard-%06d.tar") + total_written = 0 + total_skipped = 0 + + with wds.ShardWriter(shard_pattern, maxcount=max_count) as sink: + for subset in subsets: + subset_dir = os.path.join(source_dir, subset) + _ensure_extracted(subset_dir, "train_images.zip") + + parquet_files = sorted( + f for f in os.listdir(subset_dir) + if f.startswith("train-") and f.endswith(".parquet") + ) + if not parquet_files: + logger.debug("No train parquets in subset %s, skipping", subset) + continue + + for pf in parquet_files: + df = pd.read_parquet(os.path.join(subset_dir, pf)) + pf_stem = pf.replace(".parquet", "") + for idx, (_, row) in enumerate(tqdm(df.iterrows(), total=len(df), desc=f"{subset}/{pf}", unit="sample")): + if row["images"] is None or len(row["images"]) == 0: + total_skipped += 1 + continue + + try: + imgs = [ + open(os.path.join(subset_dir, ref["path"]), "rb").read() + for ref in row["images"] + ] + except Exception as exc: + logger.warning("Skipping %s/%s idx=%d: %s", subset, pf, idx, exc) + total_skipped += 1 + continue + + conversation = [dict(t) for t in row["conversation"]] + n_placeholders = sum(t["content"].count("") for t in conversation) + if n_placeholders != len(imgs): + logger.warning( + "Skipping %s/%s idx=%d: %d placeholders but %d images", + subset, pf, idx, n_placeholders, len(imgs), + ) + total_skipped += 1 + continue + + sink.write({ + "__key__": f"{subset}__{pf_stem}__{idx:06d}", + "jpgs": pickle.dumps(imgs), + "json": json.dumps(conversation).encode(), + }) + total_written += 1 + + logger.info("Wrote %d samples (%d skipped) → %s", total_written, total_skipped, output_dir) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") + p = ArgumentParser(description="Convert Mantis-Instruct to WebDataset Energon format.") + p.add_argument("--source-dir", required=True, help="Path to Mantis-Instruct/ directory") + p.add_argument("--output-dir", required=True, help="Output directory for Energon shards") + p.add_argument("--max-samples-per-tar", type=int, default=1000, metavar="N") + args = p.parse_args() + convert(args.source_dir, args.output_dir, args.max_samples_per_tar) + print(f"Done. Set dataset.path={args.output_dir}") diff --git a/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py b/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py index 50178a0749..f2cb0d5556 100644 --- a/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py +++ b/src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py @@ -21,7 +21,7 @@ import numpy as np import torch -from megatron.energon import Batch, DefaultTaskEncoder +from megatron.energon import Batch, DefaultTaskEncoder, SkipSample from transformers import BatchEncoding from megatron.bridge.data.energon.task_encoder_utils import ( @@ -56,7 +56,10 @@ def process_vision( image_grid_thw = None if videos is not None: - videos_inputs = processor(images=None, text="", videos=videos, return_tensors="pt") + # Pre-decoded frames from WDS are already at the desired sampling rate. + # do_sample_frames=False prevents the processor from re-sampling them under + # a spurious 24 fps assumption, which would reduce most clips to T=2. + videos_inputs = processor.video_processor(videos=videos, return_tensors="pt", do_sample_frames=False) video_grid_thw = videos_inputs.get("video_grid_thw", None) else: videos_inputs = {} @@ -168,6 +171,9 @@ def __init__( max_padding_length: int = 4096, min_pixels: int = 200704, max_pixels: int = 1003520, + max_num_images: int | None = 10, + max_num_frames: int | None = 60, + max_visual_tokens: int | None = 16384, ): super().__init__() @@ -176,6 +182,9 @@ def __init__( self.seq_length = max_padding_length self.min_pixels = min_pixels self.max_pixels = max_pixels + self.max_num_images = max_num_images + self.max_num_frames = max_num_frames + self.max_visual_tokens = max_visual_tokens self.temporal_patch_size = temporal_patch_size self.merge_size = spatial_merge_size @@ -202,6 +211,32 @@ def encode_sample(self, sample: ChatMLSample): videos_for_processing = ( _videos_to_pil(sample.videos) if sample.videos is not None and len(sample.videos) > 0 else None ) + + if self.max_num_images is not None and imgs_for_processing is not None: + if len(imgs_for_processing) > self.max_num_images: + logging.warning( + "Skipping sample %s: %d images exceeds max_num_images=%d", + sample.__key__, + len(imgs_for_processing), + self.max_num_images, + ) + raise SkipSample() + + if self.max_num_frames is not None and videos_for_processing is not None: + clipped = [] + for v in videos_for_processing: + if len(v) > self.max_num_frames: + logging.warning( + "Truncating %d frames to max_num_frames=%d for sample %s", + len(v), + self.max_num_frames, + sample.__key__, + ) + clipped.append(v[: self.max_num_frames]) + else: + clipped.append(v) + videos_for_processing = clipped + processed_vision = process_vision( self.image_processor, imgs_for_processing, @@ -214,6 +249,20 @@ def encode_sample(self, sample: ChatMLSample): flattened_imgs = processed_vision["image_inputs"] flattened_videos = processed_vision["video_inputs"] + merge_length = self.merge_size**2 + image_tokens = int(image_thw_grids.prod(-1).sum()) // merge_length if image_thw_grids is not None else 0 + video_tokens = int(video_thw_grids.prod(-1).sum()) // merge_length if video_thw_grids is not None else 0 + total_visual_tokens = image_tokens + video_tokens + if self.max_visual_tokens is not None: + if total_visual_tokens > self.max_visual_tokens: + logging.warning( + "Skipping sample %s: %d visual tokens exceeds max_visual_tokens=%d", + sample.__key__, + total_visual_tokens, + self.max_visual_tokens, + ) + raise SkipSample() + # Normalize conversation to [{"role": ..., "content": ...}, ...] conversation = cook_chatml_sample(sample.conversation) @@ -287,7 +336,22 @@ def encode_sample(self, sample: ChatMLSample): target_length = input_ids.shape[0] if target_length > self.seq_len: - logging.warning(f"Long sequence with length {target_length} found, dropped...") + if total_visual_tokens > self.seq_len: + logging.warning( + "Sample %s: target_length=%d with visual_tokens=%d exceeds seq_len=%d; " + "truncation would corrupt visual tokens, dropping sample.", + sample.__key__, + target_length, + total_visual_tokens, + self.seq_len, + ) + raise SkipSample() + logging.warning( + "Sample %s: target_length=%d exceeds seq_len=%d; text will be truncated.", + sample.__key__, + target_length, + self.seq_len, + ) final_input_ids = np.zeros(target_length, dtype=input_ids.dtype) final_input_masks = final_input_ids.copy() @@ -435,7 +499,9 @@ def encode_batch(self, batch: QwenVLTaskBatch) -> dict: raw["visual_inputs"] = Qwen2_5_VLVisualInputs( pixel_values=batch.pixel_values, + pixel_values_videos=batch.pixel_values_videos, image_grid_thw=batch.image_grid_thw, + video_grid_thw=batch.video_grid_thw, ) return raw diff --git a/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py b/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py index 6f12c3080d..45c0a78dda 100644 --- a/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py +++ b/src/megatron/bridge/recipes/qwen_vl/qwen3_vl.py @@ -20,6 +20,7 @@ from __future__ import annotations import os +from dataclasses import dataclass from typing import Optional, Union import torch @@ -28,6 +29,7 @@ from megatron.bridge import AutoBridge from megatron.bridge.data.energon.energon_provider import EnergonProvider +from megatron.bridge.data.utils import DatasetBuildContext from megatron.bridge.data.vlm_datasets import MockVLMConversationProvider from megatron.bridge.peft.base import PEFT from megatron.bridge.recipes.common import _peft_common_vlm, _sft_common_vlm @@ -265,10 +267,36 @@ def qwen3_vl_235b_a22b_pretrain_mock_config(**user_kwargs: Unpack[Qwen3VLCommonK return _qwen3_vl_common(**combined_kwargs) +@dataclass(kw_only=True) +class QwenVLEnergonProvider(EnergonProvider): + """EnergonProvider subclass that exposes task-encoder knobs as CLI-overridable fields. + + The task encoder is constructed eagerly (same as before), but build_datasets + syncs these fields onto it after CLI overrides have been applied. + """ + + min_pixels: int = 200704 + max_pixels: int = 1003520 + max_num_images: int | None = 10 + max_num_frames: int | None = 60 + max_visual_tokens: int | None = 16384 + + def build_datasets(self, context: DatasetBuildContext): + if self.task_encoder is not None: + self.task_encoder.seq_len = self.seq_length + self.task_encoder.seq_length = self.seq_length + self.task_encoder.min_pixels = self.min_pixels + self.task_encoder.max_pixels = self.max_pixels + self.task_encoder.max_num_images = self.max_num_images + self.task_encoder.max_num_frames = self.max_num_frames + self.task_encoder.max_visual_tokens = self.max_visual_tokens + return super().build_datasets(context) + + def _make_energon_dataset( hf_path: str, seq_length: int, micro_batch_size: int, global_batch_size: int -) -> EnergonProvider: - """Create an EnergonProvider dataset config for Qwen3-VL recipes.""" +) -> QwenVLEnergonProvider: + """Create a QwenVLEnergonProvider dataset config for Qwen3-VL recipes.""" tokenizer = AutoTokenizer.from_pretrained(hf_path) # Use Qwen3VLProcessor to match the HF flow (which uses AutoProcessor). # This processor accepts both images and videos kwargs. @@ -278,7 +306,7 @@ def _make_energon_dataset( image_processor=image_processor, max_padding_length=seq_length, ) - return EnergonProvider( + return QwenVLEnergonProvider( path="", # Must be set via CLI override: dataset.path= seq_length=seq_length, micro_batch_size=micro_batch_size, @@ -1140,5 +1168,7 @@ def qwen3_vl_8b_peft_energon_config(peft_scheme: str | PEFT = "lora") -> ConfigC """ cfg = qwen3_vl_8b_peft_config(peft_scheme=peft_scheme) hf_path = "Qwen/Qwen3-VL-8B-Instruct" - cfg.dataset = _make_energon_dataset(hf_path, 4096, cfg.train.micro_batch_size, cfg.train.global_batch_size) + cfg.dataset = _make_energon_dataset( + hf_path, cfg.model.seq_length, cfg.train.micro_batch_size, cfg.train.global_batch_size + ) return cfg diff --git a/src/megatron/bridge/training/utils/visual_inputs.py b/src/megatron/bridge/training/utils/visual_inputs.py index 76b008c996..aac2bf7a8f 100644 --- a/src/megatron/bridge/training/utils/visual_inputs.py +++ b/src/megatron/bridge/training/utils/visual_inputs.py @@ -62,9 +62,15 @@ class Qwen2_5_VLVisualInputs: # Image tensors, e.g., Qwen2.5-VL processor output. pixel_values: Optional[torch.Tensor] = None - # Per-image temporal/spatial grid metadata (T, H, W) for videos, Qwen2.5-VL. + # Video tensors, e.g., Qwen2.5-VL processor output. + pixel_values_videos: Optional[torch.Tensor] = None + + # Per-image (T, H, W) grid metadata. image_grid_thw: Optional[torch.Tensor] = None + # Per-video (T, H, W) grid metadata. + video_grid_thw: Optional[torch.Tensor] = None + def as_model_kwargs(self) -> dict[str, torch.Tensor]: """Return a mapping of non-None fields suitable for model forward kwargs.""" result: dict[str, torch.Tensor] = {} @@ -78,7 +84,9 @@ def normalized_for_model(self) -> dict[str, torch.Tensor]: """Return non-None fields with shapes normalized for model expectations. - pixel_values: [B, N, C, H, W] -> [B*N, C, H, W] + - pixel_values_videos: [B, N, C, H, W] -> [B*N, C, H, W] - image_grid_thw: [B, N, 3] -> [B*N, 3] + - video_grid_thw: [B, N, 3] -> [B*N, 3] """ kwargs = self.as_model_kwargs() @@ -87,10 +95,19 @@ def normalized_for_model(self) -> dict[str, torch.Tensor]: b, n, c, h, w = pixel_values.shape kwargs["pixel_values"] = pixel_values.view(b * n, c, h, w) + pixel_values_videos = kwargs.get("pixel_values_videos") + if isinstance(pixel_values_videos, torch.Tensor) and pixel_values_videos.dim() == 5: + b, n, c, h, w = pixel_values_videos.shape + kwargs["pixel_values_videos"] = pixel_values_videos.view(b * n, c, h, w) + image_grid_thw = kwargs.get("image_grid_thw") if isinstance(image_grid_thw, torch.Tensor) and image_grid_thw.dim() == 3: kwargs["image_grid_thw"] = image_grid_thw.view(-1, image_grid_thw.size(-1)) + video_grid_thw = kwargs.get("video_grid_thw") + if isinstance(video_grid_thw, torch.Tensor) and video_grid_thw.dim() == 3: + kwargs["video_grid_thw"] = video_grid_thw.view(-1, video_grid_thw.size(-1)) + return kwargs diff --git a/tests/unit_tests/recipes/qwen_vl/data/energon/test_task_encoder.py b/tests/unit_tests/recipes/qwen_vl/data/energon/test_task_encoder.py index bb4e83abbd..fc3e5cd143 100644 --- a/tests/unit_tests/recipes/qwen_vl/data/energon/test_task_encoder.py +++ b/tests/unit_tests/recipes/qwen_vl/data/energon/test_task_encoder.py @@ -21,6 +21,7 @@ import numpy as np import pytest import torch +from megatron.energon import SkipSample from PIL import Image from megatron.bridge.recipes.qwen_vl.data.energon.task_encoder import ( @@ -164,7 +165,7 @@ def setUp(self): self.encoder = QwenVLTaskEncoder( tokenizer=self.tokenizer, image_processor=self.image_processor, - max_padding_length=128, + max_padding_length=512, patch_size=14, spatial_merge_size=2, ) @@ -361,5 +362,212 @@ def test_encode_batch(self): self.assertNotIn("__subflavors__", encoded_dict) +class TestQwenVLTaskEncoderLimits(unittest.TestCase): + """Tests for the per-sample budget limits added to QwenVLTaskEncoder.""" + + def setUp(self): + self.tokenizer = MagicMock() + self.tokenizer.pad_token_id = 0 + self.tokenizer.eos_token_id = 1 + self.tokenizer.image_token_id = 151655 + self.tokenizer.video_token_id = 151656 + self.tokenizer.convert_tokens_to_ids.side_effect = lambda x: { + "": 151655, + "