
Training, Optimizers, Device Handling & Z-Image Improvements#713

Draft
daanforever wants to merge 17 commits into ostris:main from daanforever:pr-main

Conversation

@daanforever daanforever commented Feb 13, 2026

Training, Optimizers, Device Handling & Z-Image Improvements

Overview

This PR bundles improvements across the training pipeline, optimizers, device handling, and Z-Image inference: timestep sampling and schedulers, Adafactor extensions, low-VRAM device handling, and a separate sampling model with quantization and UI fixes. It also adds fixes and refactors for Adafactor LR clamping, Z-Image loading/dtype, train checkpoint memory and optimizer state, PEFT LoRA loading, and text-embedding cache behaviour (shuffle_tokens, per-epoch shuffle, trigger_word).

Training & Schedulers

  • blank_prompt_probability — Control how often Blank Prompt Preservation (BPP) runs to reduce step time while limiting degradation.
  • Differential guidance — Separate metric and loss-graph category; LR graph with metric filtering.
  • Timestep sampling — Gaussian/content-style distribution, configurable gaussian_std curriculum, timestep debug logging and mapping fixes in BaseSDTrainProcess.
  • Warmup for cosine LR — Warmup support for cosine schedulers; fixes for SequentialLR.step() and T_0/T_max duplication with warmup; example config and guide (WARMUP_SCHEDULER_GUIDE.md).
  • Fixed cycle training — Config and UI for fixed cycle; sigma parameter fix for weight calculation.
  • Networks & config — rank_dropout/module_dropout, alpha for PEFT LoRA, SNR compatibility for flow matching; validation for image/video scaling.
  • UI — SimpleJob controls, docs, and types for new training options; loss graph and LR visualization.
  • Train robustness — Free memory before checkpoint save to reduce OOM risk; validate optimizer state load and document save behaviour.
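
The warmup-for-cosine wiring above can be sketched with stock PyTorch schedulers. This is a minimal illustration of the pattern, not the PR's actual code (names like `warmup_steps`/`total_steps` are illustrative); note that `T_max` counts only post-warmup steps, so the warmup span is not duplicated into the cosine cycle — the kind of `T_0`/`T_max` duplication the PR says it fixes:

```python
# Sketch: linear warmup followed by cosine annealing via SequentialLR.
import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

warmup_steps = 10
total_steps = 100

warmup = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)
# T_max covers only the steps after warmup; including the warmup span here
# would stretch the cosine cycle past the end of training.
cosine = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])

lrs = []
for _ in range(total_steps):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])
```

The recorded `lrs` rise over the first `warmup_steps` steps to the base LR and then decay along the cosine curve.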

Optimizers (Adafactor)

  • min_lr/max_lr and get_avg_learning_rate for finer control and logging.
  • RMS tracking — Methods and weight-update RMS logging.
  • Fixes — lr=0 with relative_step, uninitialized state, clamping to max_lr (including max_lr 1e-2→1e-4 and streamlined clamping logic); truncated normal sampling and spelling fix.
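
The streamlined clamping can be sketched in isolation. This is a hypothetical simplification, not ai-toolkit's Adafactor code: with `relative_step`, the optimizer derives a per-step LR, which is then unconditionally bounded by `min_lr`/`max_lr` so it can neither vanish nor explode:

```python
# Sketch of relative-step LR with unconditional min_lr/max_lr clamping.
import math

def relative_step_lr(step: int, rms: float, min_lr: float = 1e-6, max_lr: float = 1e-4) -> float:
    # Standard Adafactor relative-step schedule: min(1e-2, 1/sqrt(step)),
    # scaled by the parameter RMS (a simplification for illustration).
    rel = min(1e-2, 1.0 / math.sqrt(step))
    lr = rel * max(1e-3, rms)
    # Streamlined clamp: always respect min_lr/max_lr, with no separate
    # conditional branch on relative_step.
    return min(max(lr, min_lr), max_lr)
```

With the new default `max_lr=1e-4`, the early large relative steps get capped instead of overshooting.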

Device Handling

  • safe_module_to_device in toolkit/util/device.py for consistent, low-VRAM-aware transfers.
  • Conditional model transfer in ltx2, qwen_image, and wan22 based on low_vram (and related) flags instead of raw model.to(device).
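
A minimal sketch of what a low-VRAM-aware transfer helper like `safe_module_to_device` might look like; the real signature and policy in `toolkit/util/device.py` may differ:

```python
# Sketch: skip CUDA transfers when low_vram is set, instead of raw model.to(device).
import torch

def safe_module_to_device(module: torch.nn.Module, device: torch.device,
                          low_vram: bool = False) -> torch.nn.Module:
    """Move a module to `device`, but keep it on CPU when low_vram is set
    and the target is a CUDA device (layers offloaded on demand instead)."""
    if low_vram and device.type == "cuda":
        return module  # leave where it is; caller handles on-demand offload
    return module.to(device)
```

Call sites in ltx2, qwen_image, and wan22 then route through one helper rather than each deciding when `model.to(device)` is safe.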

Z-Image & Inference

  • Separate sampling model — Own transformer with device handling (e.g. CPU when not generating) and optional quantization.
  • Sampling model unload and LoRA on sampling transformer during generation.
  • Quantization — Custom attribute for accuracy recovery adapter; assignment fix in quantize_model.
  • Robustness — File cleanup with retries and logging; ZImageModel device transfer fixes.
  • UI — Stop job API correctly marks job stopped and handles dead process; sample image viewer keyboard navigation fix.
  • Loading & dtype — Load sampling transformer first and extract load helpers; normalize model paths (HF identifier warning); use dtype instead of deprecated torch_dtype for transformers; streamline embedding conversion and device management; optional debug_zimage_load to trace safetensors load/mmap.

Text Embedding Cache & Shuffle

  • shuffle_tokens + cache_text_embeddings — Apply caption shuffle only when not caching; skip shuffling cached prompt embeds when shuffle_tokens is false in set_epoch_num and load_prompt_embedding; respect shuffle_tokens for cached embeddings so cache paths stay stable.
  • Per-epoch shuffle — When cache_text_embeddings is true, shuffle token order along the sequence at the start of each new epoch; PromptEmbeds.shuffle_sequence(), dataset set_epoch_num / clear_cached_embeddings_memory, wired at epoch boundary in BaseSDTrainProcess; epoch_num synced at train start for correct resume.
  • First segment/token fixed — In get_caption, when shuffle_tokens is on, keep the first comma-separated segment in place and shuffle the rest; in shuffle_sequence, keep index 0 fixed and permute only 1..seq_len-1.
  • Memory & UX — Clear cached embeddings when not using cache; retain prompt embeddings when cached; logging when cached tokens are shuffled each epoch.
  • SD trainer — Prefer trigger_word over blank when batch.prompt_embeds is None; require trigger_word when cache_text_embeddings is enabled and batch has no cached embeds (ValueError otherwise); reg batches keep using blank.
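
The per-epoch shuffle with a fixed first position can be sketched as follows. This mirrors what `PromptEmbeds.shuffle_sequence()` is described as doing, but it is an illustrative reimplementation, not the PR's code; tensor shapes are assumed to be `(batch, seq_len, dim)`:

```python
# Sketch: permute positions 1..seq_len-1 along the sequence dimension,
# leaving position 0 (e.g. the first/BOS token) in place.
import torch

def shuffle_sequence(embeds: torch.Tensor, generator=None) -> torch.Tensor:
    seq_len = embeds.shape[1]
    perm = torch.cat([
        torch.zeros(1, dtype=torch.long),          # index 0 stays fixed
        1 + torch.randperm(seq_len - 1, generator=generator),
    ])
    return embeds[:, perm, :]
```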

PEFT

  • LoRA loading — Apply alpha key fix for all PEFT types when loading LoRA.

Commits

  1. feat(training): timestep sampling, schedulers, networks, loss UI — blank_prompt_probability, differential guidance, LR graph; Gaussian/content-style timestep sampling, warmup for cosine LR; BaseSDTrainProcess timestep mapping, fixed cycle; rank_dropout/module_dropout, alpha for Peft LoRA, SNR flow matching; timestep debug logging, config and UI.

  2. feat(optimizers): Adafactor min/max lr, RMS tracking, fixes — min_lr/max_lr, get_avg_learning_rate, RMS tracking; fixes for lr=0 with relative_step, uninitialized state, clamp to max_lr; truncated normal, weight update RMS logging.

  3. refactor(device): safe_module_to_device and conditional model transfer — safe_module_to_device in toolkit/util/device.py; use in ltx2, qwen_image, wan22 for low_vram-aware handling.

  4. feat(Z-Image): sampling model, quantization, stop job API — Separate sampling model, device handling, quantization; unload, LoRA on sampling transformer; accuracy recovery adapter, file cleanup with retries; stop job API and sample viewer fix.

  5. fix(optimizers): adjust max_lr and improve learning rate clamping — Adafactor max_lr 1e-2→1e-4; streamlined clamping between min_lr and max_lr with relative_step.

  6. refactor(z_image): streamline embedding conversion and device management — dtype instead of deprecated torch_dtype; normalize paths; load sampling transformer first, extract helpers; embedding conversion and device handling; debug_zimage_load.

  7. fix(train): free memory before checkpoint save to reduce OOM risk — Free memory before checkpoint save; validate optimizer state load and document save behaviour.

  8. fix(peft): apply alpha key fix for all peft types when loading LoRA — Alpha key fix for all PEFT types on LoRA load.

  9. refactor(dataloader): remove redundant token shuffling logic — Redundant shuffling removed; shuffle_tokens skipped when caching; first segment/token fixed (get_caption + shuffle_sequence); set_epoch_num, clear_cached_embeddings_memory, shuffle at epoch start, logging; respect shuffle_tokens for cached embeds; trigger_word when cache_text_embeddings and no cache.

  10. fix(adafactor): apply new min_lr/max_lr on restart after loading checkpoint — Ensures Adafactor applies updated min_lr/max_lr when resuming from checkpoint.

  11. refactor(paths): rename normalize_model_path to normalize_path and strip whitespace — Centralised path normalisation and whitespace stripping in toolkit/paths.py.

  12. feat(debug): add memory_debug util and use for CUDA load/unload logging — Adds toolkit/util/debug.py with memory_debug() context manager and debug config; wraps text encoder unload and Z-Image load stages for optional [DEBUG …] CUDA memory lines when logging.debug is enabled.

  13. fix(sampling), feat(lora): sampling VRAM, single LoRA, path and debug wiring — Sampling: load sampling transformer on CPU, guarantee unload to CPU after generate_images, do not force unet to GPU. LoRA: single LoRA for training and sampling via shared parameters; free pretrained LoRA memory after load. Uses normalize_path in Z-Image and wires memory_debug in BaseSDTrainProcess LoRA load and Z-Image stages; train.example.yaml, base_model, network_mixins updates.

hinablue and others added 11 commits February 13, 2026 18:12
feat(training): timestep sampling, schedulers, networks, loss UI

- blank_prompt_probability, differential guidance metric, LR graph
- Gaussian/content-style timestep sampling, warmup for cosine LR
- BaseSDTrainProcess timestep mapping, fixed cycle training
- rank_dropout/module_dropout, alpha for Peft LoRA, SNR flow matching
- Timestep debug logging, config and UI updates

feat(optimizers): Adafactor min/max lr, RMS tracking, fixes

- min_lr/max_lr, get_avg_learning_rate, RMS tracking methods
- Fix lr=0 with relative_step, uninitialized state, clamp to max_lr
- Truncated normal sampling, weight update RMS logging

refactor(device): safe_module_to_device and conditional model transfer

- Add safe_module_to_device in toolkit/util/device.py
- Use in ltx2, qwen_image, wan22 for low_vram-aware device handling

feat(Z-Image): sampling model, quantization, stop job API

- Separate sampling model with device handling, quantization support
- Sampling model unload, LoRA on sampling transformer
- Accuracy recovery adapter in quantize, file cleanup with retries
- Stop job API marks job stopped and handles dead process
- Sample image viewer keyboard navigation fix

fix(optimizers): adjust max_lr and improve learning rate clamping

- Change max_lr from 1e-2 to 1e-4 for Adafactor optimizer
- Update comment to clarify learning rate clamping between min_lr and max_lr when using relative_step

fix(optimizers): streamline learning rate clamping in Adafactor

- Simplified the learning rate clamping logic to ensure it consistently respects min_lr and max_lr boundaries.
- Removed the conditional check for relative_step, enhancing code clarity.
refactor(z_image): streamline embedding conversion and device management

- Simplified the embedding conversion process in ZImageModel for better clarity and efficiency.
- Improved device management for the sampling transformer, ensuring proper loading and unloading to optimize resource usage.
- Enhanced code readability by removing redundant checks and restructuring the flow.

fix(z_image): use dtype instead of deprecated torch_dtype in from_pretrained calls

Revert "fix(z_image): use dtype instead of deprecated torch_dtype in from_pretrained calls"

This reverts commit 7acdd14.

refactor(zimage): load sampling transformer first and extract load helpers

- Add _load_sampling_transformer() and _load_transformer(model_path)
- Call sampling transformer load before main transformer to control peak VRAM
- Shrink load_model() by delegating to helpers; preserve base_model_path resolution
- Restore and add comments and docstrings in load paths

feat(zimage): add debug_zimage_load to trace safetensors load/mmap

- Add model.debug_zimage_load in ModelConfig to enable load debugging
- When enabled, patch safetensors.torch.load_file and safetensors.safe_open
  to log path, size, and duration via print_and_status_update
- Log markers before loading sampling vs main transformer to correlate
  slow sampling load with specific file opens
- Patches applied once per process to avoid double-wrap
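
The wrapping idea can be sketched generically. The PR patches `safetensors.torch.load_file` and `safetensors.safe_open` specifically and logs via `print_and_status_update`; this hypothetical decorator only shows the once-per-function path/size/duration logging:

```python
# Sketch: wrap a file-loading function to log path, size, and duration.
import functools
import os
import time

def logged_load(load_fn):
    @functools.wraps(load_fn)
    def wrapper(path, *args, **kwargs):
        start = time.monotonic()
        result = load_fn(path, *args, **kwargs)
        size = os.path.getsize(path)
        print(f"[DEBUG] {load_fn.__name__} {path} ({size} bytes) "
              f"in {time.monotonic() - start:.3f}s")
        return result
    return wrapper
```

Assigning the wrapped function back onto the module (guarded by a module-level flag) gives the once-per-process patch the commit describes.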

fix(zimage): normalize model paths to avoid HF identifier warning

Add normalize_model_path() in toolkit/paths.py to strip trailing path
separators. Use it in Z-Image for name_or_path, sampling_name_or_path,
and extras_name_or_path so paths like "e:\...\snapshots\hash\" no longer
trigger "The module name (originally ) is not a valid Python identifier"
from Hugging Face transformers.
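
A sketch of the helper's likely behaviour (the exact implementation in toolkit/paths.py may differ, and whitespace stripping was added in the later rename to `normalize_path`): drop surrounding whitespace and trailing separators so the final path component remains usable as a module name:

```python
# Sketch: strip whitespace and trailing path separators from a model path.
def normalize_path(path: str) -> str:
    return path.strip().rstrip("/\\")
```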

fix(zimage): use dtype instead of deprecated torch_dtype for transformers

Pass dtype=dtype to AutoTokenizer.from_pretrained and Qwen3ForCausalLM.from_pretrained to remove the "torch_dtype is deprecated! Use dtype instead!" warnings when loading the text encoder.
fix(train): free memory before checkpoint save to reduce OOM risk

Call optimizer.zero_grad(set_to_none=True), torch.cuda.synchronize(),
and flush() before self.save() in BaseSDTrainProcess so gradients and
CUDA cache are released before building state_dict and hashes.

fix(train): validate optimizer state load and document save behavior

- Add param count check before loading optimizer.pt; skip load and warn
  when current and saved param counts differ to avoid wrong state mapping.
- Comment that optimizer is always saved via unwrap to match pre-prepare
  load target; drop unused exception variable in inner except.
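
The param-count guard can be sketched as below; this is an illustrative version, and the actual check in the PR may compare different fields of `optimizer.pt`:

```python
# Sketch: skip loading saved optimizer state when param counts differ,
# to avoid mapping state onto the wrong parameters.
import torch

def safe_load_optimizer_state(optimizer: torch.optim.Optimizer, state_dict: dict) -> bool:
    current = sum(len(g["params"]) for g in optimizer.param_groups)
    saved = sum(len(g["params"]) for g in state_dict["param_groups"])
    if current != saved:
        print(f"WARNING: optimizer param count mismatch "
              f"({saved} saved vs {current} current); skipping state load")
        return False
    optimizer.load_state_dict(state_dict)
    return True
```
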
fix(peft): apply alpha key fix for all peft types when loading LoRA

Move $$alpha → .alpha key normalization into the common peft block so it
runs for all peft types (lora, lokr, etc.). Remove the duplicate alpha fix
from the lokr-only block to avoid missing_keys for alpha in state_dict.

refactor(dataloader): remove redundant token shuffling logic

- Eliminated unnecessary shuffling of tokens, as it is already handled in the subsequent code.

fix(dataloader): skip shuffle_tokens when caching text embeddings

Apply shuffle_tokens only when not cache_text_embeddings, so cached
embeddings are saved from the original caption order. Aligns with
existing caption_dropout and token_dropout behavior and keeps cache
paths stable.

feat(shuffle): keep first segment/token fixed when shuffling captions and embeddings

- get_caption: when shuffle_tokens is on, keep first comma-separated segment in place and shuffle only the rest
- shuffle_sequence: keep index 0 fixed and permute only positions 1..seq_len-1

feat(train): enhance text embedding caching logic and memory management

- Introduced a mechanism to clear cached text embeddings based on the dataset configuration, improving memory efficiency during training.
- Updated the AiToolkitDataset to retain prompt embeddings if they are cached, ensuring consistency across epochs.
- Added logging for shuffling cached text embedding tokens at the start of each epoch, enhancing visibility into the training process.

feat(train): log shuffling of cached text embedding tokens per epoch

Added a print statement to log when cached text embedding tokens are shuffled at the start of each epoch, enhancing visibility into the training process.

feat(train): shuffle cached text embedding tokens every epoch

When cache_text_embeddings is true, shuffle token order along the
sequence dimension at the start of each new epoch (after first full
pass). Add PromptEmbeds.shuffle_sequence(), dataset set_epoch_num/
clear_cached_embeddings_memory, wire epoch boundary in BaseSDTrainProcess,
and sync epoch_num to datasets at train start for correct behavior
on resume from checkpoint.

fix(sd_trainer): use trigger_word when cached text embeds are missing

- Prefer trigger over blank when batch.prompt_embeds is None (no cache on disk)
- Require trigger_word when cache_text_embeddings is enabled and batch has no cached embeds; raise ValueError otherwise
- Keep reg batches using blank (no trigger)
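
The fallback rules above can be sketched as a single resolution function; names and the string-vs-embedding return type are illustrative, not the trainer's actual API:

```python
# Sketch: pick the prompt for a batch with no cached embeds.
def resolve_prompt(prompt_embeds, trigger_word, cache_text_embeddings, is_reg_batch):
    if prompt_embeds is not None:
        return prompt_embeds              # cached embeds win
    if is_reg_batch:
        return ""                         # reg batches keep using the blank prompt
    if trigger_word:
        return trigger_word               # prefer trigger over blank
    if cache_text_embeddings:
        raise ValueError(
            "trigger_word is required when cache_text_embeddings is enabled "
            "and a batch has no cached embeddings"
        )
    return ""
```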

fix(dataloader): respect shuffle_tokens for cached text embeddings

Skip shuffling cached prompt embeds when shuffle_tokens is false in
set_epoch_num and load_prompt_embedding.
@daanforever daanforever marked this pull request as ready for review February 17, 2026 09:26
feat(debug): add memory_debug util and use for CUDA load/unload logging

- Add toolkit/util/debug.py with set_debug_config() and memory_debug() context manager
- Register logging config in BaseSDTrainProcess
- Wrap text encoder unload and Z-Image load stages in memory_debug for optional [DEBUG ...] CUDA lines
- Add debug flag and CUDA memory log on text encoder unload
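
A hedged sketch of what a `memory_debug()` context manager like the one in toolkit/util/debug.py might look like; the real helper is gated by `set_debug_config()` and the logging config rather than a plain `enabled` flag:

```python
# Sketch: log the CUDA allocation delta around a named stage.
import contextlib
import torch

@contextlib.contextmanager
def memory_debug(stage: str, enabled: bool = True):
    if not enabled or not torch.cuda.is_available():
        yield
        return
    torch.cuda.synchronize()
    before = torch.cuda.memory_allocated()
    try:
        yield
    finally:
        torch.cuda.synchronize()
        after = torch.cuda.memory_allocated()
        print(f"[DEBUG {stage}] CUDA allocated: "
              f"{before / 1e6:.1f} MB -> {after / 1e6:.1f} MB")
```

Wrapping a load or unload stage (`with memory_debug("unload text encoder"): ...`) then emits one `[DEBUG …]` line per stage when debugging is on, and is a no-op otherwise.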
fix(sampling), feat(lora): sampling VRAM, single LoRA, path and debug wiring

- fix(sampling): load sampling transformer on CPU, guarantee unload to CPU after generate_images, do not force unet to GPU
- feat(lora): use single LoRA for training and sampling via shared parameters; free pretrained LoRA memory after load
- Use normalize_path in Z-Image name_or_path; wire memory_debug in BaseSDTrainProcess LoRA load and Z-Image stages
- train.example.yaml, base_model, network_mixins updates
@daanforever daanforever marked this pull request as draft March 6, 2026 16:20
