Skip to content

Commit beb922f

Browse files
authored
[None][fix] Fix encoder-decoder beam search corruption via per-slot fragmentPointerDevice (#15461)
Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
1 parent 93feb57 commit beb922f

6 files changed

Lines changed: 185 additions & 11 deletions

File tree

cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2+
* Copyright (c) 2023-2026, NVIDIA CORPORATION. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -208,7 +208,9 @@ class RuntimeBuffers
208208
//! Temporarily store the transposed results of multiple fragment logits, [maxBeamWidth, kCACHE_LENGTH]
209209
TensorPtr transposedLogits;
210210

211-
//! Temporarily store logits buffer address during the transposing, [kCACHE_LENGTH]
211+
//! Temporarily store logits buffer address during the transposing, [maxBatchSize, kCACHE_LENGTH]
212+
//! One row per batch slot (same layout as fragmentPointerHost) so concurrent flushes for
213+
//! different requests in the same batch never clobber each other's pointer arrays.
212214
TensorPtr fragmentPointerDevice;
213215

214216
//! Temporarily store logits buffer address during the transposing, [maxBatchSize, kCACHE_LENGTH]
@@ -222,11 +224,14 @@ class RuntimeBuffers
222224
workIdx = (workIdx + 1) % (fragmentPointerHost->getShape().d[0]);
223225
}
224226

225-
[[nodiscard]] TensorPtr getFragmentPointerHost()
227+
//! Returns matching host and device pointer rows for the current workIdx, then advances
228+
//! workIdx. Always use this rather than slicing fragmentPointerHost and fragmentPointerDevice
//! separately, so that both rows are always taken at the same index.
229+
[[nodiscard]] std::pair<TensorPtr, TensorPtr> getFragmentPointerSlot()
226230
{
227-
TensorPtr slice = runtime::ITensor::slice(fragmentPointerHost, workIdx, 1);
231+
TensorPtr host = runtime::ITensor::slice(fragmentPointerHost, workIdx, 1);
232+
TensorPtr device = runtime::ITensor::slice(fragmentPointerDevice, workIdx, 1);
228233
cycleWorkIdx();
229-
return slice;
234+
return {std::move(host), std::move(device)};
230235
};
231236
};
232237

cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
* SPDX-License-Identifier: Apache-2.0
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -152,8 +152,8 @@ void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
152152
ITensor::makeShape({GenerationLogitsCache::kCACHE_LENGTH, maxBatchSize * maxBeamWidth, vocabSizePadded}),
153153
logitsType);
154154

