Merged
31 commits
328c978
Allow stochastic atmos training in CoupledStepper
jpdunc23 Jan 21, 2026
eb19d3d
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Jan 21, 2026
8533d85
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Feb 21, 2026
81f7329
Fix test
jpdunc23 Feb 21, 2026
2584ab1
Add `optimize_last_step_only` to `LossContributionsConfig`
jpdunc23 Feb 22, 2026
e5131f4
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Mar 25, 2026
604d22c
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Apr 2, 2026
7c9e35c
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 2, 2026
1f39fb5
Allow for ensemble ocean training
jpdunc23 Apr 2, 2026
106c8fa
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 2, 2026
827dc04
De-duplicate gen step processing in train_on_batch
jpdunc23 Apr 2, 2026
1ceb8d6
Merge branch 'main' into coupled-ft-with-stochastic-atmos
jpdunc23 Apr 6, 2026
138e82e
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 6, 2026
4f9585e
Add `CoupledTrainStepper._accumulate_loss()`
jpdunc23 Apr 7, 2026
cec0022
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Apr 8, 2026
dcb0cbd
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 8, 2026
b380c8f
Update tests
jpdunc23 Apr 8, 2026
8743338
Address additional review comments
jpdunc23 Apr 8, 2026
76c13e1
Fix unhelpful docstring
jpdunc23 Apr 8, 2026
c97c3cf
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 8, 2026
30ffd90
Merge branch 'main' into coupled-ft-with-stochastic-atmos
jpdunc23 Apr 9, 2026
3722c55
Merge branch 'coupled-ft-with-stochastic-atmos' into feature/coupled-…
jpdunc23 Apr 9, 2026
d7e12a5
Merge branch 'main' into coupled-ft-with-stochastic-atmos
jpdunc23 Apr 9, 2026
c6d7f35
Merge branch 'coupled-ft-with-stochastic-atmos' into feature/coupled-…
jpdunc23 Apr 9, 2026
d70af7d
Merge branch 'main' of github.com:ai2cm/ace into feature/coupled-opti…
jpdunc23 Apr 9, 2026
f14e2e3
Merge branch 'main' into feature/coupled-optimize-last-step-only
jpdunc23 Apr 13, 2026
82fa43f
Merge branch 'main' of github.com:ai2cm/ace into feature/coupled-opti…
jpdunc23 Apr 15, 2026
d5c00da
Add test with `use_gradient_accumulation` and `optimize_last_step_only`
jpdunc23 Apr 15, 2026
710e0ce
Assert gen_step realm and step
jpdunc23 Apr 15, 2026
2576d7c
Move `n_coupled_steps` to `CoupledTrainStepperConfig`
jpdunc23 Apr 15, 2026
5fc4bed
Update baseline configs
jpdunc23 Apr 15, 2026
@@ -4,7 +4,6 @@ validate_using_ema: true
ema:
decay: 0.999
max_epochs: 20
n_coupled_steps: 4
inference:
n_coupled_steps: 1456
coupled_steps_in_memory: 8
@@ -122,6 +121,7 @@ optimization:
scheduler:
type: CosineAnnealingLR
stepper_training:
n_coupled_steps: 4
parameter_init:
checkpoint_path: /ckpt.tar
ocean:
2 changes: 1 addition & 1 deletion configs/baselines/cm4-piControl/finetune-config.yaml
@@ -4,7 +4,6 @@ validate_using_ema: true
ema:
decay: 0.999
max_epochs: 20
n_coupled_steps: 4
inference:
n_coupled_steps: 1456
coupled_steps_in_memory: 8
@@ -122,6 +121,7 @@ optimization:
scheduler:
type: CosineAnnealingLR
stepper_training:
n_coupled_steps: 4
parameter_init:
checkpoint_path: /ckpt.tar
ocean:
2 changes: 1 addition & 1 deletion configs/baselines/cm4-piControl/train-config-template.yaml
@@ -4,7 +4,6 @@ validate_using_ema: true
ema:
decay: 0.999
max_epochs: 20
n_coupled_steps: 4
inference:
n_coupled_steps: 1456
coupled_steps_in_memory: 8
@@ -120,6 +119,7 @@ optimization:
weight_decay: 0.01
use_gradient_accumulation: true
stepper_training:
n_coupled_steps: 4
ocean:
parameter_init:
weights_path: /ocean_ckpt.tar
2 changes: 1 addition & 1 deletion configs/baselines/cm4-piControl/train-config.yaml
@@ -4,7 +4,6 @@ validate_using_ema: true
ema:
decay: 0.999
max_epochs: 20
n_coupled_steps: 4
inference:
n_coupled_steps: 1456
coupled_steps_in_memory: 8
@@ -120,6 +119,7 @@ optimization:
weight_decay: 0.01
use_gradient_accumulation: true
stepper_training:
n_coupled_steps: 4
ocean:
parameter_init:
weights_path: /ocean_ckpt.tar
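All four baseline-config hunks above make the same change: `n_coupled_steps` moves from the top level into the `stepper_training` section. A trimmed illustrative fragment showing the resulting layout (values taken from the diffs above; this is not a complete training config):

```yaml
ema:
  decay: 0.999
max_epochs: 20
inference:
  n_coupled_steps: 1456   # inference horizon, unchanged by this PR
  coupled_steps_in_memory: 8
stepper_training:
  n_coupled_steps: 4      # moved here from the top level
  ocean:
    parameter_init:
      weights_path: /ocean_ckpt.tar
```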
33 changes: 27 additions & 6 deletions fme/coupled/loss.py
@@ -30,8 +30,10 @@ def effective_loss_scaling(self) -> TensorDict: ...

@abc.abstractmethod
def step_is_optimized(self, step: int) -> bool:
"""Returns True if the step is less than to the number of
steps contributing to the loss.
"""Returns True if the given step should contribute to the loss.

Args:
step: The step index to check.
"""
...

@@ -53,24 +55,32 @@ class LossContributionsConfig:
starting from the first.
weight: (optional) Weight applied to each step loss for the given realm.
Each step contributes equally to the total loss.
optimize_last_step_only: If True, only the last step within the training
horizon defined by ``n_steps`` is optimized (i.e. contributes to the
loss and has gradients enabled). The optimized step index is
``min(n_steps, n_total_steps) - 1``.

"""

n_steps: float = float("inf")
weight: float = 1.0
optimize_last_step_only: bool = False

def build(
self,
loss_obj: StepLoss,
time_dim: int,
max_n_steps: int,
) -> StepLossABC:
if self.n_steps == 0 or self.weight == 0.0:
return NullLossContributions(loss_obj)
return LossContributions(
n_steps=self.n_steps,
weight=self.weight,
optimize_last_step_only=self.optimize_last_step_only,
loss_obj=loss_obj,
time_dim=time_dim,
max_n_steps=max_n_steps,
)


@@ -104,28 +114,39 @@ def __init__(
self,
n_steps: float,
weight: float,
optimize_last_step_only: bool,
loss_obj: StepLoss,
time_dim: int,
max_n_steps: int,
):
self._loss = loss_obj
self._n_steps = n_steps
self._weight = weight
self._optimize_last_step_only = optimize_last_step_only
self._time_dim = time_dim
self._max_n_steps = max_n_steps

@property
def effective_loss_scaling(self) -> TensorDict:
return self._loss.effective_loss_scaling

def step_is_optimized(self, step: int) -> bool:
"""Returns True if the step is less than to the number of steps and
weight is != 0. The first step number is assumed to be 0.
"""Returns True if the step should contribute to the loss.

When ``optimize_last_step_only`` is False (default), returns True for
steps ``0`` through ``n_steps - 1``. When True, returns True only for
the step at index ``min(n_steps, n_total_steps) - 1``.
"""
return step < self._n_steps and self._weight != 0.0
if self._weight == 0.0:
return False
if self._optimize_last_step_only:
last_optimized_step = min(self._n_steps, self._max_n_steps) - 1
return step == last_optimized_step
return step < self._n_steps

def __call__(
self, prediction: StepPredictionABC, target_data: TensorMapping
) -> torch.Tensor:
) -> torch.Tensor | None:
if self.step_is_optimized(prediction.step):
return self._weight * self._loss(
prediction.data, target_data, prediction.step
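The `step_is_optimized` logic added in the hunk above can be summarized with a small runnable sketch (`SketchLossContributions` below is a stand-alone illustration, not the actual `LossContributions` class from `fme/coupled/loss.py`):

```python
class SketchLossContributions:
    """Stand-alone illustration of LossContributions.step_is_optimized."""

    def __init__(self, n_steps, weight, optimize_last_step_only, max_n_steps):
        self._n_steps = n_steps  # may be float("inf"), the config default
        self._weight = weight
        self._optimize_last_step_only = optimize_last_step_only
        self._max_n_steps = max_n_steps  # training horizon for this realm

    def step_is_optimized(self, step: int) -> bool:
        if self._weight == 0.0:
            return False
        if self._optimize_last_step_only:
            # min() clamps n_steps (possibly inf) to the training horizon.
            return step == min(self._n_steps, self._max_n_steps) - 1
        return step < self._n_steps


# Default behavior: steps 0..n_steps-1 all contribute to the loss.
default = SketchLossContributions(4, 1.0, False, max_n_steps=4)
# Last-step-only: just index min(n_steps, max_n_steps) - 1 contributes.
last_only = SketchLossContributions(float("inf"), 1.0, True, max_n_steps=4)
```

With `max_n_steps=4` and the default `n_steps=inf`, the last-only variant optimizes only step index 3, matching the docstring's `min(n_steps, n_total_steps) - 1`.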
147 changes: 95 additions & 52 deletions fme/coupled/stepper.py
@@ -3,7 +3,7 @@
import datetime
import logging
import pathlib
from collections.abc import Callable, Generator, Iterable
from collections.abc import Generator, Iterable
from typing import Any, Literal

import dacite
@@ -1022,7 +1022,6 @@ def get_prediction_generator(
initial_condition: CoupledPrognosticState,
forcing_data: CoupledBatchData,
optimizer: OptimizationABC,
step_is_optimized: Callable[[str, int], bool] = lambda n, c: True,
) -> Generator[ComponentStepPrediction, None, None]:
if (
initial_condition.atmosphere_data.as_batch_data().n_timesteps
@@ -1082,10 +1081,7 @@ def get_prediction_generator(
# predict and yield atmosphere steps
for i_inner in range(self.n_inner_steps):
atmos_step_num = i_outer * self.n_inner_steps + i_inner
optimized = step_is_optimized("atmosphere", atmos_step_num)
context = contextlib.nullcontext() if optimized else torch.no_grad()
with context:
atmos_step = next(atmos_generator)
atmos_step = next(atmos_generator)
yield ComponentStepPrediction(
realm="atmosphere",
data=atmos_step,
@@ -1116,21 +1112,16 @@ def get_prediction_generator(
labels=ocean_window.labels,
)
# predict and yield a single ocean step
ocean_optimized = step_is_optimized("ocean", i_outer)
ocean_context = (
contextlib.nullcontext() if ocean_optimized else torch.no_grad()
)
with ocean_context:
ocean_step = next(
iter(
self.ocean.get_prediction_generator(
ocean_ic_state,
ocean_forcings,
n_forward_steps=1,
optimizer=optimizer,
)
ocean_step = next(
iter(
self.ocean.get_prediction_generator(
ocean_ic_state,
ocean_forcings,
n_forward_steps=1,
optimizer=optimizer,
)
)
)
yield ComponentStepPrediction(
realm="ocean",
data=ocean_step,
@@ -1390,7 +1381,11 @@ def effective_loss_scaling(self) -> CoupledTensorMapping:
atmosphere=self._loss_objs["atmosphere"].effective_loss_scaling,
)

def step_is_optimized(self, realm: str, step: int) -> bool:
def step_is_optimized(
self,
realm: Literal["ocean", "atmosphere"],
step: int,
) -> bool:
return self._loss_objs[realm].step_is_optimized(step)

def __call__(
@@ -1420,6 +1415,7 @@ class CoupledTrainStepperConfig:
"""Configuration for training-specific aspects of a coupled stepper.

Parameters:
n_coupled_steps: Number of forward coupled steps in the optimization.
ocean: The configuration for the ocean component.
atmosphere: The configuration for the atmosphere component.
n_ensemble: The number of ensemble members evaluated for each training
@@ -1430,6 +1426,7 @@ class CoupledTrainStepperConfig:
fine-tuning a previously-trained coupled stepper.
"""

n_coupled_steps: int
ocean: ComponentTrainingConfig
atmosphere: ComponentTrainingConfig
n_ensemble: int = -1 # sentinel value to avoid None typing of attribute
@@ -1463,14 +1460,18 @@ def __post_init__(self):
else:
self.n_ensemble = 1

def _build_loss(self, stepper: CoupledStepper) -> CoupledStepperTrainLoss:
def _build_loss(
self, stepper: CoupledStepper, n_coupled_steps: int
) -> CoupledStepperTrainLoss:
ocean_step_loss = stepper.ocean.build_loss(self.ocean.loss)
atmos_step_loss = stepper.atmosphere.build_loss(self.atmosphere.loss)
max_n_steps_ocean = n_coupled_steps
max_n_steps_atmos = n_coupled_steps * stepper.n_inner_steps
ocean_loss = self.ocean.loss_contributions.build(
ocean_step_loss, stepper.ocean.TIME_DIM
ocean_step_loss, stepper.ocean.TIME_DIM, max_n_steps=max_n_steps_ocean
)
atmos_loss = self.atmosphere.loss_contributions.build(
atmos_step_loss, stepper.atmosphere.TIME_DIM
atmos_step_loss, stepper.atmosphere.TIME_DIM, max_n_steps=max_n_steps_atmos
)
return CoupledStepperTrainLoss(ocean_loss, atmos_loss)

@@ -1552,7 +1553,7 @@ def __init__(
"""
self._stepper = stepper
self._config = config
self._loss = self._config._build_loss(stepper)
self._loss = self._config._build_loss(stepper, config.n_coupled_steps)

@property
def ocean(self) -> Stepper:
@@ -1604,6 +1605,36 @@ def load_state(self, state: dict[str, Any]):
def update_training_history(self, training_job: TrainingJob) -> None:
self._stepper.update_training_history(training_job)

def _accumulate_step_loss(
self,
gen_step: ComponentStepPrediction,
forward_data: TensorMapping,
time_dim: int,
n_ensemble: int,
optimization: OptimizationABC,
metrics: ComponentStepMetrics,
output_list: list[ComponentEnsembleStepPrediction],
) -> None:
target_step = {
k: v.select(time_dim, gen_step.step) for k, v in forward_data.items()
}
ensemble_step = ComponentEnsembleStepPrediction(
realm=gen_step.realm,
data=unfold_ensemble_dim(gen_step.data, n_ensemble),
step=gen_step.step,
)
target_step_ensemble = add_ensemble_dim(target_step)
step_loss = self._loss(ensemble_step, target_step_ensemble)
if step_loss is not None:
label = f"loss/{gen_step.realm}_step_{gen_step.step}"
metrics.add_metric(label, step_loss.detach(), gen_step.realm)
optimization.accumulate_loss(step_loss)
output_list.append(
ensemble_step.detach_if_using_gradient_accumulation(
optimization
) # eagerly detach
)

def _accumulate_loss(
self,
data: CoupledBatchData,
Expand All @@ -1630,36 +1661,48 @@ def _accumulate_loss(
input_data,
data_ensemble,
optimization,
step_is_optimized=self._loss.step_is_optimized,
)
output_iterator = iter(output_generator)
output_list: list[ComponentEnsembleStepPrediction] = []
for gen_step in output_generator:
if gen_step.realm == "ocean":
target_step = {
k: v.select(self.ocean.TIME_DIM, gen_step.step)
for k, v in ocean_forward_data.data.items()
}
else:
target_step = {
k: v.select(self.atmosphere.TIME_DIM, gen_step.step)
for k, v in atmos_forward_data.data.items()
}
ensemble_step = ComponentEnsembleStepPrediction(
realm=gen_step.realm,
data=unfold_ensemble_dim(gen_step.data, n_ensemble),
step=gen_step.step,
)
target_step_ensemble = add_ensemble_dim(target_step)
step_loss = self._loss(ensemble_step, target_step_ensemble)
if step_loss is not None:
label = f"loss/{gen_step.realm}_step_{gen_step.step}"
metrics.add_metric(label, step_loss.detach(), gen_step.realm)
optimization.accumulate_loss(step_loss)
output_list.append(
ensemble_step.detach_if_using_gradient_accumulation(
optimization
) # eagerly detach
)
n_outer_steps = data.ocean_data.n_timesteps - self.n_ic_timesteps
for i_outer in range(n_outer_steps):
for i_inner in range(self.n_inner_steps):
global_atmos_step = i_outer * self.n_inner_steps + i_inner
optimize = self._loss.step_is_optimized(
"atmosphere",
global_atmos_step,
)
jpdunc23 (Member, Author) commented on lines +1671 to +1674:

This was previously handled in CoupledStepper.get_prediction_generator() but refactored to be handled here in CoupledTrainStepper._accumulate_loss().

grad_context = contextlib.nullcontext() if optimize else torch.no_grad()
with grad_context:
gen_step = next(output_iterator)
Contributor:

Consider adding a check like:

    assert gen_step.realm == "atmosphere" and gen_step.step == global_atmos_step

since we now assume the generator always does n_inner_steps of atmosphere then 1 ocean step. This matches get_prediction_generator, but since they are in two different spots, better to add a check to make sure it's not mismatched.

jpdunc23 (Member, Author):

Added

assert (
gen_step.realm == "atmosphere"
and gen_step.step == global_atmos_step
)
self._accumulate_step_loss(
gen_step=gen_step,
forward_data=atmos_forward_data.data,
time_dim=self.atmosphere.TIME_DIM,
n_ensemble=n_ensemble,
optimization=optimization,
metrics=metrics,
output_list=output_list,
)
optimize = self._loss.step_is_optimized("ocean", i_outer)
grad_context = contextlib.nullcontext() if optimize else torch.no_grad()
with grad_context:
gen_step = next(output_iterator)
Contributor:

Same as the inner loop, also check here:

    assert gen_step.realm == "ocean" and gen_step.step == i_outer

jpdunc23 (Member, Author):

Added

assert gen_step.realm == "ocean" and gen_step.step == i_outer
self._accumulate_step_loss(
gen_step=gen_step,
forward_data=ocean_forward_data.data,
time_dim=self.ocean.TIME_DIM,
n_ensemble=n_ensemble,
optimization=optimization,
metrics=metrics,
output_list=output_list,
)

return output_list

def train_on_batch(
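The refactor above moves gradient-context selection out of CoupledStepper.get_prediction_generator() and into CoupledTrainStepper._accumulate_loss(): the caller now wraps each next() on the prediction iterator in either a null context or a no-grad context. A runnable sketch of that control flow, with torch.no_grad replaced by a minimal stand-in so the example runs without torch (GradMode, no_grad, and run_coupled_steps are illustrative names, not fme APIs):

```python
import contextlib


class GradMode:
    """Stand-in for torch's global grad mode flag."""
    enabled = True


@contextlib.contextmanager
def no_grad():
    # Mimics torch.no_grad(): disable grad inside the block, then restore.
    prev = GradMode.enabled
    GradMode.enabled = False
    try:
        yield
    finally:
        GradMode.enabled = prev


def run_coupled_steps(n_outer_steps, n_inner_steps, step_is_optimized):
    """Record (realm, step, grad_enabled) in the order _accumulate_loss
    would consume predictions: n_inner_steps atmosphere steps, then one
    ocean step, per outer (coupled) step."""
    executed = []
    for i_outer in range(n_outer_steps):
        for i_inner in range(n_inner_steps):
            step = i_outer * n_inner_steps + i_inner
            ctx = (contextlib.nullcontext()
                   if step_is_optimized("atmosphere", step) else no_grad())
            with ctx:
                # next(output_iterator) would go here; record grad mode instead
                executed.append(("atmosphere", step, GradMode.enabled))
        ctx = (contextlib.nullcontext()
               if step_is_optimized("ocean", i_outer) else no_grad())
        with ctx:
            executed.append(("ocean", i_outer, GradMode.enabled))
    return executed
```

With optimize_last_step_only-style predicates, only the final steps of each realm run with gradients enabled; earlier steps are evaluated under the no-grad context, which is the memory saving this option targets.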