diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp index 18c2b19e9eeb..f8058d5e036e 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp @@ -508,6 +508,11 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa CacheTransceiver::~CacheTransceiver() { + // Stop sender/receiver workers while the connection manager and transfer + // plugin are still alive. The workers can access both during termination. + mCacheSender.reset(); + mCacheReceiver.reset(); + if (mWrapperLibHandle) { std::lock_guard lock(mDllMutex); diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index 853866687d53..e24ab62b4ff1 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -33,6 +33,7 @@ #include #include #include +#include #include namespace tensorrt_llm::batch_manager @@ -351,7 +352,7 @@ class CacheSender::Impl mRequestToSession.erase(it); } - [[nodiscard]] RequestInfo recvRequestInfo() + [[nodiscard]] std::optional recvRequestInfo() { auto* agentConnectionManager = dynamic_cast(mManager); bool isAgent = agentConnectionManager != nullptr; @@ -361,10 +362,10 @@ class CacheSender::Impl auto const* connection = isAgent ? agentConnectionManager->recvConnectionAndRequestInfo(info, mTerminate) : mManager->recvConnect(DataContext{TransceiverTag::kID_TAG, mTerminate}, &id, sizeof(id)); - if (connection == nullptr && !mManager->isRunning()) + if (connection == nullptr) { - TLLM_LOG_WARNING(" recvRequestInfo connection is nullptr, maybe the server is terminating"); - return info; + TLLM_LOG_WARNING("recvRequestInfo connection is nullptr, maybe the server is terminating"); + return std::nullopt; } if (!isAgent) @@ -639,12 +640,12 @@ class CacheSender::Impl } if (!mReadyResponses.empty()) { - auto const& requestInfo = recvRequestInfo(); - if (mTerminate || !mManager->isRunning()) + auto requestInfo = recvRequestInfo(); + if (!requestInfo.has_value() || mTerminate || !mManager->isRunning()) { return; } - auto reqId = requestInfo.getRequestId(); + auto reqId = requestInfo->getRequestId(); { std::scoped_lock lk(mSenderMutex); @@ -674,6 +675,10 @@ class CacheSender::Impl } it = getCurrentResponse(); } + if (mTerminate || it == mReadyResponses.end()) + { + break; + } sendResponse(it); } } @@ -1288,7 +1293,9 @@ void CacheSender::sendSync(LlmRequest const& llmRequest) RequestInfo CacheSender::recvRequestInfo() { - return mImpl->recvRequestInfo(); + auto requestInfo = mImpl->recvRequestInfo(); + TLLM_CHECK(requestInfo.has_value()); + return *requestInfo; } bool CacheSender::cancelRequest(LlmRequest const& llmRequest)