155-
generationLogitsCache.fragmentPointerDevice
156-
= manager.gpu(ITensor::makeShape({GenerationLogitsCache::kCACHE_LENGTH}), nvinfer1::DataType::kINT64);
155+
generationLogitsCache.fragmentPointerDevice = manager.gpu(
156+
ITensor::makeShape({maxBatchSize, GenerationLogitsCache::kCACHE_LENGTH}), nvinfer1::DataType::kINT64);
157157
generationLogitsCache.fragmentPointerHost = tensorrt_llm::runtime::BufferManager::pinnedPool(
158158
ITensor::makeShape({maxBatchSize, GenerationLogitsCache::kCACHE_LENGTH}), nvinfer1::DataType::kINT64);
159159
}

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,8 +1216,21 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
12161216
{
12171217
for (auto const& llmReq : activeRequests)
12181218
{
1219+
// Remove from mInflightReqIds so changeBeamWidth can proceed on the next iteration.
1220+
// terminateRequest frees seqSlot/KV cache but does not clean up mInflightReqIds.
1221+
mInflightReqIds.erase(llmReq->mRequestId);
12191222
terminateRequest(llmReq);
12201223
}
1224+
// Force buffer/decoder reset to clean up any partial state from the aborted batch
1225+
// (e.g. partially-filled cross-KV block offsets from mid-context-chunk processing).
1226+
// Guard on mInflightReqIds.empty(): in pipeline-parallel multi-micro-batch mode,
1227+
// other micro-batches may still have requests tracked here; changeBeamWidth asserts
1228+
// emptiness so we skip the reset and let the next successful forwardAsync iteration
1229+
// perform it when the set is clear.
1230+
if (mWorldConfig.isLastPipelineParallelRank() && mInflightReqIds.empty())
1231+
{
1232+
changeBeamWidth(mOperatingBeamWidth);
1233+
}
12211234
}
12221235
catch (std::exception const& e)
12231236
{

cpp/tensorrt_llm/batch_manager/utils/inflightBatchingUtils.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,11 @@ void copyGenerationLogits(RuntimeBuffers::GenerationLogitsCache& generationLogit
103103

104104
auto const fragmentSize = llmReq.getGenerationLogitsFragmentsSize();
105105

106-
// Merge logits fragments on device
106+
// Merge logits fragments on device. getFragmentPointerSlot() returns the matching host and
107+
// device rows for the current workIdx and advances the index in the same call, so concurrent flushes
108+
// for different requests in the same batch never clobber each other's pointer arrays.
107109
auto const& transposeBufferPtr = generationLogitsCache.transposedLogits;
108-
auto const& cachePointerDevice = generationLogitsCache.fragmentPointerDevice;
109-
auto const& cachePointerHost = generationLogitsCache.getFragmentPointerHost();
110+
auto [cachePointerHost, cachePointerDevice] = generationLogitsCache.getFragmentPointerSlot();
110111
tensorrt_llm::runtime::kernels::mergeLogitsFragments(bufferManager, *transposeBufferPtr,
111112
llmReq.getGenerationLogitsFragments(), *cachePointerDevice, *cachePointerHost, 0, 1, reqBeamWidth,
112113
bufferManager.getStream(), 0);

cpp/tests/unit_tests/batch_manager/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ add_gtest(rnnCacheFormatterTest rnnCacheFormatterTest.cpp)
3131
add_gtest(cudaGraphExecutorCacheTest cudaGraphExecutorCacheTest.cpp)
3232
add_gtest(agentTreeTest agentTreeTest.cpp)
3333
add_gtest(truncateBlocksTest truncateBlocksTest.cpp)
34+
add_gtest(encDecBeamSearchTest encDecBeamSearchTest.cpp)
cpp/tests/unit_tests/batch_manager/encDecBeamSearchTest.cpp

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
19+
#include "tensorrt_llm/batch_manager/llmRequest.h"
20+
#include "tensorrt_llm/batch_manager/runtimeBuffers.h"
21+
#include "tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h"
22+
#include "tensorrt_llm/common/memoryUtils.h"
23+
#include "tensorrt_llm/kernels/kvCacheIndex.h"
24+
#include "tensorrt_llm/runtime/bufferManager.h"
25+
#include "tensorrt_llm/runtime/cudaStream.h"
26+
#include "tensorrt_llm/runtime/iTensor.h"
27+
#include "tensorrt_llm/runtime/samplingConfig.h"
28+
#include "gtest/gtest.h"
29+
#include <memory>
30+
31+
using namespace tensorrt_llm::batch_manager;
32+
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
33+
namespace tr = tensorrt_llm::runtime;
34+
namespace tc = tensorrt_llm::common;
35+
namespace tk = tensorrt_llm::kernels;
36+
using SizeType32 = tr::SizeType32;
37+
38+
// Verify that copyGenerationLogits correctly assembles the host logits buffer
39+
// using the real kernel merge path, and that two back-to-back calls (simulating
40+
// two requests flushing in the same batch) use distinct fragmentPointerDevice
41+
// slots so their pointer arrays do not clobber each other.
//
// Slot isolation is checked indirectly: the test never inspects workIdx or the
// pointer rows themselves; it only compares each request's host logits against
// the sentinel values that were written for that request.
// NOTE(review): both flushes are enqueued through the same BufferManager/stream.
// Confirm that this test actually fails when fragmentPointerDevice has a single
// shared row, otherwise it does not guard against a regression of the fix.
42+
TEST(CopyGenerationLogitsTest, KernelMergePathProducesCorrectHostLayoutAndSlotsAreIsolated)
43+
{
44+
SizeType32 constexpr beamWidth = 2;
45+
SizeType32 constexpr numSteps = RuntimeBuffers::GenerationLogitsCache::kCACHE_LENGTH; // full flush
46+
SizeType32 constexpr vocabSize = 8;
47+
SizeType32 constexpr promptLen = 1;
48+
SizeType32 constexpr maxBatchSize = 4; // must be >= 2 to test slot isolation
49+
50+
auto stream = std::make_shared<tr::CudaStream>();
51+
tr::BufferManager bufferMgr{stream};
52+
53+
// Build a real GenerationLogitsCache so that transposedLogits,
54+
// fragmentPointerDevice and fragmentPointerHost are all properly allocated.
55+
// cache.logits uses pinned memory so the test can fill it from the CPU while
56+
// the GPU kernel can still read from it via DMA.
// The shapes below mirror the ones used in this test only; fragmentPointerDevice
// and fragmentPointerHost both get one row of numSteps entries per batch slot.
57+
RuntimeBuffers::GenerationLogitsCache cache;
58+
cache.logits = tr::BufferManager::pinnedPool(
59+
tr::ITensor::makeShape({numSteps, maxBatchSize * beamWidth, vocabSize}), nvinfer1::DataType::kFLOAT);
60+
cache.transposedLogits
61+
= bufferMgr.gpu(tr::ITensor::makeShape({beamWidth, numSteps, vocabSize}), nvinfer1::DataType::kFLOAT);
62+
cache.fragmentPointerDevice
63+
= bufferMgr.gpu(tr::ITensor::makeShape({maxBatchSize, numSteps}), nvinfer1::DataType::kINT64);
64+
cache.fragmentPointerHost
65+
= tr::BufferManager::pinnedPool(tr::ITensor::makeShape({maxBatchSize, numSteps}), nvinfer1::DataType::kINT64);
66+
67+
// Helper: build one LlmRequest that has numSteps fragments pointing into
68+
// cache.logits[0..numSteps-1][logitsIndex:logitsIndex+beamWidth].
69+
// Each fragment is filled with sentinel value (step*100 + beam + reqOffset).
// Captures everything by reference ([&]); it is only invoked inside this test body.
70+
auto makeRequest = [&](RequestIdType reqId, SizeType32 logitsIndex, float reqOffset) -> std::shared_ptr<LlmRequest>
71+
{
// Prompt of promptLen tokens, all with id 0.
72+
auto tokens = std::make_shared<VecTokens>(promptLen, 0);
73+
tr::SamplingConfig sc{beamWidth};
// NOTE(review): the meaning of the positional arguments (numSteps, trailing `false`) is not
// visible here; presumably max new tokens and the streaming flag, so confirm against LlmRequest.
74+
auto req = std::make_shared<LlmRequest>(reqId, numSteps, tokens, sc, false);
75+
76+
// Give every beam numSteps generated tokens (all id 1) before allocating the host logits.
LlmRequest::BeamTokens gen(beamWidth, VecTokens(numSteps, 1));
77+
req->setGeneratedTokens(gen);
78+
req->allocGenerationLogitsHost(vocabSize, nvinfer1::DataType::kFLOAT);
79+
80+
// Write known values into the logits cache slots for this request and
81+
// create matching fragment slice views.
82+
for (SizeType32 step = 0; step < numSteps; ++step)
83+
{
84+
// cache.logits shape: [numSteps, maxBatchSize*beamWidth, vocabSize]
85+
// Slice to [1, maxBS*bw, vocab], squeeze to [maxBS*bw, vocab].
86+
tr::ITensor::SharedPtr slot = tr::ITensor::slice(cache.logits, step, 1);
87+
slot->squeeze(0); // [maxBS*bw, vocab]
88+
auto* slotPtr = tr::bufferCast<float>(*slot);
89+
for (SizeType32 beam = 0; beam < beamWidth; ++beam)
90+
{
91+
float const val = reqOffset + static_cast<float>(step * 100 + beam);
92+
for (SizeType32 v = 0; v < vocabSize; ++v)
93+
{
// Row-major write into row (logitsIndex + beam) of the [maxBS*bw, vocab] step slot.
94+
slotPtr[(logitsIndex + beam) * vocabSize + v] = val;
95+
}
96+
}
97+
98+
// Fragment matches HandleGenerationLogits: slice [logitsIndex:logitsIndex+beamWidth]
99+
// from the step slot, then unsqueeze(0) → [1, beamWidth, vocab].
100+
tr::ITensor::SharedPtr fragView = tr::ITensor::slice(slot, logitsIndex, beamWidth);
101+
fragView->unsqueeze(0); // [1, beamWidth, vocab]
102+
req->addGenerationLogitsFragment(fragView);
103+
}
104+
return req;
105+
};
106+
107+
// Request 0 occupies logitsIndex=0 in the batch slot.
108+
auto req0 = makeRequest(1, /*logitsIndex=*/0, /*reqOffset=*/0.0f);
109+
// Request 1 occupies logitsIndex=beamWidth in the batch slot.
110+
auto req1 = makeRequest(2, /*logitsIndex=*/beamWidth, /*reqOffset=*/1000.0f);
111+
112+
// Flush request 0 — uses workIdx=0.
113+
utils::copyGenerationLogits(cache, bufferMgr, *req0, /*beforeDecoder=*/false, {});
114+
// Flush request 1 — uses workIdx=1 (different slot → no pointer clobbering).
115+
utils::copyGenerationLogits(cache, bufferMgr, *req1, /*beforeDecoder=*/false, {});
116+
// Wait for all work enqueued on `stream` before reading the host buffers below.
// NOTE(review): presumably copyGenerationLogits only enqueues asynchronous work — confirm.
117+
ASSERT_EQ(cudaStreamSynchronize(stream->get()), cudaSuccess);
118+
119+
// Verify req0 host buffer: host[beam, step, v] == step*100 + beam
120+
auto const* host0 = tr::bufferCast<float>(*req0->getGenerationLogitsHost());
121+
for (SizeType32 beam = 0; beam < beamWidth; ++beam)
122+
{
123+
for (SizeType32 step = 0; step < numSteps; ++step)
124+
{
125+
float const expected = static_cast<float>(step * 100 + beam);
126+
for (SizeType32 v = 0; v < vocabSize; ++v)
127+
{
// Row-major index into a [beamWidth, numSteps, vocabSize] layout.
128+
SizeType32 const idx = (beam * numSteps + step) * vocabSize + v;
129+
EXPECT_FLOAT_EQ(host0[idx], expected) << "req0 host[beam=" << beam << ",step=" << step << ",v=" << v
130+
<< "]=" << host0[idx] << " expected " << expected;
131+
}
132+
}
133+
}
134+
135+
// Verify req1 host buffer: host[beam, step, v] == 1000 + step*100 + beam
136+
auto const* host1 = tr::bufferCast<float>(*req1->getGenerationLogitsHost());
137+
for (SizeType32 beam = 0; beam < beamWidth; ++beam)
138+
{
139+
for (SizeType32 step = 0; step < numSteps; ++step)
140+
{
141+
float const expected = 1000.0f + static_cast<float>(step * 100 + beam);
142+
for (SizeType32 v = 0; v < vocabSize; ++v)
143+
{
// Same [beamWidth, numSteps, vocabSize] row-major index as for req0.
144+
SizeType32 const idx = (beam * numSteps + step) * vocabSize + v;
145+
EXPECT_FLOAT_EQ(host1[idx], expected) << "req1 host[beam=" << beam << ",step=" << step << ",v=" << v
146+
<< "]=" << host1[idx] << " expected " << expected;
147+
}
148+
}
149+
}
150+
151+
// Both requests must have had their fragments cleared.
152+
EXPECT_EQ(req0->getGenerationLogitsFragmentsSize(), 0);
153+
EXPECT_EQ(req1->getGenerationLogitsFragmentsSize(), 0);
154+
}

0 commit comments

Comments
 (0)