Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ class CacheTransceiver : public BaseCacheTransceiver
// request while a C++ status check still dereferences it.
std::vector<std::pair<std::shared_ptr<LlmRequest>, std::future<void>>> mSenderFutures;
std::vector<std::pair<std::shared_ptr<LlmRequest>, std::future<void>>> mRequesterFutures;
// Dedup sets so observe-only timeout WARN logs fire at most once per stuck request.
std::unordered_set<LlmRequest::RequestIdType> mTimedOutSenderIds;
std::unordered_set<LlmRequest::RequestIdType> mTimedOutRequesterIds;
std::unordered_set<LlmRequest::RequestIdType> mCompletedSenderRequestIds;
Expand All @@ -294,6 +293,10 @@ class CacheTransceiver : public BaseCacheTransceiver
std::unordered_set<LlmRequest::RequestIdType> mCompletedRequesterRequestIds;
std::unordered_set<LlmRequest::RequestIdType> mFailedRequesterRequestIds;
std::unordered_map<LlmRequest::RequestIdType, std::shared_ptr<LlmRequest>> mRequesterRequestsAwaitingConsensus;
// checkGenTransferStatus bounded-poll metrics, logged on budget exhaustion.
size_t mGenPollWaitedCalls{0};
size_t mGenPollWaitedIterationsTotal{0};
size_t mGenPollBudgetExhaustedCount{0};
mpi::MpiComm const* mMpiWorldComm{nullptr};

std::shared_ptr<CacheTransceiverComm> mGroupComm;
Expand Down
210 changes: 78 additions & 132 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/utils/pgUtils.h"
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <numeric>
#include <thread>
#include <unordered_map>
#include <unordered_set>

Expand Down Expand Up @@ -508,6 +510,11 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa

CacheTransceiver::~CacheTransceiver()
{
// Join sender/receiver background threads while mManager (declared later,
// destroyed earlier) is still alive; otherwise the sender's response
// thread can dereference the destroyed manager and segfault on teardown.
mCacheSender.reset();
mCacheReceiver.reset();
if (mWrapperLibHandle)
{
std::lock_guard<std::mutex> lock(mDllMutex);
Expand Down Expand Up @@ -866,123 +873,74 @@ RequestStatuses CacheTransceiver::checkContextTransferStatus(

void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)
{
bool blockAll = !atLeastRequestNum.has_value();
std::vector<LlmRequest::RequestIdType> genTransferReadyRequestIds;
for (auto&& [request, future] : mRequesterFutures)
{
if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready)
{
genTransferReadyRequestIds.push_back(request->mRequestId);
}
}
std::unordered_map<LlmRequest::RequestIdType, int> frequencyMap;
// Bounded poll instead of unbounded future::get(): a stalled transfer must
// not wedge this rank while sibling ranks advance to the next collective.
// Cross-rank agreement on the committed outcome is handled by
// reduceTransferStates below, so the poll itself is purely local -- only
// futures that are already ready are completed; the rest are left in flight
// for later calls or kv-transfer-timeout handling.
static constexpr int kMaxPollIterations = 32;
static constexpr auto kPollInterval = std::chrono::milliseconds(2);

std::vector<LlmRequest::RequestIdType> toBlockRequestIds;
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
if ((syncComm) && syncComm->getSize() > 1)
{
auto gatherRequestIdVec = gatherRequestIds(syncComm, genTransferReadyRequestIds);
for (auto&& requestId : gatherRequestIdVec)
{
frequencyMap[requestId]++;
}
}
else
{
for (auto&& requestId : genTransferReadyRequestIds)
{
frequencyMap[requestId]++;
}
}

std::vector<std::pair<LlmRequest::RequestIdType, int>> freqVec(frequencyMap.begin(), frequencyMap.end());

std::sort(freqVec.begin(), freqVec.end(),
[](std::pair<LlmRequest::RequestIdType, int> const& left,
std::pair<LlmRequest::RequestIdType, int> const& right) { return left.second > right.second; });
std::unordered_set<LlmRequest::RequestIdType> toCompleteIdSet;
size_t idx = 0;
while (atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()))
{
if (idx >= freqVec.size())
{
break;
}
toCompleteIdSet.insert(freqVec.at(idx).first);
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
" checkGenTransferStatus at least from freqVec requestId: %zu ", freqVec.at(idx).first);
}
else
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
" checkGenTransferStatus at least from freqVec requestId: %zu ", freqVec.at(idx).first);
}
idx++;
}
idx = 0;
int const dbgRank = useMPI() ? mpi::MpiComm::world().getRank() : tensorrt_llm::pg_utils::get_world_pg()->getRank();

