diff --git a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h index 80a06ea5ddf7..dcd828ee8933 100644 --- a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h +++ b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h @@ -285,7 +285,6 @@ class CacheTransceiver : public BaseCacheTransceiver // request while a C++ status check still dereferences it. std::vector, std::future>> mSenderFutures; std::vector, std::future>> mRequesterFutures; - // Dedup sets so observe-only timeout WARN logs fire at most once per stuck request. std::unordered_set mTimedOutSenderIds; std::unordered_set mTimedOutRequesterIds; std::unordered_set mCompletedSenderRequestIds; @@ -294,6 +293,10 @@ class CacheTransceiver : public BaseCacheTransceiver std::unordered_set mCompletedRequesterRequestIds; std::unordered_set mFailedRequesterRequestIds; std::unordered_map> mRequesterRequestsAwaitingConsensus; + // checkGenTransferStatus bounded-poll metrics, logged on budget exhaustion. + size_t mGenPollWaitedCalls{0}; + size_t mGenPollWaitedIterationsTotal{0}; + size_t mGenPollBudgetExhaustedCount{0}; mpi::MpiComm const* mMpiWorldComm{nullptr}; std::shared_ptr mGroupComm; diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp index 18c2b19e9eeb..e05b43bdbfd4 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp @@ -53,8 +53,10 @@ #include "tensorrt_llm/runtime/utils/mpiUtils.h" #include "tensorrt_llm/runtime/utils/pgUtils.h" #include +#include #include #include +#include #include #include @@ -508,6 +510,11 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa CacheTransceiver::~CacheTransceiver() { + // Join sender/receiver background threads while mManager (declared later, + // destroyed earlier) is still alive; otherwise the sender's response + // thread can dereference the destroyed manager and segfault on teardown. + mCacheSender.reset(); + mCacheReceiver.reset(); if (mWrapperLibHandle) { std::lock_guard lock(mDllMutex); @@ -866,123 +873,74 @@ RequestStatuses CacheTransceiver::checkContextTransferStatus( void CacheTransceiver::checkGenTransferStatus(std::optional const& atLeastRequestNum) { - bool blockAll = !atLeastRequestNum.has_value(); - std::vector genTransferReadyRequestIds; - for (auto&& [request, future] : mRequesterFutures) - { - if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) - { - genTransferReadyRequestIds.push_back(request->mRequestId); - } - } - std::unordered_map frequencyMap; + // Bounded poll instead of unbounded future::get(): a stalled transfer must + // not wedge this rank while sibling ranks advance to the next collective. + // Cross-rank agreement on the committed outcome is handled by + // reduceTransferStates below, so the poll itself is purely local -- only + // futures that are already ready are completed; the rest are left in flight + // for later calls or kv-transfer-timeout handling. + static constexpr int kMaxPollIterations = 32; + static constexpr auto kPollInterval = std::chrono::milliseconds(2); - std::vector toBlockRequestIds; auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm; - if ((syncComm) && syncComm->getSize() > 1) - { - auto gatherRequestIdVec = gatherRequestIds(syncComm, genTransferReadyRequestIds); - for (auto&& requestId : gatherRequestIdVec) - { - frequencyMap[requestId]++; - } - } - else - { - for (auto&& requestId : genTransferReadyRequestIds) - { - frequencyMap[requestId]++; - } - } - - std::vector> freqVec(frequencyMap.begin(), frequencyMap.end()); - - std::sort(freqVec.begin(), freqVec.end(), - [](std::pair const& left, - std::pair const& right) { return left.second > right.second; }); - std::unordered_set toCompleteIdSet; - size_t idx = 0; - while (atLeastRequestNum.value_or(0) > static_cast(toCompleteIdSet.size())) - { - if (idx >= freqVec.size()) - { - break; - } - toCompleteIdSet.insert(freqVec.at(idx).first); - if (useMPI()) - { - TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), - " checkGenTransferStatus at least from freqVec requestId: %zu ", freqVec.at(idx).first); - } - else - { - TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(), - " checkGenTransferStatus at least from freqVec requestId: %zu ", freqVec.at(idx).first); - } - idx++; - } - idx = 0; + int const dbgRank = useMPI() ? mpi::MpiComm::world().getRank() : tensorrt_llm::pg_utils::get_world_pg()->getRank(); - // insert order - while (atLeastRequestNum.value_or(0) > static_cast(toCompleteIdSet.size())) + auto countReadyFutures = [this]() { - if (idx >= mRequesterFutures.size()) + int ready = 0; + for (auto&& [request, future] : mRequesterFutures) { - break; - } - if (toCompleteIdSet.find(mRequesterFutures.at(idx).first->mRequestId) == toCompleteIdSet.end()) - { - toCompleteIdSet.insert(mRequesterFutures.at(idx).first->mRequestId); - if (useMPI()) - { - TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), - " checkGenTransferStatus at least from RequesterFuture requestId: %zu atLeastRequestNum:%d", - mRequesterFutures.at(idx).first->mRequestId, atLeastRequestNum.value_or(0)); - } - else + if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) { - TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(), - " checkGenTransferStatus at least from RequesterFuture requestId: %zu atLeastRequestNum:%d", - mRequesterFutures.at(idx).first->mRequestId, atLeastRequestNum.value_or(0)); + ++ready; } } - idx++; - } - for (auto&& [requestId, freq] : freqVec) + return ready; + }; + + // atLeastRequestNum asks the caller to make progress on at least that many + // transfers; wait for them with a capped budget instead of blocking. A + // missing value keeps the legacy "drain everything" intent (bounded by the + // poll budget) so callers polling until checkGenTransferComplete() converge. + int const target = std::min(atLeastRequestNum.value_or(static_cast(mRequesterFutures.size())), + static_cast(mRequesterFutures.size())); + int pollIterations = 0; + int readyCount = countReadyFutures(); + while (readyCount < target && pollIterations < kMaxPollIterations - 1) { - if (freq == ((syncComm != nullptr) ? syncComm->getSize() : 1)) - { - toCompleteIdSet.insert(requestId); - } - if (useMPI()) - { - TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " checkGenTransferStatus freqVec requestId: %zu,freq:%d ", - requestId, freq); - } - else - { - TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(), - " checkGenTransferStatus freqVec requestId: %zu,freq:%d ", requestId, freq); - } + ++pollIterations; + std::this_thread::sleep_for(kPollInterval); + readyCount = countReadyFutures(); } - if (useMPI()) + + if (pollIterations > 0) { - TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), - " checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(), - atLeastRequestNum.value_or(0)); + ++mGenPollWaitedCalls; + mGenPollWaitedIterationsTotal += pollIterations; } - else + if (readyCount < target) { - TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(), - " checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(), - atLeastRequestNum.value_or(0)); + ++mGenPollBudgetExhaustedCount; + TLLM_LOG_WARNING( + "checkGenTransferStatus: poll budget exhausted (%d iterations of %ld ms): %d/%d transfers ready, " + "%zu in flight; leaving the rest for later checks or kv-transfer-timeout handling " + "(budget exhausted %zu times, waited in %zu calls, %zu poll iterations in total).", + kMaxPollIterations, static_cast(kPollInterval.count()), readyCount, target, mRequesterFutures.size(), + mGenPollBudgetExhaustedCount, mGenPollWaitedCalls, mGenPollWaitedIterationsTotal); } + TLLM_LOG_DEBUG(dbgRank, " checkGenTransferStatus readyCount: %d, atLeastRequestNum: %d, pollIterations: %d ", + readyCount, atLeastRequestNum.value_or(0), pollIterations); + // Observe-only: gen-side mirror of the context-side timeout WARN. std::optional kvTransferTimeoutMs = std::nullopt; if (mCacheTransceiverConfig.has_value()) { kvTransferTimeoutMs = mCacheTransceiverConfig->getKvTransferTimeoutMs(); } + // Complete only the futures that are already ready (get() returns + // immediately); record the local outcome so reduceTransferStates can commit + // it once every rank agrees. A not-yet-ready future is never blocked on and + // is revisited on the next call. for (auto it = mRequesterFutures.begin(); it != mRequesterFutures.end();) { auto& request = it->first; @@ -992,53 +950,41 @@ void CacheTransceiver::checkGenTransferStatus(std::optional const& atLeastR auto elapsed = std::chrono::duration_cast( LlmRequest::getSteadyClockNow() - request->getKvCacheTransferStart()); auto elapsedMs = static_cast(elapsed.count()); - if (elapsedMs > kvTransferTimeoutMs.value() && mTimedOutRequesterIds.insert(request->mRequestId).second) + if (elapsedMs > kvTransferTimeoutMs.value() && mTimedOutRequesterIds.insert(requestId).second) { TLLM_LOG_WARNING( "Generation KV cache transfer for request %ld exceeded configured timeout: " "elapsed %ld ms > limit %d ms (observe-only).", - request->mRequestId, elapsedMs, kvTransferTimeoutMs.value()); + requestId, elapsedMs, kvTransferTimeoutMs.value()); } } - if (blockAll || toCompleteIdSet.find(requestId) != toCompleteIdSet.end()) + if (it->second.wait_for(std::chrono::milliseconds(0)) != std::future_status::ready) { - try - { - it->second.get(); - bool const failed = request->getState() == LlmRequestState::kDISAGG_TRANS_ERROR; - if (failed) - { - // The receiver uses the error state as a local transfer-failed signal. - // Keep that signal local until the consensus outcome commits it globally. - request->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS); - } - recordLocalTransferOutcome(requestId, request, failed, mCompletedRequesterRequestIds, - mFailedRequesterRequestIds, mRequesterRequestsAwaitingConsensus); - } - catch (std::exception const& e) - { - TLLM_LOG_ERROR("Error occurred during generation transfer for request %ld: %s", requestId, e.what()); - recordLocalTransferOutcome(requestId, request, /*failed=*/true, mCompletedRequesterRequestIds, - mFailedRequesterRequestIds, mRequesterRequestsAwaitingConsensus); - } - if (useMPI()) - { - TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), - "**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***", requestId, - request->getContextPhaseParams().value().getReqId()); - } - else + ++it; + continue; + } + try + { + it->second.get(); + bool const failed = request->getState() == LlmRequestState::kDISAGG_TRANS_ERROR; + if (failed) { - TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(), - "**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***", requestId, - request->getContextPhaseParams().value().getReqId()); + // The receiver uses the error state as a local transfer-failed signal. + // Keep that signal local until the consensus outcome commits it globally. + request->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS); } - it = mRequesterFutures.erase(it); + recordLocalTransferOutcome(requestId, request, failed, mCompletedRequesterRequestIds, + mFailedRequesterRequestIds, mRequesterRequestsAwaitingConsensus); } - else + catch (std::exception const& e) { - ++it; + TLLM_LOG_ERROR("Error occurred during generation transfer for request %ld: %s", requestId, e.what()); + recordLocalTransferOutcome(requestId, request, /*failed=*/true, mCompletedRequesterRequestIds, + mFailedRequesterRequestIds, mRequesterRequestsAwaitingConsensus); } + TLLM_LOG_DEBUG(dbgRank, "**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***", + requestId, request->getContextPhaseParams().value().getReqId()); + it = mRequesterFutures.erase(it); } auto const consensusOutcome diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 853866687d53..a73f4cefabf0 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -361,7 +361,9 @@ class CacheSender::Impl auto const* connection = isAgent ? agentConnectionManager->recvConnectionAndRequestInfo(info, mTerminate) : mManager->recvConnect(DataContext{TransceiverTag::kID_TAG, mTerminate}, &id, sizeof(id)); - if (connection == nullptr && !mManager->isRunning()) + // A null connection only happens on shutdown paths (terminate flag or + // manager stopping); bail out before touching the empty RequestInfo. + if (connection == nullptr) { TLLM_LOG_WARNING(" recvRequestInfo connection is nullptr, maybe the server is terminating"); return info; @@ -674,6 +676,12 @@ class CacheSender::Impl } it = getCurrentResponse(); } + // Terminating while waiting leaves it == end(); bail out + // instead of dereferencing it inside sendResponse. + if (mTerminate || it == mReadyResponses.end()) + { + break; + } sendResponse(it); } } diff --git a/jenkins/scripts/perf/local/submit.py b/jenkins/scripts/perf/local/submit.py index 97527eb9a8c4..5de370f68484 100755 --- a/jenkins/scripts/perf/local/submit.py +++ b/jenkins/scripts/perf/local/submit.py @@ -820,7 +820,7 @@ def main(): if is_b200: ucx_tls_cmd = "export UCX_TLS=^ib &&" else: - ucx_tls_cmd = "unset UCX_TLS &&" + ucx_tls_cmd = "unset UCX_TLS UCX_NET_DEVICES &&" script_prefix_lines.extend( [ f'export CTX_WORKER_ENV_VARS="{ctx_worker_env_vars}"', diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 3d2c0b8980b0..639746d5c98d 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -703,6 +703,7 @@ def __init__( self.is_cuda_graph_dummy = False self.py_kv_transfer_start_time = None self.py_kv_transfer_timed_out = False + self.py_kv_transfer_cancelled = False # Performance timing info (step metrics, GPU events, context GPU timing) # Lazily created only when return_perf_metrics is enabled to avoid diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 1b624f72d88b..a007e78c8bd0 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -4044,8 +4044,98 @@ def _check_disagg_gen_transfer_status(self): at_least_num = 1 if need_check_one else 0 self._check_disagg_gen_cache_transfer_status(at_least_num) + self._cancel_timed_out_gen_transfers() + return + @nvtx_range("_cancel_timed_out_gen_transfers") + def _cancel_timed_out_gen_transfers(self) -> None: + """Cancel generation-side transfers that exceeded the transfer timeout. + + Runs on every rank each iteration. The cancel decision and the + resulting request-state change must land on the same iteration on + every rank, otherwise per-rank KV-block release / scheduling diverges + and the next collective deadlocks. Timeout flags are refreshed here so + the path also works in loops (e.g. PP) that do not call + _check_kv_transfer_timeout for generation requests. + """ + timeout_ms = self.kv_cache_transceiver.kv_transfer_timeout_ms + if timeout_ms is None: + return + in_progress = { + req.py_request_id: req + for req in self.active_requests + if req.is_disagg_generation_transmission_in_progress + } + # Refresh generation-side timeout flags (mirrors the generation branch + # of _check_kv_transfer_timeout) so cancellation does not depend on the + # caller having flagged them first. + current_time = time.time() + for req in in_progress.values(): + if req.py_kv_transfer_start_time is None: + continue + elapsed_ms = (current_time - req.py_kv_transfer_start_time) * 1000 + if elapsed_ms > timeout_ms and not req.py_kv_transfer_timed_out: + logger.warning( + f"Terminating generation request {req.py_request_id} due to KV cache transfer timeout" + ) + req.py_kv_transfer_timed_out = True + user_canceled = set(self.canceled_req_ids) + local_timed_out = sorted( + rid for rid, req in in_progress.items() + if req.py_kv_transfer_timed_out and rid not in user_canceled) + + if self.dist.tp_size > 1: + any_timed_out = self.dist.tp_allreduce(int(bool(local_timed_out)), + op=ReduceOp.MAX) + else: + any_timed_out = int(bool(local_timed_out)) + if not any_timed_out: + return + + if self.dist.tp_size > 1: + gathered = self.dist.tp_allgather(local_timed_out) + global_timed_out = sorted(set().union(*gathered)) + else: + global_timed_out = local_timed_out + + local_ok = [] + for rid in global_timed_out: + req = in_progress.get(rid) + if req is None: + # Not tracked on this rank + local_ok.append(True) + continue + if not req.py_kv_transfer_cancelled and \ + self.kv_cache_transceiver.cancel_request(req): + req.py_kv_transfer_cancelled = True + local_ok.append(req.py_kv_transfer_cancelled) + + if self.dist.tp_size > 1: + all_ok = self.dist.tp_allgather(local_ok) + global_ok = [ + all(ok[i] for ok in all_ok) + for i in range(len(global_timed_out)) + ] + else: + global_ok = local_ok + + for rid, ok in zip(global_timed_out, global_ok, strict=True): + req = in_progress.get(rid) + if not ok or req is None: + # Cancellation incomplete somewhere (e.g. transfer mid-flight, + # which the transceiver cannot abort yet); retry next iteration. + continue + logger.warning( + f"Cancelled generation request {rid} after KV cache transfer timeout" + ) + req.py_kv_transfer_start_time = None + req.state = LlmRequestState.DISAGG_TRANS_ERROR + self._handle_errors( + "Error in kv cache transfer for generation requests", + requests=self._get_disagg_reqs_in_error_state(), + charge_budget=False) + @nvtx_range("_check_kv_transfer_timeout") def _check_kv_transfer_timeout(self): if not self.kv_cache_transceiver: @@ -5007,16 +5097,6 @@ def _handle_responses(self, emit_first_iter: bool = True): requests_to_terminate.append(request) continue - # Check if generation request needs cleanup due to KV cache transfer timeout - if request.py_kv_transfer_timed_out: - is_cancelled = self.kv_cache_transceiver.cancel_request(request) - if is_cancelled: - self._handle_errors( - error_msg=f"Request {request.py_request_id} timed out", - requests=[request], - charge_budget=False) - continue - if request.is_generation_only_request() and not request.is_finished: # If request is in transmission, so we don't need to emit a response # Also, for the first iteration with overlap, we should skip since first