diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py
index 912eb31704c8..b0a30e874abd 100644
--- a/ignite/distributed/comp_models/native.py
+++ b/ignite/distributed/comp_models/native.py
@@ -74,6 +74,10 @@ def _create_from_backend(self, backend, timeout=None, **kwargs):
         if timeout is not None:
             init_pg_kwargs["timeout"] = timeout
 
+        # Check for a CUDA device before the process group is initialized.
+        if backend == dist.Backend.NCCL and not torch.cuda.is_available():
+            raise RuntimeError("Nccl backend is required but no cuda capable devices")
+
         dist.init_process_group(backend, init_method="env://", **init_pg_kwargs)
         # https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
         dist.barrier()
diff --git a/tests/ignite/conftest.py b/tests/ignite/conftest.py
index 62b907b48e3e..5857a113c597 100644
--- a/tests/ignite/conftest.py
+++ b/tests/ignite/conftest.py
@@ -216,3 +216,13 @@ def _xla_execute(fn, args, nprocs):
 @pytest.fixture()
 def xmp_executor():
     yield _xla_execute
+
+
+@pytest.fixture()
+def mock_gpu_is_not_available():
+    """Replace ``torch.cuda`` with a mock whose ``is_available()`` returns False."""
+    from unittest.mock import patch
+
+    with patch("torch.cuda") as mock_cuda:
+        mock_cuda.is_available.return_value = False
+        yield mock_cuda
diff --git a/tests/ignite/distributed/comp_models/test_native.py b/tests/ignite/distributed/comp_models/test_native.py
index 650c5fb7bcb2..fe1d3b9e8b70 100644
--- a/tests/ignite/distributed/comp_models/test_native.py
+++ b/tests/ignite/distributed/comp_models/test_native.py
@@ -32,6 +32,23 @@ def test__native_dist_model():
         assert "mpi" not in available_backends
 
 
+@pytest.mark.distributed
+@pytest.mark.skipif(not dist.is_nccl_available(), reason="Skip if nccl not available")
+def test__native_nccl_but_no_gpu(mock_gpu_is_not_available):
+
+    # Copy the values: os.environ is a live mapping, so a bare reference to it
+    # would be emptied by the clear() below and nothing would be restored.
+    env_backup = dict(os.environ)
+
+    try:
+        with pytest.raises(RuntimeError, match=r"Nccl backend is required but no cuda capable devices"):
+            _NativeDistModel(backend="nccl")
+    finally:
+        # environ could be corrupted by _NativeDistModel
+        os.environ.clear()
+        os.environ.update(env_backup)
+
+
 @pytest.mark.distributed
 @pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
 def test__native_dist_model_create_from_backend_bad_config():