// insert order
while (atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()))
auto countReadyFutures = [this]()
{
if (idx >= mRequesterFutures.size())
int ready = 0;
for (auto&& [request, future] : mRequesterFutures)
{
break;
}
if (toCompleteIdSet.find(mRequesterFutures.at(idx).first->mRequestId) == toCompleteIdSet.end())
{
toCompleteIdSet.insert(mRequesterFutures.at(idx).first->mRequestId);
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
" checkGenTransferStatus at least from RequesterFuture requestId: %zu atLeastRequestNum:%d",
mRequesterFutures.at(idx).first->mRequestId, atLeastRequestNum.value_or(0));
}
else
if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready)
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
" checkGenTransferStatus at least from RequesterFuture requestId: %zu atLeastRequestNum:%d",
mRequesterFutures.at(idx).first->mRequestId, atLeastRequestNum.value_or(0));
++ready;
}
}
idx++;
}
for (auto&& [requestId, freq] : freqVec)
return ready;
};

// atLeastRequestNum is the caller's request to make progress on at least that
// many transfers; wait for them with a capped budget instead of blocking. A
// missing value keeps the legacy "drain everything" intent (bounded by the
// poll budget) so callers polling until checkGenTransferComplete() converge.
int const target = std::min(atLeastRequestNum.value_or(static_cast<int>(mRequesterFutures.size())),
static_cast<int>(mRequesterFutures.size()));
int pollIterations = 0;
int readyCount = countReadyFutures();
while (readyCount < target && pollIterations < kMaxPollIterations - 1)
{
if (freq == ((syncComm != nullptr) ? syncComm->getSize() : 1))
{
toCompleteIdSet.insert(requestId);
}
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " checkGenTransferStatus freqVec requestId: %zu,freq:%d ",
requestId, freq);
}
else
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
" checkGenTransferStatus freqVec requestId: %zu,freq:%d ", requestId, freq);
}
++pollIterations;
std::this_thread::sleep_for(kPollInterval);
readyCount = countReadyFutures();
}
if (useMPI())

if (pollIterations > 0)
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
" checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
atLeastRequestNum.value_or(0));
++mGenPollWaitedCalls;
mGenPollWaitedIterationsTotal += pollIterations;
}
else
if (readyCount < target)
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
" checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
atLeastRequestNum.value_or(0));
++mGenPollBudgetExhaustedCount;
TLLM_LOG_WARNING(
"checkGenTransferStatus: poll budget exhausted (%d iterations of %ld ms): %d/%d transfers ready, "
"%zu in flight; leaving the rest for later checks or kv-transfer-timeout handling "
"(budget exhausted %zu times, waited in %zu calls, %zu poll iterations in total).",
kMaxPollIterations, static_cast<long>(kPollInterval.count()), readyCount, target, mRequesterFutures.size(),
mGenPollBudgetExhaustedCount, mGenPollWaitedCalls, mGenPollWaitedIterationsTotal);
}
TLLM_LOG_DEBUG(dbgRank, " checkGenTransferStatus readyCount: %d, atLeastRequestNum: %d, pollIterations: %d ",
readyCount, atLeastRequestNum.value_or(0), pollIterations);

