diff --git a/engine/misc/dist_utils.py b/engine/misc/dist_utils.py index 368d435..7b409e3 100644 --- a/engine/misc/dist_utils.py +++ b/engine/misc/dist_utils.py @@ -187,10 +187,23 @@ def reduce_dict(data, avg=True): return data with torch.no_grad(): - keys, values = [], [] - for k in sorted(data.keys()): - keys.append(k) - values.append(data[k]) + # Synchronize keys across all ranks to handle cases where some ranks + # have different loss keys (e.g., due to empty batches) + local_keys = set(data.keys()) + all_keys_list = [None] * world_size + torch.distributed.all_gather_object(all_keys_list, local_keys) + all_keys = sorted(set().union(*all_keys_list)) + + # Get device from data values + device = next(iter(data.values())).device if data else torch.device("cuda") + + # Build values tensor with zeros for missing keys + values = [] + for k in all_keys: + if k in data: + values.append(data[k]) + else: + values.append(torch.tensor(0.0, device=device)) values = torch.stack(values, dim=0) torch.distributed.all_reduce(values) @@ -198,7 +211,7 @@ def reduce_dict(data, avg=True): if avg is True: values /= world_size - return {k: v for k, v in zip(keys, values)} + return {k: v for k, v in zip(all_keys, values)} def all_gather(data):