Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 0 additions & 82 deletions comfy/multigpu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import annotations
import copy
import queue
import threading
import torch
Expand Down Expand Up @@ -176,87 +175,6 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options:
return model


def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
    """Return a shallow copy of ``upscale_model`` carrying per-device CPU deepclones.

    The returned copy gets a ``multigpu_clones`` dict that maps each extra torch
    device (as reported by ``comfy.model_management.get_all_torch_devices`` with
    the current device excluded) to a deepclone of the descriptor. Each clone is
    put in eval mode, has gradients disabled on its parameters and is moved to
    CPU. Clones already present on ``upscale_model`` for a still-selected device
    are reused instead of being rebuilt.

    Args:
        upscale_model: Descriptor to clone; must expose ``model`` and ``to``.
        max_gpus: Total number of devices to use, counting the current one, so
            at most ``max_gpus - 1`` clones are kept. Values below 1 are
            treated as "no extra devices".

    Returns:
        A shallow copy of ``upscale_model``. When there are no extra devices to
        serve, the copy has its ``multigpu_clones`` attribute removed.
    """
    full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
    # Clamp the bound: with max_gpus <= 0 a bare ``max_gpus - 1`` is negative
    # and the slice would silently keep all-but-the-last extra device(s)
    # instead of none.
    limit_extra_devices = full_extra_devices[:max(max_gpus - 1, 0)]
    cloned = copy.copy(upscale_model)
    if not limit_extra_devices:
        logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
        if hasattr(cloned, 'multigpu_clones'):
            del cloned.multigpu_clones
        return cloned

    # Carry over only the existing clones whose device is still selected.
    existing = getattr(upscale_model, 'multigpu_clones', None)
    limit_extra_device_set = set(limit_extra_devices)
    clones: dict[torch.device, object] = (
        {d: c for d, c in dict(existing).items() if d in limit_extra_device_set}
        if existing else {}
    )

    for device in limit_extra_devices:
        if device in clones:
            continue
        clone_source = copy.copy(upscale_model)
        if hasattr(clone_source, 'multigpu_clones'):
            # Drop the clone dict from the shallow copy so the deepcopy below
            # does not also duplicate every previously created clone.
            del clone_source.multigpu_clones
        clone_desc = copy.deepcopy(clone_source)
        clone_desc.model.eval()
        for p in clone_desc.model.parameters():
            p.requires_grad_(False)
        clone_desc.to("cpu")
        clones[device] = clone_desc
        logging.info(f"Created CPU upscale_model descriptor deepclone for {device}")

    cloned.multigpu_clones = clones
    return cloned


