
Fast Wan2.2 I2V Inference on 8x NVIDIA H100 GPUs by Morphic 🚀 | 56.20% Boost

The comparison was done by generating 1280*720 videos with 40 inference steps and 81 frames (see the benchmark graph).

Introduction

We have seen rapid development of open-source video generation DiT models such as Wan2.1 and, more recently, Wan2.2 with its MoE architecture.

It is exciting to see these open-source generation models catching up with, and in some cases beating, closed-source alternatives on benchmarks. However, the inference speed of these models is still a bottleneck for real-time applications and deployment.

In this article, we explore how to speed up Wan2.2 inference for the I2V task using the following techniques:

  1. Flash-Attention 3
  2. TensorFloat32 tensor cores for Matrix Multiplication
  3. Quantization: int8_weight_only
  4. Magcache
  5. Torch.compile

We have developed a complete suite to test all of these optimizations, which can be applied individually or in combination to achieve the fastest inference speed for Wan2.2 I2V tasks.

We run all our experiments on 8x NVIDIA H100 GPUs.

Baseline: Wan2.2 I2V on 8x H100 with Flash Attention 2

Since we are using H100 GPUs with 80GB of memory each, the high-noise and low-noise models do not both fit entirely on a single GPU. For this reason, we build our solutions on the original Wan2.2 GitHub repository instead of diffusers, using FSDP (Fully Sharded Data Parallel) to shard the models across GPUs. We also explore quantization as a way to bypass FSDP. In this article, we focus on optimizing the inference speed of Wan2.2 I2V entirely on top of the original repository. To get started, simply clone the repo and install the requirements:

git clone https://github.com/morphicfilms/wan2.2_optimizations.git
cd wan2.2_optimizations
pip install -r requirements.txt

Make sure to have the models downloaded:

pip install "huggingface_hub[cli]"
huggingface-cli download Wan-AI/Wan2.2-I2V-A14B --local-dir ./Wan2.2-I2V-A14B

To generate a video, simply use:

torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."

This is our baseline. Flash Attention 2 is enabled by default, and it takes 250.70s to generate one 1280x720 video of 81 frames over 40 inference steps.

⚙️Optimization Recipes:

Flash-Attention 3 ⚡

The Hopper architecture performs very well with Flash Attention 3. To set up FA3, clone the following repo outside our repository's directory and install it via:

git clone https://github.com/Dao-AILab/flash-attention.git
pip install wheel
cd flash-attention/hopper
python setup.py install
export PYTHONPATH=$PWD
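
Before launching the full 8-GPU run, a quick sanity check can confirm the FA3 build is importable. This is only a minimal sketch; the module name comes from the hopper build installed above:

# Minimal sanity check: confirm the FA3 hopper build is importable before a full run.
try:
    import flash_attn_interface  # provided by the hopper/ build installed above
    print("Flash-Attention 3 is available")
except ImportError:
    print("FA3 not found; check the installation and PYTHONPATH")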

After setting up FA3, which might take some time to build, you can rerun the same inference command:

torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."

The inference time now drops to 195.13s, a 22.16% boost from this change alone!

TensorFloat32 tensor core optimization:

PyTorch allows us to enable TF32 for matmuls and for cuDNN; both are disabled by default. This gives better performance for matrix multiplications and convolutions on torch.float32 tensors by rounding input data to 10 bits of mantissa. To use this optimization, simply pass the additional argument --tf32 True:

torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --image examples/i2v_input.JPG --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." --tf32 True

which actually does the following:

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
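
Note that on recent PyTorch versions the matmul part can equivalently be enabled with torch.set_float32_matmul_precision("high"), which allows TF32 to be used for float32 matrix multiplications.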

This brings our inference time down to 159.55 seconds, which is a 36.36% boost over the baseline.

Quantization

Quantization is yet another way to speed up inference, and it also lets both the high-noise and low-noise models fit on each H100 GPU, bypassing FSDP entirely. We use int8_weight_only for our quantization config. ❗Note: since the models' weights have been reduced to int8, applying the TensorFloat32 optimization on top of quantization shows no speed benefit, as shown in the graph above. The internal matmuls now run on the int8 path instead of TF32, so there is nothing left for the TF32 tensor cores to accelerate.

If you are not familiar with torchao quantization, you can refer to this documentation. Here, we simply install the latest torchao, which is capable of quantizing the Wan2.2 low-noise and high-noise models:

pip install -U torchao
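
Under the hood, the --quantize flag applies int8 weight-only quantization roughly along the following lines. This is a minimal sketch using the torchao API; the toy module below is a placeholder for a loaded Wan2.2 low/high-noise DiT, not the repository's actual code:

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

# A minimal sketch: int8 weight-only quantization with torchao, applied to a toy
# module standing in for a Wan2.2 low/high-noise DiT block (placeholder only).
model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))
model = model.to(torch.bfloat16).cuda()
quantize_(model, int8_weight_only())  # swaps Linear weights to int8 tensors in place

x = torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = model(x)  # forward pass now uses the quantized weights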

To use the quantization, simply pass in --quantize True:

torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --ulysses_size 8 --t5_fsdp --dit_fsdp --use_optimized_loading --prompt "some prompt." --save_file "output.mp4"  --quantize True

This does give a speed benefit when used with FA3, though not as large as TF32+FA3: it brings inference time down to 170.24 seconds, a 32.09% boost over the baseline.

Magcache

This low-loss caching method exploits an important property of the magnitude ratio of successive residual outputs: the ratio decreases monotonically and steadily over most timesteps, then drops rapidly in the last several steps. MagCache has been shown to outperform TeaCache in both speed and video quality. The original implementation only supports inference on 1x H100, and the same is true of TeaCache. We build on the same concept and scale it to 8x H100s to push the limits, leveraging sequence parallelism alongside it. We found that a setting of E012K2R20 (error threshold = 0.12, K = 2 and retention ratio = 0.2) works well at maintaining video quality while still providing a speed benefit.
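
For intuition, the skip rule behind these parameters looks roughly like the sketch below. This is illustrative only, not the repository's exact code; mag_ratio stands for the calibrated magnitude ratio of successive residual outputs at the current step, and thresh, K, retention_ratio correspond to the --magcache_thresh, --magcache_K and --retention_ratio flags:

# Rough sketch of the MagCache skip decision (illustrative, not the repo's exact code).
def magcache_should_skip(step, num_steps, mag_ratio, accumulated_err, consecutive_skips,
                         thresh=0.12, K=2, retention_ratio=0.2):
    if step < int(retention_ratio * num_steps):
        return False, 0.0, 0                      # always compute the early (retained) steps
    err = accumulated_err + abs(1.0 - mag_ratio)  # accumulate deviation from "no change"
    if err <= thresh and consecutive_skips < K:
        return True, err, consecutive_skips + 1   # skip: reuse the cached residual output
    return False, 0.0, 0                          # compute this step and reset the counters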

To use magcache, try the following:

torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --ulysses_size 8 --t5_fsdp --dit_fsdp --use_optimized_loading --prompt "some prompt." --save_file "output.mp4" --use_magcache --magcache_K 2 --magcache_thresh 0.12 --retention_ratio 0.2

This reduces inference time to 157.10 seconds, which is nearly the same as TF32+FA3. But if TF32 is also applied by passing the additional argument:

torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --ulysses_size 8 --t5_fsdp --dit_fsdp --use_optimized_loading --prompt "some prompt." --save_file "output.mp4" --use_magcache --magcache_K 2 --magcache_thresh 0.12 --retention_ratio 0.2 --tf32 True

Now we can generate one 1280x720 video of 81 frames in 40 inference steps in 150.45 seconds, which is a 39.99% boost. Further speedups with other MagCache settings such as E012K2R10 and E024K2R10 are shown in the graph above. The higher the error threshold and K value (skip steps), the more the video quality suffers; E012K2R20 is a good pick for preserving quality.

Torch.compile

Last but not least, we compile both the low-noise and high-noise models with max-autotune to push the boundaries further. torch.compile with mode="max-autotune-no-cudagraphs" or mode="max-autotune" helps achieve the best performance by generating and selecting the fastest kernels for model inference. If you are not familiar with torch.compile, refer to the official tutorial.

To compile both models, pass an extra argument --compile True:

torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --ulysses_size 8 --t5_fsdp --dit_fsdp --use_optimized_loading --prompt "some prompt." --save_file "output.mp4"  --compile True  

We need to warm up the model first to measure the speedup correctly, since the actual compilation happens during the first inference pass. We can also pass custom modes using the --compile_mode argument. Note that max-autotune takes a long time during warmup because it searches for the most optimized kernels and enables CUDA graphs by default on GPU. More information can be found here.

❗Note that due to compile-unfriendly distributed operations and dynamic slicing in the rope_apply function in the code, we set fullgraph=False so that both models compile without errors.
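
Conceptually, the --compile flag amounts to something like the sketch below. The variable names and modules are placeholders, not the repository's actual classes:

import torch
import torch.nn as nn

# Illustrative sketch of compiling both models (placeholder modules, not the repo's code).
# fullgraph=False tolerates graph breaks from distributed ops and the dynamic slicing
# in rope_apply mentioned above.
low_noise_model = nn.Linear(64, 64).cuda()   # stand-in for the low-noise DiT
high_noise_model = nn.Linear(64, 64).cuda()  # stand-in for the high-noise DiT

compile_mode = "max-autotune"  # or "max-autotune-no-cudagraphs" via --compile_mode
low_noise_model = torch.compile(low_noise_model, mode=compile_mode, fullgraph=False)
high_noise_model = torch.compile(high_noise_model, mode=compile_mode, fullgraph=False)

# Warm-up pass: compilation happens on the first call, so exclude it from timing.
x = torch.randn(1, 64, device="cuda")
with torch.no_grad():
    low_noise_model(x)
    high_noise_model(x)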

Benchmarks with Torch Compile 🚀

FA3+Torch Compile:

torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --ulysses_size 8 --t5_fsdp --dit_fsdp --use_optimized_loading --prompt "some prompt." --save_file "output.mp4"  --compile True  

Inference Time: 142.87 seconds

FA3+Quantization+Torch Compile:

torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --ulysses_size 8 --t5_fsdp --dit_fsdp --use_optimized_loading --prompt "some prompt." --save_file "output.mp4"  --quantize True --compile True

Inference Time: 142.40 seconds

FA3+TF32+Torch Compile:

torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --ulysses_size 8 --t5_fsdp --dit_fsdp --use_optimized_loading --prompt "some prompt." --save_file "output.mp4"  --tf32 True --compile True

Inference Time: 142.73 seconds

🔥🔥 Ultimate Super Fast Recipe: FA3+ TF32 + Magcache E012K2R20 + Torch Compile

Without compromising video quality, you can try the following with the max-autotune-no-cudagraphs compile mode:

torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 1280*720 --ckpt_dir ./Wan2.2-I2V-A14B --ulysses_size 8 --t5_fsdp --dit_fsdp --use_optimized_loading --prompt "some prompt." --save_file "output.mp4"  --tf32 True --use_magcache --magcache_K 2 --magcache_thresh 0.12 --retention_ratio 0.2 --compile True

Inference Time: 109.81s, a 56.20% boost from the baseline 🚀 You can always tweak the MagCache parameters to speed things up even further, but artifacts will start to appear in the video. For example, E024K2R10 gives an inference time of 98.87 seconds.

Conclusion

In this article we showcased powerful optimization techniques to reduce the latency of Wan2.2, a high-quality open-source model. This setup allows production environments to run models like these faster and more cost-efficiently, without visible loss of video quality.