Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 184 additions & 14 deletions extensions_built_in/diffusion_models/z_image/z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@
try:
from diffusers import ZImagePipeline
from diffusers.models.transformers import ZImageTransformer2DModel
# Try to import config - may be in different locations depending on diffusers version
try:
from diffusers.models.transformers.transformer_2d import ZImageTransformer2DModelConfig
except ImportError:
try:
from diffusers import ZImageTransformer2DModelConfig
except ImportError:
# If config class not available, we'll create config dict instead
ZImageTransformer2DModelConfig = None
except ImportError:
raise ImportError(
"Diffusers is out of date. Update diffusers to the latest version by doing pip uninstall diffusers and then pip install -r requirements.txt"
Expand Down Expand Up @@ -147,6 +156,122 @@ def load_training_adapter(self, transformer: ZImageTransformer2DModel):
# tell the model to invert assistant on inference since we want remove lora effects
self.invert_assistant_lora = True

@staticmethod
def _convert_zimage_safetensors_to_diffusers(state_dict):
    """
    Convert Z-Image safetensors keys from original format to diffusers format.

    This allows loading safetensors files directly without requiring a
    separate diffusers-format conversion step.

    Transformations applied to every key:
      * a leading ``model.diffusion_model.`` prefix is stripped;
      * ``x_embedder`` -> ``all_x_embedder.2-1`` and
        ``final_layer`` -> ``all_final_layer.2-1`` (skipped when the key is
        already in the ``all_*`` form, so the conversion is idempotent);
      * ``attention.out`` -> ``attention.to_out.0``,
        ``attention.q_norm`` -> ``attention.norm_q``,
        ``attention.k_norm`` -> ``attention.norm_k``;
      * fused ``<base>.attention.qkv.weight`` / ``.bias`` tensors are split
        into three equal chunks along dim 0 and stored as
        ``<base>.attention.to_q/to_k/to_v.<weight|bias>``.

    Args:
        state_dict: Mapping of parameter name -> tensor, as loaded from the
            safetensors file.

    Returns:
        A new dict with diffusers-style keys. The input mapping is not
        modified.

    Raises:
        ValueError: If a fused qkv tensor's first dimension is not divisible
            by 3, i.e. it cannot be split evenly into q, k and v.
    """
    prefix = "model.diffusion_model."
    qkv_marker = ".attention.qkv."

    def strip_prefix(k: str) -> str:
        # Only remove the prefix when it really is a prefix; str.replace would
        # also delete any later occurrence inside the key.
        return k[len(prefix):] if k.startswith(prefix) else k

    def rename_key(k: str) -> str:
        k = strip_prefix(k)

        # rename embedder + final layer to match diffusers ZImage.
        # Guarded so an already-converted key is not renamed a second time
        # (which would yield "all_all_x_embedder.2-1.2-1").
        if "all_x_embedder" not in k:
            k = k.replace("x_embedder", "all_x_embedder.2-1")
        if "all_final_layer" not in k:
            k = k.replace("final_layer", "all_final_layer.2-1")

        # attention renames
        k = k.replace("attention.out", "attention.to_out.0")
        k = k.replace("attention.q_norm", "attention.norm_q")
        k = k.replace("attention.k_norm", "attention.norm_k")

        return k

    out = {}
    for key, tensor in state_dict.items():
        stripped = strip_prefix(key)
        base, marker, suffix = stripped.partition(qkv_marker)

        # handle qkv split (both weight and bias if present)
        if marker and suffix in ("weight", "bias"):
            if tensor.shape[0] % 3 != 0:
                raise ValueError(
                    f"Cannot split fused qkv tensor '{key}' with leading "
                    f"dimension {tensor.shape[0]} into 3 equal parts"
                )
            # Fused layout is [3 * hidden, ...]: q, k, v stacked along dim 0
            q, k, v = torch.chunk(tensor, 3, dim=0)
            out[f"{base}.attention.to_q.{suffix}"] = q
            out[f"{base}.attention.to_k.{suffix}"] = k
            out[f"{base}.attention.to_v.{suffix}"] = v
            continue

        # Everything else (including qkv entries with an unrecognised suffix)
        # is carried over under its renamed key instead of being dropped.
        out[rename_key(stripped)] = tensor

    return out

@staticmethod
def _infer_config_from_state_dict(state_dict):
    """
    Infer ZImageTransformer2DModel config from state dict keys and shapes.

    Expects a state dict whose keys are already in diffusers format (i.e.
    containing ``*.attention.to_q.weight`` entries).

    Inference rules:
      * ``hidden_size``: dim 0 of the first ``*.attention.to_q.weight``.
      * ``attention_head_dim``: length of the first
        ``*.attention.norm_q.weight`` when present and it divides
        ``hidden_size``; otherwise the first of 64, 128, 32 that divides
        ``hidden_size``.
      * ``num_attention_heads``: ``hidden_size // attention_head_dim``.
      * ``num_layers``: highest numeric index among ``layers.<i>.*`` keys
        plus one, or 32 when no such key exists.
      * ``in_channels``: dim 1 of the first ``all_x_embedder*.weight``
        divided by ``patch_size ** 2`` (patch_size fixed at 2), or 4 when no
        such key exists.

    Args:
        state_dict: Mapping of parameter name -> tensor-like object exposing
            ``.shape``.

    Returns:
        A ``ZImageTransformer2DModelConfig`` instance when that class was
        importable, otherwise a plain dict with the same fields.

    Raises:
        RuntimeError: If no ``attention.to_q.weight`` key exists, or no head
            dimension can be derived from ``hidden_size``.
    """
    # Infer hidden size from q projection
    q_keys = [k for k in state_dict if k.endswith(".attention.to_q.weight")]
    if not q_keys:
        raise RuntimeError("No attention.to_q.weight keys found in state dict")

    hidden_size = state_dict[q_keys[0]].shape[0]

    # Infer head_dim. Probing 64 before 128 can never select 128 (anything
    # divisible by 128 is also divisible by 64), so e.g. hidden=3840 with
    # 128-wide heads would be misread as 60 heads of 64. Prefer the q-norm
    # weight, which is assumed to be a per-head scale of length head_dim
    # (NOTE(review): confirm against the diffusers attention module).
    head_dim = None
    norm_q_keys = [k for k in state_dict if k.endswith(".attention.norm_q.weight")]
    if norm_q_keys:
        candidate = state_dict[norm_q_keys[0]].shape[0]
        if candidate > 0 and hidden_size % candidate == 0:
            head_dim = candidate

    if head_dim is None:
        # Fallback heuristic when no usable q-norm weight is available
        for candidate in (64, 128, 32):
            if hidden_size % candidate == 0:
                head_dim = candidate
                break
        else:
            raise RuntimeError(f"Cannot infer head_dim from hidden_size={hidden_size}")

    num_heads = hidden_size // head_dim

    # Infer number of layers from "layers.<index>." keys; entries whose second
    # component is not a number are ignored rather than crashing int().
    layer_ids = set()
    for k in state_dict:
        if k.startswith("layers."):
            index = k.split(".")[1]
            if index.isdigit():
                layer_ids.add(int(index))

    num_layers = max(layer_ids) + 1 if layer_ids else 32  # default to 32 if not found

    # Infer in_channels from x_embedder weight shape
    # x_embedder has shape [hidden_size, in_channels * patch_size * patch_size]
    # patch_size is 2, so in_channels = x_embedder.shape[1] / 4
    x_keys = [k for k in state_dict if "all_x_embedder" in k and k.endswith(".weight")]
    if not x_keys:
        # Fallback: use default
        in_channels = 4  # default for standard VAE
    else:
        patch_size = 2  # ZImage uses patch_size=2
        in_channels = state_dict[x_keys[0]].shape[1] // (patch_size * patch_size)

    # NOTE(review): these key names are kept as-is because callers read them
    # back (e.g. config["hidden_size"]); verify they match the parameter names
    # accepted by ZImageTransformer2DModel in the installed diffusers version.
    config_dict = {
        "num_layers": num_layers,
        "num_attention_heads": num_heads,
        "attention_head_dim": head_dim,
        "hidden_size": hidden_size,
        "in_channels": in_channels,
        "norm_type": "ada_norm_single",
        "norm_eps": 1e-05,
        "use_bias": True,
    }

    # Create config object when the class is available, else fall back to the
    # plain dict and let from_config handle it.
    if ZImageTransformer2DModelConfig is not None:
        return ZImageTransformer2DModelConfig(**config_dict)
    return config_dict

def load_model(self):
dtype = self.torch_dtype
self.print_and_status_update("Loading ZImage model")
Expand All @@ -155,20 +280,65 @@ def load_model(self):

self.print_and_status_update("Loading transformer")

transformer_path = model_path
transformer_subfolder = "transformer"
if os.path.exists(transformer_path):
transformer_subfolder = None
transformer_path = os.path.join(transformer_path, "transformer")
# check if the path is a full checkpoint.
te_folder_path = os.path.join(model_path, "text_encoder")
# if we have the te, this folder is a full checkpoint, use it as the base
if os.path.exists(te_folder_path):
base_model_path = model_path

transformer = ZImageTransformer2DModel.from_pretrained(
transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
)
# Check if model_path is a safetensors file (direct checkpoint loading)
# Note: When loading from safetensors, only the transformer is loaded from the file.
# Text encoder and VAE are still loaded from base_model_path (extras_name_or_path),
# which must be set to a diffusers format path (e.g., "Tongyi-MAI/Z-Image-Turbo")
if model_path.endswith(".safetensors") and os.path.exists(model_path):
if not base_model_path:
raise ValueError(
"When loading transformer from safetensors file, extras_name_or_path must be set "
"to provide the text encoder and VAE (e.g., 'Tongyi-MAI/Z-Image-Turbo')"
)

self.print_and_status_update("Loading from safetensors file (converting keys to diffusers format)")
# Load and convert the safetensors file
state_dict = load_file(model_path, device='cpu')
converted_state_dict = self._convert_zimage_safetensors_to_diffusers(state_dict)

# Infer config from the converted state dict
config = self._infer_config_from_state_dict(converted_state_dict)

# Create model from config
if isinstance(config, dict):
transformer = ZImageTransformer2DModel.from_config(config)
else:
transformer = ZImageTransformer2DModel(config)
# Load the converted state dict
transformer.load_state_dict(converted_state_dict, strict=False)
transformer = transformer.to(dtype)

# Extract values for logging (handle both dict and object)
if isinstance(config, dict):
hidden_size = config["hidden_size"]
num_layers = config["num_layers"]
num_heads = config["num_attention_heads"]
else:
hidden_size = config.hidden_size
num_layers = config.num_layers
num_heads = config.num_attention_heads

self.print_and_status_update(f"Inferred config: hidden_size={hidden_size}, "
f"num_layers={num_layers}, "
f"num_attention_heads={num_heads}")
self.print_and_status_update(f"Text encoder and VAE will be loaded from: {base_model_path}")
self.print_and_status_update("Note: Only text_encoder, tokenizer, and vae subfolders will be downloaded, NOT the transformer")
else:
# Original diffusers format loading
transformer_path = model_path
transformer_subfolder = "transformer"
if os.path.exists(transformer_path):
transformer_subfolder = None
transformer_path = os.path.join(transformer_path, "transformer")
# check if the path is a full checkpoint.
te_folder_path = os.path.join(model_path, "text_encoder")
# if we have the te, this folder is a full checkpoint, use it as the base
if os.path.exists(te_folder_path):
base_model_path = model_path

transformer = ZImageTransformer2DModel.from_pretrained(
transformer_path, subfolder=transformer_subfolder, torch_dtype=dtype
)

# load assistant lora if specified
if self.model_config.assistant_lora_path is not None:
Expand Down
1 change: 1 addition & 0 deletions ui/src/app/jobs/new/options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ export const modelArchs: ModelArch[] = [
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['Tongyi-MAI/Z-Image-Turbo', defaultNameOrPath],
'config.process[0].model.extras_name_or_path': ['Tongyi-MAI/Z-Image-Turbo', undefined],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].model.low_vram': [true, false],
Expand Down
4 changes: 2 additions & 2 deletions ui/src/docs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ const docs: { [key: string]: ConfigDoc } = {
description: (
<>
The name of a diffusers repo on Huggingface or the local path to the base model you want to train from. The
folder needs to be in diffusers format for most models. For some models, such as SDXL and SD1, you can put the
path to an all in one safetensors checkpoint here.
folder needs to be in diffusers format for most models. For some models, such as Z-Image Turbo, SDXL and SD1, you can put the
path to a safetensors checkpoint here.
</>
),
},
Expand Down