From 72bbf493494109d3c177ad5de378c6c4bbae61d1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 29 Dec 2024 15:49:09 -0600 Subject: [PATCH 01/90] Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' --- comfy/hooks.py | 21 +++++++++++++++++---- comfy/model_patcher.py | 5 +++-- comfy/samplers.py | 10 ++++++---- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index cf33598ae2eb..79a7090ba206 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -366,9 +366,15 @@ def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1): self.start_t = 999999999.9 self.guarantee_steps = guarantee_steps + def get_effective_guarantee_steps(self, max_sigma: torch.Tensor): + '''If keyframe starts before current sampling range (max_sigma), treat as 0.''' + if self.start_t > max_sigma: + return 0 + return self.guarantee_steps + def clone(self): c = HookKeyframe(strength=self.strength, - start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) + start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) c.start_t = self.start_t return c @@ -408,6 +414,12 @@ def _set_first_as_current(self): else: self._current_keyframe = None + def has_guarantee_steps(self): + for kf in self.keyframes: + if kf.guarantee_steps > 0: + return True + return False + def has_index(self, index: int): return index >= 0 and index < len(self.keyframes) @@ -425,15 +437,16 @@ def initialize_timesteps(self, model: 'BaseModel'): for keyframe in self.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) - def prepare_current_keyframe(self, curr_t: float) -> bool: + def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool: if self.is_empty(): return False if curr_t == self._curr_t: return False + max_sigma = torch.max(transformer_options["sigmas"]) prev_index = self._current_index prev_strength = self._current_strength # if met guaranteed steps, look for next keyframe in case need to switch - if self._current_used_steps >= self._current_keyframe.guarantee_steps: + if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma): # if has next index, loop through and see if need to switch if self.has_index(self._current_index+1): for i in range(self._current_index+1, len(self.keyframes)): @@ -446,7 +459,7 @@ def prepare_current_keyframe(self, curr_t: float) -> bool: self._current_keyframe = eval_c self._current_used_steps = 0 # if guarantee_steps greater than zero, stop searching for other keyframes - if self._current_keyframe.guarantee_steps > 0: + if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_c is outside the percent range, stop looking further else: break diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d89d9a6a3ec7..4597ce11ccf1 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -919,11 +919,12 @@ def restore_hook_patches(self): def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): self.hook_mode = hook_mode - def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup): + def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]): curr_t = t[0] reset_current_hooks = False + transformer_options = model_options.get("transformer_options", {}) for hook in hook_group.hooks: - changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t) + changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options) # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # this will cause the weights to be recalculated when sampling if changed: diff --git a/comfy/samplers.py b/comfy/samplers.py index 27686722dcd2..6a386511a2c7 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -144,7 +144,7 @@ def cond_cat(c_list): return out -def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep): +def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options): # need to figure out remaining unmasked area for conds default_mults = [] for _ in default_conds: @@ -183,7 +183,7 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H # replace p's mult with calculated mult p = p._replace(mult=mult) if p.hooks is not None: - model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) hooked_to_run.setdefault(p.hooks, list()) hooked_to_run[p.hooks] += [(p, i)] @@ -218,7 +218,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te if p is None: continue if p.hooks is not None: - model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) hooked_to_run.setdefault(p.hooks, list()) hooked_to_run[p.hooks] += [(p, i)] default_conds.append(default_c) @@ -840,7 +840,9 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) - extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed} + extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) + extra_model_options.setdefault("transformer_options", {})["sigmas"] = sigmas + extra_args = {"model_options": extra_model_options, "seed": seed} executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( sampler.sample, From 5a2ad032cb09afcaf7fadf5cdfa20c2b0498aee5 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 3 Jan 2025 20:02:27 -0600 Subject: [PATCH 02/90] Cleaned up hooks.py, refactored Hook.should_register and add_hook_patches to use target_dict instead of target so that more information can be provided about the current execution environment if needed --- comfy/hooks.py | 148 +++++++++++++++++++++++------------- comfy/model_patcher.py | 8 +- comfy/sampler_helpers.py | 2 +- comfy_extras/nodes_hooks.py | 2 +- 4 files changed, 102 insertions(+), 58 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 79a7090ba206..181c4996a418 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -16,46 +16,86 @@ import comfy.patcher_extension from node_helpers import conditioning_set_values +# ####################################################################################################### +# Hooks explanation +# ------------------- +# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to +# make explicit special cases like it does for ControlNet and GLIGEN. +# +# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those +# that should run special code when a 'marked' cond is used in sampling. +# ####################################################################################################### + class EnumHookMode(enum.Enum): + ''' + Priority of hook memory optimization vs. speed, mostly related to WeightHooks. + + MinVram: No caching will occur for any operations related to hooks. + MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups. + ''' MinVram = "minvram" MaxSpeed = "maxspeed" class EnumHookType(enum.Enum): + ''' + Hook types, each of which has different expected behavior. + ''' Weight = "weight" Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" - Callbacks = "callbacks" Wrappers = "wrappers" - SetInjections = "add_injections" + Injections = "add_injections" class EnumWeightTarget(enum.Enum): Model = "model" Clip = "clip" +class EnumHookScope(enum.Enum): + ''' + Determines if hook should be limited in its influence over sampling. + + AllConditioning: hook will affect all conds used in sampling. + HookedOnly: hook will only affect the conds it was attached to. + ''' + AllConditioning = "all_conditioning" + HookedOnly = "hooked_only" + + class _HookRef: pass -# NOTE: this is an example of how the should_register function should look -def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + +def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + '''Example for how should_register function should look like.''' return True +def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]: + '''Creates base dictionary for use with Hooks' target param.''' + d = {} + if target is not None: + d['target'] = target + d.update(kwargs) + return d + + class Hook: def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, - hook_keyframe: 'HookKeyframeGroup'=None): + hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning): self.hook_type = hook_type self.hook_ref = hook_ref if hook_ref else _HookRef() self.hook_id = hook_id self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() self.custom_should_register = default_should_register self.auto_apply_to_nonpositive = False + self.hook_scope = hook_scope @property def strength(self): return self.hook_keyframe.strength - def initialize_timesteps(self, model: 'BaseModel'): + def initialize_timesteps(self, model: BaseModel): self.reset() self.hook_keyframe.initialize_timesteps(model) @@ -75,27 +115,32 @@ def clone(self, subtype: Callable=None): c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c - def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - return self.custom_should_register(self, model, model_options, target, registered) + def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + return self.custom_should_register(self, model, model_options, target_dict, registered) - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") - def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]): + def on_apply(self, model: ModelPatcher, transformer_options: dict[str]): pass - def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]): + def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]): pass - def __eq__(self, other: 'Hook'): + def __eq__(self, other: Hook): return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref def __hash__(self): return hash(self.hook_ref) class WeightHook(Hook): + ''' + Hook responsible for tracking weights to be applied to some model/clip. + + Note, value of hook_scope is ignored and is treated as HookedOnly. + ''' def __init__(self, strength_model=1.0, strength_clip=1.0): - super().__init__(hook_type=EnumHookType.Weight) + super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly) self.weights: dict = None self.weights_clip: dict = None self.need_weight_init = True @@ -110,27 +155,29 @@ def strength_model(self): def strength_clip(self): return self._strength_clip * self.strength - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - if not self.should_register(model, model_options, target, registered): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + if not self.should_register(model, model_options, target_dict, registered): return False weights = None - if target == EnumWeightTarget.Model: - strength = self._strength_model - else: + + target = target_dict.get('target', None) + if target == EnumWeightTarget.Clip: strength = self._strength_clip + else: + strength = self._strength_model if self.need_weight_init: key_map = {} - if target == EnumWeightTarget.Model: - key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) - else: + if target == EnumWeightTarget.Clip: key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) + else: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) else: - if target == EnumWeightTarget.Model: - weights = self.weights - else: + if target == EnumWeightTarget.Clip: weights = self.weights_clip + else: + weights = self.weights model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) registered.append(self) return True @@ -174,7 +221,12 @@ def clone(self, subtype: Callable=None): # TODO: add functionality class AddModelsHook(Hook): - def __init__(self, key: str=None, models: list['ModelPatcher']=None): + ''' + Hook responsible for telling model management any additional models that should be loaded. + + Note, value of hook_scope is ignored and is treated as AllConditioning. + ''' + def __init__(self, key: str=None, models: list[ModelPatcher]=None): super().__init__(hook_type=EnumHookType.AddModels) self.key = key self.models = models @@ -188,24 +240,15 @@ def clone(self, subtype: Callable=None): c.models = self.models.copy() if self.models else self.models c.append_when_same = self.append_when_same return c - # TODO: add functionality - -class CallbackHook(Hook): - def __init__(self, key: str=None, callback: Callable=None): - super().__init__(hook_type=EnumHookType.Callbacks) - self.key = key - self.callback = callback - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: CallbackHook = super().clone(subtype) - c.key = self.key - c.callback = self.callback - return c - # TODO: add functionality + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + if not self.should_register(model, model_options, target_dict, registered): + return False class WrapperHook(Hook): + ''' + Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options. + ''' def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): super().__init__(hook_type=EnumHookType.Wrappers) self.wrappers_dict = wrappers_dict @@ -217,17 +260,18 @@ def clone(self, subtype: Callable=None): c.wrappers_dict = self.wrappers_dict return c - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - if not self.should_register(model, model_options, target, registered): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + if not self.should_register(model, model_options, target_dict, registered): return False add_model_options = {"transformer_options": self.wrappers_dict} - comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) + if self.hook_scope == EnumHookScope.AllConditioning: + comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) registered.append(self) return True class SetInjectionsHook(Hook): - def __init__(self, key: str=None, injections: list['PatcherInjection']=None): - super().__init__(hook_type=EnumHookType.SetInjections) + def __init__(self, key: str=None, injections: list[PatcherInjection]=None): + super().__init__(hook_type=EnumHookType.Injections) self.key = key self.injections = injections @@ -239,7 +283,7 @@ def clone(self, subtype: Callable=None): c.injections = self.injections.copy() if self.injections else self.injections return c - def add_hook_injections(self, model: 'ModelPatcher'): + def add_hook_injections(self, model: ModelPatcher): # TODO: add functionality pass @@ -260,14 +304,14 @@ def clone(self): c.add(hook.clone()) return c - def clone_and_combine(self, other: 'HookGroup'): + def clone_and_combine(self, other: HookGroup): c = self.clone() if other is not None: for hook in other.hooks: c.add(hook.clone()) return c - def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'): + def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup): if hook_kf is None: hook_kf = HookKeyframeGroup() else: @@ -336,7 +380,7 @@ def reset(self): hook.reset() @staticmethod - def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup': + def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup: actual: list[HookGroup] = [] for group in hooks_list: if group is not None: @@ -433,7 +477,7 @@ def clone(self): c._set_first_as_current() return c - def initialize_timesteps(self, model: 'BaseModel'): + def initialize_timesteps(self, model: BaseModel): for keyframe in self.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) @@ -548,7 +592,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float hook.need_weight_init = False return hook_group -def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True): +def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True): if model is None: return None patches_model: dict[str, torch.Tensor] = model.model.state_dict() @@ -560,7 +604,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T return patches_model # NOTE: this function shows how to register weight hooks directly on the ModelPatchers -def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor], +def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): key_map = {} if model is not None: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4597ce11ccf1..071535526a42 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -940,13 +940,13 @@ def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: com if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None): + def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None): self.restore_hook_patches() registered_hooks: list[comfy.hooks.Hook] = [] # handle WrapperHooks, if model_options provided if model_options is not None: for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): - hook.add_hook_patches(self, model_options, target, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): @@ -956,9 +956,9 @@ def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, d # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) for hook in weight_hooks_to_register: - hook.add_hook_patches(self, model_options, target, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered_hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks_dict, target) + callback(self, hooks_dict, target_dict) def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index ac97353690f5..6f21ca3cff18 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -131,4 +131,4 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) # register hooks on model/model_options - model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options) + model.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model), model_options) diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 9d9d48378523..49b90b9d5351 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -255,7 +255,7 @@ def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, h clip.use_clip_schedule = schedule_clip if not clip.use_clip_schedule: clip.patcher.forced_hooks.set_keyframes_on_hooks(None) - clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip) + clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) return (clip,) class ConditioningTimestepsRange: From 776aa734e1ac0a46fefef6abcc5ad29763003a7e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 01:02:21 -0600 Subject: [PATCH 03/90] Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options) --- comfy/hooks.py | 24 +++++++++++++++++------- comfy/model_patcher.py | 2 +- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 181c4996a418..7ca3a8a11a7f 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -44,7 +44,7 @@ class EnumHookType(enum.Enum): Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" - Wrappers = "wrappers" + TransformerOptions = "transformer_options" Injections = "add_injections" class EnumWeightTarget(enum.Enum): @@ -245,29 +245,39 @@ def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict if not self.should_register(model, model_options, target_dict, registered): return False -class WrapperHook(Hook): +class TransformerOptionsHook(Hook): ''' - Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options. + Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options. ''' def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): - super().__init__(hook_type=EnumHookType.Wrappers) - self.wrappers_dict = wrappers_dict + super().__init__(hook_type=EnumHookType.TransformerOptions) + self.transformers_dict = wrappers_dict def clone(self, subtype: Callable=None): if subtype is None: subtype = type(self) c: WrapperHook = super().clone(subtype) - c.wrappers_dict = self.wrappers_dict + c.transformers_dict = self.transformers_dict return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): if not self.should_register(model, model_options, target_dict, registered): return False - add_model_options = {"transformer_options": self.wrappers_dict} + add_model_options = {"transformer_options": self.transformers_dict} + # TODO: call .to on patches/anything else in transformer_options that is expected to do something if self.hook_scope == EnumHookScope.AllConditioning: comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) registered.append(self) return True + + def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): + comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) + +class WrapperHook(TransformerOptionsHook): + ''' + For backwards compatibility, this hook is identical to TransformerOptionsHook. + ''' + pass class SetInjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 071535526a42..2db21bdc4504 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -945,7 +945,7 @@ def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, d registered_hooks: list[comfy.hooks.Hook] = [] # handle WrapperHooks, if model_options provided if model_options is not None: - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): + for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}): hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] From 111fd0cadfe83cdda7a1a775f89e0dd675a58d66 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 02:04:07 -0600 Subject: [PATCH 04/90] Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type --- comfy/hooks.py | 78 ++++++++++++++++++------------------- comfy/model_patcher.py | 10 ++--- comfy/sampler_helpers.py | 20 +++++----- comfy_extras/nodes_hooks.py | 2 +- 4 files changed, 53 insertions(+), 57 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 7ca3a8a11a7f..9ccfaa6d1a97 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -41,7 +41,6 @@ class EnumHookType(enum.Enum): Hook types, each of which has different expected behavior. ''' Weight = "weight" - Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" TransformerOptions = "transformer_options" @@ -194,19 +193,6 @@ def clone(self, subtype: Callable=None): c._strength_clip = self._strength_clip return c -class PatchHook(Hook): - def __init__(self): - super().__init__(hook_type=EnumHookType.Patch) - self.patches: dict = None - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: PatchHook = super().clone(subtype) - c.patches = self.patches - return c - # TODO: add functionality - class ObjectPatchHook(Hook): def __init__(self): super().__init__(hook_type=EnumHookType.ObjectPatch) @@ -244,6 +230,7 @@ def clone(self, subtype: Callable=None): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): if not self.should_register(model, model_options, target_dict, registered): return False + return True class TransformerOptionsHook(Hook): ''' @@ -298,12 +285,28 @@ def add_hook_injections(self, model: ModelPatcher): pass class HookGroup: + ''' + Stores groups of hooks, and allows them to be queried by type. + + To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly; + always use the provided functions on HookGroup. + ''' def __init__(self): self.hooks: list[Hook] = [] + self._hook_dict: dict[EnumHookType, list[Hook]] = {} def add(self, hook: Hook): if hook not in self.hooks: self.hooks.append(hook) + self._hook_dict.setdefault(hook.hook_type, []).append(hook) + + def remove(self, hook: Hook): + if hook in self.hooks: + self.hooks.remove(hook) + self._hook_dict[hook.hook_type].remove(hook) + + def get_type(self, hook_type: EnumHookType): + return self._hook_dict.get(hook_type, []) def contains(self, hook: Hook): return hook in self.hooks @@ -329,36 +332,29 @@ def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup): for hook in self.hooks: hook.hook_keyframe = hook_kf - def get_dict_repr(self): - d: dict[EnumHookType, dict[Hook, None]] = {} - for hook in self.hooks: - with_type = d.setdefault(hook.hook_type, {}) - with_type[hook] = None - return d - def get_hooks_for_clip_schedule(self): scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {} - for hook in self.hooks: - # only care about WeightHooks, for now - if hook.hook_type == EnumHookType.Weight: - hook_schedule = [] - # if no hook keyframes, assign default value - if len(hook.hook_keyframe.keyframes) == 0: - hook_schedule.append(((0.0, 1.0), None)) - scheduled_hooks[hook] = hook_schedule - continue - # find ranges of values - prev_keyframe = hook.hook_keyframe.keyframes[0] - for keyframe in hook.hook_keyframe.keyframes: - if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength): - hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe)) - prev_keyframe = keyframe - elif keyframe.start_percent == prev_keyframe.start_percent: - prev_keyframe = keyframe - # create final range, assuming last start_percent was not 1.0 - if not math.isclose(prev_keyframe.start_percent, 1.0): - hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe)) + # only care about WeightHooks, for now + for hook in self.get_type(EnumHookType.Weight): + hook: WeightHook + hook_schedule = [] + # if no hook keyframes, assign default value + if len(hook.hook_keyframe.keyframes) == 0: + hook_schedule.append(((0.0, 1.0), None)) scheduled_hooks[hook] = hook_schedule + continue + # find ranges of values + prev_keyframe = hook.hook_keyframe.keyframes[0] + for keyframe in hook.hook_keyframe.keyframes: + if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength): + hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe)) + prev_keyframe = keyframe + elif keyframe.start_percent == prev_keyframe.start_percent: + prev_keyframe = keyframe + # create final range, assuming last start_percent was not 1.0 + if not math.isclose(prev_keyframe.start_percent, 1.0): + hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe)) + scheduled_hooks[hook] = hook_schedule # hooks should not have their schedules in a list of tuples all_ranges: list[tuple[float, float]] = [] for range_kfs in scheduled_hooks.values(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2db21bdc4504..0430430e52ff 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -940,16 +940,16 @@ def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: com if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None): + def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None): self.restore_hook_patches() registered_hooks: list[comfy.hooks.Hook] = [] - # handle WrapperHooks, if model_options provided + # handle TransformerOptionsHooks, if model_options provided if model_options is not None: - for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}): + for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): + for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): if hook.hook_ref not in self.hook_patches: weight_hooks_to_register.append(hook) if len(weight_hooks_to_register) > 0: @@ -958,7 +958,7 @@ def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, d for hook in weight_hooks_to_register: hook.add_hook_patches(self, model_options, target_dict, registered_hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks_dict, target_dict) + callback(self, hooks, target_dict) def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 6f21ca3cff18..abd44cf6ec5d 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type): models += [c[model_type]] return models -def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]): +def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup): # get hooks from conds, and collect cnets so they can be checked for extra_hooks cnets: list[ControlBase] = [] for c in cond: if 'hooks' in c: for hook in c['hooks'].hooks: - hook: comfy.hooks.Hook - with_type = hooks_dict.setdefault(hook.hook_type, {}) - with_type[hook] = None + full_hooks.add(hook) if 'control' in c: cnets.append(c['control']) @@ -50,10 +48,9 @@ def get_extra_hooks_from_cnet(cnet: ControlBase, _list: list): extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list) if extra_hooks is not None: for hook in extra_hooks.hooks: - with_type = hooks_dict.setdefault(hook.hook_type, {}) - with_type[hook] = None + full_hooks.add(hook) - return hooks_dict + return full_hooks def convert_cond(cond): out = [] @@ -73,7 +70,7 @@ def get_additional_models(conds, dtype): cnets: list[ControlBase] = [] gligen = [] add_models = [] - hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {} + hooks = comfy.hooks.HookGroup() for k in conds: cnets += get_models_from_cond(conds[k], "control") @@ -90,7 +87,10 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()] + hook_models = [] + for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels): + x: comfy.hooks.AddModelsHook + hook_models.extend(x.models) models = control_models + gligen + add_models + hook_models return models, inference_memory @@ -124,7 +124,7 @@ def cleanup_models(conds, models): def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): # check for hooks in conds - if not registered, see if can be applied - hooks = {} + hooks = comfy.hooks.HookGroup() for k in conds: get_hooks_from_cond(conds[k], hooks) # add wrappers and callbacks from ModelPatcher to transformer_options diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 49b90b9d5351..642238340d23 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -255,7 +255,7 @@ def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, h clip.use_clip_schedule = schedule_clip if not clip.use_clip_schedule: clip.patcher.forced_hooks.set_keyframes_on_hooks(None) - clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) + clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) return (clip,) class ConditioningTimestepsRange: From 6620d86318d19562a4410eabc78c27538d54e445 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 5 Jan 2025 15:26:22 -0600 Subject: [PATCH 05/90] In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_options to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch --- comfy/hooks.py | 2 +- comfy/samplers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 79a7090ba206..3cb0f39636d9 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -442,7 +442,7 @@ def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, return False if curr_t == self._curr_t: return False - max_sigma = torch.max(transformer_options["sigmas"]) + max_sigma = torch.max(transformer_options["sample_sigmas"]) prev_index = self._current_index prev_strength = self._current_strength # if met guaranteed steps, look for next keyframe in case need to switch diff --git a/comfy/samplers.py b/comfy/samplers.py index 89464a42ac6a..af2b8e110cf2 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -849,7 +849,7 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) - extra_model_options.setdefault("transformer_options", {})["sigmas"] = sigmas + extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas extra_args = {"model_options": extra_model_options, "seed": seed} executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( From 8270ff312f7aefc4d29aeeed667296b2a56628ce Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 5 Jan 2025 21:07:02 -0600 Subject: [PATCH 06/90] Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time --- comfy/hooks.py | 34 +++++++++++++++--------- comfy/model_patcher.py | 15 +++++------ comfy/sampler_helpers.py | 48 +++++++++++++++++++++++++-------- comfy/samplers.py | 57 +++++++++++++++++++++++++++++++++++++--- 4 files changed, 119 insertions(+), 35 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 3ead8c963d6a..25d67b86c76a 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -65,7 +65,7 @@ class _HookRef: pass -def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): +def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): '''Example for how should_register function should look like.''' return True @@ -114,10 +114,10 @@ def clone(self, subtype: Callable=None): c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c - def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): return self.custom_should_register(self, model, model_options, target_dict, registered) - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") def on_apply(self, model: ModelPatcher, transformer_options: dict[str]): @@ -154,7 +154,7 @@ def strength_model(self): def strength_clip(self): return self._strength_clip * self.strength - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False weights = None @@ -178,7 +178,7 @@ def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict else: weights = self.weights model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) - registered.append(self) + registered.add(self) return True # TODO: add logs about any keys that were not applied @@ -212,11 +212,12 @@ class AddModelsHook(Hook): Note, value of hook_scope is ignored and is treated as AllConditioning. ''' - def __init__(self, key: str=None, models: list[ModelPatcher]=None): + def __init__(self, models: list[ModelPatcher]=None, key: str=None): super().__init__(hook_type=EnumHookType.AddModels) - self.key = key self.models = models + self.key = key self.append_when_same = True + '''Curently does nothing.''' def clone(self, subtype: Callable=None): if subtype is None: @@ -227,9 +228,10 @@ def clone(self, subtype: Callable=None): c.append_when_same = self.append_when_same return c - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False + registered.add(self) return True class TransformerOptionsHook(Hook): @@ -247,14 +249,17 @@ def clone(self, subtype: Callable=None): c.transformers_dict = self.transformers_dict return c - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False - add_model_options = {"transformer_options": self.transformers_dict} - # TODO: call .to on patches/anything else in transformer_options that is expected to do something + # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks if self.hook_scope == EnumHookScope.AllConditioning: - comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) - registered.append(self) + add_model_options = {"transformer_options": self.transformers_dict, + "to_load_options": self.transformers_dict} + else: + add_model_options = {"to_load_options": self.transformers_dict} + comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) + registered.add(self) return True def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): @@ -295,6 +300,9 @@ def __init__(self): self.hooks: list[Hook] = [] self._hook_dict: dict[EnumHookType, list[Hook]] = {} + def __len__(self): + return len(self.hooks) + def add(self, hook: Hook): if hook not in self.hooks: self.hooks.append(hook) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0430430e52ff..2a5510873e45 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -940,13 +940,11 @@ def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: com if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None): + def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, + registered: comfy.hooks.HookGroup = None): self.restore_hook_patches() - registered_hooks: list[comfy.hooks.Hook] = [] - # handle TransformerOptionsHooks, if model_options provided - if model_options is not None: - for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): - hook.add_hook_patches(self, model_options, target_dict, registered_hooks) + if registered is None: + registered = comfy.hooks.HookGroup() # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): @@ -956,9 +954,10 @@ def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: d # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) for hook in weight_hooks_to_register: - hook.add_hook_patches(self, model_options, target_dict, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks, target_dict) + callback(self, hooks, target_dict, model_options, registered) + return registered def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index abd44cf6ec5d..cb9388519249 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -70,13 +70,11 @@ def get_additional_models(conds, dtype): cnets: list[ControlBase] = [] gligen = [] add_models = [] - hooks = comfy.hooks.HookGroup() for k in conds: cnets += get_models_from_cond(conds[k], "control") gligen += get_models_from_cond(conds[k], "gligen") add_models += get_models_from_cond(conds[k], "additional_models") - get_hooks_from_cond(conds[k], hooks) control_nets = set(cnets) @@ -87,14 +85,20 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - hook_models = [] - for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels): - x: comfy.hooks.AddModelsHook - hook_models.extend(x.models) - models = control_models + gligen + add_models + hook_models + models = control_models + gligen + add_models return models, inference_memory +def get_additional_models_from_model_options(model_options: dict[str]=None): + """loads additional models from registered AddModels hooks""" + models = [] + if model_options is not None and "registered_hooks" in model_options: + registered: comfy.hooks.HookGroup = model_options["registered_hooks"] + for hook in registered.get_type(comfy.hooks.EnumHookType.AddModels): + hook: comfy.hooks.AddModelsHook + models.extend(hook.models) + return models + def cleanup_additional_models(models): """cleanup additional models that were loaded""" for m in models: @@ -102,9 +106,10 @@ def cleanup_additional_models(models): m.cleanup() -def prepare_sampling(model: 'ModelPatcher', noise_shape, conds): - real_model: 'BaseModel' = None +def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): + real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) + models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory @@ -130,5 +135,26 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): # add wrappers and callbacks from ModelPatcher to transformer_options model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) - # register hooks on model/model_options - model.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model), model_options) + # begin registering hooks + registered = comfy.hooks.HookGroup() + target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model) + # handle all TransformerOptionsHooks + for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): + hook: comfy.hooks.TransformerOptionsHook + hook.add_hook_patches(model, model_options, target_dict, registered) + # handle all AddModelsHooks + for hook in hooks.get_type(comfy.hooks.EnumHookType.AddModels): + hook: comfy.hooks.AddModelsHook + hook.add_hook_patches(model, model_options, target_dict, registered) + # handle all WeightHooks by registering on ModelPatcher + model.register_all_hook_patches(hooks, target_dict, model_options, registered) + # add registered_hooks onto model_options for further reference + if len(registered) > 0: + model_options["registered_hooks"] = registered + # merge original wrappers and callbacks with hooked wrappers and callbacks + to_load_options: dict[str] = model_options.setdefault("to_load_options", {}) + for wc_name in ["wrappers", "callbacks"]: + comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], + copy_dict1=False) + return to_load_options + diff --git a/comfy/samplers.py b/comfy/samplers.py index af2b8e110cf2..8f8345abc3f1 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -819,9 +819,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): return len(hooks_set) +def cast_to_load_options(model_options: dict[str], device=None, dtype=None): + ''' + If any patches from hooks, wrappers, or callbacks have .to to be called, call it. + ''' + if model_options is None: + return + to_load_options = model_options.get("to_load_options", None) + if to_load_options is None: + return + + casts = [] + if device is not None: + casts.append(device) + if dtype is not None: + casts.append(dtype) + # if nothing to apply, do nothing + if len(casts) == 0: + return + + # Try to call .to on patches + if "patches" in to_load_options: + patches = to_load_options["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "to"): + for cast in casts: + patch_list[i] = patch_list[i].to(cast) + if "patches_replace" in to_load_options: + patches = to_load_options["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "to"): + for cast in casts: + patch_list[k] = patch_list[k].to(cast) + # Try to call .to on any wrappers/callbacks + wrappers_and_callbacks = ["wrappers", "callbacks"] + for wc_name in wrappers_and_callbacks: + if wc_name in to_load_options: + wc: dict[str, list] = to_load_options[wc_name] + for wc_dict in wc.values(): + for wc_list in wc_dict.values(): + for i in range(len(wc_list)): + if hasattr(wc_list[i], "to"): + for cast in casts: + wc_list[i] = wc_list[i].to(cast) + + class CFGGuider: - def __init__(self, model_patcher): - self.model_patcher: 'ModelPatcher' = model_patcher + def __init__(self, model_patcher: ModelPatcher): + self.model_patcher = model_patcher self.model_options = model_patcher.model_options self.original_conds = {} self.cfg = 1.0 @@ -861,7 +910,7 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas return self.inner_model.process_latent_out(samples.to(torch.float32)) def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) + self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device if denoise_mask is not None: @@ -870,6 +919,7 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, noise = noise.to(device) latent_image = latent_image.to(device) sigmas = sigmas.to(device) + cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) try: self.model_patcher.pre_run() @@ -906,6 +956,7 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba ) output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: + cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) self.model_options = orig_model_options self.model_patcher.hook_mode = orig_hook_mode self.model_patcher.restore_hook_patches() From 4446c86052bd9a00b72205b761b3744dd51f90eb Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 5 Jan 2025 22:25:51 -0600 Subject: [PATCH 07/90] Made hook clone code sane, made clear ObjectPatchHook and SetInjectionsHook are not yet operational --- comfy/hooks.py | 64 +++++++++++++++++++--------------------- comfy/sampler_helpers.py | 1 - 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 25d67b86c76a..b62092cce380 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -101,10 +101,8 @@ def initialize_timesteps(self, model: BaseModel): def reset(self): self.hook_keyframe.reset() - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: Hook = subtype() + def clone(self): + c: Hook = self.__class__() c.hook_type = self.hook_type c.hook_ref = self.hook_ref c.hook_id = self.hook_id @@ -182,10 +180,8 @@ def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict return True # TODO: add logs about any keys that were not applied - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: WeightHook = super().clone(subtype) + def clone(self): + c: WeightHook = super().clone() c.weights = self.weights c.weights_clip = self.weights_clip c.need_weight_init = self.need_weight_init @@ -194,17 +190,21 @@ def clone(self, subtype: Callable=None): return c class ObjectPatchHook(Hook): - def __init__(self): + def __init__(self, object_patches: dict[str]=None): super().__init__(hook_type=EnumHookType.ObjectPatch) - self.object_patches: dict = None + self.object_patches = object_patches - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: ObjectPatchHook = super().clone(subtype) + def clone(self): + c: ObjectPatchHook = super().clone() c.object_patches = self.object_patches return c - # TODO: add functionality + + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.") + if not self.should_register(model, model_options, target_dict, registered): + return False + registered.add(self) + return True class AddModelsHook(Hook): ''' @@ -219,12 +219,10 @@ def __init__(self, models: list[ModelPatcher]=None, key: str=None): self.append_when_same = True '''Curently does nothing.''' - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: AddModelsHook = super().clone(subtype) - c.key = self.key + def clone(self): + c: AddModelsHook = super().clone() c.models = self.models.copy() if self.models else self.models + c.key = self.key c.append_when_same = self.append_when_same return c @@ -242,10 +240,8 @@ def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]] super().__init__(hook_type=EnumHookType.TransformerOptions) self.transformers_dict = wrappers_dict - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: WrapperHook = super().clone(subtype) + def clone(self): + c: TransformerOptionsHook = super().clone() c.transformers_dict = self.transformers_dict return c @@ -265,11 +261,8 @@ def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) -class WrapperHook(TransformerOptionsHook): - ''' - For backwards compatibility, this hook is identical to TransformerOptionsHook. - ''' - pass +WrapperHook = TransformerOptionsHook +'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' class SetInjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None): @@ -277,14 +270,19 @@ def __init__(self, key: str=None, injections: list[PatcherInjection]=None): self.key = key self.injections = injections - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: SetInjectionsHook = super().clone(subtype) + def clone(self): + c: SetInjectionsHook = super().clone() c.key = self.key c.injections = self.injections.copy() if self.injections else self.injections return c + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + raise NotImplementedError("SetInjectionsHook is not supported yet in ComfyUI.") + if not self.should_register(model, model_options, target_dict, registered): + return False + registered.add(self) + return True + def add_hook_injections(self, model: ModelPatcher): # TODO: add functionality pass diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index cb9388519249..d43280fe459c 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -157,4 +157,3 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], copy_dict1=False) return to_load_options - From 03a97b604a3e8ca9f54c711ed3b007f07c9115ba Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 01:03:59 -0600 Subject: [PATCH 08/90] Fix performance of hooks when hooks are appended via Cond Pair Set Props nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time) --- comfy/hooks.py | 33 +++++++++++++++++++++++++-------- comfy/model_patcher.py | 10 ++++++---- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index b62092cce380..dde3e8bcb50a 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -317,6 +317,18 @@ def get_type(self, hook_type: EnumHookType): def contains(self, hook: Hook): return hook in self.hooks + def is_subset_of(self, other: HookGroup): + self_hooks = set(self.hooks) + other_hooks = set(other.hooks) + return self_hooks.issubset(other_hooks) + + def new_with_common_hooks(self, other: HookGroup): + c = HookGroup() + for hook in self.hooks: + if other.contains(hook): + c.add(hook.clone()) + return c + def clone(self): c = HookGroup() for hook in self.hooks: @@ -668,24 +680,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H else: c_dict[hooks_key] = cache[hooks_tuple] -def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True): +def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True, + cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None): c = [] - hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {} + if cache is None: + cache = {} for t in conditioning: n = [t[0], t[1].copy()] for k in values: if append_hooks and k == 'hooks': - _combine_hooks_from_values(n[1], values, hooks_combine_cache) + _combine_hooks_from_values(n[1], values, cache) else: n[1][k] = values[k] c.append(n) return c -def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True): +def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None): if hooks is None: return cond - return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks) + return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache) def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]): if timestep_range is None: @@ -720,9 +734,10 @@ def combine_with_new_conds(conds: list, new_conds: list): def set_conds_props(conds: list, strength: float, set_cond_area: str, mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): final_conds = [] + cache = {} for c in conds: # first, apply lora_hook to conditioning, if provided - c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks) + c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache) # next, apply mask to conditioning c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area) # apply timesteps, if present @@ -734,9 +749,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str, def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default", mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): combined_conds = [] + cache = {} for c, masked_c in zip(conds, new_conds): # first, apply lora_hook to new conditioning, if provided - masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks) + masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache) # next, apply mask to new conditioning, if provided masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength) # apply timesteps, if present @@ -748,9 +764,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1. def set_default_conds_and_combine(conds: list, new_conds: list, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): combined_conds = [] + cache = {} for c, new_c in zip(conds, new_conds): # first, apply lora_hook to new conditioning, if provided - new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks) + new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache) # next, add default_cond key to cond so that during sampling, it can be identified new_c = conditioning_set_values(new_c, {'default': True}) # apply timesteps, if present diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2a5510873e45..57a843b8ff00 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -210,7 +210,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.injections: dict[str, list[PatcherInjection]] = {} self.hook_patches: dict[comfy.hooks._HookRef] = {} - self.hook_patches_backup: dict[comfy.hooks._HookRef] = {} + self.hook_patches_backup: dict[comfy.hooks._HookRef] = None self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {} self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {} self.current_hooks: Optional[comfy.hooks.HookGroup] = None @@ -282,7 +282,7 @@ def clone(self): n.injections[k] = i.copy() # hooks n.hook_patches = create_hook_patches_clone(self.hook_patches) - n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) + n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup for group in self.cached_hook_patches: n.cached_hook_patches[group] = {} for k in self.cached_hook_patches[group]: @@ -912,9 +912,9 @@ def prepare_state(self, timestep): callback(self, timestep) def restore_hook_patches(self): - if len(self.hook_patches_backup) > 0: + if self.hook_patches_backup is not None: self.hook_patches = self.hook_patches_backup - self.hook_patches_backup = {} + self.hook_patches_backup = None def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): self.hook_mode = hook_mode @@ -950,6 +950,8 @@ def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: d for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): if hook.hook_ref not in self.hook_patches: weight_hooks_to_register.append(hook) + else: + registered.add(hook) if len(weight_hooks_to_register) > 0: # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) From 0a7e2ae787b81035798ad2ef1ade8cf882d67b69 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 01:04:29 -0600 Subject: [PATCH 09/90] Filter only registered hooks on self.conds in CFGGuider.sample --- comfy/sampler_helpers.py | 3 +++ comfy/samplers.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index d43280fe459c..1433d185908e 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -128,6 +128,9 @@ def cleanup_models(conds, models): cleanup_additional_models(set(control_cleanup)) def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): + ''' + Registers hooks from conds. + ''' # check for hooks in conds - if not registered, see if can be applied hooks = comfy.hooks.HookGroup() for k in conds: diff --git a/comfy/samplers.py b/comfy/samplers.py index 8f8345abc3f1..43a735c6e196 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -810,6 +810,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]): for cond in conds_to_modify: cond['hooks'] = hooks +def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]): + '''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for + HookGroups that have the same reference.''' + registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None) + # if None were registered, make sure all hooks are cleaned from conds + if registered is None: + for k in conds: + for kk in conds[k]: + kk.pop('hooks', None) + return + # find conds that contain hooks to be replaced - group by common HookGroup refs + hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {} + for k in conds: + for kk in conds[k]: + hooks: comfy.hooks.HookGroup = kk.get('hooks', None) + if hooks is not None: + if not hooks.is_subset_of(registered): + to_replace = hook_replacement.setdefault(hooks, []) + to_replace.append(kk) + # for each hook to replace, create a new proper HookGroup and assign to all common conds + for hooks, conds_to_modify in hook_replacement.items(): + new_hooks = hooks.new_with_common_hooks(registered) + if len(new_hooks) == 0: + new_hooks = None + for kk in conds_to_modify: + kk['hooks'] = new_hooks + def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): hooks_set = set() @@ -949,6 +976,7 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callba if get_total_hook_groups_in_conds(self.conds) <= 1: self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options) + filter_registered_hooks_on_conds(self.conds, self.model_options) executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( self.outer_sample, self, From f48f90e471fc5440135e7886d712518467c59c00 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 02:23:04 -0600 Subject: [PATCH 10/90] Make hook_scope functional for TransformerOptionsHook --- comfy/hooks.py | 41 ++++++++++++++++++++++++++--------------- comfy/model_patcher.py | 4 ++-- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index dde3e8bcb50a..cc9f6cd546cc 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -86,9 +86,9 @@ def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_i self.hook_ref = hook_ref if hook_ref else _HookRef() self.hook_id = hook_id self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() + self.hook_scope = hook_scope self.custom_should_register = default_should_register self.auto_apply_to_nonpositive = False - self.hook_scope = hook_scope @property def strength(self): @@ -107,6 +107,7 @@ def clone(self): c.hook_ref = self.hook_ref c.hook_id = self.hook_id c.hook_keyframe = self.hook_keyframe + c.hook_scope = self.hook_scope c.custom_should_register = self.custom_should_register # TODO: make this do something c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive @@ -118,12 +119,6 @@ def should_register(self, model: ModelPatcher, model_options: dict, target_dict: def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") - def on_apply(self, model: ModelPatcher, transformer_options: dict[str]): - pass - - def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]): - pass - def __eq__(self, other: Hook): return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref @@ -143,6 +138,7 @@ def __init__(self, strength_model=1.0, strength_clip=1.0): self.need_weight_init = True self._strength_model = strength_model self._strength_clip = strength_clip + self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs @property def strength_model(self): @@ -190,9 +186,11 @@ def clone(self): return c class ObjectPatchHook(Hook): - def __init__(self, object_patches: dict[str]=None): + def __init__(self, object_patches: dict[str]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.ObjectPatch) self.object_patches = object_patches + self.hook_scope = hook_scope def clone(self): c: ObjectPatchHook = super().clone() @@ -216,14 +214,11 @@ def __init__(self, models: list[ModelPatcher]=None, key: str=None): super().__init__(hook_type=EnumHookType.AddModels) self.models = models self.key = key - self.append_when_same = True - '''Curently does nothing.''' def clone(self): c: AddModelsHook = super().clone() c.models = self.models.copy() if self.models else self.models c.key = self.key - c.append_when_same = self.append_when_same return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): @@ -236,9 +231,11 @@ class TransformerOptionsHook(Hook): ''' Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options. ''' - def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): + def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.TransformerOptions) - self.transformers_dict = wrappers_dict + self.transformers_dict = transformers_dict + self.hook_scope = hook_scope def clone(self): c: TransformerOptionsHook = super().clone() @@ -254,8 +251,9 @@ def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict "to_load_options": self.transformers_dict} else: add_model_options = {"to_load_options": self.transformers_dict} + # only register if will not be included in AllConditioning to avoid double loading + registered.add(self) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) - registered.add(self) return True def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): @@ -265,10 +263,12 @@ def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' class SetInjectionsHook(Hook): - def __init__(self, key: str=None, injections: list[PatcherInjection]=None): + def __init__(self, key: str=None, injections: list[PatcherInjection]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.Injections) self.key = key self.injections = injections + self.hook_scope = hook_scope def clone(self): c: SetInjectionsHook = super().clone() @@ -590,6 +590,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list: sorted_list.extend(object_list) return sorted_list +def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None): + # if no hooks or is not a ModelPatcher for sampling, return empty dict + if hooks is None or model.is_clip: + return {} + if transformer_options is None: + transformer_options = {} + for hook in hooks.get_type(EnumHookType.TransformerOptions): + hook: TransformerOptionsHook + hook.on_apply_hooks(model, transformer_options) + return transformer_options + def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): hook_group = HookGroup() hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 57a843b8ff00..51a62e048e72 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1010,11 +1010,11 @@ def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup): def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False): # TODO: return transformer_options dict with any additions from hooks if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)): - return {} + return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options) self.patch_hooks(hooks=hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS): callback(self, hooks) - return {} + return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options) def patch_hooks(self, hooks: comfy.hooks.HookGroup): with self.use_ejected(): From 1b38f5bf57ca07490e616dd58ec3004d05de0655 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 17:11:12 -0600 Subject: [PATCH 11/90] removed 4 whitespace lines to satisfy Ruff, --- comfy/hooks.py | 4 ++-- comfy/samplers.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index cc9f6cd546cc..46fc06bdc898 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -255,7 +255,7 @@ def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict registered.add(self) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) return True - + def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) @@ -290,7 +290,7 @@ def add_hook_injections(self, model: ModelPatcher): class HookGroup: ''' Stores groups of hooks, and allows them to be queried by type. - + To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly; always use the provided functions on HookGroup. ''' diff --git a/comfy/samplers.py b/comfy/samplers.py index 43a735c6e196..a725d5185b12 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -855,7 +855,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): to_load_options = model_options.get("to_load_options", None) if to_load_options is None: return - + casts = [] if device is not None: casts.append(device) @@ -864,7 +864,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # if nothing to apply, do nothing if len(casts) == 0: return - + # Try to call .to on patches if "patches" in to_load_options: patches = to_load_options["patches"] From 58bf8815c84b67ab26b0f08b8530a822b9899b10 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 20:34:30 -0600 Subject: [PATCH 12/90] Add a get_injections function to ModelPatcher --- comfy/model_patcher.py | 3 +++ comfy/samplers.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 51a62e048e72..7d7977c1478e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -842,6 +842,9 @@ def remove_injections(self, key: str): if key in self.injections: self.injections.pop(key) + def get_injections(self, key: str): + return self.injections.get(key, None) + def set_additional_models(self, key: str, models: list['ModelPatcher']): self.additional_models[key] = models diff --git a/comfy/samplers.py b/comfy/samplers.py index a725d5185b12..5cc33a7d9f63 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -865,7 +865,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): if len(casts) == 0: return - # Try to call .to on patches + # try to call .to on patches if "patches" in to_load_options: patches = to_load_options["patches"] for name in patches: @@ -882,7 +882,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): if hasattr(patch_list[k], "to"): for cast in casts: patch_list[k] = patch_list[k].to(cast) - # Try to call .to on any wrappers/callbacks + # try to call .to on any wrappers/callbacks wrappers_and_callbacks = ["wrappers", "callbacks"] for wc_name in wrappers_and_callbacks: if wc_name in to_load_options: From 216fea15ee033d3301241a5ceb0e193b4924de04 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 7 Jan 2025 00:59:18 -0600 Subject: [PATCH 13/90] Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable --- comfy/hooks.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 46fc06bdc898..7c2f668929b2 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -66,7 +66,7 @@ class _HookRef: def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): - '''Example for how should_register function should look like.''' + '''Example for how custom_should_register function can look like.''' return True @@ -83,12 +83,17 @@ class Hook: def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning): self.hook_type = hook_type + '''Enum identifying the general class of this hook.''' self.hook_ref = hook_ref if hook_ref else _HookRef() + '''Reference shared between hook clones that have the same value. Should NOT be modified.''' self.hook_id = hook_id + '''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.''' self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() + '''Keyframe storage that can be referenced to get strength for current sampling step.''' self.hook_scope = hook_scope + '''Scope of where this hook should apply in terms of the conds used in sampling run.''' self.custom_should_register = default_should_register - self.auto_apply_to_nonpositive = False + '''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register''' @property def strength(self): @@ -109,8 +114,6 @@ def clone(self): c.hook_keyframe = self.hook_keyframe c.hook_scope = self.hook_scope c.custom_should_register = self.custom_should_register - # TODO: make this do something - c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): @@ -236,28 +239,34 @@ def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callabl super().__init__(hook_type=EnumHookType.TransformerOptions) self.transformers_dict = transformers_dict self.hook_scope = hook_scope + self._skip_adding = False + '''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.''' def clone(self): c: TransformerOptionsHook = super().clone() c.transformers_dict = self.transformers_dict + c._skip_adding = self._skip_adding return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks + self._skip_adding = False if self.hook_scope == EnumHookScope.AllConditioning: add_model_options = {"transformer_options": self.transformers_dict, "to_load_options": self.transformers_dict} + # skip_adding if included in AllConditioning to avoid double loading + self._skip_adding = True else: add_model_options = {"to_load_options": self.transformers_dict} - # only register if will not be included in AllConditioning to avoid double loading - registered.add(self) + registered.add(self) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) return True def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): - comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) + if not self._skip_adding: + comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) WrapperHook = TransformerOptionsHook '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' From 3cd4c5cb0a9d4f4f944ee1382e074d3a41e18874 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 7 Jan 2025 02:22:49 -0600 Subject: [PATCH 14/90] Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHook to InjectionsHook (not yet implemented, but at least getting the naming figured out) --- comfy/hooks.py | 26 +++++++------------------- comfy/sampler_helpers.py | 8 ++++---- 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 7c2f668929b2..9d0731072902 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -42,7 +42,7 @@ class EnumHookType(enum.Enum): ''' Weight = "weight" ObjectPatch = "object_patch" - AddModels = "add_models" + AdditionalModels = "add_models" TransformerOptions = "transformer_options" Injections = "add_injections" @@ -202,24 +202,20 @@ def clone(self): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.") - if not self.should_register(model, model_options, target_dict, registered): - return False - registered.add(self) - return True -class AddModelsHook(Hook): +class AdditionalModelsHook(Hook): ''' Hook responsible for telling model management any additional models that should be loaded. Note, value of hook_scope is ignored and is treated as AllConditioning. ''' def __init__(self, models: list[ModelPatcher]=None, key: str=None): - super().__init__(hook_type=EnumHookType.AddModels) + super().__init__(hook_type=EnumHookType.AdditionalModels) self.models = models self.key = key def clone(self): - c: AddModelsHook = super().clone() + c: AdditionalModelsHook = super().clone() c.models = self.models.copy() if self.models else self.models c.key = self.key return c @@ -271,7 +267,7 @@ def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): WrapperHook = TransformerOptionsHook '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' -class SetInjectionsHook(Hook): +class InjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None, hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.Injections) @@ -280,21 +276,13 @@ def __init__(self, key: str=None, injections: list[PatcherInjection]=None, self.hook_scope = hook_scope def clone(self): - c: SetInjectionsHook = super().clone() + c: InjectionsHook = super().clone() c.key = self.key c.injections = self.injections.copy() if self.injections else self.injections return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): - raise NotImplementedError("SetInjectionsHook is not supported yet in ComfyUI.") - if not self.should_register(model, model_options, target_dict, registered): - return False - registered.add(self) - return True - - def add_hook_injections(self, model: ModelPatcher): - # TODO: add functionality - pass + raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.") class HookGroup: ''' diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 1433d185908e..b70e5e636261 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -94,8 +94,8 @@ def get_additional_models_from_model_options(model_options: dict[str]=None): models = [] if model_options is not None and "registered_hooks" in model_options: registered: comfy.hooks.HookGroup = model_options["registered_hooks"] - for hook in registered.get_type(comfy.hooks.EnumHookType.AddModels): - hook: comfy.hooks.AddModelsHook + for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels): + hook: comfy.hooks.AdditionalModelsHook models.extend(hook.models) return models @@ -146,8 +146,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): hook: comfy.hooks.TransformerOptionsHook hook.add_hook_patches(model, model_options, target_dict, registered) # handle all AddModelsHooks - for hook in hooks.get_type(comfy.hooks.EnumHookType.AddModels): - hook: comfy.hooks.AddModelsHook + for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels): + hook: comfy.hooks.AdditionalModelsHook hook.add_hook_patches(model, model_options, target_dict, registered) # handle all WeightHooks by registering on ModelPatcher model.register_all_hook_patches(hooks, target_dict, model_options, registered) From 733328169868b9f4120cbfc59af2b00683df8563 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 7 Jan 2025 02:58:59 -0600 Subject: [PATCH 15/90] Clean up a typehint --- comfy_extras/nodes_hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 642238340d23..1edc06f3d7ae 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -246,7 +246,7 @@ def INPUT_TYPES(s): CATEGORY = "advanced/hooks/clip" FUNCTION = "apply_hooks" - def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): + def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): if hooks is not None: clip = clip.clone() if apply_to_conds: From 871258aa722fb8031e251c7e4d0ecffa9a11c460 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 7 Jan 2025 21:06:03 -0600 Subject: [PATCH 16/90] Add get_all_torch_devices to get detected devices intended for current torch hardware device --- comfy/model_management.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index f6dfc18b02b6..003a89f51267 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -128,6 +128,19 @@ def get_torch_device(): else: return torch.device(torch.cuda.current_device()) +def get_all_torch_devices(exclude_current=False): + global cpu_state + devices = [] + if cpu_state == CPUState.GPU: + if is_nvidia(): + for i in range(torch.cuda.device_count()): + devices.append(torch.device(i)) + else: + devices.append(get_torch_device()) + if exclude_current: + devices.remove(get_torch_device()) + return devices + def get_total_memory(dev=None, torch_total_too=False): global directml_enabled if dev is None: From 7448f02b7cf0e7acf97fcdc41eda4342d062e549 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 8 Jan 2025 03:33:05 -0600 Subject: [PATCH 17/90] Initial proof of concept of giving splitting cond sampling between multiple GPUs --- comfy/model_management.py | 10 +- comfy/model_patcher.py | 4 + comfy/samplers.py | 188 +++++++++++++++++++++++++++++++++++++- 3 files changed, 198 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 003a89f51267..87ad290d0956 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -15,6 +15,7 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ +from __future__ import annotations import psutil import logging @@ -26,6 +27,11 @@ import weakref import gc +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.model_base import BaseModel + class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -330,7 +336,7 @@ def module_size(module): return module_mem class LoadedModel: - def __init__(self, model): + def __init__(self, model: ModelPatcher): self._set_model(model) self.device = model.load_device self.real_model = None @@ -338,7 +344,7 @@ def __init__(self, model): self.model_finalizer = None self._patcher_finalizer = None - def _set_model(self, model): + def _set_model(self, model: ModelPatcher): self._model = weakref.ref(model) if model.parent is not None: self._parent_model = weakref.ref(model.parent) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0501f7b38435..5465dde627b0 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -218,6 +218,8 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed + self.is_multigpu_clone = False + if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -293,6 +295,8 @@ def clone(self): n.is_clip = self.is_clip n.hook_mode = self.hook_mode + n.is_multigpu_clone = self.is_multigpu_clone + for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n diff --git a/comfy/samplers.py b/comfy/samplers.py index 5cc33a7d9f63..f3064000603c 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -19,6 +19,7 @@ import comfy.hooks import scipy.stats import numpy +import threading def get_area_and_mult(conds, x_in, timestep_in): dims = tuple(x_in.shape[2:]) @@ -130,7 +131,7 @@ def objects_concatable(obj1, obj2): return cond_equal_size(c1.conditioning, c2.conditioning) -def cond_cat(c_list): +def cond_cat(c_list, device=None): temp = {} for x in c_list: for k in x: @@ -142,6 +143,8 @@ def cond_cat(c_list): for k in temp: conds = temp[k] out[k] = conds[0].concat(conds[1:]) + if device is not None: + out[k] = out[k].to(device) return out @@ -195,7 +198,9 @@ def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Ten ) return executor.execute(model, conds, x_in, timestep, model_options) -def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + if 'multigpu_clones' in model_options: + return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options) out_conds = [] out_counts = [] # separate conds by matching hooks @@ -329,6 +334,173 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te return out_conds +def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + out_conds = [] + out_counts = [] + # separate conds by matching hooks + hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {} + default_conds = [] + has_default_conds = False + + output_device = x_in.device + + for i in range(len(conds)): + out_conds.append(torch.zeros_like(x_in)) + out_counts.append(torch.ones_like(x_in) * 1e-37) + + cond = conds[i] + default_c = [] + if cond is not None: + for x in cond: + if 'default' in x: + default_c.append(x) + has_default_conds = True + continue + p = comfy.samplers.get_area_and_mult(x, x_in, timestep) + if p is None: + continue + if p.hooks is not None: + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) + hooked_to_run.setdefault(p.hooks, list()) + hooked_to_run[p.hooks] += [(p, i)] + default_conds.append(default_c) + + if has_default_conds: + finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) + + model.current_patcher.prepare_state(timestep) + + devices = [dev_m for dev_m in model_options["multigpu_clones"].keys()] + device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {} + count = 0 + # run every hooked_to_run separately + for hooks, to_run in hooked_to_run.items(): + while len(to_run) > 0: + first = to_run[0] + first_shape = first[0][0].shape + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] + + to_batch_temp.reverse() + to_batch = to_batch_temp[:1] + + current_device = devices[count % len(devices)] + free_memory = model_management.get_free_memory(current_device) + for i in range(1, len(to_batch_temp) + 1): + batch_amount = to_batch_temp[:len(to_batch_temp)//i] + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + # if model.memory_required(input_shape) * 1.5 < free_memory: + # to_batch = batch_amount + # break + conds_to_batch = [] + for x in to_batch: + conds_to_batch.append(to_run.pop(x)) + + batched_to_run = device_batched_hooked_to_run.setdefault(current_device, []) + batched_to_run.append((hooks, conds_to_batch)) + count += 1 + + thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond']) + def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): + model_current: BaseModel = model_options["multigpu_clones"][device].model + # run every hooked_to_run separately + with torch.no_grad(): + for hooks, to_batch in batch_tuple: + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + uuids = [] + area = [] + control = None + patches = None + for x in to_batch: + o = x + p = o[0] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + uuids.append(p.uuid) + control = p.control + patches = p.patches + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x).to(device) + c = cond_cat(c, device=device) + timestep_ = torch.cat([timestep.to(device)] * batch_chunks) + + transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks) + if 'transformer_options' in model_options: + transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, + model_options['transformer_options'], + copy_dict1=False) + + if patches is not None: + # TODO: replace with merge_nested_dicts function + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + transformer_options["patches"] = cur_patches + else: + transformer_options["patches"] = patches + + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["uuids"] = uuids[:] + transformer_options["sigmas"] = timestep + transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) + + c['transformer_options'] = transformer_options + + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) + + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) + else: + output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) + results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) + + + results: list[thread_result] = [] + threads: list[threading.Thread] = [] + for device, batch_tuple in device_batched_hooked_to_run.items(): + new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results)) + threads.append(new_thread) + new_thread.start() + + for thread in threads: + thread.join() + + for output, mult, area, batch_chunks, cond_or_uncond in results: + for o in range(batch_chunks): + cond_index = cond_or_uncond[o] + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] + + for i in range(len(out_conds)): + out_conds[i] /= out_counts[i] + + return out_conds + def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) @@ -940,6 +1112,14 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device + multigpu_patchers: list[ModelPatcher] = [x for x in self.loaded_models if x.is_multigpu_clone] + if len(multigpu_patchers) > 0: + multigpu_dict: dict[torch.device, ModelPatcher] = {} + multigpu_dict[device] = self.model_patcher + for x in multigpu_patchers: + multigpu_dict[x.load_device] = x + self.model_options["multigpu_clones"] = multigpu_dict + if denoise_mask is not None: denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) @@ -950,9 +1130,13 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, try: self.model_patcher.pre_run() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.pre_run() output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: self.model_patcher.cleanup() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.cleanup() comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) del self.inner_model From e88c6c03ff16c197e7b49b7908f91a67f21ef7b1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 10 Jan 2025 23:05:24 -0600 Subject: [PATCH 18/90] Fix cond_cat to not try to cast anything that doesn't have a 'to' function --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index f3064000603c..98b1932f7a3a 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -143,7 +143,7 @@ def cond_cat(c_list, device=None): for k in temp: conds = temp[k] out[k] = conds[0].concat(conds[1:]) - if device is not None: + if device is not None and hasattr(out[k], 'to'): out[k] = out[k].to(device) return out From d5088072fb7561e6c6b44693c65e31c254c81b81 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 13 Jan 2025 20:20:25 -0600 Subject: [PATCH 19/90] Make test node for multigpu instead of storing it in just a local __init__.py --- comfy_extras/nodes_multigpu.py | 39 ++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 comfy_extras/nodes_multigpu.py diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py new file mode 100644 index 000000000000..929151b503e5 --- /dev/null +++ b/comfy_extras/nodes_multigpu.py @@ -0,0 +1,39 @@ +from comfy.model_patcher import ModelPatcher +import comfy.utils +import comfy.patcher_extension +import comfy.model_management +import copy + + +class MultiGPUInitialize: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("MODEL",), + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "init_multigpu" + CATEGORY = "DevTools" + + def init_multigpu(self, model: ModelPatcher): + model = model.clone() + extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + if len(extra_devices) > 0: + comfy.model_management.unload_all_models() + for device in extra_devices: + device_patcher = model.clone() + device_patcher.model = copy.deepcopy(model.model) + device_patcher.load_device = device + device_patcher.is_multigpu_clone = True + multigpu_models = model.get_additional_models_with_key("multigpu") + multigpu_models.append(device_patcher) + model.set_additional_models("multigpu", multigpu_models) + return (model,) + + +NODE_CLASS_MAPPINGS = { + "test_multigpuinit": MultiGPUInitialize, +} \ No newline at end of file From 198953cd088b8a02701315e59047db16a6e6438a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 14 Jan 2025 12:24:55 -0600 Subject: [PATCH 20/90] Add nodes_multigpu.py to loaded nodes --- nodes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nodes.py b/nodes.py index cfd7dd8a4532..62b6ad18ad68 100644 --- a/nodes.py +++ b/nodes.py @@ -2224,6 +2224,7 @@ def init_builtin_extra_nodes(): "nodes_mahiro.py", "nodes_lt.py", "nodes_hooks.py", + "nodes_multigpu.py", "nodes_load_3d.py", "nodes_cosmos.py", ] From 25818dc848f8db6f79b4410e46b06133165d35a2 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 14 Jan 2025 13:45:14 -0600 Subject: [PATCH 21/90] Added a 'max_gpus' input --- comfy_extras/nodes_multigpu.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 929151b503e5..3ba558621685 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -11,6 +11,9 @@ def INPUT_TYPES(cls): return { "required": { "model": ("MODEL",), + }, + "optional": { + "max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}), } } @@ -18,9 +21,10 @@ def INPUT_TYPES(cls): FUNCTION = "init_multigpu" CATEGORY = "DevTools" - def init_multigpu(self, model: ModelPatcher): + def init_multigpu(self, model: ModelPatcher, max_gpus: int): model = model.clone() extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + extra_devices = extra_devices[:max_gpus-1] if len(extra_devices) > 0: comfy.model_management.unload_all_models() for device in extra_devices: From bfce72331188c4efdfe41edcf4e941ac328632cb Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 17 Jan 2025 03:31:28 -0600 Subject: [PATCH 22/90] Initial work on multigpu_clone function, which will account for additional_models getting cloned --- comfy/model_patcher.py | 25 +++++++++++++++++++++++++ comfy_extras/nodes_multigpu.py | 6 ++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5465dde627b0..63f1f92e4054 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -219,6 +219,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed self.is_multigpu_clone = False + self.clone_uuid = uuid.uuid4() if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -296,11 +297,35 @@ def clone(self): n.hook_mode = self.hook_mode n.is_multigpu_clone = self.is_multigpu_clone + n.clone_uuid = self.clone_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n + def multigpu_clone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None): + n = self.clone() + # set load device, if present + if new_load_device is not None: + n.load_device = new_load_device + # unlike for normal clone, backup dicts that shared same ref should not; + # otherwise, patchers that have deep copies of base models will erroneously influence each other. + n.backup = copy.deepcopy(n.backup) + n.object_patches_backup = copy.deepcopy(n.object_patches_backup) + n.model = copy.deepcopy(n.model) + # multigpu clone should not have multigpu additional_models entry + n.remove_additional_models("multigpu") + # multigpu_clone all stored additional_models; make sure circular references are properly handled + if models_cache is None: + models_cache = {} + for key, model_list in n.additional_models.items(): + for i in range(len(model_list)): + add_model = n.additional_models[key][i] + if i not in models_cache: + models_cache[add_model] = add_model.multigpu_clone(new_load_device=new_load_device, models_cache=models_cache) + n.additional_models[key][i] = models_cache[add_model] + return n + def is_clone(self, other): if hasattr(other, 'model') and self.model is other.model: return True diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 3ba558621685..dec395fb3546 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -22,15 +22,13 @@ def INPUT_TYPES(cls): CATEGORY = "DevTools" def init_multigpu(self, model: ModelPatcher, max_gpus: int): - model = model.clone() extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) extra_devices = extra_devices[:max_gpus-1] if len(extra_devices) > 0: + model = model.clone() comfy.model_management.unload_all_models() for device in extra_devices: - device_patcher = model.clone() - device_patcher.model = copy.deepcopy(model.model) - device_patcher.load_device = device + device_patcher = model.multigpu_clone(new_load_device=device) device_patcher.is_multigpu_clone = True multigpu_models = model.get_additional_models_with_key("multigpu") multigpu_models.append(device_patcher) From 328d4f16a90f5d5d8ac1218bcc5b0862fe970afb Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 20 Jan 2025 04:34:26 -0600 Subject: [PATCH 23/90] Make WeightHooks compatible with MultiGPU, clean up some code --- comfy/model_patcher.py | 46 ++++++++++++++++++++++++++++++---- comfy/sampler_helpers.py | 18 ++++++++++++- comfy/samplers.py | 46 ++++++++++++++++++++-------------- comfy_extras/nodes_multigpu.py | 2 +- 4 files changed, 86 insertions(+), 26 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 63f1f92e4054..46779397e933 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -84,12 +84,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_ def create_model_options_clone(orig_model_options: dict): return comfy.patcher_extension.copy_nested_dicts(orig_model_options) -def create_hook_patches_clone(orig_hook_patches): +def create_hook_patches_clone(orig_hook_patches, copy_tuples=False): new_hook_patches = {} for hook_ref in orig_hook_patches: new_hook_patches[hook_ref] = {} for k in orig_hook_patches[hook_ref]: new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:] + if copy_tuples: + for i in range(len(new_hook_patches[hook_ref][k])): + new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i]) return new_hook_patches def wipe_lowvram_weight(m): @@ -303,7 +306,7 @@ def clone(self): callback(self, n) return n - def multigpu_clone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None): + def multigpu_deepclone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None): n = self.clone() # set load device, if present if new_load_device is not None: @@ -312,6 +315,7 @@ def multigpu_clone(self, new_load_device=None, models_cache: dict[ModelPatcher,M # otherwise, patchers that have deep copies of base models will erroneously influence each other. n.backup = copy.deepcopy(n.backup) n.object_patches_backup = copy.deepcopy(n.object_patches_backup) + n.hook_backup = copy.deepcopy(n.hook_backup) n.model = copy.deepcopy(n.model) # multigpu clone should not have multigpu additional_models entry n.remove_additional_models("multigpu") @@ -322,7 +326,7 @@ def multigpu_clone(self, new_load_device=None, models_cache: dict[ModelPatcher,M for i in range(len(model_list)): add_model = n.additional_models[key][i] if i not in models_cache: - models_cache[add_model] = add_model.multigpu_clone(new_load_device=new_load_device, models_cache=models_cache) + models_cache[add_model] = add_model.multigpu_deepclone(new_load_device=new_load_device, models_cache=models_cache) n.additional_models[key][i] = models_cache[add_model] return n @@ -952,9 +956,13 @@ def pre_run(self): for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): callback(self) - def prepare_state(self, timestep): + def prepare_state(self, timestep, model_options, ignore_multigpu=False): for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): - callback(self, timestep) + callback(self, timestep, model_options, ignore_multigpu) + if not ignore_multigpu and "multigpu_clones" in model_options: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p.prepare_state(timestep, model_options, ignore_multigpu=True) def restore_hook_patches(self): if self.hook_patches_backup is not None: @@ -967,12 +975,18 @@ def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]): curr_t = t[0] reset_current_hooks = False + multigpu_kf_changed_cache = None transformer_options = model_options.get("transformer_options", {}) for hook in hook_group.hooks: changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options) # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # this will cause the weights to be recalculated when sampling if changed: + # cache changed for multigpu usage + if "multigpu_clones" in model_options: + if multigpu_kf_changed_cache is None: + multigpu_kf_changed_cache = [] + multigpu_kf_changed_cache.append(hook) # reset current_hooks if contains hook that changed if self.current_hooks is not None: for current_hook in self.current_hooks.hooks: @@ -984,6 +998,28 @@ def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: com self.cached_hook_patches.pop(cached_group) if reset_current_hooks: self.patch_hooks(None) + if "multigpu_clones" in model_options: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p._handle_changed_hook_keyframes(multigpu_kf_changed_cache) + + def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]): + 'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.' + if kf_changed_cache is None: + return + reset_current_hooks = False + # reset current_hooks if contains hook that changed + for hook in kf_changed_cache: + if self.current_hooks is not None: + for current_hook in self.current_hooks.hooks: + if current_hook == hook: + reset_current_hooks = True + break + for cached_group in list(self.cached_hook_patches.keys()): + if cached_group.contains(hook): + self.cached_hook_patches.pop(cached_group) + if reset_current_hooks: + self.patch_hooks(None) def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, registered: comfy.hooks.HookGroup = None): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index b70e5e636261..a95231ff5d3b 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -1,7 +1,9 @@ from __future__ import annotations +import torch import uuid import comfy.model_management import comfy.conds +import comfy.model_patcher import comfy.utils import comfy.hooks import comfy.patcher_extension @@ -127,7 +129,7 @@ def cleanup_models(conds, models): cleanup_additional_models(set(control_cleanup)) -def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): +def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict): ''' Registers hooks from conds. ''' @@ -160,3 +162,17 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], copy_dict1=False) return to_load_options + +def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict): + ''' + In case multigpu acceleration is enabled, prep ModelPatchers for each device. + ''' + multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_clone] + if len(multigpu_patchers) > 0: + multigpu_dict: dict[torch.device, ModelPatcher] = {} + multigpu_dict[model_patcher.load_device] = model_patcher + for x in multigpu_patchers: + x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True) + multigpu_dict[x.load_device] = x + model_options["multigpu_clones"] = multigpu_dict + return multigpu_patchers diff --git a/comfy/samplers.py b/comfy/samplers.py index e9cd076e9344..dde0b6521405 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -232,7 +232,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te if has_default_conds: finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) - model.current_patcher.prepare_state(timestep) + model.current_patcher.prepare_state(timestep, model_options) # run every hooked_to_run separately for hooks, to_run in hooked_to_run.items(): @@ -368,39 +368,53 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t if has_default_conds: finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) - model.current_patcher.prepare_state(timestep) + model.current_patcher.prepare_state(timestep, model_options) - devices = [dev_m for dev_m in model_options["multigpu_clones"].keys()] + devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()] device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {} - count = 0 + + total_conds = 0 + for to_run in hooked_to_run.values(): + total_conds += len(to_run) + conds_per_device = max(1, math.ceil(total_conds//len(devices))) + index_device = 0 + current_device = devices[index_device] # run every hooked_to_run separately for hooks, to_run in hooked_to_run.items(): while len(to_run) > 0: + current_device = devices[index_device % len(devices)] + batched_to_run = device_batched_hooked_to_run.setdefault(current_device, []) + # keep track of conds currently scheduled onto this device + batched_to_run_length = 0 + for btr in batched_to_run: + batched_to_run_length += len(btr[1]) + first = to_run[0] first_shape = first[0][0].shape to_batch_temp = [] + # make sure not over conds_per_device limit when creating temp batch for x in range(len(to_run)): - if can_concat_cond(to_run[x][0], first[0]): + if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length): to_batch_temp += [x] to_batch_temp.reverse() to_batch = to_batch_temp[:1] - current_device = devices[count % len(devices)] free_memory = model_management.get_free_memory(current_device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - # if model.memory_required(input_shape) * 1.5 < free_memory: - # to_batch = batch_amount - # break + if model.memory_required(input_shape) * 1.5 < free_memory: + to_batch = batch_amount + break conds_to_batch = [] for x in to_batch: conds_to_batch.append(to_run.pop(x)) - - batched_to_run = device_batched_hooked_to_run.setdefault(current_device, []) + batched_to_run_length += len(conds_to_batch) + batched_to_run.append((hooks, conds_to_batch)) - count += 1 + if batched_to_run_length >= conds_per_device: + index_device += 1 thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond']) def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): @@ -1112,13 +1126,7 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device - multigpu_patchers: list[ModelPatcher] = [x for x in self.loaded_models if x.is_multigpu_clone] - if len(multigpu_patchers) > 0: - multigpu_dict: dict[torch.device, ModelPatcher] = {} - multigpu_dict[device] = self.model_patcher - for x in multigpu_patchers: - multigpu_dict[x.load_device] = x - self.model_options["multigpu_clones"] = multigpu_dict + multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options) if denoise_mask is not None: denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index dec395fb3546..b3c8635b8769 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -28,7 +28,7 @@ def init_multigpu(self, model: ModelPatcher, max_gpus: int): model = model.clone() comfy.model_management.unload_all_models() for device in extra_devices: - device_patcher = model.multigpu_clone(new_load_device=device) + device_patcher = model.multigpu_deepclone(new_load_device=device) device_patcher.is_multigpu_clone = True multigpu_models = model.get_additional_models_with_key("multigpu") multigpu_models.append(device_patcher) From 02a4d0ad7de479c8e1145b2305edab4bac2a2e45 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 23 Jan 2025 01:20:00 -0600 Subject: [PATCH 24/90] Added unload_model_and_clones to model_management.py to allow unloading only relevant models --- comfy/model_management.py | 10 ++++++++++ comfy/model_patcher.py | 4 ++-- comfy/sampler_helpers.py | 1 + comfy_extras/nodes_multigpu.py | 2 +- 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 87ad290d0956..2cf792b56b96 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1146,6 +1146,16 @@ def soft_empty_cache(force=False): def unload_all_models(): free_memory(1e30, get_torch_device()) +def unload_model_and_clones(model: ModelPatcher): + 'Unload only model and its clones - primarily for multigpu cloning purposes.' + initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy() + keep_loaded = [] + for loaded_model in initial_keep_loaded: + if loaded_model.model is not None: + if model.clone_base_uuid == loaded_model.model.clone_base_uuid: + continue + keep_loaded.append(loaded_model) + free_memory(1e30, get_torch_device(), keep_loaded) #TODO: might be cleaner to put this somewhere else import threading diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 46779397e933..b4efa8d02766 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -222,7 +222,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed self.is_multigpu_clone = False - self.clone_uuid = uuid.uuid4() + self.clone_base_uuid = uuid.uuid4() if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -300,7 +300,7 @@ def clone(self): n.hook_mode = self.hook_mode n.is_multigpu_clone = self.is_multigpu_clone - n.clone_uuid = self.clone_uuid + n.clone_base_uuid = self.clone_base_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index a95231ff5d3b..5564b62c2ad4 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -173,6 +173,7 @@ def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_mo multigpu_dict[model_patcher.load_device] = model_patcher for x in multigpu_patchers: x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True) + x.hook_mode = model_patcher.hook_mode # match main model's hook_mode multigpu_dict[x.load_device] = x model_options["multigpu_clones"] = multigpu_dict return multigpu_patchers diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index b3c8635b8769..b5c36c64d4e2 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -26,7 +26,7 @@ def init_multigpu(self, model: ModelPatcher, max_gpus: int): extra_devices = extra_devices[:max_gpus-1] if len(extra_devices) > 0: model = model.clone() - comfy.model_management.unload_all_models() + comfy.model_management.unload_model_and_clones(model) for device in extra_devices: device_patcher = model.multigpu_deepclone(new_load_device=device) device_patcher.is_multigpu_clone = True From 5db42774496189142ef1521d7280e47b14044de5 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 23 Jan 2025 19:06:05 -0600 Subject: [PATCH 25/90] Make sure additional_models are unloaded as well when perform --- comfy/model_management.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2cf792b56b96..c72ed247d18b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1146,14 +1146,25 @@ def soft_empty_cache(force=False): def unload_all_models(): free_memory(1e30, get_torch_device()) -def unload_model_and_clones(model: ModelPatcher): +def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True): 'Unload only model and its clones - primarily for multigpu cloning purposes.' initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy() + additional_models = [] + if unload_additional_models: + additional_models = model.get_nested_additional_models() keep_loaded = [] for loaded_model in initial_keep_loaded: if loaded_model.model is not None: if model.clone_base_uuid == loaded_model.model.clone_base_uuid: continue + # check additional models if they are a match + skip = False + for add_model in additional_models: + if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid: + skip = True + break + if skip: + continue keep_loaded.append(loaded_model) free_memory(1e30, get_torch_device(), keep_loaded) From 46969c380aa15dd7f26dfcee67dc59b9d830a3bd Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 24 Jan 2025 05:39:38 -0600 Subject: [PATCH 26/90] Initial MultiGPU support for controlnets --- comfy/controlnet.py | 49 ++++++++++++++++++++++++++++++++++++++---- comfy/samplers.py | 52 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 95 insertions(+), 6 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ee29251b9727..0029a4987099 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -15,13 +15,14 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ - +from __future__ import annotations import torch from enum import Enum import math import os import logging +import copy import comfy.utils import comfy.model_management import comfy.model_detection @@ -36,7 +37,7 @@ import comfy.ldm.hydit.controlnet import comfy.ldm.flux.controlnet import comfy.cldm.dit_embedder -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: from comfy.hooks import HookGroup @@ -76,7 +77,7 @@ def __init__(self): self.compression_ratio = 8 self.upscale_algorithm = 'nearest-exact' self.extra_args = {} - self.previous_controlnet = None + self.previous_controlnet: Union[ControlBase, None] = None self.extra_conds = [] self.strength_type = StrengthType.CONSTANT self.concat_mask = False @@ -84,6 +85,7 @@ def __init__(self): self.extra_concat = None self.extra_hooks: HookGroup = None self.preprocess_image = lambda a: a + self.multigpu_clones: dict[torch.device, ControlBase] = {} def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]): self.cond_hint_original = cond_hint @@ -117,10 +119,33 @@ def cleanup(self): def get_models(self): out = [] + for device_cnet in self.multigpu_clones.values(): + out += device_cnet.get_models() if self.previous_controlnet is not None: out += self.previous_controlnet.get_models() return out + def get_models_only_self(self): + 'Calls get_models, but temporarily sets previous_controlnet to None.' + try: + orig_previous_controlnet = self.previous_controlnet + self.previous_controlnet = None + return self.get_models() + finally: + self.previous_controlnet = orig_previous_controlnet + + def get_instance_for_device(self, device): + 'Returns instance of this Control object intended for selected device.' + return self.multigpu_clones.get(device, self) + + def deepclone_multigpu(self, load_device, autoregister=False): + ''' + Create deep clone of Control object where model(s) is set to other devices. + + When autoregister is set to True, the deep clone is also added to multigpu_clones dict. + ''' + raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.") + def get_extra_hooks(self): out = [] if self.extra_hooks is not None: @@ -129,7 +154,7 @@ def get_extra_hooks(self): out += self.previous_controlnet.get_extra_hooks() return out - def copy_to(self, c): + def copy_to(self, c: ControlBase): c.cond_hint_original = self.cond_hint_original c.strength = self.strength c.timestep_percent_range = self.timestep_percent_range @@ -280,6 +305,14 @@ def copy(self): self.copy_to(c) return c + def deepclone_multigpu(self, load_device, autoregister=False): + c = self.copy() + c.control_model = copy.deepcopy(c.control_model) + c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + if autoregister: + self.multigpu_clones[load_device] = c + return c + def get_models(self): out = super().get_models() out.append(self.control_model_wrapped) @@ -809,6 +842,14 @@ def copy(self): c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm) self.copy_to(c) return c + + def deepclone_multigpu(self, load_device, autoregister=False): + c = self.copy() + c.t2i_model = copy.deepcopy(c.t2i_model) + c.device = load_device + if autoregister: + self.multigpu_clones[load_device] = c + return c def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options compression_ratio = 8 diff --git a/comfy/samplers.py b/comfy/samplers.py index cf97b9820238..27d875709c72 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,4 +1,6 @@ from __future__ import annotations + +import comfy.model_management from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc from typing import TYPE_CHECKING, Callable, NamedTuple @@ -427,7 +429,7 @@ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup cond_or_uncond = [] uuids = [] area = [] - control = None + control: ControlBase = None patches = None for x in to_batch: o = x @@ -473,7 +475,8 @@ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup c['transformer_options'] = transformer_options if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) + device_control = control.get_instance_for_device(device) + c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) if 'model_function_wrapper' in model_options: output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) @@ -799,6 +802,8 @@ def pre_run_control(model, conds): percent_to_timestep_function = lambda a: s.percent_to_sigma(a) if 'control' in x: x['control'].pre_run(model, percent_to_timestep_function) + for device_cnet in x['control'].multigpu_clones.values(): + device_cnet.pre_run(model, percent_to_timestep_function) def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] @@ -1080,6 +1085,48 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): wc_list[i] = wc_list[i].to(cast) +def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model_options: dict[str], model: ModelPatcher): + '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.''' + multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu") + if len(multigpu_models) == 0: + return + extra_devices = [x.load_device for x in multigpu_models] + # handle controlnets + controlnets: set[ControlBase] = set() + for k in conds: + for kk in conds[k]: + if 'control' in kk: + controlnets.add(kk['control']) + if len(controlnets) > 0: + # first, unload all controlnet clones + for cnet in list(controlnets): + cnet_models = cnet.get_models() + for cm in cnet_models: + comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True) + + # next, make sure each controlnet has a deepclone for all relevant devices + for cnet in controlnets: + curr_cnet = cnet + while curr_cnet is not None: + for device in extra_devices: + if device not in curr_cnet.multigpu_clones: + curr_cnet.deepclone_multigpu(device, autoregister=True) + curr_cnet = curr_cnet.previous_controlnet + # since all device clones are now present, recreate the linked list for cloned cnets per device + for cnet in controlnets: + curr_cnet = cnet + while curr_cnet is not None: + prev_cnet = curr_cnet.previous_controlnet + for device in extra_devices: + device_cnet = curr_cnet.get_instance_for_device(device) + prev_device_cnet = None + if prev_cnet is not None: + prev_device_cnet = prev_cnet.get_instance_for_device(device) + device_cnet.set_previous_controlnet(prev_device_cnet) + curr_cnet = prev_cnet + # TODO: handle gligen + + class CFGGuider: def __init__(self, model_patcher: ModelPatcher): self.model_patcher = model_patcher @@ -1122,6 +1169,7 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas return self.inner_model.process_latent_out(samples.to(torch.float32)) def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + preprocess_multigpu_conds(self.conds, self.model_options, self.model_patcher) self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device From 51af7fa1b4f42c674d755e60bfb9a67410f956b4 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 25 Jan 2025 06:05:01 -0600 Subject: [PATCH 27/90] Fix multigpu ControlBase get_models and cleanup calls to avoid multiple calls of functions on multigpu_clones versions of controlnets --- comfy/controlnet.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0029a4987099..31227ae310a6 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -64,6 +64,18 @@ class StrengthType(Enum): CONSTANT = 1 LINEAR_UP = 2 +class ControlIsolation: + '''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.''' + def __init__(self, control: ControlBase): + self.control = control + self.orig_previous_controlnet = control.previous_controlnet + + def __enter__(self): + self.control.previous_controlnet = None + + def __exit__(self, *args): + self.control.previous_controlnet = self.orig_previous_controlnet + class ControlBase: def __init__(self): self.cond_hint_original = None @@ -112,7 +124,9 @@ def set_previous_controlnet(self, controlnet): def cleanup(self): if self.previous_controlnet is not None: self.previous_controlnet.cleanup() - + for device_cnet in self.multigpu_clones.values(): + with ControlIsolation(device_cnet): + device_cnet.cleanup() self.cond_hint = None self.extra_concat = None self.timestep_range = None @@ -120,19 +134,15 @@ def cleanup(self): def get_models(self): out = [] for device_cnet in self.multigpu_clones.values(): - out += device_cnet.get_models() + out += device_cnet.get_models_only_self() if self.previous_controlnet is not None: out += self.previous_controlnet.get_models() return out def get_models_only_self(self): 'Calls get_models, but temporarily sets previous_controlnet to None.' - try: - orig_previous_controlnet = self.previous_controlnet - self.previous_controlnet = None + with ControlIsolation(self): return self.get_models() - finally: - self.previous_controlnet = orig_previous_controlnet def get_instance_for_device(self, device): 'Returns instance of this Control object intended for selected device.' From c7feef90605801fbda28ae473c46008f7b5b404b Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 26 Jan 2025 05:29:27 -0600 Subject: [PATCH 28/90] Cast transformer_options for multigpu --- comfy/samplers.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 27d875709c72..b8b30f2c6132 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -471,7 +471,9 @@ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup transformer_options["uuids"] = uuids[:] transformer_options["sigmas"] = timestep transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) + transformer_options["multigpu_thread_device"] = device + cast_transformer_options(transformer_options, device=device) c['transformer_options'] = transformer_options if control is not None: @@ -1045,7 +1047,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): to_load_options = model_options.get("to_load_options", None) if to_load_options is None: return + cast_transformer_options(to_load_options, device, dtype) +def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None): casts = [] if device is not None: casts.append(device) @@ -1054,18 +1058,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # if nothing to apply, do nothing if len(casts) == 0: return - # try to call .to on patches - if "patches" in to_load_options: - patches = to_load_options["patches"] + if "patches" in transformer_options: + patches = transformer_options["patches"] for name in patches: patch_list = patches[name] for i in range(len(patch_list)): if hasattr(patch_list[i], "to"): for cast in casts: patch_list[i] = patch_list[i].to(cast) - if "patches_replace" in to_load_options: - patches = to_load_options["patches_replace"] + if "patches_replace" in transformer_options: + patches = transformer_options["patches_replace"] for name in patches: patch_list = patches[name] for k in patch_list: @@ -1075,8 +1078,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # try to call .to on any wrappers/callbacks wrappers_and_callbacks = ["wrappers", "callbacks"] for wc_name in wrappers_and_callbacks: - if wc_name in to_load_options: - wc: dict[str, list] = to_load_options[wc_name] + if wc_name in transformer_options: + wc: dict[str, list] = transformer_options[wc_name] for wc_dict in wc.values(): for wc_list in wc_dict.values(): for i in range(len(wc_list)): From e3298b84de502a9df8a20ed2ab2877a30d631ff7 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 26 Jan 2025 09:34:20 -0600 Subject: [PATCH 29/90] Create proper MultiGPU Initialize node, create gpu_options to create scaffolding for asymmetrical GPU support --- comfy_extras/nodes_multigpu.py | 107 ++++++++++++++++++++++++++++++--- 1 file changed, 100 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index b5c36c64d4e2..2ec1e3cfadcd 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -1,27 +1,32 @@ +from __future__ import annotations +import torch + from comfy.model_patcher import ModelPatcher import comfy.utils import comfy.patcher_extension import comfy.model_management -import copy class MultiGPUInitialize: + NodeId = "MultiGPU_Initialize" + NodeName = "MultiGPU Initialize" @classmethod def INPUT_TYPES(cls): return { "required": { "model": ("MODEL",), + "max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}), }, "optional": { - "max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}), + "gpu_options": ("GPU_OPTIONS",) } } RETURN_TYPES = ("MODEL",) FUNCTION = "init_multigpu" - CATEGORY = "DevTools" + CATEGORY = "advanced/multigpu" - def init_multigpu(self, model: ModelPatcher, max_gpus: int): + def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None): extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) extra_devices = extra_devices[:max_gpus-1] if len(extra_devices) > 0: @@ -33,9 +38,97 @@ def init_multigpu(self, model: ModelPatcher, max_gpus: int): multigpu_models = model.get_additional_models_with_key("multigpu") multigpu_models.append(device_patcher) model.set_additional_models("multigpu", multigpu_models) + if gpu_options is None: + gpu_options = GPUOptionsGroup() + gpu_options.register(model) return (model,) + +class MultiGPUOptionsNode: + NodeId = "MultiGPU_Options" + NodeName = "MultiGPU Options" + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "device_index": ("INT", {"default": 0, "min": 0, "max": 64}), + "relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01}) + }, + "optional": { + "gpu_options": ("GPU_OPTIONS",) + } + } + + RETURN_TYPES = ("GPU_OPTIONS",) + FUNCTION = "create_gpu_options" + CATEGORY = "advanced/multigpu" + + def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: GPUOptionsGroup=None): + if not gpu_options: + gpu_options = GPUOptionsGroup() + gpu_options.clone() + + opt = GPUOptions(device_index=device_index, relative_speed=relative_speed) + gpu_options.add(opt) + + return (gpu_options,) + + +class GPUOptions: + def __init__(self, device_index: int, relative_speed: float): + self.device_index = device_index + self.relative_speed = relative_speed + + def clone(self): + return GPUOptions(self.device_index, self.relative_speed) + def create_dict(self): + return { + "relative_speed": self.relative_speed + } + +class GPUOptionsGroup: + def __init__(self): + self.options: dict[int, GPUOptions] = {} + + def add(self, info: GPUOptions): + self.options[info.device_index] = info + + def clone(self): + c = GPUOptionsGroup() + for opt in self.options.values(): + c.add(opt) + return c + + def register(self, model: ModelPatcher): + opts_dict = {} + # get devices that are valid for this model + devices: list[torch.device] = [model.load_device] + for extra_model in model.get_additional_models_with_key("multigpu"): + extra_model: ModelPatcher + devices.append(extra_model.load_device) + # create dictionary with actual device mapped to its GPUOptions + device_opts_list: list[GPUOptions] = [] + for device in devices: + device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0)) + opts_dict[device] = device_opts.create_dict() + device_opts_list.append(device_opts) + # make relative_speed relative to 1.0 + max_speed = max([x.relative_speed for x in device_opts_list]) + for value in opts_dict.values(): + value["relative_speed"] /= max_speed + model.model_options["multigpu_options"] = opts_dict + + +node_list = [ + MultiGPUInitialize, + MultiGPUOptionsNode +] +NODE_CLASS_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS = {} + +for node in node_list: + NODE_CLASS_MAPPINGS[node.NodeId] = node + NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName -NODE_CLASS_MAPPINGS = { - "test_multigpuinit": MultiGPUInitialize, -} \ No newline at end of file +# TODO: remove +NODE_CLASS_MAPPINGS["test_multigpuinit"] = MultiGPUInitialize From eda866bf5113fcbbc03877bcfaa10bb4c24518f9 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 27 Jan 2025 06:25:48 -0600 Subject: [PATCH 30/90] Extracted multigpu core code into multigpu.py, added load_balance_devices to get subdivision of work based on available devices and splittable work item count, added MultiGPU Options nodes to set relative_speed of specific devices; does not change behavior yet --- comfy/multigpu.py | 107 +++++++++++++++++++++++++++++++++ comfy_extras/nodes_multigpu.py | 58 ++---------------- 2 files changed, 113 insertions(+), 52 deletions(-) create mode 100644 comfy/multigpu.py diff --git a/comfy/multigpu.py b/comfy/multigpu.py new file mode 100644 index 000000000000..2a1fc29d2255 --- /dev/null +++ b/comfy/multigpu.py @@ -0,0 +1,107 @@ +from __future__ import annotations +import torch + +from collections import namedtuple +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + + +class GPUOptions: + def __init__(self, device_index: int, relative_speed: float): + self.device_index = device_index + self.relative_speed = relative_speed + + def clone(self): + return GPUOptions(self.device_index, self.relative_speed) + + def create_dict(self): + return { + "relative_speed": self.relative_speed + } + +class GPUOptionsGroup: + def __init__(self): + self.options: dict[int, GPUOptions] = {} + + def add(self, info: GPUOptions): + self.options[info.device_index] = info + + def clone(self): + c = GPUOptionsGroup() + for opt in self.options.values(): + c.add(opt) + return c + + def register(self, model: ModelPatcher): + opts_dict = {} + # get devices that are valid for this model + devices: list[torch.device] = [model.load_device] + for extra_model in model.get_additional_models_with_key("multigpu"): + extra_model: ModelPatcher + devices.append(extra_model.load_device) + # create dictionary with actual device mapped to its GPUOptions + device_opts_list: list[GPUOptions] = [] + for device in devices: + device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0)) + opts_dict[device] = device_opts.create_dict() + device_opts_list.append(device_opts) + # make relative_speed relative to 1.0 + min_speed = min([x.relative_speed for x in device_opts_list]) + for value in opts_dict.values(): + value['relative_speed'] /= min_speed + model.model_options['multigpu_options'] = opts_dict + + +LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) +def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): + 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' + opts_dict = model_options['multigpu_options'] + devices = list(model_options['multigpu_clones'].keys()) + speed_per_device = [] + work_per_device = [] + # get sum of each device's relative_speed + total_speed = 0.0 + for opts in opts_dict.values(): + total_speed += opts['relative_speed'] + # get relative work for each device; + # obtained by w = (W*r)/R + for device in devices: + relative_speed = opts_dict[device]['relative_speed'] + relative_work = (total_work*relative_speed) / total_speed + speed_per_device.append(relative_speed) + work_per_device.append(relative_work) + # relative work must be expressed in whole numbers, but likely is a decimal; + # perform rounding while maintaining total sum equal to total work (sum of relative works) + work_per_device = round_preserved(work_per_device) + dict_work_per_device = {} + for device, relative_work in zip(devices, work_per_device): + dict_work_per_device[device] = relative_work + if not return_idle_time: + return LoadBalance(dict_work_per_device, None) + # divide relative work by relative speed to get estimated completion time of said work by each device; + # time here is relative and does not correspond to real-world units + completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)] + # calculate relative time spent by the devices waiting on each other after their work is completed + idle_time = abs(min(completion_time) - max(completion_time)) + if work_normalized: + idle_time *= (work_normalized/total_work) + + return LoadBalance(dict_work_per_device, idle_time) + +def round_preserved(values: list[float]): + 'Round all values in a list, preserving the combined sum of values.' + # get floor of values; casting to int does it too + floored = [int(x) for x in values] + total_floored = sum(floored) + # get remainder to distribute + remainder = round(sum(values)) - total_floored + # pair values with fractional portions + fractional = [(i, x-floored[i]) for i, x in enumerate(values)] + # sort by fractional part in descending order + fractional.sort(key=lambda x: x[1], reverse=True) + # distribute the remainder + for i in range(remainder): + index = fractional[i][0] + floored[index] += 1 + return floored diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 2ec1e3cfadcd..54f68182e696 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -1,10 +1,10 @@ from __future__ import annotations -import torch from comfy.model_patcher import ModelPatcher import comfy.utils import comfy.patcher_extension import comfy.model_management +import comfy.multigpu class MultiGPUInitialize: @@ -26,7 +26,7 @@ def INPUT_TYPES(cls): FUNCTION = "init_multigpu" CATEGORY = "advanced/multigpu" - def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None): + def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None): extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) extra_devices = extra_devices[:max_gpus-1] if len(extra_devices) > 0: @@ -39,7 +39,7 @@ def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: GPUOpti multigpu_models.append(device_patcher) model.set_additional_models("multigpu", multigpu_models) if gpu_options is None: - gpu_options = GPUOptionsGroup() + gpu_options = comfy.multigpu.GPUOptionsGroup() gpu_options.register(model) return (model,) @@ -62,63 +62,17 @@ def INPUT_TYPES(cls): FUNCTION = "create_gpu_options" CATEGORY = "advanced/multigpu" - def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: GPUOptionsGroup=None): + def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None): if not gpu_options: - gpu_options = GPUOptionsGroup() + gpu_options = comfy.multigpu.GPUOptionsGroup() gpu_options.clone() - opt = GPUOptions(device_index=device_index, relative_speed=relative_speed) + opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed) gpu_options.add(opt) return (gpu_options,) -class GPUOptions: - def __init__(self, device_index: int, relative_speed: float): - self.device_index = device_index - self.relative_speed = relative_speed - - def clone(self): - return GPUOptions(self.device_index, self.relative_speed) - - def create_dict(self): - return { - "relative_speed": self.relative_speed - } - -class GPUOptionsGroup: - def __init__(self): - self.options: dict[int, GPUOptions] = {} - - def add(self, info: GPUOptions): - self.options[info.device_index] = info - - def clone(self): - c = GPUOptionsGroup() - for opt in self.options.values(): - c.add(opt) - return c - - def register(self, model: ModelPatcher): - opts_dict = {} - # get devices that are valid for this model - devices: list[torch.device] = [model.load_device] - for extra_model in model.get_additional_models_with_key("multigpu"): - extra_model: ModelPatcher - devices.append(extra_model.load_device) - # create dictionary with actual device mapped to its GPUOptions - device_opts_list: list[GPUOptions] = [] - for device in devices: - device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0)) - opts_dict[device] = device_opts.create_dict() - device_opts_list.append(device_opts) - # make relative_speed relative to 1.0 - max_speed = max([x.relative_speed for x in device_opts_list]) - for value in opts_dict.values(): - value["relative_speed"] /= max_speed - model.model_options["multigpu_options"] = opts_dict - - node_list = [ MultiGPUInitialize, MultiGPUOptionsNode From 02747cde7ddacc3fd8a8165cf00aa13cbb770b12 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 29 Jan 2025 11:10:23 -0600 Subject: [PATCH 31/90] Carry over change from _calc_cond_batch into _calc_cond_batch_multigpu --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index b5252d1442d7..f4873e3a510d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -357,7 +357,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t default_c.append(x) has_default_conds = True continue - p = comfy.samplers.get_area_and_mult(x, x_in, timestep) + p = get_area_and_mult(x, x_in, timestep) if p is None: continue if p.hooks is not None: From 476aa79b642f7b09c2a7bbe30a0763761eb11fe5 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 6 Feb 2025 08:44:07 -0600 Subject: [PATCH 32/90] Let --cuda-device take in a string to allow multiple devices (or device order) to be chosen, print available devices on startup, potentially support MultiGPU Intel and Ascend setups --- comfy/cli_args.py | 2 +- comfy/model_management.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a92fc0dbac5e..f54be19e4cc2 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -50,7 +50,7 @@ def __call__(self, parser, namespace, values, option_string=None): parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") -parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") +parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use.") cm_group = parser.add_mutually_exclusive_group() cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 420eb9e89c1c..477bb0f5f727 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -141,6 +141,12 @@ def get_all_torch_devices(exclude_current=False): if is_nvidia(): for i in range(torch.cuda.device_count()): devices.append(torch.device(i)) + elif is_intel_xpu(): + for i in range(torch.xpu.device_count()): + devices.append(torch.device(i)) + elif is_ascend_npu(): + for i in range(torch.npu.device_count()): + devices.append(torch.device(i)) else: devices.append(get_torch_device()) if exclude_current: @@ -320,10 +326,14 @@ def get_torch_device_name(device): return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) try: - logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) + logging.info("Device [X]: {}".format(get_torch_device_name(get_torch_device()))) except: logging.warning("Could not pick default device.") - +try: + for device in get_all_torch_devices(exclude_current=True): + logging.info("Device [ ]: {}".format(get_torch_device_name(device))) +except: + pass current_loaded_models = [] From 093914a24714ef7264e34062fdeae46bd81964d9 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 3 Mar 2025 22:56:13 -0600 Subject: [PATCH 33/90] Made MultiGPU Work Units node more robust by forcing ModelPatcher clones to match at sample time, reuse loaded MultiGPU clones, finalize MultiGPU Work Units node ID and name, small refactors/cleanup of logging and multigpu-related code --- comfy/model_management.py | 14 ++++--- comfy/model_patcher.py | 67 +++++++++++++++++++++++++++++----- comfy/multigpu.py | 52 ++++++++++++++++++++++++++ comfy/patcher_extension.py | 2 + comfy/sampler_helpers.py | 47 ++++++++++++++++++++++-- comfy/samplers.py | 44 ---------------------- comfy_extras/nodes_multigpu.py | 49 ++++++++++++------------- 7 files changed, 188 insertions(+), 87 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index dd762bdc50ce..3ee8857c252a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -345,16 +345,16 @@ def get_torch_device_name(device): return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) try: - logging.info("Device [X]: {}".format(get_torch_device_name(get_torch_device()))) + logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) except: logging.warning("Could not pick default device.") try: for device in get_all_torch_devices(exclude_current=True): - logging.info("Device [ ]: {}".format(get_torch_device_name(device))) + logging.info("Device: {}".format(get_torch_device_name(device))) except: pass -current_loaded_models = [] +current_loaded_models: list[LoadedModel] = [] def module_size(module): module_mem = 0 @@ -1198,7 +1198,7 @@ def soft_empty_cache(force=False): def unload_all_models(): free_memory(1e30, get_torch_device()) -def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True): +def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False): 'Unload only model and its clones - primarily for multigpu cloning purposes.' initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy() additional_models = [] @@ -1218,7 +1218,11 @@ def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True): if skip: continue keep_loaded.append(loaded_model) - free_memory(1e30, get_torch_device(), keep_loaded) + if not all_devices: + free_memory(1e30, get_torch_device(), keep_loaded) + else: + for device in get_all_torch_devices(): + free_memory(1e30, device, keep_loaded) #TODO: might be cleaner to put this somewhere else import threading diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index eb21396be230..5ede41dd6faf 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -243,7 +243,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed - self.is_multigpu_clone = False + self.is_multigpu_base_clone = False self.clone_base_uuid = uuid.uuid4() if not hasattr(self.model, 'model_loaded_weight_memory'): @@ -324,14 +324,16 @@ def clone(self): n.is_clip = self.is_clip n.hook_mode = self.hook_mode - n.is_multigpu_clone = self.is_multigpu_clone + n.is_multigpu_base_clone = self.is_multigpu_base_clone n.clone_base_uuid = self.clone_base_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n - def multigpu_deepclone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None): + def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None): + logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.") + comfy.model_management.unload_model_and_clones(self) n = self.clone() # set load device, if present if new_load_device is not None: @@ -350,19 +352,64 @@ def multigpu_deepclone(self, new_load_device=None, models_cache: dict[ModelPatch for key, model_list in n.additional_models.items(): for i in range(len(model_list)): add_model = n.additional_models[key][i] - if i not in models_cache: - models_cache[add_model] = add_model.multigpu_deepclone(new_load_device=new_load_device, models_cache=models_cache) - n.additional_models[key][i] = models_cache[add_model] + if add_model.clone_base_uuid not in models_cache: + models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache) + n.additional_models[key][i] = models_cache[add_model.clone_base_uuid] + for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU): + callback(self, n) return n + def match_multigpu_clones(self): + multigpu_models = self.get_additional_models_with_key("multigpu") + if len(multigpu_models) > 0: + new_multigpu_models = [] + for mm in multigpu_models: + # clone main model, but bring over relevant props from existing multigpu clone + n = self.clone() + n.load_device = mm.load_device + n.backup = mm.backup + n.object_patches_backup = mm.object_patches_backup + n.hook_backup = mm.hook_backup + n.model = mm.model + n.is_multigpu_base_clone = mm.is_multigpu_base_clone + n.remove_additional_models("multigpu") + orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models) + n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models) + # figure out which additional models are not present in multigpu clone + models_cache = {} + for mm_add_model in mm.get_additional_models(): + models_cache[mm_add_model.clone_base_uuid] = mm_add_model + remove_models_uuids = set(list(models_cache.keys())) + for key, model_list in orig_additional_models.items(): + for orig_add_model in model_list: + if orig_add_model.clone_base_uuid not in models_cache: + models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache) + existing_list = n.get_additional_models_with_key(key) + existing_list.append(models_cache[orig_add_model.clone_base_uuid]) + n.set_additional_models(key, existing_list) + if orig_add_model.clone_base_uuid in remove_models_uuids: + remove_models_uuids.remove(orig_add_model.clone_base_uuid) + # remove duplicate additional models + for key, model_list in n.additional_models.items(): + new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids] + n.set_additional_models(key, new_model_list) + for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES): + callback(self, n) + new_multigpu_models.append(n) + self.set_additional_models("multigpu", new_multigpu_models) + def is_clone(self, other): if hasattr(other, 'model') and self.model is other.model: return True return False - def clone_has_same_weights(self, clone: 'ModelPatcher'): - if not self.is_clone(clone): - return False + def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False): + if allow_multigpu: + if self.clone_base_uuid != clone.clone_base_uuid: + return False + else: + if not self.is_clone(clone): + return False if self.current_hooks != clone.current_hooks: return False @@ -957,7 +1004,7 @@ def get_additional_models_with_key(self, key: str): return self.additional_models.get(key, []) def get_additional_models(self): - all_models = [] + all_models: list[ModelPatcher] = [] for models in self.additional_models.values(): all_models.extend(models) return all_models diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 2a1fc29d2255..9cc8a37fa76d 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -1,10 +1,14 @@ from __future__ import annotations import torch +import logging from collections import namedtuple from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher +import comfy.utils +import comfy.patcher_extension +import comfy.model_management class GPUOptions: @@ -53,6 +57,53 @@ def register(self, model: ModelPatcher): model.model_options['multigpu_options'] = opts_dict +def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False): + 'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.' + model = model.clone() + # check if multigpu is already prepared - get the load devices from them if possible to exclude + skip_devices = set() + multigpu_models = model.get_additional_models_with_key("multigpu") + if len(multigpu_models) > 0: + for mm in multigpu_models: + skip_devices.add(mm.load_device) + skip_devices = list(skip_devices) + + extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + extra_devices = extra_devices[:max_gpus-1] + # exclude skipped devices + for skip in skip_devices: + if skip in extra_devices: + extra_devices.remove(skip) + # create new deepclones + if len(extra_devices) > 0: + for device in extra_devices: + device_patcher = None + if reuse_loaded: + # check if there are any ModelPatchers currently loaded that could be referenced here after a clone + loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models() + for lm in loaded_models: + if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device: + device_patcher = lm.clone() + logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}") + break + if device_patcher is None: + device_patcher = model.deepclone_multigpu(new_load_device=device) + device_patcher.is_multigpu_base_clone = True + multigpu_models = model.get_additional_models_with_key("multigpu") + multigpu_models.append(device_patcher) + model.set_additional_models("multigpu", multigpu_models) + model.match_multigpu_clones() + if gpu_options is None: + gpu_options = GPUOptionsGroup() + gpu_options.register(model) + else: + logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") + # persist skip_devices for use in sampling code + # if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options: + # model.model_options["multigpu_skip_devices"] = skip_devices + return model + + LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' @@ -84,6 +135,7 @@ def load_balance_devices(model_options: dict[str], total_work: int, return_idle_ completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)] # calculate relative time spent by the devices waiting on each other after their work is completed idle_time = abs(min(completion_time) - max(completion_time)) + # if need to compare work idle time, need to normalize to a common total work if work_normalized: idle_time *= (work_normalized/total_work) diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py index 8597582447fb..5145855f5574 100644 --- a/comfy/patcher_extension.py +++ b/comfy/patcher_extension.py @@ -3,6 +3,8 @@ class CallbacksMP: ON_CLONE = "on_clone" + ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu" + ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones" ON_LOAD = "on_load_after" ON_DETACH = "on_detach_after" ON_CLEANUP = "on_cleanup" diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 40b2021f79e6..9a97c8559a87 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -106,16 +106,57 @@ def cleanup_additional_models(models): if hasattr(m, 'cleanup'): m.cleanup() +def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]): + '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.''' + multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu") + if len(multigpu_models) == 0: + return + extra_devices = [x.load_device for x in multigpu_models] + # handle controlnets + controlnets: set[ControlBase] = set() + for k in conds: + for kk in conds[k]: + if 'control' in kk: + controlnets.add(kk['control']) + if len(controlnets) > 0: + # first, unload all controlnet clones + for cnet in list(controlnets): + cnet_models = cnet.get_models() + for cm in cnet_models: + comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True) + + # next, make sure each controlnet has a deepclone for all relevant devices + for cnet in controlnets: + curr_cnet = cnet + while curr_cnet is not None: + for device in extra_devices: + if device not in curr_cnet.multigpu_clones: + curr_cnet.deepclone_multigpu(device, autoregister=True) + curr_cnet = curr_cnet.previous_controlnet + # since all device clones are now present, recreate the linked list for cloned cnets per device + for cnet in controlnets: + curr_cnet = cnet + while curr_cnet is not None: + prev_cnet = curr_cnet.previous_controlnet + for device in extra_devices: + device_cnet = curr_cnet.get_instance_for_device(device) + prev_device_cnet = None + if prev_cnet is not None: + prev_device_cnet = prev_cnet.get_instance_for_device(device) + device_cnet.set_previous_controlnet(prev_device_cnet) + curr_cnet = prev_cnet + # potentially handle gligen - since not widely used, ignored for now def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): - real_model: BaseModel = None + model.match_multigpu_clones() + preprocess_multigpu_conds(conds, model, model_options) models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required) - real_model = model.model + real_model: BaseModel = model.model return real_model, conds, models @@ -166,7 +207,7 @@ def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_mo ''' In case multigpu acceleration is enabled, prep ModelPatchers for each device. ''' - multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_clone] + multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone] if len(multigpu_patchers) > 0: multigpu_dict: dict[torch.device, ModelPatcher] = {} multigpu_dict[model_patcher.load_device] = model_patcher diff --git a/comfy/samplers.py b/comfy/samplers.py index beef0b7e4747..d02627d8adef 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1088,49 +1088,6 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype= for cast in casts: wc_list[i] = wc_list[i].to(cast) - -def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model_options: dict[str], model: ModelPatcher): - '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.''' - multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu") - if len(multigpu_models) == 0: - return - extra_devices = [x.load_device for x in multigpu_models] - # handle controlnets - controlnets: set[ControlBase] = set() - for k in conds: - for kk in conds[k]: - if 'control' in kk: - controlnets.add(kk['control']) - if len(controlnets) > 0: - # first, unload all controlnet clones - for cnet in list(controlnets): - cnet_models = cnet.get_models() - for cm in cnet_models: - comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True) - - # next, make sure each controlnet has a deepclone for all relevant devices - for cnet in controlnets: - curr_cnet = cnet - while curr_cnet is not None: - for device in extra_devices: - if device not in curr_cnet.multigpu_clones: - curr_cnet.deepclone_multigpu(device, autoregister=True) - curr_cnet = curr_cnet.previous_controlnet - # since all device clones are now present, recreate the linked list for cloned cnets per device - for cnet in controlnets: - curr_cnet = cnet - while curr_cnet is not None: - prev_cnet = curr_cnet.previous_controlnet - for device in extra_devices: - device_cnet = curr_cnet.get_instance_for_device(device) - prev_device_cnet = None - if prev_cnet is not None: - prev_device_cnet = prev_cnet.get_instance_for_device(device) - device_cnet.set_previous_controlnet(prev_device_cnet) - curr_cnet = prev_cnet - # TODO: handle gligen - - class CFGGuider: def __init__(self, model_patcher: ModelPatcher): self.model_patcher = model_patcher @@ -1173,7 +1130,6 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas return self.inner_model.process_latent_out(samples.to(torch.float32)) def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - preprocess_multigpu_conds(self.conds, self.model_options, self.model_patcher) self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 54f68182e696..d1e458b7e312 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -1,15 +1,24 @@ from __future__ import annotations +import logging +from inspect import cleandoc -from comfy.model_patcher import ModelPatcher -import comfy.utils -import comfy.patcher_extension -import comfy.model_management +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher import comfy.multigpu -class MultiGPUInitialize: - NodeId = "MultiGPU_Initialize" - NodeName = "MultiGPU Initialize" +class MultiGPUWorkUnitsNode: + """ + Prepares model to have sampling accelerated via splitting work units. + + Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes. + + Other than those exceptions, this node can be placed in any order. + """ + + NodeId = "MultiGPU_WorkUnits" + NodeName = "MultiGPU Work Units" @classmethod def INPUT_TYPES(cls): return { @@ -25,25 +34,17 @@ def INPUT_TYPES(cls): RETURN_TYPES = ("MODEL",) FUNCTION = "init_multigpu" CATEGORY = "advanced/multigpu" + DESCRIPTION = cleandoc(__doc__) def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None): - extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) - extra_devices = extra_devices[:max_gpus-1] - if len(extra_devices) > 0: - model = model.clone() - comfy.model_management.unload_model_and_clones(model) - for device in extra_devices: - device_patcher = model.multigpu_deepclone(new_load_device=device) - device_patcher.is_multigpu_clone = True - multigpu_models = model.get_additional_models_with_key("multigpu") - multigpu_models.append(device_patcher) - model.set_additional_models("multigpu", multigpu_models) - if gpu_options is None: - gpu_options = comfy.multigpu.GPUOptionsGroup() - gpu_options.register(model) + model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True) return (model,) class MultiGPUOptionsNode: + """ + Select the relative speed of GPUs in the special case they have significantly different performance from one another. + """ + NodeId = "MultiGPU_Options" NodeName = "MultiGPU Options" @classmethod @@ -61,6 +62,7 @@ def INPUT_TYPES(cls): RETURN_TYPES = ("GPU_OPTIONS",) FUNCTION = "create_gpu_options" CATEGORY = "advanced/multigpu" + DESCRIPTION = cleandoc(__doc__) def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None): if not gpu_options: @@ -74,7 +76,7 @@ def create_gpu_options(self, device_index: int, relative_speed: float, gpu_optio node_list = [ - MultiGPUInitialize, + MultiGPUWorkUnitsNode, MultiGPUOptionsNode ] NODE_CLASS_MAPPINGS = {} @@ -83,6 +85,3 @@ def create_gpu_options(self, device_index: int, relative_speed: float, gpu_optio for node in node_list: NODE_CLASS_MAPPINGS[node.NodeId] = node NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName - -# TODO: remove -NODE_CLASS_MAPPINGS["test_multigpuinit"] = MultiGPUInitialize From 6dca17bd2dd7455701d5eb466d39d72aa4520b1c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 3 Mar 2025 23:08:29 -0600 Subject: [PATCH 34/90] Satisfy ruff linting --- comfy/controlnet.py | 6 +++--- comfy/model_management.py | 1 - comfy/multigpu.py | 6 +++--- comfy/samplers.py | 4 ++-- comfy_extras/nodes_multigpu.py | 5 ++--- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 9bcd1d2e3484..14f13bd9d943 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -72,7 +72,7 @@ def __init__(self, control: ControlBase): def __enter__(self): self.control.previous_controlnet = None - + def __exit__(self, *args): self.control.previous_controlnet = self.orig_previous_controlnet @@ -151,7 +151,7 @@ def get_instance_for_device(self, device): def deepclone_multigpu(self, load_device, autoregister=False): ''' Create deep clone of Control object where model(s) is set to other devices. - + When autoregister is set to True, the deep clone is also added to multigpu_clones dict. ''' raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.") @@ -846,7 +846,7 @@ def copy(self): c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm) self.copy_to(c) return c - + def deepclone_multigpu(self, load_device, autoregister=False): c = self.copy() c.t2i_model = copy.deepcopy(c.t2i_model) diff --git a/comfy/model_management.py b/comfy/model_management.py index 10d0dece2ca6..6e243a4372fc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -30,7 +30,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher - from comfy.model_base import BaseModel class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 9cc8a37fa76d..aef0b68e831e 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -18,7 +18,7 @@ def __init__(self, device_index: int, relative_speed: float): def clone(self): return GPUOptions(self.device_index, self.relative_speed) - + def create_dict(self): return { "relative_speed": self.relative_speed @@ -86,7 +86,7 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: device_patcher = lm.clone() logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}") break - if device_patcher is None: + if device_patcher is None: device_patcher = model.deepclone_multigpu(new_load_device=device) device_patcher.is_multigpu_base_clone = True multigpu_models = model.get_additional_models_with_key("multigpu") @@ -138,7 +138,7 @@ def load_balance_devices(model_options: dict[str], total_work: int, return_idle_ # if need to compare work idle time, need to normalize to a common total work if work_normalized: idle_time *= (work_normalized/total_work) - + return LoadBalance(dict_work_per_device, idle_time) def round_preserved(values: list[float]): diff --git a/comfy/samplers.py b/comfy/samplers.py index babfe7a45f7f..bc97f9f71b16 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -384,7 +384,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()] device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {} - + total_conds = 0 for to_run in hooked_to_run.values(): total_conds += len(to_run) @@ -504,7 +504,7 @@ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results)) threads.append(new_thread) new_thread.start() - + for thread in threads: thread.join() diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index d1e458b7e312..3b68c10ff371 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -1,5 +1,4 @@ from __future__ import annotations -import logging from inspect import cleandoc from typing import TYPE_CHECKING @@ -11,7 +10,7 @@ class MultiGPUWorkUnitsNode: """ Prepares model to have sampling accelerated via splitting work units. - + Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes. Other than those exceptions, this node can be placed in any order. @@ -30,7 +29,7 @@ def INPUT_TYPES(cls): "gpu_options": ("GPU_OPTIONS",) } } - + RETURN_TYPES = ("MODEL",) FUNCTION = "init_multigpu" CATEGORY = "advanced/multigpu" From 9ce9ff8ef862f23a2486e97f4721fa56c3cea29a Mon Sep 17 00:00:00 2001 From: "kosinkadink1@gmail.com" Date: Fri, 28 Mar 2025 15:29:44 +0800 Subject: [PATCH 35/90] Allow chained MultiGPU Work Unit nodes to affect max_gpus present on ModelPatcher clone --- comfy/multigpu.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index aef0b68e831e..26edcee9029f 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -68,8 +68,9 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: skip_devices.add(mm.load_device) skip_devices = list(skip_devices) - extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) - extra_devices = extra_devices[:max_gpus-1] + full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + limit_extra_devices = full_extra_devices[:max_gpus-1] + extra_devices = limit_extra_devices.copy() # exclude skipped devices for skip in skip_devices: if skip in extra_devices: @@ -98,6 +99,13 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: gpu_options.register(model) else: logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") + # only keep model clones that don't go 'past' the intended max_gpu count + multigpu_models = model.get_additional_models_with_key("multigpu") + new_multigpu_models = [] + for m in multigpu_models: + if m.load_device in limit_extra_devices: + new_multigpu_models.append(m) + model.set_additional_models("multigpu", new_multigpu_models) # persist skip_devices for use in sampling code # if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options: # model.model_options["multigpu_skip_devices"] = skip_devices From 407a5a656f103c42497f5938a80d0771712b8613 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 28 Mar 2025 02:48:11 -0500 Subject: [PATCH 36/90] Rollback core of last commit due to weird behavior --- comfy/multigpu.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 26edcee9029f..90995a5abdca 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -99,13 +99,13 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: gpu_options.register(model) else: logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") - # only keep model clones that don't go 'past' the intended max_gpu count - multigpu_models = model.get_additional_models_with_key("multigpu") - new_multigpu_models = [] - for m in multigpu_models: - if m.load_device in limit_extra_devices: - new_multigpu_models.append(m) - model.set_additional_models("multigpu", new_multigpu_models) + # TODO: only keep model clones that don't go 'past' the intended max_gpu count + # multigpu_models = model.get_additional_models_with_key("multigpu") + # new_multigpu_models = [] + # for m in multigpu_models: + # if m.load_device in limit_extra_devices: + # new_multigpu_models.append(m) + # model.set_additional_models("multigpu", new_multigpu_models) # persist skip_devices for use in sampling code # if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options: # model.model_options["multigpu_skip_devices"] = skip_devices From 8be711715c471db68f9cea15989b5ec0f2ac2e7d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 19 Apr 2025 17:35:54 -0500 Subject: [PATCH 37/90] Make unload_all_models account for all devices --- comfy/model_management.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 90785c2c5dd1..88c1c0a12b03 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1272,7 +1272,8 @@ def soft_empty_cache(force=False): torch.cuda.ipc_collect() def unload_all_models(): - free_memory(1e30, get_torch_device()) + for device in get_all_torch_devices(): + free_memory(1e30, device) def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False): 'Unload only model and its clones - primarily for multigpu cloning purposes.' From 44e053c26dc8982e88973a253eef51b9a9a91302 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 24 Jun 2025 00:48:51 -0500 Subject: [PATCH 38/90] Improve error handling for multigpu threads --- comfy/samplers.py | 143 +++++++++++++++++++++++++--------------------- 1 file changed, 78 insertions(+), 65 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 26052766158b..90cce078d148 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -3,7 +3,7 @@ import comfy.model_management from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc -from typing import TYPE_CHECKING, Callable, NamedTuple +from typing import TYPE_CHECKING, Callable, NamedTuple, Any if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel @@ -428,74 +428,85 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t if batched_to_run_length >= conds_per_device: index_device += 1 - thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond']) + class thread_result(NamedTuple): + output: Any + mult: Any + area: Any + batch_chunks: int + cond_or_uncond: Any + error: Exception = None + def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): - model_current: BaseModel = model_options["multigpu_clones"][device].model - # run every hooked_to_run separately - with torch.no_grad(): - for hooks, to_batch in batch_tuple: - input_x = [] - mult = [] - c = [] - cond_or_uncond = [] - uuids = [] - area = [] - control: ControlBase = None - patches = None - for x in to_batch: - o = x - p = o[0] - input_x.append(p.input_x) - mult.append(p.mult) - c.append(p.conditioning) - area.append(p.area) - cond_or_uncond.append(o[1]) - uuids.append(p.uuid) - control = p.control - patches = p.patches - - batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x).to(device) - c = cond_cat(c, device=device) - timestep_ = torch.cat([timestep.to(device)] * batch_chunks) - - transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks) - if 'transformer_options' in model_options: - transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, - model_options['transformer_options'], - copy_dict1=False) - - if patches is not None: - # TODO: replace with merge_nested_dicts function - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - transformer_options["patches"] = cur_patches - else: - transformer_options["patches"] = patches + try: + model_current: BaseModel = model_options["multigpu_clones"][device].model + # run every hooked_to_run separately + with torch.no_grad(): + for hooks, to_batch in batch_tuple: + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + uuids = [] + area = [] + control: ControlBase = None + patches = None + for x in to_batch: + o = x + p = o[0] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + uuids.append(p.uuid) + control = p.control + patches = p.patches + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x).to(device) + c = cond_cat(c, device=device) + timestep_ = torch.cat([timestep.to(device)] * batch_chunks) + + transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks) + if 'transformer_options' in model_options: + transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, + model_options['transformer_options'], + copy_dict1=False) + + if patches is not None: + # TODO: replace with merge_nested_dicts function + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + transformer_options["patches"] = cur_patches + else: + transformer_options["patches"] = patches - transformer_options["cond_or_uncond"] = cond_or_uncond[:] - transformer_options["uuids"] = uuids[:] - transformer_options["sigmas"] = timestep - transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) - transformer_options["multigpu_thread_device"] = device + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["uuids"] = uuids[:] + transformer_options["sigmas"] = timestep + transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) + transformer_options["multigpu_thread_device"] = device - cast_transformer_options(transformer_options, device=device) - c['transformer_options'] = transformer_options + cast_transformer_options(transformer_options, device=device) + c['transformer_options'] = transformer_options - if control is not None: - device_control = control.get_instance_for_device(device) - c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) + if control is not None: + device_control = control.get_instance_for_device(device) + c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) - else: - output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) - results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) + else: + output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) + results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) + except Exception as e: + results.append(thread_result(None, None, None, None, None, error=e)) + raise results: list[thread_result] = [] @@ -508,7 +519,9 @@ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup for thread in threads: thread.join() - for output, mult, area, batch_chunks, cond_or_uncond in results: + for output, mult, area, batch_chunks, cond_or_uncond, error in results: + if error is not None: + raise error for o in range(batch_chunks): cond_index = cond_or_uncond[o] a = area[o] From d89dd5f0b04c09b01926002244280d98590f02fe Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 13 Oct 2025 22:00:34 -0700 Subject: [PATCH 39/90] Satisfy ruff --- comfy/sampler_helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index c43bd3bac3be..9aa9fa28aaaa 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -12,7 +12,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher - from comfy.model_base import BaseModel from comfy.controlnet import ControlBase def prepare_mask(noise_mask, shape, device): From 4661d1db5aa774f972bc270f2a1e5f8cf20ea978 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 15 Oct 2025 17:34:36 -0700 Subject: [PATCH 40/90] Bring patches changes from _calc_cond_batch into _calc_cond_batch_multigpu --- comfy/samplers.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index e0e0296f8589..ed702304cdf8 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -481,17 +481,10 @@ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup copy_dict1=False) if patches is not None: - # TODO: replace with merge_nested_dicts function - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - transformer_options["patches"] = cur_patches - else: - transformer_options["patches"] = patches + transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts( + transformer_options.get("patches", {}), + patches + ) transformer_options["cond_or_uncond"] = cond_or_uncond[:] transformer_options["uuids"] = uuids[:] From f4b99bc62389af315013dda85f24f2bbd262b686 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 17 Feb 2026 04:55:00 -0800 Subject: [PATCH 41/90] Made multigpu deepclone load model from disk to avoid needing to deepclone actual model object, fixed issues with merge, turn off cuda backend as it causes device mismatch issue with rope (and potentially other ops), will investigate --- comfy/model_patcher.py | 11 ++++++++++- comfy/quant_ops.py | 2 +- comfy/samplers.py | 4 ++-- comfy/sd.py | 2 ++ 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d0110c7c6bac..aa7b862e77e4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -23,6 +23,7 @@ import logging import math import uuid +import copy from typing import Callable, Optional import torch @@ -274,6 +275,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed + self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None self.is_multigpu_base_clone = False self.clone_base_uuid = uuid.uuid4() @@ -368,6 +370,7 @@ def clone(self): n.is_clip = self.is_clip n.hook_mode = self.hook_mode + n.cached_patcher_init = self.cached_patcher_init n.is_multigpu_base_clone = self.is_multigpu_base_clone n.clone_base_uuid = self.clone_base_uuid @@ -382,12 +385,18 @@ def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID, # set load device, if present if new_load_device is not None: n.load_device = new_load_device + if self.cached_patcher_init is not None: + temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1]) + if len(self.cached_patcher_init) > 2: + temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]] + n.model = temp_model_patcher.model + else: + n.model = copy.deepcopy(n.model) # unlike for normal clone, backup dicts that shared same ref should not; # otherwise, patchers that have deep copies of base models will erroneously influence each other. n.backup = copy.deepcopy(n.backup) n.object_patches_backup = copy.deepcopy(n.object_patches_backup) n.hook_backup = copy.deepcopy(n.hook_backup) - n.model = copy.deepcopy(n.model) # multigpu clone should not have multigpu additional_models entry n.remove_additional_models("multigpu") # multigpu_clone all stored additional_models; make sure circular references are properly handled diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 15a4f457bed6..d8addefd85f6 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -20,7 +20,7 @@ if cuda_version < (13,): ck.registry.disable("cuda") logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") - + ck.registry.disable("cuda") # multigpu will not work rn with comfy-kitchen on cuda backend ck.registry.disable("triton") for k, v in ck.list_backends().items(): logging.info(f"Found comfy_kitchen backend {k}: {v}") diff --git a/comfy/samplers.py b/comfy/samplers.py index 3f5a699d9ef9..5dee49e7e3a0 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -418,7 +418,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t to_batch_temp.reverse() to_batch = to_batch_temp[:1] - free_memory = model_management.get_free_memory(current_device) + free_memory = comfy.model_management.get_free_memory(current_device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] @@ -487,7 +487,7 @@ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup transformer_options["cond_or_uncond"] = cond_or_uncond[:] transformer_options["uuids"] = uuids[:] - transformer_options["sigmas"] = timestep + transformer_options["sigmas"] = timestep.to(device) transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) transformer_options["multigpu_thread_device"] = device diff --git a/comfy/sd.py b/comfy/sd.py index f65e7caddd55..2643de26dc79 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1510,6 +1510,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) + out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) return out def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None): @@ -1711,6 +1712,7 @@ def load_diffusion_model(unet_path, model_options={}): if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) + model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model def load_unet(unet_path, dtype=None): From 84f465e791f4957921b1452fc239fa6794c96f22 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 07:07:54 -0700 Subject: [PATCH 42/90] Set CUDA device at start of multigpu threads to avoid multithreading bugs Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b Co-authored-by: Amp --- comfy/samplers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/samplers.py b/comfy/samplers.py index ab691ed5bb72..1ff50f51d6ad 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -444,6 +444,7 @@ class thread_result(NamedTuple): def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): try: + torch.cuda.set_device(device) model_current: BaseModel = model_options["multigpu_clones"][device].model # run every hooked_to_run separately with torch.no_grad(): From d52dcbc88fa225707bc18269da69b7c18cbbf5b3 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 07:23:13 -0700 Subject: [PATCH 43/90] Rewrite multigpu nodes to V3 format Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 108 +++++++++++++++++---------------- 1 file changed, 56 insertions(+), 52 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 3b68c10ff371..789038b1df59 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -1,13 +1,17 @@ from __future__ import annotations -from inspect import cleandoc +from inspect import cleandoc from typing import TYPE_CHECKING +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher import comfy.multigpu -class MultiGPUWorkUnitsNode: +class MultiGPUWorkUnitsNode(io.ComfyNode): """ Prepares model to have sampling accelerated via splitting work units. @@ -16,54 +20,53 @@ class MultiGPUWorkUnitsNode: Other than those exceptions, this node can be placed in any order. """ - NodeId = "MultiGPU_WorkUnits" - NodeName = "MultiGPU Work Units" @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "model": ("MODEL",), - "max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}), - }, - "optional": { - "gpu_options": ("GPU_OPTIONS",) - } - } - - RETURN_TYPES = ("MODEL",) - FUNCTION = "init_multigpu" - CATEGORY = "advanced/multigpu" - DESCRIPTION = cleandoc(__doc__) - - def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None): + def define_schema(cls): + return io.Schema( + node_id="MultiGPU_WorkUnits", + display_name="MultiGPU Work Units", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Model.Input("model"), + io.Int.Input("max_gpus", default=8, min=1, step=1), + io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput: model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True) - return (model,) + return io.NodeOutput(model) + -class MultiGPUOptionsNode: +class MultiGPUOptionsNode(io.ComfyNode): """ Select the relative speed of GPUs in the special case they have significantly different performance from one another. """ - NodeId = "MultiGPU_Options" - NodeName = "MultiGPU Options" @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "device_index": ("INT", {"default": 0, "min": 0, "max": 64}), - "relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01}) - }, - "optional": { - "gpu_options": ("GPU_OPTIONS",) - } - } - - RETURN_TYPES = ("GPU_OPTIONS",) - FUNCTION = "create_gpu_options" - CATEGORY = "advanced/multigpu" - DESCRIPTION = cleandoc(__doc__) - - def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None): + def define_schema(cls): + return io.Schema( + node_id="MultiGPU_Options", + display_name="MultiGPU Options", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Int.Input("device_index", default=0, min=0, max=64), + io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01), + io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True), + ], + outputs=[ + io.Custom("GPU_OPTIONS").Output(), + ], + ) + + @classmethod + def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput: if not gpu_options: gpu_options = comfy.multigpu.GPUOptionsGroup() gpu_options.clone() @@ -71,16 +74,17 @@ def create_gpu_options(self, device_index: int, relative_speed: float, gpu_optio opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed) gpu_options.add(opt) - return (gpu_options,) + return io.NodeOutput(gpu_options) + +class MultiGPUExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + MultiGPUWorkUnitsNode, + MultiGPUOptionsNode, + ] -node_list = [ - MultiGPUWorkUnitsNode, - MultiGPUOptionsNode -] -NODE_CLASS_MAPPINGS = {} -NODE_DISPLAY_NAME_MAPPINGS = {} -for node in node_list: - NODE_CLASS_MAPPINGS[node.NodeId] = node - NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName +async def comfy_entrypoint() -> MultiGPUExtension: + return MultiGPUExtension() From 5f4fcd19e7a5ce82b998495d18c10f4a111e41b7 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 07:30:32 -0700 Subject: [PATCH 44/90] Simplify multigpu nodes: default max_gpus=2, remove gpu_options input, disable Options node Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 789038b1df59..c77dd5c1fe9d 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -29,8 +29,7 @@ def define_schema(cls): description=cleandoc(cls.__doc__), inputs=[ io.Model.Input("model"), - io.Int.Input("max_gpus", default=8, min=1, step=1), - io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True), + io.Int.Input("max_gpus", default=2, min=1, step=1), ], outputs=[ io.Model.Output(), @@ -38,8 +37,8 @@ def define_schema(cls): ) @classmethod - def execute(cls, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput: - model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True) + def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: + model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) return io.NodeOutput(model) @@ -82,7 +81,7 @@ class MultiGPUExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ MultiGPUWorkUnitsNode, - MultiGPUOptionsNode, + # MultiGPUOptionsNode, ] From 1d8e379f41154354edf7879d21606cd8dabd575a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 08:00:20 -0700 Subject: [PATCH 45/90] Rename MultiGPU Work Units to MultiGPU CFG Split Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index c77dd5c1fe9d..5d24952bf61c 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -11,7 +11,7 @@ import comfy.multigpu -class MultiGPUWorkUnitsNode(io.ComfyNode): +class MultiGPUCFGSplitNode(io.ComfyNode): """ Prepares model to have sampling accelerated via splitting work units. @@ -24,7 +24,7 @@ class MultiGPUWorkUnitsNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="MultiGPU_WorkUnits", - display_name="MultiGPU Work Units", + display_name="MultiGPU CFG Split", category="advanced/multigpu", description=cleandoc(cls.__doc__), inputs=[ @@ -80,7 +80,7 @@ class MultiGPUExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - MultiGPUWorkUnitsNode, + MultiGPUCFGSplitNode, # MultiGPUOptionsNode, ] From afdddcee66cb80b81bdc071da3773a54652d1284 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 08:32:52 -0700 Subject: [PATCH 46/90] Re-enable comfy-kitchen cuda backend for multigpu testing Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba Co-authored-by: Amp --- comfy/quant_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 9375255d142f..37e54672285f 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -20,7 +20,6 @@ if cuda_version < (13,): ck.registry.disable("cuda") logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") - ck.registry.disable("cuda") # multigpu will not work rn with comfy-kitchen on cuda backend ck.registry.disable("triton") for k, v in ck.list_backends().items(): logging.info(f"Found comfy_kitchen backend {k}: {v}") From 3fab720be9123a710578b94a89f94d80f9601761 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 08:45:55 -0700 Subject: [PATCH 47/90] Add debug logging for device mismatch in ModelPatcherDynamic.load Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba Co-authored-by: Amp --- comfy/model_management.py | 2 ++ comfy/model_patcher.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index c89f7a246c98..3e58e7dd9e24 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -639,6 +639,8 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True): return True def model_use_more_vram(self, extra_memory, force_patch_weights=False): + if self.device != self.model.load_device: + logging.error(f"LoadedModel device mismatch: self.device={self.device}, model.load_device={self.model.load_device}, model_class={self.model.model.__class__.__name__}, is_multigpu={getattr(self.model, 'is_multigpu_base_clone', False)}, id(model)={id(self.model)}") return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) def __eq__(self, other): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c3ecc276f5ca..a3872926d0a7 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1646,6 +1646,8 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False #now. assert not full_load + if device_to != self.load_device: + logging.error(f"ModelPatcherDynamic.load device mismatch: device_to={device_to}, self.load_device={self.load_device}, model_class={self.model.__class__.__name__}, is_multigpu_base_clone={getattr(self, 'is_multigpu_base_clone', False)}, id(self)={id(self)}") assert device_to == self.load_device num_patches = 0 From 20803749c3be2666d2ee34f0371c6f483a792b5d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 08:53:36 -0700 Subject: [PATCH 48/90] Add detailed multigpu debug logging to load_models_gpu Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba Co-authored-by: Amp --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3e58e7dd9e24..76d475c0da94 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -780,16 +780,19 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if loaded_model_index is not None: loaded = current_loaded_models[loaded_model_index] loaded.currently_used = True + logging.info(f"[MULTIGPU_DBG] Reusing LoadedModel for {x.model.__class__.__name__}: LoadedModel.device={loaded.device}, model.load_device={loaded.model.load_device}, is_multigpu={getattr(loaded.model, 'is_multigpu_base_clone', False)}, id(patcher)={id(loaded.model)}, id(inner)={id(loaded.model.model)}") models_to_load.append(loaded) else: if hasattr(x, "model"): logging.info(f"Requested to load {x.model.__class__.__name__}") + logging.info(f"[MULTIGPU_DBG] New LoadedModel for {x.model.__class__.__name__}: LoadedModel.device={loaded_model.device}, model.load_device={x.load_device}, is_multigpu={getattr(x, 'is_multigpu_base_clone', False)}, id(patcher)={id(x)}, id(inner)={id(x.model)}") models_to_load.append(loaded_model) for loaded_model in models_to_load: to_unload = [] for i in range(len(current_loaded_models)): if loaded_model.model.is_clone(current_loaded_models[i].model): + logging.info(f"[MULTIGPU_DBG] is_clone match: unloading idx={i}, LoadedModel.device={current_loaded_models[i].device}, model.load_device={current_loaded_models[i].model.load_device}, id(inner)={id(current_loaded_models[i].model.model)}") to_unload = [i] + to_unload for i in to_unload: model_to_unload = current_loaded_models.pop(i) From b418fb1582946578ca04daf0aeeee76955e79c7a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 08:56:33 -0700 Subject: [PATCH 49/90] Fix device mismatch: update LoadedModel.device when _switch_parent swaps to parent patcher When a multigpu clone ModelPatcher is garbage collected, LoadedModel._switch_parent switches the weakref to point at the parent (main) ModelPatcher. However, it was not updating LoadedModel.device, leaving it with the old clone's device (e.g., cuda:1). On subsequent runs, this stale device was passed to ModelPatcherDynamic.load(), causing an assertion failure (device_to != self.load_device). Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba Co-authored-by: Amp --- comfy/model_management.py | 6 +----- comfy/model_patcher.py | 2 -- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 76d475c0da94..14d9f80fb4d0 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -577,6 +577,7 @@ def _switch_parent(self): model = self._parent_model() if model is not None: self._set_model(model) + self.device = model.load_device @property def model(self): @@ -639,8 +640,6 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True): return True def model_use_more_vram(self, extra_memory, force_patch_weights=False): - if self.device != self.model.load_device: - logging.error(f"LoadedModel device mismatch: self.device={self.device}, model.load_device={self.model.load_device}, model_class={self.model.model.__class__.__name__}, is_multigpu={getattr(self.model, 'is_multigpu_base_clone', False)}, id(model)={id(self.model)}") return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) def __eq__(self, other): @@ -780,19 +779,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if loaded_model_index is not None: loaded = current_loaded_models[loaded_model_index] loaded.currently_used = True - logging.info(f"[MULTIGPU_DBG] Reusing LoadedModel for {x.model.__class__.__name__}: LoadedModel.device={loaded.device}, model.load_device={loaded.model.load_device}, is_multigpu={getattr(loaded.model, 'is_multigpu_base_clone', False)}, id(patcher)={id(loaded.model)}, id(inner)={id(loaded.model.model)}") models_to_load.append(loaded) else: if hasattr(x, "model"): logging.info(f"Requested to load {x.model.__class__.__name__}") - logging.info(f"[MULTIGPU_DBG] New LoadedModel for {x.model.__class__.__name__}: LoadedModel.device={loaded_model.device}, model.load_device={x.load_device}, is_multigpu={getattr(x, 'is_multigpu_base_clone', False)}, id(patcher)={id(x)}, id(inner)={id(x.model)}") models_to_load.append(loaded_model) for loaded_model in models_to_load: to_unload = [] for i in range(len(current_loaded_models)): if loaded_model.model.is_clone(current_loaded_models[i].model): - logging.info(f"[MULTIGPU_DBG] is_clone match: unloading idx={i}, LoadedModel.device={current_loaded_models[i].device}, model.load_device={current_loaded_models[i].model.load_device}, id(inner)={id(current_loaded_models[i].model.model)}") to_unload = [i] + to_unload for i in to_unload: model_to_unload = current_loaded_models.pop(i) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a3872926d0a7..c3ecc276f5ca 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1646,8 +1646,6 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False #now. assert not full_load - if device_to != self.load_device: - logging.error(f"ModelPatcherDynamic.load device mismatch: device_to={device_to}, self.load_device={self.load_device}, model_class={self.model.__class__.__name__}, is_multigpu_base_clone={getattr(self, 'is_multigpu_base_clone', False)}, id(self)={id(self)}") assert device_to == self.load_device num_patches = 0 From 4b93c4360f4d09fa6f3a360fbf74c858c86f091a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 8 Apr 2026 02:39:07 -1000 Subject: [PATCH 50/90] Implement persistent thread pool for multi-GPU CFG splitting (#13329) Replace per-step thread create/destroy in _calc_cond_batch_multigpu with a persistent MultiGPUThreadPool. Each worker thread calls torch.cuda.set_device() once at startup, preserving compiled kernel caches across diffusion steps. - Add MultiGPUThreadPool class in comfy/multigpu.py - Create pool in CFGGuider.outer_sample(), shut down in finally block - Main thread handles its own device batch directly for zero overhead - Falls back to sequential execution if no pool is available --- comfy/multigpu.py | 63 ++++++++++++++++++++++++++++++++++++++++ comfy/sampler_helpers.py | 1 + comfy/samplers.py | 53 ++++++++++++++++++++++++++------- 3 files changed, 106 insertions(+), 11 deletions(-) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 90995a5abdca..096270c12573 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -1,4 +1,6 @@ from __future__ import annotations +import queue +import threading import torch import logging @@ -11,6 +13,67 @@ import comfy.model_management +class MultiGPUThreadPool: + """Persistent thread pool for multi-GPU work distribution. + + Maintains one worker thread per extra GPU device. Each thread calls + torch.cuda.set_device() once at startup so that compiled kernel caches + (inductor/triton) stay warm across diffusion steps. + """ + + def __init__(self, devices: list[torch.device]): + self._workers: list[threading.Thread] = [] + self._work_queues: dict[torch.device, queue.Queue] = {} + self._result_queues: dict[torch.device, queue.Queue] = {} + + for device in devices: + wq = queue.Queue() + rq = queue.Queue() + self._work_queues[device] = wq + self._result_queues[device] = rq + t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True) + t.start() + self._workers.append(t) + + def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue): + try: + torch.cuda.set_device(device) + except Exception as e: + logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}") + while True: + item = work_q.get() + if item is None: + return + result_q.put((None, e)) + return + while True: + item = work_q.get() + if item is None: + break + fn, args, kwargs = item + try: + result = fn(*args, **kwargs) + result_q.put((result, None)) + except Exception as e: + result_q.put((None, e)) + + def submit(self, device: torch.device, fn, *args, **kwargs): + self._work_queues[device].put((fn, args, kwargs)) + + def get_result(self, device: torch.device): + return self._result_queues[device].get() + + @property + def devices(self) -> list[torch.device]: + return list(self._work_queues.keys()) + + def shutdown(self): + for wq in self._work_queues.values(): + wq.put(None) # sentinel + for t in self._workers: + t.join(timeout=5.0) + + class GPUOptions: def __init__(self, device_index: int, relative_speed: float): self.device_index = device_index diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 844fadacd844..6f5447d959f3 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -11,6 +11,7 @@ import comfy.patcher_extension from typing import TYPE_CHECKING if TYPE_CHECKING: + from comfy.model_base import BaseModel from comfy.model_patcher import ModelPatcher from comfy.controlnet import ControlBase diff --git a/comfy/samplers.py b/comfy/samplers.py index 1ff50f51d6ad..68f093749835 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -18,10 +18,10 @@ import comfy.patcher_extension import comfy.hooks import comfy.context_windows +import comfy.multigpu import comfy.utils import scipy.stats import numpy -import threading def add_area_dims(area, num_dims): @@ -509,15 +509,38 @@ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup raise + def _handle_batch_pooled(device, batch_tuple): + worker_results = [] + _handle_batch(device, batch_tuple, worker_results) + return worker_results + results: list[thread_result] = [] - threads: list[threading.Thread] = [] + thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool") + main_device = output_device + main_batch_tuple = None + + # Submit extra GPU work to pool first, then run main device on this thread + pool_devices = [] for device, batch_tuple in device_batched_hooked_to_run.items(): - new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results)) - threads.append(new_thread) - new_thread.start() + if device == main_device and thread_pool is not None: + main_batch_tuple = batch_tuple + elif thread_pool is not None: + thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple) + pool_devices.append(device) + else: + # Fallback: no pool, run everything on main thread + _handle_batch(device, batch_tuple, results) - for thread in threads: - thread.join() + # Run main device batch on this thread (parallel with pool workers) + if main_batch_tuple is not None: + _handle_batch(main_device, main_batch_tuple, results) + + # Collect results from pool workers + for device in pool_devices: + worker_results, error = thread_pool.get_result(device) + if error is not None: + raise error + results.extend(worker_results) for output, mult, area, batch_chunks, cond_or_uncond, error in results: if error is not None: @@ -1187,17 +1210,25 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options) - noise = noise.to(device=device, dtype=torch.float32) - latent_image = latent_image.to(device=device, dtype=torch.float32) - sigmas = sigmas.to(device) - cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) + # Create persistent thread pool for extra GPU devices + if multigpu_patchers: + extra_devices = [p.load_device for p in multigpu_patchers] + self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(extra_devices) try: + noise = noise.to(device=device, dtype=torch.float32) + latent_image = latent_image.to(device=device, dtype=torch.float32) + sigmas = sigmas.to(device) + cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) + self.model_patcher.pre_run() for multigpu_patcher in multigpu_patchers: multigpu_patcher.pre_run() output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) finally: + thread_pool = self.model_options.pop("multigpu_thread_pool", None) + if thread_pool is not None: + thread_pool.shutdown() self.model_patcher.cleanup() for multigpu_patcher in multigpu_patchers: multigpu_patcher.cleanup() From 48deb15c0e2b3336de4ca27b3e920954dfde453b Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 8 Apr 2026 22:15:57 -1000 Subject: [PATCH 51/90] Simplify multigpu dispatch: run all devices on pool threads (#13340) Benchmarked hybrid (main thread + pool) vs all-pool on 2x RTX 4090 with SD1.5 and NetaYume models. No meaningful performance difference (within noise). All-pool is simpler: eliminates the main_device special case, main_batch_tuple deferred execution, and the 3-way branch in the dispatch loop. --- comfy/samplers.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 68f093749835..8ebf1c496a0e 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -516,25 +516,17 @@ def _handle_batch_pooled(device, batch_tuple): results: list[thread_result] = [] thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool") - main_device = output_device - main_batch_tuple = None - # Submit extra GPU work to pool first, then run main device on this thread + # Submit all GPU work to pool threads pool_devices = [] for device, batch_tuple in device_batched_hooked_to_run.items(): - if device == main_device and thread_pool is not None: - main_batch_tuple = batch_tuple - elif thread_pool is not None: + if thread_pool is not None: thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple) pool_devices.append(device) else: # Fallback: no pool, run everything on main thread _handle_batch(device, batch_tuple, results) - # Run main device batch on this thread (parallel with pool workers) - if main_batch_tuple is not None: - _handle_batch(main_device, main_batch_tuple, results) - # Collect results from pool workers for device in pool_devices: worker_results, error = thread_pool.get_result(device) @@ -1210,10 +1202,11 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options) - # Create persistent thread pool for extra GPU devices + # Create persistent thread pool for all GPU devices (main + extras) if multigpu_patchers: extra_devices = [p.load_device for p in multigpu_patchers] - self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(extra_devices) + all_devices = [device] + extra_devices + self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices) try: noise = noise.to(device=device, dtype=torch.float32) From f0d550bd02bc0f7550cad113eca852cdf5c805c6 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 16 Apr 2026 15:49:01 +1000 Subject: [PATCH 52/90] Minor updates for worksplit_gpu with comfy-aimdo (#13419) * main: init all visible cuda devices in aimdo * mp: call vbars_analyze for the GPU in question * requirements: bump aimdo to pre-release version --- comfy/model_patcher.py | 3 ++- main.py | 4 ++-- requirements.txt | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c3ecc276f5ca..a74a5190292f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -319,7 +319,8 @@ def get_free_memory(self, device): #than pays for CFG. So return everything both torch and Aimdo could give us aimdo_mem = 0 if comfy.memory_management.aimdo_enabled: - aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze() + aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None + aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device) return comfy.model_management.get_free_memory(device) + aimdo_mem def get_clone_model_override(self): diff --git a/main.py b/main.py index 12b04719d572..de145a1e9378 100644 --- a/main.py +++ b/main.py @@ -192,7 +192,7 @@ def execute_script(script_path): if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") - +import torch import comfy.utils import execution @@ -210,7 +210,7 @@ def execute_script(script_path): if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()): if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") - elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): + elif comfy_aimdo.control.init_devices(range(torch.cuda.device_count())): if args.verbose == 'DEBUG': comfy_aimdo.control.set_log_debug() elif args.verbose == 'CRITICAL': diff --git a/requirements.txt b/requirements.txt index 1a8e1ea1ce98..c60219a88b65 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo>=0.2.12 +comfy-aimdo==0.0.213 requests simpleeval>=1.0.0 blake3 From 37deccb0d4200efb986b11bf7240e65961fbaa00 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 20 Apr 2026 04:37:18 -0500 Subject: [PATCH 53/90] Fix Hunyuan 3D 2.1 multi-GPU worksplit: use cond_or_uncond instead of hardcoded chunk(2) (#13478) --- comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index f67ba84e912c..61d1b3dc674a 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -607,9 +607,14 @@ def __init__( def forward(self, x, t, context, transformer_options = {}, **kwargs): x = x.movedim(-1, -2) - uncond_emb, cond_emb = context.chunk(2, dim = 0) - context = torch.cat([cond_emb, uncond_emb], dim = 0) + cond_or_uncond = transformer_options.get("cond_or_uncond", []) + swap_cfg_halves = len(cond_or_uncond) == 2 and set(cond_or_uncond) == {0, 1} + + if swap_cfg_halves: + first_half, second_half = context.chunk(2, dim = 0) + context = torch.cat([second_half, first_half], dim = 0) + main_condition = context t = 1.0 - t @@ -657,5 +662,8 @@ def block_wrap(args): output = self.final_layer(combined) output = output.movedim(-2, -1) * (-1.0) - cond_emb, uncond_emb = output.chunk(2, dim = 0) - return torch.cat([uncond_emb, cond_emb]) + if swap_cfg_halves: + first_half, second_half = output.chunk(2, dim = 0) + output = torch.cat([second_half, first_half], dim = 0) + + return output From 7b8b3673ff3551d7336be32045782afe6546bcf6 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 24 Apr 2026 12:09:56 +1000 Subject: [PATCH 54/90] comfy-aimdo: 0.0.214 (#13532) Cut pre-release 0.0.214 off aimdo master to pickup async mem accounting fix. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a8e4f9bf6014..d08980f81223 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo==0.0.213 +comfy-aimdo==0.0.214 requests simpleeval>=1.0.0 blake3 From aa464b36b344ff78a94c176f08ed915eeb8410e9 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 23 Apr 2026 19:10:33 -0700 Subject: [PATCH 55/90] Multi-GPU device selection for loader nodes + CUDA context fixes (#13483) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix Hunyuan 3D 2.1 multi-GPU worksplit: use cond_or_uncond instead of hardcoded chunk(2) Amp-Thread-ID: https://ampcode.com/threads/T-019da964-2cc8-77f9-9aae-23f65da233db Co-authored-by: Amp * Add GPU device selection to all loader nodes - Add get_gpu_device_options() and resolve_gpu_device_option() helpers in model_management.py for vendor-agnostic GPU device selection - Add device widget to CheckpointLoaderSimple, UNETLoader, VAELoader - Expand device options in CLIPLoader, DualCLIPLoader, LTXAVTextEncoderLoader from [default, cpu] to include gpu:0, gpu:1, etc. on multi-GPU systems - Wire load_diffusion_model_state_dict and load_state_dict_guess_config to respect model_options['load_device'] - Graceful fallback: unrecognized devices (e.g. gpu:1 on single-GPU) silently fall back to default Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a Co-authored-by: Amp * Add VALIDATE_INPUTS to skip device combo validation for workflow portability When a workflow saved on a 2-GPU machine (with device=gpu:1) is loaded on a 1-GPU machine, the combo validation would reject the unknown value. VALIDATE_INPUTS with the device parameter bypasses combo validation for that input only, allowing resolve_gpu_device_option to handle the graceful fallback at runtime. Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a Co-authored-by: Amp * Set CUDA device context in outer_sample to match model load_device Custom CUDA kernels (comfy_kitchen fp8 quantization) use torch.cuda.current_device() for DLPack tensor export. When a model is loaded on a non-default GPU (e.g. cuda:1), the CUDA context must match or the kernel fails with 'Can't export tensors on a different CUDA device index'. Save and restore the previous device around sampling. Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a Co-authored-by: Amp * Fix code review bugs: negative index guard, CPU offload_device, checkpoint te_model_options - resolve_gpu_device_option: reject negative indices (gpu:-1) - UNETLoader: set offload_device when cpu is selected - CheckpointLoaderSimple: pass te_model_options for CLIP device, set offload_device for cpu, pass load_device to VAE - load_diffusion_model_state_dict: respect offload_device from model_options - load_state_dict_guess_config: respect offload_device, pass load_device to VAE Amp-Thread-ID: https://ampcode.com/threads/T-019daa41-f394-731a-8955-4cff4f16283a Co-authored-by: Amp * Fix CUDA device context for CLIP encoding and VAE encode/decode Add torch.cuda.set_device() calls to match model's load device in: - CLIP.encode_from_tokens: fixes 'Can't export tensors on a different CUDA device index' when CLIP is loaded on a non-default GPU - CLIP.encode_from_tokens_scheduled: same fix for the hooks code path - CLIP.generate: same fix for text generation - VAE.decode: fixes VAE decoding on non-default GPU - VAE.encode: fixes VAE encoding on non-default GPU Same pattern as the existing outer_sample fix in samplers.py - saves and restores previous CUDA device in a try/finally block. Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57 Co-authored-by: Amp * Extract cuda_device_context manager, fix tiled VAE methods Add model_management.cuda_device_context() — a context manager that saves/restores torch.cuda.current_device when operating on a non-default GPU. Replaces 6 copies of the manual save/set/restore boilerplate. Refactored call sites: - CLIP.encode_from_tokens - CLIP.encode_from_tokens_scheduled (hooks path) - CLIP.generate - VAE.decode - VAE.encode - samplers.outer_sample Bug fixes (newly wrapped): - VAE.decode_tiled: was missing device context entirely, would fail on non-default GPU when called from 'VAE Decode (Tiled)' node - VAE.encode_tiled: same issue for 'VAE Encode (Tiled)' node Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57 Co-authored-by: Amp * Restore CheckpointLoaderSimple, add CheckpointLoaderDevice Revert CheckpointLoaderSimple to its original form (no device input) so it remains the simple default loader. Add new CheckpointLoaderDevice node (advanced/loaders) with separate model_device, clip_device, and vae_device inputs for per-component GPU placement in multi-GPU setups. Amp-Thread-ID: https://ampcode.com/threads/T-019dabdc-8feb-766f-b4dc-f46ef4d8ff57 Co-authored-by: Amp --------- Co-authored-by: Amp --- comfy/model_management.py | 66 ++++++- comfy/samplers.py | 35 ++-- comfy/sd.py | 319 +++++++++++++++++---------------- comfy_extras/nodes_lt_audio.py | 10 +- nodes.py | 127 +++++++++++-- 5 files changed, 375 insertions(+), 182 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 46261a0eddb9..c7f6c4e6adc0 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -28,7 +28,7 @@ import weakref import gc import os -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops @@ -231,6 +231,70 @@ def get_all_torch_devices(exclude_current=False): devices.remove(get_torch_device()) return devices +def get_gpu_device_options(): + """Return list of device option strings for node widgets. + + Always includes "default" and "cpu". When multiple GPUs are present, + adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels). + """ + options = ["default", "cpu"] + devices = get_all_torch_devices() + if len(devices) > 1: + for i in range(len(devices)): + options.append(f"gpu:{i}") + return options + +def resolve_gpu_device_option(option: str): + """Resolve a device option string to a torch.device. + + Returns None for "default" (let the caller use its normal default). + Returns torch.device("cpu") for "cpu". + For "gpu:N", returns the Nth torch device. Falls back to None if + the index is out of range (caller should use default). + """ + if option is None or option == "default": + return None + if option == "cpu": + return torch.device("cpu") + if option.startswith("gpu:"): + try: + idx = int(option[4:]) + devices = get_all_torch_devices() + if 0 <= idx < len(devices): + return devices[idx] + else: + logging.warning(f"Device '{option}' not available (only {len(devices)} GPU(s)), using default.") + return None + except (ValueError, IndexError): + logging.warning(f"Invalid device option '{option}', using default.") + return None + logging.warning(f"Unrecognized device option '{option}', using default.") + return None + +@contextmanager +def cuda_device_context(device): + """Context manager that sets torch.cuda.current_device to match *device*. + + Used when running operations on a non-default CUDA device so that custom + CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct + device index. The previous device is restored on exit. + + No-op when *device* is not CUDA, has no explicit index, or already matches + the current device. + """ + prev = None + if device.type == "cuda" and device.index is not None: + prev = torch.cuda.current_device() + if prev != device.index: + torch.cuda.set_device(device) + else: + prev = None + try: + yield + finally: + if prev is not None: + torch.cuda.set_device(prev) + def get_total_memory(dev=None, torch_total_too=False): global directml_enabled if dev is None: diff --git a/comfy/samplers.py b/comfy/samplers.py index 8ebf1c496a0e..88393e3673ca 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1208,23 +1208,24 @@ def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, all_devices = [device] + extra_devices self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices) - try: - noise = noise.to(device=device, dtype=torch.float32) - latent_image = latent_image.to(device=device, dtype=torch.float32) - sigmas = sigmas.to(device) - cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) - - self.model_patcher.pre_run() - for multigpu_patcher in multigpu_patchers: - multigpu_patcher.pre_run() - output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) - finally: - thread_pool = self.model_options.pop("multigpu_thread_pool", None) - if thread_pool is not None: - thread_pool.shutdown() - self.model_patcher.cleanup() - for multigpu_patcher in multigpu_patchers: - multigpu_patcher.cleanup() + with comfy.model_management.cuda_device_context(device): + try: + noise = noise.to(device=device, dtype=torch.float32) + latent_image = latent_image.to(device=device, dtype=torch.float32) + sigmas = sigmas.to(device) + cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) + + self.model_patcher.pre_run() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.pre_run() + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) + finally: + thread_pool = self.model_options.pop("multigpu_thread_pool", None) + if thread_pool is not None: + thread_pool.shutdown() + self.model_patcher.cleanup() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.cleanup() comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) del self.inner_model diff --git a/comfy/sd.py b/comfy/sd.py index 0ce450ace9ec..2417ac12137a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -324,41 +324,43 @@ def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model(tokens) - self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) + device = self.patcher.load_device + self.cond_stage_model.set_clip_options({"execution_device": device}) all_hooks.reset() self.patcher.patch_hooks(None) if show_pbar: pbar = ProgressBar(len(scheduled_keyframes)) - for scheduled_opts in scheduled_keyframes: - t_range = scheduled_opts[0] - # don't bother encoding any conds outside of start_percent and end_percent bounds - if "start_percent" in add_dict: - if t_range[1] < add_dict["start_percent"]: - continue - if "end_percent" in add_dict: - if t_range[0] > add_dict["end_percent"]: - continue - hooks_keyframes = scheduled_opts[1] - for hook, keyframe in hooks_keyframes: - hook.hook_keyframe._current_keyframe = keyframe - # apply appropriate hooks with values that match new hook_keyframe - self.patcher.patch_hooks(all_hooks) - # perform encoding as normal - o = self.cond_stage_model.encode_token_weights(tokens) - cond, pooled = o[:2] - pooled_dict = {"pooled_output": pooled} - # add clip_start_percent and clip_end_percent in pooled - pooled_dict["clip_start_percent"] = t_range[0] - pooled_dict["clip_end_percent"] = t_range[1] - # add/update any keys with the provided add_dict - pooled_dict.update(add_dict) - # add hooks stored on clip - self.add_hooks_to_dict(pooled_dict) - all_cond_pooled.append([cond, pooled_dict]) - if show_pbar: - pbar.update(1) - model_management.throw_exception_if_processing_interrupted() + with model_management.cuda_device_context(device): + for scheduled_opts in scheduled_keyframes: + t_range = scheduled_opts[0] + # don't bother encoding any conds outside of start_percent and end_percent bounds + if "start_percent" in add_dict: + if t_range[1] < add_dict["start_percent"]: + continue + if "end_percent" in add_dict: + if t_range[0] > add_dict["end_percent"]: + continue + hooks_keyframes = scheduled_opts[1] + for hook, keyframe in hooks_keyframes: + hook.hook_keyframe._current_keyframe = keyframe + # apply appropriate hooks with values that match new hook_keyframe + self.patcher.patch_hooks(all_hooks) + # perform encoding as normal + o = self.cond_stage_model.encode_token_weights(tokens) + cond, pooled = o[:2] + pooled_dict = {"pooled_output": pooled} + # add clip_start_percent and clip_end_percent in pooled + pooled_dict["clip_start_percent"] = t_range[0] + pooled_dict["clip_end_percent"] = t_range[1] + # add/update any keys with the provided add_dict + pooled_dict.update(add_dict) + # add hooks stored on clip + self.add_hooks_to_dict(pooled_dict) + all_cond_pooled.append([cond, pooled_dict]) + if show_pbar: + pbar.update(1) + model_management.throw_exception_if_processing_interrupted() all_hooks.reset() return all_cond_pooled @@ -372,8 +374,12 @@ def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False): self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model(tokens) - self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) - o = self.cond_stage_model.encode_token_weights(tokens) + device = self.patcher.load_device + self.cond_stage_model.set_clip_options({"execution_device": device}) + + with model_management.cuda_device_context(device): + o = self.cond_stage_model.encode_token_weights(tokens) + cond, pooled = o[:2] if return_dict: out = {"cond": cond, "pooled_output": pooled} @@ -428,9 +434,12 @@ def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_ self.cond_stage_model.reset_clip_options() self.load_model(tokens) + device = self.patcher.load_device self.cond_stage_model.set_clip_options({"layer": None}) - self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) - return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty) + self.cond_stage_model.set_clip_options({"execution_device": device}) + + with model_management.cuda_device_context(device): + return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty) def decode(self, token_ids, skip_special_tokens=True): return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) @@ -947,50 +956,52 @@ def decode(self, samples_in, vae_options={}): do_tile = False if self.latent_dim == 2 and samples_in.ndim == 5: samples_in = samples_in[:, :, 0] - try: - memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = self.patcher.get_free_memory(self.device) - batch_number = int(free_memory / memory_used) - batch_number = max(1, batch_number) - - # Pre-allocate output for VAEs that support direct buffer writes - preallocated = False - if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): - pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype()) - preallocated = True - - for x in range(0, samples_in.shape[0], batch_number): - samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) - if preallocated: - self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options) - else: - out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True) - if pixel_samples is None: - pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) - pixel_samples[x:x+batch_number].copy_(out) - del out - self.process_output(pixel_samples[x:x+batch_number]) - except Exception as e: - model_management.raise_non_oom(e) - logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") - #NOTE: We don't know what tensors were allocated to stack variables at the time of the - #exception and the exception itself refs them all until we get out of this except block. - #So we just set a flag for tiler fallback so that tensor gc can happen once the - #exception is fully off the books. - do_tile = True - - if do_tile: - comfy.model_management.soft_empty_cache() - dims = samples_in.ndim - 2 - if dims == 1 or self.extra_1d_channel is not None: - pixel_samples = self.decode_tiled_1d(samples_in) - elif dims == 2: - pixel_samples = self.decode_tiled_(samples_in) - elif dims == 3: - tile = 256 // self.spacial_compression_decode() - overlap = tile // 4 - pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + + with model_management.cuda_device_context(self.device): + try: + memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) + free_memory = self.patcher.get_free_memory(self.device) + batch_number = int(free_memory / memory_used) + batch_number = max(1, batch_number) + + # Pre-allocate output for VAEs that support direct buffer writes + preallocated = False + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype()) + preallocated = True + + for x in range(0, samples_in.shape[0], batch_number): + samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) + if preallocated: + self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options) + else: + out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True) + if pixel_samples is None: + pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) + pixel_samples[x:x+batch_number].copy_(out) + del out + self.process_output(pixel_samples[x:x+batch_number]) + except Exception as e: + model_management.raise_non_oom(e) + logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + #NOTE: We don't know what tensors were allocated to stack variables at the time of the + #exception and the exception itself refs them all until we get out of this except block. + #So we just set a flag for tiler fallback so that tensor gc can happen once the + #exception is fully off the books. + do_tile = True + + if do_tile: + comfy.model_management.soft_empty_cache() + dims = samples_in.ndim - 2 + if dims == 1 or self.extra_1d_channel is not None: + pixel_samples = self.decode_tiled_1d(samples_in) + elif dims == 2: + pixel_samples = self.decode_tiled_(samples_in) + elif dims == 3: + tile = 256 // self.spacial_compression_decode() + overlap = tile // 4 + pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples @@ -1008,20 +1019,21 @@ def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=N if overlap is not None: args["overlap"] = overlap - if dims == 1 or self.extra_1d_channel is not None: - args.pop("tile_y") - output = self.decode_tiled_1d(samples, **args) - elif dims == 2: - output = self.decode_tiled_(samples, **args) - elif dims == 3: - if overlap_t is None: - args["overlap"] = (1, overlap, overlap) - else: - args["overlap"] = (max(1, overlap_t), overlap, overlap) - if tile_t is not None: - args["tile_t"] = max(2, tile_t) + with model_management.cuda_device_context(self.device): + if dims == 1 or self.extra_1d_channel is not None: + args.pop("tile_y") + output = self.decode_tiled_1d(samples, **args) + elif dims == 2: + output = self.decode_tiled_(samples, **args) + elif dims == 3: + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (max(1, overlap_t), overlap, overlap) + if tile_t is not None: + args["tile_t"] = max(2, tile_t) - output = self.decode_tiled_3d(samples, **args) + output = self.decode_tiled_3d(samples, **args) return output.movedim(1, -1) def encode(self, pixel_samples): @@ -1034,44 +1046,46 @@ def encode(self, pixel_samples): pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) else: pixel_samples = pixel_samples.unsqueeze(2) - try: - memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = self.patcher.get_free_memory(self.device) - batch_number = int(free_memory / max(1, memory_used)) - batch_number = max(1, batch_number) - samples = None - for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) - if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): - out = self.first_stage_model.encode(pixels_in, device=self.device) + + with model_management.cuda_device_context(self.device): + try: + memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) + free_memory = self.patcher.get_free_memory(self.device) + batch_number = int(free_memory / max(1, memory_used)) + batch_number = max(1, batch_number) + samples = None + for x in range(0, pixel_samples.shape[0], batch_number): + pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + out = self.first_stage_model.encode(pixels_in, device=self.device) + else: + pixels_in = pixels_in.to(self.device) + out = self.first_stage_model.encode(pixels_in) + out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) + if samples is None: + samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) + samples[x:x + batch_number] = out + + except Exception as e: + model_management.raise_non_oom(e) + logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") + #NOTE: We don't know what tensors were allocated to stack variables at the time of the + #exception and the exception itself refs them all until we get out of this except block. + #So we just set a flag for tiler fallback so that tensor gc can happen once the + #exception is fully off the books. + do_tile = True + + if do_tile: + comfy.model_management.soft_empty_cache() + if self.latent_dim == 3: + tile = 256 + overlap = tile // 4 + samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + elif self.latent_dim == 1 or self.extra_1d_channel is not None: + samples = self.encode_tiled_1d(pixel_samples) else: - pixels_in = pixels_in.to(self.device) - out = self.first_stage_model.encode(pixels_in) - out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) - if samples is None: - samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) - samples[x:x + batch_number] = out - - except Exception as e: - model_management.raise_non_oom(e) - logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") - #NOTE: We don't know what tensors were allocated to stack variables at the time of the - #exception and the exception itself refs them all until we get out of this except block. - #So we just set a flag for tiler fallback so that tensor gc can happen once the - #exception is fully off the books. - do_tile = True - - if do_tile: - comfy.model_management.soft_empty_cache() - if self.latent_dim == 3: - tile = 256 - overlap = tile // 4 - samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) - elif self.latent_dim == 1 or self.extra_1d_channel is not None: - samples = self.encode_tiled_1d(pixel_samples) - else: - samples = self.encode_tiled_(pixel_samples) + samples = self.encode_tiled_(pixel_samples) return samples @@ -1097,26 +1111,27 @@ def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, ti if overlap is not None: args["overlap"] = overlap - if dims == 1: - args.pop("tile_y") - samples = self.encode_tiled_1d(pixel_samples, **args) - elif dims == 2: - samples = self.encode_tiled_(pixel_samples, **args) - elif dims == 3: - if tile_t is not None: - tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) - else: - tile_t_latent = 9999 - args["tile_t"] = self.upscale_ratio[0](tile_t_latent) + with model_management.cuda_device_context(self.device): + if dims == 1: + args.pop("tile_y") + samples = self.encode_tiled_1d(pixel_samples, **args) + elif dims == 2: + samples = self.encode_tiled_(pixel_samples, **args) + elif dims == 3: + if tile_t is not None: + tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) + else: + tile_t_latent = 9999 + args["tile_t"] = self.upscale_ratio[0](tile_t_latent) - if overlap_t is None: - args["overlap"] = (1, overlap, overlap) - else: - args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) - maximum = pixel_samples.shape[2] - maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) + maximum = pixel_samples.shape[2] + maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) - samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) + samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) return samples @@ -1633,7 +1648,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix) weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) - load_device = model_management.get_torch_device() + load_device = model_options.get("load_device", model_management.get_torch_device()) custom_operations = model_options.get("custom_operations", None) if custom_operations is None: @@ -1673,13 +1688,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher - model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + offload_device = model_options.get("offload_device", model_management.unet_offload_device()) + model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device) model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) vae_sd = model_config.process_vae_state_dict(vae_sd) - vae = VAE(sd=vae_sd, metadata=metadata) + vae_device = model_options.get("load_device", None) + vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device) if output_clip: if te_model_options.get("custom_operations", None) is None: @@ -1763,7 +1780,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable parameters = comfy.utils.calculate_parameters(sd) weight_dtype = comfy.utils.weight_dtype(sd) - load_device = model_management.get_torch_device() + load_device = model_options.get("load_device", model_management.get_torch_device()) model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata) if model_config is not None: @@ -1788,7 +1805,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable else: logging.warning("{} {}".format(diffusers_keys[k], k)) - offload_device = model_management.unet_offload_device() + offload_device = model_options.get("offload_device", model_management.unet_offload_device()) unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.quant_config is not None: weight_dtype = None diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 3e4222264f22..be0b1c887dbb 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -188,7 +188,7 @@ def define_schema(cls) -> io.Schema: ), io.Combo.Input( "device", - options=["default", "cpu"], + options=comfy.model_management.get_gpu_device_options(), advanced=True, ) ], @@ -203,8 +203,12 @@ def execute(cls, text_encoder, ckpt_name, device="default"): clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) model_options = {} - if device == "cpu": - model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is not None: + if resolved.type == "cpu": + model_options["load_device"] = model_options["offload_device"] = resolved + else: + model_options["load_device"] = resolved clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) return io.NodeOutput(clip) diff --git a/nodes.py b/nodes.py index 9eced6838352..d81ac2935ebe 100644 --- a/nodes.py +++ b/nodes.py @@ -608,6 +608,73 @@ def load_checkpoint(self, ckpt_name): out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out[:3] + +class CheckpointLoaderDevice: + @classmethod + def INPUT_TYPES(s): + device_options = comfy.model_management.get_gpu_device_options() + return { + "required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}), + }, + "optional": { + "model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}), + "clip_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP text encoder."}), + "vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}), + } + } + RETURN_TYPES = ("MODEL", "CLIP", "VAE") + OUTPUT_TOOLTIPS = ("The model used for denoising latents.", + "The CLIP model used for encoding text prompts.", + "The VAE model used for encoding and decoding images to and from latent space.") + FUNCTION = "load_checkpoint" + + CATEGORY = "advanced/loaders" + DESCRIPTION = "Loads a diffusion model checkpoint with per-component device selection for multi-GPU setups." + + @classmethod + def VALIDATE_INPUTS(cls, model_device="default", clip_device="default", vae_device="default"): + return True + + def load_checkpoint(self, ckpt_name, model_device="default", clip_device="default", vae_device="default"): + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + + model_options = {} + resolved_model = comfy.model_management.resolve_gpu_device_option(model_device) + if resolved_model is not None: + if resolved_model.type == "cpu": + model_options["load_device"] = model_options["offload_device"] = resolved_model + else: + model_options["load_device"] = resolved_model + + te_model_options = {} + resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_device) + if resolved_clip is not None: + if resolved_clip.type == "cpu": + te_model_options["load_device"] = te_model_options["offload_device"] = resolved_clip + else: + te_model_options["load_device"] = resolved_clip + + # VAE device is passed via model_options["load_device"] which + # load_state_dict_guess_config forwards to the VAE constructor. + # If vae_device differs from model_device, we override after loading. + resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device) + + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options, te_model_options=te_model_options) + model_patcher, clip, vae = out[:3] + + # Apply VAE device override if it differs from the model device + if resolved_vae is not None and vae is not None: + vae.device = resolved_vae + if resolved_vae.type == "cpu": + offload = resolved_vae + else: + offload = comfy.model_management.vae_offload_device() + vae.patcher.load_device = resolved_vae + vae.patcher.offload_device = offload + + return (model_patcher, clip, vae) + class DiffusersLoader: SEARCH_ALIASES = ["load diffusers model"] @@ -807,14 +874,21 @@ def load_taesd(name): @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (s.vae_list(s), )}} + return {"required": { "vae_name": (s.vae_list(s), )}, + "optional": { + "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}), + }} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" CATEGORY = "loaders" + @classmethod + def VALIDATE_INPUTS(cls, device="default"): + return True + #TODO: scale factor? - def load_vae(self, vae_name): + def load_vae(self, vae_name, device="default"): metadata = None if vae_name == "pixel_space": sd = {} @@ -827,7 +901,8 @@ def load_vae(self, vae_name): else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) - vae = comfy.sd.VAE(sd=sd, metadata=metadata) + resolved = comfy.model_management.resolve_gpu_device_option(device) + vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved) vae.throw_exception_if_invalid() return (vae,) @@ -953,13 +1028,20 @@ class UNETLoader: def INPUT_TYPES(s): return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ), "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True}) + }, + "optional": { + "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "load_unet" CATEGORY = "advanced/loaders" - def load_unet(self, unet_name, weight_dtype): + @classmethod + def VALIDATE_INPUTS(cls, device="default"): + return True + + def load_unet(self, unet_name, weight_dtype, device="default"): model_options = {} if weight_dtype == "fp8_e4m3fn": model_options["dtype"] = torch.float8_e4m3fn @@ -969,6 +1051,13 @@ def load_unet(self, unet_name, weight_dtype): elif weight_dtype == "fp8_e5m2": model_options["dtype"] = torch.float8_e5m2 + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is not None: + if resolved.type == "cpu": + model_options["load_device"] = model_options["offload_device"] = resolved + else: + model_options["load_device"] = resolved + unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name) model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) return (model,) @@ -980,7 +1069,7 @@ def INPUT_TYPES(s): "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ), }, "optional": { - "device": (["default", "cpu"], {"advanced": True}), + "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -989,12 +1078,20 @@ def INPUT_TYPES(s): DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B" + @classmethod + def VALIDATE_INPUTS(cls, device="default"): + return True + def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) model_options = {} - if device == "cpu": - model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is not None: + if resolved.type == "cpu": + model_options["load_device"] = model_options["offload_device"] = resolved + else: + model_options["load_device"] = resolved clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) @@ -1008,7 +1105,7 @@ def INPUT_TYPES(s): "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ), }, "optional": { - "device": (["default", "cpu"], {"advanced": True}), + "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -1017,6 +1114,10 @@ def INPUT_TYPES(s): DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2" + @classmethod + def VALIDATE_INPUTS(cls, device="default"): + return True + def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -1024,8 +1125,12 @@ def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) model_options = {} - if device == "cpu": - model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is not None: + if resolved.type == "cpu": + model_options["load_device"] = model_options["offload_device"] = resolved + else: + model_options["load_device"] = resolved clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) return (clip,) @@ -2098,6 +2203,7 @@ def expand_image(self, image, left, top, right, bottom, feathering): "InpaintModelConditioning": InpaintModelConditioning, "CheckpointLoader": CheckpointLoader, + "CheckpointLoaderDevice": CheckpointLoaderDevice, "DiffusersLoader": DiffusersLoader, "LoadLatent": LoadLatent, @@ -2115,6 +2221,7 @@ def expand_image(self, image, left, top, right, bottom, feathering): # Loaders "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)", "CheckpointLoaderSimple": "Load Checkpoint", + "CheckpointLoaderDevice": "Load Checkpoint (Device)", "VAELoader": "Load VAE", "LoraLoader": "Load LoRA (Model and CLIP)", "LoraLoaderModelOnly": "Load LoRA", From 1b96430c601dec6b4ffbc676adbc92894d2fa251 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 23 Apr 2026 19:20:14 -0700 Subject: [PATCH 56/90] Merge master into worksplit-multigpu (#13546) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: pin SQLAlchemy>=2.0 in requirements.txt (fixes #13036) (#13316) * Refactor io to IO in nodes_ace.py (#13485) * Bump comfyui-frontend-package to 1.42.12 (#13489) * Make the ltx audio vae more native. (#13486) * feat(api-nodes): add automatic downscaling of videos for ByteDance 2 nodes (#13465) * Support standalone LTXV audio VAEs (#13499) * [Partner Nodes] added 4K resolution for Veo models; added Veo 3 Lite model (#13330) * feat(api nodes): added 4K resolution for Veo models; added Veo 3 Lite model Signed-off-by: bigcat88 * increase poll_interval from 5 to 9 --------- Signed-off-by: bigcat88 Co-authored-by: Jedrzej Kosinski * Bump comfyui-frontend-package to 1.42.14 (#13493) * Add gpt-image-2 as version option (#13501) * Allow logging in comfy app files. (#13505) * chore: update workflow templates to v0.9.59 (#13507) * fix(veo): reject 4K resolution for veo-3.0 models in Veo3VideoGenerationNode (#13504) The tooltip on the resolution input states that 4K is not available for veo-3.1-lite or veo-3.0 models, but the execute guard only rejected the lite combination. Selecting 4K with veo-3.0-generate-001 or veo-3.0-fast-generate-001 would fall through and hit the upstream API with an invalid request. Broaden the guard to match the documented behavior and update the error message accordingly. Co-authored-by: Jedrzej Kosinski * feat: RIFE and FILM frame interpolation model support (CORE-29) (#13258) * initial RIFE support * Also support FILM * Better RAM usage, reduce FILM VRAM peak * Add model folder placeholder * Fix oom fallback frame loss * Remove torch.compile for now * Rename model input * Shorter input type name --------- * fix: use Parameter assignment for Stable_Zero123 cc_projection weights (fixes #13492) (#13518) On Windows with aimdo enabled, disable_weight_init.Linear uses lazy initialization that sets weight and bias to None to avoid unnecessary memory allocation. This caused a crash when copy_() was called on the None weight attribute in Stable_Zero123.__init__. Replace copy_() with direct torch.nn.Parameter assignment, which works correctly on both Windows (aimdo enabled) and other platforms. * Derive InterruptProcessingException from BaseException (#13523) * bump manager version to 4.2.1 (#13516) * ModelPatcherDynamic: force cast stray weights on comfy layers (#13487) the mixed_precision ops can have input_scale parameters that are used in tensor math but arent a weight or bias so dont get proper VRAM management. Treat these as force-castable parameters like the non comfy weight, random params are buffers already are. * Update logging level for invalid version format (#13526) * [Partner Nodes] add SD2 real human support (#13509) * feat(api-nodes): add SD2 real human support Signed-off-by: bigcat88 * fix: add validation before uploading Assets Signed-off-by: bigcat88 * Add asset_id and group_id displaying on the node Signed-off-by: bigcat88 * extend poll_op to use instead of custom async cycle Signed-off-by: bigcat88 * added the polling for the "Active" status after asset creation Signed-off-by: bigcat88 * updated tooltip for group_id * allow usage of real human in the ByteDance2FirstLastFrame node * add reference count limits * corrected price in status when input assets contain video Signed-off-by: bigcat88 --------- Signed-off-by: bigcat88 * feat: SAM (segment anything) 3.1 support (CORE-34) (#13408) * [Partner Nodes] GPTImage: fix price badges, add new resolutions (#13519) * fix(api-nodes): fixed price badges, add new resolutions Signed-off-by: bigcat88 * proper calculate the total run cost when "n > 1" Signed-off-by: bigcat88 --------- Signed-off-by: bigcat88 * chore: update workflow templates to v0.9.61 (#13533) * chore: update embedded docs to v0.4.4 (#13535) * add 4K resolution to Kling nodes (#13536) Signed-off-by: bigcat88 * Fix LTXV Reference Audio node (#13531) * comfy-aimdo 0.2.14: Hotfix async allocator estimations (#13534) This was doing an over-estimate of VRAM used by the async allocator when lots of little small tensors were in play. Also change the versioning scheme to == so we can roll forward aimdo without worrying about stable regressions downstream in comfyUI core. * Disable sageattention for SAM3 (#13529) Causes Nans * execution: Add anti-cycle validation (#13169) Currently if the graph contains a cycle, the just inifitiate recursions, hits a catch all then throws a generic error against the output node that seeded the validation. Instead, fail the offending cycling mode chain and handlng it as an error in its own right. Co-authored-by: guill * chore: update workflow templates to v0.9.62 (#13539) --------- Signed-off-by: bigcat88 Co-authored-by: Octopus Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Co-authored-by: Comfy Org PR Bot Co-authored-by: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Co-authored-by: Jukka Seppänen <40791699+kijai@users.noreply.github.com> Co-authored-by: AustinMroz Co-authored-by: Daxiong (Lin) Co-authored-by: Matt Miller Co-authored-by: blepping <157360029+blepping@users.noreply.github.com> Co-authored-by: Dr.Lt.Data <128333288+ltdrdata@users.noreply.github.com> Co-authored-by: rattus <46076784+rattus128@users.noreply.github.com> Co-authored-by: guill --- comfy/ldm/lightricks/vae/audio_vae.py | 55 +- comfy/ldm/sam3/detector.py | 596 ++++++ comfy/ldm/sam3/sam.py | 425 ++++ comfy/ldm/sam3/tracker.py | 1785 +++++++++++++++++ comfy/model_base.py | 9 +- comfy/model_detection.py | 12 + comfy/model_management.py | 2 +- comfy/model_patcher.py | 17 +- comfy/sd.py | 19 + comfy/supported_models.py | 53 +- comfy/text_encoders/sam3_clip.py | 97 + comfy_api_nodes/apis/bytedance.py | 48 +- comfy_api_nodes/nodes_bytedance.py | 485 ++++- comfy_api_nodes/nodes_kling.py | 91 +- comfy_api_nodes/nodes_openai.py | 63 +- comfy_api_nodes/nodes_veo2.py | 171 +- comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/client.py | 9 +- comfy_api_nodes/util/conversions.py | 104 +- .../frame_interpolation_models/film_net.py | 258 +++ .../frame_interpolation_models/ifnet.py | 128 ++ comfy_extras/nodes_ace.py | 104 +- comfy_extras/nodes_audio.py | 2 +- comfy_extras/nodes_frame_interpolation.py | 211 ++ comfy_extras/nodes_lt.py | 10 +- comfy_extras/nodes_lt_audio.py | 36 +- comfy_extras/nodes_sam3.py | 529 +++++ execution.py | 38 +- folder_paths.py | 2 + main.py | 4 +- manager_requirements.txt | 2 +- .../put_frame_interpolation_models_here | 0 nodes.py | 4 +- utils/install_util.py | 2 +- 34 files changed, 5123 insertions(+), 250 deletions(-) create mode 100644 comfy/ldm/sam3/detector.py create mode 100644 comfy/ldm/sam3/sam.py create mode 100644 comfy/ldm/sam3/tracker.py create mode 100644 comfy/text_encoders/sam3_clip.py create mode 100644 comfy_extras/frame_interpolation_models/film_net.py create mode 100644 comfy_extras/frame_interpolation_models/ifnet.py create mode 100644 comfy_extras/nodes_frame_interpolation.py create mode 100644 comfy_extras/nodes_sam3.py create mode 100644 models/frame_interpolation/put_frame_interpolation_models_here diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index fa0a00748983..dd5320c8f8fc 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -4,9 +4,6 @@ import torch import torchaudio -import comfy.model_management -import comfy.model_patcher -import comfy.utils as utils from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( @@ -43,30 +40,6 @@ def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig": return cls(autoencoder=audio_config, vocoder=vocoder_config) - -class ModelDeviceManager: - """Manages device placement and GPU residency for the composed model.""" - - def __init__(self, module: torch.nn.Module): - load_device = comfy.model_management.get_torch_device() - offload_device = comfy.model_management.vae_offload_device() - self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device) - - def ensure_model_loaded(self) -> None: - comfy.model_management.free_memory( - self.patcher.model_size(), - self.patcher.load_device, - ) - comfy.model_management.load_model_gpu(self.patcher) - - def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor: - return tensor.to(self.patcher.load_device) - - @property - def load_device(self): - return self.patcher.load_device - - class AudioLatentNormalizer: """Applies per-channel statistics in patch space and restores original layout.""" @@ -132,23 +105,17 @@ def waveform_to_mel( class AudioVAE(torch.nn.Module): """High-level Audio VAE wrapper exposing encode and decode entry points.""" - def __init__(self, state_dict: dict, metadata: dict): + def __init__(self, metadata: dict): super().__init__() component_config = AudioVAEComponentConfig.from_metadata(metadata) - vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True) - vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) - self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) if "bwe" in component_config.vocoder: self.vocoder = VocoderWithBWE(config=component_config.vocoder) else: self.vocoder = Vocoder(config=component_config.vocoder) - self.autoencoder.load_state_dict(vae_sd, strict=False) - self.vocoder.load_state_dict(vocoder_sd, strict=False) - autoencoder_config = self.autoencoder.get_config() self.normalizer = AudioLatentNormalizer( AudioPatchifier( @@ -168,18 +135,12 @@ def __init__(self, state_dict: dict, metadata: dict): n_fft=autoencoder_config["n_fft"], ) - self.device_manager = ModelDeviceManager(self) - - def encode(self, audio: dict) -> torch.Tensor: + def encode(self, audio, sample_rate=44100) -> torch.Tensor: """Encode a waveform dictionary into normalized latent tensors.""" - waveform = audio["waveform"] - waveform_sample_rate = audio["sample_rate"] + waveform = audio + waveform_sample_rate = sample_rate input_device = waveform.device - # Ensure that Audio VAE is loaded on the correct device. - self.device_manager.ensure_model_loaded() - - waveform = self.device_manager.move_to_load_device(waveform) expected_channels = self.autoencoder.encoder.in_channels if waveform.shape[1] != expected_channels: if waveform.shape[1] == 1: @@ -190,7 +151,7 @@ def encode(self, audio: dict) -> torch.Tensor: ) mel_spec = self.preprocessor.waveform_to_mel( - waveform, waveform_sample_rate, device=self.device_manager.load_device + waveform, waveform_sample_rate, device=waveform.device ) latents = self.autoencoder.encode(mel_spec) @@ -204,17 +165,13 @@ def decode(self, latents: torch.Tensor) -> torch.Tensor: """Decode normalized latent tensors into an audio waveform.""" original_shape = latents.shape - # Ensure that Audio VAE is loaded on the correct device. - self.device_manager.ensure_model_loaded() - - latents = self.device_manager.move_to_load_device(latents) latents = self.normalizer.denormalize(latents) target_shape = self.target_shape_from_latents(original_shape) mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) waveform = self.run_vocoder(mel_spec) - return self.device_manager.move_to_load_device(waveform) + return waveform def target_shape_from_latents(self, latents_shape): batch, _, time, _ = latents_shape diff --git a/comfy/ldm/sam3/detector.py b/comfy/ldm/sam3/detector.py new file mode 100644 index 000000000000..12d3a01abf92 --- /dev/null +++ b/comfy/ldm/sam3/detector.py @@ -0,0 +1,596 @@ +# SAM3 detector: transformer encoder-decoder, segmentation head, geometry encoder, scoring. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.ops import roi_align + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.sam3.tracker import SAM3Tracker, SAM31Tracker +from comfy.ldm.sam3.sam import SAM3VisionBackbone # noqa: used in __init__ +from comfy.ldm.sam3.sam import MLP, PositionEmbeddingSine + +TRACKER_CLASSES = {"SAM3": SAM3Tracker, "SAM31": SAM31Tracker} +from comfy.ops import cast_to_input + + +def box_cxcywh_to_xyxy(x): + cx, cy, w, h = x.unbind(-1) + return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1) + + +def gen_sineembed_for_position(pos_tensor, num_feats=256): + """Per-coordinate sinusoidal embedding: (..., N) -> (..., N * num_feats).""" + assert num_feats % 2 == 0 + hdim = num_feats // 2 + freqs = 10000.0 ** (2 * (torch.arange(hdim, dtype=torch.float32, device=pos_tensor.device) // 2) / hdim) + embeds = [] + for c in range(pos_tensor.shape[-1]): + raw = (pos_tensor[..., c].float() * 2 * math.pi).unsqueeze(-1) / freqs + embeds.append(torch.stack([raw[..., 0::2].sin(), raw[..., 1::2].cos()], dim=-1).flatten(-2)) + return torch.cat(embeds, dim=-1).to(pos_tensor.dtype) + + +class SplitMHA(nn.Module): + """Multi-head attention with separate Q/K/V projections (split from fused in_proj_weight).""" + def __init__(self, d_model, num_heads=8, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + + def forward(self, q_input, k_input=None, v_input=None, mask=None): + q = self.q_proj(q_input) + if k_input is None: + k = self.k_proj(q_input) + v = self.v_proj(q_input) + else: + k = self.k_proj(k_input) + v = self.v_proj(v_input if v_input is not None else k_input) + if mask is not None and mask.ndim == 2: + mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast + dtype = q.dtype # manual_cast may produce mixed dtypes + out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False) + return self.out_proj(out) + + +class MLPWithNorm(nn.Module): + """MLP with residual connection and output LayerNorm.""" + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, residual=True, device=None, dtype=None, operations=None): + super().__init__() + dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim] + self.layers = nn.ModuleList([ + operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) + for i in range(num_layers) + ]) + self.out_norm = operations.LayerNorm(output_dim, device=device, dtype=dtype) + self.residual = residual and (input_dim == output_dim) + + def forward(self, x): + orig = x + for i, layer in enumerate(self.layers): + x = layer(x) + if i < len(self.layers) - 1: + x = F.relu(x) + if self.residual: + x = x + orig + return self.out_norm(x) + + +class EncoderLayer(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_image = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + + def forward(self, x, pos, text_memory=None, text_mask=None): + normed = self.norm1(x) + q_k = normed + pos + x = x + self.self_attn(q_k, q_k, normed) + if text_memory is not None: + normed = self.norm2(x) + x = x + self.cross_attn_image(normed, text_memory, text_memory, mask=text_mask) + normed = self.norm3(x) + x = x + self.linear2(F.relu(self.linear1(normed))) + return x + + +class TransformerEncoder(nn.Module): + """Checkpoint: transformer.encoder.layers.N.*""" + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + EncoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + + def forward(self, x, pos, text_memory=None, text_mask=None): + for layer in self.layers: + x = layer(x, pos, text_memory, text_mask) + return x + + +class DecoderLayer(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.ca_text = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.catext_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + + def forward(self, x, memory, x_pos, memory_pos, text_memory=None, text_mask=None, cross_attn_bias=None): + q_k = x + x_pos + x = self.norm2(x + self.self_attn(q_k, q_k, x)) + if text_memory is not None: + x = self.catext_norm(x + self.ca_text(x + x_pos, text_memory, text_memory, mask=text_mask)) + x = self.norm1(x + self.cross_attn(x + x_pos, memory + memory_pos, memory, mask=cross_attn_bias)) + x = self.norm3(x + self.linear2(F.relu(self.linear1(x)))) + return x + + +class TransformerDecoder(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, + num_queries=200, device=None, dtype=None, operations=None): + super().__init__() + self.d_model = d_model + self.num_queries = num_queries + + self.layers = nn.ModuleList([ + DecoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.query_embed = operations.Embedding(num_queries, d_model, device=device, dtype=dtype) + self.reference_points = operations.Embedding(num_queries, 4, device=device, dtype=dtype) # Reference points: Embedding(num_queries, 4) — learned anchor boxes + self.ref_point_head = MLP(d_model * 2, d_model, d_model, 2, device=device, dtype=dtype, operations=operations) # ref_point_head input: 512 (4 coords * 128 sine features each) + self.bbox_embed = MLP(d_model, d_model, 4, 3, device=device, dtype=dtype, operations=operations) + + self.boxRPB_embed_x = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations) + self.boxRPB_embed_y = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations) + + self.presence_token = operations.Embedding(1, d_model, device=device, dtype=dtype) + self.presence_token_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations) + self.presence_token_out_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + + @staticmethod + def _inverse_sigmoid(x): + return torch.log(x / (1 - x + 1e-6) + 1e-6) + + def _compute_box_rpb(self, ref_points, H, W): + """Box rotary position bias: (B, Q, 4) cxcywh -> (B, n_heads, Q+1, H*W) bias.""" + boxes_xyxy = box_cxcywh_to_xyxy(ref_points) + B, Q, _ = boxes_xyxy.shape + coords_h = torch.arange(H, device=ref_points.device, dtype=torch.float32) / H + coords_w = torch.arange(W, device=ref_points.device, dtype=torch.float32) / W + deltas_x = coords_w.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 0:3:2] + deltas_y = coords_h.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 1:4:2] + + log2_8 = float(math.log2(8)) + def log_scale(d): + return torch.sign(d * 8) * torch.log2(torch.abs(d * 8) + 1.0) / log2_8 + + rpb_x = self.boxRPB_embed_x(log_scale(deltas_x).to(ref_points.dtype)) + rpb_y = self.boxRPB_embed_y(log_scale(deltas_y).to(ref_points.dtype)) + + bias = (rpb_y.unsqueeze(3) + rpb_x.unsqueeze(2)).flatten(2, 3).permute(0, 3, 1, 2) + pres_bias = torch.zeros(B, bias.shape[1], 1, bias.shape[3], device=bias.device, dtype=bias.dtype) + return torch.cat([pres_bias, bias], dim=2) + + def forward(self, memory, memory_pos, text_memory=None, text_mask=None, H=72, W=72): + B = memory.shape[0] + tgt = cast_to_input(self.query_embed.weight, memory).unsqueeze(0).expand(B, -1, -1) + presence_out = cast_to_input(self.presence_token.weight, memory)[None].expand(B, -1, -1) + ref_points = cast_to_input(self.reference_points.weight, memory).unsqueeze(0).expand(B, -1, -1).sigmoid() + + for layer_idx, layer in enumerate(self.layers): + query_pos = self.ref_point_head(gen_sineembed_for_position(ref_points, self.d_model)) + tgt_with_pres = torch.cat([presence_out, tgt], dim=1) + pos_with_pres = torch.cat([torch.zeros_like(presence_out), query_pos], dim=1) + tgt_with_pres = layer(tgt_with_pres, memory, pos_with_pres, memory_pos, + text_memory, text_mask, self._compute_box_rpb(ref_points, H, W)) + presence_out, tgt = tgt_with_pres[:, :1], tgt_with_pres[:, 1:] + if layer_idx < len(self.layers) - 1: + ref_inv = self._inverse_sigmoid(ref_points) + ref_points = (ref_inv + self.bbox_embed(self.norm(tgt))).sigmoid().detach() + + query_out = self.norm(tgt) + ref_inv = self._inverse_sigmoid(ref_points) + boxes = (ref_inv + self.bbox_embed(query_out)).sigmoid() + presence = self.presence_token_head(self.presence_token_out_norm(presence_out)).squeeze(-1) + return {"decoder_output": query_out, "pred_boxes": boxes, "presence": presence} + + +class Transformer(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, enc_layers=6, dec_layers=6, + num_queries=200, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = TransformerEncoder(d_model, num_heads, dim_ff, enc_layers, device=device, dtype=dtype, operations=operations) + self.decoder = TransformerDecoder(d_model, num_heads, dim_ff, dec_layers, num_queries, device=device, dtype=dtype, operations=operations) + + +class GeometryEncoder(nn.Module): + def __init__(self, d_model=256, num_heads=8, num_layers=3, roi_size=7, device=None, dtype=None, operations=None): + super().__init__() + self.d_model = d_model + self.roi_size = roi_size + self.pos_enc = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True) + self.points_direct_project = operations.Linear(2, d_model, device=device, dtype=dtype) + self.points_pool_project = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.points_pos_enc_project = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.boxes_direct_project = operations.Linear(4, d_model, device=device, dtype=dtype) + self.boxes_pool_project = operations.Conv2d(d_model, d_model, kernel_size=roi_size, device=device, dtype=dtype) + self.boxes_pos_enc_project = operations.Linear(d_model + 2, d_model, device=device, dtype=dtype) + self.label_embed = operations.Embedding(2, d_model, device=device, dtype=dtype) + self.cls_embed = operations.Embedding(1, d_model, device=device, dtype=dtype) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.img_pre_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.encode = nn.ModuleList([ + EncoderLayer(d_model, num_heads, 2048, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.encode_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.final_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + + def _encode_points(self, coords, labels, img_feat_2d): + """Encode point prompts: direct + pool + pos_enc + label. coords: [B, N, 2] normalized.""" + B, N, _ = coords.shape + embed = self.points_direct_project(coords) + # Pool features from backbone at point locations via grid_sample + grid = (coords * 2 - 1).unsqueeze(2) # [B, N, 1, 2] in [-1, 1] + sampled = F.grid_sample(img_feat_2d, grid, align_corners=False) # [B, C, N, 1] + embed = embed + self.points_pool_project(sampled.squeeze(-1).permute(0, 2, 1)) # [B, N, C] + # Positional encoding of coordinates + x, y = coords[:, :, 0], coords[:, :, 1] # [B, N] + pos_x, pos_y = self.pos_enc._encode_xy(x.flatten(), y.flatten()) + enc = torch.cat([pos_x, pos_y], dim=-1).view(B, N, -1) + embed = embed + self.points_pos_enc_project(cast_to_input(enc, embed)) + embed = embed + cast_to_input(self.label_embed(labels.long()), embed) + return embed + + def _encode_boxes(self, boxes, labels, img_feat_2d): + """Encode box prompts: direct + pool + pos_enc + label. boxes: [B, N, 4] normalized cxcywh.""" + B, N, _ = boxes.shape + embed = self.boxes_direct_project(boxes) + # ROI align from backbone at box regions + H, W = img_feat_2d.shape[-2:] + boxes_xyxy = box_cxcywh_to_xyxy(boxes) + scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device) + boxes_scaled = boxes_xyxy * scale + sampled = roi_align(img_feat_2d, boxes_scaled.view(-1, 4).split(N), self.roi_size) + proj = self.boxes_pool_project(sampled).view(B, N, -1) # Conv2d(roi_size) -> [B*N, C, 1, 1] -> [B, N, C] + embed = embed + proj + # Positional encoding of box center + size + cx, cy, w, h = boxes[:, :, 0], boxes[:, :, 1], boxes[:, :, 2], boxes[:, :, 3] + enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten()) + enc = enc.view(B, N, -1) + embed = embed + self.boxes_pos_enc_project(cast_to_input(enc, embed)) + embed = embed + cast_to_input(self.label_embed(labels.long()), embed) + return embed + + def forward(self, points=None, boxes=None, image_features=None): + """Encode geometry prompts. image_features: [B, HW, C] flattened backbone features.""" + # Prepare 2D image features for pooling + img_feat_2d = None + if image_features is not None: + B = image_features.shape[0] + HW, C = image_features.shape[1], image_features.shape[2] + hw = int(math.sqrt(HW)) + img_normed = self.img_pre_norm(image_features) + img_feat_2d = img_normed.permute(0, 2, 1).view(B, C, hw, hw) + + embeddings = [] + if points is not None: + coords, labels = points + embeddings.append(self._encode_points(coords, labels, img_feat_2d)) + if boxes is not None: + B = boxes.shape[0] + box_labels = torch.ones(B, boxes.shape[1], dtype=torch.long, device=boxes.device) + embeddings.append(self._encode_boxes(boxes, box_labels, img_feat_2d)) + if not embeddings: + return None + geo = torch.cat(embeddings, dim=1) + geo = self.norm(geo) + if image_features is not None: + for layer in self.encode: + geo = layer(geo, torch.zeros_like(geo), image_features) + geo = self.encode_norm(geo) + return self.final_proj(geo) + + +class PixelDecoder(nn.Module): + """Top-down FPN pixel decoder with GroupNorm + ReLU + nearest interpolation.""" + def __init__(self, d_model=256, num_stages=3, device=None, dtype=None, operations=None): + super().__init__() + self.conv_layers = nn.ModuleList([operations.Conv2d(d_model, d_model, kernel_size=3, padding=1, device=device, dtype=dtype) for _ in range(num_stages)]) + self.norms = nn.ModuleList([operations.GroupNorm(8, d_model, device=device, dtype=dtype) for _ in range(num_stages)]) + + def forward(self, backbone_features): + prev = backbone_features[-1] + for i, feat in enumerate(backbone_features[:-1][::-1]): + prev = F.relu(self.norms[i](self.conv_layers[i](feat + F.interpolate(prev, size=feat.shape[-2:], mode="nearest")))) + return prev + + +class MaskPredictor(nn.Module): + def __init__(self, d_model=256, device=None, dtype=None, operations=None): + super().__init__() + self.mask_embed = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + + def forward(self, query_embeddings, pixel_features): + mask_embed = self.mask_embed(query_embeddings) + return torch.einsum("bqc,bchw->bqhw", mask_embed, pixel_features) + + +class SegmentationHead(nn.Module): + def __init__(self, d_model=256, num_heads=8, device=None, dtype=None, operations=None): + super().__init__() + self.d_model = d_model + self.pixel_decoder = PixelDecoder(d_model, 3, device=device, dtype=dtype, operations=operations) + self.mask_predictor = MaskPredictor(d_model, device=device, dtype=dtype, operations=operations) + self.cross_attend_prompt = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.instance_seg_head = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype) + self.semantic_seg_head = operations.Conv2d(d_model, 1, kernel_size=1, device=device, dtype=dtype) + + def forward(self, query_embeddings, backbone_features, encoder_hidden_states=None, prompt=None, prompt_mask=None): + if encoder_hidden_states is not None and prompt is not None: + enc_normed = self.cross_attn_norm(encoder_hidden_states) + enc_cross = self.cross_attend_prompt(enc_normed, prompt, prompt, mask=prompt_mask) + encoder_hidden_states = enc_cross + encoder_hidden_states + + if encoder_hidden_states is not None: + B, H, W = encoder_hidden_states.shape[0], backbone_features[-1].shape[-2], backbone_features[-1].shape[-1] + encoder_visual = encoder_hidden_states[:, :H * W].permute(0, 2, 1).view(B, self.d_model, H, W) + backbone_features = list(backbone_features) + backbone_features[-1] = encoder_visual + + pixel_features = self.pixel_decoder(backbone_features) + instance_features = self.instance_seg_head(pixel_features) + masks = self.mask_predictor(query_embeddings, instance_features) + return masks + + +class DotProductScoring(nn.Module): + def __init__(self, d_model=256, device=None, dtype=None, operations=None): + super().__init__() + self.hs_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.prompt_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.prompt_mlp = MLPWithNorm(d_model, 2048, d_model, 2, device=device, dtype=dtype, operations=operations) + self.scale = 1.0 / (d_model ** 0.5) + + def forward(self, query_embeddings, prompt_embeddings, prompt_mask=None): + prompt = self.prompt_mlp(prompt_embeddings) + if prompt_mask is not None: + weight = prompt_mask.unsqueeze(-1).to(dtype=prompt.dtype) + pooled = (prompt * weight).sum(dim=1) / weight.sum(dim=1).clamp(min=1) + else: + pooled = prompt.mean(dim=1) + hs = self.hs_proj(query_embeddings) + pp = self.prompt_proj(pooled).unsqueeze(-1).to(hs.dtype) + scores = torch.matmul(hs, pp) + return (scores * self.scale).clamp(-12.0, 12.0).squeeze(-1) + + +class SAM3Detector(nn.Module): + def __init__(self, d_model=256, embed_dim=1024, num_queries=200, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + image_model = kwargs.pop("image_model", "SAM3") + for k in ("num_heads", "num_head_channels"): + kwargs.pop(k, None) + multiplex = image_model == "SAM31" + # SAM3: 4 FPN levels, drop last (scalp=1); SAM3.1: 3 levels, use all (scalp=0) + self.scalp = 0 if multiplex else 1 + self.backbone = nn.ModuleDict({ + "vision_backbone": SAM3VisionBackbone(embed_dim=embed_dim, d_model=d_model, multiplex=multiplex, device=device, dtype=dtype, operations=operations, **kwargs), + "language_backbone": nn.ModuleDict({"resizer": operations.Linear(embed_dim, d_model, device=device, dtype=dtype)}), + }) + self.transformer = Transformer(d_model=d_model, num_queries=num_queries, device=device, dtype=dtype, operations=operations) + self.segmentation_head = SegmentationHead(d_model=d_model, device=device, dtype=dtype, operations=operations) + self.geometry_encoder = GeometryEncoder(d_model=d_model, device=device, dtype=dtype, operations=operations) + self.dot_prod_scoring = DotProductScoring(d_model=d_model, device=device, dtype=dtype, operations=operations) + + def _get_backbone_features(self, images): + """Run backbone and return (detector_features, detector_positions, tracker_features, tracker_positions).""" + bb = self.backbone["vision_backbone"] + if bb.multiplex: + all_f, all_p, tf, tp = bb(images, tracker_mode="propagation") + else: + all_f, all_p, tf, tp = bb(images, need_tracker=True) + return all_f, all_p, tf, tp + + @staticmethod + def _run_geo_layer(layer, x, memory, memory_pos): + x = x + layer.self_attn(layer.norm1(x)) + x = x + layer.cross_attn_image(layer.norm2(x), memory + memory_pos, memory) + x = x + layer.linear2(F.relu(layer.linear1(layer.norm3(x)))) + return x + + def _detect(self, features, positions, text_embeddings=None, text_mask=None, + points=None, boxes=None): + """Shared detection: geometry encoding, transformer, scoring, segmentation.""" + B = features[0].shape[0] + # Scalp for encoder (use top-level feature), but keep all levels for segmentation head + seg_features = features + if self.scalp > 0: + features = features[:-self.scalp] + positions = positions[:-self.scalp] + enc_feat, enc_pos = features[-1], positions[-1] + _, _, H, W = enc_feat.shape + img_flat = enc_feat.flatten(2).permute(0, 2, 1) + pos_flat = enc_pos.flatten(2).permute(0, 2, 1) + + has_prompts = text_embeddings is not None or points is not None or boxes is not None + if has_prompts: + geo_enc = self.geometry_encoder + geo_prompts = geo_enc(points=points, boxes=boxes, image_features=img_flat) + geo_cls = geo_enc.norm(geo_enc.final_proj(cast_to_input(geo_enc.cls_embed.weight, img_flat).view(1, 1, -1).expand(B, -1, -1))) + for layer in geo_enc.encode: + geo_cls = self._run_geo_layer(layer, geo_cls, img_flat, pos_flat) + geo_cls = geo_enc.encode_norm(geo_cls) + if text_embeddings is not None and text_embeddings.shape[0] != B: + text_embeddings = text_embeddings.expand(B, -1, -1) + if text_mask is not None and text_mask.shape[0] != B: + text_mask = text_mask.expand(B, -1) + parts = [t for t in [text_embeddings, geo_prompts, geo_cls] if t is not None] + text_embeddings = torch.cat(parts, dim=1) + n_new = text_embeddings.shape[1] - (text_mask.shape[1] if text_mask is not None else 0) + if text_mask is not None: + text_mask = torch.cat([text_mask, torch.ones(B, n_new, dtype=torch.bool, device=text_mask.device)], dim=1) + else: + text_mask = torch.ones(B, text_embeddings.shape[1], dtype=torch.bool, device=text_embeddings.device) + + memory = self.transformer.encoder(img_flat, pos_flat, text_embeddings, text_mask) + dec_out = self.transformer.decoder(memory, pos_flat, text_embeddings, text_mask, H, W) + query_out, pred_boxes = dec_out["decoder_output"], dec_out["pred_boxes"] + + if text_embeddings is not None: + scores = self.dot_prod_scoring(query_out, text_embeddings, text_mask) + else: + scores = torch.zeros(B, query_out.shape[1], device=query_out.device) + + masks = self.segmentation_head(query_out, seg_features, encoder_hidden_states=memory, prompt=text_embeddings, prompt_mask=text_mask) + return box_cxcywh_to_xyxy(pred_boxes), scores, masks, dec_out + + def forward(self, images, text_embeddings=None, text_mask=None, points=None, boxes=None, threshold=0.3, orig_size=None): + features, positions, _, _ = self._get_backbone_features(images) + + if text_embeddings is not None: + text_embeddings = self.backbone["language_backbone"]["resizer"](text_embeddings) + if text_mask is not None: + text_mask = text_mask.bool() + + boxes_xyxy, scores, masks, dec_out = self._detect( + features, positions, text_embeddings, text_mask, points, boxes) + + if orig_size is not None: + oh, ow = orig_size + boxes_xyxy = boxes_xyxy * torch.tensor([ow, oh, ow, oh], device=boxes_xyxy.device, dtype=boxes_xyxy.dtype) + masks = F.interpolate(masks, size=orig_size, mode="bilinear", align_corners=False) + + return { + "boxes": boxes_xyxy, + "scores": scores, + "masks": masks, + "presence": dec_out.get("presence"), + } + + def forward_from_trunk(self, trunk_out, text_embeddings, text_mask): + """Run detection using a pre-computed ViTDet trunk output. + + text_embeddings must already be resized through language_backbone.resizer. + Returns dict with boxes (normalized xyxy), scores, masks at detector resolution. + """ + bb = self.backbone["vision_backbone"] + features = [conv(trunk_out) for conv in bb.convs] + positions = [cast_to_input(bb.position_encoding(f), f) for f in features] + + if text_mask is not None: + text_mask = text_mask.bool() + + boxes_xyxy, scores, masks, _ = self._detect(features, positions, text_embeddings, text_mask) + return {"boxes": boxes_xyxy, "scores": scores, "masks": masks} + + +class SAM3Model(nn.Module): + def __init__(self, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.dtype = dtype + image_model = kwargs.get("image_model", "SAM3") + tracker_cls = TRACKER_CLASSES[image_model] + self.detector = SAM3Detector(device=device, dtype=dtype, operations=operations, **kwargs) + self.tracker = tracker_cls(device=device, dtype=dtype, operations=operations, **kwargs) + + def forward(self, images, **kwargs): + return self.detector(images, **kwargs) + + def forward_segment(self, images, point_inputs=None, box_inputs=None, mask_inputs=None): + """Interactive segmentation using SAM decoder with point/box/mask prompts. + + Args: + images: [B, 3, 1008, 1008] preprocessed images + point_inputs: {"point_coords": [B, N, 2], "point_labels": [B, N]} in 1008x1008 pixel space + box_inputs: [B, 2, 2] box corners (top-left, bottom-right) in 1008x1008 pixel space + mask_inputs: [B, 1, H, W] coarse mask logits to refine + Returns: + [B, 1, image_size, image_size] high-res mask logits + """ + bb = self.detector.backbone["vision_backbone"] + if bb.multiplex: + _, _, tracker_features, tracker_positions = bb(images, tracker_mode="interactive") + else: + _, _, tracker_features, tracker_positions = bb(images, need_tracker=True) + if self.detector.scalp > 0: + tracker_features = tracker_features[:-self.detector.scalp] + tracker_positions = tracker_positions[:-self.detector.scalp] + + high_res = list(tracker_features[:-1]) + backbone_feat = tracker_features[-1] + B, C, H, W = backbone_feat.shape + # Add no-memory embedding (init frame path) + no_mem = getattr(self.tracker, 'interactivity_no_mem_embed', None) + if no_mem is None: + no_mem = getattr(self.tracker, 'no_mem_embed', None) + if no_mem is not None: + feat_flat = backbone_feat.flatten(2).permute(0, 2, 1) + feat_flat = feat_flat + cast_to_input(no_mem, feat_flat) + backbone_feat = feat_flat.view(B, H, W, C).permute(0, 3, 1, 2) + + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + _, high_res_masks, _, _ = self.tracker._forward_sam_heads( + backbone_features=backbone_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + box_inputs=box_inputs, + high_res_features=high_res, + multimask_output=(0 < num_pts <= 1), + ) + return high_res_masks + + def forward_video(self, images, initial_masks, pbar=None, text_prompts=None, + new_det_thresh=0.5, max_objects=0, detect_interval=1): + """Track video with optional per-frame text-prompted detection.""" + bb = self.detector.backbone["vision_backbone"] + + def backbone_fn(frame, frame_idx=None): + trunk_out = bb.trunk(frame) + if bb.multiplex: + _, _, tf, tp = bb(frame, tracker_mode="propagation", cached_trunk=trunk_out, tracker_only=True) + else: + _, _, tf, tp = bb(frame, need_tracker=True, cached_trunk=trunk_out, tracker_only=True) + return tf, tp, trunk_out + + detect_fn = None + if text_prompts: + resizer = self.detector.backbone["language_backbone"]["resizer"] + resized = [(resizer(emb), m.bool() if m is not None else None) for emb, m in text_prompts] + def detect_fn(trunk_out): + all_scores, all_masks = [], [] + for emb, mask in resized: + det = self.detector.forward_from_trunk(trunk_out, emb, mask) + all_scores.append(det["scores"]) + all_masks.append(det["masks"]) + return {"scores": torch.cat(all_scores, dim=1), "masks": torch.cat(all_masks, dim=1)} + + if hasattr(self.tracker, 'track_video_with_detection'): + return self.tracker.track_video_with_detection( + backbone_fn, images, initial_masks, detect_fn, + new_det_thresh=new_det_thresh, max_objects=max_objects, + detect_interval=detect_interval, backbone_obj=bb, pbar=pbar) + # SAM3 (non-multiplex) — no detection support, requires initial masks + if initial_masks is None: + raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking") + return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb) diff --git a/comfy/ldm/sam3/sam.py b/comfy/ldm/sam3/sam.py new file mode 100644 index 000000000000..75cb457cff53 --- /dev/null +++ b/comfy/ldm/sam3/sam.py @@ -0,0 +1,425 @@ +# SAM3 shared components: primitives, ViTDet backbone, FPN neck, position encodings. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope +from comfy.ldm.flux.layers import EmbedND +from comfy.ops import cast_to_input + + +class MLP(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, sigmoid_output=False, device=None, dtype=None, operations=None): + super().__init__() + dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim] + self.layers = nn.ModuleList([operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers)]) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < len(self.layers) - 1 else layer(x) + return torch.sigmoid(x) if self.sigmoid_output else x + + +class SAMAttention(nn.Module): + def __init__(self, embedding_dim, num_heads, downsample_rate=1, kv_in_dim=None, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + internal_dim = embedding_dim // downsample_rate + kv_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.q_proj = operations.Linear(embedding_dim, internal_dim, device=device, dtype=dtype) + self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.out_proj = operations.Linear(internal_dim, embedding_dim, device=device, dtype=dtype) + + def forward(self, q, k, v): + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)) + + +class TwoWayAttentionBlock(nn.Module): + def __init__(self, embedding_dim, num_heads, mlp_dim=2048, attention_downsample_rate=2, skip_first_layer_pe=False, device=None, dtype=None, operations=None): + super().__init__() + self.skip_first_layer_pe = skip_first_layer_pe + self.self_attn = SAMAttention(embedding_dim, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations) + self.cross_attn_image_to_token = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations) + self.mlp = nn.Sequential(operations.Linear(embedding_dim, mlp_dim, device=device, dtype=dtype), nn.ReLU(), operations.Linear(mlp_dim, embedding_dim, device=device, dtype=dtype)) + self.norm1 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + self.norm4 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + + def forward(self, queries, keys, query_pe, key_pe): + if self.skip_first_layer_pe: + queries = self.norm1(self.self_attn(queries, queries, queries)) + else: + q = queries + query_pe + queries = self.norm1(queries + self.self_attn(q, q, queries)) + q, k = queries + query_pe, keys + key_pe + queries = self.norm2(queries + self.cross_attn_token_to_image(q, k, keys)) + queries = self.norm3(queries + self.mlp(queries)) + q, k = queries + query_pe, keys + key_pe + keys = self.norm4(keys + self.cross_attn_image_to_token(k, q, queries)) + return queries, keys + + +class TwoWayTransformer(nn.Module): + def __init__(self, depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048, attention_downsample_rate=2, device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + TwoWayAttentionBlock(embedding_dim, num_heads, mlp_dim, attention_downsample_rate, + skip_first_layer_pe=(i == 0), device=device, dtype=dtype, operations=operations) + for i in range(depth) + ]) + self.final_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations) + self.norm_final = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + + def forward(self, image_embedding, image_pe, point_embedding): + queries, keys = point_embedding, image_embedding + for layer in self.layers: + queries, keys = layer(queries, keys, point_embedding, image_pe) + q, k = queries + point_embedding, keys + image_pe + queries = self.norm_final(queries + self.final_attn_token_to_image(q, k, keys)) + return queries, keys + + +class PositionEmbeddingRandom(nn.Module): + """Fourier feature positional encoding with random gaussian projection.""" + def __init__(self, num_pos_feats=64, scale=None): + super().__init__() + self.register_buffer("positional_encoding_gaussian_matrix", (scale or 1.0) * torch.randn(2, num_pos_feats)) + + def _encode(self, normalized_coords): + """Map normalized [0,1] coordinates to fourier features via random projection. Computes in fp32.""" + orig_dtype = normalized_coords.dtype + proj_matrix = self.positional_encoding_gaussian_matrix.to(device=normalized_coords.device, dtype=torch.float32) + projected = 2 * math.pi * (2 * normalized_coords.float() - 1) @ proj_matrix + return torch.cat([projected.sin(), projected.cos()], dim=-1).to(orig_dtype) + + def forward(self, size, device=None): + h, w = size + dev = device if device is not None else self.positional_encoding_gaussian_matrix.device + ones = torch.ones((h, w), device=dev, dtype=torch.float32) + norm_xy = torch.stack([(ones.cumsum(1) - 0.5) / w, (ones.cumsum(0) - 0.5) / h], dim=-1) + return self._encode(norm_xy).permute(2, 0, 1).unsqueeze(0) + + def forward_with_coords(self, pixel_coords, image_size): + norm = pixel_coords.clone() + norm[:, :, 0] /= image_size[1] + norm[:, :, 1] /= image_size[0] + return self._encode(norm) + + +# ViTDet backbone + FPN neck + +def window_partition(x: torch.Tensor, window_size: int): + B, H, W, C = x.shape + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw, hw): + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def rope_2d(end_x: int, end_y: int, dim: int, theta: float = 10000.0, scale_pos: float = 1.0): + """Generate 2D axial RoPE using flux EmbedND. Returns [1, 1, HW, dim//2, 2, 2].""" + t = torch.arange(end_x * end_y, dtype=torch.float32) + ids = torch.stack([(t % end_x) * scale_pos, + torch.div(t, end_x, rounding_mode="floor") * scale_pos], dim=-1) + return EmbedND(dim=dim, theta=theta, axes_dim=[dim // 2, dim // 2])(ids.unsqueeze(0)) + + +class _ViTMLP(nn.Module): + def __init__(self, dim, mlp_ratio=4.0, device=None, dtype=None, operations=None): + super().__init__() + hidden = int(dim * mlp_ratio) + self.fc1 = operations.Linear(dim, hidden, device=device, dtype=dtype) + self.act = nn.GELU() + self.fc2 = operations.Linear(hidden, dim, device=device, dtype=dtype) + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +class Attention(nn.Module): + """ViTDet multi-head attention with fused QKV projection.""" + + def __init__(self, dim, num_heads=8, qkv_bias=True, use_rope=False, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.use_rope = use_rope + self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + + def forward(self, x, freqs_cis=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0) + if self.use_rope and freqs_cis is not None: + q, k = apply_rope(q, k, freqs_cis) + return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False)) + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, window_size=0, use_rope=False, device=None, dtype=None, operations=None): + super().__init__() + self.window_size = window_size + self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.attn = Attention(dim, num_heads, qkv_bias, use_rope, device=device, dtype=dtype, operations=operations) + self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.mlp = _ViTMLP(dim, mlp_ratio, device=device, dtype=dtype, operations=operations) + + def forward(self, x, freqs_cis=None): + shortcut = x + x = self.norm1(x) + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + x = x.view(x.shape[0], self.window_size * self.window_size, -1) + x = self.attn(x, freqs_cis=freqs_cis) + x = x.view(-1, self.window_size, self.window_size, x.shape[-1]) + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + else: + B, H, W, C = x.shape + x = x.view(B, H * W, C) + x = self.attn(x, freqs_cis=freqs_cis) + x = x.view(B, H, W, C) + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + return x + + +class PatchEmbed(nn.Module): + def __init__(self, patch_size=14, in_chans=3, embed_dim=1024, device=None, dtype=None, operations=None): + super().__init__() + self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False, device=device, dtype=dtype) + + def forward(self, x): + return self.proj(x) + + +class ViTDet(nn.Module): + def __init__(self, img_size=1008, patch_size=14, embed_dim=1024, depth=32, num_heads=16, mlp_ratio=4.625, qkv_bias=True, window_size=24, + global_att_blocks=(7, 15, 23, 31), use_rope=True, pretrain_img_size=336, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.embed_dim = embed_dim + self.num_heads = num_heads + self.global_att_blocks = set(global_att_blocks) + + self.patch_embed = PatchEmbed(patch_size, 3, embed_dim, device=device, dtype=dtype, operations=operations) + + num_patches = (pretrain_img_size // patch_size) ** 2 + 1 # +1 for cls token + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, device=device, dtype=dtype)) + + self.ln_pre = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + + grid_size = img_size // patch_size + pretrain_grid = pretrain_img_size // patch_size + + self.blocks = nn.ModuleList() + for i in range(depth): + is_global = i in self.global_att_blocks + self.blocks.append(Block( + embed_dim, num_heads, mlp_ratio, qkv_bias, + window_size=0 if is_global else window_size, + use_rope=use_rope, + device=device, dtype=dtype, operations=operations, + )) + + if use_rope: + rope_scale = pretrain_grid / grid_size + self.register_buffer("freqs_cis", rope_2d(grid_size, grid_size, embed_dim // num_heads, scale_pos=rope_scale), persistent=False) + self.register_buffer("freqs_cis_window", rope_2d(window_size, window_size, embed_dim // num_heads), persistent=False) + else: + self.freqs_cis = None + self.freqs_cis_window = None + + def _get_pos_embed(self, num_tokens): + pos = self.pos_embed + if pos.shape[1] == num_tokens: + return pos + cls_pos = pos[:, :1] + spatial_pos = pos[:, 1:] + old_size = int(math.sqrt(spatial_pos.shape[1])) + new_size = int(math.sqrt(num_tokens - 1)) if num_tokens > 1 else old_size + spatial_2d = spatial_pos.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2) + tiles_h = new_size // old_size + 1 + tiles_w = new_size // old_size + 1 + tiled = spatial_2d.tile([1, 1, tiles_h, tiles_w])[:, :, :new_size, :new_size] + tiled = tiled.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1) + return torch.cat([cls_pos, tiled], dim=1) + + def forward(self, x): + x = self.patch_embed(x) + B, C, Hp, Wp = x.shape + x = x.permute(0, 2, 3, 1).reshape(B, Hp * Wp, C) + + pos = cast_to_input(self._get_pos_embed(Hp * Wp + 1), x) + x = x + pos[:, 1:Hp * Wp + 1] + + x = x.view(B, Hp, Wp, C) + x = self.ln_pre(x) + + freqs_cis_global = self.freqs_cis + freqs_cis_win = self.freqs_cis_window + if freqs_cis_global is not None: + freqs_cis_global = cast_to_input(freqs_cis_global, x) + if freqs_cis_win is not None: + freqs_cis_win = cast_to_input(freqs_cis_win, x) + + for block in self.blocks: + fc = freqs_cis_win if block.window_size > 0 else freqs_cis_global + x = block(x, freqs_cis=fc) + + return x.permute(0, 3, 1, 2) + + +class FPNScaleConv(nn.Module): + def __init__(self, in_dim, out_dim, scale, device=None, dtype=None, operations=None): + super().__init__() + if scale == 4.0: + self.dconv_2x2_0 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype) + self.dconv_2x2_1 = operations.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2, device=device, dtype=dtype) + proj_in = in_dim // 4 + elif scale == 2.0: + self.dconv_2x2 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype) + proj_in = in_dim // 2 + elif scale == 1.0: + proj_in = in_dim + elif scale == 0.5: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + proj_in = in_dim + self.scale = scale + self.conv_1x1 = operations.Conv2d(proj_in, out_dim, kernel_size=1, device=device, dtype=dtype) + self.conv_3x3 = operations.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, device=device, dtype=dtype) + + def forward(self, x): + if self.scale == 4.0: + x = F.gelu(self.dconv_2x2_0(x)) + x = self.dconv_2x2_1(x) + elif self.scale == 2.0: + x = self.dconv_2x2(x) + elif self.scale == 0.5: + x = self.pool(x) + x = self.conv_1x1(x) + x = self.conv_3x3(x) + return x + + +class PositionEmbeddingSine(nn.Module): + """2D sinusoidal position encoding (DETR-style) with result caching.""" + def __init__(self, num_pos_feats=256, temperature=10000.0, normalize=True, scale=None): + super().__init__() + assert num_pos_feats % 2 == 0 + self.half_dim = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + self.scale = scale if scale is not None else 2 * math.pi + self._cache = {} + + def _sincos(self, vals): + """Encode 1D values to interleaved sin/cos features.""" + freqs = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=vals.device) // 2) / self.half_dim) + raw = vals[..., None] * self.scale / freqs + return torch.stack((raw[..., 0::2].sin(), raw[..., 1::2].cos()), dim=-1).flatten(-2) + + def _encode_xy(self, x, y): + """Encode normalized x, y coordinates to sinusoidal features. Returns (pos_x, pos_y) each [N, half_dim].""" + dim_t = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=x.device) // 2) / self.half_dim) + pos_x = x[:, None] * self.scale / dim_t + pos_y = y[:, None] * self.scale / dim_t + pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) + pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) + return pos_x, pos_y + + def encode_boxes(self, cx, cy, w, h): + """Encode box center + size to [N, d_model+2] features.""" + pos_x, pos_y = self._encode_xy(cx, cy) + return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + + def forward(self, x): + B, C, H, W = x.shape + key = (H, W, x.device) + if key not in self._cache: + gy = torch.arange(H, dtype=torch.float32, device=x.device) + gx = torch.arange(W, dtype=torch.float32, device=x.device) + if self.normalize: + gy, gx = gy / (H - 1 + 1e-6), gx / (W - 1 + 1e-6) + yy, xx = torch.meshgrid(gy, gx, indexing="ij") + self._cache[key] = torch.cat((self._sincos(yy), self._sincos(xx)), dim=-1).permute(2, 0, 1).unsqueeze(0) + return self._cache[key].expand(B, -1, -1, -1) + + +class SAM3VisionBackbone(nn.Module): + def __init__(self, embed_dim=1024, d_model=256, multiplex=False, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.trunk = ViTDet(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations, **kwargs) + self.position_encoding = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True) + self.multiplex = multiplex + + fpn_args = dict(device=device, dtype=dtype, operations=operations) + if multiplex: + scales = [4.0, 2.0, 1.0] + self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + self.propagation_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + self.interactive_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + else: + scales = [4.0, 2.0, 1.0, 0.5] + self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + self.sam2_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + + def forward(self, images, need_tracker=False, tracker_mode=None, cached_trunk=None, tracker_only=False): + backbone_out = cached_trunk if cached_trunk is not None else self.trunk(images) + + if tracker_only: + # Skip detector FPN when only tracker features are needed (video tracking) + if self.multiplex: + tracker_convs = self.propagation_convs if tracker_mode == "propagation" else self.interactive_convs + else: + tracker_convs = self.sam2_convs + tracker_features = [conv(backbone_out) for conv in tracker_convs] + tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features] + return None, None, tracker_features, tracker_positions + + features = [conv(backbone_out) for conv in self.convs] + positions = [cast_to_input(self.position_encoding(f), f) for f in features] + + if self.multiplex: + if tracker_mode == "propagation": + tracker_convs = self.propagation_convs + elif tracker_mode == "interactive": + tracker_convs = self.interactive_convs + else: + return features, positions, None, None + elif need_tracker: + tracker_convs = self.sam2_convs + else: + return features, positions, None, None + + tracker_features = [conv(backbone_out) for conv in tracker_convs] + tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features] + return features, positions, tracker_features, tracker_positions diff --git a/comfy/ldm/sam3/tracker.py b/comfy/ldm/sam3/tracker.py new file mode 100644 index 000000000000..8f7481003cf0 --- /dev/null +++ b/comfy/ldm/sam3/tracker.py @@ -0,0 +1,1785 @@ +# SAM3 video tracker: memory encoder, memory attention, SAM mask decoder/prompt encoder. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +try: + import cv2 + _HAS_CV2 = True +except ImportError: + from scipy import ndimage + _HAS_CV2 = False + +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.sam3.sam import rope_2d, PositionEmbeddingSine +from comfy.ops import cast_to_input +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.cascade.common import LayerNorm2d_op +from comfy.ldm.sam3.sam import MLP, PositionEmbeddingRandom +from comfy.ldm.sam3.sam import TwoWayTransformer as SAMTwoWayTransformer + +NO_OBJ_SCORE = -1024.0 + + +def to_spatial(x, H, W): + """Reshape (B, H*W, C) → (B, C, H, W).""" + return x.view(x.shape[0], H, W, -1).permute(0, 3, 1, 2) + +class MultiplexState: + """Tracks object-to-slot assignments for multiplex tracking. Provides mux/demux operations.""" + + def __init__(self, num_objects, multiplex_count, device, dtype): + self.multiplex_count = multiplex_count + self.device = device + self.dtype = dtype + self._build(num_objects) + + def mux(self, x): + """[N_obj, ...] -> [num_buckets, multiplex_count, ...]""" + out_shape = (self.num_buckets, self.multiplex_count) + x.shape[1:] + return (self.mux_matrix.to(device=x.device, dtype=x.dtype) @ x.reshape(self.total_valid_entries, -1)).view(out_shape) + + def demux(self, x): + """[num_buckets, multiplex_count, ...] -> [N_obj, ...]""" + out_shape = (self.total_valid_entries,) + x.shape[2:] + flat = x.reshape(self.num_buckets * self.multiplex_count, -1) + return (self.demux_matrix.to(device=x.device, dtype=x.dtype) @ flat).view(out_shape) + + def get_valid_object_mask(self): + """[num_buckets, multiplex_count] bool tensor, True for valid slots.""" + return (self.mux_matrix.sum(dim=1) > 0).reshape(self.num_buckets, self.multiplex_count) + + def _build(self, num_objects): + M = self.multiplex_count + self.num_buckets = (num_objects + M - 1) // M + self.total_valid_entries = num_objects + total_slots = self.num_buckets * M + self.mux_matrix = torch.zeros(total_slots, num_objects, device=self.device, dtype=self.dtype) + self.demux_matrix = torch.zeros(num_objects, total_slots, device=self.device, dtype=self.dtype) + oids = torch.arange(num_objects, device=self.device) + slots = (oids // M) * M + (oids % M) + self.mux_matrix[slots, oids] = 1.0 + self.demux_matrix[oids, slots] = 1.0 + + def add_objects(self, n_new): + """Grow multiplex state for n_new additional objects.""" + self._build(self.total_valid_entries + n_new) + +def _compute_mask_overlap(masks_a, masks_b): + """Max of IoU and IoM (intersection over minimum area). More robust to size differences.""" + a_flat = (masks_a > 0).float().flatten(1) + b_flat = (masks_b > 0).float().flatten(1) + intersection = a_flat @ b_flat.T + area_a = a_flat.sum(1, keepdim=True) + area_b = b_flat.sum(1, keepdim=True).T + iou = intersection / (area_a + area_b - intersection).clamp(min=1) + iom = intersection / torch.min(area_a.expand_as(iou), area_b.expand_as(iou)).clamp(min=1) + return torch.max(iou, iom) + + +def _nms_masks(masks, scores, thresh=0.5): + """Mask-based NMS using IoU+IoM overlap. Returns (filtered_masks, filtered_scores).""" + order = scores.argsort(descending=True) + masks, scores = masks[order], scores[order] + keep = [] + for i in range(masks.shape[0]): + if keep: + if _compute_mask_overlap(masks[i:i+1], masks[torch.tensor(keep, device=masks.device)]).max() >= thresh: + continue + keep.append(i) + return masks[keep], scores[keep] + + +def _get_connected_components(mask_bin): + """Get connected component labels and areas. mask_bin: [B, 1, H, W] uint8.""" + labels_list, areas_list = [], [] + for i in range(mask_bin.shape[0]): + m = mask_bin[i, 0].cpu().numpy() + if _HAS_CV2: + _, labeled, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8) + areas = stats[labeled, cv2.CC_STAT_AREA].astype('int32') + else: + labeled, num_features = ndimage.label(m) + areas = np.zeros_like(m, dtype=np.int32) + for c in range(1, num_features + 1): + component = labeled == c + areas[component] = component.sum() + labels_list.append(torch.from_numpy(labeled).to(mask_bin.device)) + areas_list.append(torch.from_numpy(areas).to(device=mask_bin.device, dtype=torch.int32)) + return torch.stack(labels_list).unsqueeze(1), torch.stack(areas_list).unsqueeze(1) + + +def fill_holes_in_mask_scores(mask, max_area=0): + """Remove small foreground sprinkles and fill small background holes using connected components.""" + if max_area <= 0: + return mask + + # Fill holes: small connected components in background → foreground + mask_bg = (mask <= 0).to(torch.uint8) + _, areas_bg = _get_connected_components(mask_bg) + small_bg = mask_bg.bool() & (areas_bg <= max_area) + mask = torch.where(small_bg, 0.1, mask) + + # Remove sprinkles: small connected components in foreground → background + # Only remove if area < min(max_area, half of total foreground area) + mask_fg = (mask > 0).to(torch.uint8) + fg_area_thresh = mask_fg.sum(dim=(2, 3), keepdim=True, dtype=torch.int32) + fg_area_thresh.floor_divide_(2).clamp_(max=max_area) + _, areas_fg = _get_connected_components(mask_fg) + small_fg = mask_fg.bool() & (areas_fg <= fg_area_thresh) + mask = torch.where(small_fg, -0.1, mask) + + return mask + + +def apply_rope_memory(q, k, freqs, num_heads, num_k_exclude_rope=0): + """Apply 2D axial RoPE to memory attention using flux rope format. + + Args: + q: [B, Nq, C] projected queries (current frame features) + k: [B, Nk, C] projected keys (memory tokens) + freqs: [1, Nq, dim//2, 2, 2] flux-format rotation matrices for one frame + num_heads: number of attention heads + num_k_exclude_rope: number of trailing k tokens to skip RoPE (object pointers) + """ + B, Nq, C = q.shape + head_dim = C // num_heads + + # freqs shape: [1, 1, Nq, dim//2, 2, 2] (heads broadcast dim already included) + q_h = q.view(B, Nq, num_heads, head_dim).transpose(1, 2) + q_h = apply_rope1(q_h, freqs) + q = q_h.transpose(1, 2).reshape(B, Nq, C) + + # Apply RoPE to k (excluding last num_k_exclude_rope tokens) + Nk = k.shape[1] + num_k_rope = Nk - num_k_exclude_rope + if num_k_rope > 0: + # Repeat freqs for multiple frames of spatial memory + Nf = freqs.shape[2] # spatial positions in one frame + if num_k_rope > Nf: + r = (num_k_rope + Nf - 1) // Nf + pe_k = freqs.repeat(1, 1, r, 1, 1, 1)[:, :, :num_k_rope] + else: + pe_k = freqs[:, :, :num_k_rope] + + k_h = k[:, :num_k_rope].view(B, num_k_rope, num_heads, head_dim).transpose(1, 2) + k_h = apply_rope1(k_h, pe_k) + k = k.clone() + k[:, :num_k_rope] = k_h.transpose(1, 2).reshape(B, num_k_rope, C) + + return q, k + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """1D sinusoidal positional encoding for temporal positions.""" + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + pos_embed = pos_inds.unsqueeze(-1) / dim_t + return torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + + +def _pad_to_buckets(tensor, target_buckets): + """Pad a [num_buckets, ...] tensor to target_buckets along dim 0 if needed.""" + if tensor.shape[0] >= target_buckets: + return tensor + pad_shape = (target_buckets - tensor.shape[0],) + tensor.shape[1:] + return torch.cat([tensor, torch.zeros(pad_shape, device=tensor.device, dtype=tensor.dtype)], dim=0) + + +def pack_masks(masks): + """Pack binary masks [*, H, W] to bit-packed [*, H, W//8] uint8. W must be divisible by 8.""" + binary = masks > 0 + shifts = torch.arange(8, device=masks.device) + return (binary.view(*masks.shape[:-1], -1, 8) * (1 << shifts)).sum(-1).byte() + + +def unpack_masks(packed): + """Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8].""" + shifts = torch.arange(8, device=packed.device) + return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool() + + +def _compute_backbone(backbone_fn, frame, frame_idx=None): + """Compute backbone features for a single frame. Returns (vision_feats, vision_pos, feat_sizes, features, trunk_out).""" + features, positions, trunk_out = backbone_fn(frame, frame_idx=frame_idx) + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in features] + vision_feats = [x.flatten(2).permute(0, 2, 1) for x in features] + vision_pos = [x.flatten(2).permute(0, 2, 1) for x in positions] + return vision_feats, vision_pos, feat_sizes, features, trunk_out + + +def collect_memory_tokens(output_dict, frame_idx, num_maskmem, maskmem_tpos_enc, device, + collect_image_feats=False, tpos_v2=False, num_buckets=None): + """Collect spatial memory, position encodings, and optionally image features from past frames.""" + to_cat_memory, to_cat_memory_pos = [], [] + to_cat_image_feat, to_cat_image_pos = [], [] + + def _append(out, tpos_idx): + feats = out["maskmem_features"].to(device) + if num_buckets is not None: + feats = _pad_to_buckets(feats, num_buckets) + to_cat_memory.append(feats.flatten(2).permute(0, 2, 1)) + enc = out["maskmem_pos_enc"][-1].to(device).flatten(2).permute(0, 2, 1) + if num_buckets is not None: + enc = _pad_to_buckets(enc, num_buckets) + tpos = cast_to_input(maskmem_tpos_enc[tpos_idx], enc) + to_cat_memory_pos.append(enc + tpos) + if collect_image_feats and "image_features" in out: + to_cat_image_feat.append(out["image_features"].to(device)) + to_cat_image_pos.append(out["image_pos_enc"].to(device) + tpos) + + cond_outputs = output_dict["cond_frame_outputs"] + for t, out in cond_outputs.items(): + if tpos_v2: + t_pos = frame_idx - t + tpos_idx = num_maskmem - t_pos - 1 if 0 < t_pos < num_maskmem else num_maskmem - 1 + else: + tpos_idx = num_maskmem - 1 + _append(out, tpos_idx) + + for t_pos in range(1, num_maskmem): + out = output_dict["non_cond_frame_outputs"].get(frame_idx - (num_maskmem - t_pos), None) + if out is None or out.get("maskmem_features") is None: + continue + _append(out, num_maskmem - t_pos - 1) + + return to_cat_memory, to_cat_memory_pos, to_cat_image_feat, to_cat_image_pos, cond_outputs + + +def compute_tpos_enc(rel_pos_list, device, d_model, proj_layer, dtype=None, max_abs_pos=None): + """Temporal position encoding for object pointers.""" + pos_enc = torch.tensor(rel_pos_list, dtype=torch.float32, device=device) / max((max_abs_pos or 2) - 1, 1) + pos_enc = get_1d_sine_pe(pos_enc, dim=d_model) + if dtype is not None: + pos_enc = pos_enc.to(dtype) + return proj_layer(pos_enc) + + +def forward_sam_heads(backbone_features, prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, + image_size, point_inputs=None, mask_inputs=None, box_inputs=None, + high_res_features=None, multimask_output=False): + """Shared SAM prompt encoder + mask decoder forward for both SAM3 and SAM3.1 trackers.""" + device = backbone_features.device + # Batch size from inputs (mask_inputs may have N_obj > 1 while backbone is batch 1) + if mask_inputs is not None: + B = mask_inputs.shape[0] + elif box_inputs is not None: + B = box_inputs.shape[0] + elif point_inputs is not None: + B = point_inputs["point_coords"].shape[0] + else: + B = backbone_features.shape[0] + + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + else: + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + if mask_inputs is not None: + prompt_size = (prompt_encoder.image_embedding_size[0] * 4, prompt_encoder.image_embedding_size[1] * 4) + if mask_inputs.shape[-2:] != prompt_size: + sam_mask_prompt = F.interpolate(mask_inputs, size=prompt_size, mode="bilinear", align_corners=False, antialias=True) + else: + sam_mask_prompt = mask_inputs + else: + sam_mask_prompt = None + + sparse, dense = prompt_encoder(points=(sam_point_coords, sam_point_labels), boxes=box_inputs, masks=sam_mask_prompt) + sparse = cast_to_input(sparse, backbone_features) + dense = cast_to_input(dense, backbone_features) + image_pe = cast_to_input(prompt_encoder.get_dense_pe(), backbone_features) + + low_res_multimasks, ious, sam_output_tokens, object_score_logits = mask_decoder( + image_embeddings=backbone_features, image_pe=image_pe, + sparse_prompt_embeddings=sparse, dense_prompt_embeddings=dense, + high_res_features=high_res_features, multimask_output=multimask_output, return_all=True, + ) + + is_obj_appearing = object_score_logits > 0 + low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, + torch.tensor(NO_OBJ_SCORE, device=device, dtype=low_res_multimasks.dtype)) + high_res_multimasks = F.interpolate(low_res_multimasks, size=(image_size, image_size), mode="bilinear", align_corners=False) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + obj_ptr = obj_ptr_proj(sam_output_token) + obj_ptr = no_obj_fn(obj_ptr, is_obj_appearing) + + return low_res_masks, high_res_masks, obj_ptr, object_score_logits + + +def use_mask_as_output(backbone_features, high_res_features, mask_inputs, mask_downsample, + prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, image_size, backbone_stride): + """Shared mask-as-output for both SAM3 and SAM3.1 trackers.""" + out_scale, out_bias = 20.0, -10.0 + mask_inputs_float = cast_to_input(mask_inputs, backbone_features) + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate(high_res_masks, size=(image_size // backbone_stride * 4,) * 2, + mode="bilinear", align_corners=False, antialias=True) + _, _, obj_ptr, _ = forward_sam_heads( + backbone_features, prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, + image_size, mask_inputs=mask_downsample(mask_inputs_float), high_res_features=high_res_features, + ) + is_obj_appearing = torch.any(mask_inputs.flatten(1) > 0.0, dim=1)[..., None] + alpha = is_obj_appearing.to(obj_ptr.dtype) + object_score_logits = out_scale * alpha + out_bias + return low_res_masks, high_res_masks, obj_ptr, object_score_logits + + +# Split attention with configurable input dims (for asymmetric cross-attention) +class SplitAttn(nn.Module): + def __init__(self, embed_dim, num_heads=1, kv_dim=None, internal_dim=None, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + kv_dim = kv_dim or embed_dim + internal_dim = internal_dim or embed_dim + self.q_proj = operations.Linear(embed_dim, internal_dim, device=device, dtype=dtype) + self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.out_proj = operations.Linear(internal_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, q, k=None, v=None, rope=None, num_k_exclude_rope=0): + if k is None: + k = q + if v is None: + v = k + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + if rope is not None: + q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) + out = optimized_attention(q, k, v, self.num_heads, low_precision_attention=False) + return self.out_proj(out) + + +class MemoryAttnLayer(nn.Module): + def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.self_attn = SplitAttn(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_image = SplitAttn(d_model, num_heads, kv_dim=kv_dim, device=device, dtype=dtype, operations=operations) + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + + def forward(self, x, memory, memory_pos=None, rope=None, num_k_exclude_rope=0): + x = x + self.self_attn(self.norm1(x), rope=rope) + mem_k = memory + memory_pos if memory_pos is not None else memory + x = x + self.cross_attn_image(self.norm2(x), mem_k, memory, rope=rope, num_k_exclude_rope=num_k_exclude_rope) + normed = self.norm3(x) + x = x + self.linear2(F.relu(self.linear1(normed))) + return x + + +class MemoryAttnEncoder(nn.Module): + def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, num_layers=4, image_size=1008, patch_size=14, + device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + MemoryAttnLayer(d_model, num_heads, kv_dim, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + hw = image_size // patch_size + self.register_buffer("_rope", rope_2d(hw, hw, d_model // num_heads), persistent=False) + + def forward(self, x, memory, src_pos=None, memory_pos=None, num_k_exclude_rope=0): + if src_pos is not None: + x = x + 0.1 * src_pos + + rope = self._rope.to(device=x.device) + for layer in self.layers: + x = layer(x, memory, memory_pos=memory_pos, rope=rope, num_k_exclude_rope=num_k_exclude_rope) + return self.norm(x) + + +class MemoryTransformer(nn.Module): + def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, num_layers=4, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = MemoryAttnEncoder(d_model, num_heads, kv_dim, dim_ff, num_layers, device=device, dtype=dtype, operations=operations) + + +def _upscale_masks(output_upscaling, conv_s0, conv_s1, src_out, high_res_features): + """Shared upscaling for SAM mask decoders: deconv + high-res feature integration.""" + dc1, ln1, act1, dc2, act2 = output_upscaling + if high_res_features is not None: + upscaled = act1(ln1(dc1(src_out) + conv_s1(high_res_features[1]))) + upscaled = act2(dc2(upscaled) + conv_s0(high_res_features[0])) + else: + upscaled = act2(dc2(act1(ln1(dc1(src_out))))) + return upscaled + + +class SAMMaskDecoder(nn.Module): + def __init__(self, d_model=256, num_multimask_outputs=3, device=None, dtype=None, operations=None): + super().__init__() + self.num_mask_tokens = num_multimask_outputs + 1 + + self.transformer = SAMTwoWayTransformer(depth=2, embedding_dim=d_model, num_heads=8, mlp_dim=2048, device=device, dtype=dtype, operations=operations) + + self.iou_token = operations.Embedding(1, d_model, device=device, dtype=dtype) + self.mask_tokens = operations.Embedding(self.num_mask_tokens, d_model, device=device, dtype=dtype) + self.obj_score_token = operations.Embedding(1, d_model, device=device, dtype=dtype) + + # Output upscaling: d_model -> d_model//4 -> d_model//8 at 4x resolution + LN2d = LayerNorm2d_op(operations) + self.output_upscaling = nn.Sequential( + operations.ConvTranspose2d(d_model, d_model // 4, kernel_size=2, stride=2, device=device, dtype=dtype), LN2d(d_model // 4, device=device, dtype=dtype), nn.GELU(), + operations.ConvTranspose2d(d_model // 4, d_model // 8, kernel_size=2, stride=2, device=device, dtype=dtype), nn.GELU(), + ) + + # High-res feature integration + self.conv_s0 = operations.Conv2d(d_model, d_model // 8, kernel_size=1, device=device, dtype=dtype) + self.conv_s1 = operations.Conv2d(d_model, d_model // 4, kernel_size=1, device=device, dtype=dtype) + + # Per-mask hypernetwork MLPs + self.output_hypernetworks_mlps = nn.ModuleList([ + MLP(d_model, d_model, d_model // 8, 3, device=device, dtype=dtype, operations=operations) + for _ in range(self.num_mask_tokens) + ]) + + self.iou_prediction_head = MLP(d_model, d_model, self.num_mask_tokens, 3, device=device, dtype=dtype, operations=operations) + self.pred_obj_score_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations) + + def forward(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, + high_res_features=None, multimask_output=False, return_all=False): + B = sparse_prompt_embeddings.shape[0] + ref = sparse_prompt_embeddings + # Token order: [obj_score(1), iou(1), mask(num_mask_tokens)] + tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref), + cast_to_input(self.iou_token.weight, ref), + cast_to_input(self.mask_tokens.weight, ref)], dim=0) + tokens = torch.cat([tokens.unsqueeze(0).expand(B, -1, -1), sparse_prompt_embeddings], dim=1) + + src = image_embeddings + if src.shape[0] != B: + src = src.expand(B, -1, -1, -1) + src = src + dense_prompt_embeddings + pos_src = image_pe.expand(B, -1, -1, -1) + + b, c, h, w = src.shape + src_flat = src.flatten(2).permute(0, 2, 1) + pos_flat = pos_src.flatten(2).permute(0, 2, 1) + + hs, src_out = self.transformer(src_flat, pos_flat, tokens) + + obj_score_token_out = hs[:, 0, :] + iou_token_out = hs[:, 1, :] + mask_tokens_out = hs[:, 2:2 + self.num_mask_tokens, :] + + src_out = src_out.permute(0, 2, 1).view(b, c, h, w) + upscaled = _upscale_masks(self.output_upscaling, self.conv_s0, self.conv_s1, src_out, high_res_features) + + hyper_in = torch.stack([ + mlp(mask_tokens_out[:, i, :]) for i, mlp in enumerate(self.output_hypernetworks_mlps) + ], dim=1) + + masks = (hyper_in @ upscaled.flatten(2)).view(B, self.num_mask_tokens, upscaled.shape[2], upscaled.shape[3]) + iou_pred = self.iou_prediction_head(iou_token_out) + object_score_logits = self.pred_obj_score_head(obj_score_token_out) + + if multimask_output: + out_masks = masks[:, 1:] + out_iou = iou_pred[:, 1:] + out_tokens = mask_tokens_out[:, 1:] + else: + out_masks = masks[:, 0:1] + out_iou = iou_pred[:, 0:1] + out_tokens = mask_tokens_out[:, 0:1] + + if return_all: + return out_masks, out_iou, out_tokens, object_score_logits + return out_masks, out_iou + + +class SAMPromptEncoder(nn.Module): + def __init__(self, d_model=256, image_embedding_size=(72, 72), input_image_size=(1008, 1008), device=None, dtype=None, operations=None): + super().__init__() + self.embed_dim = d_model + self.image_embedding_size = image_embedding_size + self.input_image_size = input_image_size + + self.pe_layer = PositionEmbeddingRandom(d_model // 2) + self.point_embeddings = nn.ModuleList([ + operations.Embedding(1, d_model, device=device, dtype=dtype) for _ in range(4) + ]) + self.not_a_point_embed = operations.Embedding(1, d_model, device=device, dtype=dtype) + + LN2d = LayerNorm2d_op(operations) + self.mask_downscaling = nn.Sequential( + operations.Conv2d(1, 4, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(4, device=device, dtype=dtype), nn.GELU(), + operations.Conv2d(4, 16, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(16, device=device, dtype=dtype), nn.GELU(), + operations.Conv2d(16, d_model, kernel_size=1, device=device, dtype=dtype), + ) + self.no_mask_embed = operations.Embedding(1, d_model, device=device, dtype=dtype) + + def get_dense_pe(self): + return self.pe_layer(self.image_embedding_size) + + def forward(self, points=None, boxes=None, masks=None): + ref = points[0] if points is not None else boxes if boxes is not None else masks + B = 1 + sparse = torch.empty((B, 0, self.embed_dim), device=ref.device, dtype=ref.dtype) + + if points is not None: + coords, labels = points + B = coords.shape[0] + # Pad with an extra point (label=-1) when no boxes are provided (matching reference) + if boxes is None: + coords = torch.cat([coords, torch.zeros(B, 1, 2, device=coords.device, dtype=coords.dtype)], dim=1) + labels = torch.cat([labels, -torch.ones(B, 1, device=labels.device, dtype=labels.dtype)], dim=1) + pe = self.pe_layer.forward_with_coords(coords + 0.5, self.input_image_size) + for i in range(4): + pe[labels == i] += cast_to_input(self.point_embeddings[i].weight, ref) + invalid = (labels == -1) + pe[invalid] = 0.0 + pe[invalid] += cast_to_input(self.not_a_point_embed.weight, ref) + sparse = torch.cat([sparse.expand(B, -1, -1), pe], dim=1) + + if boxes is not None: + B = boxes.shape[0] + corners = self.pe_layer.forward_with_coords((boxes.reshape(-1, 2, 2) + 0.5), self.input_image_size) + corners[:, 0] += cast_to_input(self.point_embeddings[2].weight, ref) + corners[:, 1] += cast_to_input(self.point_embeddings[3].weight, ref) + sparse = torch.cat([sparse.expand(B, -1, -1), corners], dim=1) + + if masks is not None: + dense = self.mask_downscaling(masks) + else: + dense = cast_to_input(self.no_mask_embed.weight, ref).reshape(1, -1, 1, 1).expand( + B, -1, self.image_embedding_size[0], self.image_embedding_size[1]) + + return sparse, dense + + +class CXBlock(nn.Module): + def __init__(self, dim=256, kernel_size=7, device=None, dtype=None, operations=None): + super().__init__() + self.dwconv = operations.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(dim, device=device, dtype=dtype) + self.pwconv1 = operations.Linear(dim, 4 * dim, device=device, dtype=dtype) + self.pwconv2 = operations.Linear(4 * dim, dim, device=device, dtype=dtype) + self.gamma = nn.Parameter(torch.ones(dim, device=device, dtype=dtype)) + + def forward(self, x): + residual = x + x = self.dwconv(x).permute(0, 2, 3, 1) + x = self.pwconv2(F.gelu(self.pwconv1(self.norm(x)))) + x.mul_(cast_to_input(self.gamma, x)) + return residual + x.permute(0, 3, 1, 2) + + +class MaskDownSampler(nn.Module): + def __init__(self, out_dim=256, in_chans=1, channels=None, interpol_size=(1152, 1152), device=None, dtype=None, operations=None): + super().__init__() + self.interpol_size = list(interpol_size) if interpol_size else None + if channels is None: + channels = [4, 16, 64, out_dim] # SAM3 default + LN2d = LayerNorm2d_op(operations) + layers = [] + prev = in_chans + for ch in channels: + layers += [operations.Conv2d(prev, ch, kernel_size=3, stride=2, padding=1, device=device, dtype=dtype), + LN2d(ch, device=device, dtype=dtype), nn.GELU()] + prev = ch + layers.append(operations.Conv2d(prev, out_dim, kernel_size=1, device=device, dtype=dtype)) + self.encoder = nn.Sequential(*layers) + + def forward(self, x): + if self.interpol_size is not None and list(x.shape[-2:]) != self.interpol_size: + x = F.interpolate(x, size=self.interpol_size, mode="bilinear", align_corners=False, antialias=True) + return self.encoder(x) + + +class Fuser(nn.Module): + def __init__(self, dim=256, num_layers=2, device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.Sequential(*[CXBlock(dim, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)]) + + def forward(self, x): + return self.layers(x) + + +# --- SAM3.1 Multiplex components --- + +class DecoupledMemoryAttnLayer(nn.Module): + """Decoupled cross-attention layer for SAM3.1: fuses image and memory projections.""" + + def __init__(self, d_model=256, num_heads=1, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + # Self-attention projections (flat, not nested) + self.self_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.self_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.self_attn_v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.self_attn_out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + # Cross-attention projections + self.cross_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.cross_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.cross_attn_v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.cross_attn_out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + # Image cross-attention (q/k only, fused with cross_attn) + self.image_cross_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.image_cross_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + # FFN + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + + def forward(self, image, x, memory_image, memory, memory_image_pos=None, + rope=None, num_k_exclude_rope=0): + # Self-attention with RoPE + normed = self.norm1(x) + q = self.self_attn_q_proj(normed) + k = self.self_attn_k_proj(normed) + v = self.self_attn_v_proj(normed) + if rope is not None: + q, k = apply_rope_memory(q, k, rope, self.num_heads, 0) + x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)) + + # Decoupled cross-attention: fuse image and memory projections + normed = self.norm2(x) + q = self.image_cross_attn_q_proj(image) + self.cross_attn_q_proj(normed) + k = self.image_cross_attn_k_proj(memory_image) + self.cross_attn_k_proj(memory) + if memory_image_pos is not None: + k = k + memory_image_pos + v = self.cross_attn_v_proj(memory) + if rope is not None: + q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) + x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)) + + # FFN + x = x + self.linear2(F.gelu(self.linear1(self.norm3(x)))) + return image, x + + +class DecoupledMemoryEncoder(nn.Module): + """Memory attention encoder for SAM3.1 with decoupled cross-attention.""" + + def __init__(self, d_model=256, num_heads=1, dim_ff=2048, num_layers=4, image_size=1008, patch_size=14, + device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + DecoupledMemoryAttnLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + hw = image_size // patch_size + self.register_buffer("_rope", rope_2d(hw, hw, d_model // num_heads), persistent=False) + + def forward(self, x, memory, memory_pos=None, src_pos=None, num_k_exclude_rope=0, + memory_image=None, memory_image_pos=None): + image = x # constant residual for decoupled cross-attention + output = x + if src_pos is not None: + output = output + 0.1 * src_pos + + B, _, C = x.shape + rope = self._rope.to(device=x.device) + + # memory_image: raw backbone features from past frames for decoupled cross-attention + if memory_image is None: + # Fallback: use spatial portion of memory (without obj pointers) + num_spatial = memory.shape[1] - num_k_exclude_rope + memory_image = memory[:, :num_spatial] + memory_image_pos = memory_pos[:, :num_spatial] if memory_pos is not None else None + # Pad memory_image to match memory length (zeros for obj pointer tokens) + if memory_image.shape[1] < memory.shape[1]: + pad_len = memory.shape[1] - memory_image.shape[1] + pad = torch.zeros(B, pad_len, C, device=memory.device, dtype=memory.dtype) + memory_image = torch.cat([memory_image, pad], dim=1) + if memory_image_pos is not None: + ptr_pos = memory_pos[:, -pad_len:] if memory_pos is not None else torch.zeros_like(pad) + memory_image_pos = torch.cat([memory_image_pos, ptr_pos], dim=1) + + for layer in self.layers: + image, output = layer(image, output, memory_image, memory, + memory_image_pos=memory_image_pos, rope=rope, + num_k_exclude_rope=num_k_exclude_rope) + + return self.norm(output) + + +class DecoupledMemoryTransformer(nn.Module): + def __init__(self, d_model=256, num_heads=1, dim_ff=2048, num_layers=4, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = DecoupledMemoryEncoder(d_model, num_heads, dim_ff, num_layers, + device=device, dtype=dtype, operations=operations) + + +class MemoryBackbone(nn.Module): + """Memory encoder: downsamples mask, fuses with pixel features, optionally compresses.""" + + def __init__(self, d_model=256, out_dim=None, in_chans=1, channels=None, device=None, dtype=None, operations=None): + super().__init__() + self.mask_downsampler = MaskDownSampler(d_model, in_chans=in_chans, channels=channels, device=device, dtype=dtype, operations=operations) + self.pix_feat_proj = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype) + self.fuser = Fuser(d_model, num_layers=2, device=device, dtype=dtype, operations=operations) + self.has_out_proj = out_dim is not None and out_dim != d_model + if self.has_out_proj: + self.out_proj = operations.Conv2d(d_model, out_dim, kernel_size=1, device=device, dtype=dtype) + feat_dim = out_dim + else: + feat_dim = d_model + self.position_encoding = PositionEmbeddingSine(num_pos_feats=feat_dim, normalize=True) + + def forward(self, image_features, mask_for_mem, skip_mask_sigmoid=False): + if not skip_mask_sigmoid: + mask_for_mem = mask_for_mem.sigmoid() + mask_features = self.mask_downsampler(cast_to_input(mask_for_mem, image_features)) + if mask_features.shape[-2:] != image_features.shape[-2:]: + mask_features = F.interpolate(mask_features, size=image_features.shape[-2:], mode="bilinear", align_corners=False) + features = self.pix_feat_proj(image_features) + mask_features + features = self.fuser(features) + if self.has_out_proj: + features = self.out_proj(features) + pos = cast_to_input(self.position_encoding(features), features) + return {"vision_features": features, "vision_pos_enc": [pos]} + + +class MultiplexMaskDecoder(nn.Module): + """SAM mask decoder for SAM3.1 multiplex: predicts masks for num_multiplex objects simultaneously. + + Uses multimask_outputs_only=True: num_mask_output_per_object = num_multimask_outputs (no +1). + Hypernetwork MLPs are shared across multiplex objects. + Token order: [obj_score_token(M), iou_token(M), mask_tokens(M*T)]. + """ + + def __init__(self, d_model=256, num_multiplex=16, num_multimask_outputs=3, device=None, dtype=None, operations=None): + super().__init__() + self.num_multiplex = num_multiplex + self.num_mask_output_per_object = num_multimask_outputs # 3 (multimask_outputs_only) + total_mask_tokens = num_multiplex * self.num_mask_output_per_object # 48 + + self.transformer = SAMTwoWayTransformer(depth=2, embedding_dim=d_model, num_heads=8, mlp_dim=2048, device=device, dtype=dtype, operations=operations) + + self.obj_score_token = operations.Embedding(num_multiplex, d_model, device=device, dtype=dtype) + self.iou_token = operations.Embedding(num_multiplex, d_model, device=device, dtype=dtype) + self.mask_tokens = operations.Embedding(total_mask_tokens, d_model, device=device, dtype=dtype) + + LN2d = LayerNorm2d_op(operations) + self.output_upscaling = nn.Sequential( + operations.ConvTranspose2d(d_model, d_model // 4, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(d_model // 4, device=device, dtype=dtype), nn.GELU(), + operations.ConvTranspose2d(d_model // 4, d_model // 8, kernel_size=2, stride=2, device=device, dtype=dtype), nn.GELU(), + ) + self.conv_s0 = operations.Conv2d(d_model, d_model // 8, kernel_size=1, device=device, dtype=dtype) + self.conv_s1 = operations.Conv2d(d_model, d_model // 4, kernel_size=1, device=device, dtype=dtype) + + # Shared across all multiplex objects (one per mask output) + self.output_hypernetworks_mlps = nn.ModuleList([ + MLP(d_model, d_model, d_model // 8, 3, device=device, dtype=dtype, operations=operations) + for _ in range(self.num_mask_output_per_object) + ]) + self.iou_prediction_head = MLP(d_model, d_model, self.num_mask_output_per_object, 3, device=device, dtype=dtype, operations=operations) + self.pred_obj_score_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations) + + def forward(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, + high_res_features=None, multimask_output=False, return_all=False, extra_per_object_embeddings=None): + B = sparse_prompt_embeddings.shape[0] + M = self.num_multiplex + T = self.num_mask_output_per_object + + # Token order: [obj_score(M), iou(M), mask(M*T)] + ref = sparse_prompt_embeddings + mask_tokens = cast_to_input(self.mask_tokens.weight, ref) + if extra_per_object_embeddings is not None: + mask_tokens = mask_tokens.view(1, M, T, -1).expand(B, -1, -1, -1) + extra_per_object_embeddings.unsqueeze(2) + mask_tokens = mask_tokens.flatten(1, 2) # [B, M*T, C] + other_tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref), + cast_to_input(self.iou_token.weight, ref)], dim=0).unsqueeze(0).expand(B, -1, -1) + tokens = torch.cat([other_tokens, mask_tokens, sparse_prompt_embeddings], dim=1) + else: + tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref), + cast_to_input(self.iou_token.weight, ref), mask_tokens], dim=0) + tokens = torch.cat([tokens.unsqueeze(0).expand(B, -1, -1), sparse_prompt_embeddings], dim=1) + + src = image_embeddings + if src.shape[0] != B: + src = src.expand(B, -1, -1, -1) + src = src + dense_prompt_embeddings + pos_src = image_pe.expand(B, -1, -1, -1) + + b, c, h, w = src.shape + hs, src_out = self.transformer(src.flatten(2).permute(0, 2, 1), pos_src.flatten(2).permute(0, 2, 1), tokens) + + # Parse output tokens + obj_score_token_out = hs[:, :M] + iou_token_out = hs[:, M:2 * M] + mask_tokens_out = hs[:, 2 * M:2 * M + M * T] + + src_out = src_out.permute(0, 2, 1).view(b, c, h, w) + upscaled = _upscale_masks(self.output_upscaling, self.conv_s0, self.conv_s1, src_out, high_res_features) + + # Reshape mask tokens to [B, M, T, C] and apply shared hypernetwork MLPs per mask output index + mask_tokens_2d = mask_tokens_out.view(B, M, T, -1) + hyper_in = torch.stack([ + self.output_hypernetworks_mlps[i](mask_tokens_2d[:, :, i, :]) # [B, M, C//8] + for i in range(T) + ], dim=2) # [B, M, T, C//8] + + # Generate masks: [B, M*T, H*W] -> [B, M, T, H, W] + masks = torch.bmm(hyper_in.flatten(1, 2), upscaled.flatten(2)).view(b, M, T, upscaled.shape[2], upscaled.shape[3]) + + # IoU and object scores + iou_pred = self.iou_prediction_head(iou_token_out).view(b, M, T) + object_score_logits = self.pred_obj_score_head(obj_score_token_out) # [B, M, 1] + + # multimask_outputs_only: always output all T masks (no singlemask token) + sam_tokens_out = mask_tokens_2d[:, :, 0:1] # [B, M, 1, C] + + if return_all: + return masks, iou_pred, sam_tokens_out, object_score_logits + return masks, iou_pred + + +class SAM3Tracker(nn.Module): + def __init__(self, d_model=256, mem_dim=64, num_maskmem=7, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + + # Memory attention transformer + self.transformer = MemoryTransformer(d_model, num_heads=1, kv_dim=mem_dim, dim_ff=2048, num_layers=4, + device=device, dtype=dtype, operations=operations) + # SAM components + self.sam_mask_decoder = SAMMaskDecoder(d_model, device=device, dtype=dtype, operations=operations) + self.sam_prompt_encoder = SAMPromptEncoder(d_model, device=device, dtype=dtype, operations=operations) + + # Memory backbone + self.maskmem_backbone = MemoryBackbone(d_model, out_dim=mem_dim, device=device, dtype=dtype, operations=operations) + + # Standalone parameters + self.maskmem_tpos_enc = nn.Parameter(torch.zeros(num_maskmem, 1, 1, mem_dim, device=device, dtype=dtype)) + self.no_mem_embed = nn.Parameter(torch.zeros(1, 1, d_model, device=device, dtype=dtype)) + self.register_buffer("no_mem_pos_enc", torch.zeros(1, 1, d_model, device=device, dtype=dtype)) # checkpoint key, unused in forward + self.no_obj_embed_spatial = nn.Parameter(torch.zeros(1, mem_dim, device=device, dtype=dtype)) + self.no_obj_ptr = nn.Parameter(torch.zeros(1, d_model, device=device, dtype=dtype)) + + # Object pointer projection + self.obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + self.obj_ptr_tpos_proj = operations.Linear(d_model, mem_dim, device=device, dtype=dtype) + + # Mask downsample: Conv2d stride 4 to reduce GT mask to SAM logit scale + self.mask_downsample = operations.Conv2d(1, 1, kernel_size=4, stride=4, device=device, dtype=dtype) + + # Config + self.d_model = d_model + self.mem_dim = mem_dim + self.num_maskmem = num_maskmem + self.image_size = 1008 + self.backbone_stride = 14 + self.max_obj_ptrs_in_encoder = 16 + self.sigmoid_scale_for_mem_enc = 20.0 + self.sigmoid_bias_for_mem_enc = -10.0 + + def _no_obj_blend(self, obj_ptr, is_obj): + alpha = is_obj.to(obj_ptr.dtype) + return torch.lerp(cast_to_input(self.no_obj_ptr, obj_ptr), obj_ptr, alpha) + + def _forward_sam_heads(self, backbone_features, point_inputs=None, mask_inputs=None, box_inputs=None, + high_res_features=None, multimask_output=False): + return forward_sam_heads(backbone_features, self.sam_prompt_encoder, self.sam_mask_decoder, + self.obj_ptr_proj, self._no_obj_blend, self.image_size, + point_inputs, mask_inputs, box_inputs, high_res_features, multimask_output) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + return use_mask_as_output(backbone_features, high_res_features, mask_inputs, + self.mask_downsample, self.sam_prompt_encoder, self.sam_mask_decoder, + self.obj_ptr_proj, self._no_obj_blend, self.image_size, self.backbone_stride) + + def _prepare_memory_conditioned_features(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, output_dict, num_frames): + """Fuse current frame features with memory from previous frames.""" + B = current_vision_feats[-1].shape[0] + C = self.d_model + H, W = feat_sizes[-1] + device = current_vision_feats[-1].device + + if self.num_maskmem == 0: + return current_vision_feats[-1].permute(0, 2, 1).view(B, C, H, W) + + if is_init_cond_frame: + # First conditioning frame: no memory yet, add no_mem_embed + pix_feat = current_vision_feats[-1] + cast_to_input(self.no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + to_cat_memory, to_cat_memory_pos, _, _, cond_outputs = collect_memory_tokens( + output_dict, frame_idx, self.num_maskmem, self.maskmem_tpos_enc, device) + + max_obj_ptrs = min(num_frames, self.max_obj_ptrs_in_encoder) + pos_and_ptrs = [] + for t, out in cond_outputs.items(): + if t <= frame_idx: + pos_and_ptrs.append(((frame_idx - t), out["obj_ptr"].to(device))) + for t_diff in range(1, max_obj_ptrs): + t = frame_idx - t_diff + if t < 0: + break + out = output_dict["non_cond_frame_outputs"].get(t, None) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"].to(device))) + + num_obj_ptr_tokens = 0 + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + obj_ptrs = torch.stack(ptrs_list, dim=1) # [B, N, C=256] + + # Temporal position encoding for pointers + obj_pos = compute_tpos_enc( + list(pos_list), device, self.d_model, self.obj_ptr_tpos_proj, + max_abs_pos=max_obj_ptrs, dtype=current_vision_feats[-1].dtype + ) # [N, mem_dim=64] + obj_pos = obj_pos.unsqueeze(0).expand(B, -1, -1) # [B, N, 64] + + # Split each 256-dim pointer into 4 x 64-dim tokens + if self.mem_dim < C: + N = obj_ptrs.shape[1] + obj_ptrs = obj_ptrs.view(B, N, C // self.mem_dim, self.mem_dim) # [B, N, 4, 64] + obj_ptrs = obj_ptrs.reshape(B, N * (C // self.mem_dim), self.mem_dim) # [B, N*4, 64] + obj_pos = obj_pos.unsqueeze(2).expand(-1, -1, C // self.mem_dim, -1) + obj_pos = obj_pos.reshape(B, N * (C // self.mem_dim), self.mem_dim) # [B, N*4, 64] + + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[1] + + if len(to_cat_memory) == 0: + # No memory available yet, add no_mem_embed + pix_feat = current_vision_feats[-1] + cast_to_input(self.no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + # Concatenate all memory and position encodings [B, total_mem, mem_dim=64] + memory = torch.cat(to_cat_memory, dim=1) + memory_pos = torch.cat(to_cat_memory_pos, dim=1) + + # Run memory attention encoder + pix_feat = current_vision_feats[-1] # [B, HW, C] + src_pos = current_vision_pos_embeds[-1] # [B, HW, C] + + pix_feat_with_mem = self.transformer.encoder( + x=pix_feat, + memory=memory, + src_pos=src_pos, + memory_pos=memory_pos, + num_k_exclude_rope=num_obj_ptr_tokens, + ) + return to_spatial(pix_feat_with_mem, H, W) + + def _encode_new_memory(self, pix_feat, pred_masks_high_res, object_score_logits, is_mask_from_pts=False): + """Encode predicted mask into memory features.""" + if is_mask_from_pts: + mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype) + else: + mask_for_mem = torch.sigmoid(pred_masks_high_res) + + mask_for_mem.mul_(self.sigmoid_scale_for_mem_enc).add_(self.sigmoid_bias_for_mem_enc) + + maskmem_out = self.maskmem_backbone(pix_feat, mask_for_mem, skip_mask_sigmoid=True) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + # Add no_obj_embed for occluded objects + alpha = (object_score_logits > 0).to(maskmem_features.dtype)[..., None, None] + no_obj = cast_to_input(self.no_obj_embed_spatial, maskmem_features)[..., None, None].expand_as(maskmem_features) + return maskmem_features + (1 - alpha) * no_obj, maskmem_pos_enc + + def track_step(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, mask_inputs, output_dict, + num_frames, point_inputs=None): + """Track one frame: fuse with memory, predict mask, encode memory.""" + current_out = {} + + # High-res features for SAM head [stride-8, stride-4] + if len(current_vision_feats) > 1: + high_res_features = [ + x.view(x.shape[0], feat_sizes[i][0], feat_sizes[i][1], -1).permute(0, 3, 1, 2) + for i, x in enumerate(current_vision_feats[:-1]) + ] + else: + high_res_features = None + + # Top-level feature for memory + H, W = feat_sizes[-1] + + if mask_inputs is not None: + # Conditioning frame: use mask directly + pix_feat = to_spatial(current_vision_feats[-1], H, W) + sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) + else: + # Track frame: fuse with memory, then SAM decoder + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + output_dict=output_dict, + num_frames=num_frames, + ) + # Use multimask for point prompts on init frames (picks best of 3 candidates) + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = is_init_cond_frame and 0 < num_pts <= 1 + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + (low_res_masks, high_res_masks, obj_ptr, object_score_logits) = sam_outputs + + # Clean low-res masks: remove sprinkles and fill holes + low_res_masks = fill_holes_in_mask_scores(low_res_masks, max_area=200) + high_res_masks = F.interpolate(low_res_masks, size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + current_out["object_score_logits"] = object_score_logits + + # Encode memory + if self.num_maskmem > 0: + pix_feat = to_spatial(current_vision_feats[-1], H, W) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + pix_feat=pix_feat, + pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + return current_out + + def _compute_backbone_frame(self, backbone_fn, frame, frame_idx=None): + vision_feats, vision_pos, feat_sizes, _, _ = _compute_backbone(backbone_fn, frame, frame_idx) + # SAM3: drop last FPN level + return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1] + + def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None): + """Track one object, computing backbone per frame to save VRAM.""" + N = images.shape[0] + device, dt = images.device, images.dtype + output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}} + all_masks = [] + + for frame_idx in tqdm(range(N), desc="tracking"): + vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame( + backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx) + mask_input = None + if frame_idx == 0: + mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt), + size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + mask_input = (mask_input > 0.5).to(dt) + + current_out = self.track_step( + frame_idx=frame_idx, is_init_cond_frame=(frame_idx == 0), + current_vision_feats=vision_feats, current_vision_pos_embeds=vision_pos, + feat_sizes=feat_sizes, mask_inputs=mask_input, output_dict=output_dict, num_frames=N) + + if frame_idx == 0: + output_dict["cond_frame_outputs"][frame_idx] = current_out + else: + output_dict["non_cond_frame_outputs"][frame_idx] = current_out + lookback = max(self.num_maskmem, self.max_obj_ptrs_in_encoder) + for old_idx in list(output_dict["non_cond_frame_outputs"]): + if old_idx < frame_idx - lookback: + del output_dict["non_cond_frame_outputs"][old_idx] + # Move masks to CPU immediately to free VRAM + all_masks.append(current_out["pred_masks_high_res"].to(comfy.model_management.intermediate_device())) + if pbar is not None: + pbar.update(1) + + return torch.cat(all_masks, dim=0) # [N, 1, H, W] + + def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs): + """Track one or more objects across video frames. + + Args: + backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame + images: [N, 3, 1008, 1008] video frames + initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object) + pbar: optional progress bar + + Returns: + [N, N_obj, image_size, image_size] predicted mask logits per frame per object + """ + N_obj = initial_masks.shape[0] + per_object = [] + for obj_idx in range(N_obj): + obj_masks = self._track_single_object( + backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar) + per_object.append(obj_masks) + + return torch.cat(per_object, dim=1) # [N, N_obj, H, W] + + +class SAM31Tracker(nn.Module): + """SAM3.1 multiplex tracker: decoupled memory attention, dual decoder, 16-object multiplex.""" + + def __init__(self, d_model=256, mem_dim=256, num_maskmem=7, num_multiplex=16, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.d_model = d_model + self.mem_dim = mem_dim + self.num_maskmem = num_maskmem + self.num_multiplex = num_multiplex + self.image_size = 1008 + self.backbone_stride = 14 + self.max_obj_ptrs_in_encoder = 16 + self.sigmoid_scale_for_mem_enc = 2.0 + self.sigmoid_bias_for_mem_enc = -1.0 + + # Memory attention (decoupled cross-attention, 8 heads matching reference) + self.transformer = DecoupledMemoryTransformer(d_model, num_heads=8, dim_ff=2048, num_layers=4, + device=device, dtype=dtype, operations=operations) + + # Propagation decoder (multiplex: 16 objects, multimask_outputs_only) + self.sam_mask_decoder = MultiplexMaskDecoder(d_model, num_multiplex, num_multimask_outputs=3, + device=device, dtype=dtype, operations=operations) + # Interactive decoder (single object, same as SAM3) + self.interactive_sam_mask_decoder = SAMMaskDecoder(d_model, num_multimask_outputs=3, + device=device, dtype=dtype, operations=operations) + self.interactive_sam_prompt_encoder = SAMPromptEncoder(d_model, device=device, dtype=dtype, operations=operations) + + # Memory backbone (mem_dim=256, no out_proj compression) + self.maskmem_backbone = MemoryBackbone(d_model, in_chans=num_multiplex * 2, channels=[16, 64, 256, 1024], + device=device, dtype=dtype, operations=operations) + + # Standalone parameters + self.maskmem_tpos_enc = nn.Parameter(torch.zeros(num_maskmem, 1, 1, mem_dim, device=device, dtype=dtype)) + self.no_obj_embed_spatial = nn.Parameter(torch.zeros(num_multiplex, mem_dim, device=device, dtype=dtype)) + self.interactivity_no_mem_embed = nn.Parameter(torch.zeros(1, 1, d_model, device=device, dtype=dtype)) + + # Object pointer projection + self.obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + self.obj_ptr_tpos_proj = operations.Linear(d_model, mem_dim, device=device, dtype=dtype) + self.no_obj_ptr_linear = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.interactive_obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + + # Interactive mask downsample + self.interactive_mask_downsample = operations.Conv2d(1, 1, kernel_size=4, stride=4, device=device, dtype=dtype) + + # Multiplex validity embeddings + self.output_valid_embed = nn.Parameter(torch.zeros(num_multiplex, d_model, device=device, dtype=dtype)) + self.output_invalid_embed = nn.Parameter(torch.zeros(num_multiplex, d_model, device=device, dtype=dtype)) + + # Position encoding for image (used by multiplex decoder) + self.image_pe_layer = PositionEmbeddingRandom(d_model // 2) + + def _no_obj_blend(self, obj_ptr, is_obj): + alpha = is_obj.to(obj_ptr.dtype) + return torch.lerp(self.no_obj_ptr_linear(obj_ptr), obj_ptr, alpha) + + def _forward_sam_heads(self, backbone_features, point_inputs=None, mask_inputs=None, box_inputs=None, + high_res_features=None, multimask_output=False): + return forward_sam_heads(backbone_features, self.interactive_sam_prompt_encoder, self.interactive_sam_mask_decoder, + self.interactive_obj_ptr_proj, self._no_obj_blend, self.image_size, + point_inputs, mask_inputs, box_inputs, high_res_features, multimask_output) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + return use_mask_as_output(backbone_features, high_res_features, mask_inputs, + self.interactive_mask_downsample, self.interactive_sam_prompt_encoder, + self.interactive_sam_mask_decoder, self.interactive_obj_ptr_proj, + self._no_obj_blend, self.image_size, self.backbone_stride) + + def _prepare_memory_conditioned_features(self, frame_idx, is_init_cond_frame, current_vision_feats, + current_vision_pos_embeds, feat_sizes, output_dict, num_frames, + multiplex_state=None): + B = current_vision_feats[-1].shape[0] + C = self.d_model + H, W = feat_sizes[-1] + device = current_vision_feats[-1].device + num_buc = multiplex_state.num_buckets if multiplex_state is not None else None + + if self.num_maskmem == 0: + return current_vision_feats[-1].permute(0, 2, 1).view(B, C, H, W) + + if is_init_cond_frame: + pix_feat = current_vision_feats[-1] + cast_to_input(self.interactivity_no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + to_cat_memory, to_cat_memory_pos, to_cat_image_feat, to_cat_image_pos, cond_outputs = collect_memory_tokens( + output_dict, frame_idx, self.num_maskmem, self.maskmem_tpos_enc, device, + collect_image_feats=True, tpos_v2=True, num_buckets=num_buc) + + max_obj_ptrs = min(num_frames, self.max_obj_ptrs_in_encoder) + pos_and_ptrs = [] + for t, out in cond_outputs.items(): + if t <= frame_idx and "obj_ptr" in out: + ptr = out["obj_ptr"].to(device) + if num_buc is not None: + ptr = _pad_to_buckets(ptr, num_buc) + pos_and_ptrs.append(((frame_idx - t), ptr)) + for t_diff in range(1, max_obj_ptrs): + t = frame_idx - t_diff + if t < 0: + break + out = output_dict["non_cond_frame_outputs"].get(t, None) + if out is not None and "obj_ptr" in out: + ptr = out["obj_ptr"].to(device) + if num_buc is not None: + ptr = _pad_to_buckets(ptr, num_buc) + pos_and_ptrs.append((t_diff, ptr)) + + num_obj_ptr_tokens = 0 + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + obj_ptrs = torch.stack(ptrs_list, dim=1) # [num_buckets, N, M, C] + B_ptr = obj_ptrs.shape[0] + N_ptrs = obj_ptrs.shape[1] + M = obj_ptrs.shape[2] + obj_ptrs = obj_ptrs.reshape(B_ptr, N_ptrs * M, -1) + obj_pos = compute_tpos_enc(list(pos_list), device, self.d_model, self.obj_ptr_tpos_proj, + max_abs_pos=max_obj_ptrs, dtype=current_vision_feats[-1].dtype) + obj_pos = obj_pos.unsqueeze(0).expand(B_ptr, -1, -1) + obj_pos = obj_pos.unsqueeze(2).expand(-1, -1, M, -1).reshape(B_ptr, N_ptrs * M, -1) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[1] + + if len(to_cat_memory) == 0: + pix_feat = current_vision_feats[-1] + cast_to_input(self.interactivity_no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + memory = torch.cat(to_cat_memory, dim=1) + memory_pos = torch.cat(to_cat_memory_pos, dim=1) + + # Expand vision features to num_buckets if memory has more buckets than B + mem_B = memory.shape[0] + x = current_vision_feats[-1] + x_pos = current_vision_pos_embeds[-1] + if x.shape[0] < mem_B: + x = x.expand(mem_B, -1, -1) + x_pos = x_pos.expand(mem_B, -1, -1) + + if len(to_cat_image_feat) > 0: + # Decoupled cross-attention: separate image features from memory + memory_image = cast_to_input(torch.cat(to_cat_image_feat, dim=1), x) + memory_image_pos = cast_to_input(torch.cat(to_cat_image_pos, dim=1), x) + if memory_image.shape[0] < mem_B: + memory_image = memory_image.expand(mem_B, -1, -1) + memory_image_pos = memory_image_pos.expand(mem_B, -1, -1) + pix_feat_with_mem = self.transformer.encoder( + x=x, + memory=cast_to_input(memory, x), + memory_pos=cast_to_input(memory_pos, x), + src_pos=cast_to_input(x_pos, x), + num_k_exclude_rope=num_obj_ptr_tokens, + memory_image=memory_image, + memory_image_pos=memory_image_pos, + ) + else: + pix_feat_with_mem = self.transformer.encoder( + x=x, + memory=memory, + memory_pos=memory_pos, + src_pos=x_pos, + num_k_exclude_rope=num_obj_ptr_tokens, + ) + return to_spatial(pix_feat_with_mem, H, W) + + def _encode_new_memory(self, pix_feat, pred_masks_high_res, object_score_logits, is_mask_from_pts=False, + multiplex_state=None, is_conditioning=False, cond_obj_mask=None): + if is_mask_from_pts: + mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype) + else: + mask_for_mem = torch.sigmoid(pred_masks_high_res) + mask_for_mem.mul_(self.sigmoid_scale_for_mem_enc).add_(self.sigmoid_bias_for_mem_enc) + + # Mux masks: [N_obj, 1, H, W] -> [num_buckets, M, H, W] + mux_masks = multiplex_state.mux(mask_for_mem[:, 0]) + + # Conditioning channel: 1.0 = clean mask (trust it), 0.0 = propagation (noisy) + N_obj = mask_for_mem.shape[0] + cond_values = torch.full((N_obj,), 0.0, device=mask_for_mem.device, dtype=mask_for_mem.dtype) + if is_conditioning: + cond_values[:] = 1.0 + elif cond_obj_mask is not None: + cond_values[cond_obj_mask] = 1.0 + cond_spatial = cond_values.view(-1, 1, 1, 1).expand_as(mask_for_mem[:, 0:1, :, :]).squeeze(1) + mux_cond = multiplex_state.mux(cond_spatial) # [num_buckets, M, H, W] + mux_input = torch.cat([mux_masks, mux_cond], dim=1) # [num_buckets, 2*M, H, W] + + maskmem_out = self.maskmem_backbone(pix_feat, mux_input, skip_mask_sigmoid=True) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + # Add no_obj_embed_spatial for occluded objects + is_obj = (object_score_logits > 0).float() # [N_obj, 1] + mux_is_obj = multiplex_state.mux(is_obj) # [num_buckets, M, 1] + no_obj_embed = cast_to_input(self.no_obj_embed_spatial, maskmem_features) # [M, C] + no_obj_spatial = no_obj_embed.unsqueeze(0)[..., None, None] # [1, M, C, 1, 1] + # Expand and sum across multiplex slots weighted by (1 - is_obj) + alpha = mux_is_obj[..., None, None] # [num_buckets, M, 1, 1, 1] + per_slot_no_obj = ((1 - alpha) * no_obj_spatial).sum(dim=1) # [num_buckets, C, 1, 1] + maskmem_features = maskmem_features + per_slot_no_obj.expand_as(maskmem_features) + + return maskmem_features, maskmem_pos_enc + + def _forward_propagation(self, backbone_features, high_res_features=None, multiplex_state=None): + """Propagation path using the multiplex SAM decoder (no prompts).""" + B = backbone_features.shape[0] + device = backbone_features.device + + # Suppression embeddings from valid object mask + valid_mask = cast_to_input(multiplex_state.get_valid_object_mask().unsqueeze(-1).float(), backbone_features) + output_valid = cast_to_input(self.output_valid_embed, backbone_features).unsqueeze(0) + output_invalid = cast_to_input(self.output_invalid_embed, backbone_features).unsqueeze(0) + extra_embed = valid_mask * output_valid + (1 - valid_mask) * output_invalid + + image_pe = self.image_pe_layer((backbone_features.shape[-2], backbone_features.shape[-1]), device=backbone_features.device) + image_pe = cast_to_input(image_pe, backbone_features) + + masks, iou_pred, sam_tokens_out, object_score_logits = self.sam_mask_decoder( + image_embeddings=backbone_features, image_pe=image_pe, + sparse_prompt_embeddings=torch.empty(B, 0, self.d_model, device=device, dtype=backbone_features.dtype), + dense_prompt_embeddings=torch.zeros(B, self.d_model, *backbone_features.shape[-2:], device=device, dtype=backbone_features.dtype), + high_res_features=high_res_features, multimask_output=True, return_all=True, + extra_per_object_embeddings=extra_embed.expand(B, -1, -1), + ) + # masks: [B=num_buckets, M, T, H, W] + # Demux to per-object: [N_obj, T, H, W] + masks_obj = multiplex_state.demux(masks) + iou_obj = multiplex_state.demux(iou_pred) + score_obj = multiplex_state.demux(object_score_logits) + tokens_obj = multiplex_state.demux(sam_tokens_out) + + # Select best mask by IoU for each object + best_idx = torch.argmax(iou_obj, dim=-1) # [N_obj] + N_obj = masks_obj.shape[0] + obj_range = torch.arange(N_obj, device=device) + low_res_masks = masks_obj[obj_range, best_idx].unsqueeze(1) # [N_obj, 1, H, W] + # Suppress masks for objects with low confidence + is_obj = score_obj > 0 + low_res_masks = torch.where(is_obj[:, :, None, None], low_res_masks, + torch.tensor(NO_OBJ_SCORE, device=device, dtype=low_res_masks.dtype)) + high_res_masks = F.interpolate(low_res_masks.float(), size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + + # Object pointer: compute per-object, mux for storage as [num_buckets, M, C] + sam_token = tokens_obj[:, 0] # [N_obj, C] + obj_ptr = self.obj_ptr_proj(sam_token) + is_obj = (score_obj > 0).float() + no_obj = self.no_obj_ptr_linear(obj_ptr) + obj_ptr = is_obj * obj_ptr + (1 - is_obj) * no_obj + obj_ptr_muxed = multiplex_state.mux(obj_ptr) # [num_buckets, M, C] + + return low_res_masks, high_res_masks, obj_ptr_muxed, score_obj + + def track_step(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, + feat_sizes, mask_inputs, output_dict, num_frames, point_inputs=None, + interactive_high_res=None, interactive_backbone=None, propagation_high_res=None, + multiplex_state=None, run_mem_encoder=True): + current_out = {} + H, W = feat_sizes[-1] + + if mask_inputs is not None: + # Conditioning frame: use interactive features if available, else propagation + if interactive_backbone is not None: + pix_feat = interactive_backbone + # Add no_mem_embed for interactive path + pix_flat = pix_feat.flatten(2) + bf = pix_flat.permute(0, 2, 1) + cast_to_input(self.interactivity_no_mem_embed, pix_flat) + pix_feat = to_spatial(bf, H, W) + hi_res = interactive_high_res + else: + # Fallback: interactive backbone not available (e.g. called outside track_video). + # Propagation features work but may produce lower-quality conditioning. + pix_feat = to_spatial(current_vision_feats[-1], H, W) + hi_res = propagation_high_res + sam_outputs = self._use_mask_as_output(pix_feat, hi_res, mask_inputs) + elif point_inputs is not None: + # Interactive path: use interactive SAM decoder + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, output_dict=output_dict, num_frames=num_frames, + multiplex_state=multiplex_state, + ) + hi_res = interactive_high_res if interactive_high_res is not None else propagation_high_res + num_pts = point_inputs["point_labels"].size(1) + multimask_output = is_init_cond_frame and 0 < num_pts <= 1 + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, point_inputs=point_inputs, + high_res_features=hi_res, multimask_output=multimask_output, + ) + else: + # Propagation path: use multiplex SAM decoder with propagation features + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, output_dict=output_dict, num_frames=num_frames, + multiplex_state=multiplex_state, + ) + sam_outputs = self._forward_propagation(pix_feat_with_mem, propagation_high_res, + multiplex_state=multiplex_state) + + (low_res_masks, high_res_masks, obj_ptr, object_score_logits) = sam_outputs + + # Mux obj_ptr if it came from interactive path (shape [B, C]) vs propagation ([num_buckets, M, C]) + if multiplex_state is not None and obj_ptr.dim() == 2: + obj_ptr = multiplex_state.mux(obj_ptr) # [N_obj, C] -> [num_buckets, M, C] + + # Encode memory (can be deferred with run_mem_encoder=False) + if run_mem_encoder and self.num_maskmem > 0: + pix_feat = to_spatial(current_vision_feats[-1], H, W) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + pix_feat=pix_feat, pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + multiplex_state=multiplex_state, + is_conditioning=(mask_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + # Store propagation image features for decoupled memory attention + current_out["image_features"] = current_vision_feats[-1] # [B, HW, C] + current_out["image_pos_enc"] = current_vision_pos_embeds[-1] # [B, HW, C] + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + current_out["object_score_logits"] = object_score_logits + + return current_out + + def _compute_backbone_frame(self, backbone_fn, frame, frame_idx=None): + vision_feats, vision_pos, feat_sizes, features, trunk_out = _compute_backbone(backbone_fn, frame, frame_idx) + return vision_feats, vision_pos, feat_sizes, list(features[:-1]), trunk_out + + @staticmethod + def _suppress_recently_occluded(low_res_masks, last_occluded, frame_idx, threshold=0.3): + """Suppress overlapping masks for objects that were most recently occluded. + Prevents corrupted masks from occluded objects from contaminating other objects.""" + N_obj = low_res_masks.shape[0] + if N_obj <= 1: + return low_res_masks + binary = low_res_masks[:, 0] > 0 # [N_obj, H, W] + iou = _compute_mask_overlap(low_res_masks[:, 0], low_res_masks[:, 0]) + overlapping = torch.triu(iou >= threshold, diagonal=1) # [N, N] upper triangle + last_occ_i = last_occluded.unsqueeze(1) # [N, 1] + last_occ_j = last_occluded.unsqueeze(0) # [1, N] + # Suppress the more recently occluded object in each overlapping pair + suppress_i = overlapping & (last_occ_i > last_occ_j) & (last_occ_j > -1) + suppress_j = overlapping & (last_occ_j > last_occ_i) & (last_occ_i > -1) + to_suppress = suppress_i.any(dim=1) | suppress_j.any(dim=0) + # Update last_occluded for occluded/suppressed objects + is_empty = ~binary.any(dim=(-1, -2)) + newly_occluded = is_empty | to_suppress + last_occluded[newly_occluded] = frame_idx + # Suppress masks + low_res_masks[to_suppress] = -10.0 + return low_res_masks + + def _deferred_memory_encode(self, current_out, N_obj, vision_feats, feat_sizes, mux_state, device, + cond_obj_mask=None): + """Deferred memory encoding for propagation frames. cond_obj_mask: per-object bool for conditioning.""" + low_res_masks = current_out["pred_masks"] # [N_obj, 1, H_low, W_low] + + if N_obj > 1: + lr = low_res_masks.squeeze(1) # [N_obj, H, W] + max_obj = torch.argmax(lr, dim=0, keepdim=True) + batch_inds = torch.arange(N_obj, device=device)[:, None, None] + pixel_nol = torch.where(max_obj == batch_inds, lr, torch.clamp(lr, max=-10.0)) + area_before = (lr > 0).sum(dim=(-1, -2)).float().clamp(min=1) + area_after = (pixel_nol > 0).sum(dim=(-1, -2)).float() + shrink_ok = (area_after / area_before) >= 0.3 + low_res_masks = torch.where( + shrink_ok[:, None, None, None].expand_as(low_res_masks), + low_res_masks, torch.clamp(low_res_masks, max=-10.0)) + + interpol_size = self.maskmem_backbone.mask_downsampler.interpol_size + mem_masks = F.interpolate(low_res_masks, size=interpol_size, + mode="bilinear", align_corners=False) + + obj_scores = torch.where( + (mem_masks > 0).any(dim=(-1, -2)), 10.0, -10.0) + + pix_feat = to_spatial(vision_feats[-1], feat_sizes[-1][0], feat_sizes[-1][1]) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + pix_feat=pix_feat, pred_masks_high_res=mem_masks, + object_score_logits=obj_scores, + multiplex_state=mux_state, cond_obj_mask=cond_obj_mask) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + + def _add_detected_objects(self, new_masks, mux_state, vision_feats, feat_sizes, current_out): + """Grow MultiplexState with new detections, merge masks, re-encode memory. Modifies current_out.""" + n_old = mux_state.total_valid_entries + mux_state.add_objects(new_masks.shape[0]) + N_obj = mux_state.total_valid_entries + # Stored memory with old bucket counts is padded at read time by _pad_to_buckets + for k in ("pred_masks", "pred_masks_high_res"): + det = F.interpolate(new_masks.unsqueeze(1), size=current_out[k].shape[-2:], + mode="bilinear", align_corners=False) + current_out[k] = torch.cat([current_out[k], det], dim=0) + if self.num_maskmem > 0: + # Mark new objects as conditioning (clean detection masks) so model trusts them + cond_mask = torch.zeros(N_obj, dtype=torch.bool, device=new_masks.device) + cond_mask[n_old:] = True + self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, + mux_state, new_masks.device, cond_obj_mask=cond_mask) + + def _condition_with_masks(self, masks, frame_idx, vision_feats, vision_pos, feat_sizes, + high_res_prop, output_dict, N, mux_state, backbone_obj, frame, + trunk_out, threshold=0.5): + """Condition tracker with masks on a frame.""" + mask_input = F.interpolate(masks if masks.dim() == 4 else masks.unsqueeze(1), + size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + mask_input = (mask_input > threshold).to(masks.dtype) + hi_res = lo_feat = None + if backbone_obj is not None and backbone_obj.multiplex: + _, _, itf, _ = backbone_obj(frame, tracker_mode="interactive", cached_trunk=trunk_out, tracker_only=True) + hi_res, lo_feat = itf[:-1], itf[-1] + current_out = self.track_step( + frame_idx=frame_idx, is_init_cond_frame=True, current_vision_feats=vision_feats, + current_vision_pos_embeds=vision_pos, feat_sizes=feat_sizes, mask_inputs=mask_input, + output_dict=output_dict, num_frames=N, interactive_high_res=hi_res, + interactive_backbone=lo_feat, propagation_high_res=high_res_prop, + multiplex_state=mux_state, run_mem_encoder=True) + output_dict["cond_frame_outputs"][frame_idx] = current_out + return current_out + + def _match_and_add_detections(self, det_masks, det_scores, current_out, mux_state, + vision_feats, feat_sizes, device, max_objects=0, + keep_alive=None): + """Match detections against tracked masks, add new objects, recondition degraded tracks. + Updates keep_alive counters: +1 for matched tracks, -1 for unmatched.""" + N_obj = mux_state.total_valid_entries + if det_masks.shape[0] == 0: + if keep_alive is not None: + for i in range(N_obj): + keep_alive[i] = max(-4, keep_alive.get(i, 0) - 1) + return [] + + # Match at low-res (like reference) + trk_masks = current_out["pred_masks"][:, 0] # [N_obj, H_low, W_low] + det_resized = F.interpolate(det_masks.unsqueeze(1), size=trk_masks.shape[-2:], + mode="bilinear", align_corners=False)[:, 0] + overlap = _compute_mask_overlap(det_resized, trk_masks) + + # Update keep_alive and find matched tracks + matched = set() + if overlap.shape[1] > 0: + matched = set((overlap >= 0.5).any(dim=0).nonzero(as_tuple=True)[0].tolist()) + if keep_alive is not None: + for i in range(N_obj): + if i in matched: + keep_alive[i] = min(8, keep_alive.get(i, 0) + 1) + else: + keep_alive[i] = max(-4, keep_alive.get(i, 0) - 1) + + # Recondition: high-confidence detections (>=0.8) with high overlap refresh tracked masks + reconditioned = False + if det_scores is not None and overlap.shape[1] > 0: + HIGH_CONF = 0.8 + for det_idx in range(overlap.shape[0]): + if det_scores[det_idx] < HIGH_CONF: + continue + best_trk = overlap[det_idx].argmax().item() + if overlap[det_idx, best_trk] >= 0.5: + # Replace tracked mask with fresh detection mask + current_out["pred_masks"][best_trk] = det_resized[det_idx].unsqueeze(0) + det_hr = F.interpolate(det_masks[det_idx:det_idx+1].unsqueeze(1), + size=current_out["pred_masks_high_res"].shape[-2:], + mode="bilinear", align_corners=False) + current_out["pred_masks_high_res"][best_trk] = det_hr[0] + reconditioned = True + + # Re-encode memory if any tracks were reconditioned + if reconditioned and self.num_maskmem > 0: + self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, mux_state, device) + + # Add new detections (not matching any track) + if max_objects > 0 and N_obj >= max_objects: + return [] + max_overlap = overlap.max(dim=1)[0] if overlap.shape[1] > 0 else torch.zeros(overlap.shape[0], device=device) + new_dets = max_overlap < 0.5 + if new_dets.any(): + if max_objects > 0: + slots = max_objects - N_obj + new_dets = new_dets & (torch.cumsum(new_dets.int(), 0) <= slots) + self._add_detected_objects(det_masks[new_dets], mux_state, + vision_feats, feat_sizes, current_out) + if keep_alive is not None: + for i in range(N_obj, mux_state.total_valid_entries): + keep_alive[i] = 1 + return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item() + return [] + + def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None, + new_det_thresh=0.5, max_objects=0, detect_interval=1, + backbone_obj=None, pbar=None): + """Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits.""" + N, device, dt = images.shape[0], images.device, images.dtype + output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}} + all_masks = [] + idev = comfy.model_management.intermediate_device() + mux_state = None + if initial_masks is not None: + mux_state = MultiplexState(initial_masks.shape[0], self.num_multiplex, device, dt) + obj_scores = [] # per-object detection score (1.0 for initial masks) + keep_alive = {} if detect_fn is not None else None + last_occluded = torch.empty(0, device=device, dtype=torch.long) # per-object last occluded frame + + # Prefetch next frame's backbone on a separate CUDA stream + prefetch = False + backbone_stream = None + if comfy.model_management.is_device_cuda(device): + try: + backbone_stream = torch.cuda.Stream(device=device) + prefetch = True + except RuntimeError: + pass + cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0) + + for frame_idx in tqdm(range(N), desc="tracking"): + vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb + + # Start next frame's backbone on separate stream (overlaps with current frame's work) + if prefetch and frame_idx + 1 < N: + backbone_stream.wait_stream(torch.cuda.current_stream(device)) + with torch.cuda.stream(backbone_stream): + next_bb = self._compute_backbone_frame( + backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + + # Per-frame detection with NMS (skip if no detect_fn, or interval/max not met) + det_masks = torch.empty(0, device=device) + det_scores = None + run_det = (detect_fn is not None + and frame_idx % max(detect_interval, 1) == 0 + and not (max_objects > 0 and mux_state is not None + and mux_state.total_valid_entries >= max_objects)) + if run_det: + det_out = detect_fn(trunk_out) + scores = det_out["scores"][0].sigmoid() + keep = scores > new_det_thresh + det_masks, det_scores = det_out["masks"][0][keep], scores[keep] + if det_masks.shape[0] > 1: + det_masks, det_scores = _nms_masks(det_masks, det_scores) + + if frame_idx == 0 and initial_masks is not None: + current_out = self._condition_with_masks( + initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos, + feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj, + images[frame_idx:frame_idx + 1], trunk_out) + last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long) + obj_scores = [1.0] * mux_state.total_valid_entries + if keep_alive is not None: + for i in range(mux_state.total_valid_entries): + keep_alive[i] = 8 + elif mux_state is None or mux_state.total_valid_entries == 0: + if det_masks.shape[0] > 0: + if max_objects > 0: + det_scores = det_scores[:max_objects] + det_masks = det_masks[:max_objects] + mux_state = MultiplexState(det_masks.shape[0], self.num_multiplex, device, dt) + current_out = self._condition_with_masks( + det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop, + output_dict, N, mux_state, backbone_obj, + images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0) + last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long) + obj_scores = det_scores[:mux_state.total_valid_entries].tolist() + if keep_alive is not None: + for i in range(mux_state.total_valid_entries): + keep_alive[i] = 1 + else: + all_masks.append(None) + if pbar is not None: + pbar.update(1) + # Skip to backbone advance at end of loop + if frame_idx + 1 < N: + if prefetch: + torch.cuda.current_stream(device).wait_stream(backbone_stream) + cur_bb = next_bb + else: + cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + continue + else: + N_obj = mux_state.total_valid_entries + current_out = self.track_step( + frame_idx=frame_idx, is_init_cond_frame=False, current_vision_feats=vision_feats, + current_vision_pos_embeds=vision_pos, feat_sizes=feat_sizes, mask_inputs=None, + output_dict=output_dict, num_frames=N, propagation_high_res=high_res_prop, + multiplex_state=mux_state, run_mem_encoder=False) + current_out["pred_masks"] = fill_holes_in_mask_scores( + current_out["pred_masks"], max_area=16) + if last_occluded.shape[0] == N_obj and N_obj > 1: + self._suppress_recently_occluded( + current_out["pred_masks"], last_occluded, frame_idx) + if self.num_maskmem > 0: + self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, mux_state, device) + output_dict["non_cond_frame_outputs"][frame_idx] = current_out + lookback = max(self.num_maskmem, self.max_obj_ptrs_in_encoder) + for old_idx in list(output_dict["non_cond_frame_outputs"]): + if old_idx < frame_idx - lookback: + del output_dict["non_cond_frame_outputs"][old_idx] + n_before = mux_state.total_valid_entries + new_obj_scores = self._match_and_add_detections(det_masks, det_scores, current_out, mux_state, + vision_feats, feat_sizes, device, max_objects, + keep_alive if run_det else None) + n_added = mux_state.total_valid_entries - n_before + if n_added > 0: + last_occluded = torch.cat([last_occluded, + torch.full((n_added,), -1, device=device, dtype=torch.long)]) + obj_scores.extend(new_obj_scores) + + masks_out = current_out["pred_masks_high_res"][:, 0] + if keep_alive is not None: + for i in range(masks_out.shape[0]): + if keep_alive.get(i, 0) <= 0: + masks_out[i] = NO_OBJ_SCORE + N_obj_now = mux_state.total_valid_entries if mux_state is not None else 0 + if N_obj_now > 0: + all_masks.append(pack_masks(masks_out).to(idev)) + else: + all_masks.append(None) + if pbar is not None: + pbar.update(1) + + # Next frame's backbone + if frame_idx + 1 < N: + if prefetch: + torch.cuda.current_stream(device).wait_stream(backbone_stream) + cur_bb = next_bb + else: + cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + + if not all_masks or all(m is None for m in all_masks): + return {"packed_masks": None, "n_frames": N, "scores": []} + + max_obj = max(m.shape[0] for m in all_masks if m is not None) + sample = next(m for m in all_masks if m is not None) + empty_packed = torch.zeros(max_obj, *sample.shape[1:], dtype=torch.uint8, device=sample.device) + for i, m in enumerate(all_masks): + if m is None: + all_masks[i] = empty_packed + elif m.shape[0] < max_obj: + pad = torch.zeros(max_obj - m.shape[0], *m.shape[1:], dtype=torch.uint8, device=m.device) + all_masks[i] = torch.cat([m, pad], dim=0) + return {"packed_masks": torch.stack(all_masks, dim=0), "n_frames": N, "scores": obj_scores} diff --git a/comfy/model_base.py b/comfy/model_base.py index 5c2668ba9312..787ea11452e8 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -54,6 +54,7 @@ import comfy.ldm.ace.ace_step15 import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.ernie.model +import comfy.ldm.sam3.detector import comfy.model_management import comfy.patcher_extension @@ -578,8 +579,8 @@ class Stable_Zero123(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): super().__init__(model_config, model_type, device=device) self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device) - self.cc_projection.weight.copy_(cc_projection_weight) - self.cc_projection.bias.copy_(cc_projection_bias) + self.cc_projection.weight = torch.nn.Parameter(cc_projection_weight.clone()) + self.cc_projection.bias = torch.nn.Parameter(cc_projection_bias.clone()) def extra_conds(self, **kwargs): out = {} @@ -1974,3 +1975,7 @@ def extra_conds(self, **kwargs): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out + +class SAM3(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ca06cdd1e5ce..724a241bfde0 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -718,6 +718,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "ernie" return dit_config + if 'detector.backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight' in state_dict_keys: # SAM3 / SAM3.1 + if 'detector.transformer.decoder.query_embed.weight' in state_dict_keys: + dit_config = {} + dit_config["image_model"] = "SAM3" + if 'detector.backbone.vision_backbone.propagation_convs.0.conv_1x1.weight' in state_dict_keys: + dit_config["image_model"] = "SAM31" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -873,6 +881,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal return model_config def unet_prefix_from_state_dict(state_dict): + # SAM3: detector.* and tracker.* at top level, no common prefix + if any(k.startswith("detector.") for k in state_dict) and any(k.startswith("tracker.") for k in state_dict): + return "" + candidates = ["model.diffusion_model.", #ldm/sgm models "model.model.", #audio models "net.", #cosmos diff --git a/comfy/model_management.py b/comfy/model_management.py index c7f6c4e6adc0..1a4e2f2abaa2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1922,7 +1922,7 @@ def debug_memory_summary(): return torch.cuda.memory.memory_summary() return "" -class InterruptProcessingException(Exception): +class InterruptProcessingException(BaseException): pass interrupt_processing_mutex = threading.RLock() diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 092bc6a79b9c..8119b4ab3aca 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -774,9 +774,9 @@ def model_state_dict(self, filter_prefix=None): sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False, force_cast=False): weight, set_func, convert_func = get_key_weight(self.model, key) - if key not in self.patches: + if key not in self.patches and not force_cast: return weight inplace_update = self.weight_inplace_update or inplace_update @@ -784,7 +784,7 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False, retu if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) - temp_dtype = comfy.model_management.lora_compute_dtype(device_to) + temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if key in self.patches else None if device_to is not None: temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True) else: @@ -792,9 +792,10 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False, retu if convert_func is not None: temp_weight = convert_func(temp_weight, inplace=True) - out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) + out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if key in self.patches else temp_weight if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) + if key in self.patches: + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) if return_weight: return out_weight elif inplace_update: @@ -1705,7 +1706,7 @@ def force_load_param(self, param_key, device_to): key = key_param_name_to_key(n, param_key) if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) - self.patch_weight_to_device(key, device_to=device_to) + self.patch_weight_to_device(key, device_to=device_to, force_cast=True) weight, _, _ = get_key_weight(self.model, key) if weight is not None: self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() @@ -1730,6 +1731,10 @@ def force_load_param(self, param_key, device_to): m._v = vbar.alloc(v_weight_size) allocated_size += v_weight_size + for param in params: + if param not in ("weight", "bias"): + force_load_param(self, param, device_to) + else: for param in params: key = key_param_name_to_key(n, param) diff --git a/comfy/sd.py b/comfy/sd.py index 2417ac12137a..ac70abcf5a90 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -12,6 +12,7 @@ from .ldm.audio.autoencoder import AudioOobleckVAE import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder +import comfy.ldm.lightricks.vae.audio_vae import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 @@ -814,6 +815,24 @@ def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2): self.downscale_index_formula = (4, 8, 8) self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) + elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio + sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder."}) + self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata) + self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype) + self.latent_channels = self.first_stage_model.latent_channels + self.audio_sample_rate_output = self.first_stage_model.output_sample_rate + self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes + self.output_channels = 2 + self.pad_channel_value = "replicate" + self.upscale_ratio = 4096 + self.downscale_ratio = 4096 + self.latent_dim = 2 + self.process_output = lambda audio: audio + self.process_input = lambda audio: audio + self.working_dtypes = [torch.float32] + self.disable_offload = True + self.extra_1d_channel = 16 else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 58d4ce731062..8886f32d5b1c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1781,6 +1781,57 @@ def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage] +class SAM3(supported_models_base.BASE): + unet_config = {"image_model": "SAM3"} + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + text_encoder_key_prefix = ["detector.backbone.language_backbone."] + unet_extra_prefix = "" + + def process_clip_state_dict(self, state_dict): + clip_keys = getattr(self, "_clip_stash", {}) + clip_keys = utils.state_dict_prefix_replace(clip_keys, {"detector.backbone.language_backbone.": "", "backbone.language_backbone.": ""}, filter_keys=True) + clip_keys = utils.clip_text_transformers_convert(clip_keys, "encoder.", "sam3_clip.transformer.") + return {k: v for k, v in clip_keys.items() if not k.startswith("resizer.")} + + def process_unet_state_dict(self, state_dict): + self._clip_stash = {k: state_dict.pop(k) for k in list(state_dict.keys()) if "language_backbone" in k and "resizer" not in k} + # SAM3.1: remap tracker.model.* -> tracker.* + for k in list(state_dict.keys()): + if k.startswith("tracker.model."): + state_dict["tracker." + k[len("tracker.model."):]] = state_dict.pop(k) + # SAM3.1: remove per-block freqs_cis buffers (computed dynamically) + for k in [k for k in list(state_dict.keys()) if ".attn.freqs_cis" in k]: + state_dict.pop(k) + # Split fused QKV projections + for k in [k for k in list(state_dict.keys()) if k.endswith((".in_proj_weight", ".in_proj_bias"))]: + t = state_dict.pop(k) + base, suffix = k.rsplit(".in_proj_", 1) + s = ".weight" if suffix == "weight" else ".bias" + d = t.shape[0] // 3 + state_dict[base + ".q_proj" + s] = t[:d] + state_dict[base + ".k_proj" + s] = t[d:2*d] + state_dict[base + ".v_proj" + s] = t[2*d:] + # Remap tracker SAM decoder transformer key names to match sam.py TwoWayTransformer + for k in list(state_dict.keys()): + if "sam_mask_decoder.transformer." not in k: + continue + new_k = k.replace(".mlp.lin1.", ".mlp.0.").replace(".mlp.lin2.", ".mlp.2.").replace(".norm_final_attn.", ".norm_final.") + if new_k != k: + state_dict[new_k] = state_dict.pop(k) + return state_dict + + def get_model(self, state_dict, prefix="", device=None): + return model_base.SAM3(self, device=device) + + def clip_target(self, state_dict={}): + import comfy.text_encoders.sam3_clip + return supported_models_base.ClipTarget(comfy.text_encoders.sam3_clip.SAM3TokenizerWrapper, comfy.text_encoders.sam3_clip.SAM3ClipModelWrapper) + + +class SAM31(SAM3): + unet_config = {"image_model": "SAM31"} + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31] models += [SVD_img2vid] diff --git a/comfy/text_encoders/sam3_clip.py b/comfy/text_encoders/sam3_clip.py new file mode 100644 index 000000000000..11cb7d9dbff5 --- /dev/null +++ b/comfy/text_encoders/sam3_clip.py @@ -0,0 +1,97 @@ +import re +from comfy import sd1_clip + +SAM3_CLIP_CONFIG = { + "architectures": ["CLIPTextModel"], + "hidden_act": "quick_gelu", + "hidden_size": 1024, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "max_position_embeddings": 32, + "projection_dim": 512, + "vocab_size": 49408, + "layer_norm_eps": 1e-5, + "eos_token_id": 49407, +} + + +class SAM3ClipModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, max_length=32, layer="last", textmodel_json_config=SAM3_CLIP_CONFIG, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=False, return_attention_masks=True, enable_attention_masks=True, model_options=model_options) + + +class SAM3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(max_length=32, pad_with_end=False, pad_token=0, embedding_directory=embedding_directory, embedding_size=1024, embedding_key="sam3_clip", tokenizer_data=tokenizer_data) + self.disable_weights = True + + +def _parse_prompts(text): + """Split comma-separated prompts with optional :N max detections per category""" + text = text.replace("(", "").replace(")", "") + parts = [p.strip() for p in text.split(",") if p.strip()] + result = [] + for part in parts: + m = re.match(r'^(.+?)\s*:\s*([\d.]+)\s*$', part) + if m: + text_part = m.group(1).strip() + val = m.group(2) + max_det = max(1, round(float(val))) + result.append((text_part, max_det)) + else: + result.append((part, 1)) + return result + + +class SAM3TokenizerWrapper(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="l", tokenizer=SAM3Tokenizer, name="sam3_clip") + + def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs): + parsed = _parse_prompts(text) + if len(parsed) <= 1 and (not parsed or parsed[0][1] == 1): + return super().tokenize_with_weights(text, return_word_ids, **kwargs) + # Tokenize each prompt part separately, store per-part batches and metadata + inner = getattr(self, self.clip) + per_prompt = [] + for prompt_text, max_det in parsed: + batches = inner.tokenize_with_weights(prompt_text, return_word_ids, **kwargs) + per_prompt.append((batches, max_det)) + # Main output uses first prompt's tokens (for compatibility) + out = {self.clip_name: per_prompt[0][0], "sam3_per_prompt": per_prompt} + return out + + +class SAM3ClipModelWrapper(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): + super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="l", clip_model=SAM3ClipModel, name="sam3_clip") + + def encode_token_weights(self, token_weight_pairs): + per_prompt = token_weight_pairs.pop("sam3_per_prompt", None) + if per_prompt is None: + return super().encode_token_weights(token_weight_pairs) + + # Encode each prompt separately, pack into extra dict + inner = getattr(self, self.clip) + multi_cond = [] + first_pooled = None + for batches, max_det in per_prompt: + out = inner.encode_token_weights(batches) + cond, pooled = out[0], out[1] + extra = out[2] if len(out) > 2 else {} + if first_pooled is None: + first_pooled = pooled + multi_cond.append({ + "cond": cond, + "attention_mask": extra.get("attention_mask"), + "max_detections": max_det, + }) + + # Return first prompt as main (for non-SAM3 consumers), all prompts in metadata + main = multi_cond[0] + main_extra = {} + if main["attention_mask"] is not None: + main_extra["attention_mask"] = main["attention_mask"] + main_extra["sam3_multi_cond"] = multi_cond + return (main["cond"], first_pooled, main_extra) diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index 3755323acbc8..eafabbefef9f 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -122,6 +122,41 @@ class TaskStatusResponse(BaseModel): usage: TaskStatusUsage | None = Field(None) +class GetAssetResponse(BaseModel): + id: str = Field(...) + name: str | None = Field(None) + url: str | None = Field(None) + asset_type: str = Field(...) + group_id: str = Field(...) + status: str = Field(...) + error: TaskStatusError | None = Field(None) + + +class SeedanceCreateVisualValidateSessionResponse(BaseModel): + session_id: str = Field(...) + h5_link: str = Field(...) + + +class SeedanceGetVisualValidateSessionResponse(BaseModel): + session_id: str = Field(...) + status: str = Field(...) + group_id: str | None = Field(None) + error_code: str | None = Field(None) + error_message: str | None = Field(None) + + +class SeedanceCreateAssetRequest(BaseModel): + group_id: str = Field(...) + url: str = Field(...) + asset_type: str = Field(...) + name: str | None = Field(None, max_length=64) + project_name: str | None = Field(None) + + +class SeedanceCreateAssetResponse(BaseModel): + asset_id: str = Field(...) + + # Dollars per 1K tokens, keyed by (model_id, has_video_input). SEEDANCE2_PRICE_PER_1K_TOKENS = { ("dreamina-seedance-2-0-260128", False): 0.007, @@ -158,10 +193,17 @@ class TaskStatusResponse(BaseModel): ("Custom", None, None), ] -# Seedance 2.0 reference video pixel count limits per model. +# Seedance 2.0 reference video pixel count limits per model and output resolution. SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = { - "dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408}, - "dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408}, + "dreamina-seedance-2-0-260128": { + "480p": {"min": 409_600, "max": 927_408}, + "720p": {"min": 409_600, "max": 927_408}, + "1080p": {"min": 409_600, "max": 2_073_600}, + }, + "dreamina-seedance-2-0-fast-260128": { + "480p": {"min": 409_600, "max": 927_408}, + "720p": {"min": 409_600, "max": 927_408}, + }, } # The time in this dictionary are given for 10 seconds duration. diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 429c3244460b..de192c5ac2a2 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,5 +1,6 @@ import logging import math +import re import torch from typing_extensions import override @@ -11,9 +12,14 @@ SEEDANCE2_PRICE_PER_1K_TOKENS, SEEDANCE2_REF_VIDEO_PIXEL_LIMITS, VIDEO_TASKS_EXECUTION_TIME, + GetAssetResponse, Image2VideoTaskCreationRequest, ImageTaskCreationResponse, Seedance2TaskCreationRequest, + SeedanceCreateAssetRequest, + SeedanceCreateAssetResponse, + SeedanceCreateVisualValidateSessionResponse, + SeedanceGetVisualValidateSessionResponse, Seedream4Options, Seedream4TaskCreationRequest, TaskAudioContent, @@ -35,6 +41,7 @@ get_number_of_images, image_tensor_pair_to_batch, poll_op, + resize_video_to_pixel_budget, sync_op, upload_audio_to_comfyapi, upload_image_to_comfyapi, @@ -43,10 +50,16 @@ validate_image_aspect_ratio, validate_image_dimensions, validate_string, + validate_video_dimensions, + validate_video_duration, ) +from server import PromptServer BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" +_VERIFICATION_POLL_TIMEOUT_SEC = 120 +_VERIFICATION_POLL_INTERVAL_SEC = 3 + SEEDREAM_MODELS = { "seedream 5.0 lite": "seedream-5-0-260128", "seedream-4-5-251128": "seedream-4-5-251128", @@ -69,9 +82,12 @@ logger = logging.getLogger(__name__) -def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None: - """Validate reference video pixel count against Seedance 2.0 model limits.""" - limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id) +def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: str, index: int) -> None: + """Validate reference video pixel count against Seedance 2.0 model limits for the selected resolution.""" + model_limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id) + if not model_limits: + return + limits = model_limits.get(resolution) if not limits: return try: @@ -92,6 +108,169 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> ) +async def _resolve_reference_assets( + cls: type[IO.ComfyNode], + asset_ids: list[str], +) -> tuple[dict[str, str], dict[str, str], dict[str, str]]: + """Look up each asset, validate Active status, group by asset_type. + + Returns (image_assets, video_assets, audio_assets), each mapping asset_id -> "asset://". + """ + image_assets: dict[str, str] = {} + video_assets: dict[str, str] = {} + audio_assets: dict[str, str] = {} + for i, raw_id in enumerate(asset_ids, 1): + asset_id = (raw_id or "").strip() + if not asset_id: + continue + result = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/seedance/assets/{asset_id}"), + response_model=GetAssetResponse, + ) + if result.status != "Active": + extra = f" {result.error.code}: {result.error.message}" if result.error else "" + raise ValueError(f"Reference asset {i} (Id={asset_id}) is not Active (Status={result.status}).{extra}") + asset_uri = f"asset://{asset_id}" + if result.asset_type == "Image": + image_assets[asset_id] = asset_uri + elif result.asset_type == "Video": + video_assets[asset_id] = asset_uri + elif result.asset_type == "Audio": + audio_assets[asset_id] = asset_uri + return image_assets, video_assets, audio_assets + + +_ASSET_REF_RE = re.compile(r"\basset ?(\d{1,2})\b", re.IGNORECASE) + + +def _build_asset_labels( + reference_assets: dict[str, str], + image_asset_uris: dict[str, str], + video_asset_uris: dict[str, str], + audio_asset_uris: dict[str, str], + n_reference_images: int, + n_reference_videos: int, + n_reference_audios: int, +) -> dict[int, str]: + """Map asset slot number (from 'asset_N' keys) to its positional label. + + Asset entries are appended to `content` after the reference_images/videos/audios, + so their 1-indexed labels continue from the count of existing same-type refs: + one reference_images entry + one Image-type asset -> asset labelled "Image 2". + """ + image_n = n_reference_images + video_n = n_reference_videos + audio_n = n_reference_audios + labels: dict[int, str] = {} + for slot_key, raw_id in reference_assets.items(): + asset_id = (raw_id or "").strip() + if not asset_id: + continue + try: + slot_num = int(slot_key.rsplit("_", 1)[-1]) + except ValueError: + continue + if asset_id in image_asset_uris: + image_n += 1 + labels[slot_num] = f"Image {image_n}" + elif asset_id in video_asset_uris: + video_n += 1 + labels[slot_num] = f"Video {video_n}" + elif asset_id in audio_asset_uris: + audio_n += 1 + labels[slot_num] = f"Audio {audio_n}" + return labels + + +def _rewrite_asset_refs(prompt: str, labels: dict[int, str]) -> str: + """Case-insensitively replace 'assetNN' (1-2 digit) tokens with their labels.""" + if not labels: + return prompt + + def _sub(m: "re.Match[str]") -> str: + return labels.get(int(m.group(1)), m.group(0)) + + return _ASSET_REF_RE.sub(_sub, prompt) + + +async def _obtain_group_id_via_h5_auth(cls: type[IO.ComfyNode]) -> str: + session = await sync_op( + cls, + ApiEndpoint(path="/proxy/seedance/visual-validate/sessions", method="POST"), + response_model=SeedanceCreateVisualValidateSessionResponse, + ) + logger.warning("Seedance authentication required. Open link: %s", session.h5_link) + + h5_text = f"Open this link in your browser and complete face verification:\n\n{session.h5_link}" + + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/seedance/visual-validate/sessions/{session.session_id}"), + response_model=SeedanceGetVisualValidateSessionResponse, + status_extractor=lambda r: r.status, + completed_statuses=["completed"], + failed_statuses=["failed"], + poll_interval=_VERIFICATION_POLL_INTERVAL_SEC, + max_poll_attempts=(_VERIFICATION_POLL_TIMEOUT_SEC // _VERIFICATION_POLL_INTERVAL_SEC) - 1, + estimated_duration=_VERIFICATION_POLL_TIMEOUT_SEC - 1, + extra_text=h5_text, + ) + + if not result.group_id: + raise RuntimeError(f"Seedance session {session.session_id} completed without a group_id") + + logger.warning("Seedance authentication complete. New GroupId: %s", result.group_id) + PromptServer.instance.send_progress_text( + f"Authentication complete. New GroupId: {result.group_id}", cls.hidden.unique_id + ) + return result.group_id + + +async def _resolve_group_id(cls: type[IO.ComfyNode], group_id: str) -> str: + if group_id and group_id.strip(): + return group_id.strip() + return await _obtain_group_id_via_h5_auth(cls) + + +async def _create_seedance_asset( + cls: type[IO.ComfyNode], + *, + group_id: str, + url: str, + name: str, + asset_type: str, +) -> str: + req = SeedanceCreateAssetRequest( + group_id=group_id, + url=url, + asset_type=asset_type, + name=name or None, + ) + result = await sync_op( + cls, + ApiEndpoint(path="/proxy/seedance/assets", method="POST"), + response_model=SeedanceCreateAssetResponse, + data=req, + ) + return result.asset_id + + +async def _wait_for_asset_active(cls: type[IO.ComfyNode], asset_id: str, group_id: str) -> GetAssetResponse: + """Poll the newly created asset until its status becomes Active.""" + return await poll_op( + cls, + ApiEndpoint(path=f"/proxy/seedance/assets/{asset_id}"), + response_model=GetAssetResponse, + status_extractor=lambda r: r.status, + completed_statuses=["Active"], + failed_statuses=["Failed"], + poll_interval=5, + max_poll_attempts=1200, + extra_text=f"Waiting for asset pre-processing...\n\nasset_id: {asset_id}\n\ngroup_id: {group_id}", + ) + + def _seedance2_price_extractor(model_id: str, has_video_input: bool): """Returns a price_extractor closure for Seedance 2.0 poll_op.""" rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input)) @@ -1224,12 +1403,27 @@ def define_schema(cls): IO.Image.Input( "first_frame", tooltip="First frame image for the video.", + optional=True, ), IO.Image.Input( "last_frame", tooltip="Last frame image for the video.", optional=True, ), + IO.String.Input( + "first_frame_asset_id", + default="", + tooltip="Seedance asset_id to use as the first frame. " + "Mutually exclusive with the first_frame image input.", + optional=True, + ), + IO.String.Input( + "last_frame_asset_id", + default="", + tooltip="Seedance asset_id to use as the last frame. " + "Mutually exclusive with the last_frame image input.", + optional=True, + ), IO.Int.Input( "seed", default=0, @@ -1282,24 +1476,54 @@ def define_schema(cls): async def execute( cls, model: dict, - first_frame: Input.Image, seed: int, watermark: bool, + first_frame: Input.Image | None = None, last_frame: Input.Image | None = None, + first_frame_asset_id: str = "", + last_frame_asset_id: str = "", ) -> IO.NodeOutput: validate_string(model["prompt"], strip_whitespace=True, min_length=1) model_id = SEEDANCE_MODELS[model["model"]] + first_frame_asset_id = first_frame_asset_id.strip() + last_frame_asset_id = last_frame_asset_id.strip() + + if first_frame is not None and first_frame_asset_id: + raise ValueError("Provide only one of first_frame or first_frame_asset_id, not both.") + if first_frame is None and not first_frame_asset_id: + raise ValueError("Either first_frame or first_frame_asset_id is required.") + if last_frame is not None and last_frame_asset_id: + raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.") + + asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a] + image_assets: dict[str, str] = {} + if asset_ids_to_resolve: + image_assets, _, _ = await _resolve_reference_assets(cls, asset_ids_to_resolve) + for aid in asset_ids_to_resolve: + if aid not in image_assets: + raise ValueError(f"Asset {aid} is not an Image asset.") + + if first_frame_asset_id: + first_frame_url = image_assets[first_frame_asset_id] + else: + first_frame_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.") + content: list[TaskTextContent | TaskImageContent] = [ TaskTextContent(text=model["prompt"]), TaskImageContent( - image_url=TaskImageContentUrl( - url=await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.") - ), + image_url=TaskImageContentUrl(url=first_frame_url), role="first_frame", ), ] - if last_frame is not None: + if last_frame_asset_id: + content.append( + TaskImageContent( + image_url=TaskImageContentUrl(url=image_assets[last_frame_asset_id]), + role="last_frame", + ), + ) + elif last_frame is not None: content.append( TaskImageContent( image_url=TaskImageContentUrl( @@ -1373,6 +1597,32 @@ def _seedance2_reference_inputs(resolutions: list[str]): min=0, ), ), + IO.Boolean.Input( + "auto_downscale", + default=False, + advanced=True, + optional=True, + tooltip="Automatically downscale reference videos that exceed the model's pixel budget " + "for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.", + ), + IO.Autogrow.Input( + "reference_assets", + template=IO.Autogrow.TemplateNames( + IO.String.Input("reference_asset"), + names=[ + "asset_1", + "asset_2", + "asset_3", + "asset_4", + "asset_5", + "asset_6", + "asset_7", + "asset_8", + "asset_9", + ], + min=0, + ), + ), ] @@ -1474,16 +1724,47 @@ async def execute( reference_images = model.get("reference_images", {}) reference_videos = model.get("reference_videos", {}) reference_audios = model.get("reference_audios", {}) + reference_assets = model.get("reference_assets", {}) + + reference_image_assets, reference_video_assets, reference_audio_assets = await _resolve_reference_assets( + cls, list(reference_assets.values()) + ) - if not reference_images and not reference_videos: - raise ValueError("At least one reference image or video is required.") + if not reference_images and not reference_videos and not reference_image_assets and not reference_video_assets: + raise ValueError("At least one reference image or video or asset is required.") + + total_images = len(reference_images) + len(reference_image_assets) + if total_images > 9: + raise ValueError( + f"Too many reference images: {total_images} " + f"(images={len(reference_images)}, image assets={len(reference_image_assets)}). Maximum is 9." + ) + total_videos = len(reference_videos) + len(reference_video_assets) + if total_videos > 3: + raise ValueError( + f"Too many reference videos: {total_videos} " + f"(videos={len(reference_videos)}, video assets={len(reference_video_assets)}). Maximum is 3." + ) + total_audios = len(reference_audios) + len(reference_audio_assets) + if total_audios > 3: + raise ValueError( + f"Too many reference audios: {total_audios} " + f"(audios={len(reference_audios)}, audio assets={len(reference_audio_assets)}). Maximum is 3." + ) model_id = SEEDANCE_MODELS[model["model"]] - has_video_input = len(reference_videos) > 0 + has_video_input = total_videos > 0 + + if model.get("auto_downscale") and reference_videos: + max_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("max") + if max_px: + for key in reference_videos: + reference_videos[key] = resize_video_to_pixel_budget(reference_videos[key], max_px) + total_video_duration = 0.0 for i, key in enumerate(reference_videos, 1): video = reference_videos[key] - _validate_ref_video_pixels(video, model_id, i) + _validate_ref_video_pixels(video, model_id, model["resolution"], i) try: dur = video.get_duration() if dur < 1.8: @@ -1506,8 +1787,19 @@ async def execute( if total_audio_duration > 15.1: raise ValueError(f"Total reference audio duration is {total_audio_duration:.1f}s. Maximum is 15.1 seconds.") + asset_labels = _build_asset_labels( + reference_assets, + reference_image_assets, + reference_video_assets, + reference_audio_assets, + len(reference_images), + len(reference_videos), + len(reference_audios), + ) + prompt_text = _rewrite_asset_refs(model["prompt"], asset_labels) + content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = [ - TaskTextContent(text=model["prompt"]), + TaskTextContent(text=prompt_text), ] for i, key in enumerate(reference_images, 1): content.append( @@ -1548,6 +1840,21 @@ async def execute( ), ), ) + for url in reference_image_assets.values(): + content.append( + TaskImageContent( + image_url=TaskImageContentUrl(url=url), + role="reference_image", + ), + ) + for url in reference_video_assets.values(): + content.append( + TaskVideoContent(video_url=TaskVideoContentUrl(url=url)), + ) + for url in reference_audio_assets.values(): + content.append( + TaskAudioContent(audio_url=TaskAudioContentUrl(url=url)), + ) initial_response = await sync_op( cls, ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), @@ -1602,6 +1909,156 @@ async def process_video_task( return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) +class ByteDanceCreateImageAsset(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ByteDanceCreateImageAsset", + display_name="ByteDance Create Image Asset", + category="api node/image/ByteDance", + description=( + "Create a Seedance 2.0 personal image asset. Uploads the input image and " + "registers it in the given asset group. If group_id is empty, runs a real-person " + "H5 authentication flow to create a new group before adding the asset." + ), + inputs=[ + IO.Image.Input("image", tooltip="Image to register as a personal asset."), + IO.String.Input( + "group_id", + default="", + tooltip="Reuse an existing Seedance asset group ID to skip repeated human verification for the " + "same person. Leave empty to run real-person authentication in the browser and create a new group.", + ), + # IO.String.Input( + # "name", + # default="", + # tooltip="Asset name (up to 64 characters).", + # ), + ], + outputs=[ + IO.String.Output(display_name="asset_id"), + IO.String.Output(display_name="group_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + # is_api_node=True, + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + group_id: str = "", + # name: str = "", + ) -> IO.NodeOutput: + # if len(name) > 64: + # raise ValueError("Name of asset can not be greater then 64 symbols") + validate_image_dimensions(image, min_width=300, max_width=6000, min_height=300, max_height=6000) + validate_image_aspect_ratio(image, min_ratio=(0.4, 1), max_ratio=(2.5, 1)) + resolved_group = await _resolve_group_id(cls, group_id) + asset_id = await _create_seedance_asset( + cls, + group_id=resolved_group, + url=await upload_image_to_comfyapi(cls, image), + name="", + asset_type="Image", + ) + await _wait_for_asset_active(cls, asset_id, resolved_group) + PromptServer.instance.send_progress_text( + f"Please save the asset_id and group_id for reuse.\n\nasset_id: {asset_id}\n\n" + f"group_id: {resolved_group}", + cls.hidden.unique_id, + ) + return IO.NodeOutput(asset_id, resolved_group) + + +class ByteDanceCreateVideoAsset(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ByteDanceCreateVideoAsset", + display_name="ByteDance Create Video Asset", + category="api node/video/ByteDance", + description=( + "Create a Seedance 2.0 personal video asset. Uploads the input video and " + "registers it in the given asset group. If group_id is empty, runs a real-person " + "H5 authentication flow to create a new group before adding the asset." + ), + inputs=[ + IO.Video.Input("video", tooltip="Video to register as a personal asset."), + IO.String.Input( + "group_id", + default="", + tooltip="Reuse an existing Seedance asset group ID to skip repeated human verification for the " + "same person. Leave empty to run real-person authentication in the browser and create a new group.", + ), + # IO.String.Input( + # "name", + # default="", + # tooltip="Asset name (up to 64 characters).", + # ), + ], + outputs=[ + IO.String.Output(display_name="asset_id"), + IO.String.Output(display_name="group_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + # is_api_node=True, + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + group_id: str = "", + # name: str = "", + ) -> IO.NodeOutput: + # if len(name) > 64: + # raise ValueError("Name of asset can not be greater then 64 symbols") + validate_video_duration(video, min_duration=2, max_duration=15) + validate_video_dimensions(video, min_width=300, max_width=6000, min_height=300, max_height=6000) + + w, h = video.get_dimensions() + if h > 0: + ratio = w / h + if not (0.4 <= ratio <= 2.5): + raise ValueError(f"Asset video aspect ratio (W/H) must be in [0.4, 2.5], got {ratio:.3f} ({w}x{h}).") + pixels = w * h + if not (409_600 <= pixels <= 927_408): + raise ValueError( + f"Asset video total pixels (W×H) must be in [409600, 927408], " f"got {pixels:,} ({w}x{h})." + ) + + fps = float(video.get_frame_rate()) + if not (24 <= fps <= 60): + raise ValueError(f"Asset video FPS must be in [24, 60], got {fps:.2f}.") + + resolved_group = await _resolve_group_id(cls, group_id) + asset_id = await _create_seedance_asset( + cls, + group_id=resolved_group, + url=await upload_video_to_comfyapi(cls, video), + name="", + asset_type="Video", + ) + await _wait_for_asset_active(cls, asset_id, resolved_group) + PromptServer.instance.send_progress_text( + f"Please save the asset_id and group_id for reuse.\n\nasset_id: {asset_id}\n\n" + f"group_id: {resolved_group}", + cls.hidden.unique_id, + ) + return IO.NodeOutput(asset_id, resolved_group) + + class ByteDanceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -1615,6 +2072,8 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: ByteDance2TextToVideoNode, ByteDance2FirstLastFrameNode, ByteDance2ReferenceNode, + ByteDanceCreateImageAsset, + ByteDanceCreateVideoAsset, ] diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9a37ccc53407..709b3726ca8a 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -276,6 +276,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), response_model=TaskStatusResponse, + max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -862,7 +863,7 @@ def define_schema(cls) -> IO.Schema: ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider), - IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True), IO.DynamicCombo.Input( "storyboards", options=[ @@ -904,12 +905,13 @@ def define_schema(cls) -> IO.Schema: depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), expr=""" ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $res := widgets.resolution; + $mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro"); $isV3 := $contains(widgets.model_name, "v3"); $audio := $isV3 and widgets.generate_audio; $rates := $audio - ? {"std": 0.112, "pro": 0.14} - : {"std": 0.084, "pro": 0.112}; + ? {"std": 0.112, "pro": 0.14, "4k": 0.42} + : {"std": 0.084, "pro": 0.112, "4k": 0.42}; {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} ) """, @@ -934,6 +936,8 @@ async def execute( raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.") if generate_audio: raise ValueError("kling-video-o1 does not support audio generation.") + if resolution == "4k": + raise ValueError("kling-video-o1 does not support 4k resolution.") stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" if stories_enabled and model_name == "kling-video-o1": raise ValueError("kling-video-o1 does not support storyboards.") @@ -963,6 +967,12 @@ async def execute( f"must equal the global duration ({duration}s)." ) + if resolution == "4k": + mode = "4k" + elif resolution == "1080p": + mode = "pro" + else: + mode = "std" response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), @@ -972,7 +982,7 @@ async def execute( prompt=prompt, aspect_ratio=aspect_ratio, duration=str(duration), - mode="pro" if resolution == "1080p" else "std", + mode=mode, multi_shot=multi_shot, multi_prompt=multi_prompt_list, shot_type="customize" if multi_shot else None, @@ -1014,7 +1024,7 @@ def define_schema(cls) -> IO.Schema: optional=True, tooltip="Up to 6 additional reference images.", ), - IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True), IO.DynamicCombo.Input( "storyboards", options=[ @@ -1061,12 +1071,13 @@ def define_schema(cls) -> IO.Schema: depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), expr=""" ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $res := widgets.resolution; + $mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro"); $isV3 := $contains(widgets.model_name, "v3"); $audio := $isV3 and widgets.generate_audio; $rates := $audio - ? {"std": 0.112, "pro": 0.14} - : {"std": 0.084, "pro": 0.112}; + ? {"std": 0.112, "pro": 0.14, "4k": 0.42} + : {"std": 0.084, "pro": 0.112, "4k": 0.42}; {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} ) """, @@ -1093,6 +1104,8 @@ async def execute( raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") if generate_audio: raise ValueError("kling-video-o1 does not support audio generation.") + if resolution == "4k": + raise ValueError("kling-video-o1 does not support 4k resolution.") stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" if stories_enabled and model_name == "kling-video-o1": raise ValueError("kling-video-o1 does not support storyboards.") @@ -1161,6 +1174,12 @@ async def execute( validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"): image_list.append(OmniParamImage(image_url=i)) + if resolution == "4k": + mode = "4k" + elif resolution == "1080p": + mode = "pro" + else: + mode = "std" response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), @@ -1170,7 +1189,7 @@ async def execute( prompt=prompt, duration=str(duration), image_list=image_list, - mode="pro" if resolution == "1080p" else "std", + mode=mode, sound="on" if generate_audio else "off", multi_shot=multi_shot, multi_prompt=multi_prompt_list, @@ -1204,7 +1223,7 @@ def define_schema(cls) -> IO.Schema: "reference_images", tooltip="Up to 7 reference images.", ), - IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True), IO.DynamicCombo.Input( "storyboards", options=[ @@ -1251,12 +1270,13 @@ def define_schema(cls) -> IO.Schema: depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), expr=""" ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $res := widgets.resolution; + $mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro"); $isV3 := $contains(widgets.model_name, "v3"); $audio := $isV3 and widgets.generate_audio; $rates := $audio - ? {"std": 0.112, "pro": 0.14} - : {"std": 0.084, "pro": 0.112}; + ? {"std": 0.112, "pro": 0.14, "4k": 0.42} + : {"std": 0.084, "pro": 0.112, "4k": 0.42}; {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} ) """, @@ -1282,6 +1302,8 @@ async def execute( raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") if generate_audio: raise ValueError("kling-video-o1 does not support audio generation.") + if resolution == "4k": + raise ValueError("kling-video-o1 does not support 4k resolution.") stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" if stories_enabled and model_name == "kling-video-o1": raise ValueError("kling-video-o1 does not support storyboards.") @@ -1320,6 +1342,12 @@ async def execute( image_list: list[OmniParamImage] = [] for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): image_list.append(OmniParamImage(image_url=i)) + if resolution == "4k": + mode = "4k" + elif resolution == "1080p": + mode = "pro" + else: + mode = "std" response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), @@ -1330,7 +1358,7 @@ async def execute( aspect_ratio=aspect_ratio, duration=str(duration), image_list=image_list, - mode="pro" if resolution == "1080p" else "std", + mode=mode, sound="on" if generate_audio else "off", multi_shot=multi_shot, multi_prompt=multi_prompt_list, @@ -2860,7 +2888,7 @@ def define_schema(cls) -> IO.Schema: IO.DynamicCombo.Option( "kling-v3", [ - IO.Combo.Input("resolution", options=["1080p", "720p"]), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"), IO.Combo.Input( "aspect_ratio", options=["16:9", "9:16", "1:1"], @@ -2913,7 +2941,11 @@ def define_schema(cls) -> IO.Schema: ), expr=""" ( - $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; + $rates := { + "4k": {"off": 0.42, "on": 0.42}, + "1080p": {"off": 0.112, "on": 0.168}, + "720p": {"off": 0.084, "on": 0.126} + }; $res := $lookup(widgets, "model.resolution"); $audio := widgets.generate_audio ? "on" : "off"; $rate := $lookup($lookup($rates, $res), $audio); @@ -2943,7 +2975,12 @@ async def execute( start_frame: Input.Image | None = None, ) -> IO.NodeOutput: _ = seed - mode = "pro" if model["resolution"] == "1080p" else "std" + if model["resolution"] == "4k": + mode = "4k" + elif model["resolution"] == "1080p": + mode = "pro" + else: + mode = "std" custom_multi_shot = False if multi_shot["multi_shot"] == "disabled": shot_type = None @@ -3025,6 +3062,7 @@ async def execute( cls, ApiEndpoint(path=poll_path), response_model=TaskStatusResponse, + max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -3057,7 +3095,7 @@ def define_schema(cls) -> IO.Schema: IO.DynamicCombo.Option( "kling-v3", [ - IO.Combo.Input("resolution", options=["1080p", "720p"]), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"), ], ), ], @@ -3089,7 +3127,11 @@ def define_schema(cls) -> IO.Schema: ), expr=""" ( - $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; + $rates := { + "4k": {"off": 0.42, "on": 0.42}, + "1080p": {"off": 0.112, "on": 0.168}, + "720p": {"off": 0.084, "on": 0.126} + }; $res := $lookup(widgets, "model.resolution"); $audio := widgets.generate_audio ? "on" : "off"; $rate := $lookup($lookup($rates, $res), $audio); @@ -3118,6 +3160,12 @@ async def execute( validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame") image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") + if model["resolution"] == "4k": + mode = "4k" + elif model["resolution"] == "1080p": + mode = "pro" + else: + mode = "std" response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), @@ -3127,7 +3175,7 @@ async def execute( image=image_url, image_tail=image_tail_url, prompt=prompt, - mode="pro" if model["resolution"] == "1080p" else "std", + mode=mode, duration=str(duration), sound="on" if generate_audio else "off", ), @@ -3140,6 +3188,7 @@ async def execute( cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), response_model=TaskStatusResponse, + max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 4ee896fa8afe..bbb758068cfe 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -357,13 +357,17 @@ def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0 +def calculate_tokens_price_image_2_0(response: OpenAIImageGenerationResponse) -> float | None: + return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 30.0)) / 1_000_000.0 + + class OpenAIGPTImage1(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="OpenAIGPTImage1", - display_name="OpenAI GPT Image 1.5", + display_name="OpenAI GPT Image 2", category="api node/image/OpenAI", description="Generates images synchronously via OpenAI's GPT Image endpoint.", inputs=[ @@ -401,7 +405,17 @@ def define_schema(cls): IO.Combo.Input( "size", default="auto", - options=["auto", "1024x1024", "1024x1536", "1536x1024"], + options=[ + "auto", + "1024x1024", + "1024x1536", + "1536x1024", + "2048x2048", + "2048x1152", + "1152x2048", + "3840x2160", + "2160x3840", + ], tooltip="Image size", optional=True, ), @@ -427,8 +441,8 @@ def define_schema(cls): ), IO.Combo.Input( "model", - options=["gpt-image-1", "gpt-image-1.5"], - default="gpt-image-1.5", + options=["gpt-image-1", "gpt-image-1.5", "gpt-image-2"], + default="gpt-image-2", optional=True, ), ], @@ -442,23 +456,36 @@ def define_schema(cls): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]), + depends_on=IO.PriceBadgeDepends(widgets=["quality", "n", "model"]), expr=""" ( $ranges := { - "low": [0.011, 0.02], - "medium": [0.046, 0.07], - "high": [0.167, 0.3] + "gpt-image-1": { + "low": [0.011, 0.02], + "medium": [0.042, 0.07], + "high": [0.167, 0.25] + }, + "gpt-image-1.5": { + "low": [0.009, 0.02], + "medium": [0.034, 0.062], + "high": [0.133, 0.22] + }, + "gpt-image-2": { + "low": [0.0048, 0.012], + "medium": [0.041, 0.112], + "high": [0.165, 0.43] + } }; - $range := $lookup($ranges, widgets.quality); - $n := widgets.n; + $range := $lookup($lookup($ranges, widgets.model), widgets.quality); + $nRaw := widgets.n; + $n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1; ($n = 1) - ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]} + ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}} : { "type":"range_usd", - "min_usd": $range[0], - "max_usd": $range[1], - "format": { "suffix": " x " & $string($n) & "/Run" } + "min_usd": $range[0] * $n, + "max_usd": $range[1] * $n, + "format": { "suffix": "/Run", "approximate": true } } ) """, @@ -483,10 +510,18 @@ async def execute( if mask is not None and image is None: raise ValueError("Cannot use a mask without an input image") + if model in ("gpt-image-1", "gpt-image-1.5"): + if size not in ("auto", "1024x1024", "1024x1536", "1536x1024"): + raise ValueError(f"Resolution {size} is only supported by GPT Image 2 model") + if model == "gpt-image-1": price_extractor = calculate_tokens_price_image_1 elif model == "gpt-image-1.5": price_extractor = calculate_tokens_price_image_1_5 + elif model == "gpt-image-2": + price_extractor = calculate_tokens_price_image_2_0 + if background == "transparent": + raise ValueError("Transparent background is not supported for GPT Image 2 model") else: raise ValueError(f"Unknown model: {model}") diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 13fc1cc3682a..2ff75d9b2d2d 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -24,8 +24,9 @@ AVERAGE_DURATION_VIDEO_GEN = 32 MODELS_MAP = { "veo-2.0-generate-001": "veo-2.0-generate-001", - "veo-3.1-generate": "veo-3.1-generate-preview", - "veo-3.1-fast-generate": "veo-3.1-fast-generate-preview", + "veo-3.1-generate": "veo-3.1-generate-001", + "veo-3.1-fast-generate": "veo-3.1-fast-generate-001", + "veo-3.1-lite": "veo-3.1-lite-generate-001", "veo-3.0-generate-001": "veo-3.0-generate-001", "veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001", } @@ -247,17 +248,8 @@ def status_extractor(response): raise Exception("Video generation completed but no video was returned") -class Veo3VideoGenerationNode(VeoVideoGenerationNode): - """ - Generates videos from text prompts using Google's Veo 3 API. - - Supported models: - - veo-3.0-generate-001 - - veo-3.0-fast-generate-001 - - This node extends the base Veo node with Veo 3 specific features including - audio generation and fixed 8-second duration. - """ +class Veo3VideoGenerationNode(IO.ComfyNode): + """Generates videos from text prompts using Google's Veo 3 API.""" @classmethod def define_schema(cls): @@ -279,6 +271,13 @@ def define_schema(cls): default="16:9", tooltip="Aspect ratio of the output video", ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p", "4k"], + default="720p", + tooltip="Output video resolution. 4K is not available for veo-3.1-lite and veo-3.0 models.", + optional=True, + ), IO.String.Input( "negative_prompt", multiline=True, @@ -289,11 +288,11 @@ def define_schema(cls): IO.Int.Input( "duration_seconds", default=8, - min=8, + min=4, max=8, - step=1, + step=2, display_mode=IO.NumberDisplay.number, - tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)", + tooltip="Duration of the output video in seconds", optional=True, ), IO.Boolean.Input( @@ -332,10 +331,10 @@ def define_schema(cls): options=[ "veo-3.1-generate", "veo-3.1-fast-generate", + "veo-3.1-lite", "veo-3.0-generate-001", "veo-3.0-fast-generate-001", ], - default="veo-3.0-generate-001", tooltip="Veo 3 model to use for video generation", optional=True, ), @@ -356,21 +355,111 @@ def define_schema(cls): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "resolution", "duration_seconds"]), expr=""" ( $m := widgets.model; + $r := widgets.resolution; $a := widgets.generate_audio; - ($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate")) - ? {"type":"usd","usd": ($a ? 1.2 : 0.8)} - : ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate")) - ? {"type":"usd","usd": ($a ? 3.2 : 1.6)} - : {"type":"range_usd","min_usd":0.8,"max_usd":3.2} + $seconds := widgets.duration_seconds; + $pps := + $contains($m, "lite") + ? ($r = "1080p" ? ($a ? 0.08 : 0.05) : ($a ? 0.05 : 0.03)) + : $contains($m, "3.1-fast") + ? ($r = "4k" ? ($a ? 0.30 : 0.25) : $r = "1080p" ? ($a ? 0.12 : 0.10) : ($a ? 0.10 : 0.08)) + : $contains($m, "3.1-generate") + ? ($r = "4k" ? ($a ? 0.60 : 0.40) : ($a ? 0.40 : 0.20)) + : $contains($m, "3.0-fast") + ? ($a ? 0.15 : 0.10) + : ($a ? 0.40 : 0.20); + {"type":"usd","usd": $pps * $seconds} ) """, ), ) + @classmethod + async def execute( + cls, + prompt, + aspect_ratio="16:9", + resolution="720p", + negative_prompt="", + duration_seconds=8, + enhance_prompt=True, + person_generation="ALLOW", + seed=0, + image=None, + model="veo-3.0-generate-001", + generate_audio=False, + ): + if resolution == "4k" and ("lite" in model or "3.0" in model): + raise Exception("4K resolution is not supported by the veo-3.1-lite or veo-3.0 models.") + + model = MODELS_MAP[model] + + instances = [{"prompt": prompt}] + if image is not None: + image_base64 = tensor_to_base64_string(image) + if image_base64: + instances[0]["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"} + + parameters = { + "aspectRatio": aspect_ratio, + "personGeneration": person_generation, + "durationSeconds": duration_seconds, + "enhancePrompt": True, + "generateAudio": generate_audio, + } + if negative_prompt: + parameters["negativePrompt"] = negative_prompt + if seed > 0: + parameters["seed"] = seed + if "veo-3.1" in model: + parameters["resolution"] = resolution + + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( + instances=instances, + parameters=parameters, + ), + ) + + poll_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, + status_extractor=lambda r: "completed" if r.done else "pending", + data=VeoGenVidPollRequest(operationName=initial_response.name), + poll_interval=9.0, + estimated_duration=AVERAGE_DURATION_VIDEO_GEN, + ) + + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") + + response = poll_response.response + filtered_count = response.raiMediaFilteredCount + if filtered_count: + reasons = response.raiMediaFilteredReasons or [] + reason_part = f": {reasons[0]}" if reasons else "" + raise Exception( + f"Content blocked by Google's Responsible AI filters{reason_part} " + f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)." + ) + + if response.videos: + video = response.videos[0] + if video.bytesBase64Encoded: + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + if video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") + class Veo3FirstLastFrameNode(IO.ComfyNode): @@ -394,7 +483,7 @@ def define_schema(cls): default="", tooltip="Negative text prompt to guide what to avoid in the video", ), - IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Combo.Input("resolution", options=["720p", "1080p", "4k"]), IO.Combo.Input( "aspect_ratio", options=["16:9", "9:16"], @@ -424,8 +513,7 @@ def define_schema(cls): IO.Image.Input("last_frame", tooltip="End frame"), IO.Combo.Input( "model", - options=["veo-3.1-generate", "veo-3.1-fast-generate"], - default="veo-3.1-fast-generate", + options=["veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.1-lite"], ), IO.Boolean.Input( "generate_audio", @@ -443,26 +531,20 @@ def define_schema(cls): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration", "resolution"]), expr=""" ( - $prices := { - "veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 }, - "veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 } - }; $m := widgets.model; - $ga := (widgets.generate_audio = "true"); + $r := widgets.resolution; + $ga := widgets.generate_audio; $seconds := widgets.duration; - $modelKey := - $contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" : - $contains($m, "veo-3.1-generate") ? "veo-3.1-generate" : - ""; - $audioKey := $ga ? "audio" : "no_audio"; - $modelPrices := $lookup($prices, $modelKey); - $pps := $lookup($modelPrices, $audioKey); - ($pps != null) - ? {"type":"usd","usd": $pps * $seconds} - : {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2} + $pps := + $contains($m, "lite") + ? ($r = "1080p" ? ($ga ? 0.08 : 0.05) : ($ga ? 0.05 : 0.03)) + : $contains($m, "fast") + ? ($r = "4k" ? ($ga ? 0.30 : 0.25) : $r = "1080p" ? ($ga ? 0.12 : 0.10) : ($ga ? 0.10 : 0.08)) + : ($r = "4k" ? ($ga ? 0.60 : 0.40) : ($ga ? 0.40 : 0.20)); + {"type":"usd","usd": $pps * $seconds} ) """, ), @@ -482,6 +564,9 @@ async def execute( model: str, generate_audio: bool, ): + if "lite" in model and resolution == "4k": + raise Exception("4K resolution is not supported by the veo-3.1-lite model.") + model = MODELS_MAP[model] initial_response = await sync_op( cls, @@ -519,7 +604,7 @@ async def execute( data=VeoGenVidPollRequest( operationName=initial_response.name, ), - poll_interval=5.0, + poll_interval=9.0, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 0cb9a47c780f..f3584aba9a9d 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -19,6 +19,7 @@ image_tensor_pair_to_batch, pil_to_bytesio, resize_mask_to_image, + resize_video_to_pixel_budget, tensor_to_base64_string, tensor_to_bytesio, tensor_to_pil, @@ -90,6 +91,7 @@ "image_tensor_pair_to_batch", "pil_to_bytesio", "resize_mask_to_image", + "resize_video_to_pixel_budget", "tensor_to_base64_string", "tensor_to_bytesio", "tensor_to_pil", diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 9d730b81a4da..b0cf97ae48fb 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -156,6 +156,7 @@ async def poll_op( estimated_duration: int | None = None, cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, + extra_text: str | None = None, ) -> M: raw = await poll_op_raw( cls, @@ -176,6 +177,7 @@ async def poll_op( estimated_duration=estimated_duration, cancel_endpoint=cancel_endpoint, cancel_timeout=cancel_timeout, + extra_text=extra_text, ) if not isinstance(raw, dict): raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") @@ -260,6 +262,7 @@ async def poll_op_raw( estimated_duration: int | None = None, cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, + extra_text: str | None = None, ) -> dict[str, Any]: """ Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing, @@ -299,6 +302,7 @@ async def _ticker(): price=state.price, is_queued=state.is_queued, processing_elapsed_seconds=int(proc_elapsed), + extra_text=extra_text, ) await asyncio.sleep(1.0) except Exception as exc: @@ -389,6 +393,7 @@ async def _ticker(): price=state.price, is_queued=False, processing_elapsed_seconds=int(state.base_processing_elapsed), + extra_text=extra_text, ) return resp_json @@ -462,6 +467,7 @@ def _display_time_progress( price: float | None = None, is_queued: bool | None = None, processing_elapsed_seconds: int | None = None, + extra_text: str | None = None, ) -> None: if estimated_total is not None and estimated_total > 0 and is_queued is False: pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds @@ -469,7 +475,8 @@ def _display_time_progress( time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" else: time_line = f"Time elapsed: {int(elapsed_seconds)}s" - _display_text(node_cls, time_line, status=status, price=price) + text = f"{time_line}\n\n{extra_text}" if extra_text else time_line + _display_text(node_cls, text, status=status, price=price) async def _diagnose_connectivity() -> dict[str, bool]: diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 82b6d22a513d..be5d5719bdfa 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -129,22 +129,38 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: return img_byte_arr +def _compute_downscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None: + """Return downscaled (w, h) with even dims fitting ``total_pixels``, or None if already fits. + + Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions + are rounded down to even values (many codecs require divisible-by-2). + """ + pixels = src_w * src_h + if pixels <= total_pixels: + return None + scale = math.sqrt(total_pixels / pixels) + new_w = max(2, int(src_w * scale)) + new_h = max(2, int(src_h * scale)) + new_w -= new_w % 2 + new_h -= new_h % 2 + return new_w, new_h + + def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor: - """Downscale input image tensor to roughly the specified total pixels.""" + """Downscale input image tensor to roughly the specified total pixels. + + Output dimensions are rounded down to even values so that the result is guaranteed to fit within ``total_pixels`` + and is compatible with codecs that require even dimensions (e.g. yuv420p). + """ samples = image.movedim(-1, 1) - total = int(total_pixels) - scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) - if scale_by >= 1: + dims = _compute_downscale_dims(samples.shape[3], samples.shape[2], int(total_pixels)) + if dims is None: return image - width = round(samples.shape[3] * scale_by) - height = round(samples.shape[2] * scale_by) - - s = common_upscale(samples, width, height, "lanczos", "disabled") - s = s.movedim(1, -1) - return s + new_w, new_h = dims + return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1) -def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor: +def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor: """Downscale input image tensor so the largest dimension is at most max_side pixels.""" samples = image.movedim(-1, 1) height, width = samples.shape[2], samples.shape[3] @@ -399,6 +415,72 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video: raise RuntimeError(f"Failed to trim video: {str(e)}") from e +def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video: + """Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio. + + Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio. + Aspect ratio is preserved up to a fraction of a percent (even-dim rounding). + """ + src_w, src_h = video.get_dimensions() + scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels) + if scale_dims is None: + return video + return _apply_video_scale(video, scale_dims) + + +def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input.Video: + """Re-encode ``video`` scaled to ``scale_dims`` with a single decode/encode pass.""" + out_w, out_h = scale_dims + output_buffer = BytesIO() + input_container = None + output_container = None + + try: + input_source = video.get_stream_source() + input_container = av.open(input_source, mode="r") + output_container = av.open(output_buffer, mode="w", format="mp4") + + video_stream = output_container.add_stream("h264", rate=video.get_frame_rate()) + video_stream.width = out_w + video_stream.height = out_h + video_stream.pix_fmt = "yuv420p" + + audio_stream = None + for stream in input_container.streams: + if isinstance(stream, av.AudioStream): + audio_stream = output_container.add_stream("aac", rate=stream.sample_rate) + audio_stream.sample_rate = stream.sample_rate + audio_stream.layout = stream.layout + break + + for frame in input_container.decode(video=0): + frame = frame.reformat(width=out_w, height=out_h, format="yuv420p") + for packet in video_stream.encode(frame): + output_container.mux(packet) + for packet in video_stream.encode(): + output_container.mux(packet) + + if audio_stream is not None: + input_container.seek(0) + for audio_frame in input_container.decode(audio=0): + for packet in audio_stream.encode(audio_frame): + output_container.mux(packet) + for packet in audio_stream.encode(): + output_container.mux(packet) + + output_container.close() + input_container.close() + output_buffer.seek(0) + return InputImpl.VideoFromFile(output_buffer) + + except Exception as e: + if input_container is not None: + input_container.close() + if output_container is not None: + output_container.close() + raise RuntimeError(f"Failed to resize video: {str(e)}") from e + + def _f32_pcm(wav: torch.Tensor) -> torch.Tensor: """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" if wav.dtype.is_floating_point: diff --git a/comfy_extras/frame_interpolation_models/film_net.py b/comfy_extras/frame_interpolation_models/film_net.py new file mode 100644 index 000000000000..cf4f6e1e1fb6 --- /dev/null +++ b/comfy_extras/frame_interpolation_models/film_net.py @@ -0,0 +1,258 @@ +"""FILM: Frame Interpolation for Large Motion (ECCV 2022).""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + + +class FilmConv2d(nn.Module): + """Conv2d with optional LeakyReLU and FILM-style padding.""" + + def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops): + super().__init__() + self.even_pad = not size % 2 + self.conv = operations.Conv2d(in_channels, out_channels, kernel_size=size, padding=size // 2 if size % 2 else 0, device=device, dtype=dtype) + self.activation = nn.LeakyReLU(0.2) if activation else None + + def forward(self, x): + if self.even_pad: + x = F.pad(x, (0, 1, 0, 1)) + x = self.conv(x) + if self.activation is not None: + x = self.activation(x) + return x + + +def _warp_core(image, flow, grid_x, grid_y): + dtype = image.dtype + H, W = flow.shape[2], flow.shape[3] + dx = flow[:, 0].float() / (W * 0.5) + dy = flow[:, 1].float() / (H * 0.5) + grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3) + return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype) + + +def build_image_pyramid(image, pyramid_levels): + pyramid = [image] + for _ in range(1, pyramid_levels): + image = F.avg_pool2d(image, 2, 2) + pyramid.append(image) + return pyramid + + +def flow_pyramid_synthesis(residual_pyramid): + flow = residual_pyramid[-1] + flow_pyramid = [flow] + for residual_flow in residual_pyramid[:-1][::-1]: + flow = F.interpolate(flow, size=residual_flow.shape[2:4], mode="bilinear", scale_factor=None).mul_(2).add_(residual_flow) + flow_pyramid.append(flow) + flow_pyramid.reverse() + return flow_pyramid + + +def multiply_pyramid(pyramid, scalar): + return [image * scalar[:, None, None, None] for image in pyramid] + + +def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn): + return [warp_fn(features, flow) for features, flow in zip(feature_pyramid, flow_pyramid)] + + +def concatenate_pyramids(pyramid1, pyramid2): + return [torch.cat([f1, f2], dim=1) for f1, f2 in zip(pyramid1, pyramid2)] + + +class SubTreeExtractor(nn.Module): + def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops): + super().__init__() + convs = [] + for i in range(n_layers): + out_ch = channels << i + convs.append(nn.Sequential( + FilmConv2d(in_channels, out_ch, 3, device=device, dtype=dtype, operations=operations), + FilmConv2d(out_ch, out_ch, 3, device=device, dtype=dtype, operations=operations))) + in_channels = out_ch + self.convs = nn.ModuleList(convs) + + def forward(self, image, n): + head = image + pyramid = [] + for i, layer in enumerate(self.convs): + head = layer(head) + pyramid.append(head) + if i < n - 1: + head = F.avg_pool2d(head, 2, 2) + return pyramid + + +class FeatureExtractor(nn.Module): + def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops): + super().__init__() + self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations) + self.sub_levels = sub_levels + + def forward(self, image_pyramid): + sub_pyramids = [self.extract_sublevels(image_pyramid[i], min(len(image_pyramid) - i, self.sub_levels)) + for i in range(len(image_pyramid))] + feature_pyramid = [] + for i in range(len(image_pyramid)): + features = sub_pyramids[i][0] + for j in range(1, self.sub_levels): + if j <= i: + features = torch.cat([features, sub_pyramids[i - j][j]], dim=1) + feature_pyramid.append(features) + # Free sub-pyramids no longer needed by future levels + if i >= self.sub_levels - 1: + sub_pyramids[i - self.sub_levels + 1] = None + return feature_pyramid + + +class FlowEstimator(nn.Module): + def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops): + super().__init__() + self._convs = nn.ModuleList() + for _ in range(num_convs): + self._convs.append(FilmConv2d(in_channels, num_filters, 3, device=device, dtype=dtype, operations=operations)) + in_channels = num_filters + self._convs.append(FilmConv2d(in_channels, num_filters // 2, 1, device=device, dtype=dtype, operations=operations)) + self._convs.append(FilmConv2d(num_filters // 2, 2, 1, activation=False, device=device, dtype=dtype, operations=operations)) + + def forward(self, features_a, features_b): + net = torch.cat([features_a, features_b], dim=1) + for conv in self._convs: + net = conv(net) + return net + + +class PyramidFlowEstimator(nn.Module): + def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): + super().__init__() + in_channels = filters << 1 + predictors = [] + for i in range(len(flow_convs)): + predictors.append(FlowEstimator(in_channels, flow_convs[i], flow_filters[i], device=device, dtype=dtype, operations=operations)) + in_channels += filters << (i + 2) + self._predictor = predictors[-1] + self._predictors = nn.ModuleList(predictors[:-1][::-1]) + + def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn): + levels = len(feature_pyramid_a) + v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1]) + residuals = [v] + # Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels + steps = [(i, self._predictor) for i in range(levels - 2, len(self._predictors) - 1, -1)] + steps += [(len(self._predictors) - 1 - k, p) for k, p in enumerate(self._predictors)] + for i, predictor in steps: + v = F.interpolate(v, size=feature_pyramid_a[i].shape[2:4], mode="bilinear").mul_(2) + v_residual = predictor(feature_pyramid_a[i], warp_fn(feature_pyramid_b[i], v)) + residuals.append(v_residual) + v = v.add_(v_residual) + residuals.reverse() + return residuals + + +def _get_fusion_channels(level, filters): + # Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions + return (sum(filters << i for i in range(level)) + 3 + 2) * 2 + + +class Fusion(nn.Module): + def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops): + super().__init__() + self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype) + self.convs = nn.ModuleList() + in_channels = _get_fusion_channels(n_layers, filters) + increase = 0 + for i in range(n_layers)[::-1]: + num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers) + self.convs.append(nn.ModuleList([ + FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations), + FilmConv2d(in_channels + (increase or num_filters), num_filters, 3, device=device, dtype=dtype, operations=operations), + FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)])) + in_channels = num_filters + increase = _get_fusion_channels(i, filters) - num_filters // 2 + + def forward(self, pyramid): + net = pyramid[-1] + for k, layers in enumerate(self.convs): + i = len(self.convs) - 1 - k + net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest")) + net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1))) + return self.output_conv(net) + + +class FILMNet(nn.Module): + def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4, + filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): + super().__init__() + self.pyramid_levels = pyramid_levels + self.fusion_pyramid_levels = fusion_pyramid_levels + self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations) + self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations) + self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations) + self._warp_grids = {} + + def get_dtype(self): + return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype + + def _build_warp_grids(self, H, W, device): + """Pre-compute warp grids for all pyramid levels.""" + if (H, W) in self._warp_grids: + return + self._warp_grids = {} # clear old resolution grids to prevent memory leaks + for _ in range(self.pyramid_levels): + self._warp_grids[(H, W)] = ( + torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device), + torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device), + ) + H, W = H // 2, W // 2 + + def warp(self, image, flow): + grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])] + return _warp_core(image, flow, grid_x, grid_y) + + def extract_features(self, img): + """Extract image and feature pyramids for a single frame. Can be cached across pairs.""" + image_pyramid = build_image_pyramid(img, self.pyramid_levels) + feature_pyramid = self.extract(image_pyramid) + return image_pyramid, feature_pyramid + + def forward(self, img0, img1, timestep=0.5, cache=None): + # FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported) + t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep + return self.forward_multi_timestep(img0, img1, [t], cache=cache) + + def forward_multi_timestep(self, img0, img1, timesteps, cache=None): + """Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs.""" + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) + + image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0) + image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1) + + fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels] + bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels] + + # Build warp targets and free full pyramids (only first fpl levels needed from here) + fpl = self.fusion_pyramid_levels + p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]), + concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])] + del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1 + + results = [] + dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype) + for idx in range(len(timesteps)): + batch_dt = dt_tensors[idx:idx + 1] + bwd_scaled = multiply_pyramid(bwd_flow, batch_dt) + fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt) + fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp) + bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp) + aligned = [torch.cat([fw, bw, bf, ff], dim=1) + for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)] + del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled + results.append(self.fuse(aligned)) + del aligned + return torch.cat(results, dim=0) diff --git a/comfy_extras/frame_interpolation_models/ifnet.py b/comfy_extras/frame_interpolation_models/ifnet.py new file mode 100644 index 000000000000..03cb34c50a9c --- /dev/null +++ b/comfy_extras/frame_interpolation_models/ifnet.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + + +def _warp(img, flow, warp_grids): + B, _, H, W = img.shape + base_grid, flow_div = warp_grids[(H, W)] + flow_norm = torch.cat([flow[:, 0:1] / flow_div[0], flow[:, 1:2] / flow_div[1]], 1).float() + grid = (base_grid.expand(B, -1, -1, -1) + flow_norm).permute(0, 2, 3, 1) + return F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="border", align_corners=True).to(img.dtype) + + +class Head(nn.Module): + def __init__(self, out_ch=4, device=None, dtype=None, operations=ops): + super().__init__() + self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype) + self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + x = self.relu(self.cnn0(x)) + x = self.relu(self.cnn1(x)) + x = self.relu(self.cnn2(x)) + return self.cnn3(x) + + +class ResConv(nn.Module): + def __init__(self, c, device=None, dtype=None, operations=ops): + super().__init__() + self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype)) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(torch.addcmul(x, self.conv(x), self.beta)) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops): + super().__init__() + self.conv0 = nn.Sequential( + nn.Sequential(operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)), + nn.Sequential(operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True))) + self.convblock = nn.Sequential(*(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8))) + self.lastconv = nn.Sequential(operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), nn.PixelShuffle(2)) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear") + if flow is not None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale) + x = torch.cat((x, flow), 1) + feat = self.convblock(self.conv0(x)) + tmp = F.interpolate(self.lastconv(feat), scale_factor=scale, mode="bilinear") + return tmp[:, :4] * scale, tmp[:, 4:5], tmp[:, 5:] + + +class IFNet(nn.Module): + def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtype=None, operations=ops): + super().__init__() + self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations) + block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4 + self.blocks = nn.ModuleList([IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) for i in range(5)]) + self.scale_list = [16, 8, 4, 2, 1] + self.pad_align = 64 + self._warp_grids = {} + + def get_dtype(self): + return self.encode.cnn0.weight.dtype + + def _build_warp_grids(self, H, W, device): + if (H, W) in self._warp_grids: + return + self._warp_grids = {} # clear old resolution grids to prevent memory leaks + grid_y, grid_x = torch.meshgrid( + torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32), + torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32), indexing="ij") + self._warp_grids[(H, W)] = ( + torch.stack((grid_x, grid_y), dim=0).unsqueeze(0), + torch.tensor([(W - 1.0) / 2.0, (H - 1.0) / 2.0], dtype=torch.float32, device=device)) + + def warp(self, img, flow): + return _warp(img, flow, self._warp_grids) + + def extract_features(self, img): + """Extract head features for a single frame. Can be cached across pairs.""" + return self.encode(img) + + def forward(self, img0, img1, timestep=0.5, cache=None): + if not isinstance(timestep, torch.Tensor): + timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), timestep, device=img0.device, dtype=img0.dtype) + + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) + + B = img0.shape[0] + f0 = cache["img0"].expand(B, -1, -1, -1) if cache and "img0" in cache else self.encode(img0) + f1 = cache["img1"].expand(B, -1, -1, -1) if cache and "img1" in cache else self.encode(img1) + flow = mask = feat = None + warped_img0, warped_img1 = img0, img1 + for i, block in enumerate(self.blocks): + if flow is None: + flow, mask, feat = block(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i]) + else: + fd, mask, feat = block( + torch.cat((warped_img0, warped_img1, self.warp(f0, flow[:, :2]), self.warp(f1, flow[:, 2:4]), timestep, mask, feat), 1), + flow, scale=self.scale_list[i]) + flow = flow.add_(fd) + warped_img0 = self.warp(img0, flow[:, :2]) + warped_img1 = self.warp(img1, flow[:, 2:4]) + return torch.lerp(warped_img1, warped_img0, torch.sigmoid(mask)) + + +def detect_rife_config(state_dict): + head_ch = state_dict["encode.cnn3.weight"].shape[1] # ConvTranspose2d: (in_ch, out_ch, kH, kW) + channels = [] + for i in range(5): + key = f"blocks.{i}.conv0.1.0.weight" + if key in state_dict: + channels.append(state_dict[key].shape[0]) + if len(channels) != 5: + raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}") + return head_ch, channels diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index cbfaf913d8be..1602add84c86 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -3,136 +3,136 @@ import comfy.model_management import node_helpers -from comfy_api.latest import ComfyExtension, io +from comfy_api.latest import ComfyExtension, IO -class TextEncodeAceStepAudio(io.ComfyNode): +class TextEncodeAceStepAudio(IO.ComfyNode): @classmethod def define_schema(cls): - return io.Schema( + return IO.Schema( node_id="TextEncodeAceStepAudio", category="conditioning", inputs=[ - io.Clip.Input("clip"), - io.String.Input("tags", multiline=True, dynamic_prompts=True), - io.String.Input("lyrics", multiline=True, dynamic_prompts=True), - io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01), + IO.Clip.Input("clip"), + IO.String.Input("tags", multiline=True, dynamic_prompts=True), + IO.String.Input("lyrics", multiline=True, dynamic_prompts=True), + IO.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01), ], - outputs=[io.Conditioning.Output()], + outputs=[IO.Conditioning.Output()], ) @classmethod - def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput: + def execute(cls, clip, tags, lyrics, lyrics_strength) -> IO.NodeOutput: tokens = clip.tokenize(tags, lyrics=lyrics) conditioning = clip.encode_from_tokens_scheduled(tokens) conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength}) - return io.NodeOutput(conditioning) + return IO.NodeOutput(conditioning) -class TextEncodeAceStepAudio15(io.ComfyNode): +class TextEncodeAceStepAudio15(IO.ComfyNode): @classmethod def define_schema(cls): - return io.Schema( + return IO.Schema( node_id="TextEncodeAceStepAudio1.5", category="conditioning", inputs=[ - io.Clip.Input("clip"), - io.String.Input("tags", multiline=True, dynamic_prompts=True), - io.String.Input("lyrics", multiline=True, dynamic_prompts=True), - io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), - io.Int.Input("bpm", default=120, min=10, max=300), - io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1), - io.Combo.Input("timesignature", options=['2', '3', '4', '6']), - io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]), - io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]), - io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True), - io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True), - io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True), - io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True), - io.Int.Input("top_k", default=0, min=0, max=100, advanced=True), - io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True), + IO.Clip.Input("clip"), + IO.String.Input("tags", multiline=True, dynamic_prompts=True), + IO.String.Input("lyrics", multiline=True, dynamic_prompts=True), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + IO.Int.Input("bpm", default=120, min=10, max=300), + IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1), + IO.Combo.Input("timesignature", options=['2', '3', '4', '6']), + IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]), + IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]), + IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True), + IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True), + IO.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True), + IO.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True), + IO.Int.Input("top_k", default=0, min=0, max=100, advanced=True), + IO.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True), ], - outputs=[io.Conditioning.Output()], + outputs=[IO.Conditioning.Output()], ) @classmethod - def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput: + def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> IO.NodeOutput: tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p) conditioning = clip.encode_from_tokens_scheduled(tokens) - return io.NodeOutput(conditioning) + return IO.NodeOutput(conditioning) -class EmptyAceStepLatentAudio(io.ComfyNode): +class EmptyAceStepLatentAudio(IO.ComfyNode): @classmethod def define_schema(cls): - return io.Schema( + return IO.Schema( node_id="EmptyAceStepLatentAudio", display_name="Empty Ace Step 1.0 Latent Audio", category="latent/audio", inputs=[ - io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), - io.Int.Input( + IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), + IO.Int.Input( "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." ), ], - outputs=[io.Latent.Output()], + outputs=[IO.Latent.Output()], ) @classmethod - def execute(cls, seconds, batch_size) -> io.NodeOutput: + def execute(cls, seconds, batch_size) -> IO.NodeOutput: length = int(seconds * 44100 / 512 / 8) latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) - return io.NodeOutput({"samples": latent, "type": "audio"}) + return IO.NodeOutput({"samples": latent, "type": "audio"}) -class EmptyAceStep15LatentAudio(io.ComfyNode): +class EmptyAceStep15LatentAudio(IO.ComfyNode): @classmethod def define_schema(cls): - return io.Schema( + return IO.Schema( node_id="EmptyAceStep1.5LatentAudio", display_name="Empty Ace Step 1.5 Latent Audio", category="latent/audio", inputs=[ - io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01), - io.Int.Input( + IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01), + IO.Int.Input( "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." ), ], - outputs=[io.Latent.Output()], + outputs=[IO.Latent.Output()], ) @classmethod - def execute(cls, seconds, batch_size) -> io.NodeOutput: + def execute(cls, seconds, batch_size) -> IO.NodeOutput: length = round((seconds * 48000 / 1920)) latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) - return io.NodeOutput({"samples": latent, "type": "audio"}) + return IO.NodeOutput({"samples": latent, "type": "audio"}) -class ReferenceAudio(io.ComfyNode): +class ReferenceAudio(IO.ComfyNode): @classmethod def define_schema(cls): - return io.Schema( + return IO.Schema( node_id="ReferenceTimbreAudio", display_name="Reference Audio", category="advanced/conditioning/audio", is_experimental=True, description="This node sets the reference audio for ace step 1.5", inputs=[ - io.Conditioning.Input("conditioning"), - io.Latent.Input("latent", optional=True), + IO.Conditioning.Input("conditioning"), + IO.Latent.Input("latent", optional=True), ], outputs=[ - io.Conditioning.Output(), + IO.Conditioning.Output(), ] ) @classmethod - def execute(cls, conditioning, latent=None) -> io.NodeOutput: + def execute(cls, conditioning, latent=None) -> IO.NodeOutput: if latent is not None: conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True) - return io.NodeOutput(conditioning) + return IO.NodeOutput(conditioning) class AceExtension(ComfyExtension): @override - async def get_node_list(self) -> list[type[io.ComfyNode]]: + async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ TextEncodeAceStepAudio, EmptyAceStepLatentAudio, diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index a395392d86ee..5f514716f1d9 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None): std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100)) return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]} diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py new file mode 100644 index 000000000000..a3b00d36ec58 --- /dev/null +++ b/comfy_extras/nodes_frame_interpolation.py @@ -0,0 +1,211 @@ +import torch +from tqdm import tqdm +from typing_extensions import override + +import comfy.model_patcher +import comfy.utils +import folder_paths +from comfy import model_management +from comfy_extras.frame_interpolation_models.ifnet import IFNet, detect_rife_config +from comfy_extras.frame_interpolation_models.film_net import FILMNet +from comfy_api.latest import ComfyExtension, io + +FrameInterpolationModel = io.Custom("INTERP_MODEL") + + +class FrameInterpolationModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FrameInterpolationModelLoader", + display_name="Load Frame Interpolation Model", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"), + tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."), + ], + outputs=[ + FrameInterpolationModel.Output(), + ], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + + model = cls._detect_and_load(sd) + dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32 + model.eval().to(dtype) + patcher = comfy.model_patcher.ModelPatcher( + model, + load_device=model_management.get_torch_device(), + offload_device=model_management.unet_offload_device(), + ) + return io.NodeOutput(patcher) + + @classmethod + def _detect_and_load(cls, sd): + # Try FILM + if "extract.extract_sublevels.convs.0.0.conv.weight" in sd: + model = FILMNet() + model.load_state_dict(sd) + return model + + # Try RIFE (needs key remapping for raw checkpoints) + sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""}) + key_map = {} + for k in sd: + for i in range(5): + if k.startswith(f"block{i}."): + key_map[k] = f"blocks.{i}.{k[len(f'block{i}.'):]}" + if key_map: + sd = {key_map.get(k, k): v for k, v in sd.items()} + sd = {k: v for k, v in sd.items() if not k.startswith(("teacher.", "caltime."))} + + try: + head_ch, channels = detect_rife_config(sd) + except (KeyError, ValueError): + raise ValueError("Unrecognized frame interpolation model format") + model = IFNet(head_ch=head_ch, channels=channels) + model.load_state_dict(sd) + return model + + +class FrameInterpolate(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FrameInterpolate", + display_name="Frame Interpolate", + category="image/video", + search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], + inputs=[ + FrameInterpolationModel.Input("interp_model"), + io.Image.Input("images"), + io.Int.Input("multiplier", default=2, min=2, max=16), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, interp_model, images, multiplier) -> io.NodeOutput: + offload_device = model_management.intermediate_device() + + num_frames = images.shape[0] + if num_frames < 2 or multiplier < 2: + return io.NodeOutput(images) + + model_management.load_model_gpu(interp_model) + device = interp_model.load_device + dtype = interp_model.model_dtype() + inference_model = interp_model.model + + # Free VRAM for inference activations (model weights + ~20x a single frame's worth) + H, W = images.shape[1], images.shape[2] + activation_mem = H * W * 3 * images.element_size() * 20 + model_management.free_memory(activation_mem, device) + align = getattr(inference_model, "pad_align", 1) + + # Prepare a single padded frame on device for determining output dimensions + def prepare_frame(idx): + frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device) + if align > 1: + from comfy.ldm.common_dit import pad_to_patch_size + frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect") + return frame + + # Count total interpolation passes for progress bar + total_pairs = num_frames - 1 + num_interp = multiplier - 1 + total_steps = total_pairs * num_interp + pbar = comfy.utils.ProgressBar(total_steps) + tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation") + + batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit) + t_values = [t / multiplier for t in range(1, multiplier)] + + out_dtype = model_management.intermediate_dtype() + total_out_frames = total_pairs * multiplier + 1 + result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device) + result[0] = images[0].movedim(-1, 0).to(out_dtype) + out_idx = 1 + + # Pre-compute timestep tensor on device (padded dimensions needed) + sample = prepare_frame(0) + pH, pW = sample.shape[2], sample.shape[3] + ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1) + ts_full = ts_full.expand(-1, 1, pH, pW) + del sample + + multi_fn = getattr(inference_model, "forward_multi_timestep", None) + feat_cache = {} + prev_frame = None + + try: + for i in range(total_pairs): + img0_single = prev_frame if prev_frame is not None else prepare_frame(i) + img1_single = prepare_frame(i + 1) + prev_frame = img1_single + + # Cache features: img1 of pair N becomes img0 of pair N+1 + feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single) + feat_cache["img1"] = inference_model.extract_features(img1_single) + feat_cache["next"] = feat_cache["img1"] + + used_multi = False + if multi_fn is not None: + # Models with timestep-independent flow can compute it once for all timesteps + try: + mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache) + result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype) + out_idx += num_interp + pbar.update(num_interp) + tqdm_bar.update(num_interp) + used_multi = True + except model_management.OOM_EXCEPTION: + model_management.soft_empty_cache() + multi_fn = None # fall through to single-timestep path + + if not used_multi: + j = 0 + while j < num_interp: + b = min(batch, num_interp - j) + try: + img0 = img0_single.expand(b, -1, -1, -1) + img1 = img1_single.expand(b, -1, -1, -1) + mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache) + result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype) + out_idx += b + pbar.update(b) + tqdm_bar.update(b) + j += b + except model_management.OOM_EXCEPTION: + if batch <= 1: + raise + batch = max(1, batch // 2) + model_management.soft_empty_cache() + + result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype) + out_idx += 1 + finally: + tqdm_bar.close() + + # BCHW -> BHWC + result = result.movedim(1, -1).clamp_(0.0, 1.0) + return io.NodeOutput(result) + + +class FrameInterpolationExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + FrameInterpolationModelLoader, + FrameInterpolate, + ] + + +async def comfy_entrypoint() -> FrameInterpolationExtension: + return FrameInterpolationExtension() diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index d7c2e874470f..19d8a387f07b 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -1,6 +1,7 @@ import nodes import node_helpers import torch +import torchaudio import comfy.model_management import comfy.model_sampling import comfy.samplers @@ -711,7 +712,14 @@ def define_schema(cls) -> io.Schema: @classmethod def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput: # Encode reference audio to latents and patchify - audio_latents = audio_vae.encode(reference_audio) + sample_rate = reference_audio["sample_rate"] + vae_sample_rate = getattr(audio_vae, "audio_sample_rate", 44100) + if vae_sample_rate != sample_rate: + waveform = torchaudio.functional.resample(reference_audio["waveform"], sample_rate, vae_sample_rate) + else: + waveform = reference_audio["waveform"] + + audio_latents = audio_vae.encode(waveform.movedim(1, -1)) b, c, t, f = audio_latents.shape ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f) ref_audio = {"tokens": ref_tokens} diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index be0b1c887dbb..15d8a497e609 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -3,9 +3,8 @@ import comfy.model_management import torch -from comfy.ldm.lightricks.vae.audio_vae import AudioVAE from comfy_api.latest import ComfyExtension, io - +from comfy_extras.nodes_audio import VAEEncodeAudio class LTXVAudioVAELoader(io.ComfyNode): @classmethod @@ -28,10 +27,14 @@ def define_schema(cls) -> io.Schema: def execute(cls, ckpt_name: str) -> io.NodeOutput: ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) - return io.NodeOutput(AudioVAE(sd, metadata)) + sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True) + vae = comfy.sd.VAE(sd=sd, metadata=metadata) + vae.throw_exception_if_invalid() + + return io.NodeOutput(vae) -class LTXVAudioVAEEncode(io.ComfyNode): +class LTXVAudioVAEEncode(VAEEncodeAudio): @classmethod def define_schema(cls) -> io.Schema: return io.Schema( @@ -50,15 +53,8 @@ def define_schema(cls) -> io.Schema: ) @classmethod - def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: - audio_latents = audio_vae.encode(audio) - return io.NodeOutput( - { - "samples": audio_latents, - "sample_rate": int(audio_vae.sample_rate), - "type": "audio", - } - ) + def execute(cls, audio, audio_vae) -> io.NodeOutput: + return super().execute(audio_vae, audio) class LTXVAudioVAEDecode(io.ComfyNode): @@ -80,12 +76,12 @@ def define_schema(cls) -> io.Schema: ) @classmethod - def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: + def execute(cls, samples, audio_vae) -> io.NodeOutput: audio_latent = samples["samples"] if audio_latent.is_nested: audio_latent = audio_latent.unbind()[-1] - audio = audio_vae.decode(audio_latent).to(audio_latent.device) - output_audio_sample_rate = audio_vae.output_sample_rate + audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device) + output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate return io.NodeOutput( { "waveform": audio, @@ -143,17 +139,17 @@ def execute( frames_number: int, frame_rate: int, batch_size: int, - audio_vae: AudioVAE, + audio_vae, ) -> io.NodeOutput: """Generate empty audio latents matching the reference pipeline structure.""" assert audio_vae is not None, "Audio VAE model is required" z_channels = audio_vae.latent_channels - audio_freq = audio_vae.latent_frequency_bins - sampling_rate = int(audio_vae.sample_rate) + audio_freq = audio_vae.first_stage_model.latent_frequency_bins + sampling_rate = int(audio_vae.first_stage_model.sample_rate) - num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) + num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate) audio_latents = torch.zeros( (batch_size, z_channels, num_audio_latents, audio_freq), diff --git a/comfy_extras/nodes_sam3.py b/comfy_extras/nodes_sam3.py new file mode 100644 index 000000000000..5cf92ccb39a8 --- /dev/null +++ b/comfy_extras/nodes_sam3.py @@ -0,0 +1,529 @@ +""" +SAM3 (Segment Anything 3) nodes for detection, segmentation, and video tracking. +""" + +from typing_extensions import override + +import json +import os +import torch +import torch.nn.functional as F +import comfy.model_management +import comfy.utils +import folder_paths +from comfy_api.latest import ComfyExtension, io, ui +import av +from fractions import Fraction + + +def _extract_text_prompts(conditioning, device, dtype): + """Extract list of (text_embeddings, text_mask) from conditioning.""" + cond_meta = conditioning[0][1] + multi = cond_meta.get("sam3_multi_cond") + prompts = [] + if multi is not None: + for entry in multi: + emb = entry["cond"].to(device=device, dtype=dtype) + mask = entry["attention_mask"].to(device) if entry["attention_mask"] is not None else None + if mask is None: + mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device) + prompts.append((emb, mask, entry.get("max_detections", 1))) + else: + emb = conditioning[0][0].to(device=device, dtype=dtype) + mask = cond_meta.get("attention_mask") + if mask is not None: + mask = mask.to(device) + else: + mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device) + prompts.append((emb, mask, 1)) + return prompts + + +def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device, dtype, iterations): + """Refine a coarse detector mask via SAM decoder, cropping to the detection box. + + Returns: [1, H, W] binary mask + """ + def _coarse_fallback(): + return (F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), + mode="bilinear", align_corners=False)[0] > 0).float() + + if iterations <= 0: + return _coarse_fallback() + + pad_frac = 0.1 + x1, y1, x2, y2 = box_xyxy.tolist() + bw, bh = x2 - x1, y2 - y1 + cx1 = max(0, int(x1 - bw * pad_frac)) + cy1 = max(0, int(y1 - bh * pad_frac)) + cx2 = min(W, int(x2 + bw * pad_frac)) + cy2 = min(H, int(y2 + bh * pad_frac)) + if cx2 <= cx1 or cy2 <= cy1: + return _coarse_fallback() + + crop = orig_image_hwc[cy1:cy2, cx1:cx2, :3] + crop_1008 = comfy.utils.common_upscale(crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") + crop_frame = crop_1008.to(device=device, dtype=dtype) + crop_h, crop_w = cy2 - cy1, cx2 - cx1 + + # Crop coarse mask and refine via SAM on the cropped image + mask_h, mask_w = coarse_mask.shape[-2:] + mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h) + mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h) + if mx2 <= mx1 or my2 <= my1: + return _coarse_fallback() + mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0) + for _ in range(iterations): + coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False) + mask_logit = sam3_model.forward_segment(crop_frame, mask_inputs=coarse_input) + + refined_crop = F.interpolate(mask_logit, size=(crop_h, crop_w), mode="bilinear", align_corners=False) + full_mask = torch.zeros(1, 1, H, W, device=device, dtype=dtype) + full_mask[:, :, cy1:cy2, cx1:cx2] = refined_crop + coarse_full = F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False) + return ((full_mask[0] > 0) | (coarse_full[0] > 0)).float() + + + +class SAM3_Detect(io.ComfyNode): + """Open-vocabulary detection and segmentation using text, box, or point prompts.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_Detect", + display_name="SAM3 Detect", + category="detection/", + search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"], + inputs=[ + io.Model.Input("model", display_name="model"), + io.Image.Input("image", display_name="image"), + io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning from CLIPTextEncode"), + io.BoundingBox.Input("bboxes", display_name="bboxes", force_input=True, optional=True, tooltip="Bounding boxes to segment within"), + io.String.Input("positive_coords", display_name="positive_coords", force_input=True, optional=True, tooltip="Positive point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"), + io.String.Input("negative_coords", display_name="negative_coords", force_input=True, optional=True, tooltip="Negative point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"), + io.Float.Input("threshold", display_name="threshold", default=0.5, min=0.0, max=1.0, step=0.01), + io.Int.Input("refine_iterations", display_name="refine_iterations", default=2, min=0, max=5, tooltip="SAM decoder refinement passes (0=use raw detector masks)"), + io.Boolean.Input("individual_masks", display_name="individual_masks", default=False, tooltip="Output per-object masks instead of union"), + ], + outputs=[ + io.Mask.Output("masks"), + io.BoundingBox.Output("bboxes"), + ], + ) + + @classmethod + def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput: + B, H, W, C = image.shape + image_in = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") + + # Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors. + # Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame). + def _boxes_to_tensor(box_list): + coords = [] + for d in box_list: + cx = (d["x"] + d["width"] / 2) / W + cy = (d["y"] + d["height"] / 2) / H + coords.append([cx, cy, d["width"] / W, d["height"] / H]) + return torch.tensor([coords], dtype=torch.float32) # [1, N, 4] + + per_frame_boxes = None + if bboxes is not None: + if isinstance(bboxes, dict): + # Single box → same for all frames + shared = _boxes_to_tensor([bboxes]) + per_frame_boxes = [shared] * B + elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list): + # list[list[dict]] → per-frame boxes + per_frame_boxes = [_boxes_to_tensor(frame_boxes) if frame_boxes else None for frame_boxes in bboxes] + # Pad to B if fewer frames provided + while len(per_frame_boxes) < B: + per_frame_boxes.append(per_frame_boxes[-1] if per_frame_boxes else None) + elif isinstance(bboxes, list) and len(bboxes) > 0: + # list[dict] → same boxes for all frames + shared = _boxes_to_tensor(bboxes) + per_frame_boxes = [shared] * B + + # Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...]) + pos_pts = json.loads(positive_coords) if positive_coords else [] + neg_pts = json.loads(negative_coords) if negative_coords else [] + has_points = len(pos_pts) > 0 or len(neg_pts) > 0 + + comfy.model_management.load_model_gpu(model) + device = comfy.model_management.get_torch_device() + dtype = model.model.get_dtype() + sam3_model = model.model.diffusion_model + + # Build point inputs for tracker SAM decoder path + point_inputs = None + if has_points: + all_coords = [[p["x"] / W * 1008, p["y"] / H * 1008] for p in pos_pts] + \ + [[p["x"] / W * 1008, p["y"] / H * 1008] for p in neg_pts] + all_labels = [1] * len(pos_pts) + [0] * len(neg_pts) + point_inputs = { + "point_coords": torch.tensor([all_coords], dtype=dtype, device=device), + "point_labels": torch.tensor([all_labels], dtype=torch.int32, device=device), + } + + cond_list = _extract_text_prompts(conditioning, device, dtype) if conditioning is not None and len(conditioning) > 0 else [] + has_text = len(cond_list) > 0 + + # Run per-image through detector (text/boxes) and/or tracker (points) + all_bbox_dicts = [] + all_masks = [] + pbar = comfy.utils.ProgressBar(B) + + for b in range(B): + frame = image_in[b:b+1].to(device=device, dtype=dtype) + b_boxes = None + if per_frame_boxes is not None and per_frame_boxes[b] is not None: + b_boxes = per_frame_boxes[b].to(device=device, dtype=dtype) + + frame_bbox_dicts = [] + frame_masks = [] + + # Point prompts: tracker SAM decoder path with iterative refinement + if point_inputs is not None: + mask_logit = sam3_model.forward_segment(frame, point_inputs=point_inputs) + for _ in range(max(0, refine_iterations - 1)): + mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit) + mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False) + frame_masks.append((mask[0] > 0).float()) + + # Box prompts: SAM decoder path (segment inside each box) + if b_boxes is not None and not has_text: + for box_cxcywh in b_boxes[0]: + cx, cy, bw, bh = box_cxcywh.tolist() + # Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners + sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008], + [(cx + bw/2) * 1008, (cy + bh/2) * 1008]]], + device=device, dtype=dtype) + mask_logit = sam3_model.forward_segment(frame, box_inputs=sam_box) + for _ in range(max(0, refine_iterations - 1)): + mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit) + mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False) + frame_masks.append((mask[0] > 0).float()) + + # Text prompts: run detector per text prompt (each detects one category) + for text_embeddings, text_mask, max_det in cond_list: + results = sam3_model( + frame, text_embeddings=text_embeddings, text_mask=text_mask, + boxes=b_boxes, threshold=threshold, orig_size=(H, W)) + + pred_boxes = results["boxes"][0] + scores = results["scores"][0] + masks = results["masks"][0] + + probs = scores.sigmoid() + keep = probs > threshold + kept_boxes = pred_boxes[keep].cpu() + kept_scores = probs[keep].cpu() + kept_masks = masks[keep] + + order = kept_scores.argsort(descending=True)[:max_det] + kept_boxes = kept_boxes[order] + kept_scores = kept_scores[order] + kept_masks = kept_masks[order] + + for box, score in zip(kept_boxes, kept_scores): + frame_bbox_dicts.append({ + "x": float(box[0]), "y": float(box[1]), + "width": float(box[2] - box[0]), "height": float(box[3] - box[1]), + "score": float(score), + }) + for m, box in zip(kept_masks, kept_boxes): + frame_masks.append(_refine_mask( + sam3_model, image[b], m, box, H, W, device, dtype, refine_iterations)) + + all_bbox_dicts.append(frame_bbox_dicts) + if len(frame_masks) > 0: + combined = torch.cat(frame_masks, dim=0) # [N_obj, H, W] + if individual_masks: + all_masks.append(combined) + else: + all_masks.append((combined > 0).any(dim=0).float()) + else: + if individual_masks: + all_masks.append(torch.zeros(0, H, W, device=comfy.model_management.intermediate_device())) + else: + all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device())) + pbar.update(1) + + idev = comfy.model_management.intermediate_device() + all_masks = [m.to(idev) for m in all_masks] + mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks) + return io.NodeOutput(mask_out, all_bbox_dicts) + + +SAM3TrackData = io.Custom("SAM3_TRACK_DATA") + +class SAM3_VideoTrack(io.ComfyNode): + """Track objects across video frames using SAM3's memory-based tracker.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_VideoTrack", + display_name="SAM3 Video Track", + category="detection/", + search_aliases=["sam3", "video", "track", "propagate"], + inputs=[ + io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"), + io.Model.Input("model", display_name="model"), + io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"), + io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"), + io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"), + io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."), + io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."), + ], + outputs=[ + SAM3TrackData.Output("track_data", display_name="track_data"), + ], + ) + + @classmethod + def execute(cls, images, model, initial_mask=None, conditioning=None, detection_threshold=0.5, max_objects=0, detect_interval=1) -> io.NodeOutput: + N, H, W, C = images.shape + + comfy.model_management.load_model_gpu(model) + device = comfy.model_management.get_torch_device() + dtype = model.model.get_dtype() + sam3_model = model.model.diffusion_model + + frames = images[..., :3].movedim(-1, 1) + frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype) + + init_masks = None + if initial_mask is not None: + init_masks = initial_mask.unsqueeze(1).to(device=device, dtype=dtype) + + pbar = comfy.utils.ProgressBar(N) + + text_prompts = None + if conditioning is not None and len(conditioning) > 0: + text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)] + elif initial_mask is None: + raise ValueError("Either initial_mask or conditioning must be provided") + + result = sam3_model.forward_video( + images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts, + new_det_thresh=detection_threshold, max_objects=max_objects, + detect_interval=detect_interval) + result["orig_size"] = (H, W) + return io.NodeOutput(result) + + +class SAM3_TrackPreview(io.ComfyNode): + """Visualize tracked objects with distinct colors as a video preview. No tensor output — saves to temp video.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_TrackPreview", + display_name="SAM3 Track Preview", + category="detection/", + inputs=[ + SAM3TrackData.Input("track_data", display_name="track_data"), + io.Image.Input("images", display_name="images", optional=True), + io.Float.Input("opacity", display_name="opacity", default=0.5, min=0.0, max=1.0, step=0.05), + io.Float.Input("fps", display_name="fps", default=24.0, min=1.0, max=120.0, step=1.0), + ], + is_output_node=True, + ) + + COLORS = [ + (0.12, 0.47, 0.71), (1.0, 0.5, 0.05), (0.17, 0.63, 0.17), (0.84, 0.15, 0.16), + (0.58, 0.4, 0.74), (0.55, 0.34, 0.29), (0.89, 0.47, 0.76), (0.5, 0.5, 0.5), + (0.74, 0.74, 0.13), (0.09, 0.75, 0.81), (0.94, 0.76, 0.06), (0.42, 0.68, 0.84), + ] + + # 5x3 bitmap font atlas for digits 0-9 [10, 5, 3] + _glyph_cache = {} # (device, scale) -> (glyphs, outlines, gh, gw, oh, ow) + + @staticmethod + def _get_glyphs(device, scale=3): + key = (device, scale) + if key in SAM3_TrackPreview._glyph_cache: + return SAM3_TrackPreview._glyph_cache[key] + atlas = torch.tensor([ + [[1,1,1],[1,0,1],[1,0,1],[1,0,1],[1,1,1]], + [[0,1,0],[1,1,0],[0,1,0],[0,1,0],[1,1,1]], + [[1,1,1],[0,0,1],[1,1,1],[1,0,0],[1,1,1]], + [[1,1,1],[0,0,1],[1,1,1],[0,0,1],[1,1,1]], + [[1,0,1],[1,0,1],[1,1,1],[0,0,1],[0,0,1]], + [[1,1,1],[1,0,0],[1,1,1],[0,0,1],[1,1,1]], + [[1,1,1],[1,0,0],[1,1,1],[1,0,1],[1,1,1]], + [[1,1,1],[0,0,1],[0,0,1],[0,0,1],[0,0,1]], + [[1,1,1],[1,0,1],[1,1,1],[1,0,1],[1,1,1]], + [[1,1,1],[1,0,1],[1,1,1],[0,0,1],[1,1,1]], + ], dtype=torch.bool) + glyphs, outlines = [], [] + for d in range(10): + g = atlas[d].repeat_interleave(scale, 0).repeat_interleave(scale, 1) + padded = F.pad(g.float().unsqueeze(0).unsqueeze(0), (1,1,1,1)) + o = (F.max_pool2d(padded, 3, stride=1, padding=1)[0, 0] > 0) + glyphs.append(g.to(device)) + outlines.append(o.to(device)) + gh, gw = glyphs[0].shape + oh, ow = outlines[0].shape + SAM3_TrackPreview._glyph_cache[key] = (glyphs, outlines, gh, gw, oh, ow) + return SAM3_TrackPreview._glyph_cache[key] + + @staticmethod + def _draw_number_gpu(frame, number, cx, cy, color, scale=3): + """Draw a number on a GPU tensor [H, W, 3] float 0-1 at (cx, cy) with outline.""" + H, W = frame.shape[:2] + device = frame.device + glyphs, outlines, gh, gw, oh, ow = SAM3_TrackPreview._get_glyphs(device, scale) + color_t = torch.tensor(color, device=device, dtype=frame.dtype) + digs = [int(d) for d in str(number)] + total_w = len(digs) * (gw + scale) - scale + x0 = cx - total_w // 2 + y0 = cy - gh // 2 + for i, d in enumerate(digs): + dx = x0 + i * (gw + scale) + # Black outline + oy0, ox0 = y0 - 1, dx - 1 + osy1, osx1 = max(0, -oy0), max(0, -ox0) + osy2, osx2 = min(oh, H - oy0), min(ow, W - ox0) + if osy2 > osy1 and osx2 > osx1: + fy1, fx1 = oy0 + osy1, ox0 + osx1 + frame[fy1:fy1+(osy2-osy1), fx1:fx1+(osx2-osx1)][outlines[d][osy1:osy2, osx1:osx2]] = 0 + # Colored fill + sy1, sx1 = max(0, -y0), max(0, -dx) + sy2, sx2 = min(gh, H - y0), min(gw, W - dx) + if sy2 > sy1 and sx2 > sx1: + fy1, fx1 = y0 + sy1, dx + sx1 + frame[fy1:fy1+(sy2-sy1), fx1:fx1+(sx2-sx1)][glyphs[d][sy1:sy2, sx1:sx2]] = color_t + + @classmethod + def execute(cls, track_data, images=None, opacity=0.5, fps=24.0) -> io.NodeOutput: + + from comfy.ldm.sam3.tracker import unpack_masks + packed = track_data["packed_masks"] + H, W = track_data["orig_size"] + if images is not None: + H, W = images.shape[1], images.shape[2] + if packed is None: + N, N_obj = track_data["n_frames"], 0 + else: + N, N_obj = packed.shape[0], packed.shape[1] + + import uuid + gpu = comfy.model_management.get_torch_device() + temp_dir = folder_paths.get_temp_directory() + filename = f"sam3_track_preview_{uuid.uuid4().hex[:8]}.mp4" + filepath = os.path.join(temp_dir, filename) + with av.open(filepath, mode='w') as output: + stream = output.add_stream('h264', rate=Fraction(round(fps * 1000), 1000)) + stream.width = W + stream.height = H + stream.pix_fmt = 'yuv420p' + + frame_cpu = torch.empty(H, W, 3, dtype=torch.uint8) + frame_np = frame_cpu.numpy() + if N_obj > 0: + colors_t = torch.tensor([cls.COLORS[i % len(cls.COLORS)] for i in range(N_obj)], + device=gpu, dtype=torch.float32) + grid_y = torch.arange(H, device=gpu).view(1, H, 1) + grid_x = torch.arange(W, device=gpu).view(1, 1, W) + for t in range(N): + if images is not None and t < images.shape[0]: + frame = images[t].clone() + else: + frame = torch.zeros(H, W, 3) + + if N_obj > 0: + frame_binary = unpack_masks(packed[t:t+1].to(gpu)) # [1, N_obj, H, W] bool + frame_masks = F.interpolate(frame_binary.float(), size=(H, W), mode="nearest")[0] + frame_gpu = frame.to(gpu) + bool_masks = frame_masks > 0.5 + any_mask = bool_masks.any(dim=0) + if any_mask.any(): + obj_idx_map = bool_masks.to(torch.uint8).argmax(dim=0) + color_overlay = colors_t[obj_idx_map] + mask_3d = any_mask.unsqueeze(-1) + frame_gpu = torch.where(mask_3d, frame_gpu * (1 - opacity) + color_overlay * opacity, frame_gpu) + area = bool_masks.sum(dim=(-1, -2)).clamp_(min=1) + cy = (bool_masks * grid_y).sum(dim=(-1, -2)) // area + cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area + has = area > 1 + scores = track_data.get("scores", []) + for obj_idx in range(N_obj): + if has[obj_idx]: + _cx, _cy = int(cx[obj_idx]), int(cy[obj_idx]) + color = cls.COLORS[obj_idx % len(cls.COLORS)] + SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color) + if obj_idx < len(scores) and scores[obj_idx] < 1.0: + SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100), + _cx, _cy + 5 * 3 + 3, color, scale=2) + frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte()) + else: + frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte()) + + vframe = av.VideoFrame.from_ndarray(frame_np, format='rgb24') + output.mux(stream.encode(vframe.reformat(format='yuv420p'))) + output.mux(stream.encode(None)) + return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(filename, "", io.FolderType.temp)])) + + +class SAM3_TrackToMask(io.ComfyNode): + """Select tracked objects by index and output as mask.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_TrackToMask", + display_name="SAM3 Track to Mask", + category="detection/", + inputs=[ + SAM3TrackData.Input("track_data", display_name="track_data"), + io.String.Input("object_indices", display_name="object_indices", default="", + tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Empty = all objects."), + ], + outputs=[ + io.Mask.Output("masks", display_name="masks"), + ], + ) + + @classmethod + def execute(cls, track_data, object_indices="") -> io.NodeOutput: + from comfy.ldm.sam3.tracker import unpack_masks + packed = track_data["packed_masks"] + H, W = track_data["orig_size"] + + if packed is None: + N = track_data["n_frames"] + return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device())) + + N, N_obj = packed.shape[0], packed.shape[1] + + if object_indices.strip(): + indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] + indices = [i for i in indices if 0 <= i < N_obj] + else: + indices = list(range(N_obj)) + + if not indices: + return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device())) + + selected = packed[:, indices] + binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool + union = binary.any(dim=1, keepdim=True).float() + mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0] + return io.NodeOutput(mask_out) + + +class SAM3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SAM3_Detect, + SAM3_VideoTrack, + SAM3_TrackPreview, + SAM3_TrackToMask, + ] + + +async def comfy_entrypoint() -> SAM3Extension: + return SAM3Extension() diff --git a/execution.py b/execution.py index 5e02dffb204f..e15eb4bda008 100644 --- a/execution.py +++ b/execution.py @@ -811,11 +811,30 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= self._notify_prompt_lifecycle("end", prompt_id) -async def validate_inputs(prompt_id, prompt, item, validated): +async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): + if visiting is None: + visiting = [] + unique_id = item if unique_id in validated: return validated[unique_id] + if unique_id in visiting: + cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id] + cycle_nodes = list(dict.fromkeys(cycle_path_nodes)) + cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes) + for node_id in cycle_nodes: + validated[node_id] = (False, [{ + "type": "dependency_cycle", + "message": "Dependency cycle detected", + "details": cycle_path, + "extra_info": { + "node_id": node_id, + "cycle_nodes": cycle_nodes, + } + }], node_id) + return validated[unique_id] + inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -899,7 +918,11 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue try: - r = await validate_inputs(prompt_id, prompt, o_id, validated) + visiting.append(unique_id) + try: + r = await validate_inputs(prompt_id, prompt, o_id, validated, visiting) + finally: + visiting.pop() if r[0] is False: # `r` will be set in `validated[o_id]` already valid = False @@ -1048,10 +1071,13 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue - if len(errors) > 0 or valid is not True: - ret = (False, errors, unique_id) - else: - ret = (True, [], unique_id) + ret = validated.get(unique_id, (True, [], unique_id)) + # Recursive cycle detection may have already populated an error on us. Join it. + ret = ( + ret[0] and valid is True and not errors, + ret[1] + [error for error in errors if error not in ret[1]], + unique_id, + ) validated[unique_id] = ret return ret diff --git a/folder_paths.py b/folder_paths.py index 9c96540e3dd9..80f4b291a895 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -52,6 +52,8 @@ folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions) +folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") diff --git a/main.py b/main.py index de145a1e9378..fd228b2562d9 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,8 @@ import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger +setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) + from app.assets.seeder import asset_seeder from app.assets.services import register_output_files import itertools @@ -27,8 +29,6 @@ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['DO_NOT_TRACK'] = '1' -setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) - faulthandler.enable(file=sys.stderr, all_threads=False) import comfy_aimdo.control diff --git a/manager_requirements.txt b/manager_requirements.txt index f770ec933d73..a079d3492931 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1 +comfyui_manager==4.2.1 diff --git a/models/frame_interpolation/put_frame_interpolation_models_here b/models/frame_interpolation/put_frame_interpolation_models_here new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nodes.py b/nodes.py index d81ac2935ebe..82d7ef332a97 100644 --- a/nodes.py +++ b/nodes.py @@ -2565,7 +2565,9 @@ async def init_builtin_extra_nodes(): "nodes_number_convert.py", "nodes_painter.py", "nodes_curve.py", - "nodes_rtdetr.py" + "nodes_rtdetr.py", + "nodes_frame_interpolation.py", + "nodes_sam3.py" ] import_failed = [] diff --git a/utils/install_util.py b/utils/install_util.py index 34489aec563c..fdba23a8f062 100644 --- a/utils/install_util.py +++ b/utils/install_util.py @@ -39,7 +39,7 @@ def get_required_packages_versions(): if len(s) == 2: version_str = s[-1] if not is_valid_version(version_str): - logging.error(f"Invalid version format in requirements.txt: {version_str}") + logging.debug(f"Invalid version format for {s[0]} in requirements.txt: {version_str}") continue out[s[0]] = version_str return out.copy() From a61e2bbb85a3052dc75205fa30768ecdbdc89c91 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Thu, 7 May 2026 12:49:23 +0800 Subject: [PATCH 57/90] Add device selection on Image Only Load Checkpoint (CORE-158) (#13748) * Add device selection on Image Only Load Checkpoint * Rename variables * Update variable name * Fix linting --- comfy_extras/nodes_video_model.py | 65 +++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index bf98e6b8283d..a3b148d7d55c 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -23,6 +23,69 @@ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): return (out[0], out[3], out[2]) +class ImageOnlyCheckpointLoaderDevice: + @classmethod + def INPUT_TYPES(s): + device_options = comfy.model_management.get_gpu_device_options() + return { + "required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + }, + "optional": { + "model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}), + "clip_vision_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP vision encoder."}), + "vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}), + } + } + RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") + FUNCTION = "load_checkpoint" + + CATEGORY = "loaders/video_models" + + @classmethod + def VALIDATE_INPUTS(cls, model_device="default", clip_vision_device="default", vae_device="default"): + return True + + def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True, model_device="default", clip_vision_device="default", vae_device="default"): + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + + model_options = {} + resolved_model = comfy.model_management.resolve_gpu_device_option(model_device) + if resolved_model is not None: + if resolved_model.type == "cpu": + model_options["load_device"] = model_options["offload_device"] = resolved_model + else: + model_options["load_device"] = resolved_model + + cv_model_options = {} + resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_vision_device) + if resolved_clip is not None: + if resolved_clip.type == "cpu": + cv_model_options["load_device"] = cv_model_options["offload_device"] = resolved_clip + else: + cv_model_options["load_device"] = resolved_clip + + # VAE device is passed via model_options["load_device"] which + # load_state_dict_guess_config forwards to the VAE constructor. + # If vae_device differs from model_device, we override after loading. + resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device) + + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + model_patcher, clip, vae, clip_vision = out[:4] + + # Apply VAE device override if it differs from the model device + if resolved_vae is not None and vae is not None: + vae.device = resolved_vae + if resolved_vae.type == "cpu": + offload = resolved_vae + else: + offload = comfy.model_management.vae_offload_device() + vae.patcher.load_device = resolved_vae + vae.patcher.offload_device = offload + + return (model_patcher, clip_vision, vae) + + class SVD_img2vid_Conditioning: @classmethod def INPUT_TYPES(s): @@ -149,6 +212,7 @@ def append(self, conditioning, width, height, temporal, x, y, z, strength): NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, + "ImageOnlyCheckpointLoaderDevice": ImageOnlyCheckpointLoaderDevice, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance, @@ -158,4 +222,5 @@ def append(self, conditioning, width, height, temporal, x, y, z, strength): NODE_DISPLAY_NAME_MAPPINGS = { "ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)", + "ImageOnlyCheckpointLoaderDevice": "Image Only Checkpoint Loader (Device)", } From 9e3ede14062ac42d85885eae20e75b754c85e030 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 19 May 2026 20:11:53 -0700 Subject: [PATCH 58/90] Fix MultiGPU scheduler capacity accounting (#14000) Fixes _calc_cond_batch_multigpu so that: 1. conds_per_device uses real division before math.ceil. The previous expression math.ceil(total_conds // len(devices)) applied integer floor division first, making ceil a no-op. For 3 conds across 2 devices this produced conds_per_device=1 instead of 2. 2. The scheduling loop skips devices that have already reached capacity instead of appending empty batch groups. Without this guard, the loop could repeatedly emit zero-length groups for a full device, leaving sampling stuck at 0/N until timeout. Reproduces with an Omnigen2 image workflow that produces three condition entries scheduled across two CUDA devices. With the fix the scheduler assigns conds_per_device=2 and splits the batches as 2 + 1 across the two devices, allowing sampling to complete. Original fix authored and validated by @pollockjj in pollockjj/ComfyUI#64. Co-authored-by: John Pollock --- comfy/samplers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 88393e3673ca..83fa2e6098ca 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -394,7 +394,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t total_conds = 0 for to_run in hooked_to_run.values(): total_conds += len(to_run) - conds_per_device = max(1, math.ceil(total_conds//len(devices))) + conds_per_device = max(1, math.ceil(total_conds / len(devices))) index_device = 0 current_device = devices[index_device] # run every hooked_to_run separately @@ -406,13 +406,17 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t batched_to_run_length = 0 for btr in batched_to_run: batched_to_run_length += len(btr[1]) + remaining_capacity = conds_per_device - batched_to_run_length + if remaining_capacity <= 0: + index_device += 1 + continue first = to_run[0] first_shape = first[0][0].shape to_batch_temp = [] # make sure not over conds_per_device limit when creating temp batch for x in range(len(to_run)): - if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length): + if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity: to_batch_temp += [x] to_batch_temp.reverse() From 819c7c0702107511b4d08a7de5e7f03007b53799 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 19 May 2026 21:23:56 -0700 Subject: [PATCH 59/90] Refactor MultiGPU scheduler for readability and termination safety (#14001) Behaviour-equivalent cleanup of _calc_cond_batch_multigpu device scheduling. No change to batching decisions or memory checks for any valid input. Changes: * Replace re-summed batched_to_run_length with a per-device load dict (device_load), so capacity checks are O(1) and use a single source of truth. * Extract device selection into next_available_device(), which scans at most len(devices) positions and raises if no device has remaining capacity. This makes the 'skip a full device' rule live in one place instead of two and guarantees the outer while loop cannot spin forever on a scheduling bug. * Drop the unused current_device assignment before the outer loop and the index_device % len(devices) modulo dance (now handled inside next_available_device). * Minor cleanups: list comprehensions for total_conds, conds_to_batch, and the devices list. --- comfy/samplers.py | 52 ++++++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 83fa2e6098ca..f0d67cb7e0f5 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -388,33 +388,40 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t model.current_patcher.prepare_state(timestep, model_options) - devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()] + devices = list(model_options['multigpu_clones'].keys()) device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {} + # Track conds currently scheduled per device; single source of truth for capacity checks. + device_load: dict[torch.device, int] = {d: 0 for d in devices} - total_conds = 0 - for to_run in hooked_to_run.values(): - total_conds += len(to_run) + total_conds = sum(len(to_run) for to_run in hooked_to_run.values()) conds_per_device = max(1, math.ceil(total_conds / len(devices))) - index_device = 0 - current_device = devices[index_device] + + def next_available_device(start: int) -> tuple[int, torch.device]: + """Return (index, device) for the next device with remaining capacity, starting at `start`. + + Scans at most len(devices) positions, so this always terminates. Raises if no device + has remaining capacity, which would indicate a bug in conds_per_device accounting. + """ + for offset in range(len(devices)): + i = (start + offset) % len(devices) + if device_load[devices[i]] < conds_per_device: + return i, devices[i] + raise RuntimeError( + f"MultiGPU scheduler: all {len(devices)} devices at capacity " + f"({conds_per_device}) but conds remain to schedule" + ) + # run every hooked_to_run separately + index_device = 0 for hooks, to_run in hooked_to_run.items(): while len(to_run) > 0: - current_device = devices[index_device % len(devices)] - batched_to_run = device_batched_hooked_to_run.setdefault(current_device, []) - # keep track of conds currently scheduled onto this device - batched_to_run_length = 0 - for btr in batched_to_run: - batched_to_run_length += len(btr[1]) - remaining_capacity = conds_per_device - batched_to_run_length - if remaining_capacity <= 0: - index_device += 1 - continue + index_device, current_device = next_available_device(index_device) + remaining_capacity = conds_per_device - device_load[current_device] first = to_run[0] first_shape = first[0][0].shape + # collect candidate indices that can be concatenated with `first`, up to remaining capacity to_batch_temp = [] - # make sure not over conds_per_device limit when creating temp batch for x in range(len(to_run)): if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity: to_batch_temp += [x] @@ -429,13 +436,12 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t if model.memory_required(input_shape) * 1.5 < free_memory: to_batch = batch_amount break - conds_to_batch = [] - for x in to_batch: - conds_to_batch.append(to_run.pop(x)) - batched_to_run_length += len(conds_to_batch) - batched_to_run.append((hooks, conds_to_batch)) - if batched_to_run_length >= conds_per_device: + conds_to_batch = [to_run.pop(x) for x in to_batch] + device_load[current_device] += len(conds_to_batch) + device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch)) + + if device_load[current_device] >= conds_per_device: index_device += 1 class thread_result(NamedTuple): From 50d1dd6273be924d5945f52e9f218ed22c4154a1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 16:46:23 -0700 Subject: [PATCH 60/90] Fix MultiGPU Options node discarding cloned GPUOptionsGroup GPUOptionsGroup.clone() returns a new instance, but the return value was discarded, causing the node to mutate the upstream caller's group in-place. When multiple MultiGPU Options nodes share an input group, each node's additions would leak into earlier siblings. Assign the clone result back to gpu_options so each node owns its own copy. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 5d24952bf61c..53b50029e409 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -68,7 +68,8 @@ def define_schema(cls): def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput: if not gpu_options: gpu_options = comfy.multigpu.GPUOptionsGroup() - gpu_options.clone() + else: + gpu_options = gpu_options.clone() opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed) gpu_options.add(opt) From 9a681ccfc9d70f1797d1df0dd6e87eee4caf4b21 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 16:46:31 -0700 Subject: [PATCH 61/90] Guard cached_patcher_init when output_model is False load_checkpoint_guess_config_clip_only() calls load_checkpoint_guess_config() with output_model=False, leaving out[0] as None. The subsequent unconditional assignment of cached_patcher_init crashed with AttributeError, breaking CLIP-only checkpoint loading entirely. Guard the assignment with a None check. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy/sd.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index e7857bf0a954..481c87cb1e21 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1688,7 +1688,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) - out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) + if out[0] is not None: + out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) return out def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): From ba417750a73d93c035485f71f56e5a3b146c111c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 16:46:38 -0700 Subject: [PATCH 62/90] Fix get_all_torch_devices for XPU/NPU and guard remove() torch.device(i) defaults to CUDA, so XPU/NPU branches were producing 'cuda:N' devices that don't match get_torch_device() output ('xpu:N'/'npu:N'). This caused devices.remove(get_torch_device()) to raise ValueError when exclude_current=True on non-NVIDIA hardware. Use explicit device strings, and guard the remove() with a membership check for safety. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy/model_management.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2e168f363804..10b982868376 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -215,17 +215,19 @@ def get_all_torch_devices(exclude_current=False): if cpu_state == CPUState.GPU: if is_nvidia(): for i in range(torch.cuda.device_count()): - devices.append(torch.device(i)) + devices.append(torch.device("cuda", i)) elif is_intel_xpu(): for i in range(torch.xpu.device_count()): - devices.append(torch.device(i)) + devices.append(torch.device("xpu", i)) elif is_ascend_npu(): for i in range(torch.npu.device_count()): - devices.append(torch.device(i)) + devices.append(torch.device("npu", i)) else: devices.append(get_torch_device()) if exclude_current: - devices.remove(get_torch_device()) + current = get_torch_device() + if current in devices: + devices.remove(current) return devices def get_gpu_device_options(): From dd85851efec772298772f159e2134cea45bd1b3e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 16:46:45 -0700 Subject: [PATCH 63/90] Prune inherited multigpu clones when max_gpus is lowered create_multigpu_deepclones cloned the existing 'multigpu' additional_models list verbatim and never pruned entries beyond limit_extra_devices. If a workflow was previously prepared for more GPUs, reducing max_gpus would leave stale clones attached and eligible for later scheduling. Replace the TODO block with a real prune that keeps only clones whose load_device is either the model's load_device or in limit_extra_devices, and re-match clones if anything was removed. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy/multigpu.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 096270c12573..eff7d06499a9 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -162,16 +162,16 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: gpu_options.register(model) else: logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") - # TODO: only keep model clones that don't go 'past' the intended max_gpu count - # multigpu_models = model.get_additional_models_with_key("multigpu") - # new_multigpu_models = [] - # for m in multigpu_models: - # if m.load_device in limit_extra_devices: - # new_multigpu_models.append(m) - # model.set_additional_models("multigpu", new_multigpu_models) - # persist skip_devices for use in sampling code - # if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options: - # model.model_options["multigpu_skip_devices"] = skip_devices + # only keep model clones that don't go 'past' the intended max_gpu count; + # this prunes any inherited multigpu clones whose load_device is no longer allowed + # when max_gpus is lowered between runs. + allowed_devices = set(limit_extra_devices) + allowed_devices.add(model.load_device) + multigpu_models = model.get_additional_models_with_key("multigpu") + new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices] + if len(new_multigpu_models) != len(multigpu_models): + model.set_additional_models("multigpu", new_multigpu_models) + model.match_multigpu_clones() return model From ac0a90c323735333346397ff2d9b7bf493b531d3 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 19:52:03 -0700 Subject: [PATCH 64/90] Use cond_shapes in multigpu memory-fit check (parity with single-GPU path) The multigpu cond-batching loop called model.memory_required(input_shape) without conditioning shapes, while the single-GPU path at line 279 passes cond_shapes. Large conditioning tensors (e.g. video prompts, control inputs) were therefore under-counted, risking OOM at runtime when the chosen batch size was too large. Match the single-GPU pattern by building cond_shapes from each batched cond's conditioning dict and passing it to memory_required. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy/samplers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index f0d67cb7e0f5..a99af52174f4 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -433,7 +433,11 @@ def next_available_device(start: int) -> tuple[int, torch.device]: for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) * 1.5 < free_memory: + cond_shapes = collections.defaultdict(list) + for tt in batch_amount: + for k, v in to_run[tt][0].conditioning.items(): + cond_shapes[k].append(v.size()) + if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory: to_batch = batch_amount break From 4d9106dcedbecb3df8c98a9cd05cfa8fdb3fd862 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 20 May 2026 20:48:59 -0700 Subject: [PATCH 65/90] Document --cuda-device comma format and MultiGPU Options relative_speed gap Two doc-only changes addressing minor CodeRabbit findings on PR #7063: * cli_args.py: clarify --cuda-device help text to document the required comma-separated format ('0' or '0,1'), matching how the value is consumed by CUDA_VISIBLE_DEVICES in main.py. * nodes_multigpu.py: add a docstring NOTE on the (currently unregistered) MultiGPUOptionsNode explaining that its relative_speed input is plumbed through to model_options['multigpu_options'] but is not yet consulted by the cond scheduler, which still uses uniform round-robin via next_available_device(). Wire relative_speed into the scheduler before re-enabling the node. Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5 Co-authored-by: Amp --- comfy/cli_args.py | 2 +- comfy_extras/nodes_multigpu.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index df38418714b9..3a14a470d121 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -49,7 +49,7 @@ def __call__(self, parser, namespace, values, option_string=None): parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") -parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use. All other devices will not be visible.") +parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.") parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.") cm_group = parser.add_mutually_exclusive_group() cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 53b50029e409..fedafef7114e 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -45,6 +45,16 @@ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: class MultiGPUOptionsNode(io.ComfyNode): """ Select the relative speed of GPUs in the special case they have significantly different performance from one another. + + NOTE (not registered yet, see MultiGPUExtension.get_node_list below): + The output GPUOptionsGroup is plumbed through create_multigpu_deepclones() and stored on + model.model_options['multigpu_options'] via GPUOptionsGroup.register(), but the cond + scheduler in comfy/samplers.py (calc_cond_batch_outer_multigpu) does NOT yet consult + relative_speed when distributing conds across devices; it uses a uniform conds_per_device + round-robin via next_available_device(). Before re-enabling this node, wire its + relative_speed into the scheduler (e.g. via comfy.multigpu.load_balance_devices(), + which already implements the proportional split) so the input actually affects work + distribution. """ @classmethod From adde1239b1037f7bf1b2dfce9052e6fd1fde4edf Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Thu, 21 May 2026 11:35:39 -0700 Subject: [PATCH 66/90] Restore prepare_state backward-compatible signature Drop the new ignore_multigpu positional argument from prepare_state and from the ON_PREPARE_STATE callbacks; pass the flag via model_options instead. This restores the original 3-arg callback signature so existing custom-node ON_PREPARE_STATE handlers keep working unchanged, while still letting prepare_state's recursive call into multigpu_clones short-circuit. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/model_patcher.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 00d60ff72490..b680de058bf6 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1361,13 +1361,18 @@ def pre_run(self): for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): callback(self) - def prepare_state(self, timestep, model_options, ignore_multigpu=False): + def prepare_state(self, timestep, model_options): + ignore_multigpu = model_options.get("ignore_multigpu", False) for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): - callback(self, timestep, model_options, ignore_multigpu) + callback(self, timestep, model_options) if not ignore_multigpu and "multigpu_clones" in model_options: - for p in model_options["multigpu_clones"].values(): - p: ModelPatcher - p.prepare_state(timestep, model_options, ignore_multigpu=True) + model_options["ignore_multigpu"] = True + try: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p.prepare_state(timestep, model_options) + finally: + model_options.pop("ignore_multigpu", None) def restore_hook_patches(self): if self.hook_patches_backup is not None: From 963621603ce2b43a567ec7cf88709555dfa9d6b5 Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Thu, 21 May 2026 11:35:54 -0700 Subject: [PATCH 67/90] Free QwenFunControlNet base_model reference in cleanup QwenFunControlNet.pre_run stashes the model's diffusion_model into self.extra_args['base_model'], but ControlBase.cleanup never clears extra_args. The diffusion_model reference therefore lingered between sampling runs, blocking ComfyUI's model offload/eviction logic from freeing the UNet and -- for multigpu -- holding one such reference per per-device control clone (defeating the max_gpus pruning added in this PR). Override cleanup to drop the entry; super().cleanup() already recurses into multigpu_clones so each per-device clone pops its own. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/controlnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 837aa907ab00..6dbbaa959fdc 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -357,6 +357,10 @@ def pre_run(self, model, percent_to_timestep_function): super().pre_run(model, percent_to_timestep_function) self.set_extra_arg("base_model", model.diffusion_model) + def cleanup(self): + self.extra_args.pop("base_model", None) + super().cleanup() + def copy(self): c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) c.control_model = self.control_model From a18dd219d57079c8d20ee1506feaf17a0e995ffb Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Thu, 21 May 2026 11:40:49 -0700 Subject: [PATCH 68/90] Pass per-device model to multigpu control clones in pre_run_control QwenFunControlNet.pre_run stashes model.diffusion_model into extra_args, which the control_model then uses for forward passes (img_in, txt_in, pe_embedder, time_text_embed). With multigpu, every per-device control clone was being pre_run with the base model on GPU0, so secondary devices would invoke those modules with parameters on GPU0 and inputs on their own device, raising 'Expected all tensors to be on the same device'. Build a device -> per-device BaseModel lookup from the patcher's additional multigpu models and pass each clone the model on its own device. Falls back to the base model when no per-device match is found (single-GPU path and the case where cnet.multigpu_clones lags the patcher's clone set). Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/samplers.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index a99af52174f4..8bfc42bdbf36 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -870,14 +870,21 @@ def calculate_start_end_timesteps(model, conds): def pre_run_control(model, conds): s = model.model_sampling + # Per-device model lookup so multigpu control clones get the matching + # diffusion_model (e.g. QwenFunControlNet stashes it into extra_args). + device_models: dict = {} + patcher = getattr(model, "current_patcher", None) + if patcher is not None: + for p in patcher.get_additional_models_with_key("multigpu"): + device_models[p.load_device] = p.model for t in range(len(conds)): x = conds[t] percent_to_timestep_function = lambda a: s.percent_to_sigma(a) if 'control' in x: x['control'].pre_run(model, percent_to_timestep_function) - for device_cnet in x['control'].multigpu_clones.values(): - device_cnet.pre_run(model, percent_to_timestep_function) + for device, device_cnet in x['control'].multigpu_clones.items(): + device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function) def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] From 822a3ecf7372760eafb3b367fff805d8b0f2fbc5 Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Thu, 21 May 2026 11:47:53 -0700 Subject: [PATCH 69/90] Note _calc_cond_batch and _calc_cond_batch_multigpu must stay in sync Per review feedback on #7063. The two functions share the conds-by-hooks accumulation, memory-fit batching, and per-chunk output aggregation; the multigpu variant adds per-device scheduling, .to(device) placement, per-device patcher/control lookup, and thread-pool dispatch around the inner loop. Documenting the relationship without extracting helpers -- extraction can land after the initial worksplit-multigpu release once both paths have settled. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/samplers.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comfy/samplers.py b/comfy/samplers.py index 8bfc42bdbf36..6fd0387d5a2e 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -218,6 +218,9 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc return executor.execute(model, conds, x_in, timestep, model_options) def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + # NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic + # (hooked_to_run accumulation, memory-fit batching, per-chunk output + # aggregation) is duplicated there with per-device scheduling layered on top. if 'multigpu_clones' in model_options: return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options) out_conds = [] @@ -353,6 +356,10 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens return out_conds def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + # NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks + # accumulation, memory-fit batching, and output aggregation, but adds a + # per-device scheduler, per-device patcher/control lookup, tensor .to(device) + # placement, and MultiGPUThreadPool dispatch around the inner loop. out_conds = [] out_counts = [] # separate conds by matching hooks From 019261ed968640ffd53d9fd53bbdfd550ad7d9c3 Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Thu, 21 May 2026 12:14:02 -0700 Subject: [PATCH 70/90] Simplify Hunyuan 3D 2.1 swap_cfg_halves gate to a shape check The previous gate (len(cond_or_uncond) == 2 and set == {0, 1}) was intended to skip the cond/uncond swap when only one half was present under MultiGPU CFG Split, but it was too restrictive: it also skipped batch_size > 1 + CFG (cond_or_uncond like [0, 0, 1, 1] or [0,0,0,0, 1,1,1,1]), where chunk(2) still splits the batch cleanly into a cond half and an uncond half and the swap is still required. Switch to context.shape[0] >= 2, matching the parallel fix landed on master in #13699. The swap is a permutation-invariant no-op when the two halves don't form a CFG pair (since the output swap_cfg_halves block immediately undoes the permutation), so the only thing the gate actually needs to do is guard against chunk(2) on a batch of one. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index 61d1b3dc674a..cb260e0a8c2c 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -608,8 +608,7 @@ def forward(self, x, t, context, transformer_options = {}, **kwargs): x = x.movedim(-1, -2) - cond_or_uncond = transformer_options.get("cond_or_uncond", []) - swap_cfg_halves = len(cond_or_uncond) == 2 and set(cond_or_uncond) == {0, 1} + swap_cfg_halves = context.shape[0] >= 2 if swap_cfg_halves: first_half, second_half = context.chunk(2, dim = 0) From fd79f22bdfceaf58fe6bdd6fb65a267b14710fec Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Thu, 21 May 2026 12:17:24 -0700 Subject: [PATCH 71/90] Backport Hunyuan 3D 2.1 attention batch-size fixes from #13699 CrossAttention.kv.view and Attention.qkv_combined.view both hardcoded batch=1 in the reshape, crashing or silently mis-shaping whenever the actual batch dimension was greater than 1. These were fixed on master in #13699 as part of the same patch that gated the chunk(2) swap, but worksplit-multigpu only picked up the chunk(2) gate. Bring the two view() fixes over so we have parity with master. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index cb260e0a8c2c..4e4819fe3dc4 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -328,7 +328,7 @@ def forward(self, x, y): kv = torch.cat((k, v), dim=-1) split_size = kv.shape[-1] // self.num_heads // 2 - kv = kv.view(1, -1, self.num_heads, split_size * 2) + kv = kv.view(b, -1, self.num_heads, split_size * 2) k, v = torch.split(kv, split_size, dim=-1) q = q.view(b, s1, self.num_heads, self.head_dim) @@ -398,7 +398,7 @@ def forward(self, x): qkv_combined = torch.cat((query, key, value), dim=-1) split_size = qkv_combined.shape[-1] // self.num_heads // 3 - qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3) + qkv = qkv_combined.view(B, -1, self.num_heads, split_size * 3) query, key, value = torch.split(qkv, split_size, dim=-1) query = query.reshape(B, N, self.num_heads, self.head_dim) From 2ed396c769bfce4c668a840672cc44c701051dbd Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 21 May 2026 12:47:43 -0700 Subject: [PATCH 72/90] Mark non-NVIDIA multigpu gaps with TODOs in _handle_batch Two CodeRabbit findings from #7063 (#13 and #14) are deferred because worksplit-multigpu's initial release scope is NVIDIA-only QA. Leave a TODO at the unconditional torch.cuda.set_device call and at the post-aggregation point so the required guards/synchronize are easy to find when multigpu support is extended to XPU/NPU/MPS/CPU/DirectML. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/samplers.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/comfy/samplers.py b/comfy/samplers.py index 6fd0387d5a2e..42b05f3ba120 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -465,6 +465,9 @@ class thread_result(NamedTuple): def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): try: + # TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once + # we extend multigpu QA beyond CUDA. Unconditional call crashes on + # XPU/NPU/MPS/CPU/DirectML backends. torch.cuda.set_device(device) model_current: BaseModel = model_options["multigpu_clones"][device].model # run every hooked_to_run separately @@ -524,6 +527,12 @@ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) else: output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) + # TODO: non-NVIDIA support -- the `.to(output_device)` copies + # above are async on CUDA, so the main thread's aggregation + # could race with in-flight transfers. CUDA-only QA has not + # surfaced this in practice, but before extending multigpu + # beyond NVIDIA add a `torch.cuda.synchronize(output_device)` + # here (guarded by `output_device.type == "cuda"`). results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) except Exception as e: results.append(thread_result(None, None, None, None, None, error=e)) From b649502c9ce6146b946437fddf1848ba0292aa17 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 21 May 2026 13:04:54 -0700 Subject: [PATCH 73/90] Report all torch devices from /system_stats The /system_stats endpoint was returning a hardcoded single-element devices list built from get_torch_device(), which only reflects the primary CUDA device. On multi-GPU systems this hides the additional devices from frontends / tooling (the API surface that enables multigpu support discovery). Switch to iterating get_all_torch_devices(), with the primary device kept first so existing clients reading devices[0] keep working. (Worksplit-multigpu-only: get_all_torch_devices is the multigpu helper introduced on this branch; master's /system_stats remains unchanged.) Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- server.py | 39 ++++++++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/server.py b/server.py index 44470b9042ed..268441bd1a87 100644 --- a/server.py +++ b/server.py @@ -646,18 +646,37 @@ async def view_metadata(request): @routes.get("/system_stats") async def system_stats(request): - device = comfy.model_management.get_torch_device() - device_name = comfy.model_management.get_torch_device_name(device) + primary_device = comfy.model_management.get_torch_device() cpu_device = comfy.model_management.torch.device("cpu") ram_total = comfy.model_management.get_total_memory(cpu_device) ram_free = comfy.model_management.get_free_memory(cpu_device) - vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) - vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) required_frontend_version = FrontendManager.get_required_frontend_version() installed_templates_version = FrontendManager.get_installed_templates_version() required_templates_version = FrontendManager.get_required_templates_version() comfy_package_versions = FrontendManager.get_comfy_package_versions() + # Report every torch device visible to multigpu, with the primary + # device first so existing clients that read devices[0] keep working. + torch_devices = comfy.model_management.get_all_torch_devices() + if primary_device in torch_devices: + torch_devices = [primary_device] + [d for d in torch_devices if d != primary_device] + else: + torch_devices = [primary_device] + list(torch_devices) + + device_entries = [] + for d in torch_devices: + vram_total, torch_vram_total = comfy.model_management.get_total_memory(d, torch_total_too=True) + vram_free, torch_vram_free = comfy.model_management.get_free_memory(d, torch_free_too=True) + device_entries.append({ + "name": comfy.model_management.get_torch_device_name(d), + "type": d.type, + "index": d.index, + "vram_total": vram_total, + "vram_free": vram_free, + "torch_vram_total": torch_vram_total, + "torch_vram_free": torch_vram_free, + }) + system_stats = { "system": { "os": sys.platform, @@ -673,17 +692,7 @@ async def system_stats(request): "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded", "argv": sys.argv }, - "devices": [ - { - "name": device_name, - "type": device.type, - "index": device.index, - "vram_total": vram_total, - "vram_free": vram_free, - "torch_vram_total": torch_vram_total, - "torch_vram_free": torch_vram_free, - } - ] + "devices": device_entries } return web.json_response(system_stats) From df17b560c5b1666ce0693b3985a51467a6d805b9 Mon Sep 17 00:00:00 2001 From: Rattus Date: Fri, 22 May 2026 23:30:35 +1000 Subject: [PATCH 74/90] memory_management: replace thread refusal with mutex This was an attempt to be a fast path by ensuring the file slice was created by the owning thread and refusing without needing ot mutex but worksplit-multigpu doesnt work that way. Go mutex. Shoot me for overthinking next time. --- comfy/memory_management.py | 36 ++++++++++++++++++------------------ comfy/utils.py | 3 ++- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/comfy/memory_management.py b/comfy/memory_management.py index c43f0c4a2109..962addb27bf9 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -1,6 +1,5 @@ import math import ctypes -import threading import dataclasses import torch from typing import NamedTuple @@ -10,7 +9,7 @@ class TensorFileSlice(NamedTuple): file_ref: object - thread_id: int + lock: object offset: int size: int @@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N file_obj = info.file_ref if (destination.device.type != "cpu" or file_obj is None - or threading.get_ident() != info.thread_id or destination.numel() * destination.element_size() < info.size or tensor.numel() * tensor.element_size() != info.size or tensor.storage_offset() != 0 @@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N if hostbuf is not None: stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 device_ptr = destination2.data_ptr() if destination2 is not None else 0 - hostbuf.read_file_slice(file_obj, info.offset, info.size, - offset=destination.data_ptr() - hostbuf.get_raw_address(), - stream=stream_ptr, - device_ptr=device_ptr, - device=None if destination2 is None else destination2.device.index) + with info.lock: + hostbuf.read_file_slice(file_obj, info.offset, info.size, + offset=destination.data_ptr() - hostbuf.get_raw_address(), + stream=stream_ptr, + device_ptr=device_ptr, + device=None if destination2 is None else destination2.device.index) return True buf_type = ctypes.c_ubyte * info.size view = memoryview(buf_type.from_address(destination.data_ptr())) try: - file_obj.seek(info.offset) - done = 0 - while done < info.size: - try: - n = file_obj.readinto(view[done:]) - except OSError: - return False - if n <= 0: - return False - done += n + with info.lock: + file_obj.seek(info.offset) + done = 0 + while done < info.size: + try: + n = file_obj.readinto(view[done:]) + except OSError: + return False + if n <= 0: + return False + done += n return True finally: view.release() diff --git a/comfy/utils.py b/comfy/utils.py index 31052714a585..49ae12b0660c 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -86,6 +86,7 @@ def load_safetensors(ckpt): import comfy_aimdo.model_mmap f = open(ckpt, "rb", buffering=0) + file_lock = threading.Lock() model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt) file_size = os.path.getsize(ckpt) mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get())) @@ -111,7 +112,7 @@ def load_safetensors(ckpt): storage = tensor.untyped_storage() setattr(storage, "_comfy_tensor_file_slice", - comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start)) + comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start)) setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv)) sd[name] = tensor From 7a18f9affbe52f6b6560fd40eb5de0dcb520f7c8 Mon Sep 17 00:00:00 2001 From: Rattus Date: Sat, 23 May 2026 00:47:19 +1000 Subject: [PATCH 75/90] comfy-aimdo 0.4.4 Comfy-aimdo 0.4.4 contains a small bugfix to allow recovery of a hostbuf after full truncation. This pattern doesnt happen as a general rule, but does happen in the upcoming worksplit-multigpu branch. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e20b6e044b07..381e7d05fdfe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo==0.4.3 +comfy-aimdo==0.4.4 requests simpleeval>=1.0.0 blake3 From 74b0a826eaa7962e5093d83a27e13c20d4acfadf Mon Sep 17 00:00:00 2001 From: John Pollock Date: Wed, 20 May 2026 15:37:09 -0500 Subject: [PATCH 76/90] Add UPSCALE_MODEL lane to MultiGPU CFG Split Introduce tiled_scale_multidim_multigpu in comfy/utils.py: a tile scheduler that dispatches per-device tile functions through the existing MultiGPUThreadPool and merges per-device CPU output buffers in deterministic key order. The worker only catches BaseException at the thread boundary to funnel errors to the main thread; bare torch.cuda.set_device and torch.cuda.synchronize calls inside the worker fail loud if the device is not CUDA, which is part of the primitive's contract. Add UPSCALE_MODEL input on the MultiGPU CFG Split node and an upscale-model descriptor deepclone helper in comfy/multigpu.py. Clones stay CPU-resident until execute time and are returned to CPU afterward. ImageUpscaleWithModel dispatches through tiled_scale_multidim_multigpu when a multigpu descriptor is attached; the single-device path runs unchanged when no clones are present. --- comfy/multigpu.py | 30 ++++++ comfy/utils.py | 151 +++++++++++++++++++++++++++- comfy_extras/nodes_multigpu.py | 23 +++-- comfy_extras/nodes_upscale_model.py | 25 ++++- 4 files changed, 218 insertions(+), 11 deletions(-) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index eff7d06499a9..7f90b7db7484 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -1,4 +1,5 @@ from __future__ import annotations +import copy import queue import threading import torch @@ -175,6 +176,35 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: return model +def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int): + """Return a shallow copy of ``upscale_model`` with a ``multigpu_clones`` dict of CPU-resident + descriptor deepclones, one per extra CUDA device up to ``max_gpus``. + """ + full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + limit_extra_devices = full_extra_devices[:max_gpus - 1] + if len(limit_extra_devices) == 0: + logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.") + return upscale_model + + cloned = copy.copy(upscale_model) + existing = getattr(upscale_model, 'multigpu_clones', None) + clones: dict[torch.device, object] = dict(existing) if existing else {} + + for device in limit_extra_devices: + if device in clones: + continue + clone_desc = copy.deepcopy(upscale_model) + clone_desc.model.eval() + for p in clone_desc.model.parameters(): + p.requires_grad_(False) + clone_desc.to("cpu") + clones[device] = clone_desc + logging.info(f"Created CPU upscale_model descriptor deepclone for {device}") + + cloned.multigpu_clones = clones + return cloned + + LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' diff --git a/comfy/utils.py b/comfy/utils.py index 31052714a585..c53e0cb914db 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,13 +28,13 @@ from PIL import Image import logging import itertools +import threading from torch.nn.functional import interpolate from tqdm.auto import trange from einops import rearrange from comfy.cli_args import args import json import time -import threading import warnings MMAP_TORCH_FILES = args.mmap_torch_files @@ -1186,6 +1186,155 @@ def mult_list_upscale(a): def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) + +def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None): + """Multigpu variant of tiled_scale_multidim. ``functions`` is a dict[torch.device, callable]. + + Round-robin dispatches tile positions across devices via threading. Each thread maintains + its own per-device CPU output and divisor buffer, applying the same feathered overlap mask + formula as the single-device path. Buffers are summed at the end, producing output that is + bit-equivalent to ``tiled_scale_multidim`` within fp32 add-order noise. + + Falls back to ``tiled_scale_multidim`` with the only function when ``len(functions) < 2``. + Falls back to single-device on the "whole input fits in one tile" branch (no parallelism + available at that granularity). + """ + devices = list(functions.keys()) + if len(devices) < 2: + only_fn = next(iter(functions.values())) if functions else None + return tiled_scale_multidim(samples, only_fn, tile=tile, overlap=overlap, + upscale_amount=upscale_amount, out_channels=out_channels, + output_device=output_device, downscale=downscale, + index_formulas=index_formulas, pbar=pbar) + + dims = len(tile) + + if not (isinstance(upscale_amount, (tuple, list))): + upscale_amount = [upscale_amount] * dims + if not (isinstance(overlap, (tuple, list))): + overlap = [overlap] * dims + if index_formulas is None: + index_formulas = upscale_amount + if not (isinstance(index_formulas, (tuple, list))): + index_formulas = [index_formulas] * dims + + def get_upscale(dim, val): + up = upscale_amount[dim] + return up(val) if callable(up) else up * val + + def get_downscale(dim, val): + up = upscale_amount[dim] + return up(val) if callable(up) else val / up + + def get_upscale_pos(dim, val): + up = index_formulas[dim] + return up(val) if callable(up) else up * val + + def get_downscale_pos(dim, val): + up = index_formulas[dim] + return up(val) if callable(up) else val / up + + if downscale: + get_scale = get_downscale + get_pos = get_downscale_pos + else: + get_scale = get_upscale + get_pos = get_upscale_pos + + def mult_list_upscale(a): + return [round(get_scale(i, a[i])) for i in range(len(a))] + + output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device) + merge_device = torch.device("cpu") + + pbar_lock = threading.Lock() if pbar is not None else None + primary_device = devices[0] + + samples_staged = samples if samples.device.type == "cpu" else samples.to("cpu", non_blocking=False) + + for b in range(samples_staged.shape[0]): + s = samples_staged[b:b+1] + + if all(s.shape[d+2] <= tile[d] for d in range(dims)): + with torch.inference_mode(): + output[b:b+1] = functions[primary_device](s.to(primary_device, non_blocking=True)).to(output_device) + if pbar is not None: + pbar.update(1) + continue + + positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] + all_positions = list(itertools.product(*positions)) + + split = {devices[i]: all_positions[i::len(devices)] for i in range(len(devices))} + + out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]) + div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:]) + bufs = {d: torch.zeros(out_shape, device=merge_device) for d in devices} + divs = {d: torch.zeros(div_shape, device=merge_device) for d in devices} + + worker_errors: list[BaseException] = [] + worker_lock = threading.Lock() + + def worker(device, my_positions): + try: + torch.cuda.set_device(device) + fn = functions[device] + local_buf = bufs[device] + local_div = divs[device] + with torch.inference_mode(): + for it in my_positions: + s_in = s + upscaled = [] + for d in range(dims): + pos = max(0, min(s.shape[d + 2] - overlap[d], it[d])) + l = min(tile[d], s.shape[d + 2] - pos) + s_in = s_in.narrow(d + 2, pos, l) + upscaled.append(round(get_pos(d, pos))) + + s_in_dev = s_in.to(device, non_blocking=True) + ps = fn(s_in_dev).to(merge_device) + mask = torch.ones([1, 1] + list(ps.shape[2:]), device=merge_device) + + for d in range(2, dims + 2): + feather = round(get_scale(d - 2, overlap[d - 2])) + if feather >= mask.shape[d]: + continue + for t in range(feather): + a = (t + 1) / feather + mask.narrow(d, t, 1).mul_(a) + mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a) + + o = local_buf + o_d = local_div + for d in range(dims): + o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2]) + o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2]) + + o.add_(ps * mask) + o_d.add_(mask) + + if pbar is not None: + with pbar_lock: + pbar.update(1) + torch.cuda.synchronize(device) + except BaseException as e: + with worker_lock: + worker_errors.append(e) + + threads = [threading.Thread(target=worker, args=(d, split[d])) for d in devices] + for t in threads: + t.start() + for t in threads: + t.join() + if worker_errors: + raise worker_errors[0] + + combined_buf = sum(bufs.values()) + combined_div = sum(divs.values()).clamp_(min=1e-12) + output[b:b+1] = combined_buf / combined_div + + return output + def model_trange(*args, **kwargs): if not comfy.memory_management.aimdo_enabled: return trange(*args, **kwargs) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index fedafef7114e..021dfca3f591 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -13,33 +13,38 @@ class MultiGPUCFGSplitNode(io.ComfyNode): """ - Prepares model to have sampling accelerated via splitting work units. + Attaches per-device deepclones to any connected MODEL and/or UPSCALE_MODEL so downstream + nodes that recognize the attached state dispatch their work across multiple GPUs. - Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes. - - Other than those exceptions, this node can be placed in any order. + Place after nodes that modify the model object itself (compile, attention-switch, etc.). + Otherwise position is not order-sensitive. """ @classmethod def define_schema(cls): return io.Schema( node_id="MultiGPU_WorkUnits", - display_name="MultiGPU CFG Split", + display_name="MultiGPU Work Units", category="advanced/multigpu", description=cleandoc(cls.__doc__), inputs=[ - io.Model.Input("model"), + io.Model.Input("model", optional=True), + io.UpscaleModel.Input("upscale_model", optional=True), io.Int.Input("max_gpus", default=2, min=1, step=1), ], outputs=[ io.Model.Output(), + io.UpscaleModel.Output(), ], ) @classmethod - def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: - model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) - return io.NodeOutput(model) + def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None) -> io.NodeOutput: + if model is not None: + model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) + if upscale_model is not None: + upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus) + return io.NodeOutput(model, upscale_model) class MultiGPUOptionsNode(io.ComfyNode): diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index d3ee3f1c1964..3a4e3926cab3 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -81,13 +81,33 @@ def execute(cls, upscale_model, image) -> io.NodeOutput: output_device = comfy.model_management.intermediate_device() + multigpu_clones = getattr(upscale_model, 'multigpu_clones', None) + if multigpu_clones: + for dev, desc in multigpu_clones.items(): + model_management.free_memory(memory_required, dev) + desc.to(dev) + oom = True try: while oom: try: steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) pbar = comfy.utils.ProgressBar(steps) - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device) + if multigpu_clones: + functions = {device: lambda a: upscale_model(a.float())} + for dev, desc in multigpu_clones.items(): + functions[dev] = lambda a, d=desc: d(a.float()) + s = comfy.utils.tiled_scale_multidim_multigpu( + in_img, + functions, + tile=(tile, tile), + overlap=overlap, + upscale_amount=upscale_model.scale, + pbar=pbar, + output_device=output_device, + ) + else: + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device) oom = False except Exception as e: model_management.raise_non_oom(e) @@ -96,6 +116,9 @@ def execute(cls, upscale_model, image) -> io.NodeOutput: raise e finally: upscale_model.to("cpu") + if multigpu_clones: + for desc in multigpu_clones.values(): + desc.to("cpu") s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype()) return io.NodeOutput(s) From 4d3d68e4731cf366289f9f4ca11242f4a78956df Mon Sep 17 00:00:00 2001 From: John Pollock Date: Fri, 22 May 2026 12:32:30 -0500 Subject: [PATCH 77/90] Add tiled VAE lane to MultiGPU Work Units --- comfy/multigpu.py | 64 ++++++++++- comfy/sd.py | 132 +++++++++++++++++++++- comfy/utils.py | 26 +++-- comfy_extras/nodes_multigpu.py | 12 +- nodes.py | 6 + tests-unit/comfy_test/multigpu_test.py | 147 +++++++++++++++++++++++++ 6 files changed, 366 insertions(+), 21 deletions(-) create mode 100644 tests-unit/comfy_test/multigpu_test.py diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 7f90b7db7484..2573185de92f 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -182,18 +182,23 @@ def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int): """ full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) limit_extra_devices = full_extra_devices[:max_gpus - 1] - if len(limit_extra_devices) == 0: - logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.") - return upscale_model - cloned = copy.copy(upscale_model) existing = getattr(upscale_model, 'multigpu_clones', None) - clones: dict[torch.device, object] = dict(existing) if existing else {} + limit_extra_device_set = set(limit_extra_devices) + clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {} + if len(limit_extra_devices) == 0: + logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.") + if hasattr(cloned, 'multigpu_clones'): + del cloned.multigpu_clones + return cloned for device in limit_extra_devices: if device in clones: continue - clone_desc = copy.deepcopy(upscale_model) + clone_source = copy.copy(upscale_model) + if hasattr(clone_source, 'multigpu_clones'): + del clone_source.multigpu_clones + clone_desc = copy.deepcopy(clone_source) clone_desc.model.eval() for p in clone_desc.model.parameters(): p.requires_grad_(False) @@ -205,6 +210,53 @@ def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int): return cloned +def create_vae_multigpu_deepclones(vae, max_gpus: int): + """Return a shallow copy of ``vae`` with a ``multigpu_clones`` dict of CPU-resident VAE + deepclones, one per extra CUDA device up to ``max_gpus``. + """ + vae.throw_exception_if_invalid() + vae_device = torch.device(vae.device) + cloned = copy.copy(vae) + if hasattr(cloned, 'multigpu_clones'): + del cloned.multigpu_clones + if vae_device.type == "cpu": + logging.info("CPU VAE selected, skipping initializing MultiGPU VAE clones.") + return cloned + + full_extra_devices = comfy.model_management.get_all_torch_devices() + + def is_vae_device(device): + return device.type == vae_device.type and device.index == vae_device.index + + limit_extra_devices = [d for d in full_extra_devices if not is_vae_device(d)][:max_gpus - 1] + if len(limit_extra_devices) == 0: + logging.info("No extra torch devices need initialization, skipping initializing MultiGPU VAE clones.") + return cloned + + existing = getattr(vae, 'multigpu_clones', None) + limit_extra_device_set = set(limit_extra_devices) + clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {} + + for device in limit_extra_devices: + if device in clones: + continue + cloned_patcher = vae.patcher.deepclone_multigpu(new_load_device=device) + clone_vae = copy.copy(vae) + if hasattr(clone_vae, 'multigpu_clones'): + del clone_vae.multigpu_clones + clone_vae.first_stage_model = cloned_patcher.model + clone_vae.patcher = cloned_patcher + clone_vae.first_stage_model.eval() + for p in clone_vae.first_stage_model.parameters(): + p.requires_grad_(False) + clone_vae.first_stage_model.to("cpu") + clones[device] = clone_vae + logging.info(f"Created CPU VAE deepclone for {device}") + + cloned.multigpu_clones = clones + return cloned + + LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' diff --git a/comfy/sd.py b/comfy/sd.py index 1670a0486570..6401fdb144da 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -972,6 +972,26 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): pbar = comfy.utils.ProgressBar(steps) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + + multigpu_clones = getattr(self, 'multigpu_clones', None) + if multigpu_clones: + functions = {self.device: decode_fn} + try: + for dev, c in multigpu_clones.items(): + model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev) + c.first_stage_model.to(dev) + for dev, c in multigpu_clones.items(): + functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) + output = self.process_output( + (comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) + + comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) + + comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar)) + / 3.0) + return output + finally: + for c in multigpu_clones.values(): + c.first_stage_model.to("cpu") + output = self.process_output( (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + @@ -981,16 +1001,49 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled_1d(self, samples, tile_x=256, overlap=32): if samples.ndim == 3: + memory_shape = samples.shape decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype())) else: og_shape = samples.shape + memory_shape = og_shape samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1)) decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype())) + + multigpu_clones = getattr(self, 'multigpu_clones', None) + if multigpu_clones: + functions = {self.device: decode_fn} + try: + for dev, c in multigpu_clones.items(): + model_management.free_memory(c.model_size() + c.memory_used_decode(memory_shape, c.vae_dtype), dev) + c.first_stage_model.to(dev) + for dev, c in multigpu_clones.items(): + functions[dev] = clone_decode_fn_factory(c, dev) + return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) + finally: + for c in multigpu_clones.values(): + c.first_stage_model.to("cpu") return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + + multigpu_clones = getattr(self, 'multigpu_clones', None) + if multigpu_clones: + functions = {self.device: decode_fn} + try: + for dev, c in multigpu_clones.items(): + model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev) + c.first_stage_model.to(dev) + for dev, c in multigpu_clones.items(): + functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) + return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) + finally: + for c in multigpu_clones.values(): + c.first_stage_model.to("cpu") + return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -1000,6 +1053,25 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): pbar = comfy.utils.ProgressBar(steps) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + + multigpu_clones = getattr(self, 'multigpu_clones', None) + if multigpu_clones: + functions = {self.device: encode_fn} + try: + for dev, c in multigpu_clones.items(): + model_management.free_memory(c.model_size() + c.memory_used_encode(pixel_samples.shape, c.vae_dtype), dev) + c.first_stage_model.to(dev) + for dev, c in multigpu_clones.items(): + functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) + samples = comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples /= 3.0 + return samples + finally: + for c in multigpu_clones.values(): + c.first_stage_model.to("cpu") + samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) @@ -1009,6 +1081,7 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): if self.latent_dim == 1: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype())) out_channels = self.latent_channels upscale_amount = 1 / self.downscale_ratio else: @@ -1018,8 +1091,24 @@ def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): overlap = overlap // extra_channel_size upscale_amount = 1 / self.downscale_ratio encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype()) + clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).reshape(1, out_channels, -1).to(dtype=c.vae_output_dtype())) + + multigpu_clones = getattr(self, 'multigpu_clones', None) + if multigpu_clones: + functions = {self.device: encode_fn} + try: + for dev, c in multigpu_clones.items(): + model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev) + c.first_stage_model.to(dev) + for dev, c in multigpu_clones.items(): + functions[dev] = clone_encode_fn_factory(c, dev) + out = comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) + finally: + for c in multigpu_clones.values(): + c.first_stage_model.to("cpu") + else: + out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) - out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) if self.latent_dim == 1: return out else: @@ -1027,6 +1116,21 @@ def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) + + multigpu_clones = getattr(self, 'multigpu_clones', None) + if multigpu_clones: + functions = {self.device: encode_fn} + try: + for dev, c in multigpu_clones.items(): + model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev) + c.first_stage_model.to(dev) + for dev, c in multigpu_clones.items(): + functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) + return comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) + finally: + for c in multigpu_clones.values(): + c.first_stage_model.to("cpu") + return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) def decode(self, samples_in, vae_options={}): @@ -1727,8 +1831,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) if out[0] is not None: out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) + if output_vae and out[2] is not None and hasattr(out[2], "patcher"): + out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options, disable_dynamic)) return out +def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + _, _, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=False, embedding_directory=embedding_directory, output_model=False, model_options=model_options, te_model_options=te_model_options, disable_dynamic=disable_dynamic) + return vae.patcher + def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, embedding_directory=embedding_directory, @@ -1954,6 +2064,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False): model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model +def load_vae_patcher(vae_path, metadata=None, device=None): + """Reload a VAE from disk and return its patcher. + + Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so that + :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a + fresh VAE patcher with no inherited source-device storage tracking. The + optional device matches the source loader's VAE initialization path; the + cloned patcher's load_device still controls the device targeted by the + multigpu clone. Without this, bare ``copy.deepcopy`` of the VAE wrapper + carries dynamic-VRAM allocator state forward to the clone, which causes + per-device worker threads in tiled encode/decode dispatch to access weights + through the source-device buffer.""" + if metadata is None: + sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) + else: + sd = comfy.utils.load_torch_file(vae_path) + vae = VAE(sd=sd, metadata=metadata, device=device) + vae.throw_exception_if_invalid() + return vae.patcher + def load_unet(unet_path, dtype=None): logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") return load_diffusion_model(unet_path, model_options={"dtype": dtype}) diff --git a/comfy/utils.py b/comfy/utils.py index c53e0cb914db..6b12676d2dfe 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1263,9 +1263,7 @@ def mult_list_upscale(a): continue positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] - all_positions = list(itertools.product(*positions)) - - split = {devices[i]: all_positions[i::len(devices)] for i in range(len(devices))} + split = {devices[i]: itertools.islice(itertools.product(*positions), i, None, len(devices)) for i in range(len(devices))} out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]) div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:]) @@ -1277,7 +1275,8 @@ def mult_list_upscale(a): def worker(device, my_positions): try: - torch.cuda.set_device(device) + if device.type == "cuda": + torch.cuda.set_device(device) fn = functions[device] local_buf = bufs[device] local_div = divs[device] @@ -1306,17 +1305,24 @@ def worker(device, my_positions): o = local_buf o_d = local_div + ps_view = ps + mask_view = mask for d in range(dims): - o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2]) - o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2]) + l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d]) + o = o.narrow(d + 2, upscaled[d], l) + o_d = o_d.narrow(d + 2, upscaled[d], l) + if l < ps_view.shape[d + 2]: + ps_view = ps_view.narrow(d + 2, 0, l) + mask_view = mask_view.narrow(d + 2, 0, l) - o.add_(ps * mask) - o_d.add_(mask) + o.add_(ps_view * mask_view) + o_d.add_(mask_view) if pbar is not None: with pbar_lock: pbar.update(1) - torch.cuda.synchronize(device) + if device.type == "cuda": + torch.cuda.synchronize(device) except BaseException as e: with worker_lock: worker_errors.append(e) @@ -1330,7 +1336,7 @@ def worker(device, my_positions): raise worker_errors[0] combined_buf = sum(bufs.values()) - combined_div = sum(divs.values()).clamp_(min=1e-12) + combined_div = sum(divs.values()) output[b:b+1] = combined_buf / combined_div return output diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 021dfca3f591..dd0f7679869e 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -13,8 +13,8 @@ class MultiGPUCFGSplitNode(io.ComfyNode): """ - Attaches per-device deepclones to any connected MODEL and/or UPSCALE_MODEL so downstream - nodes that recognize the attached state dispatch their work across multiple GPUs. + Attaches per-device deepclones to any connected MODEL, UPSCALE_MODEL, and/or VAE so + downstream nodes that recognize the attached state dispatch their work across multiple GPUs. Place after nodes that modify the model object itself (compile, attention-switch, etc.). Otherwise position is not order-sensitive. @@ -30,21 +30,25 @@ def define_schema(cls): inputs=[ io.Model.Input("model", optional=True), io.UpscaleModel.Input("upscale_model", optional=True), + io.Vae.Input("vae", optional=True), io.Int.Input("max_gpus", default=2, min=1, step=1), ], outputs=[ io.Model.Output(), io.UpscaleModel.Output(), + io.Vae.Output(), ], ) @classmethod - def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None) -> io.NodeOutput: + def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None, vae=None) -> io.NodeOutput: if model is not None: model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) if upscale_model is not None: upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus) - return io.NodeOutput(model, upscale_model) + if vae is not None: + vae = comfy.multigpu.create_vae_multigpu_deepclones(vae, max_gpus) + return io.NodeOutput(model, upscale_model, vae) class MultiGPUOptionsNode(io.ComfyNode): diff --git a/nodes.py b/nodes.py index 2f3856330bde..9193e9ddb21a 100644 --- a/nodes.py +++ b/nodes.py @@ -869,6 +869,7 @@ def VALIDATE_INPUTS(cls, device="default"): #TODO: scale factor? def load_vae(self, vae_name, device="default"): metadata = None + vae_path = None if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) @@ -888,6 +889,11 @@ def load_vae(self, vae_name, device="default"): resolved = comfy.model_management.resolve_gpu_device_option(device) vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved) vae.throw_exception_if_invalid() + # Register a reload factory on the patcher so MultiGPU work-units can use + # ModelPatcher.deepclone_multigpu to produce per-device clones from the + # same loader context (mirrors UNETLoader / CLIPLoader / checkpoint loader). + if vae_path is not None: + vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, resolved)) return (vae,) class ControlNetLoader: diff --git a/tests-unit/comfy_test/multigpu_test.py b/tests-unit/comfy_test/multigpu_test.py new file mode 100644 index 000000000000..e7ba15df7884 --- /dev/null +++ b/tests-unit/comfy_test/multigpu_test.py @@ -0,0 +1,147 @@ +import importlib +import sys +import types + +import torch + +import comfy.utils + + +def install_fake_comfy_aimdo(monkeypatch): + package = types.ModuleType("comfy_aimdo") + package.__path__ = [] + monkeypatch.setitem(sys.modules, "comfy_aimdo", package) + for name in ("vram_buffer", "host_buffer", "torch", "model_vbar", "model_mmap", "control"): + module = types.ModuleType(f"comfy_aimdo.{name}") + monkeypatch.setitem(sys.modules, f"comfy_aimdo.{name}", module) + setattr(package, name, module) + + +def test_tiled_scale_multidim_multigpu_clips_edge_tiles(monkeypatch): + monkeypatch.setattr(torch.cuda, "set_device", lambda device: None) + monkeypatch.setattr(torch.cuda, "synchronize", lambda device: None) + + scale = 1.1 + + def upscale(a): + return torch.ones((a.shape[0], 1, round(a.shape[-1] * scale)), dtype=a.dtype, device=a.device) + + samples = torch.ones((1, 1, 11)) + devices = [torch.device("cpu:0"), torch.device("cpu:1")] + + actual = comfy.utils.tiled_scale_multidim_multigpu( + samples, + {device: upscale for device in devices}, + tile=(5,), + overlap=2, + upscale_amount=scale, + out_channels=1, + output_device="cpu", + ) + expected = comfy.utils.tiled_scale_multidim( + samples, + upscale, + tile=(5,), + overlap=2, + upscale_amount=scale, + out_channels=1, + output_device="cpu", + ) + + assert actual.shape == expected.shape == (1, 1, 12) + torch.testing.assert_close(actual, expected) + + +def test_upscale_model_deepclone_does_not_copy_existing_clone_graph(monkeypatch): + class FakeModel: + def __init__(self): + self.param = torch.nn.Parameter(torch.ones(1)) + + def eval(self): + return self + + def parameters(self): + return [self.param] + + class FakeDescriptor: + def __init__(self): + self.model = FakeModel() + self.device = None + + def to(self, device): + self.device = device + return self + + first_device = torch.device("cpu:0") + second_device = torch.device("cpu:1") + stale_device = torch.device("cpu:2") + existing_clone = FakeDescriptor() + stale_clone = FakeDescriptor() + source = FakeDescriptor() + source.multigpu_clones = {first_device: existing_clone, stale_device: stale_clone} + fake_model_management = types.ModuleType("comfy.model_management") + fake_model_management.get_all_torch_devices = lambda exclude_current=True: [first_device, second_device] + monkeypatch.setitem(sys.modules, "comfy.model_management", fake_model_management) + import comfy + monkeypatch.setattr(comfy, "model_management", fake_model_management, raising=False) + import comfy.multigpu + importlib.reload(comfy.multigpu) + + cloned = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=3) + + assert cloned is not source + assert cloned.multigpu_clones[first_device] is existing_clone + assert stale_device not in cloned.multigpu_clones + assert second_device in cloned.multigpu_clones + assert not hasattr(cloned.multigpu_clones[second_device], "multigpu_clones") + assert cloned.multigpu_clones[second_device].device == "cpu" + assert not cloned.multigpu_clones[second_device].model.param.requires_grad + + single_gpu_clone = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=1) + assert single_gpu_clone is not source + assert not hasattr(single_gpu_clone, "multigpu_clones") + + +def test_checkpoint_loader_registers_vae_cached_patcher(monkeypatch): + install_fake_comfy_aimdo(monkeypatch) + import comfy.sd + importlib.reload(comfy.sd) + + class FakeVAE: + def __init__(self): + self.patcher = types.SimpleNamespace(cached_patcher_init=None) + + model_patcher = types.SimpleNamespace(cached_patcher_init=None) + vae = FakeVAE() + metadata = {"format": "checkpoint"} + monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata)) + monkeypatch.setattr( + comfy.sd, + "load_state_dict_guess_config", + lambda *args, **kwargs: (model_patcher, None, vae, None), + ) + + comfy.sd.load_checkpoint_guess_config("checkpoint.safetensors", output_vae=True) + + assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config + assert vae.patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_vae_patcher + assert vae.patcher.cached_patcher_init[1][0] == "checkpoint.safetensors" + + +def test_checkpoint_loader_skips_cached_patcher_for_placeholder_vae(monkeypatch): + install_fake_comfy_aimdo(monkeypatch) + import comfy.sd + importlib.reload(comfy.sd) + + model_patcher = types.SimpleNamespace(cached_patcher_init=None) + placeholder_vae = types.SimpleNamespace() + metadata = {"format": "checkpoint"} + monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata)) + monkeypatch.setattr( + comfy.sd, + "load_state_dict_guess_config", + lambda *args, **kwargs: (model_patcher, None, placeholder_vae, None), + ) + + assert comfy.sd.load_checkpoint_guess_config("diffusion_only.safetensors", output_vae=True)[2] is placeholder_vae + assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config From 5dc4e38b89503ba77d58ae450d3f3fff30f57fa8 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 22 May 2026 16:44:29 -0700 Subject: [PATCH 78/90] Defer @pollockjj's tiled-VAE and UPSCALE_MODEL MultiGPU lanes (#14066) * Revert "Add tiled VAE lane to MultiGPU Work Units" This reverts commit 4d3d68e4731cf366289f9f4ca11242f4a78956df. The tiled VAE lane will land as part of a follow-up PR alongside the UPSCALE_MODEL lane, separated from the threaded-loader fix PR (#14052) to keep the upstream merge focused. * Revert "Add UPSCALE_MODEL lane to MultiGPU CFG Split" This reverts commit 74b0a826eaa7962e5093d83a27e13c20d4acfadf. The UPSCALE_MODEL lane will land as part of a follow-up PR alongside the tiled VAE lane, separated from the threaded-loader fix PR (#14052) to keep the upstream merge focused. --------- Co-authored-by: John Pollock --- comfy/multigpu.py | 82 ------------- comfy/sd.py | 132 +-------------------- comfy/utils.py | 157 +------------------------ comfy_extras/nodes_multigpu.py | 27 ++--- comfy_extras/nodes_upscale_model.py | 25 +--- nodes.py | 6 - tests-unit/comfy_test/multigpu_test.py | 147 ----------------------- 7 files changed, 12 insertions(+), 564 deletions(-) delete mode 100644 tests-unit/comfy_test/multigpu_test.py diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 2573185de92f..eff7d06499a9 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -1,5 +1,4 @@ from __future__ import annotations -import copy import queue import threading import torch @@ -176,87 +175,6 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: return model -def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int): - """Return a shallow copy of ``upscale_model`` with a ``multigpu_clones`` dict of CPU-resident - descriptor deepclones, one per extra CUDA device up to ``max_gpus``. - """ - full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) - limit_extra_devices = full_extra_devices[:max_gpus - 1] - cloned = copy.copy(upscale_model) - existing = getattr(upscale_model, 'multigpu_clones', None) - limit_extra_device_set = set(limit_extra_devices) - clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {} - if len(limit_extra_devices) == 0: - logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.") - if hasattr(cloned, 'multigpu_clones'): - del cloned.multigpu_clones - return cloned - - for device in limit_extra_devices: - if device in clones: - continue - clone_source = copy.copy(upscale_model) - if hasattr(clone_source, 'multigpu_clones'): - del clone_source.multigpu_clones - clone_desc = copy.deepcopy(clone_source) - clone_desc.model.eval() - for p in clone_desc.model.parameters(): - p.requires_grad_(False) - clone_desc.to("cpu") - clones[device] = clone_desc - logging.info(f"Created CPU upscale_model descriptor deepclone for {device}") - - cloned.multigpu_clones = clones - return cloned - - -def create_vae_multigpu_deepclones(vae, max_gpus: int): - """Return a shallow copy of ``vae`` with a ``multigpu_clones`` dict of CPU-resident VAE - deepclones, one per extra CUDA device up to ``max_gpus``. - """ - vae.throw_exception_if_invalid() - vae_device = torch.device(vae.device) - cloned = copy.copy(vae) - if hasattr(cloned, 'multigpu_clones'): - del cloned.multigpu_clones - if vae_device.type == "cpu": - logging.info("CPU VAE selected, skipping initializing MultiGPU VAE clones.") - return cloned - - full_extra_devices = comfy.model_management.get_all_torch_devices() - - def is_vae_device(device): - return device.type == vae_device.type and device.index == vae_device.index - - limit_extra_devices = [d for d in full_extra_devices if not is_vae_device(d)][:max_gpus - 1] - if len(limit_extra_devices) == 0: - logging.info("No extra torch devices need initialization, skipping initializing MultiGPU VAE clones.") - return cloned - - existing = getattr(vae, 'multigpu_clones', None) - limit_extra_device_set = set(limit_extra_devices) - clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {} - - for device in limit_extra_devices: - if device in clones: - continue - cloned_patcher = vae.patcher.deepclone_multigpu(new_load_device=device) - clone_vae = copy.copy(vae) - if hasattr(clone_vae, 'multigpu_clones'): - del clone_vae.multigpu_clones - clone_vae.first_stage_model = cloned_patcher.model - clone_vae.patcher = cloned_patcher - clone_vae.first_stage_model.eval() - for p in clone_vae.first_stage_model.parameters(): - p.requires_grad_(False) - clone_vae.first_stage_model.to("cpu") - clones[device] = clone_vae - logging.info(f"Created CPU VAE deepclone for {device}") - - cloned.multigpu_clones = clones - return cloned - - LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' diff --git a/comfy/sd.py b/comfy/sd.py index 6401fdb144da..1670a0486570 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -972,26 +972,6 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): pbar = comfy.utils.ProgressBar(steps) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - - multigpu_clones = getattr(self, 'multigpu_clones', None) - if multigpu_clones: - functions = {self.device: decode_fn} - try: - for dev, c in multigpu_clones.items(): - model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev) - c.first_stage_model.to(dev) - for dev, c in multigpu_clones.items(): - functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) - output = self.process_output( - (comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) + - comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) + - comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar)) - / 3.0) - return output - finally: - for c in multigpu_clones.values(): - c.first_stage_model.to("cpu") - output = self.process_output( (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + @@ -1001,49 +981,16 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled_1d(self, samples, tile_x=256, overlap=32): if samples.ndim == 3: - memory_shape = samples.shape decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype())) else: og_shape = samples.shape - memory_shape = og_shape samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1)) decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype())) - - multigpu_clones = getattr(self, 'multigpu_clones', None) - if multigpu_clones: - functions = {self.device: decode_fn} - try: - for dev, c in multigpu_clones.items(): - model_management.free_memory(c.model_size() + c.memory_used_decode(memory_shape, c.vae_dtype), dev) - c.first_stage_model.to(dev) - for dev, c in multigpu_clones.items(): - functions[dev] = clone_decode_fn_factory(c, dev) - return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) - finally: - for c in multigpu_clones.values(): - c.first_stage_model.to("cpu") return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - - multigpu_clones = getattr(self, 'multigpu_clones', None) - if multigpu_clones: - functions = {self.device: decode_fn} - try: - for dev, c in multigpu_clones.items(): - model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev) - c.first_stage_model.to(dev) - for dev, c in multigpu_clones.items(): - functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) - return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) - finally: - for c in multigpu_clones.values(): - c.first_stage_model.to("cpu") - return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -1053,25 +1000,6 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): pbar = comfy.utils.ProgressBar(steps) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - - multigpu_clones = getattr(self, 'multigpu_clones', None) - if multigpu_clones: - functions = {self.device: encode_fn} - try: - for dev, c in multigpu_clones.items(): - model_management.free_memory(c.model_size() + c.memory_used_encode(pixel_samples.shape, c.vae_dtype), dev) - c.first_stage_model.to(dev) - for dev, c in multigpu_clones.items(): - functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) - samples = comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) - samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) - samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) - samples /= 3.0 - return samples - finally: - for c in multigpu_clones.values(): - c.first_stage_model.to("cpu") - samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) @@ -1081,7 +1009,6 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): if self.latent_dim == 1: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype())) out_channels = self.latent_channels upscale_amount = 1 / self.downscale_ratio else: @@ -1091,24 +1018,8 @@ def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): overlap = overlap // extra_channel_size upscale_amount = 1 / self.downscale_ratio encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype()) - clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).reshape(1, out_channels, -1).to(dtype=c.vae_output_dtype())) - - multigpu_clones = getattr(self, 'multigpu_clones', None) - if multigpu_clones: - functions = {self.device: encode_fn} - try: - for dev, c in multigpu_clones.items(): - model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev) - c.first_stage_model.to(dev) - for dev, c in multigpu_clones.items(): - functions[dev] = clone_encode_fn_factory(c, dev) - out = comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) - finally: - for c in multigpu_clones.values(): - c.first_stage_model.to("cpu") - else: - out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) + out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) if self.latent_dim == 1: return out else: @@ -1116,21 +1027,6 @@ def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - - multigpu_clones = getattr(self, 'multigpu_clones', None) - if multigpu_clones: - functions = {self.device: encode_fn} - try: - for dev, c in multigpu_clones.items(): - model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev) - c.first_stage_model.to(dev) - for dev, c in multigpu_clones.items(): - functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) - return comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) - finally: - for c in multigpu_clones.values(): - c.first_stage_model.to("cpu") - return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) def decode(self, samples_in, vae_options={}): @@ -1831,14 +1727,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) if out[0] is not None: out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) - if output_vae and out[2] is not None and hasattr(out[2], "patcher"): - out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options, disable_dynamic)) return out -def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): - _, _, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=False, embedding_directory=embedding_directory, output_model=False, model_options=model_options, te_model_options=te_model_options, disable_dynamic=disable_dynamic) - return vae.patcher - def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, embedding_directory=embedding_directory, @@ -2064,26 +1954,6 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False): model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model -def load_vae_patcher(vae_path, metadata=None, device=None): - """Reload a VAE from disk and return its patcher. - - Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so that - :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a - fresh VAE patcher with no inherited source-device storage tracking. The - optional device matches the source loader's VAE initialization path; the - cloned patcher's load_device still controls the device targeted by the - multigpu clone. Without this, bare ``copy.deepcopy`` of the VAE wrapper - carries dynamic-VRAM allocator state forward to the clone, which causes - per-device worker threads in tiled encode/decode dispatch to access weights - through the source-device buffer.""" - if metadata is None: - sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) - else: - sd = comfy.utils.load_torch_file(vae_path) - vae = VAE(sd=sd, metadata=metadata, device=device) - vae.throw_exception_if_invalid() - return vae.patcher - def load_unet(unet_path, dtype=None): logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") return load_diffusion_model(unet_path, model_options={"dtype": dtype}) diff --git a/comfy/utils.py b/comfy/utils.py index abfd4079dc1e..49ae12b0660c 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,13 +28,13 @@ from PIL import Image import logging import itertools -import threading from torch.nn.functional import interpolate from tqdm.auto import trange from einops import rearrange from comfy.cli_args import args import json import time +import threading import warnings MMAP_TORCH_FILES = args.mmap_torch_files @@ -1187,161 +1187,6 @@ def mult_list_upscale(a): def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) - -def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None): - """Multigpu variant of tiled_scale_multidim. ``functions`` is a dict[torch.device, callable]. - - Round-robin dispatches tile positions across devices via threading. Each thread maintains - its own per-device CPU output and divisor buffer, applying the same feathered overlap mask - formula as the single-device path. Buffers are summed at the end, producing output that is - bit-equivalent to ``tiled_scale_multidim`` within fp32 add-order noise. - - Falls back to ``tiled_scale_multidim`` with the only function when ``len(functions) < 2``. - Falls back to single-device on the "whole input fits in one tile" branch (no parallelism - available at that granularity). - """ - devices = list(functions.keys()) - if len(devices) < 2: - only_fn = next(iter(functions.values())) if functions else None - return tiled_scale_multidim(samples, only_fn, tile=tile, overlap=overlap, - upscale_amount=upscale_amount, out_channels=out_channels, - output_device=output_device, downscale=downscale, - index_formulas=index_formulas, pbar=pbar) - - dims = len(tile) - - if not (isinstance(upscale_amount, (tuple, list))): - upscale_amount = [upscale_amount] * dims - if not (isinstance(overlap, (tuple, list))): - overlap = [overlap] * dims - if index_formulas is None: - index_formulas = upscale_amount - if not (isinstance(index_formulas, (tuple, list))): - index_formulas = [index_formulas] * dims - - def get_upscale(dim, val): - up = upscale_amount[dim] - return up(val) if callable(up) else up * val - - def get_downscale(dim, val): - up = upscale_amount[dim] - return up(val) if callable(up) else val / up - - def get_upscale_pos(dim, val): - up = index_formulas[dim] - return up(val) if callable(up) else up * val - - def get_downscale_pos(dim, val): - up = index_formulas[dim] - return up(val) if callable(up) else val / up - - if downscale: - get_scale = get_downscale - get_pos = get_downscale_pos - else: - get_scale = get_upscale - get_pos = get_upscale_pos - - def mult_list_upscale(a): - return [round(get_scale(i, a[i])) for i in range(len(a))] - - output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device) - merge_device = torch.device("cpu") - - pbar_lock = threading.Lock() if pbar is not None else None - primary_device = devices[0] - - samples_staged = samples if samples.device.type == "cpu" else samples.to("cpu", non_blocking=False) - - for b in range(samples_staged.shape[0]): - s = samples_staged[b:b+1] - - if all(s.shape[d+2] <= tile[d] for d in range(dims)): - with torch.inference_mode(): - output[b:b+1] = functions[primary_device](s.to(primary_device, non_blocking=True)).to(output_device) - if pbar is not None: - pbar.update(1) - continue - - positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] - split = {devices[i]: itertools.islice(itertools.product(*positions), i, None, len(devices)) for i in range(len(devices))} - - out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]) - div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:]) - bufs = {d: torch.zeros(out_shape, device=merge_device) for d in devices} - divs = {d: torch.zeros(div_shape, device=merge_device) for d in devices} - - worker_errors: list[BaseException] = [] - worker_lock = threading.Lock() - - def worker(device, my_positions): - try: - if device.type == "cuda": - torch.cuda.set_device(device) - fn = functions[device] - local_buf = bufs[device] - local_div = divs[device] - with torch.inference_mode(): - for it in my_positions: - s_in = s - upscaled = [] - for d in range(dims): - pos = max(0, min(s.shape[d + 2] - overlap[d], it[d])) - l = min(tile[d], s.shape[d + 2] - pos) - s_in = s_in.narrow(d + 2, pos, l) - upscaled.append(round(get_pos(d, pos))) - - s_in_dev = s_in.to(device, non_blocking=True) - ps = fn(s_in_dev).to(merge_device) - mask = torch.ones([1, 1] + list(ps.shape[2:]), device=merge_device) - - for d in range(2, dims + 2): - feather = round(get_scale(d - 2, overlap[d - 2])) - if feather >= mask.shape[d]: - continue - for t in range(feather): - a = (t + 1) / feather - mask.narrow(d, t, 1).mul_(a) - mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a) - - o = local_buf - o_d = local_div - ps_view = ps - mask_view = mask - for d in range(dims): - l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d]) - o = o.narrow(d + 2, upscaled[d], l) - o_d = o_d.narrow(d + 2, upscaled[d], l) - if l < ps_view.shape[d + 2]: - ps_view = ps_view.narrow(d + 2, 0, l) - mask_view = mask_view.narrow(d + 2, 0, l) - - o.add_(ps_view * mask_view) - o_d.add_(mask_view) - - if pbar is not None: - with pbar_lock: - pbar.update(1) - if device.type == "cuda": - torch.cuda.synchronize(device) - except BaseException as e: - with worker_lock: - worker_errors.append(e) - - threads = [threading.Thread(target=worker, args=(d, split[d])) for d in devices] - for t in threads: - t.start() - for t in threads: - t.join() - if worker_errors: - raise worker_errors[0] - - combined_buf = sum(bufs.values()) - combined_div = sum(divs.values()) - output[b:b+1] = combined_buf / combined_div - - return output - def model_trange(*args, **kwargs): if not comfy.memory_management.aimdo_enabled: return trange(*args, **kwargs) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index dd0f7679869e..fedafef7114e 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -13,42 +13,33 @@ class MultiGPUCFGSplitNode(io.ComfyNode): """ - Attaches per-device deepclones to any connected MODEL, UPSCALE_MODEL, and/or VAE so - downstream nodes that recognize the attached state dispatch their work across multiple GPUs. + Prepares model to have sampling accelerated via splitting work units. - Place after nodes that modify the model object itself (compile, attention-switch, etc.). - Otherwise position is not order-sensitive. + Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes. + + Other than those exceptions, this node can be placed in any order. """ @classmethod def define_schema(cls): return io.Schema( node_id="MultiGPU_WorkUnits", - display_name="MultiGPU Work Units", + display_name="MultiGPU CFG Split", category="advanced/multigpu", description=cleandoc(cls.__doc__), inputs=[ - io.Model.Input("model", optional=True), - io.UpscaleModel.Input("upscale_model", optional=True), - io.Vae.Input("vae", optional=True), + io.Model.Input("model"), io.Int.Input("max_gpus", default=2, min=1, step=1), ], outputs=[ io.Model.Output(), - io.UpscaleModel.Output(), - io.Vae.Output(), ], ) @classmethod - def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None, vae=None) -> io.NodeOutput: - if model is not None: - model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) - if upscale_model is not None: - upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus) - if vae is not None: - vae = comfy.multigpu.create_vae_multigpu_deepclones(vae, max_gpus) - return io.NodeOutput(model, upscale_model, vae) + def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: + model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) + return io.NodeOutput(model) class MultiGPUOptionsNode(io.ComfyNode): diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 3a4e3926cab3..d3ee3f1c1964 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -81,33 +81,13 @@ def execute(cls, upscale_model, image) -> io.NodeOutput: output_device = comfy.model_management.intermediate_device() - multigpu_clones = getattr(upscale_model, 'multigpu_clones', None) - if multigpu_clones: - for dev, desc in multigpu_clones.items(): - model_management.free_memory(memory_required, dev) - desc.to(dev) - oom = True try: while oom: try: steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) pbar = comfy.utils.ProgressBar(steps) - if multigpu_clones: - functions = {device: lambda a: upscale_model(a.float())} - for dev, desc in multigpu_clones.items(): - functions[dev] = lambda a, d=desc: d(a.float()) - s = comfy.utils.tiled_scale_multidim_multigpu( - in_img, - functions, - tile=(tile, tile), - overlap=overlap, - upscale_amount=upscale_model.scale, - pbar=pbar, - output_device=output_device, - ) - else: - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device) oom = False except Exception as e: model_management.raise_non_oom(e) @@ -116,9 +96,6 @@ def execute(cls, upscale_model, image) -> io.NodeOutput: raise e finally: upscale_model.to("cpu") - if multigpu_clones: - for desc in multigpu_clones.values(): - desc.to("cpu") s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype()) return io.NodeOutput(s) diff --git a/nodes.py b/nodes.py index 9193e9ddb21a..2f3856330bde 100644 --- a/nodes.py +++ b/nodes.py @@ -869,7 +869,6 @@ def VALIDATE_INPUTS(cls, device="default"): #TODO: scale factor? def load_vae(self, vae_name, device="default"): metadata = None - vae_path = None if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) @@ -889,11 +888,6 @@ def load_vae(self, vae_name, device="default"): resolved = comfy.model_management.resolve_gpu_device_option(device) vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved) vae.throw_exception_if_invalid() - # Register a reload factory on the patcher so MultiGPU work-units can use - # ModelPatcher.deepclone_multigpu to produce per-device clones from the - # same loader context (mirrors UNETLoader / CLIPLoader / checkpoint loader). - if vae_path is not None: - vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, resolved)) return (vae,) class ControlNetLoader: diff --git a/tests-unit/comfy_test/multigpu_test.py b/tests-unit/comfy_test/multigpu_test.py deleted file mode 100644 index e7ba15df7884..000000000000 --- a/tests-unit/comfy_test/multigpu_test.py +++ /dev/null @@ -1,147 +0,0 @@ -import importlib -import sys -import types - -import torch - -import comfy.utils - - -def install_fake_comfy_aimdo(monkeypatch): - package = types.ModuleType("comfy_aimdo") - package.__path__ = [] - monkeypatch.setitem(sys.modules, "comfy_aimdo", package) - for name in ("vram_buffer", "host_buffer", "torch", "model_vbar", "model_mmap", "control"): - module = types.ModuleType(f"comfy_aimdo.{name}") - monkeypatch.setitem(sys.modules, f"comfy_aimdo.{name}", module) - setattr(package, name, module) - - -def test_tiled_scale_multidim_multigpu_clips_edge_tiles(monkeypatch): - monkeypatch.setattr(torch.cuda, "set_device", lambda device: None) - monkeypatch.setattr(torch.cuda, "synchronize", lambda device: None) - - scale = 1.1 - - def upscale(a): - return torch.ones((a.shape[0], 1, round(a.shape[-1] * scale)), dtype=a.dtype, device=a.device) - - samples = torch.ones((1, 1, 11)) - devices = [torch.device("cpu:0"), torch.device("cpu:1")] - - actual = comfy.utils.tiled_scale_multidim_multigpu( - samples, - {device: upscale for device in devices}, - tile=(5,), - overlap=2, - upscale_amount=scale, - out_channels=1, - output_device="cpu", - ) - expected = comfy.utils.tiled_scale_multidim( - samples, - upscale, - tile=(5,), - overlap=2, - upscale_amount=scale, - out_channels=1, - output_device="cpu", - ) - - assert actual.shape == expected.shape == (1, 1, 12) - torch.testing.assert_close(actual, expected) - - -def test_upscale_model_deepclone_does_not_copy_existing_clone_graph(monkeypatch): - class FakeModel: - def __init__(self): - self.param = torch.nn.Parameter(torch.ones(1)) - - def eval(self): - return self - - def parameters(self): - return [self.param] - - class FakeDescriptor: - def __init__(self): - self.model = FakeModel() - self.device = None - - def to(self, device): - self.device = device - return self - - first_device = torch.device("cpu:0") - second_device = torch.device("cpu:1") - stale_device = torch.device("cpu:2") - existing_clone = FakeDescriptor() - stale_clone = FakeDescriptor() - source = FakeDescriptor() - source.multigpu_clones = {first_device: existing_clone, stale_device: stale_clone} - fake_model_management = types.ModuleType("comfy.model_management") - fake_model_management.get_all_torch_devices = lambda exclude_current=True: [first_device, second_device] - monkeypatch.setitem(sys.modules, "comfy.model_management", fake_model_management) - import comfy - monkeypatch.setattr(comfy, "model_management", fake_model_management, raising=False) - import comfy.multigpu - importlib.reload(comfy.multigpu) - - cloned = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=3) - - assert cloned is not source - assert cloned.multigpu_clones[first_device] is existing_clone - assert stale_device not in cloned.multigpu_clones - assert second_device in cloned.multigpu_clones - assert not hasattr(cloned.multigpu_clones[second_device], "multigpu_clones") - assert cloned.multigpu_clones[second_device].device == "cpu" - assert not cloned.multigpu_clones[second_device].model.param.requires_grad - - single_gpu_clone = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=1) - assert single_gpu_clone is not source - assert not hasattr(single_gpu_clone, "multigpu_clones") - - -def test_checkpoint_loader_registers_vae_cached_patcher(monkeypatch): - install_fake_comfy_aimdo(monkeypatch) - import comfy.sd - importlib.reload(comfy.sd) - - class FakeVAE: - def __init__(self): - self.patcher = types.SimpleNamespace(cached_patcher_init=None) - - model_patcher = types.SimpleNamespace(cached_patcher_init=None) - vae = FakeVAE() - metadata = {"format": "checkpoint"} - monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata)) - monkeypatch.setattr( - comfy.sd, - "load_state_dict_guess_config", - lambda *args, **kwargs: (model_patcher, None, vae, None), - ) - - comfy.sd.load_checkpoint_guess_config("checkpoint.safetensors", output_vae=True) - - assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config - assert vae.patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_vae_patcher - assert vae.patcher.cached_patcher_init[1][0] == "checkpoint.safetensors" - - -def test_checkpoint_loader_skips_cached_patcher_for_placeholder_vae(monkeypatch): - install_fake_comfy_aimdo(monkeypatch) - import comfy.sd - importlib.reload(comfy.sd) - - model_patcher = types.SimpleNamespace(cached_patcher_init=None) - placeholder_vae = types.SimpleNamespace() - metadata = {"format": "checkpoint"} - monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata)) - monkeypatch.setattr( - comfy.sd, - "load_state_dict_guess_config", - lambda *args, **kwargs: (model_patcher, None, placeholder_vae, None), - ) - - assert comfy.sd.load_checkpoint_guess_config("diffusion_only.safetensors", output_vae=True)[2] is placeholder_vae - assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config From 5ffea26de7a4fe046b0c95dcb85195e56f9677d6 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 22 May 2026 17:13:55 -0700 Subject: [PATCH 79/90] Fix single-GPU non-CUDA regressions on worksplit-multigpu Two fixes for single-GPU users on non-NVIDIA backends; multi-GPU non-CUDA support is intentionally out of scope here (tracked separately). 1. get_all_torch_devices: add AMD/ROCm, MLU, and a generic fallback arm. Previously the function only enumerated NVIDIA, Intel XPU, and Ascend NPU when cpu_state==GPU; on AMD/ROCm (which exposes its GPU through torch.cuda.*) and DirectML it fell through to an empty list. The biggest user-visible regression: unload_all_models() iterates this list, so it became a silent no-op on AMD/ROCm. /free, manager unloads, and shutdown stopped releasing VRAM. - is_amd() now shares the torch.cuda.* arm with is_nvidia(), since ROCm reuses the CUDA API surface. - is_mlu() gets its own arm using torch.mlu.device_count(). - A final fallback appends get_torch_device() for any GPU backend the explicit arms miss (notably DirectML), so callers see at least the current device and unload_all_models works. MPS users are unaffected: cpu_state==MPS already routes to the else branch which appends get_torch_device() returning mps. 2. main.py DynamicVRAM init: guard the comfy_aimdo branch with an explicit is_nvidia() check. The outer condition allows entering the DynamicVRAM init block when the user passes --enable-dynamic-vram explicitly, bypassing the implicit is_nvidia() gate. On non-NVIDIA backends this then runs comfy_aimdo.control.init_devices(range(torch.cuda.device_count())), which is comfy-aimdo-only territory and may crash at startup. Add a leading is_nvidia() check that logs a clean warning and falls back to the legacy ModelPatcher path. --- comfy/model_management.py | 13 ++++++++++++- main.py | 9 ++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 051062a908e2..c146eee119d4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -214,7 +214,10 @@ def get_all_torch_devices(exclude_current=False): global cpu_state devices = [] if cpu_state == CPUState.GPU: - if is_nvidia(): + # NVIDIA + AMD/ROCm both expose their GPUs through torch.cuda.*; + # without the AMD arm, single-GPU ROCm users get an empty list + # which silently turns unload_all_models() into a no-op. + if is_nvidia() or is_amd(): for i in range(torch.cuda.device_count()): devices.append(torch.device("cuda", i)) elif is_intel_xpu(): @@ -223,6 +226,14 @@ def get_all_torch_devices(exclude_current=False): elif is_ascend_npu(): for i in range(torch.npu.device_count()): devices.append(torch.device("npu", i)) + elif is_mlu(): + for i in range(torch.mlu.device_count()): + devices.append(torch.device("mlu", i)) + else: + # Fallback for unhandled GPU backends (e.g. DirectML): at least + # report the current device so callers like unload_all_models() + # do not silently no-op. + devices.append(get_torch_device()) else: devices.append(get_torch_device()) if exclude_current: diff --git a/main.py b/main.py index 9933d11eeded..9b22d1304fdf 100644 --- a/main.py +++ b/main.py @@ -216,7 +216,14 @@ def execute_script(script_path): import comfy.model_patcher if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()): - if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)): + if not comfy.model_management.is_nvidia(): + # The implicit auto-enable path is already gated by is_nvidia(); + # this guard handles users who pass --enable-dynamic-vram explicitly + # on a non-NVIDIA system, where torch.cuda.device_count() below would + # either return 0 (silently disabling) or crash on backends that + # raise without CUDA. Be explicit and disable cleanly. + logging.warning("DynamicVRAM was requested but no NVIDIA GPU was detected. Falling back to legacy ModelPatcher.") + elif (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") elif comfy_aimdo.control.init_devices(range(torch.cuda.device_count())): if args.verbose == 'DEBUG': From 403ff496472e050aaefa2fbac2757ed9ba6ec20e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 22 May 2026 19:44:13 -0700 Subject: [PATCH 80/90] Restore nodes_kling.py removal of max_poll_attempts=280 lost in merge Master commit cf758bd2 (PR #13663, "chore(api-nodes): increase default timeout for partner API node tasks") removed three explicit max_poll_attempts=280 overrides from nodes_kling.py so the new 480 default in util/client.py would take effect. The May 19 merge of master into worksplit-multigpu (ff766e5c) silently discarded those three deletions in the 3-way resolve - nodes_kling.py had no textual conflict but the resolution kept the pre-cf758bd2 lines. The other seven files cf758bd2 touched were merged correctly; this restores nodes_kling.py to match master. Amp-Thread-ID: https://ampcode.com/threads/T-019e52b4-31ee-72cd-996b-64ecd9420e13 Co-authored-by: Amp --- comfy_api_nodes/nodes_kling.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index ef647b20b136..7586f1816214 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -276,7 +276,6 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), response_model=TaskStatusResponse, - max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -3066,7 +3065,6 @@ async def execute( cls, ApiEndpoint(path=poll_path), response_model=TaskStatusResponse, - max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -3192,7 +3190,6 @@ async def execute( cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), response_model=TaskStatusResponse, - max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) From 2369eb00e7dea37654ee30e0ff273a5a573def96 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 22 May 2026 19:53:04 -0700 Subject: [PATCH 81/90] Route aimdo init through get_all_torch_devices() instead of raw torch.cuda The aimdo init call on worksplit-multigpu was using comfy_aimdo.control.init_devices(range(torch.cuda.device_count())) which required adding `import torch` at the top of main.py (violating the "torch should never be imported before this point" expectation) and an inner is_nvidia() guard added in PR #14068 to defend the raw cuda call on non-NVIDIA systems where --enable-dynamic-vram is explicitly passed. Replace the call with comfy_aimdo.control.init_devices( d.index for d in comfy.model_management.get_all_torch_devices() if d.type == "cuda" and d.index is not None ) comfy_aimdo.control.init_devices accepts any iterable of int-coercible device indices and returns False on an empty iterable, so on non-cuda systems the elif naturally falls through to the existing "No working comfy-aimdo install detected" fallback - no extra vendor gate needed. HIP devices appear as type "cuda" in torch, so ROCm setups (which comfy-aimdo supports via aimdo_rocm.so) are handled correctly too. This lets us drop both the `import torch` at the top of main.py and the inner is_nvidia() guard, leaving a single logical-line divergence from master (init_device(single index) -> init_devices(generator of cuda indices)) for multi-GPU aimdo support. Amp-Thread-ID: https://ampcode.com/threads/T-019e52b4-31ee-72cd-996b-64ecd9420e13 Co-authored-by: Amp --- main.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index 9b22d1304fdf..b1a19c7bff2d 100644 --- a/main.py +++ b/main.py @@ -200,7 +200,7 @@ def execute_script(script_path): if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") -import torch + import comfy.utils import execution @@ -216,16 +216,12 @@ def execute_script(script_path): import comfy.model_patcher if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()): - if not comfy.model_management.is_nvidia(): - # The implicit auto-enable path is already gated by is_nvidia(); - # this guard handles users who pass --enable-dynamic-vram explicitly - # on a non-NVIDIA system, where torch.cuda.device_count() below would - # either return 0 (silently disabling) or crash on backends that - # raise without CUDA. Be explicit and disable cleanly. - logging.warning("DynamicVRAM was requested but no NVIDIA GPU was detected. Falling back to legacy ModelPatcher.") - elif (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)): + if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") - elif comfy_aimdo.control.init_devices(range(torch.cuda.device_count())): + elif comfy_aimdo.control.init_devices( + d.index for d in comfy.model_management.get_all_torch_devices() + if d.type == "cuda" and d.index is not None + ): if args.verbose == 'DEBUG': comfy_aimdo.control.set_log_debug() elif args.verbose == 'CRITICAL': From 711bb1bae0de0448a46d9bfcba365f377bf23a0e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 22 May 2026 19:55:50 -0700 Subject: [PATCH 82/90] Simplify aimdo init call - drop redundant type/index filter get_all_torch_devices() only enumerates one vendor at a time (the is_nvidia/is_intel_xpu/is_ascend_npu branches are exclusive and each constructs devices via torch.device("type", i) with a real integer index), and aimdo_control.init_devices short-circuits on lib is None before iterating, so the d.type == "cuda" and d.index is not None filter cannot ever change the result. Match master's trust level and just pass the indices directly. Reduces the divergence from master to a single line: init_device(get_torch_device().index) -> init_devices(d.index for d in get_all_torch_devices()) Amp-Thread-ID: https://ampcode.com/threads/T-019e52b4-31ee-72cd-996b-64ecd9420e13 Co-authored-by: Amp --- main.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/main.py b/main.py index b1a19c7bff2d..fe824439c0c7 100644 --- a/main.py +++ b/main.py @@ -218,10 +218,7 @@ def execute_script(script_path): if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()): if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") - elif comfy_aimdo.control.init_devices( - d.index for d in comfy.model_management.get_all_torch_devices() - if d.type == "cuda" and d.index is not None - ): + elif comfy_aimdo.control.init_devices(d.index for d in comfy.model_management.get_all_torch_devices()): if args.verbose == 'DEBUG': comfy_aimdo.control.set_log_debug() elif args.verbose == 'CRITICAL': From 9a12a9328b5b007b762e04599b1eba4cd84e45e4 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 22 May 2026 21:38:58 -0700 Subject: [PATCH 83/90] Revert per-loader device inputs from #13483 / #13748 Remove the device-selection widgets that were added directly to existing loader nodes (and the new CheckpointLoaderDevice / ImageOnlyCheckpointLoaderDevice variants): - nodes.py: - delete CheckpointLoaderDevice class and its NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS entries - remove the optional `device` input + VALIDATE_INPUTS + resolve logic from UNETLoader, VAELoader, CLIPLoader, DualCLIPLoader - restore CLIPLoader/DualCLIPLoader `device` options to ["default", "cpu"] - comfy_extras/nodes_video_model.py: - delete ImageOnlyCheckpointLoaderDevice class + its mapping entries - comfy_extras/nodes_lt_audio.py: - restore LTXAVTextEncoderLoader `device` options to ["default", "cpu"] and revert the resolve logic back to the simple `if device == "cpu"` branch The replacement approach is a small set of passthrough Select*Device nodes (added in the next commit) that retarget MODEL/CLIP/VAE devices without bloating every loader's UI or duplicating loaders. The cuda_device_context helper and the model_management helpers (get_gpu_device_options / resolve_gpu_device_option) from #13483 are kept; they are still used by the new selector nodes. Amp-Thread-ID: https://ampcode.com/threads/T-019e52b4-31ee-72cd-996b-64ecd9420e13 Co-authored-by: Amp --- comfy_extras/nodes_lt_audio.py | 10 +-- comfy_extras/nodes_video_model.py | 65 --------------- nodes.py | 127 +++--------------------------- 3 files changed, 13 insertions(+), 189 deletions(-) diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index afe5b0d13a25..51ddf584a20c 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -182,7 +182,7 @@ def define_schema(cls) -> io.Schema: ), io.Combo.Input( "device", - options=comfy.model_management.get_gpu_device_options(), + options=["default", "cpu"], advanced=True, ) ], @@ -197,12 +197,8 @@ def execute(cls, text_encoder, ckpt_name, device="default"): clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) model_options = {} - resolved = comfy.model_management.resolve_gpu_device_option(device) - if resolved is not None: - if resolved.type == "cpu": - model_options["load_device"] = model_options["offload_device"] = resolved - else: - model_options["load_device"] = resolved + if device == "cpu": + model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) return io.NodeOutput(clip) diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index b0d0390ca2bb..8f19895a106f 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -23,69 +23,6 @@ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): return (out[0], out[3], out[2]) -class ImageOnlyCheckpointLoaderDevice: - @classmethod - def INPUT_TYPES(s): - device_options = comfy.model_management.get_gpu_device_options() - return { - "required": { - "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), - }, - "optional": { - "model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}), - "clip_vision_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP vision encoder."}), - "vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}), - } - } - RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") - FUNCTION = "load_checkpoint" - - CATEGORY = "loaders/video_models" - - @classmethod - def VALIDATE_INPUTS(cls, model_device="default", clip_vision_device="default", vae_device="default"): - return True - - def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True, model_device="default", clip_vision_device="default", vae_device="default"): - ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) - - model_options = {} - resolved_model = comfy.model_management.resolve_gpu_device_option(model_device) - if resolved_model is not None: - if resolved_model.type == "cpu": - model_options["load_device"] = model_options["offload_device"] = resolved_model - else: - model_options["load_device"] = resolved_model - - cv_model_options = {} - resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_vision_device) - if resolved_clip is not None: - if resolved_clip.type == "cpu": - cv_model_options["load_device"] = cv_model_options["offload_device"] = resolved_clip - else: - cv_model_options["load_device"] = resolved_clip - - # VAE device is passed via model_options["load_device"] which - # load_state_dict_guess_config forwards to the VAE constructor. - # If vae_device differs from model_device, we override after loading. - resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device) - - out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) - model_patcher, clip, vae, clip_vision = out[:4] - - # Apply VAE device override if it differs from the model device - if resolved_vae is not None and vae is not None: - vae.device = resolved_vae - if resolved_vae.type == "cpu": - offload = resolved_vae - else: - offload = comfy.model_management.vae_offload_device() - vae.patcher.load_device = resolved_vae - vae.patcher.offload_device = offload - - return (model_patcher, clip_vision, vae) - - class SVD_img2vid_Conditioning: @classmethod def INPUT_TYPES(s): @@ -212,7 +149,6 @@ def append(self, conditioning, width, height, temporal, x, y, z, strength): NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, - "ImageOnlyCheckpointLoaderDevice": ImageOnlyCheckpointLoaderDevice, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance, @@ -222,7 +158,6 @@ def append(self, conditioning, width, height, temporal, x, y, z, strength): NODE_DISPLAY_NAME_MAPPINGS = { "ImageOnlyCheckpointLoader": "Load Checkpoint Image Only (img2vid model)", - "ImageOnlyCheckpointLoaderDevice": "Image Only Checkpoint Loader (Device)", "VideoLinearCFGGuidance": "Video Linear CFG Guidance", "VideoTriangleCFGGuidance": "Video Triangle CFG Guidance", } diff --git a/nodes.py b/nodes.py index 2f3856330bde..d1e9a2511a63 100644 --- a/nodes.py +++ b/nodes.py @@ -608,73 +608,6 @@ def load_checkpoint(self, ckpt_name): out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out[:3] - -class CheckpointLoaderDevice: - @classmethod - def INPUT_TYPES(s): - device_options = comfy.model_management.get_gpu_device_options() - return { - "required": { - "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "The name of the checkpoint (model) to load."}), - }, - "optional": { - "model_device": (device_options, {"advanced": True, "tooltip": "Device for the diffusion model (UNET)."}), - "clip_device": (device_options, {"advanced": True, "tooltip": "Device for the CLIP text encoder."}), - "vae_device": (device_options, {"advanced": True, "tooltip": "Device for the VAE."}), - } - } - RETURN_TYPES = ("MODEL", "CLIP", "VAE") - OUTPUT_TOOLTIPS = ("The model used for denoising latents.", - "The CLIP model used for encoding text prompts.", - "The VAE model used for encoding and decoding images to and from latent space.") - FUNCTION = "load_checkpoint" - - CATEGORY = "advanced/loaders" - DESCRIPTION = "Loads a diffusion model checkpoint with per-component device selection for multi-GPU setups." - - @classmethod - def VALIDATE_INPUTS(cls, model_device="default", clip_device="default", vae_device="default"): - return True - - def load_checkpoint(self, ckpt_name, model_device="default", clip_device="default", vae_device="default"): - ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) - - model_options = {} - resolved_model = comfy.model_management.resolve_gpu_device_option(model_device) - if resolved_model is not None: - if resolved_model.type == "cpu": - model_options["load_device"] = model_options["offload_device"] = resolved_model - else: - model_options["load_device"] = resolved_model - - te_model_options = {} - resolved_clip = comfy.model_management.resolve_gpu_device_option(clip_device) - if resolved_clip is not None: - if resolved_clip.type == "cpu": - te_model_options["load_device"] = te_model_options["offload_device"] = resolved_clip - else: - te_model_options["load_device"] = resolved_clip - - # VAE device is passed via model_options["load_device"] which - # load_state_dict_guess_config forwards to the VAE constructor. - # If vae_device differs from model_device, we override after loading. - resolved_vae = comfy.model_management.resolve_gpu_device_option(vae_device) - - out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options, te_model_options=te_model_options) - model_patcher, clip, vae = out[:3] - - # Apply VAE device override if it differs from the model device - if resolved_vae is not None and vae is not None: - vae.device = resolved_vae - if resolved_vae.type == "cpu": - offload = resolved_vae - else: - offload = comfy.model_management.vae_offload_device() - vae.patcher.load_device = resolved_vae - vae.patcher.offload_device = offload - - return (model_patcher, clip, vae) - class DiffusersLoader: SEARCH_ALIASES = ["load diffusers model"] @@ -853,21 +786,14 @@ def load_taesd(name): @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (s.vae_list(s), )}, - "optional": { - "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}), - }} + return {"required": { "vae_name": (s.vae_list(s), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" CATEGORY = "loaders" - @classmethod - def VALIDATE_INPUTS(cls, device="default"): - return True - #TODO: scale factor? - def load_vae(self, vae_name, device="default"): + def load_vae(self, vae_name): metadata = None if vae_name == "pixel_space": sd = {} @@ -885,8 +811,7 @@ def load_vae(self, vae_name, device="default"): metadata = {"tae_latent_channels": 128} else: metadata["tae_latent_channels"] = 128 - resolved = comfy.model_management.resolve_gpu_device_option(device) - vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved) + vae = comfy.sd.VAE(sd=sd, metadata=metadata) vae.throw_exception_if_invalid() return (vae,) @@ -1012,20 +937,13 @@ class UNETLoader: def INPUT_TYPES(s): return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ), "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True}) - }, - "optional": { - "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "load_unet" CATEGORY = "advanced/loaders" - @classmethod - def VALIDATE_INPUTS(cls, device="default"): - return True - - def load_unet(self, unet_name, weight_dtype, device="default"): + def load_unet(self, unet_name, weight_dtype): model_options = {} if weight_dtype == "fp8_e4m3fn": model_options["dtype"] = torch.float8_e4m3fn @@ -1035,13 +953,6 @@ def load_unet(self, unet_name, weight_dtype, device="default"): elif weight_dtype == "fp8_e5m2": model_options["dtype"] = torch.float8_e5m2 - resolved = comfy.model_management.resolve_gpu_device_option(device) - if resolved is not None: - if resolved.type == "cpu": - model_options["load_device"] = model_options["offload_device"] = resolved - else: - model_options["load_device"] = resolved - unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name) model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) return (model,) @@ -1053,7 +964,7 @@ def INPUT_TYPES(s): "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ), }, "optional": { - "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}), + "device": (["default", "cpu"], {"advanced": True}), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -1062,20 +973,12 @@ def INPUT_TYPES(s): DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B" - @classmethod - def VALIDATE_INPUTS(cls, device="default"): - return True - def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) model_options = {} - resolved = comfy.model_management.resolve_gpu_device_option(device) - if resolved is not None: - if resolved.type == "cpu": - model_options["load_device"] = model_options["offload_device"] = resolved - else: - model_options["load_device"] = resolved + if device == "cpu": + model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name) clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) @@ -1089,7 +992,7 @@ def INPUT_TYPES(s): "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ), }, "optional": { - "device": (comfy.model_management.get_gpu_device_options(), {"advanced": True}), + "device": (["default", "cpu"], {"advanced": True}), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -1098,10 +1001,6 @@ def INPUT_TYPES(s): DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small\nnewbie: gemma-3-4b-it, jina clip v2" - @classmethod - def VALIDATE_INPUTS(cls, device="default"): - return True - def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -1109,12 +1008,8 @@ def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) model_options = {} - resolved = comfy.model_management.resolve_gpu_device_option(device) - if resolved is not None: - if resolved.type == "cpu": - model_options["load_device"] = model_options["offload_device"] = resolved - else: - model_options["load_device"] = resolved + if device == "cpu": + model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) return (clip,) @@ -2177,7 +2072,6 @@ def expand_image(self, image, left, top, right, bottom, feathering): "InpaintModelConditioning": InpaintModelConditioning, "CheckpointLoader": CheckpointLoader, - "CheckpointLoaderDevice": CheckpointLoaderDevice, "DiffusersLoader": DiffusersLoader, "LoadLatent": LoadLatent, @@ -2195,7 +2089,6 @@ def expand_image(self, image, left, top, right, bottom, feathering): # Loaders "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)", "CheckpointLoaderSimple": "Load Checkpoint", - "CheckpointLoaderDevice": "Load Checkpoint (Device)", "VAELoader": "Load VAE", "LoraLoader": "Load LoRA (Model and CLIP)", "LoraLoaderModelOnly": "Load LoRA", From d7706091aea7e0f7afd641fb807a473cf9acd06c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 22 May 2026 21:39:18 -0700 Subject: [PATCH 84/90] Add Select Model/CLIP/VAE Device passthrough nodes Replace the per-loader device widgets removed in the previous commit with three small passthrough selector nodes registered under advanced/multigpu: - Select Model Device (MODEL in/out) - options: default / cpu / gpu:N - Select CLIP Device (CLIP in/out) - options: default / cpu / gpu:N - Select VAE Device (VAE in/out) - options: default / gpu:N (no cpu) Each node clones the inbound patcher (model.clone() / clip.clone() / copy.copy(vae)+vae.patcher.clone()) and retargets load_device (and offload_device for cpu / vae_offload_device for VAE). Portability across machines with different GPU counts: - VALIDATE_INPUTS returns True so an unknown gpu:N value (e.g. a workflow saved on a 2-GPU machine opened on a 1-GPU machine) does not error at validation time. - At runtime, resolve_gpu_device_option(...) returns None for unknown options (with a warning), and each selector then logs a per-node info message and passes through unchanged, matching the no-op style used by MultiGPU CFG Split's "No extra torch devices need initialization..." log. Also adds comfy.model_management.get_gpu_device_options_no_cpu() which the VAE selector uses; on a single-GPU box this collapses to just ["default"], which is fine. Amp-Thread-ID: https://ampcode.com/threads/T-019e52b4-31ee-72cd-996b-64ecd9420e13 Co-authored-by: Amp --- comfy/model_management.py | 8 ++ comfy_extras/nodes_multigpu.py | 149 +++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index c146eee119d4..d744f174535d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -255,6 +255,14 @@ def get_gpu_device_options(): options.append(f"gpu:{i}") return options +def get_gpu_device_options_no_cpu(): + """Variant of get_gpu_device_options that omits "cpu". + + Intended for components like the VAE selector where running on CPU + is impractical and should not be offered as a choice. + """ + return [o for o in get_gpu_device_options() if o != "cpu"] + def resolve_gpu_device_option(option: str): """Resolve a device option string to a torch.device. diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index fedafef7114e..9e03c56f0ea1 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -1,5 +1,7 @@ from __future__ import annotations +import copy +import logging from inspect import cleandoc from typing import TYPE_CHECKING from typing_extensions import override @@ -8,6 +10,8 @@ if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher + from comfy.sd import CLIP, VAE +import comfy.model_management import comfy.multigpu @@ -42,6 +46,148 @@ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: return io.NodeOutput(model) +class SelectModelDeviceNode(io.ComfyNode): + """ + Place the diffusion model on a specific device (default / cpu / gpu:N). + + When the selected device does not exist on the current machine + (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), + the node passes the model through unchanged and logs a message + instead of failing. This keeps workflows portable across machines + with different GPU counts. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SelectModelDevice", + display_name="Select Model Device", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Model.Input("model"), + io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def VALIDATE_INPUTS(cls, device="default"): + # Allow unknown gpu:N values so portable workflows do not error + # at validation time; runtime fallback will handle them. + return True + + @classmethod + def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput: + model = model.clone() + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is None: + if device not in (None, "default"): + logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.") + return io.NodeOutput(model) + model.load_device = resolved + if resolved.type == "cpu": + model.offload_device = resolved + return io.NodeOutput(model) + + +class SelectCLIPDeviceNode(io.ComfyNode): + """ + Place the CLIP text encoder on a specific device (default / cpu / gpu:N). + + When the selected device does not exist on the current machine + (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), + the node passes the CLIP through unchanged and logs a message + instead of failing. This keeps workflows portable across machines + with different GPU counts. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SelectCLIPDevice", + display_name="Select CLIP Device", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Clip.Input("clip"), + io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()), + ], + outputs=[ + io.Clip.Output(), + ], + ) + + @classmethod + def VALIDATE_INPUTS(cls, device="default"): + return True + + @classmethod + def execute(cls, clip: CLIP, device: str = "default") -> io.NodeOutput: + clip = clip.clone() + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is None: + if device not in (None, "default"): + logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.") + return io.NodeOutput(clip) + clip.patcher.load_device = resolved + if resolved.type == "cpu": + clip.patcher.offload_device = resolved + return io.NodeOutput(clip) + + +class SelectVAEDeviceNode(io.ComfyNode): + """ + Place the VAE on a specific device (default / gpu:N). + + CPU is intentionally not offered as a choice; VAE on CPU is impractical. + + When the selected device does not exist on the current machine + (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), + the node passes the VAE through unchanged and logs a message + instead of failing. This keeps workflows portable across machines + with different GPU counts. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SelectVAEDevice", + display_name="Select VAE Device", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Vae.Input("vae"), + io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options_no_cpu()), + ], + outputs=[ + io.Vae.Output(), + ], + ) + + @classmethod + def VALIDATE_INPUTS(cls, device="default"): + return True + + @classmethod + def execute(cls, vae: VAE, device: str = "default") -> io.NodeOutput: + # VAE has no .clone(); shallow-copy the wrapper and clone the patcher + # so we can retarget load/offload device without affecting the input VAE. + vae = copy.copy(vae) + vae.patcher = vae.patcher.clone() + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is None: + if device not in (None, "default"): + logging.info(f"Select VAE Device: requested device '{device}' not available, passing through unchanged.") + return io.NodeOutput(vae) + vae.device = resolved + vae.patcher.load_device = resolved + vae.patcher.offload_device = comfy.model_management.vae_offload_device() + return io.NodeOutput(vae) + + class MultiGPUOptionsNode(io.ComfyNode): """ Select the relative speed of GPUs in the special case they have significantly different performance from one another. @@ -92,6 +238,9 @@ class MultiGPUExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ MultiGPUCFGSplitNode, + SelectModelDeviceNode, + SelectCLIPDeviceNode, + SelectVAEDeviceNode, # MultiGPUOptionsNode, ] From 4e650055d0b7ea2eef4a0ac5091fadd5bcf553e1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 22 May 2026 21:46:07 -0700 Subject: [PATCH 85/90] SelectXDevice nodes: register new load_device with ModelPatcherDynamic When --enable-dynamic-vram is on, every ModelPatcher is a ModelPatcherDynamic whose underlying model has a per-device dynamic_pins dict, initialized in __init__ for self.load_device only. If a cloned patcher's load_device is later reassigned (as the Select{Model,CLIP,VAE} Device nodes do), the new device key is missing and partially_unload_ram raises KeyError: device(type='cuda', index=N). Fix: - Extract the per-device dynamic_pins init in ModelPatcherDynamic.__init__ into a new helper method register_load_device(device) which is now also called from __init__. - Each Select*Device node calls clone.patcher.register_load_device(resolved) after retargeting load_device, guarded by hasattr so non-dynamic patchers (plain ModelPatcher in non-dynamic-vram installs) skip it. Caught by happy-path test where SelectCLIPDevice retargeted CLIP from cuda:0 to cuda:1 and CLIPTextEncode then crashed in partially_unload_ram -> dynamic_pins[cuda:1]. Amp-Thread-ID: https://ampcode.com/threads/T-019e52b4-31ee-72cd-996b-64ecd9420e13 Co-authored-by: Amp --- comfy/model_patcher.py | 19 +++++++++++++++---- comfy_extras/nodes_multigpu.py | 6 ++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 087b0fbfaf93..2bb363fab226 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1692,16 +1692,27 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up self.model.dynamic_vbars = {} if not hasattr(self.model, "dynamic_pins"): self.model.dynamic_pins = {} - if self.load_device not in self.model.dynamic_pins: - self.model.dynamic_pins[self.load_device] = { + self.register_load_device(self.load_device) + self.non_dynamic_delegate_model = None + assert load_device is not None + + def register_load_device(self, device): + """Ensure dynamic_pins has an entry for *device*. + + Called from __init__ and also from any code that retargets an + already-constructed patcher to a new load_device (e.g. the + Select{Model,CLIP,VAE}Device selector nodes); without this entry + partially_unload_ram() raises KeyError when it tries to read the + per-device pin state. + """ + if device not in self.model.dynamic_pins: + self.model.dynamic_pins[device] = { "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), "hostbufs_initialized": False, "failed": False, "active": False, } - self.non_dynamic_delegate_model = None - assert load_device is not None def is_dynamic(self): return True diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 9e03c56f0ea1..df701af56c3e 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -90,6 +90,8 @@ def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput: model.load_device = resolved if resolved.type == "cpu": model.offload_device = resolved + if hasattr(model, "register_load_device"): + model.register_load_device(resolved) return io.NodeOutput(model) @@ -135,6 +137,8 @@ def execute(cls, clip: CLIP, device: str = "default") -> io.NodeOutput: clip.patcher.load_device = resolved if resolved.type == "cpu": clip.patcher.offload_device = resolved + if hasattr(clip.patcher, "register_load_device"): + clip.patcher.register_load_device(resolved) return io.NodeOutput(clip) @@ -185,6 +189,8 @@ def execute(cls, vae: VAE, device: str = "default") -> io.NodeOutput: vae.device = resolved vae.patcher.load_device = resolved vae.patcher.offload_device = comfy.model_management.vae_offload_device() + if hasattr(vae.patcher, "register_load_device"): + vae.patcher.register_load_device(resolved) return io.NodeOutput(vae) From 9ee15408820168311fb74d31239f0926550141be Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 22 May 2026 21:50:29 -0700 Subject: [PATCH 86/90] SelectXDevice: use lowercase validate_inputs for V3 combo bypass V3 io.ComfyNode subclasses use the lowercase `validate_inputs` hook for opting out of strict combo validation (execution.py line 862); the uppercase `VALIDATE_INPUTS` is the V1 spelling and is ignored on V3 nodes. The strict combo check at execution.py line 1025 is gated on `if x not in validate_function_inputs`, so renaming to `validate_inputs(cls, device='default')` lets unknown `gpu:N` values pass validation and fall through to the runtime fallback. Amp-Thread-ID: https://ampcode.com/threads/T-019e52b4-31ee-72cd-996b-64ecd9420e13 Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index df701af56c3e..1fb134eca621 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -74,7 +74,7 @@ def define_schema(cls): ) @classmethod - def VALIDATE_INPUTS(cls, device="default"): + def validate_inputs(cls, device="default"): # Allow unknown gpu:N values so portable workflows do not error # at validation time; runtime fallback will handle them. return True @@ -123,7 +123,7 @@ def define_schema(cls): ) @classmethod - def VALIDATE_INPUTS(cls, device="default"): + def validate_inputs(cls, device="default"): return True @classmethod @@ -172,7 +172,7 @@ def define_schema(cls): ) @classmethod - def VALIDATE_INPUTS(cls, device="default"): + def validate_inputs(cls, device="default"): return True @classmethod From b319c8088b8f125332071d3fbfadcf5c26e830c3 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 22 May 2026 22:29:45 -0700 Subject: [PATCH 87/90] SelectXDevice: address code-review follow-ups True reset semantics for "default": - On first selector application, cache the loader's original load_device / offload_device on the underlying model object (which is shared across patcher clones) and restore those base values when the user picks "default". Previously "default" meant "passthrough" so SelectXDevice(gpu:1) -> SelectXDevice(default) silently kept the gpu:1 routing. CPU + dynamic VRAM: - When SelectModelDevice / SelectCLIPDevice resolves to CPU on a ModelPatcherDynamic, also call clone(disable_dynamic=True) so the result is a plain ModelPatcher, matching ModelPatcherDynamic.__new__'s intent that CPU loads never run through the dynamic path. Fallback to the regular dynamic clone if disable_dynamic is unsupported on that patcher. MultiGPU collision pruning: - After SelectModelDevice retargets the primary patcher, drop any multigpu clone (from a prior MultiGPU CFG Split) whose load_device now matches the primary; otherwise two patchers would be bound to the same device. Logs the prune at info level. SelectVAEDevice: reject CPU at runtime: - The UI uses get_gpu_device_options_no_cpu(), but a workflow opened from another machine could still pass "cpu" through validate_inputs. Detect that case explicitly, log a "CPU is not a supported choice" passthrough message, and leave the VAE unchanged. Cosmetic: - Update VAE node docstring to accurately reflect the runtime CPU rejection rather than the older "intentionally not offered" claim. - Demote the fallback warnings inside resolve_gpu_device_option to no log at all; the Select*Device nodes now own a single context-rich info-level message per failed lookup, so there is no double logging. Amp-Thread-ID: https://ampcode.com/threads/T-019e52b4-31ee-72cd-996b-64ecd9420e13 Co-authored-by: Amp --- comfy/model_management.py | 19 ++--- comfy_extras/nodes_multigpu.py | 142 ++++++++++++++++++++++++++------- 2 files changed, 119 insertions(+), 42 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d744f174535d..3bce128b268a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -268,8 +268,10 @@ def resolve_gpu_device_option(option: str): Returns None for "default" (let the caller use its normal default). Returns torch.device("cpu") for "cpu". - For "gpu:N", returns the Nth torch device. Falls back to None if - the index is out of range (caller should use default). + For "gpu:N", returns the Nth torch device. Returns None if the + index is out of range, the option string is malformed, or + unrecognized (callers are expected to log their own context-rich + message before falling back to the default device). """ if option is None or option == "default": return None @@ -278,16 +280,11 @@ def resolve_gpu_device_option(option: str): if option.startswith("gpu:"): try: idx = int(option[4:]) - devices = get_all_torch_devices() - if 0 <= idx < len(devices): - return devices[idx] - else: - logging.warning(f"Device '{option}' not available (only {len(devices)} GPU(s)), using default.") - return None - except (ValueError, IndexError): - logging.warning(f"Invalid device option '{option}', using default.") + except ValueError: return None - logging.warning(f"Unrecognized device option '{option}', using default.") + devices = get_all_torch_devices() + if 0 <= idx < len(devices): + return devices[idx] return None @contextmanager diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 1fb134eca621..0e109f426edc 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -46,15 +46,90 @@ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: return io.NodeOutput(model) +def _remember_base_devices(patcher: ModelPatcher): + """Stash the original load/offload device on the underlying model. + + Stored on patcher.model (which is shared across patcher clones), so + repeated selector applications can recover the loader's original + routing when the user picks "default". + """ + if not hasattr(patcher.model, "_select_base_load_device"): + patcher.model._select_base_load_device = patcher.load_device + patcher.model._select_base_offload_device = patcher.offload_device + + +def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None): + """Apply *resolved* to a freshly-cloned patcher; respect base devices on default. + + Returns the (possibly newly-replaced) patcher. For CPU on a dynamic + patcher, also tries to downgrade to a plain ModelPatcher so the + dynamic-only code paths are bypassed (best-effort: silently keeps + the dynamic patcher if downgrade is not supported). + """ + _remember_base_devices(patcher) + base_load = patcher.model._select_base_load_device + base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device + + if resolved is None: + # "default" -> reset routing to whatever the loader produced + patcher.load_device = base_load + patcher.offload_device = base_offload + elif resolved.type == "cpu": + if patcher.is_dynamic(): + try: + patcher = patcher.clone(disable_dynamic=True) + except Exception: + # Downgrade unavailable (no cached_patcher_init); fall + # back to the existing dynamic patcher. + pass + patcher.load_device = resolved + patcher.offload_device = resolved + else: + patcher.load_device = resolved + patcher.offload_device = base_offload + + if hasattr(patcher, "register_load_device"): + patcher.register_load_device(patcher.load_device) + return patcher + + +def _prune_multigpu_collision(model: ModelPatcher, primary_device): + """Drop any multigpu clone whose load_device matches *primary_device*. + + Without pruning, MultiGPU CFG Split would have stacked a clone on + the same device the primary now occupies (i.e. the workflow places + MultiGPU CFG Split before Select Model Device). Keeps the clone set + consistent with the new primary placement. + """ + multigpu_models = model.get_additional_models_with_key("multigpu") + if not multigpu_models: + return + filtered = [m for m in multigpu_models if m.load_device != primary_device] + if len(filtered) != len(multigpu_models): + logging.info(f"Select Model Device: pruning MultiGPU clone on {primary_device} that now collides with the primary model.") + model.set_additional_models("multigpu", filtered) + if hasattr(model, "match_multigpu_clones"): + model.match_multigpu_clones() + + class SelectModelDeviceNode(io.ComfyNode): """ Place the diffusion model on a specific device (default / cpu / gpu:N). + - "default" restores the device assigned by the loader (even after a + prior Select Model Device call). + - "cpu" pins both the load and offload device to CPU. + - "gpu:N" pins the load device to the Nth available GPU; the offload + device is restored to the loader's original choice. + + If the workflow already has MultiGPU CFG Split applied and the chosen + GPU collides with one of the existing multigpu clones, that clone is + dropped so two patchers don't end up bound to the same device. + When the selected device does not exist on the current machine (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), the node passes the model through unchanged and logs a message - instead of failing. This keeps workflows portable across machines - with different GPU counts. + instead of failing. """ @classmethod @@ -83,15 +158,12 @@ def validate_inputs(cls, device="default"): def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput: model = model.clone() resolved = comfy.model_management.resolve_gpu_device_option(device) - if resolved is None: - if device not in (None, "default"): - logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.") + if resolved is None and device not in (None, "default"): + logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.") return io.NodeOutput(model) - model.load_device = resolved - if resolved.type == "cpu": - model.offload_device = resolved - if hasattr(model, "register_load_device"): - model.register_load_device(resolved) + model = _apply_patcher_device(model, resolved) + if resolved is not None: + _prune_multigpu_collision(model, model.load_device) return io.NodeOutput(model) @@ -99,11 +171,14 @@ class SelectCLIPDeviceNode(io.ComfyNode): """ Place the CLIP text encoder on a specific device (default / cpu / gpu:N). + - "default" restores the device assigned by the loader. + - "cpu" pins both the load and offload device to CPU. + - "gpu:N" pins the load device to the Nth available GPU. + When the selected device does not exist on the current machine (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), the node passes the CLIP through unchanged and logs a message - instead of failing. This keeps workflows portable across machines - with different GPU counts. + instead of failing. """ @classmethod @@ -130,15 +205,10 @@ def validate_inputs(cls, device="default"): def execute(cls, clip: CLIP, device: str = "default") -> io.NodeOutput: clip = clip.clone() resolved = comfy.model_management.resolve_gpu_device_option(device) - if resolved is None: - if device not in (None, "default"): - logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.") + if resolved is None and device not in (None, "default"): + logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.") return io.NodeOutput(clip) - clip.patcher.load_device = resolved - if resolved.type == "cpu": - clip.patcher.offload_device = resolved - if hasattr(clip.patcher, "register_load_device"): - clip.patcher.register_load_device(resolved) + clip.patcher = _apply_patcher_device(clip.patcher, resolved) return io.NodeOutput(clip) @@ -146,13 +216,18 @@ class SelectVAEDeviceNode(io.ComfyNode): """ Place the VAE on a specific device (default / gpu:N). - CPU is intentionally not offered as a choice; VAE on CPU is impractical. + - "default" restores the device assigned by the loader. + - "gpu:N" pins the load device to the Nth available GPU; the offload + device is set to the standard VAE offload device. + + CPU is intentionally not exposed in the UI for the VAE; if a workflow + supplies "cpu" anyway (e.g. opened from another machine), the request + is dropped with a log message and the VAE is passed through unchanged. When the selected device does not exist on the current machine (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), the node passes the VAE through unchanged and logs a message - instead of failing. This keeps workflows portable across machines - with different GPU counts. + instead of failing. """ @classmethod @@ -182,15 +257,20 @@ def execute(cls, vae: VAE, device: str = "default") -> io.NodeOutput: vae = copy.copy(vae) vae.patcher = vae.patcher.clone() resolved = comfy.model_management.resolve_gpu_device_option(device) - if resolved is None: - if device not in (None, "default"): - logging.info(f"Select VAE Device: requested device '{device}' not available, passing through unchanged.") + if resolved is None and device not in (None, "default"): + logging.info(f"Select VAE Device: requested device '{device}' not available, passing through unchanged.") + return io.NodeOutput(vae) + if resolved is not None and resolved.type == "cpu": + logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.") return io.NodeOutput(vae) - vae.device = resolved - vae.patcher.load_device = resolved - vae.patcher.offload_device = comfy.model_management.vae_offload_device() - if hasattr(vae.patcher, "register_load_device"): - vae.patcher.register_load_device(resolved) + vae.patcher = _apply_patcher_device( + vae.patcher, resolved, + base_offload_override=comfy.model_management.vae_offload_device(), + ) + # VAE caches the working device separately from its patcher. + if not hasattr(vae, "_select_base_device"): + vae._select_base_device = vae.device + vae.device = vae._select_base_device if resolved is None else resolved return io.NodeOutput(vae) From bece6b2aec98df32c08c7ce4fbd7fe26bc539835 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 23 May 2026 19:11:48 -0700 Subject: [PATCH 88/90] multigpu: refactor deepclone_multigpu + register cached_patcher_init for CLIP/VAE; Select*Device retargets via deepclone - ModelPatcher.deepclone_multigpu: remove copy.deepcopy fallback. Require cached_patcher_init (raise a descriptive RuntimeError if missing) and always go through clone(model_override=...) with empty backup containers so the per-device clone owns a pristine, unpatched module instead of a deepcopy of an already-loaded/already-patched one. Also call register_load_device on the new patcher so ModelPatcherDynamic per-device bookkeeping (e.g. dynamic_pins) is populated for the requested load device. - comfy/sd.py: register cached_patcher_init on the CLIP and VAE patchers returned by load_checkpoint_guess_config, and on the patcher returned by load_diffusion_model's companion paths. Add load_checkpoint_clip_patcher, load_checkpoint_vae_patcher, and load_vae_patcher reload helpers so the same loader context can be reused to produce per-device clones. - nodes.py: VAELoader registers cached_patcher_init on the produced VAE's patcher when there is a single backing file (skip for pixel_space and composite image-TAESDs which aren't addressable by a single path). - comfy_extras/nodes_multigpu.py: SelectModelDevice / SelectCLIPDevice / SelectVAEDevice now retarget via deepclone_multigpu when the requested device differs from the current load_device, so the consumed model is not just relabeled but actually rehomed onto the chosen device. Verified on runner-2 (2x RTX 4090, comfy-aimdo 0.4.4): - 10/10 focused unit tests (deepclone behavior, missing-factory error path, Select*Device behavior). - Device-switch-after-consumption end-to-end (SD1.5) produces bit-identical PNGs on cuda:0 and cuda:1. - Z Image multigpu CFG split: ~1.90x speedup (10.5s vs 19.9s steady). - Qwen Image multigpu CFG split (real text negative, cfg=4): ~1.69x speedup (32.5s vs 54.8s steady) -- matches pre-refactor numbers. - Baseline (patch stashed) and patched produce identical timings on both models, so the refactor is performance-neutral. Amp-Thread-ID: https://ampcode.com/threads/T-019e5783-b810-74b1-8ca9-09d675de1479 Co-authored-by: Amp --- comfy/model_patcher.py | 41 +++++++---- comfy/sd.py | 62 ++++++++++++++++ comfy_extras/nodes_multigpu.py | 126 ++++++++++++++++++++++++--------- nodes.py | 9 +++ 4 files changed, 192 insertions(+), 46 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2bb363fab226..c68a52cc2b69 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -457,23 +457,38 @@ def clone(self, disable_dynamic=False, model_override=None): def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None): logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.") + if self.cached_patcher_init is None: + raise RuntimeError( + f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: " + "the loader that produced this model does not support multigpu " + "(cached_patcher_init is not initialized). Use a core loader " + "(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), " + "or have the custom loader register a cached_patcher_init factory." + ) comfy.model_management.unload_model_and_clones(self) - n = self.clone() + # Produce a freshly-loaded patcher from the loader factory so the multigpu + # clone owns its own untainted model weights (rather than relying on + # copy.deepcopy of an already-patched/already-loaded module). + temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1]) + if len(self.cached_patcher_init) > 2: + temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]] + # Override clone()'s normal "share self.model + share backup containers" with + # the pristine model from temp_model_patcher plus empty backup containers -- + # the fresh model has no patches applied, so any deepcopy of self's stale + # backup/object_patches_backup/pinned would just propagate dead state that + # no longer corresponds to anything in n.model. + model_override = (temp_model_patcher.model, ({}, {}, {}, set())) + n = self.clone(model_override=model_override) + # clone() copies hook_backup by reference from self; reset since model is pristine. + n.hook_backup = {} # set load device, if present if new_load_device is not None: n.load_device = new_load_device - if self.cached_patcher_init is not None: - temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1]) - if len(self.cached_patcher_init) > 2: - temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]] - n.model = temp_model_patcher.model - else: - n.model = copy.deepcopy(n.model) - # unlike for normal clone, backup dicts that shared same ref should not; - # otherwise, patchers that have deep copies of base models will erroneously influence each other. - n.backup = copy.deepcopy(n.backup) - n.object_patches_backup = copy.deepcopy(n.object_patches_backup) - n.hook_backup = copy.deepcopy(n.hook_backup) + # Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins) + # has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's + # __init__ only registered its own (default) load_device. + if hasattr(n, "register_load_device"): + n.register_load_device(n.load_device) # multigpu clone should not have multigpu additional_models entry n.remove_additional_models("multigpu") # multigpu_clone all stored additional_models; make sure circular references are properly handled diff --git a/comfy/sd.py b/comfy/sd.py index 1670a0486570..084170c627a5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1727,8 +1727,50 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) if out[0] is not None: out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) + # Register reload factories for the CLIP and VAE produced by the same checkpoint so + # ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device, + # MultiGPU work-units, etc.) without falling back to copy.deepcopy of an + # already-loaded module. + if out[1] is not None and getattr(out[1], "patcher", None) is not None: + out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options)) + if out[2] is not None and getattr(out[2], "patcher", None) is not None: + out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options)) return out + +def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + """Reload only the CLIP patcher from a checkpoint. Used as the cached_patcher_init + factory for the CLIP returned by load_checkpoint_guess_config.""" + _, clip, _, _ = load_checkpoint_guess_config( + ckpt_path, + output_vae=False, + output_clip=True, + output_clipvision=False, + embedding_directory=embedding_directory, + output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic, + ) + return clip.patcher + + +def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + """Reload only the VAE patcher from a checkpoint. Used as the cached_patcher_init + factory for the VAE returned by load_checkpoint_guess_config.""" + _, _, vae, _ = load_checkpoint_guess_config( + ckpt_path, + output_vae=True, + output_clip=False, + output_clipvision=False, + embedding_directory=embedding_directory, + output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic, + ) + return vae.patcher + def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, embedding_directory=embedding_directory, @@ -1954,6 +1996,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False): model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model + +def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False): + """Reload a disk-backed VAE from ``vae_path`` and return its patcher. + + Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so + :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a + fresh, untainted VAE patcher (no inherited per-device load state, no + in-place quantization fallout) for multigpu work-units and the + SelectVAEDevice node. The optional ``device`` matches the source loader's + VAE initialization path; the deepclone's ``load_device`` still controls + where the cloned patcher is targeted. + """ + if metadata is None: + sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) + else: + sd = comfy.utils.load_torch_file(vae_path) + vae = VAE(sd=sd, metadata=metadata, device=device) + vae.throw_exception_if_invalid() + return vae.patcher + def load_unet(unet_path, dtype=None): logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") return load_diffusion_model(unet_path, model_options={"dtype": dtype}) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 0e109f426edc..d39cca3f8f68 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -49,48 +49,82 @@ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: def _remember_base_devices(patcher: ModelPatcher): """Stash the original load/offload device on the underlying model. - Stored on patcher.model (which is shared across patcher clones), so - repeated selector applications can recover the loader's original - routing when the user picks "default". + Stored on patcher.model (which is shared with the input patcher), so + later "default" selections can recover the loader's original routing. + Only the first Select on a given chain writes these attrs; subsequent + deepclones inherit them onto their freshly-loaded model below. """ if not hasattr(patcher.model, "_select_base_load_device"): patcher.model._select_base_load_device = patcher.load_device patcher.model._select_base_offload_device = patcher.offload_device -def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None): - """Apply *resolved* to a freshly-cloned patcher; respect base devices on default. +def _propagate_base_devices(src_model, dst_model): + """Carry the loader-original device attrs onto the freshly-deepcloned model.""" + if hasattr(src_model, "_select_base_load_device") and not hasattr(dst_model, "_select_base_load_device"): + dst_model._select_base_load_device = src_model._select_base_load_device + dst_model._select_base_offload_device = src_model._select_base_offload_device + + +def _retarget_patcher(patcher: ModelPatcher, target_load_device, target_offload_device): + """Return a patcher whose actual model weights live on *target_load_device*. - Returns the (possibly newly-replaced) patcher. For CPU on a dynamic - patcher, also tries to downgrade to a plain ModelPatcher so the - dynamic-only code paths are bypassed (best-effort: silently keeps - the dynamic patcher if downgrade is not supported). + If *patcher* is already on *target_load_device* we just retarget the + (already-cloned) patcher's metadata in place. Otherwise we call + :meth:`ModelPatcher.deepclone_multigpu` to spawn a fresh model from + the loader's ``cached_patcher_init`` factory -- the only safe way to + move weights that may already be partially loaded onto another device. + + NOTE: reusing the input patcher's model when the requested device + matches its current load_device is a deliberate fast path. Anything + that has already mutated the original model (e.g. a prior KSampler + invocation on the same model) will be observed here. This is by + design and documented on the SelectXDeviceNode docstrings -- placing + Select X Device after a node that consumes the same model is not + recommended. + """ + if patcher.load_device == target_load_device: + # Fast path: weights already on the desired device, just update offload. + patcher.offload_device = target_offload_device + return patcher + src_model = patcher.model + patcher = patcher.deepclone_multigpu(new_load_device=target_load_device) + patcher.offload_device = target_offload_device + _propagate_base_devices(src_model, patcher.model) + if hasattr(patcher, "register_load_device"): + patcher.register_load_device(patcher.load_device) + return patcher + + +def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None): + """Resolve the requested device and produce a patcher routed there. + + For "default" we restore the loader's original load/offload pair. + For CPU we pin both load and offload to CPU (and, on a dynamic + patcher, downgrade to a plain ModelPatcher so the dynamic-only + code paths are bypassed). + For an explicit GPU we keep the loader's original offload but + target the requested load device; if that differs from the current + load device the patcher is deepcloned onto the new device. """ _remember_base_devices(patcher) base_load = patcher.model._select_base_load_device base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device if resolved is None: - # "default" -> reset routing to whatever the loader produced - patcher.load_device = base_load - patcher.offload_device = base_offload - elif resolved.type == "cpu": + # "default" -> route back to the loader's original devices. + return _retarget_patcher(patcher, base_load, base_offload) + if resolved.type == "cpu": if patcher.is_dynamic(): - try: - patcher = patcher.clone(disable_dynamic=True) - except Exception: - # Downgrade unavailable (no cached_patcher_init); fall - # back to the existing dynamic patcher. - pass + # clone(disable_dynamic=True) requires cached_patcher_init; let the + # exception surface to the caller (Select*DeviceNode.execute), which + # will translate it into a passthrough+log so unsupported loaders + # don't hard-fail the workflow. + patcher = patcher.clone(disable_dynamic=True) patcher.load_device = resolved patcher.offload_device = resolved - else: - patcher.load_device = resolved - patcher.offload_device = base_offload - - if hasattr(patcher, "register_load_device"): - patcher.register_load_device(patcher.load_device) - return patcher + return patcher + return _retarget_patcher(patcher, resolved, base_offload) def _prune_multigpu_collision(model: ModelPatcher, primary_device): @@ -122,6 +156,12 @@ class SelectModelDeviceNode(io.ComfyNode): - "gpu:N" pins the load device to the Nth available GPU; the offload device is restored to the loader's original choice. + When the requested device differs from the device the input model is + already on, a fresh model is spawned via the loader's reload factory + (cached_patcher_init) so the new patcher owns independent weights on + the new device. Loaders that don't support multigpu (no factory) will + cause the node to pass through unchanged with a warning. + If the workflow already has MultiGPU CFG Split applied and the chosen GPU collides with one of the existing multigpu clones, that clone is dropped so two patchers don't end up bound to the same device. @@ -130,6 +170,13 @@ class SelectModelDeviceNode(io.ComfyNode): (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), the node passes the model through unchanged and logs a message instead of failing. + + NOTE: Placing Select Model Device *after* a node that has already + consumed the same model (e.g. a KSampler that ran on this model on + the original device) is not recommended -- any state the prior + consumer mutated on the original model will be observed when the + selected device matches the original (fast path). Place Select Model + Device before any consumer of the model. """ @classmethod @@ -161,7 +208,11 @@ def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput: if resolved is None and device not in (None, "default"): logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.") return io.NodeOutput(model) - model = _apply_patcher_device(model, resolved) + try: + model = _apply_patcher_device(model, resolved) + except RuntimeError as e: + logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})") + return io.NodeOutput(model) if resolved is not None: _prune_multigpu_collision(model, model.load_device) return io.NodeOutput(model) @@ -208,7 +259,10 @@ def execute(cls, clip: CLIP, device: str = "default") -> io.NodeOutput: if resolved is None and device not in (None, "default"): logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.") return io.NodeOutput(clip) - clip.patcher = _apply_patcher_device(clip.patcher, resolved) + try: + clip.patcher = _apply_patcher_device(clip.patcher, resolved) + except RuntimeError as e: + logging.warning(f"Select CLIP Device: cannot retarget CLIP, passing through unchanged. ({e})") return io.NodeOutput(clip) @@ -263,13 +317,19 @@ def execute(cls, vae: VAE, device: str = "default") -> io.NodeOutput: if resolved is not None and resolved.type == "cpu": logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.") return io.NodeOutput(vae) - vae.patcher = _apply_patcher_device( - vae.patcher, resolved, - base_offload_override=comfy.model_management.vae_offload_device(), - ) - # VAE caches the working device separately from its patcher. if not hasattr(vae, "_select_base_device"): vae._select_base_device = vae.device + try: + vae.patcher = _apply_patcher_device( + vae.patcher, resolved, + base_offload_override=comfy.model_management.vae_offload_device(), + ) + except RuntimeError as e: + logging.warning(f"Select VAE Device: cannot retarget VAE, passing through unchanged. ({e})") + return io.NodeOutput(vae) + # Keep VAE wrapper in sync with whatever model the patcher now owns; + # deepclone_multigpu may have produced a fresh first_stage_model. + vae.first_stage_model = vae.patcher.model vae.device = vae._select_base_device if resolved is None else resolved return io.NodeOutput(vae) diff --git a/nodes.py b/nodes.py index d1e9a2511a63..fd4365c90bc3 100644 --- a/nodes.py +++ b/nodes.py @@ -795,6 +795,7 @@ def INPUT_TYPES(s): #TODO: scale factor? def load_vae(self, vae_name): metadata = None + vae_path = None if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) @@ -813,6 +814,14 @@ def load_vae(self, vae_name): metadata["tae_latent_channels"] = 128 vae = comfy.sd.VAE(sd=sd, metadata=metadata) vae.throw_exception_if_invalid() + # Register a reload factory on the patcher so multigpu deepclones + # (Select VAE Device, future MultiGPU VAE work-units) can produce + # per-device clones from the same loader context. Only set when we + # actually have a single backing file -- pixel_space and the + # image TAESDs (composed from separate encoder/decoder files via + # load_taesd) are not addressable by a single vae_path. + if vae_path is not None: + vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, None)) return (vae,) class ControlNetLoader: From ac5b7e8bd69f8d2065aa2006251d9244e60a2b18 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 24 May 2026 17:17:08 -0700 Subject: [PATCH 89/90] multigpu: drop unused copy import; sync requirements.txt with master Amp-Thread-ID: https://ampcode.com/threads/T-019e5783-b810-74b1-8ca9-09d675de1479 Co-authored-by: Amp --- comfy/model_patcher.py | 1 - requirements.txt | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c68a52cc2b69..00a15fa63b0c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -23,7 +23,6 @@ import logging import math import uuid -import copy from typing import Callable, Optional import torch diff --git a/requirements.txt b/requirements.txt index 381e7d05fdfe..a22fa50ad9f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.43.18 +comfyui-frontend-package==1.44.19 comfyui-workflow-templates==0.9.82 comfyui-embedded-docs==0.5.0 torch @@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo==0.4.4 +comfy-aimdo==0.4.5 requests simpleeval>=1.0.0 blake3 From 7d958e18addd46e03aa5a8f5acdb92c58e0b3308 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 25 May 2026 18:13:20 -0700 Subject: [PATCH 90/90] multigpu: fix CPU SelectModelDevice slowness + MGCS reuse stripping is_multigpu_base_clone Two issues surfaced while testing the worksplit-multigpu PR: 1. Select Model Device -> CPU sampled at roughly 0.01 it/s, looking like an indefinite hang. PyTorch's CPU conv2d kernels do not have native fp16/bf16 paths and software-emulate at ~500-600x slower than fp32. Force fp32 compute via set_model_compute_dtype when the target is CPU; this keeps weights fp16 in memory and casts at use so peak memory does not double. 2. After running SelectModelDevice(gpu:N) and then activating MultiGPU CFG Split, only one GPU did real work even though both were loaded. create_multigpu_deepclones' reuse_loaded path matched the prior SelectModelDevice patcher (same clone_base_uuid, same device) but never set is_multigpu_base_clone, so the cond scheduler later filtered it out. Restrict reuse to clones that already carry the flag and always set it on the chosen patcher. Also fix a related sharp edge: extra-device selection used get_all_torch_devices(exclude_current=True), which assumes the primary lives on the process's current CUDA device. After SelectModelDevice(gpu:N) that is not true. Exclude the primary model's actual load_device instead. Amp-Thread-ID: https://ampcode.com/threads/T-019e6131-7175-719e-ad94-df5d65507375 Co-authored-by: Amp --- comfy/multigpu.py | 32 +++++++++++++++++++++++++------- comfy_extras/nodes_multigpu.py | 17 +++++++++++++++++ 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index eff7d06499a9..e7f5b3d6fb75 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -131,7 +131,11 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: skip_devices.add(mm.load_device) skip_devices = list(skip_devices) - full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + # Exclude the primary model's actual device, not the global current device: + # after SelectModelDevice(gpu:N) the primary may not live on the process's + # current CUDA device, and excluding the wrong device picks bad extras. + all_devices = comfy.model_management.get_all_torch_devices(exclude_current=False) + full_extra_devices = [d for d in all_devices if d != model.load_device] limit_extra_devices = full_extra_devices[:max_gpus-1] extra_devices = limit_extra_devices.copy() # exclude skipped devices @@ -143,16 +147,30 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: for device in extra_devices: device_patcher = None if reuse_loaded: - # check if there are any ModelPatchers currently loaded that could be referenced here after a clone + # Only reuse a previously-loaded MultiGPU clone. A SelectModelDevice + # patcher on the same device shares clone_base_uuid but has + # is_multigpu_base_clone=False, which would later be filtered out by + # prepare_model_patcher_multigpu_clones() and silently shrink the + # work split back to one GPU. loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models() for lm in loaded_models: - if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device: - device_patcher = lm.clone() - logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}") - break + if lm.model is None: + continue + if lm.load_device != device: + continue + if lm.clone_base_uuid != model.clone_base_uuid: + continue + if not getattr(lm, "is_multigpu_base_clone", False): + continue + device_patcher = lm.clone() + logging.info(f"Reusing loaded multigpu deepclone of {device_patcher.model.__class__.__name__} for {device}") + break if device_patcher is None: device_patcher = model.deepclone_multigpu(new_load_device=device) - device_patcher.is_multigpu_base_clone = True + # Always flag the clone; whether reused or freshly deepcloned, it must + # advertise itself as a MultiGPU base clone so the cond scheduler picks + # it up in prepare_model_patcher_multigpu_clones(). + device_patcher.is_multigpu_base_clone = True multigpu_models = model.get_additional_models_with_key("multigpu") multigpu_models.append(device_patcher) model.set_additional_models("multigpu", multigpu_models) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index d39cca3f8f68..2bd752b7da47 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -11,6 +11,8 @@ if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher from comfy.sd import CLIP, VAE +import torch + import comfy.model_management import comfy.multigpu @@ -46,6 +48,19 @@ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: return io.NodeOutput(model) +def _force_fp32_cpu_compute(patcher: ModelPatcher): + """Force fp32 inference dtype for CPU. + + PyTorch's CPU conv2d kernels fall back to software emulation for fp16/bf16 + and run ~500-600x slower than fp32, which makes a normal-sized workflow + look frozen for hours. Routing through set_model_compute_dtype leaves the + weights as-is and casts at use, so peak memory does not blow up.""" + dtype = patcher.model_dtype() + if dtype in (torch.float16, torch.bfloat16): + logging.info(f"Select Model Device: using fp32 compute dtype for CPU inference (model dtype was {dtype}).") + patcher.set_model_compute_dtype(torch.float32) + + def _remember_base_devices(patcher: ModelPatcher): """Stash the original load/offload device on the underlying model. @@ -214,6 +229,8 @@ def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput: logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})") return io.NodeOutput(model) if resolved is not None: + if resolved.type == "cpu": + _force_fp32_cpu_compute(model) _prune_multigpu_collision(model, model.load_device) return io.NodeOutput(model)