Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
328c978
Allow stochastic atmos training in CoupledStepper
jpdunc23 Jan 21, 2026
eb19d3d
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Jan 21, 2026
8533d85
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Feb 21, 2026
81f7329
Fix test
jpdunc23 Feb 21, 2026
2584ab1
Add `optimize_last_step_only` to `LossContributionsConfig`
jpdunc23 Feb 22, 2026
b11d5af
Randomly sample coupled `LossContributions.n_steps`
jpdunc23 Feb 23, 2026
e5131f4
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Mar 25, 2026
604d22c
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Apr 2, 2026
7c9e35c
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 2, 2026
ed07188
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 2, 2026
1f39fb5
Allow for ensemble ocean training
jpdunc23 Apr 2, 2026
106c8fa
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 2, 2026
827dc04
De-duplicate gen step processing in train_on_batch
jpdunc23 Apr 2, 2026
2256741
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 2, 2026
1ceb8d6
Merge branch 'main' into coupled-ft-with-stochastic-atmos
jpdunc23 Apr 6, 2026
138e82e
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 6, 2026
10702a3
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 6, 2026
4f9585e
Add `CoupledTrainStepper._accumulate_loss()`
jpdunc23 Apr 7, 2026
cec0022
Merge branch 'main' of github.com:ai2cm/ace into coupled-ft-with-stoc…
jpdunc23 Apr 8, 2026
dcb0cbd
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 8, 2026
b380c8f
Update tests
jpdunc23 Apr 8, 2026
570fa6f
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 8, 2026
8743338
Address additional review comments
jpdunc23 Apr 8, 2026
76c13e1
Fix unhelpful docstring
jpdunc23 Apr 8, 2026
c97c3cf
Merge branch 'coupled-ft-with-stochastic-atmos' of github.com:ai2cm/a…
jpdunc23 Apr 8, 2026
c992681
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 8, 2026
30ffd90
Merge branch 'main' into coupled-ft-with-stochastic-atmos
jpdunc23 Apr 9, 2026
3722c55
Merge branch 'coupled-ft-with-stochastic-atmos' into feature/coupled-…
jpdunc23 Apr 9, 2026
6e2aa96
Merge branch 'feature/coupled-optimize-last-step-only' into feature/c…
jpdunc23 Apr 9, 2026
d7e12a5
Merge branch 'main' into coupled-ft-with-stochastic-atmos
jpdunc23 Apr 9, 2026
c6d7f35
Merge branch 'coupled-ft-with-stochastic-atmos' into feature/coupled-…
jpdunc23 Apr 9, 2026
ab0cc95
Merge branch 'feature/coupled-optimize-last-step-only' into feature/c…
jpdunc23 Apr 9, 2026
d70af7d
Merge branch 'main' of github.com:ai2cm/ace into feature/coupled-opti…
jpdunc23 Apr 9, 2026
ccc4086
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 9, 2026
3e676ef
Merge branch 'feature/coupled-loss-contrib-random-n_steps' of github.…
jpdunc23 Apr 9, 2026
f14e2e3
Merge branch 'main' into feature/coupled-optimize-last-step-only
jpdunc23 Apr 13, 2026
475046c
Merge branch 'feature/coupled-optimize-last-step-only' into feature/c…
jpdunc23 Apr 13, 2026
82fa43f
Merge branch 'main' of github.com:ai2cm/ace into feature/coupled-opti…
jpdunc23 Apr 15, 2026
d5c00da
Add test with `use_gradient_accumulation` and `optimize_last_step_only`
jpdunc23 Apr 15, 2026
710e0ce
Assert gen_step realm and step
jpdunc23 Apr 15, 2026
2576d7c
Move `n_coupled_steps` to `CoupledTrainStepperConfig`
jpdunc23 Apr 15, 2026
5fc4bed
Update baseline configs
jpdunc23 Apr 15, 2026
05b4518
Merge branch 'feature/coupled-optimize-last-step-only' of github.com:…
jpdunc23 Apr 15, 2026
c175be3
Merge branch 'feature/coupled-loss-contrib-random-n_steps' of github.…
jpdunc23 Apr 15, 2026
78f7108
Merge branch 'main' of github.com:ai2cm/ace into feature/coupled-loss…
jpdunc23 Apr 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions fme/coupled/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch

from fme.ace.stepper.time_length_probabilities import TimeLengthProbabilities
from fme.core.device import get_device
from fme.core.loss import StepLoss
from fme.core.typing_ import TensorDict, TensorMapping
Expand All @@ -28,6 +29,14 @@ class StepLossABC(abc.ABC):
@abc.abstractmethod
def effective_loss_scaling(self) -> TensorDict: ...

def sample_n_steps(self) -> None:
    """Draw a fresh effective n_steps for the upcoming batch.

    The base implementation does nothing; subclasses that support
    stochastic n_steps via ``TimeLengthProbabilities`` override this.
    """
    return None

