Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2023-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -97,7 +97,8 @@ class MaxUtilizationScheduler : public BaseCapacityScheduler
public:
MaxUtilizationScheduler(SizeType32 maxNumRequests, bool twoStepsLookAhead,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE,
bool enablePrefixAwareScheduling = true);

[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
kv_cache_manager::BaseKVCacheManager& kvCacheManager,
Expand All @@ -108,6 +109,8 @@ class MaxUtilizationScheduler : public BaseCapacityScheduler
SizeType32 mMaxNumRequests;
/// @brief Boolean that indicates if two step lookahead is enabled
bool mTwoStepsLookAhead;
/// @brief Whether to use KV prefix-reuse estimates in scheduling decisions.
bool mEnablePrefixAwareScheduling;
};

/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
Expand All @@ -120,7 +123,8 @@ class GuaranteedNoEvictScheduler : public BaseCapacityScheduler
public:
GuaranteedNoEvictScheduler(SizeType32 maxNumRequests,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE,
bool enablePrefixAwareScheduling = true);

[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
Expand All @@ -136,6 +140,8 @@ class GuaranteedNoEvictScheduler : public BaseCapacityScheduler

private:
SizeType32 mMaxNumRequests;
/// @brief Whether to use KV prefix-reuse estimates in scheduling decisions.
bool mEnablePrefixAwareScheduling;
};

/// @brief Schedule requests using the STATIC_BATCH policy
Expand All @@ -144,7 +150,8 @@ class StaticBatchScheduler : public GuaranteedNoEvictScheduler
public:
StaticBatchScheduler(SizeType32 maxNumRequests,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE,
bool enablePrefixAwareScheduling = true);

