From b6571415053998b0ca84b9fffc32a1d6435703f9 Mon Sep 17 00:00:00 2001 From: Shreyas Karnik Date: Sat, 25 Apr 2026 14:17:23 -0700 Subject: [PATCH] gate set_detect_anomaly behind STYLETTS2_DETECT_ANOMALY env var MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently train_second.py turns torch.autograd.set_detect_anomaly(True) on unconditionally at module import time. That's a debugging tool, not a production setting: * 5–10× slower backward (anomaly detection traces every op). * Materially more memory held (causes OOMs that wouldn't happen otherwise). * On single-A100 + bf16, can deadlock the first backward pass entirely (no error, just stuck) — this bit us on a Marathi fine-tune Stage 2 run until we manually patched the file. Make it opt-in via an env var. Default off; set STYLETTS2_DETECT_ANOMALY=1 when diagnosing a NaN/inf to get the same behaviour as before. Users who explicitly want anomaly detection get exactly the previous behaviour once the variable is set; only the default changes. Diagnosing-by-default is the regression we're fixing. --- train_second.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/train_second.py b/train_second.py index ba3b989a..bbeff4e1 100644 --- a/train_second.py +++ b/train_second.py @@ -1,5 +1,6 @@ # load packages import copy +import os import random import yaml import time @@ -7,7 +8,13 @@ import numpy as np import torch -torch.autograd.set_detect_anomaly(True) +# Anomaly detection is a debugging tool — it traces every op in the autograd +# graph to surface NaN/inf sources, but it makes backward 5–10× slower, holds +# extra memory (causes OOMs that wouldn't otherwise happen), and on some +# single-GPU + bf16 setups it deadlocks the first backward pass entirely. +# Off by default; opt in via STYLETTS2_DETECT_ANOMALY=1 when diagnosing. +if os.environ.get("STYLETTS2_DETECT_ANOMALY") == "1": + torch.autograd.set_detect_anomaly(True) if getattr(torch, "_original_load", None) is None: torch._original_load = torch.load