28 changes: 28 additions & 0 deletions toolkit/scheduler.py
@@ -25,6 +25,34 @@ def get_lr_scheduler(
        return torch.optim.lr_scheduler.StepLR(
            optimizer, **kwargs
        )
    elif name == "polynomial":
        # --- 1. Base values ---
        lr_start = optimizer.param_groups[0]['lr']
        # Default to 3000 total steps if not configured
        total_steps = kwargs.get('total_iters', 3000)

        # --- 2. Terminal learning rate ---
        # Prefer an explicit lr_end from the config. If absent, default to 10%
        # of the initial LR (5e-5 / 5e-4 = 0.1), so the endpoint scales
        # proportionally even if the initial LR is changed later.
        default_lr_end = lr_start * (5e-5 / 5e-4)
        # pop() because PolynomialLR does not accept an lr_end kwarg
        lr_end = kwargs.pop('lr_end', default_lr_end)
        # --- 3. Power: controls how smoothly the decay flattens out ---
        # setdefault() so the same power is also passed on to PolynomialLR;
        # otherwise it would fall back to its own default of 1.0 and the
        # curve would not land on lr_end
        power = kwargs.setdefault('power', 0.8)
        # --- 4. Virtual horizon ---
        # Guard against a misconfigured endpoint at or above the start point
        if lr_start > lr_end:
            ratio = lr_end / lr_start
            # Total virtual steps needed for the curve to land exactly on
            # lr_end after total_steps:
            #   T_max = total_steps / (1 - (lr_end / lr_start)^(1 / power))
            # e.g. 5e-4 -> 5e-5 over 3000 steps at power 0.8 gives T_max ~ 3178
            t_max = total_steps / (1 - pow(ratio, 1 / power))
            kwargs['total_iters'] = int(t_max)
        else:
            # Misconfigured (endpoint >= start point): fall back to a
            # constant schedule instead of crashing
            return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)

        return torch.optim.lr_scheduler.PolynomialLR(
            optimizer, **kwargs
        )
    elif name == "constant":
        if 'factor' not in kwargs:
            kwargs['factor'] = 1.0
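
A quick way to sanity-check the virtual-horizon math in this branch: with `total_iters` set to the computed T_max, `PolynomialLR` should land on `lr_end` after `total_steps` real steps. The following is a standalone sketch, not part of this diff; the concrete values (5e-4, 5e-5, 3000, 0.8) just mirror the defaults above, and it assumes PyTorch >= 1.13, the first release that ships `PolynomialLR`.

```python
# Standalone verification sketch, not part of this PR.
# Assumes PyTorch >= 1.13 (first release with PolynomialLR).
import torch

lr_start, lr_end, total_steps, power = 5e-4, 5e-5, 3000, 0.8

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=lr_start)

# Same closed form as the branch above:
#   T_max = total_steps / (1 - (lr_end / lr_start)^(1 / power))
t_max = total_steps / (1 - pow(lr_end / lr_start, 1 / power))
sched = torch.optim.lr_scheduler.PolynomialLR(
    opt, total_iters=int(t_max), power=power
)

for _ in range(total_steps):
    opt.step()
    sched.step()

print(opt.param_groups[0]['lr'])  # ~5e-5, modulo the int() rounding of t_max
```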