diff --git a/mup/__init__.py b/mup/__init__.py
index f11535b..28f2a99 100644
--- a/mup/__init__.py
+++ b/mup/__init__.py
@@ -4,4 +4,4 @@
 from mup.infshape import *
 from mup.init import *
 from mup.layer import *
-from mup.optim import *
\ No newline at end of file
+from mup.optim import MuSGD, MuAdam, MuAdamW, process_param_groups
\ No newline at end of file
diff --git a/mup/optim.py b/mup/optim.py
index a327996..b043720 100644
--- a/mup/optim.py
+++ b/mup/optim.py
@@ -23,6 +23,8 @@ def MuOptimizer(params, **kwargs):
 
 from torch.optim import SGD, Adam, AdamW
 
+import logging
+logger = logging.getLogger(__name__)
 
 def process_param_groups(params, **kwargs):
     param_groups = list(params)
@@ -51,6 +53,9 @@ def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
         An instance of `impl` with refined parameter groups, each of which has
         the correctly scaled learning rate according to mup.
    '''
+    if impl == Adam and kwargs.get('weight_decay', False):
+        logger.warning('MuAdam does not scale weight decay correctly. Use MuAdamW instead.')
+
     new_param_groups = []
     for param_group in process_param_groups(params, **kwargs):
         # For every existing param group, we split into several new groups
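
Usage sketch (not part of the patch): the snippet below illustrates the new warning path and the suggested MuAdamW alternative. The toy two-layer model, the make_model helper, the widths, and the hyperparameter values are illustrative assumptions; only MuAdam, MuAdamW, and set_base_shapes come from the mup package, following its documented base-shapes workflow.

    # Minimal sketch, assuming the mup package is installed.
    # Model construction and hyperparameters below are assumptions, not from the patch.
    import logging
    import torch.nn as nn
    from mup import set_base_shapes
    from mup.optim import MuAdam, MuAdamW

    logging.basicConfig(level=logging.WARNING)

    def make_model(width):
        # Toy MLP; a full muP setup would also use mup.MuReadout for the output layer.
        return nn.Sequential(nn.Linear(32, width), nn.ReLU(), nn.Linear(width, 10))

    model = make_model(width=128)
    set_base_shapes(model, make_model(width=8))  # attach the infshapes the Mu* optimizers need

    # With the default impl=Adam and a nonzero weight_decay, the new check logs:
    # "MuAdam does not scale weight decay correctly. Use MuAdamW instead."
    opt = MuAdam(model.parameters(), lr=1e-3, weight_decay=1e-2)

    # Suggested alternative: MuAdamW passes impl=AdamW (decoupled weight decay), so no warning is logged.
    opt = MuAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)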