feat: stricter layer offloading for flux2-klein#666
feat: stricter layer offloading for flux2-klein #666 — pokemans wants to merge 9 commits into ostris:main from pokemans's branch
Conversation
|
Ported to Chroma, in testing: https://github.com/pokemans/ai-toolkit/tree/chroma-low-vram |
|
Hi @pokemans. Even with 100% CPU offload, it looks like it still gets dequantized on the GPU for encoding the prompt, which is why you had to offload 60% of the transformer as well. The attached diff allows the transformer to be 0% offloaded on 16GB VRAM in qfloat8 (It's all on GPU in my setup with a 7800XT), and runs the text encoder inference on the CPU itself in the special case where it's 100% offloaded. I would still recommend setting This is still rough around the edges and not in a state to merge imo, so please try it out and let me know if this works for you! I am getting 11s/it with gradient accumulation set to 2, using the RoCm fork. diff --git a/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py b/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py
index 86fdafa..2bcb87e 100644
--- a/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py
+++ b/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py
@@ -44,19 +44,11 @@ class Flux2KleinModel(Flux2Model):
self.flux2_klein_te_path,
torch_dtype=dtype,
)
- text_encoder.to(self.device_torch, dtype=dtype)
-
- flush()
-
- if self.model_config.quantize_te:
- self.print_and_status_update("Quantizing Qwen3")
- quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
- freeze(text_encoder)
- flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_text_encoder_percent > 0
+ and not self.te_offload_to_cpu
):
MemoryManager.attach(
text_encoder,
@@ -64,6 +56,17 @@ class Flux2KleinModel(Flux2Model):
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
)
+ if not self.te_offload_to_cpu:
+ text_encoder.to(self.device_torch, dtype=dtype)
+
+ flush()
+
+ if self.model_config.quantize_te:
+ self.print_and_status_update("Quantizing Qwen3")
+ quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
+ freeze(text_encoder)
+ flush()
+
tokenizer = Qwen2Tokenizer.from_pretrained(self.flux2_klein_te_path)
return text_encoder, tokenizer
diff --git a/extensions_built_in/diffusion_models/flux2/flux2_model.py b/extensions_built_in/diffusion_models/flux2/flux2_model.py
index 726d0e4..6932959 100644
--- a/extensions_built_in/diffusion_models/flux2/flux2_model.py
+++ b/extensions_built_in/diffusion_models/flux2/flux2_model.py
@@ -85,6 +85,13 @@ class Flux2Model(BaseModel):
def get_train_scheduler():
return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
+ @property
+ def te_offload_to_cpu(self):
+ return (
+ self.model_config.layer_offloading
+ and self.model_config.layer_offloading_text_encoder_percent >= 1.0
+ )
+
def get_bucket_divisibility(self):
return 16
@@ -101,19 +108,11 @@ class Flux2Model(BaseModel):
torch_dtype=dtype,
)
)
- text_encoder.to(self.device_torch, dtype=dtype)
-
- flush()
-
- if self.model_config.quantize_te:
- self.print_and_status_update("Quantizing Mistral")
- quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
- freeze(text_encoder)
- flush()
if (
self.model_config.layer_offloading
and self.model_config.layer_offloading_text_encoder_percent > 0
+ and not self.te_offload_to_cpu
):
MemoryManager.attach(
text_encoder,
@@ -121,6 +120,17 @@ class Flux2Model(BaseModel):
offload_percent=self.model_config.layer_offloading_text_encoder_percent,
)
+ if not self.te_offload_to_cpu:
+ text_encoder.to(self.device_torch, dtype=dtype)
+
+ flush()
+
+ if self.model_config.quantize_te:
+ self.print_and_status_update("Quantizing Mistral")
+ quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
+ freeze(text_encoder)
+ flush()
+
tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
return text_encoder, tokenizer
@@ -155,7 +165,8 @@ class Flux2Model(BaseModel):
transformer.load_state_dict(transformer_state_dict, assign=True)
- transformer.to(self.quantize_device, dtype=dtype)
+ if not self.model_config.low_vram:
+ transformer.to(self.quantize_device, dtype=dtype)
if self.model_config.quantize:
# patch the state dict method
@@ -163,9 +174,12 @@ class Flux2Model(BaseModel):
self.print_and_status_update("Quantizing Transformer")
quantize_model(self, transformer)
flush()
+
+ if self.model_config.layer_offloading:
+ self.print_and_status_update("Moving transformer to CPU")
+ transformer.to("cpu")
else:
transformer.to(self.device_torch, dtype=dtype)
- flush()
if (
self.model_config.layer_offloading
@@ -177,9 +191,7 @@ class Flux2Model(BaseModel):
offload_percent=self.model_config.layer_offloading_transformer_percent,
)
- if self.model_config.low_vram:
- self.print_and_status_update("Moving transformer to CPU")
- transformer.to("cpu")
+ flush()
text_encoder, tokenizer = self.load_te()
@@ -233,8 +245,8 @@ class Flux2Model(BaseModel):
tokenizer = [pipe.tokenizer]
flush()
- # just to make sure everything is on the right device and dtype
- text_encoder[0].to(self.device_torch)
+ if not self.te_offload_to_cpu:
+ text_encoder[0].to(self.device_torch)
text_encoder[0].requires_grad_(False)
text_encoder[0].eval()
pipe.transformer = pipe.transformer.to(self.device_torch)
@@ -444,15 +456,28 @@ class Flux2Model(BaseModel):
return noise_pred
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
- if self.pipeline.text_encoder.device != self.device_torch:
- self.pipeline.text_encoder.to(self.device_torch)
+ if self.te_offload_to_cpu:
+ encode_device = self.pipeline.text_encoder.device
+ else:
+ if self.pipeline.text_encoder.device != self.device_torch:
+ self.pipeline.text_encoder.to(self.device_torch)
+ encode_device = self.device_torch
prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
- prompt, device=self.device_torch
+ prompt, device=encode_device
)
- pe = PromptEmbeds(prompt_embeds)
+ pe = PromptEmbeds(prompt_embeds.to(self.device_torch))
return pe
+ def set_device_state(self, state):
+ if self.te_offload_to_cpu:
+ if isinstance(state.get('text_encoder'), list):
+ for te_state in state['text_encoder']:
+ te_state['device'] = 'cpu'
+ elif isinstance(state.get('text_encoder'), dict):
+ state['text_encoder']['device'] = 'cpu'
+ super().set_device_state(state)
+
def get_model_has_grad(self):
return False
|
Summary
Changes to flux2-klein and shared flux2 code that more aggressively offload models to CPU memory when the low_vram and layer_offloading job config options are set. I needed these changes to train a klein-9b (fp8 quantized) LoRA on 16 GB of VRAM on Windows. I haven't been able to test on other OS/hardware, but the changes are small and should only affect jobs that are using low_vram and/or layer_offloading.
Testing
Windows 11 + 5070 Ti: a Klein 9B LoRA was able to be trained at 60% transformer offload and 100% TE offload without OOM; cache latents and blank-prompt preservation were the only other options enabled beyond the defaults.
I will try to do more testing as I can but don't have access to much hardware. The code changes improved things a lot for me so I wanted to share.