[Feat] Add cached_decode dist_2d #883

Merged
helloyongyang merged 1 commit into main from tmp/dist_vae_decode_withflag on Feb 5, 2026
Conversation

@wangshankun
Collaborator

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @wangshankun, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the VAE's decoding capabilities by introducing a distributed 2D decoding mechanism. This allows the model to process large latent representations in parallel across multiple computational ranks, thereby improving performance and scalability, especially in distributed environments. The changes involve adding a new function to manage the distributed logic and updating an existing method to integrate this new functionality based on a parallel execution flag.

Highlights

  • New Distributed 2D Decoding Function: A new method, cached_decode_dist_2d_withflag, has been introduced. This function is responsible for handling the parallel decoding of 2D latent representations by calculating appropriate slices, extracting chunks, decoding them with padding, and then reconstructing the full image from gathered chunks across distributed processes.
  • Integration of Distributed Decoding: The existing cached_decode_withflag method has been modified to conditionally utilize the new distributed decoding logic. When the self.parallel flag is enabled, it orchestrates the distributed decoding process by calling cached_decode_dist_2d_withflag; otherwise, it defaults to the standard, non-distributed decoding.
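The halo-padding scheme in the first highlight can be sanity-checked with a small standalone sketch. The function and variable names below are illustrative, not code from the PR, though the branch logic mirrors the diff and the 8x ratio matches the `spatial_ratio = 8` it uses:

```python
# Standalone sanity check of the halo-padded slicing described above.

def latent_slice(rank, world_size, total, chunk, padding):
    """[start, end) latent slice for one rank, halo padding included."""
    if rank == 0:
        return 0, chunk + 2 * padding
    if rank == world_size - 1:
        return total - (chunk + 2 * padding), total
    return rank * chunk - padding, (rank + 1) * chunk + padding

ratio = 8                      # VAE latent -> pixel upsampling factor
total, world_size, padding = 64, 4, 4
chunk = total // world_size    # 16 latent rows per rank

covered = []
for r in range(world_size):
    s, e = latent_slice(r, world_size, total, chunk, padding)
    decoded = (e - s) * ratio  # pixels this rank decodes
    # Trim the halo: edge ranks keep their leading/trailing chunk,
    # interior ranks drop padding*ratio pixels from each side.
    if r == 0:
        keep = (0, chunk * ratio)
    elif r == world_size - 1:
        keep = (decoded - chunk * ratio, decoded)
    else:
        keep = (padding * ratio, decoded - padding * ratio)
    covered.append((s * ratio + keep[0], s * ratio + keep[1]))

# The trimmed chunks tile the image exactly: no seams, no overlap.
assert covered == [(r * chunk * ratio, (r + 1) * chunk * ratio)
                   for r in range(world_size)]
```

With 64 latent rows over 4 ranks, every rank decodes its 16-row chunk plus a 4-row halo on each interior edge, and the trimmed outputs reassemble the full 512-pixel height.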


Changelog
  • lightx2v/models/video_encoders/hf/wan/vae.py
    • Implemented cached_decode_dist_2d_withflag to perform distributed 2D latent decoding, including logic for chunking, padded decoding, and aggregating results from multiple processes.
    • Updated cached_decode_withflag to conditionally invoke the new distributed decoding function when self.parallel is active, enabling parallel processing of image decoding.
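The dispatch described in the second changelog entry can be sketched as follows; the class name and the stubbed decode bodies are hypothetical stand-ins, and only the control flow of `cached_decode_withflag` reflects the PR (the real methods live in lightx2v/models/video_encoders/hf/wan/vae.py):

```python
# Illustrative stand-in for the parallel/serial dispatch in the PR.
class VAEDecodeDispatch:
    def __init__(self, parallel: bool):
        self.parallel = parallel

    def cached_decode_dist_2d_withflag(self, zs, is_first, is_last):
        # Real version: slice the latent per rank, decode with halo
        # padding, trim, and gather chunks across processes.
        return "distributed"

    def _single_process_decode(self, zs, is_first, is_last):
        # Real version: the pre-existing non-distributed decode path.
        return "standard"

    def cached_decode_withflag(self, zs, is_first, is_last):
        # When self.parallel is set, route through the distributed path;
        # otherwise fall back to standard decoding.
        if self.parallel:
            return self.cached_decode_dist_2d_withflag(zs, is_first, is_last)
        return self._single_process_decode(zs, is_first, is_last)
```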
Activity
  • No specific activity (comments, reviews, or progress updates) has been recorded for this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review — /gemini review: Performs a code review for the current pull request in its current state.
  • Pull Request Summary — /gemini summary: Provides a summary of the current pull request in its current state.
  • Comment — @gemini-code-assist: Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help — /gemini help: Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

gemini-code-assist bot left a comment
Code Review

This pull request introduces a new distributed cached decoding function, cached_decode_dist_2d_withflag, and integrates it for parallel execution. The changes look good overall. My feedback focuses on improving code maintainability and robustness by addressing a hardcoded value and duplicated logic within the new function.

Comment on lines +1453 to +1501
# Calculate H dimension slice
if cur_rank_h == 0:
    h_start = 0
    h_end = chunk_h + 2 * padding_size
elif cur_rank_h == world_size_h - 1:
    h_start = total_h - (chunk_h + 2 * padding_size)
    h_end = total_h
else:
    h_start = cur_rank_h * chunk_h - padding_size
    h_end = (cur_rank_h + 1) * chunk_h + padding_size

# Calculate W dimension slice
if cur_rank_w == 0:
    w_start = 0
    w_end = chunk_w + 2 * padding_size
elif cur_rank_w == world_size_w - 1:
    w_start = total_w - (chunk_w + 2 * padding_size)
    w_end = total_w
else:
    w_start = cur_rank_w * chunk_w - padding_size
    w_end = (cur_rank_w + 1) * chunk_w + padding_size

# Extract the latent chunk for this process
zs_chunk = zs[:, :, h_start:h_end, w_start:w_end].contiguous()

# Decode the chunk
images_chunk = self.model.cached_decode_withflag(zs_chunk.unsqueeze(0), self.scale, is_first, is_last)

# Remove padding from decoded chunk
spatial_ratio = 8
if cur_rank_h == 0:
    decoded_h_start = 0
    decoded_h_end = chunk_h * spatial_ratio
elif cur_rank_h == world_size_h - 1:
    decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
    decoded_h_end = images_chunk.shape[3]
else:
    decoded_h_start = padding_size * spatial_ratio
    decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio

if cur_rank_w == 0:
    decoded_w_start = 0
    decoded_w_end = chunk_w * spatial_ratio
elif cur_rank_w == world_size_w - 1:
    decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
    decoded_w_end = images_chunk.shape[4]
else:
    decoded_w_start = padding_size * spatial_ratio
    decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio
medium

The logic for calculating slices is duplicated for both the height (h) and width (w) dimensions. This occurs for both the initial latent tensor slicing (lines 1453-1473) and for removing padding from the decoded chunk (lines 1483-1501).

This duplication makes the code harder to read and maintain. Consider refactoring this logic into a private helper method to reduce redundancy and improve clarity. For example, a function like _calculate_slice(total_dim, chunk_dim, padding, rank, world_size) could handle the latent slicing, and a similar one for the decoded chunk slicing.

images_chunk = self.model.cached_decode_withflag(zs_chunk.unsqueeze(0), self.scale, is_first, is_last)

# Remove padding from decoded chunk
spatial_ratio = 8

medium

The spatial_ratio is hardcoded to 8. This value is related to the model's architecture (specifically, the downsampling factor). Using a hardcoded value can lead to bugs if the model architecture changes. It would be more robust to use the value from the model configuration, like self.model.spatial_compression_ratio.

Suggested change
spatial_ratio = 8
spatial_ratio = self.model.spatial_compression_ratio

@helloyongyang helloyongyang merged commit d61489d into main Feb 5, 2026
2 checks passed
@helloyongyang helloyongyang deleted the tmp/dist_vae_decode_withflag branch February 5, 2026 12:15