Skip to content

Commit 96dfb53

Browse files
authored
Merge pull request #1 from CodeLinaro/chilukam/multimodel-decoding-pipeline
Multi Model Decoder Modification
2 parents 5719c97 + 1b73c32 commit 96dfb53

6 files changed

Lines changed: 100 additions & 30 deletions

src/models/decoder_only_pipeline.cpp

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,22 +88,50 @@ static NameToLayerIdxMap GeneratePastKeyNameToLayerIdxMap(const Config& config)
8888
return m;
8989
}
9090

91-
static std::vector<size_t> GetLayerIndicesSetFromPastKeyNameInputs(
92-
const NameToLayerIdxMap& past_key_name_to_layer_idx, std::span<const std::string> inputs) {
91+
static NameToLayerIdxMap GeneratePresentKeyNameToLayerIdxMap(const Config& config) {
92+
const size_t num_layers = config.model.decoder.num_hidden_layers;
93+
const std::string& present_key_name_template = config.model.decoder.outputs.present_key_names;
94+
NameToLayerIdxMap m{};
95+
for (size_t i = 0; i < num_layers; ++i) {
96+
m.emplace(ComposeKeyValueName(present_key_name_template, static_cast<int>(i)), i);
97+
}
98+
return m;
99+
}
100+
101+
static std::vector<size_t> GetLayerIndicesSetFromPastAndPresentKeyNames(
102+
const NameToLayerIdxMap& past_key_name_to_layer_idx, const NameToLayerIdxMap& present_key_name_to_layer_idx,
103+
std::span<const std::string> inputs, std::span<const std::string> outputs) {
93104
std::vector<size_t> layer_indices{};
94105
for (const auto& input_name : inputs) {
95106
const auto it = past_key_name_to_layer_idx.find(input_name);
96107
if (it != past_key_name_to_layer_idx.end()) {
97108
layer_indices.push_back(it->second);
98109
}
99110
}
111+
for (const auto& output_name : outputs) {
112+
const auto it = present_key_name_to_layer_idx.find(output_name);
113+
if (it != present_key_name_to_layer_idx.end()) {
114+
layer_indices.push_back(it->second);
115+
}
116+
}
100117
// sort and remove duplicates
101118
std::sort(layer_indices.begin(), layer_indices.end());
102119
layer_indices.erase(std::unique(layer_indices.begin(), layer_indices.end()),
103120
layer_indices.end());
104121
return layer_indices;
105122
}
106123

124+
static bool ContainsPresentKeyNameOutputs(const NameToLayerIdxMap& present_key_name_to_layer_idx,
125+
std::span<const std::string> outputs) {
126+
for (const auto& output_name : outputs) {
127+
const auto it = present_key_name_to_layer_idx.find(output_name);
128+
if (it != present_key_name_to_layer_idx.end()) {
129+
return true;
130+
}
131+
}
132+
return false;
133+
}
134+
107135
DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineModel& model,
108136
DeviceSpan<int32_t> sequence_lengths,
109137
const GeneratorParams& params)
@@ -129,15 +157,18 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode
129157