def create_vae_multigpu_deepclones(vae, max_gpus: int):
    """Return a shallow copy of ``vae`` carrying per-device CPU VAE deepclones.

    The returned copy gets a ``multigpu_clones`` dict that maps each extra torch
    device (every device from ``comfy.model_management.get_all_torch_devices``
    other than the VAE's own) to a VAE clone. Each clone is built from
    ``vae.patcher.deepclone_multigpu(new_load_device=device)``, put in eval
    mode, has gradients disabled and has its ``first_stage_model`` moved to
    CPU. Clones already present on ``vae`` for a still-selected device are
    reused instead of being rebuilt.

    Args:
        vae: VAE to clone; ``throw_exception_if_invalid`` is called on it first.
        max_gpus: Total number of devices to use, counting the VAE's own, so at
            most ``max_gpus - 1`` clones are kept. Values below 1 are treated
            as "no extra devices".

    Returns:
        A shallow copy of ``vae``. It has no ``multigpu_clones`` attribute when
        the VAE lives on CPU or when there are no extra devices to serve.
    """
    vae.throw_exception_if_invalid()
    vae_device = torch.device(vae.device)
    cloned = copy.copy(vae)
    if hasattr(cloned, 'multigpu_clones'):
        del cloned.multigpu_clones
    if vae_device.type == "cpu":
        logging.info("CPU VAE selected, skipping initializing MultiGPU VAE clones.")
        return cloned

    full_extra_devices = comfy.model_management.get_all_torch_devices()
    other_devices = [
        d for d in full_extra_devices
        if not (d.type == vae_device.type and d.index == vae_device.index)
    ]
    # Clamp the bound: with max_gpus <= 0 a bare ``max_gpus - 1`` is negative
    # and the slice would silently keep all-but-the-last device(s) instead of
    # none.
    limit_extra_devices = other_devices[:max(max_gpus - 1, 0)]
    if not limit_extra_devices:
        logging.info("No extra torch devices need initialization, skipping initializing MultiGPU VAE clones.")
        return cloned

    # Carry over only the existing clones whose device is still selected.
    existing = getattr(vae, 'multigpu_clones', None)
    limit_extra_device_set = set(limit_extra_devices)
    clones: dict[torch.device, object] = (
        {d: c for d, c in dict(existing).items() if d in limit_extra_device_set}
        if existing else {}
    )

    for device in limit_extra_devices:
        if device in clones:
            continue
        cloned_patcher = vae.patcher.deepclone_multigpu(new_load_device=device)
        clone_vae = copy.copy(vae)
        if hasattr(clone_vae, 'multigpu_clones'):
            # A clone must not itself reference the source's clone dict.
            del clone_vae.multigpu_clones
        # Point the shallow copy at the freshly cloned patcher and its model.
        clone_vae.first_stage_model = cloned_patcher.model
        clone_vae.patcher = cloned_patcher
        clone_vae.first_stage_model.eval()
        for p in clone_vae.first_stage_model.parameters():
            p.requires_grad_(False)
        clone_vae.first_stage_model.to("cpu")
        clones[device] = clone_vae
        logging.info(f"Created CPU VAE deepclone for {device}")

    cloned.multigpu_clones = clones
    return cloned


# Record with two fields: ``work_per_device`` and ``idle_time``.
# NOTE(review): presumably the result type of load_balance_devices below - confirm.
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
Expand Down
132 changes: 1 addition & 131 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,26 +972,6 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
pbar = comfy.utils.ProgressBar(steps)

decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())

multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: decode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
output = self.process_output(
(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar))
/ 3.0)
return output
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")

output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
Expand All @@ -1001,49 +981,16 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):

def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
    """Decode ``samples`` in tiles along a single dimension.

    Inputs that are not 3-D are flattened to 3-D for tiling (dims 1 and 2 are
    merged) and each tile is reshaped back before it reaches the model. When
    the instance carries a non-empty ``multigpu_clones`` mapping, tiles are
    dispatched across this VAE and its clones; the clones' models are moved to
    their devices for the duration of the call and back to CPU afterwards,
    even if decoding fails.
    """
    # Shape of the input before any flattening; used for the memory estimate.
    memory_shape = samples.shape
    flattened = samples.ndim != 3
    if flattened:
        og_shape = samples.shape
        samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))

    def unflatten(a):
        # Undo the flattening for a single tile; no-op for native 3-D input.
        if flattened:
            return a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1]))
        return a

    def decode_fn(a):
        return self.first_stage_model.decode(
            unflatten(a).to(self.vae_dtype).to(self.device)
        ).to(dtype=self.vae_output_dtype())

    def clone_decode_fn_factory(c, dev):
        def clone_decode_fn(a):
            return c.first_stage_model.decode(
                unflatten(a).to(c.vae_dtype).to(dev)
            ).to(dtype=c.vae_output_dtype())
        return clone_decode_fn

    multigpu_clones = getattr(self, 'multigpu_clones', None)
    if not multigpu_clones:
        # Single-device path.
        return self.process_output(comfy.utils.tiled_scale_multidim(
            samples, decode_fn, tile=(tile_x,), overlap=overlap,
            upscale_amount=self.upscale_ratio, out_channels=self.output_channels,
            output_device=self.output_device))

    functions = {self.device: decode_fn}
    try:
        # Make room on every clone device and load the clone there first...
        for dev, c in multigpu_clones.items():
            model_management.free_memory(c.model_size() + c.memory_used_decode(memory_shape, c.vae_dtype), dev)
            c.first_stage_model.to(dev)
        # ...then register one decode function per clone device.
        for dev, c in multigpu_clones.items():
            functions[dev] = clone_decode_fn_factory(c, dev)
        return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(
            samples, functions, tile=(tile_x,), overlap=overlap,
            upscale_amount=self.upscale_ratio, out_channels=self.output_channels,
            output_device=self.output_device))
    finally:
        for c in multigpu_clones.values():
            c.first_stage_model.to("cpu")

