Skip to content

Add Apple Silicon (MPS) support to inference scripts#7

Open
sencho17-gif wants to merge 9 commits into
HiDream-ai:mainfrom
sencho17-gif:apple-silicon-mps
Open

Add Apple Silicon (MPS) support to inference scripts#7
sencho17-gif wants to merge 9 commits into
HiDream-ai:mainfrom
sencho17-gif:apple-silicon-mps

Conversation

@sencho17-gif

Copy link
Copy Markdown

Summary

  • Make the ReCo inference entry points (inference_reco_single.py, inference_reco_single_ref.py, tools/step_1_inference_vace_diffusers_unip.py) runnable on Apple Silicon by targeting PyTorch's mps backend instead of cuda.
  • Switch the default dtype on the MPS path from bfloat16 to float16 (bfloat16 still has incomplete coverage on MPS) and replace the CUDA-only seeding (torch.cuda.manual_seed, torch.backends.cudnn.*) with torch.mps.manual_seed.
  • Set PYTORCH_ENABLE_MPS_FALLBACK=1 so any operator that isn't yet MPS-native transparently falls back to CPU instead of erroring out.
  • Add a tiny smoke-test config assets/replace_test_mps.txt containing a single short-video prompt to make it easy to verify the replace task end-to-end on a Mac.

Why

The current entry points hard-code device=\"cuda\" and torch_dtype=torch.bfloat16, which prevents any inference on M-series Macs. None of the changes alter the CUDA code path semantically aside from the dtype default, and device is still passed through args, so CUDA users can keep using device=\"cuda\"/bfloat16 unchanged.

Test plan

  • python inference_reco_single.py --task_name replace --test_txt_file_name assets/replace_test_mps.txt --lora_ckpt all_ckpts/ReCo_ori_rank128-2025_m12_version.ckpt on an M-series Mac with PyTorch >= 2.5 produces an output video without throwing a device/dtype error.
  • Existing CUDA runs are unaffected when the user passes device=\"cuda\" (no changes to that branch).

Notes for reviewers

  • This is the entry-point half of the MPS port. Functional execution also requires complementary patches inside the vendored DiffSynth-Studio/ submodule (forcing PyTorch SDPA over flash-attn, etc.). Those live in a separate change because DiffSynth-Studio is a vendored upstream project.
  • This PR has not yet been verified end-to-end against a finished inference run; the Wan2.1-VACE-1.3B weight download was still in progress when the PR was opened. Marking as draft-ready for early review of the entry-point delta.

🤖 Generated with Claude Code

zhw-zhang and others added 9 commits April 25, 2026 13:53
Updated link in ReCo_ref entry for clarity.
Added links to inference results and additional context for ReCo_ref in the README.
Updated the release notes formatting and added checkboxes for updates.
Added a new update entry for the release of ReCo_Ref-related code.
Added download statistics and community appreciation for ReCo-Data.
Switches the ReCo inference entry points from CUDA to PyTorch's MPS
backend so they can run on Apple Silicon (M-series Macs).

- inference_reco_single.py: enable PYTORCH_ENABLE_MPS_FALLBACK, drop
  CUDA-only seeding (torch.cuda.manual_seed / cudnn.deterministic)
  in favor of torch.mps.manual_seed, switch dtype from bfloat16 to
  float16 (bfloat16 has incomplete MPS support), and target device
  "mps" when constructing WanVideoPipeline.
- inference_reco_single_ref.py and tools/step_1_inference_vace_diffusers_unip.py:
  target "mps" instead of "cuda" when moving the pipeline to device.
- assets/replace_test_mps.txt: small single-prompt config used for
  smoke testing the replace task on MPS.

Note: complementary changes are required inside the bundled
DiffSynth-Studio submodule (forcing SDPA over flash-attn, etc.); those
are not part of this commit since DiffSynth-Studio is a vendored
project tracked separately.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@zhw-zhang zhw-zhang force-pushed the main branch 3 times, most recently from 39cae6f to 2da8cc4 Compare May 26, 2026 15:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants