11/*
2- * Copyright (c) 2022-2026 , NVIDIA CORPORATION. All rights reserved.
2+ * Copyright (c) 2022-2024 , NVIDIA CORPORATION. All rights reserved.
33 *
44 * Licensed under the Apache License, Version 2.0 (the "License");
55 * you may not use this file except in compliance with the License.
2323#include < nccl.h>
2424#endif // ENABLE_MULTI_DEVICE
2525
26- #include < array>
27- #include < chrono>
28- #include < cstdlib>
29- #include < string>
30- #include < thread>
31-
3226using namespace tensorrt_llm ::runtime;
3327
3428namespace
3529{
3630#if ENABLE_MULTI_DEVICE
37- constexpr int kDefaultNcclCommInitTimeoutMs = 60'000 ;
38- constexpr int kNcclCommInitPollIntervalMs = 20 ;
39- constexpr char const * kNcclCommInitTimeoutEnv = " TLLM_NCCL_COMM_INIT_TIMEOUT_MS" ;
40- constexpr char const * kNcclNvlsEnableEnv = " NCCL_NVLS_ENABLE" ;
41-
42- struct NcclInitResult
43- {
44- ncclComm_t comm{nullptr };
45- ncclResult_t result{ncclSuccess};
46- bool timedOut{false };
47-
48- [[nodiscard]] bool isSuccess () const
49- {
50- return result == ncclSuccess;
51- }
52- };
53-
54- struct NcclInitStatus
55- {
56- bool failed{false };
57- bool timedOut{false };
58- };
59-
60- int getNcclCommInitTimeoutMs ()
61- {
62- auto const * env = std::getenv (kNcclCommInitTimeoutEnv );
63- int const timeoutMs = env == nullptr ? 0 : std::atoi (env);
64- return timeoutMs > 0 ? timeoutMs : kDefaultNcclCommInitTimeoutMs ;
65- }
66-
67- bool canSuggestNvlsDisable ()
68- {
69- auto const * nvlsEnable = std::getenv (kNcclNvlsEnableEnv );
70- return nvlsEnable == nullptr || std::string{nvlsEnable} == " 2" ;
71- }
72-
73- void setRuntimeConnectIfUnset ()
74- {
75- // Need static connection initialization for accurate KV cache size estimation.
76- #if defined(_WIN32)
77- if (getenv (" NCCL_RUNTIME_CONNECT" ) == nullptr )
78- {
79- _putenv_s (" NCCL_RUNTIME_CONNECT" , " 0" );
80- }
81- #else
82- setenv (" NCCL_RUNTIME_CONNECT" , " 0" , 0 );
83- #endif // _WIN32
84- }
85-
86- void abortNcclComm (ncclComm_t comm)
87- {
88- if (comm == nullptr )
89- {
90- return ;
91- }
92-
93- auto const result = ncclCommAbort (comm);
94- if (result != ncclSuccess)
95- {
96- TLLM_LOG_WARNING (" Failed to abort NCCL communicator: %s." , ncclGetErrorString (result));
97- }
98- }
99-
100- NcclInitResult initNcclCommWithTimeout (ncclUniqueId const & id, int worldSize, int rank, int timeoutMs)
101- {
102- NcclInitResult initResult;
103- ncclConfig_t config = NCCL_CONFIG_INITIALIZER ;
104- config.blocking = 0 ;
105-
106- auto result = ncclCommInitRankConfig (&initResult.comm , worldSize, id, rank, &config);
107- if (result != ncclSuccess && result != ncclInProgress)
108- {
109- initResult.result = result;
110- return initResult;
111- }
112- if (result == ncclSuccess)
113- {
114- initResult.result = ncclSuccess;
115- return initResult;
116- }
117- if (initResult.comm == nullptr )
118- {
119- initResult.result = result;
120- return initResult;
121- }
122-
123- auto const deadline = std::chrono::steady_clock::now () + std::chrono::milliseconds{timeoutMs};
124- while (true )
125- {
126- ncclResult_t asyncResult = ncclSuccess;
127- result = ncclCommGetAsyncError (initResult.comm , &asyncResult);
128- if (result != ncclSuccess)
129- {
130- initResult.result = result;
131- return initResult;
132- }
133- if (asyncResult != ncclInProgress)
134- {
135- initResult.result = asyncResult;
136- return initResult;
137- }
138- if (std::chrono::steady_clock::now () >= deadline)
139- {
140- initResult.result = ncclInProgress;
141- initResult.timedOut = true ;
142- return initResult;
143- }
144- std::this_thread::sleep_for (std::chrono::milliseconds{kNcclCommInitPollIntervalMs });
145- }
146- }
147-
148- NcclInitStatus getNcclInitStatus (NcclInitResult const & result, tensorrt_llm::mpi::MpiComm const & mpiComm)
149- {
150- std::array<int , 2 > localStatus{result.isSuccess () ? 0 : 1 , result.timedOut ? 1 : 0 };
151- std::array<int , 2 > globalStatus{};
152- mpiComm.allreduce (
153- localStatus.data (), globalStatus.data (), 2 , tensorrt_llm::mpi::MpiType::kINT32 , tensorrt_llm::mpi::MpiOp::MAX );
154- return {globalStatus[0 ] != 0 , globalStatus[1 ] != 0 };
155- }
156-
157- bool allRanksCanUseNvlsDisableWorkaround (tensorrt_llm::mpi::MpiComm const & mpiComm)
158- {
159- int const localCanDisable = canSuggestNvlsDisable () ? 1 : 0 ;
160- int globalCanDisable = 0 ;
161- mpiComm.allreduce (
162- &localCanDisable, &globalCanDisable, 1 , tensorrt_llm::mpi::MpiType::kINT32 , tensorrt_llm::mpi::MpiOp::MIN );
163-
164- return globalCanDisable != 0 ;
165- }
166-
167- void checkNcclResult (ncclComm_t comm, ncclResult_t result, char const * operation)
168- {
169- if (result == ncclSuccess)
170- {
171- return ;
172- }
173- if (result != ncclInProgress)
174- {
175- TLLM_NCCL_CHECK (result);
176- }
177-
178- while (true )
179- {
180- ncclResult_t asyncResult = ncclSuccess;
181- result = ncclCommGetAsyncError (comm, &asyncResult);
182- if (result != ncclSuccess)
183- {
184- TLLM_THROW (" NCCL %s failed while polling communicator status: %s." , operation, ncclGetErrorString (result));
185- }
186- if (asyncResult == ncclSuccess)
187- {
188- return ;
189- }
190- if (asyncResult != ncclInProgress)
191- {
192- TLLM_THROW (" NCCL %s failed asynchronously: %s." , operation, ncclGetErrorString (asyncResult));
193- }
194- std::this_thread::sleep_for (std::chrono::milliseconds{kNcclCommInitPollIntervalMs });
195- }
196- }
197-
198- ncclUniqueId createAndBroadcastNcclId (int rank, tensorrt_llm::mpi::MpiComm const & mpiComm)
199- {
200- ncclUniqueId id;
201- if (rank == 0 )
202- {
203- TLLM_NCCL_CHECK (ncclGetUniqueId (&id));
204- }
205- mpiComm.bcastValue (id, 0 );
206- return id;
207- }
20831
20932ncclDataType_t toNcclType (nvinfer1::DataType dataType)
21033{
@@ -230,7 +53,7 @@ void NcclCommunicator::send(
23053 void const * sendbuff, size_t count, nvinfer1::DataType dataType, int peer, CudaStream const & stream) const
23154{
23255#if ENABLE_MULTI_DEVICE
233- checkNcclResult ( mComm , ncclSend (sendbuff, count, toNcclType (dataType), peer, mComm , stream.get ()), " send " );
56+ TLLM_NCCL_CHECK ( ncclSend (sendbuff, count, toNcclType (dataType), peer, mComm , stream.get ()));
23457#else
23558 TLLM_THROW (" Multi device support is disabled." );
23659#endif // ENABLE_MULTI_DEVICE
@@ -240,7 +63,7 @@ void NcclCommunicator::receive(
24063 void * sendbuff, size_t count, nvinfer1::DataType dataType, int peer, CudaStream const & stream) const
24164{
24265#if ENABLE_MULTI_DEVICE
243- checkNcclResult ( mComm , ncclRecv (sendbuff, count, toNcclType (dataType), peer, mComm , stream.get ()), " receive " );
66+ TLLM_NCCL_CHECK ( ncclRecv (sendbuff, count, toNcclType (dataType), peer, mComm , stream.get ()));
24467#else
24568 TLLM_THROW (" Multi device support is disabled." );
24669#endif // ENABLE_MULTI_DEVICE
@@ -250,45 +73,22 @@ ncclComm_t NcclCommunicator::createComm(int worldSize, int rank, mpi::MpiComm co
25073{
25174#if ENABLE_MULTI_DEVICE
25275
253- setRuntimeConnectIfUnset ();
254- auto const timeoutMs = getNcclCommInitTimeoutMs ();
255-
256- auto id = createAndBroadcastNcclId (rank, mpiComm);
257- auto initResult = initNcclCommWithTimeout (id, worldSize, rank, timeoutMs);
258- auto const initStatus = getNcclInitStatus (initResult, mpiComm);
259-
260- if (initStatus.failed )
76+ ncclUniqueId id;
77+ if (rank == 0 )
26178 {
262- if (initStatus.timedOut )
263- {
264- if (allRanksCanUseNvlsDisableWorkaround (mpiComm))
265- {
266- TLLM_THROW (
267- " NCCL communicator initialization timed out after %d ms on at least one rank. This may indicate "
268- " an NVLS multicast resource setup failure in Fabric Manager. Set NCCL_NVLS_ENABLE=0 before "
269- " process startup and retry. TensorRT-LLM does not retry in-process because NCCL may not recover "
270- " after a timed-out NVLS initialization." ,
271- timeoutMs);
272- }
273- TLLM_THROW (
274- " NCCL communicator initialization timed out after %d ms on at least one rank. NCCL_NVLS_ENABLE is "
275- " explicitly set, so TensorRT-LLM will not override it." ,
276- timeoutMs);
277- }
278- if (!initResult.isSuccess ())
279- {
280- abortNcclComm (initResult.comm );
281- }
282- mpiComm.barrier ();
283- if (!initResult.isSuccess ())
284- {
285- TLLM_THROW (
286- " NCCL communicator initialization failed on rank %d: %s." , rank, ncclGetErrorString (initResult.result ));
287- }
288- TLLM_THROW (" NCCL communicator initialization failed on at least one peer rank." );
79+ ncclGetUniqueId (&id);
28980 }
290-
291- return initResult.comm ;
81+ mpiComm.bcastValue (id, 0 );
82+ ncclComm_t comm;
83+ // Need static connection initialization for accurate KV cache size estimation
84+ #if defined(_WIN32)
85+ if (getenv (" NCCL_RUNTIME_CONNECT" ) == nullptr )
86+ _putenv_s (" NCCL_RUNTIME_CONNECT" , " 0" );
87+ #else
88+ setenv (" NCCL_RUNTIME_CONNECT" , " 0" , 0 );
89+ #endif // _WIN32
90+ TLLM_NCCL_CHECK (ncclCommInitRank (&comm, worldSize, id, rank));
91+ return comm;
29292#else
29393 // Python runtime requires instantiation of a communicator even though it may never be used to enable
29494 // pipeline parallel code-path. To enable this, have an empty communicator with uninitialized state.
0 commit comments