Changes from all commits (19 commits)
60294bf · workable code · May 5, 2026
33ddcd9 · adding inference code for Qwen3 for multi-images and video · May 6, 2026
ca928cb · resolve conflict · May 6, 2026
9d1fed0 · style: fix ruff-format line-length violations flagged by CI · May 6, 2026
67e6f24 · style: apply ruff-format reformats and remove debug prints · May 6, 2026
c1555c0 · style: fix remaining ruff-format violations in task_encoder.py · May 6, 2026
78aca60 · Merge remote-tracking branch 'origin/main' into huvu/vlm_energon · May 6, 2026
859eb86 · [recipe] test: Update QwenVL task encoder test mocks for torch tensors · huvunvidia · May 7, 2026
4623a8f · Merge remote-tracking branch 'origin/main' into huvu/vlm_energon · huvunvidia · May 7, 2026
bca268d · [docs, recipe] docs: Document Qwen3-VL visual token budget controls · May 7, 2026
3b48e4e · [recipe] fix: Add missing blank line before module-level comment block · May 7, 2026
8427477 · [recipe] test: Add unit tests for QwenVL task encoder limits, provide… · May 7, 2026
21025ca · Merge remote-tracking branch 'origin/main' into huvu/vlm_energon · May 7, 2026
05bbdd1 · [ckpt] test: Add unit tests for VLM generate utils multi-image and vi… · huvunvidia · May 8, 2026
496899d · Merge remote-tracking branch 'origin/main' into huvu/vlm_energon · May 8, 2026
a69be19 · Merge remote-tracking branch 'origin/main' into huvu/vlm_energon · May 8, 2026
c704146 · fix lint · May 8, 2026
1fde946 · [recipe, data, docs] fix: address review feedback on Qwen-VL energon … · May 12, 2026
2b7a0d0 · Merge remote-tracking branch 'origin/main' into huvu/vlm_energon · May 12, 2026
79 changes: 61 additions & 18 deletions examples/conversion/hf_to_megatron_generate_vlm.py
@@ -38,6 +38,8 @@
pad_input_ids_to_tp_multiple,
patch_kimi_vision_processor,
process_image_inputs,
process_multi_image_inputs,
process_video_inputs,
to_cuda,
)

@@ -64,6 +66,8 @@ def __init__(
image_grid_thw=None,
image_sizes=None,
mm_token_type_ids=None,
pixel_values_videos=None,
video_grid_thw=None,
image_position_ids=None,
):
self.batch = dict(
@@ -79,6 +83,10 @@ def __init__(
self.batch["image_sizes"] = image_sizes
if mm_token_type_ids is not None:
self.batch["mm_token_type_ids"] = mm_token_type_ids
if pixel_values_videos is not None:
self.batch["pixel_values_videos"] = pixel_values_videos
if video_grid_thw is not None:
self.batch["video_grid_thw"] = video_grid_thw
if image_position_ids is not None:
self.batch["image_position_ids"] = image_position_ids
self._yielded = False
@@ -101,7 +109,15 @@ def vlm_forward_step(data_iterator, model, **kwargs) -> torch.Tensor:
"position_ids": batch["position_ids"],
"attention_mask": batch.get("attention_mask"),
}
for key in ("pixel_values", "image_grid_thw", "image_sizes", "mm_token_type_ids", "image_position_ids"):
for key in (
"pixel_values",
"image_grid_thw",
"image_sizes",
"mm_token_type_ids",
"pixel_values_videos",
"video_grid_thw",
"image_position_ids",
):
if key in batch:
forward_args[key] = batch[key]

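# A minimal sketch of the one-shot iterator contract consumed by vlm_forward_step;
# the real InferenceIterator is defined earlier in this file, and all shapes and
# values below are hypothetical.
import torch

class OneShotIterator:
    """Yield a batch dict exactly once, mirroring the _yielded flag above."""

    def __init__(self, **tensors):
        # Optional multimodal keys are added only when present, so the forward
        # step can branch on key membership instead of None checks.
        self.batch = {k: v for k, v in tensors.items() if v is not None}
        self._yielded = False

    def __iter__(self):
        return self

    def __next__(self):
        if self._yielded:
            raise StopIteration
        self._yielded = True
        return self.batch

it = OneShotIterator(
    input_ids=torch.zeros(1, 8, dtype=torch.long),
    pixel_values_videos=torch.randn(16, 1176),  # hypothetical flattened patches
    video_grid_thw=torch.tensor([[2, 4, 4]]),
    pixel_values=None,  # video-only run: image keys stay out of the batch
)
batch = next(it)
assert "video_grid_thw" in batch and "pixel_values" not in batch
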
@@ -215,27 +231,41 @@ def _disable_mtp(m):
# ------------------------------------------------------------------
# Process inputs
# ------------------------------------------------------------------
pixel_values = image_grid_thw = image_sizes = mm_token_type_ids = image_position_ids = None
pixel_values_videos = video_grid_thw = None

if args.video_path:
input_ids_raw, pixel_values_videos, video_grid_thw = process_video_inputs(
processor, args.video_path, args.prompt, fps=args.video_fps
)
elif args.image_paths:
input_ids_raw, pixel_values, image_grid_thw = process_multi_image_inputs(
processor, args.image_paths, args.prompt
)
else:
(
input_ids_raw,
pixel_values,
image_grid_thw,
image_sizes,
mm_token_type_ids,
image_position_ids,
) = process_image_inputs(
processor,
args.image_path,
args.prompt,
is_gemma4=is_gemma4,
is_kimi=is_kimi,
image_token_id=image_token_id,
)

(
input_ids_raw,
pixel_values,
image_grid_thw,
image_sizes,
mm_token_type_ids,
image_position_ids,
) = process_image_inputs(
processor,
args.image_path,
args.prompt,
is_gemma4=is_gemma4,
is_kimi=is_kimi,
image_token_id=image_token_id,
)
input_ids_raw = input_ids_raw.cuda()
pixel_values = to_cuda(pixel_values)
image_grid_thw = to_cuda(image_grid_thw)
image_sizes = to_cuda(image_sizes)
mm_token_type_ids = to_cuda(mm_token_type_ids)
pixel_values_videos = to_cuda(pixel_values_videos)
video_grid_thw = to_cuda(video_grid_thw)
image_position_ids = to_cuda(image_position_ids)

# ------------------------------------------------------------------
@@ -270,7 +300,9 @@ def _disable_mtp(m):
image_grid_thw,
image_sizes,
mm_ids_padded,
image_position_ids,
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
image_position_ids=image_position_ids,
)

output = fwd_bwd_function(
@@ -343,7 +375,18 @@ def _disable_mtp(m):
"--pp_layout", type=str, default=None, help="Pipeline model parallel layout (e.g. 'Et*15|t*15|t*16|t*15L')"
)
parser.add_argument("--megatron_model_path", type=str, default=None, help="Path to Megatron model checkpoint")
parser.add_argument("--image_path", type=str, default=None, help="Path or URL to image (optional).")
parser.add_argument("--image_path", type=str, default=None, help="Path or URL to a single image (optional).")
parser.add_argument(
"--image_paths",
type=str,
nargs="+",
default=None,
help="Paths to N image files in order (multi-image; Qwen-family only).",
)
parser.add_argument("--video_path", type=str, default=None, help="Path to a video file (Qwen-family only).")
parser.add_argument(
"--video_fps", type=float, default=2.0, help="Frames per second to sample from the video (default: 2.0)."
)
parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code for HF model loading")
args = parser.parse_args()

55 changes: 55 additions & 0 deletions examples/conversion/vlm_generate_utils.py
@@ -132,6 +132,61 @@ def to_cuda(x):
return x.cuda()


def process_multi_image_inputs(processor, image_paths: list[str], prompt: str):
"""Process N ordered images + prompt into model inputs (Qwen-family).

Returns:
(input_ids, pixel_values, image_grid_thw)
"""
if not _HAS_QWEN_VL_UTILS:
raise ImportError("qwen-vl-utils required: pip install qwen-vl-utils")
pils = [load_image(p).convert("RGB") for p in image_paths]
messages = [
{
"role": "user",
"content": [{"type": "image", "image": p} for p in pils] + [{"type": "text", "text": prompt}],
}
]
image_inputs, video_inputs = process_vision_info(messages)
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
return inputs.input_ids, inputs.get("pixel_values"), inputs.get("image_grid_thw")


def process_video_inputs(processor, video_path: str, prompt: str, *, fps: float = 2.0):
"""Process a video + prompt into model inputs (Qwen-family).

Frame decoding mirrors the Qwen3-VL training pipeline: fetch_video decodes at
``fps``, then video_processor is called with do_sample_frames=False to use the
pre-decoded frames as-is.

Returns:
(input_ids, pixel_values_videos, video_grid_thw)
"""
if not _HAS_QWEN_VL_UTILS:
raise ImportError("qwen-vl-utils required: pip install qwen-vl-utils")
from qwen_vl_utils import fetch_video

frames = fetch_video({"video": video_path, "fps": fps})
messages = [
{
"role": "user",
"content": [{"type": "video"}, {"type": "text", "text": prompt}],
}
]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
text_inputs = processor(text=[text], padding=True, return_tensors="pt")
video_proc = processor.video_processor(videos=[frames], return_tensors="pt", do_sample_frames=False)
# processor(text=...) without videos produces a single <|video_pad|> placeholder (id 151656).
# Pre-expand to match actual vision feature count so PP send/recv shapes are correct.
input_ids = pre_expand_image_tokens(
text_inputs["input_ids"],
video_proc["video_grid_thw"],
image_token_id=151656, # <|video_pad|> for Qwen-VL family
)
return input_ids, video_proc.get("pixel_values_videos"), video_proc.get("video_grid_thw")


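# A minimal sketch of the placeholder pre-expansion idea used above; the real
# pre_expand_image_tokens helper lives elsewhere in this file and may differ.
# merge_size=2 is the Qwen-VL default, assumed here; token ids are illustrative.
import torch

def expand_video_placeholder(input_ids, video_grid_thw, token_id, merge_size=2):
    """Replace the single placeholder with prod(T, H, W) // merge_size**2 copies."""
    n_tokens = int(video_grid_thw.prod(dim=-1).sum()) // merge_size**2
    out = []
    for tok in input_ids[0].tolist():
        out.extend([tok] * n_tokens if tok == token_id else [tok])
    return torch.tensor([out], dtype=input_ids.dtype)

ids = torch.tensor([[9906, 151656, 11]])  # text, <|video_pad|>, text
grid = torch.tensor([[2, 8, 8]])          # T=2, H=8, W=8 patch grid
expanded = expand_video_placeholder(ids, grid, token_id=151656)
assert expanded.shape[1] == 2 + (2 * 8 * 8) // 4  # 2 text tokens + 32 video tokens
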
def process_image_inputs(
processor,
image_path: Optional[str],
5 changes: 5 additions & 0 deletions examples/models/vlm/qwen3_vl/README.md
@@ -127,6 +127,11 @@

Then, update the dataset path (`dataset.path=/path/to/energon/dataset`) in [peft_energon.sh](peft_energon.sh) and run the script.

#### Controlling the visual token compute budget
Three independent, CLI-overridable controls bound a sample's GPU cost; they compose (see the sketch after this list):
- **`dataset.min_pixels` / `dataset.max_pixels`** — lower and upper bounds on image/frame resolution (defaults `200704` / `1003520`).
- **`dataset.max_num_images` / `dataset.max_num_frames`** — cap the number of images/frames per sample (defaults `10` / `60`). A sample with too many images is dropped; an over-long frame list is truncated.
- **`dataset.max_visual_tokens`** — cap the total visual tokens across all images and frames in a sample, computed post-rescaling as `prod(T,H,W) // merge_size²` (default `16384`; set to `None` to disable). This catches cases the other two miss (a few images at high resolution, or many at low resolution); samples that exceed it are dropped.
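
A minimal sketch of the accounting, with hypothetical grids (`merge_size=2` is the Qwen-VL default):

```python
def visual_tokens(grids, merge_size=2):
    """Sum prod(T, H, W) // merge_size**2 over all image/frame grids."""
    return sum((t * h * w) // merge_size**2 for t, h, w in grids)

image_grids = [(1, 98, 146)]  # one high-resolution image
video_grids = [(30, 28, 28)]  # e.g. 60 frames merged to T=30 temporally
total = visual_tokens(image_grids) + visual_tokens(video_grids)
print(total)  # 3577 + 5880 = 9457 <= 16384, so the sample is kept
```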

### Expected Training Dynamics
We provide a [Weights & Biases report](https://api.wandb.ai/links/nvidia-nemo-fw-public/lczz4ixx) for the expected loss curves and grad norms.
72 changes: 69 additions & 3 deletions src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py
@@ -21,7 +21,7 @@

import numpy as np
import torch
from megatron.energon import Batch, DefaultTaskEncoder
from megatron.energon import Batch, DefaultTaskEncoder, SkipSample
from transformers import BatchEncoding

from megatron.bridge.data.energon.task_encoder_utils import (
@@ -56,7 +56,10 @@ def process_vision(
image_grid_thw = None

if videos is not None:
videos_inputs = processor(images=None, text="", videos=videos, return_tensors="pt")
# Pre-decoded frames from WDS are already at the desired sampling rate.
# do_sample_frames=False prevents the processor from re-sampling them under
# a spurious 24 fps assumption, which would reduce most clips to T=2.
videos_inputs = processor.video_processor(videos=videos, return_tensors="pt", do_sample_frames=False)
video_grid_thw = videos_inputs.get("video_grid_thw", None)
else:
videos_inputs = {}
@@ -168,6 +171,9 @@ def __init__(
max_padding_length: int = 4096,
min_pixels: int = 200704,
max_pixels: int = 1003520,
max_num_images: int | None = 10,
max_num_frames: int | None = 60,
max_visual_tokens: int | None = 16384,
):
super().__init__()

Expand All @@ -176,6 +182,9 @@ def __init__(
self.seq_length = max_padding_length
self.min_pixels = min_pixels
self.max_pixels = max_pixels
self.max_num_images = max_num_images
self.max_num_frames = max_num_frames
self.max_visual_tokens = max_visual_tokens

self.temporal_patch_size = temporal_patch_size
self.merge_size = spatial_merge_size
@@ -202,6 +211,32 @@ def encode_sample(self, sample: ChatMLSample):
videos_for_processing = (
_videos_to_pil(sample.videos) if sample.videos is not None and len(sample.videos) > 0 else None
)

if self.max_num_images is not None and imgs_for_processing is not None:
if len(imgs_for_processing) > self.max_num_images:
logging.warning(
"Skipping sample %s: %d images exceeds max_num_images=%d",
sample.__key__,
len(imgs_for_processing),
self.max_num_images,
)
raise SkipSample()

if self.max_num_frames is not None and videos_for_processing is not None:
clipped = []
for v in videos_for_processing:
if len(v) > self.max_num_frames:
logging.warning(
"Truncating %d frames to max_num_frames=%d for sample %s",
len(v),
self.max_num_frames,
sample.__key__,
)
clipped.append(v[: self.max_num_frames])
else:
clipped.append(v)
videos_for_processing = clipped

processed_vision = process_vision(
self.image_processor,
imgs_for_processing,
@@ -214,6 +249,20 @@
flattened_imgs = processed_vision["image_inputs"]
flattened_videos = processed_vision["video_inputs"]

merge_length = self.merge_size**2
image_tokens = int(image_thw_grids.prod(-1).sum()) // merge_length if image_thw_grids is not None else 0
video_tokens = int(video_thw_grids.prod(-1).sum()) // merge_length if video_thw_grids is not None else 0
total_visual_tokens = image_tokens + video_tokens
if self.max_visual_tokens is not None:
if total_visual_tokens > self.max_visual_tokens:
logging.warning(
"Skipping sample %s: %d visual tokens exceeds max_visual_tokens=%d",
sample.__key__,
total_visual_tokens,
self.max_visual_tokens,
)
raise SkipSample()

# Normalize conversation to [{"role": ..., "content": ...}, ...]
conversation = cook_chatml_sample(sample.conversation)

@@ -287,7 +336,22 @@ def encode_sample(self, sample: ChatMLSample):
target_length = input_ids.shape[0]

if target_length > self.seq_len:
logging.warning(f"Long sequence with length {target_length} found, dropped...")
if total_visual_tokens > self.seq_len:
logging.warning(
"Sample %s: target_length=%d with visual_tokens=%d exceeds seq_len=%d; "
"truncation would corrupt visual tokens, dropping sample.",
sample.__key__,
target_length,
total_visual_tokens,
self.seq_len,
)
raise SkipSample()
logging.warning(
"Sample %s: target_length=%d exceeds seq_len=%d; text will be truncated.",
sample.__key__,
target_length,
self.seq_len,
)
final_input_ids = np.zeros(target_length, dtype=input_ids.dtype)
final_input_masks = final_input_ids.copy()

@@ -435,7 +499,9 @@ def encode_batch(self, batch: QwenVLTaskBatch) -> dict:

raw["visual_inputs"] = Qwen2_5_VLVisualInputs(
pixel_values=batch.pixel_values,
pixel_values_videos=batch.pixel_values_videos,
image_grid_thw=batch.image_grid_thw,
video_grid_thw=batch.video_grid_thw,
)

return raw
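
# A hypothetical wiring of the new limits, for reference: the class name
# QwenVLTaskEncoder is an assumption inferred from QwenVLTaskBatch above, and
# the defaults mirror the __init__ signature in this diff.
from megatron.bridge.recipes.qwen_vl.data.energon.task_encoder import (
    QwenVLTaskEncoder,  # assumed name; not shown verbatim in this diff
)

encoder = QwenVLTaskEncoder(
    max_padding_length=4096,
    min_pixels=200704,        # lower bound on image/frame resolution
    max_pixels=1003520,       # upper bound on image/frame resolution
    max_num_images=10,        # more images per sample raises SkipSample
    max_num_frames=60,        # longer frame lists are truncated
    max_visual_tokens=16384,  # samples above this budget raise SkipSample
)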