
Add torch_xla TPU support for ZeRO-1/2#7917

Open
PKUWZP wants to merge 9 commits into master from codex/tpu-xla-zero12

Conversation

@PKUWZP
Collaborator

@PKUWZP PKUWZP commented Mar 21, 2026

Summary

  • add an XLA accelerator for TPU execution through torch_xla
  • initialize DeepSpeed distributed with the torch.distributed xla backend and xla:// init flow
  • make ZeRO-1/2 gradient bucketing backend-agnostic and add TPU/XLA-focused tests and docs

Testing

  • git diff --check
  • python3 -m py_compile accelerator/xla_accelerator.py accelerator/real_accelerator.py deepspeed/comm/constants.py deepspeed/comm/comm.py deepspeed/comm/torch.py deepspeed/runtime/zero/stage_1_and_2.py tests/accelerator/test_ds_init.py tests/unit/common.py tests/unit/accelerator/test_accelerator.py tests/unit/comm/test_xla_backend.py

Notes

  • the local Python environment in this session did not have pytest or pre-commit installed, so those checks could not be run here
  • scope is intentionally limited to TPU support for ZeRO-1/2 via torch_xla

Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>
@PKUWZP PKUWZP requested review from delock and removed request for GuanhuaWang March 21, 2026 18:44

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 8acc916e4e


Comment on lines +49 to +52
def device_name(self, device_index=None):
    if device_index is None:
        return 'xla'
    return f'xla:{device_index}'


P1 Badge Map XLA device names to addressable device indices

device_name(device_index) is used by DeepSpeedEngine._set_distributed_vars() (deepspeed/runtime/engine.py:1287-1290) to build self.device from LOCAL_RANK, and the module is then moved there at line 1427. Under PJRT, LOCAL_RANK is not the per-process XLA device index: each worker only sees its own addressable XLA devices, which is often just xla:0. Returning xla:{LOCAL_RANK} here means nonzero ranks on a multi-chip TPU will try to place the model on xla:1, xla:2, etc., which are not addressable in that process and causes multi-rank launches to fail or target the wrong chip.


Collaborator Author


Addressed in 7f82c20. device_name(device_index) now resolves through the process's addressable XLA device list instead of treating LOCAL_RANK as a raw XLA ordinal, so single-device-per-process workers map nonzero local ranks back to their local xla:0 device correctly.
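The behavior described above can be sketched as follows. This is an illustrative helper, not the PR's exact code; `addressable_devices` stands in for the per-process device list that torch_xla reports:

```python
def resolve_xla_device_name(device_index, addressable_devices):
    """Map a launcher-provided local rank to a device this process can address.

    Under PJRT each worker typically sees only its own chip as 'xla:0',
    so a nonzero LOCAL_RANK must be clamped back to a local ordinal.
    """
    if not addressable_devices:
        raise RuntimeError("No addressable XLA devices are available in the current process.")
    if device_index is None:
        return 'xla'
    # Clamp to the last addressable ordinal, mirroring the min() logic above.
    local_ordinal = min(device_index, len(addressable_devices) - 1)
    return f'xla:{local_ordinal}'
```

With a single addressable device per process, `resolve_xla_device_name(3, ['xla:0'])` yields `'xla:0'` rather than the unaddressable `'xla:3'`.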

Comment on lines +58 to +59
def set_device(self, device_index):
    os.environ['LOCAL_RANK'] = str(device_index)


P1 Badge Actually bind each TPU worker in set_device()

DeepSpeed's launcher gives every local process the same TPU visibility mask (deepspeed/launcher/launch.py:182-183) and relies on get_accelerator().set_device(local_rank) from DeepSpeedEngine._set_distributed_vars() to pin each worker to its chip. This implementation only rewrites LOCAL_RANK; it never calls a torch_xla/PJRT device-selection API or sets the PJRT process-rank env that torch_xla uses to derive local ordinals. On a host with multiple TPU chips, multiple ranks can therefore attach to the same default XLA device, which breaks distributed initialization and ZeRO synchronization.


Collaborator Author


Addressed in 7f82c20. set_device() now calls into xm.xla_device() to select the XLA default device for the current process before DeepSpeed moves the model, and it preserves the launcher-provided rank information in the environment.
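The revised flow can be sketched like this (illustrative, not the PR's exact code): preserve the launcher-provided rank in the environment, then bind the process to its default XLA device when torch_xla is importable.

```python
import os

def set_device(device_index):
    """Sketch: record the launcher's rank, then pin this process to its
    default XLA device via torch_xla when a PJRT runtime is present."""
    os.environ['LOCAL_RANK'] = str(device_index)
    try:
        # Requires torch_xla with a TPU/PJRT runtime; without it we only
        # record the rank and return None.
        import torch_xla.core.xla_model as xm
        # xm.xla_device() returns the default XLA device for this process.
        return xm.xla_device()
    except ImportError:
        return None
```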

PKUWZP added 6 commits March 21, 2026 15:46
Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>
Collaborator

@tohtana tohtana left a comment


Thank you for submitting a great PR, @PKUWZP!
I left two comments, though they are not the core part of this PR.

        raise RuntimeError("No addressable XLA devices are available in the current process.")
    if device_index is None:
        return 0
    return min(device_index, len(devices) - 1)
Collaborator

@tohtana tohtana Mar 23, 2026


This assumes device_index is an int, but it can actually be

  • a string (from device_name)
  • a torch.device (from set_device)

For these, the function would throw a TypeError. I suggest handling these types in this function.

Collaborator Author


@tohtana good catch! I made some changes; it now handles str (e.g., "0" or "xla:0" from os.environ["LOCAL_RANK"]) and torch.device (from torch.device(get_accelerator().device_name(...))) in addition to int.
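The handling can be sketched as below. This is illustrative, not the PR's exact implementation; a torch.device is recognized by duck typing (its `.type`/`.index` attributes) so the sketch runs without torch installed:

```python
def normalize_device_index(device_index, num_devices):
    """Normalize int / str / torch.device inputs to a local XLA ordinal."""
    if device_index is None:
        return 0
    if hasattr(device_index, 'type') and hasattr(device_index, 'index'):
        # torch.device('xla', 0) -> .index == 0; torch.device('xla') -> .index is None
        device_index = device_index.index if device_index.index is not None else 0
    elif isinstance(device_index, str):
        # Accept both "0" (from LOCAL_RANK) and "xla:0" forms.
        device_index = int(device_index.rsplit(':', 1)[-1])
    if not isinstance(device_index, int):
        raise TypeError(f"Unsupported device index type: {type(device_index)}")
    # Clamp to the last addressable ordinal, as in the snippet above.
    return min(device_index, num_devices - 1)
```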

Collaborator


Thank you, I see you added handling in _normalize_device_index. But I wonder if set_device still has an issue with the path from partition_parameters.py.

In set_device, we write device_index back to the env vars. If device_index is a torch.device, LOCAL_RANK and PJRT_LOCAL_PROCESS_RANK will include the device type (e.g. xla:0), not only a number.

accelerator = get_accelerator()
dtype_order = (torch.float16, torch.float32, torch.float64, torch.bfloat16)
for dtype in dtype_order:
    bucket = [tensor for tensor in tensors if tensor.dtype == dtype and accelerator.on_accelerator(tensor)]
Collaborator


Why does xla require special handling given the prior code worked for other accelerators?

Collaborator


Another issue is that it seems on_accelerator is only defined for xla and so other accelerators will break here.

Collaborator Author


@sfc-gh-truwase Thanks for the comments:) I removed the accelerator.on_accelerator(tensor) filter that would have changed behavior for all backends. I also kept the dtype-based comparison (replacing the old string-based type names like torch.cuda.HalfTensor) since that's the actual fix needed for XLA compatibility without breaking other accelerators.
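The dtype-based grouping can be illustrated without torch; `FakeTensor` and the dtype strings below stand in for real tensors and torch dtypes, so this is a pure-logic sketch rather than the PR's code:

```python
from collections import namedtuple

# Stand-in for a real tensor; only .dtype matters for the bucketing logic.
FakeTensor = namedtuple('FakeTensor', ['name', 'dtype'])

# Mirrors the dtype_order of (float16, float32, float64, bfloat16) above.
DTYPE_ORDER = ('float16', 'float32', 'float64', 'bfloat16')

def split_by_dtype(tensors):
    """Bucket tensors by comparing dtypes directly, replacing the old
    string-based type() names such as 'torch.cuda.HalfTensor'."""
    return [
        bucket for bucket in
        ([t for t in tensors if t.dtype == dtype] for dtype in DTYPE_ORDER)
        if bucket
    ]
```

Comparing `tensor.dtype` works identically for CUDA and XLA tensors, whereas type-name strings are backend-specific.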

utils.logger.info(f'cdb={cdb}')
if cdb is None and torch.distributed.is_initialized():
    # The user initialized torch.dist themselves, create cdb and short-circuit
    if dist_backend is None:
Collaborator


Why do we need this behavior? Is it specific to xla?

Collaborator Author


@sfc-gh-truwase Thanks for the call out! I added a comment clarifying that it's a general fix (not XLA-specific) — it prevents passing None to TorchBackend when the user pre-initialized torch.distributed without specifying dist_backend.
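The fallback amounts to the small decision below; `get_backend` stands for `torch.distributed.get_backend()`, which reads the backend of an already-initialized default process group, and the wrapper itself is illustrative:

```python
def resolve_dist_backend(dist_backend, torch_dist_initialized, get_backend):
    """If the user pre-initialized torch.distributed without telling
    DeepSpeed which backend to use, read it back from the existing
    process group instead of passing None to TorchBackend."""
    if dist_backend is None and torch_dist_initialized:
        # e.g. torch.distributed.get_backend() -> 'nccl', 'gloo', or 'xla'
        return get_backend()
    return dist_backend
```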

    selected_device["index"] = n
    return FakeDevice(selected_device["index"])

torch_xla.devices = lambda: [FakeDevice(idx) for idx in range(device_count)]
Collaborator


I don't think using a fake TPU device is useful since computation cannot be tested. Rather, I think we should condition these tests to run only when a TPU is available. Another possibility is to set up CI tests specifically for TPU using the cloud credits.

Collaborator Author


@sfc-gh-truwase Great suggestions. I initially thought the fake TPU device was useful for testing; I have now removed all three fake TPU device tests and the _install_fake_torch_xla helper. I agree these tests should be conditioned on real TPU availability instead of using faked devices.
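Gating tests on real hardware can look like the sketch below. unittest is shown for a self-contained example (the repo's pytest suite would use `pytest.mark.skipif` with the same probe), and the `tpu_available` helper name is hypothetical:

```python
import unittest

def tpu_available():
    """Probe for a usable TPU; returns False when torch_xla is missing or
    no PJRT runtime is configured (broad except is deliberate here)."""
    try:
        import torch_xla.core.xla_model as xm
        return xm.xla_device_hw(xm.xla_device()) == 'TPU'
    except Exception:
        return False

class TestXlaAccelerator(unittest.TestCase):

    @unittest.skipUnless(tpu_available(), "requires a real TPU/PJRT runtime")
    def test_device_name_is_xla(self):
        from deepspeed.accelerator import get_accelerator
        self.assertTrue(get_accelerator().device_name().startswith('xla'))
```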

PKUWZP added 2 commits March 26, 2026 00:02
- split_half_float_double: use dtype comparison instead of string-based
  type names, without adding on_accelerator filtering that would change
  behavior for all backends
- comm.py: clarify that dist_backend fallback is not XLA-specific
- Remove fake TPU device tests per reviewer guidance; XLA accelerator
  tests should run on real TPU hardware

Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>
- _normalize_device_index: handle str and torch.device types in addition
  to int, since callers pass LOCAL_RANK strings and torch.device objects
- real_accelerator: catch RuntimeError from get_xla_supported_devices()
  when torch_xla is installed but no TPU/PJRT runtime is available

Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>
