|
| 1 | +/* |
| 2 | + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + * |
| 5 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | + * you may not use this file except in compliance with the License. |
| 7 | + * You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +#include "tensorrt_llm/batch_manager/kvCacheManager.h" |
| 19 | +#include "tensorrt_llm/batch_manager/llmRequest.h" |
| 20 | +#include "tensorrt_llm/batch_manager/runtimeBuffers.h" |
| 21 | +#include "tensorrt_llm/batch_manager/utils/inflightBatchingUtils.h" |
| 22 | +#include "tensorrt_llm/common/memoryUtils.h" |
| 23 | +#include "tensorrt_llm/kernels/kvCacheIndex.h" |
| 24 | +#include "tensorrt_llm/runtime/bufferManager.h" |
| 25 | +#include "tensorrt_llm/runtime/cudaStream.h" |
| 26 | +#include "tensorrt_llm/runtime/iTensor.h" |
| 27 | +#include "tensorrt_llm/runtime/samplingConfig.h" |
| 28 | +#include "gtest/gtest.h" |
| 29 | +#include <memory> |
| 30 | + |
| 31 | +using namespace tensorrt_llm::batch_manager; |
| 32 | +using namespace tensorrt_llm::batch_manager::kv_cache_manager; |
| 33 | +namespace tr = tensorrt_llm::runtime; |
| 34 | +namespace tc = tensorrt_llm::common; |
| 35 | +namespace tk = tensorrt_llm::kernels; |
| 36 | +using SizeType32 = tr::SizeType32; |
| 37 | + |
| 38 | +// Verify that copyGenerationLogits correctly assembles the host logits buffer |
| 39 | +// using the real kernel merge path, and that two back-to-back calls (simulating |
| 40 | +// two requests flushing in the same batch) use distinct fragmentPointerDevice |
| 41 | +// slots so their pointer arrays do not clobber each other. |
| 42 | +TEST(CopyGenerationLogitsTest, KernelMergePathProducesCorrectHostLayoutAndSlotsAreIsolated) |
| 43 | +{ |
| 44 | + SizeType32 constexpr beamWidth = 2; |
| 45 | + SizeType32 constexpr numSteps = RuntimeBuffers::GenerationLogitsCache::kCACHE_LENGTH; // full flush |
| 46 | + SizeType32 constexpr vocabSize = 8; |
| 47 | + SizeType32 constexpr promptLen = 1; |
| 48 | + SizeType32 constexpr maxBatchSize = 4; // must be >= 2 to test slot isolation |
| 49 | + |
| 50 | + auto stream = std::make_shared<tr::CudaStream>(); |
| 51 | + tr::BufferManager bufferMgr{stream}; |
| 52 | + |
| 53 | + // Build a real GenerationLogitsCache so that transposedLogits, |
| 54 | + // fragmentPointerDevice and fragmentPointerHost are all properly allocated. |
| 55 | + // cache.logits uses pinned memory so the test can fill it from the CPU while |
| 56 | + // the GPU kernel can still read from it via DMA. |
| 57 | + RuntimeBuffers::GenerationLogitsCache cache; |
| 58 | + cache.logits = tr::BufferManager::pinnedPool( |
| 59 | + tr::ITensor::makeShape({numSteps, maxBatchSize * beamWidth, vocabSize}), nvinfer1::DataType::kFLOAT); |
| 60 | + cache.transposedLogits |
| 61 | + = bufferMgr.gpu(tr::ITensor::makeShape({beamWidth, numSteps, vocabSize}), nvinfer1::DataType::kFLOAT); |
| 62 | + cache.fragmentPointerDevice |
| 63 | + = bufferMgr.gpu(tr::ITensor::makeShape({maxBatchSize, numSteps}), nvinfer1::DataType::kINT64); |
| 64 | + cache.fragmentPointerHost |
| 65 | + = tr::BufferManager::pinnedPool(tr::ITensor::makeShape({maxBatchSize, numSteps}), nvinfer1::DataType::kINT64); |
| 66 | + |
| 67 | + // Helper: build one LlmRequest that has numSteps fragments pointing into |
| 68 | + // cache.logits[0..numSteps-1][logitsIndex:logitsIndex+beamWidth]. |
| 69 | + // Each fragment is filled with sentinel value (step*100 + beam + reqOffset). |
| 70 | + auto makeRequest = [&](RequestIdType reqId, SizeType32 logitsIndex, float reqOffset) -> std::shared_ptr<LlmRequest> |
| 71 | + { |
| 72 | + auto tokens = std::make_shared<VecTokens>(promptLen, 0); |
| 73 | + tr::SamplingConfig sc{beamWidth}; |
| 74 | + auto req = std::make_shared<LlmRequest>(reqId, numSteps, tokens, sc, false); |
| 75 | + |
| 76 | + LlmRequest::BeamTokens gen(beamWidth, VecTokens(numSteps, 1)); |
| 77 | + req->setGeneratedTokens(gen); |
| 78 | + req->allocGenerationLogitsHost(vocabSize, nvinfer1::DataType::kFLOAT); |
| 79 | + |
| 80 | + // Write known values into the logits cache slots for this request and |
| 81 | + // create matching fragment slice views. |
| 82 | + for (SizeType32 step = 0; step < numSteps; ++step) |
| 83 | + { |
| 84 | + // cache.logits shape: [numSteps, maxBatchSize*beamWidth, vocabSize] |
| 85 | + // Slice to [1, maxBS*bw, vocab], squeeze to [maxBS*bw, vocab]. |
| 86 | + tr::ITensor::SharedPtr slot = tr::ITensor::slice(cache.logits, step, 1); |
| 87 | + slot->squeeze(0); // [maxBS*bw, vocab] |
| 88 | + auto* slotPtr = tr::bufferCast<float>(*slot); |
| 89 | + for (SizeType32 beam = 0; beam < beamWidth; ++beam) |
| 90 | + { |
| 91 | + float const val = reqOffset + static_cast<float>(step * 100 + beam); |
| 92 | + for (SizeType32 v = 0; v < vocabSize; ++v) |
| 93 | + { |
| 94 | + slotPtr[(logitsIndex + beam) * vocabSize + v] = val; |
| 95 | + } |
| 96 | + } |
| 97 | + |
| 98 | + // Fragment matches HandleGenerationLogits: slice [logitsIndex:logitsIndex+beamWidth] |
| 99 | + // from the step slot, then unsqueeze(0) → [1, beamWidth, vocab]. |
| 100 | + tr::ITensor::SharedPtr fragView = tr::ITensor::slice(slot, logitsIndex, beamWidth); |
| 101 | + fragView->unsqueeze(0); // [1, beamWidth, vocab] |
| 102 | + req->addGenerationLogitsFragment(fragView); |
| 103 | + } |
| 104 | + return req; |
| 105 | + }; |
| 106 | + |
| 107 | + // Request 0 occupies logitsIndex=0 in the batch slot. |
| 108 | + auto req0 = makeRequest(1, /*logitsIndex=*/0, /*reqOffset=*/0.0f); |
| 109 | + // Request 1 occupies logitsIndex=beamWidth in the batch slot. |
| 110 | + auto req1 = makeRequest(2, /*logitsIndex=*/beamWidth, /*reqOffset=*/1000.0f); |
| 111 | + |
| 112 | + // Flush request 0 — uses workIdx=0. |
| 113 | + utils::copyGenerationLogits(cache, bufferMgr, *req0, /*beforeDecoder=*/false, {}); |
| 114 | + // Flush request 1 — uses workIdx=1 (different slot → no pointer clobbering). |
| 115 | + utils::copyGenerationLogits(cache, bufferMgr, *req1, /*beforeDecoder=*/false, {}); |
| 116 | + |
| 117 | + ASSERT_EQ(cudaStreamSynchronize(stream->get()), cudaSuccess); |
| 118 | + |
| 119 | + // Verify req0 host buffer: host[beam, step, v] == step*100 + beam |
| 120 | + auto const* host0 = tr::bufferCast<float>(*req0->getGenerationLogitsHost()); |
| 121 | + for (SizeType32 beam = 0; beam < beamWidth; ++beam) |
| 122 | + { |
| 123 | + for (SizeType32 step = 0; step < numSteps; ++step) |
| 124 | + { |
| 125 | + float const expected = static_cast<float>(step * 100 + beam); |
| 126 | + for (SizeType32 v = 0; v < vocabSize; ++v) |
| 127 | + { |
| 128 | + SizeType32 const idx = (beam * numSteps + step) * vocabSize + v; |
| 129 | + EXPECT_FLOAT_EQ(host0[idx], expected) << "req0 host[beam=" << beam << ",step=" << step << ",v=" << v |
| 130 | + << "]=" << host0[idx] << " expected " << expected; |
| 131 | + } |
| 132 | + } |
| 133 | + } |
| 134 | + |
| 135 | + // Verify req1 host buffer: host[beam, step, v] == 1000 + step*100 + beam |
| 136 | + auto const* host1 = tr::bufferCast<float>(*req1->getGenerationLogitsHost()); |
| 137 | + for (SizeType32 beam = 0; beam < beamWidth; ++beam) |
| 138 | + { |
| 139 | + for (SizeType32 step = 0; step < numSteps; ++step) |
| 140 | + { |
| 141 | + float const expected = 1000.0f + static_cast<float>(step * 100 + beam); |
| 142 | + for (SizeType32 v = 0; v < vocabSize; ++v) |
| 143 | + { |
| 144 | + SizeType32 const idx = (beam * numSteps + step) * vocabSize + v; |
| 145 | + EXPECT_FLOAT_EQ(host1[idx], expected) << "req1 host[beam=" << beam << ",step=" << step << ",v=" << v |
| 146 | + << "]=" << host1[idx] << " expected " << expected; |
| 147 | + } |
| 148 | + } |
| 149 | + } |
| 150 | + |
| 151 | + // Both requests must have had their fragments cleared. |
| 152 | + EXPECT_EQ(req0->getGenerationLogitsFragmentsSize(), 0); |
| 153 | + EXPECT_EQ(req1->getGenerationLogitsFragmentsSize(), 0); |
| 154 | +} |
0 commit comments