Skip to content

Commit 69de4a6

Browse files
authored
[None][feat] NIXL support for hybrid model cache transfer (#11608)
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
1 parent 3b82b6c commit 69de4a6

23 files changed

Lines changed: 412 additions & 148 deletions

File tree

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ class CacheTransceiver : public BaseCacheTransceiver
288288
std::unique_ptr<executor::kv_cache::ConnectionManager> mManager;
289289
std::optional<executor::CacheTransceiverConfig> mCacheTransceiverConfig;
290290
std::vector<std::unique_ptr<kv_cache_manager::CacheTransBufferManager>> mCacheTransBufferManagers;
291-
std::vector<kv_cache_manager::CacheTransBufferManager*> mCacheTransBufferManagerPtrs;
291+
std::vector<BaseTransBufferManager*> mCacheTransBufferManagerPtrs;
292292

293293
rnn_state_manager::RnnStateManager* mRnnStateManager{nullptr};
294294
// TODO(shreyasm): update this to use same container as kv by using base trans buffers instead

cpp/include/tensorrt_llm/executor/cacheCommunicator.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
#include "tensorrt_llm/executor/serialization.h"
2020
#include <atomic>
21+
#include <cstdint>
22+
#include <optional>
2123
#include <vector>
2224

2325
namespace tensorrt_llm::executor::kv_cache
@@ -63,6 +65,13 @@ class Connection
6365
{
6466
return false;
6567
}
68+
69+
/// @brief Hook called with a buffer kind before data of that kind is handled on this connection.
/// The argument is a batch_manager::BufferKind value passed as its uint8_t underlying type
/// (CacheTransferLayer::format passes BufferKind::kRNN right before running the RNN formatter).
/// The base implementation is intentionally a no-op.
/// NOTE(review): presumably overridden by agent-based connections to switch the active
/// pre-assigned buffer -- confirm against the AgentConnection implementation.
virtual void activateBuffer(uint8_t /*kind*/) const {}
70+
71+
/// @brief Query the buffer id that was pre-assigned on this connection for the given buffer kind.
/// The argument is a batch_manager::BufferKind value passed as its uint8_t underlying type.
/// @return std::nullopt in this base implementation, meaning no buffer was pre-assigned.
/// CacheFormatter::unformat uses has_value() on the result to choose the pre-assigned-buffer path.
/// NOTE(review): presumably overridden by AgentConnection to return its cache buffer id -- confirm.
[[nodiscard]] virtual std::optional<size_t> getPreAssignedBufferId(uint8_t /*kind*/) const
72+
{
73+
return std::nullopt;
74+
}
6675
};
6776

6877
class ConnectionManager

cpp/tensorrt_llm/batch_manager/baseTransBuffer.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <atomic>
2424
#include <condition_variable>
2525
#include <cstddef>
26+
#include <cstdint>
2627
#include <memory>
2728
#include <mutex>
2829
#include <optional>
@@ -38,6 +39,13 @@ class FabricMemory;
3839
namespace tensorrt_llm::batch_manager
3940
{
4041

42+
/// @brief Identifies which kind of cache a transfer buffer manager serves.
/// Backed by uint8_t; call sites in this change hand the value to the Connection
/// interface via static_cast<uint8_t>.
enum class BufferKind : uint8_t
43+
{
44+
// Reported by CacheTransBufferManager when it was not created for the indexer K cache.
kKV = 0,
45+
// Reported by CacheTransBufferManager when it was created with transferIndexerKCache == true.
kKV_INDEXER = 1,
46+
// Used for RNN state transfer buffers (activated before the RNN formatter runs).
// NOTE(review): presumably reported by RnnCacheTransBufferManager -- confirm.
kRNN = 2
47+
};
48+
4149
/// @brief Base class for cache transfer buffer management.
4250
/// Handles buffer pool allocation, index assignment, and slicing.
4351
/// Derived classes provide cache-specific size calculations.
@@ -46,6 +54,8 @@ class BaseTransBufferManager
4654
public:
4755
virtual ~BaseTransBufferManager() = default;
4856

57+
/// @brief Report which kind of cache this buffer manager serves.
/// Pure virtual: every derived manager must identify itself. CacheReceiver uses the result
/// to decide, per peer rank, which buffer ids to forward (kRNN vs. the KV kinds).
/// @return The BufferKind of this manager.
[[nodiscard]] virtual BufferKind getBufferKind() const = 0;
58+
4959
/// @brief Assign a buffer index for sending.
5060
/// @return Assigned buffer index, or nullopt if using dynamic buffers.
5161
std::optional<int> assignBufferIndexForSend();

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -539,9 +539,9 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
539539
"bufferCoverTargetNum:%d pickUpConnections.size():%ld",
540540
bufferTargetNum, targetNum, peerDuplicateHeadFactor, targetInfo.mDupHeadFactor, bufferCoverTargetNum,
541541
pickUpConnections.size());
542-
auto* agentConnnecion
542+
auto const* agentConnection
543543
= dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[pickUpConnections[0]]);
544-
if (agentConnnecion != nullptr)
544+
if (agentConnection != nullptr)
545545
{
546546
TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == bufferTargetNum, "Agent need all buffer pre-allocated");
547547
TLLM_CHECK(onlyUseDynamicBuffer == false);
@@ -792,12 +792,11 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
792792

