Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion accelerator/real_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
except ImportError as e:
dsa2 = None

SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps', 'hpu', 'mlu', 'sdaa']
SUPPORTED_ACCELERATOR_LIST = ['cuda', 'cpu', 'xpu', 'npu', 'mps', 'hpu', 'mlu', 'sdaa', 'xla']

ds_accelerator = None

Expand Down Expand Up @@ -98,6 +98,12 @@ def get_accelerator():
import torch_mlu # noqa: F401
except ImportError as e:
raise ValueError("MLU_Accelerator requires torch_mlu, which is not installed on this system.")
elif accelerator_name in ["xla", "tpu"]:
accelerator_name = "xla"
try:
import torch_xla # noqa: F401
except ImportError as e:
raise ValueError("XLA_Accelerator requires torch_xla, which is not installed on this system.")
elif accelerator_name not in SUPPORTED_ACCELERATOR_LIST:
raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. '
f'Value "{accelerator_name}" is not supported')
Expand Down Expand Up @@ -125,6 +131,14 @@ def get_accelerator():
accelerator_name = "xpu"
except ImportError as e:
pass
if accelerator_name is None:
try:
import torch_xla.core.xla_model as xm

if len(xm.get_xla_supported_devices(devkind='TPU')) > 0:
accelerator_name = "xla"
except (ImportError, RuntimeError):
pass
if accelerator_name is None:
try:
import torch_npu # noqa: F401,F811 # type: ignore
Expand Down Expand Up @@ -220,6 +234,10 @@ def get_accelerator():
from .mlu_accelerator import MLU_Accelerator

ds_accelerator = MLU_Accelerator()
elif accelerator_name == 'xla':
from .xla_accelerator import XLA_Accelerator

ds_accelerator = XLA_Accelerator()
_validate_accelerator(ds_accelerator)
if accel_logger is not None:
accel_logger.info(f"Setting ds_accelerator to {ds_accelerator._name} ({ds_set_method})")
Expand Down
294 changes: 294 additions & 0 deletions accelerator/xla_accelerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import functools

import torch

from .abstract_accelerator import DeepSpeedAccelerator

try:
import torch_xla
import torch_xla.core.xla_model as xm
except ImportError as e:
torch_xla = None
xm = None


class XLA_Accelerator(DeepSpeedAccelerator):

def __init__(self):
    """Set up the XLA accelerator descriptor.

    Raises:
        ValueError: if torch_xla could not be imported on this system.
    """
    self._compile_backend = None
    self._communication_backend_name = 'xla'
    self._name = 'xla'
    if xm is None:
        raise ValueError("XLA_Accelerator requires torch_xla, which is not installed on this system.")

def _require_xm(self):
    """Return ``torch_xla.core.xla_model``, failing loudly when it is absent."""
    if xm is not None:
        return xm
    raise RuntimeError("torch_xla is required to use the XLA_Accelerator")

def _tensor_factory(self, dtype):
    """Return a ``torch.tensor`` constructor bound to *dtype* on the current XLA device."""
    target = self.current_device_name()
    return functools.partial(torch.tensor, dtype=dtype, device=target)

def _addressable_devices(self):
    """Enumerate the XLA devices this process can address.

    Uses the modern ``torch_xla.devices()`` API when available and falls
    back to the legacy TPU listing from ``xla_model`` otherwise.
    """
    if torch_xla is not None and hasattr(torch_xla, 'devices'):
        return list(torch_xla.devices())
    names = self._require_xm().get_xla_supported_devices(devkind='TPU')
    return [torch.device(name) for name in names]

def _normalize_device_index(self, device_index=None):
devices = self._addressable_devices()
if not devices:
raise RuntimeError("No addressable XLA devices are available in the current process.")
if device_index is None:
return 0
if isinstance(device_index, torch.device):
device_index = device_index.index if device_index.index is not None else 0
elif isinstance(device_index, str):
device_index = int(device_index) if device_index.isdigit() else int(device_index.split(':')[-1])
return min(device_index, len(devices) - 1)
Copy link
Copy Markdown
Collaborator

