1 change: 1 addition & 0 deletions fme/ace/train/train_config.py
@@ -269,6 +269,7 @@ class TrainConfig:
evaluate_before_training: bool = False
save_best_inference_epoch_checkpoints: bool = False
lr_tuning: LRTuningConfig | None = None
finetune_optimization_checkpoint_path: str | None = None
resume_results: ResumeResultsConfig | None = None

def __post_init__(self):
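The new field points stage-2 fine-tuning at a stage-1 restart checkpoint. As a minimal pre-flight sketch (not part of this PR; the path below is a placeholder), one can confirm a candidate checkpoint actually carries the required `optimization` key before launching, since per `_load_finetune_optimization_state` further down, only restart checkpoints such as `ckpt.tar` qualify and `best_ckpt.tar` does not:

```python
# Hedged pre-flight check: confirm the checkpoint was saved with
# optimization state before pointing finetune_optimization_checkpoint_path
# at it.
import torch

ckpt = torch.load(
    "/path/to/stage1/ckpt.tar",  # placeholder path
    map_location="cpu",
    weights_only=False,
)
assert "optimization" in ckpt, "checkpoint lacks optimization state"
```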
135 changes: 134 additions & 1 deletion fme/core/generics/test_trainer.py
@@ -28,6 +28,7 @@
Trainer,
TrainOutputABC,
TrainStepperABC,
_load_finetune_optimization_state,
count_parameters,
epoch_checkpoint_enabled,
)
@@ -236,6 +237,7 @@ class Config:
evaluate_before_training: bool = False
save_best_inference_epoch_checkpoints: bool = False
lr_tuning: LRTuningConfig | None = None
finetune_optimization_checkpoint_path: str | None = None

def __post_init__(self):
start_epoch = 0 if self.evaluate_before_training else 1
@@ -343,6 +345,8 @@ def get_trainer(
n_validation_batches: int = 5,
save_checkpoint: bool = True,
lr_tuning: LRTuningConfig | None = None,
finetune_optimization_checkpoint_path: str | None = None,
lr: float = 0.01,
) -> tuple[TrainConfigProtocol, Trainer]:
if checkpoint_dir is None:
checkpoint_dir = os.path.join(tmp_path, "checkpoints")
@@ -376,7 +380,7 @@ def build_optimization(modules: torch.nn.ModuleList) -> Optimization:
opt = Optimization(
parameters=itertools.chain(*[module.parameters() for module in modules]),
optimizer_type="Adam",
lr=0.01,
lr=lr,
max_epochs=max_epochs,
scheduler=scheduler_config,
enable_automatic_mixed_precision=False,
@@ -400,6 +404,17 @@ def step_weights_side_effect(*args, **kwargs):
if stepper_module_values is None:
raise ValueError("stepper_module_values is None")
module.weight.data.fill_(stepper_module_values[i])
for param in module.parameters():
if param not in opt.optimizer.state:
opt.optimizer.state[param] = {
"step": torch.tensor(0.0, device=param.data.device),
"exp_avg": torch.zeros_like(param.data),
"exp_avg_sq": torch.zeros_like(param.data),
}
state = opt.optimizer.state[param]
state["step"] += 1
state["exp_avg"] += 0.1
state["exp_avg_sq"] += 0.01

opt.step_weights = unittest.mock.MagicMock(side_effect=step_weights_side_effect) # type: ignore
return opt
@@ -419,6 +434,7 @@ def build_ema(modules: torch.nn.ModuleList) -> EMATracker:
save_best_inference_epoch_checkpoints=save_best_inference_epoch_checkpoints,
save_checkpoint=save_checkpoint,
lr_tuning=lr_tuning,
finetune_optimization_checkpoint_path=finetune_optimization_checkpoint_path,
)
aggregator_builder = AggregatorBuilder(
train_losses=train_losses,
@@ -1388,3 +1404,120 @@ def test_epoch_checkpoint_enabled_includes_final_epoch():
save_epochs = Slice(step=5)
assert epoch_checkpoint_enabled(5, max_epochs, save_epochs)
assert epoch_checkpoint_enabled(10, max_epochs, save_epochs)


def test_load_finetune_optimization_state_missing_key(tmp_path: str):
"""_load_finetune_optimization_state raises ValueError when the checkpoint
does not contain an 'optimization' key (e.g. a best_ckpt.tar)."""
checkpoint_path = os.path.join(tmp_path, "bad_ckpt.tar")
torch.save({"stepper": {}, "ema": {}}, checkpoint_path)

model = torch.nn.Linear(1, 1).to(get_device())
optimization = Optimization(
parameters=model.parameters(),
optimizer_type="Adam",
lr=0.01,
max_epochs=1,
scheduler=SchedulerConfig(),
enable_automatic_mixed_precision=False,
kwargs={},
)

with pytest.raises(ValueError, match="does not contain optimization state"):
_load_finetune_optimization_state(optimization, checkpoint_path)


def test_finetune_optimization_checkpoint_loads_optimizer_state(tmp_path: str):
"""Trainer loads optimizer state from a finetune checkpoint while
keeping counters and scheduler fresh.

Uses a StepLR scheduler in stage 1 so the saved checkpoint contains a
decayed LR and advanced scheduler state, then verifies that stage 2 starts
with the configured LR and a fresh scheduler.
"""
configured_lr = 0.01
stage1_scheduler = SchedulerConfig(
type="StepLR", kwargs={"step_size": 1, "gamma": 0.5}
)

stage1_dir = os.path.join(tmp_path, "stage1")
_, stage1_trainer = get_trainer(
stage1_dir,
max_epochs=1,
n_train_batches=4,
stepper_module_values=np.array([1.0]),
scheduler_config=stage1_scheduler,
lr=configured_lr,
)
assert stage1_trainer.optimization.optimizer.state_dict()["state"] == {}

stage1_trainer.train()
assert stage1_trainer.optimization.learning_rate < configured_lr

stage1_trainer._save_restart_checkpoints()
stage1_ckpt_path = stage1_trainer.paths.latest_checkpoint_path

# verify training updated the optimizer state dict
stage1_opt_state = stage1_trainer.optimization.optimizer.state_dict()["state"]
assert stage1_opt_state != {}, "optimizer state should change during training"

stage2_dir = os.path.join(tmp_path, "stage2")
_, stage2_trainer = get_trainer(
stage2_dir,
max_epochs=1,
n_train_batches=4,
stepper_module_values=np.array([2.0]),
finetune_optimization_checkpoint_path=stage1_ckpt_path,
lr=configured_lr,
)

assert stage2_trainer._epochs_trained == 0
assert stage2_trainer._start_epoch == 0
assert stage2_trainer.num_batches_seen == 0

# optimizer state loaded from stage1 ckpt
stage2_opt_state = stage2_trainer.optimization.optimizer.state_dict()["state"]
for param_id in stage1_opt_state:
assert param_id in stage2_opt_state
for key in ("step", "exp_avg", "exp_avg_sq"):
assert key in stage2_opt_state[param_id]
torch.testing.assert_close(
stage2_opt_state[param_id][key],
stage1_opt_state[param_id][key],
)

# lr and scheduler are overwritten by TrainConfig
assert stage2_trainer.optimization.learning_rate == configured_lr
fresh_scheduler = SchedulerConfig().build(
stage2_trainer.optimization.optimizer, max_epochs=1
)
assert (
stage2_trainer.optimization.scheduler.state_dict()
== fresh_scheduler.state_dict()
)


def test_resume_takes_precedence_over_finetune_path(tmp_path: str):
"""When a ckpt.tar exists in the checkpoint dir (preemption resume),
Trainer resumes from it and ignores finetune_optimization_checkpoint_path."""
_, trainer = get_trainer(
tmp_path,
max_epochs=2,
n_train_batches=4,
stepper_module_values=np.array([1.0, 2.0]),
)
trainer.train()

ckpt_path = trainer.paths.latest_checkpoint_path
assert os.path.isfile(ckpt_path)

_, resumed_trainer = get_trainer(
tmp_path,
max_epochs=2,
n_train_batches=4,
stepper_module_values=np.array([1.0, 2.0]),
finetune_optimization_checkpoint_path=ckpt_path,
)

assert resumed_trainer._epochs_trained == 2
assert resumed_trainer.num_batches_seen == 8
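For reference, the `step`/`exp_avg`/`exp_avg_sq` keys asserted above are standard `torch.optim.Adam` per-parameter state, independent of fme:

```python
# Plain-PyTorch illustration: after one step, Adam keeps per-parameter
# state under exactly these keys (amsgrad=False, the default).
import torch

p = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.Adam([p], lr=0.01)
p.grad = torch.ones(3)
opt.step()
print(sorted(opt.state[p].keys()))  # ['exp_avg', 'exp_avg_sq', 'step']
```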
47 changes: 47 additions & 0 deletions fme/core/generics/trainer.py
@@ -139,6 +139,9 @@ def save_best_inference_epoch_checkpoints(self) -> bool: ...
@property
def lr_tuning(self) -> LRTuningConfig | None: ...

@property
def finetune_optimization_checkpoint_path(self) -> str | None: ...

def get_inference_epochs(self) -> list[int]: ...


@@ -262,6 +265,14 @@ def __init__(
if resuming:
logging.info(f"Resuming training from {self.paths.latest_checkpoint_path}")
self.restore_checkpoint(self.paths.latest_checkpoint_path)
elif config.finetune_optimization_checkpoint_path is not None:
logging.info(
"Loading optimizer state for fine-tuning from "
f"{config.finetune_optimization_checkpoint_path}"
)
_load_finetune_optimization_state(
self.optimization, config.finetune_optimization_checkpoint_path
)

wandb = WandB.get_instance()
wandb.watch(self.stepper.modules)
@@ -789,6 +800,42 @@ def _restore_checkpoint(trainer: Trainer, checkpoint_path):
trainer._ema = EMATracker.from_state(checkpoint["ema"], trainer.stepper.modules)


def _tensors_to_device(obj, device: torch.device):
"""Recursively move all tensors in a nested dict/list to *device*."""
if isinstance(obj, torch.Tensor):
return obj.to(device)
elif isinstance(obj, dict):
return {k: _tensors_to_device(v, device) for k, v in obj.items()}
elif isinstance(obj, list | tuple):
return type(obj)(_tensors_to_device(v, device) for v in obj)
return obj


def _load_finetune_optimization_state(optimization: Optimization, checkpoint_path: str):
"""Load optimizer (and optionally grad scaler) state for fine-tuning.

Only loads the optimizer state dict and grad scaler state from the
checkpoint. Scheduler state and training counters are not restored, so
the current config's schedule starts from scratch. The configured
learning rate is preserved.

The checkpoint is loaded on CPU so that only the optimization state
(not model weights, EMA, etc.) is transferred to the training device.
"""
lr = optimization.learning_rate
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
if "optimization" not in checkpoint:
raise ValueError(
f"Checkpoint at {checkpoint_path} does not contain optimization "
"state. Only checkpoints saved with include_optimization=True "
"(i.e. ckpt.tar) support fine-tune optimization loading."
)
optim_state = checkpoint["optimization"]
del checkpoint
optim_state = _tensors_to_device(optim_state, fme.get_device())
optimization.load_optimizer_state_for_finetuning(optim_state, lr=lr)


def count_parameters(modules: torch.nn.ModuleList) -> int:
parameters = 0
for module in modules:
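The `__init__` branch above gives a preemption resume priority over the fine-tune path. A standalone sketch of that precedence (function and filename are illustrative, not the fme API; the trainer itself consults `self.paths.latest_checkpoint_path`):

```python
# Illustrative only: the decision order implemented in Trainer.__init__.
import os

def choose_checkpoint_action(run_dir: str, finetune_path: str | None):
    latest = os.path.join(run_dir, "ckpt.tar")  # assumed restart filename
    if os.path.isfile(latest):
        return ("resume", latest)  # full restore: counters, scheduler, EMA
    if finetune_path is not None:
        return ("finetune", finetune_path)  # optimizer/gscaler state only
    return ("fresh", None)
```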
19 changes: 19 additions & 0 deletions fme/core/optimization.py
@@ -195,6 +195,25 @@ def set_learning_rate(self, lr: float):
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr

def load_optimizer_state_for_finetuning(self, state: dict, lr: float):
"""Load optimizer and grad scaler state from a checkpoint for fine-tuning.

Restores the optimizer state dict (e.g. Adam momentum buffers) and,
if available, the grad scaler state. Does NOT restore the scheduler
state, so the current config's schedule starts from scratch. After
loading, re-applies the given ``lr`` because
``optimizer.load_state_dict`` overwrites param-group learning rates.

Args:
state: The optimization state dict as saved by ``get_state()``,
containing at least ``"optimizer_state_dict"``.
lr: The learning rate to set after loading.
"""
self.optimizer.load_state_dict(state["optimizer_state_dict"])
self.set_learning_rate(lr)
if self.gscaler is not None and state.get("gscaler_state_dict") is not None:
self.gscaler.load_state_dict(state["gscaler_state_dict"])

def get_state(self):
"""
Returns state as a serializable data structure.
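The explicit `set_learning_rate` call is needed because `Optimizer.load_state_dict` restores the saved param groups, including their old learning rates. A minimal plain-PyTorch sketch of the same pattern (not the fme API):

```python
# Carry over Adam's momentum buffers but keep the newly configured LR.
import torch

model = torch.nn.Linear(2, 2)
opt1 = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(4, 2)).sum().backward()
opt1.step()  # populates exp_avg / exp_avg_sq

new_lr = 1e-2
opt2 = torch.optim.Adam(model.parameters(), lr=new_lr)
opt2.load_state_dict(opt1.state_dict())  # restores buffers AND the old lr
for group in opt2.param_groups:
    group["lr"] = new_lr  # re-apply, as set_learning_rate does above
```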
84 changes: 84 additions & 0 deletions fme/core/test_optimization.py
@@ -535,6 +535,90 @@ def test_load_state_then_set_learning_rate():
), "Parameters should differ when trained at different learning rates"


def test_load_optimizer_state_for_finetuning():
"""load_optimizer_state_for_finetuning restores momentum buffers, applies a
new LR, and leaves the scheduler in its initial state."""
torch.manual_seed(0)
model = nn.Linear(2, 2).to(fme.get_device())
x = torch.randn(10, 2).to(fme.get_device())

optimization = _build_optimization(model.parameters(), lr=0.001)
assert optimization.optimizer.state_dict()["state"] == {}

for _ in range(3):
loss = model(x).sum()
optimization.accumulate_loss(loss)
optimization.step_weights()
optimization.step_scheduler(is_iteration=False)

assert optimization.optimizer.state_dict()["state"] != {}

saved_state = optimization.get_state()

model2 = copy.deepcopy(model)
new_lr = 0.01
optimization2 = _build_optimization(model2.parameters(), lr=new_lr)
fresh_scheduler_state = optimization2.scheduler.state_dict()

optimization2.load_optimizer_state_for_finetuning(saved_state, lr=new_lr)

assert optimization2.learning_rate == new_lr

orig_opt_state = optimization.optimizer.state_dict()["state"]
loaded_opt_state = optimization2.optimizer.state_dict()["state"]
for param_id in orig_opt_state:
for key in ("exp_avg", "exp_avg_sq"):
torch.testing.assert_close(
loaded_opt_state[param_id][key], orig_opt_state[param_id][key]
)

assert optimization2.scheduler.state_dict() == fresh_scheduler_state


@pytest.mark.skipif(not torch.cuda.is_available(), reason="GradScaler requires CUDA")
def test_load_optimizer_state_for_finetuning_with_gscaler():
"""load_optimizer_state_for_finetuning restores grad scaler state when AMP
is enabled on both the source and target Optimization."""
torch.manual_seed(0)
model = nn.Linear(2, 2).to(fme.get_device())
x = torch.randn(10, 2).to(fme.get_device())

optimization = Optimization(
parameters=model.parameters(),
optimizer_type="Adam",
lr=0.001,
max_epochs=10,
scheduler=SchedulerConfig(),
enable_automatic_mixed_precision=True,
kwargs={},
)

for _ in range(3):
with optimization.autocast():
loss = model(x).sum()
optimization.accumulate_loss(loss)
optimization.step_weights()

saved_state = optimization.get_state()
assert saved_state["gscaler_state_dict"] != {}

model2 = copy.deepcopy(model)
optimization2 = Optimization(
parameters=model2.parameters(),
optimizer_type="Adam",
lr=0.01,
max_epochs=10,
scheduler=SchedulerConfig(),
enable_automatic_mixed_precision=True,
kwargs={},
)

optimization2.load_optimizer_state_for_finetuning(saved_state, lr=0.01)

assert optimization2.gscaler is not None
assert optimization2.gscaler.state_dict() == saved_state["gscaler_state_dict"]


def test_scheduler_step_timing():
"""
Test that schedulers step at the correct timing based on
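The grad scaler state round-tripped here is AMP's dynamic loss-scale bookkeeping; restoring it spares stage 2 the scale warm-up. A small sketch of what that state contains (standard `torch.cuda.amp.GradScaler`; on a CPU-only machine the disabled scaler reports an empty dict):

```python
# Peek at the state restored by load_optimizer_state_for_finetuning's
# gscaler branch: scale, growth/backoff factors, and a growth tracker.
import torch

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
print(scaler.state_dict())  # {} when disabled (no CUDA)
```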
1 change: 1 addition & 0 deletions fme/coupled/train/train_config.py
@@ -157,6 +157,7 @@ class TrainConfig:
evaluate_before_training: bool = True
save_best_inference_epoch_checkpoints: bool = False
lr_tuning: LRTuningConfig | None = None
finetune_optimization_checkpoint_path: str | None = None
resume_results: ResumeResultsConfig | None = None

def __post_init__(self):
1 change: 1 addition & 0 deletions fme/diffusion/train_config.py
@@ -116,6 +116,7 @@ class TrainConfig:
evaluate_before_training: bool = False
save_best_inference_epoch_checkpoints: bool = False
lr_tuning: LRTuningConfig | None = None
finetune_optimization_checkpoint_path: str | None = None
resume_results: ResumeResultsConfig | None = None

def __post_init__(self):