Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,11 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa

CacheTransceiver::~CacheTransceiver()
{
// Stop sender/receiver workers while the connection manager and transfer
// plugin are still alive. The workers can access both during termination.
mCacheSender.reset();
mCacheReceiver.reset();

if (mWrapperLibHandle)
{
std::lock_guard<std::mutex> lock(mDllMutex);
Expand Down
25 changes: 16 additions & 9 deletions cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -33,6 +33,7 @@
#include <future>
#include <map>
#include <memory>
#include <optional>
#include <unordered_map>

namespace tensorrt_llm::batch_manager
Expand Down Expand Up @@ -351,7 +352,7 @@ class CacheSender::Impl
mRequestToSession.erase(it);
}

[[nodiscard]] RequestInfo recvRequestInfo()
[[nodiscard]] std::optional<RequestInfo> recvRequestInfo()
{
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
bool isAgent = agentConnectionManager != nullptr;
Expand All @@ -361,10 +362,10 @@ class CacheSender::Impl
auto const* connection = isAgent
? agentConnectionManager->recvConnectionAndRequestInfo(info, mTerminate)
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG, mTerminate}, &id, sizeof(id));
if (connection == nullptr && !mManager->isRunning())
if (connection == nullptr)
{
TLLM_LOG_WARNING(" recvRequestInfo connection is nullptr, maybe the server is terminating");
return info;
TLLM_LOG_WARNING("recvRequestInfo connection is nullptr, maybe the server is terminating");
return std::nullopt;
}

if (!isAgent)
Expand Down Expand Up @@ -639,12 +640,12 @@ class CacheSender::Impl
}
if (!mReadyResponses.empty())
{
auto const& requestInfo = recvRequestInfo();
if (mTerminate || !mManager->isRunning())
auto requestInfo = recvRequestInfo();
if (!requestInfo.has_value() || mTerminate || !mManager->isRunning())
{
return;
}
auto reqId = requestInfo.getRequestId();
auto reqId = requestInfo->getRequestId();

{
std::scoped_lock lk(mSenderMutex);
Expand Down Expand Up @@ -674,6 +675,10 @@ class CacheSender::Impl
}
it = getCurrentResponse();
}
if (mTerminate || it == mReadyResponses.end())
{
break;
}
sendResponse(it);
}
}
Expand Down Expand Up @@ -1288,7 +1293,9 @@ void CacheSender::sendSync(LlmRequest const& llmRequest)

// Public wrapper around Impl::recvRequestInfo() that always yields a concrete RequestInfo.
//
// The implementation returns std::optional<RequestInfo> and produces an empty optional
// when no connection could be obtained (it logs that the server may be terminating).
// This wrapper keeps the by-value RequestInfo return type for existing callers, so an
// empty result is treated as a violated precondition and rejected via TLLM_CHECK
// before the value is unwrapped.
RequestInfo CacheSender::recvRequestInfo()
{
    auto requestInfo = mImpl->recvRequestInfo();
    // Never dereference an empty optional: verify a value is present first.
    TLLM_CHECK(requestInfo.has_value());
    return *requestInfo;
}

bool CacheSender::cancelRequest(LlmRequest const& llmRequest)
Expand Down
Loading