@tohtana tohtana Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assumes device_index is int, but can actually be

  • string (device_name: code)
  • torch.device (set_device: code)

For these, the function should throw TypeError. I suggest handling these types in this function.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tohtana good catch! I made some changes, it now handles str (e.g., "0" or "xla:0" from os.environ["LOCAL_RANK"]) and torch.device (from torch.device(get_accelerator().device_name(...))) in addition to int.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, I see you added handling in _normalize_device_index. But I wonder if set_device still has an issue with the path from partition_parameters.py.

In set_device, we write device_index back to the env vars (code). If the device_index is torch.device, LOCAL_RANK and PJRT_LOCAL_PROCESS_RANK will include device type (e.g. xla:0), not only a number.


def is_synchronized_device(self):
    # Report the device as synchronized so DeepSpeed skips explicit
    # host/device synchronization bookkeeping for this backend.
    return True

def use_host_timers(self):
    # No device-side timer events are exposed; timing is done on the host.
    return True

def resolves_data_dependency(self):
    # Data dependencies are presumed tracked by the XLA graph itself, so
    # DeepSpeed's own dependency resolution is skipped.
    return True

def handles_memory_backpressure(self):
    # Memory backpressure is presumed handled by the XLA runtime rather
    # than by DeepSpeed's allocator hooks.
    return True

# Device APIs
def device_name(self, device_index=None):
    """Return the device string for *device_index*, or the bare type 'xla'."""
    if device_index is None:
        return 'xla'
    ordinal = self._normalize_device_index(device_index)
    return str(self._addressable_devices()[ordinal])

def device(self, device_index=None):
    """Return the torch_xla TPU device object for *device_index*."""
    xla = self._require_xm()
    if device_index is None:
        return xla.xla_device(devkind='TPU')
    return xla.xla_device(n=self._normalize_device_index(device_index), devkind='TPU')

def set_device(self, device_index):
# XLA uses the default device selected for the current process.
self.device(device_index)
os.environ['LOCAL_RANK'] = str(device_index)
os.environ.setdefault('PJRT_LOCAL_PROCESS_RANK', str(device_index))

def current_device(self):
    """Return the ordinal index of the device bound to this process."""
    bound = self.device()
    ordinal = getattr(bound, 'index', None)
    return ordinal if ordinal is not None else self._normalize_device_index()

def current_device_name(self):
    """Return the fully-qualified name (e.g. 'xla:0') of the bound device."""
    bound = self.device()
    return str(bound)

def device_count(self):
    """Return the number of TPU XLA devices addressable by this process."""
    tpu_devices = self._require_xm().get_xla_supported_devices(devkind='TPU')
    return len(tpu_devices)

def synchronize(self, device_index=None):
    """Flush pending lazy operations and wait for outstanding device work.

    *device_index* is ignored: the barrier applies to the whole process.
    """
    xla = self._require_xm()
    xla.mark_step()
    return xla.wait_device_ops()

# RNG APIs
# All RNG calls delegate to torch's global (host) generator; per-device RNG
# state is not tracked separately, so device_index arguments are ignored.
def random(self):
    return torch.random

def set_rng_state(self, new_state, device_index=None):
    return torch.set_rng_state(new_state)

def get_rng_state(self, device_index=None):
    return torch.get_rng_state()

def manual_seed(self, seed):
    return torch.manual_seed(seed)

def manual_seed_all(self, seed):
    # Identical to manual_seed: one global generator covers this process.
    return torch.manual_seed(seed)

def initial_seed(self):
    return torch.initial_seed()

def default_generator(self, device_index):
    # device_index is accepted for interface compatibility but unused.
    return torch.default_generator

