Skip to content

Commit 9ac53b3

Browse files
committed
Address comments from Robin and Tyler.
Signed-off-by: Simeng Liu <simengl@nvidia.com>
1 parent 9733bdb commit 9ac53b3

6 files changed

Lines changed: 36 additions & 50 deletions

File tree

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -989,6 +989,8 @@ class DynamicBatchConfig
989989

990990
[[nodiscard]] std::vector<std::pair<SizeType32, SizeType32>> getBatchSizeTable() const;
991991

992+
bool operator==(DynamicBatchConfig const& other) const;
993+
992994
/// @brief The default value of batch size table
993995
static std::vector<std::pair<SizeType32, SizeType32>> const kDefaultBatchSizeTable;
994996

cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp

Lines changed: 21 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -314,38 +314,39 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
314314
bool const isEncoderInit = req->isEncoderInitState();
315315
std::optional<kv_cache_manager::PrefixReuseSummary> summary;
316316
std::optional<kv_cache_manager::PrefixReuseSummary> crossSummary;
317-
if (isFirstChunkContext)
317+
if (mEnablePrefixAwareScheduling)
318318
{
319-
// analyzePrefixReuse asserts on variable-window managers; skip the walk there
320-
// and let downstream callers fall back to their fresh tree-walk path.
321-
if (!mEnablePrefixAwareScheduling)
319+
if (isFirstChunkContext)
322320
{
323-
summary = kv_cache_manager::PrefixReuseSummary{};
324-
if (crossKvCacheManager)
321+
// analyzePrefixReuse asserts on variable-window managers; skip the walk there
322+
// and let downstream callers fall back to their fresh tree-walk path.
323+
if (kvCacheManager.isEnableBlockReuse() && !kvCacheManager.getBlockManager().isVariableWindow())
325324
{
326-
crossSummary = kv_cache_manager::PrefixReuseSummary{};
325+
auto uniqueTokens = req->getUniqueTokens(0);
326+
summary = kvCacheManager.analyzePrefixReuse(uniqueTokens, *req);
327+
}
328+
if (crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse()
329+
&& !crossKvCacheManager->getBlockManager().isVariableWindow())
330+
{
331+
auto uniqueTokens = *(req->getEncoderUniqueTokens().value());
332+
crossSummary = crossKvCacheManager->analyzePrefixReuse(uniqueTokens, *req);
327333
}
328334
}
329-
else if (kvCacheManager.isEnableBlockReuse()
330-
&& !kvCacheManager.getBlockManager().isVariableWindow())
331-
{
332-
auto uniqueTokens = req->getUniqueTokens(0);
333-
summary = kvCacheManager.analyzePrefixReuse(uniqueTokens, *req);
334-
}
335-
if (mEnablePrefixAwareScheduling && crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse()
335+
else if (isEncoderInit && crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse()
336336
&& !crossKvCacheManager->getBlockManager().isVariableWindow())
337337
{
338+
// Encoder admission only needs the cross summary for reuse ordering.
338339
auto uniqueTokens = *(req->getEncoderUniqueTokens().value());
339340
crossSummary = crossKvCacheManager->analyzePrefixReuse(uniqueTokens, *req);
340341
}
341342
}
342-
else if (mEnablePrefixAwareScheduling && isEncoderInit && crossKvCacheManager
343-
&& crossKvCacheManager->isEnableBlockReuse()
344-
&& !crossKvCacheManager->getBlockManager().isVariableWindow())
343+
else if (isFirstChunkContext)
345344
{
346-
// Encoder admission only needs the cross summary for reuse ordering.
347-
auto uniqueTokens = *(req->getEncoderUniqueTokens().value());
348-
crossSummary = crossKvCacheManager->analyzePrefixReuse(uniqueTokens, *req);
345+
summary = kv_cache_manager::PrefixReuseSummary{};
346+
if (crossKvCacheManager)
347+
{
348+
crossSummary = kv_cache_manager::PrefixReuseSummary{};
349+
}
349350
}
350351
// Beneficial-to-skip check using the cached summary
351352
if (!StaticBatchScheduling && skippingIsRelevant && (isFirstChunkContext || isEncoderInit)

cpp/tensorrt_llm/executor/dynamicBatchConfig.cpp

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
* SPDX-License-Identifier: Apache-2.0
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -49,6 +49,14 @@ std::vector<std::pair<SizeType32, SizeType32>> DynamicBatchConfig::getBatchSizeT
4949
return mBatchSizeTable;
5050
}
5151

52+
bool DynamicBatchConfig::operator==(DynamicBatchConfig const& other) const
53+
{
54+
return mEnableBatchSizeTuning == other.mEnableBatchSizeTuning
55+
&& mEnableMaxNumTokensTuning == other.mEnableMaxNumTokensTuning
56+
&& mDynamicBatchMovingAverageWindow == other.mDynamicBatchMovingAverageWindow
57+
&& mBatchSizeTable == other.mBatchSizeTable;
58+
}
59+
52 60
std::vector<std::pair<SizeType32, SizeType32>> const DynamicBatchConfig::kDefaultBatchSizeTable{
53 61
{144, 128},
54 62
{336, 256},

cpp/tensorrt_llm/executor/schedulerConfig.cpp

Lines changed: 1 addition & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -20,29 +20,6 @@
2020
namespace tensorrt_llm::executor
2121
{
2222

23-
namespace
24-
{
25-
26-
bool dynamicBatchConfigsEqual(
27-
std::optional<DynamicBatchConfig> const& lhs, std::optional<DynamicBatchConfig> const& rhs)
28-
{
29-
if (lhs.has_value() != rhs.has_value())
30-
{
31-
return false;
32-
}
33-
if (!lhs.has_value())
34-
{
35-
return true;
36-
}
37-
38-
return lhs->getEnableBatchSizeTuning() == rhs->getEnableBatchSizeTuning()
39-
&& lhs->getEnableMaxNumTokensTuning() == rhs->getEnableMaxNumTokensTuning()
40-
&& lhs->getDynamicBatchMovingAverageWindow() == rhs->getDynamicBatchMovingAverageWindow()
41-
&& lhs->getBatchSizeTable() == rhs->getBatchSizeTable();
42-
}
43-
44-
} // namespace
45-
46 23
SchedulerConfig::SchedulerConfig(CapacitySchedulerPolicy capacitySchedulerPolicy,
47 24
std::optional<ContextChunkingPolicy> contextChunkingPolicy, std::optional<DynamicBatchConfig> dynamicBatchConfig,
48 25
bool enablePrefixAwareScheduling)
@@ -56,8 +33,7 @@ SchedulerConfig::SchedulerConfig(CapacitySchedulerPolicy capacitySchedulerPolicy
56 33
bool SchedulerConfig::operator==(SchedulerConfig const& other) const
57 34
{
58 35
return mCapacitySchedulerPolicy == other.mCapacitySchedulerPolicy
59-
&& mContextChunkingPolicy == other.mContextChunkingPolicy
60-
&& dynamicBatchConfigsEqual(mDynamicBatchConfig, other.mDynamicBatchConfig)
36+
&& mContextChunkingPolicy == other.mContextChunkingPolicy && mDynamicBatchConfig == other.mDynamicBatchConfig
61 37
&& mEnablePrefixAwareScheduling == other.mEnablePrefixAwareScheduling;
62 38
}
63 39

cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -76,14 +76,13 @@ void initConfigBindings(nb::module_& m)
7676

7777
auto schedulerConfigSetstate = [](tle::SchedulerConfig& self, nb::tuple const& state)
7878
{
79-
if (state.size() != 3 && state.size() != 4)
79+
if (state.size() != 4)
8080
{
8181
throw std::runtime_error("Invalid state!");
8282
}
83-
bool const enablePrefixAwareScheduling = state.size() == 4 ? nb::cast<bool>(state[3]) : true;
84 83
new (&self) tle::SchedulerConfig(nb::cast<tle::CapacitySchedulerPolicy>(state[0]),
85 84
nb::cast<std::optional<tle::ContextChunkingPolicy>>(state[1]),
86-
nb::cast<std::optional<tle::DynamicBatchConfig>>(state[2]), enablePrefixAwareScheduling);
85+
nb::cast<std::optional<tle::DynamicBatchConfig>>(state[2]), nb::cast<bool>(state[3]));
87 86
};
88 87
auto schedulerConfigGetstate = [](tle::SchedulerConfig const& self)
89 88
{

tests/unittest/bindings/test_executor_bindings.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1372,7 +1372,7 @@ def test_dynamic_batch_config_pickle():
13721372
assert config_copy.dynamic_batch_moving_average_window == 128
13731373

13741374

1375-
def test_scheduler_config() -> None:
1375+
def test_scheduler_config():
13761376
capacity_scheduler_policy = trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
13771377
config = trtllm.SchedulerConfig()
13781378
assert config.capacity_scheduler_policy == capacity_scheduler_policy

0 commit comments

Comments
 (0)