-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Expand file tree
/
Copy pathcapacityScheduler.h
More file actions
210 lines (181 loc) · 9.13 KB
/
Copy pathcapacityScheduler.h
File metadata and controls
210 lines (181 loc) · 9.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
/*
* Copyright (c) 2023-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/reorderPolicy.h"
#include "tensorrt_llm/common/algorithm.h"
#include "tensorrt_llm/common/optionalRef.h"
#include "tensorrt_llm/runtime/common.h"
#include <memory>
#include <variant>
// Forward declarations of the cache manager types. The scheduler declarations in this header
// name these types only as reference parameters or as OptionalRef template arguments.
namespace tensorrt_llm::batch_manager
{
namespace kv_cache_manager
{
class BaseKVCacheManager;
}
class BasePeftCacheManager;
} // namespace tensorrt_llm::batch_manager
namespace tensorrt_llm::batch_manager
{
// SizeType32 and OptionalRef are used unqualified by the scheduler declarations below.
using tensorrt_llm::runtime::SizeType32;
using common::OptionalRef;
/// @brief Common base of the capacity scheduling policies declared in this header.
/// @details This scheduler takes into account the given request capacity and the KV cache capacity.
/// Depending on the CapacitySchedulerPolicy it will schedule already started and new requests,
/// or even pause previously started requests.
/// The base class itself only stores and exposes the two request-state bounds; the scheduling
/// entry points (operator()) are declared by the derived policy classes.
class BaseCapacityScheduler
{
public:
    /// @brief Stores the two request-state bounds.
    /// @param noScheduleUntilState Value later returned by getNoScheduleUntilState().
    /// @param noScheduleAfterState Value later returned by getNoScheduleAfterState().
    /// @note The constructor is constexpr so that the constexpr accessors below are actually usable
    /// in constant expressions; with a non-constexpr constructor no object of this type could be
    /// created during constant evaluation. It is noexcept for the same reason the accessors are:
    /// it only copies LlmRequestState values.
    explicit constexpr BaseCapacityScheduler(
        LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState) noexcept
        : mNoScheduleUntilState(noScheduleUntilState)
        , mNoScheduleAfterState(noScheduleAfterState)
    {
    }

    /// @return The state until which the scheduler should not schedule a request.
    [[nodiscard]] LlmRequestState constexpr getNoScheduleUntilState() const noexcept
    {
        return mNoScheduleUntilState;
    }

    /// @return The state after which the scheduler should not schedule a request.
    [[nodiscard]] LlmRequestState constexpr getNoScheduleAfterState() const noexcept
    {
        return mNoScheduleAfterState;
    }

private:
    /// The state until/after which the scheduler should not schedule requests
    LlmRequestState mNoScheduleUntilState;
    LlmRequestState mNoScheduleAfterState;
};
/// @brief Schedule up to maxNumRequests requests
/// @details Unlike the other policies in this header, the call operator takes no cache manager
/// arguments; only the list of active requests is passed in.
class MaxRequestsScheduler : public BaseCapacityScheduler
{
public:
/// @brief Construct the scheduler.
/// @param maxNumRequests Upper bound on the number of requests to schedule (see the class brief).
/// @param noScheduleUntilState Forwarded state bound; defaults to LlmRequestState::kCONTEXT_INIT.
/// @param noScheduleAfterState Forwarded state bound; defaults to LlmRequestState::kGENERATION_COMPLETE.
explicit MaxRequestsScheduler(SizeType32 maxNumRequests,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
/// @brief Takes as input a sorted list of requests and outputs two request vectors.
/// @param activeRequests The sorted list of currently active requests.
/// @return A tuple of two RequestVector: the sorted requests to update for the current iteration,
/// and the requests to pause (a vector, not a map).
/// NOTE(review): the order of the two tuple elements is taken from the previous wording of this
/// comment -- confirm against the definition.
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
private:
/// Maximum number of requests, as passed to the constructor.
SizeType32 mMaxNumRequests;
};
/// @brief Schedule requests using the MAX_UTILIZATION policy
/// @details Try reserving resources to advance requests by one step,
/// may pause previously started requests. When a
/// ``crossKvCacheManager`` is supplied, requests in the
/// ``ENCODER_INIT`` state may be admitted for encoder compute
/// without consuming self- or cross-KV blocks; the later
/// ``CONTEXT_INIT`` decoder admission owns cross-pool budgeting.
class MaxUtilizationScheduler : public BaseCapacityScheduler
{
public:
/// @brief Construct the scheduler.
/// @param maxNumRequests Maximum number of requests (stored in mMaxNumRequests).
/// @param twoStepsLookAhead Whether two step lookahead is enabled.
/// @param noScheduleUntilState Forwarded state bound; defaults to LlmRequestState::kCONTEXT_INIT.
/// @param noScheduleAfterState Forwarded state bound; defaults to LlmRequestState::kGENERATION_COMPLETE.
/// @param enablePrefixAwareScheduling Whether to use KV prefix-reuse estimates in scheduling
/// decisions; defaults to true.
MaxUtilizationScheduler(SizeType32 maxNumRequests, bool twoStepsLookAhead,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE,
bool enablePrefixAwareScheduling = true);
/// @brief Run one scheduling pass over the active requests.
/// @param kvCacheManager KV cache manager, taken by non-const reference (this policy is the only
/// one in this header that does not take it as const).
/// @param crossKvCacheManager Optional cross-attention KV cache manager, also non-const.
/// @param peftCacheManager Optional PEFT cache manager (read-only).
/// @param activeRequests The list of currently active requests.
/// @return A tuple of two RequestVector.
/// NOTE(review): presumably (requests to schedule, requests to pause) -- confirm in the definition.
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
kv_cache_manager::BaseKVCacheManager& kvCacheManager,
OptionalRef<kv_cache_manager::BaseKVCacheManager> crossKvCacheManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
private:
/// Maximum number of requests, as passed to the constructor.
SizeType32 mMaxNumRequests;
/// @brief Boolean that indicates if two step lookahead is enabled
bool mTwoStepsLookAhead;
/// @brief Whether to use KV prefix-reuse estimates in scheduling decisions.
bool mEnablePrefixAwareScheduling;
};
/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
/// @details When a ``crossKvCacheManager`` is supplied, requests in the
/// ``ENCODER_INIT`` state may be admitted for encoder compute
/// without consuming self- or cross-KV blocks. The later
/// ``CONTEXT_INIT`` decoder admission owns cross-pool budgeting.
class GuaranteedNoEvictScheduler : public BaseCapacityScheduler
{
public:
/// @brief Construct the scheduler.
/// @param maxNumRequests Maximum number of requests (stored in mMaxNumRequests).
/// @param noScheduleUntilState Forwarded state bound; defaults to LlmRequestState::kCONTEXT_INIT.
/// @param noScheduleAfterState Forwarded state bound; defaults to LlmRequestState::kGENERATION_COMPLETE.
/// @param enablePrefixAwareScheduling Whether to use KV prefix-reuse estimates in scheduling
/// decisions; defaults to true.
GuaranteedNoEvictScheduler(SizeType32 maxNumRequests,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE,
bool enablePrefixAwareScheduling = true);
/// @brief Run one scheduling pass over the active requests.
/// @param kvCacheManager KV cache manager (const: this policy takes both KV managers read-only).
/// @param crossKvCacheManager Optional cross-attention KV cache manager (const).
/// @param peftCacheManager Optional PEFT cache manager (const).
/// @param activeRequests The list of currently active requests.
/// @return A tuple of two RequestVector.
/// NOTE(review): presumably (requests to schedule, requests to pause) -- confirm in the definition.
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
protected:
/// @brief Implementation template with the same parameters and return type as operator().
/// @tparam StaticBatchScheduling Compile-time switch.
/// NOTE(review): being protected, this is presumably shared with the derived StaticBatchScheduler
/// (instantiated with true there and false here) -- the definition is not visible in this header.
template <bool StaticBatchScheduling>
[[nodiscard]] std::tuple<RequestVector, RequestVector> impl(
kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
private:
/// Maximum number of requests, as passed to the constructor.
SizeType32 mMaxNumRequests;
/// @brief Whether to use KV prefix-reuse estimates in scheduling decisions.
bool mEnablePrefixAwareScheduling;
};
/// @brief Schedule requests using the STATIC_BATCH policy
/// @details Derives from GuaranteedNoEvictScheduler and adds no data members of its own. Its
/// operator() has the same signature as the base class one and is non-virtual, so it hides the
/// base version rather than overriding it; which one runs depends on the static type of the object.
class StaticBatchScheduler : public GuaranteedNoEvictScheduler
{
public:
/// @brief Construct the scheduler; the parameter list and defaults mirror the
/// GuaranteedNoEvictScheduler constructor.
/// @param maxNumRequests Maximum number of requests.
/// @param noScheduleUntilState State bound; defaults to LlmRequestState::kCONTEXT_INIT.
/// @param noScheduleAfterState State bound; defaults to LlmRequestState::kGENERATION_COMPLETE.
/// @param enablePrefixAwareScheduling Defaults to true.
StaticBatchScheduler(SizeType32 maxNumRequests,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE,
bool enablePrefixAwareScheduling = true);
/// @brief Run one scheduling pass over the active requests.
/// @param kvCacheManager KV cache manager (const).
/// @param crossKvCacheManager Optional cross-attention KV cache manager (const).
/// @param peftCacheManager Optional PEFT cache manager (const).
/// @param activeRequests The list of currently active requests.
/// @return A tuple of two RequestVector.
/// NOTE(review): presumably (requests to schedule, requests to pause) -- confirm in the definition.
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
kv_cache_manager::BaseKVCacheManager const& kvCacheManager,
OptionalRef<kv_cache_manager::BaseKVCacheManager const> crossKvCacheManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
};
/// @brief Front-end that holds one of the policy schedulers declared above (in a std::variant)
/// together with an optional request reorder policy.
class CapacityScheduler : public Algorithm
{
public:
/// Name of this algorithm.
constexpr static auto name{"CapacityScheduler"};
/// @brief Construct the scheduler for the given policy.
/// @param maxNumRequests Maximum number of requests.
/// @param capacitySchedulerPolicy The policy to follow.
/// @param hasKvCacheManager Whether a KV cache manager is available.
/// NOTE(review): how this flag and the policy select the stored scheduler is decided in the
/// definition, which is not visible in this header.
/// @param twoStepsLookAhead Two step lookahead flag; defaults to false.
/// @param noScheduleUntilState State bound; defaults to LlmRequestState::kCONTEXT_INIT.
/// @param noScheduleAfterState State bound; defaults to LlmRequestState::kGENERATION_COMPLETE.
/// @param enablePrefixAwareScheduling Defaults to true.
explicit CapacityScheduler(SizeType32 maxNumRequests, executor::CapacitySchedulerPolicy capacitySchedulerPolicy,
bool hasKvCacheManager, bool twoStepsLookAhead = false,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE,
bool enablePrefixAwareScheduling = true);
/**
* @brief Schedules requests following the selected policy.
*
* Note that the positional argument order here is (activeRequests, kvCacheManager, peftCacheManager,
* crossKvCacheManager), which differs from the policy functors above, where the order is
* (kvCacheManager, crossKvCacheManager, peftCacheManager, activeRequests).
*
* @param activeRequests The list of currently active requests to schedule from.
* @param kvCacheManager Required in MaxUtilizationScheduler (as a ref) and in GuaranteedNoEvictScheduler and
* StaticBatchScheduler (as a const ref). Defaults to std::nullopt.
* @param peftCacheManager Optional, used in MaxUtilizationScheduler, GuaranteedNoEvictScheduler and
* StaticBatchScheduler. Defaults to std::nullopt.
* @param crossKvCacheManager Optional cross-attention KV cache manager. Used by
* MaxUtilizationScheduler (mutates: ``startScheduling`` / ``schedulingRemoveSequence``)
* and GuaranteedNoEvictScheduler / StaticBatchScheduler (read-only). Required for
* encoder-decoder admission. Encoder-init requests only require this pool
* to be configured; decoder context admission budgets blocks from it. Defaults to std::nullopt.
* @return std::tuple<RequestVector, RequestVector, RequestVector>: fittingRequests, fittingDisaggInitRequests and
* pausedRequests respectively.
* NOTE(review): the element order is taken from the previous wording of this comment -- confirm against
* the definition.
*/
[[nodiscard]] std::tuple<RequestVector, RequestVector, RequestVector> operator()(RequestList const& activeRequests,
OptionalRef<kv_cache_manager::BaseKVCacheManager> kvCacheManager = std::nullopt,
OptionalRef<BasePeftCacheManager const> peftCacheManager = std::nullopt,
OptionalRef<kv_cache_manager::BaseKVCacheManager> crossKvCacheManager = std::nullopt) const;
/// @brief Sets the reorder policy to use AgentTreePolicy with the given configuration.
/// @param agentPercentage The ratio of agent requests to schedule (0.0-1.0, -1.0 for random).
/// @param agentTypes The list of agent types to schedule.
/// @param agentInflightSeqNum The maximum number of inflight sequences for agent requests.
void setAgentTreeReorderPolicy(
float agentPercentage, std::optional<std::vector<std::string>> agentTypes, SizeType32 agentInflightSeqNum);
private:
/// The policy scheduler in use. std::monostate is the first alternative, so a default-constructed
/// variant holds no scheduler.
/// NOTE(review): when (if ever) the monostate alternative remains active after construction is
/// decided in the constructor definition -- not visible here.
std::variant<std::monostate, MaxRequestsScheduler, MaxUtilizationScheduler, GuaranteedNoEvictScheduler,
StaticBatchScheduler>
mScheduler;
/// Optional reorder policy for reordering requests before scheduling.
std::unique_ptr<ReorderPolicy> mReorderPolicy;
};
} // namespace tensorrt_llm::batch_manager