def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
    """Decode ``samples`` in tiles of ``(tile_t, tile_x, tile_y)``.

    When the instance carries a non-empty ``multigpu_clones`` mapping, tiles
    are dispatched across this VAE and its clones; the clones' models are
    moved to their devices for the duration of the call and back to CPU
    afterwards, even if decoding fails.
    """
    tile = (tile_t, tile_x, tile_y)

    def decode_fn(a):
        return self.first_stage_model.decode(
            a.to(self.vae_dtype).to(self.device)
        ).to(dtype=self.vae_output_dtype())

    def make_clone_decode_fn(c, dev):
        # Factory so each closure keeps its own clone/device pair.
        def clone_decode_fn(a):
            return c.first_stage_model.decode(
                a.to(c.vae_dtype).to(dev)
            ).to(dtype=c.vae_output_dtype())
        return clone_decode_fn

    multigpu_clones = getattr(self, 'multigpu_clones', None)
    if not multigpu_clones:
        # Single-device path.
        return self.process_output(comfy.utils.tiled_scale_multidim(
            samples, decode_fn, tile=tile, overlap=overlap,
            upscale_amount=self.upscale_ratio, out_channels=self.output_channels,
            index_formulas=self.upscale_index_formula, output_device=self.output_device))

    functions = {self.device: decode_fn}
    try:
        # Make room on every clone device and load the clone there first...
        for dev, c in multigpu_clones.items():
            model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
            c.first_stage_model.to(dev)
        # ...then register one decode function per clone device.
        for dev, c in multigpu_clones.items():
            functions[dev] = make_clone_decode_fn(c, dev)
        return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(
            samples, functions, tile=tile, overlap=overlap,
            upscale_amount=self.upscale_ratio, out_channels=self.output_channels,
            index_formulas=self.upscale_index_formula, output_device=self.output_device))
    finally:
        for c in multigpu_clones.values():
            c.first_stage_model.to("cpu")

def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
Expand All @@ -1053,25 +1000,6 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
pbar = comfy.utils.ProgressBar(steps)

encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())

multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: encode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_encode(pixel_samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
samples = comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples /= 3.0
return samples
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")

samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
Expand All @@ -1081,7 +1009,6 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio
else:
Expand All @@ -1091,46 +1018,15 @@ def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).reshape(1, out_channels, -1).to(dtype=c.vae_output_dtype()))

multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: encode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = clone_encode_fn_factory(c, dev)
out = comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
else:
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)

out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1:
return out
else:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)

