|
1 | 1 | /* |
2 | | - * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved. |
| 2 | + * Copyright (c) 2020-2026, NVIDIA CORPORATION. All rights reserved. |
3 | 3 | * |
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 5 | * you may not use this file except in compliance with the License. |
@@ -728,7 +728,7 @@ namespace tensorrt_llm::runtime::kernels |
728 | 728 | { |
729 | 729 | // Must be similar to [cpp/tensorrt_llm/thop/gatherTreeOp.cpp] gatherTree |
730 | 730 | void gatherTree(DecodingOutput const& decodingOutput, DecodingInput const& decodingInput, |
731 | | - SamplingConfig const& samplingConfig, runtime::CudaStream const& cudaStream) |
| 731 | + SamplingConfig const& samplingConfig, runtime::CudaStream const& cudaStream, runtime::SizeType32 batchSlot) |
732 | 732 | { |
733 | 733 | TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); |
734 | 734 |
|
@@ -781,15 +781,32 @@ void gatherTree(DecodingOutput const& decodingOutput, DecodingInput const& decod |
781 | 781 | lengthPenaltyPtr = manager.copyFrom(lengthPenaltyVec, ITensor::makeShape({batchSize}), runtime::MemoryType::kGPU); |
782 | 782 |
|
783 | 783 | tensorrt_llm::kernels::BeamHypotheses bh; |
784 | | - bh.nMaxBatchSize = batchSize; |
| 784 | + // logProbsTiled has shape [MSL, maxNumSequences, BM] and is passed unsliced. |
| 785 | + // nMaxBatchSize must equal the allocation stride (dim-1), not the per-slot batchSize=1. |
| 786 | + // The pointer is pre-offset by batchSlot*BM so that insertUnfinishedPathKernel, |
| 787 | + // which uses bid=0 / nBatchSize=1, computes: |
| 788 | + // (base + batchSlot*BM)[step * maxBS * BM + 0*BM + beamIdx] |
| 789 | + // = base[step * maxBS * BM + batchSlot * BM + beamIdx] |
| 790 | + // = logProbsTiled[step][batchSlot][beamIdx] ✓ |
| 791 | + auto const logProbsTiledMaxBatchSize = static_cast<SizeType32>(decodingOutput.logProbsTiled->getShape().d[1]); |
| 792 | + auto const logProbsTiledBeamWidth = static_cast<SizeType32>(decodingOutput.logProbsTiled->getShape().d[2]); |
| 793 | + TLLM_CHECK_WITH_INFO(batchSlot < logProbsTiledMaxBatchSize, |
| 794 | + "batchSlot (%d) must be < logProbsTiled maxBatchSize (%d); " |
| 795 | + "logProbsTiled would be accessed out of bounds.", |
| 796 | + batchSlot, logProbsTiledMaxBatchSize); |
| 797 | + TLLM_CHECK_WITH_INFO(beamWidth == logProbsTiledBeamWidth, |
| 798 | + "beamWidth (%d) must equal logProbsTiled BM dimension (%d); " |
| 799 | + "pointer offset batchSlot*beamWidth would be misaligned.", |
| 800 | + beamWidth, logProbsTiledBeamWidth); |
| 801 | + bh.nMaxBatchSize = logProbsTiledMaxBatchSize; |
785 | 802 | bh.nBatchSize = batchSize; |
786 | 803 | bh.nBeamWidth = beamWidth; |
787 | 804 | bh.nMaxSeqLen = maxSeqLength; |
788 | 805 | bh.lengthPenalties = bufferCast<float>(*lengthPenaltyPtr); |
789 | 806 | bh.inputLengths = bufferCast<SizeType32>(*decodingInput.lengths); |
790 | 807 | bh.outputIds = bufferCast<TokenIdType>(finalOutputIds); |
791 | 808 | bh.logProbs = bufferCastOrNull<float>(decodingOutput.logProbs); |
792 | | - bh.logProbsTiled = bufferCast<float>(*decodingOutput.logProbsTiled); |
| 809 | + bh.logProbsTiled = bufferCast<float>(*decodingOutput.logProbsTiled) + batchSlot * beamWidth; |
793 | 810 | bh.sequenceLengths = bufferCast<SizeType32>(*decodingOutput.lengths); |
794 | 811 | bh.cumLogProbs = bufferCast<float>(*decodingOutput.cumLogProbs); |
795 | 812 | bh.outputIdsCBA = bufferCast<TokenIdType>(*decodingOutput.beamHypotheses.outputIdsCBA); |
|
0 commit comments