diff --git a/test/distributed/_shard/sharded_optim/test_sharded_optim.py b/test/distributed/_shard/sharded_optim/test_sharded_optim.py index 12ba2a2aed1b6..8fcefd7394321 100644 --- a/test/distributed/_shard/sharded_optim/test_sharded_optim.py +++ b/test/distributed/_shard/sharded_optim/test_sharded_optim.py @@ -1,19 +1,32 @@ # Owner(s): ["oncall: distributed"] +# Adapted from upstream test_sharded_optim.py — made device-agnostic for PrivateUse1 backends. +import sys from copy import deepcopy import torch +import torch.distributed as dist import torch.optim as optim from torch.distributed._shard import shard_parameter, sharded_tensor from torch.distributed._shard.sharded_optim import ShardedOptimizer from torch.distributed._shard.sharding_spec import ChunkShardingSpec -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, with_comms, ) +if torch.accelerator.current_accelerator() is None: + print("No accelerator available, skipping tests", file=sys.stderr) + sys.exit(0) + +DEVICE_TYPE = torch.accelerator.current_accelerator().type +BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE) + class MyShardedModel(torch.nn.Module): def __init__(self, spec=None, group=None): @@ -46,30 +59,15 @@ def __init__(self, rank=None): self.linear2 = torch.nn.Linear(12, 29) self.gelu = torch.nn.GELU() - if rank: - self.linear1.cuda(rank) - self.linear2.cuda(rank) + if rank is not None: + device = torch.device(DEVICE_TYPE, rank) + self.linear1.to(device) + self.linear2.to(device) def shard_parameter(self): - rowwise_sharding_spec = ChunkShardingSpec( - dim=0, - placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", - ], - ) - - colwise_sharding_spec = ChunkShardingSpec( - dim=1, - placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", - ], - ) + placements = [f"rank:{i}/{DEVICE_TYPE}:{i}" for i in range(4)] + rowwise_sharding_spec = ChunkShardingSpec(dim=0, placements=placements) + colwise_sharding_spec = ChunkShardingSpec(dim=1, placements=placements) shard_parameter(self.linear1, "weight", rowwise_sharding_spec) shard_parameter(self.linear2, "weight", colwise_sharding_spec) @@ -79,21 +77,15 @@ def forward(self, inp): class TestShardedOptimizer(ShardedTensorTestBase): - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_optim(self): - rowwise_spec = ChunkShardingSpec( - dim=0, - placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", - ], - ) - local_model = MyShardedModel().cuda() - sharded_model = MyShardedModel(spec=rowwise_spec).cuda() + placements = [f"rank:{i}/{DEVICE_TYPE}:{i}" for i in range(4)] + rowwise_spec = ChunkShardingSpec(dim=0, placements=placements) + device = torch.device(DEVICE_TYPE) + local_model = MyShardedModel().to(device) + sharded_model = MyShardedModel(spec=rowwise_spec).to(device) # copy the parameters from local model sharded_model.sharded_param.local_shards()[0].tensor = ( @@ -109,7 +101,7 @@ def test_sharded_optim(self): before_update = deepcopy(sharded_optim.named_params) - inp = torch.rand([5, 10]).cuda(self.rank).requires_grad_() + inp = torch.rand([5, 10]).to(torch.device(DEVICE_TYPE, self.rank)).requires_grad_() # run forward local_output = local_model(inp) @@ -138,27 +130,21 @@ def test_sharded_optim(self): self.assertNotEqual(val, new_val) self.assertEqual(new_val, local_model.param) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_named_params_with_sharded_tensor(self): - rowwise_spec = ChunkShardingSpec( - dim=0, - placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", - ], - ) - sharded_model = MyShardedModel(spec=rowwise_spec).cuda() + placements = [f"rank:{i}/{DEVICE_TYPE}:{i}" for i in range(4)] + rowwise_spec = ChunkShardingSpec(dim=0, placements=placements) + device = torch.device(DEVICE_TYPE) + sharded_model = MyShardedModel(spec=rowwise_spec).to(device) sharded_model_params = dict(sharded_model.named_parameters()) param_keys = list(sharded_model_params.keys()) self.assertEqual(len(param_keys), 2) self.assertTrue("param" in param_keys) self.assertTrue("sharded_param" in param_keys) - sharded_linear = MyShardedLinear(rank=self.rank).cuda() + sharded_linear = MyShardedLinear(rank=self.rank).to(device) sharded_linear.shard_parameter() sharded_linear_params = dict(sharded_linear.named_parameters()) param_keys = list(sharded_linear_params.keys()) diff --git a/test/distributed/_shard/sharded_tensor/ops/test_binary_cmp.py b/test/distributed/_shard/sharded_tensor/ops/test_binary_cmp.py index 094bc0f53d938..6b2ba6ad8cc82 100644 --- a/test/distributed/_shard/sharded_tensor/ops/test_binary_cmp.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_binary_cmp.py @@ -1,3 +1,4 @@ +# Adapted from upstream — made device-agnostic for PrivateUse1 backends. # Owner(s): ["oncall: distributed"] import sys @@ -7,7 +8,10 @@ from torch.distributed._shard import sharded_tensor from torch.distributed._shard.sharding_spec import ChunkShardingSpec from torch.distributed.distributed_c10d import _get_default_group -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, @@ -22,6 +26,13 @@ ) sys.exit(0) +if torch.accelerator.current_accelerator() is None: + print("No accelerator available, skipping tests", file=sys.stderr) + sys.exit(0) + +DEVICE_TYPE = torch.accelerator.current_accelerator().type +BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE) + class TestShardedTensorBinaryOps(ShardedTensorTestBase): """Test base for binary comparison functions such as torch.equal, torch.allclose etc. for ShardedTensor""" @@ -45,20 +56,20 @@ def get_gpu_specs(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) alt_spec = ChunkShardingSpec( dim=0, placements=[ - "rank:1/cuda:1", - "rank:0/cuda:0", - "rank:3/cuda:3", - "rank:2/cuda:2", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:3/{DEVICE_TYPE}:3", + f"rank:2/{DEVICE_TYPE}:2", ], ) return spec, alt_spec @@ -119,13 +130,13 @@ def _test_common_failures(self, cmp_op): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_torch_equal_tensor_specs(self): self._test_common_failures(torch.equal) @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_torch_equal(self): """Test torch.equal(ShardedTensor, ShardedTensor)""" @@ -135,13 +146,13 @@ def test_torch_equal(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_torch_allclose_tensor_specs(self): self._test_common_failures(torch.allclose) @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_torch_allclose(self): """Test torch.allclose(ShardedTensor, ShardedTensor)""" diff --git a/test/distributed/_shard/sharded_tensor/ops/test_embedding.py b/test/distributed/_shard/sharded_tensor/ops/test_embedding.py index 0b4cb6d1f642c..a783e0a818a8b 100644 --- a/test/distributed/_shard/sharded_tensor/ops/test_embedding.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_embedding.py @@ -1,3 +1,4 @@ +# Adapted from upstream — made device-agnostic for PrivateUse1 backends. # Owner(s): ["oncall: distributed"] import sys @@ -5,7 +6,10 @@ import torch import torch.distributed as dist from torch.distributed._shard import shard_parameter -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, @@ -26,6 +30,13 @@ ) sys.exit(0) +if torch.accelerator.current_accelerator() is None: + print("No accelerator available, skipping tests", file=sys.stderr) + sys.exit(0) + +DEVICE_TYPE = torch.accelerator.current_accelerator().type +BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE) + class TestShardedEmbedding(ShardedTensorTestBase): def _run_sharded_embedding( @@ -46,7 +57,7 @@ def _run_sharded_embedding( max_norm=max_norm, norm_type=norm_type, padding_idx=padding_idx, - ).cuda(self.rank) + ).to(torch.device(DEVICE_TYPE, self.rank)) sharded_embedding = torch.nn.Embedding( num_embeddings, @@ -64,7 +75,7 @@ def _run_sharded_embedding( # Run sharded computation torch.manual_seed(self.rank) # inputs different on each rank - inp = torch.randint(0, num_embeddings, tuple(input_size)).cuda(self.rank) + inp = torch.randint(0, num_embeddings, tuple(input_size)).to(torch.device(DEVICE_TYPE, self.rank)) sharded_output = sharded_embedding(inp) # If max_norm is set, we need to ensure that the renorm has been applied across @@ -112,9 +123,9 @@ def _run_sharded_embedding( self.assertEqual(local_output, sharded_output) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_embedding_colwise(self): for spec in generate_chunk_sharding_specs_for_test(1): self._run_sharded_embedding(spec, [5, 4], 17, 12) @@ -148,9 +159,9 @@ def test_sharded_embedding_colwise(self): ) self._run_sharded_embedding(spec, [30], 15, 14, max_norm=2.0) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_embedding_rowwise(self): for spec in generate_chunk_sharding_specs_for_test(0): # Test even split. diff --git a/test/distributed/_shard/sharded_tensor/ops/test_embedding_bag.py b/test/distributed/_shard/sharded_tensor/ops/test_embedding_bag.py index e1af5bf2b9919..79c26f7849d6e 100644 --- a/test/distributed/_shard/sharded_tensor/ops/test_embedding_bag.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_embedding_bag.py @@ -1,3 +1,4 @@ +# Adapted from upstream — made device-agnostic for PrivateUse1 backends. # Owner(s): ["oncall: distributed"] import sys @@ -5,7 +6,10 @@ import torch import torch.distributed as dist from torch.distributed._shard import shard_parameter -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, @@ -26,6 +30,13 @@ ) sys.exit(0) +if torch.accelerator.current_accelerator() is None: + print("No accelerator available, skipping tests", file=sys.stderr) + sys.exit(0) + +DEVICE_TYPE = torch.accelerator.current_accelerator().type +BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE) + class TestShardedEmbeddingBag(ShardedTensorTestBase): def _run_sharded_embedding_bag( @@ -51,7 +62,7 @@ def _run_sharded_embedding_bag( norm_type=norm_type, include_last_offset=include_last_offset, padding_idx=padding_idx, - ).cuda(self.rank) + ).to(torch.device(DEVICE_TYPE, self.rank)) sharded_embedding_bag = torch.nn.EmbeddingBag( num_embeddings, @@ -73,10 +84,10 @@ def _run_sharded_embedding_bag( # Run sharded computation torch.manual_seed(self.rank) # inputs different on each rank - inp = torch.randint(0, num_embeddings, tuple(input_size)).cuda(self.rank) + inp = torch.randint(0, num_embeddings, tuple(input_size)).to(torch.device(DEVICE_TYPE, self.rank)) per_sample_weights = None if mode == "sum": - per_sample_weights = torch.rand(*input_size).cuda(self.rank) + per_sample_weights = torch.rand(*input_size).to(torch.device(DEVICE_TYPE, self.rank)) offsets = None if len(input_size) == 1: @@ -91,7 +102,7 @@ def _run_sharded_embedding_bag( if include_last_offset: offsets[-1] = input_size[0] offsets = ( - torch.unique(offsets, sorted=True).contiguous().cuda(self.rank) + torch.unique(offsets, sorted=True).contiguous().to(torch.device(DEVICE_TYPE, self.rank)) ) # If max_norm is set, we need to ensure that the renorm has been applied across @@ -100,7 +111,7 @@ def _run_sharded_embedding_bag( gathered_inputs = [torch.zeros_like(inp) for _ in range(TEST_GPU_NUM)] dist.all_gather(gathered_inputs, inp) unique_inp = torch.unique(torch.cat(gathered_inputs)) - offsets_dummy = torch.tensor([len(unique_inp) // 2]).cuda(self.rank) + offsets_dummy = torch.tensor([len(unique_inp) // 2]).to(torch.device(DEVICE_TYPE, self.rank)) local_embedding_bag(unique_inp, offsets=offsets_dummy) sharded_output = sharded_embedding_bag( @@ -158,16 +169,16 @@ def _run_sharded_embedding_bag( self.assertEqual(local_output, sharded_output) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_embedding_bag_colwise(self): for spec in generate_chunk_sharding_specs_for_test(1): self._test_sharded_embedding_bag_with_test_cases(spec, 1) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_embedding_bag_rowwise(self): for spec in generate_chunk_sharding_specs_for_test(0): self._test_sharded_embedding_bag_with_test_cases(spec, 0) diff --git a/test/distributed/_shard/sharded_tensor/ops/test_init.py b/test/distributed/_shard/sharded_tensor/ops/test_init.py index c33136f33eefa..4cb65ef372a35 100644 --- a/test/distributed/_shard/sharded_tensor/ops/test_init.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_init.py @@ -1,11 +1,16 @@ +# Adapted from upstream — made device-agnostic for PrivateUse1 backends. # Owner(s): ["oncall: distributed"] import sys import torch +import torch.distributed as dist from torch.distributed._shard import sharded_tensor from torch.distributed._shard.sharding_spec import ChunkShardingSpec -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, @@ -20,23 +25,30 @@ ) sys.exit(0) +if torch.accelerator.current_accelerator() is None: + print("No accelerator available, skipping tests", file=sys.stderr) + sys.exit(0) + +DEVICE_TYPE = torch.accelerator.current_accelerator().type +BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE) + class TestShardedTensorNNInit(ShardedTensorTestBase): """Testing torch.nn.init functions for ShardedTensor""" @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_sharded_tensor_with_uniform(self): """Test torch.nn.init.uniform_(ShardedTensor, a, b)""" spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 8, 2 @@ -59,17 +71,17 @@ def test_init_sharded_tensor_with_uniform(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_sharded_tensor_with_normal(self): """Test torch.nn.init.normal_(ShardedTensor, mean, std)""" spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 8, 2 @@ -92,17 +104,17 @@ def test_init_sharded_tensor_with_normal(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_sharded_tensor_with_kaiming_uniform(self): """Test torch.nn.init.kaiming_uniform_(ShardedTensor, a, mode, nonlinearit)""" spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 8, 2 diff --git a/test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py b/test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py index ddf88424b2336..625aab7226d0a 100644 --- a/test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_tensor_ops.py @@ -1,11 +1,17 @@ +# Adapted from upstream — made device-agnostic for PrivateUse1 backends. # Owner(s): ["oncall: distributed"] import copy +import sys import torch +import torch.distributed as dist import torch.distributed._shard.sharded_tensor as sharded_tensor from torch.distributed._shard.sharding_spec import ChunkShardingSpec -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, @@ -14,18 +20,26 @@ ) +if torch.accelerator.current_accelerator() is None: + print("No accelerator available, skipping tests", file=sys.stderr) + sys.exit(0) + +DEVICE_TYPE = torch.accelerator.current_accelerator().type +BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE) + + class TestTensorOps(ShardedTensorTestBase): - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_deep_copy(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) st = sharded_tensor.rand(spec, (12, 5)) @@ -34,17 +48,17 @@ def test_deep_copy(self): self.assertEqual(copied_st.local_tensor(), st.local_tensor()) self.assertFalse(copied_st is st) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_inplace_copy(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) st = sharded_tensor.rand(spec, (12, 5)) @@ -61,17 +75,17 @@ def test_inplace_copy(self): st_with_grad.copy_(ones_st) self.assertEqual(st_with_grad.local_tensor(), ones_st.local_tensor()) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_clone(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) st = sharded_tensor.rand(spec, (12, 5)) @@ -80,17 +94,17 @@ def test_clone(self): self.assertEqual(copied_st.local_tensor(), st.local_tensor()) self.assertFalse(copied_st is st) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_detach(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) st = sharded_tensor.rand(spec, (12, 5), requires_grad=True) @@ -105,17 +119,17 @@ def test_detach(self): for local_shard in detached_st.local_shards(): self.assertFalse(local_shard.tensor.requires_grad) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_set_requires_grad(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) st = sharded_tensor.rand(spec, (12, 5)) diff --git a/test/distributed/_shard/sharded_tensor/test_logger.py b/test/distributed/_shard/sharded_tensor/test_logger.py index fa946819f93b2..372547ae236fd 100644 --- a/test/distributed/_shard/sharded_tensor/test_logger.py +++ b/test/distributed/_shard/sharded_tensor/test_logger.py @@ -1,3 +1,4 @@ +# Adapted from upstream — made device-agnostic for PrivateUse1 backends. # Owner(s): ["oncall: distributed"] import logging diff --git a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py index 3d9183bf63238..237c684a5556b 100644 --- a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py +++ b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py @@ -1,3 +1,4 @@ +# Adapted from upstream — made device-agnostic for PrivateUse1 backends. # Owner(s): ["oncall: distributed"] import copy @@ -41,6 +42,7 @@ ) from torch.distributed.remote_device import _remote_device from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, requires_nccl, skip_if_lt_x_gpu, spawn_threads_and_init_comms, @@ -63,6 +65,12 @@ MyShardedModel1, ) +if torch.accelerator.current_accelerator() is None: + print("No accelerator available, skipping tests", file=sys.stderr) + sys.exit(0) + +DEVICE_TYPE = torch.accelerator.current_accelerator().type +BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE) if TEST_WITH_DEV_DBG_ASAN: print( @@ -78,22 +86,22 @@ def test_serialize_and_deserialize(self): ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[0, 5], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:2/cuda:2", + placement=f"rank:2/{DEVICE_TYPE}:2", ), ShardMetadata( shard_offsets=[5, 5], shard_sizes=[5, 5], - placement="rank:3/cuda:3", + placement=f"rank:3/{DEVICE_TYPE}:3", ), ] @@ -167,21 +175,21 @@ def test_empty(self): class TestShardParameter(ShardedTensorTestBase): - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_shard_parameter(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) - fc = torch.nn.Linear(12, 12).cuda(self.rank) + fc = torch.nn.Linear(12, 12).to(torch.device(DEVICE_TYPE, self.rank)) weight_og = fc.weight.clone() shard_parameter(fc, "weight", spec) @@ -196,21 +204,21 @@ def test_shard_parameter(self): torch.narrow(weight_og, 0, 3 * self.rank, 3), local_shards[0].tensor ) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_shard_parameter_errors(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) - fc = torch.nn.Linear(12, 12).cuda(self.rank) + fc = torch.nn.Linear(12, 12).to(torch.device(DEVICE_TYPE, self.rank)) with self.assertRaisesRegex(ValueError, "does not match with src_rank"): shard_parameter(fc, "weight", spec, src_rank=self.rank) @@ -225,16 +233,16 @@ def test_shard_parameter_errors(self): shard_parameter(fc, "bias", spec) with self.assertRaisesRegex(ValueError, "not a contiguous Tensor"): - fc.bias = torch.rand(10, 10).cuda(self.rank).t() + fc.bias = torch.rand(10, 10).to(torch.device(DEVICE_TYPE, self.rank)).t() shard_parameter(fc, "bias", spec) spec = ChunkShardingSpec( dim=0, placements=[ - f"rank:{self.rank}/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:{self.rank}/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) with self.assertRaisesRegex(ValueError, "does not match with sharding_spec"): @@ -245,12 +253,12 @@ def test_shard_parameter_errors(self): ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ] ) @@ -259,20 +267,20 @@ def test_shard_parameter_errors(self): class TestShardTensor(ShardedTensorTestBase): - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_shard_tensor(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) - tensor = torch.rand(12, 12).cuda(self.rank) + tensor = torch.rand(12, 12).to(torch.device(DEVICE_TYPE, self.rank)) st = _shard_tensor(tensor, spec) # Verify. @@ -282,20 +290,20 @@ def test_shard_tensor(self): self.assertEqual(torch.Size([3, 12]), local_shard.size()) self.assertEqual(torch.narrow(tensor, 0, 3 * self.rank, 3), local_shard) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_shard_tensor_with_empty_shard(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) - tensor = torch.rand(9, 12).cuda(self.rank) + tensor = torch.rand(9, 12).to(torch.device(DEVICE_TYPE, self.rank)) st = _shard_tensor(tensor, spec) # Verify. @@ -313,35 +321,35 @@ def test_shard_tensor_with_empty_shard(self): else: self.assertEqual(torch.Size([0, 12]), local_shard.size()) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_shard_tensor_errors(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) - tensor = torch.rand(12, 12).cuda(self.rank) + tensor = torch.rand(12, 12).to(torch.device(DEVICE_TYPE, self.rank)) with self.assertRaisesRegex(ValueError, "does not match with src_rank"): _shard_tensor(tensor, spec, src_rank=self.rank) with self.assertRaisesRegex(ValueError, "not a contiguous Tensor"): - tensor_t = torch.rand(12, 12).cuda(self.rank).t() + tensor_t = torch.rand(12, 12).to(torch.device(DEVICE_TYPE, self.rank)).t() _shard_tensor(tensor_t, spec) spec = ChunkShardingSpec( dim=0, placements=[ - f"rank:{self.rank}/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:{self.rank}/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) with self.assertRaisesRegex(ValueError, "does not match with sharding_spec"): @@ -352,12 +360,12 @@ def test_shard_tensor_errors(self): ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ] ) @@ -374,9 +382,9 @@ def __init__(self, spec, tensor_size): def forward(self): return self.st - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_reshard_output(self): specs = _chunk_sharding_specs_list_for_test([0, 1], seed=5) spec, reshard_spec = specs[0], specs[1] @@ -399,9 +407,9 @@ def test_reshard_output(self): self.assertEqual(local_shard.size(0), 24) self.assertEqual(local_shard.size(1), 3) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_collect_local_shard(self): specs = _chunk_sharding_specs_list_for_test([0], seed=5) spec = specs[0] @@ -415,17 +423,17 @@ def test_collect_local_shard(self): class TestLocalTensor(ShardedTensorTestBase): - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_local_tensor(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) st = sharded_tensor.rand(spec, 24, 12) @@ -433,23 +441,23 @@ def test_local_tensor(self): self.assertEqual(torch.Size([6, 12]), local_shard.size()) self.assertEqual(st.local_tensor(), local_shard) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_local_tensor_error(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:1/cuda:1", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:2/cuda:2", - "rank:2/cuda:2", - "rank:3/cuda:3", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", + f"rank:3/{DEVICE_TYPE}:3", ], ) st = sharded_tensor.rand(spec, 24, 12) @@ -462,15 +470,15 @@ def test_local_tensor_error(self): class TestShardedTensorChunked(ShardedTensorTestBase): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_tensor_metadata(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) @@ -512,16 +520,16 @@ def test_sharded_tensor_metadata(self): @skipIfRocm @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_complete_world_size(self): for dim in [0, -2]: spec = ChunkShardingSpec( dim=dim, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) @@ -530,7 +538,7 @@ def test_complete_world_size(self): local_shards = st.local_shards() self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.device) if self.rank == 3: self.assertEqual((1, 20), local_shard.size()) else: @@ -548,7 +556,7 @@ def test_complete_world_size(self): else: self.assertEqual([3, 20], shard_metadata.shard_sizes) self.assertEqual( - f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement) + f"rank:{rank}/{DEVICE_TYPE}:{rank}", str(shard_metadata.placement) ) # Validate remote shards. @@ -561,7 +569,7 @@ def test_complete_world_size(self): self.assertEqual(rpc_rank, remote_shard.owner().id) shard = remote_shard.to_here() self.assertEqual( - f"rank:{rpc_rank}/cuda:{rpc_rank}", + f"rank:{rpc_rank}/{DEVICE_TYPE}:{rpc_rank}", str(shard.metadata.placement), ) if rpc_rank == 3: @@ -571,17 +579,17 @@ def test_complete_world_size(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_create_sharded_tensor_with_ones(self): """Test sharded_tensor.ones(...)""" spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 10, 20 @@ -591,7 +599,7 @@ def test_create_sharded_tensor_with_ones(self): local_shards = st.local_shards() self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.device) # The split: for rank!=3 ceil(h/4)=3 for rank=3 1 expected_h = 1 if self.rank == 3 else math.ceil(h / 4) self.assertEqual((expected_h, w), local_shard.size()) @@ -599,17 +607,17 @@ def test_create_sharded_tensor_with_ones(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_gather_even(self) -> None: """Test _sharded_tensor.gather(...) with evenly distributed._shards""" spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 10, 20 @@ -621,7 +629,7 @@ def test_gather_even(self) -> None: full_tensor = torch.zeros( h, w, - device=torch.device(f"cuda:{dst}"), + device=torch.device(DEVICE_TYPE, dst), ) st.gather(dst, full_tensor) @@ -632,18 +640,18 @@ def test_gather_even(self) -> None: @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_gather_uneven(self) -> None: """Test _sharded_tensor.gather(...) with unevenly distributed._shards""" spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:1/cuda:1", - "rank:2/cuda:2", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", ], ) h, w = 10, 20 @@ -655,7 +663,7 @@ def test_gather_uneven(self) -> None: full_tensor = torch.zeros( h, w, - device=torch.device(f"cuda:{dst}"), + device=torch.device(DEVICE_TYPE, dst), ) st.gather(dst, full_tensor) @@ -666,17 +674,17 @@ def test_gather_uneven(self) -> None: @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_create_sharded_tensor_with_zeros(self): """Test sharded_tensor.zeros(...)""" spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 10, 20 @@ -686,7 +694,7 @@ def test_create_sharded_tensor_with_zeros(self): local_shards = st.local_shards() self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.device) # The split: for rank!=3 ceil(h/4)=3 for rank=3 1 expected_h = 1 if self.rank == 3 else math.ceil(h / 4) self.assertEqual((expected_h, w), local_shard.size()) @@ -694,24 +702,24 @@ def test_create_sharded_tensor_with_zeros(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_create_sharded_tensor_with_rand(self): """Test sharded_tensor.rand(...)/randn(...)""" spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 8, 2 seed = 1234 expected_h = 2 - expected_device = torch.device(f"cuda:{self.rank}") + expected_device = torch.device(DEVICE_TYPE, self.rank) dtype = torch.double torch.manual_seed(seed) # Test sharded_tensor.rand creation @@ -745,17 +753,17 @@ def test_create_sharded_tensor_with_rand(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_create_sharded_tensor_with_full(self): """Test sharded_tensor.full(...)""" spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 10, 20 @@ -768,7 +776,7 @@ def test_create_sharded_tensor_with_full(self): local_shards = st.local_shards() self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.device) # The split: for rank!=3 ceil(h/4)=3 for rank=3 1 expected_h = 1 if self.rank == 3 else math.ceil(h / 4) self.assertEqual((expected_h, w), local_shard.size()) @@ -779,24 +787,24 @@ def test_create_sharded_tensor_with_full(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_create_sharded_tensor_like(self): """Test tensor like methods, i.e. torch.zeros_like(...), torch.full_like, etc.""" spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 8, 8 expected_h = 2 seed = 1234 dtype = torch.double - expected_device = torch.device(f"cuda:{self.rank}") + expected_device = torch.device(DEVICE_TYPE, self.rank) st = sharded_tensor.rand(spec, (h, w), dtype=dtype) tensor_like_ops = { torch.zeros_like: torch.zeros, @@ -833,13 +841,13 @@ def test_create_sharded_tensor_like(self): @skipIfRocm @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_partial_world_size(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) @@ -849,7 +857,7 @@ def test_partial_world_size(self): if self.rank >= 2: self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.device) self.assertEqual((5, 20), local_shard.size()) else: self.assertEqual(0, len(local_shards)) @@ -863,7 +871,7 @@ def test_partial_world_size(self): self.assertEqual([shard_rank * 5, 0], shard_metadata.shard_offsets) self.assertEqual([5, 20], shard_metadata.shard_sizes) self.assertEqual( - f"rank:{shard_rank + 2}/cuda:{shard_rank + 2}", + f"rank:{shard_rank + 2}/{DEVICE_TYPE}:{shard_rank + 2}", str(shard_metadata.placement), ) @@ -880,20 +888,20 @@ def test_partial_world_size(self): self.assertEqual(rpc_rank, remote_shard.owner().id) shard = remote_shard.to_here() self.assertEqual( - f"rank:{rpc_rank}/cuda:{rpc_rank}", str(shard.metadata.placement) + f"rank:{rpc_rank}/{DEVICE_TYPE}:{rpc_rank}", str(shard.metadata.placement) ) self.assertEqual((5, 20), shard.tensor.size()) @skipIfRocm @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_new_group(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) @@ -905,7 +913,7 @@ def test_new_group(self): if self.rank >= 2: self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.device) self.assertEqual((5, 20), local_shard.size()) else: self.assertEqual(0, len(local_shards)) @@ -919,7 +927,7 @@ def test_new_group(self): self.assertEqual([shard_rank * 5, 0], shard_metadata.shard_offsets) self.assertEqual([5, 20], shard_metadata.shard_sizes) self.assertEqual( - f"rank:{shard_rank + 2}/cuda:{shard_rank + 2}", + f"rank:{shard_rank + 2}/{DEVICE_TYPE}:{shard_rank + 2}", str(shard_metadata.placement), ) @@ -936,26 +944,26 @@ def test_new_group(self): shard = remote_shard.to_here() self.assertEqual(rpc_rank, remote_shard.owner().id) self.assertEqual( - f"rank:{rpc_rank}/cuda:{rpc_rank}", str(shard.metadata.placement) + f"rank:{rpc_rank}/{DEVICE_TYPE}:{rpc_rank}", str(shard.metadata.placement) ) self.assertEqual((5, 20), shard.tensor.size()) @skipIfRocm @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_multiple_local_shards(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) st = sharded_tensor.empty(spec, 16, 20, init_rrefs=True) @@ -965,7 +973,7 @@ def test_multiple_local_shards(self): self.assertEqual(2, len(local_shards)) for local_shard in local_shards: self.assertEqual( - torch.device(f"cuda:{self.rank}"), local_shard.tensor.device + torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device ) self.assertEqual((2, 20), local_shard.tensor.size()) @@ -978,7 +986,7 @@ def test_multiple_local_shards(self): self.assertEqual([shard_idx * 2, 0], shard_metadata.shard_offsets) self.assertEqual([2, 20], shard_metadata.shard_sizes) self.assertEqual( - f"rank:{shard_idx % 4}/cuda:{shard_idx % 4}", + f"rank:{shard_idx % 4}/{DEVICE_TYPE}:{shard_idx % 4}", str(shard_metadata.placement), ) @@ -993,7 +1001,7 @@ def test_multiple_local_shards(self): self.assertEqual(rpc_rank, remote_shard.owner().id) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharding_columns(self): self.init_pg() @@ -1001,10 +1009,10 @@ def test_sharding_columns(self): spec = ChunkShardingSpec( dim=dim, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) @@ -1014,7 +1022,7 @@ def test_sharding_columns(self): local_shards = st.local_shards() self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.device) self.assertEqual((10, 8), local_shard.size()) # Validate global metadata. @@ -1026,32 +1034,32 @@ def test_sharding_columns(self): self.assertEqual([0, rank * 8], shard_metadata.shard_offsets) self.assertEqual([10, 8], shard_metadata.shard_sizes) self.assertEqual( - f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement) + f"rank:{rank}/{DEVICE_TYPE}:{rank}", str(shard_metadata.placement) ) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_invalid_sharding(self): self.init_pg() with self.assertRaisesRegex( NotImplementedError, "does not support named dimension" ): - spec = ChunkShardingSpec(dim="H", placements=["rank:1/cuda:1"]) + spec = ChunkShardingSpec(dim="H", placements=[f"rank:1/{DEVICE_TYPE}:1"]) sharded_tensor.empty(spec, 10, 20) for dim in [2, 3, 4, -3, -4, -5]: - spec = ChunkShardingSpec(dim=dim, placements=["rank:1/cuda:1"]) + spec = ChunkShardingSpec(dim=dim, placements=[f"rank:1/{DEVICE_TYPE}:1"]) with self.assertRaisesRegex(ValueError, "Invalid sharding dim"): sharded_tensor.empty(spec, 10, 20) - spec = ChunkShardingSpec(dim=0, placements=["rank:5/cuda:1"]) + spec = ChunkShardingSpec(dim=0, placements=[f"rank:5/{DEVICE_TYPE}:1"]) with self.assertRaisesRegex( ValueError, "Global rank 5 does not exist in input process group" ): sharded_tensor.empty(spec, 10, 20) - spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) + spec = ChunkShardingSpec(dim=0, placements=[f"rank:0/{DEVICE_TYPE}:1"]) st = sharded_tensor.empty(spec, 10, 20) tensor = torch.empty(10, 20) with self.assertRaisesRegex( @@ -1059,26 +1067,26 @@ def test_invalid_sharding(self): ): torch.add(st, tensor) - spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) + spec = ChunkShardingSpec(dim=0, placements=[f"rank:0/{DEVICE_TYPE}:1"]) with self.assertRaisesRegex( ValueError, "Only torch.strided layout is currently supported" ): sharded_tensor.empty(spec, 10, 20, layout=torch.sparse_coo) - spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) + spec = ChunkShardingSpec(dim=0, placements=[f"rank:0/{DEVICE_TYPE}:1"]) with self.assertRaisesRegex( ValueError, "Only torch.contiguous_format memory_format is currently supported", ): sharded_tensor.empty(spec, 10, 20, memory_format=torch.channels_last) - spec = ChunkShardingSpec(dim=0, placements=["worker0/cuda:1"]) + spec = ChunkShardingSpec(dim=0, placements=[f"worker0/{DEVICE_TYPE}:1"]) with self.assertRaisesRegex( RuntimeError, "RPC framework needs to be initialized" ): sharded_tensor.empty(spec, 10, 20) - spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) + spec = ChunkShardingSpec(dim=0, placements=[f"rank:0/{DEVICE_TYPE}:1"]) with self.assertRaisesRegex( RuntimeError, "RPC Framework needs to be initialized" ): @@ -1091,12 +1099,12 @@ def test_invalid_sharding(self): st.remote_shards() self.init_rpc() - spec = ChunkShardingSpec(dim=0, placements=["workerfoo/cuda:1"]) + spec = ChunkShardingSpec(dim=0, placements=[f"workerfoo/{DEVICE_TYPE}:1"]) with self.assertRaisesRegex(ValueError, "Invalid worker name"): sharded_tensor.empty(spec, 10, 20, init_rrefs=True) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_invalid_pg_rpc_ranks(self): self.init_pg() @@ -1113,24 +1121,24 @@ def test_invalid_pg_rpc_ranks(self): rpc_backend_options=rpc_backend_options, ) - spec = ChunkShardingSpec(dim=0, placements=["rank:1/cuda:1"]) + spec = ChunkShardingSpec(dim=0, placements=[f"rank:1/{DEVICE_TYPE}:1"]) with self.assertRaisesRegex( ValueError, "Default ProcessGroup and RPC ranks must be the same" ): sharded_tensor.empty(spec, 10, 20, init_rrefs=True) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_insufficient_sharding_dims(self): self.init_pg() spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) st = sharded_tensor.empty(spec, 2, 20) @@ -1140,12 +1148,12 @@ def test_insufficient_sharding_dims(self): if self.rank <= 1: self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.device) self.assertEqual((1, 20), local_shard.size()) else: self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.device) self.assertEqual(local_shard.numel(), 0) # Validate global metadata. @@ -1156,7 +1164,7 @@ def test_insufficient_sharding_dims(self): for shard_rank, shard_metadata in enumerate(shards_metadata): self.assertEqual([shard_rank, 0], shard_metadata.shard_offsets) self.assertEqual( - f"rank:{shard_rank}/cuda:{shard_rank}", str(shard_metadata.placement) + f"rank:{shard_rank}/{DEVICE_TYPE}:{shard_rank}", str(shard_metadata.placement) ) if shard_rank <= 1: self.assertEqual([1, 20], shard_metadata.shard_sizes) @@ -1165,15 +1173,15 @@ def test_insufficient_sharding_dims(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_tensor_sizes(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) @@ -1220,15 +1228,15 @@ def test_sharded_tensor_sizes(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_state_dict(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) @@ -1266,15 +1274,15 @@ def test_state_dict(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_state_dict_new_group(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:2/cuda:0", - "rank:3/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:2/{DEVICE_TYPE}:0", + f"rank:3/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) @@ -1307,7 +1315,7 @@ def test_state_dict_new_group(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_state_dict_no_sharded_tensors(self): # Verify hooks don't affect modules with no ShardedTensors. m = torch.nn.Linear(10, 10) @@ -1332,12 +1340,12 @@ def test_state_dict_no_sharded_tensors(self): self.assertEqual(m.bias, module_load.bias) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_load_state_dict_errors(self): self.init_rpc() dist.init_process_group( - backend="nccl", + backend=BACKEND, world_size=self.world_size, rank=self.rank, init_method=f"file://{self.file_name}", @@ -1346,10 +1354,10 @@ def test_load_state_dict_errors(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) @@ -1387,16 +1395,16 @@ def test_load_state_dict_errors(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_cleanup(self): def create_tensors(): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) sharded_tensor.empty(spec, 10, 20, init_rrefs=True) @@ -1409,29 +1417,29 @@ def create_tensors(): class TestShardedTensorEnumerable(ShardedTensorTestBase): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_tensor_metadata(self): spec = EnumerableShardingSpec( [ ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[0, 5], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:2/cuda:2", + placement=f"rank:2/{DEVICE_TYPE}:2", ), ShardMetadata( shard_offsets=[5, 5], shard_sizes=[5, 5], - placement="rank:3/cuda:3", + placement=f"rank:3/{DEVICE_TYPE}:3", ), ] ) @@ -1483,29 +1491,29 @@ def test_sharded_tensor_metadata(self): @skipIfRocm @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_grid_sharding(self): spec = EnumerableShardingSpec( [ ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[0, 5], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:2/cuda:2", + placement=f"rank:2/{DEVICE_TYPE}:2", ), ShardMetadata( shard_offsets=[5, 5], shard_sizes=[5, 5], - placement="rank:3/cuda:3", + placement=f"rank:3/{DEVICE_TYPE}:3", ), ] ) @@ -1516,7 +1524,7 @@ def test_grid_sharding(self): # Verify local shard. local_shard = st.local_shards()[0] - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) # Verify local shard metadata. @@ -1526,7 +1534,7 @@ def test_grid_sharding(self): ) self.assertEqual((5, 5), local_shard.metadata.shard_sizes) self.assertEqual( - f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement) + f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", str(local_shard.metadata.placement) ) # Verify global metadata. @@ -1538,7 +1546,7 @@ def test_grid_sharding(self): (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets ) self.assertEqual((5, 5), shard_metadata.shard_sizes) - self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) + self.assertEqual(f"rank:{rank}/{DEVICE_TYPE}:{rank}", str(shard_metadata.placement)) # Validate remote shards. remote_shards = st.remote_shards() @@ -1553,7 +1561,7 @@ def test_grid_sharding(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_create_sharded_tensor_with_ones(self): """Test sharded_tensor.ones(...)""" @@ -1562,22 +1570,22 @@ def test_create_sharded_tensor_with_ones(self): ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[0, 5], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:2/cuda:2", + placement=f"rank:2/{DEVICE_TYPE}:2", ), ShardMetadata( shard_offsets=[5, 5], shard_sizes=[5, 5], - placement="rank:3/cuda:3", + placement=f"rank:3/{DEVICE_TYPE}:3", ), ] ) @@ -1588,13 +1596,13 @@ def test_create_sharded_tensor_with_ones(self): # Verify local shard is initialized with torch.ones local_shard = st.local_shards()[0] - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) self.assertEqual(local_shard.tensor, torch.ones(5, 5)) @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_gather_even(self) -> None: """Test _sharded_tensor.gather(...) with evenly distributed._shards""" @@ -1603,22 +1611,22 @@ def test_gather_even(self) -> None: ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[0, 5], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:2/cuda:2", + placement=f"rank:2/{DEVICE_TYPE}:2", ), ShardMetadata( shard_offsets=[5, 5], shard_sizes=[5, 5], - placement="rank:3/cuda:3", + placement=f"rank:3/{DEVICE_TYPE}:3", ), ] ) @@ -1629,7 +1637,7 @@ def test_gather_even(self) -> None: full_tensor = None dst = 0 if self.rank == dst: - full_tensor = torch.zeros(h, w, device=torch.device(f"cuda:{dst}")) + full_tensor = torch.zeros(h, w, device=torch.device(DEVICE_TYPE, dst)) st.gather(dst, full_tensor) if self.rank == dst: @@ -1639,7 +1647,7 @@ def test_gather_even(self) -> None: @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_gather_uneven(self) -> None: """Test _sharded_tensor.gather(...) with unevenly distributed._shards""" @@ -1648,22 +1656,22 @@ def test_gather_uneven(self) -> None: ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[0, 5], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[5, 5], shard_sizes=[5, 5], - placement="rank:3/cuda:3", + placement=f"rank:3/{DEVICE_TYPE}:3", ), ] ) @@ -1674,7 +1682,7 @@ def test_gather_uneven(self) -> None: full_tensor = None dst = 0 if self.rank == dst: - full_tensor = torch.zeros(h, w, device=torch.device(f"cuda:{dst}")) + full_tensor = torch.zeros(h, w, device=torch.device(DEVICE_TYPE, dst)) st.gather(dst, full_tensor) if self.rank == dst: @@ -1684,7 +1692,7 @@ def test_gather_uneven(self) -> None: @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_tensor_to_cpu(self): cpu_spec = ChunkShardingSpec( dim=0, @@ -1698,10 +1706,10 @@ def test_sharded_tensor_to_cpu(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 10, 20 @@ -1744,8 +1752,8 @@ def test_sharded_tensor_to_cpu(self): placements=[ "rank:0/cpu", "rank:1/cpu", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) @@ -1771,7 +1779,7 @@ def test_sharded_tensor_to_cpu(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_tensor_to_cuda(self): cpu_spec = ChunkShardingSpec( dim=0, @@ -1785,19 +1793,19 @@ def test_sharded_tensor_to_cuda(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 10, 20 - # CUDA sharded tensor should return a new ShardedTensor, but same + # Accelerator sharded tensor should return a new ShardedTensor, but same # local shards(no movements) - st_cuda = sharded_tensor.zeros(spec, h, w) - new_st_cuda = st_cuda.cuda() - self.assertTrue(st_cuda is not new_st_cuda) - self.assertTrue(st_cuda.local_tensor() is new_st_cuda.local_tensor()) + st_dev = sharded_tensor.zeros(spec, h, w) + new_st_dev = st_dev.cuda() + self.assertTrue(st_dev is not new_st_dev) + self.assertTrue(st_dev.local_tensor() is new_st_dev.local_tensor()) gloo_pg = dist.new_group(backend="gloo") @@ -1818,32 +1826,32 @@ def test_sharded_tensor_to_cuda(self): remote_device_before = spec_before_move.placements[i] self.assertEqual(remote_device_before.rank(), remote_device_after.rank()) self.assertEqual(str(remote_device_before.device().type), "cpu") - self.assertEqual(str(remote_device_after.device().type), "cuda") + self.assertEqual(str(remote_device_after.device().type), DEVICE_TYPE) # ensure metadata also get changed to GPU metas = new_st_gpu.metadata().shards_metadata for meta in metas: - self.assertEqual(str(meta.placement.device().type), "cuda") + self.assertEqual(str(meta.placement.device().type), DEVICE_TYPE) @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_tensor_to_test(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 10, 20 - # CUDA sharded tensor should return a new ShardedTensor, but same + # Accelerator sharded tensor should return a new ShardedTensor, but same # local shards(no movements) st = sharded_tensor.zeros(spec, h, w) # test same dtype, device return itself - st_self = st.to(dtype=st.dtype, device="cuda") + st_self = st.to(dtype=st.dtype, device=DEVICE_TYPE) self.assertTrue(st_self is st) # test dtype to @@ -1854,42 +1862,42 @@ def test_sharded_tensor_to_test(self): st_cpu = st.to(device=torch.device("cpu")) self.assertFalse(st_cpu is st) self.assertEqual(st_cpu.local_tensor().device.type, "cpu") - st_cuda = st_cpu.to(device=torch.device("cuda")) - self.assertEqual(st_cuda.local_tensor().device.type, "cuda") + st_dev = st_cpu.to(device=torch.device(DEVICE_TYPE)) + self.assertEqual(st_dev.local_tensor().device.type, DEVICE_TYPE) # non-kwarg device to - st_cuda = st_cpu.to(torch.device("cuda")) - self.assertEqual(st_cuda.local_tensor().device.type, "cuda") - st_cpu = st_cuda.to(torch.device("cpu")) + st_dev = st_cpu.to(torch.device(DEVICE_TYPE)) + self.assertEqual(st_dev.local_tensor().device.type, DEVICE_TYPE) + st_cpu = st_dev.to(torch.device("cpu")) self.assertEqual(st_cpu.local_tensor().device.type, "cpu") # with string like device conversion - st_cpu = st_cuda.to("cpu") + st_cpu = st_dev.to("cpu") self.assertEqual(st_cpu.local_tensor().device.type, "cpu") - st_cuda = st_cpu.to("cuda") - self.assertEqual(st_cuda.local_tensor().device.type, "cuda") + st_dev = st_cpu.to(DEVICE_TYPE) + self.assertEqual(st_dev.local_tensor().device.type, DEVICE_TYPE) # with int like device conversion - st_cpu = st_cuda.to("cpu") + st_cpu = st_dev.to("cpu") self.assertEqual(st_cpu.local_tensor().device.type, "cpu") - st_cuda = st_cpu.to(self.rank) - self.assertEqual(st_cuda.local_tensor().device.type, "cuda") + st_dev = st_cpu.to(self.rank) + self.assertEqual(st_dev.local_tensor().device.type, DEVICE_TYPE) # test tensor to - cuda_tensor = torch.randn(3, 4, dtype=torch.float16, device="cuda") - st_cuda = st.to(cuda_tensor) - self.assertFalse(st_cuda is st) - self.assertEqual(st_cuda.dtype, torch.float16) + dev_tensor = torch.randn(3, 4, dtype=torch.float16, device=DEVICE_TYPE) + st_dev = st.to(dev_tensor) + self.assertFalse(st_dev is st) + self.assertEqual(st_dev.dtype, torch.float16) - cuda_tensor = torch.randn(3, 4, dtype=torch.float16, device="cuda:2") - st_cuda = st.to(cuda_tensor) - self.assertEqual(st_cuda.dtype, torch.float16) + dev_tensor = torch.randn(3, 4, dtype=torch.float16, device=f"{DEVICE_TYPE}:2") + st_dev = st.to(dev_tensor) + self.assertEqual(st_dev.dtype, torch.float16) # test dtype and device together st_cpu_16 = st.to("cpu", torch.float16) self.assertEqual(st_cpu_16.dtype, torch.float16) self.assertEqual(st_cpu_16.local_tensor().device.type, "cpu") - st_cuda_32 = st_cpu_16.to("cuda", torch.float32) - self.assertEqual(st_cuda_32.dtype, torch.float32) - self.assertEqual(st_cuda_32.local_tensor().device.type, "cuda") + st_dev_32 = st_cpu_16.to(DEVICE_TYPE, torch.float32) + self.assertEqual(st_dev_32.dtype, torch.float32) + self.assertEqual(st_dev_32.local_tensor().device.type, DEVICE_TYPE) # test pass additional process group gloo_pg = dist.new_group(backend="gloo") @@ -1900,22 +1908,22 @@ def test_sharded_tensor_to_test(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_tensor_device(self): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) h, w = 10, 20 - # CUDA sharded tensor should return a new ShardedTensor, but same + # Accelerator sharded tensor should return a new ShardedTensor, but same # local shards(no movements) st = sharded_tensor.zeros(spec, h, w) - current_device = torch.device(torch.cuda.current_device()) + current_device = torch.device(torch.accelerator.current_accelerator().type, torch.accelerator.current_device_idx()) self.assertEqual(current_device, st.device) # test after to cpu, device get changed @@ -1924,7 +1932,7 @@ def test_sharded_tensor_device(self): self.assertEqual(st_cpu.device, cpu_device) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_uneven_shards(self): self.init_pg() @@ -1933,22 +1941,22 @@ def test_uneven_shards(self): ShardMetadata( shard_offsets=[0, 0], shard_sizes=[2, 4], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[0, 4], shard_sizes=[4, 2], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ShardMetadata( shard_offsets=[2, 0], shard_sizes=[4, 4], - placement="rank:2/cuda:2", + placement=f"rank:2/{DEVICE_TYPE}:2", ), ShardMetadata( shard_offsets=[4, 4], shard_sizes=[2, 2], - placement="rank:3/cuda:3", + placement=f"rank:3/{DEVICE_TYPE}:3", ), ] ) @@ -1979,14 +1987,14 @@ def verify_offsets(rank, offsets): # Verify local shard. local_shard = st.local_shards()[0] - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device) verify_size(self.rank, local_shard.tensor.size()) # Verify local shard metadata. verify_offsets(self.rank, local_shard.metadata.shard_offsets) verify_size(self.rank, local_shard.metadata.shard_sizes) self.assertEqual( - f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement) + f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", str(local_shard.metadata.placement) ) # Verify global metadata. @@ -1996,24 +2004,25 @@ def verify_offsets(rank, offsets): for rank, shard_metadata in enumerate(shards_metadata): verify_offsets(rank, shard_metadata.shard_offsets) verify_size(rank, shard_metadata.shard_sizes) - self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) + self.assertEqual(f"rank:{rank}/{DEVICE_TYPE}:{rank}", str(shard_metadata.placement)) + @skipIfRocm @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_partial_world_size(self): spec = EnumerableShardingSpec( [ ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ] ) @@ -2029,7 +2038,7 @@ def test_partial_world_size(self): # Verify local shard. local_shard = st.local_shards()[0] self.assertEqual( - torch.device(f"cuda:{self.rank}"), local_shard.tensor.device + torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device ) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -2037,7 +2046,7 @@ def test_partial_world_size(self): self.assertEqual((self.rank * 5, 0), local_shard.metadata.shard_offsets) self.assertEqual((5, 5), local_shard.metadata.shard_sizes) self.assertEqual( - f"rank:{self.rank}/cuda:{self.rank}", + f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", str(local_shard.metadata.placement), ) @@ -2048,7 +2057,7 @@ def test_partial_world_size(self): for rank, shard_metadata in enumerate(shards_metadata): self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets) self.assertEqual((5, 5), shard_metadata.shard_sizes) - self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) + self.assertEqual(f"rank:{rank}/{DEVICE_TYPE}:{rank}", str(shard_metadata.placement)) # Validate remote shards. remote_shards = st.remote_shards() @@ -2068,19 +2077,19 @@ def test_partial_world_size(self): @skipIfRocm @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_new_group(self): spec = EnumerableShardingSpec( [ ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:3/cuda:3", + placement=f"rank:3/{DEVICE_TYPE}:3", ), ] ) @@ -2093,7 +2102,7 @@ def test_new_group(self): # Verify local shard. local_shard = st.local_shards()[0] self.assertEqual( - torch.device(f"cuda:{self.rank}"), local_shard.tensor.device + torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device ) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -2103,7 +2112,7 @@ def test_new_group(self): ) self.assertEqual((5, 5), local_shard.metadata.shard_sizes) self.assertEqual( - f"rank:{self.rank}/cuda:{self.rank}", + f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", str(local_shard.metadata.placement), ) @@ -2115,7 +2124,7 @@ def test_new_group(self): self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets) self.assertEqual((5, 5), shard_metadata.shard_sizes) self.assertEqual( - f"rank:{rank * 2 + 1}/cuda:{rank * 2 + 1}", + f"rank:{rank * 2 + 1}/{DEVICE_TYPE}:{rank * 2 + 1}", str(shard_metadata.placement), ) @@ -2137,29 +2146,29 @@ def test_new_group(self): @skipIfRocm @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_multiple_local_shards(self): spec = EnumerableShardingSpec( [ ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[0, 5], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[5, 5], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ] ) @@ -2173,7 +2182,7 @@ def test_multiple_local_shards(self): # Verify local shards. for idx, local_shard in enumerate(st.local_shards()): self.assertEqual( - torch.device(f"cuda:{self.rank}"), local_shard.tensor.device + torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device ) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -2183,7 +2192,7 @@ def test_multiple_local_shards(self): ) self.assertEqual((5, 5), local_shard.metadata.shard_sizes) self.assertEqual( - f"rank:{self.rank}/cuda:{self.rank}", + f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", str(local_shard.metadata.placement), ) else: @@ -2200,7 +2209,7 @@ def test_multiple_local_shards(self): ) self.assertEqual((5, 5), shard_metadata.shard_sizes) self.assertEqual( - f"rank:{shard_rank % 2}/cuda:{shard_rank % 2}", + f"rank:{shard_rank % 2}/{DEVICE_TYPE}:{shard_rank % 2}", str(shard_metadata.placement), ) @@ -2221,29 +2230,29 @@ def test_multiple_local_shards(self): @skipIfRocm @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_with_rpc_names(self): spec = EnumerableShardingSpec( [ ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="worker0/cuda:0", + placement=f"worker0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[0, 5], shard_sizes=[5, 5], - placement="worker1/cuda:1", + placement=f"worker1/{DEVICE_TYPE}:1", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="worker2/cuda:2", + placement=f"worker2/{DEVICE_TYPE}:2", ), ShardMetadata( shard_offsets=[5, 5], shard_sizes=[5, 5], - placement="worker3/cuda:3", + placement=f"worker3/{DEVICE_TYPE}:3", ), ] ) @@ -2254,7 +2263,7 @@ def test_with_rpc_names(self): # Verify local shard. local_shard = st.local_shards()[0] - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) # Verify local shard metadata. @@ -2264,7 +2273,7 @@ def test_with_rpc_names(self): ) self.assertEqual((5, 5), local_shard.metadata.shard_sizes) self.assertEqual( - f"worker{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement) + f"worker{self.rank}/{DEVICE_TYPE}:{self.rank}", str(local_shard.metadata.placement) ) # Verify global metadata. @@ -2276,7 +2285,7 @@ def test_with_rpc_names(self): (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets ) self.assertEqual((5, 5), shard_metadata.shard_sizes) - self.assertEqual(f"worker{rank}/cuda:{rank}", str(shard_metadata.placement)) + self.assertEqual(f"worker{rank}/{DEVICE_TYPE}:{rank}", str(shard_metadata.placement)) # Validate remote shards. remote_shards = st.remote_shards() @@ -2304,7 +2313,7 @@ def _generate_st_from_chunk_local_tensor(self, st_size, sharding_spec): ) rank_to_metadata[rank] = shard_metadata if rank == self.rank: - local_tensor = torch.rand(shard_metadata.shard_sizes).cuda(device) + local_tensor = torch.rand(shard_metadata.shard_sizes).to(torch.device(DEVICE_TYPE, device)) local_shard_metadata = shard_metadata # TODO: figure out what the API should behave when some rank have no shard @@ -2323,7 +2332,7 @@ def _generate_st_from_chunk_local_tensor(self, st_size, sharding_spec): # Verify local shard. local_shard = st.local_shards()[0] self.assertEqual(st.local_tensor(), local_tensor) - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device) # Verify local shard metadata. self.assertEqual( @@ -2356,7 +2365,7 @@ def _generate_st_from_chunk_local_tensor(self, st_size, sharding_spec): @skipIfRocm @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_tensor(self): chunk_specs = _chunk_sharding_specs_list_for_test([0, 1, 1, 0], seed=31) for spec in chunk_specs: @@ -2367,24 +2376,24 @@ def test_init_from_local_tensor(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_tensor_errors(self): enumerable_sharding_spec = EnumerableShardingSpec( [ ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ] ) st_size = [24, 12] - local_tensor = torch.rand(*st_size).cuda(self.rank) + local_tensor = torch.rand(*st_size).to(torch.device(DEVICE_TYPE, self.rank)) with self.assertRaisesRegex(ValueError, "do not cover the entire tensor"): ShardedTensor._init_from_local_tensor( local_tensor, @@ -2403,18 +2412,18 @@ def test_init_from_local_tensor_errors(self): class TestShardedTensorFromLocalShards(ShardedTensorTestBase): - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_local_shards(self): shard_offsets = [(self.rank // 2) * 5, (self.rank % 2) * 5] local_shard_metadata = ShardMetadata( shard_offsets=shard_offsets, shard_sizes=[5, 5], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) - local_tensor = torch.randn(5, 5, device=f"cuda:{self.rank}") + local_tensor = torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}") local_shard = sharded_tensor.Shard(local_tensor, local_shard_metadata) local_shard_from_offsets = sharded_tensor.Shard.from_tensor_and_offsets( local_tensor, shard_offsets=shard_offsets, rank=self.rank @@ -2424,7 +2433,7 @@ def test_local_shards(self): wrong_local_shard_metadata = ShardMetadata( shard_offsets=shard_offsets, shard_sizes=[6, 5], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) with self.assertRaisesRegex(ValueError, "Shard tensor size does not match"): sharded_tensor.Shard(local_tensor, metadata=wrong_local_shard_metadata) @@ -2432,17 +2441,17 @@ def test_local_shards(self): @skipIfRocm @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_shards(self): local_shard_metadata = ShardMetadata( shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], shard_sizes=[5, 5], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) local_shards = [ sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata ) ] @@ -2454,7 +2463,7 @@ def test_init_from_local_shards(self): # Verify local shard. local_shard = st.local_shards()[0] - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) # Verify local shard metadata. @@ -2464,7 +2473,7 @@ def test_init_from_local_shards(self): ) self.assertEqual((5, 5), local_shard.metadata.shard_sizes) self.assertEqual( - f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement) + f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", str(local_shard.metadata.placement) ) # Verify global metadata. @@ -2475,7 +2484,7 @@ def test_init_from_local_shards(self): (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets ) self.assertEqual((5, 5), shard_metadata.shard_sizes) - self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) + self.assertEqual(f"rank:{rank}/{DEVICE_TYPE}:{rank}", str(shard_metadata.placement)) # Validate remote shards. remote_shards = st.remote_shards() @@ -2489,21 +2498,21 @@ def test_init_from_local_shards(self): self.assertEqual((5, 5), shard.tensor.size()) @skipIfRocm - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_recalc_for_metadata(self): shard_sizes = [0, 5] # test 2 different shard sizes for shard_size in shard_sizes: local_shard_metadata = ShardMetadata( shard_offsets=[0, 0], shard_sizes=[shard_size, shard_size], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) local_shards = [ sharded_tensor.Shard( - torch.randn(shard_size, shard_size, device=f"cuda:{self.rank}"), + torch.randn(shard_size, shard_size, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata, ) ] @@ -2515,7 +2524,7 @@ def test_recalc_for_metadata(self): # Verify local shard. local_shard = st.local_shards()[0] self.assertEqual( - torch.device(f"cuda:{self.rank}"), local_shard.tensor.device + torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device ) self.assertEqual((shard_size, shard_size), local_shard.tensor.size()) @@ -2526,7 +2535,7 @@ def test_recalc_for_metadata(self): ) self.assertEqual((shard_size, shard_size), local_shard.metadata.shard_sizes) self.assertEqual( - f"rank:{self.rank}/cuda:{self.rank}", + f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", str(local_shard.metadata.placement), ) @@ -2537,26 +2546,26 @@ def test_recalc_for_metadata(self): self.assertEqual((rank * shard_size, 0), shard_metadata.shard_offsets) self.assertEqual((shard_size, shard_size), shard_metadata.shard_sizes) self.assertEqual( - f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement) + f"rank:{rank}/{DEVICE_TYPE}:{rank}", str(shard_metadata.placement) ) with self.assertRaises(ValueError): st = sharded_tensor.init_from_local_shards(local_shards) @skipIfRocm - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_shards_with_different_glb_size(self): wrong_offset_local_shard_metadata = ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) wrong_offset_local_shards = [ sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{self.rank}"), + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}"), wrong_offset_local_shard_metadata, ) ] @@ -2566,31 +2575,31 @@ def test_init_from_local_shards_with_different_glb_size(self): local_shard_metadata = ShardMetadata( shard_offsets=[self.rank * 5, 0], shard_sizes=[5, 5], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) local_shards = [ sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata ) ] with self.assertRaises(ValueError): sharded_tensor.init_from_local_shards(local_shards, 0, 0) @skipIfRocm - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_non_rw_sharded_recalc_for_metadata(self): local_shard_metadata = ShardMetadata( shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], shard_sizes=[5, 5], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) local_shards = [ sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata ) ] @@ -2621,12 +2630,12 @@ def test_st_base_init_from_local_shards_and_global_metadata(self): local_shard_metadata = ShardMetadata( shard_offsets=[(rank // 2) * 5, (rank % 2) * 5], shard_sizes=[5, 5], - placement=f"rank:{rank}/cuda:{rank}", + placement=f"rank:{rank}/{DEVICE_TYPE}:{rank}", ) shards_metadata.append(local_shard_metadata) shards.append( sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{rank}"), local_shard_metadata + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{rank}"), local_shard_metadata ) ) @@ -2651,7 +2660,7 @@ def test_st_base_init_from_local_shards_and_global_metadata(self): # Verify local shard of st_base local_shard = st_base.local_shards()[0] - self.assertEqual(torch.device("cuda:0"), local_shard.tensor.device) + self.assertEqual(torch.device(f"{DEVICE_TYPE}:0"), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) # Verify local shard metadata. @@ -2660,7 +2669,7 @@ def test_st_base_init_from_local_shards_and_global_metadata(self): local_shard.metadata.shard_offsets, ) self.assertEqual((5, 5), local_shard.metadata.shard_sizes) - self.assertEqual("rank:0/cuda:0", str(local_shard.metadata.placement)) + self.assertEqual(f"rank:0/{DEVICE_TYPE}:0", str(local_shard.metadata.placement)) # Verify global metadata. shards_metadata = st_base.metadata().shards_metadata @@ -2670,17 +2679,17 @@ def test_st_base_init_from_local_shards_and_global_metadata(self): (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets ) self.assertEqual((5, 5), shard_metadata.shard_sizes) - self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) + self.assertEqual(f"rank:{rank}/{DEVICE_TYPE}:{rank}", str(shard_metadata.placement)) @skipIfRocm - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_shards_and_global_metadata_with_all_zeros(self): local_shard_metadata = ShardMetadata( shard_offsets=[0, 0], shard_sizes=[0, 0], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) shards_metadata = [] @@ -2692,13 +2701,13 @@ def test_init_from_local_shards_and_global_metadata_with_all_zeros(self): ShardMetadata( shard_offsets=[0, 0], shard_sizes=[0, 0], - placement=f"rank:{r}/cuda:{r}", + placement=f"rank:{r}/{DEVICE_TYPE}:{r}", ) ) local_shards = [ sharded_tensor.Shard( - torch.randn(0, 0, device=f"cuda:{self.rank}"), local_shard_metadata + torch.randn(0, 0, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata ) ] @@ -2726,7 +2735,7 @@ def test_init_from_local_shards_and_global_metadata_with_all_zeros(self): # Verify local shard. local_shard = st.local_shards()[0] - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device) self.assertEqual((0, 0), local_shard.tensor.size()) # Verify local shard metadata. @@ -2736,7 +2745,7 @@ def test_init_from_local_shards_and_global_metadata_with_all_zeros(self): ) self.assertEqual((0, 0), local_shard.metadata.shard_sizes) self.assertEqual( - f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement) + f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", str(local_shard.metadata.placement) ) # Verify global metadata. @@ -2745,12 +2754,12 @@ def test_init_from_local_shards_and_global_metadata_with_all_zeros(self): for rank, shard_metadata in enumerate(shards_metadata): self.assertEqual((0, 0), shard_metadata.shard_offsets) self.assertEqual((0, 0), shard_metadata.shard_sizes) - self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) + self.assertEqual(f"rank:{rank}/{DEVICE_TYPE}:{rank}", str(shard_metadata.placement)) @skipIfRocm - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_shards_and_global_metadata_with_local_view(self): # testing cases where we create ST with local view, meaning we initialize other rank's metadata with 0s shard_offsets = [0, 1] # valid, invalid @@ -2758,7 +2767,7 @@ def test_init_from_local_shards_and_global_metadata_with_local_view(self): local_shard_metadata = ShardMetadata( shard_offsets=[shard_offset, 0], shard_sizes=[5, 5], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) shards_metadata = [] @@ -2770,13 +2779,13 @@ def test_init_from_local_shards_and_global_metadata_with_local_view(self): ShardMetadata( shard_offsets=[0 if r < self.rank else 5, 0], shard_sizes=[0, 0], - placement=f"rank:{r}/cuda:{r}", + placement=f"rank:{r}/{DEVICE_TYPE}:{r}", ) ) local_shards = [ sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata ) ] @@ -2814,7 +2823,7 @@ def test_init_from_local_shards_and_global_metadata_with_local_view(self): # Verify local shard. local_shard = st.local_shards()[0] self.assertEqual( - torch.device(f"cuda:{self.rank}"), local_shard.tensor.device + torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device ) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -2825,7 +2834,7 @@ def test_init_from_local_shards_and_global_metadata_with_local_view(self): ) self.assertEqual((5, 5), local_shard.metadata.shard_sizes) self.assertEqual( - f"rank:{self.rank}/cuda:{self.rank}", + f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", str(local_shard.metadata.placement), ) @@ -2841,18 +2850,18 @@ def test_init_from_local_shards_and_global_metadata_with_local_view(self): else: self.assertEqual((0, 0), shard_metadata.shard_sizes) self.assertEqual( - f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement) + f"rank:{rank}/{DEVICE_TYPE}:{rank}", str(shard_metadata.placement) ) @skipIfRocm @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_shards_and_global_metadata(self): local_shard_metadata = ShardMetadata( shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], shard_sizes=[5, 5], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) shards_metadata = [] @@ -2864,13 +2873,13 @@ def test_init_from_local_shards_and_global_metadata(self): ShardMetadata( shard_offsets=[(r // 2) * 5, (r % 2) * 5], shard_sizes=[5, 5], - placement=f"rank:{r}/cuda:{r}", + placement=f"rank:{r}/{DEVICE_TYPE}:{r}", ) ) local_shards = [ sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata ) ] @@ -2898,7 +2907,7 @@ def test_init_from_local_shards_and_global_metadata(self): # Verify local shard. local_shard = st.local_shards()[0] - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) # Verify local shard metadata. @@ -2908,7 +2917,7 @@ def test_init_from_local_shards_and_global_metadata(self): ) self.assertEqual((5, 5), local_shard.metadata.shard_sizes) self.assertEqual( - f"rank:{self.rank}/cuda:{self.rank}", str(local_shard.metadata.placement) + f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", str(local_shard.metadata.placement) ) # Verify global metadata. @@ -2919,7 +2928,7 @@ def test_init_from_local_shards_and_global_metadata(self): (rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets ) self.assertEqual((5, 5), shard_metadata.shard_sizes) - self.assertEqual(f"rank:{rank}/cuda:{rank}", str(shard_metadata.placement)) + self.assertEqual(f"rank:{rank}/{DEVICE_TYPE}:{rank}", str(shard_metadata.placement)) # Validate remote shards. remote_shards = st.remote_shards() @@ -2934,7 +2943,7 @@ def test_init_from_local_shards_and_global_metadata(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_shards_new_group(self): new_pg = dist.new_group(ranks=[1, 2, 3]) @@ -2942,11 +2951,11 @@ def test_init_from_local_shards_new_group(self): local_shard_metadata = ShardMetadata( shard_offsets=[5 * (self.rank - 1), 0], shard_sizes=[5, 5], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) local_shards = [ sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata ) ] @@ -2957,7 +2966,7 @@ def test_init_from_local_shards_new_group(self): # Verify local shard. local_shard = st.local_shards()[0] self.assertEqual( - torch.device(f"cuda:{self.rank}"), local_shard.tensor.device + torch.device(DEVICE_TYPE, self.rank), local_shard.tensor.device ) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -2967,7 +2976,7 @@ def test_init_from_local_shards_new_group(self): ) self.assertEqual((5, 5), local_shard.metadata.shard_sizes) self.assertEqual( - f"rank:{self.rank}/cuda:{self.rank}", + f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", str(local_shard.metadata.placement), ) @@ -2979,23 +2988,23 @@ def test_init_from_local_shards_new_group(self): self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets) self.assertEqual((5, 5), shard_metadata.shard_sizes) self.assertEqual( - f"rank:{rank + 1}/cuda:{rank + 1}", str(shard_metadata.placement) + f"rank:{rank + 1}/{DEVICE_TYPE}:{rank + 1}", str(shard_metadata.placement) ) @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_shards_invalid_local_shards(self): local_shard_metadata = ShardMetadata( shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], shard_sizes=[5, 5], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) indices = [[0, 1, 1], [2, 0, 2]] values = [3.2, 4.5, 5.8] sparse_tensor = torch.sparse_coo_tensor( - indices, values, (5, 5), device=f"cuda:{self.rank}" + indices, values, (5, 5), device=f"{DEVICE_TYPE}:{self.rank}" ) empty_local_shards = [] @@ -3016,7 +3025,7 @@ def test_init_from_local_shards_invalid_local_shards(self): wrong_memory_format_shards = [ sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{self.rank}").t(), local_shard_metadata + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}").t(), local_shard_metadata ) ] with self.assertRaisesRegex( @@ -3029,7 +3038,7 @@ def test_init_from_local_shards_invalid_local_shards(self): with self.assertRaisesRegex(ValueError, "Shard tensor size does not match"): sharded_tensor.Shard( - torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata + torch.randn(2, 3, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata ) with self.assertRaisesRegex( @@ -3039,17 +3048,17 @@ def test_init_from_local_shards_invalid_local_shards(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_shards_invalid_property_cross_ranks(self): local_shard_metadata = ShardMetadata( shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], shard_sizes=[5, 5], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) tensor_overall_size = [10, 10] if self.rank == 0 else [10, 5] wrong_dtype_shards = [ sharded_tensor.Shard( - torch.ones(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata + torch.ones(5, 5, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata ) ] with self.assertRaisesRegex( @@ -3063,7 +3072,7 @@ def test_init_from_local_shards_invalid_property_cross_ranks(self): tensor_dtype = torch.int if self.rank == 0 else torch.float32 wrong_dtype_shards = [ sharded_tensor.Shard( - torch.ones(5, 5, device=f"cuda:{self.rank}", dtype=tensor_dtype), + torch.ones(5, 5, device=f"{DEVICE_TYPE}:{self.rank}", dtype=tensor_dtype), local_shard_metadata, ) ] @@ -3079,7 +3088,7 @@ def test_init_from_local_shards_invalid_property_cross_ranks(self): wrong_requires_grad_shards = [ sharded_tensor.Shard( torch.randn( - 5, 5, device=f"cuda:{self.rank}", requires_grad=tensor_requires_grad + 5, 5, device=f"{DEVICE_TYPE}:{self.rank}", requires_grad=tensor_requires_grad ), local_shard_metadata, ) @@ -3138,18 +3147,18 @@ def test_init_from_local_shards_invalid_pin_memory(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_shards_invalid_shards_overlap(self): local_shard_size = [5, 5] if self.rank != 0 else [6, 6] local_shard_metadata = ShardMetadata( shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], shard_sizes=local_shard_size, - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) local_shards = [ sharded_tensor.Shard( - torch.randn(local_shard_size, device=f"cuda:{self.rank}"), + torch.randn(local_shard_size, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata, ) ] @@ -3161,18 +3170,18 @@ def test_init_from_local_shards_invalid_shards_overlap(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_shards_invalid_shards_gaps(self): local_shard_size = [5, 5] if self.rank != 0 else [4, 4] local_shard_metadata = ShardMetadata( shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], shard_sizes=local_shard_size, - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) local_shards = [ sharded_tensor.Shard( - torch.randn(local_shard_size, device=f"cuda:{self.rank}"), + torch.randn(local_shard_size, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata, ) ] @@ -3184,12 +3193,12 @@ def test_init_from_local_shards_invalid_shards_gaps(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_init_from_local_shards_and_global_metadata_invalid_shards(self): local_shard_metadata = ShardMetadata( shard_offsets=[(self.rank // 2) * 5, (self.rank % 2) * 5], shard_sizes=[5, 5], - placement=f"rank:{self.rank}/cuda:{self.rank}", + placement=f"rank:{self.rank}/{DEVICE_TYPE}:{self.rank}", ) shards_metadata = [] @@ -3201,7 +3210,7 @@ def test_init_from_local_shards_and_global_metadata_invalid_shards(self): ShardMetadata( shard_offsets=[(r // 2) * 5, (r % 2) * 5], shard_sizes=[5, 5], - placement=f"rank:{r}/cuda:{r}", + placement=f"rank:{r}/{DEVICE_TYPE}:{r}", ) ) @@ -3229,10 +3238,10 @@ def test_init_from_local_shards_and_global_metadata_invalid_shards(self): wrong_num_shards = [ sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata ), sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata ), ] with self.assertRaisesRegex( @@ -3246,7 +3255,7 @@ def test_init_from_local_shards_and_global_metadata_invalid_shards(self): ValueError, "Shard tensor size does not match with metadata.shard_lengths" ): sharded_tensor.Shard( - torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata + torch.randn(2, 3, device=f"{DEVICE_TYPE}:{self.rank}"), local_shard_metadata ) with self.assertRaisesRegex( @@ -3257,7 +3266,7 @@ def test_init_from_local_shards_and_global_metadata_invalid_shards(self): wrong_dtype_shards = [ sharded_tensor.Shard( - torch.ones(5, 5, device=f"cuda:{self.rank}", dtype=torch.int), + torch.ones(5, 5, device=f"{DEVICE_TYPE}:{self.rank}", dtype=torch.int), local_shard_metadata, ) ] @@ -3271,7 +3280,7 @@ def test_init_from_local_shards_and_global_metadata_invalid_shards(self): indices = [[0, 1, 1], [2, 0, 2]] values = [3.2, 4.5, 5.8] sparse_tensor = torch.sparse_coo_tensor( - indices, values, (5, 5), device=f"cuda:{self.rank}" + indices, values, (5, 5), device=f"{DEVICE_TYPE}:{self.rank}" ) wrong_layout_shards = [ @@ -3286,7 +3295,7 @@ def test_init_from_local_shards_and_global_metadata_invalid_shards(self): wrong_requires_grad_shards = [ sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{self.rank}", requires_grad=True), + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}", requires_grad=True), local_shard_metadata, ) ] @@ -3300,7 +3309,7 @@ def test_init_from_local_shards_and_global_metadata_invalid_shards(self): wrong_memory_format_shards = [ sharded_tensor.Shard( - torch.randn(5, 5, device=f"cuda:{self.rank}").t(), local_shard_metadata + torch.randn(5, 5, device=f"{DEVICE_TYPE}:{self.rank}").t(), local_shard_metadata ) ] with self.assertRaisesRegex( @@ -3328,7 +3337,7 @@ def test_init_from_local_shards_and_global_metadata_invalid_shards(self): class TestShardedTensorCustomOps(ShardedTensorTestBase): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_custom_op(self): @custom_sharded_op_impl(torch.asin) def my_sharded_asin(types, args, kwargs, process_group): @@ -3337,10 +3346,10 @@ def my_sharded_asin(types, args, kwargs, process_group): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) @@ -3350,9 +3359,9 @@ def my_sharded_asin(types, args, kwargs, process_group): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_custom_op_override(self): - t = torch.rand(10, 10).cuda(self.rank) + t = torch.rand(10, 10).to(torch.device(DEVICE_TYPE, self.rank)) from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op @@ -3363,21 +3372,21 @@ def my_sharded_linear(types, args, kwargs, process_group): spec = ChunkShardingSpec( dim=0, placements=[ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) - m = torch.nn.Linear(32, 16).cuda(self.rank) + m = torch.nn.Linear(32, 16).to(torch.device(DEVICE_TYPE, self.rank)) shard_parameter(m, "weight", spec) - result = m(torch.rand(15, 32).cuda(self.rank)) + result = m(torch.rand(15, 32).to(torch.device(DEVICE_TYPE, self.rank))) self.assertEqual(t, result) @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_custom_op_errors(self): with self.assertRaisesRegex(TypeError, "expects signature"): @@ -3394,7 +3403,7 @@ def my_op2(types): class TestShardMetadata(ShardedTensorTestBase): @with_comms - @requires_nccl() + @requires_accelerator_dist_backend() def test_shard_metadata_init(self): pg = dist.distributed_c10d._get_default_group() @@ -3411,7 +3420,7 @@ def test_shard_metadata_init(self): self.assertEqual(device, torch.device("cpu")) @with_comms - @requires_nccl() + @requires_accelerator_dist_backend() def test_create_shard_with_no_placement(self): md = ShardMetadata([0], [10]) shard = Shard(torch.zeros(10), md) @@ -3465,7 +3474,7 @@ def test_sub_process_group_placement_validation(self): for r in sub_pg_ranks: _parse_and_validate_remote_device( - sub_pg, _remote_device(f"rank:{r}/cuda:{r % sub_group_sz}") + sub_pg, _remote_device(f"rank:{r}/{DEVICE_TYPE}:{r % sub_group_sz}") ) diff --git a/test/distributed/_shard/sharded_tensor/test_sharded_tensor_reshard.py b/test/distributed/_shard/sharded_tensor/test_sharded_tensor_reshard.py index 05502ac168f8d..99ae455146b3c 100644 --- a/test/distributed/_shard/sharded_tensor/test_sharded_tensor_reshard.py +++ b/test/distributed/_shard/sharded_tensor/test_sharded_tensor_reshard.py @@ -1,12 +1,17 @@ # Owner(s): ["oncall: distributed"] +# Adapted from upstream test_sharded_tensor_reshard.py — made device-agnostic. import sys from itertools import product import torch +import torch.distributed as dist from torch.distributed._shard import _shard_tensor, sharded_tensor from torch.distributed._shard.sharding_spec import EnumerableShardingSpec, ShardMetadata -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, @@ -16,6 +21,12 @@ _chunk_sharding_specs_list_for_test, ) +if torch.accelerator.current_accelerator() is None: + print("No accelerator available, skipping tests", file=sys.stderr) + sys.exit(0) + +DEVICE_TYPE = torch.accelerator.current_accelerator().type +BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE) if TEST_WITH_DEV_DBG_ASAN: print( @@ -28,7 +39,7 @@ class TestReshard(ShardedTensorTestBase): def _run_sharded_tensor_reshard(self, sharding_spec, reshard_spec, input_size): torch.manual_seed(0) - local_tensor = torch.rand(*input_size).cuda(self.rank) + local_tensor = torch.rand(*input_size).to(torch.device(DEVICE_TYPE, self.rank)) st = _shard_tensor(local_tensor, sharding_spec) st_compare = _shard_tensor(local_tensor, reshard_spec) st.reshard(reshard_spec) @@ -43,9 +54,9 @@ def _run_sharded_tensor_reshard(self, sharding_spec, reshard_spec, input_size): st.local_shards()[0].metadata, st_compare.local_shards()[0].metadata ) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_tensor_reshard(self): dims = [0, 1] for sharding_dim, reshard_dim in product(dims, dims): @@ -58,9 +69,9 @@ def test_sharded_tensor_reshard(self): self._run_sharded_tensor_reshard(spec, reshard_spec, [15, 26]) self._run_sharded_tensor_reshard(spec, reshard_spec, [12, 24]) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharded_tensor_reshard_errors(self): specs = _chunk_sharding_specs_list_for_test([0, 1], seed=6) spec, reshard_spec = specs[0], specs[1] @@ -69,12 +80,12 @@ def test_sharded_tensor_reshard_errors(self): ShardMetadata( shard_offsets=[0, 0], shard_sizes=[5, 5], - placement="rank:0/cuda:0", + placement=f"rank:0/{DEVICE_TYPE}:0", ), ShardMetadata( shard_offsets=[5, 0], shard_sizes=[5, 5], - placement="rank:1/cuda:1", + placement=f"rank:1/{DEVICE_TYPE}:1", ), ] ) diff --git a/test/distributed/_shard/sharding_plan/test_sharding_plan.py b/test/distributed/_shard/sharding_plan/test_sharding_plan.py index 7310c43bb4a09..80eeaf39c229a 100644 --- a/test/distributed/_shard/sharding_plan/test_sharding_plan.py +++ b/test/distributed/_shard/sharding_plan/test_sharding_plan.py @@ -1,4 +1,6 @@ # Owner(s): ["oncall: distributed"] +# Adapted from upstream test_sharding_plan.py — made device-agnostic for PrivateUse1 backends. + import sys import torch @@ -8,7 +10,10 @@ from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._shard.sharding_plan import ShardingPlan, ShardingPlanner from torch.distributed._shard.sharding_spec import ChunkShardingSpec -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, @@ -20,6 +25,12 @@ ) from torch.testing._internal.distributed._shard.test_common import SimpleMegatronLM +if torch.accelerator.current_accelerator() is None: + print("No accelerator available, skipping tests", file=sys.stderr) + sys.exit(0) + +DEVICE_TYPE = torch.accelerator.current_accelerator().type +BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE) if TEST_WITH_DEV_DBG_ASAN: print( @@ -37,7 +48,7 @@ class ChunkAllShardingPlanner(ShardingPlanner): def __init__(self, chunk_dim=0, device_count=0): self.dim = chunk_dim - self.devices = [f"rank:{i}/cuda:{i}" for i in range(device_count)] + self.devices = [f"rank:{i}/{DEVICE_TYPE}:{i}" for i in range(device_count)] def build_plan(self, module: nn.Module) -> ShardingPlan: named_params = module.named_parameters() @@ -49,9 +60,9 @@ def build_plan(self, module: nn.Module) -> ShardingPlan: class TestShardingPlan(ShardedTensorTestBase): - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_sharding_plan_errors(self): rowwise_sharding_spec = generate_chunk_sharding_specs_for_test(1)[0] sharding_plan_wrong_plan = ShardingPlan( @@ -61,12 +72,13 @@ def test_sharding_plan_errors(self): output_plan={"": rowwise_sharding_spec}, ) - megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).cuda(self.rank) + megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]]).to( + torch.device(DEVICE_TYPE, self.rank) + ) with self.assertRaisesRegex( TypeError, "Only `ShardingSpec` and `Sharder` are supported to shard" ): - # shard the module with the provided sharding plan shard_module(megatron_lm, sharding_plan_wrong_plan) sharding_plan_wrong_output_plan = ShardingPlan( @@ -79,7 +91,6 @@ def test_sharding_plan_errors(self): with self.assertRaisesRegex( TypeError, "Only `ShardingSpec` is supported as output_plan" ): - # shard the module with the provided sharding plan shard_module(megatron_lm, sharding_plan_wrong_output_plan) sharding_plan_wrong_module_path = ShardingPlan( @@ -88,7 +99,6 @@ def test_sharding_plan_errors(self): }, ) with self.assertRaisesRegex(AttributeError, "has no attribute"): - # shard the module with the provided sharding plan shard_module(megatron_lm, sharding_plan_wrong_module_path) sharding_plan_wrong_param_path = ShardingPlan( @@ -97,15 +107,14 @@ def test_sharding_plan_errors(self): }, ) with self.assertRaisesRegex(AttributeError, "has no attribute"): - # shard the module with the provided sharding plan shard_module(megatron_lm, sharding_plan_wrong_param_path) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_custom_sharding_planner(self): - megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]], rank=self.rank).cuda( - self.rank + megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]], rank=self.rank).to( + torch.device(DEVICE_TYPE, self.rank) ) planner = ChunkAllShardingPlanner(device_count=TEST_GPU_NUM) sharding_plan = planner.build_plan(megatron_lm) @@ -118,23 +127,23 @@ def test_custom_sharding_planner(self): self.assertTrue(isinstance(megatron_lm.fc1.bias, ShardedTensor)) self.assertTrue(isinstance(megatron_lm.fc2.bias, ShardedTensor)) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_shard_module_sub_process_group(self): megatron_lm = SimpleMegatronLM([[17, 12], [12, 29]], rank=self.rank) colwise_sharding_spec = ChunkShardingSpec( dim=0, placements=[ - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) rowwise_sharding_spec = ChunkShardingSpec( dim=1, placements=[ - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ], ) sharding_plan = ShardingPlan( diff --git a/test/distributed/_shard/sharding_spec/test_sharding_spec.py b/test/distributed/_shard/sharding_spec/test_sharding_spec.py index fe14f815749b1..ae69fccc8df84 100644 --- a/test/distributed/_shard/sharding_spec/test_sharding_spec.py +++ b/test/distributed/_shard/sharding_spec/test_sharding_spec.py @@ -1,8 +1,11 @@ +# Adapted from upstream — made device-agnostic for PrivateUse1 backends. # Owner(s): ["oncall: distributed"] import copy +import sys from dataclasses import dataclass import torch +import torch.distributed as dist from torch.distributed._shard import _shard_tensor, sharded_tensor from torch.distributed._shard.sharded_tensor import ( ShardedTensor, @@ -25,7 +28,10 @@ validate_non_overlapping_shards_metadata, ) from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import ( run_tests, skip_but_pass_in_sandcastle_if, @@ -39,7 +45,16 @@ _chunk_sharding_specs_list_for_test, ) +if torch.accelerator.current_accelerator() is None: + print("No accelerator available, skipping tests", file=sys.stderr) + sys.exit(0) + +DEVICE_TYPE = torch.accelerator.current_accelerator().type +BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE) + +# NOTE: TestShardingSpec tests the API parser itself with hardcoded "cuda:0" strings. +# Those are kept as-is since they're testing the parser, not running distributed tests. class TestShardingSpec(TestCase): @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "2 CUDA GPUs are needed") def test_device_placement(self): @@ -614,10 +629,10 @@ def shard( class TestCustomShardingSpec(ShardedTensorTestBase): def test_custom_sharding_spec(self): ranks = [ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ] grid_spec = GridShardingSpec(grid_size=4, placements=ranks) @@ -635,17 +650,17 @@ def test_custom_sharding_spec(self): @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_custom_sharding_spec_tensor_ctor(self): """Test sharded_tensor.ones(...) with the custom grid sharding spec. """ ranks = [ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ] grid_spec = GridShardingSpec(grid_size=2, placements=ranks) @@ -656,23 +671,23 @@ def test_custom_sharding_spec_tensor_ctor(self): local_shards = st.local_shards() self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor - self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) + self.assertEqual(torch.device(DEVICE_TYPE, self.rank), local_shard.device) self.assertEqual((2, 2), local_shard.size()) self.assertEqual(local_shard, torch.ones(2, 2)) @with_comms @skip_if_lt_x_gpu(4) - @requires_nccl() + @requires_accelerator_dist_backend() def test_custom_sharding_spec_shard_tensor(self): """Test custom spec can be invoked from the _shard_tensor callsite. """ ranks = [ - "rank:0/cuda:0", - "rank:1/cuda:1", - "rank:2/cuda:2", - "rank:3/cuda:3", + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + f"rank:2/{DEVICE_TYPE}:2", + f"rank:3/{DEVICE_TYPE}:3", ] grid_spec = GridShardingSpec(grid_size=2, placements=ranks) diff --git a/test/distributed/_shard/test_sharder.py b/test/distributed/_shard/test_sharder.py index 27b79c55406d8..4729cbb62527b 100644 --- a/test/distributed/_shard/test_sharder.py +++ b/test/distributed/_shard/test_sharder.py @@ -1,15 +1,22 @@ # Owner(s): ["oncall: distributed"] +# Adapted from upstream test_sharder.py — made device-agnostic for PrivateUse1 backends. +# Uses torch.accelerator to derive DEVICE_TYPE and BACKEND dynamically. + import copy import sys import torch +import torch.distributed as dist import torch.nn as nn from torch.distributed._shard import shard_module from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._shard.sharder import Sharder from torch.distributed._shard.sharding_plan import ShardingPlan from torch.distributed._shard.sharding_spec import ChunkShardingSpec -from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + requires_accelerator_dist_backend, + skip_if_lt_x_gpu, +) from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, @@ -17,6 +24,12 @@ with_comms, ) +if torch.accelerator.current_accelerator() is None: + print("No accelerator available, skipping tests", file=sys.stderr) + sys.exit(0) + +DEVICE_TYPE = torch.accelerator.current_accelerator().type +BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE) if TEST_WITH_DEV_DBG_ASAN: print( @@ -98,9 +111,9 @@ def shard(self, ebc: nn.Module) -> nn.Module: class TestCustomSharder(ShardedTensorTestBase): - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_custom_sharder(self): class MyModule(nn.Module): def __init__(self) -> None: @@ -111,7 +124,7 @@ def forward(self, inputs): return self.ebc(inputs) custom_sharder = CustomSharder( - devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)], + devices=[f"rank:{i}/{DEVICE_TYPE}:{i}" for i in range(TEST_GPU_NUM)], split_sharding_idx=TEST_GPU_NUM // 2, ) @@ -121,7 +134,7 @@ def forward(self, inputs): } ) - local_model = MyModule().cuda(self.rank) + local_model = MyModule().to(torch.device(DEVICE_TYPE, self.rank)) sharded_model = copy.deepcopy(local_model) # shard the module with the provided sharding plan @@ -142,18 +155,18 @@ def forward(self, inputs): # make sure we can run sharded computation and compare outputs # with the local model version - input = torch.arange(8).reshape((2, 4)).cuda(self.rank) + input = torch.arange(8).reshape((2, 4)).to(torch.device(DEVICE_TYPE, self.rank)) local_output = local_model(input) sharded_output = sharded_model(input) self.assertEqual(local_output, sharded_output) - @with_comms(init_rpc=False) + @with_comms(init_rpc=False, backend=BACKEND) @skip_if_lt_x_gpu(TEST_GPU_NUM) - @requires_nccl() + @requires_accelerator_dist_backend() def test_custom_sharder_errors(self): custom_sharder = CustomSharder( - devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)], + devices=[f"rank:{i}/{DEVICE_TYPE}:{i}" for i in range(TEST_GPU_NUM)], split_sharding_idx=TEST_GPU_NUM // 2, ) @@ -163,7 +176,9 @@ def test_custom_sharder_errors(self): } ) - sharded_model = CustomEmbeddingBagCollection(10, 10, 8).cuda(self.rank) + sharded_model = CustomEmbeddingBagCollection(10, 10, 8).to( + torch.device(DEVICE_TYPE, self.rank) + ) with self.assertRaisesRegex( KeyError, "path must not be empty for custom sharder!" @@ -172,7 +187,13 @@ def test_custom_sharder_errors(self): shard_module(sharded_model, sharding_plan) # test conflicted sharding plan - spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"]) + spec = ChunkShardingSpec( + dim=0, + placements=[ + f"rank:0/{DEVICE_TYPE}:0", + f"rank:1/{DEVICE_TYPE}:1", + ], + ) sharding_plan = ShardingPlan( plan={ "embedding_bags.embedding_bag_0.weight": spec,