# Streams/Events
# No CUDA-style stream/event objects are provided; these return None or a
# no-op context so stream-aware call sites degrade gracefully.
@property
def Stream(self):
    return None

def stream(self, stream):
    # Accepts any stream-like object and ignores it.
    from deepspeed.runtime.utils import noop_context
    return noop_context()

def current_stream(self, device_index=None):
    return None

def default_stream(self, device_index=None):
    return None

@property
def Event(self):
    return None

# Memory management
# torch-style allocator statistics are not wired up for this backend: all
# queries report zero / empty and the reset/clear hooks are no-ops, so
# callers that merely log memory numbers keep working with neutral values.
def empty_cache(self):
    return

def memory_allocated(self, device_index=None):
    return 0

def max_memory_allocated(self, device_index=None):
    return 0

def reset_max_memory_allocated(self, device_index=None):
    return

def memory_cached(self, device_index=None):
    return 0

def max_memory_cached(self, device_index=None):
    return 0

def reset_max_memory_cached(self, device_index=None):
    return

def memory_stats(self, device_index=None):
    return {}

def reset_peak_memory_stats(self, device_index=None):
    return

def memory_reserved(self, device_index=None):
    return 0

def max_memory_reserved(self, device_index=None):
    return 0

def total_memory(self, device_index=None):
    # NOTE(review): returning 0 makes total/available memory look exhausted
    # to any caller that budgets against these values — confirm no caller
    # divides by or compares against them.
    return 0

def available_memory(self, device_index=None):
    return 0

# Data types
def is_bf16_supported(self):
    # bfloat16 is advertised as the supported reduced-precision dtype.
    return True

def is_fp16_supported(self):
    # float16 paths are declared unsupported for this backend.
    return False

def supported_dtypes(self):
    # Keep in sync with the two capability flags above.
    return [torch.float32, torch.bfloat16]

# Misc
def is_available(self):
    # Available iff at least one TPU XLA device is visible to this process.
    return self.device_count() > 0

def range_push(self, msg):
    # Profiler range markers are not implemented for this backend.
    return

def range_pop(self):
    return

def lazy_call(self, callback):
    # No deferred-initialization machinery: run the callback immediately.
    return callback()

def communication_backend_name(self):
    return self._communication_backend_name

def is_triton_supported(self):
    return False

# Graph operations
def create_graph(self):
    # CUDA-graph-style capture is not supported; callers receive None.
    return None

def capture_to_graph(self, graph, pool=None, stream=None):
    # No-op capture context so graph-capture call sites degrade gracefully.
    from deepspeed.runtime.utils import noop_context
    return noop_context()

def replay_graph(self, graph):
    return

# Tensor operations
# Each property yields a torch.tensor constructor pre-bound (via
# functools.partial in _tensor_factory) to the given dtype and the current
# XLA device, mirroring torch.cuda's *Tensor factory attributes.
@property
def BFloat16Tensor(self):
    return self._tensor_factory(torch.bfloat16)

@property
def ByteTensor(self):
    return self._tensor_factory(torch.uint8)

@property
def DoubleTensor(self):
    return self._tensor_factory(torch.float64)

@property
def FloatTensor(self):
    return self._tensor_factory(torch.float32)

@property
def HalfTensor(self):
    # NOTE(review): a float16 factory is exposed even though
    # is_fp16_supported() returns False — confirm this is intentional.
    return self._tensor_factory(torch.float16)

@property
def IntTensor(self):
    return self._tensor_factory(torch.int32)

@property
def LongTensor(self):
    return self._tensor_factory(torch.int64)

def pin_memory(self, tensor, align_bytes=1):
    # Host-memory pinning is not applicable; return the tensor unchanged.
    return tensor

def is_pinned(self, tensor):
    # Consistent with pin_memory above: nothing is ever pinned.
    return False