[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
Expand All @@ -160,7 +167,8 @@ class CapacityScheduler : public Algorithm
explicit CapacityScheduler(SizeType32 maxNumRequests, executor::CapacitySchedulerPolicy capacitySchedulerPolicy,
bool hasKvCacheManager, bool twoStepsLookAhead = false,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE,
bool enablePrefixAwareScheduling = true);

/**
* @brief Schedules requests following the selected policy.
Expand Down
9 changes: 8 additions & 1 deletion cpp/include/tensorrt_llm/executor/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,8 @@ class DynamicBatchConfig

[[nodiscard]] std::vector<std::pair<SizeType32, SizeType32>> getBatchSizeTable() const;

bool operator==(DynamicBatchConfig const& other) const;

/// @brief The default value of batch size table
static std::vector<std::pair<SizeType32, SizeType32>> const kDefaultBatchSizeTable;

Expand Down Expand Up @@ -1019,7 +1021,7 @@ class SchedulerConfig
explicit SchedulerConfig(
CapacitySchedulerPolicy capacitySchedulerPolicy = CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT,
std::optional<ContextChunkingPolicy> contextChunkingPolicy = std::nullopt,
std::optional<DynamicBatchConfig> dynamicBatchConfig = std::nullopt);
std::optional<DynamicBatchConfig> dynamicBatchConfig = std::nullopt, bool enablePrefixAwareScheduling = true);

bool operator==(SchedulerConfig const& other) const;

Expand All @@ -1029,6 +1031,8 @@ class SchedulerConfig

[[nodiscard]] std::optional<DynamicBatchConfig> getDynamicBatchConfig() const;

[[nodiscard]] bool getEnablePrefixAwareScheduling() const;

private:
friend class Serialization;

Expand All @@ -1040,6 +1044,9 @@ class SchedulerConfig

/// @brief The config for tuning batch size dynamically. See DynamicBatchConfig.
std::optional<DynamicBatchConfig> mDynamicBatchConfig;

/// @brief Whether schedulers use KV prefix-reuse estimates for admission and token-budget decisions.
bool mEnablePrefixAwareScheduling;
};

/// @brief Configuration class for the KV cache
Expand Down
81 changes: 53 additions & 28 deletions cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,23 +151,26 @@ MaxRequestsScheduler::MaxRequestsScheduler(
}

MaxUtilizationScheduler::MaxUtilizationScheduler(SizeType32 maxNumRequests, bool twoStepsLookAhead,
LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState, bool enablePrefixAwareScheduling)
: BaseCapacityScheduler(noScheduleUntilState, noScheduleAfterState)
, mMaxNumRequests(maxNumRequests)
, mTwoStepsLookAhead{twoStepsLookAhead}
, mEnablePrefixAwareScheduling{enablePrefixAwareScheduling}
{
}

GuaranteedNoEvictScheduler::GuaranteedNoEvictScheduler(
SizeType32 maxNumRequests, LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
GuaranteedNoEvictScheduler::GuaranteedNoEvictScheduler(SizeType32 maxNumRequests, LlmRequestState noScheduleUntilState,
LlmRequestState noScheduleAfterState, bool enablePrefixAwareScheduling)
: BaseCapacityScheduler(noScheduleUntilState, noScheduleAfterState)
, mMaxNumRequests(maxNumRequests)
, mEnablePrefixAwareScheduling{enablePrefixAwareScheduling}
{
}

StaticBatchScheduler::StaticBatchScheduler(
SizeType32 maxNumRequests, LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
: GuaranteedNoEvictScheduler(maxNumRequests, noScheduleUntilState, noScheduleAfterState)
StaticBatchScheduler::StaticBatchScheduler(SizeType32 maxNumRequests, LlmRequestState noScheduleUntilState,
LlmRequestState noScheduleAfterState, bool enablePrefixAwareScheduling)
: GuaranteedNoEvictScheduler(
maxNumRequests, noScheduleUntilState, noScheduleAfterState, enablePrefixAwareScheduling)
{
}

Expand Down Expand Up @@ -226,7 +229,7 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
= peftCacheManager ? peftCacheManager->getMaxDevicePages() : std::numeric_limits<SizeType32>::max();

// The optimization of delaying requests won't work for variable window attention
bool skippingIsRelevant = (!kvCacheManager.getBlockManager().isVariableWindow())
bool skippingIsRelevant = mEnablePrefixAwareScheduling && (!kvCacheManager.getBlockManager().isVariableWindow())
&& (!crossKvCacheManager || !crossKvCacheManager->getBlockManager().isVariableWindow());

// Keep track of blocks contributed by requests in context phase
Expand Down Expand Up @@ -323,28 +326,39 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
bool const isEncoderInit = req->isEncoderInitState();
std::optional<kv_cache_manager::PrefixReuseSummary> summary;
std::optional<kv_cache_manager::PrefixReuseSummary> crossSummary;
if (isFirstChunkContext)
if (mEnablePrefixAwareScheduling)
{
// analyzePrefixReuse asserts on variable-window managers; skip the walk there
// and let downstream callers fall back to their fresh tree-walk path.
if (kvCacheManager.isEnableBlockReuse() && !kvCacheManager.getBlockManager().isVariableWindow())
if (isFirstChunkContext)
{
auto uniqueTokens = req->getUniqueTokens(0);
summary = kvCacheManager.analyzePrefixReuse(uniqueTokens, *req);
// analyzePrefixReuse asserts on variable-window managers; skip the walk there
// and let downstream callers fall back to their fresh tree-walk path.
if (kvCacheManager.isEnableBlockReuse() && !kvCacheManager.getBlockManager().isVariableWindow())
{
auto uniqueTokens = req->getUniqueTokens(0);
summary = kvCacheManager.analyzePrefixReuse(uniqueTokens, *req);
}
if (crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse()
&& !crossKvCacheManager->getBlockManager().isVariableWindow())
{
auto uniqueTokens = *(req->getEncoderUniqueTokens().value());
crossSummary = crossKvCacheManager->analyzePrefixReuse(uniqueTokens, *req);
}
}
if (crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse()
else if (isEncoderInit && crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse()
&& !crossKvCacheManager->getBlockManager().isVariableWindow())
{
// Encoder admission only needs the cross summary for reuse ordering.
auto uniqueTokens = *(req->getEncoderUniqueTokens().value());
crossSummary = crossKvCacheManager->analyzePrefixReuse(uniqueTokens, *req);
}
}
else if (isEncoderInit && crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse()
&& !crossKvCacheManager->getBlockManager().isVariableWindow())
else if (isFirstChunkContext)
{
// Encoder admission only needs the cross summary for reuse ordering.
auto uniqueTokens = *(req->getEncoderUniqueTokens().value());
crossSummary = crossKvCacheManager->analyzePrefixReuse(uniqueTokens, *req);
summary = kv_cache_manager::PrefixReuseSummary{};
if (crossKvCacheManager)
{
crossSummary = kv_cache_manager::PrefixReuseSummary{};
}
}
// Beneficial-to-skip check using the cached summary
if (!StaticBatchScheduling && skippingIsRelevant && (isFirstChunkContext || isEncoderInit)
Expand Down Expand Up @@ -442,7 +456,7 @@ std::tuple<RequestVector, RequestVector> MaxUtilizationScheduler::operator()(
}

// The optimization of delaying requests won't work for variable window attention
bool skippingIsRelevant = !kvCacheManager.getBlockManager().isVariableWindow();
bool skippingIsRelevant = mEnablePrefixAwareScheduling && !kvCacheManager.getBlockManager().isVariableWindow();

// Keep track of number of requests and block needed for the scheduled requests
auto scheduledBlocksManager
Expand All @@ -459,8 +473,13 @@ std::tuple<RequestVector, RequestVector> MaxUtilizationScheduler::operator()(
std::unordered_set<uint64_t> seenTaskIds;

// Keep track of blocks contributed by requests in context phase
auto [newlyContributedContextBlocks, newlyContributedCrossContextBlocks]
= prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager);
std::unordered_set<BlockKey, BlockKeyHasher> newlyContributedContextBlocks;
std::unordered_set<BlockKey, BlockKeyHasher> newlyContributedCrossContextBlocks;
if (skippingIsRelevant)
{
std::tie(newlyContributedContextBlocks, newlyContributedCrossContextBlocks)
= prefillWithChunkedContextsAlreadyExecuting(activeRequests, kvCacheManager);
}

// Find last active in case we need to evict. Encoder-init requests are
// intentionally excluded here: they hold no started self- or cross-pool
Expand Down Expand Up @@ -511,7 +530,11 @@ std::tuple<RequestVector, RequestVector> MaxUtilizationScheduler::operator()(
std::optional<kv_cache_manager::PrefixReuseSummary> summary;
// analyzePrefixReuse asserts on variable-window managers; skip the walk there
// and let downstream callers fall back to their fresh tree-walk path.
if (isFirstChunkContext && kvCacheManager.isEnableBlockReuse()
if (isFirstChunkContext && !mEnablePrefixAwareScheduling)
{
summary = kv_cache_manager::PrefixReuseSummary{};
}
else if (isFirstChunkContext && kvCacheManager.isEnableBlockReuse()
&& !kvCacheManager.getBlockManager().isVariableWindow())
{
auto uniqueTokens = req->getUniqueTokens(0);
Expand Down Expand Up @@ -644,24 +667,26 @@ bool trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req,

CapacityScheduler::CapacityScheduler(SizeType32 maxNumRequests,
executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool hasKvCacheManager, bool twoStepsLookAhead,
LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState, bool enablePrefixAwareScheduling)
{
if (!hasKvCacheManager)
{
mScheduler = MaxRequestsScheduler{maxNumRequests, noScheduleUntilState, noScheduleAfterState};
}
else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kMAX_UTILIZATION)
{
mScheduler
= MaxUtilizationScheduler{maxNumRequests, twoStepsLookAhead, noScheduleUntilState, noScheduleAfterState};
mScheduler = MaxUtilizationScheduler{
maxNumRequests, twoStepsLookAhead, noScheduleUntilState, noScheduleAfterState, enablePrefixAwareScheduling};
}
else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT)
{
mScheduler = GuaranteedNoEvictScheduler{maxNumRequests, noScheduleUntilState, noScheduleAfterState};
mScheduler = GuaranteedNoEvictScheduler{
maxNumRequests, noScheduleUntilState, noScheduleAfterState, enablePrefixAwareScheduling};
}
else if (capacitySchedulerPolicy == executor::CapacitySchedulerPolicy::kSTATIC_BATCH)
{
mScheduler = StaticBatchScheduler{maxNumRequests, noScheduleUntilState, noScheduleAfterState};
mScheduler = StaticBatchScheduler{
maxNumRequests, noScheduleUntilState, noScheduleAfterState, enablePrefixAwareScheduling};
}
else
{
Expand Down
9 changes: 6 additions & 3 deletions cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -76,8 +76,11 @@ TrtEncoderModel::TrtEncoderModel(runtime::ModelConfig const& modelConfig, WorldC
// handling of maximizing utilization or pause/evict
// TODO: finer control on encoder requests scheduling
mCapacityScheduler = std::make_unique<tensorrt_llm::batch_manager::CapacityScheduler>(
getMaxBatchSize() * mNumMicroBatches, executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy(), false,
false, LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
getMaxBatchSize() * mNumMicroBatches, executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy(),
/*hasKvCacheManager=*/false, /*twoStepsLookAhead=*/false,
/*noScheduleUntilState=*/LlmRequestState::kENCODER_INIT,
/*noScheduleAfterState=*/LlmRequestState::kCONTEXT_INIT,
/*enablePrefixAwareScheduling=*/executorConfig.getSchedulerConfig().getEnablePrefixAwareScheduling());

