Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 36 additions & 50 deletions test/distributed/_shard/sharded_optim/test_sharded_optim.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
# Owner(s): ["oncall: distributed"]
# Adapted from upstream test_sharded_optim.py — made device-agnostic for PrivateUse1 backends.

import sys
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.optim as optim
from torch.distributed._shard import shard_parameter, sharded_tensor
from torch.distributed._shard.sharded_optim import ShardedOptimizer
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_distributed import (
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._shard.sharded_tensor import (
ShardedTensorTestBase,
with_comms,
)

if torch.accelerator.current_accelerator() is None:
print("No accelerator available, skipping tests", file=sys.stderr)
sys.exit(0)

DEVICE_TYPE = torch.accelerator.current_accelerator().type
BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE)


class MyShardedModel(torch.nn.Module):
def __init__(self, spec=None, group=None):
Expand Down Expand Up @@ -46,30 +59,15 @@ def __init__(self, rank=None):
self.linear2 = torch.nn.Linear(12, 29)
self.gelu = torch.nn.GELU()

if rank:
self.linear1.cuda(rank)
self.linear2.cuda(rank)
if rank is not None:
device = torch.device(DEVICE_TYPE, rank)
self.linear1.to(device)
self.linear2.to(device)

def shard_parameter(self):
rowwise_sharding_spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
],
)

colwise_sharding_spec = ChunkShardingSpec(
dim=1,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
],
)
placements = [f"rank:{i}/{DEVICE_TYPE}:{i}" for i in range(4)]
rowwise_sharding_spec = ChunkShardingSpec(dim=0, placements=placements)
colwise_sharding_spec = ChunkShardingSpec(dim=1, placements=placements)

shard_parameter(self.linear1, "weight", rowwise_sharding_spec)
shard_parameter(self.linear2, "weight", colwise_sharding_spec)
Expand All @@ -79,21 +77,15 @@ def forward(self, inp):


class TestShardedOptimizer(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@with_comms(init_rpc=False, backend=BACKEND)
@skip_if_lt_x_gpu(4)
@requires_nccl()
@requires_accelerator_dist_backend()
def test_sharded_optim(self):
rowwise_spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
],
)
local_model = MyShardedModel().cuda()
sharded_model = MyShardedModel(spec=rowwise_spec).cuda()
placements = [f"rank:{i}/{DEVICE_TYPE}:{i}" for i in range(4)]
rowwise_spec = ChunkShardingSpec(dim=0, placements=placements)
device = torch.device(DEVICE_TYPE)
local_model = MyShardedModel().to(device)
sharded_model = MyShardedModel(spec=rowwise_spec).to(device)

# copy the parameters from local model
sharded_model.sharded_param.local_shards()[0].tensor = (
Expand All @@ -109,7 +101,7 @@ def test_sharded_optim(self):

before_update = deepcopy(sharded_optim.named_params)

inp = torch.rand([5, 10]).cuda(self.rank).requires_grad_()
inp = torch.rand([5, 10]).to(torch.device(DEVICE_TYPE, self.rank)).requires_grad_()

# run forward
local_output = local_model(inp)
Expand Down Expand Up @@ -138,27 +130,21 @@ def test_sharded_optim(self):
self.assertNotEqual(val, new_val)
self.assertEqual(new_val, local_model.param)

@with_comms(init_rpc=False)
@with_comms(init_rpc=False, backend=BACKEND)
@skip_if_lt_x_gpu(4)
@requires_nccl()
@requires_accelerator_dist_backend()
def test_named_params_with_sharded_tensor(self):
rowwise_spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
],
)
sharded_model = MyShardedModel(spec=rowwise_spec).cuda()
placements = [f"rank:{i}/{DEVICE_TYPE}:{i}" for i in range(4)]
rowwise_spec = ChunkShardingSpec(dim=0, placements=placements)
device = torch.device(DEVICE_TYPE)
sharded_model = MyShardedModel(spec=rowwise_spec).to(device)
sharded_model_params = dict(sharded_model.named_parameters())
param_keys = list(sharded_model_params.keys())
self.assertEqual(len(param_keys), 2)
self.assertTrue("param" in param_keys)
self.assertTrue("sharded_param" in param_keys)

sharded_linear = MyShardedLinear(rank=self.rank).cuda()
sharded_linear = MyShardedLinear(rank=self.rank).to(device)
sharded_linear.shard_parameter()
sharded_linear_params = dict(sharded_linear.named_parameters())
param_keys = list(sharded_linear_params.keys())
Expand Down
37 changes: 24 additions & 13 deletions test/distributed/_shard/sharded_tensor/ops/test_binary_cmp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Adapted from upstream — made device-agnostic for PrivateUse1 backends.
# Owner(s): ["oncall: distributed"]

