Add Apple Silicon (MPS) support to inference scripts #7
Open
sencho17-gif wants to merge 9 commits into
Updated link in ReCo_ref entry for clarity.
Added links to inference results and additional context for ReCo_ref in the README.
Updated the release notes formatting and added checkboxes for updates.
Added a new update entry for the release of ReCo_Ref-related code.
Added download statistics and community appreciation for ReCo-Data.
Switches the ReCo inference entry points from CUDA to PyTorch's MPS backend so they can run on Apple Silicon (M-series Macs).

- inference_reco_single.py: enable PYTORCH_ENABLE_MPS_FALLBACK, drop CUDA-only seeding (torch.cuda.manual_seed / cudnn.deterministic) in favor of torch.mps.manual_seed, switch dtype from bfloat16 to float16 (bfloat16 has incomplete MPS support), and target device "mps" when constructing WanVideoPipeline.
- inference_reco_single_ref.py and tools/step_1_inference_vace_diffusers_unip.py: target "mps" instead of "cuda" when moving the pipeline to device.
- assets/replace_test_mps.txt: small single-prompt config used for smoke testing the replace task on MPS.

Note: complementary changes are required inside the bundled DiffSynth-Studio submodule (forcing SDPA over flash-attn, etc.); those are not part of this commit since DiffSynth-Studio is a vendored project tracked separately.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
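The device and dtype choices described in this commit message could be sketched as below. This is a minimal illustration under stated assumptions, not the code in the PR: `pick_backend` is a hypothetical helper name, the CPU default is an assumption, and the actual scripts target "mps" directly rather than detecting the backend at runtime.

```python
import os

# Opt in to CPU fallback for operators that lack an MPS kernel. Setting the
# variable before PyTorch is imported is the conservative order.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")


def pick_backend(cuda_available: bool, mps_available: bool) -> tuple[str, str]:
    """Hypothetical helper (not in the PR): return (device, dtype name)."""
    if cuda_available:
        return "cuda", "bfloat16"  # the existing CUDA default
    if mps_available:
        return "mps", "float16"  # bfloat16 coverage on MPS is incomplete
    return "cpu", "float32"  # assumption: a safe default when no GPU is present
```

With PyTorch installed, the two flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, the dtype name would resolve via `getattr(torch, name)`, and seeding on the MPS branch would use `torch.mps.manual_seed`.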
39cae6f to 2da8cc4
Summary
- Make the ReCo inference entry points (`inference_reco_single.py`, `inference_reco_single_ref.py`, `tools/step_1_inference_vace_diffusers_unip.py`) runnable on Apple Silicon by targeting PyTorch's `mps` backend instead of `cuda`.
- Switch the dtype from `bfloat16` to `float16` (bfloat16 still has incomplete coverage on MPS) and replace the CUDA-only seeding (`torch.cuda.manual_seed`, `torch.backends.cudnn.*`) with `torch.mps.manual_seed`.
- Set `PYTORCH_ENABLE_MPS_FALLBACK=1` so any operator that isn't yet MPS-native transparently falls back to CPU instead of erroring out.
- Add `assets/replace_test_mps.txt` containing a single short-video prompt to make it easy to verify the replace task end-to-end on a Mac.

Why
The current entry points hard-code `device="cuda"` and `torch_dtype=torch.bfloat16`, which prevents any inference on M-series Macs. None of the changes alter the CUDA code path semantically aside from the dtype default, and `device` is still passed through `args`, so CUDA users can keep using `device="cuda"` / `bfloat16` unchanged.

Test plan
- `python inference_reco_single.py --task_name replace --test_txt_file_name assets/replace_test_mps.txt --lora_ckpt all_ckpts/ReCo_ori_rank128-2025_m12_version.ckpt` on an M-series Mac with PyTorch >= 2.5 produces an output video without throwing a device/dtype error.
- CUDA path with `device="cuda"` (no changes to that branch).

Notes for reviewers
Complementary changes are required inside the bundled `DiffSynth-Studio/` submodule (forcing PyTorch SDPA over flash-attn, etc.). Those live in a separate change because `DiffSynth-Studio` is a vendored upstream project.

🤖 Generated with Claude Code
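For reviewers unfamiliar with the submodule change mentioned in the notes, the idea of forcing SDPA over flash-attn on non-CUDA devices can be sketched as follows. This is an assumption-laden illustration only: `select_attention_backend` is a hypothetical name, and the real dispatch logic inside DiffSynth-Studio may look quite different.

```python
import importlib.util


def select_attention_backend(device: str) -> str:
    """Hypothetical selector: return "flash_attn" or "sdpa" for a device string."""
    # find_spec returns None when the top-level package is not installed.
    flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
    # flash-attn has no MPS kernels, so it is only considered on CUDA devices.
    if device.startswith("cuda") and flash_attn_installed:
        return "flash_attn"
    # Everything else (including "mps" and "cpu") uses PyTorch's built-in
    # torch.nn.functional.scaled_dot_product_attention.
    return "sdpa"
```

On an M-series Mac, `select_attention_backend("mps")` returns `"sdpa"` regardless of whether flash-attn happens to be importable.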