diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py index 35cac4b94b70..05240e53c96b 100644 --- a/accelerator/real_accelerator.py +++ b/accelerator/real_accelerator.py @@ -20,7 +20,7 @@ except ImportError as e: dsa2 = None -SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps', 'hpu', 'mlu', 'sdaa'] +SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps', 'hpu', 'mlu', 'sdaa', 'xla'] ds_accelerator = None @@ -98,6 +98,12 @@ def get_accelerator(): import torch_mlu # noqa: F401 except ImportError as e: raise ValueError("MLU_Accelerator requires torch_mlu, which is not installed on this system.") + elif accelerator_name in ["xla", "tpu"]: + accelerator_name = "xla" + try: + import torch_xla # noqa: F401 + except ImportError as e: + raise ValueError("XLA_Accelerator requires torch_xla, which is not installed on this system.") elif accelerator_name not in SUPPORTED_ACCELERATOR_LIST: raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. 
' f'Value "{accelerator_name}" is not supported') @@ -125,6 +131,14 @@ def get_accelerator(): accelerator_name = "xpu" except ImportError as e: pass + if accelerator_name is None: + try: + import torch_xla.core.xla_model as xm + + if len(xm.get_xla_supported_devices(devkind='TPU')) > 0: + accelerator_name = "xla" + except (ImportError, RuntimeError): + pass if accelerator_name is None: try: import torch_npu # noqa: F401,F811 # type: ignore @@ -220,6 +234,10 @@ def get_accelerator(): from .mlu_accelerator import MLU_Accelerator ds_accelerator = MLU_Accelerator() + elif accelerator_name == 'xla': + from .xla_accelerator import XLA_Accelerator + + ds_accelerator = XLA_Accelerator() _validate_accelerator(ds_accelerator) if accel_logger is not None: accel_logger.info(f"Setting ds_accelerator to {ds_accelerator._name} ({ds_set_method})")
diff --git a/accelerator/xla_accelerator.py b/accelerator/xla_accelerator.py
new file mode 100644
index 000000000000..ca80aa8d879b
--- /dev/null
+++ b/accelerator/xla_accelerator.py
@@ -0,0 +1,333 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import os
+import functools
+
+import torch
+
+from .abstract_accelerator import DeepSpeedAccelerator
+
+try:
+    import torch_xla
+    import torch_xla.core.xla_model as xm
+except ImportError as e:
+    torch_xla = None
+    xm = None
+
+
+class XLA_Accelerator(DeepSpeedAccelerator):
+    """DeepSpeed accelerator backed by PyTorch/XLA (``torch_xla``) TPU devices.
+
+    Streams, events, graphs, pinned memory and memory statistics are exposed as
+    no-op stubs (``None``, ``0`` or ``{}``); RNG calls are forwarded to the
+    module-level ``torch`` RNG functions.
+    """
+
+    def __init__(self):
+        self._name = 'xla'
+        self._communication_backend_name = 'xla'
+        self._compile_backend = None
+        if xm is None:
+            raise ValueError("XLA_Accelerator requires torch_xla, which is not installed on this system.")
+
+    def _require_xm(self):
+        """Return the ``torch_xla.core.xla_model`` module, raising ``RuntimeError`` if it failed to import."""
+        if xm is None:
+            raise RuntimeError("torch_xla is required to use the XLA_Accelerator")
+        return xm
+
+    def _tensor_factory(self, dtype):
+        """Return a ``torch.tensor`` partial bound to ``dtype`` and the current XLA device name."""
+        return functools.partial(torch.tensor, dtype=dtype, device=self.current_device_name())
+
+    def _addressable_devices(self):
+        """Return ``torch_xla.devices()`` when available, else the TPU devices listed by ``xla_model``."""
+        if torch_xla is not None and hasattr(torch_xla, 'devices'):
+            return list(torch_xla.devices())
+
+        xm_module = self._require_xm()
+        return [torch.device(device) for device in xm_module.get_xla_supported_devices(devkind='TPU')]
+
+    @staticmethod
+    def _parse_device_index(device_index):
+        """Convert an int, ``torch.device`` or string (``'1'``, ``'xla:1'``, ``'xla'``) to an unclamped index.
+
+        Devices without an explicit index (``torch.device('xla')`` or ``'xla'``) map to 0. Other values
+        are returned unchanged.
+        """
+        if isinstance(device_index, torch.device):
+            return device_index.index if device_index.index is not None else 0
+        if isinstance(device_index, str):
+            # A bare device type such as 'xla' has no index to parse; treat it as the first device.
+            if ':' not in device_index and not device_index.lstrip('-').isdigit():
+                return 0
+            return int(device_index.rsplit(':', 1)[-1])
+        return device_index
+
+    def _normalize_device_index(self, device_index=None):
+        """Map ``device_index`` onto a valid position in the addressable device list.
+
+        ``None`` selects device 0 and indexes past the end are clamped to the last addressable device.
+
+        Raises:
+            RuntimeError: if the current process has no addressable XLA device.
+        """
+        devices = self._addressable_devices()
+        if not devices:
+            raise RuntimeError("No addressable XLA devices are available in the current process.")
+        if device_index is None:
+            return 0
+        return min(self._parse_device_index(device_index), len(devices) - 1)
+
+    def is_synchronized_device(self):
+        return True
+
+    def use_host_timers(self):
+        return True
+
+    def resolves_data_dependency(self):
+        return True
+
+    def handles_memory_backpressure(self):
+        return True
+
+    # Device APIs
+    def device_name(self, device_index=None):
+        """Return ``'xla'`` without an index, else the name of the (clamped) addressable device."""
+        if device_index is None:
+            return 'xla'
+        return str(self._addressable_devices()[self._normalize_device_index(device_index)])
+
+    def device(self, device_index=None):
+        xm_module = self._require_xm()
+        if device_index is None:
+            return xm_module.xla_device(devkind='TPU')
+        return xm_module.xla_device(n=self._normalize_device_index(device_index), devkind='TPU')
+
+    def set_device(self, device_index):
+        """Select the XLA device and publish the local rank through the environment.
+
+        ``LOCAL_RANK`` is overwritten and ``PJRT_LOCAL_PROCESS_RANK`` is only set when absent. The
+        exported value is the parsed integer index, so ``torch.device`` or ``'xla:N'`` arguments
+        yield ``'N'``. A ``None`` argument leaves the environment untouched.
+        """
+        # XLA uses the default device selected for the current process.
+        self.device(device_index)
+        if device_index is None:
+            return
+        local_rank = str(self._parse_device_index(device_index))
+        os.environ['LOCAL_RANK'] = local_rank
+        os.environ.setdefault('PJRT_LOCAL_PROCESS_RANK', local_rank)
+
+    def current_device(self):
+        current_device = self.device()
+        device_index = getattr(current_device, 'index', None)
+        if device_index is not None:
+            return device_index
+        return self._normalize_device_index()
+
+    def current_device_name(self):
+        return str(self.device())
+
+    def device_count(self):
+        xm_module = self._require_xm()
+        return len(xm_module.get_xla_supported_devices(devkind='TPU'))
+
+    def synchronize(self, device_index=None):
+        """Call ``mark_step`` and then ``wait_device_ops`` on ``xla_model``; ``device_index`` is ignored."""
+        xm_module = self._require_xm()
+        xm_module.mark_step()
+        return xm_module.wait_device_ops()
+
+    # RNG APIs
+    def random(self):
+        return torch.random
+
+    def set_rng_state(self, new_state, device_index=None):
+        return torch.set_rng_state(new_state)
+
+    def get_rng_state(self, device_index=None):
+        return torch.get_rng_state()
+
+    def manual_seed(self, seed):
+        return torch.manual_seed(seed)
+
+    def manual_seed_all(self, seed):
+        return torch.manual_seed(seed)
+
+    def initial_seed(self):
+        return torch.initial_seed()
+
+    def default_generator(self, device_index):
+        return torch.default_generator
+
+    # Streams/Events
+    @property
+    def Stream(self):
+        return None
+
+    def stream(self, stream):
+        from deepspeed.runtime.utils import noop_context
+        return noop_context()
+
+    def current_stream(self, device_index=None):
+        return None
+
+    def default_stream(self, device_index=None):
+        return None
+
+    @property
+    def Event(self):
+        return None
+
+    # Memory management
+    def empty_cache(self):
+        return
+
+    def memory_allocated(self, device_index=None):
+        return 0
+
+    def max_memory_allocated(self, device_index=None):
+        return 0
+
+    def reset_max_memory_allocated(self, device_index=None):
+        return
+
+    def memory_cached(self, device_index=None):
+        return 0
+
+    def max_memory_cached(self, device_index=None):
+        return 0
+
+    def reset_max_memory_cached(self, device_index=None):
+        return
+
+    def memory_stats(self, device_index=None):
+        return {}
+
+    def reset_peak_memory_stats(self, device_index=None):
+        return
+
+    def memory_reserved(self, device_index=None):
+        return 0
+
+    def max_memory_reserved(self, device_index=None):
+        return 0
+
+    def total_memory(self, device_index=None):
+        return 0
+
+    def available_memory(self, device_index=None):
+        return 0
+
+    # Data types
+    def is_bf16_supported(self):
+        return True
+
+    def is_fp16_supported(self):
+        return False
+
+    def supported_dtypes(self):
+        return [torch.float32, torch.bfloat16]
+
+    # Misc
+    def is_available(self):
+        return self.device_count() > 0
+
+    def range_push(self, msg):
+        return
+
+    def range_pop(self):
+        return
+
+    def lazy_call(self, callback):
+        return callback()
+
+    def communication_backend_name(self):
+        return self._communication_backend_name
+
+    def is_triton_supported(self):
+        return False
+
+    # Graph operations
+    def create_graph(self):
+        return None
+
+    def capture_to_graph(self, graph, pool=None, stream=None):
+        from deepspeed.runtime.utils import noop_context
+        return noop_context()
+
+    def replay_graph(self, graph):
+        return
+
+    # Tensor operations
+    @property
+    def BFloat16Tensor(self):
+        return self._tensor_factory(torch.bfloat16)
+
+    @property
+    def ByteTensor(self):
+        return self._tensor_factory(torch.uint8)
+
+    @property
+    def DoubleTensor(self):
+        return self._tensor_factory(torch.float64)
+
+    @property
+    def FloatTensor(self):
+        return self._tensor_factory(torch.float32)
+
+    @property
+    def HalfTensor(self):
+        return self._tensor_factory(torch.float16)
+
+    @property
+    def IntTensor(self):
+        return self._tensor_factory(torch.int32)
+
+    @property
+    def LongTensor(self):
+        return self._tensor_factory(torch.int64)
+
+    def pin_memory(self, tensor, align_bytes=1):
+        return tensor
+
+    def is_pinned(self, tensor):
+        return False
+
+    def on_accelerator(self, tensor):
+        return getattr(tensor.device, 'type', None) == 'xla'
+
+    def op_builder_dir(self):
+        return "deepspeed.ops.op_builder.cpu"
+
+    def create_op_builder(self, op_name):
+        return None
+
+    def get_op_builder(self, class_name):
+        return None
+
+    def build_extension(self):
+        from torch.utils.cpp_extension import BuildExtension
+        return BuildExtension
+
+    def export_envs(self):
+        return ['PJRT_DEVICE', 'TPU_VISIBLE_CHIPS']
+
+    def visible_devices_envs(self):
+        return ['TPU_VISIBLE_CHIPS']
+
+    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
+        for env in self.visible_devices_envs():
+            current_env[env] = ",".join(map(str, local_accelerator_ids))
+        current_env.setdefault('PJRT_DEVICE', 'TPU')
+
+    def get_compile_backend(self):
+        return self._compile_backend
+
+    def set_compile_backend(self, backend):
+        if backend is not None:
+            raise ValueError(f"{backend} not supported by {self.device_name()}. Supported Backends are [None]")
diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index f9e94f0175e2..ada6a8e8a661 100755 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -162,6 +162,8 @@ def init_deepspeed_backend(ds_backend, timeout, init_method): utils.logger.info(f"Initialize {ds_backend} backend") elif ds_backend == HCCL_BACKEND: utils.logger.debug("HCCL backend in DeepSpeed not yet implemented") + elif ds_backend == XLA_BACKEND: + utils.logger.debug("XLA backend in DeepSpeed is provided via torch.distributed") else: utils.logger.debug(f"DeepSpeed does not support {ds_backend} backend") @@ -820,7 +822,10 @@ def init_distributed(dist_backend=None, set_backend() utils.logger.info(f'cdb={cdb}') if cdb is None and torch.distributed.is_initialized(): - # The user initialized torch.dist themselves, create cdb and short-circuit + # The user initialized torch.dist themselves, create cdb and short-circuit. + # Resolve dist_backend so TorchBackend always receives a concrete name. 
+ if dist_backend is None: + dist_backend = get_accelerator().communication_backend_name() cdb = TorchBackend(dist_backend, timeout, init_method) return @@ -831,7 +836,8 @@ def init_distributed(dist_backend=None, else: # Initialize torch distributed if needed required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] - if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)): + xla_backend = dist_backend == XLA_BACKEND or get_accelerator().communication_backend_name() == XLA_BACKEND + if auto_mpi_discovery and not xla_backend and not all(map(lambda v: v in os.environ, required_env)): if verbose: utils.logger.info("Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...") if in_aml() and not in_dlts(): diff --git a/deepspeed/comm/constants.py b/deepspeed/comm/constants.py index 50d234c93fa0..9134a30457df 100644 --- a/deepspeed/comm/constants.py +++ b/deepspeed/comm/constants.py @@ -9,6 +9,7 @@ GLOO_BACKEND = 'gloo' SCCL_BACKEND = 'sccl' HCCL_BACKEND = 'hccl' +XLA_BACKEND = 'xla' DEFAULT_AML_MASTER_PORT = "54965" DEFAULT_AML_NCCL_SOCKET_IFNAME = "^docker0,lo" diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py index 8e821f2fdd6d..d77ed1ba013e 100755 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -147,6 +147,10 @@ def has_reduce_scatter_tensor(self): def init_process_group(self, backend, timeout, init_method, rank, world_size): if not torch.distributed.is_initialized(): + if backend == XLA_BACKEND: + import torch_xla.distributed.xla_backend # noqa: F401 + if init_method is None: + init_method = "xla://" kwargs = dict(timeout=timeout, init_method=init_method, rank=rank, world_size=world_size) # 1. 
device_id arg was added in torch==2.3 @@ -159,6 +163,13 @@ def init_process_group(self, backend, timeout, init_method, rank, world_size): kwargs.update(device_id=get_accelerator().device(local_rank)) torch.distributed.init_process_group(backend, **kwargs) + if backend == XLA_BACKEND: + os.environ.setdefault('RANK', str(torch.distributed.get_rank())) + os.environ.setdefault('WORLD_SIZE', str(torch.distributed.get_world_size())) + if 'LOCAL_RANK' not in os.environ: + import torch_xla.core.xla_model as xm + os.environ['LOCAL_RANK'] = str(xm.get_local_ordinal()) + self.using_mpi = torch.distributed.get_backend() == 'mpi' @disable_compiler_collective
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 49a66ebfbbfe..8581f27767ae 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -59,14 +59,17 @@ def input(msg):
 
 
 def split_half_float_double(tensors):
-    device_type = get_accelerator().device_name()
-    dtypes = [
-        "torch.{}.HalfTensor".format(device_type), "torch.{}.FloatTensor".format(device_type),
-        "torch.{}.DoubleTensor".format(device_type), "torch.{}.BFloat16Tensor".format(device_type)
-    ]
-    buckets = []
-    for i, dtype in enumerate(dtypes):
-        bucket = [t for t in tensors if t.type() == dtype]
-        if bucket:
-            buckets.append(bucket)
+    """Bucket ``tensors`` by dtype, in the order fp16, fp32, fp64, bf16.
+
+    Each tensor is placed by its ``dtype`` only; tensors of any other dtype are
+    skipped. Empty buckets are omitted and the input order is kept inside every
+    bucket. ``tensors`` is traversed once, so one-shot iterables are supported.
+    """
+    grouped = {dtype: [] for dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16)}
+    for tensor in tensors:
+        # dict.get returns None for dtypes without a bucket, which are ignored.
+        bucket = grouped.get(tensor.dtype)
+        if bucket is not None:
+            bucket.append(tensor)
+    buckets = [bucket for bucket in grouped.values() if bucket]
     return buckets
diff --git a/docs/_tutorials/accelerator-setup-guide.md b/docs/_tutorials/accelerator-setup-guide.md index 20e667170eaa..d32b04f8baba 100644 --- a/docs/_tutorials/accelerator-setup-guide.md +++ b/docs/_tutorials/accelerator-setup-guide.md @@ -8,6 +8,7 @@ tags: getting-started training accelerator - [Introduction](#introduction) - [Intel Architecture (IA) CPU](#intel-architecture-ia-cpu) - [Intel XPU](#intel-xpu) +- [Google Cloud TPU](#google-cloud-tpu) - [Huawei Ascend NPU](#huawei-ascend-npu) - [Intel 
Gaudi](#intel-gaudi) @@ -159,6 +160,31 @@ accelerator: xpu ``` +# Google Cloud TPU +DeepSpeed TPU support targets ZeRO-1 and ZeRO-2 on top of PyTorch/XLA. + +## Installation steps for Google Cloud TPU +1. Install PyTorch and PyTorch/XLA with TPU support following the matching wheel set for your runtime. +2. Install DeepSpeed. + +## How to use DeepSpeed on Google Cloud TPU +DeepSpeed uses the `xla` distributed backend on TPU through `torch.distributed`, and falls back to the `xla://` init method when no `init_method` is given. + +To validate that TPU support is selected, here is an example: +``` +$ python +>>> import torch_xla +>>> from deepspeed.accelerator import get_accelerator +>>> accelerator = get_accelerator() +>>> print(accelerator._name) +xla +>>> print(accelerator.communication_backend_name()) +xla +``` + +TPU support currently focuses on ZeRO-1 and ZeRO-2 training paths. Use `bf16` training on TPU because fp16 is not supported by the XLA accelerator path. + + # Huawei Ascend NPU DeepSpeed has been verified on the following Huawei Ascend NPU products: diff --git a/tests/accelerator/test_ds_init.py b/tests/accelerator/test_ds_init.py index 9594a6f5ea58..926b47215d15 100644 --- a/tests/accelerator/test_ds_init.py +++ b/tests/accelerator/test_ds_init.py @@ -40,9 +40,9 @@ def test_literal_device(): os.environ['LOCAL_RANK'] = '0' deepspeed.init_distributed(get_accelerator().communication_backend_name()) deepspeed.initialize(model=model, config='ds_config.json') - string = get_accelerator().device_name() #'xpu' or 'cuda' - string0 = get_accelerator().device_name(0) #'xpu:0' or 'cuda:0' - string1 = get_accelerator().device_name(1) #'xpu:1' or 'cuda:1' - assert string == 'xpu' or string == 'cuda' - assert string0 == 'xpu:0' or string0 == 'cuda:0' - assert string1 == 'xpu:1' or string1 == 'cuda:1' + string = get_accelerator().device_name() #'xpu', 'cuda', or 'xla' + string0 = get_accelerator().device_name(0) #'xpu:0', 'cuda:0', or 'xla:0' + string1 = 
get_accelerator().device_name(1) #'xpu:1', 'cuda:1', or 'xla:1' + assert string in ['xpu', 'cuda', 'xla'] + assert string0 in ['xpu:0', 'cuda:0', 'xla:0'] + assert string1 in ['xpu:1', 'cuda:1', 'xla:1']
diff --git a/tests/unit/comm/test_xla_backend.py b/tests/unit/comm/test_xla_backend.py
new file mode 100644
index 000000000000..a10c7db2d16d
--- /dev/null
+++ b/tests/unit/comm/test_xla_backend.py
@@ -0,0 +1,123 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import os
+import sys
+import types
+import importlib
+from datetime import timedelta
+
+import pytest
+import torch
+
+from deepspeed.comm.constants import XLA_BACKEND
+
+
+def _install_fake_torch_xla(monkeypatch, local_ordinal=0):
+    """Register empty stand-ins for the ``torch_xla`` modules imported by the XLA code paths.
+
+    ``torch_xla.core.xla_model.get_local_ordinal`` returns ``local_ordinal``. Entries are added with
+    ``monkeypatch.setitem`` so ``sys.modules`` is restored when the test ends.
+    """
+    module_names = (
+        "torch_xla",
+        "torch_xla.core",
+        "torch_xla.core.xla_model",
+        "torch_xla.distributed",
+        "torch_xla.distributed.xla_backend",
+    )
+    fakes = {name: types.ModuleType(name) for name in module_names}
+    fakes["torch_xla.core.xla_model"].get_local_ordinal = lambda: local_ordinal
+
+    for name, module in fakes.items():
+        parent_name, _, child_name = name.rpartition(".")
+        if parent_name:
+            # Expose each submodule on its parent so attribute access (torch_xla.core) also resolves.
+            setattr(fakes[parent_name], child_name, module)
+        monkeypatch.setitem(sys.modules, name, module)
+
+
+def test_torch_backend_uses_xla_init_method(monkeypatch):
+    """TorchBackend must default to ``xla://`` and export RANK, WORLD_SIZE and LOCAL_RANK for XLA."""
+    _install_fake_torch_xla(monkeypatch, local_ordinal=3)
+    ds_torch = importlib.import_module("deepspeed.comm.torch")
+
+    init_calls = []
+    dist_pkg = torch.distributed
+
+    class FakeAccelerator:
+
+        @staticmethod
+        def device_name():
+            return 'xla'
+
+    # The backend only fills these variables when they are missing. setenv followed by delenv clears
+    # any inherited value and makes monkeypatch record the original state, so the values written
+    # during the test are rolled back on teardown instead of leaking into later tests.
+    for env_var in ("RANK", "WORLD_SIZE", "LOCAL_RANK"):
+        monkeypatch.setenv(env_var, "placeholder")
+        monkeypatch.delenv(env_var)
+    monkeypatch.setattr(ds_torch, "build_shm_op", lambda: None)
+    monkeypatch.setattr(ds_torch, "get_accelerator", lambda: FakeAccelerator())
+    monkeypatch.setattr(dist_pkg, "is_initialized", lambda: False)
+    monkeypatch.setattr(
+        dist_pkg,
+        "init_process_group",
+        lambda backend, **kwargs: init_calls.append((backend, kwargs)),
+    )
+    monkeypatch.setattr(dist_pkg, "get_rank", lambda: 1)
+    monkeypatch.setattr(dist_pkg, "get_world_size", lambda: 8)
+    monkeypatch.setattr(dist_pkg, "get_backend", lambda: XLA_BACKEND)
+
+    ds_torch.TorchBackend(XLA_BACKEND, timedelta(seconds=5), None)
+
+    assert init_calls[0][0] == XLA_BACKEND
+    assert init_calls[0][1]["init_method"] == "xla://"
+    assert os.environ["LOCAL_RANK"] == "3"
+    assert os.environ["RANK"] == "1"
+    assert os.environ["WORLD_SIZE"] == "8"
+
+
+def test_init_distributed_skips_mpi_discovery_for_xla(monkeypatch):
+    """init_distributed must build the backend without MPI discovery when the backend is XLA."""
+    import deepspeed.comm.comm as ds_comm
+
+    calls = []
+
+    class FakeAccelerator:
+
+        @staticmethod
+        def communication_backend_name():
+            return XLA_BACKEND
+
+    class FakeTorchBackend:
+
+        def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1):
+            calls.append((backend, timeout, init_method, rank, world_size))
+
+        @staticmethod
+        def is_initialized():
+            return True
+
+    monkeypatch.setattr(ds_comm, "cdb", None)
+    monkeypatch.setattr(ds_comm, "configure", lambda deepspeed_config=None: None)
+    monkeypatch.setattr(ds_comm, "init_deepspeed_backend", lambda *args, **kwargs: None)
+    monkeypatch.setattr(ds_comm, "set_backend", lambda: None)
+    monkeypatch.setattr(ds_comm, "get_accelerator", lambda: FakeAccelerator())
+    monkeypatch.setattr(ds_comm, "TorchBackend", FakeTorchBackend)
+    monkeypatch.setattr(ds_comm, "mpi_discovery",
+                        lambda *args, **kwargs: pytest.fail("mpi_discovery should not run for the xla backend"))
+    # Pin the "not yet initialized" state so the branch containing the MPI discovery check is exercised.
+    monkeypatch.setattr(torch.distributed, "is_initialized", lambda: False)
+
+    for env_var in ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]:
+        monkeypatch.delenv(env_var, raising=False)
+
+    ds_comm.init_distributed(dist_backend=XLA_BACKEND,
+                             auto_mpi_discovery=True,
+                             verbose=False,
+                             timeout=timedelta(seconds=5))
+
+    assert calls == [(XLA_BACKEND, timedelta(seconds=5), None, -1, -1)]
diff --git a/tests/unit/common.py b/tests/unit/common.py index f57ee3395973..0fcae11d22a5 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -112,6 +112,8 @@ def set_accelerator_visible(): elif get_accelerator().device_name() == 'npu': npu_smi = subprocess.check_output(['npu-smi', 'info', '-l']) num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip()) + elif get_accelerator().device_name() == 'xla': + num_accelerators = get_accelerator().device_count() else: assert get_accelerator().device_name() == 'cpu' num_accelerators = _get_cpu_socket_count()