Allow stochastic CoupledStepper training #750
Conversation
```diff
     gen_data=ocean_gen_data,
     target_data=add_ensemble_dim(dict(ocean_forward_data.data)),
-    time=gen_data.ocean_data.time,
+    time=ocean_forward_data.time,
```
Is this equivalent to `gen_data.ocean_data.time` when no ensemble is used?
Ya, the target data is `ocean_forward_data`, so this is just a different but equivalent source for the time dim.
fme/coupled/stepper.py

```python
"""
atmos_gen_data = process_ensemble_prediction_generator_list(
    [
        EnsembleTensorDict(x.data)  # FIXME: fix output_list typing
```
Maybe just cast `x.data` to `EnsembleTensorDict`?
Refactored things a bit to get the typing in line with `fme/ace/stepper/single_module.py`. `CoupledTrainStepper` now deals with `ComponentEnsembleStepPrediction`, while `CoupledStepper` uses `ComponentStepPrediction`.
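To make the distinction between the two prediction types concrete, here is a minimal sketch. The dataclass names and the `to_ensemble` helper are illustrative stand-ins based on this thread, not the real fme API; the real types carry tensor maps, while plain lists are used here to keep the sketch self-contained.

```python
from dataclasses import dataclass

@dataclass
class ComponentStepPrediction:
    realm: str
    data: dict  # fields shaped [batch, ...]
    step: int

@dataclass
class ComponentEnsembleStepPrediction:
    realm: str
    data: dict  # fields shaped [ensemble, batch, ...]
    step: int

def to_ensemble(pred: ComponentStepPrediction, n_ensemble: int) -> ComponentEnsembleStepPrediction:
    # Repeat each field along a new leading ensemble axis.
    data = {name: [field] * n_ensemble for name, field in pred.data.items()}
    return ComponentEnsembleStepPrediction(pred.realm, data, pred.step)

pred = ComponentStepPrediction("atmosphere", {"T": [290.0, 291.0]}, step=0)
ens = to_ensemble(pred, n_ensemble=2)
```

The point of the split is that the train stepper always sees the ensemble-typed prediction, even when `n_ensemble` is 1, so no casts are needed downstream.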
fme/coupled/stepper.py

```diff
     gen_data=atmos_gen_data,
     target_data=add_ensemble_dim(dict(atmos_forward_data.data)),
-    time=gen_data.atmosphere_data.time,
+    time=atmos_forward_data.time,  # Use original time (not broadcasted)
```
So is this the reason for the switch to forward data instead of gen for ocean too?
The main reason is that there is no longer a `gen_data`; `ocean_gen_data` and `atmos_gen_data` here are no longer `BatchData` objects, and getting the `time` attribute from `ocean_forward_data` or `atmos_forward_data` is equivalent to what we had before (note that this happens before `stepped = stepped.prepend_initial_condition(ic)`).
```diff
     self._stepper = stepper
-    self._loss = loss
+    self._config = config
+    self._loss = self._config._build_loss(stepper)
```
[nit] This would make `CoupledTrainStepper` harder to build in tests, since you now need the full config. You can still just build the loss in `get_train_stepper`, right?
The changes to the tests are luckily minor, since we were already building the `CoupledTrainStepperConfig` and using its `get_train_stepper()` method to build the `CoupledTrainStepper`. I can see some benefit to letting the loss be built outside of the `CoupledTrainStepper` init, but on the other hand this makes the init args identical to the `fme.ace.TrainStepper` init and avoids having to add an `n_ensemble: int` arg, both of which are nice. Don't feel super strongly one way or the other.
Okay, let's keep it as is to be consistent with `fme.ace.TrainStepper`.
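The design being agreed on here can be sketched as follows. This is a stand-in, not the real fme implementation: the config fields, the weighted-MAE loss body, and the `stepper` argument are all assumptions for illustration; only the shape of the init (config in, loss built internally via `_build_loss`) reflects the discussion above.

```python
from dataclasses import dataclass

@dataclass
class CoupledTrainStepperConfig:
    n_ensemble: int = 1
    loss_weight: float = 1.0  # hypothetical field, for illustration only

    def _build_loss(self, stepper):
        # Stand-in loss: weighted mean absolute error.
        def loss(pred, target):
            errors = [abs(p - t) for p, t in zip(pred, target)]
            return self.loss_weight * sum(errors) / len(errors)
        return loss

class CoupledTrainStepper:
    def __init__(self, stepper, config: CoupledTrainStepperConfig):
        self._stepper = stepper
        self._config = config
        # The loss is derived from the config rather than passed in,
        # keeping the init args parallel to fme.ace.TrainStepper.
        self._loss = self._config._build_loss(stepper)

train_stepper = CoupledTrainStepper(stepper=None, config=CoupledTrainStepperConfig(loss_weight=2.0))
```

The trade-off discussed above: tests must construct a full config, but the init signature stays identical to `fme.ace.TrainStepper` and needs no extra `n_ensemble` argument.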
fme/coupled/stepper.py

```python
# Ensemble support: broadcast atmosphere data for ensemble training
n_ensemble = self._config.n_ensemble
atmos_data_ensemble = data.atmosphere_data.broadcast_ensemble(n_ensemble)
ocean_data_ensemble = data.ocean_data.broadcast_ensemble(n_ensemble)
```
My understanding is that you broadcast the ocean data as well, even though we currently only support training a stochastic atmosphere, and the stochastic losses are propagated to the ocean via the surface forcing variables. Can you add a short description to `n_ensemble`?
I added some info about stochastic training assumptions to the `CoupledTrainStepper` docstring. Lmk if this is more clear now.
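A minimal sketch of the broadcast semantics being discussed, assuming fields are nested lists with a leading batch axis; `broadcast_ensemble` here is a hypothetical stand-in for the fme method of the same name.

```python
def broadcast_ensemble(data: dict, n_ensemble: int) -> dict:
    """Repeat each field along a new leading ensemble axis."""
    return {name: [field for _ in range(n_ensemble)] for name, field in data.items()}

# Both realms get broadcast, but only the atmosphere produces
# stochastic samples; ocean members start identical and diverge only
# through the surface forcings they receive from the atmosphere.
ocean = {"sst": [[300.0, 301.0]]}                    # [batch=1, x=2]
ocean_ens = broadcast_ensemble(ocean, n_ensemble=3)  # [ensemble=3, batch=1, x=2]
```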
fme/coupled/stepper.py

```python
else:
    gen_data = self._stepper._process_prediction_generator_list(
        output_list, data_ensemble
    )
    ocean_gen_data = unfold_ensemble_dim(
        dict(gen_data.ocean_data.data), n_ensemble=1
    )
    atmos_gen_data = unfold_ensemble_dim(
        dict(gen_data.atmosphere_data.data), n_ensemble=1
    )
```
Would be better if `_process_ensemble_prediction_generator_list` also accepted `n_ensemble=1` so you don't need the if/else here.
Was able to remove these `if n_ensemble > 1:` blocks.
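The simplification suggested above can be sketched like this. `unfold_ensemble_dim` is modeled on the helper in the snippet, but the list-based implementation here is an assumption for illustration: the key property is that `n_ensemble=1` is handled by the same code path, which is what lets the caller drop the if/else.

```python
def unfold_ensemble_dim(data: dict, n_ensemble: int) -> dict:
    """Split a flat leading axis of size ensemble * batch into
    [ensemble][batch]; with n_ensemble=1 the data is simply wrapped,
    so callers no longer need an `if n_ensemble > 1:` branch."""
    out = {}
    for name, rows in data.items():
        batch = len(rows) // n_ensemble
        out[name] = [rows[i * batch:(i + 1) * batch] for i in range(n_ensemble)]
    return out

flat = {"T": [1.0, 2.0, 3.0, 4.0]}  # ensemble * batch = 4
two_members = unfold_ensemble_dim(flat, n_ensemble=2)
one_member = unfold_ensemble_dim(flat, n_ensemble=1)
```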
```python
def detach(self, optimizer: OptimizationABC) -> "ComponentStepPrediction":
    """Detach the data tensor map from the computation graph."""
    return ComponentStepPrediction(
        realm=self.realm,
        data=optimizer.detach_if_using_gradient_accumulation(self.data),
        step=self.step,
    )
```
`ComponentStepPrediction` is not returned by the `CoupledTrainStepper` methods now, so it doesn't need this `detach()` method.
```python
class CoupledStepperTrainLoss:
    def __init__(
        self,
        ocean_loss: StepLossABC,
        atmosphere_loss: StepLossABC,
    ):
        self._loss_objs = {
            "ocean": ocean_loss,
            "atmosphere": atmosphere_loss,
        }

    @property
    def effective_loss_scaling(self) -> CoupledTensorMapping:
        return CoupledTensorMapping(
            ocean=self._loss_objs["ocean"].effective_loss_scaling,
            atmosphere=self._loss_objs["atmosphere"].effective_loss_scaling,
        )

    def step_is_optimized(self, realm: str, step: int) -> bool:
        return self._loss_objs[realm].step_is_optimized(step)

    def __call__(
        self,
        prediction: ComponentStepPrediction,
        target_data: TensorMapping,
    ) -> torch.Tensor | None:
        loss_obj = self._loss_objs[prediction.realm]
        if loss_obj.step_is_optimized(prediction.step):
            return loss_obj(prediction, target_data)
        return None
```
Moved below for a bit better organization. The only change is to the typing of the `prediction` arg to `__call__`, which is now `prediction: ComponentEnsembleStepPrediction`.
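The per-step gating in `__call__` above can be illustrated with a tiny stand-in. `StepLoss` here is hypothetical (not the real `StepLossABC`), and the scalar absolute-error loss is an assumption; the sketch only demonstrates that steps outside the optimized set contribute `None` rather than a loss.

```python
class StepLoss:
    """Stand-in for StepLossABC: optimizes only a configured set of steps."""

    def __init__(self, optimized_steps):
        self._steps = set(optimized_steps)

    def step_is_optimized(self, step: int) -> bool:
        return step in self._steps

    def __call__(self, prediction: float, target: float) -> float:
        return abs(prediction - target)

loss = StepLoss(optimized_steps={0, 2})
# Steps outside the optimized set yield None, mirroring the
# `return None` branch of CoupledStepperTrainLoss.__call__.
results = [loss(1.0, 0.0) if loss.step_is_optimized(s) else None for s in range(3)]
```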
```python
def step(self) -> int:
    return self._step

def detach_if_using_gradient_accumulation(
```
This replaces `ComponentStepPrediction.detach()`.
Adds `n_ensemble` to `CoupledTrainStepperConfig` and handling for the ensemble dimension in `CoupledTrainStepper`.

Changes:

- `CoupledTrainStepper` init now has `config: CoupledTrainStepperConfig` as an arg rather than `loss: CoupledStepperTrainLoss`.
- Tests added.