diff --git a/sdxl_train.py b/sdxl_train.py index 60239b69c..b20fd85a2 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -569,14 +569,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) - # lr schedulerを用意する - if args.fused_optimizer_groups: - # prepare lr schedulers for each optimizer - lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] - lr_scheduler = lr_schedulers[0] # avoid error in the following code - else: - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) - # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする if args.full_fp16: assert ( @@ -657,6 +649,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する + if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model( args,