793793
TLLM_CHECK(blockNum > 0);
794794

795-
auto* agentConnnecion
796-
= dynamic_cast<executor::kv_cache::AgentConnection const*>(connections[pickUpConnections[0]]);
797-
if (agentConnnecion != nullptr)
795+
auto preAssignedKvId
796+
= connections[pickUpConnections[0]]->getPreAssignedBufferId(static_cast<uint8_t>(BufferKind::kKV));
797+
if (preAssignedKvId.has_value())
798798
{
799-
cacheBufferId = agentConnnecion->getCacheBufferId();
800-
TLLM_CHECK(cacheBufferId.has_value());
799+
cacheBufferId = static_cast<int>(*preAssignedKvId);
801800
}
802801
else
803802
{
@@ -811,7 +810,7 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
811810
bufferCoverTargetNum = bufferCoverTargetNumtmp;
812811
remainNoCoverTargetNum = targetNum > bufferCoverTargetNum ? targetNum - bufferCoverTargetNum : 0;
813812

814-
if (agentConnnecion != nullptr)
813+
if (preAssignedKvId.has_value())
815814
{
816815
TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == targetNum, "Agent need buffer pre-allocated");
817816
TLLM_CHECK(onlyUseDynamicBuffer == false);

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ CacheTransBufferManager::CacheTransBufferManager(
249249
: cacheManager->getPrimaryPool(0)->getDataType(),
250250
maxNumTokens)
251251
, mCacheManager{cacheManager}
252+
, mTransferIndexerKCache{transferIndexerKCache}
252253
{
253254
// TODO: FP4 dataSize
254255
TLLM_CHECK(mCacheManager);

cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,18 @@ class CacheTransBufferManager : public BaseTransBufferManager
7474
return mCacheManager;
7575
}
7676

77+
/// @brief Report the buffer kind of this KV transfer buffer manager.
/// @return BufferKind::kKV_INDEXER when the manager was constructed with
/// transferIndexerKCache == true (stored in mTransferIndexerKCache), otherwise BufferKind::kKV.
[[nodiscard]] BufferKind getBufferKind() const override
78+
{
79+
return mTransferIndexerKCache ? BufferKind::kKV_INDEXER : BufferKind::kKV;
80+
}
81+
7782
private:
7883
/// @brief Compute transfer buffer size from KV cache configuration.
7984
static size_t computeTransferBufferSize(KVCacheManager::BaseKVCacheManager* cacheManager,
8085
std::optional<size_t> maxNumTokens, bool transferIndexerKCache);
8186

8287
KVCacheManager::BaseKVCacheManager* mCacheManager;
88+
bool mTransferIndexerKCache;
8389
};
8490

8591
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -185,26 +185,13 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
185185
mCacheTransBufferManagers.push_back(
186186
std::make_unique<kv_cache_manager::CacheTransBufferManager>(cacheManager, maxNumTokens, true));
187187
}
188-
mCacheTransBufferManagerPtrs.clear();
189-
mCacheTransBufferManagerPtrs.reserve(mCacheTransBufferManagers.size());
190-
for (auto& manager : mCacheTransBufferManagers)
191-
{
192-
mCacheTransBufferManagerPtrs.push_back(manager.get());
193-
}
194188

195189
// RNN specific setup
196190
if (mRnnStateManager != nullptr)
197191
{
198192
TLLM_LOG_DEBUG("Setting up RNN cache transfer components.");
199193
TLLM_CHECK(!rnnLayerNumPerPP.empty());
200194

201-
if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL
202-
|| backendType.value() == executor::CacheTransceiverConfig::BackendType::MOONCAKE)
203-
{
204-
TLLM_LOG_ERROR("RNN cache transfer is not supported for NIXL and MOONCAKE yet");
205-
return;
206-
}
207-
208195
mRnnCacheTransBufferManager
209196
= std::make_unique<rnn_state_manager::RnnCacheTransBufferManager>(mRnnStateManager, maxNumTokens);
210197

@@ -218,6 +205,17 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
218205
TLLM_LOG_INFO("RNN cache transfer components initialized.");
219206
}
220207

