11/*
2- * Copyright (c) 2022-2024 , NVIDIA CORPORATION. All rights reserved.
2+ * Copyright (c) 2022-2026 , NVIDIA CORPORATION. All rights reserved.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
2323#include < nccl.h>
2424#endif // ENABLE_MULTI_DEVICE
2525
#include <array>
#include <chrono>
#include <cstdlib>
#include <limits>
#include <string>
#include <thread>
31+
2632using namespace tensorrt_llm ::runtime;
2733
2834namespace
2935{
3036#if ENABLE_MULTI_DEVICE
// Communicator initialization timeout, in milliseconds, used when the environment supplies no usable override.
constexpr int kDefaultNcclCommInitTimeoutMs = 60 * 1000;
// Pause, in milliseconds, between two consecutive polls of an in-progress NCCL operation.
constexpr int kNcclCommInitPollIntervalMs = 20;
// Name of the environment variable that overrides the communicator initialization timeout.
constexpr char const* kNcclCommInitTimeoutEnv = "TLLM_NCCL_COMM_INIT_TIMEOUT_MS";
// Name of the environment variable consulted before suggesting that NVLS be disabled.
constexpr char const* kNcclNvlsEnableEnv = "NCCL_NVLS_ENABLE";
41+
42+ struct NcclInitResult
43+ {
44+ ncclComm_t comm{nullptr };
45+ ncclResult_t result{ncclSuccess};
46+ bool timedOut{false };
47+
48+ [[nodiscard]] bool isSuccess () const
49+ {
50+ return result == ncclSuccess;
51+ }
52+ };
53+
// Pair of flags describing a communicator initialization outcome; both default to "no problem".
struct NcclInitStatus
{
    // Initialization did not succeed.
    bool failed = false;
    // Initialization ran into its deadline.
    bool timedOut = false;
};
59+
60+ int getNcclCommInitTimeoutMs ()
61+ {
62+ auto const * env = std::getenv (kNcclCommInitTimeoutEnv );
63+ int const timeoutMs = env == nullptr ? 0 : std::atoi (env);
64+ return timeoutMs > 0 ? timeoutMs : kDefaultNcclCommInitTimeoutMs ;
65+ }
66+
67+ bool canSuggestNvlsDisable ()
68+ {
69+ auto const * nvlsEnable = std::getenv (kNcclNvlsEnableEnv );
70+ return nvlsEnable == nullptr || std::string{nvlsEnable} == " 2" ;
71+ }
72+
// Gives NCCL_RUNTIME_CONNECT the value "0" unless the environment already defines it.
void setRuntimeConnectIfUnset()
{
    // Need static connection initialization for accurate KV cache size estimation.
#if !defined(_WIN32)
    // Third argument 0: an existing value is left untouched.
    setenv("NCCL_RUNTIME_CONNECT", "0", 0);
#else
    // _putenv_s is given no overwrite flag, so the "unset" check is explicit.
    if (getenv("NCCL_RUNTIME_CONNECT") == nullptr)
    {
        _putenv_s("NCCL_RUNTIME_CONNECT", "0");
    }
#endif // !_WIN32
}
85+
86+ void abortNcclComm (ncclComm_t comm)
87+ {
88+ if (comm == nullptr )
89+ {
90+ return ;
91+ }
92+
93+ auto const result = ncclCommAbort (comm);
94+ if (result != ncclSuccess)
95+ {
96+ TLLM_LOG_WARNING (" Failed to abort NCCL communicator: %s." , ncclGetErrorString (result));
97+ }
98+ }
99+
100+ NcclInitResult initNcclCommWithTimeout (ncclUniqueId const & id, int worldSize, int rank, int timeoutMs)
101+ {
102+ NcclInitResult initResult;
103+ ncclConfig_t config = NCCL_CONFIG_INITIALIZER ;
104+ config.blocking = 0 ;
105+
106+ auto result = ncclCommInitRankConfig (&initResult.comm , worldSize, id, rank, &config);
107+ if (result != ncclSuccess && result != ncclInProgress)
108+ {
109+ initResult.result = result;
110+ return initResult;
111+ }
112+ if (result == ncclSuccess)
113+ {
114+ initResult.result = ncclSuccess;
115+ return initResult;
116+ }
117+ if (initResult.comm == nullptr )
118+ {
119+ initResult.result = result;
120+ return initResult;
121+ }
122+
123+ auto const deadline = std::chrono::steady_clock::now () + std::chrono::milliseconds{timeoutMs};
124+ while (true )
125+ {
126+ ncclResult_t asyncResult = ncclSuccess;
127+ result = ncclCommGetAsyncError (initResult.comm , &asyncResult);
128+ if (result != ncclSuccess)
129+ {
130+ initResult.result = result;
131+ return initResult;
132+ }
133+ if (asyncResult != ncclInProgress)
134+ {
135+ initResult.result = asyncResult;
136+ return initResult;
137+ }
138+ if (std::chrono::steady_clock::now () >= deadline)
139+ {
140+ initResult.result = ncclInProgress;
141+ initResult.timedOut = true ;
142+ return initResult;
143+ }
144+ std::this_thread::sleep_for (std::chrono::milliseconds{kNcclCommInitPollIntervalMs });
145+ }
146+ }
147+
148+ NcclInitStatus getNcclInitStatus (NcclInitResult const & result, tensorrt_llm::mpi::MpiComm const & mpiComm)
149+ {
150+ std::array<int , 2 > localStatus{result.isSuccess () ? 0 : 1 , result.timedOut ? 1 : 0 };
151+ std::array<int , 2 > globalStatus{};
152+ mpiComm.allreduce (
153+ localStatus.data (), globalStatus.data (), 2 , tensorrt_llm::mpi::MpiType::kINT32 , tensorrt_llm::mpi::MpiOp::MAX );
154+ return {globalStatus[0 ] != 0 , globalStatus[1 ] != 0 };
155+ }
156+
157+ bool allRanksCanUseNvlsDisableWorkaround (tensorrt_llm::mpi::MpiComm const & mpiComm)
158+ {
159+ int const localCanDisable = canSuggestNvlsDisable () ? 1 : 0 ;
160+ int globalCanDisable = 0 ;
161+ mpiComm.allreduce (
162+ &localCanDisable, &globalCanDisable, 1 , tensorrt_llm::mpi::MpiType::kINT32 , tensorrt_llm::mpi::MpiOp::MIN );
163+
164+ return globalCanDisable != 0 ;
165+ }
166+
167+ void checkNcclResult (ncclComm_t comm, ncclResult_t result, char const * operation)
168+ {
169+ if (result == ncclSuccess)
170+ {
171+ return ;
172+ }
173+ if (result != ncclInProgress)
174+ {
175+ TLLM_NCCL_CHECK (result);
176+ }
177+
178+ while (true )
179+ {
180+ ncclResult_t asyncResult = ncclSuccess;
181+ result = ncclCommGetAsyncError (comm, &asyncResult);
182+ if (result != ncclSuccess)
183+ {
184+ TLLM_THROW (" NCCL %s failed while polling communicator status: %s." , operation, ncclGetErrorString (result));
185+ }
186+ if (asyncResult == ncclSuccess)
187+ {
188+ return ;
189+ }
190+ if (asyncResult != ncclInProgress)
191+ {
192+ TLLM_THROW (" NCCL %s failed asynchronously: %s." , operation, ncclGetErrorString (asyncResult));
193+ }
194+ std::this_thread::sleep_for (std::chrono::milliseconds{kNcclCommInitPollIntervalMs });
195+ }
196+ }
197+
198+ ncclUniqueId createAndBroadcastNcclId (int rank, tensorrt_llm::mpi::MpiComm const & mpiComm)
199+ {
200+ ncclUniqueId id;
201+ if (rank == 0 )
202+ {
203+ TLLM_NCCL_CHECK (ncclGetUniqueId (&id));
204+ }
205+ mpiComm.bcastValue (id, 0 );
206+ return id;
207+ }
31208
32209ncclDataType_t toNcclType (nvinfer1::DataType dataType)
33210{
@@ -53,7 +230,7 @@ void NcclCommunicator::send(
53230 void const * sendbuff, size_t count, nvinfer1::DataType dataType, int peer, CudaStream const & stream) const
54231{
55232#if ENABLE_MULTI_DEVICE
56- TLLM_NCCL_CHECK ( ncclSend (sendbuff, count, toNcclType (dataType), peer, mComm , stream.get ()));
233+ checkNcclResult ( mComm , ncclSend (sendbuff, count, toNcclType (dataType), peer, mComm , stream.get ()), " send " );
57234#else
58235 TLLM_THROW (" Multi device support is disabled." );
59236#endif // ENABLE_MULTI_DEVICE
@@ -63,7 +240,7 @@ void NcclCommunicator::receive(
63240 void * sendbuff, size_t count, nvinfer1::DataType dataType, int peer, CudaStream const & stream) const
64241{
65242#if ENABLE_MULTI_DEVICE
66- TLLM_NCCL_CHECK ( ncclRecv (sendbuff, count, toNcclType (dataType), peer, mComm , stream.get ()));
243+ checkNcclResult ( mComm , ncclRecv (sendbuff, count, toNcclType (dataType), peer, mComm , stream.get ()), " receive " );
67244#else
68245 TLLM_THROW (" Multi device support is disabled." );
69246#endif // ENABLE_MULTI_DEVICE
@@ -73,22 +250,45 @@ ncclComm_t NcclCommunicator::createComm(int worldSize, int rank, mpi::MpiComm co
73250{
74251#if ENABLE_MULTI_DEVICE
75252
76- ncclUniqueId id;
77- if (rank == 0 )
253+ setRuntimeConnectIfUnset ();
254+ auto const timeoutMs = getNcclCommInitTimeoutMs ();
255+
256+ auto id = createAndBroadcastNcclId (rank, mpiComm);
257+ auto initResult = initNcclCommWithTimeout (id, worldSize, rank, timeoutMs);
258+ auto const initStatus = getNcclInitStatus (initResult, mpiComm);
259+
260+ if (initStatus.failed )
78261 {
79- ncclGetUniqueId (&id);
262+ if (initStatus.timedOut )
263+ {
264+ if (allRanksCanUseNvlsDisableWorkaround (mpiComm))
265+ {
266+ TLLM_THROW (
267+ " NCCL communicator initialization timed out after %d ms on at least one rank. This may indicate "
268+ " an NVLS multicast resource setup failure in Fabric Manager. Set NCCL_NVLS_ENABLE=0 before "
269+ " process startup and retry. TensorRT-LLM does not retry in-process because NCCL may not recover "
270+ " after a timed-out NVLS initialization." ,
271+ timeoutMs);
272+ }
273+ TLLM_THROW (
274+ " NCCL communicator initialization timed out after %d ms on at least one rank. NCCL_NVLS_ENABLE is "
275+ " explicitly set, so TensorRT-LLM will not override it." ,
276+ timeoutMs);
277+ }
278+ if (!initResult.isSuccess ())
279+ {
280+ abortNcclComm (initResult.comm );
281+ }
282+ mpiComm.barrier ();
283+ if (!initResult.isSuccess ())
284+ {
285+ TLLM_THROW (
286+ " NCCL communicator initialization failed on rank %d: %s." , rank, ncclGetErrorString (initResult.result ));
287+ }
288+ TLLM_THROW (" NCCL communicator initialization failed on at least one peer rank." );
80289 }
81- mpiComm.bcastValue (id, 0 );
82- ncclComm_t comm;
83- // Need static connection initialization for accurate KV cache size estimation
84- #if defined(_WIN32)
85- if (getenv (" NCCL_RUNTIME_CONNECT" ) == nullptr )
86- _putenv_s (" NCCL_RUNTIME_CONNECT" , " 0" );
87- #else
88- setenv (" NCCL_RUNTIME_CONNECT" , " 0" , 0 );
89- #endif // _WIN32
90- TLLM_NCCL_CHECK (ncclCommInitRank (&comm, worldSize, id, rank));
91- return comm;
290+
291+ return initResult.comm ;
92292#else
93293 // Python runtime requires instantiation of a communicator even though it may never be used to enable
94294 // pipeline parallel code-path. To enable this, have an empty communicator with uninitialized state.
0 commit comments