
multigpu: use unet_manual_cast for SelectModelDevice compute dtype #14108

Merged
comfyanonymous merged 3 commits into master from smd-use-unet-manual-cast
May 26, 2026

Conversation

@Kosinkadink
Member

Follow-up to PR #7063.

What

Replace the hardcoded _force_fp32_cpu_compute helper in SelectModelDeviceNode with _force_supported_compute_dtype, which delegates to comfy.model_management.unet_manual_cast(weight_dtype, device) instead of testing for cpu and casting to torch.float32 by hand.

Why

unet_manual_cast already encodes per-device dtype support:

  • CPU: should_use_fp16/should_use_bf16 return False -> cast dtype = torch.float32
  • pre-Ampere NVIDIA with bf16 weights: cast dtype = torch.float16
  • pre-macOS-14 MPS with bf16 weights: cast dtype = torch.float32
  • weight dtype already supported on the device: returns None (no-op)

For the existing SelectModelDevice -> CPU case on an fp16/bf16 model, the behavior is unchanged: unet_manual_cast returns torch.float32 and set_model_compute_dtype casts at use without inflating peak memory. As a bonus, the same code path now handles other 'weight dtype not supported on device' cases without growing the code surface, so the if resolved.type == 'cpu': gate at the call site goes away.
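
The per-device rules listed above can be modelled with a small standalone function. This is only a simplified sketch, not ComfyUI's code: `manual_cast_sketch` is a hypothetical name, dtypes are plain strings instead of `torch` dtypes, and the `fp16_ok` / `bf16_ok` flags stand in for whatever `should_use_fp16` / `should_use_bf16` report for the target device.

```python
from typing import Optional


def manual_cast_sketch(weight_dtype: str, fp16_ok: bool, bf16_ok: bool) -> Optional[str]:
    """Simplified model of the dtype decision (assumption, not the real implementation).

    Returns the dtype to compute in, or None when no cast is needed.
    """
    # fp32 weights are usable on every device, so nothing to do.
    if weight_dtype == "float32":
        return None
    # The device natively supports the weight dtype: no-op.
    if weight_dtype == "float16" and fp16_ok:
        return None
    if weight_dtype == "bfloat16" and bf16_ok:
        return None
    # Otherwise pick a half-precision dtype the device does support,
    # falling back to fp32 when it supports neither.
    if fp16_ok:
        return "float16"
    if bf16_ok:
        return "bfloat16"
    return "float32"


# CPU case: neither half-precision dtype is accepted.
cpu_result = manual_cast_sketch("bfloat16", fp16_ok=False, bf16_ok=False)
# Pre-Ampere NVIDIA case: fp16 is fine, bf16 is not.
old_gpu_result = manual_cast_sketch("bfloat16", fp16_ok=True, bf16_ok=False)
```

Under these assumptions the CPU call yields `"float32"` and the pre-Ampere call yields `"float16"`, matching the first two bullets.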

CPU verification trace

| weight dtype | unet_manual_cast(weight, cpu) | action |
| --- | --- | --- |
| torch.float32 | None (early return) | no-op |
| torch.float16 | torch.float32 (fp16/bf16 rejected on CPU) | set fp32 compute |
| torch.bfloat16 | torch.float32 (fp16/bf16 rejected on CPU) | set fp32 compute |

Replace the hardcoded `_force_fp32_cpu_compute` helper with `_force_supported_compute_dtype`, which delegates to `comfy.model_management.unet_manual_cast(weight_dtype, device)`. The interrogator already encodes per-device dtype support (CPU returns False for fp16/bf16, older GPUs may not support bf16, pre-14 MPS doesn't support bf16, etc.) and returns None when no cast is needed.

For SelectModelDevice -> CPU on an fp16/bf16 model, behavior is unchanged: `unet_manual_cast` returns `torch.float32` and `set_model_compute_dtype` casts at use without touching peak memory. As a bonus, the same code path now handles other `weight_dtype not supported on device` cases (e.g. bf16 weights on pre-Ampere NVIDIA, bf16 on pre-macOS-14 MPS) without growing the code surface, so the call site no longer needs the `if resolved.type == 'cpu':` gate.

Amp-Thread-ID: https://ampcode.com/threads/T-019e61db-ffb1-73a6-b2a8-3d23d7b05792
Co-authored-by: Amp <amp@ampcode.com>
@coderabbitai

coderabbitai Bot commented May 26, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 2cc142bc-3030-4925-9d5b-8b9b6ecb1cc3

📥 Commits

Reviewing files that changed from the base of the PR and between 4ca4d39 and f663018.

📒 Files selected for processing (1)
  • comfy_extras/nodes_multigpu.py

📝 Walkthrough

This PR replaces the CPU-specific helper _force_fp32_cpu_compute with _force_supported_compute_dtype(patcher, device), which queries comfy.model_management.unet_manual_cast for a device-appropriate compute dtype and calls patcher.set_model_compute_dtype only when a cast is needed. SelectModelDeviceNode.execute is updated to call this helper for any resolved non-null device immediately after retargeting the patcher.
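
The helper described in this walkthrough might look roughly like the sketch below. It is a self-contained illustration under stated assumptions, not the merged code: `StubPatcher`, its `weight_dtype` attribute, the injected `manual_cast` callable, and `cpu_only_cast` are all hypothetical stand-ins, and dtypes and devices are plain strings so the example runs without `torch` or ComfyUI. Only the method name `set_model_compute_dtype` comes from the PR text.

```python
from typing import Callable, Optional


class StubPatcher:
    """Hypothetical stand-in for a model patcher; holds only what the sketch needs."""

    def __init__(self, weight_dtype: str) -> None:
        self.weight_dtype = weight_dtype
        self.compute_dtype: Optional[str] = None

    def set_model_compute_dtype(self, dtype: str) -> None:
        # Record the compute dtype; weights themselves are left untouched.
        self.compute_dtype = dtype


def force_supported_compute_dtype(
    patcher: StubPatcher,
    device: str,
    manual_cast: Callable[[str, str], Optional[str]],
) -> None:
    """Ask the cast function for a dtype and apply it only when one is returned."""
    cast = manual_cast(patcher.weight_dtype, device)
    if cast is not None:
        patcher.set_model_compute_dtype(cast)


def cpu_only_cast(weight_dtype: str, device: str) -> Optional[str]:
    """Toy cast rule (assumption): half-precision weights on CPU compute in fp32."""
    if device == "cpu" and weight_dtype in ("float16", "bfloat16"):
        return "float32"
    return None


patcher = StubPatcher("bfloat16")
force_supported_compute_dtype(patcher, "cpu", cpu_only_cast)
```

After the call, `patcher.compute_dtype` is `"float32"`; with fp32 weights the cast function returns `None` and the patcher is left unchanged, which is the no-op path the walkthrough mentions.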

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately summarizes the main change: replacing a hardcoded CPU dtype helper with a device-agnostic one that uses unet_manual_cast. |
| Description check | ✅ Passed | The description is directly related to the changeset, explaining the motivation, implementation details, and providing verification traces for the dtype handling behavior. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


@comfyanonymous merged commit 88956e7 into master on May 26, 2026
16 checks passed
@comfyanonymous deleted the smd-use-unet-manual-cast branch on May 26, 2026 03:03