mMicroBatchScheduler = std::make_unique<tensorrt_llm::batch_manager::MicroBatchScheduler>(
std::nullopt, mModelConfig.getMaxInputLen(), LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,10 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr<nvinfer

mCapacityScheduler = std::make_unique<CapacityScheduler>(getMaxNumSequences(),
executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy(), mKvCacheManager != nullptr,
mWorldConfig.isPipelineParallel());
/*twoStepsLookAhead=*/mWorldConfig.isPipelineParallel(),
/*noScheduleUntilState=*/LlmRequestState::kCONTEXT_INIT,
/*noScheduleAfterState=*/LlmRequestState::kGENERATION_COMPLETE,
/*enablePrefixAwareScheduling=*/executorConfig.getSchedulerConfig().getEnablePrefixAwareScheduling());

mMicroBatchScheduler = std::make_unique<MicroBatchScheduler>(ctxChunkConfig, maxContextLength);

Expand Down
10 changes: 9 additions & 1 deletion cpp/tensorrt_llm/executor/dynamicBatchConfig.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -49,6 +49,14 @@ std::vector<std::pair<SizeType32, SizeType32>> DynamicBatchConfig::getBatchSizeT
return mBatchSizeTable;
}

/// @brief Equality comparison over the batch-size tuning flag, the max-num-tokens tuning flag,
/// the moving average window and the batch size table.
/// @param other The config to compare against.
/// @return true when all four compared members are equal, false otherwise.
bool DynamicBatchConfig::operator==(DynamicBatchConfig const& other) const
{
    // Members are checked in a fixed order and the first mismatch ends the comparison.
    if (!(mEnableBatchSizeTuning == other.mEnableBatchSizeTuning))
    {
        return false;
    }
    if (!(mEnableMaxNumTokensTuning == other.mEnableMaxNumTokensTuning))
    {
        return false;
    }
    if (!(mDynamicBatchMovingAverageWindow == other.mDynamicBatchMovingAverageWindow))
    {
        return false;
    }
    // The batch size table is only compared once every other member has matched.
    return mBatchSizeTable == other.mBatchSizeTable;
}

