2626#include " tensorrt_llm/common/tllmException.h"
2727#include " tensorrt_llm/common/utils.h"
2828#include " tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
29+ #include " tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
2930#include " tensorrt_llm/runtime/common.h"
3031#include " tensorrt_llm/runtime/utils/mpiUtils.h"
3132#include < chrono>
@@ -384,9 +385,10 @@ class CacheSender::Impl
384385 auto allCounterparts = mCacheTransferLayer .computeCounterparts (
385386 mSelfState .getCommState ().value ().getSelfIdx (), info.getTransState ());
386387
387- auto peerSelfIdx = info.getTransState ().getCommState ()->getSelfIdx (); // Index of self in peer's comm state
388+ auto peerSelfIdx = info.getTransState ().getCommState ()->getSelfIdx ();
388389 int peerIdx = std::distance (
389390 allCounterparts.begin (), std::find (allCounterparts.begin (), allCounterparts.end (), peerSelfIdx));
391+
390392 TLLM_CHECK_WITH_INFO (peerIdx < static_cast <int >(allCounterparts.size ()),
391393 " Peer rank %d not found in expected counterparts" , peerSelfIdx);
392394 {
@@ -861,6 +863,19 @@ class CacheReceiver::Impl
861863 auto allCounterparts
862864 = mCacheTransferLayer .computeCounterparts (mSelfState .getCommState ().value ().getSelfIdx (), contextState);
863865
866+ auto kvCounterParts = mCacheTransferLayer .getKvFormatter ()->getCounterparts (
867+ mCacheTransferLayer .getCacheState (), mSelfState .getCommState ().value ().getSelfIdx (), destCacheState);
868+
869+ bool hasRnn = mCacheTransferLayer .getCacheState ().hasRnnConfig () && destCacheState.hasRnnConfig ();
870+
871+ std::vector<SizeType32> rnnCounterParts;
872+ if (hasRnn)
873+ {
874+ rnnCounterParts = executor::kv_cache::targetIRanksForRnn (
875+ destCacheState, mCacheTransferLayer .getCacheState (), mSelfState .getCommState ().value ().getSelfIdx ())
876+ .mIRanks ;
877+ }
878+
864879 auto connections = mManager ->getConnections (commState);
865880 std::vector<executor::kv_cache::Connection const *> allConnections;
866881 for (auto index : allCounterparts)
@@ -869,24 +884,59 @@ class CacheReceiver::Impl
869884 allConnections.emplace_back (connection);
870885 }
871886
872- for (size_t i = 0 ; i < allConnections .size (); i ++)
887+ for (size_t ci = 0 ; ci < allCounterparts .size (); ci ++)
873888 {
874- auto const * connection = allConnections[i];
875- // if Manager is agentConnectionManager, then send request info to agent
876- auto * agentConnectionManager = dynamic_cast <executor::kv_cache::AgentConnectionManager*>(mManager );
889+ auto rank = allCounterparts[ci];
890+ auto const * connection = connections.at (rank);
891+
892+ bool isKvCounterpart
893+ = std::find (kvCounterParts.begin (), kvCounterParts.end (), rank) != kvCounterParts.end ();
894+ bool isRnnCounterpart
895+ = hasRnn && std::find (rnnCounterParts.begin (), rnnCounterParts.end (), rank) != rnnCounterParts.end ();
896+
877897 if (agentConnectionManager)
878898 {
879- // TODO: index -> validConnectionIdx conversion
880- // TODO(shreyasm): this will not work for RNN. Will error out in the constructor if used with RNN.
881- auto [pickUpIdx, localRankIdx] = mCacheTransferLayer .getKvFormatter ()->pickRecvConnections (
882- allCounterparts.size (), mSelfState .getCacheState ().value (),
883- mSelfState .getCommState ().value ().getSelfIdx (), destCacheState, allCounterparts);
884- auto validConnectionIdx = std::find (localRankIdx.begin (), localRankIdx.end (), i) - localRankIdx.begin ();
899+ auto idsForRank = cacheBufferIds;
900+ auto const & managers = agentConnectionManager->getCacheTransBufferManagers ();
901+ for (size_t i = 0 ; i < idsForRank.size (); i++)
902+ {
903+ auto kind = managers[i]->getBufferKind ();
904+ bool include = (kind != BufferKind::kRNN ) ? isKvCounterpart : isRnnCounterpart;
905+ if (!include)
906+ {
907+ idsForRank[i] = std::nullopt ;
908+ }
909+ }
910+
911+ int validConnectionIdx = 0 ;
912+ if (isKvCounterpart)
913+ {
914+ auto kvCpIdx
915+ = std::find (kvCounterParts.begin (), kvCounterParts.end (), rank) - kvCounterParts.begin ();
916+ auto [pickUpIdx, localRankIdx] = mCacheTransferLayer .getKvFormatter ()->pickRecvConnections (
917+ allCounterparts.size (), mSelfState .getCacheState ().value (),
918+ mSelfState .getCommState ().value ().getSelfIdx (), destCacheState, allCounterparts);
919+ validConnectionIdx
920+ = std::find (localRankIdx.begin (), localRankIdx.end (), kvCpIdx) - localRankIdx.begin ();
921+ }
922+ else if (isRnnCounterpart)
923+ {
924+ auto rnnTargetInfo = executor::kv_cache::targetIRanksForRnn (destCacheState,
925+ mCacheTransferLayer .getCacheState (), mSelfState .getCommState ().value ().getSelfIdx ());
926+ auto rnnCpIdx
927+ = std::find (rnnCounterParts.begin (), rnnCounterParts.end (), rank) - rnnCounterParts.begin ();
928+ auto [pickUpIdx, localRankIdx] = cache_formatter_utils::pickRecvConnections (rnnCounterParts.size (),
929+ mCacheTransferLayer .getCacheState (), mSelfState .getCommState ().value ().getSelfIdx (),
930+ destCacheState, rnnCounterParts, rnnTargetInfo);
931+ validConnectionIdx
932+ = std::find (localRankIdx.begin (), localRankIdx.end (), rnnCpIdx) - localRankIdx.begin ();
933+ }
934+
885935 auto * agentConnection = dynamic_cast <executor::kv_cache::AgentConnection const *>(connection);
886936 TLLM_CHECK (agentConnection != nullptr );
887- TLLM_CHECK (!cacheBufferIds. empty ());
937+
888938 const_cast <executor::kv_cache::AgentConnection*>(agentConnection)
889- ->sendRequestAndBufferInfo (requestInfo, cacheBufferIds , validConnectionIdx);
939+ ->sendRequestAndBufferInfo (requestInfo, idsForRank , validConnectionIdx);
890940 }
891941 else
892942 {
0 commit comments