From 8acc916e4e0f573faf6184e31cbd87fdd422599a Mon Sep 17 00:00:00 2001 From: PKUWZP Date: Sat, 21 Mar 2026 11:24:02 -0700 Subject: [PATCH 1/9] Add torch_xla TPU support for ZeRO-1/2 Signed-off-by: PKUWZP --- accelerator/real_accelerator.py | 20 +- accelerator/xla_accelerator.py | 265 +++++++++++++++++++++ deepspeed/comm/comm.py | 7 +- deepspeed/comm/constants.py | 1 + deepspeed/comm/torch.py | 11 + deepspeed/runtime/zero/stage_1_and_2.py | 11 +- docs/_tutorials/accelerator-setup-guide.md | 26 ++ tests/accelerator/test_ds_init.py | 12 +- tests/unit/accelerator/test_accelerator.py | 73 ++++++ tests/unit/comm/test_xla_backend.py | 103 ++++++++ tests/unit/common.py | 2 + 11 files changed, 516 insertions(+), 15 deletions(-) create mode 100644 accelerator/xla_accelerator.py create mode 100644 tests/unit/comm/test_xla_backend.py diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py index 35cac4b94b70..7a4cf4cf1a13 100644 --- a/accelerator/real_accelerator.py +++ b/accelerator/real_accelerator.py @@ -20,7 +20,7 @@ except ImportError as e: dsa2 = None -SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps', 'hpu', 'mlu', 'sdaa'] +SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps', 'hpu', 'mlu', 'sdaa', 'xla'] ds_accelerator = None @@ -98,6 +98,12 @@ def get_accelerator(): import torch_mlu # noqa: F401 except ImportError as e: raise ValueError("MLU_Accelerator requires torch_mlu, which is not installed on this system.") + elif accelerator_name in ["xla", "tpu"]: + accelerator_name = "xla" + try: + import torch_xla # noqa: F401 + except ImportError as e: + raise ValueError("XLA_Accelerator requires torch_xla, which is not installed on this system.") elif accelerator_name not in SUPPORTED_ACCELERATOR_LIST: raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. ' f'Value "{accelerator_name}" is not supported') @@ -125,6 +131,14 @@ def get_accelerator(): accelerator_name = "xpu" except ImportError as e: pass + if accelerator_name is None: + try: + import torch_xla.core.xla_model as xm + + if len(xm.get_xla_supported_devices(devkind='TPU')) > 0: + accelerator_name = "xla" + except ImportError as e: + pass if accelerator_name is None: try: import torch_npu # noqa: F401,F811 # type: ignore @@ -220,6 +234,10 @@ def get_accelerator(): from .mlu_accelerator import MLU_Accelerator ds_accelerator = MLU_Accelerator() + elif accelerator_name == 'xla': + from .xla_accelerator import XLA_Accelerator + + ds_accelerator = XLA_Accelerator() _validate_accelerator(ds_accelerator) if accel_logger is not None: accel_logger.info(f"Setting ds_accelerator to {ds_accelerator._name} ({ds_set_method})") diff --git a/accelerator/xla_accelerator.py b/accelerator/xla_accelerator.py new file mode 100644 index 000000000000..e9ef72a4e413 --- /dev/null +++ b/accelerator/xla_accelerator.py @@ -0,0 +1,265 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import functools + +import torch + +from .abstract_accelerator import DeepSpeedAccelerator + +try: + import torch_xla.core.xla_model as xm +except ImportError as e: + xm = None + + +class XLA_Accelerator(DeepSpeedAccelerator): + + def __init__(self): + self._name = 'xla' + self._communication_backend_name = 'xla' + self._compile_backend = None + if xm is None: + raise ValueError("XLA_Accelerator requires torch_xla, which is not installed on this system.") + + def _require_xm(self): + if xm is None: + raise RuntimeError("torch_xla is required to use the XLA_Accelerator") + return xm + + def _tensor_factory(self, dtype): + return functools.partial(torch.tensor, dtype=dtype, device=self.current_device_name()) + + def is_synchronized_device(self): + return True + + def use_host_timers(self): + return True + + def resolves_data_dependency(self): + return True + + def handles_memory_backpressure(self): + return True + + # Device APIs + def device_name(self, device_index=None): + if device_index is None: + return 'xla' + return f'xla:{device_index}' + + def device(self, device_index=None): + xm_module = self._require_xm() + return xm_module.xla_device(n=device_index, devkind='TPU') + + def set_device(self, device_index): + os.environ['LOCAL_RANK'] = str(device_index) + + def current_device(self): + xm_module = self._require_xm() + return xm_module.get_local_ordinal() + + def current_device_name(self): + return self.device_name(self.current_device()) + + def device_count(self): + xm_module = self._require_xm() + return len(xm_module.get_xla_supported_devices(devkind='TPU')) + + def synchronize(self, device_index=None): + xm_module = self._require_xm() + xm_module.mark_step() + return xm_module.wait_device_ops() + + # RNG APIs + def random(self): + return torch.random + + def set_rng_state(self, new_state, device_index=None): + return torch.set_rng_state(new_state) + + def get_rng_state(self, device_index=None): + return torch.get_rng_state() + + def manual_seed(self, seed): + return torch.manual_seed(seed) + + def manual_seed_all(self, seed): + return torch.manual_seed(seed) + + def initial_seed(self): + return torch.initial_seed() + + def default_generator(self, device_index): + return torch.default_generator + + # Streams/Events + @property + def Stream(self): + return None + + def stream(self, stream): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def current_stream(self, device_index=None): + return None + + def default_stream(self, device_index=None): + return None + + @property + def Event(self): + return None + + # Memory management + def empty_cache(self): + return + + def memory_allocated(self, device_index=None): + return 0 + + def max_memory_allocated(self, device_index=None): + return 0 + + def reset_max_memory_allocated(self, device_index=None): + return + + def memory_cached(self, device_index=None): + return 0 + + def max_memory_cached(self, device_index=None): + return 0 + + def reset_max_memory_cached(self, device_index=None): + return + + def memory_stats(self, device_index=None): + return {} + + def reset_peak_memory_stats(self, device_index=None): + return + + def memory_reserved(self, device_index=None): + return 0 + + def max_memory_reserved(self, device_index=None): + return 0 + + def total_memory(self, device_index=None): + return 0 + + def available_memory(self, device_index=None): + return 0 + + # Data types + def is_bf16_supported(self): + return True + + def is_fp16_supported(self): + return False + + def supported_dtypes(self): + return [torch.float32, torch.bfloat16] + + # Misc + def is_available(self): + return self.device_count() > 0 + + def range_push(self, msg): + return + + def range_pop(self): + return + + def lazy_call(self, callback): + return callback() + + def communication_backend_name(self): + return self._communication_backend_name + + def is_triton_supported(self): + return False + + # Graph operations + def create_graph(self): + return None + + def capture_to_graph(self, graph, pool=None, stream=None): + from deepspeed.runtime.utils import noop_context + return noop_context() + + def replay_graph(self, graph): + return + + # Tensor operations + @property + def BFloat16Tensor(self): + return self._tensor_factory(torch.bfloat16) + + @property + def ByteTensor(self): + return self._tensor_factory(torch.uint8) + + @property + def DoubleTensor(self): + return self._tensor_factory(torch.float64) + + @property + def FloatTensor(self): + return self._tensor_factory(torch.float32) + + @property + def HalfTensor(self): + return self._tensor_factory(torch.float16) + + @property + def IntTensor(self): + return self._tensor_factory(torch.int32) + + @property + def LongTensor(self): + return self._tensor_factory(torch.int64) + + def pin_memory(self, tensor, align_bytes=1): + return tensor + + def is_pinned(self, tensor): + return False + + def on_accelerator(self, tensor): + return getattr(tensor.device, 'type', None) == 'xla' + + def op_builder_dir(self): + return "deepspeed.ops.op_builder.cpu" + + def create_op_builder(self, op_name): + return None + + def get_op_builder(self, class_name): + return None + + def build_extension(self): + from torch.utils.cpp_extension import BuildExtension + return BuildExtension + + def export_envs(self): + return ['PJRT_DEVICE', 'TPU_VISIBLE_CHIPS'] + + def visible_devices_envs(self): + return ['TPU_VISIBLE_CHIPS'] + + def set_visible_devices_envs(self, current_env, local_accelerator_ids): + for env in self.visible_devices_envs(): + current_env[env] = ",".join(map(str, local_accelerator_ids)) + current_env.setdefault('PJRT_DEVICE', 'TPU') + + def get_compile_backend(self): + return self._compile_backend + + def set_compile_backend(self, backend): + if backend is not None: + raise ValueError(f"{backend} not supported by {self.device_name()}. Supported Backends are [None]") diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index f9e94f0175e2..133231b8dc6f 100755 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -162,6 +162,8 @@ def init_deepspeed_backend(ds_backend, timeout, init_method): utils.logger.info(f"Initialize {ds_backend} backend") elif ds_backend == HCCL_BACKEND: utils.logger.debug("HCCL backend in DeepSpeed not yet implemented") + elif ds_backend == XLA_BACKEND: + utils.logger.debug("XLA backend in DeepSpeed is provided via torch.distributed") else: utils.logger.debug(f"DeepSpeed does not support {ds_backend} backend") @@ -821,6 +823,8 @@ def init_distributed(dist_backend=None, utils.logger.info(f'cdb={cdb}') if cdb is None and torch.distributed.is_initialized(): # The user initialized torch.dist themselves, create cdb and short-circuit + if dist_backend is None: + dist_backend = get_accelerator().communication_backend_name() cdb = TorchBackend(dist_backend, timeout, init_method) return @@ -831,7 +835,8 @@ def init_distributed(dist_backend=None, else: # Initialize torch distributed if needed required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] - if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)): + xla_backend = dist_backend == XLA_BACKEND or get_accelerator().communication_backend_name() == XLA_BACKEND + if auto_mpi_discovery and not xla_backend and not all(map(lambda v: v in os.environ, required_env)): if verbose: utils.logger.info("Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...") if in_aml() and not in_dlts(): diff --git a/deepspeed/comm/constants.py b/deepspeed/comm/constants.py index 50d234c93fa0..9134a30457df 100644 --- a/deepspeed/comm/constants.py +++ b/deepspeed/comm/constants.py @@ -9,6 +9,7 @@ GLOO_BACKEND = 'gloo' SCCL_BACKEND = 'sccl' HCCL_BACKEND = 'hccl' +XLA_BACKEND = 'xla' DEFAULT_AML_MASTER_PORT = "54965" DEFAULT_AML_NCCL_SOCKET_IFNAME = "^docker0,lo" diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py index 8e821f2fdd6d..d77ed1ba013e 100755 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -147,6 +147,10 @@ def has_reduce_scatter_tensor(self): def init_process_group(self, backend, timeout, init_method, rank, world_size): if not torch.distributed.is_initialized(): + if backend == XLA_BACKEND: + import torch_xla.distributed.xla_backend # noqa: F401 + if init_method is None: + init_method = "xla://" kwargs = dict(timeout=timeout, init_method=init_method, rank=rank, world_size=world_size) # 1. device_id arg was added in torch==2.3 @@ -159,6 +163,13 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size): kwargs.update(device_id=get_accelerator().device(local_rank)) torch.distributed.init_process_group(backend, **kwargs) + if backend == XLA_BACKEND: + os.environ.setdefault('RANK', str(torch.distributed.get_rank())) + os.environ.setdefault('WORLD_SIZE', str(torch.distributed.get_world_size())) + if 'LOCAL_RANK' not in os.environ: + import torch_xla.core.xla_model as xm + os.environ['LOCAL_RANK'] = str(xm.get_local_ordinal()) + self.using_mpi = torch.distributed.get_backend() == 'mpi' @disable_compiler_collective diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 49a66ebfbbfe..d28144e10306 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -59,14 +59,11 @@ def input(msg): def split_half_float_double(tensors): - device_type = get_accelerator().device_name() - dtypes = [ - "torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type), - "torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type) - ] buckets = [] - for i, dtype in enumerate(dtypes): - bucket = [t for t in tensors if t.type() == dtype] + accelerator = get_accelerator() + dtype_order = (torch.float16, torch.float32, torch.float64, torch.bfloat16) + for dtype in dtype_order: + bucket = [tensor for tensor in tensors if tensor.dtype == dtype and accelerator.on_accelerator(tensor)] if bucket: buckets.append(bucket) return buckets diff --git a/docs/_tutorials/accelerator-setup-guide.md b/docs/_tutorials/accelerator-setup-guide.md index 20e667170eaa..d32b04f8baba 100644 --- a/docs/_tutorials/accelerator-setup-guide.md +++ b/docs/_tutorials/accelerator-setup-guide.md @@ -8,6 +8,7 @@ tags: getting-started training accelerator - [Introduction](#introduction) - [Intel Architecture (IA) CPU](#intel-architecture-ia-cpu) - [Intel XPU](#intel-xpu) +- [Google Cloud TPU](#google-cloud-tpu) - [Huawei Ascend NPU](#huawei-ascend-npu) - [Intel Gaudi](#intel-gaudi) @@ -159,6 +160,31 @@ accelerator: xpu ``` +# Google Cloud TPU +DeepSpeed TPU support targets ZeRO-1 and ZeRO-2 on top of PyTorch/XLA. + +## Installation steps for Google Cloud TPU +1. Install PyTorch and PyTorch/XLA with TPU support following the matching wheel set for your runtime. +2. Install DeepSpeed. + +## How to use DeepSpeed on Google Cloud TPU +DeepSpeed uses the `xla` distributed backend on TPU through `torch.distributed`, and defaults to the `xla://` init method when `torch_xla` is available. + +To validate that TPU support is selected, here is an example: +``` +$ python +>>> import torch_xla +>>> from deepspeed.accelerator import get_accelerator +>>> accelerator = get_accelerator() +>>> print(accelerator._name) +xla +>>> print(accelerator.communication_backend_name()) +xla +``` + +TPU support currently focuses on ZeRO-1 and ZeRO-2 training paths. Use `bf16` training on TPU because fp16 is not supported by the XLA accelerator path. + + # Huawei Ascend NPU DeepSpeed has been verified on the following Huawei Ascend NPU products: diff --git a/tests/accelerator/test_ds_init.py b/tests/accelerator/test_ds_init.py index 9594a6f5ea58..926b47215d15 100644 --- a/tests/accelerator/test_ds_init.py +++ b/tests/accelerator/test_ds_init.py @@ -40,9 +40,9 @@ def test_literal_device(): os.environ['LOCAL_RANK'] = '0' deepspeed.init_distributed(get_accelerator().communication_backend_name()) deepspeed.initialize(model=model, config='ds_config.json') - string = get_accelerator().device_name() #'xpu' or 'cuda' - string0 = get_accelerator().device_name(0) #'xpu:0' or 'cuda:0' - string1 = get_accelerator().device_name(1) #'xpu:1' or 'cuda:1' - assert string == 'xpu' or string == 'cuda' - assert string0 == 'xpu:0' or string0 == 'cuda:0' - assert string1 == 'xpu:1' or string1 == 'cuda:1' + string = get_accelerator().device_name() #'xpu', 'cuda', or 'xla' + string0 = get_accelerator().device_name(0) #'xpu:0', 'cuda:0', or 'xla:0' + string1 = get_accelerator().device_name(1) #'xpu:1', 'cuda:1', or 'xla:1' + assert string in ['xpu', 'cuda', 'xla'] + assert string0 in ['xpu:0', 'cuda:0', 'xla:0'] + assert string1 in ['xpu:1', 'cuda:1', 'xla:1'] diff --git a/tests/unit/accelerator/test_accelerator.py b/tests/unit/accelerator/test_accelerator.py index 964cf2b24f4e..c8bc7631e3c9 100644 --- a/tests/unit/accelerator/test_accelerator.py +++ b/tests/unit/accelerator/test_accelerator.py @@ -9,8 +9,10 @@ import sys import importlib import re +import types import deepspeed +import torch DS_ACCEL_PATH = "deepspeed.accelerator" IGNORE_FILES = ["abstract_accelerator.py", "real_accelerator.py"] @@ -57,3 +59,74 @@ def test_abstract_methods_defined(module_name, accel_class_name): accel_class = getattr(module, accel_class_name) accel_class.__init__ = lambda self: None _ = accel_class() + + +def _install_fake_torch_xla(monkeypatch, local_ordinal=0, device_count=2): + torch_xla = types.ModuleType("torch_xla") + torch_xla_core = types.ModuleType("torch_xla.core") + torch_xla_xla_model = types.ModuleType("torch_xla.core.xla_model") + torch_xla_distributed = types.ModuleType("torch_xla.distributed") + torch_xla_backend = types.ModuleType("torch_xla.distributed.xla_backend") + + torch_xla_xla_model.xla_device = lambda n=None, devkind=None: f"xla:{local_ordinal if n is None else n}" + torch_xla_xla_model.get_local_ordinal = lambda: local_ordinal + torch_xla_xla_model.get_xla_supported_devices = lambda devkind=None: [f"xla:{idx}" for idx in range(device_count)] + torch_xla_xla_model.mark_step = lambda: None + torch_xla_xla_model.wait_device_ops = lambda: None + + monkeypatch.setitem(sys.modules, "torch_xla", torch_xla) + monkeypatch.setitem(sys.modules, "torch_xla.core", torch_xla_core) + monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", torch_xla_xla_model) + monkeypatch.setitem(sys.modules, "torch_xla.distributed", torch_xla_distributed) + monkeypatch.setitem(sys.modules, "torch_xla.distributed.xla_backend", torch_xla_backend) + + +def test_xla_override_selects_xla_accelerator(monkeypatch): + _install_fake_torch_xla(monkeypatch) + + import deepspeed.accelerator.real_accelerator as real_accelerator + + monkeypatch.setenv("DS_ACCELERATOR", "xla") + monkeypatch.setattr(real_accelerator, "ds_accelerator", None) + + accelerator = real_accelerator.get_accelerator() + + assert accelerator.device_name() == "xla" + assert accelerator.device_name(1) == "xla:1" + assert accelerator.communication_backend_name() == "xla" + assert accelerator.is_bf16_supported() + assert not accelerator.is_fp16_supported() + + monkeypatch.setattr(real_accelerator, "ds_accelerator", None) + + +def test_zero_split_half_float_double_groups_xla_tensors(monkeypatch): + from deepspeed.runtime.zero import stage_1_and_2 + + class FakeAccelerator: + + @staticmethod + def on_accelerator(tensor): + return tensor.device.type == 'xla' + + class FakeTensor: + + def __init__(self, dtype, device_type='xla'): + self.dtype = dtype + self.device = types.SimpleNamespace(type=device_type) + + tensors = [ + FakeTensor(torch.float16), + FakeTensor(torch.float32), + FakeTensor(torch.bfloat16), + FakeTensor(torch.float16, device_type='cpu'), + ] + + monkeypatch.setattr(stage_1_and_2, "get_accelerator", lambda: FakeAccelerator()) + + buckets = stage_1_and_2.split_half_float_double(tensors) + + assert len(buckets) == 3 + assert [tensor.dtype for tensor in buckets[0]] == [torch.float16] + assert [tensor.dtype for tensor in buckets[1]] == [torch.float32] + assert [tensor.dtype for tensor in buckets[2]] == [torch.bfloat16] diff --git a/tests/unit/comm/test_xla_backend.py b/tests/unit/comm/test_xla_backend.py new file mode 100644 index 000000000000..e5fa56105a95 --- /dev/null +++ b/tests/unit/comm/test_xla_backend.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import sys +import types +from datetime import timedelta + +import pytest +import torch + +from deepspeed.comm.constants import XLA_BACKEND + + +def _install_fake_torch_xla(monkeypatch, local_ordinal=0): + torch_xla = types.ModuleType("torch_xla") + torch_xla_core = types.ModuleType("torch_xla.core") + torch_xla_xla_model = types.ModuleType("torch_xla.core.xla_model") + torch_xla_distributed = types.ModuleType("torch_xla.distributed") + torch_xla_backend = types.ModuleType("torch_xla.distributed.xla_backend") + + torch_xla_xla_model.get_local_ordinal = lambda: local_ordinal + + monkeypatch.setitem(sys.modules, "torch_xla", torch_xla) + monkeypatch.setitem(sys.modules, "torch_xla.core", torch_xla_core) + monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", torch_xla_xla_model) + monkeypatch.setitem(sys.modules, "torch_xla.distributed", torch_xla_distributed) + monkeypatch.setitem(sys.modules, "torch_xla.distributed.xla_backend", torch_xla_backend) + + +def test_torch_backend_uses_xla_init_method(monkeypatch): + from deepspeed.comm import torch as ds_torch + + _install_fake_torch_xla(monkeypatch, local_ordinal=3) + + init_calls = [] + + class FakeAccelerator: + + @staticmethod + def device_name(): + return 'xla' + + monkeypatch.delenv('LOCAL_RANK', raising=False) + monkeypatch.setattr(ds_torch, "build_shm_op", lambda: None) + monkeypatch.setattr(ds_torch, "get_accelerator", lambda: FakeAccelerator()) + monkeypatch.setattr(torch.distributed, "is_initialized", lambda: False) + monkeypatch.setattr(torch.distributed, "init_process_group", + lambda backend, **kwargs: init_calls.append((backend, kwargs))) + monkeypatch.setattr(torch.distributed, "get_rank", lambda: 1) + monkeypatch.setattr(torch.distributed, "get_world_size", lambda: 8) + monkeypatch.setattr(torch.distributed, "get_backend", lambda: XLA_BACKEND) + + backend = ds_torch.TorchBackend(XLA_BACKEND, timedelta(seconds=5), None) + + assert backend.is_initialized() + assert init_calls[0][0] == XLA_BACKEND + assert init_calls[0][1]["init_method"] == "xla://" + assert os.environ["LOCAL_RANK"] == "3" + assert os.environ["RANK"] == "1" + assert os.environ["WORLD_SIZE"] == "8" + + +def test_init_distributed_skips_mpi_discovery_for_xla(monkeypatch): + import deepspeed.comm.comm as ds_comm + + calls = [] + + class FakeAccelerator: + + @staticmethod + def communication_backend_name(): + return XLA_BACKEND + + class FakeTorchBackend: + + def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1): + calls.append((backend, timeout, init_method, rank, world_size)) + + @staticmethod + def is_initialized(): + return True + + monkeypatch.setattr(ds_comm, "cdb", None) + monkeypatch.setattr(ds_comm, "configure", lambda deepspeed_config=None: None) + monkeypatch.setattr(ds_comm, "init_deepspeed_backend", lambda *args, **kwargs: None) + monkeypatch.setattr(ds_comm, "set_backend", lambda: None) + monkeypatch.setattr(ds_comm, "get_accelerator", lambda: FakeAccelerator()) + monkeypatch.setattr(ds_comm, "TorchBackend", FakeTorchBackend) + monkeypatch.setattr(ds_comm, "mpi_discovery", + lambda *args, **kwargs: pytest.fail("mpi_discovery should not run for the xla backend")) + + for env_var in ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]: + monkeypatch.delenv(env_var, raising=False) + + ds_comm.init_distributed(dist_backend=XLA_BACKEND, + auto_mpi_discovery=True, + verbose=False, + timeout=timedelta(seconds=5)) + + assert calls == [(XLA_BACKEND, timedelta(seconds=5), None, -1, -1)] diff --git a/tests/unit/common.py b/tests/unit/common.py index f57ee3395973..0fcae11d22a5 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -112,6 +112,8 @@ def set_accelerator_visible(): elif get_accelerator().device_name() == 'npu': npu_smi = subprocess.check_output(['npu-smi', 'info', '-l']) num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip()) + elif get_accelerator().device_name() == 'xla': + num_accelerators = get_accelerator().device_count() else: assert get_accelerator().device_name() == 'cpu' num_accelerators = _get_cpu_socket_count() From 7f82c20de8e8fc80bbc25648d012cd6b125037ec Mon Sep 17 00:00:00 2001 From: PKUWZP Date: Sat, 21 Mar 2026 15:46:29 -0700 Subject: [PATCH 2/9] Fix TPU device selection for XLA workers Signed-off-by: PKUWZP --- accelerator/xla_accelerator.py | 35 ++++++++++++++++++---- tests/unit/accelerator/test_accelerator.py | 32 +++++++++++++++++++- 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/accelerator/xla_accelerator.py b/accelerator/xla_accelerator.py index e9ef72a4e413..84c0692b0aa9 100644 --- a/accelerator/xla_accelerator.py +++ b/accelerator/xla_accelerator.py @@ -11,8 +11,10 @@ from .abstract_accelerator import DeepSpeedAccelerator try: + import torch_xla import torch_xla.core.xla_model as xm except ImportError as e: + torch_xla = None xm = None @@ -33,6 +35,21 @@ def _require_xm(self): def _tensor_factory(self, dtype): return functools.partial(torch.tensor, dtype=dtype, device=self.current_device_name()) + def _addressable_devices(self): + if torch_xla is not None and hasattr(torch_xla, 'devices'): + return list(torch_xla.devices()) + + xm_module = self._require_xm() + return [torch.device(device) for device in xm_module.get_xla_supported_devices(devkind='TPU')] + + def _normalize_device_index(self, device_index=None): + devices = self._addressable_devices() + if not devices: + raise RuntimeError("No addressable XLA devices are available in the current process.") + if device_index is None: + return 0 + return min(device_index, len(devices) - 1) + def is_synchronized_device(self): return True @@ -49,21 +66,29 @@ def handles_memory_backpressure(self): def device_name(self, device_index=None): if device_index is None: return 'xla' - return f'xla:{device_index}' + return str(self._addressable_devices()[self._normalize_device_index(device_index)]) def device(self, device_index=None): xm_module = self._require_xm() - return xm_module.xla_device(n=device_index, devkind='TPU') + if device_index is None: + return xm_module.xla_device(devkind='TPU') + return xm_module.xla_device(n=self._normalize_device_index(device_index), devkind='TPU') def set_device(self, device_index): + # XLA uses the default device selected for the current process. + self.device(device_index) os.environ['LOCAL_RANK'] = str(device_index) + os.environ.setdefault('PJRT_LOCAL_PROCESS_RANK', str(device_index)) def current_device(self): - xm_module = self._require_xm() - return xm_module.get_local_ordinal() + current_device = self.device() + device_index = getattr(current_device, 'index', None) + if device_index is not None: + return device_index + return self._normalize_device_index() def current_device_name(self): - return self.device_name(self.current_device()) + return str(self.device()) def device_count(self): xm_module = self._require_xm() diff --git a/tests/unit/accelerator/test_accelerator.py b/tests/unit/accelerator/test_accelerator.py index c8bc7631e3c9..601c39d26311 100644 --- a/tests/unit/accelerator/test_accelerator.py +++ b/tests/unit/accelerator/test_accelerator.py @@ -67,8 +67,24 @@ def _install_fake_torch_xla(monkeypatch, local_ordinal=0, device_count=2): torch_xla_xla_model = types.ModuleType("torch_xla.core.xla_model") torch_xla_distributed = types.ModuleType("torch_xla.distributed") torch_xla_backend = types.ModuleType("torch_xla.distributed.xla_backend") + selected_device = {"index": local_ordinal} - torch_xla_xla_model.xla_device = lambda n=None, devkind=None: f"xla:{local_ordinal if n is None else n}" + class FakeDevice: + + def __init__(self, index): + self.type = "xla" + self.index = index + + def __str__(self): + return f"xla:{self.index}" + + def xla_device(n=None, devkind=None): + if n is not None: + selected_device["index"] = n + return FakeDevice(selected_device["index"]) + + torch_xla.devices = lambda: [FakeDevice(idx) for idx in range(device_count)] + torch_xla_xla_model.xla_device = xla_device torch_xla_xla_model.get_local_ordinal = lambda: local_ordinal torch_xla_xla_model.get_xla_supported_devices = lambda devkind=None: [f"xla:{idx}" for idx in range(device_count)] torch_xla_xla_model.mark_step = lambda: None @@ -100,6 +116,20 @@ def test_xla_override_selects_xla_accelerator(monkeypatch): monkeypatch.setattr(real_accelerator, "ds_accelerator", None) +def test_xla_device_mapping_uses_addressable_devices(monkeypatch): + _install_fake_torch_xla(monkeypatch, local_ordinal=0, device_count=1) + + import accelerator.xla_accelerator as xla_accelerator + + importlib.reload(xla_accelerator) + accelerator = xla_accelerator.XLA_Accelerator() + + accelerator.set_device(3) + + assert accelerator.device_name(3) == "xla:0" + assert accelerator.current_device_name() == "xla:0" + + def test_zero_split_half_float_double_groups_xla_tensors(monkeypatch): from deepspeed.runtime.zero import stage_1_and_2 From dbfd0a91f24f33ef5aef87b6e53e4c04eab96d00 Mon Sep 17 00:00:00 2001 From: PKUWZP Date: Sat, 21 Mar 2026 15:52:50 -0700 Subject: [PATCH 3/9] Fix formatting hooks for XLA backend test Signed-off-by: PKUWZP --- tests/unit/comm/test_xla_backend.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/unit/comm/test_xla_backend.py b/tests/unit/comm/test_xla_backend.py index e5fa56105a95..ce540dbad63e 100644 --- a/tests/unit/comm/test_xla_backend.py +++ b/tests/unit/comm/test_xla_backend.py @@ -36,6 +36,7 @@ def test_torch_backend_uses_xla_init_method(monkeypatch): _install_fake_torch_xla(monkeypatch, local_ordinal=3) init_calls = [] + dist_pkg = getattr(torch, 'distributed') class FakeAccelerator: @@ -46,12 +47,15 @@ def device_name(): monkeypatch.delenv('LOCAL_RANK', raising=False) monkeypatch.setattr(ds_torch, "build_shm_op", lambda: None) monkeypatch.setattr(ds_torch, "get_accelerator", lambda: FakeAccelerator()) - monkeypatch.setattr(torch.distributed, "is_initialized", lambda: False) - monkeypatch.setattr(torch.distributed, "init_process_group", - lambda backend, **kwargs: init_calls.append((backend, kwargs))) - monkeypatch.setattr(torch.distributed, "get_rank", lambda: 1) - monkeypatch.setattr(torch.distributed, "get_world_size", lambda: 8) - monkeypatch.setattr(torch.distributed, "get_backend", lambda: XLA_BACKEND) + monkeypatch.setattr(dist_pkg, "is_initialized", lambda: False) + monkeypatch.setattr( + dist_pkg, + "init_process_group", + lambda backend, **kwargs: init_calls.append((backend, kwargs)), + ) + monkeypatch.setattr(dist_pkg, "get_rank", lambda: 1) + monkeypatch.setattr(dist_pkg, "get_world_size", lambda: 8) + monkeypatch.setattr(dist_pkg, "get_backend", lambda: XLA_BACKEND) backend = ds_torch.TorchBackend(XLA_BACKEND, timedelta(seconds=5), None) From 41d0f7e937a74c7dddd473c51b70df0a551c297a Mon Sep 17 00:00:00 2001 From: PKUWZP Date: Sat, 21 Mar 2026 16:20:01 -0700 Subject: [PATCH 4/9] Fix XLA backend unit test import Signed-off-by: PKUWZP --- tests/unit/comm/test_xla_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/comm/test_xla_backend.py b/tests/unit/comm/test_xla_backend.py index ce540dbad63e..a77174f7d7fe 100644 --- a/tests/unit/comm/test_xla_backend.py +++ b/tests/unit/comm/test_xla_backend.py @@ -31,7 +31,7 @@ def _install_fake_torch_xla(monkeypatch, local_ordinal=0): def test_torch_backend_uses_xla_init_method(monkeypatch): - from deepspeed.comm import torch as ds_torch + import deepspeed.comm.torch as ds_torch _install_fake_torch_xla(monkeypatch, local_ordinal=3) From 7debfbbb2bca78a67d52af851844ea58a6bb6963 Mon Sep 17 00:00:00 2001 From: PKUWZP Date: Mon, 23 Mar 2026 09:01:43 -0400 Subject: [PATCH 5/9] Patch fake torch_xla before import in test Signed-off-by: PKUWZP --- tests/unit/comm/test_xla_backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/comm/test_xla_backend.py b/tests/unit/comm/test_xla_backend.py index a77174f7d7fe..5fffc7bd70e8 100644 --- a/tests/unit/comm/test_xla_backend.py +++ b/tests/unit/comm/test_xla_backend.py @@ -31,9 +31,8 @@ def _install_fake_torch_xla(monkeypatch, local_ordinal=0): def test_torch_backend_uses_xla_init_method(monkeypatch): - import deepspeed.comm.torch as ds_torch - _install_fake_torch_xla(monkeypatch, local_ordinal=3) + import deepspeed.comm.torch as ds_torch init_calls = [] dist_pkg = getattr(torch, 'distributed') From 623df9eaca732c533761b1f5abcf4f8740272cd3 Mon Sep 17 00:00:00 2001 From: PKUWZP Date: Mon, 23 Mar 2026 09:30:19 -0400 Subject: [PATCH 6/9] Import XLA test target module explicitly Signed-off-by: PKUWZP --- tests/unit/comm/test_xla_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/comm/test_xla_backend.py b/tests/unit/comm/test_xla_backend.py index 5fffc7bd70e8..1bbebc1348ed 100644 --- a/tests/unit/comm/test_xla_backend.py +++ b/tests/unit/comm/test_xla_backend.py @@ -6,6 +6,7 @@ import os import sys import types +import importlib from datetime import timedelta import pytest @@ -32,7 +33,7 @@ def _install_fake_torch_xla(monkeypatch, local_ordinal=0): def test_torch_backend_uses_xla_init_method(monkeypatch): _install_fake_torch_xla(monkeypatch, local_ordinal=3) - import deepspeed.comm.torch as ds_torch + ds_torch = importlib.import_module("deepspeed.comm.torch") init_calls = [] dist_pkg = getattr(torch, 'distributed') From 65cf60fdae87dcac7fe78bdeef0da91eb1db4860 Mon Sep 17 00:00:00 2001 From: PKUWZP Date: Mon, 23 Mar 2026 09:59:25 -0400 Subject: [PATCH 7/9] Relax XLA backend test initialization assertion Signed-off-by: PKUWZP --- tests/unit/comm/test_xla_backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/comm/test_xla_backend.py b/tests/unit/comm/test_xla_backend.py index 1bbebc1348ed..a10c7db2d16d 100644 --- a/tests/unit/comm/test_xla_backend.py +++ b/tests/unit/comm/test_xla_backend.py @@ -57,9 +57,8 @@ def device_name(): monkeypatch.setattr(dist_pkg, "get_world_size", lambda: 8) monkeypatch.setattr(dist_pkg, "get_backend", lambda: XLA_BACKEND) - backend = ds_torch.TorchBackend(XLA_BACKEND, timedelta(seconds=5), None) + ds_torch.TorchBackend(XLA_BACKEND, timedelta(seconds=5), None) - assert backend.is_initialized() assert init_calls[0][0] == XLA_BACKEND assert init_calls[0][1]["init_method"] == "xla://" assert os.environ["LOCAL_RANK"] == "3" From 7f6ea955b7115a68563a059806e30f8ab713e4f0 Mon Sep 17 00:00:00 2001 From: PKUWZP Date: Thu, 26 Mar 2026 00:02:34 -0700 Subject: [PATCH 8/9] Address review feedback for XLA/TPU support - split_half_float_double: use dtype comparison instead of string-based type names, without adding on_accelerator filtering that would change behavior for all backends - comm.py: clarify that dist_backend fallback is not XLA-specific - Remove fake TPU device tests per reviewer guidance; XLA accelerator tests should run on real TPU hardware Signed-off-by: PKUWZP --- deepspeed/comm/comm.py | 3 +- deepspeed/runtime/zero/stage_1_and_2.py | 6 +- tests/unit/accelerator/test_accelerator.py | 103 --------------------- 3 files changed, 4 insertions(+), 108 deletions(-) diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index 133231b8dc6f..ada6a8e8a661 100755 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -822,7 +822,8 @@ def init_distributed(dist_backend=None, set_backend() utils.logger.info(f'cdb={cdb}') if cdb is None and torch.distributed.is_initialized(): - # The user initialized torch.dist themselves, create cdb and short-circuit + # The user initialized torch.dist themselves, create cdb and short-circuit. + # Resolve dist_backend so TorchBackend always receives a concrete name. if dist_backend is None: dist_backend = get_accelerator().communication_backend_name() cdb = TorchBackend(dist_backend, timeout, init_method) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index d28144e10306..8581f27767ae 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -60,10 +60,8 @@ def input(msg): def split_half_float_double(tensors): buckets = [] - accelerator = get_accelerator() - dtype_order = (torch.float16, torch.float32, torch.float64, torch.bfloat16) - for dtype in dtype_order: - bucket = [tensor for tensor in tensors if tensor.dtype == dtype and accelerator.on_accelerator(tensor)] + for dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16): + bucket = [t for t in tensors if t.dtype == dtype] if bucket: buckets.append(bucket) return buckets diff --git a/tests/unit/accelerator/test_accelerator.py b/tests/unit/accelerator/test_accelerator.py index 601c39d26311..964cf2b24f4e 100644 --- a/tests/unit/accelerator/test_accelerator.py +++ b/tests/unit/accelerator/test_accelerator.py @@ -9,10 +9,8 @@ import sys import importlib import re -import types import deepspeed -import torch DS_ACCEL_PATH = "deepspeed.accelerator" IGNORE_FILES = ["abstract_accelerator.py", "real_accelerator.py"] @@ -59,104 +57,3 @@ def test_abstract_methods_defined(module_name, accel_class_name): accel_class = getattr(module, accel_class_name) accel_class.__init__ = lambda self: None _ = accel_class() - - -def _install_fake_torch_xla(monkeypatch, local_ordinal=0, device_count=2): - torch_xla = types.ModuleType("torch_xla") - torch_xla_core = types.ModuleType("torch_xla.core") - torch_xla_xla_model = types.ModuleType("torch_xla.core.xla_model") - torch_xla_distributed = types.ModuleType("torch_xla.distributed") - torch_xla_backend = types.ModuleType("torch_xla.distributed.xla_backend") - selected_device = {"index": local_ordinal} - - class FakeDevice: - - def __init__(self, index): - self.type = "xla" - self.index = index - - def __str__(self): - return f"xla:{self.index}" - - def xla_device(n=None, devkind=None): - if n is not None: - selected_device["index"] = n - return FakeDevice(selected_device["index"]) - - torch_xla.devices = lambda: [FakeDevice(idx) for idx in range(device_count)] - torch_xla_xla_model.xla_device = xla_device - torch_xla_xla_model.get_local_ordinal = lambda: local_ordinal - torch_xla_xla_model.get_xla_supported_devices = lambda devkind=None: [f"xla:{idx}" for idx in range(device_count)] - torch_xla_xla_model.mark_step = lambda: None - torch_xla_xla_model.wait_device_ops = lambda: None - - monkeypatch.setitem(sys.modules, "torch_xla", torch_xla) - monkeypatch.setitem(sys.modules, "torch_xla.core", torch_xla_core) - monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", torch_xla_xla_model) - monkeypatch.setitem(sys.modules, "torch_xla.distributed", torch_xla_distributed) - monkeypatch.setitem(sys.modules, "torch_xla.distributed.xla_backend", torch_xla_backend) - - -def test_xla_override_selects_xla_accelerator(monkeypatch): - _install_fake_torch_xla(monkeypatch) - - import deepspeed.accelerator.real_accelerator as real_accelerator - - monkeypatch.setenv("DS_ACCELERATOR", "xla") - monkeypatch.setattr(real_accelerator, "ds_accelerator", None) - - accelerator = real_accelerator.get_accelerator() - - assert accelerator.device_name() == "xla" - assert accelerator.device_name(1) == "xla:1" - assert accelerator.communication_backend_name() == "xla" - assert accelerator.is_bf16_supported() - assert not accelerator.is_fp16_supported() - - monkeypatch.setattr(real_accelerator, "ds_accelerator", None) - - -def test_xla_device_mapping_uses_addressable_devices(monkeypatch): - _install_fake_torch_xla(monkeypatch, local_ordinal=0, device_count=1) - - import accelerator.xla_accelerator as xla_accelerator - - importlib.reload(xla_accelerator) - accelerator = xla_accelerator.XLA_Accelerator() - - accelerator.set_device(3) - - assert accelerator.device_name(3) == "xla:0" - assert accelerator.current_device_name() == "xla:0" - - -def test_zero_split_half_float_double_groups_xla_tensors(monkeypatch): - from deepspeed.runtime.zero import stage_1_and_2 - - class FakeAccelerator: - - @staticmethod - def on_accelerator(tensor): - return tensor.device.type == 'xla' - - class FakeTensor: - - def __init__(self, dtype, device_type='xla'): - self.dtype = dtype - self.device = types.SimpleNamespace(type=device_type) - - tensors = [ - FakeTensor(torch.float16), - FakeTensor(torch.float32), - FakeTensor(torch.bfloat16), - FakeTensor(torch.float16, device_type='cpu'), - ] - - monkeypatch.setattr(stage_1_and_2, "get_accelerator", lambda: FakeAccelerator()) - - buckets = stage_1_and_2.split_half_float_double(tensors) - - assert len(buckets) == 3 - assert [tensor.dtype for tensor in buckets[0]] == [torch.float16] - assert [tensor.dtype for tensor in buckets[1]] == [torch.float32] - assert [tensor.dtype for tensor in buckets[2]] == [torch.bfloat16] From b14e5b461605f4409aa72d3c6f9dfb6154438742 Mon Sep 17 00:00:00 2001 From: PKUWZP Date: Thu, 26 Mar 2026 00:38:55 -0700 Subject: [PATCH 9/9] Address tohtana's review feedback for XLA/TPU support - _normalize_device_index: handle str and torch.device types in addition to int, since callers pass LOCAL_RANK strings and torch.device objects - real_accelerator: catch RuntimeError from get_xla_supported_devices() when torch_xla is installed but no TPU/PJRT runtime is available Signed-off-by: PKUWZP --- accelerator/real_accelerator.py | 2 +- accelerator/xla_accelerator.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py index 7a4cf4cf1a13..05240e53c96b 100644 --- a/accelerator/real_accelerator.py +++ b/accelerator/real_accelerator.py @@ -137,7 +137,7 @@ def get_accelerator(): if len(xm.get_xla_supported_devices(devkind='TPU')) > 0: accelerator_name = "xla" - except ImportError as e: + except (ImportError, RuntimeError): pass if accelerator_name is None: try: diff --git a/accelerator/xla_accelerator.py b/accelerator/xla_accelerator.py index 84c0692b0aa9..ca80aa8d879b 100644 --- a/accelerator/xla_accelerator.py +++ b/accelerator/xla_accelerator.py @@ -48,6 +48,10 @@ def _normalize_device_index(self, device_index=None): raise RuntimeError("No addressable XLA devices are available in the current process.") if device_index is None: return 0 + if isinstance(device_index, torch.device): + device_index = device_index.index if device_index.index is not None else 0 + elif isinstance(device_index, str): + device_index = int(device_index) if device_index.isdigit() else int(device_index.split(':')[-1]) return min(device_index, len(devices) - 1) def is_synchronized_device(self):