diff --git a/README.md b/README.md index b0f62695b053..6d09758c0ead 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,8 @@ Python 3.14 works but you may encounter issues with the torch compile node. The Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 +torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old. + ### Instructions: Git clone this repo. diff --git a/app/model_manager.py b/app/model_manager.py index ab36bca74414..f124d1117f64 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -44,7 +44,7 @@ async def get_model_folders(request): @routes.get("/experiment/models/{folder}") async def get_all_models(request): folder = request.match_info.get("folder", None) - if not folder in folder_paths.folder_names_and_paths: + if folder not in folder_paths.folder_names_and_paths: return web.Response(status=404) files = self.get_model_file_list(folder) return web.json_response(files) @@ -55,7 +55,7 @@ async def get_model_preview(request): path_index = int(request.match_info.get("path_index", None)) filename = request.match_info.get("filename", None) - if not folder_name in folder_paths.folder_names_and_paths: + if folder_name not in folder_paths.folder_names_and_paths: return web.Response(status=404) folders = folder_paths.folder_names_and_paths[folder_name] diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 7c0cadab5bc9..e88872728f4a 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -2,6 +2,25 @@ from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.ops +def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): + image = image[:, :, :, :3] if image.shape[3] > 3 else image + mean = torch.tensor(mean, device=image.device, dtype=image.dtype) + std = torch.tensor(std, device=image.device, dtype=image.dtype) + image = image.movedim(-1, 1) + if not (image.shape[2] == size and image.shape[3] == size): + if crop: + scale = (size / min(image.shape[2], image.shape[3])) + scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3])) + else: + scale_size = (size, size) + + image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] + image = torch.clip((255. * image), 0, 255).round() / 255.0 + return (image - mean.view([3,1,1])) / std.view([3,1,1]) + class CLIPAttention(torch.nn.Module): def __init__(self, embed_dim, heads, dtype, device, operations): super().__init__() diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 447b1ce4abef..d5fc5349748c 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,6 +1,5 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace import os -import torch import json import logging @@ -17,24 +16,7 @@ def __getitem__(self, key): def __setitem__(self, key, item): setattr(self, key, item) -def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): - image = image[:, :, :, :3] if image.shape[3] > 3 else image - mean = torch.tensor(mean, device=image.device, dtype=image.dtype) - std = torch.tensor(std, device=image.device, dtype=image.dtype) - image = image.movedim(-1, 1) - if not (image.shape[2] == size and image.shape[3] == size): - if crop: - scale = (size / min(image.shape[2], image.shape[3])) - scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3])) - else: - scale_size = (size, size) - - image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) - h = (image.shape[2] - size)//2 - w = (image.shape[3] - size)//2 - image = image[:,:,h:h+size,w:w+size] - image = torch.clip((255. * image), 0, 255).round() / 255.0 - return (image - mean.view([3,1,1])) / std.view([3,1,1]) +clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually IMAGE_ENCODERS = { "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection, @@ -73,7 +55,7 @@ def get_sd(self): def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() + pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 1e0f86026344..2f82d51dadee 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -188,6 +188,12 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde audio_cond = cond_value.cond if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim): new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1)) + # Handle vace_context (temporal dim is 3) + elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + vace_cond = cond_value.cond + if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim): + sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list) + new_cond_item[cond_key] = cond_value._copy_with(sliced_vace) # if has cond that is a Tensor, check if needs to be subset elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \ diff --git a/comfy/hooks.py b/comfy/hooks.py index 9d0731072902..1a76c7ba41f6 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -527,7 +527,8 @@ def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_c is outside the percent range, stop looking further - else: break + else: + break # update steps current context is used self._current_used_steps += 1 # update current timestep this was performed on diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 1ba9edad73a1..0949dee44cf0 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -74,6 +74,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.): def default_noise_sampler(x, seed=None): if seed is not None: + if x.device == torch.device("cpu"): + seed += 1 + generator = torch.Generator(device=x.device) generator.manual_seed(seed) else: diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f1ca0151e0b2..9bbe30b53f75 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -407,6 +407,9 @@ def __init__(self): self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] +class LTXAV(LTXV): + pass + class HunyuanVideo(LatentFormat): latent_channels = 16 latent_dimensions = 3 diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 70d1738890cb..4fb56165e338 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -270,7 +270,7 @@ def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams: bad_keys = tuple( k for k, v in overrides.items() - if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys) + if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys) ) if bad_keys: e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 85f515f67abc..d9e76922f527 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -3,7 +3,8 @@ import torch.nn.functional as F from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm -import model_management, model_patcher +import model_management +import model_patcher class SRResidualCausalBlock3D(nn.Module): def __init__(self, channels: int): diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py new file mode 100644 index 000000000000..759535501f23 --- /dev/null +++ b/comfy/ldm/lightricks/av_model.py @@ -0,0 +1,837 @@ +from typing import Tuple +import torch +import torch.nn as nn +from comfy.ldm.lightricks.model import ( + CrossAttention, + FeedForward, + AdaLayerNormSingle, + PixArtAlphaTextProjection, + LTXVModel, +) +from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier +import comfy.ldm.common_dit + +class BasicAVTransformerBlock(nn.Module): + def __init__( + self, + v_dim, + a_dim, + v_heads, + a_heads, + vd_head, + ad_head, + v_context_dim=None, + a_context_dim=None, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + + self.attn_precision = attn_precision + + self.attn1 = CrossAttention( + query_dim=v_dim, + heads=v_heads, + dim_head=vd_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + self.audio_attn1 = CrossAttention( + query_dim=a_dim, + heads=a_heads, + dim_head=ad_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + self.attn2 = CrossAttention( + query_dim=v_dim, + context_dim=v_context_dim, + heads=v_heads, + dim_head=vd_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + self.audio_attn2 = CrossAttention( + query_dim=a_dim, + context_dim=a_context_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + # Q: Video, K,V: Audio + self.audio_to_video_attn = CrossAttention( + query_dim=v_dim, + context_dim=a_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + # Q: Audio, K,V: Video + self.video_to_audio_attn = CrossAttention( + query_dim=a_dim, + context_dim=v_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + self.ff = FeedForward( + v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations + ) + self.audio_ff = FeedForward( + a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations + ) + + self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype)) + self.audio_scale_shift_table = nn.Parameter( + torch.empty(6, a_dim, device=device, dtype=dtype) + ) + + self.scale_shift_table_a2v_ca_audio = nn.Parameter( + torch.empty(5, a_dim, device=device, dtype=dtype) + ) + self.scale_shift_table_a2v_ca_video = nn.Parameter( + torch.empty(5, v_dim, device=device, dtype=dtype) + ) + + def get_ada_values( + self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None) + ): + num_ada_params = scale_shift_table.shape[0] + + ada_values = ( + scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] + ).unbind(dim=2) + return ada_values + + def get_av_ca_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + scale_shift_timestep: torch.Tensor, + gate_timestep: torch.Tensor, + num_scale_shift_values: int = 4, + ): + scale_shift_ada_values = self.get_ada_values( + scale_shift_table[:num_scale_shift_values, :], + batch_size, + scale_shift_timestep, + ) + gate_ada_values = self.get_ada_values( + scale_shift_table[num_scale_shift_values:, :], + batch_size, + gate_timestep, + ) + + scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values] + gate_ada_values = [t.squeeze(2) for t in gate_ada_values] + + return (*scale_shift_chunks, *gate_ada_values) + + def forward( + self, + x: Tuple[torch.Tensor, torch.Tensor], + v_context=None, + a_context=None, + attention_mask=None, + v_timestep=None, + a_timestep=None, + v_pe=None, + a_pe=None, + v_cross_pe=None, + a_cross_pe=None, + v_cross_scale_shift_timestep=None, + a_cross_scale_shift_timestep=None, + v_cross_gate_timestep=None, + a_cross_gate_timestep=None, + transformer_options=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + run_vx = transformer_options.get("run_vx", True) + run_ax = transformer_options.get("run_ax", True) + + vx, ax = x + run_ax = run_ax and ax.numel() > 0 + run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0 + run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True) + + if run_vx: + vshift_msa, vscale_msa, vgate_msa = ( + self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3)) + ) + + norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa + vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa + vx += self.attn2( + comfy.ldm.common_dit.rms_norm(vx), + context=v_context, + mask=attention_mask, + transformer_options=transformer_options, + ) + + del vshift_msa, vscale_msa, vgate_msa + + if run_ax: + ashift_msa, ascale_msa, agate_msa = ( + self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3)) + ) + + norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa + ax += ( + self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options) + * agate_msa + ) + ax += self.audio_attn2( + comfy.ldm.common_dit.rms_norm(ax), + context=a_context, + mask=attention_mask, + transformer_options=transformer_options, + ) + + del ashift_msa, ascale_msa, agate_msa + + # Audio - Video cross attention. + if run_a2v or run_v2a: + # norm3 + vx_norm3 = comfy.ldm.common_dit.rms_norm(vx) + ax_norm3 = comfy.ldm.common_dit.rms_norm(ax) + + ( + scale_ca_audio_hidden_states_a2v, + shift_ca_audio_hidden_states_a2v, + scale_ca_audio_hidden_states_v2a, + shift_ca_audio_hidden_states_v2a, + gate_out_v2a, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_audio, + ax.shape[0], + a_cross_scale_shift_timestep, + a_cross_gate_timestep, + ) + + ( + scale_ca_video_hidden_states_a2v, + shift_ca_video_hidden_states_a2v, + scale_ca_video_hidden_states_v2a, + shift_ca_video_hidden_states_v2a, + gate_out_a2v, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_video, + vx.shape[0], + v_cross_scale_shift_timestep, + v_cross_gate_timestep, + ) + + if run_a2v: + vx_scaled = ( + vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + + shift_ca_video_hidden_states_a2v + ) + ax_scaled = ( + ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + + shift_ca_audio_hidden_states_a2v + ) + vx += ( + self.audio_to_video_attn( + vx_scaled, + context=ax_scaled, + pe=v_cross_pe, + k_pe=a_cross_pe, + transformer_options=transformer_options, + ) + * gate_out_a2v + ) + + del gate_out_a2v + del scale_ca_video_hidden_states_a2v,\ + shift_ca_video_hidden_states_a2v,\ + scale_ca_audio_hidden_states_a2v,\ + shift_ca_audio_hidden_states_a2v,\ + + if run_v2a: + ax_scaled = ( + ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + + shift_ca_audio_hidden_states_v2a + ) + vx_scaled = ( + vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + + shift_ca_video_hidden_states_v2a + ) + ax += ( + self.video_to_audio_attn( + ax_scaled, + context=vx_scaled, + pe=a_cross_pe, + k_pe=v_cross_pe, + transformer_options=transformer_options, + ) + * gate_out_v2a + ) + + del gate_out_v2a + del scale_ca_video_hidden_states_v2a,\ + shift_ca_video_hidden_states_v2a,\ + scale_ca_audio_hidden_states_v2a,\ + shift_ca_audio_hidden_states_v2a + + if run_vx: + vshift_mlp, vscale_mlp, vgate_mlp = ( + self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None)) + ) + + vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp + vx += self.ff(vx_scaled) * vgate_mlp + del vshift_mlp, vscale_mlp, vgate_mlp + + if run_ax: + ashift_mlp, ascale_mlp, agate_mlp = ( + self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None)) + ) + + ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp + ax += self.audio_ff(ax_scaled) * agate_mlp + + del ashift_mlp, ascale_mlp, agate_mlp + + + return vx, ax + + +class LTXAVModel(LTXVModel): + """LTXAV model for audio-video generation.""" + + def __init__( + self, + in_channels=128, + audio_in_channels=128, + cross_attention_dim=4096, + audio_cross_attention_dim=2048, + attention_head_dim=128, + audio_attention_head_dim=64, + num_attention_heads=32, + audio_num_attention_heads=32, + caption_channels=3840, + num_layers=48, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + audio_positional_embedding_max_pos=[20], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier=1000.0, + av_ca_timestep_scale_multiplier=1.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + # Store audio-specific parameters + self.audio_in_channels = audio_in_channels + self.audio_cross_attention_dim = audio_cross_attention_dim + self.audio_attention_head_dim = audio_attention_head_dim + self.audio_num_attention_heads = audio_num_attention_heads + self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + + # Calculate audio dimensions + self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + self.audio_out_channels = audio_in_channels + + # Audio-specific constants + self.num_audio_channels = 8 + self.audio_frequency_bins = 16 + + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + + super().__init__( + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + num_layers=num_layers, + positional_embedding_theta=positional_embedding_theta, + positional_embedding_max_pos=positional_embedding_max_pos, + causal_temporal_positioning=causal_temporal_positioning, + vae_scale_factors=vae_scale_factors, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + dtype=dtype, + device=device, + operations=operations, + **kwargs, + ) + + def _init_model_components(self, device, dtype, **kwargs): + """Initialize LTXAV-specific components.""" + # Audio-specific projections + self.audio_patchify_proj = self.operations.Linear( + self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device + ) + + # Audio-specific AdaLN + self.audio_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + + num_scale_shift_values = 4 + self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( + self.inner_dim, + use_additional_conditions=False, + embedding_coefficient=num_scale_shift_values, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( + self.inner_dim, + use_additional_conditions=False, + embedding_coefficient=1, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + embedding_coefficient=num_scale_shift_values, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + embedding_coefficient=1, + dtype=dtype, + device=device, + operations=self.operations, + ) + + # Audio caption projection + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks for LTXAV.""" + self.transformer_blocks = nn.ModuleList( + [ + BasicAVTransformerBlock( + v_dim=self.inner_dim, + a_dim=self.audio_inner_dim, + v_heads=self.num_attention_heads, + a_heads=self.audio_num_attention_heads, + vd_head=self.attention_head_dim, + ad_head=self.audio_attention_head_dim, + v_context_dim=self.cross_attention_dim, + a_context_dim=self.audio_cross_attention_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + for _ in range(self.num_layers) + ] + ) + + def _init_output_components(self, device, dtype): + """Initialize output components for LTXAV.""" + # Video output components + super()._init_output_components(device, dtype) + # Audio output components + self.audio_scale_shift_table = nn.Parameter( + torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device) + ) + self.audio_norm_out = self.operations.LayerNorm( + self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.audio_proj_out = self.operations.Linear( + self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device + ) + self.a_patchifier = AudioPatchifier(1, start_end=True) + + def separate_audio_and_video_latents(self, x, audio_length): + """Separate audio and video latents from combined input.""" + # vx = x[:, : self.in_channels] + # ax = x[:, self.in_channels :] + # + # ax = ax.reshape(ax.shape[0], -1) + # ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins] + # + # ax = ax.reshape( + # ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins + # ) + + vx = x[0] + ax = x[1] if len(x) > 1 else torch.zeros( + (vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins), + device=vx.device, dtype=vx.dtype + ) + return vx, ax + + def recombine_audio_and_video_latents(self, vx, ax, target_shape=None): + if ax.numel() == 0: + return vx + else: + return [vx, ax] + """Recombine audio and video latents for output.""" + # if ax.device != vx.device or ax.dtype != vx.dtype: + # logging.warning("Audio and video latents are on different devices or dtypes.") + # ax = ax.to(device=vx.device, dtype=vx.dtype) + # logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}") + # + # ax = ax.reshape(ax.shape[0], -1) + # # pad to f x h x w of the video latents + # divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3] + # if target_shape is None: + # repetitions = math.ceil(ax.shape[-1] / divisor) + # else: + # repetitions = target_shape[1] - vx.shape[1] + # padded_len = repetitions * divisor + # ax = F.pad(ax, (0, padded_len - ax.shape[-1])) + # ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1]) + # return torch.cat([vx, ax], dim=1) + + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input for LTXAV - separate audio and video, then patchify.""" + audio_length = kwargs.get("audio_length", 0) + # Separate audio and video latents + vx, ax = self.separate_audio_and_video_latents(x, audio_length) + [vx, v_pixel_coords, additional_args] = super()._process_input( + vx, keyframe_idxs, denoise_mask, **kwargs + ) + + ax, a_latent_coords = self.a_patchifier.patchify(ax) + ax = self.audio_patchify_proj(ax) + + # additional_args.update({"av_orig_shape": list(x.shape)}) + return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args + + def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): + """Prepare timestep embeddings.""" + # TODO: some code reuse is needed here. + grid_mask = kwargs.get("grid_mask", None) + if grid_mask is not None: + timestep = timestep[:, grid_mask] + + timestep = timestep * self.timestep_scale_multiplier + v_timestep, v_embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + # Second dimension is 1 or number of tokens (if timestep_per_token) + v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1]) + v_embedded_timestep = v_embedded_timestep.view( + batch_size, -1, v_embedded_timestep.shape[-1] + ) + + # Prepare audio timestep + a_timestep = kwargs.get("a_timestep") + if a_timestep is not None: + a_timestep = a_timestep * self.timestep_scale_multiplier + av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier + + av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( + a_timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( + timestep.flatten() * av_ca_factor, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( + a_timestep.flatten() * av_ca_factor, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + a_timestep, a_embedded_timestep = self.audio_adaln_single( + a_timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) + a_embedded_timestep = a_embedded_timestep.view( + batch_size, -1, a_embedded_timestep.shape[-1] + ) + cross_av_timestep_ss = [ + av_ca_audio_scale_shift_timestep, + av_ca_video_scale_shift_timestep, + av_ca_a2v_gate_noise_timestep, + av_ca_v2a_gate_noise_timestep, + ] + cross_av_timestep_ss = list( + [t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss] + ) + else: + a_timestep = timestep + a_embedded_timestep = kwargs.get("embedded_timestep") + cross_av_timestep_ss = [] + + return [v_timestep, a_timestep, cross_av_timestep_ss], [ + v_embedded_timestep, + a_embedded_timestep, + ] + + def _prepare_context(self, context, batch_size, x, attention_mask=None): + vx = x[0] + ax = x[1] + v_context, a_context = torch.split( + context, int(context.shape[-1] / 2), len(context.shape) - 1 + ) + + v_context, attention_mask = super()._prepare_context( + v_context, batch_size, vx, attention_mask + ) + if self.audio_caption_projection is not None: + a_context = self.audio_caption_projection(a_context) + a_context = a_context.view(batch_size, -1, ax.shape[-1]) + + return [v_context, a_context], attention_mask + + def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype): + v_pixel_coords = pixel_coords[0] + v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype) + + a_latent_coords = pixel_coords[1] + a_pe = self._precompute_freqs_cis( + a_latent_coords, + dim=self.audio_inner_dim, + out_dtype=x_dtype, + max_pos=self.audio_positional_embedding_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.audio_num_attention_heads, + ) + + # calculate positional embeddings for the middle of the token duration, to use in av cross attention layers. + max_pos = max( + self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0] + ) + v_pixel_coords = v_pixel_coords.to(torch.float32) + v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate) + av_cross_video_freq_cis = self._precompute_freqs_cis( + v_pixel_coords[:, 0:1, :], + dim=self.audio_cross_attention_dim, + out_dtype=x_dtype, + max_pos=[max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.audio_num_attention_heads, + ) + av_cross_audio_freq_cis = self._precompute_freqs_cis( + a_latent_coords[:, 0:1, :], + dim=self.audio_cross_attention_dim, + out_dtype=x_dtype, + max_pos=[max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.audio_num_attention_heads, + ) + + return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)] + + def _process_transformer_blocks( + self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs + ): + vx = x[0] + ax = x[1] + v_context = context[0] + a_context = context[1] + v_timestep = timestep[0] + a_timestep = timestep[1] + v_pe, av_cross_video_freq_cis = pe[0] + a_pe, av_cross_audio_freq_cis = pe[1] + + ( + av_ca_audio_scale_shift_timestep, + av_ca_video_scale_shift_timestep, + av_ca_a2v_gate_noise_timestep, + av_ca_v2a_gate_noise_timestep, + ) = timestep[2] + + """Process transformer blocks for LTXAV.""" + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + + # Process transformer blocks + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + + def block_wrap(args): + out = {} + out["img"] = block( + args["img"], + v_context=args["v_context"], + a_context=args["a_context"], + attention_mask=args["attention_mask"], + v_timestep=args["v_timestep"], + a_timestep=args["a_timestep"], + v_pe=args["v_pe"], + a_pe=args["a_pe"], + v_cross_pe=args["v_cross_pe"], + a_cross_pe=args["a_cross_pe"], + v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"], + a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"], + v_cross_gate_timestep=args["v_cross_gate_timestep"], + a_cross_gate_timestep=args["a_cross_gate_timestep"], + transformer_options=args["transformer_options"], + ) + return out + + out = blocks_replace[("double_block", i)]( + { + "img": (vx, ax), + "v_context": v_context, + "a_context": a_context, + "attention_mask": attention_mask, + "v_timestep": v_timestep, + "a_timestep": a_timestep, + "v_pe": v_pe, + "a_pe": a_pe, + "v_cross_pe": av_cross_video_freq_cis, + "a_cross_pe": av_cross_audio_freq_cis, + "v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep, + "a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep, + "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep, + "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, + "transformer_options": transformer_options, + }, + {"original_block": block_wrap}, + ) + vx, ax = out["img"] + else: + vx, ax = block( + (vx, ax), + v_context=v_context, + a_context=a_context, + attention_mask=attention_mask, + v_timestep=v_timestep, + a_timestep=a_timestep, + v_pe=v_pe, + a_pe=a_pe, + v_cross_pe=av_cross_video_freq_cis, + a_cross_pe=av_cross_audio_freq_cis, + v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep, + a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep, + v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep, + a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, + transformer_options=transformer_options, + ) + + return [vx, ax] + + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + vx = x[0] + ax = x[1] + v_embedded_timestep = embedded_timestep[0] + a_embedded_timestep = embedded_timestep[1] + vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs) + + # Process audio output + a_scale_shift_values = ( + self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype) + + a_embedded_timestep[:, :, None] + ) + a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1] + + ax = self.audio_norm_out(ax) + ax = ax * (1 + a_scale) + a_shift + ax = self.audio_proj_out(ax) + + # Unpatchify audio + ax = self.a_patchifier.unpatchify( + ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins + ) + + # Recombine audio and video + original_shape = kwargs.get("av_orig_shape") + return self.recombine_audio_and_video_latents(vx, ax, original_shape) + + def forward( + self, + x, + timestep, + context, + attention_mask=None, + frame_rate=25, + transformer_options={}, + keyframe_idxs=None, + **kwargs, + ): + """ + Forward pass for LTXAV model. + + Args: + x: Combined audio-video input tensor + timestep: Tuple of (video_timestep, audio_timestep) or single timestep + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments including audio_length + + Returns: + Combined audio-video output tensor + """ + # Handle timestep format + if isinstance(timestep, (tuple, list)) and len(timestep) == 2: + v_timestep, a_timestep = timestep + kwargs["a_timestep"] = a_timestep + timestep = v_timestep + else: + kwargs["a_timestep"] = timestep + + # Call parent forward method + return super().forward( + x, + timestep, + context, + attention_mask, + frame_rate, + transformer_options, + keyframe_idxs, + **kwargs, + ) diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py new file mode 100644 index 000000000000..f7a43f3c32ad --- /dev/null +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -0,0 +1,305 @@ +import math +from typing import Optional + +import comfy.ldm.common_dit +import torch +from comfy.ldm.lightricks.model import ( + CrossAttention, + FeedForward, + generate_freq_grid_np, + interleaved_freqs_cis, + split_freqs_cis, +) +from torch import nn + + +class BasicTransformerBlock1D(nn.Module): + r""" + A basic Transformer block. + + Parameters: + + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. + norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`. + ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer. + attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer. + use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE). + ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer. + """ + + def __init__( + self, + dim, + n_heads, + d_head, + context_dim=None, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + context_dim=None, + dtype=dtype, + device=device, + operations=operations, + ) + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dim_out=dim, + glu=True, + dtype=dtype, + device=device, + operations=operations, + ) + + def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor: + + # Notice that normalization is always applied before the real computation in the following blocks. + + # 1. Normalization Before Self-Attention + norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + norm_hidden_states = norm_hidden_states.squeeze(1) + + # 2. Self-Attention + attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Normalization before Feed-Forward + norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class Embeddings1DConnector(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=128, + num_attention_heads=30, + num_layers=2, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[4096], + causal_temporal_positioning=False, + num_learnable_registers: Optional[int] = 128, + dtype=None, + device=None, + operations=None, + split_rope=False, + double_precision_rope=False, + **kwargs, + ): + super().__init__() + self.dtype = dtype + self.out_channels = in_channels + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.split_rope = split_rope + self.double_precision_rope = double_precision_rope + self.transformer_1d_blocks = nn.ModuleList( + [ + BasicTransformerBlock1D( + self.inner_dim, + num_attention_heads, + attention_head_dim, + context_dim=cross_attention_dim, + dtype=dtype, + device=device, + operations=operations, + ) + for _ in range(num_layers) + ] + ) + + inner_dim = num_attention_heads * attention_head_dim + self.num_learnable_registers = num_learnable_registers + if self.num_learnable_registers: + self.learnable_registers = nn.Parameter( + torch.rand( + self.num_learnable_registers, inner_dim, dtype=dtype, device=device + ) + * 2.0 + - 1.0 + ) + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [ + indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(1) + ], + dim=-1, + ) + return fractional_positions + + def precompute_freqs(self, indices_grid, spacing): + source_dtype = indices_grid.dtype + dtype = ( + torch.float32 + if source_dtype in (torch.bfloat16, torch.float16) + else source_dtype + ) + + fractional_positions = self.get_fractional_positions(indices_grid) + indices = ( + generate_freq_grid_np( + self.positional_embedding_theta, + indices_grid.shape[1], + self.inner_dim, + ) + if self.double_precision_rope + else self.generate_freq_grid(spacing, dtype, fractional_positions.device) + ).to(device=fractional_positions.device) + + if spacing == "exp_2": + freqs = ( + (indices * fractional_positions.unsqueeze(-1)) + .transpose(-1, -2) + .flatten(2) + ) + else: + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + return freqs + + def generate_freq_grid(self, spacing, dtype, device): + dim = self.inner_dim + theta = self.positional_embedding_theta + n_pos_dims = 1 + n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. x 3 = 6 + start = 1 + end = theta + + if spacing == "exp": + indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem)) + indices = indices.to(dtype=dtype, device=device) + elif spacing == "exp_2": + indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim) + indices = indices.to(dtype=dtype) + elif spacing == "linear": + indices = torch.linspace( + start, end, dim // n_elem, device=device, dtype=dtype + ) + elif spacing == "sqrt": + indices = torch.linspace( + start**2, end**2, dim // n_elem, device=device, dtype=dtype + ).sqrt() + + indices = indices * math.pi / 2 + + return indices + + def precompute_freqs_cis(self, indices_grid, spacing="exp"): + dim = self.inner_dim + n_elem = 2 # 2 because of cos and sin + freqs = self.precompute_freqs(indices_grid, spacing) + if self.split_rope: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis( + freqs, pad_size, self.num_attention_heads + ) + else: + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. Input + + if self.num_learnable_registers: + num_registers_duplications = math.ceil( + max(1024, hidden_states.shape[1]) / self.num_learnable_registers + ) + learnable_registers = torch.tile( + self.learnable_registers, (num_registers_duplications, 1) + ) + + hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1) + + if attention_mask is not None: + attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device) + + indices_grid = torch.arange( + hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device + ) + indices_grid = indices_grid[None, None, :] + freqs_cis = self.precompute_freqs_cis(indices_grid) + + # 2. Blocks + for block_idx, block in enumerate(self.transformer_1d_blocks): + hidden_states = block( + hidden_states, attention_mask=attention_mask, pe=freqs_cis + ) + + # 3. Output + # if self.output_scale is not None: + # hidden_states = hidden_states / self.output_scale + + hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + return hidden_states, attention_mask diff --git a/comfy/ldm/lightricks/latent_upsampler.py b/comfy/ldm/lightricks/latent_upsampler.py new file mode 100644 index 000000000000..78ed7653f74f --- /dev/null +++ b/comfy/ldm/lightricks/latent_upsampler.py @@ -0,0 +1,292 @@ +from typing import Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def _rational_for_scale(scale: float) -> Tuple[int, int]: + mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)} + if float(scale) not in mapping: + raise ValueError( + f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}" + ) + return mapping[float(scale)] + + +class PixelShuffleND(nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x): + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) + + +class BlurDownsample(nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. + Applies only on H,W. Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int): + super().__init__() + assert dims in (2, 3) + assert stride >= 1 and isinstance(stride, int) + self.dims = dims + self.stride = stride + + # 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized + k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (5,5) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + def _apply_2d(x2d: torch.Tensor) -> torch.Tensor: + # x2d: (B, C, H, W) + B, C, H, W = x2d.shape + weight = self.kernel.expand(C, 1, 5, 5) # depthwise + x2d = F.conv2d( + x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C + ) + return x2d + + if self.dims == 2: + return _apply_2d(x) + else: + # dims == 3: apply per-frame on H,W + b, c, f, h, w = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = _apply_2d(x) + h2, w2 = x.shape[-2:] + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2) + return x + + +class SpatialRationalResampler(nn.Module): + """ + Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased + downsample by 'den' using fixed blur + stride. Operates on H,W only. + + For dims==3, work per-frame for spatial scaling (temporal axis untouched). + """ + + def __init__(self, mid_channels: int, scale: float): + super().__init__() + self.scale = float(scale) + self.num, self.den = _rational_for_scale(self.scale) + self.conv = nn.Conv2d( + mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1 + ) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + return x + + +class ResBlock(nn.Module): + def __init__( + self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 + ): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(32, channels) + self.activation = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class LatentUpsampler(nn.Module): + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + spatial_scale: float = 2.0, + rational_resampler: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + self.spatial_scale = float(spatial_scale) + self.rational_resampler = rational_resampler + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = nn.GroupNorm(32, mid_channels) + self.initial_activation = nn.SiLU() + + self.res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + if spatial_upsample and temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_resampler: + self.upsampler = SpatialRationalResampler( + mid_channels=mid_channels, scale=self.spatial_scale + ) + else: + self.upsampler = nn.Sequential( + nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError( + "Either spatial_upsample or temporal_upsample must be True" + ) + + self.post_upsample_res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + x = x[:, :, 1:, :, :] + else: + if isinstance(self.upsampler, SpatialRationalResampler): + x = self.upsampler(x) + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + @classmethod + def from_config(cls, config): + return cls( + in_channels=config.get("in_channels", 4), + mid_channels=config.get("mid_channels", 128), + num_blocks_per_stage=config.get("num_blocks_per_stage", 4), + dims=config.get("dims", 2), + spatial_upsample=config.get("spatial_upsample", True), + temporal_upsample=config.get("temporal_upsample", False), + spatial_scale=config.get("spatial_scale", 2.0), + rational_resampler=config.get("rational_resampler", False), + ) + + def config(self): + return { + "_class_name": "LatentUpsampler", + "in_channels": self.in_channels, + "mid_channels": self.mid_channels, + "num_blocks_per_stage": self.num_blocks_per_stage, + "dims": self.dims, + "spatial_upsample": self.spatial_upsample, + "temporal_upsample": self.temporal_upsample, + "spatial_scale": self.spatial_scale, + "rational_resampler": self.rational_resampler, + } diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 593f7940f8cc..d61e19d6e6ac 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,13 +1,47 @@ +from abc import ABC, abstractmethod +from enum import Enum +import functools +import math +from typing import Dict, Optional, Tuple + +from einops import rearrange +import numpy as np import torch from torch import nn import comfy.patcher_extension import comfy.ldm.modules.attention import comfy.ldm.common_dit -import math -from typing import Dict, Optional, Tuple from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords -from comfy.ldm.flux.math import apply_rope1 + +def _log_base(x, base): + return np.log(x) / np.log(base) + +class LTXRopeType(str, Enum): + INTERLEAVED = "interleaved" + SPLIT = "split" + + KEY = "rope_type" + + @classmethod + def from_dict(cls, kwargs, default=None): + if default is None: + default = cls.INTERLEAVED + return cls(kwargs.get(cls.KEY, default)) + + +class LTXFrequenciesPrecision(str, Enum): + FLOAT32 = "float32" + FLOAT64 = "float64" + + KEY = "frequencies_precision" + + @classmethod + def from_dict(cls, kwargs, default=None): + if default is None: + default = cls.FLOAT32 + return cls(kwargs.get(cls.KEY, default)) + def get_timestep_embedding( timesteps: torch.Tensor, @@ -39,9 +73,7 @@ def get_timestep_embedding( assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) @@ -73,7 +105,9 @@ def __init__( post_act_fn: Optional[str] = None, cond_proj_dim=None, sample_proj_bias=True, - dtype=None, device=None, operations=None, + dtype=None, + device=None, + operations=None, ): super().__init__() @@ -90,7 +124,9 @@ def __init__( time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device) + self.linear_2 = operations.Linear( + time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device + ) if post_act_fn is None: self.post_act = None @@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 """ - def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + def __init__( + self, + embedding_dim, + size_emb_dim, + use_additional_conditions: bool = False, + dtype=None, + device=None, + operations=None, + ): super().__init__() self.outdim = size_emb_dim self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations + ) def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): timesteps_proj = self.time_proj(timestep) @@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module): use_additional_conditions (`bool`): To use additional conditions for normalization or not. """ - def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + def __init__( + self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None + ): super().__init__() self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( - embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations + embedding_dim, + size_emb_dim=embedding_dim // 3, + use_additional_conditions=use_additional_conditions, + dtype=dtype, + device=device, + operations=operations, ) self.silu = nn.SiLU() - self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device) + self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device) def forward( self, @@ -185,6 +238,7 @@ def forward( embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep + class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. @@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module): Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ - def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None): + def __init__( + self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None + ): super().__init__() if out_features is None: out_features = hidden_size - self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device) + self.linear_1 = operations.Linear( + in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device + ) if act_fn == "gelu_tanh": self.act_1 = nn.GELU(approximate="tanh") elif act_fn == "silu": self.act_1 = nn.SiLU() else: raise ValueError(f"Unknown activation function: {act_fn}") - self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device) + self.linear_2 = operations.Linear( + in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device + ) def forward(self, caption): hidden_states = self.linear_1(caption) @@ -222,23 +282,68 @@ def forward(self, x): class FeedForward(nn.Module): - def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None): + def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None): super().__init__() inner_dim = int(dim * mult) project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations) self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) + project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) ) def forward(self, x): return self.net(x) +def apply_rotary_emb(input_tensor, freqs_cis): + cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1] + split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False + return ( + apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs) + if split_pe else + apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs) + ) + +def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + +def apply_split_rotary_emb(input_tensor, cos, sin): + needs_reshape = False + if input_tensor.ndim != 4 and cos.ndim == 4: + B, H, T, _ = cos.shape + input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2) + needs_reshape = True + split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) + first_half_input = split_input[..., :1, :] + second_half_input = split_input[..., 1:, :] + output = split_input * cos.unsqueeze(-2) + first_half_output = output[..., :1, :] + second_half_output = output[..., 1:, :] + first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input) + second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input) + output = rearrange(output, "... d r -> ... (d r)") + return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output + class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): super().__init__() inner_dim = dim_head * heads context_dim = query_dim if context_dim is None else context_dim @@ -254,9 +359,11 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) - self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) + self.to_out = nn.Sequential( + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) + ) - def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): + def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}): q = self.to_q(x) context = x if context is None else context k = self.to_k(context) @@ -266,8 +373,8 @@ def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): k = self.k_norm(k) if pe is not None: - q = apply_rope1(q.unsqueeze(1), pe).squeeze(1) - k = apply_rope1(k.unsqueeze(1), pe).squeeze(1) + q = apply_rotary_emb(q, pe) + k = apply_rotary_emb(k, pe if k_pe is None else k_pe) if mask is None: out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) @@ -277,14 +384,34 @@ def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None): + def __init__( + self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None + ): super().__init__() self.attn_precision = attn_precision - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) @@ -306,116 +433,446 @@ def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, return x def get_fractional_positions(indices_grid, max_pos): + n_pos_dims = indices_grid.shape[1] + assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})' fractional_positions = torch.stack( - [ - indices_grid[:, i] / max_pos[i] - for i in range(3) - ], - dim=-1, + [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], + axis=-1, ) return fractional_positions -def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]): - dtype = torch.float32 - device = indices_grid.device +@functools.lru_cache(maxsize=5) +def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None): + theta = positional_embedding_theta + start = 1 + end = theta + + n_elem = 2 * positional_embedding_max_pos_count + pow_indices = np.power( + theta, + np.linspace( + _log_base(start, theta), + _log_base(end, theta), + inner_dim // n_elem, + dtype=np.float64, + ), + ) + return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) + +def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device): + theta = positional_embedding_theta + start = 1 + end = theta + n_elem = 2 * positional_embedding_max_pos_count + + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + inner_dim // n_elem, + device=device, + dtype=torch.float32, + ) + ) + indices = indices.to(dtype=torch.float32) + + indices = indices * math.pi / 2 + + return indices + +def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid): + if use_middle_indices_grid: + assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2) + indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] + indices_grid = (indices_grid_start + indices_grid_end) / 2.0 + elif len(indices_grid.shape) == 4: + indices_grid = indices_grid[..., 0] # Get fractional positions and compute frequency indices fractional_positions = get_fractional_positions(indices_grid, max_pos) - indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2 + indices = indices.to(device=fractional_positions.device) + + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + return freqs - # Compute frequencies and apply cos/sin - freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) - cos_vals = freqs.cos().repeat_interleave(2, dim=-1) - sin_vals = freqs.sin().repeat_interleave(2, dim=-1) +def interleaved_freqs_cis(freqs, pad_size): + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, : pad_size]) + sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq, sin_freq - # Pad if dim is not divisible by 6 - if dim % 6 != 0: - padding_size = dim % 6 - cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1) - sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1) +def split_freqs_cis(freqs, pad_size, num_attention_heads): + cos_freq = freqs.cos() + sin_freq = freqs.sin() - # Reshape and extract one value per pair (since repeat_interleave duplicates each value) - cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] - sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) - # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension - freqs_cis = torch.stack([ - torch.stack([cos_vals, -sin_vals], dim=-1), - torch.stack([sin_vals, cos_vals], dim=-1) - ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2] + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) - return freqs_cis + # Reshape freqs to be compatible with multi-head attention + B , T, half_HD = cos_freq.shape + cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads) + sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads) -class LTXVModel(torch.nn.Module): - def __init__(self, - in_channels=128, - cross_attention_dim=2048, - attention_head_dim=64, - num_attention_heads=32, + cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + return cos_freq, sin_freq - caption_channels=4096, - num_layers=28, +class LTXBaseModel(torch.nn.Module, ABC): + """ + Abstract base class for LTX models (Lightricks Transformer models). + This class defines the common interface and shared functionality for all LTX models, + including LTXV (video) and LTXAV (audio-video) variants. + """ - positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - causal_temporal_positioning=False, - vae_scale_factors=(8, 32, 32), - dtype=None, device=None, operations=None, **kwargs): + def __init__( + self, + in_channels: int, + cross_attention_dim: int, + attention_head_dim: int, + num_attention_heads: int, + caption_channels: int, + num_layers: int, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list = [20, 2048, 2048], + causal_temporal_positioning: bool = False, + vae_scale_factors: tuple = (8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier = 1000.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): super().__init__() self.generator = None self.vae_scale_factors = vae_scale_factors + self.use_middle_indices_grid = use_middle_indices_grid self.dtype = dtype - self.out_channels = in_channels - self.inner_dim = num_attention_heads * attention_head_dim + self.in_channels = in_channels + self.cross_attention_dim = cross_attention_dim + self.attention_head_dim = attention_head_dim + self.num_attention_heads = num_attention_heads + self.caption_channels = caption_channels + self.num_layers = num_layers + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.split_positional_embedding = LTXRopeType.from_dict(kwargs) + self.freq_grid_generator = ( + generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64 + else generate_freq_grid_pytorch + ) self.causal_temporal_positioning = causal_temporal_positioning + self.operations = operations + self.timestep_scale_multiplier = timestep_scale_multiplier - self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) + # Common dimensions + self.inner_dim = num_attention_heads * attention_head_dim + self.out_channels = in_channels - self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations + # Initialize common components + self._init_common_components(device, dtype) + + # Initialize model-specific components + self._init_model_components(device, dtype, **kwargs) + + # Initialize transformer blocks + self._init_transformer_blocks(device, dtype, **kwargs) + + # Initialize output components + self._init_output_components(device, dtype) + + def _init_common_components(self, device, dtype): + """Initialize components common to all LTX models + - patchify_proj: Linear projection for patchifying input + - adaln_single: AdaLN layer for timestep embedding + - caption_projection: Linear projection for caption embedding + """ + self.patchify_proj = self.operations.Linear( + self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device ) - # self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device) + self.adaln_single = AdaLayerNormSingle( + self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + ) self.caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, ) - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - self.inner_dim, - num_attention_heads, - attention_head_dim, - context_dim=cross_attention_dim, - # attn_precision=attn_precision, - dtype=dtype, device=device, operations=operations - ) - for d in range(num_layers) - ] + @abstractmethod + def _init_model_components(self, device, dtype, **kwargs): + """Initialize model-specific components. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _init_output_components(self, device, dtype): + """Initialize output components. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input data. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs): + """Process transformer blocks. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + """Process output data. Must be implemented by subclasses.""" + pass + + def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): + """Prepare timestep embeddings.""" + grid_mask = kwargs.get("grid_mask", None) + if grid_mask is not None: + timestep = timestep[:, grid_mask] + + timestep = timestep * self.timestep_scale_multiplier + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, ) - self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device)) - self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + + return timestep, embedded_timestep + + def _prepare_context(self, context, batch_size, x, attention_mask=None): + """Prepare context for transformer blocks.""" + if self.caption_projection is not None: + context = self.caption_projection(context) + context = context.view(batch_size, -1, x.shape[-1]) + + return context, attention_mask + + def _precompute_freqs_cis( + self, + indices_grid, + dim, + out_dtype, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=False, + num_attention_heads=32, + ): + split_mode = self.split_positional_embedding == LTXRopeType.SPLIT + indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device) + freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) + + if split_mode: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) + else: + # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only + n_elem = 2 * indices_grid.shape[1] + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode + + def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype): + """Prepare positional embeddings.""" + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + pe = self._precompute_freqs_cis( + fractional_coords, + dim=self.inner_dim, + out_dtype=x_dtype, + max_pos=self.positional_embedding_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + ) + return pe - self.patchifier = SymmetricPatchifier(1) + def _prepare_attention_mask(self, attention_mask, x_dtype): + """Prepare attention mask.""" + if attention_mask is not None and not torch.is_floating_point(attention_mask): + attention_mask = (attention_mask - 1).to(x_dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) * torch.finfo(x_dtype).max + return attention_mask - def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): + def forward( + self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs + ): + """ + Forward pass for LTX models. + + Args: + x: Input tensor + timestep: Timestep tensor + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments + + Returns: + Processed output tensor + """ return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs) + comfy.patcher_extension.get_all_wrappers( + comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options + ), + ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs) + + def _forward( + self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs + ): + """ + Internal forward pass for LTX models. + + Args: + x: Input tensor + timestep: Timestep tensor + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments + + Returns: + Processed output tensor + """ + if isinstance(x, list): + input_dtype = x[0].dtype + batch_size = x[0].shape[0] + else: + input_dtype = x.dtype + batch_size = x.shape[0] + # Process input + merged_args = {**transformer_options, **kwargs} + x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args) + merged_args.update(additional_args) + + # Prepare timestep and context + timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask) + + # Prepare attention mask and positional embeddings + attention_mask = self._prepare_attention_mask(attention_mask, input_dtype) + pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype) + + # Process transformer blocks + x = self._process_transformer_blocks( + x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args + ) + + # Process output + x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args) + return x - def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): - patches_replace = transformer_options.get("patches_replace", {}) - orig_shape = list(x.shape) +class LTXVModel(LTXBaseModel): + """LTXV model for video generation.""" + def __init__( + self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=64, + num_attention_heads=32, + caption_channels=4096, + num_layers=28, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier = 1000.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + super().__init__( + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + num_layers=num_layers, + positional_embedding_theta=positional_embedding_theta, + positional_embedding_max_pos=positional_embedding_max_pos, + causal_temporal_positioning=causal_temporal_positioning, + vae_scale_factors=vae_scale_factors, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + dtype=dtype, + device=device, + operations=operations, + **kwargs, + ) + + def _init_model_components(self, device, dtype, **kwargs): + """Initialize LTXV-specific components.""" + # No additional components needed for LTXV beyond base class + pass + + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks for LTXV.""" + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + self.inner_dim, + self.num_attention_heads, + self.attention_head_dim, + context_dim=self.cross_attention_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + for _ in range(self.num_layers) + ] + ) + + def _init_output_components(self, device, dtype): + """Initialize output components for LTXV.""" + self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device)) + self.norm_out = self.operations.LayerNorm( + self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) + self.patchifier = SymmetricPatchifier(1, start_end=True) + + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input for LTXV.""" + additional_args = {"orig_shape": list(x.shape)} x, latent_coords = self.patchifier.patchify(x) pixel_coords = latent_to_pixel_coords( latent_coords=latent_coords, @@ -423,44 +880,30 @@ def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transfor causal_fix=self.causal_temporal_positioning, ) + grid_mask = None if keyframe_idxs is not None: - pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs + additional_args.update({ "orig_patchified_shape": list(x.shape)}) + denoise_mask = self.patchifier.patchify(denoise_mask)[0] + grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0] + additional_args.update({"grid_mask": grid_mask}) + x = x[:, grid_mask, :] + pixel_coords = pixel_coords[:, :, grid_mask, ...] - fractional_coords = pixel_coords.to(torch.float32) - fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:] + keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] + pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs x = self.patchify_proj(x) - timestep = timestep * 1000.0 - - if attention_mask is not None and not torch.is_floating_point(attention_mask): - attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max - - pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) - - batch_size = x.shape[0] - timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=x.dtype, - ) - # Second dimension is 1 or number of tokens (if timestep_per_token) - timestep = timestep.view(batch_size, -1, timestep.shape[-1]) - embedded_timestep = embedded_timestep.view( - batch_size, -1, embedded_timestep.shape[-1] - ) - - # 2. Blocks - if self.caption_projection is not None: - batch_size = x.shape[0] - context = self.caption_projection(context) - context = context.view( - batch_size, -1, x.shape[-1] - ) + return x, pixel_coords, additional_args + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs): + """Process transformer blocks for LTXV.""" + patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: + def block_wrap(args): out = {} out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) @@ -478,16 +921,28 @@ def block_wrap(args): transformer_options=transformer_options, ) - # 3. Output + return x + + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + """Process output for LTXV.""" + # Apply scale-shift modulation scale_shift_values = ( self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + x = self.norm_out(x) - # Modulation - x = torch.addcmul(x, x, scale).add_(shift) + x = x * (1 + scale) + shift x = self.proj_out(x) + if keyframe_idxs is not None: + grid_mask = kwargs["grid_mask"] + orig_patchified_shape = kwargs["orig_patchified_shape"] + full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device) + full_x[:, grid_mask, :] = x + x = full_x + # Unpatchify to restore original dimensions + orig_shape = kwargs["orig_shape"] x = self.patchifier.unpatchify( latents=x, output_height=orig_shape[3], diff --git a/comfy/ldm/lightricks/symmetric_patchifier.py b/comfy/ldm/lightricks/symmetric_patchifier.py index 4b9972b9fb58..8f9a41186bf7 100644 --- a/comfy/ldm/lightricks/symmetric_patchifier.py +++ b/comfy/ldm/lightricks/symmetric_patchifier.py @@ -21,20 +21,23 @@ def latent_to_pixel_coords( Returns: Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. """ + shape = [1] * latent_coords.ndim + shape[1] = -1 pixel_coords = ( latent_coords - * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + * torch.tensor(scale_factors, device=latent_coords.device).view(*shape) ) if causal_fix: # Fix temporal scale for first frame to 1 due to causality - pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0) return pixel_coords class Patchifier(ABC): - def __init__(self, patch_size: int): + def __init__(self, patch_size: int, start_end: bool=False): super().__init__() self._patch_size = (1, patch_size, patch_size) + self.start_end = start_end @abstractmethod def patchify( @@ -71,11 +74,23 @@ def get_latent_coords( torch.arange(0, latent_width, self._patch_size[2], device=device), indexing="ij", ) - latent_sample_coords = torch.stack(latent_sample_coords, dim=0) - latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - latent_coords = rearrange( - latent_coords, "b c f h w -> b c (f h w)", b=batch_size + latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0) + delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None] + latent_sample_coords_end = latent_sample_coords_start + delta + + latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_sample_coords_start = rearrange( + latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size ) + if self.start_end: + latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_sample_coords_end = rearrange( + latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size + ) + + latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1) + else: + latent_coords = latent_sample_coords_start return latent_coords @@ -115,3 +130,61 @@ def unpatchify( q=self._patch_size[2], ) return latents + + +class AudioPatchifier(Patchifier): + def __init__(self, patch_size: int, + sample_rate=16000, + hop_length=160, + audio_latent_downsample_factor=4, + is_causal=True, + start_end=False, + shift = 0 + ): + super().__init__(patch_size, start_end=start_end) + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self.shift = shift + + def copy_with_shift(self, shift): + return AudioPatchifier( + self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor, + self.is_causal, self.start_end, shift + ) + + def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device): + audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device) + audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor + if self.is_causal: + audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0) + return audio_mel_frame * self.hop_length / self.sample_rate + + + def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # audio_latents: (batch, channels, time, freq) + b, _, t, _ = audio_latents.shape + audio_latents = rearrange( + audio_latents, + "b c t f -> b t (c f)", + ) + + audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device) + audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1) + + if self.start_end: + audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device) + audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1) + + audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1) + else: + audio_latents_timings = audio_latents_start_timings + return audio_latents, audio_latents_timings + + def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor: + # audio_latents: (batch, time, freq * channels) + audio_latents = rearrange( + audio_latents, "b t (c f) -> b c t f", c=channels, f=freq + ) + return audio_latents diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py new file mode 100644 index 000000000000..a9111d3bda85 --- /dev/null +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -0,0 +1,286 @@ +import json +from dataclasses import dataclass +import math +import torch +import torchaudio + +import comfy.model_management +import comfy.model_patcher +import comfy.utils as utils +from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution +from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier +from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( + CausalityAxis, + CausalAudioAutoencoder, +) +from comfy.ldm.lightricks.vocoders.vocoder import Vocoder + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +@dataclass(frozen=True) +class AudioVAEComponentConfig: + """Container for model component configuration extracted from metadata.""" + + autoencoder: dict + vocoder: dict + + @classmethod + def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig": + assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE" + + raw_config = metadata["config"] + if isinstance(raw_config, str): + parsed_config = json.loads(raw_config) + else: + parsed_config = raw_config + + audio_config = parsed_config.get("audio_vae") + vocoder_config = parsed_config.get("vocoder") + + assert audio_config is not None, "Audio VAE config is required for audio VAE" + assert vocoder_config is not None, "Vocoder config is required for audio VAE" + + return cls(autoencoder=audio_config, vocoder=vocoder_config) + + +class ModelDeviceManager: + """Manages device placement and GPU residency for the composed model.""" + + def __init__(self, module: torch.nn.Module): + load_device = comfy.model_management.get_torch_device() + offload_device = comfy.model_management.vae_offload_device() + self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device) + + def ensure_model_loaded(self) -> None: + comfy.model_management.free_memory( + self.patcher.model_size(), + self.patcher.load_device, + ) + comfy.model_management.load_model_gpu(self.patcher) + + def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(self.patcher.load_device) + + @property + def load_device(self): + return self.patcher.load_device + + +class AudioLatentNormalizer: + """Applies per-channel statistics in patch space and restores original layout.""" + + def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module): + self.patchifier = patchfier + self.statistics = statistics_processor + + def normalize(self, latents: torch.Tensor) -> torch.Tensor: + channels = latents.shape[1] + freq = latents.shape[3] + patched, _ = self.patchifier.patchify(latents) + normalized = self.statistics.normalize(patched) + return self.patchifier.unpatchify(normalized, channels=channels, freq=freq) + + def denormalize(self, latents: torch.Tensor) -> torch.Tensor: + channels = latents.shape[1] + freq = latents.shape[3] + patched, _ = self.patchifier.patchify(latents) + denormalized = self.statistics.un_normalize(patched) + return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq) + + +class AudioPreprocessor: + """Prepares raw waveforms for the autoencoder by matching training conditions.""" + + def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int): + self.target_sample_rate = target_sample_rate + self.mel_bins = mel_bins + self.mel_hop_length = mel_hop_length + self.n_fft = n_fft + + def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor: + if source_rate == self.target_sample_rate: + return waveform + return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate) + + @staticmethod + def normalize_amplitude( + waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5 + ) -> torch.Tensor: + waveform = waveform - waveform.mean(dim=2, keepdim=True) + peak = torch.max(torch.abs(waveform)) + eps + scale = peak.clamp(max=max_amplitude) / peak + return waveform * scale + + def waveform_to_mel( + self, waveform: torch.Tensor, waveform_sample_rate: int, device + ) -> torch.Tensor: + waveform = self.resample(waveform, waveform_sample_rate) + waveform = self.normalize_amplitude(waveform) + + mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=self.target_sample_rate, + n_fft=self.n_fft, + win_length=self.n_fft, + hop_length=self.mel_hop_length, + f_min=0.0, + f_max=self.target_sample_rate / 2.0, + n_mels=self.mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale="slaney", + norm="slaney", + ).to(device) + + mel = mel_transform(waveform) + mel = torch.log(torch.clamp(mel, min=1e-5)) + return mel.permute(0, 1, 3, 2).contiguous() + + +class AudioVAE(torch.nn.Module): + """High-level Audio VAE wrapper exposing encode and decode entry points.""" + + def __init__(self, state_dict: dict, metadata: dict): + super().__init__() + + component_config = AudioVAEComponentConfig.from_metadata(metadata) + + vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True) + vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) + + self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) + self.vocoder = Vocoder(config=component_config.vocoder) + + self.autoencoder.load_state_dict(vae_sd, strict=False) + self.vocoder.load_state_dict(vocoder_sd, strict=False) + + autoencoder_config = self.autoencoder.get_config() + self.normalizer = AudioLatentNormalizer( + AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=autoencoder_config["sampling_rate"], + hop_length=autoencoder_config["mel_hop_length"], + is_causal=autoencoder_config["is_causal"], + ), + self.autoencoder.per_channel_statistics, + ) + + self.preprocessor = AudioPreprocessor( + target_sample_rate=autoencoder_config["sampling_rate"], + mel_bins=autoencoder_config["mel_bins"], + mel_hop_length=autoencoder_config["mel_hop_length"], + n_fft=autoencoder_config["n_fft"], + ) + + self.device_manager = ModelDeviceManager(self) + + def encode(self, audio: dict) -> torch.Tensor: + """Encode a waveform dictionary into normalized latent tensors.""" + + waveform = audio["waveform"] + waveform_sample_rate = audio["sample_rate"] + input_device = waveform.device + # Ensure that Audio VAE is loaded on the correct device. + self.device_manager.ensure_model_loaded() + + waveform = self.device_manager.move_to_load_device(waveform) + expected_channels = self.autoencoder.encoder.in_channels + if waveform.shape[1] != expected_channels: + raise ValueError( + f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}" + ) + + mel_spec = self.preprocessor.waveform_to_mel( + waveform, waveform_sample_rate, device=self.device_manager.load_device + ) + + latents = self.autoencoder.encode(mel_spec) + posterior = DiagonalGaussianDistribution(latents) + latent_mode = posterior.mode() + + normalized = self.normalizer.normalize(latent_mode) + return normalized.to(input_device) + + def decode(self, latents: torch.Tensor) -> torch.Tensor: + """Decode normalized latent tensors into an audio waveform.""" + original_shape = latents.shape + + # Ensure that Audio VAE is loaded on the correct device. + self.device_manager.ensure_model_loaded() + + latents = self.device_manager.move_to_load_device(latents) + latents = self.normalizer.denormalize(latents) + + target_shape = self.target_shape_from_latents(original_shape) + mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) + + waveform = self.run_vocoder(mel_spec) + return self.device_manager.move_to_load_device(waveform) + + def target_shape_from_latents(self, latents_shape): + batch, _, time, _ = latents_shape + target_length = time * LATENT_DOWNSAMPLE_FACTOR + if self.autoencoder.causality_axis != CausalityAxis.NONE: + target_length -= LATENT_DOWNSAMPLE_FACTOR - 1 + return ( + batch, + self.autoencoder.decoder.out_ch, + target_length, + self.autoencoder.mel_bins, + ) + + def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int: + return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second) + + def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor: + audio_channels = self.autoencoder.decoder.out_ch + vocoder_input = mel_spec.transpose(2, 3) + + if audio_channels == 1: + vocoder_input = vocoder_input.squeeze(1) + elif audio_channels != 2: + raise ValueError(f"Unsupported audio_channels: {audio_channels}") + + return self.vocoder(vocoder_input) + + @property + def sample_rate(self) -> int: + return int(self.autoencoder.sampling_rate) + + @property + def mel_hop_length(self) -> int: + return int(self.autoencoder.mel_hop_length) + + @property + def mel_bins(self) -> int: + return int(self.autoencoder.mel_bins) + + @property + def latent_channels(self) -> int: + return int(self.autoencoder.decoder.z_channels) + + @property + def latent_frequency_bins(self) -> int: + return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR) + + @property + def latents_per_second(self) -> float: + return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR + + @property + def output_sample_rate(self) -> int: + output_rate = getattr(self.vocoder, "output_sample_rate", None) + if output_rate is not None: + return int(output_rate) + upsample_factor = getattr(self.vocoder, "upsample_factor", None) + if upsample_factor is None: + raise AttributeError( + "Vocoder is missing upsample_factor; cannot infer output sample rate" + ) + return int(self.sample_rate * upsample_factor / self.mel_hop_length) + + def memory_required(self, input_shape): + return self.device_manager.patcher.model_size() diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py new file mode 100644 index 000000000000..f12b9bb53d86 --- /dev/null +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -0,0 +1,909 @@ +from __future__ import annotations +import torch +from torch import nn +from torch.nn import functional as F +from typing import Optional +from enum import Enum +from .pixel_norm import PixelNorm +import comfy.ops +import logging + +ops = comfy.ops.disable_weight_init + + +class StringConvertibleEnum(Enum): + """ + Base enum class that provides string-to-enum conversion functionality. + + This mixin adds a str_to_enum() class method that handles conversion from + strings, None, or existing enum instances with case-insensitive matching. + """ + + @classmethod + def str_to_enum(cls, value): + """ + Convert a string, enum instance, or None to the appropriate enum member. + + Args: + value: Can be an enum instance of this class, a string, or None + + Returns: + Enum member of this class + + Raises: + ValueError: If the value cannot be converted to a valid enum member + """ + # Already an enum instance of this class + if isinstance(value, cls): + return value + + # None maps to NONE member if it exists + if value is None: + if hasattr(cls, "NONE"): + return cls.NONE + raise ValueError(f"{cls.__name__} does not have a NONE member to map None to") + + # String conversion (case-insensitive) + if isinstance(value, str): + value_lower = value.lower() + + # Try to match against enum values + for member in cls: + # Handle members with None values + if member.value is None: + if value_lower == "none": + return member + # Handle members with string values + elif isinstance(member.value, str) and member.value.lower() == value_lower: + return member + + # Build helpful error message with valid values + valid_values = [] + for member in cls: + if member.value is None: + valid_values.append("none") + elif isinstance(member.value, str): + valid_values.append(member.value) + + raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}") + + raise ValueError( + f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. " + f"Expected string, None, or {cls.__name__} instance." + ) + + +class AttentionType(StringConvertibleEnum): + """Enum for specifying the attention mechanism type.""" + + VANILLA = "vanilla" + LINEAR = "linear" + NONE = "none" + + +class CausalityAxis(StringConvertibleEnum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" + + +def Normalize(in_channels, *, num_groups=32, normtype="group"): + if normtype == "group": + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif normtype == "pixel": + return PixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {normtype}") + + +class CausalConv2d(nn.Module): + """ + A causal 2D convolution. + + This layer ensures that the output at time `t` only depends on inputs + at time `t` and earlier. It achieves this by applying asymmetric padding + to the time dimension (width) before the convolution. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ): + super().__init__() + + self.causality_axis = causality_axis + + # Ensure kernel_size and dilation are tuples + kernel_size = nn.modules.utils._pair(kernel_size) + dilation = nn.modules.utils._pair(dilation) + + # Calculate padding dimensions + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY: + self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.HEIGHT: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + case _: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + # The internal convolution layer uses no padding, as we handle it manually + self.conv = ops.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x): + # Apply causal padding before convolution + x = F.pad(x, self.padding) + return self.conv(x) + + +def make_conv2d( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=None, + dilation=1, + groups=1, + bias=True, + causality_axis: Optional[CausalityAxis] = None, +): + """ + Create a 2D convolution layer that can be either causal or non-causal. + + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + stride: Convolution stride + padding: Padding (if None, will be calculated based on causal flag) + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + causality_axis: Dimension along which to apply causality. + + Returns: + Either a regular Conv2d or CausalConv2d layer + """ + if causality_axis is not None: + # For causal convolution, padding is handled internally by CausalConv2d + return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis) + else: + # For non-causal convolution, use symmetric padding if not specified + if padding is None: + if isinstance(kernel_size, int): + padding = kernel_size // 2 + else: + padding = tuple(k // 2 for k in kernel_size) + return ops.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT): + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n. + # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2]. + # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2], + # So the output elements rely on the following windows: + # 0: [-,-,0] + # 1: [-,0,0] + # 2: [0,0,1] + # 3: [0,1,1] + # 4: [1,1,2] + # 5: [1,2,2] + # Notice that the first and second elements in the output rely only on the first element in the input, + # while all other elements rely on two elements in the input. + # So we can drop the first element to undo the padding (rather than the last element). + # This is a no-op for non-causal convolutions. + match self.causality_axis: + case CausalityAxis.NONE: + pass # x remains unchanged + case CausalityAxis.HEIGHT: + x = x[:, :, 1:, :] + case CausalityAxis.WIDTH: + x = x[:, :, :, 1:] + case CausalityAxis.WIDTH_COMPATIBILITY: + pass # x remains unchanged + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class Downsample(nn.Module): + """ + A downsampling layer that can use either a strided convolution + or average pooling. Supports standard and causal padding for the + convolutional mode. + """ + + def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH): + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and not self.with_conv: + raise ValueError("causality is only supported when `with_conv=True`.") + + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + # (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + pad = (0, 1, 0, 1) + case CausalityAxis.WIDTH: + pad = (2, 0, 0, 1) + case CausalityAxis.HEIGHT: + pad = (0, 1, 2, 0) + case CausalityAxis.WIDTH_COMPATIBILITY: + pad = (1, 0, 0, 1) + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # This branch is only taken if with_conv=False, which implies causality_axis is NONE. + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + norm_type="group", + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ): + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and norm_type == "group": + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, normtype=norm_type) + self.non_linearity = nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = ops.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels, normtype=norm_type) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type="group"): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels, normtype=norm_type) + self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla", norm_type="group"): + # Convert string to enum if needed + attn_type = AttentionType.str_to_enum(attn_type) + + if attn_type != AttentionType.NONE: + logging.info(f"making attention of type '{attn_type.value}' with {in_channels} in_channels") + else: + logging.info(f"making identity attention with {in_channels} in_channels") + + match attn_type: + case AttentionType.VANILLA: + return AttnBlock(in_channels, norm_type=norm_type) + case AttentionType.NONE: + return nn.Identity(in_channels) + case AttentionType.LINEAR: + raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + case _: + raise ValueError(f"Unknown attention type: {attn_type}") + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + attn_type="vanilla", + mid_block_add_attention=True, + norm_type="group", + causality_axis=CausalityAxis.WIDTH.value, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.z_channels = z_channels + self.double_z = double_z + self.norm_type = norm_type + # Convert string to enum if needed (for config loading) + causality_axis = CausalityAxis.str_to_enum(causality_axis) + self.attn_type = AttentionType.str_to_enum(attn_type) + + # downsampling + self.conv_in = make_conv2d( + in_channels, + self.ch, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + + self.non_linearity = nn.SiLU() + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for _ in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)) + + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + + # end + self.norm_out = Normalize(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + + def forward(self, x): + """ + Forward pass through the encoder. + + Args: + x: Input tensor of shape [batch, channels, time, n_mels] + + Returns: + Encoded latent representation + """ + feature_maps = [self.conv_in(x)] + + # Process each resolution level (from high to low resolution) + for resolution_level in range(self.num_resolutions): + # Apply residual blocks at current resolution level + for block_idx in range(self.num_res_blocks): + # Apply ResNet block with optional timestep embedding + current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None) + + # Apply attention if configured for this resolution level + if len(self.down[resolution_level].attn) > 0: + current_features = self.down[resolution_level].attn[block_idx](current_features) + + # Store processed features + feature_maps.append(current_features) + + # Downsample spatial dimensions (except at the final resolution level) + if resolution_level != self.num_resolutions - 1: + downsampled_features = self.down[resolution_level].downsample(feature_maps[-1]) + feature_maps.append(downsampled_features) + + # === MIDDLE PROCESSING PHASE === + # Take the lowest resolution features for middle processing + bottleneck_features = feature_maps[-1] + + # Apply first middle ResNet block + bottleneck_features = self.mid.block_1(bottleneck_features, temb=None) + + # Apply middle attention block + bottleneck_features = self.mid.attn_1(bottleneck_features) + + # Apply second middle ResNet block + bottleneck_features = self.mid.block_2(bottleneck_features, temb=None) + + # === OUTPUT PHASE === + # Normalize the bottleneck features + output_features = self.norm_out(bottleneck_features) + + # Apply non-linearity (SiLU activation) + output_features = self.non_linearity(output_features) + + # Final convolution to produce latent representation + # [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels] + return self.conv_out(output_features) + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + attn_type="vanilla", + mid_block_add_attention=True, + norm_type="group", + causality_axis=CausalityAxis.WIDTH.value, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = out_ch + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.norm_type = norm_type + self.z_channels = z_channels + # Convert string to enum if needed (for config loading) + causality_axis = CausalityAxis.str_to_enum(causality_axis) + self.attn_type = AttentionType.str_to_enum(attn_type) + + # compute block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis) + + self.non_linearity = nn.SiLU() + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis) + + def _adjust_output_shape(self, decoded_output, target_shape): + """ + Adjust output shape to match target dimensions for variable-length audio. + + This function handles the common case where decoded audio spectrograms need to be + resized to match a specific target shape. + + Args: + decoded_output: Tensor of shape (batch, channels, time, frequency) + target_shape: Target shape tuple (batch, channels, time, frequency) + + Returns: + Tensor adjusted to match target_shape exactly + """ + # Current output shape: (batch, channels, time, frequency) + _, _, current_time, current_freq = decoded_output.shape + _, target_channels, target_time, target_freq = target_shape + + # Step 1: Crop first to avoid exceeding target dimensions + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + # Step 2: Calculate padding needed for time and frequency dimensions + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + # Step 3: Apply padding if needed + if time_padding_needed > 0 or freq_padding_needed > 0: + # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom) + # For audio: pad_left/right = frequency, pad_top/bottom = time + padding = ( + 0, + max(freq_padding_needed, 0), # frequency padding (left, right) + 0, + max(time_padding_needed, 0), # time padding (top, bottom) + ) + decoded_output = F.pad(decoded_output, padding) + + # Step 4: Final safety crop to ensure exact target shape + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def get_config(self): + return { + "ch": self.ch, + "out_ch": self.out_ch, + "ch_mult": self.ch_mult, + "num_res_blocks": self.num_res_blocks, + "in_channels": self.in_channels, + "resolution": self.resolution, + "z_channels": self.z_channels, + } + + def forward(self, latent_features, target_shape=None): + """ + Decode latent features back to audio spectrograms. + + Args: + latent_features: Encoded latent representation of shape (batch, channels, height, width) + target_shape: Optional target output shape (batch, channels, time, frequency) + If provided, output will be cropped/padded to match this shape + + Returns: + Reconstructed audio spectrogram of shape (batch, channels, time, frequency) + """ + assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder" + + # Transform latent features to decoder's internal feature dimension + hidden_features = self.conv_in(latent_features) + + # Middle processing + hidden_features = self.mid.block_1(hidden_features, temb=None) + hidden_features = self.mid.attn_1(hidden_features) + hidden_features = self.mid.block_2(hidden_features, temb=None) + + # Upsampling + # Progressively increase spatial resolution from lowest to highest + for resolution_level in reversed(range(self.num_resolutions)): + # Apply residual blocks at current resolution level + for block_index in range(self.num_res_blocks + 1): + hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None) + + if len(self.up[resolution_level].attn) > 0: + hidden_features = self.up[resolution_level].attn[block_index](hidden_features) + + if resolution_level != 0: + hidden_features = self.up[resolution_level].upsample(hidden_features) + + # Output + if self.give_pre_end: + # Return intermediate features before final processing (for debugging/analysis) + decoded_output = hidden_features + else: + # Standard output path: normalize, activate, and convert to output channels + # Final normalization layer + hidden_features = self.norm_out(hidden_features) + + # Apply SiLU (Swish) activation function + hidden_features = self.non_linearity(hidden_features) + + # Final convolution to map to output channels (typically 2 for stereo audio) + decoded_output = self.conv_out(hidden_features) + + # Optional tanh activation to bound output values to [-1, 1] range + if self.tanh_out: + decoded_output = torch.tanh(decoded_output) + + # Adjust shape for audio data + if target_shape is not None: + decoded_output = self._adjust_output_shape(decoded_output, target_shape) + + return decoded_output + + +class processor(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("std-of-means", torch.empty(128)) + self.register_buffer("mean-of-means", torch.empty(128)) + + def un_normalize(self, x): + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x): + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) + + +class CausalAudioAutoencoder(nn.Module): + def __init__(self, config=None): + super().__init__() + + if config is None: + config = self._guess_config() + + # Extract encoder and decoder configs from the new format + model_config = config.get("model", {}).get("params", {}) + variables_config = config.get("variables", {}) + + self.sampling_rate = variables_config.get( + "sampling_rate", + model_config.get("sampling_rate", config.get("sampling_rate", 16000)), + ) + encoder_config = model_config.get("encoder", model_config.get("ddconfig", {})) + decoder_config = model_config.get("decoder", encoder_config) + + # Load mel spectrogram parameters + self.mel_bins = encoder_config.get("mel_bins", 64) + self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) + self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) + + # Store causality configuration at VAE level (not just in encoder internals) + causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value) + self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value) + self.is_causal = self.causality_axis == CausalityAxis.HEIGHT + + self.encoder = Encoder(**encoder_config) + self.decoder = Decoder(**decoder_config) + + self.per_channel_statistics = processor() + + def _guess_config(self): + encoder_config = { + # Required parameters - based on ltx-video-av-1679000 model metadata + "ch": 128, + "out_ch": 8, + "ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8] + "num_res_blocks": 2, + "attn_resolutions": [], # Based on metadata: empty list, no attention + "dropout": 0.0, + "resamp_with_conv": True, + "in_channels": 2, # stereo + "resolution": 256, + "z_channels": 8, + "double_z": True, + "attn_type": "vanilla", + "mid_block_add_attention": False, # Based on metadata: false + "norm_type": "pixel", + "causality_axis": "height", # Based on metadata + "mel_bins": 64, # Based on metadata: mel_bins = 64 + } + + decoder_config = { + # Inherits encoder config, can override specific params + **encoder_config, + "out_ch": 2, # Stereo audio output (2 channels) + "give_pre_end": False, + "tanh_out": False, + } + + config = { + "_class_name": "CausalAudioAutoencoder", + "sampling_rate": 16000, + "model": { + "params": { + "encoder": encoder_config, + "decoder": decoder_config, + } + }, + } + + return config + + def get_config(self): + return { + "sampling_rate": self.sampling_rate, + "mel_bins": self.mel_bins, + "mel_hop_length": self.mel_hop_length, + "n_fft": self.n_fft, + "causality_axis": self.causality_axis.value, + "is_causal": self.is_causal, + } + + def encode(self, x): + return self.encoder(x) + + def decode(self, x, target_shape=None): + return self.decoder(x, target_shape=target_shape) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py new file mode 100644 index 000000000000..b1f15f2c5f77 --- /dev/null +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -0,0 +1,213 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +import comfy.ops +import numpy as np + +ops = comfy.ops.disable_weight_init + +LRELU_SLOPE = 0.1 + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + +class Vocoder(torch.nn.Module): + """ + Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan. + + """ + + def __init__(self, config=None): + super(Vocoder, self).__init__() + + if config is None: + config = self.get_default_config() + + resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11]) + upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2]) + upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]) + resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_initial_channel = config.get("upsample_initial_channel", 1024) + stereo = config.get("stereo", True) + resblock = config.get("resblock", "1") + + self.output_sample_rate = config.get("output_sample_rate") + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 + self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + ops.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock_class(ch, k, d)) + + out_channels = 2 if stereo else 1 + self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3) + + self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))]) + + def get_default_config(self): + """Generate default configuration for the vocoder.""" + + config = { + "resblock_kernel_sizes": [3, 7, 11], + "upsample_rates": [6, 5, 2, 2, 2], + "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "upsample_initial_channel": 1024, + "stereo": True, + "resblock": "1", + } + + return config + + def forward(self, x): + """ + Forward pass of the vocoder. + + Args: + x: Input spectrogram tensor. Can be: + - 3D: (batch_size, channels, time_steps) for mono + - 4D: (batch_size, 2, channels, time_steps) for stereo + + Returns: + Audio tensor of shape (batch_size, out_channels, audio_length) + """ + if x.dim() == 4: # stereo + assert x.shape[1] == 2, "Input must have 2 channels for stereo" + x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1) + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index e80b1c1389f5..afbab2ac7256 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -491,7 +491,8 @@ def __init__( for layer_id in range(n_layers) ] ) - self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + # This norm final is in the lumina 2.0 code but isn't actually used for anything. + # self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) if self.pad_tokens_multiple is not None: diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index a8800ded0ca6..ccf690945aaa 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -30,6 +30,13 @@ raise e exit(-1) +SAGE_ATTENTION3_IS_AVAILABLE = False +try: + from sageattn3 import sageattn3_blackwell + SAGE_ATTENTION3_IS_AVAILABLE = True +except ImportError: + pass + FLASH_ATTENTION_IS_AVAILABLE = False try: from flash_attn import flash_attn_func @@ -563,6 +570,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= out = out.reshape(b, -1, heads * dim_head) return out +@wrap_attn +def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + exception_fallback = False + if (q.device.type != "cuda" or + q.dtype not in (torch.float16, torch.bfloat16) or + mask is not None): + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if skip_reshape: + B, H, L, D = q.shape + if H != heads: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=True, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + q_s, k_s, v_s = q, k, v + N = q.shape[2] + dim_head = D + else: + B, N, inner_dim = q.shape + if inner_dim % heads != 0: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=False, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + dim_head = inner_dim // heads + + if dim_head >= 256 or N <= 1024: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if not skip_reshape: + q_s, k_s, v_s = map( + lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(), + (q, k, v), + ) + B, H, L, D = q_s.shape + + try: + out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False) + except Exception as e: + exception_fallback = True + logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e) + + if exception_fallback: + if not skip_reshape: + del q_s, k_s, v_s + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=False, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if skip_reshape: + if not skip_output_reshape: + out = out.permute(0, 2, 1, 3).reshape(B, L, H * D) + else: + if skip_output_reshape: + pass + else: + out = out.permute(0, 2, 1, 3).reshape(B, L, H * D) + + return out try: @torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) @@ -650,6 +744,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape # register core-supported attention functions if SAGE_ATTENTION_IS_AVAILABLE: register_attention_function("sage", attention_sage) +if SAGE_ATTENTION3_IS_AVAILABLE: + register_attention_function("sage3", attention3_sage) if FLASH_ATTENTION_IS_AVAILABLE: register_attention_function("flash", attention_flash) if model_management.xformers_enabled(): diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 681a55db5a1b..1ae3ef0343cd 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -394,7 +394,8 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = self.ch*4 self.num_resolutions = len(ch_mult) @@ -548,7 +549,8 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, conv3d=False, time_compress=None, **ignore_kwargs): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) diff --git a/comfy/ldm/modules/ema.py b/comfy/ldm/modules/ema.py index bded25019b9b..96ee6e89549f 100644 --- a/comfy/ldm/modules/ema.py +++ b/comfy/ldm/modules/ema.py @@ -45,7 +45,7 @@ def forward(self, model): shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -54,7 +54,7 @@ def copy_to(self, model): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/comfy/ldm/util.py b/comfy/ldm/util.py index 30b4b4721056..304936ff4414 100644 --- a/comfy/ldm/util.py +++ b/comfy/ldm/util.py @@ -71,7 +71,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config): - if not "target" in config: + if "target" not in config: if config == '__is_first_stage__': return None elif config == "__is_unconditional__": diff --git a/comfy/model_base.py b/comfy/model_base.py index c4f3c0639bab..7939bbc7b2e5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging +import comfy.ldm.lightricks.av_model from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_b import StageB @@ -61,6 +62,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher +from comfy.model_management import get_free_memory class ModelType(Enum): EPS = 1 @@ -304,8 +306,15 @@ def load_model_weights(self, sd, unet_prefix=""): if k.startswith(unet_prefix): to_load[k[len(unet_prefix):]] = sd.pop(k) + free_cpu_memory = get_free_memory(torch.device("cpu")) + logging.debug(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") + logging.debug(f"model destination device {next(self.diffusion_model.parameters()).device}") to_load = self.model_config.process_unet_state_dict(to_load) - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + logging.debug(f"load model {self.model_config} weights process end") + # replace tensor with mmap tensor by assign + m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True) + free_cpu_memory = get_free_memory(torch.device("cpu")) + logging.debug(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") if len(m) > 0: logging.warning("unet missing: {}".format(m)) @@ -946,7 +955,7 @@ def extra_conds(self, **kwargs): class LTXV(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -977,6 +986,60 @@ def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class LTXAV(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + + audio_denoise_mask = None + if denoise_mask is not None and "latent_shapes" in kwargs: + denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"]) + if len(denoise_mask) > 1: + audio_denoise_mask = denoise_mask[1] + denoise_mask = denoise_mask[0] + + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + + if audio_denoise_mask is not None: + out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask) + + keyframe_idxs = kwargs.get("keyframe_idxs", None) + if keyframe_idxs is not None: + out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + + latent_shapes = kwargs.get("latent_shapes", None) + if latent_shapes is not None: + out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + + return out + + def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): + v_timestep = timestep + a_timestep = timestep + + if denoise_mask is not None: + v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0] + if audio_denoise_mask is not None: + a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0] + + return v_timestep, a_timestep + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + return latent_image + class HunyuanVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 539e296ed218..0853b3aec5cd 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -305,7 +305,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv dit_config = {} - dit_config["image_model"] = "ltxv" + dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv" dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape dit_config["attention_head_dim"] = shape[0] // 32 diff --git a/comfy/model_management.py b/comfy/model_management.py index 1889ab0acda7..79f50a5fb16e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -28,6 +28,17 @@ import gc import os +from functools import lru_cache + +@lru_cache(maxsize=1) +def get_mmap_mem_threshold_gb(): + mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) + logging.debug(f"MMAP_MEM_THRESHOLD_GB: {mmap_mem_threshold_gb}") + return mmap_mem_threshold_gb + +def get_free_disk(): + return psutil.disk_usage("/").free + class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -456,7 +467,7 @@ def module_size(module): sd = module.state_dict() for k in sd: t = sd[k] - module_mem += t.nelement() * t.element_size() + module_mem += t.nbytes return module_mem class LoadedModel: @@ -524,16 +535,47 @@ def should_reload_model(self, force_patch_weights=False): return False def model_unload(self, memory_to_free=None, unpatch_weights=True): - if memory_to_free is not None: - if memory_to_free < self.model.loaded_size(): - freed = self.model.partially_unload(self.model.offload_device, memory_to_free) - if freed >= memory_to_free: - return False - self.model.detach(unpatch_weights) - self.model_finalizer.detach() - self.model_finalizer = None - self.real_model = None - return True + logging.debug(f"model_unload: {self.model.model.__class__.__name__}") + logging.debug(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB") + logging.debug(f"unpatch_weights: {unpatch_weights}") + logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB") + logging.debug(f"offload_device: {self.model.offload_device}") + + if memory_to_free is None: + # free the full model + memory_to_free = self.model.loaded_size() + + available_memory = get_free_memory(self.model.offload_device) + logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + + mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage + if memory_to_free > available_memory - mmap_mem_threshold or memory_to_free < self.model.loaded_size(): + partially_unload = True + else: + partially_unload = False + + if partially_unload: + logging.debug("Do partially unload") + freed = self.model.partially_unload(self.model.offload_device, memory_to_free) + logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB") + if freed < memory_to_free: + logging.warning(f"Partially unload not enough memory, freed {freed/(1024*1024*1024)} GB, memory_to_free {memory_to_free/(1024*1024*1024)} GB") + else: + logging.debug("Do full unload") + self.model.detach(unpatch_weights) + logging.debug("Do full unload done") + self.model_finalizer.detach() + self.model_finalizer = None + self.real_model = None + + available_memory = get_free_memory(self.model.offload_device) + logging.debug(f"after unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + + if partially_unload: + return False + else: + return True + def model_use_more_vram(self, extra_memory, force_patch_weights=False): return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) @@ -582,6 +624,7 @@ def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() def free_memory(memory_required, device, keep_loaded=[]): + logging.debug("start to free mem") cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -619,6 +662,7 @@ def free_memory(memory_required, device, keep_loaded=[]): return unloaded_models def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): + logging.debug(f"start to load models") cleanup_models_gc() global vram_state @@ -640,6 +684,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] for x in models: + logging.debug(f"start loading model to vram: {x.model.__class__.__name__}") loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) @@ -1019,8 +1064,8 @@ def force_channels_last(): if args.async_offload is not None: NUM_STREAMS = args.async_offload else: - # Enable by default on Nvidia - if is_nvidia(): + # Enable by default on Nvidia and AMD + if is_nvidia() or is_amd(): NUM_STREAMS = 2 if args.disable_async_offload: @@ -1126,6 +1171,16 @@ def cast_to_device(tensor, device, dtype, copy=False): PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) +def discard_cuda_async_error(): + try: + a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) + b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) + _ = a + b + torch.cuda.synchronize() + except torch.AcceleratorError: + #Dump it! We already know about it from the synchronous return + pass + def pin_memory(tensor): global TOTAL_PINNED_MEMORY if MAX_PINNED_MEMORY <= 0: @@ -1158,6 +1213,9 @@ def pin_memory(tensor): PINNED_MEMORY[ptr] = size TOTAL_PINNED_MEMORY += size return True + else: + logging.warning("Pin error.") + discard_cuda_async_error() return False @@ -1186,6 +1244,9 @@ def unpin_memory(tensor): if len(PINNED_MEMORY) == 0: TOTAL_PINNED_MEMORY = 0 return True + else: + logging.warning("Unpin error.") + discard_cuda_async_error() return False @@ -1526,6 +1587,10 @@ def soft_empty_cache(force=False): def unload_all_models(): free_memory(1e30, get_torch_device()) +def debug_memory_summary(): + if is_amd() or is_nvidia(): + return torch.cuda.memory.memory_summary() + return "" #TODO: might be cleaner to put this somewhere else import threading diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 93d26c690665..f5c9633fd8ca 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -27,6 +27,10 @@ from typing import Callable, Optional import torch +import os +import tempfile +import weakref +import gc import comfy.float import comfy.hooks @@ -37,6 +41,87 @@ from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP +from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk +from comfy.quant_ops import QuantizedTensor + +def need_mmap() -> bool: + free_cpu_mem = get_free_memory(torch.device("cpu")) + mmap_mem_threshold_gb = get_mmap_mem_threshold_gb() + if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024: + logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB") + return True + return False + +def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: + """ + Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support. + """ + # Create temporary file + if filename is None: + temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1] + else: + temp_file = filename + + # Save tensor to file + cpu_tensor = t.cpu() + torch.save(cpu_tensor, temp_file) + + # If we created a CPU copy from other device, delete it to free memory + if not t.device.type == 'cpu': + del cpu_tensor + gc.collect() + + # Load with mmap - this doesn't load all data into RAM + mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False) + + # Register cleanup callback - will be called when tensor is garbage collected + def _cleanup(): + try: + if os.path.exists(temp_file): + os.remove(temp_file) + logging.debug(f"Cleaned up mmap file: {temp_file}") + except Exception: + pass + + weakref.finalize(mmap_tensor, _cleanup) + + return mmap_tensor + +def model_to_mmap(model: torch.nn.Module): + """Convert all parameters and buffers to memory-mapped tensors + + This function mimics PyTorch's Module.to() behavior but converts + tensors to memory-mapped format instead, using _apply() method. + + Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244 + + Note: For Parameters, we modify .data in-place because + MemoryMappedTensor cannot be wrapped in torch.nn.Parameter. + For buffers, _apply() will automatically update the reference. + + Args: + model: PyTorch module to convert + + Returns: + The same model with all tensors converted to memory-mapped format + """ + free_cpu_mem = get_free_memory(torch.device("cpu")) + logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") + + def convert_fn(t): + if isinstance(t, QuantizedTensor): + logging.debug(f"QuantizedTensor detected, tensor meta info: size {t.size()}, dtype {t.dtype}, device {t.device}, is_contiguous {t.is_contiguous()}") + if isinstance(t, torch.nn.Parameter): + new_tensor = to_mmap(t.detach()) + return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad) + elif isinstance(t, torch.Tensor): + return to_mmap(t) + return t + + new_model = model._apply(convert_fn) + free_cpu_mem = get_free_memory(torch.device("cpu")) + logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") + return new_model def string_to_seed(data): @@ -506,6 +591,7 @@ def get_model_object(self, name: str) -> torch.nn.Module: return comfy.utils.get_attr(self.model, name) def model_patches_to(self, device): + # TODO(sf): to mmap to = self.model_options["transformer_options"] if "patches" in to: patches = to["patches"] @@ -853,9 +939,15 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): self.model.current_weight_patches_uuid = None self.backup.clear() + if device_to is not None: - self.model.to(device_to) + if need_mmap(): + # offload to mmap + model_to_mmap(self.model) + else: + self.model.to(device_to) self.model.device = device_to + self.model.model_loaded_weight_memory = 0 self.model.model_offload_buffer_memory = 0 @@ -914,7 +1006,14 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - m.to(device_to) + if need_mmap(): + if get_free_disk() < module_mem: + logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB") + break + # offload to mmap + model_to_mmap(m) + else: + m.to(device_to) module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index cd96541d78eb..b88a2661f335 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -130,7 +130,19 @@ def __new__(cls, qdata, layout_type, layout_params): layout_type: Layout class (subclass of QuantizedLayout) layout_params: Dict with layout-specific parameters """ - return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) + # Use as_subclass so the QuantizedTensor instance shares the same + # storage and metadata as the underlying qdata tensor. This ensures + # torch.save/torch.load and the torch serialization storage scanning + # see a valid underlying storage (fixes data_ptr errors). + if not isinstance(qdata, torch.Tensor): + raise TypeError("qdata must be a torch.Tensor") + obj = qdata.as_subclass(cls) + # Ensure grad flag is consistent for quantized tensors + try: + obj.requires_grad_(False) + except Exception: + pass + return obj def __init__(self, qdata, layout_type, layout_params): self._qdata = qdata @@ -578,3 +590,34 @@ def fp8_func(func, args, kwargs): ar[0] = plain_input return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) return func(*args, **kwargs) + +def _rebuild_quantized_tensor(qdata, layout_type, layout_params): + """Rebuild QuantizedTensor during unpickling when qdata is already a tensor.""" + return QuantizedTensor(qdata, layout_type, layout_params) + + +def _rebuild_quantized_tensor_from_base(qdata_reduce, layout_type, layout_params): + """Rebuild QuantizedTensor during unpickling given the base tensor's reduce tuple. + + qdata_reduce is the tuple returned by qdata.__reduce_ex__(protocol) on the original + inner tensor. We call the provided rebuild function with its args to recreate the + inner tensor, then wrap it in QuantizedTensor. + """ + rebuild_fn, rebuild_args = qdata_reduce + qdata = rebuild_fn(*rebuild_args) + return QuantizedTensor(qdata, layout_type, layout_params) + + +# Register custom globals with torch.serialization so torch.load(..., weights_only=True) +# accepts these during unpickling. Wrapped in try/except for older PyTorch versions. +try: + import torch as _torch_serial + if hasattr(_torch_serial, "serialization") and hasattr(_torch_serial.serialization, "add_safe_globals"): + _torch_serial.serialization.add_safe_globals([ + QuantizedTensor, + _rebuild_quantized_tensor, + _rebuild_quantized_tensor_from_base, + ]) +except Exception: + # If add_safe_globals doesn't exist or registration fails, we silently continue. + pass diff --git a/comfy/sd.py b/comfy/sd.py index 7de7dd9c65d5..78e1e7af444a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1041,7 +1041,8 @@ class TEModel(Enum): MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 QWEN3_2B = 17 - JINA_CLIP_2 = 18 + GEMMA_3_12B = 18 + JINA_CLIP_2 = 19 def detect_te_model(sd): @@ -1067,6 +1068,8 @@ def detect_te_model(sd): return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + if 'model.layers.47.self_attn.q_norm.weight' in sd: + return TEModel.GEMMA_3_12B if 'model.layers.0.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B @@ -1271,6 +1274,10 @@ class EmptyClass: elif clip_type == CLIPType.KANDINSKY5_IMAGE: clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage + elif clip_type == CLIPType.LTXV: + clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.NEWBIE: clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer @@ -1506,6 +1513,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): logging.warning("{} {}".format(diffusers_keys[k], k)) offload_device = model_management.unet_offload_device() + logging.debug(f"loader load model to offload device: {offload_device}") unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.quant_config is not None: weight_dtype = None diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1888f35baf31..ee9a7900190d 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -836,6 +836,21 @@ def clip_target(self, state_dict={}): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect)) +class LTXAV(LTXV): + unet_config = { + "image_model": "ltxav", + } + + latent_format = latent_formats.LTXAV + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = 0.055 # TODO + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.LTXAV(self, device=device) + return out + class HunyuanVideo(supported_models_base.BASE): unet_config = { "image_model": "hunyuan_video", @@ -1536,6 +1551,6 @@ def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] models += [SVD_img2vid] diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index 3dfe1e4d482d..0e5f9a378b58 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -154,7 +154,8 @@ def show_progress_bar(self, value): self._show_progress_bar = value def encode(self, x, **kwargs): - if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size) + if self.patch_size > 1: + x = F.pixel_unshuffle(x, self.patch_size) x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] if x.shape[1] % 4 != 0: # pad at end to multiple of 4 @@ -167,5 +168,6 @@ def encode(self, x, **kwargs): def decode(self, x, **kwargs): x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) - if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size) + if self.patch_size > 1: + x = F.pixel_shuffle(x, self.patch_size) return x[:, self.frames_to_trim:].movedim(2, 1) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index ed29e014d636..76731576b461 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -7,8 +7,8 @@ from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit +import comfy.clip_model -import comfy.model_management from . import qwen_vl @dataclass @@ -189,6 +189,31 @@ class Gemma3_4B_Config: rope_scale = [8.0, 1.0] final_norm: bool = True +@dataclass +class Gemma3_12B_Config: + vocab_size: int = 262208 + hidden_size: int = 3840 + intermediate_size: int = 15360 + num_hidden_layers: int = 48 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 131072 + rms_norm_eps: float = 1e-6 + rope_theta = [1000000.0, 10000.0] + transformer_type: str = "gemma3" + head_dim = 256 + rms_norm_add = True + mlp_activation = "gelu_pytorch_tanh" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + rope_scale = [8.0, 1.0] + final_norm: bool = True + vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} + mm_tokens_per_image = 256 + class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): super().__init__() @@ -521,6 +546,41 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed return x, intermediate + +class Gemma3MultiModalProjector(torch.nn.Module): + def __init__(self, config, dtype, device, operations): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype) + ) + + self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + + self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"]) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype)) + return projected_vision_outputs.type_as(vision_outputs) + + class BaseLlama: def get_input_embeddings(self): return self.model.embed_tokens @@ -637,3 +697,21 @@ def __init__(self, config_dict, dtype, device, operations): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype + +class Gemma3_12B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Gemma3_12B_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations) + self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations) + self.dtype = dtype + self.image_size = config.vision_config["image_size"] + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True) + return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None + return None, None diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 48ea67e67823..2c2d453e89c9 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -1,7 +1,11 @@ from comfy import sd1_clip import os from transformers import T5TokenizerFast +from .spiece_tokenizer import SPieceTokenizer import comfy.text_encoders.genmo +from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector +import torch +import comfy.utils class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -16,3 +20,110 @@ def __init__(self, embedding_directory=None, tokenizer_data={}): def ltxv_te(*args, **kwargs): return comfy.text_encoders.genmo.mochi_te(*args, **kwargs) + + +class Gemma3_12BTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + +class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer) + +class Gemma3_12BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_scaled_fp8 = model_options.get("gemma_scaled_fp8", None) + if llama_scaled_fp8 is not None: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs): + text = llama_template.format(text) + text_tokens = super().tokenize_with_weights(text, return_word_ids) + embed_count = 0 + for k in text_tokens: + tt = text_tokens[k] + for r in tt: + for i in range(len(r)): + if r[i][0] == 262144: + if image_embeds is not None and embed_count < image_embeds.shape[0]: + r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + return text_tokens + +class LTXAVTEModel(torch.nn.Module): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + super().__init__() + self.dtypes = set() + self.dtypes.add(dtype) + + self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None) + self.dtypes.add(dtype_llama) + + operations = self.gemma3_12b.operations # TODO + self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + + self.audio_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + operations=operations, + ) + + self.video_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + operations=operations, + ) + + def set_clip_options(self, options): + self.gemma3_12b.set_clip_options(options) + + def reset_clip_options(self): + self.gemma3_12b.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs = token_weight_pairs["gemma3_12b"] + + out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) + out_device = out.device + out = out.movedim(1, -1).to(self.text_embedding_projection.weight.device) + out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) + out = out.reshape((out.shape[0], out.shape[1], -1)) + out = self.text_embedding_projection(out) + out_vid = self.video_embeddings_connector(out)[0] + out_audio = self.audio_embeddings_connector(out)[0] + out = torch.concat((out_vid, out_audio), dim=-1) + + return out.to(out_device), pooled + + def load_sd(self, sd): + if "model.layers.47.self_attn.q_norm.weight" in sd: + return self.gemma3_12b.load_sd(sd) + else: + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) + if len(sdo) == 0: + sdo = sd + + return self.load_state_dict(sdo, strict=False) + + +def ltxav_te(dtype_llama=None, llama_scaled_fp8=None): + class LTXAVTEModel_(LTXAVTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["llama_scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return LTXAVTEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index 8d4e2b445a39..84467b3bed50 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -61,6 +61,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: + if not DISABLE_MMAP: + logging.debug(f"load_torch_file of safetensors into mmap True") with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: sd = {} for k in f.keys(): @@ -81,6 +83,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): else: torch_args = {} if MMAP_TORCH_FILES: + logging.debug(f"load_torch_file of torch state dict into mmap True") torch_args["mmap"] = True if safe_load or ALWAYS_SAFE_LOAD: @@ -1198,7 +1201,7 @@ def unpack_latents(combined_latent, latent_shapes): combined_latent = combined_latent[:, :, cut:] output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) else: - output_tensors = combined_latent + output_tensors = [combined_latent] return output_tensors def detect_layer_quantization(state_dict, prefix): @@ -1230,6 +1233,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): out_sd = {} layers = {} for k in list(state_dict.keys()): + if k == scaled_fp8_key: + continue if not k.startswith(model_prefix): out_sd[k] = state_dict[k] continue diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index fab63c7dfe3b..b0fa14ff6749 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -10,7 +10,6 @@ from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from . import _io_public as io from . import _ui_public as ui -# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 from comfy_execution.utils import get_executing_context from comfy_execution.progress import get_progress_state, PreviewImageTuple from PIL import Image diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4b14e5ded38b..764fa8b2b4da 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -26,11 +26,9 @@ from comfy_api.input import VideoInput from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) -from ._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL +from ._util import MESH, VOXEL, SVG as _SVG -# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference class FolderType(str, Enum): input = "input" @@ -77,16 +75,6 @@ class NumberDisplay(str, Enum): slider = "slider" -class _StringIOType(str): - def __ne__(self, value: object) -> bool: - if self == "*" or value == "*": - return False - if not isinstance(value, str): - return True - a = frozenset(self.split(",")) - b = frozenset(value.split(",")) - return not (b.issubset(a) or a.issubset(b)) - class _ComfyType(ABC): Type = Any io_type: str = None @@ -126,8 +114,7 @@ def decorator(cls: T) -> T: new_cls.__module__ = cls.__module__ new_cls.__doc__ = cls.__doc__ # assign ComfyType attributes, if needed - # NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details) - new_cls.io_type = _StringIOType(io_type) + new_cls.io_type = io_type if hasattr(new_cls, "Input") and new_cls.Input is not None: new_cls.Input.Parent = new_cls if hasattr(new_cls, "Output") and new_cls.Output is not None: @@ -166,7 +153,7 @@ class Input(_IO_V3): ''' Base class for a V3 Input. ''' - def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): super().__init__() self.id = id self.display_name = display_name @@ -174,6 +161,7 @@ def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str self.tooltip = tooltip self.lazy = lazy self.extra_dict = extra_dict if extra_dict is not None else {} + self.rawLink = raw_link def as_dict(self): return prune_dict({ @@ -181,10 +169,11 @@ def as_dict(self): "optional": self.optional, "tooltip": self.tooltip, "lazy": self.lazy, + "rawLink": self.rawLink, }) | prune_dict(self.extra_dict) def get_io_type(self): - return _StringIOType(self.io_type) + return self.io_type def get_all(self) -> list[Input]: return [self] @@ -195,8 +184,8 @@ class WidgetInput(Input): ''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: Any=None, - socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) self.default = default self.socketless = socketless self.widget_type = widget_type @@ -218,13 +207,14 @@ class Output(_IO_V3): def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, is_output_list=False): self.id = id - self.display_name = display_name + self.display_name = display_name if display_name else id self.tooltip = tooltip self.is_output_list = is_output_list def as_dict(self): + display_name = self.display_name if self.display_name else self.id return prune_dict({ - "display_name": self.display_name, + "display_name": display_name, "tooltip": self.tooltip, "is_output_list": self.is_output_list, }) @@ -252,8 +242,8 @@ class Input(WidgetInput): '''Boolean input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, - socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.label_on = label_on self.label_off = label_off self.default: bool @@ -272,8 +262,8 @@ class Input(WidgetInput): '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.min = min self.max = max self.step = step @@ -298,8 +288,8 @@ class Input(WidgetInput): '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.min = min self.max = max self.step = step @@ -324,8 +314,8 @@ class Input(WidgetInput): '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, - socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) self.multiline = multiline self.placeholder = placeholder self.dynamic_prompts = dynamic_prompts @@ -358,12 +348,14 @@ def __init__( image_folder: FolderType=None, remote: RemoteOptions=None, socketless: bool=None, + extra_dict=None, + raw_link: bool=None, ): if isinstance(options, type) and issubclass(options, Enum): options = [v.value for v in options] if isinstance(default, Enum): default = default.value - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) self.multiselect = False self.options = options self.control_after_generate = control_after_generate @@ -387,10 +379,6 @@ def __init__(self, id: str=None, display_name: str=None, options: list[str]=None super().__init__(id, display_name, tooltip, is_output_list) self.options = options if options is not None else [] - @property - def io_type(self): - return self.options - @comfytype(io_type="COMBO") class MultiCombo(ComfyTypeI): '''Multiselect Combo input (dropdown for selecting potentially more than one value).''' @@ -399,8 +387,8 @@ class MultiCombo(ComfyTypeI): class Input(Combo.Input): def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, - socketless: bool=None): - super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless) + socketless: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link) self.multiselect = True self.placeholder = placeholder self.chip = chip @@ -433,9 +421,9 @@ class Input(WidgetInput): Type = str def __init__( self, id: str, display_name: str=None, optional=False, - tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None + tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None ): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) @comfytype(io_type="MASK") @@ -656,7 +644,7 @@ class Video(ComfyTypeIO): @comfytype(io_type="SVG") class SVG(ComfyTypeIO): - Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3 + Type = _SVG @comfytype(io_type="LORA_MODEL") class LoraModel(ComfyTypeIO): @@ -788,7 +776,7 @@ class Input(Input): ''' Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. ''' - def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): # if id is an Input, then use that Input with overridden values self.input_override = None if isinstance(id, Input): @@ -801,7 +789,7 @@ def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], # if is a widget input, make sure widget_type is set appropriately if isinstance(self.input_override, WidgetInput): self.input_override.widget_type = self.input_override.get_io_type() - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) self._io_types = types @property @@ -855,8 +843,8 @@ def as_dict(self): class Input(Input): def __init__(self, id: str, template: MatchType.Template, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) self.template = template def as_dict(self): @@ -867,6 +855,8 @@ def as_dict(self): class Output(Output): def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, is_output_list=False): + if not id and not display_name: + display_name = "MATCHTYPE" super().__init__(id, display_name, tooltip, is_output_list) self.template = template @@ -879,24 +869,30 @@ class DynamicInput(Input, ABC): ''' Abstract class for dynamic input registration. ''' - def get_dynamic(self) -> list[Input]: - return [] - - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - pass + pass class DynamicOutput(Output, ABC): ''' Abstract class for dynamic output registration. ''' - def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, - is_output_list=False): - super().__init__(id, display_name, tooltip, is_output_list) + pass - def get_dynamic(self) -> list[Output]: - return [] +def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]: + if prefix_list is None: + prefix_list = [] + if id is not None: + prefix_list = prefix_list + [id] + return prefix_list + +def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str: + assert not (prefix_list is None and id is None) + if prefix_list is None: + return id + elif id is not None: + prefix_list = prefix_list + [id] + return ".".join(prefix_list) @comfytype(io_type="COMFY_AUTOGROW_V3") class Autogrow(ComfyTypeI): @@ -933,14 +929,6 @@ def as_dict(self): def validate(self): self.input.validate() - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - real_inputs = [] - for name, input in self.cached_inputs.items(): - if name in live_inputs: - real_inputs.append(input) - add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, real_inputs, curr_prefix) - class TemplatePrefix(_AutogrowTemplate): def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): super().__init__(input) @@ -985,22 +973,45 @@ def as_dict(self): "template": self.template.as_dict(), }) - def get_dynamic(self) -> list[Input]: - return self.template.get_all() - def get_all(self) -> list[Input]: return [self] + self.template.get_all() def validate(self): self.template.validate() - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - curr_prefix = f"{curr_prefix}{self.id}." - # need to remove self from expected inputs dictionary; replaced by template inputs in frontend - for inner_dict in d.values(): - if self.id in inner_dict: - del inner_dict[self.id] - self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + # NOTE: purposely do not include self in out_dict; instead use only the template inputs + # need to figure out names based on template type + is_names = ("names" in value[1]["template"]) + is_prefix = ("prefix" in value[1]["template"]) + input = value[1]["template"]["input"] + if is_names: + min = value[1]["template"]["min"] + names = value[1]["template"]["names"] + max = len(names) + elif is_prefix: + prefix = value[1]["template"]["prefix"] + min = value[1]["template"]["min"] + max = value[1]["template"]["max"] + names = [f"{prefix}{i}" for i in range(max)] + # need to create a new input based on the contents of input + template_input = None + for _, dict_input in input.items(): + # for now, get just the first value from dict_input + template_input = list(dict_input.values())[0] + new_dict = {} + for i, name in enumerate(names): + expected_id = finalize_prefix(curr_prefix, name) + if expected_id in live_inputs: + # required + if i < min: + type_dict = new_dict.setdefault("required", {}) + # optional + else: + type_dict = new_dict.setdefault("optional", {}) + type_dict[name] = template_input + parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix) @comfytype(io_type="COMFY_DYNAMICCOMBO_V3") class DynamicCombo(ComfyTypeI): @@ -1023,23 +1034,6 @@ def __init__(self, id: str, options: list[DynamicCombo.Option], super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) self.options = options - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - # check if dynamic input's id is in live_inputs - if self.id in live_inputs: - curr_prefix = f"{curr_prefix}{self.id}." - key = live_inputs[self.id] - selected_option = None - for option in self.options: - if option.key == key: - selected_option = option - break - if selected_option is not None: - add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self) - - def get_dynamic(self) -> list[Input]: - return [input for option in self.options for input in option.inputs] - def get_all(self) -> list[Input]: return [self] + [input for option in self.options for input in option.inputs] @@ -1054,6 +1048,24 @@ def validate(self): for input in option.inputs: input.validate() + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + finalized_id = finalize_prefix(curr_prefix) + if finalized_id in live_inputs: + key = live_inputs[finalized_id] + selected_option = None + # get options from dict + options: list[dict[str, str | dict[str, Any]]] = value[1]["options"] + for option in options: + if option["key"] == key: + selected_option = option + break + if selected_option is not None: + parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix) + # add self to inputs + out_dict[input_type][finalized_id] = value + out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) + @comfytype(io_type="COMFY_DYNAMICSLOT_V3") class DynamicSlot(ComfyTypeI): Type = dict[str, Any] @@ -1076,17 +1088,8 @@ def __init__(self, slot: Input, inputs: list[Input], self.force_input = True self.slot.force_input = True - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - if self.id in live_inputs: - curr_prefix = f"{curr_prefix}{self.id}." - add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix) - - def get_dynamic(self) -> list[Input]: - return [self.slot] + self.inputs - def get_all(self) -> list[Input]: - return [self] + [self.slot] + self.inputs + return [self.slot] + self.inputs def as_dict(self): return super().as_dict() | prune_dict({ @@ -1100,17 +1103,41 @@ def validate(self): for input in self.inputs: input.validate() -def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): - dynamic = d.setdefault("dynamic_paths", {}) - if self is not None: - dynamic[self.id] = f"{curr_prefix}{self.id}" - for i in inputs: - if not isinstance(i, DynamicInput): - dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + finalized_id = finalize_prefix(curr_prefix) + if finalized_id in live_inputs: + inputs = value[1]["inputs"] + parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix) + # add self to inputs + out_dict[input_type][finalized_id] = value + out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) + +DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} +def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): + DYNAMIC_INPUT_LOOKUP[io_type] = func + +def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]: + return DYNAMIC_INPUT_LOOKUP[io_type] + +def setup_dynamic_input_funcs(): + # Autogrow.Input + register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic) + # DynamicCombo.Input + register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic) + # DynamicSlot.Input + register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic) + +if len(DYNAMIC_INPUT_LOOKUP) == 0: + setup_dynamic_input_funcs() class V3Data(TypedDict): hidden_inputs: dict[str, Any] + 'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.' dynamic_paths: dict[str, Any] + 'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.' + create_dynamic_tuple: bool + 'When True, the value of the dynamic input will be in the format (value, path_key).' class HiddenHolder: def __init__(self, unique_id: str, prompt: Any, @@ -1146,6 +1173,10 @@ def from_dict(cls, d: dict | None): api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), ) + @classmethod + def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder: + return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None) + class Hidden(str, Enum): ''' Enumerator for requesting hidden variables in nodes. @@ -1251,61 +1282,56 @@ def validate(self): - verify ids on inputs and outputs are unique - both internally and in relation to each other ''' nested_inputs: list[Input] = [] - if self.inputs is not None: - for input in self.inputs: + for input in self.inputs: + if not isinstance(input, DynamicInput): nested_inputs.extend(input.get_all()) - input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] - output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] + input_ids = [i.id for i in nested_inputs] + output_ids = [o.id for o in self.outputs] input_set = set(input_ids) output_set = set(output_ids) - issues = [] + issues: list[str] = [] # verify ids are unique per list if len(input_set) != len(input_ids): issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.") if len(output_set) != len(output_ids): issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.") - # verify ids are unique between lists - intersection = input_set & output_set - if len(intersection) > 0: - issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") if len(issues) > 0: raise ValueError("\n".join(issues)) # validate inputs and outputs - if self.inputs is not None: - for input in self.inputs: - input.validate() - if self.outputs is not None: - for output in self.outputs: - output.validate() + for input in self.inputs: + input.validate() + for output in self.outputs: + output.validate() def finalize(self): """Add hidden based on selected schema options, and give outputs without ids default ids.""" + # ensure inputs, outputs, and hidden are lists + if self.inputs is None: + self.inputs = [] + if self.outputs is None: + self.outputs = [] + if self.hidden is None: + self.hidden = [] # if is an api_node, will need key-related hidden if self.is_api_node: - if self.hidden is None: - self.hidden = [] if Hidden.auth_token_comfy_org not in self.hidden: self.hidden.append(Hidden.auth_token_comfy_org) if Hidden.api_key_comfy_org not in self.hidden: self.hidden.append(Hidden.api_key_comfy_org) # if is an output_node, will need prompt and extra_pnginfo if self.is_output_node: - if self.hidden is None: - self.hidden = [] if Hidden.prompt not in self.hidden: self.hidden.append(Hidden.prompt) if Hidden.extra_pnginfo not in self.hidden: self.hidden.append(Hidden.extra_pnginfo) # give outputs without ids default ids - if self.outputs is not None: - for i, output in enumerate(self.outputs): - if output.id is None: - output.id = f"_{i}_{output.io_type}_" + for i, output in enumerate(self.outputs): + if output.id is None: + output.id = f"_{i}_{output.io_type}_" - def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1: - # NOTE: live_inputs will not be used anymore very soon and this will be done another way + def get_v1_info(self, cls) -> NodeInfoV1: # get V1 inputs - input = create_input_dict_v1(self.inputs, live_inputs) + input = create_input_dict_v1(self.inputs) if self.hidden: for hidden in self.hidden: input.setdefault("hidden", {})[hidden.name] = (hidden.value,) @@ -1385,33 +1411,54 @@ def get_v3_info(self, cls) -> NodeInfoV3: ) return info +def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], V3Data]: + out_dict = { + "required": {}, + "optional": {}, + "dynamic_paths": {}, + } + d = d.copy() + # ignore hidden for parsing + hidden = d.pop("hidden", None) + parse_class_inputs(out_dict, live_inputs, d) + if hidden is not None and include_hidden: + out_dict["hidden"] = hidden + v3_data = {} + dynamic_paths = out_dict.pop("dynamic_paths", None) + if dynamic_paths is not None: + v3_data["dynamic_paths"] = dynamic_paths + return out_dict, hidden, v3_data + +def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None: + for input_type, inner_d in curr_dict.items(): + for id, value in inner_d.items(): + io_type = value[0] + if io_type in DYNAMIC_INPUT_LOOKUP: + # dynamic inputs need to be handled with lookup functions + dynamic_input_func = get_dynamic_input_func(io_type) + new_prefix = handle_prefix(curr_prefix, id) + dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix) + else: + # non-dynamic inputs get directly transferred + finalized_id = finalize_prefix(curr_prefix, id) + out_dict[input_type][finalized_id] = value + if curr_prefix: + out_dict["dynamic_paths"][finalized_id] = finalized_id -def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: +def create_input_dict_v1(inputs: list[Input]) -> dict: input = { "required": {} } - add_to_input_dict_v1(input, inputs, live_inputs) - return input - -def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): for i in inputs: - if isinstance(i, DynamicInput): - add_to_dict_v1(i, d) - if live_inputs is not None: - i.expand_schema_for_dynamic(d, live_inputs, curr_prefix) - else: - add_to_dict_v1(i, d) + add_to_dict_v1(i, input) + return input -def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None): +def add_to_dict_v1(i: Input, d: dict): key = "optional" if i.optional else "required" as_dict = i.as_dict() # for v1, we don't want to include the optional key as_dict.pop("optional", None) - if dynamic_dict is None: - value = (i.get_io_type(), as_dict) - else: - value = (i.get_io_type(), as_dict, dynamic_dict) - d.setdefault(key, {})[i.id] = value + d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) def add_to_dict_v3(io: Input | Output, d: dict): d[io.id] = (io.get_io_type(), io.as_dict()) @@ -1423,6 +1470,8 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): values = values.copy() result = {} + create_tuple = v3_data.get("create_dynamic_tuple", False) + for key, path in paths.items(): parts = path.split(".") current = result @@ -1431,7 +1480,10 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): is_last = (i == len(parts) - 1) if is_last: - current[p] = values.pop(key, None) + value = values.pop(key, None) + if create_tuple: + value = (value, key) + current[p] = value else: current = current.setdefault(p, {}) @@ -1446,7 +1498,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): SCHEMA = None # filled in during execution - resources: Resources = None hidden: HiddenHolder = None @classmethod @@ -1493,7 +1544,6 @@ def check_lazy_status(cls, **kwargs) -> list[str]: return [name for name in kwargs if kwargs[name] is None] def __init__(self): - self.local_resources: ResourcesLocal = None self.__class__.VALIDATE_CLASS() @classmethod @@ -1561,7 +1611,7 @@ def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]: c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None) + type_clone.hidden = HiddenHolder.from_v3_data(v3_data) return type_clone @final @@ -1678,19 +1728,10 @@ def NOT_IDEMPOTENT(cls): # noqa @final @classmethod - def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: + def INPUT_TYPES(cls) -> dict[str, dict]: schema = cls.FINALIZE_SCHEMA() - info = schema.get_v1_info(cls, live_inputs) - input = info.input - if not include_hidden: - input.pop("hidden", None) - if return_schema: - v3_data: V3Data = {} - dynamic = input.pop("dynamic_paths", None) - if dynamic is not None: - v3_data["dynamic_paths"] = dynamic - return input, schema, v3_data - return input + info = schema.get_v1_info(cls) + return info.input @final @classmethod @@ -1809,7 +1850,7 @@ def result(self): return self.args if len(self.args) > 0 else None @classmethod - def from_dict(cls, data: dict[str, Any]) -> "NodeOutput": + def from_dict(cls, data: dict[str, Any]) -> NodeOutput: args = () ui = None expand = None @@ -1904,8 +1945,8 @@ def as_dict(self) -> dict: "Tracks", # Dynamic Types "MatchType", - # "DynamicCombo", - # "Autogrow", + "DynamicCombo", + "Autogrow", # Other classes "HiddenHolder", "Hidden", diff --git a/comfy_api/latest/_resources.py b/comfy_api/latest/_resources.py deleted file mode 100644 index a6bdda97204e..000000000000 --- a/comfy_api/latest/_resources.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations -import comfy.utils -import folder_paths -import logging -from abc import ABC, abstractmethod -from typing import Any -import torch - -class ResourceKey(ABC): - Type = Any - def __init__(self): - ... - -class TorchDictFolderFilename(ResourceKey): - '''Key for requesting a torch file via file_name from a folder category.''' - Type = dict[str, torch.Tensor] - def __init__(self, folder_name: str, file_name: str): - self.folder_name = folder_name - self.file_name = file_name - - def __hash__(self): - return hash((self.folder_name, self.file_name)) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, TorchDictFolderFilename): - return False - return self.folder_name == other.folder_name and self.file_name == other.file_name - - def __str__(self): - return f"{self.folder_name} -> {self.file_name}" - -class Resources(ABC): - def __init__(self): - ... - - @abstractmethod - def get(self, key: ResourceKey, default: Any=...) -> Any: - pass - -class ResourcesLocal(Resources): - def __init__(self): - super().__init__() - self.local_resources: dict[ResourceKey, Any] = {} - - def get(self, key: ResourceKey, default: Any=...) -> Any: - cached = self.local_resources.get(key, None) - if cached is not None: - logging.info(f"Using cached resource '{key}'") - return cached - logging.info(f"Loading resource '{key}'") - to_return = None - if isinstance(key, TorchDictFolderFilename): - if default is ...: - to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True) - else: - full_path = folder_paths.get_full_path(key.folder_name, key.file_name) - if full_path is not None: - to_return = comfy.utils.load_torch_file(full_path, safe_load=True) - - if to_return is not None: - self.local_resources[key] = to_return - return to_return - if default is not ...: - return default - raise Exception(f"Unsupported resource key type: {type(key)}") - - -class _RESOURCES: - ResourceKey = ResourceKey - TorchDictFolderFilename = TorchDictFolderFilename - Resources = Resources - ResourcesLocal = ResourcesLocal diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index fc5431dda5bf..6313eb01bcb5 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,6 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents from .geometry_types import VOXEL, MESH +from .image_types import SVG __all__ = [ # Utility Types @@ -8,4 +9,5 @@ "VideoComponents", "VOXEL", "MESH", + "SVG", ] diff --git a/comfy_api/latest/_util/image_types.py b/comfy_api/latest/_util/image_types.py new file mode 100644 index 000000000000..f031ed42615c --- /dev/null +++ b/comfy_api/latest/_util/image_types.py @@ -0,0 +1,18 @@ +from io import BytesIO + + +class SVG: + """Stores SVG representations via a list of BytesIO objects.""" + + def __init__(self, data: list[BytesIO]): + self.data = data + + def combine(self, other: 'SVG') -> 'SVG': + return SVG(self.data + other.data) + + @staticmethod + def combine_all(svgs: list['SVG']) -> 'SVG': + all_svgs_list: list[BytesIO] = [] + for svg_item in svgs: + all_svgs_list.extend(svg_item.data) + return SVG(all_svgs_list) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index f8edc38c936e..d81337dae73d 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -133,6 +133,7 @@ class GeminiImageGenerateContentRequest(BaseModel): systemInstruction: GeminiSystemInstructionContent | None = Field(None) tools: list[GeminiTool] | None = Field(None) videoMetadata: GeminiVideoMetadata | None = Field(None) + uploadImagesToStorage: bool = Field(True) class GeminiGenerateContentRequest(BaseModel): diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py index 80a7584660d0..bf54ede3ebc5 100644 --- a/comfy_api_nodes/apis/kling_api.py +++ b/comfy_api_nodes/apis/kling_api.py @@ -102,3 +102,12 @@ class ImageToVideoWithAudioRequest(BaseModel): prompt: str = Field(...) mode: str = Field("pro") sound: str = Field(..., description="'on' or 'off'") + + +class MotionControlRequest(BaseModel): + prompt: str = Field(...) + image_url: str = Field(...) + video_url: str = Field(...) + keep_original_sound: str = Field(...) + character_orientation: str = Field(...) + mode: str = Field(..., description="'pro' or 'std'") diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 636cc1265780..d4a2cfae681a 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -229,6 +229,7 @@ def define_schema(cls): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -269,7 +270,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ByteDanceSeedreamNode", - display_name="ByteDance Seedream 4", + display_name="ByteDance Seedream 4.5", category="api node/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index ad0f4b4d1a70..e8ed7e797218 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -34,6 +34,7 @@ ApiEndpoint, audio_to_base64_string, bytesio_to_image_tensor, + download_url_to_image_tensor, get_number_of_images, sync_op, tensor_to_base64_string, @@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera ) parts = [] for part in response.candidates[0].content.parts: - if part_type == "text" and hasattr(part, "text") and part.text: + if part_type == "text" and part.text: parts.append(part) - elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: + elif part.inlineData and part.inlineData.mimeType == part_type: + parts.append(part) + elif part.fileData and part.fileData.mimeType == part_type: parts.append(part) # Skip parts that don't match the requested type return parts @@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: +async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/png") for part in parts: - image_data = base64.b64decode(part.inlineData.data) - returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + if part.inlineData: + image_data = base64.b64decode(part.inlineData.data) + returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + else: + returned_image = await download_url_to_image_tensor(part.fileData.fileUri) image_tensors.append(returned_image) if len(image_tensors) == 0: return torch.zeros((1, 1024, 1024, 4)) @@ -596,7 +602,7 @@ async def execute( response = await sync_op( cls, - endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -610,7 +616,7 @@ async def execute( response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiImage2(IO.ComfyNode): @@ -729,7 +735,7 @@ async def execute( response = await sync_op( cls, - ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -743,7 +749,7 @@ async def execute( response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 5294b10d4d7e..9c707a339f60 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -51,6 +51,7 @@ ) from comfy_api_nodes.apis.kling_api import ( ImageToVideoWithAudioRequest, + MotionControlRequest, OmniImageParamImage, OmniParamImage, OmniParamVideo, @@ -806,6 +807,7 @@ def define_schema(cls) -> IO.Schema: ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input("duration", options=[5, 10]), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -825,6 +827,7 @@ async def execute( prompt: str, aspect_ratio: str, duration: int, + resolution: str = "1080p", ) -> IO.NodeOutput: validate_string(prompt, min_length=1, max_length=2500) response = await sync_op( @@ -836,6 +839,7 @@ async def execute( prompt=prompt, aspect_ratio=aspect_ratio, duration=str(duration), + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -871,6 +875,7 @@ def define_schema(cls) -> IO.Schema: optional=True, tooltip="Up to 6 additional reference images.", ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -892,6 +897,7 @@ async def execute( first_frame: Input.Image, end_frame: Input.Image | None = None, reference_images: Input.Image | None = None, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -935,6 +941,7 @@ async def execute( prompt=prompt, duration=str(duration), image_list=image_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -963,6 +970,7 @@ def define_schema(cls) -> IO.Schema: "reference_images", tooltip="Up to 7 reference images.", ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -983,6 +991,7 @@ async def execute( aspect_ratio: str, duration: int, reference_images: Input.Image, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -1004,6 +1013,7 @@ async def execute( aspect_ratio=aspect_ratio, duration=str(duration), image_list=image_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -1035,6 +1045,7 @@ def define_schema(cls) -> IO.Schema: tooltip="Up to 4 additional reference images.", optional=True, ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -1057,6 +1068,7 @@ async def execute( reference_video: Input.Video, keep_original_sound: bool, reference_images: Input.Image | None = None, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -1089,6 +1101,7 @@ async def execute( duration=str(duration), image_list=image_list if image_list else None, video_list=video_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -1118,6 +1131,7 @@ def define_schema(cls) -> IO.Schema: tooltip="Up to 4 additional reference images.", optional=True, ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), ], outputs=[ IO.Video.Output(), @@ -1138,6 +1152,7 @@ async def execute( video: Input.Video, keep_original_sound: bool, reference_images: Input.Image | None = None, + resolution: str = "1080p", ) -> IO.NodeOutput: prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) @@ -1170,6 +1185,7 @@ async def execute( duration=None, image_list=image_list if image_list else None, video_list=video_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -2163,6 +2179,91 @@ async def execute( return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) +class MotionControl(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingMotionControl", + display_name="Kling Motion Control", + category="api node/video/Kling", + inputs=[ + IO.String.Input("prompt", multiline=True), + IO.Image.Input("reference_image"), + IO.Video.Input( + "reference_video", + tooltip="Motion reference video used to drive movement/expression.\n" + "Duration limits depend on character_orientation:\n" + " - image: 3–10s (max 10s)\n" + " - video: 3–30s (max 30s)", + ), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Combo.Input( + "character_orientation", + options=["video", "image"], + tooltip="Controls where the character's facing/orientation comes from.\n" + "video: movements, expressions, camera moves, and orientation " + "follow the motion reference video (other details via prompt).\n" + "image: movements and expressions still follow the motion reference video, " + "but the character orientation matches the reference image (camera/other details via prompt).", + ), + IO.Combo.Input("mode", options=["pro", "std"]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + reference_image: Input.Image, + reference_video: Input.Video, + keep_original_sound: bool, + character_orientation: str, + mode: str, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2500) + validate_image_dimensions(reference_image, min_width=340, min_height=340) + validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1)) + if character_orientation == "image": + validate_video_duration(reference_video, min_duration=3, max_duration=10) + else: + validate_video_duration(reference_video, min_duration=3, max_duration=30) + validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"), + response_model=TaskStatusResponse, + data=MotionControlRequest( + prompt=prompt, + image_url=(await upload_images_to_comfyapi(cls, reference_image))[0], + video_url=await upload_video_to_comfyapi(cls, reference_video), + keep_original_sound="yes" if keep_original_sound else "no", + character_orientation=character_orientation, + mode=mode, + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + class KlingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -2188,6 +2289,7 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: OmniProImageNode, TextToVideoWithAudio, ImageToVideoWithAudio, + MotionControl, ] diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index bd3c24fb393d..e72f8e96a781 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -155,7 +155,7 @@ async def execute( model_seed=model_seed, texture_seed=texture_seed, texture_quality=texture_quality, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, geometry_quality=geometry_quality, auto_size=True, quad=quad, @@ -255,7 +255,7 @@ async def execute( texture_alignment=texture_alignment, texture_seed=texture_seed, texture_quality=texture_quality, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, auto_size=True, quad=quad, ), @@ -369,7 +369,7 @@ async def execute( texture_quality=texture_quality, geometry_quality=geometry_quality, texture_alignment=texture_alignment, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, quad=quad, ), ) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index e165b8380f94..13a6bfd91f24 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -168,6 +168,8 @@ async def execute( # Only add generateAudio for Veo 3 models if model.find("veo-2.0") == -1: parameters["generateAudio"] = generate_audio + # force "enhance_prompt" to True for Veo3 models + parameters["enhancePrompt"] = True initial_response = await sync_op( cls, @@ -291,7 +293,7 @@ def define_schema(cls): IO.Boolean.Input( "enhance_prompt", default=True, - tooltip="Whether to enhance the prompt with AI assistance", + tooltip="This parameter is deprecated and ignored.", optional=True, ), IO.Combo.Input( diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 491e6b6a8d8e..648defe3deba 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -1,16 +1,22 @@ import asyncio import contextlib import os +import re import time from collections.abc import Callable from io import BytesIO +from yarl import URL + from comfy.cli_args import args from comfy.model_management import processing_interrupted from comfy_api.latest import IO from .common_exceptions import ProcessingInterrupted +_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits +_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits + def is_processing_interrupted() -> bool: """Return True if user/runtime requested interruption.""" @@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int: if isinstance(path_or_object, str): return os.path.getsize(path_or_object) return len(path_or_object.getvalue()) + + +def to_aiohttp_url(url: str) -> URL: + """If `url` appears to be already percent-encoded (contains at least one valid %HH + escape and no malformed '%' sequences) and contains no raw whitespace/control + characters preserve the original encoding byte-for-byte (important for signed/presigned URLs). + Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed.""" + if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url): + # Avoid encoded=True if URL contains raw whitespace/control chars + return URL(url) + if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url): + # Preserve encoding only if it appears pre-encoded AND has no invalid % sequences + return URL(url, encoded=True) + return URL(url) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index bf37cba5ff98..f372ec7b5ed6 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -430,9 +430,9 @@ def _display_text( if status: display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") if price is not None: - p = f"{float(price):,.4f}".rstrip("0").rstrip(".") + p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".") if p != "0": - display_lines.append(f"Price: ${p}") + display_lines.append(f"Price: {p} credits") if text is not None: display_lines.append(text) if display_lines: diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 3e0d0352da2b..4668d14a9416 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -19,6 +19,7 @@ get_auth_header, is_processing_interrupted, sleep_with_interrupt, + to_aiohttp_url, ) from .client import _diagnose_connectivity from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted @@ -94,7 +95,7 @@ async def _monitor(): monitor_task = asyncio.create_task(_monitor()) - req_task = asyncio.create_task(session.get(url, headers=headers)) + req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers)) done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) if monitor_task in done and req_task in pending: diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 0d811e3546ef..9d170b16e140 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -97,6 +97,11 @@ def get_input_info( extra_info = input_info[1] else: extra_info = {} + # if input_type is a list, it is a Combo defined in outdated format; convert it. + # NOTE: uncomment this when we are confident old format going away won't cause too much trouble. + # if isinstance(input_type, list): + # extra_info["options"] = input_type + # input_type = IO.Combo.io_type return input_type, input_category, extra_info class TopologicalSort: @@ -202,15 +207,15 @@ def is_cached(self, node_id): return self.output_cache.get(node_id) is not None def cache_link(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: + if to_node_id not in self.execution_cache: self.execution_cache[to_node_id] = {} self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) - if not from_node_id in self.execution_cache_listeners: + if from_node_id not in self.execution_cache_listeners: self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) def get_cache(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: + if to_node_id not in self.execution_cache: return None value = self.execution_cache[to_node_id].get(from_node_id) if value is None: diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py index 24c0b4ed76cd..e73624bd1ef5 100644 --- a/comfy_execution/validation.py +++ b/comfy_execution/validation.py @@ -21,14 +21,24 @@ def validate_node_input( """ # If the types are exactly the same, we can return immediately # Use pre-union behaviour: inverse of `__ne__` + # NOTE: this lets legacy '*' Any types work that override the __ne__ method of the str class. if not received_type != input_type: return True + # If one of the types is '*', we can return True immediately; this is the 'Any' type. + if received_type == IO.AnyType.io_type or input_type == IO.AnyType.io_type: + return True + # If the received type or input_type is a MatchType, we can return True immediately; # validation for this is handled by the frontend if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: return True + # This accounts for some custom nodes that output lists of options as the type; + # if we ever want to break them on purpose, this can be removed + if isinstance(received_type, list) and input_type == IO.Combo.io_type: + return True + # Not equal, and not strings if not isinstance(received_type, str) or not isinstance(input_type, str): return False @@ -37,6 +47,10 @@ def validate_node_input( received_types = set(t.strip() for t in received_type.split(",")) input_types = set(t.strip() for t in input_type.split(",")) + # If any of the types is '*', we can return True immediately; this is the 'Any' type. + if IO.AnyType.io_type in received_types or IO.AnyType.io_type in input_types: + return True + if strict: # In strict mode, all received types must be in the input types return received_types.issubset(input_types) diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py index f27ae7da8ce7..b9df2dcc946a 100644 --- a/comfy_extras/nodes_apg.py +++ b/comfy_extras/nodes_apg.py @@ -55,7 +55,8 @@ def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput: def pre_cfg_function(args): nonlocal running_avg, prev_sigma - if len(args["conds_out"]) == 1: return args["conds_out"] + if len(args["conds_out"]) == 1: + return args["conds_out"] cond = args["conds_out"][0] uncond = args["conds_out"][1] diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index c7916443caed..94ad5e8a840d 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -112,7 +112,7 @@ def execute(cls, vae, samples) -> IO.NodeOutput: std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return IO.NodeOutput({"waveform": audio, "sample_rate": 44100}) + return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]}) decode = execute # TODO: remove diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 7ee4caac111f..f19adf4b9eb5 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -9,6 +9,7 @@ import node_helpers from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import re class BasicScheduler(io.ComfyNode): @@ -760,8 +761,12 @@ def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, out = latent.copy() out["samples"] = samples if "x0" in x0_output: + x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) @@ -948,8 +953,12 @@ def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput: out = latent.copy() out["samples"] = samples if "x0" in x0_output: + x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) @@ -1005,6 +1014,25 @@ def execute(cls, model, noise, sigmas, latent_image) -> io.NodeOutput: add_noise = execute +class ManualSigmas(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ManualSigmas", + category="_for_testing/custom_sampling", + is_experimental=True, + inputs=[ + io.String.Input("sigmas", default="1, 0.5", multiline=False) + ], + outputs=[io.Sigmas.Output()] + ) + + @classmethod + def execute(cls, sigmas) -> io.NodeOutput: + sigmas = re.findall(r"[-+]?(?:\d*\.*\d+)", sigmas) + sigmas = [float(i) for i in sigmas] + sigmas = torch.FloatTensor(sigmas) + return io.NodeOutput(sigmas) class CustomSamplersExtension(ComfyExtension): @override @@ -1044,6 +1072,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: DisableNoise, AddNoise, SamplerCustomAdvanced, + ManualSigmas, ] diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 513aecf3a941..5ef851bd0387 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -667,16 +667,19 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode): @classmethod def _process(cls, image, longer_edge): - img = tensor_to_pil(image) - w, h = img.size - if w > h: - new_w = longer_edge - new_h = int(h * (longer_edge / w)) - else: - new_h = longer_edge - new_w = int(w * (longer_edge / h)) - img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) - return pil_to_tensor(img) + resized_images = [] + for image_i in image: + img = tensor_to_pil(image_i) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + resized_images.append(pil_to_tensor(img)) + return torch.cat(resized_images, dim=0) class CenterCropImagesNode(ImageProcessingNode): diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 32be182f1575..ceff657d35f9 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -5,7 +5,9 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, io from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel +from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler import folder_paths +import json class CLIPTextEncodeHunyuanDiT(io.ComfyNode): @classmethod @@ -186,7 +188,7 @@ def define_schema(cls): @classmethod def execute(cls, model_name) -> io.NodeOutput: model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name) - sd = comfy.utils.load_torch_file(model_path, safe_load=True) + sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True) if "blocks.0.block.0.conv.weight" in sd: config = { @@ -197,6 +199,8 @@ def execute(cls, model_name) -> io.NodeOutput: "global_residual": False, } model_type = "720p" + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) elif "up.0.block.0.conv1.conv.weight" in sd: sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()} config = { @@ -205,9 +209,12 @@ def execute(cls, model_name) -> io.NodeOutput: "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))), } model_type = "1080p" - - model = HunyuanVideo15SRModel(model_type, config) - model.load_sd(sd) + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) + elif "post_upsample_res_blocks.0.conv2.bias" in sd: + config = json.loads(metadata["config"]) + model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32])) + model.load_state_dict(sd) return io.NodeOutput(model) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 392aea32c268..ce21caade8cf 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -2,280 +2,231 @@ import nodes import folder_paths -from comfy.cli_args import args -from PIL import Image -from PIL.PngImagePlugin import PngInfo - -import numpy as np import json import os import re -from io import BytesIO -from inspect import cleandoc import torch import comfy.utils -from comfy.comfy_types import FileLocator, IO from server import PromptServer +from comfy_api.latest import ComfyExtension, IO, UI +from typing_extensions import override + +SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later. MAX_RESOLUTION = nodes.MAX_RESOLUTION -class ImageCrop: +class ImageCrop(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageCrop", + display_name="Image Crop", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) + @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "crop" - - CATEGORY = "image/transform" - - def crop(self, image, width, height, x, y): + def execute(cls, image, width, height, x, y) -> IO.NodeOutput: x = min(x, image.shape[2] - 1) y = min(y, image.shape[1] - 1) to_x = width + x to_y = height + y img = image[:,y:to_y, x:to_x, :] - return (img,) + return IO.NodeOutput(img) -class RepeatImageBatch: - @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "amount": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "repeat" + crop = execute # TODO: remove - CATEGORY = "image/batch" - def repeat(self, image, amount): - s = image.repeat((amount, 1,1,1)) - return (s,) +class RepeatImageBatch(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RepeatImageBatch", + category="image/batch", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("amount", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) -class ImageFromBatch: @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), - "length": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "frombatch" + def execute(cls, image, amount) -> IO.NodeOutput: + s = image.repeat((amount, 1,1,1)) + return IO.NodeOutput(s) - CATEGORY = "image/batch" + repeat = execute # TODO: remove - def frombatch(self, image, batch_index, length): + +class ImageFromBatch(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageFromBatch", + category="image/batch", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("batch_index", default=0, min=0, max=4095), + IO.Int.Input("length", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, batch_index, length) -> IO.NodeOutput: s_in = image batch_index = min(s_in.shape[0] - 1, batch_index) length = min(s_in.shape[0] - batch_index, length) s = s_in[batch_index:batch_index + length].clone() - return (s,) + return IO.NodeOutput(s) + frombatch = execute # TODO: remove -class ImageAddNoise: - @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), - "strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "repeat" - CATEGORY = "image" +class ImageAddNoise(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageAddNoise", + category="image", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01), + ], + outputs=[IO.Image.Output()], + ) - def repeat(self, image, seed, strength): + @classmethod + def execute(cls, image, seed, strength) -> IO.NodeOutput: generator = torch.manual_seed(seed) s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) - return (s,) + return IO.NodeOutput(s) + + repeat = execute # TODO: remove + -class SaveAnimatedWEBP: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAnimatedWEBP(IO.ComfyNode): + COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6} - methods = {"default": 4, "fastest": 0, "slowest": 6} @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "lossless": ("BOOLEAN", {"default": True}), - "quality": ("INT", {"default": 80, "min": 0, "max": 100}), - "method": (list(s.methods.keys()),), - # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - RETURN_TYPES = () - FUNCTION = "save_images" - - OUTPUT_NODE = True - - CATEGORY = "image/animation" - - def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): - method = self.methods.get(method) - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results: list[FileLocator] = [] - pil_images = [] - for image in images: - i = 255. * image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pil_images.append(img) - - metadata = pil_images[0].getexif() - if not args.disable_metadata: - if prompt is not None: - metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) - if extra_pnginfo is not None: - inital_exif = 0x010f - for x in extra_pnginfo: - metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x])) - inital_exif -= 1 - - if num_frames == 0: - num_frames = len(pil_images) - - c = len(pil_images) - for i in range(0, c, num_frames): - file = f"{filename}_{counter:05}_.webp" - pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedWEBP", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Boolean.Input("lossless", default=True), + IO.Int.Input("quality", default=80, min=0, max=100), + IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())), + # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) - animated = num_frames != 1 - return { "ui": { "images": results, "animated": (animated,) } } + @classmethod + def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_webp_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + lossless=lossless, + quality=quality, + method=cls.COMPRESS_METHODS.get(method) + ) + ) + + save_images = execute # TODO: remove -class SaveAnimatedPNG: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" + +class SaveAnimatedPNG(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - RETURN_TYPES = () - FUNCTION = "save_images" - - OUTPUT_NODE = True - - CATEGORY = "image/animation" - - def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results = list() - pil_images = [] - for image in images: - i = 255. * image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pil_images.append(img) - - metadata = None - if not args.disable_metadata: - metadata = PngInfo() - if prompt is not None: - metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) - - file = f"{filename}_{counter:05}_.png" - pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - - return { "ui": { "images": results, "animated": (True,)} } - -class SVG: - """ - Stores SVG representations via a list of BytesIO objects. - """ - def __init__(self, data: list[BytesIO]): - self.data = data - - def combine(self, other: 'SVG') -> 'SVG': - return SVG(self.data + other.data) - - @staticmethod - def combine_all(svgs: list['SVG']) -> 'SVG': - all_svgs_list: list[BytesIO] = [] - for svg_item in svgs: - all_svgs_list.extend(svg_item.data) - return SVG(all_svgs_list) - - -class ImageStitch: + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedPNG", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Int.Input("compress_level", default=4, min=0, max=9), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_png_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + compress_level=compress_level, + ) + ) + + save_images = execute # TODO: remove + + +class ImageStitch(IO.ComfyNode): """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image1": ("IMAGE",), - "direction": (["right", "down", "left", "up"], {"default": "right"}), - "match_image_size": ("BOOLEAN", {"default": True}), - "spacing_width": ( - "INT", - {"default": 0, "min": 0, "max": 1024, "step": 2}, - ), - "spacing_color": ( - ["white", "black", "red", "green", "blue"], - {"default": "white"}, - ), - }, - "optional": { - "image2": ("IMAGE",), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="ImageStitch", + display_name="Image Stitch", + description="Stitches image2 to image1 in the specified direction.\n" + "If image2 is not provided, returns image1 unchanged.\n" + "Optional spacing can be added between images.", + category="image/transform", + inputs=[ + IO.Image.Input("image1"), + IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"), + IO.Boolean.Input("match_image_size", default=True), + IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2), + IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white"), + IO.Image.Input("image2", optional=True), + ], + outputs=[IO.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "stitch" - CATEGORY = "image/transform" - DESCRIPTION = """ -Stitches image2 to image1 in the specified direction. -If image2 is not provided, returns image1 unchanged. -Optional spacing can be added between images. -""" - - def stitch( - self, + @classmethod + def execute( + cls, image1, direction, match_image_size, spacing_width, spacing_color, image2=None, - ): + ) -> IO.NodeOutput: if image2 is None: - return (image1,) + return IO.NodeOutput(image1) # Handle batch size differences if image1.shape[0] != image2.shape[0]: @@ -412,36 +363,30 @@ def stitch( images.insert(1, spacing) concat_dim = 2 if direction in ["left", "right"] else 1 - return (torch.cat(images, dim=concat_dim),) + return IO.NodeOutput(torch.cat(images, dim=concat_dim)) + + stitch = execute # TODO: remove -class ResizeAndPadImage: - @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image": ("IMAGE",), - "target_width": ("INT", { - "default": 512, - "min": 1, - "max": MAX_RESOLUTION, - "step": 1 - }), - "target_height": ("INT", { - "default": 512, - "min": 1, - "max": MAX_RESOLUTION, - "step": 1 - }), - "padding_color": (["white", "black"],), - "interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],), - } - } - RETURN_TYPES = ("IMAGE",) - FUNCTION = "resize_and_pad" - CATEGORY = "image/transform" +class ResizeAndPadImage(IO.ComfyNode): - def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ResizeAndPadImage", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Combo.Input("padding_color", options=["white", "black"]), + IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"]), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput: batch_size, orig_height, orig_width, channels = image.shape scale_w = target_width / orig_width @@ -469,52 +414,47 @@ def resize_and_pad(self, image, target_width, target_height, padding_color, inte padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized output = padded.permute(0, 2, 3, 1) - return (output,) + return IO.NodeOutput(output) -class SaveSVGNode: - """ - Save SVG files on disk. - """ + resize_and_pad = execute # TODO: remove - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" - RETURN_TYPES = () - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "save_svg" - CATEGORY = "image/save" # Changed - OUTPUT_NODE = True +class SaveSVGNode(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "svg": ("SVG",), # Changed - "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) - }, - "hidden": { - "prompt": "PROMPT", - "extra_pnginfo": "EXTRA_PNGINFO" - } - } + def define_schema(cls): + return IO.Schema( + node_id="SaveSVGNode", + description="Save SVG files on disk.", + category="image/save", + inputs=[ + IO.SVG.Input("svg"), + IO.String.Input( + "filename_prefix", + default="svg/ComfyUI", + tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) - def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results = list() + @classmethod + def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput: + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) + results: list[UI.SavedResult] = [] # Prepare metadata JSON metadata_dict = {} - if prompt is not None: - metadata_dict["prompt"] = prompt - if extra_pnginfo is not None: - metadata_dict.update(extra_pnginfo) + if cls.hidden.prompt is not None: + metadata_dict["prompt"] = cls.hidden.prompt + if cls.hidden.extra_pnginfo is not None: + metadata_dict.update(cls.hidden.extra_pnginfo) # Convert metadata to JSON string metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None + for batch_number, svg_bytes in enumerate(svg.data): filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) file = f"{filename_with_batch_num}_{counter:05}_.svg" @@ -544,57 +484,64 @@ def replacement(match): with open(os.path.join(full_output_folder, file), 'wb') as svg_file: svg_file.write(svg_content.encode('utf-8')) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) + results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output)) counter += 1 - return { "ui": { "images": results } } + return IO.NodeOutput(ui={"images": results}) -class GetImageSize: + save_svg = execute # TODO: remove - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - }, - "hidden": { - "unique_id": "UNIQUE_ID", - } - } - RETURN_TYPES = (IO.INT, IO.INT, IO.INT) - RETURN_NAMES = ("width", "height", "batch_size") - FUNCTION = "get_size" +class GetImageSize(IO.ComfyNode): - CATEGORY = "image" - DESCRIPTION = """Returns width and height of the image, and passes it through unchanged.""" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GetImageSize", + display_name="Get Image Size", + description="Returns width and height of the image, and passes it through unchanged.", + category="image", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + IO.Int.Output(display_name="batch_size"), + ], + hidden=[IO.Hidden.unique_id], + ) - def get_size(self, image, unique_id=None) -> tuple[int, int]: + @classmethod + def execute(cls, image) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] batch_size = image.shape[0] # Send progress text to display size on the node - if unique_id: - PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id) + if cls.hidden.unique_id: + PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id) - return width, height, batch_size + return IO.NodeOutput(width, height, batch_size) -class ImageRotate: - @classmethod - def INPUT_TYPES(s): - return {"required": { "image": (IO.IMAGE,), - "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],), - }} - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "rotate" + get_size = execute # TODO: remove + + +class ImageRotate(IO.ComfyNode): - CATEGORY = "image/transform" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageRotate", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]), + ], + outputs=[IO.Image.Output()], + ) - def rotate(self, image, rotation): + @classmethod + def execute(cls, image, rotation) -> IO.NodeOutput: rotate_by = 0 if rotation.startswith("90"): rotate_by = 1 @@ -604,41 +551,57 @@ def rotate(self, image, rotation): rotate_by = 3 image = torch.rot90(image, k=rotate_by, dims=[2, 1]) - return (image,) + return IO.NodeOutput(image) -class ImageFlip: - @classmethod - def INPUT_TYPES(s): - return {"required": { "image": (IO.IMAGE,), - "flip_method": (["x-axis: vertically", "y-axis: horizontally"],), - }} - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "flip" + rotate = execute # TODO: remove + + +class ImageFlip(IO.ComfyNode): - CATEGORY = "image/transform" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageFlip", + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]), + ], + outputs=[IO.Image.Output()], + ) - def flip(self, image, flip_method): + @classmethod + def execute(cls, image, flip_method) -> IO.NodeOutput: if flip_method.startswith("x"): image = torch.flip(image, dims=[1]) elif flip_method.startswith("y"): image = torch.flip(image, dims=[2]) - return (image,) + return IO.NodeOutput(image) -class ImageScaleToMaxDimension: - upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"] + flip = execute # TODO: remove - @classmethod - def INPUT_TYPES(s): - return {"required": {"image": ("IMAGE",), - "upscale_method": (s.upscale_methods,), - "largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "upscale" - CATEGORY = "image/upscaling" +class ImageScaleToMaxDimension(IO.ComfyNode): - def upscale(self, image, upscale_method, largest_size): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageScaleToMaxDimension", + category="image/upscaling", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input( + "upscale_method", + options=["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"], + ), + IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] @@ -655,20 +618,30 @@ def upscale(self, image, upscale_method, largest_size): samples = image.movedim(-1, 1) s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") s = s.movedim(1, -1) - return (s,) - -NODE_CLASS_MAPPINGS = { - "ImageCrop": ImageCrop, - "RepeatImageBatch": RepeatImageBatch, - "ImageFromBatch": ImageFromBatch, - "ImageAddNoise": ImageAddNoise, - "SaveAnimatedWEBP": SaveAnimatedWEBP, - "SaveAnimatedPNG": SaveAnimatedPNG, - "SaveSVGNode": SaveSVGNode, - "ImageStitch": ImageStitch, - "ResizeAndPadImage": ResizeAndPadImage, - "GetImageSize": GetImageSize, - "ImageRotate": ImageRotate, - "ImageFlip": ImageFlip, - "ImageScaleToMaxDimension": ImageScaleToMaxDimension, -} + return IO.NodeOutput(s) + + upscale = execute # TODO: remove + + +class ImagesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ImageCrop, + RepeatImageBatch, + ImageFromBatch, + ImageAddNoise, + SaveAnimatedWEBP, + SaveAnimatedPNG, + SaveSVGNode, + ImageStitch, + ResizeAndPadImage, + GetImageSize, + ImageRotate, + ImageFlip, + ImageScaleToMaxDimension, + ] + + +async def comfy_entrypoint() -> ImagesExtension: + return ImagesExtension() diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 2815c5ffc258..9ba1c4ba8a66 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -255,6 +255,7 @@ def define_schema(cls): return io.Schema( node_id="LatentBatch", category="latent/batch", + is_deprecated=True, inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 95a6ba788faf..eb888316acdc 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -1,8 +1,11 @@ +from __future__ import annotations from typing import TypedDict from typing_extensions import override from comfy_api.latest import ComfyExtension, io from comfy_api.latest import _io +# sentinel for missing inputs +MISSING = object() class SwitchNode(io.ComfyNode): @@ -14,6 +17,37 @@ def define_schema(cls): display_name="Switch", category="logic", is_experimental=True, + inputs=[ + io.Boolean.Input("switch"), + io.MatchType.Input("on_false", template=template, lazy=True), + io.MatchType.Input("on_true", template=template, lazy=True), + ], + outputs=[ + io.MatchType.Output(template=template, display_name="output"), + ], + ) + + @classmethod + def check_lazy_status(cls, switch, on_false=None, on_true=None): + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + @classmethod + def execute(cls, switch, on_true, on_false) -> io.NodeOutput: + return io.NodeOutput(on_true if switch else on_false) + + +class SoftSwitchNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = io.MatchType.Template("switch") + return io.Schema( + node_id="ComfySoftSwitchNode", + display_name="Soft Switch", + category="logic", + is_experimental=True, inputs=[ io.Boolean.Input("switch"), io.MatchType.Input("on_false", template=template, lazy=True, optional=True), @@ -25,14 +59,14 @@ def define_schema(cls): ) @classmethod - def check_lazy_status(cls, switch, on_false=..., on_true=...): - # We use ... instead of None, as None is passed for connected-but-unevaluated inputs. + def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING): + # We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs. # This trick allows us to ignore the value of the switch and still be able to run execute(). # One of the inputs may be missing, in which case we need to evaluate the other input - if on_false is ...: + if on_false is MISSING: return ["on_true"] - if on_true is ...: + if on_true is MISSING: return ["on_false"] # Normal lazy switch operation if switch and on_true is None: @@ -41,22 +75,50 @@ def check_lazy_status(cls, switch, on_false=..., on_true=...): return ["on_false"] @classmethod - def validate_inputs(cls, switch, on_false=..., on_true=...): + def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING): # This check happens before check_lazy_status(), so we can eliminate the case where # both inputs are missing. - if on_false is ... and on_true is ...: + if on_false is MISSING and on_true is MISSING: return "At least one of on_false or on_true must be connected to Switch node" return True @classmethod - def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput: - if on_true is ...: + def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput: + if on_true is MISSING: return io.NodeOutput(on_false) - if on_false is ...: + if on_false is MISSING: return io.NodeOutput(on_true) return io.NodeOutput(on_true if switch else on_false) +class CustomComboNode(io.ComfyNode): + """ + Frontend node that allows user to write their own options for a combo. + This is here to make sure the node has a backend-representation to avoid some annoyances. + """ + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CustomCombo", + display_name="Custom Combo", + category="utils", + is_experimental=True, + inputs=[io.Combo.Input("choice", options=[])], + outputs=[io.String.Output()] + ) + + @classmethod + def validate_inputs(cls, choice: io.Combo.Type) -> bool: + # NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs. + # I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined. + # I need to skip checking that the chosen combo option is in the options list, since those are defined by the user. + return True + + @classmethod + def execute(cls, choice: io.Combo.Type) -> io.NodeOutput: + return io.NodeOutput(choice) + + class DCTestNode(io.ComfyNode): class DCValues(TypedDict): combo: str @@ -72,14 +134,14 @@ def define_schema(cls): display_name="DCTest", category="logic", is_output_node=True, - inputs=[_io.DynamicCombo.Input("combo", options=[ - _io.DynamicCombo.Option("option1", [io.String.Input("string")]), - _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), - _io.DynamicCombo.Option("option3", [io.Image.Input("image")]), - _io.DynamicCombo.Option("option4", [ - _io.DynamicCombo.Input("subcombo", options=[ - _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), - _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), + inputs=[io.DynamicCombo.Input("combo", options=[ + io.DynamicCombo.Option("option1", [io.String.Input("string")]), + io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), + io.DynamicCombo.Option("option3", [io.Image.Input("image")]), + io.DynamicCombo.Option("option4", [ + io.DynamicCombo.Input("subcombo", options=[ + io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), + io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), ]) ])] )], @@ -141,14 +203,65 @@ def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput: combined = ",".join([str(x) for x in vals]) return io.NodeOutput(combined) +class ComboOutputTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ComboOptionTestNode", + display_name="ComboOptionTest", + category="logic", + inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]), + io.Combo.Input("combo2", options=["option4", "option5", "option6"])], + outputs=[io.Combo.Output(), io.Combo.Output()], + ) + + @classmethod + def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput: + return io.NodeOutput(combo, combo2) + +class ConvertStringToComboNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ConvertStringToComboNode", + display_name="Convert String to Combo", + category="logic", + inputs=[io.String.Input("string")], + outputs=[io.Combo.Output()], + ) + + @classmethod + def execute(cls, string: str) -> io.NodeOutput: + return io.NodeOutput(string) + +class InvertBooleanNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="InvertBooleanNode", + display_name="Invert Boolean", + category="logic", + inputs=[io.Boolean.Input("boolean")], + outputs=[io.Boolean.Output()], + ) + + @classmethod + def execute(cls, boolean: bool) -> io.NodeOutput: + return io.NodeOutput(not boolean) + class LogicExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - # SwitchNode, + SwitchNode, + CustomComboNode, + # SoftSwitchNode, + # ConvertStringToComboNode, # DCTestNode, # AutogrowNamesTestNode, # AutogrowPrefixTestNode, + # ComboOutputTestNode, + # InvertBooleanNode, ] async def comfy_entrypoint() -> LogicExtension: diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 50da5f4eb606..b91a22309de1 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -81,6 +81,59 @@ def execute(cls, positive, negative, image, vae, width, height, length, batch_si generate = execute # TODO: remove +class LTXVImgToVideoInplace(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVImgToVideoInplace", + category="conditioning/video_models", + inputs=[ + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Latent.Input("latent"), + io.Float.Input("strength", default=1.0, min=0.0, max=1.0), + io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.") + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput: + if bypass: + return (latent,) + + samples = latent["samples"] + _, height_scale_factor, width_scale_factor = ( + vae.downscale_index_formula + ) + + batch, _, latent_frames, latent_height, latent_width = samples.shape + width = latent_width * width_scale_factor + height = latent_height * height_scale_factor + + if image.shape[1] != height or image.shape[2] != width: + pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + else: + pixels = image + encode_pixels = pixels[:, :, :, :3] + t = vae.encode(encode_pixels) + + samples[:, :, :t.shape[2]] = t + + conditioning_latent_frames_mask = torch.ones( + (batch, 1, latent_frames, 1, 1), + dtype=torch.float32, + device=samples.device, + ) + conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength + + return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask}) + + generate = execute # TODO: remove + + def conditioning_get_any_value(conditioning, key, default=None): for t in conditioning: if key in t[1]: @@ -106,12 +159,12 @@ def get_keyframe_idxs(cond): keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None) if keyframe_idxs is None: return None, 0 - num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] + # keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start + num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0] return keyframe_idxs, num_keyframes class LTXVAddGuide(io.ComfyNode): - NUM_PREFIX_FRAMES = 2 - PATCHIFIER = SymmetricPatchifier(1) + PATCHIFIER = SymmetricPatchifier(1, start_end=True) @classmethod def define_schema(cls): @@ -182,26 +235,35 @@ def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors): return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) @classmethod - def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): - _, latent_idx = cls.get_latent_index( - cond=positive, - latent_length=latent_image.shape[2], - guide_length=guiding_latent.shape[2], - frame_idx=frame_idx, - scale_factors=scale_factors, - ) - noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0 + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128): + if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: + raise ValueError("Adding guide to a combined AV latent is not supported.") positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) - mask = torch.full( - (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), - 1.0 - strength, - dtype=noise_mask.dtype, - device=noise_mask.device, - ) + if guide_mask is not None: + target_h = max(noise_mask.shape[3], guide_mask.shape[3]) + target_w = max(noise_mask.shape[4], guide_mask.shape[4]) + if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1: + noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w) + + if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1: + guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w) + mask = guide_mask - strength + else: + mask = torch.full( + (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), + 1.0 - strength, + dtype=noise_mask.dtype, + device=noise_mask.device, + ) + # This solves audio video combined latent case where latent_image has audio latent concatenated + # in channel dimension with video latent. The solution is to pad guiding latent accordingly. + if latent_image.shape[1] > guiding_latent.shape[1]: + pad_len = latent_image.shape[1] - guiding_latent.shape[1] + guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0) latent_image = torch.cat([latent_image, guiding_latent], dim=2) noise_mask = torch.cat([noise_mask, mask], dim=2) return positive, negative, latent_image, noise_mask @@ -238,31 +300,15 @@ def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." - num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2]) - positive, negative, latent_image, noise_mask = cls.append_keyframe( positive, negative, frame_idx, - latent_image, - noise_mask, - t[:, :, :num_prefix_frames], - strength, - scale_factors, - ) - - latent_idx += num_prefix_frames - - t = t[:, :, num_prefix_frames:] - if t.shape[2] == 0: - return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) - - latent_image, noise_mask = cls.replace_latent_frames( latent_image, noise_mask, t, - latent_idx, strength, + scale_factors, ) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) @@ -507,18 +553,90 @@ def execute(cls, image, img_compression) -> io.NodeOutput: preprocess = execute # TODO: remove + +import comfy.nested_tensor +class LTXVConcatAVLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVConcatAVLatent", + category="latent/video/ltxv", + inputs=[ + io.Latent.Input("video_latent"), + io.Latent.Input("audio_latent"), + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, video_latent, audio_latent) -> io.NodeOutput: + output = {} + output.update(video_latent) + output.update(audio_latent) + video_noise_mask = video_latent.get("noise_mask", None) + audio_noise_mask = audio_latent.get("noise_mask", None) + + if video_noise_mask is not None or audio_noise_mask is not None: + if video_noise_mask is None: + video_noise_mask = torch.ones_like(video_latent["samples"]) + if audio_noise_mask is None: + audio_noise_mask = torch.ones_like(audio_latent["samples"]) + output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask)) + + output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"])) + + return io.NodeOutput(output) + + +class LTXVSeparateAVLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVSeparateAVLatent", + category="latent/video/ltxv", + description="LTXV Separate AV Latent", + inputs=[ + io.Latent.Input("av_latent"), + ], + outputs=[ + io.Latent.Output(display_name="video_latent"), + io.Latent.Output(display_name="audio_latent"), + ], + ) + + @classmethod + def execute(cls, av_latent) -> io.NodeOutput: + latents = av_latent["samples"].unbind() + video_latent = av_latent.copy() + video_latent["samples"] = latents[0] + audio_latent = av_latent.copy() + audio_latent["samples"] = latents[1] + if "noise_mask" in av_latent: + masks = av_latent["noise_mask"] + if masks is not None: + masks = masks.unbind() + video_latent["noise_mask"] = masks[0] + audio_latent["noise_mask"] = masks[1] + return io.NodeOutput(video_latent, audio_latent) + + class LtxvExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ EmptyLTXVLatentVideo, LTXVImgToVideo, + LTXVImgToVideoInplace, ModelSamplingLTXV, LTXVConditioning, LTXVScheduler, LTXVAddGuide, LTXVPreprocess, LTXVCropGuides, + LTXVConcatAVLatent, + LTXVSeparateAVLatent, ] diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py new file mode 100644 index 000000000000..26b0160d2d98 --- /dev/null +++ b/comfy_extras/nodes_lt_audio.py @@ -0,0 +1,216 @@ +import folder_paths +import comfy.utils +import comfy.model_management +import torch + +from comfy.ldm.lightricks.vae.audio_vae import AudioVAE +from comfy_api.latest import ComfyExtension, io + + +class LTXVAudioVAELoader(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAELoader", + display_name="LTXV Audio VAE Loader", + category="audio", + inputs=[ + io.Combo.Input( + "ckpt_name", + options=folder_paths.get_filename_list("checkpoints"), + tooltip="Audio VAE checkpoint to load.", + ) + ], + outputs=[io.Vae.Output(display_name="Audio VAE")], + ) + + @classmethod + def execute(cls, ckpt_name: str) -> io.NodeOutput: + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) + return io.NodeOutput(AudioVAE(sd, metadata)) + + +class LTXVAudioVAEEncode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAEEncode", + display_name="LTXV Audio VAE Encode", + category="audio", + inputs=[ + io.Audio.Input("audio", tooltip="The audio to be encoded."), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model to use for encoding.", + ), + ], + outputs=[io.Latent.Output(display_name="Audio Latent")], + ) + + @classmethod + def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: + audio_latents = audio_vae.encode(audio) + return io.NodeOutput( + { + "samples": audio_latents, + "sample_rate": int(audio_vae.sample_rate), + "type": "audio", + } + ) + + +class LTXVAudioVAEDecode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAEDecode", + display_name="LTXV Audio VAE Decode", + category="audio", + inputs=[ + io.Latent.Input("samples", tooltip="The latent to be decoded."), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model used for decoding the latent.", + ), + ], + outputs=[io.Audio.Output(display_name="Audio")], + ) + + @classmethod + def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: + audio_latent = samples["samples"] + if audio_latent.is_nested: + audio_latent = audio_latent.unbind()[-1] + audio = audio_vae.decode(audio_latent).to(audio_latent.device) + output_audio_sample_rate = audio_vae.output_sample_rate + return io.NodeOutput( + { + "waveform": audio, + "sample_rate": int(output_audio_sample_rate), + } + ) + + +class LTXVEmptyLatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVEmptyLatentAudio", + display_name="LTXV Empty Latent Audio", + category="latent/audio", + inputs=[ + io.Int.Input( + "frames_number", + default=97, + min=1, + max=1000, + step=1, + display_mode=io.NumberDisplay.number, + tooltip="Number of frames.", + ), + io.Int.Input( + "frame_rate", + default=25, + min=1, + max=1000, + step=1, + display_mode=io.NumberDisplay.number, + tooltip="Number of frames per second.", + ), + io.Int.Input( + "batch_size", + default=1, + min=1, + max=4096, + display_mode=io.NumberDisplay.number, + tooltip="The number of latent audio samples in the batch.", + ), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model to get configuration from.", + ), + ], + outputs=[io.Latent.Output(display_name="Latent")], + ) + + @classmethod + def execute( + cls, + frames_number: int, + frame_rate: int, + batch_size: int, + audio_vae: AudioVAE, + ) -> io.NodeOutput: + """Generate empty audio latents matching the reference pipeline structure.""" + + assert audio_vae is not None, "Audio VAE model is required" + + z_channels = audio_vae.latent_channels + audio_freq = audio_vae.latent_frequency_bins + sampling_rate = int(audio_vae.sample_rate) + + num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) + + audio_latents = torch.zeros( + (batch_size, z_channels, num_audio_latents, audio_freq), + device=comfy.model_management.intermediate_device(), + ) + + return io.NodeOutput( + { + "samples": audio_latents, + "sample_rate": sampling_rate, + "type": "audio", + } + ) + + +class LTXAVTextEncoderLoader(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXAVTextEncoderLoader", + display_name="LTXV Audio Text Encoder Loader", + category="advanced/loaders", + description="[Recipes]\n\nltxav: gemma 3 12B", + inputs=[ + io.Combo.Input( + "text_encoder", + options=folder_paths.get_filename_list("text_encoders"), + ), + io.Combo.Input( + "ckpt_name", + options=folder_paths.get_filename_list("checkpoints"), + ) + ], + outputs=[io.Clip.Output()], + ) + + @classmethod + def execute(cls, text_encoder, ckpt_name, device="default"): + clip_type = comfy.sd.CLIPType.LTXV + + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder) + clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) + return io.NodeOutput(clip) + + +class LTXVAudioExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LTXVAudioVAELoader, + LTXVAudioVAEEncode, + LTXVAudioVAEDecode, + LTXVEmptyLatentAudio, + LTXAVTextEncoderLoader, + ] + + +async def comfy_entrypoint() -> ComfyExtension: + return LTXVAudioExtension() diff --git a/comfy_extras/nodes_lt_upsampler.py b/comfy_extras/nodes_lt_upsampler.py new file mode 100644 index 000000000000..f99ba13fb552 --- /dev/null +++ b/comfy_extras/nodes_lt_upsampler.py @@ -0,0 +1,75 @@ +from comfy import model_management +import math + +class LTXVLatentUpsampler: + """ + Upsamples a video latent by a factor of 2. + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "samples": ("LATENT",), + "upscale_model": ("LATENT_UPSCALE_MODEL",), + "vae": ("VAE",), + } + } + + RETURN_TYPES = ("LATENT",) + FUNCTION = "upsample_latent" + CATEGORY = "latent/video" + EXPERIMENTAL = True + + def upsample_latent( + self, + samples: dict, + upscale_model, + vae, + ) -> tuple: + """ + Upsample the input latent using the provided model. + + Args: + samples (dict): Input latent samples + upscale_model (LatentUpsampler): Loaded upscale model + vae: VAE model for normalization + auto_tiling (bool): Whether to automatically tile the input for processing + + Returns: + tuple: Tuple containing the upsampled latent + """ + device = model_management.get_torch_device() + memory_required = model_management.module_size(upscale_model) + + model_dtype = next(upscale_model.parameters()).dtype + latents = samples["samples"] + input_dtype = latents.dtype + + memory_required += math.prod(latents.shape) * 3000.0 # TODO: more accurate + model_management.free_memory(memory_required, device) + + try: + upscale_model.to(device) # TODO: use the comfy model management system. + + latents = latents.to(dtype=model_dtype, device=device) + + """Upsample latents without tiling.""" + latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents) + upsampled_latents = upscale_model(latents) + finally: + upscale_model.cpu() + + upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize( + upsampled_latents + ) + upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device()) + return_dict = samples.copy() + return_dict["samples"] = upsampled_latents + return_dict.pop("noise_mask", None) + return (return_dict,) + + +NODE_CLASS_MAPPINGS = { + "LTXVLatentUpsampler": LTXVLatentUpsampler, +} diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py index 07b3353f42f9..6459ca8c1099 100644 --- a/comfy_extras/nodes_mahiro.py +++ b/comfy_extras/nodes_mahiro.py @@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Mahiro", - display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)", + display_name="Mahiro CFG", category="_for_testing", description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", inputs=[ diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ca2cdeb50546..01afa13a18ee 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -4,11 +4,15 @@ import torch.nn.functional as F from PIL import Image import math +from enum import Enum +from typing import TypedDict, Literal import comfy.utils import comfy.model_management +from comfy_extras.nodes_latent import reshape_latent_to import node_helpers from comfy_api.latest import ComfyExtension, io +from nodes import MAX_RESOLUTION class Blend(io.ComfyNode): @classmethod @@ -241,6 +245,353 @@ def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.Node s = s.movedim(1,-1) return io.NodeOutput(s) +class ResizeType(str, Enum): + SCALE_BY = "scale by multiplier" + SCALE_DIMENSIONS = "scale dimensions" + SCALE_LONGER_DIMENSION = "scale longer dimension" + SCALE_SHORTER_DIMENSION = "scale shorter dimension" + SCALE_WIDTH = "scale width" + SCALE_HEIGHT = "scale height" + SCALE_TOTAL_PIXELS = "scale total pixels" + MATCH_SIZE = "match size" + +def is_image(input: torch.Tensor) -> bool: + # images have 4 dimensions: [batch, height, width, channels] + # masks have 3 dimensions: [batch, height, width] + return len(input.shape) == 4 + +def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor: + if is_type_image: + input = input.movedim(-1, 1) + else: + input = input.unsqueeze(1) + return input + +def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor: + if is_type_image: + input = input.movedim(1, -1) + else: + input = input.squeeze(1) + return input + +def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = round(input.shape[-1] * multiplier) + height = round(input.shape[-2] * multiplier) + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor: + if width == 0 and height == 0: + return input + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + + if width == 0: + width = max(1, round(input.shape[-1] * height / input.shape[-2])) + elif height == 0: + height = max(1, round(input.shape[-2] * width / input.shape[-1])) + + input = comfy.utils.common_upscale(input, width, height, scale_method, crop) + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = input.shape[-1] + height = input.shape[-2] + + if height > width: + width = round((width / height) * longer_size) + height = longer_size + elif width > height: + height = round((height / width) * longer_size) + width = longer_size + else: + height = longer_size + width = longer_size + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = input.shape[-1] + height = input.shape[-2] + + if height < width: + width = round((width / height) * shorter_size) + height = shorter_size + elif width > height: + height = round((height / width) * shorter_size) + width = shorter_size + else: + height = shorter_size + width = shorter_size + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + total = int(megapixels * 1024 * 1024) + + scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2])) + width = round(input.shape[-1] * scale_by) + height = round(input.shape[-2] * scale_by) + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + match = init_image_mask_input(match, is_image(match)) + + width = match.shape[-1] + height = match.shape[-2] + input = comfy.utils.common_upscale(input, width, height, scale_method, crop) + input = finalize_image_mask_input(input, is_type_image) + return input + +class ResizeImageMaskNode(io.ComfyNode): + + scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] + crop_methods = ["disabled", "center"] + + class ResizeTypedDict(TypedDict): + resize_type: ResizeType + scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] + crop: Literal["disabled", "center"] + multiplier: float + width: int + height: int + longer_size: int + shorter_size: int + megapixels: float + + @classmethod + def define_schema(cls): + template = io.MatchType.Template("input_type", [io.Image, io.Mask]) + crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center") + return io.Schema( + node_id="ResizeImageMaskNode", + display_name="Resize Image/Mask", + category="transform", + inputs=[ + io.MatchType.Input("input", template=template), + io.DynamicCombo.Input("resize_type", options=[ + io.DynamicCombo.Option(ResizeType.SCALE_BY, [ + io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), + crop_combo, + ]), + io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ + io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ + io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ + io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + ]), + io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ + io.MultiType.Input("match", [io.Image, io.Mask]), + crop_combo, + ]), + ]), + io.Combo.Input("scale_method", options=cls.scale_methods, default="area"), + ], + outputs=[io.MatchType.Output(template=template, display_name="resized")] + ) + + @classmethod + def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput: + selected_type = resize_type["resize_type"] + if selected_type == ResizeType.SCALE_BY: + return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method)) + elif selected_type == ResizeType.SCALE_DIMENSIONS: + return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"])) + elif selected_type == ResizeType.SCALE_LONGER_DIMENSION: + return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method)) + elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION: + return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method)) + elif selected_type == ResizeType.SCALE_WIDTH: + return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method)) + elif selected_type == ResizeType.SCALE_HEIGHT: + return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method)) + elif selected_type == ResizeType.SCALE_TOTAL_PIXELS: + return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method)) + elif selected_type == ResizeType.MATCH_SIZE: + return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"])) + raise ValueError(f"Unsupported resize type: {selected_type}") + +def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None: + if len(images) == 0: + return None + # first, get the max channels count + max_channels = max(image.shape[-1] for image in images) + # then, pad all images to have the same channels count + padded_images: list[torch.Tensor] = [] + for image in images: + if image.shape[-1] < max_channels: + padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0)) + else: + padded_images.append(image) + # resize all images to be the same size as the first image + resized_images: list[torch.Tensor] = [] + first_image_shape = padded_images[0].shape + for image in padded_images: + if image.shape[1:] != first_image_shape[1:]: + resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1)) + else: + resized_images.append(image) + # batch the images in the format [b, h, w, c] + return torch.cat(resized_images, dim=0) + +def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None: + if len(masks) == 0: + return None + # resize all masks to be the same size as the first mask + resized_masks: list[torch.Tensor] = [] + first_mask_shape = masks[0].shape + for mask in masks: + if mask.shape[1:] != first_mask_shape[1:]: + mask = init_image_mask_input(mask, is_type_image=False) + mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center") + resized_masks.append(finalize_image_mask_input(mask, is_type_image=False)) + else: + resized_masks.append(mask) + # batch the masks in the format [b, h, w] + return torch.cat(resized_masks, dim=0) + +def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None: + if len(latents) == 0: + return None + samples_out = latents[0].copy() + samples_out["batch_index"] = [] + first_samples = latents[0]["samples"] + tensors: list[torch.Tensor] = [] + for latent in latents: + # first, deal with latent tensors + tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False)) + # next, deal with batch_index + samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])])) + samples_out["samples"] = torch.cat(tensors, dim=0) + return samples_out + +class BatchImagesNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50) + return io.Schema( + node_id="BatchImagesNode", + display_name="Batch Images", + category="image", + inputs=[ + io.Autogrow.Input("images", template=autogrow_template) + ], + outputs=[ + io.Image.Output() + ] + ) + + @classmethod + def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_images(list(images.values()))) + +class BatchMasksNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50) + return io.Schema( + node_id="BatchMasksNode", + display_name="Batch Masks", + category="mask", + inputs=[ + io.Autogrow.Input("masks", template=autogrow_template) + ], + outputs=[ + io.Mask.Output() + ] + ) + + @classmethod + def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_masks(list(masks.values()))) + +class BatchLatentsNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50) + return io.Schema( + node_id="BatchLatentsNode", + display_name="Batch Latents", + category="latent", + inputs=[ + io.Autogrow.Input("latents", template=autogrow_template) + ], + outputs=[ + io.Latent.Output() + ] + ) + + @classmethod + def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_latents(list(latents.values()))) + +class BatchImagesMasksLatentsNode(io.ComfyNode): + @classmethod + def define_schema(cls): + matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent]) + autogrow_template = io.Autogrow.TemplatePrefix( + io.MatchType.Input("input", matchtype_template), + prefix="input", min=1, max=50) + return io.Schema( + node_id="BatchImagesMasksLatentsNode", + display_name="Batch Images/Masks/Latents", + category="util", + inputs=[ + io.Autogrow.Input("inputs", template=autogrow_template) + ], + outputs=[ + io.MatchType.Output(id=None, template=matchtype_template) + ] + ) + + @classmethod + def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput: + batched = None + values = list(inputs.values()) + # latents + if isinstance(values[0], dict): + batched = batch_latents(values) + # images + elif is_image(values[0]): + batched = batch_images(values) + # masks + else: + batched = batch_masks(values) + return io.NodeOutput(batched) + class PostProcessingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -250,6 +601,11 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]: Quantize, Sharpen, ImageScaleToTotalPixels, + ResizeImageMaskNode, + BatchImagesNode, + BatchMasksNode, + BatchLatentsNode, + # BatchImagesMasksLatentsNode, ] async def comfy_entrypoint() -> PostProcessingExtension: diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 5a1aeba80077..937321800249 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -66,7 +66,7 @@ def define_schema(cls): display_name="Float", category="utils/primitive", inputs=[ - io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize), + io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1), ], outputs=[io.Float.Output()], ) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 4d62b87be209..ed587851c521 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -78,18 +78,20 @@ def execute(cls, upscale_model, image) -> io.NodeOutput: overlap = 32 oom = True - while oom: - try: - steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) - pbar = comfy.utils.ProgressBar(steps) - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) - oom = False - except model_management.OOM_EXCEPTION as e: - tile //= 2 - if tile < 128: - raise e - - upscale_model.to("cpu") + try: + while oom: + try: + steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) + pbar = comfy.utils.ProgressBar(steps) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) + oom = False + except model_management.OOM_EXCEPTION as e: + tile //= 2 + if tile < 128: + raise e + finally: + upscale_model.to("cpu") + s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) return io.NodeOutput(s) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index b0bd471bfb42..d32aad98e024 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -817,7 +817,7 @@ def get_sample_indices(original_fps, if required_duration > total_frames / original_fps: raise ValueError("required_duration must be less than video length") - if not fixed_start is None and fixed_start >= 0: + if fixed_start is not None and fixed_start >= 0: start_frame = fixed_start else: max_start = total_frames - required_origin_frames diff --git a/comfyui_version.py b/comfyui_version.py index b4530919846e..1ed60fe5c007 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.5.1" +__version__ = "0.7.0" diff --git a/execution.py b/execution.py index 0c239efd7a2f..648f204ecd73 100644 --- a/execution.py +++ b/execution.py @@ -79,7 +79,7 @@ async def get(self, node_id): # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) try: - is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) + is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data) is_changed = await resolve_map_node_over_list_results(is_changed) node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] except Exception as e: @@ -148,13 +148,12 @@ def recursive_debug_dump(self): def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) v3_data: io.V3Data = {} + hidden_inputs_v3 = {} + valid_inputs = class_def.INPUT_TYPES() if is_v3: - valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) - else: - valid_inputs = class_def.INPUT_TYPES() + valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(valid_inputs, inputs) input_data_all = {} missing_keys = {} - hidden_inputs_v3 = {} for x in inputs: input_data = inputs[x] _, input_category, input_info = get_input_info(class_def, x, valid_inputs) @@ -180,18 +179,18 @@ def mark_missing(): input_data_all[x] = [input_data] if is_v3: - if schema.hidden: - if io.Hidden.prompt in schema.hidden: + if hidden is not None: + if io.Hidden.prompt.name in hidden: hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {} - if io.Hidden.dynprompt in schema.hidden: + if io.Hidden.dynprompt.name in hidden: hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt - if io.Hidden.extra_pnginfo in schema.hidden: + if io.Hidden.extra_pnginfo.name in hidden: hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None) - if io.Hidden.unique_id in schema.hidden: + if io.Hidden.unique_id.name in hidden: hidden_inputs_v3[io.Hidden.unique_id] = unique_id - if io.Hidden.auth_token_comfy_org in schema.hidden: + if io.Hidden.auth_token_comfy_org.name in hidden: hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) - if io.Hidden.api_key_comfy_org in schema.hidden: + if io.Hidden.api_key_comfy_org.name in hidden: hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) else: if "hidden" in valid_inputs: @@ -258,7 +257,7 @@ async def process_inputs(inputs, index=None, input_is_list=False): pre_execute_cb(index) # V3 if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): - # if is just a class, then assign no resources or state, just create clone + # if is just a class, then assign no state, just create clone if is_class(obj): type_obj = obj obj.VALIDATE_CLASS() @@ -481,7 +480,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, else: lazy_status_present = getattr(obj, "check_lazy_status", None) is not None if lazy_status_present: - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data) + # for check_lazy_status, the returned data should include the original key of the input + v3_data_lazy = v3_data.copy() + v3_data_lazy["create_dynamic_tuple"] = True + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data_lazy) required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = [x for x in required_inputs if isinstance(x,str) and ( @@ -599,6 +601,7 @@ async def await_completion(): if isinstance(ex, comfy.model_management.OOM_EXCEPTION): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." + logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.") comfy.model_management.unload_all_models() @@ -756,10 +759,13 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors = [] valid = True + v3_data = None validate_function_inputs = [] validate_has_kwargs = False if issubclass(obj_class, _ComfyNodeInternal): - class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) + obj_class: _io._ComfyNodeBaseInternal + class_inputs = obj_class.INPUT_TYPES() + class_inputs, _, v3_data = _io.get_finalized_class_inputs(class_inputs, inputs) validate_function_name = "validate_inputs" validate_function = first_real_override(obj_class, validate_function_name) else: @@ -779,10 +785,11 @@ async def validate_inputs(prompt_id, prompt, item, validated): assert extra_info is not None if x not in inputs: if input_category == "required": + details = f"{x}" if not v3_data else x.split(".")[-1] error = { "type": "required_input_missing", "message": "Required input is missing", - "details": f"{x}", + "details": details, "extra_info": { "input_name": x } @@ -916,8 +923,11 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue - if isinstance(input_type, list): - combo_options = input_type + if isinstance(input_type, list) or input_type == io.Combo.io_type: + if input_type == io.Combo.io_type: + combo_options = extra_info.get("options", []) + else: + combo_options = input_type if val not in combo_options: input_config = info list_info = "" diff --git a/manager_requirements.txt b/manager_requirements.txt index 2300f0c70083..6585b0c193e4 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.3b7 +comfyui_manager==4.0.4 diff --git a/nodes.py b/nodes.py index 7d83ecb21db8..56b74ebe3130 100644 --- a/nodes.py +++ b/nodes.py @@ -295,7 +295,11 @@ def INPUT_TYPES(s): DESCRIPTION = "Decodes latent images back into pixel space images." def decode(self, vae, samples): - images = vae.decode(samples["samples"]) + latent = samples["samples"] + if latent.is_nested: + latent = latent.unbind()[0] + + images = vae.decode(latent) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) @@ -970,7 +974,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -1663,8 +1667,6 @@ def load_image(self, image): output_masks = [] w, h = None, None - excluded_formats = ['MPO'] - for i in ImageSequence.Iterator(img): i = node_helpers.pillow(ImageOps.exif_transpose, i) @@ -1692,7 +1694,10 @@ def load_image(self, image): output_images.append(image) output_masks.append(mask.unsqueeze(0)) - if len(output_images) > 1 and img.format not in excluded_formats: + if img.format == "MPO": + break # ignore all frames except the first one for MPO format + + if len(output_images) > 1: output_image = torch.cat(output_images, dim=0) output_mask = torch.cat(output_masks, dim=0) else: @@ -1863,6 +1868,7 @@ def INPUT_TYPES(s): FUNCTION = "batch" CATEGORY = "image" + DEPRECATED = True def batch(self, image1, image2): if image1.shape[-1] != image2.shape[-1]: @@ -2241,8 +2247,10 @@ async def init_external_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) - if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - if module_path.endswith(".disabled"): continue + if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": + continue + if module_path.endswith(".disabled"): + continue if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes: logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes") continue @@ -2327,6 +2335,8 @@ async def init_builtin_extra_nodes(): "nodes_mochi.py", "nodes_slg.py", "nodes_mahiro.py", + "nodes_lt_upsampler.py", + "nodes_lt_audio.py", "nodes_lt.py", "nodes_hooks.py", "nodes_load_3d.py", diff --git a/pyproject.toml b/pyproject.toml index 3a6960811275..a7d159be962e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,9 @@ [project] name = "ComfyUI" -version = "0.5.1" +version = "0.7.0" readme = "README.md" license = { file = "LICENSE" } -requires-python = ">=3.9" +requires-python = ">=3.10" [project.urls] homepage = "https://www.comfy.org/" @@ -15,12 +15,16 @@ lint.select = [ "N805", # invalid-first-argument-name-for-method "S307", # suspicious-eval-usage "S102", # exec + "E", "T", # print-usage "W", # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names. # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f "F", ] + +lint.ignore = ["E501", "E722", "E731", "E712", "E402", "E741"] + exclude = ["*.ipynb", "**/generated/*.pyi"] [tool.pylint] diff --git a/requirements.txt b/requirements.txt index 54696395f9b8..3a05799eba8d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.34.9 -comfyui-workflow-templates==0.7.60 +comfyui-frontend-package==1.35.9 +comfyui-workflow-templates==0.7.65 comfyui-embedded-docs==0.3.1 torch torchsde diff --git a/server.py b/server.py index c27f8be7deda..70c8b5e3be5d 100644 --- a/server.py +++ b/server.py @@ -324,7 +324,7 @@ def list_model_types(request): @routes.get("/models/{folder}") async def get_models(request): folder = request.match_info.get("folder", None) - if not folder in folder_paths.folder_names_and_paths: + if folder not in folder_paths.folder_names_and_paths: return web.Response(status=404) files = folder_paths.get_filename_list(folder) return web.json_response(files) @@ -579,7 +579,7 @@ async def view_metadata(request): folder_name = request.match_info.get("folder_name", None) if folder_name is None: return web.Response(status=404) - if not "filename" in request.rel_url.query: + if "filename" not in request.rel_url.query: return web.Response(status=404) filename = request.rel_url.query["filename"] @@ -593,7 +593,7 @@ async def view_metadata(request): if out is None: return web.Response(status=404) dt = json.loads(out) - if not "__metadata__" in dt: + if "__metadata__" not in dt: return web.Response(status=404) return web.json_response(dt["__metadata__"]) diff --git a/tests-unit/comfy_extras_test/image_stitch_test.py b/tests-unit/comfy_extras_test/image_stitch_test.py index b5a0f022cd0c..5c6a15ac44c7 100644 --- a/tests-unit/comfy_extras_test/image_stitch_test.py +++ b/tests-unit/comfy_extras_test/image_stitch_test.py @@ -25,7 +25,7 @@ def test_no_image2_passthrough(self): result = node.stitch(image1, "right", True, 0, "white", image2=None) - assert len(result) == 1 + assert len(result.result) == 1 assert torch.equal(result[0], image1) def test_basic_horizontal_stitch_right(self): diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 9cb54ede8026..51d27dd26d0d 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -47,6 +47,29 @@ def test_dequantize(self): self.assertEqual(dequantized.dtype, torch.float32) self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + def test_save_load(self): + """Test creating a QuantizedTensor with TensorCoreFP8Layout""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} + + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._layout_params['scale'], scale) + self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) + self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") + + torch.save(qt, "test.pt") + loaded_qt = torch.load("test.pt", weights_only=False) + # loaded_qt = torch.load("test.pt", map_location='cpu', mmap=True, weights_only=False) + + self.assertEqual(loaded_qt._layout_type, "TensorCoreFP8Layout") + self.assertEqual(loaded_qt._layout_params['scale'], scale) + self.assertEqual(loaded_qt._layout_params['orig_dtype'], torch.bfloat16) + def test_from_float(self): """Test creating QuantizedTensor from float tensor""" float_tensor = torch.randn(64, 32, dtype=torch.float32) diff --git a/tests/inference/test_model_mmap.py b/tests/inference/test_model_mmap.py new file mode 100644 index 000000000000..a7bff3bfc5cb --- /dev/null +++ b/tests/inference/test_model_mmap.py @@ -0,0 +1,287 @@ +import pytest +import torch +import torch.nn as nn +import psutil +import os +import gc +import tempfile +import sys + +# Ensure the project root is on the Python path (so `import comfy` works when running tests from this folder) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) + +from comfy.model_patcher import model_to_mmap, to_mmap + + +class LargeModel(nn.Module): + """A simple model with large parameters for testing memory mapping""" + + def __init__(self, size_gb=10): + super().__init__() + # Calculate number of float32 elements needed for target size + # 1 GB = 1024^3 bytes, float32 = 4 bytes + bytes_per_gb = 1024 * 1024 * 1024 + elements_per_gb = bytes_per_gb // 4 # float32 is 4 bytes + total_elements = int(size_gb * elements_per_gb) + + # Create a large linear layer + # Split into multiple layers to avoid single tensor size limits + self.layers = nn.ModuleList() + elements_per_layer = 500 * 1024 * 1024 # 500M elements per layer (~2GB) + num_layers = (total_elements + elements_per_layer - 1) // elements_per_layer + + for i in range(num_layers): + if i == num_layers - 1: + # Last layer gets the remaining elements + remaining = total_elements - (i * elements_per_layer) + in_features = int(remaining ** 0.5) + out_features = (remaining + in_features - 1) // in_features + else: + in_features = int(elements_per_layer ** 0.5) + out_features = (elements_per_layer + in_features - 1) // in_features + + # Create layer without bias to control size precisely + self.layers.append(nn.Linear(in_features, out_features, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def get_process_memory_gb(): + """Get current process memory usage in GB""" + process = psutil.Process(os.getpid()) + mem_info = process.memory_info() + return mem_info.rss / (1024 ** 3) # Convert to GB + + +def get_model_size_gb(model): + """Calculate model size in GB""" + total_size = 0 + for param in model.parameters(): + total_size += param.nelement() * param.element_size() + for buffer in model.buffers(): + total_size += buffer.nelement() * buffer.element_size() + return total_size / (1024 ** 3) + + +def test_model_to_mmap_memory_efficiency(): + """Test that model_to_mmap reduces memory usage for a 10GB model to less than 1GB + + The typical use case is: + 1. Load a large model on CUDA + 2. Convert to mmap to offload from GPU to disk-backed memory + 3. This frees GPU memory and reduces CPU RAM usage + """ + + # Check if CUDA is available + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available, skipping test") + + # Force garbage collection before starting + gc.collect() + torch.cuda.empty_cache() + + # Record initial memory + initial_cpu_memory = get_process_memory_gb() + initial_gpu_memory = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"\nInitial CPU memory: {initial_cpu_memory:.2f} GB") + print(f"Initial GPU memory: {initial_gpu_memory:.2f} GB") + + # Create a 10GB model + print("Creating 10GB model...") + model = LargeModel(size_gb=10) + + # Verify model size + model_size = get_model_size_gb(model) + print(f"Model size: {model_size:.2f} GB") + assert model_size >= 9.5, f"Model size {model_size:.2f} GB is less than expected 10 GB" + + # Move model to CUDA + print("Moving model to CUDA...") + model = model.cuda() + torch.cuda.synchronize() + + # Memory after moving to CUDA + cpu_after_cuda = get_process_memory_gb() + gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"CPU memory after moving to CUDA: {cpu_after_cuda:.2f} GB") + print(f"GPU memory after moving to CUDA: {gpu_after_cuda:.2f} GB") + + # Convert to mmap (this should move model from GPU to disk-backed memory) + # Note: model_to_mmap modifies the model in-place via _apply() + # so model and model_mmap will be the same object + print("Converting model to mmap...") + model_mmap = model_to_mmap(model) + + # Verify that model and model_mmap are the same object (in-place modification) + assert model is model_mmap, "model_to_mmap should modify the model in-place" + + # Force garbage collection and clear CUDA cache + # The original CUDA tensors should be automatically freed when replaced + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Memory after mmap conversion + cpu_after_mmap = get_process_memory_gb() + gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"CPU memory after mmap: {cpu_after_mmap:.2f} GB") + print(f"GPU memory after mmap: {gpu_after_mmap:.2f} GB") + + # Calculate memory changes from CUDA state (the baseline we're converting from) + cpu_increase = cpu_after_mmap - cpu_after_cuda + gpu_decrease = gpu_after_cuda - gpu_after_mmap # Should be positive (freed) + print(f"\nCPU memory increase from CUDA: {cpu_increase:.2f} GB") + print(f"GPU memory freed: {gpu_decrease:.2f} GB") + + # Verify that CPU memory usage increase is less than 1GB + # The mmap should use disk-backed storage, keeping CPU RAM usage low + # We use 1.5 GB threshold to account for overhead + assert cpu_increase < 1.5, ( + f"CPU memory increase after mmap ({cpu_increase:.2f} GB) should be less than 1.5 GB. " + f"CUDA state: {cpu_after_cuda:.2f} GB, After mmap: {cpu_after_mmap:.2f} GB" + ) + + # Verify that GPU memory has been freed + # We expect at least 9 GB to be freed (original 10GB model with some tolerance) + assert gpu_decrease > 9.0, ( + f"GPU memory should be freed after mmap. " + f"Freed: {gpu_decrease:.2f} GB (from {gpu_after_cuda:.2f} to {gpu_after_mmap:.2f} GB), expected > 9 GB" + ) + + # Verify the model is still functional (basic sanity check) + assert model_mmap is not None + assert len(list(model_mmap.parameters())) > 0 + + print(f"\n✓ Test passed!") + print(f" CPU memory increase: {cpu_increase:.2f} GB < 1.5 GB") + print(f" GPU memory freed: {gpu_decrease:.2f} GB > 9.0 GB") + print(f" Model successfully offloaded from GPU to disk-backed memory") + + # Cleanup (model and model_mmap are the same object) + del model, model_mmap + gc.collect() + torch.cuda.empty_cache() + + +def test_to_mmap_cuda_cycle(): + """Test CUDA -> mmap -> CUDA cycle + + This test verifies: + 1. CUDA tensor can be converted to mmap tensor + 2. CPU memory increase is minimal when using mmap (< 0.1 GB) + 3. GPU memory is freed when converting to mmap + 4. mmap tensor can be moved back to CUDA + 5. Data remains consistent throughout the cycle + 6. mmap file is automatically cleaned up via garbage collection + """ + + # Check if CUDA is available + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available, skipping test") + + # Force garbage collection + gc.collect() + torch.cuda.empty_cache() + + print("\nTest: CUDA -> mmap -> CUDA cycle") + + # Record initial CPU memory + initial_cpu_memory = get_process_memory_gb() + print(f"Initial CPU memory: {initial_cpu_memory:.2f} GB") + + # Step 1: Create a CUDA tensor + print("\n1. Creating CUDA tensor...") + original_data = torch.randn(5000, 5000).cuda() + original_sum = original_data.sum().item() + print(f" Shape: {original_data.shape}") + print(f" Device: {original_data.device}") + print(f" Sum: {original_sum:.2f}") + + # Record GPU and CPU memory after CUDA allocation + cpu_after_cuda = get_process_memory_gb() + gpu_before_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + print(f" GPU memory: {gpu_before_mmap:.2f} GB") + print(f" CPU memory: {cpu_after_cuda:.2f} GB") + + # Step 2: Convert to mmap tensor + print("\n2. Converting to mmap tensor...") + mmap_tensor = to_mmap(original_data) + del original_data + gc.collect() + torch.cuda.empty_cache() + + print(f" Device: {mmap_tensor.device}") + print(f" Sum: {mmap_tensor.sum().item():.2f}") + + # Verify GPU memory is freed + gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + cpu_after_mmap = get_process_memory_gb() + print(f" GPU memory freed: {gpu_before_mmap - gpu_after_mmap:.2f} GB") + print(f" CPU memory: {cpu_after_mmap:.2f} GB") + + # Verify GPU memory is freed + assert gpu_after_mmap < 0.1, f"GPU memory should be freed, but {gpu_after_mmap:.2f} GB still allocated" + + # Verify CPU memory increase is minimal (should be close to 0 due to mmap) + cpu_increase = cpu_after_mmap - cpu_after_cuda + print(f" CPU memory increase: {cpu_increase:.2f} GB") + assert cpu_increase < 0.1, f"CPU memory should increase minimally, but increased by {cpu_increase:.2f} GB" + + # Get the temp file path (we'll check if it gets cleaned up) + # The file should exist at this point + temp_files_before = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')]) + print(f" Temp mmap files exist: {temp_files_before}") + + # Step 3: Move back to CUDA + print("\n3. Moving back to CUDA...") + cuda_tensor = mmap_tensor.to('cuda') + torch.cuda.synchronize() + + print(f" Device: {cuda_tensor.device}") + final_sum = cuda_tensor.sum().item() + print(f" Sum: {final_sum:.2f}") + + # Verify GPU memory is used again + gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3) + print(f" GPU memory: {gpu_after_cuda:.2f} GB") + + # Step 4: Verify data consistency + print("\n4. Verifying data consistency...") + sum_diff = abs(original_sum - final_sum) + print(f" Original sum: {original_sum:.2f}") + print(f" Final sum: {final_sum:.2f}") + print(f" Difference: {sum_diff:.6f}") + assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}" + + # Step 5: Verify file cleanup (delayed until garbage collection) + print("\n5. Verifying file cleanup...") + # Delete the mmap tensor reference to trigger garbage collection + del mmap_tensor + gc.collect() + import time + time.sleep(0.1) # Give OS time to clean up + temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')]) + print(f" Temp mmap files after GC: {temp_files_after}") + # File should be cleaned up after garbage collection + assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after garbage collection" + + print("\n✓ Test passed!") + print(" CUDA -> mmap -> CUDA cycle works correctly") + print(f" CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)") + print(" Data consistency maintained") + print(" File cleanup successful (via garbage collection)") + + # Cleanup + del cuda_tensor # mmap_tensor already deleted in Step 5 + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == "__main__": + # Run the tests directly + test_model_to_mmap_memory_efficiency() + test_to_mmap_cuda_cycle() +