208+
mCacheTransBufferManagerPtrs.clear();
209+
mCacheTransBufferManagerPtrs.reserve(mCacheTransBufferManagers.size() + (mRnnCacheTransBufferManager ? 1 : 0));
210+
for (auto& manager : mCacheTransBufferManagers)
211+
{
212+
mCacheTransBufferManagerPtrs.push_back(manager.get());
213+
}
214+
if (mRnnCacheTransBufferManager)
215+
{
216+
mCacheTransBufferManagerPtrs.push_back(mRnnCacheTransBufferManager.get());
217+
}
218+
221219
if (backendType.value() == executor::CacheTransceiverConfig::BackendType::UCX)
222220
{
223221
std::lock_guard<std::mutex> lock(mDllMutex);
@@ -239,14 +237,18 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
239237
}
240238
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
241239
{
240+
auto rnnState
241+
= mCacheState->hasRnnConfig() ? std::make_optional(mCacheState->getRnnCacheState()) : std::nullopt;
242242
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
243-
mCacheTransBufferManagerPtrs, *mCacheState, "nixl");
243+
mCacheTransBufferManagerPtrs, *mCacheState, "nixl", rnnState);
244244
TLLM_LOG_INFO("NIXL Connection Manager created");
245245
}
246246
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MOONCAKE)
247247
{
248+
auto rnnState
249+
= mCacheState->hasRnnConfig() ? std::make_optional(mCacheState->getRnnCacheState()) : std::nullopt;
248250
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
249-
mCacheTransBufferManagerPtrs, *mCacheState, "mooncake");
251+
mCacheTransBufferManagerPtrs, *mCacheState, "mooncake", rnnState);
250252
TLLM_LOG_INFO("MOONCAKE Connection Manager created");
251253
}
252254
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
@@ -261,7 +263,15 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
261263
}
262264

263265
auto makeFormatter = [cacheManager, isMLA, this]()
264-
{ return createCacheFormatter(cacheManager, mCacheTransBufferManagerPtrs, isMLA); };
266+
{
267+
std::vector<kv_cache_manager::CacheTransBufferManager*> kvBufferPtrs;
268+
kvBufferPtrs.reserve(mCacheTransBufferManagers.size());
269+
for (auto& mgr : mCacheTransBufferManagers)
270+
{
271+
kvBufferPtrs.push_back(mgr.get());
272+
}
273+
return createCacheFormatter(cacheManager, kvBufferPtrs, isMLA);
274+
};
265275

266276
auto makeRnnFormatter = [this]() -> std::unique_ptr<RnnCacheFormatter>
267277
{

cpp/tensorrt_llm/batch_manager/cacheTransferLayer.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "tensorrt_llm/batch_manager/rnnCacheFormatter.h"
2222
#include "tensorrt_llm/common/assert.h"
2323
#include "tensorrt_llm/common/logger.h"
24+
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
2425
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
2526

2627
#include <algorithm>
@@ -95,6 +96,13 @@ void CacheTransferLayer::format(TransferSession& session) const
9596
mKvFormatter->format(session);
9697
if (mRnnFormatter)
9798
{
99+
for (auto const* conn : session.getConnections())
100+
{
101+
if (conn != nullptr)
102+
{
103+
conn->activateBuffer(static_cast<uint8_t>(BufferKind::kRNN));
104+
}
105+
}
98106
mRnnFormatter->format(session);
99107
}
100108
}

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "tensorrt_llm/common/tllmException.h"
2727
#include "tensorrt_llm/common/utils.h"
2828
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
29+
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
2930
#include "tensorrt_llm/runtime/common.h"
3031
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
3132
#include <chrono>
@@ -384,9 +385,10 @@ class CacheSender::Impl
384385
auto allCounterparts = mCacheTransferLayer.computeCounterparts(
385386
mSelfState.getCommState().value().getSelfIdx(), info.getTransState());
386387

387-
auto peerSelfIdx = info.getTransState().getCommState()->getSelfIdx(); // Index of self in peer's comm state
388+
auto peerSelfIdx = info.getTransState().getCommState()->getSelfIdx();
388389
int peerIdx = std::distance(
389390
allCounterparts.begin(), std::find(allCounterparts.begin(), allCounterparts.end(), peerSelfIdx));
391+
390392
TLLM_CHECK_WITH_INFO(peerIdx < static_cast<int>(allCounterparts.size()),
391393
"Peer rank %d not found in expected counterparts", peerSelfIdx);
392394
{
@@ -861,6 +863,19 @@ class CacheReceiver::Impl
861863
auto allCounterparts
862864
= mCacheTransferLayer.computeCounterparts(mSelfState.getCommState().value().getSelfIdx(), contextState);
863865

866+
auto kvCounterParts = mCacheTransferLayer.getKvFormatter()->getCounterparts(
867+
mCacheTransferLayer.getCacheState(), mSelfState.getCommState().value().getSelfIdx(), destCacheState);
868+
869+
bool hasRnn = mCacheTransferLayer.getCacheState().hasRnnConfig() && destCacheState.hasRnnConfig();
870+
871+
std::vector<SizeType32> rnnCounterParts;
872+
if (hasRnn)
873+
{
874+
rnnCounterParts = executor::kv_cache::targetIRanksForRnn(
875+
destCacheState, mCacheTransferLayer.getCacheState(), mSelfState.getCommState().value().getSelfIdx())
876+
.mIRanks;
877+
}
878+
864879
auto connections = mManager->getConnections(commState);
865880
std::vector<executor::kv_cache::Connection const*> allConnections;
866881
for (auto index : allCounterparts)
@@ -869,24 +884,59 @@ class CacheReceiver::Impl
869884
allConnections.emplace_back(connection);
870885
}
871886

