multigpu: use unet_manual_cast for SelectModelDevice compute dtype #14108
Replace the hardcoded `_force_fp32_cpu_compute` helper with `_force_supported_compute_dtype`, which delegates to `comfy.model_management.unet_manual_cast(weight_dtype, device)`. `unet_manual_cast` already encodes per-device dtype support (CPU returns False for fp16/bf16, older GPUs may not support bf16, pre-macOS-14 MPS doesn't support bf16, etc.) and returns None when no cast is needed.

For SelectModelDevice -> CPU on an fp16/bf16 model, behavior is unchanged: `unet_manual_cast` returns `torch.float32` and `set_model_compute_dtype` casts at use without touching peak memory. As a bonus, the same code path now handles other `weight_dtype not supported on device` cases (e.g. bf16 weights on pre-Ampere NVIDIA, bf16 on pre-macOS-14 MPS) without growing the code surface, so the call site no longer needs the `if resolved.type == 'cpu':` gate.

Amp-Thread-ID: https://ampcode.com/threads/T-019e61db-ffb1-73a6-b2a8-3d23d7b05792
Co-authored-by: Amp <amp@ampcode.com>
CodeRabbit walkthrough: This PR replaces the CPU-specific helper
Pre-merge checks: 4 passed, 1 failed (1 warning)
Follow-up to PR #7063.
What
Replace the hardcoded `_force_fp32_cpu_compute` helper in `SelectModelDeviceNode` with `_force_supported_compute_dtype`, which delegates to `comfy.model_management.unet_manual_cast(weight_dtype, device)` instead of testing for `cpu` and casting to `torch.float32` by hand.

Why

`unet_manual_cast` already encodes per-device dtype support:

`should_use_fp16` / `should_use_bf16` return `False` -> cast dtype = `torch.float32`

`torch.float16`

`torch.float32`

`None` (no-op)

For the existing `SelectModelDevice -> CPU` case on an fp16/bf16 model, the behavior is unchanged: `unet_manual_cast` returns `torch.float32` and `set_model_compute_dtype` casts at use without inflating peak memory. As a bonus, the same code path now handles other 'weight dtype not supported on device' cases without growing the code surface, so the `if resolved.type == 'cpu':` gate at the call site goes away.

CPU verification trace

| weight | `unet_manual_cast(weight, cpu)` |
| --- | --- |
| `torch.float32` | `None` (early return) |
| `torch.float16` | `torch.float32` (fp16/bf16 rejected on CPU) |
| `torch.bfloat16` | `torch.float32` (fp16/bf16 rejected on CPU) |
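The delegation pattern described above can be sketched in a few lines. This is a minimal illustration, not the code in this PR: `manual_cast_stub` is a hypothetical stand-in for `unet_manual_cast` that models only the CPU verification trace, `FakeModel` is a hypothetical placeholder for the patched model object, and dtypes are plain strings so the sketch runs without torch or ComfyUI installed.

```python
from dataclasses import dataclass
from typing import Callable, Optional

# Dtypes are plain strings so the sketch has no torch dependency.
FP32, FP16, BF16 = "float32", "float16", "bfloat16"


def manual_cast_stub(weight_dtype: str, device_type: str) -> Optional[str]:
    """Hypothetical stand-in for unet_manual_cast; models only the CPU trace."""
    if weight_dtype == FP32:
        return None  # early return: fp32 weights need no manual cast
    if device_type == "cpu":
        return FP32  # fp16/bf16 rejected on CPU, so compute in fp32
    # Simplifying assumption: every other device supports the weight dtype.
    return None


@dataclass
class FakeModel:
    """Hypothetical placeholder for the model object being patched."""
    weight_dtype: str
    compute_dtype: Optional[str] = None

    def set_model_compute_dtype(self, dtype: str) -> None:
        # Records the requested compute dtype; the weights are left untouched.
        self.compute_dtype = dtype


def force_supported_compute_dtype(
    model: FakeModel,
    device_type: str,
    manual_cast: Callable[[str, str], Optional[str]] = manual_cast_stub,
) -> FakeModel:
    """Ask the cast helper for a dtype; apply it only when one is returned."""
    cast_dtype = manual_cast(model.weight_dtype, device_type)
    if cast_dtype is not None:
        model.set_model_compute_dtype(cast_dtype)
    return model


if __name__ == "__main__":
    for dtype in (FP32, FP16, BF16):
        patched = force_supported_compute_dtype(FakeModel(dtype), "cpu")
        print(dtype, "->", patched.compute_dtype)
```

Because the helper acts only on a non-`None` result, the caller needs no device-specific branch, which is why the `cpu` gate at the call site becomes unnecessary.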