diff --git a/python/ray/dag/tests/experimental/test_non_compiled_nccl_dag.py b/python/ray/dag/tests/experimental/test_non_compiled_nccl_dag.py
new file mode 100644
index 000000000000..eeb9bf755394
--- /dev/null
+++ b/python/ray/dag/tests/experimental/test_non_compiled_nccl_dag.py
@@ -0,0 +1,64 @@
+# coding: utf-8
+import os
+import sys
+
+import pytest
+
+from ray.experimental.channel.torch_tensor_type import TorchTensorType
+from ray.tests.conftest import *  # noqa
+
+
+def test_non_compiled_nccl_falls_back_to_cpu(ray_start_regular_shared):
+    """
+    Test that TorchTensorType(transport='accelerator') in a non-compiled DAG
+    falls back to CPU/shared-memory transport with a warning.
+    """
+    t = TorchTensorType(transport="accelerator")
+
+    with pytest.warns(
+        UserWarning, match="Falling back to shared-memory \\(CPU\\) transport"
+    ):
+        channel = t.create_channel(
+            writer=None,
+            reader_and_node_list=[],
+        )
+
+    from ray.experimental.channel.shared_memory_channel import CompositeChannel
+
+    assert isinstance(channel, CompositeChannel)
+
+
+def test_non_compiled_custom_communicator_falls_back_to_cpu(ray_start_regular_shared):
+    """
+    Test that TorchTensorType(transport=custom_comm) in a non-compiled DAG falls back
+    to CPU/shared-memory transport with a warning because communicator_id is None.
+    """
+    from ray.experimental.channel.communicator import Communicator
+
+    # Register as virtual subclass of Communicator to satisfy isinstance check
+    # without running into abstract method instantiation TypeErrors.
+    @Communicator.register
+    class DummyCommunicator:
+        def get_transport_name(self) -> str:
+            return "accelerator"
+
+    t = TorchTensorType(transport=DummyCommunicator())
+
+    with pytest.warns(
+        UserWarning, match="Falling back to shared-memory \\(CPU\\) transport"
+    ):
+        channel = t.create_channel(
+            writer=None,
+            reader_and_node_list=[],
+        )
+
+    from ray.experimental.channel.shared_memory_channel import CompositeChannel
+
+    assert isinstance(channel, CompositeChannel)
+
+
+if __name__ == "__main__":
+    if os.environ.get("PARALLEL_CI"):
+        sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
+    else:
+        sys.exit(pytest.main(["-sv", __file__]))
diff --git a/python/ray/experimental/channel/torch_tensor_type.py b/python/ray/experimental/channel/torch_tensor_type.py
index 21220acf9959..10a6c7013d24 100644
--- a/python/ray/experimental/channel/torch_tensor_type.py
+++ b/python/ray/experimental/channel/torch_tensor_type.py
@@ -1,4 +1,5 @@
 import logging
+import warnings
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import ray
@@ -142,20 +143,36 @@ def create_channel(
         _tensor_metadata_channel: Optional["Channel"] = None,
     ) -> type:
         if self.requires_accelerator():
-            from ray.experimental.channel.torch_tensor_accelerator_channel import (
-                TorchTensorAcceleratorChannel,
-            )
-
-            return TorchTensorAcceleratorChannel(
-                writer,
-                reader_and_node_list,
-                self,
-                driver_actor_id,
-                _tensor_metadata_channel,
-                _cpu_data_channel,
-            )
-
-        # Data does not require accelerator. Transfer via host memory using a
+            # Check if this type hint has been set up by the Compiled Graph
+            # compiler, i.e. a communicator ID has been assigned. A custom
+            # communicator alone does not count. If no ID is set, we are in a
+            # non-compiled graph context and fall back to shared memory for debugging.
+            if self._communicator_id is None:
+                warnings.warn(
+                    "TorchTensorType(transport='accelerator') used outside of a "
+                    "Compiled Graph. Falling back to shared-memory (CPU) "
+                    "transport for debugging. Performance will be "
+                    "significantly worse than compiled NCCL.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+                # Fall through to the shared memory path below.
+            else:
+                from ray.experimental.channel.torch_tensor_accelerator_channel import (
+                    TorchTensorAcceleratorChannel,
+                )
+
+                return TorchTensorAcceleratorChannel(
+                    writer,
+                    reader_and_node_list,
+                    self,
+                    driver_actor_id,
+                    _tensor_metadata_channel,
+                    _cpu_data_channel,
+                )
+
+        # Data does not require accelerator, OR we are in a non-compiled graph
+        # context (debugging path). Transfer via host memory using a
         # shared-memory channel.
         # TODO(swang): Allow the initial max buffer size to be overridden.
         typ = SharedMemoryType()