import sys
Expand All @@ -7,7 +8,10 @@
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed.distributed_c10d import _get_default_group
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_distributed import (
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
ShardedTensorTestBase,
Expand All @@ -22,6 +26,13 @@
)
sys.exit(0)

if torch.accelerator.current_accelerator() is None:
print("No accelerator available, skipping tests", file=sys.stderr)
sys.exit(0)

DEVICE_TYPE = torch.accelerator.current_accelerator().type
BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE)


class TestShardedTensorBinaryOps(ShardedTensorTestBase):
"""Test base for binary comparison functions such as torch.equal, torch.allclose etc. for ShardedTensor"""
Expand All @@ -45,20 +56,20 @@ def get_gpu_specs(self):
spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
"rank:2/cuda:2",
"rank:3/cuda:3",
f"rank:0/{DEVICE_TYPE}:0",
f"rank:1/{DEVICE_TYPE}:1",
f"rank:2/{DEVICE_TYPE}:2",
f"rank:3/{DEVICE_TYPE}:3",
],
)

alt_spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:1/cuda:1",
"rank:0/cuda:0",
"rank:3/cuda:3",
"rank:2/cuda:2",
f"rank:1/{DEVICE_TYPE}:1",
f"rank:0/{DEVICE_TYPE}:0",
f"rank:3/{DEVICE_TYPE}:3",
f"rank:2/{DEVICE_TYPE}:2",
],
)
return spec, alt_spec
Expand Down Expand Up @@ -119,13 +130,13 @@ def _test_common_failures(self, cmp_op):

@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
@requires_accelerator_dist_backend()
def test_torch_equal_tensor_specs(self):
self._test_common_failures(torch.equal)

@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
@requires_accelerator_dist_backend()
def test_torch_equal(self):
"""Test torch.equal(ShardedTensor, ShardedTensor)"""

Expand All @@ -135,13 +146,13 @@ def test_torch_equal(self):

@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
@requires_accelerator_dist_backend()
def test_torch_allclose_tensor_specs(self):
self._test_common_failures(torch.allclose)

@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
@requires_accelerator_dist_backend()
def test_torch_allclose(self):
"""Test torch.allclose(ShardedTensor, ShardedTensor)"""

Expand Down
25 changes: 18 additions & 7 deletions test/distributed/_shard/sharded_tensor/ops/test_embedding.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# Adapted from upstream — made device-agnostic for PrivateUse1 backends.
# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed as dist
from torch.distributed._shard import shard_parameter
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_distributed import (
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
ShardedTensorTestBase,
Expand All @@ -26,6 +30,13 @@
)
sys.exit(0)

if torch.accelerator.current_accelerator() is None:
print("No accelerator available, skipping tests", file=sys.stderr)
sys.exit(0)

DEVICE_TYPE = torch.accelerator.current_accelerator().type
BACKEND = dist.get_default_backend_for_device(DEVICE_TYPE)


class TestShardedEmbedding(ShardedTensorTestBase):
def _run_sharded_embedding(
Expand All @@ -46,7 +57,7 @@ def _run_sharded_embedding(
max_norm=max_norm,
norm_type=norm_type,
padding_idx=padding_idx,
).cuda(self.rank)
).to(torch.device(DEVICE_TYPE, self.rank))

sharded_embedding = torch.nn.Embedding(
num_embeddings,
Expand All @@ -64,7 +75,7 @@ def _run_sharded_embedding(

# Run sharded computation
torch.manual_seed(self.rank) # inputs different on each rank
inp = torch.randint(0, num_embeddings, tuple(input_size)).cuda(self.rank)
inp = torch.randint(0, num_embeddings, tuple(input_size)).to(torch.device(DEVICE_TYPE, self.rank))
sharded_output = sharded_embedding(inp)

# If max_norm is set, we need to ensure that the renorm has been applied across
Expand Down Expand Up @@ -112,9 +123,9 @@ def _run_sharded_embedding(

self.assertEqual(local_output, sharded_output)

@with_comms(init_rpc=False)
@with_comms(init_rpc=False, backend=BACKEND)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
@requires_accelerator_dist_backend()
def test_sharded_embedding_colwise(self):
for spec in generate_chunk_sharding_specs_for_test(1):
self._run_sharded_embedding(spec, [5, 4], 17, 12)
Expand Down Expand Up @@ -148,9 +159,9 @@ def test_sharded_embedding_colwise(self):
)
self._run_sharded_embedding(spec, [30], 15, 14, max_norm=2.0)

@with_comms(init_rpc=False)
@with_comms(init_rpc=False, backend=BACKEND)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
@requires_accelerator_dist_backend()
def test_sharded_embedding_rowwise(self):
for spec in generate_chunk_sharding_specs_for_test(0):
# Test even split.
Expand Down
Loading
Loading