26 changes: 22 additions & 4 deletions modules/module/LoRAModule.py
@@ -290,12 +290,14 @@ class LoRAModule(PeftBase):
rank: int
alpha: torch.Tensor
dropout: Dropout
use_stiefel: bool

# Note there's a few times in this class where we assert the existence of
# optional members. This is because these members might not exist at
# construction, but definitely exist by the time those methods are called.

def __init__(self, prefix: str, orig_module: nn.Module | None, rank: int, alpha: float):
def __init__(self, prefix: str, orig_module: nn.Module | None, rank: int, alpha: float, use_stiefel: bool = False):
self.use_stiefel = use_stiefel
super().__init__(prefix, orig_module)

self.rank = rank
@@ -312,8 +314,18 @@ def __init__(self, prefix: str, orig_module: nn.Module | None, rank: int, alpha:
def initialize_weights(self):
self._initialized = True
self.lora_down, self.lora_up = self.create_layer()
nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_up.weight)

if self.use_stiefel:
# Stiefel initialization: the paper keeps B (lora_up) on the Stiefel manifold (orthonormal).
# A (lora_down) is zeroed so the initial delta B @ A is zero and the adapter starts with no effect on the base weights.
nn.init.orthogonal_(self.lora_up.weight)
nn.init.zeros_(self.lora_down.weight)
self.lora_down.weight._is_lora_A = True
self.lora_up.weight._is_lora_B = True
else:
nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_up.weight)

def check_initialized(self):
super().check_initialized()
@@ -519,6 +531,8 @@ def initialize_weights(self):
.to(device=self.orig_module.weight.device)
)

self.dora_scale._is_dora_scale = True

del orig_weight

def check_initialized(self):
@@ -593,6 +607,7 @@ def __init__(
self.peft_type = config.peft_type
self.rank = config.lora_rank
self.alpha = config.lora_alpha
use_stiefel = config.use_stiefel

self.module_filters = [
ModuleFilter(pattern, use_regex=config.layer_filter_regex)
@@ -609,12 +624,15 @@ def __init__(
'norm_epsilon': config.lora_decompose_norm_epsilon,
'decompose_output_axis': config.lora_decompose_output_axis,
'train_device': torch.device(config.train_device),
'use_stiefel': use_stiefel,
}
else:
self.klass = LoRAModule
self.dummy_klass = DummyLoRAModule
self.additional_args = [self.rank, self.alpha]
self.additional_kwargs = {}
self.additional_kwargs = {
'use_stiefel': use_stiefel
}
elif self.peft_type == PeftType.LOHA:
self.klass = LoHaModule
self.dummy_klass = DummyLoHaModule
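Reviewer note: a minimal standalone sketch (not part of the diff) of what the Stiefel initialization above guarantees — with B orthogonal and A zero, the adapter delta starts at zero, and the `_is_lora_A` / `_is_lora_B` attributes merely tag the factors for later use by the optimizer. Shapes below are illustrative.

```python
# Minimal sketch of the initialization above; shapes are illustrative.
import torch
import torch.nn as nn

rank, in_features, out_features = 16, 320, 320
lora_down = nn.Linear(in_features, rank, bias=False)   # A
lora_up = nn.Linear(rank, out_features, bias=False)    # B

nn.init.orthogonal_(lora_up.weight)   # B gets orthonormal columns (Stiefel manifold)
nn.init.zeros_(lora_down.weight)      # A = 0, so delta = B @ A = 0

delta = lora_up.weight @ lora_down.weight
assert torch.count_nonzero(delta) == 0  # adapter starts with no effect on the base weights

# Tags consumed later when the optimizer decides which tensors get the manifold update.
lora_up.weight._is_lora_B = True
lora_down.weight._is_lora_A = True
```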
18 changes: 18 additions & 0 deletions modules/ui/LoraTab.py
@@ -84,6 +84,24 @@ def setup_lora(self, peft_type: PeftType):
tooltip="Apply the weight decomposition on the output axis instead of the input axis.")
components.switch(master, 3, 4, self.ui_state, "lora_decompose_output_axis")

components.label(master, 4, 3, "Use Stiefel LoRA",
tooltip="Implements Stiefel-LoRA, which treats the B factor as an orthogonal matrix updated via manifold projection and retraction. This addresses the 'rank collapse' issue seen in standard LoRA/DoRA training. Includes rank-invariant scaling of the learning rate and weight decay. Alpha MUST be set equal to Rank.")

stiefel_frame = ctk.CTkFrame(master, fg_color="transparent")
stiefel_frame.grid(row=4, column=4, sticky="w")
components.switch(stiefel_frame, 0, 0, self.ui_state, "use_stiefel")

def open_stiefel_settings():
from modules.ui.OptimizerParamsWindow import OptimizerParamsWindow
from modules.util.enum.Optimizer import Optimizer
self.ui_state.get_var("use_stiefel").set(True)
self.ui_state.get_var("optimizer.optimizer").set(str(Optimizer.Stiefel_LoRA))
self.train_config.optimizer.optimizer = Optimizer.Stiefel_LoRA
window = OptimizerParamsWindow(self.master, self.train_config, self.ui_state)
self.master.wait_window(window)

ctk.CTkButton(stiefel_frame, text="...", width=30, command=open_stiefel_settings).grid(row=0, column=1, padx=(5, 0))

# LoRA and LoHA shared settings
if peft_type == PeftType.LORA or peft_type == PeftType.LOHA:
# rank
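The tooltip above names the mechanism (manifold projection and retraction) without showing it. As a rough illustration only — the actual update lives inside `adv_optm.Stiefel_LoRA` and is not reproduced here — a common Stiefel recipe projects the Euclidean gradient onto the tangent space at B and retracts with a QR decomposition:

```python
# Illustrative single step of a projection + retraction update for an
# orthonormal factor B. This mirrors the generic Stiefel recipe, not
# necessarily the exact adv_optm implementation.
import torch

def stiefel_step(B: torch.Tensor, grad: torch.Tensor, lr: float) -> torch.Tensor:
    # Project the Euclidean gradient onto the tangent space at B:
    # G_tangent = G - B @ sym(B^T @ G)
    BtG = B.T @ grad
    sym = 0.5 * (BtG + BtG.T)
    riem_grad = grad - B @ sym

    # Take a step and retract back onto the manifold via QR.
    Q, R = torch.linalg.qr(B - lr * riem_grad)
    # Fix the sign ambiguity of QR so the retraction is deterministic.
    Q = Q * torch.sign(torch.diagonal(R)).unsqueeze(0)
    return Q

B = torch.linalg.qr(torch.randn(320, 16)).Q   # orthonormal start
grad = torch.randn_like(B)                    # gradient from backprop
B = stiefel_step(B, grad, lr=1e-3)
assert torch.allclose(B.T @ B, torch.eye(16), atol=1e-5)
```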
19 changes: 18 additions & 1 deletion modules/ui/TrainingTab.py
@@ -269,8 +269,25 @@ def __create_base_frame(self, master, row):
# optimizer
components.label(frame, 0, 0, "Optimizer",
tooltip="The type of optimizer")
components.options_adv(frame, 0, 1, [str(x) for x in list(Optimizer)], self.ui_state, "optimizer.optimizer",
opt_var, opt_d = components.options_adv(frame, 0, 1, [str(x) for x in list(Optimizer)], self.ui_state, "optimizer.optimizer",
command=self.__restore_optimizer_config, adv_command=self.__open_optimizer_params_window)
self.optimizer_comp = opt_d['component']

def update_optimizer_options(*args):
try:
if self.ui_state.get_var("use_stiefel").get():
self.optimizer_comp.configure(values=[str(Optimizer.Stiefel_LoRA)])
if self.ui_state.get_var("optimizer.optimizer").get() != str(Optimizer.Stiefel_LoRA):
self.ui_state.get_var("optimizer.optimizer").set(str(Optimizer.Stiefel_LoRA))
self.train_config.optimizer.optimizer = Optimizer.Stiefel_LoRA
self.__restore_optimizer_config()
else:
self.optimizer_comp.configure(values=[str(x) for x in list(Optimizer)])
except Exception:
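# use_stiefel may not be registered for every training method; if so, silently keep the current optimizer list (assumed intent).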
pass

self.ui_state.get_var("use_stiefel").trace_add("write", update_optimizer_options)
update_optimizer_options()

# learning rate scheduler
# Wackiness will ensue when reloading configs if we don't check and clear this first.
2 changes: 2 additions & 0 deletions modules/util/config/TrainConfig.py
@@ -526,6 +526,7 @@ class TrainConfig(BaseConfig):
lora_decompose_output_axis: bool
lora_weight_dtype: DataType
bundle_additional_embeddings: bool
use_stiefel: bool

# oft
oft_block_size: int
@@ -1155,6 +1156,7 @@ def default_values() -> 'TrainConfig':
data.append(("lora_decompose_output_axis", False, bool, False))
data.append(("lora_weight_dtype", DataType.FLOAT_32, DataType, False))
data.append(("bundle_additional_embeddings", True, bool, False))
data.append(("use_stiefel", False, bool, False))

# oft
data.append(("oft_block_size", 32, int, False))
16 changes: 16 additions & 0 deletions modules/util/create.py
@@ -789,6 +789,22 @@ def create_optimizer(
alpha_grad=optimizer_config.alpha_grad if optimizer_config.alpha_grad is not None else 100,
)

# Stiefel_LoRA Optimizer
case Optimizer.Stiefel_LoRA:
from adv_optm import Stiefel_LoRA
optimizer = Stiefel_LoRA(
params=parameters,
lr=config.learning_rate,
momentum=optimizer_config.momentum if optimizer_config.momentum is not None else 0,
weight_decay=optimizer_config.weight_decay if optimizer_config.weight_decay is not None else 0.0,
nnmf_factor=optimizer_config.nnmf_factor if optimizer_config.nnmf_factor is not None else False,
cautious_wd=optimizer_config.cautious_wd if optimizer_config.cautious_wd is not None else False,
stochastic_rounding=optimizer_config.stochastic_rounding,
compiled_optimizer=optimizer_config.compile if optimizer_config.compile is not None else False,
Simplified_AdEMAMix=optimizer_config.Simplified_AdEMAMix if optimizer_config.Simplified_AdEMAMix is not None else False,
alpha_grad=optimizer_config.alpha_grad if optimizer_config.alpha_grad is not None else 100,
)

# LION_ADV Optimizer
case Optimizer.LION_ADV:
from adv_optm import Lion_adv
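How `Stiefel_LoRA` tells the orthonormal B factors apart from everything else in `params=parameters` is internal to `adv_optm`; presumably it checks the `_is_lora_A` / `_is_lora_B` / `_is_dora_scale` tags set in `LoRAModule.py`. A purely hypothetical routing helper, not taken from this PR or from `adv_optm`:

```python
# Hypothetical sketch: split tagged parameters so only the B factors receive
# the Stiefel (manifold) update. Names and grouping are assumptions.
def split_lora_parameters(parameters):
    stiefel_params, euclidean_params = [], []
    for p in parameters:
        if getattr(p, "_is_lora_B", False):
            stiefel_params.append(p)      # orthonormal factors
        else:
            euclidean_params.append(p)    # A factors, DoRA scales, everything else
    return stiefel_params, euclidean_params
```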
2 changes: 2 additions & 0 deletions modules/util/enum/Optimizer.py
@@ -43,6 +43,7 @@ class Optimizer(Enum):
SGD = 'SGD'
SGD_8BIT = 'SGD_8BIT'
SIGNSGD_ADV = 'SIGNSGD_ADV'
Stiefel_LoRA = 'Stiefel_LoRA'

# Schedule-free optimizers
SCHEDULE_FREE_ADAMW = 'SCHEDULE_FREE_ADAMW'
@@ -118,6 +119,7 @@ def supports_fused_back_pass(self):
Optimizer.MUON_ADV,
Optimizer.ADAMUON_ADV,
Optimizer.SIGNSGD_ADV,
Optimizer.Stiefel_LoRA,
]

# Small helper for adjusting learning rates to adaptive optimizers.
13 changes: 12 additions & 1 deletion modules/util/optimizer_util.py
@@ -539,7 +539,7 @@ def init_model_parameters(
"k_warmup_steps": None,
},
Optimizer.SIGNSGD_ADV: {
"momentum": 0.99,
"momentum": 0.95,
"cautious_wd": False,
"weight_decay": 0.0,
"nnmf_factor": False,
@@ -550,6 +550,17 @@
"Simplified_AdEMAMix": False,
"alpha_grad": 100.0,
},
Optimizer.Stiefel_LoRA: {
"momentum": 0.95,
"cautious_wd": False,
"weight_decay": 0.0,
"nnmf_factor": False,
"stochastic_rounding": True,
"compile": False,
"fused_back_pass": False,
"Simplified_AdEMAMix": False,
"alpha_grad": 100.0,
},
Optimizer.LION_ADV: {
"beta1": 0.9,
"beta2": 0.99,
2 changes: 1 addition & 1 deletion requirements-global.txt
@@ -42,7 +42,7 @@ prodigyopt==1.1.2 # prodigy optimizer
schedulefree==1.4.1 # schedule-free optimizers
pytorch_optimizer==3.6.0 # pytorch optimizers
prodigy-plus-schedule-free==2.0.1 # Prodigy plus optimizer
adv_optm==2.2.3 # advanced optimizers
adv_optm==2.3.dev3 # advanced optimizers
-e git+https://github.com/KellerJordan/Muon.git@f90a42b#egg=muon-optimizer

# Profiling