From b11b085fcfb0116dc8a9b85e5c2b60fa41eab690 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 11 Mar 2026 16:28:17 +0800 Subject: [PATCH 01/10] Reformat S2V models --- examples/wan2.2/predict_s2v.py | 11 +- scripts/longcatvideo/train_avatar.py | 72 ++------ scripts/longcatvideo/train_avatar_lora.py | 75 ++------ scripts/wan2.2/train_s2v.py | 22 ++- scripts/wan2.2/train_s2v_lora.py | 22 ++- videox_fun/data/dataset_video.py | 109 +++++++----- videox_fun/models/__init__.py | 3 +- .../models/fantasytalking_audio_encoder.py | 10 +- .../models/longcatvideo_audio_encoder.py | 166 +++++++++++++++++- videox_fun/models/wan_audio_encoder.py | 63 +++---- .../pipeline/pipeline_fantasy_talking.py | 22 ++- .../pipeline/pipeline_longcatvideo_avatar.py | 88 ++-------- videox_fun/pipeline/pipeline_wan2_2_s2v.py | 14 +- videox_fun/utils/utils.py | 5 +- 14 files changed, 371 insertions(+), 311 deletions(-) diff --git a/examples/wan2.2/predict_s2v.py b/examples/wan2.2/predict_s2v.py index 2c526f2c..2b5dcf13 100644 --- a/examples/wan2.2/predict_s2v.py +++ b/examples/wan2.2/predict_s2v.py @@ -104,7 +104,8 @@ # Other params sample_size = [832, 480] -video_length = 80 +# How many frames to generate per clips. 
+infer_frames = 80 fps = 16 # Use torch.float16 if GPU does not support torch.bfloat16 @@ -311,8 +312,8 @@ pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): - video_length = video_length // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio if video_length != 1 else 1 - latent_frames = video_length // vae.config.temporal_compression_ratio + infer_frames = infer_frames // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio if infer_frames != 1 else 1 + latent_frames = infer_frames // vae.config.temporal_compression_ratio if enable_riflex: pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames) @@ -322,11 +323,11 @@ if ref_image is not None: ref_image = get_image_latent(ref_image, sample_size=sample_size) - pose_video, _, _, _ = get_video_to_video_latent(control_video, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None) + pose_video, _, _, _ = get_video_to_video_latent(control_video, video_length=None, sample_size=sample_size, fps=fps, ref_image=None) sample = pipeline( prompt, - num_frames = video_length, + num_frames = infer_frames, negative_prompt = negative_prompt, height = sample_size[0], width = sample_size[1], diff --git a/scripts/longcatvideo/train_avatar.py b/scripts/longcatvideo/train_avatar.py index 3d54e44b..0525d0c8 100644 --- a/scripts/longcatvideo/train_avatar.py +++ b/scripts/longcatvideo/train_avatar.py @@ -79,9 +79,9 @@ get_random_mask) from videox_fun.data.dataset_video import VideoSpeechDataset from videox_fun.models import (AutoencoderKLLongCatVideo, AutoTokenizer, - CLIPModel, LongCatVideoAvatarTransformer3DModel, - UMT5EncoderModel, Wav2Vec2FeatureExtractor, - Wav2Vec2ModelWrapper) + CLIPModel, LongCatVideoAudioEncoder, + LongCatVideoAvatarTransformer3DModel, + UMT5EncoderModel) from videox_fun.pipeline import LongCatVideoAvatarPipeline from 
videox_fun.utils.discrete_sampler import DiscreteSampling from videox_fun.utils.utils import (calculate_dimensions, @@ -172,7 +172,7 @@ def resize_mask(mask, latent, process_first_frame_only=True): logger = get_logger(__name__, log_level="INFO") -def log_validation(vae, text_encoder, tokenizer, audio_encoder, wav2vec_feature_extractor, transformer3d, args, accelerator, weight_dtype, global_step): +def log_validation(vae, text_encoder, tokenizer, audio_encoder, transformer3d, args, accelerator, weight_dtype, global_step): try: is_deepspeed = type(transformer3d).__name__ == 'DeepSpeedEngine' if is_deepspeed: @@ -192,7 +192,6 @@ def log_validation(vae, text_encoder, tokenizer, audio_encoder, wav2vec_feature_ transformer=accelerator.unwrap_model(transformer3d) if type(transformer3d).__name__ == 'DistributedDataParallel' else transformer3d, scheduler=scheduler, audio_encoder=audio_encoder, - wav2vec_feature_extractor=wav2vec_feature_extractor, ) pipeline = pipeline.to(accelerator.device) @@ -881,15 +880,10 @@ def deepspeed_zero_init_disabled_context_manager(): vae.eval() # Get Audio encoder (for avatar mode) - audio_encoder = Wav2Vec2ModelWrapper( + audio_encoder = LongCatVideoAudioEncoder( os.path.join(args.pretrained_avatar_model_name_or_path, 'chinese-wav2vec2-base') ) - audio_encoder.feature_extractor._freeze_parameters() - - wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( - os.path.join(args.pretrained_avatar_model_name_or_path, 'chinese-wav2vec2-base'), - local_files_only=True - ) + audio_encoder.audio_encoder.feature_extractor._freeze_parameters() # Get Transformer transformer3d = LongCatVideoAvatarTransformer3DModel.from_pretrained( @@ -1670,53 +1664,16 @@ def _batch_encode_vae(pixel_values): inpaint_latents = (inpaint_latents - latents_mean) * latents_std with torch.no_grad(): - def _loudness_norm(audio_array, sr=16000, lufs=-23, threshold=100): - meter = pyln.Meter(sr) - loudness = meter.integrated_loudness(audio_array) - if abs(loudness) 
> threshold: - return audio_array - normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs) - return normalized_audio - - def _add_noise_floor(audio, noise_db=-45): - noise_amp = 10 ** (noise_db / 20) - noise = np.random.randn(len(audio)) * noise_amp - return audio + noise - - def _smooth_transients(audio, sr=16000): - b, a = ss.butter(3, 3000 / (sr/2)) - return ss.lfilter(b, a, audio) - audio_stride = 2 + num_frames = pixel_values.size()[1] audio_cond_embs = [] for index, speech_array in enumerate(audio): - # speech preprocess - speech_array = _loudness_norm(speech_array.cpu().numpy(), sample_rate[index]) - speech_array = _add_noise_floor(speech_array) - speech_array = _smooth_transients(speech_array) - - # wav2vec_feature_extractor - audio_feature = np.squeeze( - wav2vec_feature_extractor(speech_array, sampling_rate=sample_rate[index]).input_values - ) - audio_feature = torch.from_numpy(audio_feature).float().to(device=accelerator.device) - audio_feature = audio_feature.unsqueeze(0) - - # audio embedding - embeddings = audio_encoder(audio_feature, seq_len=int(audio_stride * pixel_values.size()[1]), output_hidden_states=True) - - audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0) - audio_emb = rearrange(audio_emb, "b s d -> s b d").contiguous() # T, 12, 768 - - # Prepare audio embedding with sliding window - indices = torch.arange(2 * 2 + 1) - 2 # [-2, -1, 0, 1, 2] - audio_start_idx = 0 - audio_end_idx = audio_start_idx + audio_stride * pixel_values.size()[1] - - center_indices = torch.arange(audio_start_idx, audio_end_idx, audio_stride).unsqueeze(1) + indices.unsqueeze(0) - center_indices = torch.clamp(center_indices, min=0, max=audio_emb.shape[0] - 1) - audio_emb = audio_emb[center_indices][None, ...].to(accelerator.device) - + audio_emb = audio_encoder.extract_audio_feat_without_file_load( + audio_segment=speech_array.cpu().numpy(), + sample_rate=sample_rate[index], + num_frames=num_frames, + audio_stride=audio_stride + 
).to(accelerator.device) audio_cond_embs.append(audio_emb) audio_cond_embs = torch.cat(audio_cond_embs, dim=0) @@ -1911,8 +1868,7 @@ def custom_mse_loss(noise_pred, target, weighting=None, threshold=50): vae, text_encoder, tokenizer, - audio_encoder, - wav2vec_feature_extractor, + audio_encoder, transformer3d, args, accelerator, diff --git a/scripts/longcatvideo/train_avatar_lora.py b/scripts/longcatvideo/train_avatar_lora.py index 654ce4f9..12db804a 100644 --- a/scripts/longcatvideo/train_avatar_lora.py +++ b/scripts/longcatvideo/train_avatar_lora.py @@ -79,9 +79,9 @@ get_random_mask) from videox_fun.data.dataset_video import VideoSpeechDataset from videox_fun.models import (AutoencoderKLLongCatVideo, AutoTokenizer, - CLIPModel, LongCatVideoAvatarTransformer3DModel, - UMT5EncoderModel, Wav2Vec2FeatureExtractor, - Wav2Vec2ModelWrapper) + CLIPModel, LongCatVideoAudioEncoder, + LongCatVideoAvatarTransformer3DModel, + UMT5EncoderModel) from videox_fun.pipeline import LongCatVideoAvatarPipeline from videox_fun.utils.discrete_sampler import DiscreteSampling from videox_fun.utils.lora_utils import (convert_peft_lora_to_kohya_lora, @@ -175,7 +175,7 @@ def resize_mask(mask, latent, process_first_frame_only=True): logger = get_logger(__name__, log_level="INFO") -def log_validation(vae, text_encoder, tokenizer, audio_encoder, wav2vec_feature_extractor, transformer3d, network, args, accelerator, weight_dtype, global_step): +def log_validation(vae, text_encoder, tokenizer, audio_encoder, transformer3d, network, args, accelerator, weight_dtype, global_step): try: is_deepspeed = type(transformer3d).__name__ == 'DeepSpeedEngine' if is_deepspeed: @@ -195,7 +195,6 @@ def log_validation(vae, text_encoder, tokenizer, audio_encoder, wav2vec_feature_ transformer=accelerator.unwrap_model(transformer3d) if type(transformer3d).__name__ == 'DistributedDataParallel' else transformer3d, scheduler=scheduler, audio_encoder=audio_encoder, - wav2vec_feature_extractor=wav2vec_feature_extractor, ) 
pipeline = pipeline.to(accelerator.device) @@ -881,15 +880,10 @@ def deepspeed_zero_init_disabled_context_manager(): vae.eval() # Get Audio encoder (for avatar mode) - audio_encoder = Wav2Vec2ModelWrapper( + audio_encoder = LongCatVideoAudioEncoder( os.path.join(args.pretrained_avatar_model_name_or_path, 'chinese-wav2vec2-base') ) - audio_encoder.feature_extractor._freeze_parameters() - - wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( - os.path.join(args.pretrained_avatar_model_name_or_path, 'chinese-wav2vec2-base'), - local_files_only=True - ) + audio_encoder.audio_encoder.feature_extractor._freeze_parameters() # Get Transformer transformer3d = LongCatVideoAvatarTransformer3DModel.from_pretrained( @@ -1709,53 +1703,16 @@ def _batch_encode_vae(pixel_values): inpaint_latents = (inpaint_latents - latents_mean) * latents_std with torch.no_grad(): - def _loudness_norm(audio_array, sr=16000, lufs=-23, threshold=100): - meter = pyln.Meter(sr) - loudness = meter.integrated_loudness(audio_array) - if abs(loudness) > threshold: - return audio_array - normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs) - return normalized_audio - - def _add_noise_floor(audio, noise_db=-45): - noise_amp = 10 ** (noise_db / 20) - noise = np.random.randn(len(audio)) * noise_amp - return audio + noise - - def _smooth_transients(audio, sr=16000): - b, a = ss.butter(3, 3000 / (sr/2)) - return ss.lfilter(b, a, audio) - audio_stride = 2 + num_frames = pixel_values.size()[1] audio_cond_embs = [] for index, speech_array in enumerate(audio): - # speech preprocess - speech_array = _loudness_norm(speech_array.cpu().numpy(), sample_rate[index]) - speech_array = _add_noise_floor(speech_array) - speech_array = _smooth_transients(speech_array) - - # wav2vec_feature_extractor - audio_feature = np.squeeze( - wav2vec_feature_extractor(speech_array, sampling_rate=sample_rate[index]).input_values - ) - audio_feature = 
torch.from_numpy(audio_feature).float().to(device=accelerator.device) - audio_feature = audio_feature.unsqueeze(0) - - # audio embedding - embeddings = audio_encoder(audio_feature, seq_len=int(audio_stride * pixel_values.size()[1]), output_hidden_states=True) - - audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0) - audio_emb = rearrange(audio_emb, "b s d -> s b d").contiguous() # T, 12, 768 - - # Prepare audio embedding with sliding window - indices = torch.arange(2 * 2 + 1) - 2 # [-2, -1, 0, 1, 2] - audio_start_idx = 0 - audio_end_idx = audio_start_idx + audio_stride * pixel_values.size()[1] - - center_indices = torch.arange(audio_start_idx, audio_end_idx, audio_stride).unsqueeze(1) + indices.unsqueeze(0) - center_indices = torch.clamp(center_indices, min=0, max=audio_emb.shape[0] - 1) - audio_emb = audio_emb[center_indices][None, ...].to(accelerator.device) - + audio_emb = audio_encoder.extract_audio_feat_without_file_load( + audio_segment=speech_array.cpu().numpy(), + sample_rate=sample_rate[index], + num_frames=num_frames, + audio_stride=audio_stride + ).to(accelerator.device) audio_cond_embs.append(audio_emb) audio_cond_embs = torch.cat(audio_cond_embs, dim=0) @@ -1937,8 +1894,7 @@ def custom_mse_loss(noise_pred, target, weighting=None, threshold=50): vae, text_encoder, tokenizer, - audio_encoder, - wav2vec_feature_extractor, + audio_encoder, transformer3d, network, args, @@ -1958,8 +1914,7 @@ def custom_mse_loss(noise_pred, target, weighting=None, threshold=50): vae, text_encoder, tokenizer, - audio_encoder, - wav2vec_feature_extractor, + audio_encoder, transformer3d, network, args, diff --git a/scripts/wan2.2/train_s2v.py b/scripts/wan2.2/train_s2v.py index 1175379c..4cb5b63d 100644 --- a/scripts/wan2.2/train_s2v.py +++ b/scripts/wan2.2/train_s2v.py @@ -1690,6 +1690,8 @@ def _batch_encode_vae(pixel_values): new_pixel_values.append(pixel_values_bs) return torch.cat(new_pixel_values, dim = 0) + # Control pixel values Process Start + # Used 
in padding if rng is None: zero_tail_frames = np.random.choice([0, 1], p = [0.90, 0.10]) else: @@ -1718,6 +1720,7 @@ def _batch_encode_vae(pixel_values): ref_latents = _batch_encode_vae(ref_pixel_values) # Encode Motion latents + # Determine whether to set motion_pixel_values to all zeros; all zeros means no reference value. if rng is None: zero_motion_pixel_values = np.random.choice([0, 1], p = [0.90, 0.10]) else: @@ -1726,7 +1729,13 @@ def _batch_encode_vae(pixel_values): height, width = control_pixel_values.size()[-2], control_pixel_values.size()[-1] motion_pixel_values = torch.zeros([1, args.motion_frames, 3, height, width], dtype=control_latents.dtype, device=control_latents.device) + # has_motion_pixel_values indicates whether there is a reference value; True means yes, False means no + # If there is reference content, it corresponds to the nth generation (not the first round), so the reference value is not processed. + # If there is no reference content, a reference value (first frame) can be assigned at this time or no operation is performed. has_motion_pixel_values = torch.sum(motion_pixel_values) == 0 + # Check clip_idx to see if ref_latents is the first frame + # If clip_idx is 0, it means ref_latents is the first frame, and a reference value can be assigned at this time + # If clip_idx is not 0, it means ref_latents is not the first frame, and a reference value cannot be assigned at this time if torch.sum(clip_idx) != 0: init_first_frame = False else: @@ -1735,16 +1744,20 @@ def _batch_encode_vae(pixel_values): else: init_first_frame = rng.choice([0, 1], p = [0.50, 0.50]) if init_first_frame or has_motion_pixel_values: + # If has_motion_pixel_values=False but enters the if statement, + # it means clip_idx is 0 and the first frame is used as reference. 
if not has_motion_pixel_values: motion_pixel_values[:, -6:, :] = ref_pixel_values - + motion_frames_latents_length = int((args.motion_frames - 1) / sample_n_frames_bucket_interval + 1) local_pixel_values = torch.cat([motion_pixel_values, pixel_values], dim = 1) local_latents = _batch_encode_vae(local_pixel_values) + # Separate motion_latents and the inferred latents latents = local_latents[:, :, motion_frames_latents_length:] motion_latents = local_latents[:, :, :motion_frames_latents_length] drop_motion_frames = False else: + # No motion_latents reference value, but has ref_latents; typically the first round of generation. local_pixel_values = torch.cat([ref_pixel_values, pixel_values], dim = 1) latents = _batch_encode_vae(local_pixel_values) latents = latents[:, :, 1:] @@ -1801,13 +1814,14 @@ def _batch_encode_vae(pixel_values): for bs_index in range(audio_wav2vec_fea.size()[0]): if rng is None: - zero_init_control_latents_conv_in = np.random.choice([0, 1], p = [0.90, 0.10]) + zero_init_audio_wav2vec_fea = np.random.choice([0, 1], p = [0.90, 0.10]) else: - zero_init_control_latents_conv_in = rng.choice([0, 1], p = [0.90, 0.10]) + zero_init_audio_wav2vec_fea = rng.choice([0, 1], p = [0.90, 0.10]) - if zero_init_control_latents_conv_in: + if zero_init_audio_wav2vec_fea: audio_wav2vec_fea[bs_index] = torch.ones_like(audio_wav2vec_fea[bs_index]) * 0 + # Used in padding if zero_tail_frames: audio_wav2vec_fea[..., zero_frames_num:] = torch.zeros_like(audio_wav2vec_fea[..., zero_frames_num:]) # audio_wav2vec_fea = audio_wav2vec_fea[..., :control_pixel_values.size()[1]] diff --git a/scripts/wan2.2/train_s2v_lora.py b/scripts/wan2.2/train_s2v_lora.py index a0d23943..078602cd 100644 --- a/scripts/wan2.2/train_s2v_lora.py +++ b/scripts/wan2.2/train_s2v_lora.py @@ -1700,6 +1700,8 @@ def _batch_encode_vae(pixel_values): new_pixel_values.append(pixel_values_bs) return torch.cat(new_pixel_values, dim = 0) + # Control pixel values Process Start + # Used in padding if rng is 
None: zero_tail_frames = np.random.choice([0, 1], p = [0.90, 0.10]) else: @@ -1728,6 +1730,7 @@ def _batch_encode_vae(pixel_values): ref_latents = _batch_encode_vae(ref_pixel_values) # Encode Motion latents + # Determine whether to set motion_pixel_values to all zeros; all zeros means no reference value. if rng is None: zero_motion_pixel_values = np.random.choice([0, 1], p = [0.90, 0.10]) else: @@ -1736,7 +1739,13 @@ def _batch_encode_vae(pixel_values): height, width = control_pixel_values.size()[-2], control_pixel_values.size()[-1] motion_pixel_values = torch.zeros([1, args.motion_frames, 3, height, width], dtype=control_latents.dtype, device=control_latents.device) + # has_motion_pixel_values indicates whether there is a reference value; True means yes, False means no + # If there is reference content, it corresponds to the nth generation (not the first round), so the reference value is not processed. + # If there is no reference content, a reference value (first frame) can be assigned at this time or no operation is performed. has_motion_pixel_values = torch.sum(motion_pixel_values) == 0 + # Check clip_idx to see if ref_latents is the first frame + # If clip_idx is 0, it means ref_latents is the first frame, and a reference value can be assigned at this time + # If clip_idx is not 0, it means ref_latents is not the first frame, and a reference value cannot be assigned at this time if torch.sum(clip_idx) != 0: init_first_frame = False else: @@ -1745,16 +1754,20 @@ def _batch_encode_vae(pixel_values): else: init_first_frame = rng.choice([0, 1], p = [0.50, 0.50]) if init_first_frame or has_motion_pixel_values: + # If has_motion_pixel_values=False but enters the if statement, + # it means clip_idx is 0 and the first frame is used as reference. 
if not has_motion_pixel_values: motion_pixel_values[:, -6:, :] = ref_pixel_values - + motion_frames_latents_length = int((args.motion_frames - 1) / sample_n_frames_bucket_interval + 1) local_pixel_values = torch.cat([motion_pixel_values, pixel_values], dim = 1) local_latents = _batch_encode_vae(local_pixel_values) + # Separate motion_latents and the inferred latents latents = local_latents[:, :, motion_frames_latents_length:] motion_latents = local_latents[:, :, :motion_frames_latents_length] drop_motion_frames = False else: + # No motion_latents reference value, but has ref_latents; typically the first round of generation. local_pixel_values = torch.cat([ref_pixel_values, pixel_values], dim = 1) latents = _batch_encode_vae(local_pixel_values) latents = latents[:, :, 1:] @@ -1811,13 +1824,14 @@ def _batch_encode_vae(pixel_values): for bs_index in range(audio_wav2vec_fea.size()[0]): if rng is None: - zero_init_control_latents_conv_in = np.random.choice([0, 1], p = [0.90, 0.10]) + zero_init_audio_wav2vec_fea = np.random.choice([0, 1], p = [0.90, 0.10]) else: - zero_init_control_latents_conv_in = rng.choice([0, 1], p = [0.90, 0.10]) + zero_init_audio_wav2vec_fea = rng.choice([0, 1], p = [0.90, 0.10]) - if zero_init_control_latents_conv_in: + if zero_init_audio_wav2vec_fea: audio_wav2vec_fea[bs_index] = torch.ones_like(audio_wav2vec_fea[bs_index]) * 0 + # Used in padding if zero_tail_frames: audio_wav2vec_fea[..., zero_frames_num:] = torch.zeros_like(audio_wav2vec_fea[..., zero_frames_num:]) # audio_wav2vec_fea = audio_wav2vec_fea[..., :control_pixel_values.size()[1]] diff --git a/videox_fun/data/dataset_video.py b/videox_fun/data/dataset_video.py index 91147603..bf8c9695 100644 --- a/videox_fun/data/dataset_video.py +++ b/videox_fun/data/dataset_video.py @@ -214,7 +214,8 @@ def __init__( self, ann_path, data_root=None, video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, - enable_bucket=False, enable_inpaint=False, + enable_bucket=False, + 
enable_inpaint=False, audio_sr=16000, # New: target audio sample rate text_drop_ratio=0.1 # New: text drop probability ): @@ -290,35 +291,35 @@ def get_batch(self, idx): pixel_values = pixel_values / 255. pixel_values = self.pixel_transforms(pixel_values) - # === New: Load and extract the corresponding audio segment === - # Start and end times (in seconds) of the video clip - start_time = start_frame / fps - end_time = (start_frame + (actual_n_frames - 1) * self.video_sample_stride) / fps - duration = end_time - start_time + # === New: Load and extract the corresponding audio segment === + # Start and end times (in seconds) of the video clip + start_time = start_frame / fps + end_time = (start_frame + (actual_n_frames - 1) * self.video_sample_stride) / fps + duration = end_time - start_time - # Use librosa to load the entire audio (librosa.load does not support precise seeking, so load first then slice) - audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr) # Resample to target sr + # Use librosa to load the entire audio (librosa.load does not support precise seeking, so load first then slice) + audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr) # Resample to target sr - # Convert to sample indices - start_sample = int(start_time * self.audio_sr) - end_sample = int(end_time * self.audio_sr) + # Convert to sample indices + start_sample = int(start_time * self.audio_sr) + end_sample = int(end_time * self.audio_sr) - # Safe slicing - if start_sample >= len(audio_input): - # Audio is too short, pad with zeros or truncate - audio_segment = np.zeros(int(duration * self.audio_sr), dtype=np.float32) - else: - audio_segment = audio_input[start_sample:end_sample] - # If too short, pad with zeros - target_len = int(duration * self.audio_sr) - if len(audio_segment) < target_len: - audio_segment = np.pad(audio_segment, (0, target_len - len(audio_segment)), mode='constant') + # Safe slicing + if start_sample >= len(audio_input): + # Audio is 
too short, pad with zeros or truncate + raise ValueError(f"Audio file too short: {audio_path}") + else: + audio_segment = audio_input[start_sample:end_sample] + # If too short, pad with zeros + target_len = int(duration * self.audio_sr) + if len(audio_segment) < target_len: + raise ValueError(f"Audio file too short: {audio_path}") - # === Random text dropping === - if random.random() < self.text_drop_ratio: - text = '' + # === Random text dropping === + if random.random() < self.text_drop_ratio: + text = '' - return pixel_values, text, audio_segment, sample_rate + return pixel_values, text, audio_segment, sample_rate def __len__(self): return self.length @@ -356,7 +357,8 @@ def __init__( self, ann_path, data_root=None, video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, - enable_bucket=False, enable_inpaint=False, + enable_bucket=False, + enable_inpaint=False, audio_sr=16000, text_drop_ratio=0.1, enable_motion_info=False, @@ -415,24 +417,29 @@ def get_batch(self, idx): # Video information with VideoReader_contextmanager(video_path, num_threads=2) as video_reader: total_frames = len(video_reader) - fps = video_reader.get_avg_fps() + fps = video_reader.get_avg_fps() # Get the original video frame rate if fps <= 0: raise ValueError(f"Video has negative fps: {video_path}") + + # Avoid fps > 30 local_video_sample_stride = self.video_sample_stride new_fps = int(fps // local_video_sample_stride) while new_fps > 30: local_video_sample_stride = local_video_sample_stride + 1 new_fps = int(fps // local_video_sample_stride) + # Calculate the actual number of sampled video frames (considering boundaries) max_possible_frames = (total_frames - 1) // local_video_sample_stride + 1 actual_n_frames = min(self.video_sample_n_frames, max_possible_frames) if actual_n_frames <= 0: raise ValueError(f"Video too short: {video_path}") + # Randomly select the starting frame max_start = total_frames - (actual_n_frames - 1) * local_video_sample_stride - 1 start_frame = 
random.randint(0, max_start) if max_start > 0 else 0 frame_indices = [start_frame + i * local_video_sample_stride for i in range(actual_n_frames)] + # Read video frames try: sample_args = (video_reader, frame_indices) pixel_values = func_timeout( @@ -443,6 +450,7 @@ def get_batch(self, idx): except Exception as e: raise ValueError(f"Failed to extract frames from video. Error is {e}.") + # Motion Video Process for Wan-S2V _, height, width, channel = np.shape(pixel_values) if self.enable_motion_info: motion_pixel_values = np.ones([self.motion_frames, height, width, channel]) * 127.5 @@ -464,28 +472,12 @@ def get_batch(self, idx): else: motion_pixel_values = None + # Video post-processing if not self.enable_bucket: pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() pixel_values = pixel_values / 255. pixel_values = self.pixel_transforms(pixel_values) - # Audio information - start_time = start_frame / fps - end_time = (start_frame + (actual_n_frames - 1) * local_video_sample_stride) / fps - duration = end_time - start_time - - audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr) - start_sample = int(start_time * self.audio_sr) - end_sample = int(end_time * self.audio_sr) - - if start_sample >= len(audio_input): - raise ValueError(f"Audio file too short: {audio_path}") - else: - audio_segment = audio_input[start_sample:end_sample] - target_len = int(duration * self.audio_sr) - if len(audio_segment) < target_len: - raise ValueError(f"Audio file too short: {audio_path}") - # Control information with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader: try: @@ -507,13 +499,34 @@ def get_batch(self, idx): if not self.enable_bucket: control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous() control_pixel_values = control_pixel_values / 255. 
+ control_pixel_values = self.pixel_transforms(control_pixel_values) del control_video_reader - else: - control_pixel_values = control_pixel_values - if not self.enable_bucket: - control_pixel_values = self.video_transforms(control_pixel_values) + # === New: Load and extract the corresponding audio segment === + # Start and end times (in seconds) of the video clip + start_time = start_frame / fps + end_time = (start_frame + (actual_n_frames - 1) * local_video_sample_stride) / fps + duration = end_time - start_time + + # Use librosa to load the entire audio (librosa.load does not support precise seeking, so load first then slice) + audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr) # Resample to target sr + + # Convert to sample indices + start_sample = int(start_time * self.audio_sr) + end_sample = int(end_time * self.audio_sr) + + # Safe slicing + if start_sample >= len(audio_input): + # Audio is too short, pad with zeros or truncate + raise ValueError(f"Audio file too short: {audio_path}") + else: + audio_segment = audio_input[start_sample:end_sample] + # If too short, pad with zeros + target_len = int(duration * self.audio_sr) + if len(audio_segment) < target_len: + raise ValueError(f"Audio file too short: {audio_path}") + # === Random text dropping === if random.random() < self.text_drop_ratio: text = '' diff --git a/videox_fun/models/__init__.py b/videox_fun/models/__init__.py index 7f0d9f9b..6b5a14ef 100755 --- a/videox_fun/models/__init__.py +++ b/videox_fun/models/__init__.py @@ -30,7 +30,8 @@ from .flux_transformer2d import FluxTransformer2DModel from .hunyuanvideo_transformer3d import HunyuanVideoTransformer3DModel from .hunyuanvideo_vae import AutoencoderKLHunyuanVideo -from .longcatvideo_audio_encoder import Wav2Vec2ModelWrapper +from .longcatvideo_audio_encoder import (LongCatVideoAudioEncoder, + Wav2Vec2ModelWrapper) from .longcatvideo_transformer3d import LongCatVideoTransformer3DModel from .longcatvideo_transformer3d_avatar import 
\ LongCatVideoAvatarTransformer3DModel diff --git a/videox_fun/models/fantasytalking_audio_encoder.py b/videox_fun/models/fantasytalking_audio_encoder.py index 6201e1f9..4b32e428 100644 --- a/videox_fun/models/fantasytalking_audio_encoder.py +++ b/videox_fun/models/fantasytalking_audio_encoder.py @@ -20,7 +20,13 @@ def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device=' self.model = Wav2Vec2Model.from_pretrained(pretrained_model_path) self.model = self.model.to(device) - def extract_audio_feat(self, audio_path, num_frames = 81, fps = 16, sr = 16000): + def extract_audio_feat( + self, + audio_path, + num_frames = 81, + fps = 16, + sr = 16000 + ): audio_input, sample_rate = librosa.load(audio_path, sr=sr) start_time = 0 @@ -34,6 +40,7 @@ def extract_audio_feat(self, audio_path, num_frames = 81, fps = 16, sr = 16000): except Exception: audio_segment = audio_input + # INFERENCE input_values = self.processor( audio_segment, sampling_rate=sample_rate, return_tensors="pt" ).input_values.to(self.model.device, self.model.dtype) @@ -47,6 +54,7 @@ def extract_audio_feat_without_file_load(self, audio_segment, sample_rate): audio_segment, sampling_rate=sample_rate, return_tensors="pt" ).input_values.to(self.model.device, self.model.dtype) + # INFERENCE with torch.no_grad(): fea = self.model(input_values).last_hidden_state return fea \ No newline at end of file diff --git a/videox_fun/models/longcatvideo_audio_encoder.py b/videox_fun/models/longcatvideo_audio_encoder.py index a3a41edd..c05433dc 100644 --- a/videox_fun/models/longcatvideo_audio_encoder.py +++ b/videox_fun/models/longcatvideo_audio_encoder.py @@ -1,17 +1,24 @@ # Modified from https://github.com/meituan-longcat/LongCat-Video/blob/main/longcat_video/audio_process/wav2vec2.py import copy import logging +import math import os +import librosa +import numpy as np import torch import torch.nn as nn -from transformers import Wav2Vec2Config +import torch.nn.functional as F +from 
diffusers.configuration_utils import ConfigMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.modeling_utils import ModelMixin +from einops import rearrange +from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Model as Wav2Vec2Model_base from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutput from transformers.models.wav2vec2.modeling_wav2vec2 import ( Wav2Vec2PositionalConvEmbedding, Wav2Vec2SamePadLayer) -import torch.nn.functional as F def linear_interpolation(features, seq_len): @@ -265,4 +272,157 @@ def encode( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - ) \ No newline at end of file + ) + + +class LongCatVideoAudioEncoder(ModelMixin, ConfigMixin, FromOriginalModelMixin): + """Audio encoder for LongCatVideo Avatar pipeline. + + This class provides a clean interface for audio feature extraction, + similar to FantasyTalkingAudioEncoder but with LongCatVideo-specific + audio preprocessing (loudness normalization, noise floor, transient smoothing). + + Uses existing Wav2Vec2ModelWrapper and Wav2Vec2FeatureExtractor internally. 
+ """ + + def __init__(self, config_path, device='cpu', prefix='wav2vec2.'): + super(LongCatVideoAudioEncoder, self).__init__() + + # Use existing Wav2Vec2ModelWrapper + self.audio_encoder = Wav2Vec2ModelWrapper(config_path, device=device, prefix=prefix) + + # Use existing Wav2Vec2FeatureExtractor + self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config_path) + + @property + def dtype(self): + return self.audio_encoder.dtype + + @property + def device(self): + return self.audio_encoder.device + + def _loudness_norm(self, audio_array, sr=16000, lufs=-23, threshold=100): + """Normalize audio loudness to target LUFS.""" + import pyloudnorm as pyln + meter = pyln.Meter(sr) + loudness = meter.integrated_loudness(audio_array) + if abs(loudness) > threshold: + return audio_array + normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs) + return normalized_audio + + def _add_noise_floor(self, audio, noise_db=-45): + """Add noise floor to audio.""" + noise_amp = 10 ** (noise_db / 20) + noise = np.random.randn(len(audio)) * noise_amp + return audio + noise + + def _smooth_transients(self, audio, sr=16000): + """Smooth audio transients using low-pass filter.""" + import scipy.signal as ss + b, a = ss.butter(3, 3000 / (sr / 2)) + return ss.lfilter(b, a, audio) + + def _preprocess_audio(self, speech_array, sample_rate=16000): + """Apply LongCatVideo-specific audio preprocessing.""" + speech_array = self._loudness_norm(speech_array, sample_rate) + speech_array = self._add_noise_floor(speech_array) + speech_array = self._smooth_transients(speech_array, sample_rate) + return speech_array + + @torch.no_grad() + def _extract_embedding(self, speech_array, sample_rate, num_frames, audio_stride=2): + """Core method to extract audio embedding from preprocessed speech array. + + Args: + speech_array: Preprocessed audio array. + sample_rate: Audio sample rate. + num_frames: Number of video frames. + audio_stride: Audio stride for sliding window. 
+ + Returns: + Audio embeddings tensor of shape [1, num_frames, 5, 12, 768]. + """ + seq_len = int(audio_stride * num_frames) + + # wav2vec_feature_extractor + audio_feature = np.squeeze( + self.wav2vec_feature_extractor(speech_array, sampling_rate=sample_rate).input_values + ) + audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device) + audio_feature = audio_feature.unsqueeze(0) + + # audio embedding using Wav2Vec2ModelWrapper + embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True) + + audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0) + audio_emb = rearrange(audio_emb, "b s d -> s b d").contiguous() # T, 12, 768 + + # Prepare audio embedding with sliding window + indices = torch.arange(2 * 2 + 1) - 2 # [-2, -1, 0, 1, 2] + audio_start_idx = 0 + audio_end_idx = audio_start_idx + audio_stride * num_frames + + center_indices = torch.arange(audio_start_idx, audio_end_idx, audio_stride).unsqueeze(1) + \ + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=audio_emb.shape[0] - 1) + audio_emb = audio_emb[center_indices][None, ...] # [1, num_frames, 5, 12, 768] + + return audio_emb + + def extract_audio_feat( + self, + audio_path, + num_frames=49, + fps=16, + sr=16000, + audio_stride=2 + ): + """Extract audio features from audio file. + + Args: + audio_path: Path to audio file. + num_frames: Number of video frames. + fps: Video frames per second. + sr: Audio sample rate. + audio_stride: Audio stride for sliding window. + + Returns: + Audio embeddings tensor of shape [1, num_frames, 5, 12, 768]. + """ + # Load audio + speech_array, sample_rate = librosa.load(audio_path, sr=sr) + + # Pad audio to target length + generate_duration = num_frames / fps + source_duration = len(speech_array) / sample_rate + added_sample_nums = math.ceil((generate_duration - source_duration) * sample_rate) + if added_sample_nums > 0: + speech_array = np.append(speech_array, [0.] 
* added_sample_nums) + + # Preprocess and extract embedding + speech_array = self._preprocess_audio(speech_array, sample_rate) + return self._extract_embedding(speech_array, sample_rate, num_frames, audio_stride) + + def extract_audio_feat_without_file_load( + self, + audio_segment, + sample_rate, + num_frames=49, + audio_stride=2 + ): + """Extract audio features from audio array without file loading. + + Args: + audio_segment: Audio array (numpy array). + sample_rate: Audio sample rate. + num_frames: Number of video frames. + audio_stride: Audio stride for sliding window. + + Returns: + Audio embeddings tensor of shape [1, num_frames, 5, 12, 768]. + """ + # Preprocess and extract embedding + speech_array = self._preprocess_audio(audio_segment, sample_rate) + return self._extract_embedding(speech_array, sample_rate, num_frames, audio_stride) \ No newline at end of file diff --git a/videox_fun/models/wan_audio_encoder.py b/videox_fun/models/wan_audio_encoder.py index 5dc652c8..011af85f 100644 --- a/videox_fun/models/wan_audio_encoder.py +++ b/videox_fun/models/wan_audio_encoder.py @@ -57,7 +57,6 @@ def linear_interpolation(features, input_fps, output_fps, output_len=None): class WanAudioEncoder(ModelMixin, ConfigMixin, FromOriginalModelMixin): - def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device='cpu'): super(WanAudioEncoder, self).__init__() # load pretrained model @@ -68,49 +67,44 @@ def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device=' self.video_rate = 30 - def extract_audio_feat(self, - audio_path, - return_all_layers=False, - dtype=torch.float32): - audio_input, sample_rate = librosa.load(audio_path, sr=16000) + def extract_audio_feat( + self, + audio_path, + return_all_layers=False, + sr = 16000, + ): + audio_input, sample_rate = librosa.load(audio_path, sr=sr) input_values = self.processor( audio_input, sampling_rate=sample_rate, return_tensors="pt" ).input_values # INFERENCE - - # retrieve logits & take 
argmax - res = self.model( - input_values.to(self.model.device), output_hidden_states=True) - if return_all_layers: - feat = torch.cat(res.hidden_states) - else: - feat = res.hidden_states[-1] - feat = linear_interpolation( - feat, input_fps=50, output_fps=self.video_rate) - - z = feat.to(dtype) # Encoding for the motion - return z + with torch.no_grad(): + res = self.model( + input_values.to(self.model.device), output_hidden_states=True) + if return_all_layers: + feat = torch.cat(res.hidden_states) + else: + feat = res.hidden_states[-1] + feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate) + return feat - def extract_audio_feat_without_file_load(self, audio_input, sample_rate, return_all_layers=False, dtype=torch.float32): + def extract_audio_feat_without_file_load(self, audio_segment, sample_rate, return_all_layers=False): input_values = self.processor( - audio_input, sampling_rate=sample_rate, return_tensors="pt" + audio_segment, sampling_rate=sample_rate, return_tensors="pt" ).input_values # INFERENCE - # retrieve logits & take argmax - res = self.model( - input_values.to(self.model.device), output_hidden_states=True) - if return_all_layers: - feat = torch.cat(res.hidden_states) - else: - feat = res.hidden_states[-1] - feat = linear_interpolation( - feat, input_fps=50, output_fps=self.video_rate) - - z = feat.to(dtype) # Encoding for the motion - return z + with torch.no_grad(): + res = self.model( + input_values.to(self.model.device), output_hidden_states=True) + if return_all_layers: + feat = torch.cat(res.hidden_states) + else: + feat = res.hidden_states[-1] + feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate) + return feat def get_audio_embed_bucket(self, audio_embed, @@ -207,7 +201,6 @@ def get_audio_embed_bucket_fps(self, torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \ else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device) 
batch_audio_eb.append(frame_audio_embed) - batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], - dim=0) + batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb], dim=0) return batch_audio_eb, min_batch_num \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_fantasy_talking.py b/videox_fun/pipeline/pipeline_fantasy_talking.py index 114bd228..8243033a 100644 --- a/videox_fun/pipeline/pipeline_fantasy_talking.py +++ b/videox_fun/pipeline/pipeline_fantasy_talking.py @@ -22,7 +22,7 @@ from transformers import T5Tokenizer from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, - Wan2_2Transformer3DModel_S2V, WanAudioEncoder, + FantasyTalkingTransformer3DModel, FantasyTalkingAudioEncoder, WanT5EncoderModel) from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas) @@ -159,8 +159,8 @@ class FantasyTalkingPipeline(DiffusionPipeline): library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
""" - _optional_components = ["transformer_2", "audio_encoder"] - model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + _optional_components = ["audio_encoder"] + model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae" _callback_tensor_inputs = [ "latents", @@ -172,18 +172,17 @@ def __init__( self, tokenizer: AutoTokenizer, text_encoder: WanT5EncoderModel, - audio_encoder: WanAudioEncoder, + audio_encoder: FantasyTalkingAudioEncoder, vae: AutoencoderKLWan, - transformer: Wan2_2Transformer3DModel_S2V, + transformer: FantasyTalkingTransformer3DModel, clip_image_encoder: CLIPModel, - transformer_2: Wan2_2Transformer3DModel_S2V = None, scheduler: FlowMatchEulerDiscreteScheduler = None, ): super().__init__() self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, - transformer_2=transformer_2, scheduler=scheduler, clip_image_encoder=clip_image_encoder, audio_encoder=audio_encoder + clip_image_encoder=clip_image_encoder, audio_encoder=audio_encoder, scheduler=scheduler, ) self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) @@ -316,6 +315,11 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + def encode_audio_embeddings(self, audio_path, num_frames, fps, weight_dtype, device): + audio_wav2vec_fea = self.audio_encoder.extract_audio_feat(audio_path, num_frames=num_frames, fps=fps) + audio_wav2vec_fea = audio_wav2vec_fea.to(device, weight_dtype) + return audio_wav2vec_fea + def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None ): @@ -655,7 +659,9 @@ def __call__( clip_context = torch.zeros_like(clip_context) # Extract audio emb - audio_wav2vec_fea = self.audio_encoder.extract_audio_feat(audio_path, num_frames=num_frames, fps=fps) + 
audio_wav2vec_fea = self.encode_audio_embeddings( + audio_path, num_frames=num_frames, fps=fps, weight_dtype=weight_dtype, device=device + ) if comfyui_progressbar: pbar.update(1) diff --git a/videox_fun/pipeline/pipeline_longcatvideo_avatar.py b/videox_fun/pipeline/pipeline_longcatvideo_avatar.py index dd4cec8a..dc0d2dc7 100644 --- a/videox_fun/pipeline/pipeline_longcatvideo_avatar.py +++ b/videox_fun/pipeline/pipeline_longcatvideo_avatar.py @@ -5,7 +5,6 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import librosa import numpy as np import torch import torch.nn.functional as F @@ -22,9 +21,9 @@ from PIL import Image from ..models import (AutoencoderKLWan, AutoTokenizer, + LongCatVideoAudioEncoder, LongCatVideoAvatarTransformer3DModel, - LongCatVideoTransformer3DModel, UMT5EncoderModel, - Wav2Vec2FeatureExtractor, Wav2Vec2ModelWrapper) + LongCatVideoTransformer3DModel, UMT5EncoderModel) logger = logging.get_logger(__name__) @@ -141,8 +140,7 @@ def __init__( vae: AutoencoderKLWan, transformer: LongCatVideoAvatarTransformer3DModel, scheduler: FlowMatchEulerDiscreteScheduler, - audio_encoder: Wav2Vec2ModelWrapper, - wav2vec_feature_extractor: Wav2Vec2FeatureExtractor + audio_encoder: LongCatVideoAudioEncoder, ): super().__init__() @@ -153,7 +151,6 @@ def __init__( transformer=transformer, scheduler=scheduler, audio_encoder=audio_encoder, - wav2vec_feature_extractor=wav2vec_feature_extractor ) self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 @@ -421,78 +418,15 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def _loudness_norm(self, audio_array, sr=16000, lufs=-23, threshold=100): - import pyloudnorm as pyln - meter = pyln.Meter(sr) - loudness = meter.integrated_loudness(audio_array) - if abs(loudness) > threshold: - return audio_array - normalized_audio = 
pyln.normalize.loudness(audio_array, loudness, lufs) - return normalized_audio - - def _add_noise_floor(self, audio, noise_db=-45): - noise_amp = 10 ** (noise_db / 20) - noise = np.random.randn(len(audio)) * noise_amp - return audio + noise - - def _smooth_transients(self, audio, sr=16000): - import scipy.signal as ss - b, a = ss.butter(3, 3000 / (sr/2)) - return ss.lfilter(b, a, audio) - - @torch.no_grad() - def get_audio_embedding(self, speech_array, fps=32, device='cpu', sample_rate=16000): - audio_duration = len(speech_array) / sample_rate - video_length = audio_duration * fps - - # speech preprocess - speech_array = self._loudness_norm(speech_array, sample_rate) - speech_array = self._add_noise_floor(speech_array) - speech_array = self._smooth_transients(speech_array) - - # wav2vec_feature_extractor - audio_feature = np.squeeze( - self.wav2vec_feature_extractor(speech_array, sampling_rate=sample_rate).input_values + def encode_audio_embeddings(self, audio_path, num_frames, fps, weight_dtype, device, audio_stride=2): + """Encode audio embeddings using LongCatVideoAudioEncoder.""" + audio_emb = self.audio_encoder.extract_audio_feat( + audio_path, + num_frames=num_frames, + fps=fps, + audio_stride=audio_stride ) - audio_feature = torch.from_numpy(audio_feature).float().to(device=device) - audio_feature = audio_feature.unsqueeze(0) - - # audio embedding - embeddings = self.audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True) - - audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0) - audio_emb = rearrange(audio_emb, "b s d -> s b d").contiguous() # T, 12, 768 - return audio_emb - - def encode_audio_embeddings(self, audio_path, num_frames, fps, weight_dtype, device, audio_stride = 2): - # Load and pad audio to target length - speech_array, sample_rate = librosa.load(audio_path, sr=16000) - - generate_duration = num_frames / fps - source_duration = len(speech_array) / sample_rate - added_sample_nums = 
math.ceil((generate_duration - source_duration) * sample_rate) - if added_sample_nums > 0: - speech_array = np.append(speech_array, [0.] * added_sample_nums) - - # Get audio embedding - with torch.no_grad(): - audio_emb = self.get_audio_embedding( - speech_array, - fps=fps * audio_stride, - device=device, - sample_rate=sample_rate - ) - - # Prepare audio embedding with sliding window - indices = torch.arange(2 * 2 + 1) - 2 # [-2, -1, 0, 1, 2] - audio_start_idx = 0 - audio_end_idx = audio_start_idx + audio_stride * num_frames - - center_indices = torch.arange(audio_start_idx, audio_end_idx, audio_stride).unsqueeze(1) + \ - indices.unsqueeze(0) - center_indices = torch.clamp(center_indices, min=0, max=audio_emb.shape[0] - 1) - audio_emb = audio_emb[center_indices][None, ...].to(device, weight_dtype) - return audio_emb + return audio_emb.to(device, weight_dtype) @property def guidance_scale(self): diff --git a/videox_fun/pipeline/pipeline_wan2_2_s2v.py b/videox_fun/pipeline/pipeline_wan2_2_s2v.py index eb421af3..6fbcbaf9 100644 --- a/videox_fun/pipeline/pipeline_wan2_2_s2v.py +++ b/videox_fun/pipeline/pipeline_wan2_2_s2v.py @@ -322,8 +322,7 @@ def encode_audio_embeddings(self, audio_path, num_frames, fps, weight_dtype, dev audio_path, return_all_layers=True) audio_embed_bucket, num_repeat = self.audio_encoder.get_audio_embed_bucket_fps( z, fps=fps, batch_frames=num_frames, m=self.audio_sample_m) - audio_embed_bucket = audio_embed_bucket.to(device, - weight_dtype) + audio_embed_bucket = audio_embed_bucket.to(device, weight_dtype) audio_embed_bucket = audio_embed_bucket.unsqueeze(0) if len(audio_embed_bucket.shape) == 3: audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) @@ -587,7 +586,9 @@ def __call__( # corresponds to doing no classifier free guidance. 
do_classifier_free_guidance = guidance_scale > 1.0 + # lat_motion_frames = 76 / 4 = 19 lat_motion_frames = (self.motion_frames + 3) // 4 + # lat_motion_frames ~= num_frames // 4 lat_target_frames = (num_frames + 3 + self.motion_frames) // 4 - lat_motion_frames # 3. Encode input prompt @@ -610,7 +611,7 @@ def __call__( from comfy.utils import ProgressBar pbar = ProgressBar(num_inference_steps + 2) - # 5. Prepare latents. + # 4. Prepare latents. latent_channels = self.vae.config.latent_channels if comfyui_progressbar: pbar.update(1) @@ -670,14 +671,14 @@ def __call__( if comfyui_progressbar: pbar.update(1) - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) videos = [] copy_timesteps = copy.deepcopy(timesteps) copy_latents = copy.deepcopy(latents) for r in range(num_repeat): - # Prepare timesteps + # 6. Prepare timesteps if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps, mu=1) elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): @@ -693,6 +694,7 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps) self._num_timesteps = len(timesteps) + # 7. Prepare latents again. target_shape = (self.vae.latent_channels, lat_target_frames, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) @@ -708,7 +710,7 @@ def __call__( copy_latents, num_length_latents=target_shape[1] ) - # 7. Denoising loop + # 8. 
Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self.transformer.num_inference_steps = num_inference_steps with self.progress_bar(total=num_inference_steps) as progress_bar: diff --git a/videox_fun/utils/utils.py b/videox_fun/utils/utils.py index d2084586..675c4fff 100755 --- a/videox_fun/utils/utils.py +++ b/videox_fun/utils/utils.py @@ -274,7 +274,10 @@ def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=N else: input_video = input_video_path - input_video = torch.from_numpy(np.array(input_video))[:video_length] + if video_length is not None: + input_video = torch.from_numpy(np.array(input_video))[:video_length] + else: + input_video = torch.from_numpy(np.array(input_video)) input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255 if validation_video_mask is not None: From 3ced952e4a2f0d9343d180ca63b703a625ba3a83 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Wed, 11 Mar 2026 16:35:51 +0800 Subject: [PATCH 02/10] Update s2v models training code --- scripts/fantasytalking/train.py | 18 ++++---- scripts/longcatvideo/train_avatar.py | 3 ++ scripts/longcatvideo/train_avatar_lora.py | 3 ++ scripts/wan2.2/train_s2v.py | 55 +++++++++++------------ scripts/wan2.2/train_s2v_lora.py | 55 +++++++++++------------ 5 files changed, 70 insertions(+), 64 deletions(-) diff --git a/scripts/fantasytalking/train.py b/scripts/fantasytalking/train.py index d943a6ff..fdee4277 100644 --- a/scripts/fantasytalking/train.py +++ b/scripts/fantasytalking/train.py @@ -1584,6 +1584,7 @@ def _create_special_list(length): torch.cuda.empty_cache() vae.to(accelerator.device) clip_image_encoder.to(accelerator.device) + audio_encoder.to(accelerator.device) if not args.enable_text_encoder_in_dataloader: text_encoder.to("cpu") @@ -1638,6 +1639,14 @@ def _batch_encode_vae(pixel_values): clip_context.append(_clip_context if not zero_init_clip_in else torch.zeros_like(_clip_context)) 
clip_context = torch.cat(clip_context) + + with torch.no_grad(): + # Extract audio emb + audio_wav2vec_fea = [] + for index in range(len(audio)): + _audio_wav2vec_fea = audio_encoder.extract_audio_feat_without_file_load(audio[index], sample_rate[index]) + audio_wav2vec_fea.append(_audio_wav2vec_fea) + audio_wav2vec_fea = torch.cat(audio_wav2vec_fea).to(weight_dtype) # wait for latents = vae.encode(pixel_values) to complete if vae_stream_1 is not None: @@ -1646,6 +1655,7 @@ def _batch_encode_vae(pixel_values): if args.low_vram: vae.to('cpu') clip_image_encoder.to('cpu') + audio_encoder.to("cpu") torch.cuda.empty_cache() if not args.enable_text_encoder_in_dataloader: text_encoder.to(accelerator.device) @@ -1669,14 +1679,6 @@ def _batch_encode_vae(pixel_values): prompt_embeds = text_encoder(text_input_ids.to(latents.device), attention_mask=prompt_attention_mask.to(latents.device))[0] prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - with torch.no_grad(): - # Extract audio emb - audio_wav2vec_fea = [] - for index in range(len(audio)): - _audio_wav2vec_fea = audio_encoder.extract_audio_feat_without_file_load(audio[index], sample_rate[index]) - audio_wav2vec_fea.append(_audio_wav2vec_fea) - audio_wav2vec_fea = torch.cat(audio_wav2vec_fea).to(weight_dtype) - if args.low_vram and not args.enable_text_encoder_in_dataloader: text_encoder.to('cpu') torch.cuda.empty_cache() diff --git a/scripts/longcatvideo/train_avatar.py b/scripts/longcatvideo/train_avatar.py index 0525d0c8..49a72194 100644 --- a/scripts/longcatvideo/train_avatar.py +++ b/scripts/longcatvideo/train_avatar.py @@ -1402,6 +1402,7 @@ def _create_special_list(length): vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + audio_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our 
total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1625,6 +1626,7 @@ def _create_special_list(length): if args.low_vram: torch.cuda.empty_cache() vae.to(accelerator.device) + audio_encoder.to(accelerator.device) if not args.enable_text_encoder_in_dataloader: text_encoder.to("cpu") @@ -1683,6 +1685,7 @@ def _batch_encode_vae(pixel_values): if args.low_vram: vae.to('cpu') + audio_encoder.to("cpu") torch.cuda.empty_cache() if not args.enable_text_encoder_in_dataloader: text_encoder.to(accelerator.device) diff --git a/scripts/longcatvideo/train_avatar_lora.py b/scripts/longcatvideo/train_avatar_lora.py index 12db804a..e0fc214f 100644 --- a/scripts/longcatvideo/train_avatar_lora.py +++ b/scripts/longcatvideo/train_avatar_lora.py @@ -1375,6 +1375,7 @@ def _create_special_list(length): transformer3d.to(accelerator.device, dtype=weight_dtype) if not args.enable_text_encoder_in_dataloader: text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + audio_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1664,6 +1665,7 @@ def _create_special_list(length): if args.low_vram: torch.cuda.empty_cache() vae.to(accelerator.device) + audio_encoder.to(accelerator.device) if not args.enable_text_encoder_in_dataloader: text_encoder.to("cpu") @@ -1722,6 +1724,7 @@ def _batch_encode_vae(pixel_values): if args.low_vram: vae.to('cpu') + audio_encoder.to("cpu") torch.cuda.empty_cache() if not args.enable_text_encoder_in_dataloader: text_encoder.to(accelerator.device) diff --git a/scripts/wan2.2/train_s2v.py b/scripts/wan2.2/train_s2v.py index 4cb5b63d..09d7d642 100644 --- a/scripts/wan2.2/train_s2v.py +++ b/scripts/wan2.2/train_s2v.py @@ -1673,9 +1673,9 @@ def _create_special_list(length): if args.low_vram: torch.cuda.empty_cache() vae.to(accelerator.device) + audio_encoder.to(accelerator.device) if not args.enable_text_encoder_in_dataloader: text_encoder.to("cpu") - audio_encoder.to("cpu") with torch.no_grad(): # This way is quicker when batch grows up @@ -1764,32 +1764,6 @@ def _batch_encode_vae(pixel_values): motion_latents = _batch_encode_vae(motion_pixel_values) drop_motion_frames = True - if args.low_vram: - vae.to('cpu') - torch.cuda.empty_cache() - if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) - audio_encoder.to(accelerator.device) - - if args.enable_text_encoder_in_dataloader: - prompt_embeds = batch['encoder_hidden_states'].to(device=latents.device) - else: - with torch.no_grad(): - prompt_ids = tokenizer( - batch['text'], - padding="max_length", - max_length=args.tokenizer_max_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt" - ) - text_input_ids = prompt_ids.input_ids - prompt_attention_mask = prompt_ids.attention_mask - - seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() - prompt_embeds = text_encoder(text_input_ids.to(latents.device), attention_mask=prompt_attention_mask.to(latents.device))[0] - 
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - with torch.no_grad(): # Extract audio emb new_audio_wav2vec_fea = [] @@ -1826,9 +1800,34 @@ def _batch_encode_vae(pixel_values): audio_wav2vec_fea[..., zero_frames_num:] = torch.zeros_like(audio_wav2vec_fea[..., zero_frames_num:]) # audio_wav2vec_fea = audio_wav2vec_fea[..., :control_pixel_values.size()[1]] + if args.low_vram: + vae.to('cpu') + audio_encoder.to("cpu") + torch.cuda.empty_cache() + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device) + + if args.enable_text_encoder_in_dataloader: + prompt_embeds = batch['encoder_hidden_states'].to(device=latents.device) + else: + with torch.no_grad(): + prompt_ids = tokenizer( + batch['text'], + padding="max_length", + max_length=args.tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt" + ) + text_input_ids = prompt_ids.input_ids + prompt_attention_mask = prompt_ids.attention_mask + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = text_encoder(text_input_ids.to(latents.device), attention_mask=prompt_attention_mask.to(latents.device))[0] + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + if args.low_vram and not args.enable_text_encoder_in_dataloader: text_encoder.to('cpu') - audio_encoder.to("cpu") torch.cuda.empty_cache() bsz, channel, num_frames, height, width = latents.size() diff --git a/scripts/wan2.2/train_s2v_lora.py b/scripts/wan2.2/train_s2v_lora.py index 078602cd..1d36b1d8 100644 --- a/scripts/wan2.2/train_s2v_lora.py +++ b/scripts/wan2.2/train_s2v_lora.py @@ -1683,9 +1683,9 @@ def _create_special_list(length): if args.low_vram: torch.cuda.empty_cache() vae.to(accelerator.device) + audio_encoder.to(accelerator.device) if not args.enable_text_encoder_in_dataloader: text_encoder.to("cpu") - audio_encoder.to("cpu") with torch.no_grad(): # This way is quicker when batch grows up @@ -1774,32 +1774,6 @@ def 
_batch_encode_vae(pixel_values): motion_latents = _batch_encode_vae(motion_pixel_values) drop_motion_frames = True - if args.low_vram: - vae.to('cpu') - torch.cuda.empty_cache() - if not args.enable_text_encoder_in_dataloader: - text_encoder.to(accelerator.device) - audio_encoder.to(accelerator.device) - - if args.enable_text_encoder_in_dataloader: - prompt_embeds = batch['encoder_hidden_states'].to(device=latents.device) - else: - with torch.no_grad(): - prompt_ids = tokenizer( - batch['text'], - padding="max_length", - max_length=args.tokenizer_max_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt" - ) - text_input_ids = prompt_ids.input_ids - prompt_attention_mask = prompt_ids.attention_mask - - seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() - prompt_embeds = text_encoder(text_input_ids.to(latents.device), attention_mask=prompt_attention_mask.to(latents.device))[0] - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - with torch.no_grad(): # Extract audio emb new_audio_wav2vec_fea = [] @@ -1836,9 +1810,34 @@ def _batch_encode_vae(pixel_values): audio_wav2vec_fea[..., zero_frames_num:] = torch.zeros_like(audio_wav2vec_fea[..., zero_frames_num:]) # audio_wav2vec_fea = audio_wav2vec_fea[..., :control_pixel_values.size()[1]] + if args.low_vram: + vae.to('cpu') + audio_encoder.to("cpu") + torch.cuda.empty_cache() + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device) + + if args.enable_text_encoder_in_dataloader: + prompt_embeds = batch['encoder_hidden_states'].to(device=latents.device) + else: + with torch.no_grad(): + prompt_ids = tokenizer( + batch['text'], + padding="max_length", + max_length=args.tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt" + ) + text_input_ids = prompt_ids.input_ids + prompt_attention_mask = prompt_ids.attention_mask + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = 
text_encoder(text_input_ids.to(latents.device), attention_mask=prompt_attention_mask.to(latents.device))[0] + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + if args.low_vram and not args.enable_text_encoder_in_dataloader: text_encoder.to('cpu') - audio_encoder.to("cpu") torch.cuda.empty_cache() bsz, channel, num_frames, height, width = latents.size() From 5f465506bd82214febe5165eae84367db6b5aa44 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Thu, 12 Mar 2026 16:57:30 +0800 Subject: [PATCH 03/10] Update LTX2, Update pipelines, Reformat S2V codes --- examples/longcatvideo/predict_i2v.py | 1 - examples/longcatvideo/predict_s2v_avatar.py | 17 +- examples/longcatvideo/predict_t2v.py | 1 - examples/ltx2/predict_i2v.py | 331 ++++ examples/ltx2/predict_t2v.py | 327 ++++ examples/wan2.2/predict_animate.py | 11 +- examples/wan2.2/predict_s2v.py | 12 +- scripts/fantasytalking/train.py | 6 +- scripts/wan2.2/train_animate.py | 2 +- scripts/wan2.2/train_animate_lora.py | 2 +- scripts/wan2.2/train_s2v.py | 2 +- scripts/wan2.2/train_s2v_lora.py | 2 +- videox_fun/data/bucket_sampler.py | 28 +- videox_fun/models/__init__.py | 5 + .../models/longcatvideo_audio_encoder.py | 2 +- videox_fun/models/ltx2_connecter.py | 325 ++++ videox_fun/models/ltx2_transformer3d.py | 1352 +++++++++++++++ videox_fun/models/ltx2_vae.py | 1519 +++++++++++++++++ videox_fun/models/ltx2_vae_audio.py | 802 +++++++++ videox_fun/models/ltx2_vocoder.py | 158 ++ videox_fun/models/wan_audio_injector.py | 4 +- videox_fun/pipeline/__init__.py | 2 + videox_fun/pipeline/pipeline_cogvideox_fun.py | 12 +- .../pipeline_cogvideox_fun_control.py | 12 +- .../pipeline_cogvideox_fun_inpaint.py | 12 +- .../pipeline/pipeline_fantasy_talking.py | 13 +- videox_fun/pipeline/pipeline_hunyuanvideo.py | 12 +- .../pipeline/pipeline_hunyuanvideo_i2v.py | 12 +- videox_fun/pipeline/pipeline_longcatvideo.py | 13 +- .../pipeline/pipeline_longcatvideo_avatar.py | 16 +- 
videox_fun/pipeline/pipeline_ltx2.py | 1259 ++++++++++++++ videox_fun/pipeline/pipeline_ltx2_i2v.py | 1299 ++++++++++++++ videox_fun/pipeline/pipeline_wan.py | 12 +- videox_fun/pipeline/pipeline_wan2_2.py | 12 +- .../pipeline/pipeline_wan2_2_animate.py | 42 +- .../pipeline/pipeline_wan2_2_fun_control.py | 12 +- .../pipeline/pipeline_wan2_2_fun_inpaint.py | 12 +- videox_fun/pipeline/pipeline_wan2_2_s2v.py | 40 +- videox_fun/pipeline/pipeline_wan2_2_ti2v.py | 12 +- .../pipeline/pipeline_wan2_2_vace_fun.py | 12 +- .../pipeline/pipeline_wan_fun_control.py | 12 +- .../pipeline/pipeline_wan_fun_inpaint.py | 12 +- videox_fun/pipeline/pipeline_wan_phantom.py | 12 +- videox_fun/pipeline/pipeline_wan_vace.py | 12 +- 44 files changed, 7565 insertions(+), 208 deletions(-) create mode 100644 examples/ltx2/predict_i2v.py create mode 100644 examples/ltx2/predict_t2v.py create mode 100644 videox_fun/models/ltx2_connecter.py create mode 100644 videox_fun/models/ltx2_transformer3d.py create mode 100644 videox_fun/models/ltx2_vae.py create mode 100644 videox_fun/models/ltx2_vae_audio.py create mode 100644 videox_fun/models/ltx2_vocoder.py create mode 100644 videox_fun/pipeline/pipeline_ltx2.py create mode 100644 videox_fun/pipeline/pipeline_ltx2_i2v.py diff --git a/examples/longcatvideo/predict_i2v.py b/examples/longcatvideo/predict_i2v.py index b9ae53ef..f51ad28d 100644 --- a/examples/longcatvideo/predict_i2v.py +++ b/examples/longcatvideo/predict_i2v.py @@ -149,7 +149,6 @@ if GPU_memory_mode == "sequential_cpu_offload": replace_parameters_by_name(transformer, ["modulation",], device=device) - transformer.freqs = transformer.freqs.to(device=device) pipeline.enable_sequential_cpu_offload(device=device) elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device) diff --git a/examples/longcatvideo/predict_s2v_avatar.py b/examples/longcatvideo/predict_s2v_avatar.py index 3f57b4c0..0681f647 
100644 --- a/examples/longcatvideo/predict_s2v_avatar.py +++ b/examples/longcatvideo/predict_s2v_avatar.py @@ -1,9 +1,7 @@ -import math import os import sys from pathlib import Path -import librosa import numpy as np import torch from audio_separator.separator import Separator @@ -18,9 +16,9 @@ from videox_fun.dist import set_multi_gpus_devices, shard_model from videox_fun.models import (AutoencoderKLLongCatVideo, AutoTokenizer, + LongCatVideoAudioEncoder, LongCatVideoAvatarTransformer3DModel, - UMT5EncoderModel, Wav2Vec2FeatureExtractor, - Wav2Vec2ModelWrapper) + UMT5EncoderModel) from videox_fun.models.cache_utils import get_teacache_coefficients from videox_fun.pipeline import LongCatVideoAvatarPipeline from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler @@ -136,15 +134,10 @@ ) # Get Audio encoder (for avatar mode) -audio_encoder = Wav2Vec2ModelWrapper( +audio_encoder = LongCatVideoAudioEncoder( os.path.join(model_name_avatar, 'chinese-wav2vec2-base') ) -audio_encoder.feature_extractor._freeze_parameters() - -wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( - os.path.join(model_name_avatar, 'chinese-wav2vec2-base'), - local_files_only=True -) +audio_encoder.audio_encoder.feature_extractor._freeze_parameters() # Get Scheduler Chosen_Scheduler = scheduler_dict = { @@ -165,7 +158,6 @@ text_encoder=text_encoder, scheduler=scheduler, audio_encoder=audio_encoder, - wav2vec_feature_extractor=wav2vec_feature_extractor, ) if compile_dit: @@ -175,7 +167,6 @@ if GPU_memory_mode == "sequential_cpu_offload": replace_parameters_by_name(transformer, ["modulation",], device=device) - transformer.freqs = transformer.freqs.to(device=device) pipeline.enable_sequential_cpu_offload(device=device) elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device) diff --git a/examples/longcatvideo/predict_t2v.py b/examples/longcatvideo/predict_t2v.py index 
8172a965..a03f291c 100644 --- a/examples/longcatvideo/predict_t2v.py +++ b/examples/longcatvideo/predict_t2v.py @@ -147,7 +147,6 @@ if GPU_memory_mode == "sequential_cpu_offload": replace_parameters_by_name(transformer, ["modulation",], device=device) - transformer.freqs = transformer.freqs.to(device=device) pipeline.enable_sequential_cpu_offload(device=device) elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device) diff --git a/examples/ltx2/predict_i2v.py b/examples/ltx2/predict_i2v.py new file mode 100644 index 00000000..78edb216 --- /dev/null +++ b/examples/ltx2/predict_i2v.py @@ -0,0 +1,331 @@ +import os +import sys + +import av +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from PIL import Image +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizerFast + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.models import (AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, + LTX2TextConnectors, LTX2VideoTransformer3DModel, + LTX2Vocoder) +from videox_fun.pipeline import LTX2I2VPipeline +from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler +from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from videox_fun.utils.utils import merge_video_audio, save_videos_grid + +# GPU memory mode, which can be chosen in [model_full_load, model_full_load_and_qfloat8, model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload]. +# model_full_load means that the entire model will be moved to the GPU. 
+# +# model_full_load_and_qfloat8 means that the entire model will be moved to the GPU, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory. +# +# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use, +# resulting in slower speeds but saving a large amount of GPU memory. +GPU_memory_mode = "model_full_load" +# Compile will give a speedup in fixed resolution and need a little GPU memory. +# The compile_dit is not compatible with sequential_cpu_offload. +compile_dit = False + +# model path +model_name = "models/Diffusion_Transformer/LTX-2" +# Choose the sampler in "Flow", "Flow_Unipc", "Flow_DPM++" +sampler_name = "Flow" + +# Load pretrained model if need +transformer_path = None +vae_path = None +lora_path = None + +# Other params +sample_size = [480, 832] +video_length = 121 +fps = 24 + +# Use torch.float16 if GPU does not support torch.bfloat16 +# ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 +weight_dtype = torch.bfloat16 +# If you want to generate from text, please set the validation_image_start = None and validation_image_end = None +validation_image_start = "asset/1.png" + +# prompts +prompt = "A brown dog barks on a sofa, sitting on a light-colored couch in a cozy room. Behind the dog, there is a framed painting on a shelf, surrounded by pink flowers. 
" +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static, low quality, artifacts" +guidance_scale = 6.0 +seed = 43 +num_inference_steps = 50 +lora_weight = 0.55 +save_path = "samples/ltx2-videos-t2v" + +# Audio sample rate will be read from vocoder config +audio_sample_rate = 24000 + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +# Transformer +transformer = LTX2VideoTransformer3DModel.from_pretrained( + model_name, + subfolder="transformer", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, +) + +if transformer_path is not None: + print(f"From checkpoint: {transformer_path}") + if transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(transformer_path) + else: + state_dict = torch.load(transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Video VAE +vae = AutoencoderKLLTX2Video.from_pretrained( + model_name, + subfolder="vae", + torch_dtype=weight_dtype, +) + +if vae_path is not None: + print(f"From checkpoint: {vae_path}") + if vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(vae_path) + else: + state_dict = torch.load(vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Audio VAE +audio_vae = AutoencoderKLLTX2Audio.from_pretrained( + model_name, + subfolder="audio_vae", + torch_dtype=weight_dtype, +) + +# Get Tokenizer +tokenizer = GemmaTokenizerFast.from_pretrained( + model_name, + subfolder="tokenizer", +) + +# Get Text encoder +text_encoder = 
Gemma3ForConditionalGeneration.from_pretrained( + model_name, + subfolder="text_encoder", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, +) +text_encoder = text_encoder.eval() + +# Connectors +connectors = LTX2TextConnectors.from_pretrained( + model_name, + subfolder="connectors", + torch_dtype=weight_dtype, +) + +# Vocoder +vocoder = LTX2Vocoder.from_pretrained( + model_name, + subfolder="vocoder", + torch_dtype=weight_dtype, +) + +# Get Scheduler +Chosen_Scheduler = { + "Flow": FlowMatchEulerDiscreteScheduler, + "Flow_Unipc": FlowUniPCMultistepScheduler, + "Flow_DPM++": FlowDPMSolverMultistepScheduler, +}[sampler_name] +scheduler = Chosen_Scheduler.from_pretrained( + model_name, + subfolder="scheduler" +) + +pipeline = LTX2I2VPipeline( + scheduler=scheduler, + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, +) + +if compile_dit: + for i in range(len(pipeline.transformer.transformer_blocks)): + pipeline.transformer.transformer_blocks[i] = torch.compile(pipeline.transformer.transformer_blocks[i]) + print("Add Compile") + +if GPU_memory_mode == "sequential_cpu_offload": + pipeline.enable_sequential_cpu_offload(device=device) +elif GPU_memory_mode == "model_group_offload": + register_auto_device_hook(pipeline.transformer) + safe_enable_group_offload(pipeline, onload_device=device, offload_device="cpu", offload_type="leaf_level", use_stream=True) +elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_cpu_offload": + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", 
"txt_in", "timestep"], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.to(device=device) +else: + pipeline.to(device=device) + +generator = torch.Generator(device=device).manual_seed(seed) + +if lora_path is not None: + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +with torch.no_grad(): + output = pipeline( + image=Image.open(validation_image_start), + prompt=prompt, + negative_prompt=negative_prompt, + height=sample_size[0], + width=sample_size[1], + num_frames=video_length, + frame_rate=fps, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=generator, + output_type="pt", + ) + +if lora_path is not None: + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +sample = output.videos +audio = output.audio + +def _prepare_audio_stream(container, audio_sample_rate): + from fractions import Fraction + audio_stream = container.add_stream("aac", rate=audio_sample_rate) + audio_stream.codec_context.sample_rate = audio_sample_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) + return audio_stream + + +def _write_audio(container, audio_stream, samples, audio_sample_rate): + if samples.ndim == 1: + samples = samples[:, None] + if samples.shape[1] != 2 and samples.shape[0] == 2: + samples = samples.T + if samples.shape[1] != 2: + # mono -> duplicate to stereo + samples = samples.expand(-1, 2) + + if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + cc = audio_stream.codec_context + target_format = cc.format or "fltp" + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or 
frame_in.sample_rate + + resampler = av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + for packet in audio_stream.encode(): + container.mux(packet) + + +def save_results(): + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + index = len([path for path in os.listdir(save_path)]) + 1 + prefix = str(index).zfill(8) + + if video_length == 1: + out_path = os.path.join(save_path, prefix + ".png") + image = sample[0, :, 0] + image = image.transpose(0, 1).transpose(1, 2) + image = (image * 255).numpy().astype(np.uint8) + Image.fromarray(image).save(out_path) + print(f"Saved image to: {out_path}") + return + + import torchvision + from einops import rearrange + + frames_t = rearrange(sample, "b c t h w -> t b c h w") + frame_list = [] + for x in frames_t: + x = torchvision.utils.make_grid(x, nrow=6) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + x = (x * 255).numpy().astype(np.uint8) + frame_list.append(x) + + height, width = frame_list[0].shape[:2] + audio_tensor = audio[0].float().cpu() + if audio_tensor.ndim == 1: + audio_tensor = audio_tensor.unsqueeze(-1) + if audio_tensor.shape[1] != 2 and audio_tensor.shape[0] == 2: + audio_tensor = audio_tensor.T + sr = getattr(pipeline.vocoder.config, "output_sampling_rate", audio_sample_rate) + + out_path = os.path.join(save_path, prefix + "_with_audio.mp4") + container = av.open(out_path, mode="w") + v_stream = container.add_stream("libx264", rate=int(fps)) + v_stream.width = width + v_stream.height = height + v_stream.pix_fmt = "yuv420p" + + a_stream = _prepare_audio_stream(container, sr) + + for frame_np in frame_list: + frame = av.VideoFrame.from_ndarray(frame_np, format="rgb24") + for pkt 
in v_stream.encode(frame): + container.mux(pkt) + for pkt in v_stream.encode(): + container.mux(pkt) + + _write_audio(container, a_stream, audio_tensor, sr) + + container.close() + print(f"Saved merged video+audio to: {out_path}") + +save_results() \ No newline at end of file diff --git a/examples/ltx2/predict_t2v.py b/examples/ltx2/predict_t2v.py new file mode 100644 index 00000000..3ac1b22b --- /dev/null +++ b/examples/ltx2/predict_t2v.py @@ -0,0 +1,327 @@ +import os +import shutil +import sys + +import av +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from PIL import Image +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizerFast + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.models import (AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, + LTX2TextConnectors, LTX2VideoTransformer3DModel, + LTX2Vocoder) +from videox_fun.pipeline import LTX2Pipeline +from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler +from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from videox_fun.utils.utils import merge_video_audio, save_videos_grid + +# GPU memory mode, which can be chosen in [model_full_load, model_full_load_and_qfloat8, model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload]. +# model_full_load means that the entire model will be moved to the GPU. +# +# model_full_load_and_qfloat8 means that the entire model will be moved to the GPU, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory. 
+# +# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use, +# resulting in slower speeds but saving a large amount of GPU memory. +GPU_memory_mode = "sequential_cpu_offload" +# Compile will give a speedup in fixed resolution and need a little GPU memory. +# The compile_dit is not compatible with sequential_cpu_offload. +compile_dit = False + +# model path +model_name = "models/Diffusion_Transformer/LTX-2" +# Choose the sampler in "Flow", "Flow_Unipc", "Flow_DPM++" +sampler_name = "Flow" + +# Load pretrained model if need +transformer_path = None +vae_path = None +lora_path = None + +# Other params +sample_size = [512, 768] +video_length = 121 +fps = 24 + +# Use torch.float16 if GPU does not support torch.bfloat16 +# ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 +weight_dtype = torch.bfloat16 +prompt = "A brown dog barks on a sofa, sitting on a light-colored couch in a cozy room. Behind the dog, there is a framed painting on a shelf, surrounded by pink flowers. 
" +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, static, low quality, artifacts" +guidance_scale = 6.0 +seed = 43 +num_inference_steps = 50 +lora_weight = 0.55 +save_path = "samples/ltx2-videos-t2v" + +# Audio sample rate will be read from vocoder config +audio_sample_rate = 24000 + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +# Transformer +transformer = LTX2VideoTransformer3DModel.from_pretrained( + model_name, + subfolder="transformer", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, +) + +if transformer_path is not None: + print(f"From checkpoint: {transformer_path}") + if transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(transformer_path) + else: + state_dict = torch.load(transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Video VAE +vae = AutoencoderKLLTX2Video.from_pretrained( + model_name, + subfolder="vae", + torch_dtype=weight_dtype, +) + +if vae_path is not None: + print(f"From checkpoint: {vae_path}") + if vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(vae_path) + else: + state_dict = torch.load(vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Audio VAE +audio_vae = AutoencoderKLLTX2Audio.from_pretrained( + model_name, + subfolder="audio_vae", + torch_dtype=weight_dtype, +) + +# Get Tokenizer +tokenizer = GemmaTokenizerFast.from_pretrained( + model_name, + subfolder="tokenizer", +) + +# Get Text encoder +text_encoder = 
Gemma3ForConditionalGeneration.from_pretrained( + model_name, + subfolder="text_encoder", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, +) +text_encoder = text_encoder.eval() + +# Connectors +connectors = LTX2TextConnectors.from_pretrained( + model_name, + subfolder="connectors", + torch_dtype=weight_dtype, +) + +# Vocoder +vocoder = LTX2Vocoder.from_pretrained( + model_name, + subfolder="vocoder", + torch_dtype=weight_dtype, +) + +# Get Scheduler +Chosen_Scheduler = { + "Flow": FlowMatchEulerDiscreteScheduler, + "Flow_Unipc": FlowUniPCMultistepScheduler, + "Flow_DPM++": FlowDPMSolverMultistepScheduler, +}[sampler_name] +scheduler = Chosen_Scheduler.from_pretrained( + model_name, + subfolder="scheduler" +) + +pipeline = LTX2Pipeline( + scheduler=scheduler, + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, +) + +if compile_dit: + for i in range(len(pipeline.transformer.transformer_blocks)): + pipeline.transformer.transformer_blocks[i] = torch.compile(pipeline.transformer.transformer_blocks[i]) + print("Add Compile") + +if GPU_memory_mode == "sequential_cpu_offload": + pipeline.enable_sequential_cpu_offload(device=device) +elif GPU_memory_mode == "model_group_offload": + register_auto_device_hook(pipeline.transformer) + safe_enable_group_offload(pipeline, onload_device=device, offload_device="cpu", offload_type="leaf_level", use_stream=True) +elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_cpu_offload": + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", 
"timestep"], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.to(device=device) +else: + pipeline.to(device=device) + +generator = torch.Generator(device=device).manual_seed(seed) + +if lora_path is not None: + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +with torch.no_grad(): + output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + height=sample_size[0], + width=sample_size[1], + num_frames=video_length, + frame_rate=fps, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + generator=generator, + output_type="pt", + ) + +if lora_path is not None: + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +sample = output.videos +audio = output.audio + +def _prepare_audio_stream(container, audio_sample_rate): + from fractions import Fraction + audio_stream = container.add_stream("aac", rate=audio_sample_rate) + audio_stream.codec_context.sample_rate = audio_sample_rate + audio_stream.codec_context.layout = "stereo" + audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) + return audio_stream + + +def _write_audio(container, audio_stream, samples, audio_sample_rate): + if samples.ndim == 1: + samples = samples[:, None] + if samples.shape[1] != 2 and samples.shape[0] == 2: + samples = samples.T + if samples.shape[1] != 2: + # mono -> duplicate to stereo + samples = samples.expand(-1, 2) + + if samples.dtype != torch.int16: + samples = torch.clip(samples, -1.0, 1.0) + samples = (samples * 32767.0).to(torch.int16) + + frame_in = av.AudioFrame.from_ndarray( + samples.contiguous().reshape(1, -1).cpu().numpy(), + format="s16", + layout="stereo", + ) + frame_in.sample_rate = audio_sample_rate + + cc = audio_stream.codec_context + target_format = cc.format or "fltp" + target_layout = cc.layout or "stereo" + target_rate = cc.sample_rate or frame_in.sample_rate + + resampler = 
av.audio.resampler.AudioResampler( + format=target_format, + layout=target_layout, + rate=target_rate, + ) + + audio_next_pts = 0 + for rframe in resampler.resample(frame_in): + if rframe.pts is None: + rframe.pts = audio_next_pts + audio_next_pts += rframe.samples + rframe.sample_rate = frame_in.sample_rate + container.mux(audio_stream.encode(rframe)) + + for packet in audio_stream.encode(): + container.mux(packet) + + +def save_results(): + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + index = len([path for path in os.listdir(save_path)]) + 1 + prefix = str(index).zfill(8) + + if video_length == 1: + out_path = os.path.join(save_path, prefix + ".png") + image = sample[0, :, 0] + image = image.transpose(0, 1).transpose(1, 2) + image = (image * 255).numpy().astype(np.uint8) + Image.fromarray(image).save(out_path) + print(f"Saved image to: {out_path}") + return + + import torchvision + from einops import rearrange + + frames_t = rearrange(sample, "b c t h w -> t b c h w") + frame_list = [] + for x in frames_t: + x = torchvision.utils.make_grid(x, nrow=6) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + x = (x * 255).numpy().astype(np.uint8) + frame_list.append(x) + + height, width = frame_list[0].shape[:2] + audio_tensor = audio[0].float().cpu() + if audio_tensor.ndim == 1: + audio_tensor = audio_tensor.unsqueeze(-1) + if audio_tensor.shape[1] != 2 and audio_tensor.shape[0] == 2: + audio_tensor = audio_tensor.T + sr = getattr(pipeline.vocoder.config, "output_sampling_rate", audio_sample_rate) + + out_path = os.path.join(save_path, prefix + "_with_audio.mp4") + container = av.open(out_path, mode="w") + v_stream = container.add_stream("libx264", rate=int(fps)) + v_stream.width = width + v_stream.height = height + v_stream.pix_fmt = "yuv420p" + + a_stream = _prepare_audio_stream(container, sr) + + for frame_np in frame_list: + frame = av.VideoFrame.from_ndarray(frame_np, format="rgb24") + for pkt in v_stream.encode(frame): + 
container.mux(pkt) + for pkt in v_stream.encode(): + container.mux(pkt) + + _write_audio(container, a_stream, audio_tensor, sr) + + container.close() + print(f"Saved merged video+audio to: {out_path}") + +save_results() \ No newline at end of file diff --git a/examples/wan2.2/predict_animate.py b/examples/wan2.2/predict_animate.py index c209c887..d8544265 100644 --- a/examples/wan2.2/predict_animate.py +++ b/examples/wan2.2/predict_animate.py @@ -106,9 +106,12 @@ src_mask_path = os.path.join(src_root_path, "src_mask.mp4") # Other params -sample_size = [480, 832] -video_length = 81 -fps = 16 +sample_size = [480, 832] +# Total num frames +video_length = 81 +# How many frames to generate per clips. +segment_frame_length = 77 +fps = 16 # Use torch.float16 if GPU does not support torch.bfloat16 # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 @@ -331,7 +334,7 @@ sample = pipeline( prompt, - num_frames = video_length, + segment_frame_length = segment_frame_length, negative_prompt = negative_prompt, height = sample_size[0], width = sample_size[1], diff --git a/examples/wan2.2/predict_s2v.py b/examples/wan2.2/predict_s2v.py index 2b5dcf13..97e80a5f 100644 --- a/examples/wan2.2/predict_s2v.py +++ b/examples/wan2.2/predict_s2v.py @@ -103,10 +103,10 @@ lora_high_path = None # Other params -sample_size = [832, 480] +sample_size = [832, 480] # How many frames to generate per clips. 
-infer_frames = 80 -fps = 16 +segment_frame_length = 80 +fps = 16 # Use torch.float16 if GPU does not support torch.bfloat16 # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 @@ -312,8 +312,8 @@ pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2") with torch.no_grad(): - infer_frames = infer_frames // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio if infer_frames != 1 else 1 - latent_frames = infer_frames // vae.config.temporal_compression_ratio + segment_frame_length = segment_frame_length // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio if segment_frame_length != 1 else 1 + latent_frames = segment_frame_length // vae.config.temporal_compression_ratio if enable_riflex: pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames) @@ -327,7 +327,7 @@ sample = pipeline( prompt, - num_frames = infer_frames, + segment_frame_length = segment_frame_length, negative_prompt = negative_prompt, height = sample_size[0], width = sample_size[1], diff --git a/scripts/fantasytalking/train.py b/scripts/fantasytalking/train.py index fdee4277..df856352 100644 --- a/scripts/fantasytalking/train.py +++ b/scripts/fantasytalking/train.py @@ -190,7 +190,7 @@ def log_validation(vae, text_encoder, tokenizer, clip_image_encoder, audio_encod sample = pipeline( args.validation_prompts[i], num_frames = video_length, - negative_prompt = "bad detailed", + negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", height = height, width = width, generator = generator, @@ -201,8 +201,8 @@ def log_validation(vae, text_encoder, tokenizer, clip_image_encoder, audio_encod mask_video = input_video_mask, clip_image = clip_image, audio_path = audio_path, - shift = 5, - fps = 16 + shift = 3, + fps = 23 ).videos 
os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) save_videos_grid( diff --git a/scripts/wan2.2/train_animate.py b/scripts/wan2.2/train_animate.py index 015e0321..4e106e89 100644 --- a/scripts/wan2.2/train_animate.py +++ b/scripts/wan2.2/train_animate.py @@ -242,7 +242,7 @@ def log_validation(vae, text_encoder, tokenizer, clip_image_encoder, transformer sample = pipeline( args.validation_prompts[i], - num_frames = video_length, + segment_frame_length = 77, negative_prompt = "bad detailed", height = height, width = width, diff --git a/scripts/wan2.2/train_animate_lora.py b/scripts/wan2.2/train_animate_lora.py index 30540020..18fb3232 100644 --- a/scripts/wan2.2/train_animate_lora.py +++ b/scripts/wan2.2/train_animate_lora.py @@ -249,7 +249,7 @@ def log_validation(vae, text_encoder, tokenizer, clip_image_encoder, transformer sample = pipeline( args.validation_prompts[i], - num_frames = video_length, + segment_frame_length = 77, negative_prompt = "bad detailed", height = height, width = width, diff --git a/scripts/wan2.2/train_s2v.py b/scripts/wan2.2/train_s2v.py index 09d7d642..a4fcb0a5 100644 --- a/scripts/wan2.2/train_s2v.py +++ b/scripts/wan2.2/train_s2v.py @@ -197,7 +197,7 @@ def log_validation(vae, text_encoder, tokenizer, audio_encoder, transformer3d, a sample = pipeline( args.validation_prompts[i], - num_frames = args.video_sample_n_frames, + segment_frame_length = args.video_sample_n_frames, negative_prompt = "bad detailed", height = height, width = width, diff --git a/scripts/wan2.2/train_s2v_lora.py b/scripts/wan2.2/train_s2v_lora.py index 1d36b1d8..4f930c5e 100644 --- a/scripts/wan2.2/train_s2v_lora.py +++ b/scripts/wan2.2/train_s2v_lora.py @@ -207,7 +207,7 @@ def log_validation(vae, text_encoder, tokenizer, audio_encoder, transformer3d, n sample = pipeline( args.validation_prompts[i], - num_frames = args.video_sample_n_frames, + segment_frame_length = args.video_sample_n_frames, negative_prompt = "bad detailed", height = height, width = 
width, diff --git a/videox_fun/data/bucket_sampler.py b/videox_fun/data/bucket_sampler.py index 24b4160f..4cfc3239 100755 --- a/videox_fun/data/bucket_sampler.py +++ b/videox_fun/data/bucket_sampler.py @@ -9,6 +9,7 @@ from PIL import Image from torch.utils.data import BatchSampler, Dataset, Sampler + ASPECT_RATIO_512 = { '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], @@ -37,15 +38,18 @@ ] ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB) + def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512): aspect_ratio = height / width closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) return ratios[closest_ratio], float(closest_ratio) + def get_image_size_without_loading(path): with Image.open(path) as img: return img.size # (width, height) + class RandomSampler(Sampler[int]): r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. @@ -56,18 +60,22 @@ class RandomSampler(Sampler[int]): replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` num_samples (int): number of samples to draw, default=`len(dataset)`. generator (Generator): Generator used in sampling. + k_repeat (int): number of times to repeat each sampled index consecutively, default=1. + When k_repeat > 1, each index is yielded k_repeat times in a row, + so a batch of size B will contain B // k_repeat unique samples. 
""" data_source: Sized replacement: bool def __init__(self, data_source: Sized, replacement: bool = False, - num_samples: Optional[int] = None, generator=None) -> None: + num_samples: Optional[int] = None, generator=None, k_repeat: int = 1) -> None: self.data_source = data_source self.replacement = replacement self._num_samples = num_samples self.generator = generator self._pos_start = 0 + self.k_repeat = k_repeat if not isinstance(self.replacement, bool): raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") @@ -93,8 +101,12 @@ def __iter__(self) -> Iterator[int]: if self.replacement: for _ in range(self.num_samples // 32): - yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() - yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() + for idx in torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist(): + for _ in range(self.k_repeat): + yield idx + for idx in torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist(): + for _ in range(self.k_repeat): + yield idx else: for _ in range(self.num_samples // n): xx = torch.randperm(n, generator=generator).tolist() @@ -102,13 +114,17 @@ def __iter__(self) -> Iterator[int]: self._pos_start = 0 print("xx top 10", xx[:10], self._pos_start) for idx in range(self._pos_start, n): - yield xx[idx] + for _ in range(self.k_repeat): + yield xx[idx] self._pos_start = (self._pos_start + 1) % n self._pos_start = 0 - yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n] + for idx in torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]: + for _ in range(self.k_repeat): + yield idx def __len__(self) -> int: - return self.num_samples + return self.num_samples * self.k_repeat + class AspectRatioBatchImageSampler(BatchSampler): """A sampler wrapper for grouping images with similar 
aspect ratio into a same batch. diff --git a/videox_fun/models/__init__.py b/videox_fun/models/__init__.py index 6b5a14ef..4c538c13 100755 --- a/videox_fun/models/__init__.py +++ b/videox_fun/models/__init__.py @@ -36,6 +36,11 @@ from .longcatvideo_transformer3d_avatar import \ LongCatVideoAvatarTransformer3DModel from .longcatvideo_vae import AutoencoderKLLongCatVideo +from .ltx2_transformer3d import LTX2VideoTransformer3DModel +from .ltx2_vae import AutoencoderKLLTX2Video +from .ltx2_vae_audio import AutoencoderKLLTX2Audio +from .ltx2_connecter import LTX2TextConnectors +from .ltx2_vocoder import LTX2Vocoder from .qwenimage_transformer2d import QwenImageTransformer2DModel from .qwenimage_transformer2d_control import QwenImageControlTransformer2DModel from .qwenimage_transformer2d_instantx import QwenImageInstantXControlNetModel diff --git a/videox_fun/models/longcatvideo_audio_encoder.py b/videox_fun/models/longcatvideo_audio_encoder.py index c05433dc..920ad463 100644 --- a/videox_fun/models/longcatvideo_audio_encoder.py +++ b/videox_fun/models/longcatvideo_audio_encoder.py @@ -350,7 +350,7 @@ def _extract_embedding(self, speech_array, sample_rate, num_frames, audio_stride audio_feature = np.squeeze( self.wav2vec_feature_extractor(speech_array, sampling_rate=sample_rate).input_values ) - audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device) + audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device, dtype=self.dtype) audio_feature = audio_feature.unsqueeze(0) # audio embedding using Wav2Vec2ModelWrapper diff --git a/videox_fun/models/ltx2_connecter.py b/videox_fun/models/ltx2_connecter.py new file mode 100644 index 00000000..985583f9 --- /dev/null +++ b/videox_fun/models/ltx2_connecter.py @@ -0,0 +1,325 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ltx2/connectors.py +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils 
import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.models.attention import FeedForward +from diffusers.models.modeling_utils import ModelMixin + +from .ltx2_transformer3d import LTX2Attention, LTX2AudioVideoAttnProcessor + + +class LTX2RotaryPosEmbed1d(nn.Module): + """ + 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. + """ + + def __init__( + self, + dim: int, + base_seq_len: int = 4096, + theta: float = 10000.0, + double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, + ): + super().__init__() + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.") + + self.dim = dim + self.base_seq_len = base_seq_len + self.theta = theta + self.double_precision = double_precision + self.rope_type = rope_type + self.num_attention_heads = num_attention_heads + + def forward( + self, + batch_size: int, + pos: int, + device: str | torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + # 1. Get 1D position ids + grid_1d = torch.arange(pos, dtype=torch.float32, device=device) + # Get fractional indices relative to self.base_seq_len + grid_1d = grid_1d / self.base_seq_len + grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] + + # 2. Calculate 1D RoPE frequencies + num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2 + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape + # (self.dim // 2,). + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2] + + # 4. 
Get real, interleaved (cos, sin) frequencies, padded to self.dim + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) + + cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + + return cos_freqs, sin_freqs + + +class LTX2TransformerBlock1d(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + activation_fn: str = "gelu-approximate", + eps: float = 1e-6, + rope_type: str = "interleaved", + ): + super().__init__() + + self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + processor=LTX2AudioVideoAttnProcessor(), + rope_type=rope_type, + ) + + self.norm2 = torch.nn.RMSNorm(dim, eps=eps, 
elementwise_affine=False) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + rotary_emb: torch.Tensor | None = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb) + hidden_states = hidden_states + attn_hidden_states + + norm_hidden_states = self.norm2(hidden_states) + ff_hidden_states = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_hidden_states + + return hidden_states + + +class LTX2ConnectorTransformer1d(nn.Module): + """ + A 1D sequence transformer for modalities such as text. + + In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 128, + num_layers: int = 2, + num_learnable_registers: int | None = 128, + rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + eps: float = 1e-6, + causal_temporal_positioning: bool = False, + rope_type: str = "interleaved", + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + + self.num_learnable_registers = num_learnable_registers + self.learnable_registers = None + if num_learnable_registers is not None: + init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0 + self.learnable_registers = torch.nn.Parameter(init_registers) + + self.rope = LTX2RotaryPosEmbed1d( + self.inner_dim, + base_seq_len=rope_base_seq_len, + theta=rope_theta, + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + + 
self.transformer_blocks = torch.nn.ModuleList( + [ + LTX2TransformerBlock1d( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + attn_mask_binarize_threshold: float = -9000.0, + ) -> tuple[torch.Tensor, torch.Tensor]: + # hidden_states shape: [batch_size, seq_len, hidden_dim] + # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len] + batch_size, seq_len, _ = hidden_states.shape + + # 1. Replace padding with learned registers, if using + if self.learnable_registers is not None: + if seq_len % self.num_learnable_registers != 0: + raise ValueError( + f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number" + f" of learnable registers {self.num_learnable_registers}" + ) + + num_register_repeats = seq_len // self.num_learnable_registers + registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim] + + binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int() + if binary_attn_mask.ndim == 4: + binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] + + hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)] + valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded] + pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens] + padded_hidden_states = [ + F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths) + ] + padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D] + + flipped_mask = torch.flip(binary_attn_mask, 
dims=[1]).unsqueeze(-1) # [B, L, 1] + hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers + + # Overwrite attention_mask with an all-zeros mask if using registers. + attention_mask = torch.zeros_like(attention_mask) + + # 2. Calculate 1D RoPE positional embeddings + rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device) + + # 3. Run 1D transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb) + else: + hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb) + + hidden_states = self.norm_out(hidden_states) + + return hidden_states, attention_mask + + +class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin): + """ + Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio + streams. 
+ """ + + @register_to_config + def __init__( + self, + caption_channels: int, + text_proj_in_factor: int, + video_connector_num_attention_heads: int, + video_connector_attention_head_dim: int, + video_connector_num_layers: int, + video_connector_num_learnable_registers: int | None, + audio_connector_num_attention_heads: int, + audio_connector_attention_head_dim: int, + audio_connector_num_layers: int, + audio_connector_num_learnable_registers: int | None, + connector_rope_base_seq_len: int, + rope_theta: float, + rope_double_precision: bool, + causal_temporal_positioning: bool, + rope_type: str = "interleaved", + ): + super().__init__() + self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False) + self.video_connector = LTX2ConnectorTransformer1d( + num_attention_heads=video_connector_num_attention_heads, + attention_head_dim=video_connector_attention_head_dim, + num_layers=video_connector_num_layers, + num_learnable_registers=video_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type, + ) + self.audio_connector = LTX2ConnectorTransformer1d( + num_attention_heads=audio_connector_num_attention_heads, + attention_head_dim=audio_connector_attention_head_dim, + num_layers=audio_connector_num_layers, + num_learnable_registers=audio_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + rope_type=rope_type, + ) + + def forward( + self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False + ): + # Convert to additive attention mask, if necessary + if not additive_mask: + text_dtype = text_encoder_hidden_states.dtype + attention_mask = 
(attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max + + text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states) + + video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask) + + attn_mask = (new_attn_mask < 1e-6).to(torch.int64) + attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) + video_text_embedding = video_text_embedding * attn_mask + new_attn_mask = attn_mask.squeeze(-1) + + audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask) + + return video_text_embedding, audio_text_embedding, new_attn_mask \ No newline at end of file diff --git a/videox_fun/models/ltx2_transformer3d.py b/videox_fun/models/ltx2_transformer3d.py new file mode 100644 index 00000000..96588f44 --- /dev/null +++ b/videox_fun/models/ltx2_transformer3d.py @@ -0,0 +1,1352 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_ltx2.py +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.attention import FeedForward +from diffusers.models.embeddings import ( + PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import RMSNorm +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging +from diffusers.utils.outputs import BaseOutput + +from .attention_utils import attention + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + +def apply_split_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + + x_dtype = x.dtype + needs_reshape = False + if x.ndim != 4 and cos.ndim == 4: + # cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head) + b, h, t, _ = cos.shape + x = x.reshape(b, t, h, -1).swapaxes(1, 2) + needs_reshape = True + + # Split last dim (2*r) into (d=2, r) + last = x.shape[-1] + if last % 2 != 0: + raise ValueError(f"Expected x.shape[-1] to be even for split rotary, got {last}.") + r = last // 2 + + # (..., 2, r) + split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float + first_x = split_x[..., :1, :] # (..., 1, r) + second_x = split_x[..., 1:, :] # (..., 1, r) + + cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 
2, r) + sin_u = sin.unsqueeze(-2) + + out = split_x * cos_u + first_out = out[..., :1, :] + second_out = out[..., 1:, :] + + first_out.addcmul_(-sin_u, second_x) + second_out.addcmul_(sin_u, first_x) + + out = out.reshape(*out.shape[:-2], last) + + if needs_reshape: + out = out.swapaxes(1, 2).reshape(b, t, -1) + + out = out.to(dtype=x_dtype) + return out + + +@dataclass +class AudioVisualModelOutput(BaseOutput): + r""" + Holds the output of an audiovisual model which produces both visual (e.g. video) and audio outputs. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on the `encoder_hidden_states` input, representing the visual output + of the model. This is typically a video (spatiotemporal) output. + audio_sample (`torch.Tensor` of shape `(batch_size, TODO)`): + The audio output of the audiovisual model. + """ + + sample: "torch.Tensor" # noqa: F821 + audio_sample: "torch.Tensor" # noqa: F821 + + +class LTX2AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted by the LTX-2.0 + model. In particular, the number of modulation parameters to be calculated is now configurable. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_mod_params (`int`, *optional*, defaults to `6`): + The number of modulation parameters which will be calculated in the first return argument. The default of 6 + is standard, but sometimes we may want to have a different (usually smaller) number of modulation + parameters. + use_additional_conditions (`bool`, *optional*, defaults to `False`): + Whether to use additional conditions for normalization or not. 
+ """ + + def __init__(self, embedding_dim: int, num_mod_params: int = 6, use_additional_conditions: bool = False): + super().__init__() + self.num_mod_params = num_mod_params + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, self.num_mod_params * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: dict[str, torch.Tensor] | None = None, + batch_size: int | None = None, + hidden_dtype: torch.dtype | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class LTX2AudioVideoAttnProcessor: + r""" + Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model. + Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can + support audio-to-video (a2v) and video-to-audio (v2a) cross attention. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if is_torch_version("<", "2.0"): + raise ValueError( + "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." 
+ ) + + def __call__( + self, + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if query_rotary_emb is not None: + if attn.rope_type == "interleaved": + query = apply_interleaved_rotary_emb(query, query_rotary_emb) + key = apply_interleaved_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + elif attn.rope_type == "split": + query = apply_split_rotary_emb(query, query_rotary_emb) + key = apply_split_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = attention( + q=query, + k=key, + v=value, + attn_mask=attention_mask, + dropout_p=0.0, + causal=False, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LTX2Attention(torch.nn.Module): + r""" + Attention class for all LTX-2.0 attention layers. 
Compared to LTX-1.0, this supports specifying the query and key + RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention. + """ + + _default_processor_cls = LTX2AudioVideoAttnProcessor + _available_processors = [LTX2AudioVideoAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + kv_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = True, + cross_attention_dim: int | None = None, + out_bias: bool = True, + qk_norm: str = "rms_norm_across_heads", + norm_eps: float = 1e-6, + norm_elementwise_affine: bool = True, + rope_type: str = "interleaved", + processor=None, + ): + super().__init__() + if qk_norm != "rms_norm_across_heads": + raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.") + + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = query_dim + self.heads = heads + self.rope_type = rope_type + + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if processor is None: + processor = self._default_processor_cls() + self.processor = processor + + def prepare_attention_mask( + 
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3 + ) -> torch.Tensor: + """ + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): The attention mask to prepare. + target_length (`int`): The target length of the attention mask. + batch_size (`int`): The batch size for repeating the attention mask. + out_dim (`int`, *optional*, defaults to `3`): Output dimension. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) + padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + query_rotary_emb: 
tuple[torch.Tensor, torch.Tensor] | None = None, + key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + hidden_states = self.processor( + self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs + ) + return hidden_states + + +class LTX2VideoTransformerBlock(nn.Module): + r""" + Transformer block used in [LTX-2.0](https://huggingface.co/Lightricks/LTX-Video). + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + audio_dim: int, + audio_num_attention_heads: int, + audio_attention_head_dim, + audio_cross_attention_dim: int, + qk_norm: str = "rms_norm_across_heads", + activation_fn: str = "gelu-approximate", + attention_bias: bool = True, + attention_out_bias: bool = True, + eps: float = 1e-6, + elementwise_affine: bool = False, + rope_type: str = "interleaved", + ): + super().__init__() + + # 1. 
Self-Attention (video and audio) + self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn1 = LTX2Attention( + query_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # 2. Prompt Cross-Attention + self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn2 = LTX2Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn2 = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=audio_cross_attention_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # 3. 
Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + # Audio-to-Video (a2v) Attention --> Q: Video; K,V: Audio + self.audio_to_video_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_to_video_attn = LTX2Attention( + query_dim=dim, + cross_attention_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video + self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.video_to_audio_attn = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # 4. Feedforward layers + self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn) + + # 5. 
Per-Layer Modulation Parameters + # Self-Attention / Feedforward AdaLayerNorm-Zero mod params + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5) + + # Per-layer a2v, v2a Cross-Attention mod params + self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) + self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim)) + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + temb_audio: torch.Tensor, + temb_ca_scale_shift: torch.Tensor, + temb_ca_audio_scale_shift: torch.Tensor, + temb_ca_gate: torch.Tensor, + temb_ca_audio_gate: torch.Tensor, + video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + audio_encoder_attention_mask: torch.Tensor | None = None, + a2v_cross_attention_mask: torch.Tensor | None = None, + v2a_cross_attention_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + batch_size = hidden_states.size(0) + + # 1. 
Video and Audio Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( + batch_size, temb.size(1), num_ada_params, -1 + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=video_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa + + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) + + num_audio_ada_params = self.audio_scale_shift_table.shape[0] + audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( + batch_size, temb_audio.size(1), num_audio_ada_params, -1 + ) + audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( + audio_ada_values.unbind(dim=2) + ) + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa + + attn_audio_hidden_states = self.audio_attn1( + hidden_states=norm_audio_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=audio_rotary_emb, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa + + # 2. 
Video and Audio Cross-Attention with the text embeddings + norm_hidden_states = self.norm2(hidden_states) + attn_hidden_states = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_rotary_emb=None, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + + norm_audio_hidden_states = self.audio_norm2(audio_hidden_states) + attn_audio_hidden_states = self.audio_attn2( + norm_audio_hidden_states, + encoder_hidden_states=audio_encoder_hidden_states, + query_rotary_emb=None, + attention_mask=audio_encoder_attention_mask, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states + + # 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention + norm_hidden_states = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states) + + # Combine global and per-layer cross attention modulation parameters + # Video + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] + + video_ca_scale_shift_table = ( + video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) + + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + video_ca_gate = ( + video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) + + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) + ).unbind(dim=2) + + video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table + a2v_gate = video_ca_gate[0].squeeze(2) + + # Audio + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + + audio_ca_scale_shift_table = ( + audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype) + + 
temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + audio_ca_gate = ( + audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) + + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) + ).unbind(dim=2) + + audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table + v2a_gate = audio_ca_gate[0].squeeze(2) + + # Audio-to-Video Cross Attention: Q: Video; K,V: Audio + mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze( + 2 + ) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_a2v_ca_scale.squeeze(2) + ) + audio_a2v_ca_shift.squeeze(2) + + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + query_rotary_emb=ca_video_rotary_emb, + key_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) + + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + # Video-to-Audio Cross Attention: Q: Audio; K,V: Video + mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze( + 2 + ) + mod_norm_audio_hidden_states = norm_audio_hidden_states * ( + 1 + audio_v2a_ca_scale.squeeze(2) + ) + audio_v2a_ca_shift.squeeze(2) + + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + encoder_hidden_states=mod_norm_hidden_states, + query_rotary_emb=ca_audio_rotary_emb, + key_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) + + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + + # 4. 
Feedforward + norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp + + norm_audio_hidden_states = self.audio_norm3(audio_hidden_states) * (1 + audio_scale_mlp) + audio_shift_mlp + audio_ff_output = self.audio_ff(norm_audio_hidden_states) + audio_hidden_states = audio_hidden_states + audio_ff_output * audio_gate_mlp + + return hidden_states, audio_hidden_states + + +class LTX2AudioVideoRotaryPosEmbed(nn.Module): + """ + Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model. + + Args: + causal_offset (`int`, *optional*, defaults to `1`): + Offset in the temporal axis for causal VAE modeling. This is typically 1 (for causal modeling where the VAE + treats the very first frame differently), but could also be 0 (for non-causal modeling). + """ + + def __init__( + self, + dim: int, + patch_size: int = 1, + patch_size_t: int = 1, + base_num_frames: int = 20, + base_height: int = 2048, + base_width: int = 2048, + sampling_rate: int = 16000, + hop_length: int = 160, + scale_factors: tuple[int, ...] = (8, 32, 32), + theta: float = 10000.0, + causal_offset: int = 1, + modality: str = "video", + double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, + ) -> None: + super().__init__() + + self.dim = dim + self.patch_size = patch_size + self.patch_size_t = patch_size_t + + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. 
Choose between 'interleaved' and 'split'.") + self.rope_type = rope_type + + self.base_num_frames = base_num_frames + self.num_attention_heads = num_attention_heads + + # Video-specific + self.base_height = base_height + self.base_width = base_width + + # Audio-specific + self.sampling_rate = sampling_rate + self.hop_length = hop_length + self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0]) + + self.scale_factors = scale_factors + self.theta = theta + self.causal_offset = causal_offset + + self.modality = modality + if self.modality not in ["video", "audio"]: + raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.") + self.double_precision = double_precision + + def prepare_video_coords( + self, + batch_size: int, + num_frames: int, + height: int, + width: int, + device: torch.device, + fps: float = 24.0, + ) -> torch.Tensor: + """ + Create per-dimension bounds [inclusive start, exclusive end) for each patch with respect to the original pixel + space video grid (num_frames, height, width). This will ultimately have shape (batch_size, 3, num_patches, 2) + where + - axis 1 (size 3) enumerates (frame, height, width) dimensions (e.g. idx 0 corresponds to frames) + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + batch_size (`int`): + Batch size of the video latents. + num_frames (`int`): + Number of latent frames in the video latents. + height (`int`): + Latent height of the video latents. + width (`int`): + Latent width of the video latents. + device (`torch.device`): + Device on which to create the video grid. + + Returns: + `torch.Tensor`: + Per-dimension patch boundaries tensor of shape [batch_size, 3, num_patches, 2]. + """ + + # 1. 
Generate grid coordinates for each spatiotemporal dimension (frames, height, width) + # Always compute rope in fp32 + grid_f = torch.arange(start=0, end=num_frames, step=self.patch_size_t, dtype=torch.float32, device=device) + grid_h = torch.arange(start=0, end=height, step=self.patch_size, dtype=torch.float32, device=device) + grid_w = torch.arange(start=0, end=width, step=self.patch_size, dtype=torch.float32, device=device) + # indexing='ij' ensures that the dimensions are kept in order as (frames, height, width) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) # [3, N_F, N_H, N_W], where e.g. N_F is the number of temporal patches + + # 2. Get the patch boundaries with respect to the latent video grid + patch_size = (self.patch_size_t, self.patch_size, self.patch_size) + patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device) + patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) + + # Combine the start (grid) and end (patch_ends) coordinates along new trailing dimension + latent_coords = torch.stack([grid, patch_ends], dim=-1) # [3, N_F, N_H, N_W, 2] + # Reshape to (batch_size, 3, num_patches, 2) + latent_coords = latent_coords.flatten(1, 3) + latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) + + # 3. Calculate the pixel space patch boundaries from the latent boundaries. 
+ scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device) + # Broadcast the VAE scale factors such that they are compatible with latent_coords's shape + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 # This is the (frame, height, width) dim + # Apply per-axis scaling to convert latent coordinates to pixel space coordinates + pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape) + + # As the VAE temporal stride for the first frame is 1 instead of self.vae_scale_factors[0], we need to shift + # and clamp to keep the first-frame timestamps causal and non-negative. + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + self.causal_offset - self.scale_factors[0]).clamp(min=0) + + # Scale the temporal coordinates by the video FPS + pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps + + return pixel_coords + + def prepare_audio_coords( + self, + batch_size: int, + num_frames: int, + device: torch.device, + shift: int = 0, + ) -> torch.Tensor: + """ + Create per-dimension bounds [inclusive start, exclusive end) of start and end timestamps for each latent frame. + This will ultimately have shape (batch_size, 3, num_patches, 2) where + - axis 1 (size 1) represents the temporal dimension + - axis 3 (size 2) stores `[start, end)` indices within each dimension + + Args: + batch_size (`int`): + Batch size of the audio latents. + num_frames (`int`): + Number of latent frames in the audio latents. + device (`torch.device`): + Device on which to create the audio grid. + shift (`int`, *optional*, defaults to `0`): + Offset on the latent indices. Different shift values correspond to different overlapping windows with + respect to the same underlying latent grid. + + Returns: + `torch.Tensor`: + Per-dimension patch boundaries tensor of shape [batch_size, 1, num_patches, 2]. + """ + + # 1. Generate coordinates in the frame (time) dimension. 
+ # Always compute rope in fp32 + grid_f = torch.arange( + start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device + ) + + # 2. Calculate start timstamps in seconds with respect to the original spectrogram grid + audio_scale_factor = self.scale_factors[0] + # Scale back to mel spectrogram space + grid_start_mel = grid_f * audio_scale_factor + # Handle first frame causal offset, ensuring non-negative timestamps + grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0) + # Convert mel bins back into seconds + grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate + + # 3. Calculate start timstamps in seconds with respect to the original spectrogram grid + grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor + grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip(min=0) + grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate + + audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1) # [num_patches, 2] + audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1) # [batch_size, num_patches, 2] + audio_coords = audio_coords.unsqueeze(1) # [batch_size, 1, num_patches, 2] + return audio_coords + + def prepare_coords(self, *args, **kwargs): + if self.modality == "video": + return self.prepare_video_coords(*args, **kwargs) + elif self.modality == "audio": + return self.prepare_audio_coords(*args, **kwargs) + + def forward( + self, coords: torch.Tensor, device: str | torch.device | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or coords.device + + # Number of spatiotemporal dimensions (3 for video, 1 (temporal) for audio and cross attn) + num_pos_dims = coords.shape[1] + + # 1. 
If the coords are patch boundaries [start, end), use the midpoint of these boundaries as the patch + # position index + if coords.ndim == 4: + coords_start, coords_end = coords.chunk(2, dim=-1) + coords = (coords_start + coords_end) / 2.0 + coords = coords.squeeze(-1) # [B, num_pos_dims, num_patches] + + # 2. Get coordinates as a fraction of the base data shape + if self.modality == "video": + max_positions = (self.base_num_frames, self.base_height, self.base_width) + elif self.modality == "audio": + max_positions = (self.base_num_frames,) + # [B, num_pos_dims, num_patches] --> [B, num_patches, num_pos_dims] + grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device) + # Number of spatiotemporal dimensions (3 for video, 1 for audio and cross attn) times 2 for cos, sin + num_rope_elems = num_pos_dims * 2 + + # 3. Create a 1D grid of frequencies for RoPE + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + # 4. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape + # (self.dim // num_elems,) + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, num_patches, num_pos_dims, self.dim // num_elems] + freqs = freqs.transpose(-1, -2).flatten(2) # [B, num_patches, self.dim // 2] + + # 5. Get real, interleaved (cos, sin) frequencies, padded to self.dim + # TODO: consider implementing this as a utility and reuse in `connectors.py`. 
+ # src/diffusers/pipelines/ltx2/connectors.py + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + # Reshape freqs to be compatible with multi-head attention + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) + + cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + + return cos_freqs, sin_freqs + + +class LTX2VideoTransformer3DModel( + ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin +): + r""" + A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). + + Args: + in_channels (`int`, defaults to `128`): + The number of channels in the input. + out_channels (`int`, defaults to `128`): + The number of channels in the output. + patch_size (`int`, defaults to `1`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the tmeporal patches to use in the patch embedding layer. 
+ num_attention_heads (`int`, defaults to `32`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + cross_attention_dim (`int`, defaults to `2048 `): + The number of channels for cross attention heads. + num_layers (`int`, defaults to `28`): + The number of layers of Transformer blocks to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + qk_norm (`str`, defaults to `"rms_norm_across_heads"`): + The normalization layer to use. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["LTX2VideoTransformerBlock"] + + @register_to_config + def __init__( + self, + in_channels: int = 128, # Video Arguments + out_channels: int | None = 128, + patch_size: int = 1, + patch_size_t: int = 1, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + cross_attention_dim: int = 4096, + vae_scale_factors: tuple[int, int, int] = (8, 32, 32), + pos_embed_max_pos: int = 20, + base_height: int = 2048, + base_width: int = 2048, + audio_in_channels: int = 128, # Audio Arguments + audio_out_channels: int | None = 128, + audio_patch_size: int = 1, + audio_patch_size_t: int = 1, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_cross_attention_dim: int = 2048, + audio_scale_factor: int = 4, + audio_pos_embed_max_pos: int = 20, + audio_sampling_rate: int = 16000, + audio_hop_length: int = 160, + num_layers: int = 48, # Shared arguments + activation_fn: str = "gelu-approximate", + qk_norm: str = "rms_norm_across_heads", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 3840, + attention_bias: bool = True, + attention_out_bias: bool = True, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + causal_offset: int = 1, + timestep_scale_multiplier: int = 1000, + 
cross_attn_timestep_scale_multiplier: int = 1000, + rope_type: str = "interleaved", + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + audio_out_channels = audio_out_channels or audio_in_channels + inner_dim = num_attention_heads * attention_head_dim + audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + + # 1. Patchification input projections + self.proj_in = nn.Linear(in_channels, inner_dim) + self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim) + + # 2. Prompt embeddings + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=audio_inner_dim + ) + + # 3. Timestep Modulation Params and Embedding + # 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding + # time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters + self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False) + self.audio_time_embed = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=6, use_additional_conditions=False + ) + + # 3.2. Global Cross Attention Modulation Parameters + # Used in the audio-to-video and video-to-audio cross attention layers as a global set of modulation params, + # which are then further modified by per-block modulaton params in each transformer block. 
+ # There are 2 sets of scale/shift parameters for each modality, 1 each for audio-to-video (a2v) and + # video-to-audio (v2a) cross attention + self.av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=4, use_additional_conditions=False + ) + self.av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=4, use_additional_conditions=False + ) + # Gate param for audio-to-video (a2v) cross attn (where the video is the queries (Q) and the audio is the keys + # and values (KV)) + self.av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=1, use_additional_conditions=False + ) + # Gate param for video-to-audio (v2a) cross attn (where the audio is the queries (Q) and the video is the keys + # and values (KV)) + self.av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=1, use_additional_conditions=False + ) + + # 3.3. Output Layer Scale/Shift Modulation parameters + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5) + + # 4. 
Rotary Positional Embeddings (RoPE) + # Self-Attention + self.rope = LTX2AudioVideoRotaryPosEmbed( + dim=inner_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + scale_factors=vae_scale_factors, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + self.audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_inner_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=audio_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + scale_factors=[audio_scale_factor], + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, + ) + + # Audio-to-Video, Video-to-Audio Cross-Attention + cross_attn_pos_embed_max_pos = max(pos_embed_max_pos, audio_pos_embed_max_pos) + self.cross_attn_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, + ) + + # 5. 
Transformer Blocks + self.transformer_blocks = nn.ModuleList( + [ + LTX2VideoTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + audio_dim=audio_inner_dim, + audio_num_attention_heads=audio_num_attention_heads, + audio_attention_head_dim=audio_attention_head_dim, + audio_cross_attention_dim=audio_cross_attention_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + # 6. Output layers + self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels) + + self.audio_norm_out = nn.LayerNorm(audio_inner_dim, eps=1e-6, elementwise_affine=False) + self.audio_proj_out = nn.Linear(audio_inner_dim, audio_out_channels) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + audio_timestep: torch.LongTensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + audio_encoder_attention_mask: torch.Tensor | None = None, + num_frames: int | None = None, + height: int | None = None, + width: int | None = None, + fps: float = 24.0, + audio_num_frames: int | None = None, + video_coords: torch.Tensor | None = None, + audio_coords: torch.Tensor | None = None, + attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + ) -> torch.Tensor: + """ + Forward pass for LTX-2.0 audiovisual video transformer. + + Args: + hidden_states (`torch.Tensor`): + Input patchified video latents of shape `(batch_size, num_video_tokens, in_channels)`. 
+ audio_hidden_states (`torch.Tensor`): + Input patchified audio latents of shape `(batch_size, num_audio_tokens, audio_in_channels)`. + encoder_hidden_states (`torch.Tensor`): + Input video text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`. + audio_encoder_hidden_states (`torch.Tensor`): + Input audio text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`. + timestep (`torch.Tensor`): + Input timestep of shape `(batch_size, num_video_tokens)`. These should already be scaled by + `self.config.timestep_scale_multiplier`. + audio_timestep (`torch.Tensor`, *optional*): + Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation + params. This is only used by certain pipelines such as the I2V pipeline. + encoder_attention_mask (`torch.Tensor`, *optional*): + Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`. + audio_encoder_attention_mask (`torch.Tensor`, *optional*): + Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling. + num_frames (`int`, *optional*): + The number of latent video frames. Used if calculating the video coordinates for RoPE. + height (`int`, *optional*): + The latent video height. Used if calculating the video coordinates for RoPE. + width (`int`, *optional*): + The latent video width. Used if calculating the video coordinates for RoPE. + fps: (`float`, *optional*, defaults to `24.0`): + The desired frames per second of the generated video. Used if calculating the video coordinates for + RoPE. + audio_num_frames: (`int`, *optional*): + The number of latent audio frames. Used if calculating the audio coordinates for RoPE. + video_coords (`torch.Tensor`, *optional*): + The video coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape + `(batch_size, 3, num_video_tokens, 2)`. If not supplied, this will be calculated inside `forward`. 
+ audio_coords (`torch.Tensor`, *optional*): + The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape + `(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`. + attention_kwargs (`dict[str, Any]`, *optional*): + Optional dict of keyword args to be passed to the attention processor. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a dict-like structured output of type `AudioVisualModelOutput` or a tuple. + + Returns: + `AudioVisualModelOutput` or `tuple`: + If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a + `tuple` is returned where the first element is the denoised video latent patch sequence and the second + element is the denoised audio latent patch sequence. + """ + # Determine timestep for audio. + audio_timestep = audio_timestep if audio_timestep is not None else timestep + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2: + audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0 + audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + + # 1. 
Prepare RoPE positional embeddings + if video_coords is None: + video_coords = self.rope.prepare_video_coords( + batch_size, num_frames, height, width, hidden_states.device, fps=fps + ) + if audio_coords is None: + audio_coords = self.audio_rope.prepare_audio_coords( + batch_size, audio_num_frames, audio_hidden_states.device + ) + + video_rotary_emb = self.rope(video_coords, device=hidden_states.device) + audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) + + video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) + audio_cross_attn_rotary_emb = self.cross_attn_audio_rope( + audio_coords[:, 0:1, :], device=audio_hidden_states.device + ) + + # 2. Patchify input projections + hidden_states = self.proj_in(hidden_states) + audio_hidden_states = self.audio_proj_in(audio_hidden_states) + + # 3. Prepare timestep embeddings and modulation parameters + timestep_cross_attn_gate_scale_factor = ( + self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier + ) + + # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters + # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer + # modulation with scale_shift_table (and similarly for audio) + temb, embedded_timestep = self.time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + + temb_audio, audio_embedded_timestep = self.audio_time_embed( + audio_timestep.flatten(), + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) + audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) + + # 3.2. 
Prepare global modality cross attention modulation parameters + video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( + timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_scale_shift = video_cross_attn_scale_shift.view( + batch_size, -1, video_cross_attn_scale_shift.shape[-1] + ) + video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1]) + + audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( + audio_timestep.flatten(), + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( + audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, + hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view( + batch_size, -1, audio_cross_attn_scale_shift.shape[-1] + ) + audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]) + + # 4. Prepare prompt embeddings + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1)) + + # 5. 
Run transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, audio_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + audio_hidden_states, + encoder_hidden_states, + audio_encoder_hidden_states, + temb, + temb_audio, + video_cross_attn_scale_shift, + audio_cross_attn_scale_shift, + video_cross_attn_a2v_gate, + audio_cross_attn_v2a_gate, + video_rotary_emb, + audio_rotary_emb, + video_cross_attn_rotary_emb, + audio_cross_attn_rotary_emb, + encoder_attention_mask, + audio_encoder_attention_mask, + ) + else: + hidden_states, audio_hidden_states = block( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, + temb=temb, + temb_audio=temb_audio, + temb_ca_scale_shift=video_cross_attn_scale_shift, + temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, + temb_ca_gate=video_cross_attn_a2v_gate, + temb_ca_audio_gate=audio_cross_attn_v2a_gate, + video_rotary_emb=video_rotary_emb, + audio_rotary_emb=audio_rotary_emb, + ca_video_rotary_emb=video_cross_attn_rotary_emb, + ca_audio_rotary_emb=audio_cross_attn_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, + ) + + # 6. 
Output layers (including unpatchification) + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + output = self.proj_out(hidden_states) + + audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None] + audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1] + + audio_hidden_states = self.audio_norm_out(audio_hidden_states) + audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift + audio_output = self.audio_proj_out(audio_hidden_states) + + if not return_dict: + return (output, audio_output) + return AudioVisualModelOutput(sample=output, audio_sample=audio_output) \ No newline at end of file diff --git a/videox_fun/models/ltx2_vae.py b/videox_fun/models/ltx2_vae.py new file mode 100644 index 00000000..74ede24b --- /dev/null +++ b/videox_fun/models/ltx2_vae.py @@ -0,0 +1,1519 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.activations import get_activation +from diffusers.models.autoencoders.vae import (DecoderOutput, + DiagonalGaussianDistribution) +from diffusers.models.embeddings import \ + PixArtAlphaCombinedTimestepSizeEmbeddings +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.accelerate_utils import apply_forward_hook + + +class PerChannelRMSNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. + + For each element along the chosen dimension, this layer normalizes the tensor by the root-mean-square of its values + across that dimension: + + y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps) + """ + + def __init__(self, channel_dim: int = 1, eps: float = 1e-8) -> None: + """ + Args: + channel_dim: Dimension along which to compute the RMS (typically channels). + eps: Small constant added for numerical stability. + """ + super().__init__() + self.channel_dim = channel_dim + self.eps = eps + + def forward(self, x: torch.Tensor, channel_dim: int | None = None) -> torch.Tensor: + """ + Apply RMS normalization along the configured dimension (or the per-call override). + """ + channel_dim = self.channel_dim if channel_dim is None else channel_dim + # Compute mean of squared values along `channel_dim`, keep dimensions for broadcasting. + mean_sq = torch.mean(x**2, dim=channel_dim, keepdim=True) + # Normalize by the root-mean-square (RMS). 
+ rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime +class LTX2VideoCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int] = 3, + stride: int | tuple[int, int, int] = 1, + dilation: int | tuple[int, int, int] = 1, + groups: int = 1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + + dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + height_pad = self.kernel_size[1] // 2 + width_pad = self.kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + self.kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + padding=padding, + padding_mode=spatial_padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: + time_kernel_size = self.kernel_size[0] + + if causal: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) + else: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding +# mode is configurable +class LTX2VideoResnetBlock3d(nn.Module): + r""" + A 3D ResNet block used in the 
LTX 2.0 audiovisual model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + elementwise_affine (`bool`, defaults to `False`): + Whether to enable elementwise affinity in the normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + """ + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + dropout: float = 0.0, + eps: float = 1e-6, + elementwise_affine: bool = False, + non_linearity: str = "swish", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = PerChannelRMSNorm() + self.conv1 = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, + ) + + self.norm2 = PerChannelRMSNorm() + self.dropout = nn.Dropout(dropout) + self.conv2 = LTX2VideoCausalConv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, + ) + + self.norm3 = None + self.conv_shortcut = None + if in_channels != out_channels: + self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) + # LTX 2.0 uses a normal nn.Conv3d here rather than LTXVideoCausalConv3d + self.conv_shortcut = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1) + + self.per_channel_scale1 = None + self.per_channel_scale2 = None + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 
1, 1)) + self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + + self.scale_shift_table = None + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) + + def forward( + self, + inputs: torch.Tensor, + temb: torch.Tensor | None = None, + generator: torch.Generator | None = None, + causal: bool = True, + ) -> torch.Tensor: + hidden_states = inputs + + hidden_states = self.norm1(hidden_states) + + if self.scale_shift_table is not None: + temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] + shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale_1) + shift_1 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.per_channel_scale1 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] + + hidden_states = self.norm2(hidden_states) + + if self.scale_shift_table is not None: + hidden_states = hidden_states * (1 + scale_2) + shift_2 + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.per_channel_scale2 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...] 
+ + if self.norm3 is not None: + inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) + + if self.conv_shortcut is not None: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states + + +# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d +class LTXVideoDownsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int | tuple[int, int, int] = 1, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels + + out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) + + self.conv = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: + hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) + + residual = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + residual = residual.unflatten(1, (-1, self.group_size)) + residual = residual.mean(dim=2) + + hidden_states = self.conv(hidden_states, causal=causal) + hidden_states = ( + hidden_states.unflatten(4, (-1, self.stride[2])) + .unflatten(3, (-1, self.stride[1])) + .unflatten(2, (-1, self.stride[0])) + ) + hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4) + hidden_states = hidden_states + residual + + return hidden_states + + +# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d +class LTXVideoUpsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + stride: int | tuple[int, int, int] = 1, + 
residual: bool = False, + upscale_factor: int = 1, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.residual = residual + self.upscale_factor = upscale_factor + + out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor + + self.conv = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + if self.residual: + residual = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor + residual = residual.repeat(1, repeats, 1, 1, 1) + residual = residual[:, :, self.stride[0] - 1 :] + + hidden_states = self.conv(hidden_states, causal=causal) + hidden_states = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + + if self.residual: + hidden_states = hidden_states + residual + + return hidden_states + + +# Like LTX 1.0 LTXVideo095DownBlock3D, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoDownBlock3D(nn.Module): + r""" + Down block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. 
+ dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + downsample_type: str = "conv", + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + spatial_padding_mode=spatial_padding_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList() + + if downsample_type == "conv": + self.downsamplers.append( + LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif downsample_type == "spatial": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif downsample_type == "temporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, + out_channels=out_channels, + 
stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, + ) + ) + elif downsample_type == "spatiotemporal": + self.downsamplers.append( + LTXVideoDownsampler3d( + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, + ) + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + generator: torch.Generator | None = None, + causal: bool = True, + ) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) + else: + hidden_states = resnet(hidden_states, temb, generator, causal=causal) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, causal=causal) + + return hidden_states + + +# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d +# Like LTX 1.0 LTXVideoMidBlock3d, but with the updated LTX2VideoResnetBlock3d +class LTX2VideoMidBlock3d(nn.Module): + r""" + A middle block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. 
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + inject_noise: bool = False, + timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", + ) -> None: + super().__init__() + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + generator: torch.Generator | None = None, + causal: bool = True, + ) -> torch.Tensor: + r"""Forward method of the `LTXMidBlock3D` class.""" + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) + else: + hidden_states = resnet(hidden_states, temb, generator, causal=causal) + + return hidden_states + + +# Like LTXVideoUpBlock3d but with no conv_in and the updated LTX2VideoResnetBlock3d +class LTX2VideoUpBlock3d(nn.Module): + r""" + Up block used in the LTXVideo model. + + Args: + in_channels (`int`): + Number of input channels. 
+ out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + upsample_residual: bool = False, + upscale_factor: int = 1, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + + self.conv_in = None + if in_channels != out_channels: + self.conv_in = LTX2VideoResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + + self.upsamplers = None + if spatio_temporal_scale: + self.upsamplers = nn.ModuleList( + [ + LTXVideoUpsampler3d( + out_channels * upscale_factor, + stride=(2, 2, 2), + residual=upsample_residual, + upscale_factor=upscale_factor, + 
spatial_padding_mode=spatial_padding_mode, + ) + ] + ) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTX2VideoResnetBlock3d( + in_channels=out_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + generator: torch.Generator | None = None, + causal: bool = True, + ) -> torch.Tensor: + if self.conv_in is not None: + hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, causal=causal) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) + else: + hidden_states = resnet(hidden_states, temb, generator, causal=causal) + + return hidden_states + + +# Like LTX 1.0 LTXVideoEncoder3d but with different default args - the spatiotemporal downsampling pattern is +# different, as is the layers_per_block (the 2.0 VAE is bigger) +class LTX2VideoEncoder3d(nn.Module): + r""" + The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + in_channels (`int`, defaults to 3): + Number of input channels. + out_channels (`int`, defaults to 128): + Number of latent channels. 
+ block_out_channels (`tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`): + The number of output channels for each block. + spatio_temporal_scaling (`tuple[bool, ...], defaults to `(True, True, True, True)`: + Whether a block should contain spatio-temporal downscaling layers or not. + layers_per_block (`tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`): + The number of layers per block. + downsample_type (`tuple[str, ...]`, defaults to `("spatial", "temporal", "spatiotemporal", "spatiotemporal")`): + The spatiotemporal downsampling pattern per block. Per-layer values can be + - `"spatial"` (downsample spatial dims by 2x) + - `"temporal"` (downsample temporal dim by 2x) + - `"spatiotemporal"` (downsample both spatial and temporal dims by 2x) + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 128, + block_out_channels: tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True), + layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2), + downsample_type: tuple[str, ...] 
= ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + spatial_padding_mode: str = "zeros", + ): + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.in_channels = in_channels * patch_size**2 + self.is_causal = is_causal + + output_channel = out_channels + + self.conv_in = LTX2VideoCausalConv3d( + in_channels=self.in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + # down blocks + num_block_out_channels = len(block_out_channels) + self.down_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + output_channel = block_out_channels[i] + + if down_block_types[i] == "LTX2VideoDownBlock3D": + down_block = LTX2VideoDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + downsample_type=downsample_type[i], + spatial_padding_mode=spatial_padding_mode, + ) + else: + raise ValueError(f"Unknown down block type: {down_block_types[i]}") + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[-1], + resnet_eps=resnet_norm_eps, + spatial_padding_mode=spatial_padding_mode, + ) + + # out + self.norm_out = PerChannelRMSNorm() + self.conv_act = nn.SiLU() + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=out_channels + 1, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor, causal: bool | None = None) -> torch.Tensor: + r"""The forward method of the `LTXVideoEncoder3d` class.""" + + p = self.patch_size + p_t = self.patch_size_t 
+ + batch_size, num_channels, num_frames, height, width = hidden_states.shape + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + causal = causal or self.is_causal + + hidden_states = hidden_states.reshape( + batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p + ) + # Thanks for driving me insane with the weird patching order :( + hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) + hidden_states = self.conv_in(hidden_states, causal=causal) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states, None, None, causal) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, None, None, causal) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states, causal=causal) + + hidden_states = self.mid_block(hidden_states, causal=causal) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) + + last_channel = hidden_states[:, -1:] + last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) + hidden_states = torch.cat([hidden_states, last_channel], dim=1) + + return hidden_states + + +# Like LTX 1.0 LTXVideoDecoder3d, but has only 3 symmetric up blocks which are causal and residual with upsample_factor 2 +class LTX2VideoDecoder3d(nn.Module): + r""" + The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, defaults to 128): + Number of latent channels. + out_channels (`int`, defaults to 3): + Number of output channels. + block_out_channels (`tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. 
+ spatio_temporal_scaling (`tuple[bool, ...]`, defaults to `(True, True, True)`): + Whether a block should contain spatio-temporal upscaling layers or not. + layers_per_block (`tuple[int, ...]`, defaults to `(5, 5, 5, 5)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `False`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + timestep_conditioning (`bool`, defaults to `False`): + Whether to condition the model on timesteps. + """ + + def __init__( + self, + in_channels: int = 128, + out_channels: int = 3, + block_out_channels: tuple[int, ...] = (256, 512, 1024), + spatio_temporal_scaling: tuple[bool, ...] = (True, True, True), + layers_per_block: tuple[int, ...] = (5, 5, 5, 5), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = False, + inject_noise: tuple[bool, ...] = (False, False, False), + timestep_conditioning: bool = False, + upsample_residual: tuple[bool, ...] = (True, True, True), + upsample_factor: tuple[int, ...] 
= (2, 2, 2), + spatial_padding_mode: str = "reflect", + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.out_channels = out_channels * patch_size**2 + self.is_causal = is_causal + + block_out_channels = tuple(reversed(block_out_channels)) + spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) + layers_per_block = tuple(reversed(layers_per_block)) + inject_noise = tuple(reversed(inject_noise)) + upsample_residual = tuple(reversed(upsample_residual)) + upsample_factor = tuple(reversed(upsample_factor)) + output_channel = block_out_channels[0] + + self.conv_in = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, + ) + + self.mid_block = LTX2VideoMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[0], + resnet_eps=resnet_norm_eps, + inject_noise=inject_noise[0], + timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, + ) + + # up blocks + num_block_out_channels = len(block_out_channels) + self.up_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel // upsample_factor[i] + output_channel = block_out_channels[i] // upsample_factor[i] + + up_block = LTX2VideoUpBlock3d( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i + 1], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + inject_noise=inject_noise[i + 1], + timestep_conditioning=timestep_conditioning, + upsample_residual=upsample_residual[i], + upscale_factor=upsample_factor[i], + spatial_padding_mode=spatial_padding_mode, + ) + + self.up_blocks.append(up_block) + + # out + self.norm_out = PerChannelRMSNorm() + self.conv_act = nn.SiLU() + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + 
spatial_padding_mode=spatial_padding_mode, + ) + + # timestep embedding + self.time_embedder = None + self.scale_shift_table = None + self.timestep_scale_multiplier = None + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32)) + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor | None = None, + causal: bool | None = None, + ) -> torch.Tensor: + causal = causal or self.is_causal + + hidden_states = self.conv_in(hidden_states, causal=causal) + + if self.timestep_scale_multiplier is not None: + temb = temb * self.timestep_scale_multiplier + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb, None, causal) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb, None, causal) + else: + hidden_states = self.mid_block(hidden_states, temb, causal=causal) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb, causal=causal) + + hidden_states = self.norm_out(hidden_states) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1)) + temb = temb + self.scale_shift_table[None, ..., None, None, None] + shift, scale = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) + + p = self.patch_size + p_t = self.patch_size_t + + 
batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return hidden_states + + +class AutoencoderKLLTX2Video(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [LTX-2](https://huggingface.co/Lightricks/LTX-2). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented + for all models (such as downloading or saving). + + Args: + in_channels (`int`, defaults to `3`): + Number of input channels. + out_channels (`int`, defaults to `3`): + Number of output channels. + latent_channels (`int`, defaults to `128`): + Number of latent channels. + block_out_channels (`tuple[int, ...]`, defaults to `(256, 512, 1024, 2048)`): + The number of output channels for each block. + spatio_temporal_scaling (`tuple[bool, ...]`, defaults to `(True, True, True, True)`): + Whether a block should contain spatio-temporal downscaling or not. + layers_per_block (`tuple[int, ...]`, defaults to `(4, 6, 6, 2, 2)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + scaling_factor (`float`, *optional*, defaults to `1.0`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. 
When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + encoder_causal (`bool`, defaults to `True`): + Whether the encoder should behave causally (future frames depend only on past frames) or not. + decoder_causal (`bool`, defaults to `False`): + Whether the decoder should behave causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 128, + block_out_channels: tuple[int, ...] = (256, 512, 1024, 2048), + down_block_types: tuple[str, ...] = ( + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + "LTX2VideoDownBlock3D", + ), + decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024), + layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2), + decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5), + spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True), + decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True), + decoder_inject_noise: tuple[bool, ...] = (False, False, False, False), + downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"), + upsample_residual: tuple[bool, ...] = (True, True, True), + upsample_factor: tuple[int, ...] 
= (2, 2, 2), + timestep_conditioning: bool = False, + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + scaling_factor: float = 1.0, + encoder_causal: bool = True, + decoder_causal: bool = True, + encoder_spatial_padding_mode: str = "zeros", + decoder_spatial_padding_mode: str = "reflect", + spatial_compression_ratio: int = None, + temporal_compression_ratio: int = None, + ) -> None: + super().__init__() + + self.encoder = LTX2VideoEncoder3d( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=block_out_channels, + down_block_types=down_block_types, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + downsample_type=downsample_type, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=encoder_causal, + spatial_padding_mode=encoder_spatial_padding_mode, + ) + self.decoder = LTX2VideoDecoder3d( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=decoder_block_out_channels, + spatio_temporal_scaling=decoder_spatio_temporal_scaling, + layers_per_block=decoder_layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=decoder_causal, + timestep_conditioning=timestep_conditioning, + inject_noise=decoder_inject_noise, + upsample_residual=upsample_residual, + upsample_factor=upsample_factor, + spatial_padding_mode=decoder_spatial_padding_mode, + ) + + latents_mean = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.ones((latent_channels,), requires_grad=False) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + self.spatial_compression_ratio = ( + patch_size * 2 ** sum(spatio_temporal_scaling) + if spatial_compression_ratio is None + else spatial_compression_ratio + ) + self.temporal_compression_ratio = ( + patch_size_t * 2 ** 
sum(spatio_temporal_scaling) + if temporal_compression_ratio is None + else temporal_compression_ratio + ) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # This can be configured based on the amount of GPU memory available. + # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. + # Setting it to higher values results in higher memory usage. + self.num_sample_frames_batch_size = 16 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + self.tile_sample_min_num_frames = 16 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + self.tile_sample_stride_num_frames = 8 + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_sample_min_num_frames: int | None = None, + tile_sample_stride_height: float | None = None, + tile_sample_stride_width: float | None = None, + tile_sample_stride_num_frames: float | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. 
When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + + def _encode(self, x: torch.Tensor, causal: bool | None = None) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: + return self._temporal_tiled_encode(x, causal=causal) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x, causal=causal) + + enc = self.encoder(x, causal=causal) + 
+ return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, causal: bool | None = None, return_dict: bool = True + ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice, causal=causal) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x, causal=causal) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode( + self, + z: torch.Tensor, + temb: torch.Tensor | None = None, + causal: bool | None = None, + return_dict: bool = True, + ) -> DecoderOutput | torch.Tensor: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + + if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: + return self._temporal_tiled_decode(z, temb, causal=causal, return_dict=return_dict) + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, temb, causal=causal, return_dict=return_dict) + + dec = self.decoder(z, temb, causal=causal) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def 
decode( + self, + z: torch.Tensor, + temb: torch.Tensor | None = None, + causal: bool | None = None, + return_dict: bool = True, + ) -> DecoderOutput | torch.Tensor: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + if temb is not None: + decoded_slices = [ + self._decode(z_slice, t_slice, causal=causal).sample + for z_slice, t_slice in (z.split(1), temb.split(1)) + ] + else: + decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z, temb, causal=causal).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / 
blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor, causal: bool | None = None) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + time = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width], + causal=causal, + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode( + self, z: torch.Tensor, temb: torch.Tensor | None, causal: bool | None = None, return_dict: bool = True + ) -> DecoderOutput | torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+ """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + time = self.decoder( + z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb, causal=causal + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def _temporal_tiled_encode(self, x: torch.Tensor, causal: bool | None = None) -> AutoencoderKLOutput: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) 
// self.temporal_compression_ratio + 1 + + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): + tile = self.tiled_encode(tile, causal=causal) + else: + tile = self.encoder(tile, causal=causal) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + + enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames] + return enc + + def _temporal_tiled_decode( + self, z: torch.Tensor, temb: torch.Tensor | None, causal: bool | None = None, return_dict: bool = True + ) -> DecoderOutput | torch.Tensor: + batch_size, num_channels, num_frames, height, width = z.shape + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio + blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + 
tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): + decoded = self.tiled_decode(tile, temb, causal=causal, return_dict=True).sample + else: + decoded = self.decoder(tile, temb, causal=causal) + if i > 0: + decoded = decoded[:, :, :-1, :, :] + row.append(decoded) + + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, blend_num_frames) + tile = tile[:, :, : self.tile_sample_stride_num_frames, :, :] + result_row.append(tile) + else: + result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :]) + + dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + temb: torch.Tensor | None = None, + sample_posterior: bool = False, + encoder_causal: bool | None = None, + decoder_causal: bool | None = None, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> torch.Tensor | torch.Tensor: + x = sample + posterior = self.encode(x, causal=encoder_causal).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, temb, causal=decoder_causal) + if not return_dict: + return (dec.sample,) + return dec diff --git a/videox_fun/models/ltx2_vae_audio.py b/videox_fun/models/ltx2_vae_audio.py new file mode 100644 index 00000000..8300f454 --- /dev/null +++ b/videox_fun/models/ltx2_vae_audio.py @@ -0,0 +1,802 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.autoencoders.vae import (DecoderOutput, + DiagonalGaussianDistribution) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.accelerate_utils import apply_forward_hook + + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +class LTX2AudioCausalConv2d(nn.Module): + """ + A causal 2D convolution that pads asymmetrically along the causal axis. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int], + stride: int = 1, + dilation: int | tuple[int, int] = 1, + groups: int = 1, + bias: bool = True, + causality_axis: str = "height", + ) -> None: + super().__init__() + + self.causality_axis = causality_axis + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + dilation = (dilation, dilation) if isinstance(dilation, int) else dilation + + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + if self.causality_axis == "none": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis in {"width", "width-compatibility"}: + padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + elif self.causality_axis == "height": + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + else: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + self.padding = padding + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding) + return self.conv(x) + + +class LTX2AudioPixelNorm(nn.Module): + """ + Per-pixel (per-location) RMS normalization layer. 
+ """ + + def __init__(self, dim: int = 1, eps: float = 1e-8) -> None: + super().__init__() + self.dim = dim + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True) + rms = torch.sqrt(mean_sq + self.eps) + return x / rms + + +class LTX2AudioAttnBlock(nn.Module): + def __init__( + self, + in_channels: int, + norm_type: str = "group", + ) -> None: + super().__init__() + self.in_channels = in_channels + + if norm_type == "group": + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == "pixel": + self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {norm_type}") + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = self.norm(x) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + batch, channels, height, width = q.shape + q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous() + k = k.reshape(batch, channels, height * width).contiguous() + attn = torch.bmm(q, k) * (int(channels) ** (-0.5)) + attn = torch.nn.functional.softmax(attn, dim=2) + + v = v.reshape(batch, channels, height * width) + attn = attn.permute(0, 2, 1).contiguous() + h_ = torch.bmm(v, attn).reshape(batch, channels, height, width) + + h_ = self.proj_out(h_) + return x + h_ + + +class LTX2AudioResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + temb_channels: int = 512, + norm_type: str = "group", + causality_axis: str = "height", + ) -> None: + 
class LTX2AudioResnetBlock(nn.Module):
    """
    Pre-activation residual block (norm -> SiLU -> conv, twice) with optional
    timestep-embedding injection and optional causal convolutions.

    When `in_channels != out_channels` the skip path is projected, either with a
    3x3 conv (`conv_shortcut=True`) or a 1x1 conv (`nin_shortcut`).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int | None = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        norm_type: str = "group",
        causality_axis: str = "height",
    ) -> None:
        super().__init__()
        self.causality_axis = causality_axis

        # GroupNorm shares statistics across spatial positions, which would leak
        # information across the causal axis — only "pixel" norm is allowed then.
        if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
            raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
        self.in_channels = in_channels
        # Default to an identity-width block when out_channels is not given.
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        if norm_type == "group":
            self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        elif norm_type == "pixel":
            self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {norm_type}")
        self.non_linearity = nn.SiLU()
        # Causal variant pads asymmetrically; the plain Conv2d pads symmetrically.
        if causality_axis is not None:
            self.conv1 = LTX2AudioCausalConv2d(
                in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
            )
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Timestep-embedding projection; skipped entirely when temb_channels <= 0.
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)
        if norm_type == "group":
            self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        elif norm_type == "pixel":
            self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {norm_type}")
        self.dropout = nn.Dropout(dropout)
        if causality_axis is not None:
            self.conv2 = LTX2AudioCausalConv2d(
                out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
            )
        else:
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Skip-path projection, only needed when the channel count changes.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                if causality_axis is not None:
                    self.conv_shortcut = LTX2AudioCausalConv2d(
                        in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                    )
                else:
                    self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                if causality_axis is not None:
                    self.nin_shortcut = LTX2AudioCausalConv2d(
                        in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                    )
                else:
                    self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor, temb: torch.Tensor | None = None) -> torch.Tensor:
        h = self.norm1(x)
        h = self.non_linearity(h)
        h = self.conv1(h)

        # Inject the (projected) timestep embedding as a per-channel bias.
        if temb is not None:
            h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.non_linearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)

        return x + h
+ ) + + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # with_conv=False implies that causality_axis is "none" + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class LTX2AudioUpsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool, causality_axis: str | None = "height") -> None: + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + if causality_axis is not None: + self.conv = LTX2AudioCausalConv2d( + in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + if self.causality_axis is None or self.causality_axis == "none": + pass + elif self.causality_axis == "height": + x = x[:, :, 1:, :] + elif self.causality_axis == "width": + x = x[:, :, :, 1:] + elif self.causality_axis == "width-compatibility": + pass + else: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class LTX2AudioAudioPatchifier: + """ + Patchifier for spectrogram/audio latents. 
+ """ + + def __init__( + self, + patch_size: int, + sample_rate: int = 16000, + hop_length: int = 160, + audio_latent_downsample_factor: int = 4, + is_causal: bool = True, + ): + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self._patch_size = (1, patch_size, patch_size) + + def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor: + batch, channels, time, freq = audio_latents.shape + return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq) + + def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor: + batch, time, _ = audio_latents.shape + return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3) + + @property + def patch_size(self) -> tuple[int, int, int]: + return self._patch_size + + +class LTX2AudioEncoder(nn.Module): + def __init__( + self, + base_channels: int = 128, + output_channels: int = 1, + num_res_blocks: int = 2, + attn_resolutions: tuple[int, ...] | None = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + ch_mult: tuple[int, ...] 
class LTX2AudioEncoder(nn.Module):
    """
    Convolutional encoder mapping audio spectrograms (batch, channels, time,
    mel_bins) to latent features: a stem conv, `len(ch_mult)` down stages of
    ResNet blocks (with optional attention and 2x downsampling between stages),
    a middle block, and an output projection to `2 * latent_channels` channels
    (mean and log-variance) when `double_z` is True.
    """

    def __init__(
        self,
        base_channels: int = 128,
        output_channels: int = 1,
        num_res_blocks: int = 2,
        attn_resolutions: tuple[int, ...] | None = None,
        in_channels: int = 2,
        resolution: int = 256,
        latent_channels: int = 8,
        ch_mult: tuple[int, ...] = (1, 2, 4),
        norm_type: str = "group",
        causality_axis: str | None = "width",
        dropout: float = 0.0,
        mid_block_add_attention: bool = False,
        sample_rate: int = 16000,
        mel_hop_length: int = 160,
        is_causal: bool = True,
        mel_bins: int | None = 64,
        double_z: bool = True,
    ):
        super().__init__()

        # Audio front-end metadata (not used by the conv stack itself).
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
        self.mel_bins = mel_bins

        self.base_channels = base_channels
        # No timestep embedding in the encoder path.
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.out_ch = output_channels
        self.give_pre_end = False
        self.tanh_out = False
        self.norm_type = norm_type
        self.latent_channels = latent_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = causality_axis

        base_block_channels = base_channels
        base_resolution = resolution
        self.z_shape = (1, latent_channels, base_resolution, base_resolution)

        # Stem: causal conv when a causal axis is configured, plain conv otherwise.
        if self.causality_axis is not None:
            self.conv_in = LTX2AudioCausalConv2d(
                in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
            )
        else:
            self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)

        self.down = nn.ModuleList()
        block_in = base_block_channels
        curr_res = self.resolution

        for level in range(self.num_resolutions):
            stage = nn.Module()
            stage.block = nn.ModuleList()
            stage.attn = nn.ModuleList()
            block_out = self.base_channels * self.channel_multipliers[level]

            for _ in range(self.num_res_blocks):
                stage.block.append(
                    LTX2AudioResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                        norm_type=self.norm_type,
                        causality_axis=self.causality_axis,
                    )
                )
                block_in = block_out
                # One attention block is appended per ResNet block at matching
                # resolutions, so `stage.attn` indexes align with `stage.block`.
                if self.attn_resolutions:
                    if curr_res in self.attn_resolutions:
                        stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))

            # Downsample between stages (not after the last one).
            if level != self.num_resolutions - 1:
                stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
                curr_res = curr_res // 2

            self.down.append(stage)

        self.mid = nn.Module()
        self.mid.block_1 = LTX2AudioResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
        )
        if mid_block_add_attention:
            self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
        else:
            self.mid.attn_1 = nn.Identity()
        self.mid.block_2 = LTX2AudioResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
        )

        final_block_channels = block_in
        # Twice the latent channels when predicting a Gaussian (mean + logvar).
        z_channels = 2 * latent_channels if double_z else latent_channels
        if self.norm_type == "group":
            self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
        elif self.norm_type == "pixel":
            self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {self.norm_type}")
        self.non_linearity = nn.SiLU()

        if self.causality_axis is not None:
            self.conv_out = LTX2AudioCausalConv2d(
                final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
            )
        else:
            self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
        hidden_states = self.conv_in(hidden_states)

        for level in range(self.num_resolutions):
            stage = self.down[level]
            for block_idx, block in enumerate(stage.block):
                hidden_states = block(hidden_states, temb=None)
                # Relies on one attn entry per res block (see __init__);
                # `stage.attn` is empty at non-attention resolutions.
                if stage.attn:
                    hidden_states = stage.attn[block_idx](hidden_states)

            if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
                hidden_states = stage.downsample(hidden_states)

        hidden_states = self.mid.block_1(hidden_states, temb=None)
        hidden_states = self.mid.attn_1(hidden_states)
        hidden_states = self.mid.block_2(hidden_states, temb=None)

        hidden_states = self.norm_out(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
= (1, 2, 4), + norm_type: str = "group", + causality_axis: str | None = "width", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: int | None = 64, + ) -> None: + super().__init__() + + self.sample_rate = sample_rate + self.mel_hop_length = mel_hop_length + self.is_causal = is_causal + self.mel_bins = mel_bins + self.patchifier = LTX2AudioAudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=sample_rate, + hop_length=mel_hop_length, + is_causal=is_causal, + ) + + self.base_channels = base_channels + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = output_channels + self.give_pre_end = False + self.tanh_out = False + self.norm_type = norm_type + self.latent_channels = latent_channels + self.channel_multipliers = ch_mult + self.attn_resolutions = attn_resolutions + self.causality_axis = causality_axis + + base_block_channels = base_channels * self.channel_multipliers[-1] + base_resolution = resolution // (2 ** (self.num_resolutions - 1)) + self.z_shape = (1, latent_channels, base_resolution, base_resolution) + + if self.causality_axis is not None: + self.conv_in = LTX2AudioCausalConv2d( + latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1) + self.non_linearity = nn.SiLU() + self.mid = nn.Module() + self.mid.block_1 = LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, 
norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = LTX2AudioResnetBlock( + in_channels=base_block_channels, + out_channels=base_block_channels, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + + self.up = nn.ModuleList() + block_in = base_block_channels + curr_res = self.resolution // (2 ** (self.num_resolutions - 1)) + + for level in reversed(range(self.num_resolutions)): + stage = nn.Module() + stage.block = nn.ModuleList() + stage.attn = nn.ModuleList() + block_out = self.base_channels * self.channel_multipliers[level] + + for _ in range(self.num_res_blocks + 1): + stage.block.append( + LTX2AudioResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=self.causality_axis, + ) + ) + block_in = block_out + if self.attn_resolutions: + if curr_res in self.attn_resolutions: + stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)) + + if level != 0: + stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis) + curr_res *= 2 + + self.up.insert(0, stage) + + final_block_channels = block_in + + if self.norm_type == "group": + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True) + elif self.norm_type == "pixel": + self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {self.norm_type}") + + if self.causality_axis is not None: + self.conv_out = LTX2AudioCausalConv2d( + final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis + ) + else: + self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1) + + def forward( + self, + sample: torch.Tensor, + ) -> torch.Tensor: + _, _, frames, mel_bins = sample.shape + + target_frames = frames * 
LATENT_DOWNSAMPLE_FACTOR + + if self.causality_axis is not None: + target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) + + target_channels = self.out_ch + target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins + + hidden_features = self.conv_in(sample) + hidden_features = self.mid.block_1(hidden_features, temb=None) + hidden_features = self.mid.attn_1(hidden_features) + hidden_features = self.mid.block_2(hidden_features, temb=None) + + for level in reversed(range(self.num_resolutions)): + stage = self.up[level] + for block_idx, block in enumerate(stage.block): + hidden_features = block(hidden_features, temb=None) + if stage.attn: + hidden_features = stage.attn[block_idx](hidden_features) + + if level != 0 and hasattr(stage, "upsample"): + hidden_features = stage.upsample(hidden_features) + + if self.give_pre_end: + return hidden_features + + hidden = self.norm_out(hidden_features) + hidden = self.non_linearity(hidden) + decoded_output = self.conv_out(hidden) + decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output + + _, _, current_time, current_freq = decoded_output.shape + target_time = target_frames + target_freq = target_mel_bins + + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + if time_padding_needed > 0 or freq_padding_needed > 0: + padding = ( + 0, + max(freq_padding_needed, 0), + 0, + max(time_padding_needed, 0), + ) + decoded_output = F.pad(decoded_output, padding) + + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + +class AutoencoderKLLTX2Audio(ModelMixin, ConfigMixin): + r""" + LTX2 audio VAE for encoding and decoding audio latent representations. 
+ """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + base_channels: int = 128, + output_channels: int = 2, + ch_mult: tuple[int, ...] = (1, 2, 4), + num_res_blocks: int = 2, + attn_resolutions: tuple[int, ...] | None = None, + in_channels: int = 2, + resolution: int = 256, + latent_channels: int = 8, + norm_type: str = "pixel", + causality_axis: str | None = "height", + dropout: float = 0.0, + mid_block_add_attention: bool = False, + sample_rate: int = 16000, + mel_hop_length: int = 160, + is_causal: bool = True, + mel_bins: int | None = 64, + double_z: bool = True, + ) -> None: + super().__init__() + + supported_causality_axes = {"none", "width", "height", "width-compatibility"} + if causality_axis not in supported_causality_axes: + raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}") + + attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions + + self.encoder = LTX2AudioEncoder( + base_channels=base_channels, + output_channels=output_channels, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolution_set, + in_channels=in_channels, + resolution=resolution, + latent_channels=latent_channels, + norm_type=norm_type, + causality_axis=causality_axis, + dropout=dropout, + mid_block_add_attention=mid_block_add_attention, + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + is_causal=is_causal, + mel_bins=mel_bins, + double_z=double_z, + ) + + self.decoder = LTX2AudioDecoder( + base_channels=base_channels, + output_channels=output_channels, + ch_mult=ch_mult, + num_res_blocks=num_res_blocks, + attn_resolutions=attn_resolution_set, + in_channels=in_channels, + resolution=resolution, + latent_channels=latent_channels, + norm_type=norm_type, + causality_axis=causality_axis, + dropout=dropout, + mid_block_add_attention=mid_block_add_attention, + sample_rate=sample_rate, + mel_hop_length=mel_hop_length, + 
is_causal=is_causal, + mel_bins=mel_bins, + ) + + # Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over + # the entire dataset and stored in model's checkpoint under AudioVAE state_dict + latents_std = torch.ones((base_channels,)) + latents_mean = torch.zeros((base_channels,)) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + # TODO: calculate programmatically instead of hardcoding + self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4 + # TODO: confirm whether the mel compression ratio below is correct + self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + return self.encoder(x) + + @apply_forward_hook + def encode(self, x: torch.Tensor, return_dict: bool = True): + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + return self.decoder(z) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | torch.Tensor: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> DecoderOutput | torch.Tensor: + posterior = self.encode(sample).latent_dist + if sample_posterior: + z = 
posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec.sample,) + return dec diff --git a/videox_fun/models/ltx2_vocoder.py b/videox_fun/models/ltx2_vocoder.py new file mode 100644 index 00000000..7e038827 --- /dev/null +++ b/videox_fun/models/ltx2_vocoder.py @@ -0,0 +1,158 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ltx2/vocoder.py +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin + + +class ResBlock(nn.Module): + def __init__( + self, + channels: int, + kernel_size: int = 3, + stride: int = 1, + dilations: tuple[int, ...] = (1, 3, 5), + leaky_relu_negative_slope: float = 0.1, + padding_mode: str = "same", + ): + super().__init__() + self.dilations = dilations + self.negative_slope = leaky_relu_negative_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode) + for dilation in dilations + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode) + for _ in range(len(dilations)) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for conv1, conv2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, negative_slope=self.negative_slope) + xt = conv1(xt) + xt = F.leaky_relu(xt, negative_slope=self.negative_slope) + xt = conv2(xt) + x = x + xt + return x + + +class LTX2Vocoder(ModelMixin, ConfigMixin): + r""" + LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms. 
+ """ + + @register_to_config + def __init__( + self, + in_channels: int = 128, + hidden_channels: int = 1024, + out_channels: int = 2, + upsample_kernel_sizes: list[int] = [16, 15, 8, 4, 4], + upsample_factors: list[int] = [6, 5, 2, 2, 2], + resnet_kernel_sizes: list[int] = [3, 7, 11], + resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + leaky_relu_negative_slope: float = 0.1, + output_sampling_rate: int = 24000, + ): + super().__init__() + self.num_upsample_layers = len(upsample_kernel_sizes) + self.resnets_per_upsample = len(resnet_kernel_sizes) + self.out_channels = out_channels + self.total_upsample_factor = math.prod(upsample_factors) + self.negative_slope = leaky_relu_negative_slope + + if self.num_upsample_layers != len(upsample_factors): + raise ValueError( + f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length" + f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively." + ) + + if self.resnets_per_upsample != len(resnet_dilations): + raise ValueError( + f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length" + f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively." 
+ ) + + self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3) + + self.upsamplers = nn.ModuleList() + self.resnets = nn.ModuleList() + input_channels = hidden_channels + for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + output_channels = input_channels // 2 + self.upsamplers.append( + nn.ConvTranspose1d( + input_channels, # hidden_channels // (2 ** i) + output_channels, # hidden_channels // (2 ** (i + 1)) + kernel_size, + stride=stride, + padding=(kernel_size - stride) // 2, + ) + ) + + for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations): + self.resnets.append( + ResBlock( + output_channels, + kernel_size, + dilations=dilations, + leaky_relu_negative_slope=leaky_relu_negative_slope, + ) + ) + input_channels = output_channels + + self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3) + + def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor: + r""" + Forward pass of the vocoder. + + Args: + hidden_states (`torch.Tensor`): + Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last` + is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is + `True`. + time_last (`bool`, *optional*, defaults to `False`): + Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension. 
+ + Returns: + `torch.Tensor`: + Audio waveform tensor of shape (batch_size, out_channels, audio_length) + """ + + # Ensure that the time/frame dimension is last + if not time_last: + hidden_states = hidden_states.transpose(2, 3) + # Combine channels and frequency (mel bins) dimensions + hidden_states = hidden_states.flatten(1, 2) + + hidden_states = self.conv_in(hidden_states) + + for i in range(self.num_upsample_layers): + hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope) + hidden_states = self.upsamplers[i](hidden_states) + + # Run all resnets in parallel on hidden_states + start = i * self.resnets_per_upsample + end = (i + 1) * self.resnets_per_upsample + resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0) + + hidden_states = torch.mean(resnet_outputs, dim=0) + + # NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of + # 0.01 (whereas the others usually use a slope of 0.1). 
Not sure if this is intended + hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01) + hidden_states = self.conv_out(hidden_states) + hidden_states = torch.tanh(hidden_states) + + return hidden_states \ No newline at end of file diff --git a/videox_fun/models/wan_audio_injector.py b/videox_fun/models/wan_audio_injector.py index 35568c8c..7f77e43a 100644 --- a/videox_fun/models/wan_audio_injector.py +++ b/videox_fun/models/wan_audio_injector.py @@ -282,7 +282,7 @@ def forward(self, x): x = self.norm3(x) x = self.act(x) x = rearrange(x, '(b n) t c -> b t n c', b=b) - padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + padding = self.padding_tokens.to(x.dtype).repeat(b, x.shape[1], 1, 1) x = torch.cat([x, padding], dim=-2) x_local = x.clone() @@ -332,7 +332,7 @@ def __init__(self, def forward(self, features): with amp.autocast(dtype=torch.float32): # features B * num_layers * dim * video_length - weights = self.act(self.weights) + weights = self.act(self.weights.to(features.dtype)) weights_sum = weights.sum(dim=1, keepdims=True) weighted_feat = ((features * weights) / weights_sum).sum( dim=1) # b dim f diff --git a/videox_fun/pipeline/__init__.py b/videox_fun/pipeline/__init__.py index 731e558f..b29c87f1 100755 --- a/videox_fun/pipeline/__init__.py +++ b/videox_fun/pipeline/__init__.py @@ -9,6 +9,8 @@ from .pipeline_hunyuanvideo_i2v import HunyuanVideoI2VPipeline from .pipeline_longcatvideo import LongCatVideoPipeline from .pipeline_longcatvideo_avatar import LongCatVideoAvatarPipeline +from .pipeline_ltx2_i2v import LTX2ImageToVideoPipeline +from .pipeline_ltx2 import LTX2Pipeline from .pipeline_qwenimage import QwenImagePipeline from .pipeline_qwenimage_control import QwenImageControlPipeline from .pipeline_qwenimage_instantx import QwenImageControlNetPipeline diff --git a/videox_fun/pipeline/pipeline_cogvideox_fun.py b/videox_fun/pipeline/pipeline_cogvideox_fun.py index 68568a60..08f07920 100644 --- a/videox_fun/pipeline/pipeline_cogvideox_fun.py 
+++ b/videox_fun/pipeline/pipeline_cogvideox_fun.py @@ -607,8 +607,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -845,11 +845,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -857,6 +855,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return CogVideoXFunPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_cogvideox_fun_control.py b/videox_fun/pipeline/pipeline_cogvideox_fun_control.py index e91df20d..e5f31dd4 100644 --- a/videox_fun/pipeline/pipeline_cogvideox_fun_control.py +++ b/videox_fun/pipeline/pipeline_cogvideox_fun_control.py @@ -659,8 +659,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -939,11 +939,9 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": video = 
self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -951,6 +949,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return CogVideoXFunPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py b/videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py index 7044d9d0..c4e72f21 100644 --- a/videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py +++ b/videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py @@ -757,8 +757,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -1119,11 +1119,9 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -1131,6 +1129,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return CogVideoXFunPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_fantasy_talking.py b/videox_fun/pipeline/pipeline_fantasy_talking.py index 8243033a..ccb411a1 100644 --- a/videox_fun/pipeline/pipeline_fantasy_talking.py +++ b/videox_fun/pipeline/pipeline_fantasy_talking.py @@ -160,6 +160,7 @@ class 
FantasyTalkingPipeline(DiffusionPipeline): """ _optional_components = ["audio_encoder"] + _exclude_from_cpu_offload = ["audio_encoder"] model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae" _callback_tensor_inputs = [ @@ -496,8 +497,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -743,11 +744,9 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -755,6 +754,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_hunyuanvideo.py b/videox_fun/pipeline/pipeline_hunyuanvideo.py index 9afe5c79..efd19790 100644 --- a/videox_fun/pipeline/pipeline_hunyuanvideo.py +++ b/videox_fun/pipeline/pipeline_hunyuanvideo.py @@ -539,8 +539,8 @@ def __call__( negative_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] @@ -788,11 +788,9 @@ def __call__( 
self._current_timestep = None - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -800,6 +798,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return HunyuanVideoPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py b/videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py index e2628ef4..151e3d25 100644 --- a/videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py +++ b/videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py @@ -678,8 +678,8 @@ def __call__( negative_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] @@ -955,11 +955,9 @@ def __call__( self._current_timestep = None - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -967,6 +965,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return HunyuanVideoPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_longcatvideo.py b/videox_fun/pipeline/pipeline_longcatvideo.py index 88ad355c..dea78a32 100644 --- 
a/videox_fun/pipeline/pipeline_longcatvideo.py +++ b/videox_fun/pipeline/pipeline_longcatvideo.py @@ -474,8 +474,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -697,13 +697,10 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": latents = self.denormalize_latents(latents) video = self.decode_latents(latents) - elif not output_type == "latent": - latents = self.denormalize_latents(latents) - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -711,6 +708,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return LongCatVideoPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_longcatvideo_avatar.py b/videox_fun/pipeline/pipeline_longcatvideo_avatar.py index dc0d2dc7..67a478a4 100644 --- a/videox_fun/pipeline/pipeline_longcatvideo_avatar.py +++ b/videox_fun/pipeline/pipeline_longcatvideo_avatar.py @@ -124,7 +124,8 @@ class LongCatVideoAvatarPipeline(DiffusionPipeline): """ - _optional_components = [] + _exclude_from_cpu_offload = ["audio_encoder"] + _optional_components = ["audio_encoder"] model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = [ @@ -464,8 +465,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str 
= "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -685,13 +686,10 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": latents = self.denormalize_latents(latents) video = self.decode_latents(latents) - elif not output_type == "latent": - latents = self.denormalize_latents(latents) - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -699,6 +697,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return LongCatVideoAvatarPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_ltx2.py b/videox_fun/pipeline/pipeline_ltx2.py new file mode 100644 index 00000000..8b6772cb --- /dev/null +++ b/videox_fun/pipeline/pipeline_ltx2.py @@ -0,0 +1,1259 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import torch +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import PipelineImageInput +from diffusers.loaders import FromSingleFileMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from transformers import (Gemma3ForConditionalGeneration, GemmaTokenizer, + GemmaTokenizerFast) + +from ..models import (AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, + LTX2TextConnectors, LTX2VideoTransformer3DModel, + LTX2Vocoder) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2Pipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + + >>> pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> frame_rate = 24.0 + >>> video, audio = pipe( + ... 
prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="np", + ... return_dict=False, + ... ) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. 
Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. 
Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +@dataclass +class LTX2PipelineOutput(BaseOutput): + r""" + Output class for LTX pipelines. 
+ + Args: + videos (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + audio (`torch.Tensor`, `np.ndarray`): + TODO + """ + + videos: torch.Tensor + audio: torch.Tensor + + +class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin): + r""" + Pipeline for text-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + connectors ([`LTX2TextConnectors`]): + Text connector stack used to adapt text encoder hidden states for the video and audio branches. 
+ """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + # TODO: check whether the MEL compression ratio logic here is correct + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor =
VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + @staticmethod + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: str | torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. 
+ """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return 
normalized_hidden_states + + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + device: (`str` or `torch.device`): + torch device to place the resulting embeddings on + dtype: (`torch.dtype`): + torch dtype to cast the prompt embeds to + max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + # Gemma expects left padding for chat-style prompts + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + 
scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). 
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. 
+ batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + def _create_noised_state( + latents: torch.Tensor, noise_scale: float | torch.Tensor, generator: torch.Generator | None = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + + @staticmethod + def _pack_audio_latents( + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None + ) -> torch.Tensor: + # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins + if patch_size is not None and patch_size_t is not None: + # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (a 
ndim=3 tensor). + # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size. + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length // patch_size_t + post_patch_mel_bins = latent_mel_bins // patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel) + # patch_size of M (all mel bins constitutes a single patch) and a patch_size_t of 1. + latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M] + return latents + + @staticmethod + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: int | None = None, + patch_size_t: int | None = None, + ) -> torch.Tensor: + # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], + # where L is the latent audio length and M is the number of mel bins. + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1.
+ latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + @staticmethod + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + + @staticmethod + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + # latents are of shape [B, C, F, H, W], need to be packed + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." 
+ ) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + return latents + + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + audio_latent_length: int = 1, # 1 is just a dummy value + num_mel_bins: int = 64, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 4: + # latents are of shape [B, C, L, M], need to be packed + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." 
+ ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + # TODO: confirm whether this logic is correct + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 40, + sigmas: list[float] | None = None, + timesteps: list[int] | None = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + noise_scale: float = 0.0, + num_videos_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None 
= None, + audio_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 512 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 768 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate. + frame_rate (`float`, *optional*, defaults to `24.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. 
+ timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `0.0`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTX2PipelineOutput`] instead of a plain tuple. 
+ attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. 
Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." + ) + video_sequence_length = latent_num_frames * latent_height * latent_width + + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + noise_scale, + torch.float32, + device, + generator, + latents, + ) + + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." 
+ ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." + ) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + # Duplicate the positional ids as well if using CFG + if self.do_classifier_free_guidance: + video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim + audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) + + # 7. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if self.guidance_rescale > 0: + # Based on 3.4. 
in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0] + # NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = 
randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type).cpu().float().permute(0, 2, 1, 3, 4) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms).cpu().float() + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(videos=video, audio=audio) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_ltx2_i2v.py b/videox_fun/pipeline/pipeline_ltx2_i2v.py new file mode 100644 index 00000000..d631d01c --- /dev/null +++ b/videox_fun/pipeline/pipeline_ltx2_i2v.py @@ -0,0 +1,1299 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +from dataclasses import dataclass +from typing import Any, Callable + +import numpy as np +import torch +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import PipelineImageInput +from diffusers.loaders import FromSingleFileMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from transformers import (Gemma3ForConditionalGeneration, GemmaTokenizer, + GemmaTokenizerFast) + +from ..models import (AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, + LTX2TextConnectors, LTX2VideoTransformer3DModel, + LTX2Vocoder) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode 
== "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
+ + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). 
+ + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +@dataclass +class LTX2PipelineOutput(BaseOutput): + r""" + Output class for LTX pipelines. + + Args: + videos (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + audio (`torch.Tensor`, `np.ndarray`): + TODO + """ + + videos: torch.Tensor + audio: torch.Tensor + + +class LTX2I2VPipeline(DiffusionPipeline, FromSingleFileMixin): + r""" + Pipeline for image-to-video generation. 
+ + Reference: https://github.com/Lightricks/LTX-Video + + TODO + """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + # TODO: check whether the MEL compression ratio logic here is correct + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + # NOTE(review): the sibling lookups above all pass a None default to getattr; the next line + # originally called getattr(self, "transformer") without one, which would raise AttributeError + # instead of falling back to 1 when the transformer is absent — confirm the intended fallback. + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not 
            None else 160
        )

        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio, resample="bilinear")
        # Fall back to 1024 when no tokenizer is loaded (e.g. when prompt embeddings are passed in directly).
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
        )

    @staticmethod
    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
    def _pack_text_embeds(
        text_hidden_states: torch.Tensor,
        sequence_lengths: torch.Tensor,
        device: str | torch.device,
        padding_side: str = "left",
        scale_factor: int = 8,
        eps: float = 1e-6,
    ) -> torch.Tensor:
        """
        Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
        per-layer in a masked fashion (only over non-padded positions).

        Args:
            text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
                Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
            sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):
                The number of valid (non-padded) tokens for each batch instance.
            device: (`str` or `torch.device`, *optional*):
                torch device to place the resulting embeddings on
            padding_side: (`str`, *optional*, defaults to `"left"`):
                Whether the text tokenizer performs padding on the `"left"` or `"right"`.
            scale_factor (`int`, *optional*, defaults to `8`):
                Scaling factor to multiply the normalized hidden states by.
            eps (`float`, *optional*, defaults to `1e-6`):
                A small positive value for numerical stability when performing normalization.

        Returns:
            `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
                Normed and flattened text encoder hidden states.
        """
        batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
        original_dtype = text_hidden_states.dtype

        # Create padding mask
        token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
        if padding_side == "right":
            # For right padding, valid tokens are from 0 to sequence_length-1
            mask = token_indices < sequence_lengths[:, None]  # [batch_size, seq_len]
        elif padding_side == "left":
            # For left padding, valid tokens are from (T - sequence_length) to T-1
            start_indices = seq_len - sequence_lengths[:, None]  # [batch_size, 1]
            mask = token_indices >= start_indices  # [B, T]
        else:
            raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
        mask = mask[:, :, None, None]  # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]

        # Compute masked mean over non-padding positions; result has shape (batch_size, 1, 1, num_layers)
        # (the sum runs over seq_len and hidden_dim, hence sequence_lengths * hidden_dim valid positions per layer)
        masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
        num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
        masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)

        # Compute min/max over non-padding positions; results have shape (batch_size, 1, 1, num_layers)
        x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
        x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)

        # Normalization: center on the masked mean and scale by the masked min-max range.
        normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
        normalized_hidden_states = normalized_hidden_states * scale_factor

        # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
        normalized_hidden_states = normalized_hidden_states.flatten(2)
        mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
        # Zero out padded positions again so padding carries no signal after normalization.
        normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
        normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
        return normalized_hidden_states

    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
    def _get_gemma_prompt_embeds(
        self,
        prompt: str | list[str],
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 1024,
        scale_factor: int = 8,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            device: (`str` or `torch.device`):
                torch device to place the resulting embeddings on
            dtype: (`torch.dtype`):
                torch dtype to cast the prompt embeds to
            max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if getattr(self, "tokenizer", None) is not None:
            # Gemma expects left padding for chat-style prompts
            self.tokenizer.padding_side = "left"
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

        prompt = [p.strip() for p in prompt]
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask
        text_input_ids = text_input_ids.to(device)
        prompt_attention_mask = prompt_attention_mask.to(device)

        text_encoder_outputs = self.text_encoder(
            input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
        )
        text_encoder_hidden_states = text_encoder_outputs.hidden_states
        # Stack all per-layer hidden states into a single [B, seq_len, hidden_dim, num_layers] tensor.
        text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
        sequence_lengths = prompt_attention_mask.sum(dim=-1)

        prompt_embeds = self._pack_text_embeds(
            text_encoder_hidden_states,
            sequence_lengths,
            device=device,
            padding_side=self.tokenizer.padding_side,
            scale_factor=scale_factor,
        )
        prompt_embeds = prompt_embeds.to(dtype=dtype)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)

        return prompt_embeds, prompt_attention_mask

    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: str | list[str],
        negative_prompt: str | list[str] | None = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: torch.Tensor | None = None,
        negative_prompt_embeds: torch.Tensor | None = None,
        prompt_attention_mask: torch.Tensor | None = None,
        negative_prompt_attention_mask: torch.Tensor | None = None,
        max_sequence_length: int = 1024,
        scale_factor: int = 8,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt.
                torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                scale_factor=scale_factor,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # An empty negative prompt is the conventional "unconditional" input for CFG.
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                scale_factor=scale_factor,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs
    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_on_step_end_tensor_inputs=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
    ):
        # Validates mutually-exclusive / co-required argument combinations before running the pipeline.
        if height % 32 != 0 or width % 32 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            # Positive and negative embeddings are concatenated for CFG, so their shapes must agree.
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
                raise ValueError(
                    "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
                    f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
                    f" {negative_prompt_attention_mask.shape}."
                )

    @staticmethod
    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
        # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
        # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
        batch_size, num_channels, num_frames, height, width = latents.shape
        post_patch_num_frames = num_frames // patch_size_t
        post_patch_height = height // patch_size
        post_patch_width = width // patch_size
        latents = latents.reshape(
            batch_size,
            -1,
            post_patch_num_frames,
            patch_size_t,
            post_patch_height,
            patch_size,
            post_patch_width,
            patch_size,
        )
        latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
        return latents

    @staticmethod
    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
    def _unpack_latents(
        latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
    ) -> torch.Tensor:
        # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
        # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
        # what happens in the `_pack_latents` method.
+ batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._create_noised_state + def _create_noised_state( + latents: torch.Tensor, noise_scale: float | torch.Tensor, generator: torch.Generator | None = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_audio_latents + def _pack_audio_latents( + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None + ) -> 
torch.Tensor: + # Audio latents shape: [B, C, L, M], where L is the latent audio length and M is the number of mel bins + if patch_size is not None and patch_size_t is not None: + # Packs the latents into a patch sequence of shape [B, L // p_t * M // p, C * p_t * p] (a ndim=3 tnesor). + # dim=1 is the effective audio sequence length and dim=2 is the effective audio input feature size. + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + # Packs the latents into a patch sequence of shape [B, L, C * M]. This implicitly assumes a (mel) + # patch_size of M (all mel bins constitutes a single patch) and a patch_size_t of 1. + latents = latents.transpose(1, 2).flatten(2, 3) # [B, C, L, M] --> [B, L, C * M] + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_audio_latents + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: int | None = None, + patch_size_t: int | None = None, + ) -> torch.Tensor: + # Unpacks an audio patch sequence of shape [B, S, D] into a latent spectrogram tensor of shape [B, C, L, M], + # where L is the latent audio length and M is the number of mel bins. + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + # Assume [B, S, D] = [B, L, C * M], which implies that patch_size = M and patch_size_t = 1. 
+ latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._normalize_audio_latents + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_audio_latents + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + def prepare_latents( + self, + image: torch.Tensor | None = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + mask_shape = (batch_size, 1, num_frames, height, width) + + if latents is not None: + if latents.ndim == 5: + # conditioning_mask needs to the same shape as latents in two stages generation. 
                batch_size, _, num_frames, height, width = latents.shape
                mask_shape = (batch_size, 1, num_frames, height, width)
                conditioning_mask = latents.new_zeros(mask_shape)
                # Only the first latent frame is conditioning (kept clean during denoising).
                conditioning_mask[:, :, 0] = 1.0

                latents = self._normalize_latents(
                    latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
                )
                # Re-noise only the non-conditioning frames (noise_scale is zeroed where the mask is 1).
                latents = self._create_noised_state(latents, noise_scale * (1 - conditioning_mask), generator)
                # latents are of shape [B, C, F, H, W], need to be packed
                latents = self._pack_latents(
                    latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
                )
            else:
                # Already-packed latents ([B, S, D]); build the matching packed conditioning mask.
                conditioning_mask = latents.new_zeros(mask_shape)
                conditioning_mask[:, :, 0] = 1.0
                conditioning_mask = self._pack_latents(
                    conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
                ).squeeze(-1)
                if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape:
                    raise ValueError(
                        f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}."
                    )
            return latents.to(device=device, dtype=dtype), conditioning_mask

        if isinstance(generator, list):
            if len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            init_latents = [
                retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i], "argmax")
                for i in range(batch_size)
            ]
        else:
            init_latents = [
                retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator, "argmax") for img in image
            ]

        init_latents = torch.cat(init_latents, dim=0).to(dtype)
        init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
        # Broadcast the single encoded image latent across all latent frames.
        init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)

        # First condition is image latents and those should be kept clean.
        conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
        conditioning_mask[:, :, 0] = 1.0

        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        # Interpolation.
        latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)

        conditioning_mask = self._pack_latents(
            conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
        ).squeeze(-1)
        latents = self._pack_latents(
            latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
        )

        return latents, conditioning_mask

    # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.prepare_audio_latents
    def prepare_audio_latents(
        self,
        batch_size: int = 1,
        num_channels_latents: int = 8,
        audio_latent_length: int = 1,  # 1 is just a dummy value
        num_mel_bins: int = 64,
        noise_scale: float = 0.0,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        generator: torch.Generator | None = None,
        latents: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if latents is not None:
            if latents.ndim == 4:
                # latents are of shape [B, C, L, M], need to be packed
                latents = self._pack_audio_latents(latents)
            if latents.ndim != 3:
                raise ValueError(
                    f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]."
                )
            latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std)
            latents = self._create_noised_state(latents, noise_scale, generator)
            return latents.to(device=device, dtype=dtype)

        # TODO: confirm whether this logic is correct
        latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio

        shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = self._pack_audio_latents(latents)
        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def do_classifier_free_guidance(self):
        # CFG is active whenever the guidance scale exceeds 1.
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def attention_kwargs(self):
        return self._attention_kwargs

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: PipelineImageInput = None,
        prompt: str | list[str] = None,
        negative_prompt: str | list[str] | None = None,
        height: int = 512,
        width: int = 768,
        num_frames: int = 121,
        frame_rate: float = 24.0,
        num_inference_steps: int = 40,
        sigmas: list[float] | None = None,
        timesteps: list[int] | None = None,
        guidance_scale: float = 4.0,
        guidance_rescale: float = 0.0,
        noise_scale: float = 0.0,
        num_videos_per_prompt: int = 1,
        generator: torch.Generator | list[torch.Generator] | None =
None, + latents: torch.Tensor | None = None, + audio_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate + frame_rate (`float`, *optional*, defaults to `24.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is passed + will be used. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + noise_scale (`float`, *optional*, defaults to `0.0`): + The interpolation factor between random noise and denoised latents at each timestep. Applying noise to + the `latents` and `audio_latents` before continue denoising. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. 
+ latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTX2PipelineOutput`] instead of a plain tuple. 
+ attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. 
Prepare latent variables + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W] + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." 
+ ) + video_sequence_length = latent_num_frames * latent_height * latent_width + + if latents is None: + image = self.video_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=prompt_embeds.dtype) + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + noise_scale, + torch.float32, + device, + generator, + latents, + ) + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape # [B, C, L, M] + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent dims" + f" cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." 
+ ) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + # For now, duplicate the scheduler for use with the audio latents + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps( + audio_scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Prepare micro-conditions + rope_interpolation_scale = ( + self.vae_temporal_compression_ratio / frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + # Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + # Duplicate the positional ids as well if using CFG + if self.do_classifier_free_guidance: + video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) # Repeat twice in batch dim + audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + + timestep = t.expand(latent_model_input.shape[0]) + video_timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, + timestep=video_timestep, + audio_timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, + num_frames=latent_num_frames, + 
height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + # rope_interpolation_scale=rope_interpolation_scale, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + + if self.guidance_rescale > 0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred_video = self._unpack_latents( + noise_pred_video, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + noise_pred_video = noise_pred_video[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + # NOTE: for now duplicate 
scheduler for audio latents in case self.scheduler sets internal state in + # the step method (such as _step_index) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = 
torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type).cpu().float().permute(0, 2, 1, 3, 4) + + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms).cpu().float() + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(videos=video, audio=audio) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_wan.py b/videox_fun/pipeline/pipeline_wan.py index f105c9a3..0dc024bf 100755 --- a/videox_fun/pipeline/pipeline_wan.py +++ b/videox_fun/pipeline/pipeline_wan.py @@ -399,8 +399,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -559,11 +559,9 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -571,6 +569,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return WanPipelineOutput(videos=video) 
diff --git a/videox_fun/pipeline/pipeline_wan2_2.py b/videox_fun/pipeline/pipeline_wan2_2.py index e96287a9..b317458c 100755 --- a/videox_fun/pipeline/pipeline_wan2_2.py +++ b/videox_fun/pipeline/pipeline_wan2_2.py @@ -401,8 +401,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -574,11 +574,9 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -586,6 +584,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2_animate.py b/videox_fun/pipeline/pipeline_wan2_2_animate.py index 6a0df263..456fcef0 100644 --- a/videox_fun/pipeline/pipeline_wan2_2_animate.py +++ b/videox_fun/pipeline/pipeline_wan2_2_animate.py @@ -384,8 +384,8 @@ def inputs_padding(self, x, target_len): else: raise ValueError(f"Unsupported input dimension: {ndim}. 
Expected 4D or 5D.") - def get_valid_len(self, real_len, clip_len=81, overlap=1): - real_clip_len = clip_len - overlap + def get_valid_len(self, real_len, segment_frame_length=81, overlap=1): + real_clip_len = segment_frame_length - overlap last_clip_num = (real_len - overlap) % real_clip_len if last_clip_num == 0: extra = 0 @@ -568,8 +568,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 480, width: int = 720, - clip_len=77, - num_frames: int = 49, + segment_frame_length = 77, num_inference_steps: int = 50, pose_video = None, face_video = None, @@ -585,8 +584,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -680,7 +679,7 @@ def __call__( face_video = None real_frame_len = pose_video.size()[2] - target_len = self.get_valid_len(real_frame_len, clip_len, overlap=refert_num) + target_len = self.get_valid_len(real_frame_len, segment_frame_length, overlap=refert_num) print('real frames: {} target frames: {}'.format(real_frame_len, target_len)) pose_video = self.inputs_padding(pose_video, target_len).to(device, weight_dtype) face_video = self.inputs_padding(face_video, target_len).to(device, weight_dtype) @@ -704,12 +703,12 @@ def __call__( # 5. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + target_shape = (self.vae.latent_channels, (segment_frame_length - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) # 6. Denoising loop start = 0 - end = clip_len + end = segment_frame_length all_out_frames = [] copy_timesteps = copy.deepcopy(timesteps) copy_latents = copy.deepcopy(latents) @@ -738,7 +737,7 @@ def __call__( latents = self.prepare_latents( batch_size * num_videos_per_prompt, latent_channels, - num_frames, + segment_frame_length, height, width, weight_dtype, @@ -794,7 +793,7 @@ def __call__( mask_pixel_values = F.interpolate(mask_pixel_values, size=(target_shape[-1], target_shape[-2]), mode='nearest') mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b c t h w", b = bs)[:, 0] msk_reft = self.get_i2v_mask( - int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, mask_pixel_values=mask_pixel_values, device=device + int((segment_frame_length - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, mask_pixel_values=mask_pixel_values, device=device ) else: refer_t_pixel_values = rearrange(refer_t_pixel_values[:, :, :mask_reft_len], "b c t h w -> (b t) c h w") @@ -805,12 +804,12 @@ def __call__( torch.concat( [ refer_t_pixel_values, - torch.zeros(bs, 3, clip_len - mask_reft_len, height, width).to(device=device, dtype=weight_dtype), + torch.zeros(bs, 3, segment_frame_length - 
mask_reft_len, height, width).to(device=device, dtype=weight_dtype), ], dim=2, ).to(device=device, dtype=weight_dtype) )[0].mode() msk_reft = self.get_i2v_mask( - int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, device=device + int((segment_frame_length - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, device=device ) else: if replace_flag: @@ -824,14 +823,14 @@ def __call__( mask_pixel_values = F.interpolate(mask_pixel_values, size=(target_shape[-1], target_shape[-2]), mode='nearest') mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b c t h w", b = bs)[:, 0] msk_reft = self.get_i2v_mask( - int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, mask_pixel_values=mask_pixel_values, device=device + int((segment_frame_length - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, mask_pixel_values=mask_pixel_values, device=device ) else: y_reft = self.vae.encode( - torch.zeros(1, 3, clip_len - mask_reft_len, height, width).to(device=device, dtype=weight_dtype) + torch.zeros(1, 3, segment_frame_length - mask_reft_len, height, width).to(device=device, dtype=weight_dtype) )[0].mode() msk_reft = self.get_i2v_mask( - int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, device=device + int((segment_frame_length - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, device=device ) y_reft = torch.concat([msk_reft, y_reft], dim=1).to(device=device, dtype=weight_dtype) @@ -918,12 +917,15 @@ def __call__( if start != 0: out_frames = out_frames[:, :, refert_num:] all_out_frames.append(out_frames.cpu()) - start += clip_len - refert_num - end += clip_len - refert_num + start += segment_frame_length - refert_num + end += segment_frame_length 
- refert_num - videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len] + videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len].float().cpu() # Offload all models self.maybe_free_model_hooks() - return WanPipelineOutput(videos=videos.float().cpu()) + if not return_dict: + return videos + + return WanPipelineOutput(videos=videos) diff --git a/videox_fun/pipeline/pipeline_wan2_2_fun_control.py b/videox_fun/pipeline/pipeline_wan2_2_fun_control.py index 5b923191..66423315 100644 --- a/videox_fun/pipeline/pipeline_wan2_2_fun_control.py +++ b/videox_fun/pipeline/pipeline_wan2_2_fun_control.py @@ -524,8 +524,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -886,11 +886,9 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -898,6 +896,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py b/videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py index 8eaf68e6..ca35821e 100644 --- a/videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py +++ b/videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py @@ -487,8 +487,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -735,11 +735,9 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -747,6 +745,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2_s2v.py b/videox_fun/pipeline/pipeline_wan2_2_s2v.py index 6fbcbaf9..42e2b84c 100644 --- a/videox_fun/pipeline/pipeline_wan2_2_s2v.py +++ b/videox_fun/pipeline/pipeline_wan2_2_s2v.py @@ -158,6 +158,7 @@ class Wan2_2S2VPipeline(DiffusionPipeline): library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
""" + _exclude_from_cpu_offload = ["audio_encoder"] _optional_components = ["transformer_2", "audio_encoder"] model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" @@ -317,11 +318,11 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def encode_audio_embeddings(self, audio_path, num_frames, fps, weight_dtype, device): + def encode_audio_embeddings(self, audio_path, segment_frame_length, fps, weight_dtype, device): z = self.audio_encoder.extract_audio_feat( audio_path, return_all_layers=True) audio_embed_bucket, num_repeat = self.audio_encoder.get_audio_embed_bucket_fps( - z, fps=fps, batch_frames=num_frames, m=self.audio_sample_m) + z, fps=fps, batch_frames=segment_frame_length, m=self.audio_sample_m) audio_embed_bucket = audio_embed_bucket.to(device, weight_dtype) audio_embed_bucket = audio_embed_bucket.unsqueeze(0) if len(audio_embed_bucket.shape) == 3: @@ -330,10 +331,10 @@ def encode_audio_embeddings(self, audio_path, num_frames, fps, weight_dtype, dev audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) return audio_embed_bucket, num_repeat - def encode_pose_latents(self, pose_video, num_repeat, num_frames, size, fps, weight_dtype, device): + def encode_pose_latents(self, pose_video, num_repeat, segment_frame_length, size, fps, weight_dtype, device): height, width = size if not pose_video is None: - padding_frame_num = num_repeat * num_frames - pose_video.shape[2] + padding_frame_num = num_repeat * segment_frame_length - pose_video.shape[2] pose_video = torch.cat( [ pose_video, @@ -344,7 +345,7 @@ def encode_pose_latents(self, pose_video, num_repeat, num_frames, size, fps, wei cond_tensors = torch.chunk(pose_video, num_repeat, dim=2) else: - cond_tensors = [-torch.ones([1, 3, num_frames, height, width])] + cond_tensors = [-torch.ones([1, 3, segment_frame_length, height, width])] pose_latents = [] for r in range(len(cond_tensors)): @@ -519,7 +520,7 @@ def __call__( ref_image: Union[torch.FloatTensor] = None, audio_path 
= None, pose_video = None, - num_frames: int = 49, + segment_frame_length: int = 49, num_inference_steps: int = 50, timesteps: Optional[List[int]] = None, guidance_scale: float = 6, @@ -529,8 +530,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -588,8 +589,8 @@ def __call__( # lat_motion_frames = 76 / 4 = 19 lat_motion_frames = (self.motion_frames + 3) // 4 - # lat_motion_frames ~= num_frames // 4 - lat_target_frames = (num_frames + 3 + self.motion_frames) // 4 - lat_motion_frames + # lat_motion_frames ~= segment_frame_length // 4 + lat_target_frames = (segment_frame_length + 3 + self.motion_frames) // 4 - lat_motion_frames # 3. 
Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( @@ -636,7 +637,7 @@ def __call__( # Extract audio emb audio_emb, num_repeat = self.encode_audio_embeddings( - audio_path, num_frames=num_frames, fps=fps, weight_dtype=weight_dtype, device=device + audio_path, segment_frame_length=segment_frame_length, fps=fps, weight_dtype=weight_dtype, device=device ) # Encode the motion latents @@ -661,7 +662,7 @@ def __call__( pose_latents = self.encode_pose_latents( pose_video=pose_video, num_repeat=num_repeat, - num_frames=num_frames, + segment_frame_length=segment_frame_length, size=(height, width), fps=fps, weight_dtype=weight_dtype, @@ -701,7 +702,7 @@ def __call__( latents = self.prepare_latents( batch_size * num_videos_per_prompt, latent_channels, - num_frames, + segment_frame_length, height, width, weight_dtype, @@ -725,8 +726,8 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) with torch.no_grad(): - left_idx = r * num_frames - right_idx = r * num_frames + num_frames + left_idx = r * segment_frame_length + right_idx = r * segment_frame_length + segment_frame_length cond_latents = pose_latents[r] if pose_video is not None else pose_latents[0] * 0 cond_latents = cond_latents.to(dtype=weight_dtype, device=device) audio_input = audio_emb[..., left_idx:right_idx] @@ -793,7 +794,7 @@ def __call__( decode_latents = torch.cat([ref_image_latentes, latents], dim=2) image = self.vae.decode(decode_latents).sample - image = image[:, :, -(num_frames):] + image = image[:, :, -(segment_frame_length):] if (drop_first_motion and r == 0): image = image[:, :, 3:] @@ -809,9 +810,12 @@ def __call__( videos.append(image) videos = torch.cat(videos, dim=2) - videos = (videos / 2 + 0.5).clamp(0, 1) + videos = (videos / 2 + 0.5).clamp(0, 1).float().cpu() # Offload all models self.maybe_free_model_hooks() - return WanPipelineOutput(videos=videos.float().cpu()) + if not return_dict: + return videos + + return
WanPipelineOutput(videos=videos) diff --git a/videox_fun/pipeline/pipeline_wan2_2_ti2v.py b/videox_fun/pipeline/pipeline_wan2_2_ti2v.py index 12b3bd8c..bd5af93f 100644 --- a/videox_fun/pipeline/pipeline_wan2_2_ti2v.py +++ b/videox_fun/pipeline/pipeline_wan2_2_ti2v.py @@ -486,8 +486,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -715,11 +715,9 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -727,6 +725,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2_vace_fun.py b/videox_fun/pipeline/pipeline_wan2_2_vace_fun.py index c8b0d993..545a6536 100644 --- a/videox_fun/pipeline/pipeline_wan2_2_vace_fun.py +++ b/videox_fun/pipeline/pipeline_wan2_2_vace_fun.py @@ -555,8 +555,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -784,11 +784,9 @@ def __call__( len_subject_ref_images = 
len(subject_ref_images[0]) latents = latents[:, :, len_subject_ref_images:, :, :] - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -796,6 +794,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan_fun_control.py b/videox_fun/pipeline/pipeline_wan_fun_control.py index 4d4ec75e..1f300198 100755 --- a/videox_fun/pipeline/pipeline_wan_fun_control.py +++ b/videox_fun/pipeline/pipeline_wan_fun_control.py @@ -487,8 +487,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -782,11 +782,9 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -794,6 +792,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan_fun_inpaint.py b/videox_fun/pipeline/pipeline_wan_fun_inpaint.py index 35f3b96d..aa4c5412 100755 --- a/videox_fun/pipeline/pipeline_wan_fun_inpaint.py 
+++ b/videox_fun/pipeline/pipeline_wan_fun_inpaint.py @@ -486,8 +486,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -717,11 +717,9 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -729,6 +727,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan_phantom.py b/videox_fun/pipeline/pipeline_wan_phantom.py index fd993b00..c1494000 100644 --- a/videox_fun/pipeline/pipeline_wan_phantom.py +++ b/videox_fun/pipeline/pipeline_wan_phantom.py @@ -483,8 +483,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -678,11 +678,9 @@ def __call__( if comfyui_progressbar: pbar.update(1) - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = 
self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -690,6 +688,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan_vace.py b/videox_fun/pipeline/pipeline_wan_vace.py index 7c0ded27..4a2b08d4 100644 --- a/videox_fun/pipeline/pipeline_wan_vace.py +++ b/videox_fun/pipeline/pipeline_wan_vace.py @@ -553,8 +553,8 @@ def __call__( latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: str = "numpy", - return_dict: bool = False, + output_type: str = "pil", + return_dict: bool = True, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, @@ -770,11 +770,9 @@ def __call__( len_subject_ref_images = len(subject_ref_images[0]) latents = latents[:, :, len_subject_ref_images:, :, :] - if output_type == "numpy": + if output_type == "pil": video = self.decode_latents(latents) - elif not output_type == "latent": - video = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video=video, output_type=output_type) + video = torch.from_numpy(video) else: video = latents @@ -782,6 +780,6 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: - video = torch.from_numpy(video) + return video return WanPipelineOutput(videos=video) From 2b3b25794800aca4ac18261b1c3379f1bbbad9a0 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Thu, 12 Mar 2026 17:03:13 +0800 Subject: [PATCH 04/10] Update save_results --- examples/wan2.2/predict_s2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/wan2.2/predict_s2v.py b/examples/wan2.2/predict_s2v.py index 97e80a5f..f38296a2 100644 --- 
a/examples/wan2.2/predict_s2v.py +++ b/examples/wan2.2/predict_s2v.py @@ -355,7 +355,7 @@ def save_results(): index = len([path for path in os.listdir(save_path)]) + 1 prefix = str(index).zfill(8) - if video_length == 1: + if sample.size()[2] == 1: video_path = os.path.join(save_path, prefix + ".png") image = sample[0, :, 0] From f209cb332ecaf3e6113bb09923d32a50283cf770 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Thu, 12 Mar 2026 17:06:37 +0800 Subject: [PATCH 05/10] Update init --- videox_fun/pipeline/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/videox_fun/pipeline/__init__.py b/videox_fun/pipeline/__init__.py index b29c87f1..8c3edd57 100755 --- a/videox_fun/pipeline/__init__.py +++ b/videox_fun/pipeline/__init__.py @@ -9,7 +9,7 @@ from .pipeline_hunyuanvideo_i2v import HunyuanVideoI2VPipeline from .pipeline_longcatvideo import LongCatVideoPipeline from .pipeline_longcatvideo_avatar import LongCatVideoAvatarPipeline -from .pipeline_ltx2_i2v import LTX2ImageToVideoPipeline +from .pipeline_ltx2_i2v import LTX2I2VPipeline from .pipeline_ltx2 import LTX2Pipeline from .pipeline_qwenimage import QwenImagePipeline from .pipeline_qwenimage_control import QwenImageControlPipeline From 8f0768fcb9db46a09c873d0cb4a34c6249f803fd Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Thu, 12 Mar 2026 22:08:31 +0800 Subject: [PATCH 06/10] Update Wan s2v codes --- scripts/wan2.2/train_s2v.py | 3 +-- scripts/wan2.2/train_s2v_lora.py | 3 +-- scripts/wan2.2/train_s2v_lora.sh | 2 +- videox_fun/pipeline/pipeline_wan2_2_animate.py | 4 ++-- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/scripts/wan2.2/train_s2v.py b/scripts/wan2.2/train_s2v.py index a4fcb0a5..8781a81c 100644 --- a/scripts/wan2.2/train_s2v.py +++ b/scripts/wan2.2/train_s2v.py @@ -190,9 +190,8 @@ def log_validation(vae, text_encoder, tokenizer, audio_encoder, transformer3d, a start_image = Image.open(args.validation_image_paths[i]) 
width, height = start_image.width, start_image.height width, height = calculate_dimensions(args.video_sample_size * args.video_sample_size, width / height) - video_length = int((args.video_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if args.video_sample_n_frames != 1 else 1 - pose_video, _, _, _ = get_video_to_video_latent(None, video_length=video_length, sample_size=(height, width), ref_image=None) + pose_video, _, _, _ = get_video_to_video_latent(None, video_length=None, sample_size=(height, width), ref_image=None) ref_image = get_image_latent(args.validation_image_paths[i], sample_size=(height, width)) sample = pipeline( diff --git a/scripts/wan2.2/train_s2v_lora.py b/scripts/wan2.2/train_s2v_lora.py index 4f930c5e..a1056d01 100644 --- a/scripts/wan2.2/train_s2v_lora.py +++ b/scripts/wan2.2/train_s2v_lora.py @@ -200,9 +200,8 @@ def log_validation(vae, text_encoder, tokenizer, audio_encoder, transformer3d, n start_image = Image.open(args.validation_image_paths[i]) width, height = start_image.width, start_image.height width, height = calculate_dimensions(args.video_sample_size * args.video_sample_size, width / height) - video_length = int((args.video_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if args.video_sample_n_frames != 1 else 1 - pose_video, _, _, _ = get_video_to_video_latent(None, video_length=video_length, sample_size=(height, width), ref_image=None) + pose_video, _, _, _ = get_video_to_video_latent(None, video_length=None, sample_size=(height, width), ref_image=None) ref_image = get_image_latent(args.validation_image_paths[i], sample_size=(height, width)) sample = pipeline( diff --git a/scripts/wan2.2/train_s2v_lora.sh b/scripts/wan2.2/train_s2v_lora.sh index e9a07b83..79d20a25 100644 --- a/scripts/wan2.2/train_s2v_lora.sh +++ b/scripts/wan2.2/train_s2v_lora.sh @@ -14,7 +14,7 @@ accelerate launch --mixed_precision="bf16" 
scripts/wan2.2/train_s2v_lora.py \ --video_sample_size=640 \ --token_sample_size=640 \ --video_sample_stride=2 \ - --video_sample_n_frames=81 \ + --video_sample_n_frames=80 \ --train_batch_size=1 \ --video_repeat=1 \ --gradient_accumulation_steps=1 \ diff --git a/videox_fun/pipeline/pipeline_wan2_2_animate.py b/videox_fun/pipeline/pipeline_wan2_2_animate.py index 456fcef0..cb81fcd8 100644 --- a/videox_fun/pipeline/pipeline_wan2_2_animate.py +++ b/videox_fun/pipeline/pipeline_wan2_2_animate.py @@ -703,7 +703,7 @@ def __call__( # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - target_shape = (self.vae.latent_channels, (segment_frame_length - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + target_shape = (self.vae.latent_channels, (segment_frame_length + 4 - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) # 6. 
Denoising loop @@ -737,7 +737,7 @@ def __call__( latents = self.prepare_latents( batch_size * num_videos_per_prompt, latent_channels, - segment_frame_length, + segment_frame_length + 4, height, width, weight_dtype, From e5b7590357ee7a997aeecee39fd18e2ecb10f003 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Thu, 12 Mar 2026 22:11:11 +0800 Subject: [PATCH 07/10] Update sh --- scripts/wan2.2/train_s2v_lora.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/wan2.2/train_s2v_lora.sh b/scripts/wan2.2/train_s2v_lora.sh index 79d20a25..b3ef60d4 100644 --- a/scripts/wan2.2/train_s2v_lora.sh +++ b/scripts/wan2.2/train_s2v_lora.sh @@ -16,7 +16,6 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.2/train_s2v_lora.py \ --video_sample_stride=2 \ --video_sample_n_frames=80 \ --train_batch_size=1 \ - --video_repeat=1 \ --gradient_accumulation_steps=1 \ --dataloader_num_workers=8 \ --num_train_epochs=100 \ From 2bd2d2e388a830a0dab01dd0b2e2c0cb1e565d7b Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Thu, 12 Mar 2026 22:48:58 +0800 Subject: [PATCH 08/10] Add gradient checkpointing in s2v --- videox_fun/models/wan_transformer3d_s2v.py | 48 +++++++++++++++++----- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/videox_fun/models/wan_transformer3d_s2v.py b/videox_fun/models/wan_transformer3d_s2v.py index e55f86a3..84e1e7f9 100644 --- a/videox_fun/models/wan_transformer3d_s2v.py +++ b/videox_fun/models/wan_transformer3d_s2v.py @@ -434,7 +434,19 @@ def process_motion_transformer_motioner(self, dtype=motion_latents[0].dtype) gride_sizes = [] - zip_motion = self.motioner(motion_latents) + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + zip_motion = torch.utils.checkpoint.checkpoint( + 
create_custom_forward(self.motioner), + motion_latents, + **ckpt_kwargs + ) + else: + zip_motion = self.motioner(motion_latents) zip_motion = self.zip_motion_out(zip_motion) if drop_motion_frames: zip_motion = zip_motion * 0.0 @@ -629,7 +641,21 @@ def custom_forward(*inputs): self.merged_audio_emb = audio_emb[:, motion_frames_1:, :] # Cond states - cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states] + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + cond = [ + torch.utils.checkpoint.checkpoint( + create_custom_forward(self.cond_encoder), + c.unsqueeze(0), + **ckpt_kwargs + ) for c in cond_states + ] + else: + cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states] x = [x_ + pose for x_, pose in zip(x, cond)] grid_sizes = torch.stack( @@ -790,14 +816,16 @@ def custom_forward(*inputs): for idx, block in enumerate(self.blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - def create_custom_forward(module): + def create_custom_forward_with_audio(module, block_idx): def custom_forward(*inputs): - return module(*inputs) + x = module(*inputs) + x = self.after_transformer_block(block_idx, x) + return x return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + create_custom_forward_with_audio(block, idx), x, e0, seq_lens, @@ -809,7 +837,6 @@ def custom_forward(*inputs): t, **ckpt_kwargs, ) - x = self.after_transformer_block(idx, x) else: # arguments kwargs = dict( @@ -833,14 +860,16 @@ def custom_forward(*inputs): for idx, block in enumerate(self.blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - def create_custom_forward(module): + def 
create_custom_forward_with_audio(module, block_idx): def custom_forward(*inputs): - return module(*inputs) + x = module(*inputs) + x = self.after_transformer_block(block_idx, x) + return x return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + create_custom_forward_with_audio(block, idx), x, e0, seq_lens, @@ -852,7 +881,6 @@ def custom_forward(*inputs): t, **ckpt_kwargs, ) - x = self.after_transformer_block(idx, x) else: # arguments kwargs = dict( From 8fb48cb481c25f11be321540e358582f88aa65e9 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Fri, 13 Mar 2026 14:57:20 +0800 Subject: [PATCH 09/10] Add gradient checkpointing in wan animate --- .../models/wan_transformer3d_animate.py | 69 +++++++++++++++---- 1 file changed, 55 insertions(+), 14 deletions(-) diff --git a/videox_fun/models/wan_transformer3d_animate.py b/videox_fun/models/wan_transformer3d_animate.py index 227a3bd5..fb621d51 100644 --- a/videox_fun/models/wan_transformer3d_animate.py +++ b/videox_fun/models/wan_transformer3d_animate.py @@ -2,7 +2,7 @@ import math import types from copy import deepcopy -from typing import List +from typing import Any, Dict, List import numpy as np import torch @@ -86,12 +86,38 @@ def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_ encode_bs = 8 face_pixel_values_tmp = [] for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)): - face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + face_pixel_values_tmp.append( + torch.utils.checkpoint.checkpoint( + 
create_custom_forward(self.motion_encoder.get_motion), + face_pixel_values[i*encode_bs:(i+1)*encode_bs], + **ckpt_kwargs + ) + ) + else: + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) motion_vec = torch.cat(face_pixel_values_tmp) motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) - motion_vec = self.face_encoder(motion_vec) + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + motion_vec = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.face_encoder), + motion_vec, + **ckpt_kwargs + ) + else: + motion_vec = self.face_encoder(motion_vec) B, L, H, C = motion_vec.shape pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) @@ -207,14 +233,17 @@ def forward( for idx, block in enumerate(self.blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - def create_custom_forward(module): + def create_custom_forward_with_adapter(module, block_idx, motion_vec_ref): def custom_forward(*inputs): - return module(*inputs) + x = module(*inputs) + x = x.to(inputs[0].dtype) + x = self.after_transformer_block(block_idx, x, motion_vec_ref.to(inputs[0].dtype)) + return x return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + create_custom_forward_with_adapter(block, idx, motion_vec), x, e0, seq_lens, @@ -226,8 +255,6 @@ def custom_forward(*inputs): t, **ckpt_kwargs, ) - x, motion_vec = x.to(dtype), motion_vec.to(dtype) - x = self.after_transformer_block(idx, x, motion_vec) else: # arguments kwargs = dict( @@ -252,14 +279,17 @@ def custom_forward(*inputs): for idx, block in enumerate(self.blocks): if torch.is_grad_enabled() and 
self.gradient_checkpointing: - def create_custom_forward(module): + def create_custom_forward_with_adapter(module, block_idx, motion_vec_ref): def custom_forward(*inputs): - return module(*inputs) + x = module(*inputs) + x = x.to(inputs[0].dtype) + x = self.after_transformer_block(block_idx, x, motion_vec_ref.to(inputs[0].dtype)) + return x return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} x = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + create_custom_forward_with_adapter(block, idx, motion_vec), x, e0, seq_lens, @@ -271,8 +301,6 @@ def custom_forward(*inputs): t, **ckpt_kwargs, ) - x, motion_vec = x.to(dtype), motion_vec.to(dtype) - x = self.after_transformer_block(idx, x, motion_vec) else: # arguments kwargs = dict( @@ -290,7 +318,20 @@ def custom_forward(*inputs): x = self.after_transformer_block(idx, x, motion_vec) # head - x = self.head(x, e) + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.head), + x, + e, + **ckpt_kwargs + ) + else: + x = self.head(x, e) # Context Parallel if self.sp_world_size > 1: From c7989a60455caaf6ba91726ef7180865ecaaea79 Mon Sep 17 00:00:00 2001 From: bubbliiiing <3323290568@qq.com> Date: Fri, 13 Mar 2026 16:24:06 +0800 Subject: [PATCH 10/10] update attention and s2v validation --- scripts/fantasytalking/train.py | 11 +++- scripts/longcatvideo/train_avatar.py | 11 +++- scripts/longcatvideo/train_avatar_lora.py | 11 +++- scripts/wan2.2/train_animate.sh | 2 +- scripts/wan2.2/train_animate_lora.sh | 4 +- scripts/wan2.2/train_s2v.py | 11 +++- scripts/wan2.2/train_s2v_lora.py | 11 +++- videox_fun/models/attention_utils.py | 63 ++++++++++++++++++++++- 
8 files changed, 110 insertions(+), 14 deletions(-) diff --git a/scripts/fantasytalking/train.py b/scripts/fantasytalking/train.py index df856352..9b8fc915 100644 --- a/scripts/fantasytalking/train.py +++ b/scripts/fantasytalking/train.py @@ -81,7 +81,7 @@ from videox_fun.utils.discrete_sampler import DiscreteSampling from videox_fun.utils.utils import (calculate_dimensions, get_image_to_video_latent, - save_videos_grid) + merge_video_audio, save_videos_grid) if is_wandb_available(): import wandb @@ -209,9 +209,16 @@ def log_validation(vae, text_encoder, tokenizer, clip_image_encoder, audio_encod sample, os.path.join( args.output_dir, - f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.gif" + f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" ) ) + merge_video_audio( + video_path=os.path.join( + args.output_dir, + f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" + ), + audio_path=args.validation_audio_paths[i] + ) del pipeline gc.collect() diff --git a/scripts/longcatvideo/train_avatar.py b/scripts/longcatvideo/train_avatar.py index 49a72194..98020560 100644 --- a/scripts/longcatvideo/train_avatar.py +++ b/scripts/longcatvideo/train_avatar.py @@ -86,7 +86,7 @@ from videox_fun.utils.discrete_sampler import DiscreteSampling from videox_fun.utils.utils import (calculate_dimensions, get_image_to_video_latent, - save_videos_grid) + merge_video_audio, save_videos_grid) if is_wandb_available(): import wandb @@ -230,9 +230,16 @@ def log_validation(vae, text_encoder, tokenizer, audio_encoder, transformer3d, a sample, os.path.join( args.output_dir, - f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.gif" + f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" ) ) + merge_video_audio( + video_path=os.path.join( + args.output_dir, + f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" + ), + audio_path=args.validation_audio_paths[i] + 
) del pipeline gc.collect() diff --git a/scripts/longcatvideo/train_avatar_lora.py b/scripts/longcatvideo/train_avatar_lora.py index e0fc214f..1e690bd6 100644 --- a/scripts/longcatvideo/train_avatar_lora.py +++ b/scripts/longcatvideo/train_avatar_lora.py @@ -89,7 +89,7 @@ unmerge_lora) from videox_fun.utils.utils import (calculate_dimensions, get_image_to_video_latent, - save_videos_grid) + merge_video_audio, save_videos_grid) if is_wandb_available(): import wandb @@ -233,9 +233,16 @@ def log_validation(vae, text_encoder, tokenizer, audio_encoder, transformer3d, n sample, os.path.join( args.output_dir, - f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.gif" + f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" ) ) + merge_video_audio( + video_path=os.path.join( + args.output_dir, + f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" + ), + audio_path=args.validation_audio_paths[i] + ) del pipeline gc.collect() diff --git a/scripts/wan2.2/train_animate.sh b/scripts/wan2.2/train_animate.sh index 5f3ff8cb..fdefd709 100644 --- a/scripts/wan2.2/train_animate.sh +++ b/scripts/wan2.2/train_animate.sh @@ -14,7 +14,7 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.2/train_animate.py \ --video_sample_size=640 \ --token_sample_size=640 \ --video_sample_stride=2 \ - --video_sample_n_frames=81 \ + --video_sample_n_frames=77 \ --train_batch_size=1 \ --video_repeat=1 \ --gradient_accumulation_steps=1 \ diff --git a/scripts/wan2.2/train_animate_lora.sh b/scripts/wan2.2/train_animate_lora.sh index 19cd3499..c8d23561 100644 --- a/scripts/wan2.2/train_animate_lora.sh +++ b/scripts/wan2.2/train_animate_lora.sh @@ -14,7 +14,7 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.2/train_animate_lora.py --video_sample_size=640 \ --token_sample_size=640 \ --video_sample_stride=2 \ - --video_sample_n_frames=81 \ + --video_sample_n_frames=77 \ --train_batch_size=1 \ --video_repeat=1 \ 
--gradient_accumulation_steps=1 \ @@ -23,7 +23,7 @@ accelerate launch --mixed_precision="bf16" scripts/wan2.2/train_animate_lora.py --checkpointing_steps=50 \ --learning_rate=1e-04 \ --seed=42 \ - --output_dir="output_dir_animate_lora" \ + --output_dir="output_dir_wan2.2_animate_lora" \ --gradient_checkpointing \ --mixed_precision="bf16" \ --adam_weight_decay=3e-2 \ diff --git a/scripts/wan2.2/train_s2v.py b/scripts/wan2.2/train_s2v.py index 8781a81c..e5abe7c6 100644 --- a/scripts/wan2.2/train_s2v.py +++ b/scripts/wan2.2/train_s2v.py @@ -82,7 +82,7 @@ from videox_fun.utils.utils import (calculate_dimensions, get_image_latent, get_image_to_video_latent, get_video_to_video_latent, - save_videos_grid) + merge_video_audio, save_videos_grid) if is_wandb_available(): import wandb @@ -216,9 +216,16 @@ def log_validation(vae, text_encoder, tokenizer, audio_encoder, transformer3d, a sample, os.path.join( args.output_dir, - f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.gif" + f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" ) ) + merge_video_audio( + video_path=os.path.join( + args.output_dir, + f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" + ), + audio_path=args.validation_audio_paths[i] + ) del pipeline gc.collect() diff --git a/scripts/wan2.2/train_s2v_lora.py b/scripts/wan2.2/train_s2v_lora.py index a1056d01..fe68c6be 100644 --- a/scripts/wan2.2/train_s2v_lora.py +++ b/scripts/wan2.2/train_s2v_lora.py @@ -92,7 +92,7 @@ from videox_fun.utils.utils import (calculate_dimensions, get_image_latent, get_image_to_video_latent, get_video_to_video_latent, - save_videos_grid) + merge_video_audio, save_videos_grid) if is_wandb_available(): import wandb @@ -226,9 +226,16 @@ def log_validation(vae, text_encoder, tokenizer, audio_encoder, transformer3d, n sample, os.path.join( args.output_dir, - f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.gif" + 
f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" ) ) + merge_video_audio( + video_path=os.path.join( + args.output_dir, + f"sample/sample-{global_step}-rank{accelerator.process_index}-image-{i}.mp4" + ), + audio_path=args.validation_audio_paths[i] + ) del pipeline gc.collect() diff --git a/videox_fun/models/attention_utils.py b/videox_fun/models/attention_utils.py index 6fc2d7ef..12769442 100644 --- a/videox_fun/models/attention_utils.py +++ b/videox_fun/models/attention_utils.py @@ -65,6 +65,56 @@ def convert_qkv_dtype(q, k, v): return q, k, v +def _convert_attn_mask_to_lens(attn_mask): + """ + Convert attention mask to sequence lengths for Flash Attention. + + Args: + attn_mask: Attention mask, can be: + - [B, L] with 1=valid, 0=padding + - [B, 1, L] or [B, 1, 1, L] attention bias with 0=valid, -inf/-10000=padding + - [B, H, Lq, Lk] full attention mask + + Returns: + k_lens: [B] tensor of valid sequence lengths, or None if not a simple padding mask + """ + if attn_mask is None: + return None + + # Squeeze to simplest form + while attn_mask.ndim > 2 and attn_mask.shape[1] == 1: + attn_mask = attn_mask.squeeze(1) + + # Only handle [B, L] case (simple padding mask) + if attn_mask.ndim != 2: + return None + + # Check if it's attention bias format (0 and -inf/-10000) or binary mask (0/1) + unique_vals = torch.unique(attn_mask) + if len(unique_vals) > 2: + return None # Complex mask, can't convert + + # Determine which value means "valid" + max_val = unique_vals.max().item() + + if max_val <= 0: # Attention bias format: 0=valid, negative=padding + valid_mask = (attn_mask >= -1.0) # 0 is valid + else: # Binary format: 1=valid, 0=padding + valid_mask = (attn_mask > 0.5) + + # Check if it's a simple left-padded or right-padded mask + # For right-padding: [1,1,1,0,0] -> valid tokens are contiguous from start + k_lens = valid_mask.sum(dim=-1).to(torch.int32) + + # Verify it's actually a contiguous padding mask by reconstruction + B, L = 
valid_mask.shape + reconstructed = torch.arange(L, device=valid_mask.device).unsqueeze(0) < k_lens.unsqueeze(1) + if not torch.all(reconstructed == valid_mask): + return None # Not a simple contiguous padding mask + + return k_lens + + def flash_attention_naive( q, k, @@ -232,10 +282,21 @@ def attention( if torch.is_grad_enabled() and attention_type == "SAGE_ATTENTION": attention_type = "FLASH_ATTENTION" + # Convert attn_mask to k_lens for Flash Attention if possible + # Note: flash_attention doesn't support variable-length query, only set k_lens + if attn_mask is not None and k_lens is None and attention_type == "FLASH_ATTENTION": + converted_lens = _convert_attn_mask_to_lens(attn_mask) + if converted_lens is not None: + k_lens = converted_lens + attn_mask = None # Successfully converted, clear the mask + else: + # Conversion failed, fallback to SDPA which supports attn_mask + attention_type = "SDPA" + if attention_type == "SAGE_ATTENTION" and SAGE_ATTENTION_AVAILABLE: if q_lens is not None or k_lens is not None: warnings.warn( - 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' + 'Padding mask is disabled when using SAGE_ATTENTION. It can have a significant impact on performance.' ) q, k, v = convert_qkv_dtype(q, k, v)