def on_accelerator(self, tensor):
    # True when the tensor already lives on an XLA device; getattr guards
    # against tensor-like objects without a .device attribute type.
    return getattr(tensor.device, 'type', None) == 'xla'

def op_builder_dir(self):
    # Reuse the CPU op builders; no XLA-specific fused kernels are provided.
    return "deepspeed.ops.op_builder.cpu"

def create_op_builder(self, op_name):
    # No custom builders are constructed for this backend.
    return None

def get_op_builder(self, class_name):
    return None

def build_extension(self):
    # Standard torch C++ extension builder (used by the CPU op builders).
    from torch.utils.cpp_extension import BuildExtension
    return BuildExtension

def export_envs(self):
    # Environment variables the launcher must propagate to worker processes.
    return ['PJRT_DEVICE', 'TPU_VISIBLE_CHIPS']

def visible_devices_envs(self):
    # The variable used to restrict which accelerators a worker may see.
    return ['TPU_VISIBLE_CHIPS']

def set_visible_devices_envs(self, current_env, local_accelerator_ids):
    # Restrict the worker to the given chip ids and default the PJRT device
    # type to TPU only if the caller has not already set it.
    for env in self.visible_devices_envs():
        current_env[env] = ",".join(map(str, local_accelerator_ids))
    current_env.setdefault('PJRT_DEVICE', 'TPU')

def get_compile_backend(self):
    # Remains None (set in __init__) — torch.compile is not supported here.
    return self._compile_backend

def set_compile_backend(self, backend):
    # Only the default (None) is accepted; any concrete backend is rejected.
    if backend is not None:
        raise ValueError(f"{backend} not supported by {self.device_name()}. Supported Backends are [None]")
10 changes: 8 additions & 2 deletions deepspeed/comm/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def init_deepspeed_backend(ds_backend, timeout, init_method):
utils.logger.info(f"Initialize {ds_backend} backend")
elif ds_backend == HCCL_BACKEND:
utils.logger.debug("HCCL backend in DeepSpeed not yet implemented")
elif ds_backend == XLA_BACKEND:
utils.logger.debug("XLA backend in DeepSpeed is provided via torch.distributed")
else:
utils.logger.debug(f"DeepSpeed does not support {ds_backend} backend")

Expand Down Expand Up @@ -820,7 +822,10 @@ def init_distributed(dist_backend=None,
set_backend()
utils.logger.info(f'cdb={cdb}')
if cdb is None and torch.distributed.is_initialized():
# The user initialized torch.dist themselves, create cdb and short-circuit
# The user initialized torch.dist themselves, create cdb and short-circuit.
# Resolve dist_backend so TorchBackend always receives a concrete name.
if dist_backend is None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this behavior? Is it specific to xla?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sfc-gh-truwase Thanks for the call out! I added a comment clarifying that it's a general fix (not XLA-specific) — it prevents passing None to TorchBackend when the user pre-initialized torch.distributed without specifying dist_backend.

dist_backend = get_accelerator().communication_backend_name()
cdb = TorchBackend(dist_backend, timeout, init_method)
return

Expand All @@ -831,7 +836,8 @@ def init_distributed(dist_backend=None,
else:
# Initialize torch distributed if needed
required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
xla_backend = dist_backend == XLA_BACKEND or get_accelerator().communication_backend_name() == XLA_BACKEND
if auto_mpi_discovery and not xla_backend and not all(map(lambda v: v in os.environ, required_env)):
if verbose:
utils.logger.info("Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...")
if in_aml() and not in_dlts():
Expand Down
1 change: 1 addition & 0 deletions deepspeed/comm/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
GLOO_BACKEND = 'gloo'
SCCL_BACKEND = 'sccl'
HCCL_BACKEND = 'hccl'
XLA_BACKEND = 'xla'

DEFAULT_AML_MASTER_PORT = "54965"
DEFAULT_AML_NCCL_SOCKET_IFNAME = "^docker0,lo"
Expand Down
Loading
Loading