From 4ca4d390760a8f68bee73d20fa4904873f67dc14 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 25 May 2026 19:54:22 -0700 Subject: [PATCH 1/3] multigpu: use unet_manual_cast for SelectModelDevice compute dtype Replace the hardcoded `_force_fp32_cpu_compute` helper with`_force_supported_compute_dtype`, which delegates to`comfy.model_management.unet_manual_cast(weight_dtype, device)`. The interrogator already encodes per-device dtype support (CPU returns False for fp16/bf16, older GPUs may not support bf16, pre-14 MPS doesn't support bf16, etc.) and returns None when no cast is needed.For SelectModelDevice -> CPU on an fp16/bf16 model, behavior is unchanged: `unet_manual_cast` returns `torch.float32` and `set_model_compute_dtype` casts at use without touching peak memory. As a bonus the same code path now handles other `weight_dtype not supported on device` cases (e.g. bf16 weights on pre-Ampere NVIDIA, bf16 on pre-macOS-14 MPS) without growing the code surface, so the call site no longer needs the `if resolved.type == 'cpu':` gate. Amp-Thread-ID: https://ampcode.com/threads/T-019e61db-ffb1-73a6-b2a8-3d23d7b05792 Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 2bd752b7da47..0bd5f2995dae 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -48,17 +48,25 @@ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: return io.NodeOutput(model) -def _force_fp32_cpu_compute(patcher: ModelPatcher): - """Force fp32 inference dtype for CPU. - - PyTorch's CPU conv2d kernels fall back to software emulation for fp16/bf16 - and run ~500-600x slower than fp32, which makes a normal-sized workflow - look frozen for hours. Routing through set_model_compute_dtype leaves the - weights as-is and casts at use, so peak memory does not blow up.""" - dtype = patcher.model_dtype() - if dtype in (torch.float16, torch.bfloat16): - logging.info(f"Select Model Device: using fp32 compute dtype for CPU inference (model dtype was {dtype}).") - patcher.set_model_compute_dtype(torch.float32) +def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device): + """Ensure the patcher's compute dtype is one the target device actually supports. + + Defers to comfy.model_management.unet_manual_cast, which already encodes + per-device dtype support (CPU returns False for fp16/bf16, older GPUs may + not support bf16, pre-14 MPS doesn't support bf16, etc.). It returns None + when the weight dtype is already fine and the cast dtype otherwise. + + Concrete motivation: PyTorch's CPU conv2d kernels emulate fp16/bf16 in + software (~500-600x slower than fp32), so SelectModelDevice -> CPU on an + fp16 model would otherwise look frozen for hours. Routing through + set_model_compute_dtype leaves the weights as-is and casts at use, so peak + memory does not blow up.""" + weight_dtype = patcher.model_dtype() + cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device) + if cast_dtype is None: + return + logging.info(f"Select Model Device: using {cast_dtype} compute dtype on {device} (model weight dtype was {weight_dtype}).") + patcher.set_model_compute_dtype(cast_dtype) def _remember_base_devices(patcher: ModelPatcher): @@ -229,8 +237,7 @@ def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput: logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})") return io.NodeOutput(model) if resolved is not None: - if resolved.type == "cpu": - _force_fp32_cpu_compute(model) + _force_supported_compute_dtype(model, resolved) _prune_multigpu_collision(model, model.load_device) return io.NodeOutput(model) From 8969bbbf0252080b3e5186e2150d626dd98cd5b4 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 25 May 2026 19:56:07 -0700 Subject: [PATCH 2/3] multigpu: shorten _force_supported_compute_dtype docstring Amp-Thread-ID: https://ampcode.com/threads/T-019e61db-ffb1-73a6-b2a8-3d23d7b05792 Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 0bd5f2995dae..878d85baf11f 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -49,18 +49,11 @@ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device): - """Ensure the patcher's compute dtype is one the target device actually supports. - - Defers to comfy.model_management.unet_manual_cast, which already encodes - per-device dtype support (CPU returns False for fp16/bf16, older GPUs may - not support bf16, pre-14 MPS doesn't support bf16, etc.). It returns None - when the weight dtype is already fine and the cast dtype otherwise. - - Concrete motivation: PyTorch's CPU conv2d kernels emulate fp16/bf16 in - software (~500-600x slower than fp32), so SelectModelDevice -> CPU on an - fp16 model would otherwise look frozen for hours. Routing through - set_model_compute_dtype leaves the weights as-is and casts at use, so peak - memory does not blow up.""" + """Cast compute dtype to one the device supports; no-op if already supported. + + Uses unet_manual_cast which encodes per-device dtype support (e.g. CPU + rejects fp16/bf16, falling back to fp32 to avoid PyTorch's ~500-600x + slower software emulation).""" weight_dtype = patcher.model_dtype() cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device) if cast_dtype is None: From f663018950e25bb71ed2023699359fe4cc664a17 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 25 May 2026 19:58:36 -0700 Subject: [PATCH 3/3] multigpu: trim _force_supported_compute_dtype docstring to one line Amp-Thread-ID: https://ampcode.com/threads/T-019e61db-ffb1-73a6-b2a8-3d23d7b05792 Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 878d85baf11f..d2f6fe67a08f 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -49,11 +49,7 @@ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device): - """Cast compute dtype to one the device supports; no-op if already supported. - - Uses unet_manual_cast which encodes per-device dtype support (e.g. CPU - rejects fp16/bf16, falling back to fp32 to avoid PyTorch's ~500-600x - slower software emulation).""" + """Cast compute dtype to one the device supports; no-op if already supported.""" weight_dtype = patcher.model_dtype() cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device) if cast_dtype is None: