-
Notifications
You must be signed in to change notification settings - Fork 4.8k
Add torch_xla TPU support for ZeRO-1/2 #7917
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
8acc916
7f82c20
dbfd0a9
41d0f7e
7debfbb
623df9e
65cf60f
7f6ea95
b14e5b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,294 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| # DeepSpeed Team | ||
|
|
||
| import os | ||
| import functools | ||
|
|
||
| import torch | ||
|
|
||
| from .abstract_accelerator import DeepSpeedAccelerator | ||
|
|
||
| try: | ||
| import torch_xla | ||
| import torch_xla.core.xla_model as xm | ||
| except ImportError as e: | ||
| torch_xla = None | ||
| xm = None | ||
|
|
||
|
|
||
class XLA_Accelerator(DeepSpeedAccelerator):
    """DeepSpeed accelerator implementation for TPUs via torch_xla (PJRT).

    XLA devices execute lazily and synchronize at graph barriers, so this
    accelerator reports itself as a synchronized device, uses host timers,
    and implements the stream/event/memory-stat APIs as no-ops.
    """

    def __init__(self):
        self._name = 'xla'
        self._communication_backend_name = 'xla'
        self._compile_backend = None
        # Fail fast at construction: every other method needs torch_xla.
        if xm is None:
            raise ValueError("XLA_Accelerator requires torch_xla, which is not installed on this system.")

    def _require_xm(self):
        # Defensive guard for call paths that bypass __init__ (e.g. an
        # unpickled instance); returns the xm module for convenience.
        if xm is None:
            raise RuntimeError("torch_xla is required to use the XLA_Accelerator")
        return xm

    def _tensor_factory(self, dtype):
        # Emulate torch.cuda.FloatTensor-style constructors: a callable that
        # creates tensors of `dtype` on the current XLA device.
        return functools.partial(torch.tensor, dtype=dtype, device=self.current_device_name())

    def _addressable_devices(self):
        """Return the list of XLA devices addressable by this process."""
        # torch_xla >= 2.x exposes torch_xla.devices(); fall back to the
        # legacy xm API (TPU kind only) on older releases.
        if torch_xla is not None and hasattr(torch_xla, 'devices'):
            return list(torch_xla.devices())
        xm_module = self._require_xm()
        return [torch.device(device) for device in xm_module.get_xla_supported_devices(devkind='TPU')]

    def _normalize_device_index(self, device_index=None):
        """Coerce a device spec into a valid integer index.

        Accepts an int, a str (e.g. "0" or "xla:0" as found in LOCAL_RANK),
        or a torch.device; None maps to device 0.

        Raises:
            RuntimeError: if no XLA devices are addressable in this process.
        """
        devices = self._addressable_devices()
        if not devices:
            raise RuntimeError("No addressable XLA devices are available in the current process.")
        if device_index is None:
            return 0
        if isinstance(device_index, torch.device):
            # torch.device('xla') has index None; treat it as device 0.
            device_index = device_index.index if device_index.index is not None else 0
        elif isinstance(device_index, str):
            device_index = int(device_index) if device_index.isdigit() else int(device_index.split(':')[-1])
        # Clamp instead of raising so global-rank-style indices still map
        # onto a local device.
        return min(device_index, len(devices) - 1)

    def is_synchronized_device(self):
        # Lazy XLA execution synchronizes at mark_step barriers.
        return True

    def use_host_timers(self):
        return True

    def resolves_data_dependency(self):
        return True

    def handles_memory_backpressure(self):
        return True

    # Device APIs
    def device_name(self, device_index=None):
        """Return 'xla' or the concrete device string (e.g. 'xla:0')."""
        if device_index is None:
            return 'xla'
        return str(self._addressable_devices()[self._normalize_device_index(device_index)])

    def device(self, device_index=None):
        """Return the torch.device for the given index (default: process device)."""
        xm_module = self._require_xm()
        if device_index is None:
            return xm_module.xla_device(devkind='TPU')
        return xm_module.xla_device(n=self._normalize_device_index(device_index), devkind='TPU')

    def set_device(self, device_index):
        # XLA uses the default device selected for the current process.
        # Normalize first: callers may pass an int, a "xla:0"-style string,
        # or a torch.device, and the env vars below must hold a bare number
        # (writing str(torch.device('xla', 0)) would produce 'xla:0').
        normalized_index = self._normalize_device_index(device_index)
        self.device(normalized_index)
        os.environ['LOCAL_RANK'] = str(normalized_index)
        os.environ.setdefault('PJRT_LOCAL_PROCESS_RANK', str(normalized_index))

    def current_device(self):
        """Return the integer index of the process's current XLA device."""
        current_device = self.device()
        device_index = getattr(current_device, 'index', None)
        if device_index is not None:
            return device_index
        return self._normalize_device_index()

    def current_device_name(self):
        return str(self.device())

    def device_count(self):
        xm_module = self._require_xm()
        return len(xm_module.get_xla_supported_devices(devkind='TPU'))

    def synchronize(self, device_index=None):
        # Flush the pending lazy graph, then block until device ops finish.
        xm_module = self._require_xm()
        xm_module.mark_step()
        return xm_module.wait_device_ops()

    # RNG APIs
    # XLA defers to the host torch RNG; per-device state is not tracked.
    def random(self):
        return torch.random

    def set_rng_state(self, new_state, device_index=None):
        return torch.set_rng_state(new_state)

    def get_rng_state(self, device_index=None):
        return torch.get_rng_state()

    def manual_seed(self, seed):
        return torch.manual_seed(seed)

    def manual_seed_all(self, seed):
        # Single host RNG: seeding "all" devices is the same as seeding one.
        return torch.manual_seed(seed)

    def initial_seed(self):
        return torch.initial_seed()

    def default_generator(self, device_index):
        return torch.default_generator

    # Streams/Events — XLA has no user-visible stream/event abstraction,
    # so these return None / no-op contexts.
    @property
    def Stream(self):
        return None

    def stream(self, stream):
        from deepspeed.runtime.utils import noop_context
        return noop_context()

    def current_stream(self, device_index=None):
        return None

    def default_stream(self, device_index=None):
        return None

    @property
    def Event(self):
        return None

    # Memory management — no caching-allocator introspection on XLA;
    # report zeros/empty so callers see "no pressure" rather than failing.
    def empty_cache(self):
        return

    def memory_allocated(self, device_index=None):
        return 0

    def max_memory_allocated(self, device_index=None):
        return 0

    def reset_max_memory_allocated(self, device_index=None):
        return

    def memory_cached(self, device_index=None):
        return 0

    def max_memory_cached(self, device_index=None):
        return 0

    def reset_max_memory_cached(self, device_index=None):
        return

    def memory_stats(self, device_index=None):
        return {}

    def reset_peak_memory_stats(self, device_index=None):
        return

    def memory_reserved(self, device_index=None):
        return 0

    def max_memory_reserved(self, device_index=None):
        return 0

    def total_memory(self, device_index=None):
        return 0

    def available_memory(self, device_index=None):
        return 0

    # Data types
    def is_bf16_supported(self):
        # bf16 is the native TPU compute dtype.
        return True

    def is_fp16_supported(self):
        return False

    def supported_dtypes(self):
        return [torch.float32, torch.bfloat16]

    # Misc
    def is_available(self):
        return self.device_count() > 0

    def range_push(self, msg):
        # No NVTX-style profiler ranges on XLA.
        return

    def range_pop(self):
        return

    def lazy_call(self, callback):
        # Nothing to defer for; execute immediately.
        return callback()

    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        return False

    # Graph operations — CUDA-graph capture has no XLA equivalent; provide
    # no-op stand-ins so generic code paths remain usable.
    def create_graph(self):
        return None

    def capture_to_graph(self, graph, pool=None, stream=None):
        from deepspeed.runtime.utils import noop_context
        return noop_context()

    def replay_graph(self, graph):
        return

    # Tensor operations — typed tensor constructors bound to the XLA device.
    @property
    def BFloat16Tensor(self):
        return self._tensor_factory(torch.bfloat16)

    @property
    def ByteTensor(self):
        return self._tensor_factory(torch.uint8)

    @property
    def DoubleTensor(self):
        return self._tensor_factory(torch.float64)

    @property
    def FloatTensor(self):
        return self._tensor_factory(torch.float32)

    @property
    def HalfTensor(self):
        return self._tensor_factory(torch.float16)

    @property
    def IntTensor(self):
        return self._tensor_factory(torch.int32)

    @property
    def LongTensor(self):
        return self._tensor_factory(torch.int64)

    def pin_memory(self, tensor, align_bytes=1):
        # Pinned host memory does not apply to XLA; return unchanged.
        return tensor

    def is_pinned(self, tensor):
        return False

    def on_accelerator(self, tensor):
        return getattr(tensor.device, 'type', None) == 'xla'

    def op_builder_dir(self):
        # Reuse the CPU op builders; no XLA-specific fused kernels.
        return "deepspeed.ops.op_builder.cpu"

    def create_op_builder(self, op_name):
        return None

    def get_op_builder(self, class_name):
        return None

    def build_extension(self):
        from torch.utils.cpp_extension import BuildExtension
        return BuildExtension

    def export_envs(self):
        # Env vars the launcher must propagate to worker processes.
        return ['PJRT_DEVICE', 'TPU_VISIBLE_CHIPS']

    def visible_devices_envs(self):
        return ['TPU_VISIBLE_CHIPS']

    def set_visible_devices_envs(self, current_env, local_accelerator_ids):
        for env in self.visible_devices_envs():
            current_env[env] = ",".join(map(str, local_accelerator_ids))
        # Ensure the PJRT runtime targets TPU unless the user overrode it.
        current_env.setdefault('PJRT_DEVICE', 'TPU')

    def get_compile_backend(self):
        return self._compile_backend

    def set_compile_backend(self, backend):
        # torch.compile backends are not supported on this accelerator.
        if backend is not None:
            raise ValueError(f"{backend} not supported by {self.device_name()}. Supported Backends are [None]")
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -162,6 +162,8 @@ def init_deepspeed_backend(ds_backend, timeout, init_method): | |
| utils.logger.info(f"Initialize {ds_backend} backend") | ||
| elif ds_backend == HCCL_BACKEND: | ||
| utils.logger.debug("HCCL backend in DeepSpeed not yet implemented") | ||
| elif ds_backend == XLA_BACKEND: | ||
| utils.logger.debug("XLA backend in DeepSpeed is provided via torch.distributed") | ||
| else: | ||
| utils.logger.debug(f"DeepSpeed does not support {ds_backend} backend") | ||
|
|
||
|
|
@@ -820,7 +822,10 @@ def init_distributed(dist_backend=None, | |
| set_backend() | ||
| utils.logger.info(f'cdb={cdb}') | ||
| if cdb is None and torch.distributed.is_initialized(): | ||
| # The user initialized torch.dist themselves, create cdb and short-circuit | ||
| # The user initialized torch.dist themselves, create cdb and short-circuit. | ||
| # Resolve dist_backend so TorchBackend always receives a concrete name. | ||
| if dist_backend is None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need this behavior? Is it specific to
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @sfc-gh-truwase Thanks for the call out! I added a comment clarifying that it's a general fix (not XLA-specific) — it prevents passing |
||
| dist_backend = get_accelerator().communication_backend_name() | ||
| cdb = TorchBackend(dist_backend, timeout, init_method) | ||
| return | ||
|
|
||
|
|
@@ -831,7 +836,8 @@ def init_distributed(dist_backend=None, | |
| else: | ||
| # Initialize torch distributed if needed | ||
| required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] | ||
| if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)): | ||
| xla_backend = dist_backend == XLA_BACKEND or get_accelerator().communication_backend_name() == XLA_BACKEND | ||
| if auto_mpi_discovery and not xla_backend and not all(map(lambda v: v in os.environ, required_env)): | ||
| if verbose: | ||
| utils.logger.info("Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...") | ||
| if in_aml() and not in_dlts(): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assumes `device_index` is an int, but it can actually be a `torch.device` (from `device_name`) or a string (from `set_device`). For these, the function would throw a
`TypeError`. I suggest handling these types in this function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tohtana good catch! I made some changes, it now handles
`str` (e.g., "0" or "xla:0" from `os.environ["LOCAL_RANK"]`) and `torch.device` (from `torch.device(get_accelerator().device_name(...))`) in addition to `int`. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, I see you added handling in
`_normalize_device_index`. But I wonder if `set_device` still has an issue with the path from `partition_parameters.py`. In
`set_device`, we write `device_index` back to the env vars (code). If the `device_index` is a `torch.device`, then `LOCAL_RANK` and `PJRT_LOCAL_PROCESS_RANK` will include the device type (e.g. `xla:0`), not only a number.