Skip to content

Commit bcd57ce

Browse files
committed
Address comments from Robin and Tyler.
Signed-off-by: Simeng Liu <simengl@nvidia.com>
1 parent 02cd4bf commit bcd57ce

6 files changed

Lines changed: 36 additions & 50 deletions

File tree

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -989,6 +989,8 @@ class DynamicBatchConfig
989989

990990
[[nodiscard]] std::vector<std::pair<SizeType32, SizeType32>> getBatchSizeTable() const;
991991

992+
bool operator==(DynamicBatchConfig const& other) const;
993+
992994
/// @brief The default value of batch size table
993995
static std::vector<std::pair<SizeType32, SizeType32>> const kDefaultBatchSizeTable;
994996

cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp

Lines changed: 21 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -326,38 +326,39 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
326326
bool const isEncoderInit = req->isEncoderInitState();
327327
std::optional<kv_cache_manager::PrefixReuseSummary> summary;
328328
std::optional<kv_cache_manager::PrefixReuseSummary> crossSummary;
329-
if (isFirstChunkContext)
329+
if (mEnablePrefixAwareScheduling)
330330
{
331-
// analyzePrefixReuse asserts on variable-window managers; skip the walk there
332-
// and let downstream callers fall back to their fresh tree-walk path.
333-
if (!mEnablePrefixAwareScheduling)
331+
if (isFirstChunkContext)
334332
{
335-
summary = kv_cache_manager::PrefixReuseSummary{};
336-
if (crossKvCacheManager)
333+
// analyzePrefixReuse asserts on variable-window managers; skip the walk there
334+
// and let downstream callers fall back to their fresh tree-walk path.
335+
if (kvCacheManager.isEnableBlockReuse() && !kvCacheManager.getBlockManager().isVariableWindow())
337336
{
338-
crossSummary = kv_cache_manager::PrefixReuseSummary{};
337+
auto uniqueTokens = req->getUniqueTokens(0);
338+
summary = kvCacheManager.analyzePrefixReuse(uniqueTokens, *req);
339+
}
340+
if (crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse()
341+
&& !crossKvCacheManager->getBlockManager().isVariableWindow())
342+
{
343+
auto uniqueTokens = *(req->getEncoderUniqueTokens().value());
344+
crossSummary = crossKvCacheManager->analyzePrefixReuse(uniqueTokens, *req);
339345
}
340346
}
341-
else if (kvCacheManager.isEnableBlockReuse()
342-
&& !kvCacheManager.getBlockManager().isVariableWindow())
343-
{
344-
auto uniqueTokens = req->getUniqueTokens(0);
345-
summary = kvCacheManager.analyzePrefixReuse(uniqueTokens, *req);
346-
}
347-
if (mEnablePrefixAwareScheduling && crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse()
347+
else if (isEncoderInit && crossKvCacheManager && crossKvCacheManager->isEnableBlockReuse()
348348
&& !crossKvCacheManager->getBlockManager().isVariableWindow())
349349
{
350+
// Encoder admission only needs the cross summary for reuse ordering.
350351
auto uniqueTokens = *(req->getEncoderUniqueTokens().value());
351352
crossSummary = crossKvCacheManager->analyzePrefixReuse(uniqueTokens, *req);
352353
}
353354
}
354-
else if (mEnablePrefixAwareScheduling && isEncoderInit && crossKvCacheManager
355-
&& crossKvCacheManager->isEnableBlockReuse()
356-
&& !crossKvCacheManager->getBlockManager().isVariableWindow())
355+
else if (isFirstChunkContext)
357356
{
358-
// Encoder admission only needs the cross summary for reuse ordering.
359-
auto uniqueTokens = *(req->getEncoderUniqueTokens().value());
360-
crossSummary = crossKvCacheManager->analyzePrefixReuse(uniqueTokens, *req);
357+
summary = kv_cache_manager::PrefixReuseSummary{};
358+
if (crossKvCacheManager)
359+
{
360+
crossSummary = kv_cache_manager::PrefixReuseSummary{};
361+
}
361362
}
362363
// Beneficial-to-skip check using the cached summary
363364
if (!StaticBatchScheduling && skippingIsRelevant && (isFirstChunkContext || isEncoderInit)

cpp/tensorrt_llm/executor/dynamicBatchConfig.cpp

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
* SPDX-License-Identifier: Apache-2.0
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -49,6 +49,14 @@ std::vector<std::pair<SizeType32, SizeType32>> DynamicBatchConfig::getBatchSizeT
4949
return mBatchSizeTable;
5050
}
5151

52+
bool DynamicBatchConfig::operator==(DynamicBatchConfig const& other) const
53+
{
54+
return mEnableBatchSizeTuning == other.mEnableBatchSizeTuning
55+
&& mEnableMaxNumTokensTuning == other.mEnableMaxNumTokensTuning
56+
&& mDynamicBatchMovingAverageWindow == other.mDynamicBatchMovingAverageWindow
57+
&& mBatchSizeTable == other.mBatchSizeTable;
58+
}
59+
5260
std::vector<std::pair<SizeType32, SizeType32>> const DynamicBatchConfig::kDefaultBatchSizeTable{
5361
{144, 128},
5462
{336, 256},

cpp/tensorrt_llm/executor/schedulerConfig.cpp

Lines changed: 1 addition & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -20,29 +20,6 @@
2020
namespace tensorrt_llm::executor
2121
{
2222

23-
namespace
24-
{
25-
26-
bool dynamicBatchConfigsEqual(
27-
std::optional<DynamicBatchConfig> const& lhs, std::optional<DynamicBatchConfig> const& rhs)
28-
{
29-
if (lhs.has_value() != rhs.has_value())
30-
{
31-
return false;
32-
}
33-
if (!lhs.has_value())
34-
{
35-
return true;
36-
}
37-
38-
return lhs->getEnableBatchSizeTuning() == rhs->getEnableBatchSizeTuning()
39-
&& lhs->getEnableMaxNumTokensTuning() == rhs->getEnableMaxNumTokensTuning()
40-
&& lhs->getDynamicBatchMovingAverageWindow() == rhs->getDynamicBatchMovingAverageWindow()
41-
&& lhs->getBatchSizeTable() == rhs->getBatchSizeTable();
42-
}
43-
44-
} // namespace
45-
4623
SchedulerConfig::SchedulerConfig(CapacitySchedulerPolicy capacitySchedulerPolicy,
4724
std::optional<ContextChunkingPolicy> contextChunkingPolicy, std::optional<DynamicBatchConfig> dynamicBatchConfig,
4825
bool enablePrefixAwareScheduling)
@@ -56,8 +33,7 @@ SchedulerConfig::SchedulerConfig(CapacitySchedulerPolicy capacitySchedulerPolicy
5633
bool SchedulerConfig::operator==(SchedulerConfig const& other) const
5734
{
5835
return mCapacitySchedulerPolicy == other.mCapacitySchedulerPolicy
59-
&& mContextChunkingPolicy == other.mContextChunkingPolicy
60-
&& dynamicBatchConfigsEqual(mDynamicBatchConfig, other.mDynamicBatchConfig)
36+
&& mContextChunkingPolicy == other.mContextChunkingPolicy && mDynamicBatchConfig == other.mDynamicBatchConfig
6137
&& mEnablePrefixAwareScheduling == other.mEnablePrefixAwareScheduling;
6238
}
6339

cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -76,14 +76,13 @@ void initConfigBindings(nb::module_& m)
7676

7777
auto schedulerConfigSetstate = [](tle::SchedulerConfig& self, nb::tuple const& state)
7878
{
79-
if (state.size() != 3 && state.size() != 4)
79+
if (state.size() != 4)
8080
{
8181
throw std::runtime_error("Invalid state!");
8282
}
83-
bool const enablePrefixAwareScheduling = state.size() == 4 ? nb::cast<bool>(state[3]) : true;
8483
new (&self) tle::SchedulerConfig(nb::cast<tle::CapacitySchedulerPolicy>(state[0]),
8584
nb::cast<std::optional<tle::ContextChunkingPolicy>>(state[1]),
86-
nb::cast<std::optional<tle::DynamicBatchConfig>>(state[2]), enablePrefixAwareScheduling);
85+
nb::cast<std::optional<tle::DynamicBatchConfig>>(state[2]), nb::cast<bool>(state[3]));
8786
};
8887
auto schedulerConfigGetstate = [](tle::SchedulerConfig const& self)
8988
{

tests/unittest/bindings/test_executor_bindings.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1372,7 +1372,7 @@ def test_dynamic_batch_config_pickle():
13721372
assert config_copy.dynamic_batch_moving_average_window == 128
13731373

13741374

1375-
def test_scheduler_config() -> None:
1375+
def test_scheduler_config():
13761376
capacity_scheduler_policy = trtllm.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
13771377
config = trtllm.SchedulerConfig()
13781378
assert config.capacity_scheduler_policy == capacity_scheduler_policy

0 commit comments

Comments
 (0)