diff --git a/comfy/model_management.py b/comfy/model_management.py
index 56dd10f22ca6..fa8fb76ec1ff 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -546,21 +546,22 @@ def should_reload_model(self, force_patch_weights=False):
         return False
 
     def model_unload(self, memory_to_free=None, unpatch_weights=True):
+        model_loaded_size = self.model.loaded_size()
+        if memory_to_free is None:
+            # free the full model
+            memory_to_free = model_loaded_size
+
         logging.debug(f"model_unload: {self.model.model.__class__.__name__}")
         logging.debug(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB")
         logging.debug(f"unpatch_weights: {unpatch_weights}")
-        logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB")
+        logging.debug(f"loaded_size: {model_loaded_size/(1024*1024*1024)} GB")
         logging.debug(f"offload_device: {self.model.offload_device}")
 
-        if memory_to_free is None:
-            # free the full model
-            memory_to_free = self.model.loaded_size()
-
         available_memory = get_free_memory(self.model.offload_device)
         logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
         mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage
-        if memory_to_free > available_memory - mmap_mem_threshold or memory_to_free < self.model.loaded_size():
+        if min(memory_to_free, model_loaded_size) > available_memory - mmap_mem_threshold or memory_to_free < model_loaded_size:
             partially_unload = True
         else:
             partially_unload = False
 
@@ -571,6 +572,8 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True):
             logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB")
             if freed < memory_to_free:
                 logging.warning(f"Partially unload not enough memory, freed {freed/(1024*1024*1024)} GB, memory_to_free {memory_to_free/(1024*1024*1024)} GB")
+            if freed == model_loaded_size:
+                partially_unload = False
         else:
             logging.debug("Do full unload")
             self.model.detach(unpatch_weights)
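
# For reference, the decision logic this diff produces can be isolated as a
# pair of pure functions. This is a minimal sketch under assumptions, not
# ComfyUI's real API: choose_unload, finalize_unload_mode, and their
# parameters are hypothetical names that mirror the variables in the patched
# model_unload method.

def choose_unload(memory_to_free, loaded_size, available_memory, reserved_bytes):
    """Return True for a partial unload, False for a full unload."""
    if memory_to_free is None:
        memory_to_free = loaded_size  # default added by the diff: free the whole model
    # Partial unload when the offload device cannot absorb even the smaller of
    # (requested bytes, whole model) on top of the reserved mmap threshold,
    # or when the caller asked to free less than the whole model.
    return (min(memory_to_free, loaded_size) > available_memory - reserved_bytes
            or memory_to_free < loaded_size)

def finalize_unload_mode(partially_unload, freed, loaded_size):
    """Second hunk: a 'partial' unload that ended up freeing the entire model
    is reclassified as a full unload."""
    if partially_unload and freed == loaded_size:
        partially_unload = False
    return partially_unload

# Example: freeing 2 GiB of a 6 GiB model with 1 GiB free on the offload
# device and a 4 GiB reserved threshold takes the partial path, while a
# default (None) request with 16 GiB free fits as a full unload.
GIB = 1024 ** 3
assert choose_unload(2 * GIB, 6 * GIB, 1 * GIB, 4 * GIB) is True
assert choose_unload(None, 6 * GIB, 16 * GIB, 4 * GIB) is False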