diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 62566eee2..e47630e80 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -185,6 +185,37 @@ def _get_trainer_fit_params(n_epochs, n_iterations, save_every_kth_epoch, pbar_s return trainer_fit_params +def _get_mixed_precision_dtype(device: torch.device) -> torch.dtype: + """Return the appropriate dtype for mixed precision training on the given device. + + Uses fp16 on GPUs with Tensor Cores (which accumulate fp16 dot products in fp32, preventing + overflow). Falls back to bfloat16 on GPUs without Tensor Cores to avoid fp16 accumulation + overflow in large convolutions such as Conv2d(512, 512, 3x3) in the UNETR decoder. + + Tensor Core presence by architecture: + - Ampere (sm_8.x) and newer: always present. + - Volta (sm_7.0): always present (V100). + - Turing (sm_7.5): present on T4, RTX 20xx, Quadro RTX; absent on GTX 16xx (TU116/TU117). + - Pascal (sm_6.x) and older: never present. + + Note: torch.cuda.is_bf16_supported() is not used here because it reflects software-level + dtype support (which can return True on sm_7.5 with CUDA 12.x), not hardware Tensor Core + presence. The compute capability and device name are the correct discriminators. + """ + if not torch.cuda.is_available() or device.type != "cuda": + return torch.float16 + major, _ = torch.cuda.get_device_capability(device) + if major >= 8: + return torch.float16 # Ampere and newer: Tensor Cores always present + if major < 7: + return torch.bfloat16 # Pascal and older: no Tensor Cores + # Volta (7.0) and Turing (7.5): Tensor Cores present on all cards except GTX 16xx series + name = torch.cuda.get_device_properties(device).name.lower() + if "gtx 16" in name: + return torch.bfloat16 # TU116/TU117: Turing without Tensor Cores + return torch.float16 # V100, T4, RTX 20xx, Quadro RTX: Tensor Cores present + + def _get_optimizer_and_scheduler(model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs): optimizer = optimizer_class(model_params, lr=lr) if scheduler_kwargs is None: @@ -318,6 +349,7 @@ def train_sam( ) # The trainer which performs training and validation. + mixed_precision_dtype = _get_mixed_precision_dtype(device) if with_segmentation_decoder: instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True) trainer = joint_trainers.JointSamTrainer( @@ -332,6 +364,7 @@ def train_sam( logger=joint_trainers.JointSamLogger, log_image_interval=100, mixed_precision=True, + mixed_precision_dtype=mixed_precision_dtype, convert_inputs=convert_inputs, n_objects_per_batch=n_objects_per_batch, n_sub_iteration=n_sub_iteration, @@ -354,6 +387,7 @@ def train_sam( logger=trainers.SamLogger, log_image_interval=100, mixed_precision=True, + mixed_precision_dtype=mixed_precision_dtype, convert_inputs=convert_inputs, n_objects_per_batch=n_objects_per_batch, n_sub_iteration=n_sub_iteration, @@ -521,6 +555,7 @@ def train_instance_segmentation( val_loader=val_loader, device=device, mixed_precision=True, + mixed_precision_dtype=_get_mixed_precision_dtype(device), log_image_interval=50, compile_model=False, save_root=save_root,