Allow stochastic CoupledStepper training #750

Merged: jpdunc23 merged 14 commits into main from coupled-ft-with-stochastic-atmos on Apr 9, 2026

@jpdunc23 (Member) commented on Jan 21, 2026

Adds n_ensemble to CoupledTrainStepperConfig and handling for the ensemble dimension in CoupledTrainStepper.

Changes:

  • CoupledTrainStepper.__init__ now takes config: CoupledTrainStepperConfig as an argument instead of loss: CoupledStepperTrainLoss.

  • Tests added

@jpdunc23 changed the title from "Allow stochastic atmos training in CoupledStepper" to "Allow stochastic CoupledStepper training" on Apr 2, 2026
@jpdunc23 marked this pull request as ready for review on April 6, 2026 21:11

gen_data=ocean_gen_data,
target_data=add_ensemble_dim(dict(ocean_forward_data.data)),
- time=gen_data.ocean_data.time,
+ time=ocean_forward_data.time,
Contributor

Is this equivalent to gen_data.ocean_data.time when no ensemble is used?

Member Author

Yes, the target data is ocean_forward_data, so this is just a different but equivalent source for the time dim.

"""
atmos_gen_data = process_ensemble_prediction_generator_list(
    [
        EnsembleTensorDict(x.data)  # FIXME: fix output_list typing
Contributor

Maybe just cast x.data to EnsembleTensorDict?

Member Author

Refactored things a bit to get the typing in line with fme/ace/stepper/single_module.py. CoupledTrainStepper now deals with ComponentEnsembleStepPrediction, while CoupledStepper uses ComponentStepPrediction.

gen_data=atmos_gen_data,
target_data=add_ensemble_dim(dict(atmos_forward_data.data)),
- time=gen_data.atmosphere_data.time,
+ time=atmos_forward_data.time,  # Use original time (not broadcasted)
Contributor

So is this the reason for the switch to forward data instead of gen for ocean too?

Member Author

The main reason is that there is no more gen_data, and ocean_gen_data and atmos_gen_data here are no longer BatchData objects. Getting the time attribute from ocean_forward_data or atmos_forward_data is equivalent to what we had before (note that this happens before stepped = stepped.prepend_initial_condition(ic)).
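
A toy illustration of why the un-broadcast time is an equivalent source may help here. This is only a sketch: `ToyBatch` is a hypothetical stand-in, not a BatchData object, and it assumes that broadcasting over the ensemble repeats the time coordinate along with the data.

```python
from dataclasses import dataclass


@dataclass
class ToyBatch:
    """Hypothetical batch: one data value and one time stamp per sample."""

    data: list
    time: list

    def broadcast_ensemble(self, n_ensemble: int) -> "ToyBatch":
        # Assumed behavior: data and time are both repeated once per member.
        def repeat(values: list) -> list:
            return [v for v in values for _ in range(n_ensemble)]

        return ToyBatch(data=repeat(self.data), time=repeat(self.time))


forward = ToyBatch(data=[1.0, 2.0], time=["2020-01-01", "2020-01-06"])

# With n_ensemble=1 the broadcast time is identical to the original time.
same_when_deterministic = forward.broadcast_ensemble(1).time == forward.time

# With n_ensemble=2 the broadcast time carries duplicated entries, while the
# original keeps exactly one entry per sample.
broadcast_time = forward.broadcast_ensemble(2).time
```

Under these assumptions, reading the time from the forward data gives one entry per sample regardless of n_ensemble, which is what a target with a singleton ensemble dimension needs.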

self._stepper = stepper
- self._loss = loss
+ self._config = config
+ self._loss = self._config._build_loss(stepper)
Contributor

[nit] This would make CoupledTrainStepper harder to build in tests, since you now need the full config. You can still just build the loss in get_train_stepper, right?

Member Author

The changes to the tests are luckily minor since we were already building the CoupledTrainStepperConfig and using its get_train_stepper() method to build the CoupledTrainStepper. I can see some benefit to letting the loss be built outside of the CoupledTrainStepper init, but on the other hand this makes the init args identical to the fme.ace.TrainStepper init and avoids having to add an n_ensemble: int arg, both of which are nice. Don't feel super strongly one way or another.

Contributor

Okay, let's keep it as is to be consistent with fme.ace.TrainStepper.
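
The design settled on in this thread can be sketched roughly as follows. This is a minimal illustration, not the fme code: `ToyConfig`, `ToyTrainStepper`, and `ToyLoss` are hypothetical stand-ins, and only the shape of the constructor (a config argument instead of a prebuilt loss) mirrors what the PR describes.

```python
from dataclasses import dataclass


class ToyLoss:
    """Hypothetical stand-in for the coupled training loss."""

    def __init__(self, stepper_name: str):
        self.stepper_name = stepper_name


class ToyTrainStepper:
    """Takes (stepper, config) and builds its own loss from the config."""

    def __init__(self, stepper: str, config: "ToyConfig"):
        self._stepper = stepper
        self._config = config
        # The loss is derived from the config rather than passed in.
        self._loss = config.build_loss(stepper)

    @property
    def n_ensemble(self) -> int:
        # No separate n_ensemble argument is needed: it comes from the config.
        return self._config.n_ensemble


@dataclass
class ToyConfig:
    """Hypothetical stand-in for a train stepper config."""

    n_ensemble: int = 1

    def build_loss(self, stepper: str) -> ToyLoss:
        return ToyLoss(stepper)

    def get_train_stepper(self, stepper: str) -> ToyTrainStepper:
        return ToyTrainStepper(stepper, self)


train_stepper = ToyConfig(n_ensemble=2).get_train_stepper("coupled")
```

Tests that already go through the config's factory method are unaffected by the constructor change, which matches the author's observation that the test changes were minor.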

# Ensemble support: broadcast atmosphere data for ensemble training
n_ensemble = self._config.n_ensemble
atmos_data_ensemble = data.atmosphere_data.broadcast_ensemble(n_ensemble)
ocean_data_ensemble = data.ocean_data.broadcast_ensemble(n_ensemble)
Contributor

My understanding is that you broadcast ocean data as well even though we currently only support training stochastic atmosphere, and the stochastic losses are propagated to ocean via the surface forcing variables. Can you add a short description in n_ensemble?

Member Author

I added some info about stochastic training assumptions to the CoupledTrainStepper docstring. Let me know if this is clearer now.

Contributor

Thanks, the docs are good.
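
The reviewer's point about broadcasting both components can be illustrated with a toy version. Assumptions: the `broadcast_ensemble` function below is a hypothetical list-based helper (the method in the diff lives on the batch data objects and presumably operates on tensors), and repeating each sample contiguously is an assumed ordering.

```python
def broadcast_ensemble(samples: list, n_ensemble: int) -> list:
    """Repeat each sample n_ensemble times, keeping members contiguous."""
    return [s for s in samples for _ in range(n_ensemble)]


n_ensemble = 3
atmos = broadcast_ensemble(["atm_0", "atm_1"], n_ensemble)
ocean = broadcast_ensemble(["ocn_0", "ocn_1"], n_ensemble)

# Each atmosphere member is paired with an ocean copy of the same sample, so
# a per-member surface forcing has a matching ocean state to act on.
pairs = list(zip(atmos, ocean))
```

If only the atmosphere were broadcast, the two lists would differ in length and the member-wise pairing would no longer line up.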

Comment on lines +1658 to +1667
else:
    gen_data = self._stepper._process_prediction_generator_list(
        output_list, data_ensemble
    )
    ocean_gen_data = unfold_ensemble_dim(
        dict(gen_data.ocean_data.data), n_ensemble=1
    )
    atmos_gen_data = unfold_ensemble_dim(
        dict(gen_data.atmosphere_data.data), n_ensemble=1
    )
Contributor

Would be better if _process_ensemble_prediction_generator_list also accepts n_ensemble=1 so you don't need to have if/else here.

Member Author

Was able to remove these if n_ensemble > 1: blocks.
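
One way to see why the branch can go away: if the unfold step treats n_ensemble=1 as an ordinary case, a single code path covers both deterministic and stochastic training. The sketch below rests on assumptions: `unfold_ensemble` is a hypothetical list-based helper, not fme's unfold_ensemble_dim, and it assumes the members of one sample are stored contiguously along the flattened batch dimension.

```python
def unfold_ensemble(flat: list, n_ensemble: int) -> list[list]:
    """Reshape a flat [sample * ensemble] list into [sample][ensemble].

    Assumes the members of one sample are contiguous in `flat`.
    """
    if n_ensemble < 1:
        raise ValueError("n_ensemble must be >= 1")
    if len(flat) % n_ensemble != 0:
        raise ValueError("length must be divisible by n_ensemble")
    return [flat[i : i + n_ensemble] for i in range(0, len(flat), n_ensemble)]


# The same call handles both cases; n_ensemble=1 yields singleton groups.
deterministic = unfold_ensemble(["a", "b"], n_ensemble=1)
stochastic = unfold_ensemble(["a0", "a1", "b0", "b1"], n_ensemble=2)
```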

Comment on lines -758 to -764
def detach(self, optimizer: OptimizationABC) -> "ComponentStepPrediction":
    """Detach the data tensor map from the computation graph."""
    return ComponentStepPrediction(
        realm=self.realm,
        data=optimizer.detach_if_using_gradient_accumulation(self.data),
        step=self.step,
    )
Member Author

ComponentStepPrediction is not returned by the CoupledTrainStepper methods now, so it doesn't need this detach() method.

Comment on lines -767 to -796
class CoupledStepperTrainLoss:
    def __init__(
        self,
        ocean_loss: StepLossABC,
        atmosphere_loss: StepLossABC,
    ):
        self._loss_objs = {
            "ocean": ocean_loss,
            "atmosphere": atmosphere_loss,
        }

    @property
    def effective_loss_scaling(self) -> CoupledTensorMapping:
        return CoupledTensorMapping(
            ocean=self._loss_objs["ocean"].effective_loss_scaling,
            atmosphere=self._loss_objs["atmosphere"].effective_loss_scaling,
        )

    def step_is_optimized(self, realm: str, step: int) -> bool:
        return self._loss_objs[realm].step_is_optimized(step)

    def __call__(
        self,
        prediction: ComponentStepPrediction,
        target_data: TensorMapping,
    ) -> torch.Tensor | None:
        loss_obj = self._loss_objs[prediction.realm]
        if loss_obj.step_is_optimized(prediction.step):
            return loss_obj(prediction, target_data)
        return None
Member Author

Moved below for a bit better organization. The only change is to the typing of the prediction arg to __call__, which is now prediction: ComponentEnsembleStepPrediction.
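
For readers unfamiliar with the class, here is a small usage sketch of the dispatch-by-realm pattern it implements. Everything below is a toy under stated assumptions: `ToyPrediction`, `ToyStepLoss`, and `ToyCoupledLoss` are hypothetical stand-ins, scalars replace tensors, and the rule that a step is optimized when it is below `n_steps` is invented purely for illustration.

```python
from dataclasses import dataclass


@dataclass
class ToyPrediction:
    """Hypothetical stand-in for a component step prediction."""

    realm: str
    step: int
    value: float


class ToyStepLoss:
    """Hypothetical loss that only optimizes steps below `n_steps`."""

    def __init__(self, n_steps: int):
        self._n_steps = n_steps

    def step_is_optimized(self, step: int) -> bool:
        return step < self._n_steps

    def __call__(self, prediction: ToyPrediction, target: float) -> float:
        return (prediction.value - target) ** 2


class ToyCoupledLoss:
    """Same dispatch-by-realm pattern as the class quoted above."""

    def __init__(self, ocean_loss: ToyStepLoss, atmosphere_loss: ToyStepLoss):
        self._loss_objs = {"ocean": ocean_loss, "atmosphere": atmosphere_loss}

    def __call__(self, prediction: ToyPrediction, target: float):
        loss_obj = self._loss_objs[prediction.realm]
        if loss_obj.step_is_optimized(prediction.step):
            return loss_obj(prediction, target)
        # Steps that are not optimized contribute no loss.
        return None


loss = ToyCoupledLoss(ocean_loss=ToyStepLoss(1), atmosphere_loss=ToyStepLoss(4))
atmos_value = loss(ToyPrediction("atmosphere", step=2, value=3.0), target=1.0)
ocean_value = loss(ToyPrediction("ocean", step=2, value=3.0), target=1.0)
```

The caller has to handle the None case, since each realm decides independently whether a given step is optimized.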

def step(self) -> int:
    return self._step

def detach_if_using_gradient_accumulation(
Member Author

This replaces ComponentStepPrediction.detach()

@jpdunc23 requested a review from elynnwu on April 9, 2026 15:58
@jpdunc23 enabled auto-merge (squash) on April 9, 2026 16:49
@jpdunc23 merged commit 366f117 into main on Apr 9, 2026
7 checks passed
@jpdunc23 deleted the coupled-ft-with-stochastic-atmos branch on April 9, 2026 16:58

2 participants