-
Notifications
You must be signed in to change notification settings - Fork 41
Add optimize_last_step_only to coupled LossContributionsConfig
#868
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
328c978
eb19d3d
8533d85
81f7329
2584ab1
e5131f4
604d22c
7c9e35c
1f39fb5
106c8fa
827dc04
1ceb8d6
138e82e
4f9585e
cec0022
dcb0cbd
b380c8f
8743338
76c13e1
c97c3cf
30ffd90
3722c55
d7e12a5
c6d7f35
d70af7d
f14e2e3
82fa43f
d5c00da
710e0ce
2576d7c
5fc4bed
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ | |
| import datetime | ||
| import logging | ||
| import pathlib | ||
| from collections.abc import Callable, Generator, Iterable | ||
| from collections.abc import Generator, Iterable | ||
| from typing import Any, Literal | ||
|
|
||
| import dacite | ||
|
|
@@ -1022,7 +1022,6 @@ def get_prediction_generator( | |
| initial_condition: CoupledPrognosticState, | ||
| forcing_data: CoupledBatchData, | ||
| optimizer: OptimizationABC, | ||
| step_is_optimized: Callable[[str, int], bool] = lambda n, c: True, | ||
| ) -> Generator[ComponentStepPrediction, None, None]: | ||
| if ( | ||
| initial_condition.atmosphere_data.as_batch_data().n_timesteps | ||
|
|
@@ -1082,10 +1081,7 @@ def get_prediction_generator( | |
| # predict and yield atmosphere steps | ||
| for i_inner in range(self.n_inner_steps): | ||
| atmos_step_num = i_outer * self.n_inner_steps + i_inner | ||
| optimized = step_is_optimized("atmosphere", atmos_step_num) | ||
| context = contextlib.nullcontext() if optimized else torch.no_grad() | ||
| with context: | ||
| atmos_step = next(atmos_generator) | ||
| atmos_step = next(atmos_generator) | ||
| yield ComponentStepPrediction( | ||
| realm="atmosphere", | ||
| data=atmos_step, | ||
|
|
@@ -1116,21 +1112,16 @@ def get_prediction_generator( | |
| labels=ocean_window.labels, | ||
| ) | ||
| # predict and yield a single ocean step | ||
| ocean_optimized = step_is_optimized("ocean", i_outer) | ||
| ocean_context = ( | ||
| contextlib.nullcontext() if ocean_optimized else torch.no_grad() | ||
| ) | ||
| with ocean_context: | ||
| ocean_step = next( | ||
| iter( | ||
| self.ocean.get_prediction_generator( | ||
| ocean_ic_state, | ||
| ocean_forcings, | ||
| n_forward_steps=1, | ||
| optimizer=optimizer, | ||
| ) | ||
| ocean_step = next( | ||
| iter( | ||
| self.ocean.get_prediction_generator( | ||
| ocean_ic_state, | ||
| ocean_forcings, | ||
| n_forward_steps=1, | ||
| optimizer=optimizer, | ||
| ) | ||
| ) | ||
| ) | ||
| yield ComponentStepPrediction( | ||
| realm="ocean", | ||
| data=ocean_step, | ||
|
|
@@ -1390,7 +1381,11 @@ def effective_loss_scaling(self) -> CoupledTensorMapping: | |
| atmosphere=self._loss_objs["atmosphere"].effective_loss_scaling, | ||
| ) | ||
|
|
||
| def step_is_optimized(self, realm: str, step: int) -> bool: | ||
| def step_is_optimized( | ||
| self, | ||
| realm: Literal["ocean", "atmosphere"], | ||
| step: int, | ||
| ) -> bool: | ||
| return self._loss_objs[realm].step_is_optimized(step) | ||
|
|
||
| def __call__( | ||
|
|
@@ -1420,6 +1415,7 @@ class CoupledTrainStepperConfig: | |
| """Configuration for training-specific aspects of a coupled stepper. | ||
|
|
||
| Parameters: | ||
| n_coupled_steps: Number of forward coupled steps in the optimization. | ||
| ocean: The configuration for the ocean component. | ||
| atmosphere: The configuration for the atmosphere component. | ||
| n_ensemble: The number of ensemble members evaluated for each training | ||
|
|
@@ -1430,6 +1426,7 @@ class CoupledTrainStepperConfig: | |
| fine-tuning a previously-trained coupled stepper. | ||
| """ | ||
|
|
||
| n_coupled_steps: int | ||
| ocean: ComponentTrainingConfig | ||
| atmosphere: ComponentTrainingConfig | ||
| n_ensemble: int = -1 # sentinel value to avoid None typing of attribute | ||
|
|
@@ -1463,14 +1460,18 @@ def __post_init__(self): | |
| else: | ||
| self.n_ensemble = 1 | ||
|
|
||
| def _build_loss(self, stepper: CoupledStepper) -> CoupledStepperTrainLoss: | ||
| def _build_loss( | ||
| self, stepper: CoupledStepper, n_coupled_steps: int | ||
| ) -> CoupledStepperTrainLoss: | ||
| ocean_step_loss = stepper.ocean.build_loss(self.ocean.loss) | ||
| atmos_step_loss = stepper.atmosphere.build_loss(self.atmosphere.loss) | ||
| max_n_steps_ocean = n_coupled_steps | ||
| max_n_steps_atmos = n_coupled_steps * stepper.n_inner_steps | ||
| ocean_loss = self.ocean.loss_contributions.build( | ||
| ocean_step_loss, stepper.ocean.TIME_DIM | ||
| ocean_step_loss, stepper.ocean.TIME_DIM, max_n_steps=max_n_steps_ocean | ||
| ) | ||
| atmos_loss = self.atmosphere.loss_contributions.build( | ||
| atmos_step_loss, stepper.atmosphere.TIME_DIM | ||
| atmos_step_loss, stepper.atmosphere.TIME_DIM, max_n_steps=max_n_steps_atmos | ||
| ) | ||
| return CoupledStepperTrainLoss(ocean_loss, atmos_loss) | ||
|
|
||
|
|
@@ -1552,7 +1553,7 @@ def __init__( | |
| """ | ||
| self._stepper = stepper | ||
| self._config = config | ||
| self._loss = self._config._build_loss(stepper) | ||
| self._loss = self._config._build_loss(stepper, config.n_coupled_steps) | ||
|
|
||
| @property | ||
| def ocean(self) -> Stepper: | ||
|
|
@@ -1604,6 +1605,36 @@ def load_state(self, state: dict[str, Any]): | |
| def update_training_history(self, training_job: TrainingJob) -> None: | ||
| self._stepper.update_training_history(training_job) | ||
|
|
||
| def _accumulate_step_loss( | ||
| self, | ||
| gen_step: ComponentStepPrediction, | ||
| forward_data: TensorMapping, | ||
| time_dim: int, | ||
| n_ensemble: int, | ||
| optimization: OptimizationABC, | ||
| metrics: ComponentStepMetrics, | ||
| output_list: list[ComponentEnsembleStepPrediction], | ||
| ) -> None: | ||
| target_step = { | ||
| k: v.select(time_dim, gen_step.step) for k, v in forward_data.items() | ||
| } | ||
| ensemble_step = ComponentEnsembleStepPrediction( | ||
| realm=gen_step.realm, | ||
| data=unfold_ensemble_dim(gen_step.data, n_ensemble), | ||
| step=gen_step.step, | ||
| ) | ||
| target_step_ensemble = add_ensemble_dim(target_step) | ||
| step_loss = self._loss(ensemble_step, target_step_ensemble) | ||
| if step_loss is not None: | ||
| label = f"loss/{gen_step.realm}_step_{gen_step.step}" | ||
| metrics.add_metric(label, step_loss.detach(), gen_step.realm) | ||
| optimization.accumulate_loss(step_loss) | ||
| output_list.append( | ||
| ensemble_step.detach_if_using_gradient_accumulation( | ||
| optimization | ||
| ) # eagerly detach | ||
| ) | ||
|
|
||
| def _accumulate_loss( | ||
| self, | ||
| data: CoupledBatchData, | ||
|
|
@@ -1630,36 +1661,48 @@ def _accumulate_loss( | |
| input_data, | ||
| data_ensemble, | ||
| optimization, | ||
| step_is_optimized=self._loss.step_is_optimized, | ||
| ) | ||
| output_iterator = iter(output_generator) | ||
| output_list: list[ComponentEnsembleStepPrediction] = [] | ||
| for gen_step in output_generator: | ||
| if gen_step.realm == "ocean": | ||
| target_step = { | ||
| k: v.select(self.ocean.TIME_DIM, gen_step.step) | ||
| for k, v in ocean_forward_data.data.items() | ||
| } | ||
| else: | ||
| target_step = { | ||
| k: v.select(self.atmosphere.TIME_DIM, gen_step.step) | ||
| for k, v in atmos_forward_data.data.items() | ||
| } | ||
| ensemble_step = ComponentEnsembleStepPrediction( | ||
| realm=gen_step.realm, | ||
| data=unfold_ensemble_dim(gen_step.data, n_ensemble), | ||
| step=gen_step.step, | ||
| ) | ||
| target_step_ensemble = add_ensemble_dim(target_step) | ||
| step_loss = self._loss(ensemble_step, target_step_ensemble) | ||
| if step_loss is not None: | ||
| label = f"loss/{gen_step.realm}_step_{gen_step.step}" | ||
| metrics.add_metric(label, step_loss.detach(), gen_step.realm) | ||
| optimization.accumulate_loss(step_loss) | ||
| output_list.append( | ||
| ensemble_step.detach_if_using_gradient_accumulation( | ||
| optimization | ||
| ) # eagerly detach | ||
| ) | ||
| n_outer_steps = data.ocean_data.n_timesteps - self.n_ic_timesteps | ||
| for i_outer in range(n_outer_steps): | ||
| for i_inner in range(self.n_inner_steps): | ||
| global_atmos_step = i_outer * self.n_inner_steps + i_inner | ||
| optimize = self._loss.step_is_optimized( | ||
| "atmosphere", | ||
| global_atmos_step, | ||
| ) | ||
| grad_context = contextlib.nullcontext() if optimize else torch.no_grad() | ||
| with grad_context: | ||
| gen_step = next(output_iterator) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider adding a check like `assert gen_step.realm == "atmosphere" and gen_step.step == global_atmos_step`, since we now assume the generator always does …
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added |
||
| assert ( | ||
| gen_step.realm == "atmosphere" | ||
| and gen_step.step == global_atmos_step | ||
| ) | ||
| self._accumulate_step_loss( | ||
| gen_step=gen_step, | ||
| forward_data=atmos_forward_data.data, | ||
| time_dim=self.atmosphere.TIME_DIM, | ||
| n_ensemble=n_ensemble, | ||
| optimization=optimization, | ||
| metrics=metrics, | ||
| output_list=output_list, | ||
| ) | ||
| optimize = self._loss.step_is_optimized("ocean", i_outer) | ||
| grad_context = contextlib.nullcontext() if optimize else torch.no_grad() | ||
| with grad_context: | ||
| gen_step = next(output_iterator) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as the inner loop, also check here: `assert gen_step.realm == "ocean" and gen_step.step == i_outer`.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added |
||
| assert gen_step.realm == "ocean" and gen_step.step == i_outer | ||
| self._accumulate_step_loss( | ||
| gen_step=gen_step, | ||
| forward_data=ocean_forward_data.data, | ||
| time_dim=self.ocean.TIME_DIM, | ||
| n_ensemble=n_ensemble, | ||
| optimization=optimization, | ||
| metrics=metrics, | ||
| output_list=output_list, | ||
| ) | ||
|
|
||
| return output_list | ||
|
|
||
| def train_on_batch( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was previously handled in `CoupledStepper.get_prediction_generator()`
but was refactored to be handled here in `CoupledTrainStepper._accumulate_loss()`.