From 2ca9c5fb239dcd6e621f91c01d7d6ab38919b2eb Mon Sep 17 00:00:00 2001
From: Jimmy Mendez <54858901+Jimmy-Mendez@users.noreply.github.com>
Date: Thu, 15 Jan 2026 12:06:36 -0500
Subject: [PATCH] Synchronize keys and handle missing values in dist_utils

When training with multiple GPUs, batches with no ground truth objects
cause some ranks to produce fewer loss keys (e.g., denoising losses are
skipped). This results in reduce_dict attempting all_reduce on tensors
of different sizes across ranks, causing a deadlock.

Fix: Synchronize loss dictionary keys across all ranks before all_reduce,
filling missing keys with zeros.
---
 engine/misc/dist_utils.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/engine/misc/dist_utils.py b/engine/misc/dist_utils.py
index 368d435..7b409e3 100644
--- a/engine/misc/dist_utils.py
+++ b/engine/misc/dist_utils.py
@@ -187,10 +187,23 @@ def reduce_dict(data, avg=True):
         return data
 
     with torch.no_grad():
-        keys, values = [], []
-        for k in sorted(data.keys()):
-            keys.append(k)
-            values.append(data[k])
+        # Synchronize keys across all ranks to handle cases where some ranks
+        # have different loss keys (e.g., due to empty batches)
+        local_keys = set(data.keys())
+        all_keys_list = [None] * world_size
+        torch.distributed.all_gather_object(all_keys_list, local_keys)
+        all_keys = sorted(set().union(*all_keys_list))
+
+        # Get device from data values
+        device = next(iter(data.values())).device if data else torch.device("cuda")
+
+        # Build values tensor with zeros for missing keys
+        values = []
+        for k in all_keys:
+            if k in data:
+                values.append(data[k])
+            else:
+                values.append(torch.tensor(0.0, device=device))
 
         values = torch.stack(values, dim=0)
         torch.distributed.all_reduce(values)
@@ -198,7 +211,7 @@ def reduce_dict(data, avg=True):
         if avg is True:
             values /= world_size
 
-        return {k: v for k, v in zip(keys, values)}
+        return {k: v for k, v in zip(all_keys, values)}
 
 
 def all_gather(data):