From dcade2b8647a1651564825f9db192e921d3199f4 Mon Sep 17 00:00:00 2001 From: chuyaowang Date: Thu, 11 Jun 2026 15:58:13 +0200 Subject: [PATCH] fix(training): auto-select bfloat16 on GPUs without Tensor Cores fp16 mixed precision causes NaN in Conv2d(512, 512, 3x3) inside the UNETR decoder on GPUs that lack Tensor Cores (e.g. GTX 1660 Ti). Without Tensor Cores, cuDNN accumulates the 4608-element dot product in fp16, overflowing its ceiling of ~65504. bfloat16 avoids this by sharing fp32's exponent range. Confirmed by forward hook diagnostic and deterministic mode test ruling out Winograd as the specific cause. - Add `_get_mixed_precision_dtype(device)` helper that returns torch.float16 for GPUs with Tensor Cores and torch.bfloat16 otherwise - Detect Tensor Core presence via compute capability and device name: Ampere+ (sm_8.x+) always fp16; Volta/Turing (sm_7.x) fp16 except GTX 16xx series (TU116/TU117); Pascal and below bfloat16 - Avoid torch.cuda.is_bf16_supported() as it reflects software-level dtype support and returns True on sm_7.5 with CUDA 12.x regardless of Tensor Core presence - Wire dtype through train_sam (JointSamTrainer and SamTrainer) and train_instance_segmentation (DefaultTrainer) via the new mixed_precision_dtype parameter added to torch_em DefaultTrainer --- micro_sam/training/training.py | 35 ++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 62566eee2..e47630e80 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -185,6 +185,37 @@ def _get_trainer_fit_params(n_epochs, n_iterations, save_every_kth_epoch, pbar_s return trainer_fit_params +def _get_mixed_precision_dtype(device: torch.device) -> torch.dtype: + """Return the appropriate dtype for mixed precision training on the given device. + + Uses fp16 on GPUs with Tensor Cores (which accumulate fp16 dot products in fp32, preventing + overflow). Falls back to bfloat16 on GPUs without Tensor Cores to avoid fp16 accumulation + overflow in large convolutions such as Conv2d(512, 512, 3x3) in the UNETR decoder. + + Tensor Core presence by architecture: + - Ampere (sm_8.x) and newer: always present. + - Volta (sm_7.0): always present (V100). + - Turing (sm_7.5): present on T4, RTX 20xx, Quadro RTX; absent on GTX 16xx (TU116/TU117). + - Pascal (sm_6.x) and older: never present. + + Note: torch.cuda.is_bf16_supported() is not used here because it reflects software-level + dtype support (which can return True on sm_7.5 with CUDA 12.x), not hardware Tensor Core + presence. The compute capability and device name are the correct discriminators. + """ + if not torch.cuda.is_available() or device.type != "cuda": + return torch.float16 + major, _ = torch.cuda.get_device_capability(device) + if major >= 8: + return torch.float16 # Ampere and newer: Tensor Cores always present + if major < 7: + return torch.bfloat16 # Pascal and older: no Tensor Cores + # Volta (7.0) and Turing (7.5): Tensor Cores present on all cards except GTX 16xx series + name = torch.cuda.get_device_properties(device).name.lower() + if "gtx 16" in name: + return torch.bfloat16 # TU116/TU117: Turing without Tensor Cores + return torch.float16 # V100, T4, RTX 20xx, Quadro RTX: Tensor Cores present + + def _get_optimizer_and_scheduler(model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs): optimizer = optimizer_class(model_params, lr=lr) if scheduler_kwargs is None: @@ -318,6 +349,7 @@ def train_sam( ) # The trainer which performs training and validation. + mixed_precision_dtype = _get_mixed_precision_dtype(device) if with_segmentation_decoder: instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True) trainer = joint_trainers.JointSamTrainer( @@ -332,6 +364,7 @@ def train_sam( logger=joint_trainers.JointSamLogger, log_image_interval=100, mixed_precision=True, + mixed_precision_dtype=mixed_precision_dtype, convert_inputs=convert_inputs, n_objects_per_batch=n_objects_per_batch, n_sub_iteration=n_sub_iteration, @@ -354,6 +387,7 @@ def train_sam( logger=trainers.SamLogger, log_image_interval=100, mixed_precision=True, + mixed_precision_dtype=mixed_precision_dtype, convert_inputs=convert_inputs, n_objects_per_batch=n_objects_per_batch, n_sub_iteration=n_sub_iteration, @@ -521,6 +555,7 @@ def train_instance_segmentation( val_loader=val_loader, device=device, mixed_precision=True, + mixed_precision_dtype=_get_mixed_precision_dtype(device), log_image_interval=50, compile_model=False, save_root=save_root,