Don't add unoptimized steps to computational graph in coupled training#1013
Conversation
Slide here showing the reduction in GPU memory util.
mcgibbon
left a comment
Comment: Not something you need to change, but noting the separation of responsibilities is different between the coupled code and ace. In ace, the TrainStepper is responsible for deciding/knowing which steps should be optimized, keeping the loss object a simpler "gets the loss on a particular step" object. Here the loss defines the loss on a series of steps in a window, though the way it's called to compute the loss is still by passing particular steps.
I think this leads to more coupling between the train stepper and the loss, but also, I can see the feeling that because the window of losses is more complicated in the coupled case, it's nice to pull it out into a level other than the stepper.
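The two designs being contrasted can be sketched roughly as follows. This is an illustrative mock, assuming simplified interfaces; class and method names other than `step_is_optimized` are hypothetical and not taken from the ace/fme codebase.

```python
from typing import Protocol


class StepLoss(Protocol):
    """Scores a single step; knows nothing about the window."""

    def loss_on_step(self, step: int) -> float: ...


# Design 1 (ace-style, hypothetical sketch): the train stepper decides
# which steps are optimized; the loss stays a simple per-step object.
class TrainStepper:
    def __init__(self, loss: StepLoss, optimized_steps: set[int]):
        self.loss = loss
        self.optimized_steps = optimized_steps

    def window_loss(self, n_steps: int) -> float:
        return sum(
            self.loss.loss_on_step(s)
            for s in range(n_steps)
            if s in self.optimized_steps
        )


# Design 2 (coupled-style, hypothetical sketch): the loss itself knows
# the per-step weights over the window, and the stepper queries it.
class WindowLoss:
    def __init__(self, weights: list[float]):
        self.weights = weights

    def step_is_optimized(self, step: int) -> bool:
        # A step with zero loss weight contributes nothing to the loss.
        return self.weights[step] > 0.0
```

The trade-off described above is visible here: in design 2 the stepper must call back into the loss (`step_is_optimized`), coupling the two, but the windowing logic lives in one place.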
fme/coupled/stepper.py
Outdated
```python
initial_condition: CoupledPrognosticState,
forcing_data: CoupledBatchData,
optimizer: OptimizationABC,
step_is_optimized: Callable[[str, int], bool] | None = None,
```
Issue: It took me a while to understand what this was doing. At first I misread the code below and thought this argument overrides a default implementation that calls `self.step_is_optimized`, but then I noticed there's no `self.` below.
Suggestion: I think the behavior would be clear and the logic below simpler if you made the default lambda n, c: True or something similar.
```diff
- step_is_optimized: Callable[[str, int], bool] | None = None,
+ step_is_optimized: Callable[[str, int], bool] = lambda n, c: True,
```
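A minimal sketch of why the default-`True` callable simplifies the logic: with no `None` sentinel, the loop body can call `step_is_optimized` unconditionally. The generator below is illustrative; only the `step_is_optimized` parameter and its signature come from the reviewed code, and the `"ocean"` component name is a made-up placeholder.

```python
from typing import Callable, Iterator, Tuple


def get_prediction_generator(
    n_steps: int,
    # Defaulting to "every step is optimized" removes the need for an
    # `if step_is_optimized is None` branch inside the loop.
    step_is_optimized: Callable[[str, int], bool] = lambda name, step: True,
) -> Iterator[Tuple[int, bool]]:
    for step in range(n_steps):
        # No None check needed: the callable is always present.
        yield step, step_is_optimized("ocean", step)


# With the default, every step is treated as optimized:
print([opt for _, opt in get_prediction_generator(3)])  # [True, True, True]
```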
mcgibbon
left a comment
Approving pending the line suggestion or something similar.
Agreed. Will refactor as in the implementation in #868 when I get back to that PR.
Avoid adding unoptimized steps (i.e., those where `LossContributionsConfig` settings result in 0 loss weight) to the computational graph by computing those steps with `torch.no_grad()`. In a production job, this change resulted in a ~13% decrease in GPU memory utilization.

Changes:

- Adds a `step_is_optimized()` helper method to `CoupledStepperTrainLoss`, which can be passed to `CoupledStepper.get_prediction_generator()` via its new argument of the same name.

Tests added
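The control flow this PR describes can be sketched as below. To keep the example dependency-free, a toy `no_grad()` context manager stands in for `torch.no_grad()` and merely records whether gradient tracking would be enabled; the `run_window` function, the `"ocean"` component name, and the flag bookkeeping are all hypothetical illustrations of the pattern, not the actual `CoupledStepper` implementation.

```python
from contextlib import contextmanager
from typing import Callable, List

# Toy stand-in for torch's global gradient-tracking state.
GRAD_ENABLED = True


@contextmanager
def no_grad():
    """Stand-in for torch.no_grad(): disables grad tracking in the block."""
    global GRAD_ENABLED
    prev, GRAD_ENABLED = GRAD_ENABLED, False
    try:
        yield
    finally:
        GRAD_ENABLED = prev


def run_window(
    n_steps: int,
    step_is_optimized: Callable[[str, int], bool],
    component: str = "ocean",  # placeholder component name
) -> List[bool]:
    """Run each step, skipping graph construction for unoptimized steps."""
    grad_flags = []
    for step in range(n_steps):
        if step_is_optimized(component, step):
            # Normal forward pass: step participates in backprop.
            grad_flags.append(GRAD_ENABLED)
        else:
            # Zero loss weight: compute the step without building the
            # computational graph, saving activation memory.
            with no_grad():
                grad_flags.append(GRAD_ENABLED)
    return grad_flags


# Steps with zero loss weight (here: even-numbered steps) stay out of the graph.
print(run_window(4, lambda name, i: i % 2 == 1))  # [False, True, False, True]
```

The memory saving comes from the unoptimized steps never storing activations for backprop, which is consistent with the ~13% reduction reported above.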