fix(trainer): add mixed_precision_dtype parameter to DefaultTrainer#718
Open
chuyaowang wants to merge 1 commit into
Open
fix(trainer): add mixed_precision_dtype parameter to DefaultTrainer#718chuyaowang wants to merge 1 commit into
chuyaowang wants to merge 1 commit into
Conversation
Allows callers to select the autocast dtype for mixed precision training instead of the hardcoded torch.float16. Defaults to torch.float16 to preserve existing behaviour. - Add `mixed_precision_dtype: torch.dtype = torch.float16` parameter to `DefaultTrainer.__init__` with corresponding docstring entry - Store `self.mixed_precision_dtype` as an instance attribute so the Serializer can round-trip it through checkpoints - Create GradScaler only when dtype is float16; bfloat16 has fp32 range and does not require gradient scaling - Pass `dtype=self.mixed_precision_dtype` to the autocast context in `_train_epoch_mixed` and `_validate_mixed` - Select `_backprop` (no scaler) vs `_backprop_mixed` (with scaler) based on whether `self.scaler` is None, driven by the chosen dtype
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The PRs mentioned in the issue
Allows callers to select the autocast dtype for mixed precision training instead of the hardcoded torch.float16. Defaults to torch.float16 to preserve existing behaviour.
mixed_precision_dtype: torch.dtype = torch.float16parameter toDefaultTrainer.__init__with corresponding docstring entryself.mixed_precision_dtypeas an instance attribute so the Serializer can round-trip it through checkpointsdtype=self.mixed_precision_dtypeto the autocast context in_train_epoch_mixedand_validate_mixed_backprop(no scaler) vs_backprop_mixed(with scaler) based on whetherself.scaleris None, driven by the chosen dtype