Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions micro_sam/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,37 @@ def _get_trainer_fit_params(n_epochs, n_iterations, save_every_kth_epoch, pbar_s
return trainer_fit_params


def _get_mixed_precision_dtype(device: torch.device) -> torch.dtype:
"""Return the appropriate dtype for mixed precision training on the given device.

Uses fp16 on GPUs with Tensor Cores (which accumulate fp16 dot products in fp32, preventing
overflow). Falls back to bfloat16 on GPUs without Tensor Cores to avoid fp16 accumulation
overflow in large convolutions such as Conv2d(512, 512, 3x3) in the UNETR decoder.

Tensor Core presence by architecture:
- Ampere (sm_8.x) and newer: always present.
- Volta (sm_7.0): always present (V100).
- Turing (sm_7.5): present on T4, RTX 20xx, Quadro RTX; absent on GTX 16xx (TU116/TU117).
- Pascal (sm_6.x) and older: never present.

Note: torch.cuda.is_bf16_supported() is not used here because it reflects software-level
dtype support (which can return True on sm_7.5 with CUDA 12.x), not hardware Tensor Core
presence. The compute capability and device name are the correct discriminators.
"""
if not torch.cuda.is_available() or device.type != "cuda":
return torch.float16
major, _ = torch.cuda.get_device_capability(device)
if major >= 8:
return torch.float16 # Ampere and newer: Tensor Cores always present
if major < 7:
return torch.bfloat16 # Pascal and older: no Tensor Cores
# Volta (7.0) and Turing (7.5): Tensor Cores present on all cards except GTX 16xx series
name = torch.cuda.get_device_properties(device).name.lower()
if "gtx 16" in name:
return torch.bfloat16 # TU116/TU117: Turing without Tensor Cores
return torch.float16 # V100, T4, RTX 20xx, Quadro RTX: Tensor Cores present


def _get_optimizer_and_scheduler(model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs):
optimizer = optimizer_class(model_params, lr=lr)
if scheduler_kwargs is None:
Expand Down Expand Up @@ -318,6 +349,7 @@ def train_sam(
)

# The trainer which performs training and validation.
mixed_precision_dtype = _get_mixed_precision_dtype(device)
if with_segmentation_decoder:
instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True)
trainer = joint_trainers.JointSamTrainer(
Expand All @@ -332,6 +364,7 @@ def train_sam(
logger=joint_trainers.JointSamLogger,
log_image_interval=100,
mixed_precision=True,
mixed_precision_dtype=mixed_precision_dtype,
convert_inputs=convert_inputs,
n_objects_per_batch=n_objects_per_batch,
n_sub_iteration=n_sub_iteration,
Expand All @@ -354,6 +387,7 @@ def train_sam(
logger=trainers.SamLogger,
log_image_interval=100,
mixed_precision=True,
mixed_precision_dtype=mixed_precision_dtype,
convert_inputs=convert_inputs,
n_objects_per_batch=n_objects_per_batch,
n_sub_iteration=n_sub_iteration,
Expand Down Expand Up @@ -521,6 +555,7 @@ def train_instance_segmentation(
val_loader=val_loader,
device=device,
mixed_precision=True,
mixed_precision_dtype=_get_mixed_precision_dtype(device),
log_image_interval=50,
compile_model=False,
save_root=save_root,
Expand Down