From daaeb5c96c859deabd459d31194fe5fa58c328b4 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 1 Jan 2026 00:36:06 +1000 Subject: [PATCH 01/51] Reduce RAM and compute time in model saving with Loras Get the model saving logic away from force_patch_weights and instead do the patching JIT during safetensors saving. Firstly switch off force_patch_weights in the load for save which avoids creating CPU side tensors with loras calculated. Then at save time, wrap the tensor to catch safetensors call to .to() and patch it live. This avoids having to ever have a lora-calculated copy of offloaded weights on the CPU. Also take advantage of the presence of the GPU when doing this Lora calculation. The former force_patch_weights would just do eveyrthing on the CPU. Its generally faster to go the GPU and back even if its just a Lora application. --- comfy/model_base.py | 9 +++---- comfy/model_patcher.py | 53 +++++++++++++++++++++++++++++++++++++----- comfy/sd.py | 4 ++-- 3 files changed, 52 insertions(+), 14 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 66e52864d763..4a248beec2ab 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -322,7 +322,7 @@ def process_latent_in(self, latent): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): extra_sds = [] if clip_state_dict is not None: extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) @@ -330,10 +330,7 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_ extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) if clip_vision_state_dict is not None: extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) - - unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) @@ -776,8 +773,8 @@ def extra_conds(self, **kwargs): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): - sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()} for k in d: s = d[k] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f6b80a40f1f9..30ca39b2a974 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -24,6 +24,7 @@ import logging import math import uuid +import types from typing import Callable, Optional import torch @@ -212,6 +213,27 @@ def is_useable(self, used: int): def decrement(self, used: int): self.value -= used +CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0) + +class LazyCastingParam(torch.nn.Parameter): + def __new__(cls, model, key, tensor): + return super().__new__(cls, tensor) + + def __init__(self, model, key, tensor): + self.model = model + self.key = key + + @property + def device(self): + return CustomTorchDevice + + #safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is + #then just a short lived thing in the safetensors serialization logic inside its big for loop over + #all weights getting garbage collected per-weight + def to(self, *args, **kwargs): + return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu") + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -611,14 +633,14 @@ def model_state_dict(self, filter_prefix=None): sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False): + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): + weight, set_func, convert_func = get_key_weight(self.model, key) if key not in self.patches: - return + return weight - weight, set_func, convert_func = get_key_weight(self.model, key) inplace_update = self.weight_inplace_update or inplace_update - if key not in self.backup: + if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) temp_dtype = comfy.model_management.lora_compute_dtype(device_to) @@ -632,12 +654,14 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False): out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) - if inplace_update: + if return_weight: + return out_weight + elif inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: comfy.utils.set_attr_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + return set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key), return_weight=return_weight) def pin_weight_to_device(self, key): weight, set_func, convert_func = get_key_weight(self.model, key) @@ -1355,6 +1379,23 @@ def clean_hooks(self): self.unpatch_hooks() self.clear_cached_hook_weights() + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + unet_state_dict = self.model.diffusion_model.state_dict() + for k, v in unet_state_dict.items(): + op_keys = k.rsplit('.', 1) + if (len(op_keys) < 2) or not op_keys[1] in ["weight", "bias"]: + continue + try: + op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) + except: + continue + if not op or not hasattr(op, "comfy_cast_weights") or \ + (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True): + continue + key = "diffusion_model." + k + unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) + return self.model.state_dict_for_saving(unet_state_dict) + def __del__(self): self.unpin_all_weights() self.detach(unpatch_all=False) diff --git a/comfy/sd.py b/comfy/sd.py index f627f7d55780..1953505cc2a8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1692,9 +1692,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if metadata is None: metadata = {} - model_management.load_models_gpu(load_models, force_patch_weights=True) + model_management.load_models_gpu(load_models) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None - sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) + sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) for k in extra_keys: sd[k] = extra_keys[k] From db78623796800254be3c75bda75aa5326ac62d4e Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 22 Dec 2025 19:40:22 +1000 Subject: [PATCH 02/51] ops: Do bias dtype conversion on compute stream For consistency with weights. --- comfy/ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index e406ba7edeab..35a1ac953ee5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -102,7 +102,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of bias = None if s.bias is not None: - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) comfy.model_management.sync_stream(device, offload_stream) @@ -110,6 +110,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight_a = weight if s.bias is not None: + bias = bias.to(dtype=bias_dtype) for f in s.bias_function: bias = f(bias) From a08aed2d7e6eb7e7057d75d484be1ddcd28a96d6 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 23 Dec 2025 11:01:32 +1000 Subject: [PATCH 03/51] mm: Implement cast buffer allocations --- comfy/memory_management.py | 51 ++++++++++++++++++++++++++++++++++ comfy/model_management.py | 57 ++++++++++++++++++++++++++++++++++++-- comfy/ops.py | 20 +++++++++++-- cuda_malloc.py | 5 +++- 4 files changed, 126 insertions(+), 7 deletions(-) create mode 100644 comfy/memory_management.py diff --git a/comfy/memory_management.py b/comfy/memory_management.py new file mode 100644 index 000000000000..f8bca526322f --- /dev/null +++ b/comfy/memory_management.py @@ -0,0 +1,51 @@ +import torch +from comfy.quant_ops import QuantizedTensor + +def vram_aligned_size(tensor): + if isinstance(tensor, list): + return sum([vram_aligned_size(t) for t in tensor]) + + if isinstance(tensor, QuantizedTensor): + inner_tensors, _ = tensor.__tensor_flatten__() + return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ]) + + if tensor is None: + return 0 + + size = tensor.numel() * tensor.element_size() + aligment_req = 1024 + return (size + aligment_req - 1) // aligment_req * aligment_req + +def interpret_gathered_like(tensors, gathered): + offset = 0 + dest_views = [] + + if gathered.dim() != 1 or gathered.element_size() != 1: + raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})") + + for tensor in tensors: + + if tensor is None: + dest_views.append(None) + continue + + if isinstance(tensor, QuantizedTensor): + inner_tensors, qt_ctx = tensor.__tensor_flatten__() + templates = { attr: getattr(tensor, attr) for attr in inner_tensors } + else: + templates = { "data": tensor } + + actuals = {} + for attr, template in templates.items(): + size = template.numel() * template.element_size() + if offset + size > gathered.numel(): + raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ") + actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape) + offset += vram_aligned_size(template) + + if isinstance(tensor, QuantizedTensor): + dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0)) + else: + dest_views.append(actuals["data"]) + + return dest_views diff --git a/comfy/model_management.py b/comfy/model_management.py index 9d39be7b2a19..790236ede8e8 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,8 @@ import weakref import gc import os +from contextlib import nullcontext +import comfy.quant_ops class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -732,6 +734,9 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): do_gc = False + + reset_cast_buffers() + for i in range(len(current_loaded_models)): cur = current_loaded_models[i] if cur.is_dead(): @@ -1051,6 +1056,49 @@ def current_stream(device): return None stream_counters = {} + +STREAM_CAST_BUFFERS = {} +LARGEST_CASTED_WEIGHT = (None, 0) + +def get_cast_buffer(offload_stream, device, size, ref): + global LARGEST_CASTED_WEIGHT + + if offload_stream is not None: + wf_context = offload_stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(offload_stream) + else: + wf_context = nullcontext() + + cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None or cast_buffer.numel() < size: + if ref is LARGEST_CASTED_WEIGHT[0]: + #If there is one giant weight we do not want both streams to + #allocate a buffer for it. It's up to the caster to get the other + #offload stream in this corner case + return None + if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): + #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now + del STREAM_CAST_BUFFERS[offload_stream] + del cast_buffer + torch.cuda.synchronize() + torch.cuda.empty_cache() + with wf_context: + cast_buffer = torch.empty((size), dtype=torch.int8, device=device) + STREAM_CAST_BUFFERS[offload_stream] = cast_buffer + + if size > LARGEST_CASTED_WEIGHT[1]: + LARGEST_CASTED_WEIGHT = (ref, size) + + return cast_buffer + +def reset_cast_buffers(): + global LARGEST_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) + STREAM_CAST_BUFFERS.clear() + torch.cuda.synchronize() + torch.cuda.empty_cache() + def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) if NUM_STREAMS == 0: @@ -1093,7 +1141,7 @@ def sync_stream(device, stream): return current_stream(device).wait_stream(stream) -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: @@ -1112,10 +1160,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if hasattr(wf_context, "as_context"): wf_context = wf_context.as_context(stream) with wf_context: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) return r @@ -1557,6 +1607,7 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): + torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/comfy/ops.py b/comfy/ops.py index 35a1ac953ee5..276081addd4a 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,6 +23,7 @@ import comfy.float import comfy.rmsnorm import json +import comfy.memory_management def run_every_op(): if torch.compiler.is_compiling(): @@ -93,16 +94,29 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of else: offload_stream = None + bias = None + weight = None + + if offload_stream is not None and not args.cuda_malloc: + cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ]) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + #The streams can be uneven in buffer capability and reject us. Retry to get the other stream + if cast_buffer is None: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + params = interpret_gathered_like([ s.weight, s.bias ], cast_buffer) + weight = params[0] + bias = params[1] + non_blocking = comfy.model_management.device_supports_non_blocking(device) weight_has_function = len(s.weight_function) > 0 bias_has_function = len(s.bias_function) > 0 - weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) + weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight) - bias = None if s.bias is not None: - bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias) comfy.model_management.sync_stream(device, offload_stream) diff --git a/cuda_malloc.py b/cuda_malloc.py index ee2bc4b69d8f..00ee7b633050 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -86,7 +86,10 @@ def cuda_malloc_supported(): pass -if args.cuda_malloc and not args.disable_cuda_malloc: +if args.disable_cuda_malloc: + args.cuda_malloc = False + +if args.cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync" From 37567cb0d1ced133368ef1920d007f7be3d0bd49 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 8 Jan 2026 18:21:50 +1000 Subject: [PATCH 04/51] move string_to_seed to utils.py This needs to be visible by ops which may want to do stochastic rounding on the fly. --- comfy/model_patcher.py | 21 ++++----------------- comfy/utils.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 30ca39b2a974..46dcf5be82fc 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -40,19 +40,6 @@ from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP -def string_to_seed(data): - crc = 0xFFFFFFFF - for byte in data: - if isinstance(byte, str): - byte = ord(byte) - crc ^= byte - for _ in range(8): - if crc & 1: - crc = (crc >> 1) ^ 0xEDB88320 - else: - crc >>= 1 - return crc ^ 0xFFFFFFFF - def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -653,7 +640,7 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False, retu out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) if return_weight: return out_weight elif inplace_update: @@ -661,7 +648,7 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False, retu else: comfy.utils.set_attr_param(self.model, key, out_weight) else: - return set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key), return_weight=return_weight) + return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight) def pin_weight_to_device(self, key): weight, set_func, convert_func = get_key_weight(self.model, key) @@ -1341,10 +1328,10 @@ def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_pat key, original_weights=original_weights) del original_weights[key] if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) comfy.utils.copy_to_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=True, seed=string_to_seed(key)) + set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key)) if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # TODO: disable caching if not enough system RAM to do so target_device = self.offload_device diff --git a/comfy/utils.py b/comfy/utils.py index d97d753e6da3..2d11dedbeba4 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1308,3 +1308,16 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8) return state_dict, metadata + +def string_to_seed(data): + crc = 0xFFFFFFFF + for byte in data: + if isinstance(byte, str): + byte = ord(byte) + crc ^= byte + for _ in range(8): + if crc & 1: + crc = (crc >> 1) ^ 0xEDB88320 + else: + crc >>= 1 + return crc ^ 0xFFFFFFFF From 3c2ce0d58d8103ba31803ad6dc72f85eef418113 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:16:41 +1000 Subject: [PATCH 05/51] pinned_memory: add python Add a python for managing pinned memory of the weight/bias module level. This allocates, pins and attached a tensor to a module for the pin for this module. It does not set the weight, just allocates a singular ram buffer for population and bulk DMA transfer. --- comfy/model_management.py | 2 +- comfy/pinned_memory.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 comfy/pinned_memory.py diff --git a/comfy/model_management.py b/comfy/model_management.py index 790236ede8e8..21761d971fb4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1185,7 +1185,7 @@ def cast_to_device(tensor, device, dtype, copy=False): MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) -PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) +PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) def discard_cuda_async_error(): try: diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py new file mode 100644 index 000000000000..be303b4f18db --- /dev/null +++ b/comfy/pinned_memory.py @@ -0,0 +1,34 @@ +import torch +import logging +import comfy.model_management +import comfy.memory_management + +from comfy.cli_args import args + +def get_pin(module): + return getattr(module, "_pin", None) + +def pin_memory(module): + if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: + return + #FIXME: This is a RAM cache trigger event + params = [ module.weight, module.bias ] + size = comfy.memory_management.vram_aligned_size(params) + try: + pin = torch.empty((size,), dtype=torch.uint8) + if comfy.model_management.pin_memory(pin): + module._pin = pin + else: + module.pin_failed = True + return False + except: + module.pin_failed = True + return False + return True + +def unpin_memory(module): + if get_pin(module) is None: + return 0 + size = module._pin.numel() * module._pin.element_size() + comfy.model_management.unpin_memory(module._pin) + return size From b6fd3dc2eb02d4846c5d853ef2ce05bf35044212 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:20:47 +1000 Subject: [PATCH 06/51] mp: wrap get_free_memory Dynamic load needs to adjust these numbers based on future movements, so wrap this in a MP API. --- comfy/model_patcher.py | 3 +++ comfy/samplers.py | 2 +- comfy/sd.py | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 46dcf5be82fc..24e3e5fcd782 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -293,6 +293,9 @@ def loaded_size(self): def lowvram_patch_counter(self): return self.model.lowvram_patch_counter + def get_free_memory(self, device): + return comfy.model_management.get_free_memory(device) + def clone(self): n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} diff --git a/comfy/samplers.py b/comfy/samplers.py index 1989ef107adc..d495ca203895 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -260,7 +260,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens to_batch_temp.reverse() to_batch = to_batch_temp[:1] - free_memory = model_management.get_free_memory(x_in.device) + free_memory = model.current_patcher.get_free_memory(x_in.device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] diff --git a/comfy/sd.py b/comfy/sd.py index 1953505cc2a8..ee0ebd50df80 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -897,7 +897,7 @@ def decode(self, samples_in, vae_options={}): try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -971,7 +971,7 @@ def encode(self, pixel_samples): try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) samples = None From 13a7b68ad76ffc2f2142d89dda247b83bab4643b Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:23:49 +1000 Subject: [PATCH 07/51] mp/mm: APi expansions for dynamic loading Add two api expansions, a flag for whether a model patcher is dynamic a a very basic RAM freeing system. Implement the semantics of the dynamic model patcher which never frees VRAM ahead of time for the sake of another dynamic model patcher. At the same time add an API for clearing out pins on a reservation of model size x2 heuristic, as pins consume RAM in their own right in the dynamic patcher. This is actually less about OOMing RAM and more about performance, as with assign=True load semantics there needs to be plenty headroom for the OS to load models to dosk cache on demand so err on the side of kicking old pins out. --- comfy/model_management.py | 34 +++++++++++++++++++++++++--------- comfy/model_patcher.py | 6 ++++++ 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 21761d971fb4..b1bf3bd59f19 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -594,7 +594,7 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() -def free_memory(memory_required, device, keep_loaded=[]): +def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0): cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -609,15 +609,22 @@ def free_memory(memory_required, device, keep_loaded=[]): for x in sorted(can_unload): i = x[-1] - memory_to_free = None + memory_to_free = 1e32 + ram_to_free = 1e32 if not DISABLE_SMART_MEMORY: - free_mem = get_free_memory(device) - if free_mem > memory_required: - break - memory_to_free = memory_required - free_mem + memory_to_free = memory_required - get_free_memory(device) + ram_to_free = ram_required - psutil.virtual_memory().available + + if current_loaded_models[i].model.is_dynamic() and for_dynamic: + #don't actually unload dynamic models for the sake of other dynamic models + #as that works on-demand. + memory_required -= current_loaded_models[i].model.loaded_size() + continue logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") - if current_loaded_models[i].model_unload(memory_to_free): + if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): unloaded_model.append(i) + if ram_to_free > 0: + current_loaded_models[i].model.partially_unload_ram(ram_to_free) for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -652,7 +659,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] + free_for_dynamic=True for x in models: + if not x.is_dynamic(): + free_for_dynamic = False loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) @@ -678,19 +688,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model.detach(unpatch_all=False) model_to_unload.model_finalizer.detach() + total_memory_required = {} + total_ram_required = {} for loaded_model in models_to_load: total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + #x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we + #want to do. + #FIXME: This should subtract off the to_load current pin consumption. + total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2 for device in total_memory_required: if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.1 + extra_mem, device) + free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device]) for device in total_memory_required: if device != torch.device("cpu"): free_mem = get_free_memory(device) if free_mem < minimum_memory_required: - models_l = free_memory(minimum_memory_required, device) + models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic) logging.info("{} models unloaded.".format(len(models_l))) for loaded_model in models_to_load: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 24e3e5fcd782..57cec274762d 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -278,6 +278,9 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up if not hasattr(self.model, 'model_offload_buffer_memory'): self.model.model_offload_buffer_memory = 0 + def is_dynamic(self): + return False + def model_size(self): if self.size > 0: return self.size @@ -998,6 +1001,9 @@ def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): return self.model.model_loaded_weight_memory - current_used + def partially_unload_ram(self, ram_to_unload): + pass + def detach(self, unpatch_all=True): self.eject_model() self.model_patches_to(self.offload_device) From 594b472ca9099dbcbb341ca172596373e74dce13 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:33:18 +1000 Subject: [PATCH 08/51] mp: add mode for non comfy weight prioritization non-comfy weights dont get async offload and a few other performance limitations. Load them at top priority accordingly. --- comfy/model_patcher.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 57cec274762d..0263f133d1e6 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -671,7 +671,7 @@ def unpin_all_weights(self): for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self): + def _load_list(self, prio_comfy_cast_weights=False): loading = [] for n, m in self.model.named_modules(): params = [] @@ -698,7 +698,8 @@ def check_module_offload_mem(key): return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - loading.append((module_offload_mem, module_mem, n, m, params)) + prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else () + loading.append(prepend + (module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): From 6a8255f0c53802d8f72c7595c217cea52df28c3b Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:36:09 +1000 Subject: [PATCH 09/51] ops/mp: implement aimdo Implement a model patcher and caster for aimdo. A new ModelPatcher implementation which backs onto comfy-aimdo to implement varying model load levels that can be adjusted during model use. The patcher defers all load processes to lazily load the model during use (e.g. the first step of a ksampler) and automatically negotiates a load level during the inference to maximize VRAM usage without OOMing. If inference requires more VRAM than is available weights are offloaded to make space before the OOM happens. As for loading the weight onto the GPU, that happens via comfy_cast_weights which is now used in all cases. cast_bias_weight checks whether the VBAR assigned to the model has space for the weight (based on the same load priority semantics as the original ModelPatcher). If it does, the VRAM as returned by the Aimdo allocator is used as the parameter GPU side. The caster is responsible for populating the weight data. This is done using the usual offload_stream (which mean we now have asynchronous load overlapping first use compute). Pinning works a little differently. When a weight is detected during load as unable to fit, a pin is allocated at the time of casting and the weight as used by the layer is DMAd back to the the pin using the GPU DMA TX engine, also using the asynchronous offload streams. This means you get to pin the Lora modified and requantized weights which can be a major speedup for offload+quantize+lora use cases, This works around the JIT Lora + FP8 exclusion and brings FP8MM to heavy offloading users (who probably really need it with more modest GPUs). There is a performance risk in that a CPU+RAM patch has been replace with a GPU+RAM patch but my initial performance results look good. Most users as likely to have a GPU that outruns their CPU in these woods. Some common code is written to consolidate a layers tensors for aimdo mapping, pinning, and DMA transfers. interpret_gathered_like() allows unpacking a raw buffer as a set of tensors. This is used consistently to bundle and pack weights, quantization metadata (QuantizedTensor bits) and biases into one payload for DMA in the load process reducing Cuda overhead a little. Some Quantization metadata was missing async offload is some cases which is now added. This also pins quantization metadata and consolidates the number of cuda_host_register calls (which can be expensive). --- comfy/model_management.py | 56 ++++++++++ comfy/model_patcher.py | 214 ++++++++++++++++++++++++++++++++++++++ comfy/ops.py | 142 ++++++++++++++++++++++--- 3 files changed, 400 insertions(+), 12 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index b1bf3bd59f19..c5a22e04c65f 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -27,8 +27,12 @@ import gc import os from contextlib import nullcontext +import comfy.utils import comfy.quant_ops +import comfy_aimdo.torch +import comfy_aimdo.model_vbar + class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -1157,7 +1161,59 @@ def sync_stream(device, stream): return current_stream(device).wait_stream(stream) + +def cast_to_gathered(tensors, r, non_blocking=False, stream=None): + wf_context = nullcontext() + if stream is not None: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + + dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + with wf_context: + for tensor in tensors: + dest_view = dest_views.pop(0) + if tensor is None: + continue + dest_view.copy_(tensor, non_blocking=non_blocking) + + def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): + if hasattr(weight, "_v"): + #Unexpected usage patterns. There is no reason these don't work but they + #have no testing and no callers do this. + assert r is None + assert stream is None + + r = torch.empty_like(weight, dtype=dtype, device=device) + + signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) + if signature is not None: + raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) + v_tensor = comfy.memory_management.interpret_gathered_like([weight], raw_tensor)[0] + + if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + #always take a deep copy even if _v is good, as we have no reasonable point to unpin + #a non comfy weight + r.copy_(v_tensor) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + return r + + r.copy_(weight, non_blocking=non_blocking) + + #FIXME: remove hooks before PR + if hasattr(weight, "comfy_hook"): + dtype = r.dtype + r = weight.comfy_hook(r) + if r.dtype != dtype: + r = comfy.float.stochastic_rounding(r, dtype, seed=comfy.utils.string_to_seed(weight.seed_key)) + + if signature is not None: + v_tensor.copy_(r) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + + return r + if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0263f133d1e6..77e7eec90150 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -39,6 +39,7 @@ from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP +import comfy_aimdo.model_vbar def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -1397,3 +1398,216 @@ def __del__(self): self.unpin_all_weights() self.detach(unpatch_all=False) +class ModelPatcherDynamic(ModelPatcher): + + def __new__(cls, model, load_device, offload_device, size=0, weight_inplace_update=False): + if comfy.model_management.is_device_cpu(load_device): + #reroute to default MP for CPUs + return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update) + return super().__new__(cls) + + def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): + super().__init__(model, load_device, offload_device, size, weight_inplace_update) + #this is now way more dynamic and we dont support the same base model for both Dynamic + #and non-dynamic patchers. + if hasattr(self.model, "model_loaded_weight_memory"): + del self.model.model_loaded_weight_memory + if not hasattr(self.model, "dynamic_vbars"): + self.model.dynamic_vbars = {} + assert load_device is not None + + def is_dynamic(self): + return True + + def _vbar_get(self, create=False): + if self.load_device == torch.device("cpu"): + return None + vbar = self.model.dynamic_vbars.get(self.load_device, None) + if create and vbar is None: + vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 1.2, self.load_device.index) + self.model.dynamic_vbars[self.load_device] = vbar + return vbar + + def loaded_size(self): + vbar = self._vbar_get() + if vbar is None: + return 0 + return vbar.loaded_size() + + def get_free_memory(self, device): + #NOTE: on high condition / batch counts, estimate should have already vacated + #all non-dynamic models so this is safe even if its not 100% true that this + #would all be avaiable for inference use. + return comfy.model_management.get_total_memory(device) - self.model_size() + + #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. + + def pin_weight_to_device(self, key): + raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading") + + def unpin_weight(self, key): + raise RuntimeError("unpin_weight invalid for dymamic weight loading") + + def unpin_all_weights(self): + pass + + def memory_required(self, input_shape): + #Pad this significantly. We are trying to get away from precise estimates. This + #estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you + #use all ModelPatcherDynamic this is ignored and its all done dynamically. + return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3) + + + def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False): + + #Force patching doesn't make sense in Dynamic loading, as you dont know what does and + #doesn't need to be forced at this stage. The only thing you could do would be patch + #it all on CPU which consumes huge RAM. + assert not force_patch_weights + + #Full load doesn't make sense as we dont actually have any loader capability here and + #now. + assert not full_load; + + assert device_to == self.load_device + + num_patches = 0 + allocated_size = 0 + + with self.use_ejected(): + self.unpatch_hooks() + + vbar = self._vbar_get(create=True) + if vbar is not None: + vbar.prioritize() + + #We have way more tools for acceleration on comfy weight offloading, so always + #prioritize the non-comfy weights (note the order reverse). + loading = self._load_list(prio_comfy_cast_weights=True) + loading.sort(reverse=True) + + for x in loading: + _, _, _, n, m, params = x + + def set_dirty(item, dirty): + if dirty or not hasattr(item, "_v_signature"): + item._v_signature = None + if dirty: + comfy.pinned_memory.unpin_memory(item) + + def setup_param(self, m, n, param_key): + nonlocal num_patches + key = "{}.{}".format(n, param_key) + + weight_function = [] + + weight, _, _ = get_key_weight(self.model, key) + if key in self.patches: + setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) + num_patches += 1 + else: + setattr(m, param_key + "_lowvram_function", None) + + if key in self.weight_wrapper_patches: + weight_function.extend(self.weight_wrapper_patches[key]) + setattr(m, param_key + "_function", weight_function) + return comfy.memory_management.vram_aligned_size(weight) + + if hasattr(m, "comfy_cast_weights"): + m.comfy_cast_weights = True + m.pin_failed = False + m.seed_key = n + set_dirty(m, dirty) + + v_weight_size = 0 + v_weight_size += setup_param(self, m, n, "weight") + v_weight_size += setup_param(self, m, n, "bias") + + if vbar is not None and not hasattr(m, "_v"): + m._v = vbar.alloc(v_weight_size) + allocated_size += v_weight_size + + else: + for param in params: + key = "{}.{}".format(n, param) + weight, _, _ = get_key_weight(self.model, key) + weight.seed_key = key + set_dirty(weight, dirty) + weight_size = weight.numel() * weight.element_size() + if vbar is not None and not hasattr(weight, "_v"): + weight._v = vbar.alloc(weight_size) + allocated_size += weight_size + + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") + + self.model.device = device_to + self.model.current_weight_patches_uuid = self.patches_uuid + + for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): + #These are all super dangerous. Who knows what the custom nodes actually do here... + callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) + + self.apply_hooks(self.forced_hooks, force_apply=True) + + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): + assert not force_patch_weights #See above + assert self.load_device != torch.device("cpu") + + vbar = self._vbar_get() + return 0 if vbar is None else vbar.free_memory(memory_to_free) + + def partially_unload_ram(self, ram_to_unload): + loading = self._load_list(prio_comfy_cast_weights=True) + for x in loading: + _, _, _, _, m, _ = x + ram_to_unload -= comfy.pinned_memory.unpin_memory(m) + if ram_to_unload <= 0: + return + + def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): + #This isn't used by the core at all and can only be to load a model out of + #the control of proper model_managment. If you are a custom node author reading + #this, the correct pattern is to call load_models_gpu() to get a proper + #managed load of your model. + assert not load_weights + return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights) + + def unpatch_model(self, device_to=None, unpatch_weights=True): + super().unpatch_model(device_to=None, unpatch_weights=False) + + if unpatch_weights: + self.partially_unload_ram(1e32) + self.partially_unload(None) + + def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): + assert not force_patch_weights #See above + with self.use_ejected(skip_and_inject_on_exit_only=True): + dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid) + + self.unpatch_model(self.offload_device, unpatch_weights=False) + self.patch_model(load_weights=False) + + try: + self.load(device_to, dirty=dirty) + except Exception as e: + self.detach() + raise e + #ModelPatcher::partially_load returns a number on what got loaded but + #nothing in core uses this and we have no data in the Dynamic world. Hit + #the custom node devs with a None rather than a 0 that would mislead any + #logic they might have. + return None + + def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): + assert False #Should be unreachable - we dont ever cache in the new implementation + + def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter): + if key not in combined_patches: + return + + raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup") + + def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: + pass + +CoreModelPatcher = ModelPatcher diff --git a/comfy/ops.py b/comfy/ops.py index 276081addd4a..ce6bd012aaa7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -24,6 +24,11 @@ import comfy.rmsnorm import json import comfy.memory_management +import comfy.pinned_memory +import comfy.utils + +import comfy_aimdo.model_vbar +import comfy_aimdo.torch def run_every_op(): if torch.compiler.is_compiling(): @@ -73,7 +78,108 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): +def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): + offload_stream = None + xfer_dest = None + + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + if signature is not None: + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + + if not resident: + + xfer_source = [ s.weight, s.bias ] + + pin = comfy.pinned_memory.get_pin(s) + if pin is not None: + xfer_source = [ pin ] + resident = True #If pinned data exists, it always has LowVram already applied + + dest_size = comfy.memory_management.vram_aligned_size(xfer_source) + offload_stream = comfy.model_management.get_offload_stream(device) + if xfer_dest is None and offload_stream is not None: + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + offload_stream = comfy.model_management.get_offload_stream(device) + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) + offload_stream = None + + #send it over + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + comfy.model_management.sync_stream(device, offload_stream) + + pin = None + if signature is not None: + #If we are able to increase our load level (e.g. user reduces resolution or batch number) + #reclaim the pin previously used for offload. + comfy.pinned_memory.unpin_memory(s) + elif not resident: + #prepare a new pin + assert comfy.pinned_memory.get_pin(s) is None + comfy.pinned_memory.pin_memory(s) + pin = comfy.pinned_memory.get_pin(s) + + params = comfy.memory_management.interpret_gathered_like([s.weight, s.bias], xfer_dest) + weight = params[0] + bias = params[1] + + def post_cast(s, param_key, x, dtype, resident, update_weight): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + hook_fn = getattr(s, param_key + "_hooks", None) + fns = getattr(s, param_key + "_function", []) + + orig = x + q_layout = None + + def to_dequant(tensor, dtype): + tensor = tensor.to(dtype=dtype) + if isinstance(tensor, QuantizedTensor): + tensor = tensor.dequantize() + return tensor + + if orig.dtype != dtype or len(fns) > 0: + x = to_dequant(x, dtype) + if not resident and lowvram_fn is not None: + x = to_dequant(x, dtype if compute_dtype is None else compute_dtype) + #FIXME: this is not accurate, we need to be sensitive to the compute dtype + x = lowvram_fn(x) + if (isinstance(orig, QuantizedTensor) and + (orig.dtype == dtype and len(fns) == 0 or update_weight)): + seed = comfy.utils.string_to_seed(s.seed_key) + y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) + if orig.dtype == dtype and len(fns) == 0: + #The layer actually wants our freshly saved QT + x = y + else: + y = x + if update_weight: + orig.copy_(y) + for f in fns: + x = f(x) + return x + + update_weight = signature is not None or pin is not None + + weight = post_cast(s, "weight", weight, dtype, resident, update_weight) + if s.bias is not None: + bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) + s._v_signature=signature + + if pin is not None: + xfer_dest = comfy.memory_management.interpret_gathered_like([ pin ], xfer_dest)[0] + if offload_stream is not None: + #FIXME: if post cast didnt do anything this sync is un-needed + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + comfy.model_management.cast_to(xfer_dest, device=pin.device, non_blocking=non_blocking, stream=offload_stream, r=pin) + + #FIXME: weird offload return protocol + return weight, bias, (offload_stream, device if signature is not None else None, None) + + +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None): # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This # will add async-offload support to your cast and improve performance. @@ -88,6 +194,11 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device + non_blocking = comfy.model_management.device_supports_non_blocking(device) + + if hasattr(s, "_v"): + return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype) + if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): offload_stream = comfy.model_management.get_offload_stream(device) @@ -108,8 +219,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight = params[0] bias = params[1] - non_blocking = comfy.model_management.device_supports_non_blocking(device) - weight_has_function = len(s.weight_function) > 0 bias_has_function = len(s.bias_function) > 0 @@ -146,14 +255,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return os, weight_a, bias_a = offload_stream + device=None + #FIXME: This is not good RTTI + if not isinstance(weight_a, torch.Tensor): + comfy_aimdo.model_vbar.vbar_unpin(s._v) + device = weight_a if os is None: return - if weight_a is not None: - device = weight_a.device - else: - if bias_a is None: - return - device = bias_a.device + if device is None: + if weight_a is not None: + device = weight_a.device + else: + if bias_a is None: + return + device = bias_a.device os.wait_stream(comfy.model_management.current_stream(device)) @@ -670,8 +785,8 @@ def state_dict(self, *args, destination=None, prefix="", **kwargs): def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) - def forward_comfy_cast_weights(self, input): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + def forward_comfy_cast_weights(self, input, compute_dtype=None): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype) x = self._forward(input, weight, bias) uncast_bias_weight(self, weight, bias, offload_stream) return x @@ -681,6 +796,8 @@ def forward(self, input, *args, **kwargs): input_shape = input.shape reshaped_3d = False + #If cast needs to apply lora, it should be done in the compute dtype + compute_dtype = input.dtype if (getattr(self, 'layout_type', None) is not None and not isinstance(input, QuantizedTensor) and not self._full_precision_mm and @@ -699,7 +816,8 @@ def forward(self, input, *args, **kwargs): scale = comfy.model_management.cast_to_device(scale, input.device, None) input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) - output = self.forward_comfy_cast_weights(input) + + output = self.forward_comfy_cast_weights(input, compute_dtype) # Reshape output back to 3D if input was 3D if reshaped_3d: From c862c42311985041da8320bc51fd77b10f4cc623 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:46:39 +1000 Subject: [PATCH 10/51] models: Use CoreModelPatcher Use CoreModelPatcher for all internal ModelPatcher implementations. This drives conditional use of the aimdo feature, while making sure custom node packs get to keep ModelPatcher unchanged for the moment. --- comfy/audio_encoders/audio_encoders.py | 4 +-- comfy/clip_vision.py | 4 +-- comfy/controlnet.py | 2 +- comfy/ldm/hunyuan_video/upsampler.py | 4 +-- comfy/model_base.py | 4 +-- comfy/sd.py | 39 +++++++++++++++----------- comfy_extras/nodes_model_patch.py | 6 ++-- 7 files changed, 34 insertions(+), 29 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 46ef21c95cf1..16998af9405d 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -25,11 +25,11 @@ def __init__(self, config): elif model_type == "whisper3": self.model = WhisperLargeV3(**model_config) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index b28bf636c3f7..1691fca817b7 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -47,10 +47,10 @@ def __init__(self, json_config): self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0b5e30f52a62..9e1e704e077b 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -203,7 +203,7 @@ def __init__(self, control_model=None, global_average_pooling=False, compression self.control_model = control_model self.load_device = load_device if control_model is not None: - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) self.compression_ratio = compression_ratio self.global_average_pooling = global_average_pooling diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 51b6d1da8325..1f68144e2cbe 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -109,10 +109,10 @@ def __init__(self, model_type, config): self.model_class = UPSAMPLERS.get(model_type) self.model = self.model_class(**config).eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=True) + return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/model_base.py b/comfy/model_base.py index 4a248beec2ab..8aeb057f1a45 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -299,7 +299,7 @@ def extra_conds(self, **kwargs): return out - def load_model_weights(self, sd, unet_prefix=""): + def load_model_weights(self, sd, unet_prefix="", assign=False): to_load = {} keys = list(sd.keys()) for k in keys: @@ -307,7 +307,7 @@ def load_model_weights(self, sd, unet_prefix=""): to_load[k[len(unet_prefix):]] = sd.pop(k) to_load = self.model_config.process_unet_state_dict(to_load) - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign) if len(m) > 0: logging.warning("unet missing: {}".format(m)) diff --git a/comfy/sd.py b/comfy/sd.py index ee0ebd50df80..accacdc6efa8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -229,7 +229,7 @@ def __init__(self, target=None, embedding_directory=None, no_init=False, tokeniz logging.warning("Had to shift TE back.") self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram @@ -389,7 +389,7 @@ def encode(self, text): def load_sd(self, sd, full_model=False): if full_model: - return self.cond_stage_model.load_state_dict(sd, strict=False) + return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) else: return self.cond_stage_model.load_sd(sd) @@ -765,13 +765,6 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2): self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - m, u = self.first_stage_model.load_state_dict(sd, strict=False) - if len(m) > 0: - logging.warning("Missing VAE keys {}".format(m)) - - if len(u) > 0: - logging.debug("Leftover VAE keys {}".format(u)) - if device is None: device = model_management.vae_device() self.device = device @@ -782,7 +775,18 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2): self.first_stage_model.to(self.vae_dtype) self.output_device = model_management.intermediate_device() - self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + mp = comfy.model_patcher.CoreModelPatcher + if self.disable_offload: + mp = comfy.model_patcher.ModelPatcher + self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device) + + m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) + if len(m) > 0: + logging.warning("Missing VAE keys {}".format(m)) + + if len(u) > 0: + logging.debug("Leftover VAE keys {}".format(u)) + logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) self.model_size() @@ -1432,7 +1436,7 @@ def load_gligen(ckpt_path): model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() - return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) + return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def model_detection_error_hint(path, state_dict): filename = os.path.basename(path) @@ -1520,7 +1524,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if output_model: inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) - model.load_model_weights(sd, diffusion_model_prefix) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) @@ -1563,7 +1568,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c logging.debug("left over keys: {}".format(left_over)) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) if inital_load_device != torch.device("cpu"): logging.info("loaded diffusion model directly to GPU") model_management.load_models_gpu([model_patcher], force_full_load=True) @@ -1655,13 +1659,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - model = model.to(offload_device) - model.load_model_weights(new_sd, "") + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device) + if not model_management.is_device_cpu(offload_device): + model.to(offload_device) + model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic()) left_over = sd.keys() if len(left_over) > 0: logging.info("left over keys in diffusion model: {}".format(left_over)) - return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) - + return model_patcher def load_diffusion_model(unet_path, model_options={}): sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 82c4754a3f79..53b87e9fc5c5 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -267,9 +267,9 @@ def load_model_patch(self, name): device=comfy.model_management.unet_offload_device(), operations=comfy.ops.manual_cast) - model.load_state_dict(sd) - model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) - return (model,) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + model.load_state_dict(sd, assign=self.model_patcher.is_dynamic()) + return (model_patcher,) class DiffSynthCnetPatch: From 469d7a62de29484c1710c6eef235c1ad94a591c4 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:49:07 +1000 Subject: [PATCH 11/51] execution: add aimdo primary pytorch cache integration We need to general pytorch cache defragmentation on an appropriate level for aimdo. Do in here on the per node basis, which has a reasonable chance of purging stale shapes out of the pytorch caching allocator and saving VRAM without costing too much garbage collector thrash. This looks like a lot of GC but because aimdo never fails from pytorch and saves the pytorch allocator from ever need to defrag out of demand, but it needs a oil change every now and then so we gotta do it. Doing it here also means the pytorch temps are cleared from task manager VRAM usage so user anxiety can go down a little when they see their vram drop back at the end of workflows inline with inference usage (rather than assuming full VRAM leaks). --- comfy/memory_management.py | 6 ++++++ execution.py | 20 +++++++++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/comfy/memory_management.py b/comfy/memory_management.py index f8bca526322f..88b6da1e327f 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -1,6 +1,10 @@ import torch from comfy.quant_ops import QuantizedTensor +import comfy_aimdo.torch + +import logging + def vram_aligned_size(tensor): if isinstance(tensor, list): return sum([vram_aligned_size(t) for t in tensor]) @@ -49,3 +53,5 @@ def interpret_gathered_like(tensors, gathered): dest_views.append(actuals["data"]) return dest_views + +aimdo_allocator = comfy_aimdo.torch.CUDAPluggableAllocator() diff --git a/execution.py b/execution.py index 4b4f63c80595..f10a8795fc37 100644 --- a/execution.py +++ b/execution.py @@ -1,3 +1,4 @@ +import gc import copy import heapq import inspect @@ -9,9 +10,12 @@ from enum import Enum from typing import List, Literal, NamedTuple, Optional, Union import asyncio +from contextlib import nullcontext import torch +import comfy.pinned_memory +import comfy.memory_management import comfy.model_management from latent_preview import set_preview_method import nodes @@ -515,7 +519,21 @@ def execution_block_cb(block): def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + + #Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows + #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc + #that we just want to cull out each model run. + allocator = comfy.memory_management.aimdo_allocator + with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + torch.cuda.synchronize() + if allocator is not None: + #FIXME: this is probably a little zealous + # Torch code comments says some stuff about not actually freeing tensors on mempool + #context release. Explicitly garbage collect now. + gc.collect() + torch.cuda.empty_cache() + if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) From 04bf6ef0def90ea2e8ce8496f55e2edc115b0a49 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:54:49 +1000 Subject: [PATCH 12/51] main: Go live with --fast dynamic_vram Add the optional command line switch --fast dynamic_vram. This is mutually exclusing --high-vram and --gpu-only which contradict aimdos underlying feature. Add appropriate installation warning and a startup message, match the comfy debug level inconfiguring aimdo. Add comfy-aimdo pip requirement. This will safely stub to a nop for unsupported platforms. --- comfy/cli_args.py | 4 ++++ cuda_malloc.py | 7 ++++++- main.py | 35 ++++++++++++++++++++++++++++++++++- requirements.txt | 1 + 4 files changed, 45 insertions(+), 2 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 1716c3de7524..63daca8611de 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" + DynamicVRAM = "dynamic_vram" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) @@ -257,3 +258,6 @@ def is_valid_directory(path: str) -> str: # '--fast' is provided with a list of performance features, use that list else: args.fast = set(args.fast) + +def enables_dynamic_vram(): + return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only diff --git a/cuda_malloc.py b/cuda_malloc.py index 00ee7b633050..3c7c8593ebe1 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,8 +1,10 @@ import os import importlib.util -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import subprocess +import comfy_aimdo.control + #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. def get_gpu_names(): if os.name == 'nt': @@ -85,6 +87,9 @@ def cuda_malloc_supported(): except: pass +if enables_dynamic_vram() and comfy_aimdo.control.lib is not None: + args.cuda_malloc = False + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "" if args.disable_cuda_malloc: args.cuda_malloc = False diff --git a/main.py b/main.py index 37b06c1faada..52f11bffff4a 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,7 @@ import importlib.util import folder_paths import time -from comfy.cli_args import args +from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger from app.assets.scanner import seed_assets import itertools @@ -173,6 +173,30 @@ def execute_script(script_path): if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") + +has_aimdo = False + +import comfy_aimdo.control + +if comfy_aimdo.control.lib is not None: + if args.verbose == 'DEBUG': + comfy_aimdo.control.set_log_debug() + elif args.verbose == 'CRITICAL': + comfy_aimdo.control.set_log_critical() + elif args.verbose == 'ERROR': + comfy_aimdo.control.set_log_error() + elif args.verbose == 'WARNING': + comfy_aimdo.control.set_log_warning() + else: #INFO + comfy_aimdo.control.set_log_info() + + if enables_dynamic_vram(): + logging.info("DynamicVRAM support detected and enabled") + has_aimdo = True +else: + if enables_dynamic_vram(): + logging.info("No native comfy-aimdo install detected. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + import comfy.utils import execution @@ -184,6 +208,15 @@ def execute_script(script_path): import app.logger import hook_breaker_ac10a0 +import comfy.memory_management +import comfy.model_patcher + +if has_aimdo: + comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic + comfy_aimdo.control.init_vram_guard(comfy.model_management.get_torch_device().index) +else: + comfy.memory_management.aimdo_allocator = None + def cuda_malloc_warning(): device = comfy.model_management.get_torch_device() device_name = comfy.model_management.get_torch_device_name(device) diff --git a/requirements.txt b/requirements.txt index 666a0e35b455..ff9e330d5842 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 +comfy-aimdo>=0.1.0 requests #non essential dependencies: From ff434ea98cc071ab7f026ed94a7d520f778aabe2 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 19:37:46 +1000 Subject: [PATCH 13/51] mm: fix sync Sync before deleting anything. --- comfy/model_management.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index c5a22e04c65f..af59592ea954 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1099,9 +1099,9 @@ def get_cast_buffer(offload_stream, device, size, ref): return None if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now + torch.cuda.synchronize() del STREAM_CAST_BUFFERS[offload_stream] del cast_buffer - torch.cuda.synchronize() torch.cuda.empty_cache() with wf_context: cast_buffer = torch.empty((size), dtype=torch.int8, device=device) @@ -1115,8 +1115,8 @@ def get_cast_buffer(offload_stream, device, size, ref): def reset_cast_buffers(): global LARGEST_CASTED_WEIGHT LARGEST_CASTED_WEIGHT = (None, 0) - STREAM_CAST_BUFFERS.clear() torch.cuda.synchronize() + STREAM_CAST_BUFFERS.clear() torch.cuda.empty_cache() def get_offload_stream(device): From e2d62b8f80e343ab5f303644f76728245287c634 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 19:38:36 +1000 Subject: [PATCH 14/51] write better tx commentary --- comfy/ops.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index ce6bd012aaa7..476d521eb227 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -170,8 +170,11 @@ def to_dequant(tensor, dtype): if pin is not None: xfer_dest = comfy.memory_management.interpret_gathered_like([ pin ], xfer_dest)[0] + #FIXME: This might be the wrong thing to do. Some reading suggests the DMA engine + #is posted writes and the compute stream could just fire and forget here. That + #would save this sync and some stalling on the offload stream that is better off + #running ahead to the next layer to read. if offload_stream is not None: - #FIXME: if post cast didnt do anything this sync is un-needed offload_stream.wait_stream(comfy.model_management.current_stream(device)) comfy.model_management.cast_to(xfer_dest, device=pin.device, non_blocking=non_blocking, stream=offload_stream, r=pin) From e8c9977973d371130c056ae37be3797e538db099 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 19:39:00 +1000 Subject: [PATCH 15/51] add missing del on unpin --- comfy/pinned_memory.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index be303b4f18db..923872dace6d 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -31,4 +31,5 @@ def unpin_memory(module): return 0 size = module._pin.numel() * module._pin.element_size() comfy.model_management.unpin_memory(module._pin) + del module._pin return size From 7a18963a3384c82bba09839a13481897874dd038 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 19:40:52 +1000 Subject: [PATCH 16/51] misc cleanup --- comfy/pinned_memory.py | 12 ++++-------- execution.py | 1 - requirements.txt | 2 +- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 923872dace6d..dfb5fcfcd408 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -14,14 +14,10 @@ def pin_memory(module): #FIXME: This is a RAM cache trigger event params = [ module.weight, module.bias ] size = comfy.memory_management.vram_aligned_size(params) - try: - pin = torch.empty((size,), dtype=torch.uint8) - if comfy.model_management.pin_memory(pin): - module._pin = pin - else: - module.pin_failed = True - return False - except: + pin = torch.empty((size,), dtype=torch.uint8) + if comfy.model_management.pin_memory(pin): + module._pin = pin + else: module.pin_failed = True return False return True diff --git a/execution.py b/execution.py index f10a8795fc37..b0812be1a14e 100644 --- a/execution.py +++ b/execution.py @@ -14,7 +14,6 @@ import torch -import comfy.pinned_memory import comfy.memory_management import comfy.model_management from latent_preview import set_preview_method diff --git a/requirements.txt b/requirements.txt index ff9e330d5842..e39a6c80c533 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.1.0 +comfy-aimdo>=0.1.1 requests #non essential dependencies: From 01ca403bed124e76cda0e2d0fd03e35f5f09ff37 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 20:29:13 +1000 Subject: [PATCH 17/51] ruff --- comfy/memory_management.py | 3 --- comfy/model_patcher.py | 5 ++--- comfy/ops.py | 4 +--- comfy/pinned_memory.py | 1 - comfy/samplers.py | 1 - 5 files changed, 3 insertions(+), 11 deletions(-) diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 88b6da1e327f..4169e853cc98 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -1,10 +1,7 @@ -import torch from comfy.quant_ops import QuantizedTensor import comfy_aimdo.torch -import logging - def vram_aligned_size(tensor): if isinstance(tensor, list): return sum([vram_aligned_size(t) for t in tensor]) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 77e7eec90150..b30a9c63daf4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -24,7 +24,6 @@ import logging import math import uuid -import types from typing import Callable, Optional import torch @@ -1381,7 +1380,7 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_ unet_state_dict = self.model.diffusion_model.state_dict() for k, v in unet_state_dict.items(): op_keys = k.rsplit('.', 1) - if (len(op_keys) < 2) or not op_keys[1] in ["weight", "bias"]: + if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]: continue try: op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) @@ -1467,7 +1466,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False #Full load doesn't make sense as we dont actually have any loader capability here and #now. - assert not full_load; + assert not full_load assert device_to == self.load_device diff --git a/comfy/ops.py b/comfy/ops.py index 476d521eb227..fc401262db72 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -128,11 +128,9 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu def post_cast(s, param_key, x, dtype, resident, update_weight): lowvram_fn = getattr(s, param_key + "_lowvram_function", None) - hook_fn = getattr(s, param_key + "_hooks", None) fns = getattr(s, param_key + "_function", []) orig = x - q_layout = None def to_dequant(tensor, dtype): tensor = tensor.to(dtype=dtype) @@ -218,7 +216,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if cast_buffer is None: offload_stream = comfy.model_management.get_offload_stream(device) cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) - params = interpret_gathered_like([ s.weight, s.bias ], cast_buffer) + params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer) weight = params[0] bias = params[1] diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index dfb5fcfcd408..650e27a10420 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -1,5 +1,4 @@ import torch -import logging import comfy.model_management import comfy.memory_management diff --git a/comfy/samplers.py b/comfy/samplers.py index d495ca203895..8b9782956046 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -9,7 +9,6 @@ import torch from functools import partial import collections -from comfy import model_management import math import logging import comfy.sampler_helpers From 9f701f69dc8cc97d82b9b304e7d7ab79539f040d Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 21:13:35 +1000 Subject: [PATCH 18/51] sd: empty cache on tiler fallback This is needed for aimdo where the cache cant self recover from fragmentation. It is however a good thing to do anyway after an OOM so make it unconditional. --- comfy/sd.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index accacdc6efa8..905c81ac96a6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -920,6 +920,7 @@ def decode(self, samples_in, vae_options={}): do_tile = True if do_tile: + torch.cuda.empty_cache() dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) @@ -995,6 +996,7 @@ def encode(self, pixel_samples): do_tile = True if do_tile: + torch.cuda.empty_cache() if self.latent_dim == 3: tile = 256 overlap = tile // 4 From 0983fb88ccc73da71d11cf0c45e86ac6e978645a Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 15 Jan 2026 12:43:10 +1000 Subject: [PATCH 19/51] clip: support assign load when taking clip from a ckpt --- comfy/sd.py | 10 ++++++++++ comfy/sd1_clip.py | 2 +- comfy/text_encoders/lt.py | 2 +- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 905c81ac96a6..f3fc56997638 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -391,6 +391,16 @@ def load_sd(self, sd, full_model=False): if full_model: return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) else: + can_assign = self.patcher.is_dynamic() + self.cond_stage_model.can_assign_sd = can_assign + + # The CLIP models are a pretty complex web of wrappers and its + # a bit of an API change to plumb this all the way through. + # So spray paint the model with this flag that the loading + # nn.Module can then inspect for itself. + for m in self.cond_stage_model.modules(): + m.can_assign_sd = can_assign + return self.cond_stage_model.load_sd(sd) def get_sd(self): diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index d4f22120bbbb..9ecfc9c5571c 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -297,7 +297,7 @@ def encode(self, tokens): return self(tokens) def load_sd(self, sd): - return self.transformer.load_state_dict(sd, strict=False) + return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) def parse_parentheses(string): result = [] diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e491619640e3..26573fb12059 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -125,7 +125,7 @@ def load_sd(self, sd): for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} if component_sd: - missing, unexpected = component.load_state_dict(component_sd, strict=False) + missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False)) missing_all.extend([f"{prefix}{k}" for k in missing]) unexpected_all.extend([f"{prefix}{k}" for k in unexpected]) From f3021770a497b74b29326ce237ad91470feca1ae Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 15 Jan 2026 15:35:20 +1000 Subject: [PATCH 20/51] sampling: improve progress meter accuracy for dynamic loading --- comfy/k_diffusion/sampling.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 0949dee44cf0..2a08066a05a7 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,11 +1,12 @@ import math +import time from functools import partial from scipy import integrate import torch from torch import nn import torchsde -from tqdm.auto import trange, tqdm +from tqdm.auto import trange as trange_, tqdm from . import utils from . import deis @@ -13,6 +14,37 @@ import comfy.model_patcher import comfy.model_sampling +import comfy.memory_management + + +def trange(*args, **kwargs): + if comfy.memory_management.aimdo_allocator == None: + return trange_(*args, **kwargs) + + pbar = trange_(*args, **kwargs, smoothing=1.0) + pbar._i = 0 + pbar.set_postfix_str(" Model Initializing ... ") + + _update = pbar.update + initialized = False + + def warmup_update(n=1): + pbar._i += 1 + if pbar._i == 1: + pbar.i1_time = time.time() + pbar.set_postfix_str(" Model Initialization complete! ") + elif pbar._i == 2: + #bring forward the effective start time based the the diff between first and second iteration + #to attempt to remove load overhead from the final step rate estimate. + pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time) + pbar.set_postfix_str("") + + _update(n) + + pbar.update = warmup_update + return pbar + + def append_zero(x): return torch.cat([x, x.new_zeros([1])]) From 390805673071e739eb54b2def07fb5b361b7cc41 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 15 Jan 2026 17:40:32 +1000 Subject: [PATCH 21/51] main: Rework aimdo into process Be more tolerant of unsupported platforms and fallback properly. Fixes crash when cuda is not installed at all. --- comfy/memory_management.py | 4 +-- cuda_malloc.py | 2 +- main.py | 51 +++++++++++++++++--------------------- 3 files changed, 25 insertions(+), 32 deletions(-) diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 4169e853cc98..3765de0a199e 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -1,7 +1,5 @@ from comfy.quant_ops import QuantizedTensor -import comfy_aimdo.torch - def vram_aligned_size(tensor): if isinstance(tensor, list): return sum([vram_aligned_size(t) for t in tensor]) @@ -51,4 +49,4 @@ def interpret_gathered_like(tensors, gathered): return dest_views -aimdo_allocator = comfy_aimdo.torch.CUDAPluggableAllocator() +aimdo_allocator = None diff --git a/cuda_malloc.py b/cuda_malloc.py index 3c7c8593ebe1..d08162cbc674 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -87,7 +87,7 @@ def cuda_malloc_supported(): except: pass -if enables_dynamic_vram() and comfy_aimdo.control.lib is not None: +if enables_dynamic_vram() and comfy_aimdo.control.init(0): args.cuda_malloc = False os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "" diff --git a/main.py b/main.py index 52f11bffff4a..b8c951375130 100644 --- a/main.py +++ b/main.py @@ -174,29 +174,6 @@ def execute_script(script_path): logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") -has_aimdo = False - -import comfy_aimdo.control - -if comfy_aimdo.control.lib is not None: - if args.verbose == 'DEBUG': - comfy_aimdo.control.set_log_debug() - elif args.verbose == 'CRITICAL': - comfy_aimdo.control.set_log_critical() - elif args.verbose == 'ERROR': - comfy_aimdo.control.set_log_error() - elif args.verbose == 'WARNING': - comfy_aimdo.control.set_log_warning() - else: #INFO - comfy_aimdo.control.set_log_info() - - if enables_dynamic_vram(): - logging.info("DynamicVRAM support detected and enabled") - has_aimdo = True -else: - if enables_dynamic_vram(): - logging.info("No native comfy-aimdo install detected. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") - import comfy.utils import execution @@ -211,11 +188,29 @@ def execute_script(script_path): import comfy.memory_management import comfy.model_patcher -if has_aimdo: - comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic - comfy_aimdo.control.init_vram_guard(comfy.model_management.get_torch_device().index) -else: - comfy.memory_management.aimdo_allocator = None +import comfy_aimdo.control +import comfy_aimdo.torch + +if enables_dynamic_vram(): + if comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): + if args.verbose == 'DEBUG': + comfy_aimdo.control.set_log_debug() + elif args.verbose == 'CRITICAL': + comfy_aimdo.control.set_log_critical() + elif args.verbose == 'ERROR': + comfy_aimdo.control.set_log_error() + elif args.verbose == 'WARNING': + comfy_aimdo.control.set_log_warning() + else: #INFO + comfy_aimdo.control.set_log_info() + + comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic + comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator() + logging.info("DynamicVRAM support detected and enabled") + else: + logging.info("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + comfy.memory_management.aimdo_allocator = None + def cuda_malloc_warning(): device = comfy.model_management.get_torch_device() From 5684c678daca6e47714f6eb2adf59e39fd807ceb Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 15 Jan 2026 17:43:37 +1000 Subject: [PATCH 22/51] aimdo version bump --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e39a6c80c533..c7fd356e7c29 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.1.1 +comfy-aimdo>=0.1.2 requests #non essential dependencies: From b0580b83932360860c5f49522126889a80e93c70 Mon Sep 17 00:00:00 2001 From: Rattus Date: Fri, 16 Jan 2026 01:34:12 +1000 Subject: [PATCH 23/51] remove junk arg --- cuda_malloc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_malloc.py b/cuda_malloc.py index d08162cbc674..b2182df374cb 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -87,7 +87,7 @@ def cuda_malloc_supported(): except: pass -if enables_dynamic_vram() and comfy_aimdo.control.init(0): +if enables_dynamic_vram() and comfy_aimdo.control.init(): args.cuda_malloc = False os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "" From 2f29e215f35a0ef2148a2101a191d6866ed8a530 Mon Sep 17 00:00:00 2001 From: Rattus Date: Sun, 18 Jan 2026 19:29:58 +1000 Subject: [PATCH 24/51] ops: defer creation of the parameters until state dict load If running on Windows, defer creation of the layer parameters until the state dict is loaded. This avoids a massive charge in windows commit charge spike when a model is created and not loaded. This problem doesnt exist on Linux as linux allows RAM overcommit, however windows does not. Before dynamic memory work this was also a non issue as every non-quant model would just immediate RAM load and need the memory anyway. Make the workaround windows specific, as there may be someone out there with some training from scratch workflow (which this might break), and assume said someone is on Linux. --- comfy/ops.py | 50 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index fc401262db72..7bdbb53a1bf0 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -19,7 +19,7 @@ import torch import logging import comfy.model_management -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import comfy.float import comfy.rmsnorm import json @@ -280,6 +280,54 @@ class CastWeightBiasOp: class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): + + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + return super().__init__(in_features, out_features, bias, device, dtype) + + # Issue is with `torch.empty` still reserving the full memory for the layer. + # Windows doesn't over-commit memory so without this, We are momentarily commit + # charged for the weight even though we might zero-copy it when we load the + # state dict. If the commit charge exceeds the ceiling we can destabilize the + # system. + torch.nn.Module.__init__(self) + self.in_features = in_features + self.out_features = out_features + self.weight = None + self.bias = None + self.comfy_need_lazy_init_bias=bias + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) + prefix_len = len(prefix) + for k,v in state_dict.items(): + if k[prefix_len:] == "weight": + if not assign_to_params_buffers: + v = v.clone() + self.weight = torch.nn.Parameter(v, requires_grad=False) + elif k[prefix_len:] == "bias" and v is not None: + if not assign_to_params_buffers: + v = v.clone() + self.bias = torch.nn.Parameter(v, requires_grad=False) + else: + unexpected_keys.append(k) + + #Reconcile default construction of the weight if its missing. + if self.weight is None: + v = torch.zeros(self.in_features, self.out_features) + self.weight = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix+"weight") + if self.bias is None and self.comfy_need_lazy_init_bias: + v = torch.zeros(self.out_features,) + self.bias = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix+"bias") + + def reset_parameters(self): return None From cecf8c55f2ea5572d8d8eb04bdf1683e2ceb93d1 Mon Sep 17 00:00:00 2001 From: Rattus Date: Sun, 18 Jan 2026 22:00:50 +1000 Subject: [PATCH 25/51] implement lightweight safetensors with READ mmap The CoW MMAP as used by safetensors is hardcoded to CoW which forcibly consumes windows commit charge on a zero copy. RIP. Implement safetensors in pytorch itself with a READ mmap to not get commit charged for all our open models. --- comfy/utils.py | 75 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 65 insertions(+), 10 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 2d11dedbeba4..9471eed20f55 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,9 +28,13 @@ import itertools from torch.nn.functional import interpolate from einops import rearrange -from comfy.cli_args import args +from comfy.cli_args import args, enables_dynamic_vram import json import time +import mmap +import ctypes + +import packaging MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -56,21 +60,72 @@ def scalar(*args, **kwargs): else: logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.") +# Current as of safetensors 0.7.0 +_TYPES = { + "F64": torch.float64, + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I64": torch.int64, + "I32": torch.int32, + "I16": torch.int16, + "I8": torch.int8, + "U8": torch.uint8, + "BOOL": torch.bool, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, + "C64": torch.complex64, +} +if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"): + _TYPES.update( + { + "U64": torch.uint64, + "U32": torch.uint32, + "U16": torch.uint16, + } + ) + +def load_safetensors(ckpt): + f = open(ckpt, "rb") + mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + + header_size = struct.unpack(" 0: message = e.args[0] From 607d15cad6ff20ed3639e83a3373c3340d6dad7f Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 20 Jan 2026 21:57:28 +1000 Subject: [PATCH 26/51] execution: remove per node gc.collect() This isn't worth it and the likelyhood of inference leaving a complex data-structure with cyclic reference behind is now. Remove it. We would replace it with a condition on nodes that actually touch the GPU which might be win. --- execution.py | 1 - 1 file changed, 1 deletion(-) diff --git a/execution.py b/execution.py index b0812be1a14e..6229d372de9e 100644 --- a/execution.py +++ b/execution.py @@ -530,7 +530,6 @@ def pre_execute_cb(call_index): #FIXME: this is probably a little zealous # Torch code comments says some stuff about not actually freeing tensors on mempool #context release. Explicitly garbage collect now. - gc.collect() torch.cuda.empty_cache() if has_pending_tasks: From 322d917991180f4a6b481f56a3d83c37a7810331 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 20 Jan 2026 21:59:45 +1000 Subject: [PATCH 27/51] mm: remove left over hooks draft code This is phase 2 --- comfy/model_management.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index af59592ea954..4a3a0f886eda 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1201,13 +1201,6 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str r.copy_(weight, non_blocking=non_blocking) - #FIXME: remove hooks before PR - if hasattr(weight, "comfy_hook"): - dtype = r.dtype - r = weight.comfy_hook(r) - if r.dtype != dtype: - r = comfy.float.stochastic_rounding(r, dtype, seed=comfy.utils.string_to_seed(weight.seed_key)) - if signature is not None: v_tensor.copy_(r) comfy_aimdo.model_vbar.vbar_unpin(weight._v) From f3854f6d2e167780a9870fe4947e078920aa550c Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 20 Jan 2026 22:02:58 +1000 Subject: [PATCH 28/51] mp: handle blank __new__ call This is needed for deepcopy construction. We shouldnt really have deep copies of MP or MODynamic however this is a stay one in some controlnet flows. --- comfy/model_patcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b30a9c63daf4..2f9be07d86d6 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1399,8 +1399,8 @@ def __del__(self): class ModelPatcherDynamic(ModelPatcher): - def __new__(cls, model, load_device, offload_device, size=0, weight_inplace_update=False): - if comfy.model_management.is_device_cpu(load_device): + def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weight_inplace_update=False): + if load_device is not None and comfy.model_management.is_device_cpu(load_device): #reroute to default MP for CPUs return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update) return super().__new__(cls) From e54440a0c7a76feab43ec86ca44571fab7a873ab Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 20 Jan 2026 22:04:06 +1000 Subject: [PATCH 29/51] nodes_model_patch: fix copy-paste coding error --- comfy_extras/nodes_model_patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 53b87e9fc5c5..176e6bc2f766 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -268,7 +268,7 @@ def load_model_patch(self, name): operations=comfy.ops.manual_cast) model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) - model.load_state_dict(sd, assign=self.model_patcher.is_dynamic()) + model.load_state_dict(sd, assign=model_patcher.is_dynamic()) return (model_patcher,) From 12263b7fbf0c4f94abf8f7335d2924ff62b4e374 Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 21 Jan 2026 14:43:14 +1000 Subject: [PATCH 30/51] ruff --- comfy/k_diffusion/sampling.py | 3 +-- comfy/utils.py | 4 ++-- execution.py | 1 - 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 2a08066a05a7..c0c51d51a252 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -18,7 +18,7 @@ def trange(*args, **kwargs): - if comfy.memory_management.aimdo_allocator == None: + if comfy.memory_management.aimdo_allocator is None: return trange_(*args, **kwargs) pbar = trange_(*args, **kwargs, smoothing=1.0) @@ -26,7 +26,6 @@ def trange(*args, **kwargs): pbar.set_postfix_str(" Model Initializing ... ") _update = pbar.update - initialized = False def warmup_update(n=1): pbar._i += 1 diff --git a/comfy/utils.py b/comfy/utils.py index 9471eed20f55..c620e75458fb 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -32,7 +32,6 @@ import json import time import mmap -import ctypes import packaging @@ -96,7 +95,8 @@ def load_safetensors(ckpt): sd = {} for name, info in header.items(): - if name == "__metadata__": continue + if name == "__metadata__": + continue start, end = info["data_offsets"] sd[name] = data_area[start:end].view(_TYPES[info["dtype"]]).view(info["shape"]) diff --git a/execution.py b/execution.py index 6229d372de9e..0c64cbe6ac20 100644 --- a/execution.py +++ b/execution.py @@ -1,4 +1,3 @@ -import gc import copy import heapq import inspect From 49809b7b2dbd6eddae84974e7b93662cc82af7ef Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 21 Jan 2026 23:57:52 +1000 Subject: [PATCH 31/51] mp: big bump on the VBAR sizes Now that the model defined dtype is decoupled from the state_dict dtypes we need to be able to handle worst case scenario casts between the SD and VBAR. --- comfy/model_patcher.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2f9be07d86d6..6b25436f24ab 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1423,7 +1423,10 @@ def _vbar_get(self, create=False): return None vbar = self.model.dynamic_vbars.get(self.load_device, None) if create and vbar is None: - vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 1.2, self.load_device.index) + # x10. We dont know what model defined type casts we have in the vbar, but virtual address + # space is pretty free. This will cover someone casting an entire model from FP4 to FP32 + # with some left over. + vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index) self.model.dynamic_vbars[self.load_device] = vbar return vbar From d1778d8085c91c9b7f360af1496a9bde401ecf7d Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 22 Jan 2026 00:00:33 +1000 Subject: [PATCH 32/51] archive the model defined dtypes Scan created models and save off the dtypes as defined by the model creation process. This is needed for assign=True, which will override the dtypes. --- comfy/model_base.py | 2 ++ comfy/model_management.py | 5 +++++ comfy/ops.py | 2 ++ comfy/sd.py | 4 ++++ 4 files changed, 13 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8aeb057f1a45..85acdb66ae9b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -149,6 +149,8 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) + comfy.model_management.archive_model_dtypes(self.diffusion_model) + self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 diff --git a/comfy/model_management.py b/comfy/model_management.py index 4a3a0f886eda..cdb9542c0a7f 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -774,6 +774,11 @@ def cleanup_models_gc(): logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) +def archive_model_dtypes(model): + for name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + def cleanup_models(): to_delete = [] diff --git a/comfy/ops.py b/comfy/ops.py index 7bdbb53a1bf0..0010a9c8d073 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -296,6 +296,8 @@ def __init__(self, in_features, out_features, bias=True, device=None, dtype=None self.weight = None self.bias = None self.comfy_need_lazy_init_bias=bias + self.weight_comfy_model_dtype = dtype + self.bias_comfy_model_dtype = dtype def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): diff --git a/comfy/sd.py b/comfy/sd.py index f3fc56997638..7e67c6919f2e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -228,6 +228,8 @@ def __init__(self, target=None, embedding_directory=None, no_init=False, tokeniz self.cond_stage_model.to(offload_device) logging.warning("Had to shift TE back.") + model_management.archive_model_dtypes(self.cond_stage_model) + self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention @@ -775,6 +777,8 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2): self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() + model_management.archive_model_dtypes(self.first_stage_model) + if device is None: device = model_management.vae_device() self.device = device From 36c76527de0f5c1ed81afbeb6d71840e63c32a89 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 22 Jan 2026 00:02:11 +1000 Subject: [PATCH 33/51] ops: fix __init__ return --- comfy/ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ops.py b/comfy/ops.py index 0010a9c8d073..a2ca607d3719 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -283,7 +283,8 @@ class Linear(torch.nn.Linear, CastWeightBiasOp): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): - return super().__init__(in_features, out_features, bias, device, dtype) + super().__init__(in_features, out_features, bias, device, dtype) + return # Issue is with `torch.empty` still reserving the full memory for the layer. # Windows doesn't over-commit memory so without this, We are momentarily commit From ede3d4b96674fbf7ff02c1473f20c9e3cbeeccea Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 22 Jan 2026 00:03:01 +1000 Subject: [PATCH 34/51] MPDynamic: Add support for model defined dtype If the model defines a dtype that is different to what is in the state dict, respect that at load time. This is done as part of the casting process. --- comfy/memory_management.py | 29 +++++++++++++++++++++++++++++ comfy/model_management.py | 4 ++-- comfy/model_patcher.py | 15 +++++++++++++-- comfy/ops.py | 21 ++++++++++++++++++++- comfy/pinned_memory.py | 2 +- 5 files changed, 65 insertions(+), 6 deletions(-) diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 3765de0a199e..858bd4cc782b 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -1,5 +1,34 @@ +import math +import torch +from typing import NamedTuple + from comfy.quant_ops import QuantizedTensor +class TensorGeometry(NamedTuple): + shape: any + dtype: torch.dtype + + def element_size(self): + info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype) + return info.bits // 8 + + def numel(self): + return math.prod(self.shape) + +def tensors_to_geometries(tensors, dtype=None): + geometries = [] + for t in tensors: + if t is None or isinstance(t, QuantizedTensor): + geometries.append(t) + continue + tdtype = t.dtype + if hasattr(t, "_model_dtype"): + tdtype = t._model_dtype + if dtype is not None: + tdtype = dtype + geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype)) + return geometries + def vram_aligned_size(tensor): if isinstance(tensor, list): return sum([vram_aligned_size(t) for t in tensor]) diff --git a/comfy/model_management.py b/comfy/model_management.py index cdb9542c0a7f..5271974475a9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1190,12 +1190,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str assert r is None assert stream is None - r = torch.empty_like(weight, dtype=dtype, device=device) + r = torch.empty_like(weight, dtype=weight._model_dtype, device=device) signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) if signature is not None: raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) - v_tensor = comfy.memory_management.interpret_gathered_like([weight], raw_tensor)[0] + v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0] if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): #always take a deep copy even if _v is good, as we have no reasonable point to unpin diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6b25436f24ab..1ef5b6661c5b 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1504,6 +1504,8 @@ def setup_param(self, m, n, param_key): weight_function = [] weight, _, _ = get_key_weight(self.model, key) + if weight is None: + return 0 if key in self.patches: setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) num_patches += 1 @@ -1513,7 +1515,12 @@ def setup_param(self, m, n, param_key): if key in self.weight_wrapper_patches: weight_function.extend(self.weight_wrapper_patches[key]) setattr(m, param_key + "_function", weight_function) - return comfy.memory_management.vram_aligned_size(weight) + geometry = weight + if not isinstance(weight, QuantizedTensor): + model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype) + weight._model_dtype = model_dtype + geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) + return comfy.memory_management.vram_aligned_size(geometry) if hasattr(m, "comfy_cast_weights"): m.comfy_cast_weights = True @@ -1535,9 +1542,13 @@ def setup_param(self, m, n, param_key): weight, _, _ = get_key_weight(self.model, key) weight.seed_key = key set_dirty(weight, dirty) - weight_size = weight.numel() * weight.element_size() + geometry = weight + model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype) + geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) + weight_size = geometry.numel() * geometry.element_size() if vbar is not None and not hasattr(weight, "_v"): weight._v = vbar.alloc(weight_size) + weight._model_dtype = model_dtype allocated_size += weight_size logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") diff --git a/comfy/ops.py b/comfy/ops.py index a2ca607d3719..33c43f18327b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -81,6 +81,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): offload_stream = None xfer_dest = None + cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) signature = comfy_aimdo.model_vbar.vbar_fault(s._v) if signature is not None: @@ -88,6 +89,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) if not resident: + cast_dest = None xfer_source = [ s.weight, s.bias ] @@ -95,6 +97,16 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu if pin is not None: xfer_source = [ pin ] resident = True #If pinned data exists, it always has LowVram already applied + else: + for data, geometry in zip([ s.weight, s.bias ], cast_geometry): + if data is None: + continue + if data.dtype != geometry.dtype: + cast_dest = xfer_dest + if cast_dest is None: + cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) + xfer_dest = None + break dest_size = comfy.memory_management.vram_aligned_size(xfer_source) offload_stream = comfy.model_management.get_offload_stream(device) @@ -111,6 +123,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) comfy.model_management.sync_stream(device, offload_stream) + if cast_dest is not None: + for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest), + comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): + if post_cast is not None: + post_cast.copy_(pre_cast) + xfer_dest = cast_dest + pin = None if signature is not None: #If we are able to increase our load level (e.g. user reduces resolution or batch number) @@ -122,7 +141,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu comfy.pinned_memory.pin_memory(s) pin = comfy.pinned_memory.get_pin(s) - params = comfy.memory_management.interpret_gathered_like([s.weight, s.bias], xfer_dest) + params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) weight = params[0] bias = params[1] diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 650e27a10420..0650e4d1ab5b 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -11,7 +11,7 @@ def pin_memory(module): if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: return #FIXME: This is a RAM cache trigger event - params = [ module.weight, module.bias ] + params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ]) size = comfy.memory_management.vram_aligned_size(params) pin = torch.empty((size,), dtype=torch.uint8) if comfy.model_management.pin_memory(pin): From 355172fe7e8502140bc86517f167fd8f9cb3ce51 Mon Sep 17 00:00:00 2001 From: Rattus Date: Sun, 25 Jan 2026 09:14:52 +1000 Subject: [PATCH 35/51] remove bad pyt2.4 versions gate --- comfy/utils.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index c620e75458fb..5aae2c1bbe4f 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -33,8 +33,6 @@ import time import mmap -import packaging - MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -74,15 +72,11 @@ def scalar(*args, **kwargs): "F8_E4M3": torch.float8_e4m3fn, "F8_E5M2": torch.float8_e5m2, "C64": torch.complex64, + + "U64": torch.uint64, + "U32": torch.uint32, + "U16": torch.uint16, } -if packaging.version.Version(torch.__version__) >= packaging.version.Version("2.3.0"): - _TYPES.update( - { - "U64": torch.uint64, - "U32": torch.uint32, - "U16": torch.uint16, - } - ) def load_safetensors(ckpt): f = open(ckpt, "rb") From 8bb291ba173bcefddb9f66f20b743be0c822ff8e Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 26 Jan 2026 00:53:07 +1000 Subject: [PATCH 36/51] disable async pin population --- comfy/ops.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 33c43f18327b..d8ff261f0962 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -187,13 +187,11 @@ def to_dequant(tensor, dtype): if pin is not None: xfer_dest = comfy.memory_management.interpret_gathered_like([ pin ], xfer_dest)[0] - #FIXME: This might be the wrong thing to do. Some reading suggests the DMA engine - #is posted writes and the compute stream could just fire and forget here. That - #would save this sync and some stalling on the offload stream that is better off - #running ahead to the next layer to read. - if offload_stream is not None: - offload_stream.wait_stream(comfy.model_management.current_stream(device)) - comfy.model_management.cast_to(xfer_dest, device=pin.device, non_blocking=non_blocking, stream=offload_stream, r=pin) + #FIXME: put this on nsight and see if its worth offloading to the pin with + #the offload stream. This adds extra sync requirements on xfer_dest in addition to: + #if offload_stream is not None: + # offload_stream.wait_stream(comfy.model_management.current_stream(device)) + comfy.model_management.cast_to(xfer_dest, device=pin.device, non_blocking=non_blocking, stream=None, r=pin) #FIXME: weird offload return protocol return weight, bias, (offload_stream, device if signature is not None else None, None) From 4c875a2a8fb354cb1e471f85c8bcda3f664eef53 Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 26 Jan 2026 01:06:17 +1000 Subject: [PATCH 37/51] fix syncs Fix these sync to conditionalize properly for CPU and always run in exception flows. --- comfy/model_management.py | 3 ++- execution.py | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 5271974475a9..888cea5c3187 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1120,7 +1120,8 @@ def get_cast_buffer(offload_stream, device, size, ref): def reset_cast_buffers(): global LARGEST_CASTED_WEIGHT LARGEST_CASTED_WEIGHT = (None, 0) - torch.cuda.synchronize() + for offload_stream in STREAM_CAST_BUFFERS: + offload_stream.synchronize() STREAM_CAST_BUFFERS.clear() torch.cuda.empty_cache() diff --git a/execution.py b/execution.py index 0c64cbe6ac20..9607e16364b2 100644 --- a/execution.py +++ b/execution.py @@ -523,8 +523,11 @@ def pre_execute_cb(call_index): #that we just want to cull out each model run. allocator = comfy.memory_management.aimdo_allocator with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) - torch.cuda.synchronize() + try: + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + finally: + if allocator is not None: + torch.cuda.synchronize() if allocator is not None: #FIXME: this is probably a little zealous # Torch code comments says some stuff about not actually freeing tensors on mempool From f98c86ce9dd382ea4736664529ca8550a58858eb Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 26 Jan 2026 22:14:08 +1000 Subject: [PATCH 38/51] add missing signature set for non comfy --- comfy/model_management.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 888cea5c3187..143991ad6645 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1208,6 +1208,7 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str r.copy_(weight, non_blocking=non_blocking) if signature is not None: + weight._v_signature = signature v_tensor.copy_(r) comfy_aimdo.model_vbar.vbar_unpin(weight._v) From 2a76ec6e033da65575fef5ea660e5e77bf664ae5 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 27 Jan 2026 14:07:15 +1000 Subject: [PATCH 39/51] fix missing import --- comfy/model_management.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 143991ad6645..70c2d5e22ef7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -27,6 +27,7 @@ import gc import os from contextlib import nullcontext +import comfy.memory_management import comfy.utils import comfy.quant_ops From 04141efe5497d41a45b5e1420046c41152523ccb Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 27 Jan 2026 14:10:54 +1000 Subject: [PATCH 40/51] mm: Dont GPU load models Aimdo will do this on demand as 0 copy. Remove the special case for vram > ram. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 70c2d5e22ef7..412752503de4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -823,7 +823,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev: + if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None: return torch_dev else: return cpu_dev From cd085314f96b63cb4d3d79c27cf5ba5b0440ccff Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 27 Jan 2026 14:11:42 +1000 Subject: [PATCH 41/51] ops: dont discard pins Its more likely that the user will rerun their workflow and want whatever pins are inplace so remove this. pins still have to respect RAM pressure per model anyway. --- comfy/ops.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index d8ff261f0962..886d2735068d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -131,11 +131,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu xfer_dest = cast_dest pin = None - if signature is not None: - #If we are able to increase our load level (e.g. user reduces resolution or batch number) - #reclaim the pin previously used for offload. - comfy.pinned_memory.unpin_memory(s) - elif not resident: + if signature is None and not resident: #prepare a new pin assert comfy.pinned_memory.get_pin(s) is None comfy.pinned_memory.pin_memory(s) From 101367b0da6fd9ae0b653825763b1fa717c7cef9 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 27 Jan 2026 14:19:07 +1000 Subject: [PATCH 42/51] mm: redefine free memory for Windows As commented. --- comfy/model_management.py | 8 +++++- comfy/windows.py | 52 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 comfy/windows.py diff --git a/comfy/model_management.py b/comfy/model_management.py index 412752503de4..804be7768d48 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -585,9 +585,15 @@ def offloaded_memory(loaded_models, device): EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 if WINDOWS: + import comfy.windows EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 + def get_free_ram(): + return comfy.windows.get_free_ram() +else: + def get_free_ram(): + return psutil.virtual_memory().available if args.reserve_vram is not None: EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 @@ -618,7 +624,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ ram_to_free = 1e32 if not DISABLE_SMART_MEMORY: memory_to_free = memory_required - get_free_memory(device) - ram_to_free = ram_required - psutil.virtual_memory().available + ram_to_free = ram_required - get_free_ram() if current_loaded_models[i].model.is_dynamic() and for_dynamic: #don't actually unload dynamic models for the sake of other dynamic models diff --git a/comfy/windows.py b/comfy/windows.py new file mode 100644 index 000000000000..213dc481d937 --- /dev/null +++ b/comfy/windows.py @@ -0,0 +1,52 @@ +import ctypes +import logging +import psutil +from ctypes import wintypes + +import comfy_aimdo.control + +psapi = ctypes.WinDLL("psapi") +kernel32 = ctypes.WinDLL("kernel32") + +class PERFORMANCE_INFORMATION(ctypes.Structure): + _fields_ = [ + ("cb", wintypes.DWORD), + ("CommitTotal", ctypes.c_size_t), + ("CommitLimit", ctypes.c_size_t), + ("CommitPeak", ctypes.c_size_t), + ("PhysicalTotal", ctypes.c_size_t), + ("PhysicalAvailable", ctypes.c_size_t), + ("SystemCache", ctypes.c_size_t), + ("KernelTotal", ctypes.c_size_t), + ("KernelPaged", ctypes.c_size_t), + ("KernelNonpaged", ctypes.c_size_t), + ("PageSize", ctypes.c_size_t), + ("HandleCount", wintypes.DWORD), + ("ProcessCount", wintypes.DWORD), + ("ThreadCount", wintypes.DWORD), + ] + +def get_free_ram(): + #Windows is way too conservative and chalks recently used uncommitted model RAM + #as "in-use". So, calculate free RAM for the sake of general use as the greater of: + # + #1: What psutil says + #2: Total Memory - (Committed Memory - VRAM in use) + # + #We have to subtract VRAM in use from the comitted memory as WDDM creates a naked + #commit charge for all VRAM used just incase it wants to page it all out. This just + #isn't realistic so "overcommit" on our calculations by just subtracting it off. + + pi = PERFORMANCE_INFORMATION() + pi.cb = ctypes.sizeof(pi) + + if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb): + logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal") + return psutil.virtual_memory().available + + committed = pi.CommitTotal * pi.PageSize + total = pi.PhysicalTotal * pi.PageSize + + return max(psutil.virtual_memory().available, + total - (committed - comfy_aimdo.control.get_total_vram_usage())) + From dff1ee9351e6552e5c408f0f1d3cde6afc6d333e Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 27 Jan 2026 17:23:51 +1000 Subject: [PATCH 43/51] free dynamic pins properly --- comfy/model_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 1ef5b6661c5b..6db41dfcfac1 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1451,7 +1451,7 @@ def unpin_weight(self, key): raise RuntimeError("unpin_weight invalid for dymamic weight loading") def unpin_all_weights(self): - pass + self.partially_unload_ram(1e32) def memory_required(self, input_shape): #Pad this significantly. We are trying to get away from precise estimates. This From f8f9a89f6e8b5d043fc7ccc80e495c83b67d5c36 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 27 Jan 2026 18:58:52 +1000 Subject: [PATCH 44/51] bump aimdo to 1.4 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c7fd356e7c29..fee6c69c5f19 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.1.2 +comfy-aimdo>=0.1.4 requests #non essential dependencies: From 8067cb4f93de50b322187aae4921526f7bacf092 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 29 Jan 2026 01:42:38 +1000 Subject: [PATCH 45/51] mm: dont clear_cache with mempools Two things. * pyt2.7 crashes if you try and clear_cache in the presence of mempools. * mempools don't actually ever clear_cache because the mempool itself is considered a ref. Guard the code accordingly and remove useless clear_cache calls. The offload stream resizer will need some fixing. --- comfy/model_management.py | 13 +++++++++---- comfy/sd.py | 2 -- execution.py | 6 +----- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 804be7768d48..bb9ae5852876 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1114,6 +1114,7 @@ def get_cast_buffer(offload_stream, device, size, ref): torch.cuda.synchronize() del STREAM_CAST_BUFFERS[offload_stream] del cast_buffer + #FIXME: This doesn't work in Aimdo because mempool cant clear cache torch.cuda.empty_cache() with wf_context: cast_buffer = torch.empty((size), dtype=torch.int8, device=device) @@ -1130,7 +1131,9 @@ def reset_cast_buffers(): for offload_stream in STREAM_CAST_BUFFERS: offload_stream.synchronize() STREAM_CAST_BUFFERS.clear() - torch.cuda.empty_cache() + if comfy.memory_management.aimdo_allocator is None: + #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist + torch.cuda.empty_cache() def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) @@ -1686,9 +1689,11 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + if comfy.memory_management.aimdo_allocator is None: + #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def unload_all_models(): free_memory(1e30, get_torch_device()) diff --git a/comfy/sd.py b/comfy/sd.py index 7e67c6919f2e..fd0ac85e76b5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -934,7 +934,6 @@ def decode(self, samples_in, vae_options={}): do_tile = True if do_tile: - torch.cuda.empty_cache() dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) @@ -1010,7 +1009,6 @@ def encode(self, pixel_samples): do_tile = True if do_tile: - torch.cuda.empty_cache() if self.latent_dim == 3: tile = 256 overlap = tile // 4 diff --git a/execution.py b/execution.py index 9607e16364b2..93fafc4a28d3 100644 --- a/execution.py +++ b/execution.py @@ -527,12 +527,8 @@ def pre_execute_cb(call_index): output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) finally: if allocator is not None: + comfy.model_management.reset_cast_buffers() torch.cuda.synchronize() - if allocator is not None: - #FIXME: this is probably a little zealous - # Torch code comments says some stuff about not actually freeing tensors on mempool - #context release. Explicitly garbage collect now. - torch.cuda.empty_cache() if has_pending_tasks: pending_async_nodes[unique_id] = output_data From bc80f784d867be6a9d128658a8731144029a2674 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 29 Jan 2026 23:48:00 +1000 Subject: [PATCH 46/51] Fix ram freeing logic --- comfy/model_management.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index bb9ae5852876..758e718e84c4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -630,11 +630,12 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ #don't actually unload dynamic models for the sake of other dynamic models #as that works on-demand. memory_required -= current_loaded_models[i].model.loaded_size() - continue - logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") + memory_to_free = 0 if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): + logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) if ram_to_free > 0: + logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") current_loaded_models[i].model.partially_unload_ram(ram_to_free) for i in sorted(unloaded_model, reverse=True): From b1eb25b5c15d172f91d02bf409c9c1c609460ec4 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 29 Jan 2026 23:48:27 +1000 Subject: [PATCH 47/51] Go back to pre-pins Post pins dont really work for low spec users and you are more likely to recycle your model with a different lora than to really care about that tiny little bit of perf of pre-computed Lora. Do it the old way. --- comfy/model_patcher.py | 2 -- comfy/ops.py | 26 ++++++++++---------------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6db41dfcfac1..57b53d8c566e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1494,8 +1494,6 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False def set_dirty(item, dirty): if dirty or not hasattr(item, "_v_signature"): item._v_signature = None - if dirty: - comfy.pinned_memory.unpin_memory(item) def setup_param(self, m, n, param_key): nonlocal num_patches diff --git a/comfy/ops.py b/comfy/ops.py index 886d2735068d..3e7c019e1d46 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -119,6 +119,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) offload_stream = None + if signature is None and pin is None: + comfy.pinned_memory.pin_memory(s) + pin = comfy.pinned_memory.get_pin(s) + else: + pin = None + + if pin is not None: + comfy.model_management.cast_to_gathered(xfer_source, pin) + xfer_srouce = [ pin ] #send it over comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) comfy.model_management.sync_stream(device, offload_stream) @@ -130,13 +139,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu post_cast.copy_(pre_cast) xfer_dest = cast_dest - pin = None - if signature is None and not resident: - #prepare a new pin - assert comfy.pinned_memory.get_pin(s) is None - comfy.pinned_memory.pin_memory(s) - pin = comfy.pinned_memory.get_pin(s) - params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) weight = params[0] bias = params[1] @@ -174,21 +176,13 @@ def to_dequant(tensor, dtype): x = f(x) return x - update_weight = signature is not None or pin is not None + update_weight = signature is not None weight = post_cast(s, "weight", weight, dtype, resident, update_weight) if s.bias is not None: bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) s._v_signature=signature - if pin is not None: - xfer_dest = comfy.memory_management.interpret_gathered_like([ pin ], xfer_dest)[0] - #FIXME: put this on nsight and see if its worth offloading to the pin with - #the offload stream. This adds extra sync requirements on xfer_dest in addition to: - #if offload_stream is not None: - # offload_stream.wait_stream(comfy.model_management.current_stream(device)) - comfy.model_management.cast_to(xfer_dest, device=pin.device, non_blocking=non_blocking, stream=None, r=pin) - #FIXME: weird offload return protocol return weight, bias, (offload_stream, device if signature is not None else None, None) From 46f9ac1967777bda1af938dd4d009a818c21797b Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 29 Jan 2026 23:50:00 +1000 Subject: [PATCH 48/51] bump aimdo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index fee6c69c5f19..c03d3fce0658 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.1.4 +comfy-aimdo>=0.1.5 requests #non essential dependencies: From 74584f69c6342fc69c9544e9a9e487019a1955ff Mon Sep 17 00:00:00 2001 From: Rattus Date: Sat, 31 Jan 2026 01:12:55 +1000 Subject: [PATCH 49/51] fixes to pinning rework --- comfy/ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 3e7c019e1d46..c3a1825cecd2 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -96,7 +96,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu pin = comfy.pinned_memory.get_pin(s) if pin is not None: xfer_source = [ pin ] - resident = True #If pinned data exists, it always has LowVram already applied else: for data, geometry in zip([ s.weight, s.bias ], cast_geometry): if data is None: @@ -127,7 +126,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu if pin is not None: comfy.model_management.cast_to_gathered(xfer_source, pin) - xfer_srouce = [ pin ] + xfer_source = [ pin ] #send it over comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) comfy.model_management.sync_stream(device, offload_stream) From 58fd609a2e11c20288d740bc35f24bf8f55a4ab1 Mon Sep 17 00:00:00 2001 From: Rattus Date: Sat, 31 Jan 2026 01:13:04 +1000 Subject: [PATCH 50/51] bump aimdo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c03d3fce0658..41823ab6bbe1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.1.5 +comfy-aimdo>=0.1.6 requests #non essential dependencies: From 882a3bcba42400455773ff4dfa182fb674d0b8f0 Mon Sep 17 00:00:00 2001 From: Rattus Date: Sun, 1 Feb 2026 09:23:49 +1000 Subject: [PATCH 51/51] remove bad assertion --- comfy/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 5aae2c1bbe4f..9e98eb1769c0 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -101,8 +101,6 @@ def load_safetensors(ckpt): def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): if device is None: device = torch.device("cpu") - else: - assert False metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: