Training, Optimizers, Device Handling & Z-Image Improvements #713
Draft
daanforever wants to merge 17 commits into ostris:main from
Conversation
- blank_prompt_probability, differential guidance metric, LR graph
- Gaussian/content-style timestep sampling, warmup for cosine LR
- BaseSDTrainProcess timestep mapping, fixed cycle training
- rank_dropout/module_dropout, alpha for Peft LoRA, SNR flow matching
- Timestep debug logging, config and UI updates
- min_lr/max_lr, get_avg_learning_rate, RMS tracking methods
- Fix lr=0 with relative_step, uninitialized state, clamp to max_lr
- Truncated normal sampling, weight update RMS logging
- Add safe_module_to_device in toolkit/util/device.py
- Use in ltx2, qwen_image, wan22 for low_vram-aware device handling
- Separate sampling model with device handling, quantization support
- Sampling model unload, LoRA on sampling transformer
- Accuracy recovery adapter in quantize, file cleanup with retries
- Stop job API marks job stopped and handles dead process
- Sample image viewer keyboard navigation fix
- Change max_lr from 1e-2 to 1e-4 for Adafactor optimizer
- Update comment to clarify learning rate clamping between min_lr and max_lr when using relative_step

fix(optimizers): streamline learning rate clamping in Adafactor
- Simplified the learning rate clamping logic to ensure it consistently respects min_lr and max_lr boundaries.
- Removed the conditional check for relative_step, enhancing code clarity.
- Simplified the embedding conversion process in ZImageModel for better clarity and efficiency.
- Improved device management for the sampling transformer, ensuring proper loading and unloading to optimize resource usage.
- Enhanced code readability by removing redundant checks and restructuring the flow.

fix(z_image): use dtype instead of deprecated torch_dtype in from_pretrained calls

Revert "fix(z_image): use dtype instead of deprecated torch_dtype in from_pretrained calls"
This reverts commit 7acdd14.

refactor(zimage): load sampling transformer first and extract load helpers
- Add _load_sampling_transformer() and _load_transformer(model_path)
- Call sampling transformer load before main transformer to control peak VRAM
- Shrink load_model() by delegating to helpers; preserve base_model_path resolution
- Restore and add comments and docstrings in load paths

feat(zimage): add debug_zimage_load to trace safetensors load/mmap
- Add model.debug_zimage_load in ModelConfig to enable load debugging
- When enabled, patch safetensors.torch.load_file and safetensors.safe_open to log path, size, and duration via print_and_status_update
- Log markers before loading sampling vs main transformer to correlate slow sampling load with specific file opens
- Patches applied once per process to avoid double-wrap

fix(zimage): normalize model paths to avoid HF identifier warning
Add normalize_model_path() in toolkit/paths.py to strip trailing path separators. Use it in Z-Image for name_or_path, sampling_name_or_path, and extras_name_or_path so paths like "e:\...\snapshots\hash\" no longer trigger "The module name (originally ) is not a valid Python identifier" from Hugging Face transformers.

fix(zimage): use dtype instead of deprecated torch_dtype for transformers
Pass dtype=dtype to AutoTokenizer.from_pretrained and Qwen3ForCausalLM.from_pretrained to remove the "torch_dtype is deprecated! Use dtype instead!" warnings when loading the text encoder.
Call optimizer.zero_grad(set_to_none=True), torch.cuda.synchronize(), and flush() before self.save() in BaseSDTrainProcess so gradients and CUDA cache are released before building state_dict and hashes.

fix(train): validate optimizer state load and document save behavior
- Add a param count check before loading optimizer.pt; skip the load and warn when current and saved param counts differ, to avoid a wrong state mapping.
- Comment that the optimizer is always saved via unwrap to match the pre-prepare load target; drop an unused exception variable in the inner except.
Move the $$alpha → .alpha key normalization into the common peft block so it runs for all PEFT types (lora, lokr, etc.). Remove the duplicate alpha fix from the lokr-only block to avoid missing_keys for alpha in the state_dict.
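The key normalization described above can be sketched as a small state-dict rewrite. The $$alpha → .alpha pattern is taken from the commit text; the function name normalize_alpha_keys is a hypothetical stand-in, not the repo's actual helper.

```python
def normalize_alpha_keys(state_dict: dict) -> dict:
    """Rewrite keys ending in '$$alpha' to '.alpha' so every PEFT variant
    (LoRA, LoKr, ...) finds its alpha entries when loading weights.

    Illustrative sketch only; the real normalization lives in the common
    PEFT loading block and may handle more key shapes than this."""
    return {k.replace("$$alpha", ".alpha"): v for k, v in state_dict.items()}
```

Doing this once in the shared PEFT path, rather than per network type, is what keeps alpha out of missing_keys for LoKr and friends.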
- Eliminated unnecessary shuffling of tokens as it is already handled in the subsequent code.

fix(dataloader): skip shuffle_tokens when caching text embeddings
Apply shuffle_tokens only when not cache_text_embeddings, so cached embeddings are saved from the original caption order. Aligns with existing caption_dropout and token_dropout behavior and keeps cache paths stable.

feat(shuffle): keep first segment/token fixed when shuffling captions and embeddings
- get_caption: when shuffle_tokens is on, keep the first comma-separated segment in place and shuffle only the rest
- shuffle_sequence: keep index 0 fixed and permute only positions 1..seq_len-1

feat(train): enhance text embedding caching logic and memory management
- Introduced a mechanism to clear cached text embeddings based on the dataset configuration, improving memory efficiency during training.
- Updated the AiToolkitDataset to retain prompt embeddings if they are cached, ensuring consistency across epochs.
- Added logging for shuffling cached text embedding tokens at the start of each epoch.

feat(train): log shuffling of cached text embedding tokens per epoch
Added a print statement to log when cached text embedding tokens are shuffled at the start of each epoch, enhancing visibility into the training process.

feat(train): shuffle cached text embedding tokens every epoch
When cache_text_embeddings is true, shuffle token order along the sequence dimension at the start of each new epoch (after the first full pass). Add PromptEmbeds.shuffle_sequence(), dataset set_epoch_num/clear_cached_embeddings_memory, wire the epoch boundary in BaseSDTrainProcess, and sync epoch_num to datasets at train start for correct behavior on resume from checkpoint.
fix(sd_trainer): use trigger_word when cached text embeds are missing
- Prefer trigger over blank when batch.prompt_embeds is None (no cache on disk)
- Require trigger_word when cache_text_embeddings is enabled and the batch has no cached embeds; raise ValueError otherwise
- Keep reg batches using blank (no trigger)

fix(dataloader): respect shuffle_tokens for cached text embeddings
Skip shuffling cached prompt embeds when shuffle_tokens is false in set_epoch_num and load_prompt_embedding.
- Add toolkit/util/debug.py with set_debug_config() and a memory_debug() context manager
- Register logging config in BaseSDTrainProcess
- Wrap text encoder unload and Z-Image load stages in memory_debug for optional [DEBUG ...] CUDA lines
- Add debug flag and CUDA memory log on text encoder unload
… wiring
- fix(sampling): load the sampling transformer on CPU, guarantee unload to CPU after generate_images, do not force the unet to GPU
- feat(lora): use a single LoRA for training and sampling via shared parameters; free pretrained LoRA memory after load
- Use normalize_path in Z-Image name_or_path; wire memory_debug in BaseSDTrainProcess LoRA load and Z-Image stages
- train.example.yaml, base_model, network_mixins updates
Training, Optimizers, Device Handling & Z-Image Improvements
Overview
This PR bundles improvements across the training pipeline, optimizers, device handling, and Z-Image inference: timestep sampling and schedulers, Adafactor extensions, low-VRAM device handling, and a separate sampling model with quantization and UI fixes. It also adds fixes and refactors for Adafactor LR clamping, Z-Image loading/dtype, train checkpoint memory and optimizer state, PEFT LoRA loading, and text-embedding cache behaviour (shuffle_tokens, per-epoch shuffle, trigger_word).
Training & Schedulers
- Gaussian/content-style timestep sampling with gaussian_std curriculum, timestep debug logging and mapping fixes in BaseSDTrainProcess.
- Warmup for cosine LR: SequentialLR.step() and T_0/T_max duplication with warmup; example config and guide (WARMUP_SCHEDULER_GUIDE.md).
- rank_dropout/module_dropout, alpha for Peft LoRA, SNR compatibility for flow matching; validation for image/video scaling.

Optimizers (Adafactor)
- min_lr/max_lr, get_avg_learning_rate, RMS tracking; fixes for lr=0 with relative_step, uninitialized state, clamping to max_lr (including max_lr 1e-2→1e-4 and streamlined clamping logic); truncated normal sampling and spelling fix.

Device Handling
- safe_module_to_device in toolkit/util/device.py for consistent, low-VRAM-aware transfers.
- Use low_vram (and related) flags instead of raw model.to(device) in ltx2, qwen_image, and wan22.

Z-Image & Inference
- Separate sampling model with device handling and quantization support (quantize_model); sampling model unload, LoRA on the sampling transformer, accuracy recovery adapter, stop job API, and sample viewer fix.
- dtype instead of deprecated torch_dtype for transformers; streamline embedding conversion and device management; optional debug_zimage_load to trace safetensors load/mmap.

Text Embedding Cache & Shuffle
- Skip shuffling cached prompt embeds when shuffle_tokens is false in set_epoch_num and load_prompt_embedding; respect shuffle_tokens for cached embeddings so cache paths stay stable.
- When cache_text_embeddings is true, shuffle token order along the sequence at the start of each new epoch; PromptEmbeds.shuffle_sequence(), dataset set_epoch_num/clear_cached_embeddings_memory, wired at the epoch boundary in BaseSDTrainProcess; epoch_num synced at train start for correct resume.
- In get_caption, when shuffle_tokens is on, keep the first comma-separated segment in place and shuffle the rest; in shuffle_sequence, keep index 0 fixed and permute only 1..seq_len-1.
- Prefer trigger_word over blank when batch.prompt_embeds is None; require trigger_word when cache_text_embeddings is enabled and the batch has no cached embeds (ValueError otherwise); reg batches keep using blank.

PEFT
- Alpha key ($$alpha → .alpha) normalization moved into the common peft block so it runs for all PEFT types (lora, lokr, etc.), avoiding missing_keys for alpha in the state_dict.
Commits
feat(training): timestep sampling, schedulers, networks, loss UI — blank_prompt_probability, differential guidance, LR graph; Gaussian/content-style timestep sampling, warmup for cosine LR; BaseSDTrainProcess timestep mapping, fixed cycle; rank_dropout/module_dropout, alpha for Peft LoRA, SNR flow matching; timestep debug logging, config and UI.
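The Gaussian timestep sampling mentioned in this commit can be illustrated with a tiny standalone helper. This is a minimal sketch, not the repo's implementation: the function name sample_timestep_gaussian and its parameters are assumptions, and it clamps to the schedule bounds rather than performing true truncated-normal resampling.

```python
import random

def sample_timestep_gaussian(num_timesteps: int, mean_frac: float = 0.5,
                             std_frac: float = 0.2) -> int:
    """Draw a training timestep from a Gaussian centered at a fraction of
    the noise schedule, then clamp into [0, num_timesteps - 1].

    Hypothetical sketch: concentrating samples around mean_frac lets a
    curriculum bias training toward a chosen region of the schedule.
    Note that clamping piles boundary mass at 0 and num_timesteps - 1;
    a true truncated normal would resample out-of-range draws instead."""
    t = random.gauss(mean_frac * num_timesteps, std_frac * num_timesteps)
    return min(max(int(t), 0), num_timesteps - 1)
```

Widening std_frac over training is one way such a curriculum could anneal from a focused band back toward uniform coverage.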
feat(optimizers): Adafactor min/max lr, RMS tracking, fixes — min_lr/max_lr, get_avg_learning_rate, RMS tracking; fixes for lr=0 with relative_step, uninitialized state, clamp to max_lr; truncated normal, weight update RMS logging.
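The min_lr/max_lr clamping this commit describes reduces to a one-line bound check. A minimal sketch, with the hypothetical name clamp_lr and the 1e-4 default taken from the max_lr change in this PR:

```python
def clamp_lr(lr: float, min_lr: float = 1e-8, max_lr: float = 1e-4) -> float:
    """Clamp a learning rate into [min_lr, max_lr].

    Illustrative sketch of the streamlined Adafactor clamping: applying the
    bounds unconditionally (no relative_step branch) also repairs the lr=0
    case, since 0 is lifted to min_lr. min_lr's default here is assumed."""
    return max(min_lr, min(lr, max_lr))
```

Applying the same clamp on restart is what lets updated min_lr/max_lr values take effect when resuming from a checkpoint.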
refactor(device): safe_module_to_device and conditional model transfer — safe_module_to_device in toolkit/util/device.py; use in ltx2, qwen_image, wan22 for low_vram-aware handling.
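A low_vram-aware transfer helper like the one this commit names can be sketched as follows. This is an assumption-laden illustration, not the code in toolkit/util/device.py: the real helper's signature and policy may differ.

```python
def safe_module_to_device(module, device, low_vram: bool = False):
    """Move a module to `device`, unless low_vram mode asks to keep large
    modules on CPU. Sketch only: the actual toolkit helper may consider
    quantization state, current placement, and per-model flags."""
    target = "cpu" if (low_vram and str(device) != "cpu") else device
    module.to(target)  # works for any object exposing a .to(device) method
    return module
```

Routing all transfers through one helper is what lets models like ltx2, qwen_image, and wan22 honour low_vram without scattering raw model.to(device) calls.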
feat(Z-Image): sampling model, quantization, stop job API — Separate sampling model, device handling, quantization; unload, LoRA on sampling transformer; accuracy recovery adapter, file cleanup with retries; stop job API and sample viewer fix.
fix(optimizers): adjust max_lr and improve learning rate clamping — Adafactor max_lr 1e-2→1e-4; streamlined clamping between min_lr and max_lr with relative_step.
refactor(z_image): streamline embedding conversion and device management — dtype instead of deprecated torch_dtype; normalize paths; load sampling transformer first, extract helpers; embedding conversion and device handling; debug_zimage_load.
fix(train): free memory before checkpoint save to reduce OOM risk — Free memory before checkpoint save; validate optimizer state load and document save behaviour.
fix(peft): apply alpha key fix for all peft types when loading LoRA — Alpha key fix for all PEFT types on LoRA load.
refactor(dataloader): remove redundant token shuffling logic — Redundant shuffling removed; shuffle_tokens skipped when caching; first segment/token fixed (get_caption + shuffle_sequence); set_epoch_num, clear_cached_embeddings_memory, shuffle at epoch start, logging; respect shuffle_tokens for cached embeds; trigger_word when cache_text_embeddings and no cache.
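The first-segment-fixed shuffle described for get_caption can be sketched on a list of comma-separated segments. The function name shuffle_keep_first is hypothetical; the fixed-index-0 behavior mirrors what the commit describes for both captions and shuffle_sequence.

```python
import random

def shuffle_keep_first(segments: list, rng=None) -> list:
    """Shuffle caption segments while keeping the first one (often the
    trigger word) in place, per the get_caption behavior described above.
    Sketch only; PromptEmbeds.shuffle_sequence applies the same idea to
    positions 1..seq_len-1 of an embedding tensor."""
    rng = rng or random
    rest = segments[1:]
    rng.shuffle(rest)
    return segments[:1] + rest
```

Keeping index 0 fixed preserves the trigger's position, so shuffling adds caption variety without disturbing what the adapter keys on.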
fix(adafactor): apply new min_lr/max_lr on restart after loading checkpoint — Ensures Adafactor applies updated min_lr/max_lr when resuming from checkpoint.
refactor(paths): rename normalize_model_path to normalize_path and strip whitespace — Centralised path normalisation and whitespace stripping in toolkit/paths.py.
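The normalization this commit centralises is small enough to sketch in full. A minimal version under the assumptions stated in the commits (strip whitespace and trailing path separators); the real helper in toolkit/paths.py may do more:

```python
def normalize_path(path: str) -> str:
    """Strip surrounding whitespace and trailing path separators.

    Sketch of the helper named above: paths like 'e:\\...\\snapshots\\hash\\'
    otherwise make Hugging Face transformers derive an empty module name and
    warn 'The module name (originally ) is not a valid Python identifier'."""
    return path.strip().rstrip("/\\")
```

Feeding the normalized value to name_or_path, sampling_name_or_path, and extras_name_or_path keeps all three load paths consistent.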
feat(debug): add memory_debug util and use for CUDA load/unload logging — Adds toolkit/util/debug.py with memory_debug() context manager and debug config; wraps text encoder unload and Z-Image load stages for optional [DEBUG …] CUDA memory lines when logging.debug is enabled.
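A memory_debug() context manager of the kind this commit adds might look like the following. This is a sketch under stated assumptions, not the toolkit's code: the enabled flag stands in for the debug config, and it degrades to a no-op when torch or CUDA is unavailable.

```python
import contextlib

@contextlib.contextmanager
def memory_debug(label: str, enabled: bool = False):
    """Optionally print CUDA memory usage before and after a block,
    mimicking the [DEBUG ...] lines described above. Sketch only."""
    def report(stage: str):
        try:
            import torch
            if torch.cuda.is_available():
                mb = torch.cuda.memory_allocated() / 1024**2
                print(f"[DEBUG {label}] {stage}: {mb:.1f} MiB allocated")
        except ImportError:
            pass  # no torch installed: stay silent
    if enabled:
        report("before")
    try:
        yield
    finally:
        if enabled:
            report("after")
```

Wrapping stages like text encoder unload or Z-Image transformer load in such a manager makes per-stage VRAM deltas visible only when debugging is switched on.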
fix(sampling), feat(lora): sampling VRAM, single LoRA, path and debug wiring — Sampling: load sampling transformer on CPU, guarantee unload to CPU after generate_images, do not force unet to GPU. LoRA: single LoRA for training and sampling via shared parameters; free pretrained LoRA memory after load. Uses normalize_path in Z-Image and wires memory_debug in BaseSDTrainProcess LoRA load and Z-Image stages; train.example.yaml, base_model, network_mixins updates.