std::vector<std::pair<SizeType32, SizeType32>> const DynamicBatchConfig::kDefaultBatchSizeTable{
{144, 128},
{336, 256},
Expand Down
14 changes: 11 additions & 3 deletions cpp/tensorrt_llm/executor/schedulerConfig.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -21,17 +21,20 @@ namespace tensorrt_llm::executor
{

SchedulerConfig::SchedulerConfig(CapacitySchedulerPolicy capacitySchedulerPolicy,
std::optional<ContextChunkingPolicy> contextChunkingPolicy, std::optional<DynamicBatchConfig> dynamicBatchConfig)
std::optional<ContextChunkingPolicy> contextChunkingPolicy, std::optional<DynamicBatchConfig> dynamicBatchConfig,
bool enablePrefixAwareScheduling)
: mCapacitySchedulerPolicy(capacitySchedulerPolicy)
, mContextChunkingPolicy(std::move(contextChunkingPolicy))
, mDynamicBatchConfig(std::move(dynamicBatchConfig))
, mEnablePrefixAwareScheduling(enablePrefixAwareScheduling)
{
}

bool SchedulerConfig::operator==(SchedulerConfig const& other) const
{
return mCapacitySchedulerPolicy == other.mCapacitySchedulerPolicy
&& mContextChunkingPolicy == other.mContextChunkingPolicy;
&& mContextChunkingPolicy == other.mContextChunkingPolicy && mDynamicBatchConfig == other.mDynamicBatchConfig
&& mEnablePrefixAwareScheduling == other.mEnablePrefixAwareScheduling;
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

[[nodiscard]] CapacitySchedulerPolicy SchedulerConfig::getCapacitySchedulerPolicy() const
Expand All @@ -49,4 +52,9 @@ bool SchedulerConfig::operator==(SchedulerConfig const& other) const
return mDynamicBatchConfig;
}

/// @brief Accessor for the prefix-aware scheduling flag.
/// @return The value held in mEnablePrefixAwareScheduling.
[[nodiscard]] bool SchedulerConfig::getEnablePrefixAwareScheduling() const
{
return mEnablePrefixAwareScheduling;
}

} // namespace tensorrt_llm::executor
Loading
Loading