From c0f86e11812b35a244001bc33f68b2c102126c23 Mon Sep 17 00:00:00 2001
From: Dave Cranwell
Date: Fri, 30 Jan 2026 21:25:47 +0000
Subject: [PATCH 1/8] First attempt at converting Z-Image safetensors files to diffusers format

---
 .../wan22/wan22_14b_i2v_model.py              |  11 +
 .../diffusion_models/z_image/z_image.py       | 201 ++++++++++++++++--
 2 files changed, 198 insertions(+), 14 deletions(-)

diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
index 32eb11e8a..bafcbec2d 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
@@ -15,6 +15,17 @@ class Wan2214bI2VModel(Wan2214bModel):
 
     arch = "wan22_14b_i2v"
 
+    def get_model_has_grad(self):
+        # Check if the transformers have gradients enabled
+        # We need to check both transformers since this is a dual model
+        transformer_1_has_grad = self.model.transformer_1.proj_out.weight.requires_grad
+        transformer_2_has_grad = self.model.transformer_2.proj_out.weight.requires_grad
+        return transformer_1_has_grad or transformer_2_has_grad
+
+    def get_te_has_grad(self):
+        # Check if the text encoder has gradients enabled
+        return self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
+
     def generate_single_image(
         self,

diff --git a/extensions_built_in/diffusion_models/z_image/z_image.py b/extensions_built_in/diffusion_models/z_image/z_image.py
index 368ae9e7c..9d407f7a1 100644
--- a/extensions_built_in/diffusion_models/z_image/z_image.py
+++ b/extensions_built_in/diffusion_models/z_image/z_image.py
@@ -24,6 +24,15 @@
 try:
     from diffusers import ZImagePipeline
     from diffusers.models.transformers import ZImageTransformer2DModel
+    # Try to import config - may be in different locations depending on diffusers version
+    try:
+        from diffusers.models.transformers.transformer_2d import ZImageTransformer2DModelConfig
+    except ImportError:
+        try:
+            from diffusers import ZImageTransformer2DModelConfig
+        except ImportError:
+            # If config class not available, we'll create config dict instead
+            ZImageTransformer2DModelConfig = None
 except ImportError:
     raise ImportError(
         "Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt"
     )
@@ -147,6 +156,117 @@ def load_training_adapter(self, transformer: ZImageTransformer2DModel):
         # tell the model to invert assistant on inference since we want remove lora effects
         self.invert_assistant_lora = True
 
+    @staticmethod
+    def _convert_zimage_safetensors_to_diffusers(state_dict):
+        """
+        Convert Z-Image safetensors keys from original format to diffusers format.
+        This allows loading safetensors files directly without requiring diffusers format conversion.
+        """
+        out = {}
+
+        def rename_key(k: str) -> str:
+            # strip prefix
+            if k.startswith("model.diffusion_model."):
+                k = k.replace("model.diffusion_model.", "")
+
+            # rename embedder + final layer to match diffusers ZImage
+            k = k.replace("x_embedder", "all_x_embedder.2-1")
+            k = k.replace("final_layer", "all_final_layer.2-1")
+
+            # attention renames
+            k = k.replace("attention.out", "attention.to_out.0")
+            k = k.replace("attention.q_norm", "attention.norm_q")
+            k = k.replace("attention.k_norm", "attention.norm_k")
+
+            return k
+
+        for key, tensor in state_dict.items():
+            # handle qkv split (both weight and bias if present)
+            if ".attention.qkv." in key:
+                base = key.replace("model.diffusion_model.", "")
+                # Remove .attention.qkv.weight or .attention.qkv.bias
+                if base.endswith(".attention.qkv.weight"):
+                    base = base.replace(".attention.qkv.weight", "")
+                    # ZImage uses [3 * hidden, hidden] layout for weights
+                    q, k, v = torch.chunk(tensor, 3, dim=0)
+                    out[f"{base}.attention.to_q.weight"] = q
+                    out[f"{base}.attention.to_k.weight"] = k
+                    out[f"{base}.attention.to_v.weight"] = v
+                elif base.endswith(".attention.qkv.bias"):
+                    base = base.replace(".attention.qkv.bias", "")
+                    # ZImage uses [3 * hidden] layout for bias
+                    q, k, v = torch.chunk(tensor, 3, dim=0)
+                    out[f"{base}.attention.to_q.bias"] = q
+                    out[f"{base}.attention.to_k.bias"] = k
+                    out[f"{base}.attention.to_v.bias"] = v
+                continue
+
+            new_key = rename_key(key)
+            out[new_key] = tensor
+
+        return out
+
+    @staticmethod
+    def _infer_config_from_state_dict(state_dict):
+        """
+        Infer ZImageTransformer2DModel config from state dict keys and shapes.
+        Similar to make_configs.py but integrated for runtime use.
+        """
+        # Infer hidden size from q projection
+        q_keys = [k for k in state_dict if k.endswith(".attention.to_q.weight")]
+        if not q_keys:
+            raise RuntimeError("No attention.to_q.weight keys found in state dict")
+
+        sample_q = state_dict[q_keys[0]]
+        hidden_size = sample_q.shape[0]
+
+        # Infer attention heads - Z-Image uses square Q projection: [hidden, hidden]
+        # head_dim is typically 64; verify divisibility
+        for head_dim in (64, 128, 32):
+            if hidden_size % head_dim == 0:
+                num_heads = hidden_size // head_dim
+                break
+        else:
+            raise RuntimeError(f"Cannot infer head_dim from hidden_size={hidden_size}")
+
+        # Infer number of layers
+        layer_ids = set()
+        for k in state_dict:
+            if k.startswith("layers."):
+                layer_ids.add(int(k.split(".")[1]))
+
+        num_layers = max(layer_ids) + 1 if layer_ids else 32  # default to 32 if not found
+
+        # Infer cross-attention dim from x_embedder
+        x_keys = [k for k in state_dict if "all_x_embedder" in k and k.endswith(".weight")]
+        if not x_keys:
+            # Fallback: try to infer from other keys or use default
+            cross_attention_dim = 2048  # default for Z-Image-Turbo
+        else:
+            x_w = state_dict[x_keys[0]]
+            cross_attention_dim = x_w.shape[1]
+
+        # Create config dict or config object depending on availability
+        config_dict = {
+            "num_layers": num_layers,
+            "num_attention_heads": num_heads,
+            "attention_head_dim": head_dim,
+            "hidden_size": hidden_size,
+            "in_channels": 4,
+            "cross_attention_dim": cross_attention_dim,
+            "norm_type": "ada_norm_single",
+            "norm_eps": 1e-05,
+            "use_bias": True,
+        }
+
+        if ZImageTransformer2DModelConfig is not None:
+            config = ZImageTransformer2DModelConfig(**config_dict)
+        else:
+            # Fallback: use dict and let from_config handle it
+            config = config_dict
+
+        return config
+
     def load_model(self):
         dtype = self.torch_dtype
         self.print_and_status_update("Loading ZImage model")
@@ -155,20 +275,65 @@ def load_model(self):
 
         self.print_and_status_update("Loading transformer")
 
-        transformer_path = model_path
-        transformer_subfolder = "transformer"
-        if os.path.exists(transformer_path):
-            transformer_subfolder = None
-            transformer_path = os.path.join(transformer_path, "transformer")
-            # check if the path is a full checkpoint.
-            te_folder_path = os.path.join(model_path, "text_encoder")
-            # if we have the te, this folder is a full checkpoint, use it as the base
-            if os.path.exists(te_folder_path):
-                base_model_path = model_path
-
-        transformer = ZImageTransformer2DModel.from_pretrained(
-            transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
-        )
+        # Check if model_path is a safetensors file (direct checkpoint loading)
+        # Note: When loading from safetensors, only the transformer is loaded from the file.
+        # Text encoder and VAE are still loaded from base_model_path (extras_name_or_path),
+        # which must be set to a diffusers format path (e.g., "Tongyi-MAI/Z-Image-Turbo")
+        if model_path.endswith(".safetensors") and os.path.exists(model_path):
+            if not base_model_path:
+                raise ValueError(
+                    "When loading transformer from safetensors file, extras_name_or_path must be set "
+                    "to provide the text encoder and VAE (e.g., 'Tongyi-MAI/Z-Image-Turbo')"
+                )
+
+            self.print_and_status_update("Loading from safetensors file (converting keys to diffusers format)")
+            # Load and convert the safetensors file
+            state_dict = load_file(model_path, device='cpu')
+            converted_state_dict = self._convert_zimage_safetensors_to_diffusers(state_dict)
+
+            # Infer config from the converted state dict
+            config = self._infer_config_from_state_dict(converted_state_dict)
+
+            # Create model from config
+            if isinstance(config, dict):
+                transformer = ZImageTransformer2DModel.from_config(config)
+            else:
+                transformer = ZImageTransformer2DModel(config)
+            # Load the converted state dict
+            transformer.load_state_dict(converted_state_dict, strict=False)
+            transformer = transformer.to(dtype)
+
+            # Extract values for logging (handle both dict and object)
+            if isinstance(config, dict):
+                hidden_size = config["hidden_size"]
+                num_layers = config["num_layers"]
+                num_heads = config["num_attention_heads"]
+            else:
+                hidden_size = config.hidden_size
+                num_layers = config.num_layers
+                num_heads = config.num_attention_heads
+
+            self.print_and_status_update(f"Inferred config: hidden_size={hidden_size}, "
+                                         f"num_layers={num_layers}, "
+                                         f"num_attention_heads={num_heads}")
+            self.print_and_status_update(f"Text encoder and VAE will be loaded from: {base_model_path}")
+            self.print_and_status_update("Note: Only text_encoder, tokenizer, and vae subfolders will be downloaded, NOT the transformer")
+        else:
+            # Original diffusers format loading
+            transformer_path = model_path
+            transformer_subfolder = "transformer"
+            if os.path.exists(transformer_path):
+                transformer_subfolder = None
+                transformer_path = os.path.join(transformer_path, "transformer")
+                # check if the path is a full checkpoint.
+                te_folder_path = os.path.join(model_path, "text_encoder")
+                # if we have the te, this folder is a full checkpoint, use it as the base
+                if os.path.exists(te_folder_path):
+                    base_model_path = model_path
+
+            transformer = ZImageTransformer2DModel.from_pretrained(
+                transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
+            )
 
         # load assistant lora if specified
         if self.model_config.assistant_lora_path is not None:
@@ -202,7 +367,11 @@ def load_model(self):
 
         flush()
 
+        # Load text encoder and tokenizer from base_model_path
+        # Using subfolder parameter ensures ONLY the text_encoder and tokenizer subfolders
+        # are downloaded, NOT the transformer folder (which is loaded from safetensors file)
         self.print_and_status_update("Text Encoder")
+        self.print_and_status_update(f"Downloading text encoder and tokenizer from: {base_model_path}")
         tokenizer = AutoTokenizer.from_pretrained(
             base_model_path, subfolder="tokenizer", torch_dtype=dtype
         )
@@ -229,7 +398,11 @@ def load_model(self):
         freeze(text_encoder)
         flush()
 
+        # Load VAE from base_model_path
+        # Using subfolder parameter ensures ONLY the vae subfolder is downloaded,
+        # NOT the transformer folder (which is loaded from safetensors file)
         self.print_and_status_update("Loading VAE")
+        self.print_and_status_update(f"Downloading VAE from: {base_model_path}")
         vae = AutoencoderKL.from_pretrained(
             base_model_path, subfolder="vae", torch_dtype=dtype
         )
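[Review note, not part of the series] The qkv handling in PATCH 1 relies on the fused projection being stored as a single [3 * hidden, hidden] tensor, so chunking along dim 0 recovers q, k, and v losslessly. A minimal sanity check of that step, using a toy hidden size (the shapes are illustrative, not taken from a real checkpoint):

import torch

# Toy fused-qkv tensor in the original Z-Image key layout.
hidden = 8
key = "model.diffusion_model.layers.0.attention.qkv.weight"
sd = {key: torch.randn(3 * hidden, hidden)}

# Mirror the patch: split the fused projection along dim 0.
q, k, v = torch.chunk(sd[key], 3, dim=0)

# Concatenating the chunks reproduces the fused tensor exactly,
# so the conversion loses no information.
assert torch.equal(torch.cat([q, k, v], dim=0), sd[key])
assert q.shape == (hidden, hidden)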
From eea696c0cd3cb1d543f8d89242e8fa00ed71d7c4 Mon Sep 17 00:00:00 2001
From: Dave Cranwell
Date: Sun, 1 Feb 2026 14:30:29 +0000
Subject: [PATCH 2/8] Wrong size channels

---
 .../diffusion_models/z_image/z_image.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/extensions_built_in/diffusion_models/z_image/z_image.py b/extensions_built_in/diffusion_models/z_image/z_image.py
index 9d407f7a1..51f44901e 100644
--- a/extensions_built_in/diffusion_models/z_image/z_image.py
+++ b/extensions_built_in/diffusion_models/z_image/z_image.py
@@ -237,14 +237,17 @@ def _infer_config_from_state_dict(state_dict):
 
         num_layers = max(layer_ids) + 1 if layer_ids else 32  # default to 32 if not found
 
-        # Infer cross-attention dim from x_embedder
+        # Infer in_channels from x_embedder weight shape
+        # x_embedder has shape [hidden_size, in_channels * patch_size * patch_size]
+        # patch_size is 2, so in_channels = x_embedder.shape[1] / 4
         x_keys = [k for k in state_dict if "all_x_embedder" in k and k.endswith(".weight")]
         if not x_keys:
-            # Fallback: try to infer from other keys or use default
-            cross_attention_dim = 2048  # default for Z-Image-Turbo
+            # Fallback: use default
+            in_channels = 4  # default for standard VAE
         else:
             x_w = state_dict[x_keys[0]]
-            cross_attention_dim = x_w.shape[1]
+            patch_size = 2  # ZImage uses patch_size=2
+            in_channels = x_w.shape[1] // (patch_size * patch_size)
 
         # Create config dict or config object depending on availability
         config_dict = {
@@ -252,13 +255,15 @@ def _infer_config_from_state_dict(state_dict):
             "num_attention_heads": num_heads,
             "attention_head_dim": head_dim,
             "hidden_size": hidden_size,
-            "in_channels": 4,
-            "cross_attention_dim": cross_attention_dim,
+            "in_channels": in_channels,
             "norm_type": "ada_norm_single",
             "norm_eps": 1e-05,
             "use_bias": True,
         }
 
+        # Note: cross_attention_dim is not typically needed in config as it's inferred
+        # from text encoder, but if needed it can be added here
+
         if ZImageTransformer2DModelConfig is not None:
             config = ZImageTransformer2DModelConfig(**config_dict)
         else:
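[Review note, not part of the series] PATCH 2 matters because the second dimension of the x_embedder weight is a patchified channel count, not a cross-attention width: with patch_size=2, each input channel contributes 4 columns. A small check of that arithmetic (the channel and hidden sizes are illustrative; only the shape relationship comes from the patch):

import torch

hidden_size, in_channels, patch_size = 1024, 16, 2
# x_embedder weight: [hidden_size, in_channels * patch_size**2]
x_w = torch.empty(hidden_size, in_channels * patch_size ** 2)

# Divide out the patch area to recover the channel count.
inferred = x_w.shape[1] // (patch_size * patch_size)
assert inferred == in_channels  # 64 // 4 == 16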
From 262baf521417413b58f4f5ec5135022440078517 Mon Sep 17 00:00:00 2001
From: Dave Cranwell
Date: Sun, 1 Feb 2026 15:34:32 +0000
Subject: [PATCH 3/8] Updates docs/options

---
 ui/src/app/jobs/new/options.ts | 1 +
 ui/src/docs.tsx                | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts
index 64fa7cfa0..361fbdc9c 100644
--- a/ui/src/app/jobs/new/options.ts
+++ b/ui/src/app/jobs/new/options.ts
@@ -565,6 +565,7 @@ export const modelArchs: ModelArch[] = [
     defaults: {
       // default updates when [selected, unselected] in the UI
       'config.process[0].model.name_or_path': ['Tongyi-MAI/Z-Image-Turbo', defaultNameOrPath],
+      'config.process[0].model.extras_name_or_path': ['Tongyi-MAI/Z-Image-Turbo', undefined],
       'config.process[0].model.quantize': [true, false],
       'config.process[0].model.quantize_te': [true, false],
       'config.process[0].model.low_vram': [true, false],

diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx
index c82d800bd..d7bf23e5d 100644
--- a/ui/src/docs.tsx
+++ b/ui/src/docs.tsx
@@ -47,8 +47,8 @@ const docs: { [key: string]: ConfigDoc } = {
     description: (
       <>
         The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The
-        folder needs to be in diffusers format for most models. For some models, such as SDXL and SD1, you can put the
-        path to an all in one safetensors checkpoint here.
+        folder needs to be in diffusers format for most models. For some models, such as Z-Image Turbo, SDXL, and
+        SD1, you can put the path to a safetensors checkpoint here.
       </>
     ),
   },
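[Review note, not part of the series] With the new UI default from PATCH 3, name_or_path can point at a bare transformer checkpoint while extras_name_or_path supplies the text encoder and VAE. A hypothetical config fragment showing the two keys working together (the local file path is invented; the repo id comes from the patch itself):

# Hypothetical model section of a job config, expressed as a Python dict.
model_config = {
    "name_or_path": "/models/z_image_turbo.safetensors",  # invented local checkpoint path
    "extras_name_or_path": "Tongyi-MAI/Z-Image-Turbo",    # provides text encoder + VAE
    "quantize": True,
    "quantize_te": True,
}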
From 02074b3b548d583a49083ae2e77bde3a8803b4d3 Mon Sep 17 00:00:00 2001
From: Dave Cranwell
Date: Sun, 1 Feb 2026 15:38:05 +0000
Subject: [PATCH 4/8] Unnecessary changes removed

---
 extensions_built_in/diffusion_models/z_image/z_image.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/extensions_built_in/diffusion_models/z_image/z_image.py b/extensions_built_in/diffusion_models/z_image/z_image.py
index 51f44901e..25834a2e2 100644
--- a/extensions_built_in/diffusion_models/z_image/z_image.py
+++ b/extensions_built_in/diffusion_models/z_image/z_image.py
@@ -372,11 +372,7 @@ def load_model(self):
 
         flush()
 
-        # Load text encoder and tokenizer from base_model_path
-        # Using subfolder parameter ensures ONLY the text_encoder and tokenizer subfolders
-        # are downloaded, NOT the transformer folder (which is loaded from safetensors file)
         self.print_and_status_update("Text Encoder")
-        self.print_and_status_update(f"Downloading text encoder and tokenizer from: {base_model_path}")
         tokenizer = AutoTokenizer.from_pretrained(
             base_model_path, subfolder="tokenizer", torch_dtype=dtype
         )
@@ -403,11 +399,7 @@ def load_model(self):
         freeze(text_encoder)
         flush()
 
-        # Load VAE from base_model_path
-        # Using subfolder parameter ensures ONLY the vae subfolder is downloaded,
-        # NOT the transformer folder (which is loaded from safetensors file)
         self.print_and_status_update("Loading VAE")
-        self.print_and_status_update(f"Downloading VAE from: {base_model_path}")
         vae = AutoencoderKL.from_pretrained(
             base_model_path, subfolder="vae", torch_dtype=dtype
         )
From 2d6460cc6eae71f670b53aeb1fddf5644801848b Mon Sep 17 00:00:00 2001
From: Dave Cranwell
Date: Sun, 1 Feb 2026 15:41:02 +0000
Subject: [PATCH 5/8] Undoes unwanted change

---
 .../diffusion_models/wan22/wan22_14b_i2v_model.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
index bafcbec2d..6a1c567da 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
@@ -15,18 +15,7 @@ class Wan2214bI2VModel(Wan2214bModel):
 
     arch = "wan22_14b_i2v"
 
-    def get_model_has_grad(self):
-        # Check if the transformers have gradients enabled
-        # We need to check both transformers since this is a dual model
-        transformer_1_has_grad = self.model.transformer_1.proj_out.weight.requires_grad
-        transformer_2_has_grad = self.model.transformer_2.proj_out.weight.requires_grad
-        return transformer_1_has_grad or transformer_2_has_grad
-
-    def get_te_has_grad(self):
-        # Check if the text encoder has gradients enabled
-        return self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
-
     def generate_single_image(
         self,
         pipeline: Wan22Pipeline,

From 660581a79bc2cc3052ee0bbd22cd0dd013f8d257 Mon Sep 17 00:00:00 2001
From: Dave Cranwell
Date: Sun, 1 Feb 2026 15:41:33 +0000
Subject: [PATCH 6/8] Tweak

---
 .../diffusion_models/wan22/wan22_14b_i2v_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
index 6a1c567da..e7afdf2e1 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
@@ -14,8 +14,8 @@ class Wan2214bI2VModel(Wan2214bModel):
 
     arch = "wan22_14b_i2v"
 
-
+
     def generate_single_image(
         self,
         pipeline: Wan22Pipeline,

From 2ad2339ece15cbb43139e90654a1f953f080c0dc Mon Sep 17 00:00:00 2001
From: Dave Cranwell
Date: Sun, 1 Feb 2026 15:41:49 +0000
Subject: [PATCH 7/8] Tweak

---
 .../diffusion_models/wan22/wan22_14b_i2v_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
index e7afdf2e1..6a1c567da 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
@@ -14,8 +14,8 @@ class Wan2214bI2VModel(Wan2214bModel):
 
     arch = "wan22_14b_i2v"
 
-
+
     def generate_single_image(
         self,
         pipeline: Wan22Pipeline,
From 13d513be7d379c758bec9a5b07c8abfcad92f9a8 Mon Sep 17 00:00:00 2001
From: Dave Cranwell
Date: Sun, 1 Feb 2026 15:45:41 +0000
Subject: [PATCH 8/8] Revert

---
 .../diffusion_models/wan22/wan22_14b_i2v_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
index 6a1c567da..32eb11e8a 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_i2v_model.py
@@ -15,7 +15,7 @@ class Wan2214bI2VModel(Wan2214bModel):
 
     arch = "wan22_14b_i2v"
 
-
+
     def generate_single_image(
         self,
         pipeline: Wan22Pipeline,
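[Review note, not part of the series] One caveat that survives the whole series: _infer_config_from_state_dict tries head dims in the fixed order (64, 128, 32), so any hidden size divisible by 64 resolves to 64-wide heads even if the checkpoint was trained with 128-wide ones. The heuristic in isolation (hidden sizes illustrative):

def infer_heads(hidden_size):
    # Same candidate order as the patch: 64 is tried first, so it wins ties.
    for head_dim in (64, 128, 32):
        if hidden_size % head_dim == 0:
            return hidden_size // head_dim, head_dim
    raise RuntimeError(f"Cannot infer head_dim from hidden_size={hidden_size}")

print(infer_heads(3072))  # (48, 64): 64 divides evenly and is tried first
print(infer_heads(1056))  # (33, 32): falls through 64 and 128 to 32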