Skip to content

Commit 302e3db

Browse files
committed
Handle NCCL NVLS init hangs in unwaived tests
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
1 parent d09c338 commit 302e3db

2 files changed

Lines changed: 235 additions & 20 deletions

File tree

cpp/tensorrt_llm/runtime/ncclCommunicator.cpp

Lines changed: 217 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -23,11 +23,188 @@
2323
#include <nccl.h>
2424
#endif // ENABLE_MULTI_DEVICE
2525

26+
#include <array>
27+
#include <chrono>
28+
#include <cstdlib>
29+
#include <string>
30+
#include <thread>
31+
2632
using namespace tensorrt_llm::runtime;
2733

2834
namespace
2935
{
3036
#if ENABLE_MULTI_DEVICE
37+
constexpr int kDefaultNcclCommInitTimeoutMs = 60'000;
38+
constexpr int kNcclCommInitPollIntervalMs = 20;
39+
constexpr char const* kNcclCommInitTimeoutEnv = "TLLM_NCCL_COMM_INIT_TIMEOUT_MS";
40+
constexpr char const* kNcclNvlsEnableEnv = "NCCL_NVLS_ENABLE";
41+
42+
struct NcclInitResult
43+
{
44+
ncclComm_t comm{nullptr};
45+
ncclResult_t result{ncclSuccess};
46+
bool timedOut{false};
47+
48+
[[nodiscard]] bool isSuccess() const
49+
{
50+
return result == ncclSuccess;
51+
}
52+
};
53+
54+
struct NcclInitStatus
55+
{
56+
bool failed{false};
57+
bool timedOut{false};
58+
};
59+
60+
int getNcclCommInitTimeoutMs()
61+
{
62+
auto const* env = std::getenv(kNcclCommInitTimeoutEnv);
63+
int const timeoutMs = env == nullptr ? 0 : std::atoi(env);
64+
return timeoutMs > 0 ? timeoutMs : kDefaultNcclCommInitTimeoutMs;
65+
}
66+
67+
bool canSuggestNvlsDisable()
68+
{
69+
auto const* nvlsEnable = std::getenv(kNcclNvlsEnableEnv);
70+
return nvlsEnable == nullptr || std::string{nvlsEnable} == "2";
71+
}
72+
73+
void setRuntimeConnectIfUnset()
74+
{
75+
// Need static connection initialization for accurate KV cache size estimation.
76+
#if defined(_WIN32)
77+
if (getenv("NCCL_RUNTIME_CONNECT") == nullptr)
78+
{
79+
_putenv_s("NCCL_RUNTIME_CONNECT", "0");
80+
}
81+
#else
82+
setenv("NCCL_RUNTIME_CONNECT", "0", 0);
83+
#endif // _WIN32
84+
}
85+
86+
void abortNcclComm(ncclComm_t comm)
87+
{
88+
if (comm == nullptr)
89+
{
90+
return;
91+
}
92+
93+
auto const result = ncclCommAbort(comm);
94+
if (result != ncclSuccess)
95+
{
96+
TLLM_LOG_WARNING("Failed to abort NCCL communicator: %s.", ncclGetErrorString(result));
97+
}
98+
}
99+
100+
NcclInitResult initNcclCommWithTimeout(ncclUniqueId const& id, int worldSize, int rank, int timeoutMs)
101+
{
102+
NcclInitResult initResult;
103+
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
104+
config.blocking = 0;
105+
106+
auto result = ncclCommInitRankConfig(&initResult.comm, worldSize, id, rank, &config);
107+
if (result != ncclSuccess && result != ncclInProgress)
108+
{
109+
initResult.result = result;
110+
return initResult;
111+
}
112+
if (result == ncclSuccess)
113+
{
114+
initResult.result = ncclSuccess;
115+
return initResult;
116+
}
117+
if (initResult.comm == nullptr)
118+
{
119+
initResult.result = result;
120+
return initResult;
121+
}
122+
123+
auto const deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds{timeoutMs};
124+
while (true)
125+
{
126+
ncclResult_t asyncResult = ncclSuccess;
127+
result = ncclCommGetAsyncError(initResult.comm, &asyncResult);
128+
if (result != ncclSuccess)
129+
{
130+
initResult.result = result;
131+
return initResult;
132+
}
133+
if (asyncResult != ncclInProgress)
134+
{
135+
initResult.result = asyncResult;
136+
return initResult;
137+
}
138+
if (std::chrono::steady_clock::now() >= deadline)
139+
{
140+
initResult.result = ncclInProgress;
141+
initResult.timedOut = true;
142+
return initResult;
143+
}
144+
std::this_thread::sleep_for(std::chrono::milliseconds{kNcclCommInitPollIntervalMs});
145+
}
146+
}
147+
148+
NcclInitStatus getNcclInitStatus(NcclInitResult const& result, tensorrt_llm::mpi::MpiComm const& mpiComm)
149+
{
150+
std::array<int, 2> localStatus{result.isSuccess() ? 0 : 1, result.timedOut ? 1 : 0};
151+
std::array<int, 2> globalStatus{};
152+
mpiComm.allreduce(
153+
localStatus.data(), globalStatus.data(), 2, tensorrt_llm::mpi::MpiType::kINT32, tensorrt_llm::mpi::MpiOp::MAX);
154+
return {globalStatus[0] != 0, globalStatus[1] != 0};
155+
}
156+
157+
bool allRanksCanUseNvlsDisableWorkaround(tensorrt_llm::mpi::MpiComm const& mpiComm)
158+
{
159+
int const localCanDisable = canSuggestNvlsDisable() ? 1 : 0;
160+
int globalCanDisable = 0;
161+
mpiComm.allreduce(
162+
&localCanDisable, &globalCanDisable, 1, tensorrt_llm::mpi::MpiType::kINT32, tensorrt_llm::mpi::MpiOp::MIN);
163+
164+
return globalCanDisable != 0;
165+
}
166+
167+
void checkNcclResult(ncclComm_t comm, ncclResult_t result, char const* operation)
168+
{
169+
if (result == ncclSuccess)
170+
{
171+
return;
172+
}
173+
if (result != ncclInProgress)
174+
{
175+
TLLM_NCCL_CHECK(result);
176+
}
177+
178+
while (true)
179+
{
180+
ncclResult_t asyncResult = ncclSuccess;
181+
result = ncclCommGetAsyncError(comm, &asyncResult);
182+
if (result != ncclSuccess)
183+
{
184+
TLLM_THROW("NCCL %s failed while polling communicator status: %s.", operation, ncclGetErrorString(result));
185+
}
186+
if (asyncResult == ncclSuccess)
187+
{
188+
return;
189+
}
190+
if (asyncResult != ncclInProgress)
191+
{
192+
TLLM_THROW("NCCL %s failed asynchronously: %s.", operation, ncclGetErrorString(asyncResult));
193+
}
194+
std::this_thread::sleep_for(std::chrono::milliseconds{kNcclCommInitPollIntervalMs});
195+
}
196+
}
197+
198+
ncclUniqueId createAndBroadcastNcclId(int rank, tensorrt_llm::mpi::MpiComm const& mpiComm)
199+
{
200+
ncclUniqueId id;
201+
if (rank == 0)
202+
{
203+
TLLM_NCCL_CHECK(ncclGetUniqueId(&id));
204+
}
205+
mpiComm.bcastValue(id, 0);
206+
return id;
207+
}
31208

32209
ncclDataType_t toNcclType(nvinfer1::DataType dataType)
33210
{
@@ -53,7 +230,7 @@ void NcclCommunicator::send(
53230
void const* sendbuff, size_t count, nvinfer1::DataType dataType, int peer, CudaStream const& stream) const
54231
{
55232
#if ENABLE_MULTI_DEVICE
56-
TLLM_NCCL_CHECK(ncclSend(sendbuff, count, toNcclType(dataType), peer, mComm, stream.get()));
233+
checkNcclResult(mComm, ncclSend(sendbuff, count, toNcclType(dataType), peer, mComm, stream.get()), "send");
57234
#else
58235
TLLM_THROW("Multi device support is disabled.");
59236
#endif // ENABLE_MULTI_DEVICE
@@ -63,7 +240,7 @@ void NcclCommunicator::receive(
63240
void* sendbuff, size_t count, nvinfer1::DataType dataType, int peer, CudaStream const& stream) const
64241
{
65242
#if ENABLE_MULTI_DEVICE
66-
TLLM_NCCL_CHECK(ncclRecv(sendbuff, count, toNcclType(dataType), peer, mComm, stream.get()));
243+
checkNcclResult(mComm, ncclRecv(sendbuff, count, toNcclType(dataType), peer, mComm, stream.get()), "receive");
67244
#else
68245
TLLM_THROW("Multi device support is disabled.");
69246
#endif // ENABLE_MULTI_DEVICE
@@ -73,22 +250,45 @@ ncclComm_t NcclCommunicator::createComm(int worldSize, int rank, mpi::MpiComm co
73250
{
74251
#if ENABLE_MULTI_DEVICE
75252

76-
ncclUniqueId id;
77-
if (rank == 0)
253+
setRuntimeConnectIfUnset();
254+
auto const timeoutMs = getNcclCommInitTimeoutMs();
255+
256+
auto id = createAndBroadcastNcclId(rank, mpiComm);
257+
auto initResult = initNcclCommWithTimeout(id, worldSize, rank, timeoutMs);
258+
auto const initStatus = getNcclInitStatus(initResult, mpiComm);
259+
260+
if (initStatus.failed)
78261
{
79-
ncclGetUniqueId(&id);
262+
if (initStatus.timedOut)
263+
{
264+
if (allRanksCanUseNvlsDisableWorkaround(mpiComm))
265+
{
266+
TLLM_THROW(
267+
"NCCL communicator initialization timed out after %d ms on at least one rank. This may indicate "
268+
"an NVLS multicast resource setup failure in Fabric Manager. Set NCCL_NVLS_ENABLE=0 before "
269+
"process startup and retry. TensorRT-LLM does not retry in-process because NCCL may not recover "
270+
"after a timed-out NVLS initialization.",
271+
timeoutMs);
272+
}
273+
TLLM_THROW(
274+
"NCCL communicator initialization timed out after %d ms on at least one rank. NCCL_NVLS_ENABLE is "
275+
"explicitly set, so TensorRT-LLM will not override it.",
276+
timeoutMs);
277+
}
278+
if (!initResult.isSuccess())
279+
{
280+
abortNcclComm(initResult.comm);
281+
}
282+
mpiComm.barrier();
283+
if (!initResult.isSuccess())
284+
{
285+
TLLM_THROW(
286+
"NCCL communicator initialization failed on rank %d: %s.", rank, ncclGetErrorString(initResult.result));
287+
}
288+
TLLM_THROW("NCCL communicator initialization failed on at least one peer rank.");
80289
}
81-
mpiComm.bcastValue(id, 0);
82-
ncclComm_t comm;
83-
// Need static connection initialization for accurate KV cache size estimation
84-
#if defined(_WIN32)
85-
if (getenv("NCCL_RUNTIME_CONNECT") == nullptr)
86-
_putenv_s("NCCL_RUNTIME_CONNECT", "0");
87-
#else
88-
setenv("NCCL_RUNTIME_CONNECT", "0", 0);
89-
#endif // _WIN32
90-
TLLM_NCCL_CHECK(ncclCommInitRank(&comm, worldSize, id, rank));
91-
return comm;
290+
291+
return initResult.comm;
92292
#else
93293
// Python runtime requires instantiation of a communicator even though it may never be used to enable
94294
// pipeline parallel code-path. To enable this, have an empty communicator with uninitialized state.

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
JsonModeEval, LlmapiAccuracyTestHarness,
4949
LongBenchV1, LongBenchV2)
5050

51+
_NCCL_NVLS_DISABLED_ENV = {"NCCL_NVLS_ENABLE": "0"}
52+
5153

5254
# Keep helper definitions below imports so new imports do not need E402
5355
# suppressions in this legacy test file.
@@ -75,6 +77,11 @@ def patched_start_mpi_pool(self):
7577
patched_start_mpi_pool)
7678

7779

80+
def disable_nccl_nvls_for_test(mocker):
81+
mocker.patch.dict(os.environ, _NCCL_NVLS_DISABLED_ENV)
82+
patch_mpi_pool_session_for_env(mocker, _NCCL_NVLS_DISABLED_ENV)
83+
84+
7885
def _get_default_torch_compile_config(torch_compile):
7986
return TorchCompileConfig(enable_fullgraph=True,
8087
enable_piecewise_cuda_graph=True,
@@ -1698,7 +1705,10 @@ def test_bfloat16_4gpus_kv_cache_aware_routing(self, mtp_nextn):
16981705
ids=["tp4", "ep4", "tp2pp2", "pp4"])
16991706
def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
17001707
attention_dp, cuda_graph, overlap_scheduler,
1701-
torch_compile):
1708+
torch_compile, mocker):
1709+
if pp_size > 1:
1710+
disable_nccl_nvls_for_test(mocker)
1711+
17021712
if pp_size > 1 and mtp_nextn > 0:
17031713
num_hidden_layers = 30
17041714
pp_partition = [num_hidden_layers // pp_size + 1] * pp_size
@@ -1954,7 +1964,10 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
19541964
def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
19551965
fp8kv, attention_dp, cuda_graph,
19561966
overlap_scheduler, torch_compile,
1957-
sampler_async_worker):
1967+
sampler_async_worker, mocker):
1968+
if pp_size > 1:
1969+
disable_nccl_nvls_for_test(mocker)
1970+
19581971
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
19591972
torch_compile_config = _get_default_torch_compile_config(torch_compile)
19601973
pytorch_config = dict(
@@ -2390,12 +2403,14 @@ def test_nvfp4_batch_waiting(self, torch_compile, fp8kv, cuda_graph,
23902403
def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph,
23912404
overlap_scheduler, low_precision_combine, tp_size,
23922405
pp_size, ep_size, torch_compile, mtp_nextn,
2393-
moe_backend):
2406+
moe_backend, mocker):
23942407
sm_version = get_sm_version()
23952408
if moe_backend == "TRTLLM" and sm_version in (120, 121):
23962409
pytest.skip(f"{moe_backend} backend does not support SM 120 or 121")
23972410
if moe_backend == "CUTEDSL" and sm_version not in (100, 103):
23982411
pytest.skip(f"{moe_backend} backend supports SM 100 and 103 only")
2412+
if pp_size > 1:
2413+
disable_nccl_nvls_for_test(mocker)
23992414

24002415
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
24012416
# Picewise Cuda Graph cannot be enabled for nvfp4 attention dp.

0 commit comments

Comments
 (0)