872-
for (size_t i = 0; i < allConnections.size(); i++)
887+
for (size_t ci = 0; ci < allCounterparts.size(); ci++)
873888
{
874-
auto const* connection = allConnections[i];
875-
// if Manager is agentConnectionManager, then send request info to agent
876-
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
889+
auto rank = allCounterparts[ci];
890+
auto const* connection = connections.at(rank);
891+
892+
bool isKvCounterpart
893+
= std::find(kvCounterParts.begin(), kvCounterParts.end(), rank) != kvCounterParts.end();
894+
bool isRnnCounterpart
895+
= hasRnn && std::find(rnnCounterParts.begin(), rnnCounterParts.end(), rank) != rnnCounterParts.end();
896+
877897
if (agentConnectionManager)
878898
{
879-
// TODO: index -> validConnectionIdx conversion
880-
// TODO(shreyasm): this will not work for RNN. Will error out in the constructor if used with RNN.
881-
auto [pickUpIdx, localRankIdx] = mCacheTransferLayer.getKvFormatter()->pickRecvConnections(
882-
allCounterparts.size(), mSelfState.getCacheState().value(),
883-
mSelfState.getCommState().value().getSelfIdx(), destCacheState, allCounterparts);
884-
auto validConnectionIdx = std::find(localRankIdx.begin(), localRankIdx.end(), i) - localRankIdx.begin();
899+
auto idsForRank = cacheBufferIds;
900+
auto const& managers = agentConnectionManager->getCacheTransBufferManagers();
901+
for (size_t i = 0; i < idsForRank.size(); i++)
902+
{
903+
auto kind = managers[i]->getBufferKind();
904+
bool include = (kind != BufferKind::kRNN) ? isKvCounterpart : isRnnCounterpart;
905+
if (!include)
906+
{
907+
idsForRank[i] = std::nullopt;
908+
}
909+
}
910+
911+
int validConnectionIdx = 0;
912+
if (isKvCounterpart)
913+
{
914+
auto kvCpIdx
915+
= std::find(kvCounterParts.begin(), kvCounterParts.end(), rank) - kvCounterParts.begin();
916+
auto [pickUpIdx, localRankIdx] = mCacheTransferLayer.getKvFormatter()->pickRecvConnections(
917+
allCounterparts.size(), mSelfState.getCacheState().value(),
918+
mSelfState.getCommState().value().getSelfIdx(), destCacheState, allCounterparts);
919+
validConnectionIdx
920+
= std::find(localRankIdx.begin(), localRankIdx.end(), kvCpIdx) - localRankIdx.begin();
921+
}
922+
else if (isRnnCounterpart)
923+
{
924+
auto rnnTargetInfo = executor::kv_cache::targetIRanksForRnn(destCacheState,
925+
mCacheTransferLayer.getCacheState(), mSelfState.getCommState().value().getSelfIdx());
926+
auto rnnCpIdx
927+
= std::find(rnnCounterParts.begin(), rnnCounterParts.end(), rank) - rnnCounterParts.begin();
928+
auto [pickUpIdx, localRankIdx] = cache_formatter_utils::pickRecvConnections(rnnCounterParts.size(),
929+
mCacheTransferLayer.getCacheState(), mSelfState.getCommState().value().getSelfIdx(),
930+
destCacheState, rnnCounterParts, rnnTargetInfo);
931+
validConnectionIdx
932+
= std::find(localRankIdx.begin(), localRankIdx.end(), rnnCpIdx) - localRankIdx.begin();
933+
}
934+
885935
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection);
886936
TLLM_CHECK(agentConnection != nullptr);
887-
TLLM_CHECK(!cacheBufferIds.empty());
937+
888938
const_cast<executor::kv_cache::AgentConnection*>(agentConnection)
889-
->sendRequestAndBufferInfo(requestInfo, cacheBufferIds, validConnectionIdx);
939+
->sendRequestAndBufferInfo(requestInfo, idsForRank, validConnectionIdx);
890940
}
891941
else
892942
{

0 commit comments

Comments
 (0)