fix(trainer): add mixed_precision_dtype parameter to DefaultTrainer #718

Open
chuyaowang wants to merge 1 commit into constantinpape:main from chuyaowang:fix/mixed-precision-dtype

@chuyaowang

The PRs mentioned in the issue

Allows callers to select the autocast dtype for mixed precision training instead of the hardcoded torch.float16. Defaults to torch.float16 to preserve existing behaviour.

  • Add mixed_precision_dtype: torch.dtype = torch.float16 parameter to DefaultTrainer.__init__ with corresponding docstring entry
  • Store self.mixed_precision_dtype as an instance attribute so the Serializer can round-trip it through checkpoints
  • Create GradScaler only when dtype is float16; bfloat16 has the same exponent range as float32 and does not require gradient scaling
  • Pass dtype=self.mixed_precision_dtype to the autocast context in _train_epoch_mixed and _validate_mixed
  • Select _backprop (no scaler) vs _backprop_mixed (with scaler) based on whether self.scaler is None, driven by the chosen dtype
