fix(training): auto-select bfloat16 on GPUs without Tensor Cores #1234
Open
chuyaowang wants to merge 1 commit into
fp16 mixed precision causes NaN in Conv2d(512, 512, 3x3) inside the UNETR decoder on GPUs that lack Tensor Cores (e.g. GTX 1660 Ti). Without Tensor Cores, cuDNN accumulates the 4608-element dot product (512 input channels x 3 x 3 kernel) in fp16, which can overflow the fp16 maximum of ~65504. bfloat16 avoids this by sharing fp32's exponent range. Confirmed by a forward hook diagnostic and by a deterministic mode test that ruled out Winograd as the specific cause.

- Add `_get_mixed_precision_dtype(device)` helper that returns torch.float16 for GPUs with Tensor Cores and torch.bfloat16 otherwise
- Detect Tensor Core presence via compute capability and device name:
  - Ampere+ (sm_8.x+): always fp16
  - Volta/Turing (sm_7.x): fp16, except the GTX 16xx series (TU116/TU117)
  - Pascal and below: bfloat16
- Avoid torch.cuda.is_bf16_supported(), as it reflects software-level dtype support and returns True on sm_7.5 with CUDA 12.x regardless of Tensor Core presence
- Wire the dtype through train_sam (JointSamTrainer and SamTrainer) and train_instance_segmentation (DefaultTrainer) via the new mixed_precision_dtype parameter added to torch_em DefaultTrainer
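For readers who want to see the selection rule in code, here is a minimal sketch of the decision logic described above. It is not the helper from this PR: the function names `sketch_dtype_name` and `sketch_mixed_precision_dtype` are hypothetical, and matching GTX 16xx cards by a regex on the device name is an assumption about how the name check could be done. The decision itself is kept free of torch so it can be checked without a GPU.

```python
import re

# Assumption: GTX 16xx cards (TU116/TU117) are recognised by their marketing
# name, e.g. "NVIDIA GeForce GTX 1660 Ti". The PR's actual check may differ.
_GTX_16XX = re.compile(r"GTX\s*16\d{2}", re.IGNORECASE)


def sketch_dtype_name(capability, device_name):
    """Return "float16" if the GPU is expected to have Tensor Cores, else "bfloat16".

    capability is a (major, minor) compute capability tuple.
    """
    major, _minor = capability
    if major >= 8:
        # Ampere and newer: always fp16.
        return "float16"
    if major == 7:
        # Volta/Turing: fp16, except the GTX 16xx series, which lacks Tensor Cores.
        if _GTX_16XX.search(device_name):
            return "bfloat16"
        return "float16"
    # Pascal and below: no Tensor Cores.
    return "bfloat16"


def sketch_mixed_precision_dtype(device):
    """Hypothetical torch wrapper: query the CUDA device and map the name to a dtype."""
    import torch  # imported lazily so the pure logic above works without torch

    capability = torch.cuda.get_device_capability(device)
    name = torch.cuda.get_device_name(device)
    return getattr(torch, sketch_dtype_name(capability, name))
```

On a CUDA machine, `sketch_mixed_precision_dtype(torch.device("cuda:0"))` would return either `torch.float16` or `torch.bfloat16`. How non-CUDA devices are handled is not stated in the description, so the sketch leaves that case out.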
This is one of the PRs mentioned in issue #1214. It is meant to work together with the corresponding PR for torch-em.