def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
    """Encode ``samples`` in tiles of ``(tile_t, tile_x, tile_y)``.

    Each tile goes through ``process_input`` before it is handed to the model.
    When the instance carries a non-empty ``multigpu_clones`` mapping, tiles
    are dispatched across this VAE and its clones; the clones' models are
    moved to their devices for the duration of the call and back to CPU
    afterwards, even if encoding fails.
    """
    tile = (tile_t, tile_x, tile_y)

    def encode_fn(a):
        return self.first_stage_model.encode(
            (self.process_input(a)).to(self.vae_dtype).to(self.device)
        ).to(dtype=self.vae_output_dtype())

    def make_clone_encode_fn(c, dev):
        # Factory so each closure keeps its own clone/device pair.
        def clone_encode_fn(a):
            return c.first_stage_model.encode(
                (c.process_input(a)).to(c.vae_dtype).to(dev)
            ).to(dtype=c.vae_output_dtype())
        return clone_encode_fn

    multigpu_clones = getattr(self, 'multigpu_clones', None)
    if not multigpu_clones:
        # Single-device path.
        return comfy.utils.tiled_scale_multidim(
            samples, encode_fn, tile=tile, overlap=overlap,
            upscale_amount=self.downscale_ratio, out_channels=self.latent_channels,
            downscale=True, index_formulas=self.downscale_index_formula,
            output_device=self.output_device)

    functions = {self.device: encode_fn}
    try:
        # Make room on every clone device and load the clone there first...
        for dev, c in multigpu_clones.items():
            model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
            c.first_stage_model.to(dev)
        # ...then register one encode function per clone device.
        for dev, c in multigpu_clones.items():
            functions[dev] = make_clone_encode_fn(c, dev)
        return comfy.utils.tiled_scale_multidim_multigpu(
            samples, functions, tile=tile, overlap=overlap,
            upscale_amount=self.downscale_ratio, out_channels=self.latent_channels,
            downscale=True, index_formulas=self.downscale_index_formula,
            output_device=self.output_device)
    finally:
        for c in multigpu_clones.values():
            c.first_stage_model.to("cpu")

def decode(self, samples_in, vae_options={}):
Expand Down Expand Up @@ -1831,14 +1727,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
if output_vae and out[2] is not None and hasattr(out[2], "patcher"):
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options, disable_dynamic))
return out

def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
    """Load only the VAE from the checkpoint at ``ckpt_path`` and return its patcher.

    Delegates to ``load_checkpoint_guess_config`` with model, CLIP and
    CLIP-vision outputs disabled, forwarding the remaining options unchanged.
    """
    _model, _clip, vae, _clipvision = load_checkpoint_guess_config(
        ckpt_path,
        output_vae=True,
        output_clip=False,
        output_clipvision=False,
        embedding_directory=embedding_directory,
        output_model=False,
        model_options=model_options,
        te_model_options=te_model_options,
        disable_dynamic=disable_dynamic,
    )
    return vae.patcher

def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory,
Expand Down Expand Up @@ -2064,26 +1954,6 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model

def load_vae_patcher(vae_path, metadata=None, device=None):
    """Load the VAE stored at ``vae_path`` from disk and hand back its patcher.

    Intended as the ``cached_patcher_init`` factory for ``VAE.patcher``, so
    that :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can build
    a fresh VAE patcher that carries no storage tracking from the source
    device. A plain ``copy.deepcopy`` of the VAE wrapper would carry the
    dynamic-VRAM allocator state into the clone, making per-device worker
    threads in tiled encode/decode dispatch read weights through the
    source-device buffer.

    ``device`` mirrors the VAE initialization path of the source loader; the
    device actually targeted by a multigpu clone is still decided by the
    cloned patcher's load_device. When ``metadata`` is omitted it is read from
    the file together with the state dict.
    """
    if metadata is not None:
        # Caller already has the metadata; only the state dict is needed.
        sd = comfy.utils.load_torch_file(vae_path)
    else:
        sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)

    loaded_vae = VAE(sd=sd, metadata=metadata, device=device)
    loaded_vae.throw_exception_if_invalid()
    return loaded_vae.patcher

def load_unet(unet_path, dtype=None):
    """Deprecated alias: log a warning, then forward to ``load_diffusion_model``.

    ``dtype`` is passed on as the ``"dtype"`` entry of ``model_options``.
    """
    logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
    options = {"dtype": dtype}
    return load_diffusion_model(unet_path, model_options=options)
Expand Down
Loading
Loading