Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions comfy_extras/nodes_multigpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,14 @@ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
return io.NodeOutput(model)


def _force_fp32_cpu_compute(patcher: ModelPatcher):
    """Switch a half-precision model to fp32 compute for CPU inference.

    On CPU, PyTorch's conv2d kernels emulate fp16/bf16 in software and run
    roughly 500-600x slower than fp32, so an ordinary workflow can appear
    frozen for hours. Going through set_model_compute_dtype keeps the stored
    weights as they are and casts at use, so peak memory does not blow up.

    Models whose dtype is neither fp16 nor bf16 are left untouched.
    """
    half_precision_dtypes = (torch.float16, torch.bfloat16)
    dtype = patcher.model_dtype()
    if dtype not in half_precision_dtypes:
        # Already a dtype that CPU kernels handle natively; nothing to do.
        return

    logging.info(f"Select Model Device: using fp32 compute dtype for CPU inference (model dtype was {dtype}).")
    patcher.set_model_compute_dtype(torch.float32)
def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device):
    """Set the patcher's compute dtype to one usable on ``device``.

    The model's weight dtype is passed to
    comfy.model_management.unet_manual_cast together with the target device.
    A ``None`` result means no cast is required and the patcher is left
    unchanged; otherwise the returned dtype is logged and installed as the
    compute dtype via set_model_compute_dtype.
    """
    weight_dtype = patcher.model_dtype()
    cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device)
    if cast_dtype is not None:
        # Only touch the patcher when a manual cast is actually needed.
        logging.info(f"Select Model Device: using {cast_dtype} compute dtype on {device} (model weight dtype was {weight_dtype}).")
        patcher.set_model_compute_dtype(cast_dtype)


def _remember_base_devices(patcher: ModelPatcher):
Expand Down Expand Up @@ -229,8 +226,7 @@ def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput:
logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
return io.NodeOutput(model)
if resolved is not None:
if resolved.type == "cpu":
_force_fp32_cpu_compute(model)
_force_supported_compute_dtype(model, resolved)
_prune_multigpu_collision(model, model.load_device)
return io.NodeOutput(model)

Expand Down
Loading