forked from microsoft/onnxruntime-genai
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecoder_only_pipeline.cpp
More file actions
454 lines (390 loc) · 20.8 KB
/
decoder_only_pipeline.cpp
File metadata and controls
454 lines (390 loc) · 20.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "../generators.h"
#include "../logging.h"
#include "../tracing.h"
#include "decoder_only_pipeline.h"
#include "windowed_kv_cache.h"
namespace Generators {
// Builds the pipeline model. One ONNX Runtime session is created per configured pipeline model,
// in config order, so that sessions_[i] corresponds to config_->model.decoder.pipeline[i].
// Once every session exists, each one is registered with session_info_.
DecoderOnlyPipelineModel::DecoderOnlyPipelineModel(std::unique_ptr<Config> config, OrtEnv& ort_env)
    : Model{std::move(config)}, ort_env_{ort_env} {
  // `config` has been moved into the base; read the pipeline description through config_.
  const auto& pipeline_models = config_->model.decoder.pipeline;
  for (const auto& pipeline_model : pipeline_models) {
    sessions_.emplace_back(
        CreateSession(ort_env, pipeline_model.filename, GetSessionOptions(pipeline_model.model_id)));
  }

  // Register sessions only after all of them have been created.
  for (const auto& created_session : sessions_) {
    session_info_.Add(*created_session);
  }
}
// Creates the generation state for this model. The returned state keeps a reference to *this,
// so every state created here shares this model's sessions and config.
std::unique_ptr<State> DecoderOnlyPipelineModel::CreateState(DeviceSpan<int32_t> sequence_lengths,
                                                             const GeneratorParams& params) const {
  std::unique_ptr<State> pipeline_state = std::make_unique<DecoderOnlyPipelineState>(*this, sequence_lengths, params);
  return pipeline_state;
}
IntermediatePipelineState::IntermediatePipelineState(const DecoderOnlyPipelineModel& model, const GeneratorParams& params,
size_t pipeline_state_index)
: State{params, model},
id_{pipeline_state_index},
model_{model} {}
// Returns true if `name` is listed among the inputs declared in the config for this pipeline model.
bool IntermediatePipelineState::HasInput(std::string_view name) const {
  const auto& declared_inputs = model_.config_->model.decoder.pipeline[id_].inputs;
  return std::find(declared_inputs.begin(), declared_inputs.end(), name) != declared_inputs.end();
}
// Returns true if `name` is listed among the outputs declared in the config for this pipeline model.
bool IntermediatePipelineState::HasOutput(std::string_view name) const {
  const auto& declared_outputs = model_.config_->model.decoder.pipeline[id_].outputs;
  return std::find(declared_outputs.begin(), declared_outputs.end(), name) != declared_outputs.end();
}
bool IntermediatePipelineState::SupportsPrimaryDevice() const {
if (model_.p_device_->GetType() == DeviceType::CPU || model_.p_device_->GetType() == DeviceType::QNN) {
return true;
} else if (model_.p_device_->GetType() == DeviceType::CUDA) {
if (!model_.config_->model.decoder.pipeline[id_].session_options.has_value()) {
// No session options, so this session uses the default session options.
// Default session options supports the cuda device type.
return true;
} else if (auto& provider_options = (*model_.config_->model.decoder.pipeline[id_].session_options).provider_options;
std::any_of(provider_options.begin(), provider_options.end(),
[](const Config::ProviderOptions& elem) { return elem.name == "cuda"; })) {
// cuda is listed as one of the providers. This session supports the cuda device type.
return true;
} else {
// cuda is not listed as one of the providers. This session does not support the cuda device type.
return false;
}
}
return false;
}
// Runs this pipeline model's session with the inputs/outputs already wired into this State.
// If the session was released (see reset_session_idx in DecoderOnlyPipelineState::RunPipeline),
// it is recreated from config_path / filename with the model's session options.
// total_length, next_tokens and next_indices are unused here; they exist to match the State::Run signature.
// Always returns an empty span: the logits are owned by the enclosing DecoderOnlyPipelineState.
DeviceSpan<float> IntermediatePipelineState::Run(int total_length, DeviceSpan<int32_t>& next_tokens,
                                                 DeviceSpan<int32_t> next_indices) {
  // sessions_ lives on the (const) model; lazily recreating a released session requires mutating it.
  auto& stage_session = const_cast<DecoderOnlyPipelineModel&>(model_).sessions_[id_];
  if (!stage_session) {
    const auto& stage_config = model_.config_->model.decoder.pipeline[id_];
    const auto stage_model_path = model_.config_->config_path / fs::path(stage_config.filename);
    stage_session = OrtSession::Create(model_.ort_env_, stage_model_path.c_str(),
                                       model_.GetSessionOptions(stage_config.model_id));
  }
  State::Run(*stage_session);
  return {};
}
using NameToLayerIdxMap = std::unordered_map<std::string, size_t>;

// Maps each per-layer tensor name produced by expanding `name_template` with ComposeKeyValueName
// for layers [0, num_layers) to its layer index.
// Shared by the past-key and present-key builders below, which previously duplicated this loop.
static NameToLayerIdxMap GenerateKeyValueNameToLayerIdxMap(const std::string& name_template, size_t num_layers) {
  NameToLayerIdxMap m{};
  m.reserve(num_layers);  // exactly one entry per layer; avoid rehashing while filling
  for (size_t i = 0; i < num_layers; ++i) {
    m.emplace(ComposeKeyValueName(name_template, static_cast<int>(i)), i);
  }
  return m;
}

// Maps every past-key input name (config.model.decoder.inputs.past_key_names template) to its layer index.
static NameToLayerIdxMap GeneratePastKeyNameToLayerIdxMap(const Config& config) {
  const size_t num_layers = config.model.decoder.num_hidden_layers;
  return GenerateKeyValueNameToLayerIdxMap(config.model.decoder.inputs.past_key_names, num_layers);
}

// Maps every present-key output name (config.model.decoder.outputs.present_key_names template) to its layer index.
static NameToLayerIdxMap GeneratePresentKeyNameToLayerIdxMap(const Config& config) {
  const size_t num_layers = config.model.decoder.num_hidden_layers;
  return GenerateKeyValueNameToLayerIdxMap(config.model.decoder.outputs.present_key_names, num_layers);
}
static std::vector<size_t> GetLayerIndicesSetFromPastAndPresentKeyNames(
const NameToLayerIdxMap& past_key_name_to_layer_idx, const NameToLayerIdxMap& present_key_name_to_layer_idx,
std::span<const std::string> inputs, std::span<const std::string> outputs) {
std::vector<size_t> layer_indices{};
for (const auto& input_name : inputs) {
const auto it = past_key_name_to_layer_idx.find(input_name);
if (it != past_key_name_to_layer_idx.end()) {
layer_indices.push_back(it->second);
}
}
for (const auto& output_name : outputs) {
const auto it = present_key_name_to_layer_idx.find(output_name);
if (it != present_key_name_to_layer_idx.end()) {
layer_indices.push_back(it->second);
}
}
// sort and remove duplicates
std::sort(layer_indices.begin(), layer_indices.end());
layer_indices.erase(std::unique(layer_indices.begin(), layer_indices.end()),
layer_indices.end());
return layer_indices;
}
static bool ContainsPresentKeyNameOutputs(const NameToLayerIdxMap& present_key_name_to_layer_idx,
std::span<const std::string> outputs) {
for (const auto& output_name : outputs) {
const auto it = present_key_name_to_layer_idx.find(output_name);
if (it != present_key_name_to_layer_idx.end()) {
return true;
}
}
return false;
}
// Builds the decoder-only pipeline state:
//  - creates and registers the managed inputs/outputs (input ids, position inputs, logits and,
//    when the model has one, the key-value cache);
//  - creates one IntermediatePipelineState per configured pipeline model, in config order;
//  - when the KV cache supports partial updates, groups pipeline models by the set of KV cache
//    layers they read or write, so that RunPipeline can update those layers as soon as the
//    producing model finishes.
// Throws std::runtime_error if two pipeline models' layer-index sets partially overlap.
DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineModel& model,
DeviceSpan<int32_t> sequence_lengths,
const GeneratorParams& params)
: State{params, model},
model_{model},
input_ids_{CreateInputIDs(*this)},
key_value_cache_{CreateKeyValueCache(*this)},
// NOTE(review): reads key_value_cache_, so it relies on key_value_cache_ being declared before
// do_key_value_cache_partial_update_ in the class definition (members initialize in declaration order).
do_key_value_cache_partial_update_{key_value_cache_ && key_value_cache_->IsPartialUpdateSupported()},
position_inputs_{CreatePositionInputs(*this, sequence_lengths, model_.config_->model.decoder.inputs.attention_mask)} {
input_ids_->Add();
position_inputs_->Add();
logits_.Add();
if (key_value_cache_) {
key_value_cache_->Add();
}
const auto& config_pipeline = model_.config_->model.decoder.pipeline;
// One intermediate state per pipeline model. The id passed is the state's position
// (pipeline_states_.size() == i here), which IntermediatePipelineState uses to index the config and sessions.
for (size_t i = 0; i < config_pipeline.size(); ++i) {
auto pipeline_model_state = std::make_unique<IntermediatePipelineState>(model_, params, pipeline_states_.size());
pipeline_states_.emplace_back(std::move(pipeline_model_state));
}
if (do_key_value_cache_partial_update_) {
const auto past_key_name_to_layer_idx = GeneratePastKeyNameToLayerIdxMap(*model_.config_);
const auto present_key_name_to_layer_idx = GeneratePresentKeyNameToLayerIdxMap(*model_.config_);
// Sorted, de-duplicated layer-index set -> index into partial_kv_cache_update_records_.
std::map<std::vector<size_t>, size_t> layer_indices_to_update_record_idx{};
// Union of every layer index already assigned to some record (used for the disjointness check).
std::unordered_set<size_t> layer_indices_encountered{};
for (size_t i = 0; i < config_pipeline.size(); ++i) {
const auto& pipeline_model = config_pipeline[i];
const auto layer_indices = GetLayerIndicesSetFromPastAndPresentKeyNames(past_key_name_to_layer_idx, present_key_name_to_layer_idx,
pipeline_model.inputs, pipeline_model.outputs);
// RunPipeline only enqueues a partial update after a model that actually produces KV cache outputs.
// (The member name's spelling comes from the class definition.)
pipeline_states_[i]->constains_kv_cache_output_ = ContainsPresentKeyNameOutputs(present_key_name_to_layer_idx, pipeline_model.outputs);
// Pipeline models that touch no KV cache layers get no update record.
if (layer_indices.empty()) {
continue;
}
size_t record_idx{};
if (auto layer_indices_to_update_record_it = layer_indices_to_update_record_idx.find(layer_indices);
layer_indices_to_update_record_it != layer_indices_to_update_record_idx.end()) {
// we have seen this exact set of layer indices before. reuse the existing record.
record_idx = layer_indices_to_update_record_it->second;
} else {
// verify that the new set of layer indices is valid.
// i.e., it is disjoint with the set of all layer indices we've seen so far.
const bool layer_indices_valid =
std::all_of(layer_indices.begin(), layer_indices.end(),
[&layer_indices_encountered](size_t layer_idx) {
return layer_indices_encountered.find(layer_idx) == layer_indices_encountered.end();
});
if (!layer_indices_valid) {
throw std::runtime_error(
"Invalid layer indices. Layer index sets for partial key value cache update must be either an exact "
"match with another set or disjoint with all other sets.");
}
// add a new record
auto record = PartialKeyValueCacheUpdateRecord{};
record.layer_indices = layer_indices;
partial_kv_cache_update_records_.emplace_back(std::move(record));
record_idx = partial_kv_cache_update_records_.size() - 1;
// add layer_indices to what we've seen so far
layer_indices_encountered.insert(layer_indices.begin(), layer_indices.end());
layer_indices_to_update_record_idx.emplace(layer_indices, record_idx);
}
// Pipeline state i shares record `record_idx` with every other state that has the same layer set.
pipeline_state_id_to_partial_kv_cache_update_record_idx_.emplace(i, record_idx);
}
// Create the worker that RunPipeline enqueues partial KV cache updates onto, only when it will be used.
if (!partial_kv_cache_update_records_.empty()) {
key_value_cache_update_worker_thread_.emplace();
}
}
}
// Registers caller-supplied extra inputs against every pipeline session. For each session,
// extra_inputs_.Add receives the full extra-input list together with that session's input names.
// NOTE(review): entries of model_.sessions_ can be null after RunPipeline releases a session via
// reset_session_idx; this dereferences every entry unconditionally. Confirm it is only called
// before any such reset on the shared model.
void DecoderOnlyPipelineState::SetExtraInputs(const std::vector<ExtraInput>& extra_inputs) {
  const auto& pipeline_sessions = model_.sessions_;
  for (const auto& pipeline_session : pipeline_sessions) {
    extra_inputs_.Add(extra_inputs, pipeline_session->GetInputNames());
  }
}
// Runs, in pipeline order, every pipeline model enabled for the current phase.
// The phase is prompt when first_run_ is true and token generation otherwise.
// For each enabled model it:
//   1. optionally releases the session at its reset_session_idx;
//   2. waits for any outstanding partial KV cache update for the layers this model touches;
//   3. wires in inputs and outputs:
//      - managed inputs/outputs owned by this State;
//      - non-managed values already held in ortvalue_store_ (from earlier pipeline models, or
//        forwarded from earlier runs);
//   4. runs the model, then enqueues a partial KV cache update if it produced KV cache outputs;
//   5. moves ownership of its non-managed outputs into ortvalue_store_, under the forwarded name
//      when output_names_forwarder has an entry.
// Throws std::runtime_error on an out-of-range reset_session_idx, or when a managed input/output is
// needed by a pipeline model whose session does not support the primary device.
void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>& next_tokens,
DeviceSpan<int32_t> next_indices) {
for (auto& pipeline_state : pipeline_states_) {
// Skip models that are not enabled for the current phase (prompt vs. token generation).
if (first_run_ && !model_.config_->model.decoder.pipeline[pipeline_state->id_].run_on_prompt) {
continue;
} else if (!first_run_ && !model_.config_->model.decoder.pipeline[pipeline_state->id_].run_on_token_gen) {
continue;
}
DurationTrace trace{MakeString("DecoderOnlyPipelineState::RunPipeline[", pipeline_state->id_, "]")};
// Release the session at reset_session_idx (an index into model_.sessions_). A released session is
// recreated on demand by IntermediatePipelineState::Run.
// NOTE(review): sessions_ belongs to the model, which every state created by CreateState references;
// this const_cast mutation is not synchronized with other states.
if (model_.config_->model.decoder.pipeline[pipeline_state->id_].reset_session_idx > -1) {
if (model_.config_->model.decoder.pipeline[pipeline_state->id_].reset_session_idx >=
static_cast<int>(model_.sessions_.size())) {
throw std::runtime_error(
MakeString("Invalid reset_session_idx ", model_.config_->model.decoder.pipeline[pipeline_state->id_].reset_session_idx,
" for pipeline model ", model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id));
}
(const_cast<DecoderOnlyPipelineModel*>(&model_))->sessions_[model_.config_->model.decoder.pipeline[pipeline_state->id_].reset_session_idx].reset();
}
// The partial KV cache update record assigned to this model in the constructor, or nullptr if none.
auto* const partial_kv_cache_update_record = [&]() -> PartialKeyValueCacheUpdateRecord* {
auto it = pipeline_state_id_to_partial_kv_cache_update_record_idx_.find(pipeline_state->id_);
if (it != pipeline_state_id_to_partial_kv_cache_update_record_idx_.end()) {
return &partial_kv_cache_update_records_[it->second];
}
return nullptr;
}();
// If there is any outstanding partial KV cache update, wait for it to finish.
// It is important to synchronize at this point, before setting input/output tensors for this pipeline state run,
// because a KV cache update may replace the KV cache input/output tensors.
if (partial_kv_cache_update_record) {
if (partial_kv_cache_update_record->outstanding_update.valid()) {
// get() also invalidates the future, which UpdateKeyValueCache checks via valid().
partial_kv_cache_update_record->outstanding_update.get();
}
}
// Clear the intermediate pipeline state outputs from the previous runs.
// These outputs will be replaced by the outputs from the current run.
// Only this model's own output names are erased; values stored under a forwarded name persist
// (and may be consumed below) until they are overwritten.
for (const auto& output_name : pipeline_state->output_names_) {
if (auto iter = ortvalue_store_.find(output_name); iter != ortvalue_store_.end()) {
ortvalue_store_.erase(iter);
}
}
pipeline_state->ClearIO();
// Managed inputs and outputs are those inputs and outputs that the
// Model knows how to create and update from one run to the next.
// Add all the managed inputs to the intermediate pipeline state
for (const auto& input_name : input_names_) {
if (pipeline_state->HasInput(input_name)) {
if (!pipeline_state->SupportsPrimaryDevice()) {
throw std::runtime_error(
MakeString("Managed input ", input_name, " resides on the primary device type (",
static_cast<int>(model_.p_device_->GetType()), "). But the pipeline model ",
model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id,
" is expecting it to reside elsewhere."));
}
pipeline_state->input_names_.push_back(input_name);
pipeline_state->inputs_.push_back(State::GetInput(input_name));
}
}
// Add outputs from the previous pipeline states to the current pipeline state
// NOTE(review): name.c_str() points into an ortvalue_store_ key. If input_names_ stores raw
// const char*, the entry must not be erased while this model's inputs are in use — confirm.
for (auto& [name, ortvalue] : ortvalue_store_) {
if (pipeline_state->HasInput(name)) {
pipeline_state->input_names_.push_back(name.c_str());
pipeline_state->inputs_.push_back(ortvalue.get());
}
}
// Add all the managed outputs to the intermediate pipeline state
for (const auto& output_name : output_names_) {
if (pipeline_state->HasOutput(output_name)) {
if (!pipeline_state->SupportsPrimaryDevice()) {
throw std::runtime_error(
MakeString("Managed output ", output_name, " resides on the primary device type (",
static_cast<int>(model_.p_device_->GetType()), "). But the pipeline model ",
model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id,
" is expecting it to reside elsewhere."));
}
pipeline_state->output_names_.push_back(output_name);
pipeline_state->outputs_.push_back(State::GetOutput(output_name));
}
}
// Output of pipeline models could also be managed inputs.
// For example, the output of a pipeline model could be the key-value cache.
// In such cases, use the managed output buffers and register them with the pipeline model as outputs.
for (const auto& input_name : input_names_) {
if (pipeline_state->HasOutput(input_name)) {
if (!pipeline_state->SupportsPrimaryDevice()) {
throw std::runtime_error(
MakeString("Managed input ", input_name, " resides on the primary device type (",
static_cast<int>(model_.p_device_->GetType()), "). But the pipeline model ",
model_.config_->model.decoder.pipeline[pipeline_state->id_].model_id,
" is expecting it to reside elsewhere."));
}
pipeline_state->output_names_.push_back(input_name);
pipeline_state->outputs_.push_back(State::GetInput(input_name));
}
}
// Add all the remaining outputs for the intermediate pipeline state
// A nullptr slot presumably lets the session allocate the value during Run; ownership of it is
// taken into ortvalue_store_ below.
for (const auto& output_name : model_.config_->model.decoder.pipeline[pipeline_state->id_].outputs) {
if (std::none_of(pipeline_state->output_names_.begin(), pipeline_state->output_names_.end(),
[&](const std::string& elem) { return elem == output_name; })) {
pipeline_state->output_names_.push_back(output_name.c_str());
pipeline_state->outputs_.push_back(nullptr);
}
}
// Run the intermediate pipeline state
pipeline_state->Run(total_length, next_tokens, next_indices);
// If there is any partial KV cache update to start and if KV cache is present in the output, enqueue it.
// The update job captures the KV cache by reference and the layer indices, next_indices and
// total_length by value. Its future is kept on the record so that the next model sharing this
// record waits on it (above), and so that UpdateKeyValueCache can see it is still pending.
if (partial_kv_cache_update_record && pipeline_state->constains_kv_cache_output_) {
assert(key_value_cache_update_worker_thread_.has_value());
auto update_fn = [&key_value_cache = *key_value_cache_.get(),
layer_indices = partial_kv_cache_update_record->layer_indices,
next_indices, total_length]() {
key_value_cache.PartialUpdate(next_indices, total_length, layer_indices);
};
partial_kv_cache_update_record->outstanding_update = key_value_cache_update_worker_thread_->Enqueue(update_fn);
}
// Transfer ownership of all the non-managed outputs from the current pipeline state to the ortvalue store.
// All non managed outputs are assumed to be on CPU
// NOTE(review): assumes each such outputs_[i] is a value produced by the run that nothing else owns;
// the store's unique_ptr will delete it when the entry is erased or overwritten.
for (size_t i = 0; i < pipeline_state->output_names_.size(); ++i) {
if (std::none_of(output_names_.begin(), output_names_.end(),
[&](const std::string& elem) { return elem == pipeline_state->output_names_[i]; }) &&
std::none_of(input_names_.begin(), input_names_.end(),
[&](const std::string& elem) { return elem == pipeline_state->output_names_[i]; })) {
// A forwarded output is stored under its forwarded name, so a model with an input of that name
// can pick it up from ortvalue_store_.
auto forwarded_output = model_.config_->model.decoder.pipeline[pipeline_state->id_].output_names_forwarder.find(pipeline_state->output_names_[i]);
if (forwarded_output != model_.config_->model.decoder.pipeline[pipeline_state->id_].output_names_forwarder.end()) {
ortvalue_store_[forwarded_output->second] = std::unique_ptr<OrtValue>(pipeline_state->outputs_[i]);
} else {
ortvalue_store_[pipeline_state->output_names_[i]] = std::unique_ptr<OrtValue>(pipeline_state->outputs_[i]);
}
}
}
}
}
// Runs one generation step through the whole pipeline and returns the logits.
//
// On the first (prompt) run with a configured sliding window, the prompt is processed in
// ceil(next_tokens.size() / window_size) chunks. The pipeline runs once per chunk, and the managed
// inputs (input ids, KV cache, position inputs, logits) are slid forward between chunks.
// After each token-generation run, the stored outputs of prompt-only pipeline models are dropped.
//
// Throws std::runtime_error if the sliding window's window_size is not positive.
DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int32_t>& next_tokens,
                                                DeviceSpan<int32_t> next_indices) {
  DurationTrace trace{"DecoderOnlyPipelineState::Run"};
  UpdateInputsOutputs(next_tokens, next_indices, total_length);

  const auto& decoder_config = model_.config_->model.decoder;

  size_t num_chunks{1};
  if (first_run_ && decoder_config.sliding_window.has_value()) {
    const int window_size = decoder_config.sliding_window->window_size;
    // window_size is used as a divisor below. Zero would divide by zero. A negative value would wrap
    // when mixed with the unsigned token count, yielding 0 chunks and silently skipping the pipeline.
    if (window_size <= 0) {
      throw std::runtime_error(MakeString("Invalid sliding window size ", window_size,
                                          ". window_size must be positive."));
    }
    const auto window = static_cast<size_t>(window_size);
    num_chunks = (next_tokens.size() + window - 1) / window;
  }

  for (size_t i = 0; i < num_chunks; ++i) {
    RunPipeline(total_length, next_tokens, next_indices);
    if (decoder_config.sliding_window.has_value() && i < num_chunks - 1) {
      // Sliding the window over the input_ids, key_cache, and value_cache, position_ids, and attention_mask
      input_ids_->Update(next_tokens);
      UpdateKeyValueCache(next_indices, total_length);
      position_inputs_->Update(next_tokens, total_length, static_cast<int>(input_ids_->GetShape()[1]));
      logits_.Update(WrapTensor<int32_t>(*model_.p_device_inputs_, *input_ids_->Get()),
                     static_cast<int>(input_ids_->GetShape()[1]));
    }
  }

  // Prompt-only pipeline models are not re-run during token generation, so their outputs stay in
  // ortvalue_store_ after the prompt run. They are dropped only after a token-generation run, not
  // right after the prompt, so the first token-generation step can still consume them.
  if (!first_run_) {
    for (auto& pipeline_state : pipeline_states_) {
      if (decoder_config.pipeline[pipeline_state->id_].run_on_token_gen) {
        continue;
      }
      for (const auto& output_name : pipeline_state->output_names_) {
        if (auto iter = ortvalue_store_.find(output_name); iter != ortvalue_store_.end()) {
          ortvalue_store_.erase(iter);
        }
      }
    }
  }

  first_run_ = false;
  return logits_.Get();
}
// Applies the full key-value cache update for this step, if the model has a KV cache.
// The full update is skipped when partial updates are enabled and any partial update record still
// holds a valid (not yet consumed) future. RunPipeline waits on those futures before the next use.
void DecoderOnlyPipelineState::UpdateKeyValueCache(DeviceSpan<int32_t> beam_indices, int total_length) {
  if (!key_value_cache_) {
    return;
  }
  if (do_key_value_cache_partial_update_) {
    const bool has_pending_partial_update =
        std::any_of(partial_kv_cache_update_records_.begin(), partial_kv_cache_update_records_.end(),
                    [](const PartialKeyValueCacheUpdateRecord& pending) { return pending.outstanding_update.valid(); });
    if (has_pending_partial_update) {
      return;
    }
  }
  key_value_cache_->Update(beam_indices, total_length);
}
// Advances the managed inputs/outputs for the next step. input_ids is updated first because its
// new sequence length (shape dimension 1) drives the position inputs and the logits update.
// The KV cache update happens in between.
void DecoderOnlyPipelineState::UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens,
                                                   DeviceSpan<int32_t> beam_indices, int total_length) {
  input_ids_->Update(next_tokens);
  const size_t updated_length = input_ids_->GetShape()[1];

  position_inputs_->Update(next_tokens, total_length, static_cast<int>(updated_length));
  UpdateKeyValueCache(beam_indices, total_length);

  logits_.Update(WrapTensor<int32_t>(*model_.p_device_inputs_, *input_ids_->Get()), updated_length);
}
// Looks up an output by name. Non-managed outputs that RunPipeline moved into ortvalue_store_ take
// precedence; otherwise the lookup falls back to the managed outputs held by the base State.
OrtValue* DecoderOnlyPipelineState::GetOutput(const char* name) {
  if (const auto stored = ortvalue_store_.find(name); stored != ortvalue_store_.end()) {
    return stored->second.get();
  }
  return State::GetOutput(name);
}
} // namespace Generators