fix(training): auto-select bfloat16 on GPUs without Tensor Cores #1234

Open: chuyaowang wants to merge 1 commit into computational-cell-analytics:main from chuyaowang:fix/mixed-precision-dtype

@chuyaowang

This is one of the PRs mentioned in issue #1214. It is meant to work together with the corresponding PR for torch-em.

fp16 mixed precision produces NaNs in Conv2d(512, 512, 3x3) inside the UNETR decoder on GPUs that lack Tensor Cores (e.g. GTX 1660 Ti). Without Tensor Cores, cuDNN accumulates the 4608-element dot product (512 input channels × 3 × 3 kernel positions) in fp16, so the running sum can exceed the fp16 maximum of 65504. bfloat16 avoids this because it shares fp32's exponent range. This was confirmed with a forward-hook diagnostic and a deterministic-mode test, which ruled out Winograd as the specific cause.
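The overflow described above can be illustrated without a GPU. The following is a minimal sketch, not the diagnostic used in this PR: it emulates fp16 rounding with Python's `struct` half-precision format and accumulates a 4608-element dot product one term at a time. The activation and weight values (8.0 and 4.0) are assumptions chosen so that every partial sum is exactly representable until the overflow; real activations are smaller and of mixed sign, and the function names are hypothetical.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite fp16 value


def to_fp16(x: float) -> float:
    """Round x to the nearest fp16 value; out-of-range results become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        # struct refuses to pack values beyond the fp16 range; hardware yields inf.
        return math.copysign(math.inf, x)


def dot_fp16_accumulate(a, b):
    """Dot product whose running sum is rounded to fp16 after every addition."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = to_fp16(acc + to_fp16(x * y))
    return acc


def dot_wide_accumulate(a, b):
    """Same dot product with a wide (Python float) accumulator for comparison."""
    return sum(x * y for x, y in zip(a, b))


# One output element of Conv2d(512, 512, 3x3) sums 512 * 3 * 3 products.
n = 512 * 3 * 3
activations = [8.0] * n  # assumed values, for illustration only
weights = [4.0] * n      # assumed values, for illustration only

fp16_result = dot_fp16_accumulate(activations, weights)
wide_result = dot_wide_accumulate(activations, weights)
print(n, fp16_result, wide_result)
```

With these inputs the fp16 accumulator saturates to infinity once the running sum passes 65504, while the wide accumulator returns the finite total. Once an infinity is present, later operations such as subtracting two infinities yield NaN.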

  • Add _get_mixed_precision_dtype(device) helper that returns torch.float16 for GPUs with Tensor Cores and torch.bfloat16 otherwise
  • Detect Tensor Core presence via compute capability and device name: Ampere+ (sm_8.x+) always fp16; Volta/Turing (sm_7.x) fp16 except GTX 16xx series (TU116/TU117); Pascal and below bfloat16
  • Avoid torch.cuda.is_bf16_supported() as it reflects software-level dtype support and returns True on sm_7.5 with CUDA 12.x regardless of Tensor Core presence
  • Wire dtype through train_sam (JointSamTrainer and SamTrainer) and train_instance_segmentation (DefaultTrainer) via the new mixed_precision_dtype parameter added to torch_em DefaultTrainer
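The selection rules in the list above can be sketched as follows. This is an illustrative reconstruction from the description, not the code in this PR: the function names, the regular expression used to spot GTX 16xx cards, and the lazy `torch` import are all assumptions. Only `torch.cuda.get_device_capability` and `torch.cuda.get_device_name` are existing PyTorch calls.

```python
import re

# GTX 16xx cards (TU116/TU117) report sm_7.5 like other Turing GPUs but have
# no Tensor Cores, so they are matched by device name. Pattern is an assumption.
_GTX_16XX = re.compile(r"GTX\s*16\d\d", re.IGNORECASE)


def has_tensor_cores(major: int, device_name: str) -> bool:
    """Guess Tensor Core presence from compute capability and device name."""
    if major >= 8:  # Ampere and newer
        return True
    if major == 7:  # Volta / Turing, except the GTX 16xx series
        return _GTX_16XX.search(device_name) is None
    return False    # Pascal and older


def pick_mixed_precision_dtype_name(major: int, device_name: str) -> str:
    """Return 'float16' for GPUs with Tensor Cores and 'bfloat16' otherwise."""
    return "float16" if has_tensor_cores(major, device_name) else "bfloat16"


def sketch_get_mixed_precision_dtype(device=None):
    """Hypothetical wrapper that queries a CUDA device through PyTorch."""
    import torch  # imported lazily so the helpers above work without PyTorch

    major, _minor = torch.cuda.get_device_capability(device)
    name = torch.cuda.get_device_name(device)
    return getattr(torch, pick_mixed_precision_dtype_name(major, name))
```

For example, `pick_mixed_precision_dtype_name(7, "NVIDIA GeForce GTX 1660 Ti")` selects bfloat16, while `pick_mixed_precision_dtype_name(7, "NVIDIA GeForce RTX 2080 Ti")` selects float16. Keeping the decision in a pure function that takes the capability and name as arguments makes the rule testable on machines without a GPU.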
