Make autograd anomaly detection opt-in in train_second.py.

Hunk 1 adds `import os` to the import block.
Hunk 2 removes the unconditional torch.autograd.set_detect_anomaly(True)
call and re-adds it behind a check that the STYLETTS2_DETECT_ANOMALY
environment variable equals the exact string "1"; any other value, or an
unset variable, leaves anomaly detection at its default.

NOTE(review): the blank context lines and the 4-space indentation of the
nested statements below were reconstructed from the hunk line counts
(-7,7 +8,13) and Python syntax -- confirm against the target file before
applying.

diff --git a/train_second.py b/train_second.py
index ba3b989a..bbeff4e1 100644
--- a/train_second.py
+++ b/train_second.py
@@ -1,5 +1,6 @@
 # load packages
 import copy
+import os
 import random
 import yaml
 import time
@@ -7,7 +8,13 @@
 import numpy as np
 import torch
 
-torch.autograd.set_detect_anomaly(True)
+# Anomaly detection is a debugging tool — it traces every op in the autograd
+# graph to surface NaN/inf sources, but it makes backward 5–10× slower, holds
+# extra memory (causes OOMs that wouldn't otherwise happen), and on some
+# single-GPU + bf16 setups it deadlocks the first backward pass entirely.
+# Off by default; opt in via STYLETTS2_DETECT_ANOMALY=1 when diagnosing.
+if os.environ.get("STYLETTS2_DETECT_ANOMALY") == "1":
+    torch.autograd.set_detect_anomaly(True)
 
 if getattr(torch, "_original_load", None) is None:
     torch._original_load = torch.load