@abc.abstractmethod
def step_is_optimized(self, step: int) -> bool:
"""Returns True if the given step should contribute to the loss.
Expand All @@ -51,8 +60,10 @@ class LossContributionsConfig:
Configuration for loss contributions.

Parameters:
n_steps: (optional) The number of consecutive steps contributing to the loss,
starting from the first.
n_steps: The number of consecutive steps contributing to the loss,
starting from the first. Can be a float (including ``inf`` for all
steps) or a ``TimeLengthProbabilities`` for stochastic per-batch
sampling.
weight: (optional) Weight applied to each step loss for the given realm.
Each step contributes equally to the total loss.
optimize_last_step_only: If True, only the last step within the training
Expand All @@ -62,7 +73,7 @@ class LossContributionsConfig:

"""

n_steps: float = float("inf")
n_steps: TimeLengthProbabilities | float = float("inf")
weight: float = 1.0
optimize_last_step_only: bool = False

Expand All @@ -72,7 +83,9 @@ def build(
time_dim: int,
max_n_steps: int,
) -> StepLossABC:
if self.n_steps == 0 or self.weight == 0.0:
if self.weight == 0.0:
return NullLossContributions(loss_obj)
if isinstance(self.n_steps, int | float) and self.n_steps == 0:
return NullLossContributions(loss_obj)
return LossContributions(
n_steps=self.n_steps,
Expand Down Expand Up @@ -112,20 +125,29 @@ def __call__(
class LossContributions(StepLossABC):
def __init__(
self,
n_steps: float,
n_steps: TimeLengthProbabilities | float,
weight: float,
optimize_last_step_only: bool,
loss_obj: StepLoss,
time_dim: int,
max_n_steps: int,
):
self._loss = loss_obj
self._n_steps = n_steps
if isinstance(n_steps, TimeLengthProbabilities):
self._n_steps_sampler: TimeLengthProbabilities | None = n_steps
self._n_steps: float = float(n_steps.max_n_forward_steps)
else:
self._n_steps_sampler = None
self._n_steps = n_steps
self._weight = weight
self._optimize_last_step_only = optimize_last_step_only
self._time_dim = time_dim
self._max_n_steps = max_n_steps

def sample_n_steps(self) -> None:
    """Resample the effective number of optimized steps, if stochastic."""
    sampler = self._n_steps_sampler
    if sampler is None:
        # A fixed (float) n_steps was configured; nothing to resample.
        return
    self._n_steps = float(sampler.sample())

@property
def effective_loss_scaling(self) -> TensorDict:
return self._loss.effective_loss_scaling
Expand Down
5 changes: 5 additions & 0 deletions fme/coupled/stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,10 @@ def effective_loss_scaling(self) -> CoupledTensorMapping:
atmosphere=self._loss_objs["atmosphere"].effective_loss_scaling,
)

def sample_n_steps(self) -> None:
    """Delegate per-batch n_steps sampling to each realm's loss object."""
    for realm_loss in self._loss_objs.values():
        realm_loss.sample_n_steps()

def step_is_optimized(
self,
realm: Literal["ocean", "atmosphere"],
Expand Down Expand Up @@ -1731,6 +1735,7 @@ def train_on_batch(
)

metrics = ComponentStepMetrics()
self._loss.sample_n_steps()
optimization.set_mode(self.modules)
with optimization.autocast():
output_list = self._accumulate_loss(
Expand Down
96 changes: 96 additions & 0 deletions fme/coupled/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,102 @@ def test_step_is_optimized_last_step_only_weight_zero():
assert not loss.step_is_optimized(0)


def test_stochastic_n_steps_sample_changes_step_is_optimized():
    """A single-outcome sampler yields the same optimized-step pattern
    before and after ``sample_n_steps`` is called."""
    from fme.ace.stepper.time_length_probabilities import (
        TimeLengthProbabilities,
        TimeLengthProbability,
    )

    always_two = TimeLengthProbabilities(
        outcomes=[TimeLengthProbability(steps=2, probability=1.0)]
    )
    contributions = LossContributionsConfig(n_steps=always_two).build(
        loss_obj=Mock(spec=StepLoss), time_dim=1, max_n_steps=4
    )

    def check_pattern() -> None:
        # Steps 0 and 1 contribute; step 2 does not.
        assert contributions.step_is_optimized(0)
        assert contributions.step_is_optimized(1)
        assert not contributions.step_is_optimized(2)

    # Before sampling, _n_steps is max_n_forward_steps = 2.
    check_pattern()
    # After sampling (deterministic: always 2), same behavior.
    contributions.sample_n_steps()
    check_pattern()


def test_stochastic_n_steps_deterministic_outcome():
    """A sampler whose only outcome is 3 optimizes exactly steps 0-2."""
    from fme.ace.stepper.time_length_probabilities import (
        TimeLengthProbabilities,
        TimeLengthProbability,
    )

    always_three = TimeLengthProbabilities(
        outcomes=[TimeLengthProbability(steps=3, probability=1.0)]
    )
    contributions = LossContributionsConfig(n_steps=always_three).build(
        loss_obj=Mock(spec=StepLoss), time_dim=1, max_n_steps=4
    )
    contributions.sample_n_steps()
    for step in range(3):
        assert contributions.step_is_optimized(step)
    assert not contributions.step_is_optimized(3)


def test_stochastic_n_steps_samples_vary():
    """With multiple outcomes, repeated sampling should eventually produce
    different effective n_steps values."""
    from fme.ace.stepper.time_length_probabilities import (
        TimeLengthProbabilities,
        TimeLengthProbability,
    )

    two_outcomes = TimeLengthProbabilities(
        outcomes=[
            TimeLengthProbability(steps=1, probability=0.5),
            TimeLengthProbability(steps=4, probability=0.5),
        ]
    )
    contributions = LossContributionsConfig(n_steps=two_outcomes).build(
        loss_obj=Mock(spec=StepLoss), time_dim=1, max_n_steps=5
    )
    saw_long = False  # n_steps=4 sampled -> step 3 is optimized
    saw_short = False  # n_steps=1 sampled -> step 1 is not optimized
    for _ in range(100):
        contributions.sample_n_steps()
        saw_long = saw_long or contributions.step_is_optimized(3)
        saw_short = saw_short or not contributions.step_is_optimized(1)
        if saw_long and saw_short:
            break
    assert saw_long, "should sometimes sample n_steps=4"
    assert saw_short, "should sometimes sample n_steps=1"


def test_sample_n_steps_noop_for_float_config():
    """With a plain float n_steps, sampling leaves the step pattern unchanged."""
    contributions = LossContributionsConfig(n_steps=5.0).build(
        loss_obj=Mock(spec=StepLoss), time_dim=1, max_n_steps=5
    )
    contributions.sample_n_steps()
    assert contributions.step_is_optimized(4)
    assert not contributions.step_is_optimized(5)


def test_coupled_stepper_train_loss_sample_n_steps_delegates():
    """``sample_n_steps`` on the coupled loss fans out to both realm losses."""
    from unittest.mock import MagicMock

    realm_mocks = {
        "ocean": MagicMock(spec=StepLossABC),
        "atmosphere": MagicMock(spec=StepLossABC),
    }
    coupled_loss = CoupledStepperTrainLoss(
        ocean_loss=realm_mocks["ocean"],
        atmosphere_loss=realm_mocks["atmosphere"],
    )
    coupled_loss.sample_n_steps()
    for realm_mock in realm_mocks.values():
        realm_mock.sample_n_steps.assert_called_once()


@pytest.mark.parametrize("ocean_config_kwargs", [{"n_steps": 0}, {"weight": 0.0}])
def test_null_loss_contributions(steps_thru_atmos_7, ocean_config_kwargs):
# test LossContributionsConfig with n_steps = 0
Expand Down
50 changes: 50 additions & 0 deletions fme/coupled/test_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1760,3 +1760,53 @@ def test_train_on_batch_optimize_last_step_only_with_n_steps(
# ocean: step 0 is optimized (n_steps=1)
expected_calls = atmos_n_steps + ocean_n_steps
assert len(optimization.accumulate_loss.call_args_list) == expected_calls


def test_train_on_batch_stochastic_n_steps():
    """End-to-end: per-realm single-outcome samplers control how many
    generated steps accumulate loss in ``train_on_batch``."""
    from fme.ace.stepper.time_length_probabilities import (
        TimeLengthProbabilities,
        TimeLengthProbability,
    )

    torch.manual_seed(0)

    # Deterministic sampler: atmosphere always samples n_steps=2,
    # ocean always samples n_steps=1.
    atmos_sampler = TimeLengthProbabilities(
        outcomes=[TimeLengthProbability(steps=2, probability=1.0)]
    )
    ocean_sampler = TimeLengthProbabilities(
        outcomes=[TimeLengthProbability(steps=1, probability=1.0)]
    )
    stepper_config = CoupledTrainStepperConfig(
        n_coupled_steps=1,
        ocean=ComponentTrainingConfig(
            loss=StepLossConfig(type="MSE"),
            loss_contributions=LossContributionsConfig(n_steps=ocean_sampler),
        ),
        atmosphere=ComponentTrainingConfig(
            loss=StepLossConfig(type="MSE"),
            loss_contributions=LossContributionsConfig(n_steps=atmos_sampler),
        ),
    )
    stepper, coupled_batch, _, _ = get_train_stepper_and_batch(
        train_stepper_config=stepper_config,
        ocean_in_names=["sst", "mask_0"],
        ocean_out_names=["sst"],
        atmosphere_in_names=["surface_temperature", "ocean_fraction"],
        atmosphere_out_names=["surface_temperature"],
        n_forward_times_ocean=2,
        n_forward_times_atmosphere=4,
        n_samples=3,
    )
    optimization = Mock(wraps=NullOptimization())
    stepper.train_on_batch(
        data=coupled_batch.data,
        optimization=optimization,
    )
    # atmos: n_steps=2, so steps 0 and 1 are optimized (out of 4 total)
    # ocean: n_steps=1, so step 0 is optimized (out of 2 total)
    assert len(optimization.accumulate_loss.call_args_list) == 2 + 1
Loading