// Observe-only: gen-side mirror of the context-side timeout WARN.
std::optional<int> kvTransferTimeoutMs = std::nullopt;
if (mCacheTransceiverConfig.has_value())
{
kvTransferTimeoutMs = mCacheTransceiverConfig->getKvTransferTimeoutMs();
}
// Complete only the futures that are already ready (get() returns
// immediately); record the local outcome so reduceTransferStates can commit
// it once every rank agrees. A not-yet-ready future is never blocked on and
// is revisited on the next call.
for (auto it = mRequesterFutures.begin(); it != mRequesterFutures.end();)
{
auto& request = it->first;
Expand All @@ -992,53 +950,41 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
LlmRequest::getSteadyClockNow() - request->getKvCacheTransferStart());
auto elapsedMs = static_cast<long>(elapsed.count());
if (elapsedMs > kvTransferTimeoutMs.value() && mTimedOutRequesterIds.insert(request->mRequestId).second)
if (elapsedMs > kvTransferTimeoutMs.value() && mTimedOutRequesterIds.insert(requestId).second)
{
TLLM_LOG_WARNING(
"Generation KV cache transfer for request %ld exceeded configured timeout: "
"elapsed %ld ms > limit %d ms (observe-only).",
request->mRequestId, elapsedMs, kvTransferTimeoutMs.value());
requestId, elapsedMs, kvTransferTimeoutMs.value());
}
}
if (blockAll || toCompleteIdSet.find(requestId) != toCompleteIdSet.end())
if (it->second.wait_for(std::chrono::milliseconds(0)) != std::future_status::ready)
{
try
{
it->second.get();
bool const failed = request->getState() == LlmRequestState::kDISAGG_TRANS_ERROR;
if (failed)
{
// The receiver uses the error state as a local transfer-failed signal.
// Keep that signal local until the consensus outcome commits it globally.
request->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS);
}
recordLocalTransferOutcome(requestId, request, failed, mCompletedRequesterRequestIds,
mFailedRequesterRequestIds, mRequesterRequestsAwaitingConsensus);
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR("Error occurred during generation transfer for request %ld: %s", requestId, e.what());
recordLocalTransferOutcome(requestId, request, /*failed=*/true, mCompletedRequesterRequestIds,
mFailedRequesterRequestIds, mRequesterRequestsAwaitingConsensus);
}
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***", requestId,
request->getContextPhaseParams().value().getReqId());
}
else
++it;
continue;
}
try
{
it->second.get();
bool const failed = request->getState() == LlmRequestState::kDISAGG_TRANS_ERROR;
if (failed)
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***", requestId,
request->getContextPhaseParams().value().getReqId());
// The receiver uses the error state as a local transfer-failed signal.
// Keep that signal local until the consensus outcome commits it globally.
request->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS);
}
it = mRequesterFutures.erase(it);
recordLocalTransferOutcome(requestId, request, failed, mCompletedRequesterRequestIds,
mFailedRequesterRequestIds, mRequesterRequestsAwaitingConsensus);
}
else
catch (std::exception const& e)
{
++it;
TLLM_LOG_ERROR("Error occurred during generation transfer for request %ld: %s", requestId, e.what());
recordLocalTransferOutcome(requestId, request, /*failed=*/true, mCompletedRequesterRequestIds,
mFailedRequesterRequestIds, mRequesterRequestsAwaitingConsensus);
}
TLLM_LOG_DEBUG(dbgRank, "**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
requestId, request->getContextPhaseParams().value().getReqId());
it = mRequesterFutures.erase(it);
}

auto const consensusOutcome
Expand Down
10 changes: 9 additions & 1 deletion cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,9 @@ class CacheSender::Impl
auto const* connection = isAgent
? agentConnectionManager->recvConnectionAndRequestInfo(info, mTerminate)
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG, mTerminate}, &id, sizeof(id));
if (connection == nullptr && !mManager->isRunning())
// A null connection only happens on shutdown paths (terminate flag or
// manager stopping); bail out before touching the empty RequestInfo.
if (connection == nullptr)
{
TLLM_LOG_WARNING(" recvRequestInfo connection is nullptr, maybe the server is terminating");
return info;
Expand Down Expand Up @@ -674,6 +676,12 @@ class CacheSender::Impl
}
it = getCurrentResponse();
}
// Terminating while waiting leaves it == end(); bail out
// instead of dereferencing it inside sendResponse.
if (mTerminate || it == mReadyResponses.end())
{
break;
}
sendResponse(it);
}
}
Expand Down
2 changes: 1 addition & 1 deletion jenkins/scripts/perf/local/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,7 @@ def main():
if is_b200:
ucx_tls_cmd = "export UCX_TLS=^ib &&"
else:
ucx_tls_cmd = "unset UCX_TLS &&"
ucx_tls_cmd = "unset UCX_TLS UCX_NET_DEVICES &&"
script_prefix_lines.extend(
[
f'export CTX_WORKER_ENV_VARS="{ctx_worker_env_vars}"',
Expand Down
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ def __init__(
self.is_cuda_graph_dummy = False
self.py_kv_transfer_start_time = None
self.py_kv_transfer_timed_out = False
self.py_kv_transfer_cancelled = False

# Performance timing info (step metrics, GPU events, context GPU timing)
# Lazily created only when return_perf_metrics is enabled to avoid
Expand Down
Loading
Loading