
feat: stricter layer offloading for flux2-klein #666

Open
pokemans wants to merge 9 commits into ostris:main from pokemans:flux2k-lowvram

Conversation


pokemans commented Jan 23, 2026

Summary

Changes to flux2-klein and the shared flux2 code that more aggressively offload models to CPU memory when the low_vram and layer_offloading job config options are set. I needed these changes to train a klein-9b (fp8-quantized) LoRA on 16 GB of VRAM on Windows. I haven't been able to test on other OS/hardware combinations, but the changes are small and should only affect jobs that use low_vram and/or layer_offloading.
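For context, the relevant job config options look roughly like this (a sketch only; the key names are inferred from the model_config attributes the code reads, and their exact nesting in an ai-toolkit job YAML may differ):

```yaml
model:
  name_or_path: "path/to/flux2-klein-9b"      # hypothetical path
  quantize: true                              # fp8 transformer
  low_vram: true
  layer_offloading: true
  layer_offloading_transformer_percent: 0.6   # 60% of transformer blocks offloaded to CPU
  layer_offloading_text_encoder_percent: 1.0  # 100% of the text encoder offloaded
```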

Testing

On Windows 11 with a 5070 Ti, a klein 9B LoRA trained at 60% transformer offload and 100% TE offload without OOM; cached latents and blank prompt preservation were the only other options enabled beyond the defaults.

I will try to do more testing as I can, but I don't have access to much hardware. The code changes improved things a lot for me, so I wanted to share.

pokemans changed the title from "[wip] feat: stricter layer offloading for flux2-klein" to "feat: stricter layer offloading for flux2-klein" on Jan 23, 2026
pokemans (Author) commented

Ported to Chroma, in testing: https://github.com/pokemans/ai-toolkit/tree/chroma-low-vram
The fp8 base is working on 16 GB of VRAM at 10% transformer offload.


EQuioMaX commented Feb 19, 2026

Hi @pokemans. Even with 100% CPU offload, it looks like the text encoder still gets dequantized on the GPU to encode the prompt, which is why you had to offload 60% of the transformer as well.

The attached diff allows the transformer to be 0% offloaded on 16 GB of VRAM in qfloat8 (it's all on the GPU in my setup with a 7800 XT), and runs text encoder inference on the CPU itself in the special case where it's 100% offloaded. I would still recommend setting cache_text_embeddings to true to reduce the dependency on slow CPU inference (which also removes the need for the custom te_offload_to_cpu property).
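The cache_text_embeddings suggestion amounts to memoizing the encoder output per prompt, so the slow CPU text encoder runs only once per unique caption. A minimal sketch of that idea (the encode function here is a hypothetical stand-in, not the real TE forward pass):

```python
from functools import lru_cache

def encode_prompt_cpu(prompt: str) -> list[float]:
    # Hypothetical stand-in for the slow CPU text-encoder forward pass.
    return [float(ord(c)) for c in prompt]

@lru_cache(maxsize=None)
def get_cached_embeds(prompt: str) -> tuple[float, ...]:
    # Each unique prompt is encoded exactly once; repeated captions in
    # later epochs hit the cache instead of re-running CPU inference.
    return tuple(encode_prompt_cpu(prompt))
```

In training, repeated captions then cost a dictionary lookup instead of a full CPU forward pass.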

This is still rough around the edges and not in a state to merge imo, so please try it out and let me know if it works for you! I am getting ~11 s/it with gradient accumulation set to 2, using the ROCm fork.

lora_v1.2:   7%|██▏                             | 333/5000 [1:06:59<14:11:18, 10.94s/it, lr: 7.6e-05 loss: 2.870e-01]
diff --git a/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py b/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py
index 86fdafa..2bcb87e 100644
--- a/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py
+++ b/extensions_built_in/diffusion_models/flux2/flux2_klein_model.py
@@ -44,19 +44,11 @@ class Flux2KleinModel(Flux2Model):
             self.flux2_klein_te_path,
             torch_dtype=dtype,
         )
-        text_encoder.to(self.device_torch, dtype=dtype)
-
-        flush()
-
-        if self.model_config.quantize_te:
-            self.print_and_status_update("Quantizing Qwen3")
-            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
-            freeze(text_encoder)
-            flush()
 
         if (
             self.model_config.layer_offloading
             and self.model_config.layer_offloading_text_encoder_percent > 0
+            and not self.te_offload_to_cpu
         ):
             MemoryManager.attach(
                 text_encoder,
@@ -64,6 +56,17 @@ class Flux2KleinModel(Flux2Model):
                 offload_percent=self.model_config.layer_offloading_text_encoder_percent,
             )
 
+        if not self.te_offload_to_cpu:
+            text_encoder.to(self.device_torch, dtype=dtype)
+
+        flush()
+
+        if self.model_config.quantize_te:
+            self.print_and_status_update("Quantizing Qwen3")
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
+            freeze(text_encoder)
+            flush()
+
         tokenizer = Qwen2Tokenizer.from_pretrained(self.flux2_klein_te_path)
         return text_encoder, tokenizer
 
diff --git a/extensions_built_in/diffusion_models/flux2/flux2_model.py b/extensions_built_in/diffusion_models/flux2/flux2_model.py
index 726d0e4..6932959 100644
--- a/extensions_built_in/diffusion_models/flux2/flux2_model.py
+++ b/extensions_built_in/diffusion_models/flux2/flux2_model.py
@@ -85,6 +85,13 @@ class Flux2Model(BaseModel):
     def get_train_scheduler():
         return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
 
+    @property
+    def te_offload_to_cpu(self):
+        return (
+            self.model_config.layer_offloading
+            and self.model_config.layer_offloading_text_encoder_percent >= 1.0
+        )
+
     def get_bucket_divisibility(self):
         return 16
 
@@ -101,19 +108,11 @@ class Flux2Model(BaseModel):
                 torch_dtype=dtype,
             )
         )
-        text_encoder.to(self.device_torch, dtype=dtype)
-
-        flush()
-
-        if self.model_config.quantize_te:
-            self.print_and_status_update("Quantizing Mistral")
-            quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
-            freeze(text_encoder)
-            flush()
 
         if (
             self.model_config.layer_offloading
             and self.model_config.layer_offloading_text_encoder_percent > 0
+            and not self.te_offload_to_cpu
         ):
             MemoryManager.attach(
                 text_encoder,
@@ -121,6 +120,17 @@ class Flux2Model(BaseModel):
                 offload_percent=self.model_config.layer_offloading_text_encoder_percent,
             )
 
+        if not self.te_offload_to_cpu:
+            text_encoder.to(self.device_torch, dtype=dtype)
+
+        flush()
+
+        if self.model_config.quantize_te:
+            self.print_and_status_update("Quantizing Mistral")
+            quantize(text_encoder, weights=get_qtype(self.model_config.qtype_te))
+            freeze(text_encoder)
+            flush()
+
         tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
         return text_encoder, tokenizer
 
@@ -155,7 +165,8 @@ class Flux2Model(BaseModel):
 
         transformer.load_state_dict(transformer_state_dict, assign=True)
 
-        transformer.to(self.quantize_device, dtype=dtype)
+        if not self.model_config.low_vram:
+            transformer.to(self.quantize_device, dtype=dtype)
 
         if self.model_config.quantize:
             # patch the state dict method
@@ -163,9 +174,12 @@ class Flux2Model(BaseModel):
             self.print_and_status_update("Quantizing Transformer")
             quantize_model(self, transformer)
             flush()
+
+        if self.model_config.layer_offloading:
+            self.print_and_status_update("Moving transformer to CPU")
+            transformer.to("cpu")
         else:
             transformer.to(self.device_torch, dtype=dtype)
-        flush()
 
         if (
             self.model_config.layer_offloading
@@ -177,9 +191,7 @@ class Flux2Model(BaseModel):
                 offload_percent=self.model_config.layer_offloading_transformer_percent,
             )
 
-        if self.model_config.low_vram:
-            self.print_and_status_update("Moving transformer to CPU")
-            transformer.to("cpu")
+        flush()
 
         text_encoder, tokenizer = self.load_te()
 
@@ -233,8 +245,8 @@ class Flux2Model(BaseModel):
         tokenizer = [pipe.tokenizer]
 
         flush()
-        # just to make sure everything is on the right device and dtype
-        text_encoder[0].to(self.device_torch)
+        if not self.te_offload_to_cpu:
+            text_encoder[0].to(self.device_torch)
         text_encoder[0].requires_grad_(False)
         text_encoder[0].eval()
         pipe.transformer = pipe.transformer.to(self.device_torch)
@@ -444,15 +456,28 @@ class Flux2Model(BaseModel):
         return noise_pred
 
     def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
-        if self.pipeline.text_encoder.device != self.device_torch:
-            self.pipeline.text_encoder.to(self.device_torch)
+        if self.te_offload_to_cpu:
+            encode_device = self.pipeline.text_encoder.device
+        else:
+            if self.pipeline.text_encoder.device != self.device_torch:
+                self.pipeline.text_encoder.to(self.device_torch)
+            encode_device = self.device_torch
 
         prompt_embeds, prompt_embeds_mask = self.pipeline.encode_prompt(
-            prompt, device=self.device_torch
+            prompt, device=encode_device
         )
-        pe = PromptEmbeds(prompt_embeds)
+        pe = PromptEmbeds(prompt_embeds.to(self.device_torch))
         return pe
 
+    def set_device_state(self, state):
+        if self.te_offload_to_cpu:
+            if isinstance(state.get('text_encoder'), list):
+                for te_state in state['text_encoder']:
+                    te_state['device'] = 'cpu'
+            elif isinstance(state.get('text_encoder'), dict):
+                state['text_encoder']['device'] = 'cpu'
+        super().set_device_state(state)
+
     def get_model_has_grad(self):
         return False
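Distilled from the diff above, the text-encoder device selection reduces to the following pure-Python sketch (the config class is a stub mirroring the model_config fields the patch reads; the real code operates on torch modules):

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    # Stub mirroring the fields the patch reads from model_config.
    layer_offloading: bool = False
    layer_offloading_text_encoder_percent: float = 0.0

def te_offload_to_cpu(cfg: ModelConfig) -> bool:
    # The property added in the patch: the TE stays on the CPU only when
    # layer offloading is enabled and the TE offload percent is 100%.
    return (
        cfg.layer_offloading
        and cfg.layer_offloading_text_encoder_percent >= 1.0
    )

def encode_device(cfg: ModelConfig, gpu: str = "cuda:0") -> str:
    # get_prompt_embeds: encode on the CPU in the fully-offloaded case,
    # otherwise on the GPU; the resulting embeddings are moved to the
    # GPU either way before being used by the transformer.
    return "cpu" if te_offload_to_cpu(cfg) else gpu
```

So a 100% TE offload encodes on the CPU, while a 60% offload still moves the TE to the GPU for encoding.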
 
