sine2pi/Maxfactor

import torch


class MaxFactor(torch.optim.Optimizer):
    def __init__(self, named_params, lr=0.00025, b_decay=-0.8, eps=(1e-8, 1e-8), d=1.0,
                 w_decay=0.025, gamma=0.99, max=False, clip=False, cap=0.1):

        named_params = list(named_params)
        total = len(named_params)
        params = [p for n, p in named_params]

        defaults = dict(lr=lr, b_decay=b_decay, eps=eps, d=d, w_decay=w_decay, 
                        gamma=gamma, max=max, clip=clip, cap=cap)
        super().__init__(params, defaults)

        # Assign each parameter a role from its relative position in the
        # named_parameters() order: early layers get conservative updates,
        # later layers more aggressive ones (used in step() to pick the
        # update magnitude).
        for i, (name, p) in enumerate(named_params):
            depth = i / total
            state = self.state[p]

            if depth < 0.2:
                state['role'] = 'robust'
            elif depth < 0.7:
                state['role'] = 'balanced'
            else:
                state['role'] = 'aggressive'

    @staticmethod
    def _rms(tensor):
        # Root-mean-square of the entries: ||t||_2 / sqrt(numel).
        return tensor.norm() / (tensor.numel() ** 0.5)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            p_grad, grads, row_vars, col_vars, v, state_steps = [], [], [], [], [], []
            eps1, eps2 = group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]
                if "step" not in state:
                    state["step"] = torch.tensor(0.0, dtype=torch.float32)
                    if p.grad.dim() > 1:
                        row_shape, col_shape = list(p.grad.shape), list(p.grad.shape)
                        row_shape[-1], col_shape[-2] = 1, 1
                        state["row_var"], state["col_var"] = p.grad.new_zeros(row_shape), p.grad.new_zeros(col_shape)
                    state["v"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["RMS"] = self._rms(p).item()

                row_vars.append(state.get("row_var", None))
                col_vars.append(state.get("col_var", None))
                v.append(state["v"])
                state_steps.append(state["step"])
                p_grad.append(p)
                grads.append(grad)

            for i, param in enumerate(p_grad):
                grad = grads[i]

                if group["max"]:
                    grad = -grad
                step_t, row_var, col_var, vi = state_steps[i], row_vars[i], col_vars[i], v[i]

                if eps1 is None:
                    eps1 = torch.finfo(param.dtype).eps
                    
                step_t += 1
                step_float = step_t.item()
                
                beta_t = min(0.999, max(0.001, step_float ** group["b_decay"]))
                state["RMS"] = self._rms(param).item()
                
                # Relative step size: lr capped by 1/sqrt(t), scaled by the
                # parameter's RMS (floored at eps2).
                rho_t = min(group["lr"], 1 / (step_float ** 0.5))
                alpha = max(eps2, param.norm(2).item() / (param.numel() ** 0.5)) * rho_t

                if group["w_decay"] != 0:
                    param.mul_(1 - group["lr"] * group["w_decay"])

                if grad.dim() > 1:
                    # Factored second moment: running row/column means of the
                    # squared gradient instead of a full per-element buffer.
                    row_mean = torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1) + 1e-8)
                    row_var.lerp_(row_mean, beta_t)
                    col_mean = torch.norm(grad, dim=-2, keepdim=True).square_().div_(grad.size(-2) + 1e-8)
                    col_var.lerp_(col_mean, beta_t)
                    var_est = row_var @ col_var
                    max_row = row_var.max(dim=-2, keepdim=True)[0]
                    var_est.div_(max_row.clamp_(min=eps1))
                else:
                    # 1-D parameters fall back to an EMA of squared gradients.
                    vi.mul_(group["gamma"]).add_(grad ** 2, alpha=1 - group["gamma"])
                    var_est = vi

                # Out-of-place ops here: for 1-D params var_est aliases the
                # EMA state "v", and in-place clamp_/rsqrt_ would corrupt it.
                update = var_est.clamp(min=eps1 * eps1).rsqrt() * grad
                update = update.div_(torch.norm(update, float('inf')).clamp_(min=eps1))
                
                denom = max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * group["d"]))
                step_size = alpha / denom

                # Role-dependent magnitude for the sign-based update:
                # robust -> per-row median, balanced -> per-row RMS,
                # aggressive -> per-row max.
                role = self.state[param].get('role', 'balanced')
                if role == 'robust':
                    scale = torch.median(update.abs(), dim=-1, keepdim=True)[0]
                elif role == 'balanced':
                    scale = torch.sqrt(torch.mean(update**2, dim=-1, keepdim=True))
                else:
                    scale = update.abs().max(dim=-1, keepdim=True)[0]

                impulse = update.sign() * scale
                if group["clip"]:
                    param_rms = torch.norm(param) / (param.numel() ** 0.5)
                    max_allowed_step = param_rms * group["cap"]
                    update_rms = (torch.norm(impulse * step_size) / (impulse.numel() ** 0.5))
                    if update_rms > max_allowed_step:
                        step_size = step_size * (max_allowed_step / (update_rms + 1e-8))
               
                param.add_(impulse, alpha=-step_size)

        return loss    
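
A minimal usage sketch, assuming a toy model (the nn.Linear and random data below are illustrative, not from the repository); note the constructor expects named_parameters() rather than parameters():

import torch
import torch.nn as nn

model = nn.Linear(16, 4)
optimizer = MaxFactor(model.named_parameters(), lr=2.5e-4)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()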

About

A memory-efficient PyTorch optimizer.
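
The memory saving comes from the factored second moment used in step() for matrix-shaped parameters: only a row vector and a column vector of statistics are stored per matrix, not a full per-element buffer. A minimal sketch of that reconstruction (the shapes below are hypothetical):

import torch

grad = torch.randn(1024, 4096)                      # hypothetical gradient
row_var = grad.square().mean(dim=-1, keepdim=True)  # (1024, 1)
col_var = grad.square().mean(dim=-2, keepdim=True)  # (1, 4096)

# Rank-1 reconstruction, normalized by the largest row statistic as in step():
var_est = (row_var @ col_var) / row_var.max().clamp(min=1e-8)

# Optimizer state per matrix: 1024 + 4096 floats instead of 1024 * 4096.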
