Conversation
Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 8acc916e4e
accelerator/xla_accelerator.py
```python
def device_name(self, device_index=None):
    if device_index is None:
        return 'xla'
    return f'xla:{device_index}'
```
**Map XLA device names to addressable device indices**

`device_name(device_index)` is used by `DeepSpeedEngine._set_distributed_vars()` (deepspeed/runtime/engine.py:1287-1290) to build `self.device` from `LOCAL_RANK`, and the module is then moved there at line 1427. Under PJRT, `LOCAL_RANK` is not the per-process XLA device index: each worker only sees its own addressable XLA devices, which is often just `xla:0`. Returning `xla:{LOCAL_RANK}` here means nonzero ranks on a multi-chip TPU will try to place the model on `xla:1`, `xla:2`, etc., which are not addressable in that process, causing multi-rank launches to fail or target the wrong chip.
Addressed in 7f82c20. `device_name(device_index)` now resolves through the process's addressable XLA device list instead of treating `LOCAL_RANK` as a raw XLA ordinal, so single-device-per-process workers map nonzero local ranks back to their local `xla:0` device correctly.
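A minimal sketch of that mapping, with a hypothetical helper name and the clamping behavior shown in the later diff hunk (the actual commit obtains the addressable-device list from torch_xla's runtime):

```python
def resolve_xla_device_name(device_index, devices):
    # devices: the XLA devices addressable by *this* process,
    # e.g. ['xla:0'] for a one-chip-per-process PJRT worker.
    if device_index is None:
        return 'xla'
    # Clamp LOCAL_RANK into the local addressable range so a nonzero
    # global rank still lands on its process-local device.
    local_index = min(device_index, len(devices) - 1)
    return f'xla:{local_index}'
```

With one addressable device per process, any rank resolves to `xla:0` locally.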
```python
def set_device(self, device_index):
    os.environ['LOCAL_RANK'] = str(device_index)
```
**Actually bind each TPU worker in `set_device()`**

DeepSpeed's launcher gives every local process the same TPU visibility mask (deepspeed/launcher/launch.py:182-183) and relies on `get_accelerator().set_device(local_rank)` from `DeepSpeedEngine._set_distributed_vars()` to pin each worker to its chip. This implementation only rewrites `LOCAL_RANK`; it never calls a torch_xla/PJRT device-selection API or sets the PJRT process-rank env var that torch_xla uses to derive local ordinals. On a host with multiple TPU chips, multiple ranks can therefore attach to the same default XLA device, which breaks distributed initialization and ZeRO synchronization.
Addressed in 7f82c20. `set_device()` now calls into `xm.xla_device()` to select the default XLA device for the current process before DeepSpeed moves the model, and it preserves the launcher-provided rank information in the environment.
```python
        raise RuntimeError("No addressable XLA devices are available in the current process.")
    if device_index is None:
        return 0
    return min(device_index, len(devices) - 1)
```
@tohtana good catch! I made some changes; it now handles `str` (e.g., `"0"` or `"xla:0"` from `os.environ["LOCAL_RANK"]`) and `torch.device` (from `torch.device(get_accelerator().device_name(...))`) in addition to `int`.
Thank you, I see you added handling in `_normalize_device_index`. But I wonder if `set_device` still has an issue with the path from partition_parameters.py.
In `set_device`, we write `device_index` back to the env vars (code). If `device_index` is a `torch.device`, `LOCAL_RANK` and `PJRT_LOCAL_PROCESS_RANK` will include the device type (e.g. `xla:0`), not only a number.
```python
accelerator = get_accelerator()
dtype_order = (torch.float16, torch.float32, torch.float64, torch.bfloat16)
for dtype in dtype_order:
    bucket = [tensor for tensor in tensors if tensor.dtype == dtype and accelerator.on_accelerator(tensor)]
```
Why does xla require special handling given the prior code worked for other accelerators?
Another issue is that `on_accelerator` seems to be defined only for xla, so other accelerators will break here.
@sfc-gh-truwase Thanks for the comments :) I removed the `accelerator.on_accelerator(tensor)` filter that would have changed behavior for all backends. I also kept the dtype-based comparison (replacing the old string-based type names like `torch.cuda.HalfTensor`), since that's the actual fix needed for XLA compatibility without breaking other accelerators.
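The resulting dtype-based split can be sketched like this (illustrative helper; the real code iterates torch dtypes such as `torch.float16`, shown here with stand-in dtype tags so the sketch stays dependency-free):

```python
def split_by_dtype(tensors, dtype_order):
    # Group tensors by dtype identity instead of string type names like
    # 'torch.cuda.HalfTensor', so XLA tensors bucket correctly too.
    buckets = []
    for dtype in dtype_order:
        bucket = [t for t in tensors if t.dtype == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets
```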
```python
utils.logger.info(f'cdb={cdb}')
if cdb is None and torch.distributed.is_initialized():
    # The user initialized torch.dist themselves, create cdb and short-circuit
    if dist_backend is None:
```
Why do we need this behavior? Is it specific to xla?
@sfc-gh-truwase Thanks for the call-out! I added a comment clarifying that it's a general fix (not XLA-specific): it prevents passing `None` to `TorchBackend` when the user pre-initialized `torch.distributed` without specifying `dist_backend`.
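A minimal sketch of that fallback (hypothetical names; DeepSpeed actually derives the default backend from the active accelerator):

```python
def resolve_dist_backend(dist_backend, default_backend):
    # If the user pre-initialized torch.distributed without naming a
    # backend, fall back to the accelerator default rather than handing
    # None to TorchBackend.
    return default_backend if dist_backend is None else dist_backend
```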
```python
        selected_device["index"] = n
    return FakeDevice(selected_device["index"])

torch_xla.devices = lambda: [FakeDevice(idx) for idx in range(device_count)]
```
I don't think using a fake TPU device is useful since computation cannot be tested. Rather, I think we should condition these tests to run only when a TPU is available. Another possibility is to set up CI tests specifically for TPU using the cloud credits.
@sfc-gh-truwase Great suggestions. I initially thought the fake TPU device would be useful for testing; I have now removed all three fake TPU device tests and the `_install_fake_torch_xla` helper. I agree that these should be conditioned on real TPU availability instead of using faked devices.
- split_half_float_double: use dtype comparison instead of string-based type names, without adding on_accelerator filtering that would change behavior for all backends
- comm.py: clarify that dist_backend fallback is not XLA-specific
- Remove fake TPU device tests per reviewer guidance; XLA accelerator tests should run on real TPU hardware

Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>
- _normalize_device_index: handle str and torch.device types in addition to int, since callers pass LOCAL_RANK strings and torch.device objects
- real_accelerator: catch RuntimeError from get_xla_supported_devices() when torch_xla is installed but no TPU/PJRT runtime is available

Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>