diff --git a/extensions_built_in/diffusion_models/z_image/z_image.py b/extensions_built_in/diffusion_models/z_image/z_image.py
index 368ae9e7c..25834a2e2 100644
--- a/extensions_built_in/diffusion_models/z_image/z_image.py
+++ b/extensions_built_in/diffusion_models/z_image/z_image.py
@@ -24,6 +24,15 @@
 try:
     from diffusers import ZImagePipeline
     from diffusers.models.transformers import ZImageTransformer2DModel
+    # Try to import config - may be in different locations depending on diffusers version
+    try:
+        from diffusers.models.transformers.transformer_2d import ZImageTransformer2DModelConfig
+    except ImportError:
+        try:
+            from diffusers import ZImageTransformer2DModelConfig
+        except ImportError:
+            # If config class not available, we'll create config dict instead
+            ZImageTransformer2DModelConfig = None
 except ImportError:
     raise ImportError(
         "Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt"
@@ -147,6 +156,150 @@ def load_training_adapter(self, transformer: ZImageTransformer2DModel):
         # tell the model to invert assistant on inference since we want remove lora effects
         self.invert_assistant_lora = True
 
+    @staticmethod
+    def _convert_zimage_safetensors_to_diffusers(state_dict):
+        """
+        Convert Z-Image safetensors keys from original format to diffusers format.
+        This allows loading safetensors files directly without requiring diffusers format conversion.
+
+        Args:
+            state_dict: mapping of checkpoint key -> tensor in the original layout.
+
+        Returns:
+            A new dict with renamed keys. Fused ``attention.qkv`` weight/bias tensors
+            are split into ``to_q`` / ``to_k`` / ``to_v`` entries; every other entry is
+            kept under its renamed key. The input mapping is not modified.
+
+        Raises:
+            ValueError: if a fused qkv tensor cannot be split into three equal parts.
+        """
+        prefix = "model.diffusion_model."
+        out = {}
+
+        def strip_prefix(k: str) -> str:
+            # drop only a leading prefix, never a match further inside the key
+            return k[len(prefix):] if k.startswith(prefix) else k
+
+        def rename_key(k: str) -> str:
+            k = strip_prefix(k)
+
+            # rename embedder + final layer to match diffusers ZImage
+            k = k.replace("x_embedder", "all_x_embedder.2-1")
+            k = k.replace("final_layer", "all_final_layer.2-1")
+
+            # attention renames
+            k = k.replace("attention.out", "attention.to_out.0")
+            k = k.replace("attention.q_norm", "attention.norm_q")
+            k = k.replace("attention.k_norm", "attention.norm_k")
+
+            return k
+
+        for key, tensor in state_dict.items():
+            # handle qkv split (both weight and bias if present)
+            for suffix in ("weight", "bias"):
+                fused = f".attention.qkv.{suffix}"
+                if not key.endswith(fused):
+                    continue
+                if tensor.shape[0] % 3 != 0:
+                    raise ValueError(
+                        f"Cannot split {key} with shape {tuple(tensor.shape)} into q/k/v"
+                    )
+                base = strip_prefix(key)[: -len(fused)]
+                # ZImage uses a [3 * hidden, ...] layout, so q/k/v are thirds along dim 0
+                q, k, v = torch.chunk(tensor, 3, dim=0)
+                out[f"{base}.attention.to_q.{suffix}"] = q
+                out[f"{base}.attention.to_k.{suffix}"] = k
+                out[f"{base}.attention.to_v.{suffix}"] = v
+                break
+            else:
+                # not a fused qkv weight/bias: keep the tensor under its renamed key
+                # (dropping unrecognised ".attention.qkv.*" entries would lose weights silently)
+                out[rename_key(key)] = tensor
+
+        return out
+
+    @staticmethod
+    def _infer_config_from_state_dict(state_dict):
+        """
+        Infer ZImageTransformer2DModel config from state dict keys and shapes.
+        Similar to make_configs.py but integrated for runtime use.
+
+        Args:
+            state_dict: converted (diffusers-style) key -> tensor mapping.
+
+        Returns:
+            A ZImageTransformer2DModelConfig when that class could be imported,
+            otherwise a plain dict with the same fields.
+
+        Raises:
+            RuntimeError: if no q projection is present or head_dim cannot be inferred.
+        """
+        # Infer hidden size from q projection
+        q_keys = [k for k in state_dict if k.endswith(".attention.to_q.weight")]
+        if not q_keys:
+            raise RuntimeError("No attention.to_q.weight keys found in state dict")
+
+        hidden_size = state_dict[q_keys[0]].shape[0]
+
+        # Infer head_dim. Prefer the length of a q-norm weight when the checkpoint has
+        # one: trying 64 first would pick the wrong head count for any hidden size that
+        # is divisible by 64 but actually uses 128-wide heads.
+        # NOTE(review): assumes norm_q.weight is 1-D with length head_dim - confirm.
+        head_dim = None
+        norm_q_keys = [k for k in state_dict if k.endswith(".attention.norm_q.weight")]
+        if norm_q_keys:
+            norm_q_shape = tuple(state_dict[norm_q_keys[0]].shape)
+            is_per_head = len(norm_q_shape) == 1 and norm_q_shape[0] > 0
+            if is_per_head and hidden_size % norm_q_shape[0] == 0:
+                head_dim = norm_q_shape[0]
+        if head_dim is None:
+            # Fallback heuristic: first common head size that divides hidden_size
+            head_dim = next((d for d in (64, 128, 32) if hidden_size % d == 0), None)
+        if head_dim is None:
+            raise RuntimeError(f"Cannot infer head_dim from hidden_size={hidden_size}")
+        num_heads = hidden_size // head_dim
+
+        # Infer number of layers from "layers.<index>." keys; skip non-numeric segments
+        layer_ids = set()
+        for k in state_dict:
+            parts = k.split(".")
+            if len(parts) > 1 and parts[0] == "layers" and parts[1].isdigit():
+                layer_ids.add(int(parts[1]))
+
+        num_layers = max(layer_ids) + 1 if layer_ids else 32  # default to 32 if not found
+
+        # Infer in_channels from x_embedder weight shape
+        # x_embedder has shape [hidden_size, in_channels * patch_size * patch_size]
+        # patch_size is 2, so in_channels = x_embedder.shape[1] / 4
+        x_keys = [k for k in state_dict if "all_x_embedder" in k and k.endswith(".weight")]
+        if not x_keys:
+            # Fallback: use default
+            in_channels = 4  # default for standard VAE
+        else:
+            patch_size = 2  # ZImage uses patch_size=2
+            in_channels = state_dict[x_keys[0]].shape[1] // (patch_size * patch_size)
+
+        # NOTE(review): these field names are not checked against the installed
+        # ZImageTransformer2DModel constructor - confirm they match before relying on them.
+        config_dict = {
+            "num_layers": num_layers,
+            "num_attention_heads": num_heads,
+            "attention_head_dim": head_dim,
+            "hidden_size": hidden_size,
+            "in_channels": in_channels,
+            "norm_type": "ada_norm_single",
+            "norm_eps": 1e-05,
+            "use_bias": True,
+        }
+
+        # Note: cross_attention_dim is not typically needed in config as it's inferred
+        # from text encoder, but if needed it can be added here
+
+        if ZImageTransformer2DModelConfig is not None:
+            return ZImageTransformer2DModelConfig(**config_dict)
+        # Fallback: use dict and let from_config handle it
+        return config_dict
+
     def load_model(self):
         dtype = self.torch_dtype
         self.print_and_status_update("Loading ZImage model")
@@ -155,20 +308,72 @@ def load_model(self):
 
         self.print_and_status_update("Loading transformer")
 
-        transformer_path = model_path
-        transformer_subfolder = "transformer"
-        if os.path.exists(transformer_path):
-            transformer_subfolder = None
-            transformer_path = os.path.join(transformer_path, "transformer")
-            # check if the path is a full checkpoint.
-            te_folder_path = os.path.join(model_path, "text_encoder")
-            # if we have the te, this folder is a full checkpoint, use it as the base
-            if os.path.exists(te_folder_path):
-                base_model_path = model_path
-
-        transformer = ZImageTransformer2DModel.from_pretrained(
-            transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
-        )
+        # Check if model_path is a safetensors file (direct checkpoint loading)
+        # Note: When loading from safetensors, only the transformer is loaded from the file.
+        # Text encoder and VAE are still loaded from base_model_path (extras_name_or_path),
+        # which must be set to a diffusers format path (e.g., "Tongyi-MAI/Z-Image-Turbo")
+        if model_path.endswith(".safetensors") and os.path.exists(model_path):
+            if not base_model_path:
+                raise ValueError(
+                    "When loading transformer from safetensors file, extras_name_or_path must be set "
+                    "to provide the text encoder and VAE (e.g., 'Tongyi-MAI/Z-Image-Turbo')"
+                )
+
+            self.print_and_status_update("Loading from safetensors file (converting keys to diffusers format)")
+            # Load and convert the safetensors file
+            # NOTE(review): load_file is not imported by this patch - presumably
+            # safetensors.torch.load_file; confirm it is already in scope in this module.
+            state_dict = load_file(model_path, device='cpu')
+            converted_state_dict = self._convert_zimage_safetensors_to_diffusers(state_dict)
+
+            # Infer config from the converted state dict
+            config = self._infer_config_from_state_dict(converted_state_dict)
+
+            # Create model from config
+            if isinstance(config, dict):
+                transformer = ZImageTransformer2DModel.from_config(config)
+            else:
+                transformer = ZImageTransformer2DModel(config)
+            # strict=False tolerates key mismatches, so report them instead of hiding them
+            load_result = transformer.load_state_dict(converted_state_dict, strict=False)
+            if load_result.missing_keys or load_result.unexpected_keys:
+                self.print_and_status_update(
+                    f"Warning: {len(load_result.missing_keys)} missing and "
+                    f"{len(load_result.unexpected_keys)} unexpected keys while loading transformer"
+                )
+            transformer = transformer.to(dtype)
+
+            # Extract values for logging (handle both dict and object)
+            if isinstance(config, dict):
+                hidden_size = config["hidden_size"]
+                num_layers = config["num_layers"]
+                num_heads = config["num_attention_heads"]
+            else:
+                hidden_size = config.hidden_size
+                num_layers = config.num_layers
+                num_heads = config.num_attention_heads
+
+            self.print_and_status_update(f"Inferred config: hidden_size={hidden_size}, "
+                                        f"num_layers={num_layers}, "
+                                        f"num_attention_heads={num_heads}")
+            self.print_and_status_update(f"Text encoder and VAE will be loaded from: {base_model_path}")
+            self.print_and_status_update("Note: Only text_encoder, tokenizer, and vae subfolders will be downloaded, NOT the transformer")
+        else:
+            # Original diffusers format loading
+            transformer_path = model_path
+            transformer_subfolder = "transformer"
+            if os.path.exists(transformer_path):
+                transformer_subfolder = None
+                transformer_path = os.path.join(transformer_path, "transformer")
+                # check if the path is a full checkpoint.
+                te_folder_path = os.path.join(model_path, "text_encoder")
+                # if we have the te, this folder is a full checkpoint, use it as the base
+                if os.path.exists(te_folder_path):
+                    base_model_path = model_path
+
+            transformer = ZImageTransformer2DModel.from_pretrained(
+                transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
+            )
 
         # load assistant lora if specified
         if self.model_config.assistant_lora_path is not None:
diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts
index 64fa7cfa0..361fbdc9c 100644
--- a/ui/src/app/jobs/new/options.ts
+++ b/ui/src/app/jobs/new/options.ts
@@ -565,6 +565,7 @@ export const modelArchs: ModelArch[] = [
     defaults: {
       // default updates when [selected, unselected] in the UI
       'config.process[0].model.name_or_path': ['Tongyi-MAI/Z-Image-Turbo', defaultNameOrPath],
+      'config.process[0].model.extras_name_or_path': ['Tongyi-MAI/Z-Image-Turbo', undefined],
       'config.process[0].model.quantize': [true, false],
       'config.process[0].model.quantize_te': [true, false],
       'config.process[0].model.low_vram': [true, false],
diff --git a/ui/src/docs.tsx b/ui/src/docs.tsx
index c82d800bd..d7bf23e5d 100644
--- a/ui/src/docs.tsx
+++ b/ui/src/docs.tsx
@@ -47,8 +47,8 @@ const docs: { [key: string]: ConfigDoc } = {
     description: (
       <>
         The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The
-        folder needs to be in diffusers format for most models. For some models, such as SDXL and SD1, you can put the
-        path to an all in one safetensors checkpoint here.
+        folder needs to be in diffusers format for most models. For some models, such as Z-Image Turbo, SDXL and SD1, you can put the
+        path to a safetensors checkpoint here.
       </>
     ),
   },