130158
if (do_key_value_cache_partial_update_) {
131159
const auto past_key_name_to_layer_idx = GeneratePastKeyNameToLayerIdxMap(*model_.config_);
160+
const auto present_key_name_to_layer_idx = GeneratePresentKeyNameToLayerIdxMap(*model_.config_);
132161

133162
std::map<std::vector<size_t>, size_t> layer_indices_to_update_record_idx{};
134163
std::unordered_set<size_t> layer_indices_encountered{};
135164

136165
for (size_t i = 0; i < config_pipeline.size(); ++i) {
137166
const auto& pipeline_model = config_pipeline[i];
138167

139-
const auto layer_indices = GetLayerIndicesSetFromPastKeyNameInputs(past_key_name_to_layer_idx,
140-
pipeline_model.inputs);
168+
const auto layer_indices = GetLayerIndicesSetFromPastAndPresentKeyNames(past_key_name_to_layer_idx, present_key_name_to_layer_idx,
169+
pipeline_model.inputs, pipeline_model.outputs);
170+
171+
pipeline_states_[i]->constains_kv_cache_output_ = ContainsPresentKeyNameOutputs(present_key_name_to_layer_idx, pipeline_model.outputs);
141172

142173
if (layer_indices.empty()) {
143174
continue;
@@ -308,8 +339,8 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
308339
// Run the intermediate pipeline state
309340
pipeline_state->Run(total_length, next_tokens, next_indices);
310341

311-
// If there is any partial KV cache update to start, enqueue it.
312-
if (partial_kv_cache_update_record) {
342+
// If there is any partial KV cache update to start and if KV cache is present in the output, enqueue it.
343+
if (partial_kv_cache_update_record && pipeline_state->constains_kv_cache_output_) {
313344
assert(key_value_cache_update_worker_thread_.has_value());
314345
auto update_fn = [&key_value_cache = *key_value_cache_.get(),
315346
layer_indices = partial_kv_cache_update_record->layer_indices,

src/models/decoder_only_pipeline.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ struct IntermediatePipelineState : State {
4747
bool SupportsPrimaryDevice() const;
4848

4949
size_t id_;
50+
bool constains_kv_cache_output_ = false; // Specifies if the intermediate pipeline state output contains present KV cache
5051

5152
private:
5253
const DecoderOnlyPipelineModel& model_;

src/models/embeddings.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) {
103103

104104
const uint16_t* full_data = full_embeddings->GetTensorData<uint16_t>();
105105

106+
const int mem_copy_factor = (type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) ? 2 : 1;
107+
106108
if (window_index_ == 0) {
107109
num_windows_ = (sequence_length + window_size_ - 1) / window_size_;
108110
shape_ = {
@@ -114,7 +116,7 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) {
114116
embeddings_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
115117
std::copy_n(
116118
full_data,
117-
window_size_ * hidden_size * 2,
119+
window_size_ * hidden_size * mem_copy_factor,
118120
embeddings_->GetTensorMutableData<uint16_t>());
119121

120122
} else if (window_index_ < num_windows_) {
@@ -125,17 +127,17 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) {
125127
};
126128
embeddings_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
127129
std::copy_n(
128-
full_data + window_index_ * window_size_ * hidden_size * 2,
129-
window_size_ * hidden_size * 2,
130+
full_data + window_index_ * window_size_ * hidden_size * mem_copy_factor,
131+
window_size_ * hidden_size * mem_copy_factor,
130132
embeddings_->GetTensorMutableData<uint16_t>());
131133

132134
} else {
133135
// Final token case (e.g., generated token)
134136
shape_ = {1, 1, hidden_size};
135137
embeddings_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
136138
std::copy_n(
137-
full_data + (sequence_length - 1) * hidden_size * 2,
138-
hidden_size * 2,
139+
full_data + (sequence_length - 1) * hidden_size * mem_copy_factor,
140+
hidden_size * mem_copy_factor,
139141
embeddings_->GetTensorMutableData<uint16_t>());
140142

141143
}

src/models/multi_decoder_pipeline_modal.cpp

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ int64_t GetNumAudioTokens(const std::vector<ExtraInput>& extra_inputs,
4444
return 0;
4545
}
4646

47-
int64_t GetImageFeatureBatchSize(const std::vector<ExtraInput>& extra_inputs) {
47+
int64_t GetImageFeatureBatchSize(const std::vector<ExtraInput>& extra_inputs,
48+
const std::string& pixel_values_name) {
4849
for (size_t i = 0; i < extra_inputs.size(); ++i) {
49-
if (extra_inputs[i].name == Config::Defaults::PixelValuesName) {
50+
if (extra_inputs[i].name == pixel_values_name) {
5051
assert(extra_inputs[i].tensor->ort_tensor_);
5152
const auto num_dims = extra_inputs[i].tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape().size();
5253
if (num_dims < 3) {
@@ -119,7 +120,7 @@ void VisionPipelineState::SetExtraInputs(const std::vector<ExtraInput>& extra_in
119120
model_.config_->model.vision.outputs.image_features,
120121
num_images_, num_image_tokens_);
121122
for (const auto& ei : extra_inputs) {
122-
if (ei.name == "pixel_values") {
123+
if (ei.name == model_.config_->model.vision.inputs.pixel_values) {
123124
pixel_values_tensor_ = ei.tensor;
124125
break;
125126
}
@@ -192,7 +193,7 @@ DeviceSpan<float> VisionPipelineState::Run(int current_length,
192193

193194
for (int64_t i = 0; i < total_images; ++i) {
194195
auto pixel_values_i = MakeSingleImagePixelValues(pixel_values_tensor_, i, model_.p_device_);
195-
extra_inputs_.Replace("pixel_values", pixel_values_i);
196+
extra_inputs_.Replace(model_.config_->model.vision.inputs.pixel_values, pixel_values_i);
196197

197198
State::Run(*model_.vision_session_);
198199

@@ -338,22 +339,50 @@ static NameToLayerIdxMap GeneratePastKeyNameToLayerIdxMap(const Config& config)
338339
return m;
339340
}
340341

341-
static std::vector<size_t> GetLayerIndicesSetFromPastKeyNameInputs(
342-
const NameToLayerIdxMap& past_key_name_to_layer_idx, std::span<const std::string> inputs) {
342+
static NameToLayerIdxMap GeneratePresentKeyNameToLayerIdxMap(const Config& config) {
343+
const size_t num_layers = config.model.decoder.num_hidden_layers;
344+
const std::string& present_key_name_template = config.model.decoder.outputs.present_key_names;
345+
NameToLayerIdxMap m{};
346+
for (size_t i = 0; i < num_layers; ++i) {
347+
m.emplace(ComposeKeyValueName(present_key_name_template, static_cast<int>(i)), i);
348+
}
349+
return m;
350+
}
351+
352+
static std::vector<size_t> GetLayerIndicesSetFromPastAndPresentKeyNames(
353+
const NameToLayerIdxMap& past_key_name_to_layer_idx, const NameToLayerIdxMap& present_key_name_to_layer_idx,
354+
std::span<const std::string> inputs, std::span<const std::string> outputs) {
343355
std::vector<size_t> layer_indices{};
344356
for (const auto& input_name : inputs) {
345357
const auto it = past_key_name_to_layer_idx.find(input_name);
346358
if (it != past_key_name_to_layer_idx.end()) {
347359
layer_indices.push_back(it->second);
348360
}
349361
}
362+
for (const auto& output_name : outputs) {
363+
const auto it = present_key_name_to_layer_idx.find(output_name);
364+
if (it != present_key_name_to_layer_idx.end()) {
365+
layer_indices.push_back(it->second);
366+
}
367+
}
350368
// sort and remove duplicates
351369
std::sort(layer_indices.begin(), layer_indices.end());
352370
layer_indices.erase(std::unique(layer_indices.begin(), layer_indices.end()),
353371
layer_indices.end());
354372
return layer_indices;
355373
}
356374

375+
static bool ContainsPresentKeyNameOutputs(const NameToLayerIdxMap& present_key_name_to_layer_idx,
376+
std::span<const std::string> outputs) {
377+
for (const auto& output_name : outputs) {
378+
const auto it = present_key_name_to_layer_idx.find(output_name);
379+
if (it != present_key_name_to_layer_idx.end()) {
380+
return true;
381+
}
382+
}
383+
return false;
384+
}
385+
357386
DecoderPipelineState::DecoderPipelineState(const MultiModalPipelineLanguageModel& model,
358387
DeviceSpan<int32_t> sequence_lengths,
359388
const GeneratorParams& params)
@@ -384,15 +413,18 @@ DecoderPipelineState::DecoderPipelineState(const MultiModalPipelineLanguageModel
384413

385414
if (do_key_value_cache_partial_update_) {
386415
const auto past_key_name_to_layer_idx = GeneratePastKeyNameToLayerIdxMap(*model_.config_);
416+
const auto present_key_name_to_layer_idx = GeneratePresentKeyNameToLayerIdxMap(*model_.config_);
387417

388418
std::map<std::vector<size_t>, size_t> layer_indices_to_update_record_idx{};
389419
std::unordered_set<size_t> layer_indices_encountered{};
390420

391421
for (size_t i = 0; i < config_pipeline.size(); ++i) {
392422
const auto& pipeline_model = config_pipeline[i];
393423

394-
const auto layer_indices = GetLayerIndicesSetFromPastKeyNameInputs(past_key_name_to_layer_idx,
395-
pipeline_model.inputs);
424+
const auto layer_indices = GetLayerIndicesSetFromPastAndPresentKeyNames(past_key_name_to_layer_idx, present_key_name_to_layer_idx,
425+
pipeline_model.inputs, pipeline_model.outputs);
426+
427+
pipeline_states_[i]->constains_kv_cache_output_ = ContainsPresentKeyNameOutputs(present_key_name_to_layer_idx, pipeline_model.outputs);
396428

397429
if (layer_indices.empty()) {
398430
continue;
@@ -560,8 +592,8 @@ void DecoderPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>& ne
560592
// Run the intermediate pipeline state
561593
pipeline_state->Run(total_length, next_tokens, next_indices);
562594

563-
// If there is any partial KV cache update to start, enqueue it.
564-
if (partial_kv_cache_update_record) {
595+
// If there is any partial KV cache update to start and if KV cache is present in the output, enqueue it.
596+
if (partial_kv_cache_update_record && pipeline_state->constains_kv_cache_output_) {
565597
assert(key_value_cache_update_worker_thread_.has_value());
566598
auto update_fn = [&key_value_cache = *key_value_cache_.get(),
567599
layer_indices = partial_kv_cache_update_record->layer_indices,
@@ -702,7 +734,7 @@ MultiModalDecoderPipelineState::MultiModalDecoderPipelineState(const MultiModalP
702734
void MultiModalDecoderPipelineState::SetExtraInputs(const std::vector<ExtraInput>& extra_inputs) {
703735
num_image_tokens_ = GetNumImageTokens(extra_inputs);
704736
num_audio_tokens_ = GetNumAudioTokens(extra_inputs, model_.config_->model.speech.inputs.audio_sizes);
705-
num_images_ = GetImageFeatureBatchSize(extra_inputs);
737+
num_images_ = GetImageFeatureBatchSize(extra_inputs, model_.config_->model.vision.inputs.pixel_values);
706738

707739
if (model_.vision_session_) {
708740
vision_state_->SetExtraInputs(extra_inputs, num_images_, num_image_tokens_);

src/models/multi_decoder_pipeline_modal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ struct IntermediateDecoderPipelineState : State {
105105
bool SupportsPrimaryDevice() const;
106106

107107
size_t id_;
108+
bool constains_kv_cache_output_ = false; // Specifies if the intermediate pipeline state output contains present KV cache
108109

109110
private:
110111
const MultiModalPipelineLanguageModel& model_;

src/models/windowed_kv_cache.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) {
140140
const auto& layer_state = per_layer_states_[layer_idx];
141141

142142
const auto window_size = layer_state.window_size;
143+
const auto seq_len = layer_state.window_index * layer_state.window_size;
143144
const auto& key_cache_shape_in = layer_state.key_cache_shape_in;
144145
const auto& key_cache_shape_out = layer_state.key_cache_shape_out;
145146
const auto& value_cache_shape_in = layer_state.value_cache_shape_in;
@@ -151,10 +152,10 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) {
151152
int64_t num_key_cache_chunks = key_cache_shape_in[0] * key_cache_shape_in[2];
152153
for (int64_t j = 0; j < num_key_cache_chunks; ++j) {
153154
{
154-
cpu_span<uint8_t> key_cache_dst(key_cache_in_data + j * key_cache_shape_in[3],
155-
key_cache_shape_in[3] - window_size);
156-
cpu_span<uint8_t> key_cache_src(key_cache_in_data + j * key_cache_shape_in[3] + window_size,
157-
key_cache_shape_in[3] - window_size);
155+
cpu_span<uint8_t> key_cache_dst(key_cache_in_data + j * key_cache_shape_in[3] + key_cache_shape_in[3] - seq_len - window_size,
156+
seq_len);
157+
cpu_span<uint8_t> key_cache_src(key_cache_in_data + j * key_cache_shape_in[3] + key_cache_shape_in[3] - seq_len,
158+
seq_len);
158159
std::copy(key_cache_src.begin(), key_cache_src.end(), key_cache_dst.begin());
159160
}
160161
{
@@ -171,11 +172,12 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) {
171172

172173
for (int64_t j = 0; j < value_cache_shape_in[0]; ++j) {
173174
{
174-
cpu_span<uint8_t> value_cache_dst(value_cache_in_data + (j * value_cache_shape_in[2] * value_cache_shape_in[3]),
175-
(value_cache_shape_in[2] - window_size) * value_cache_shape_in[3]);
175+
cpu_span<uint8_t> value_cache_dst(value_cache_in_data + (j * value_cache_shape_in[2] * value_cache_shape_in[3]) +
176+
((value_cache_shape_in[2] - seq_len - window_size) * value_cache_shape_in[3]),
177+
seq_len * value_cache_shape_in[3]);
176178
cpu_span<uint8_t> value_cache_src(value_cache_in_data + (j * value_cache_shape_in[2] * value_cache_shape_in[3]) +
177-
(window_size * value_cache_shape_in[3]),
178-
(value_cache_shape_in[2] - window_size) * value_cache_shape_in[3]);
179+
((value_cache_shape_in[2] - seq_len) * value_cache_shape_in[3]),
180+
seq_len * value_cache_shape_in[3]);
179181
std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin());
180182
}
181183
{
@@ -287,6 +289,7 @@ void WindowedKeyValueCache::TransitionLayerToTokenGeneration(size_t layer_idx) {
287289
value_caches_out_[layer_idx] = OrtValue::CreateTensor(Allocator(), updated_value_cache_shape_out, type_);
288290

289291
// update values in per-layer state
292+
layer_state.window_index = layer_state.window_index * layer_state.window_size / updated_window_size;
290293
layer_state.window_size = updated_window_size;
291294
layer_state.key_cache_shape_in = updated_key_cache_shape_in;
292295
layer_state.value_cache_shape_in = updated_value_cache_shape_in;

0 commit comments

Comments
 (0)