diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index 921a5ba7d7..9118ee59db 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -88,8 +88,19 @@ static NameToLayerIdxMap GeneratePastKeyNameToLayerIdxMap(const Config& config) return m; } -static std::vector GetLayerIndicesSetFromPastKeyNameInputs( - const NameToLayerIdxMap& past_key_name_to_layer_idx, std::span inputs) { +static NameToLayerIdxMap GeneratePresentKeyNameToLayerIdxMap(const Config& config) { + const size_t num_layers = config.model.decoder.num_hidden_layers; + const std::string& present_key_name_template = config.model.decoder.outputs.present_key_names; + NameToLayerIdxMap m{}; + for (size_t i = 0; i < num_layers; ++i) { + m.emplace(ComposeKeyValueName(present_key_name_template, static_cast(i)), i); + } + return m; +} + +static std::vector GetLayerIndicesSetFromPastAndPresentKeyNames( + const NameToLayerIdxMap& past_key_name_to_layer_idx, const NameToLayerIdxMap& present_key_name_to_layer_idx, + std::span inputs, std::span outputs) { std::vector layer_indices{}; for (const auto& input_name : inputs) { const auto it = past_key_name_to_layer_idx.find(input_name); @@ -97,6 +108,12 @@ static std::vector GetLayerIndicesSetFromPastKeyNameInputs( layer_indices.push_back(it->second); } } + for (const auto& output_name : outputs) { + const auto it = present_key_name_to_layer_idx.find(output_name); + if (it != present_key_name_to_layer_idx.end()) { + layer_indices.push_back(it->second); + } + } // sort and remove duplicates std::sort(layer_indices.begin(), layer_indices.end()); layer_indices.erase(std::unique(layer_indices.begin(), layer_indices.end()), @@ -104,6 +121,17 @@ static std::vector GetLayerIndicesSetFromPastKeyNameInputs( return layer_indices; } +static bool ContainsPresentKeyNameOutputs(const NameToLayerIdxMap& present_key_name_to_layer_idx, + std::span outputs) { + for (const auto& output_name : outputs) { + const 
auto it = present_key_name_to_layer_idx.find(output_name); + if (it != present_key_name_to_layer_idx.end()) { + return true; + } + } + return false; +} + DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineModel& model, DeviceSpan sequence_lengths, const GeneratorParams& params) @@ -129,6 +157,7 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode if (do_key_value_cache_partial_update_) { const auto past_key_name_to_layer_idx = GeneratePastKeyNameToLayerIdxMap(*model_.config_); + const auto present_key_name_to_layer_idx = GeneratePresentKeyNameToLayerIdxMap(*model_.config_); std::map, size_t> layer_indices_to_update_record_idx{}; std::unordered_set layer_indices_encountered{}; @@ -136,8 +165,10 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode for (size_t i = 0; i < config_pipeline.size(); ++i) { const auto& pipeline_model = config_pipeline[i]; - const auto layer_indices = GetLayerIndicesSetFromPastKeyNameInputs(past_key_name_to_layer_idx, - pipeline_model.inputs); + const auto layer_indices = GetLayerIndicesSetFromPastAndPresentKeyNames(past_key_name_to_layer_idx, present_key_name_to_layer_idx, + pipeline_model.inputs, pipeline_model.outputs); + + pipeline_states_[i]->contains_kv_cache_output_ = ContainsPresentKeyNameOutputs(present_key_name_to_layer_idx, pipeline_model.outputs); if (layer_indices.empty()) { continue; @@ -308,8 +339,8 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan // Run the intermediate pipeline state pipeline_state->Run(total_length, next_tokens, next_indices); - // If there is any partial KV cache update to start, enqueue it. - if (partial_kv_cache_update_record) { + // If there is any partial KV cache update to start and if KV cache is present in the output, enqueue it. 
+ if (partial_kv_cache_update_record && pipeline_state->contains_kv_cache_output_) { assert(key_value_cache_update_worker_thread_.has_value()); auto update_fn = [&key_value_cache = *key_value_cache_.get(), layer_indices = partial_kv_cache_update_record->layer_indices, diff --git a/src/models/decoder_only_pipeline.h b/src/models/decoder_only_pipeline.h index 167bec9c56..2968a7b7dd 100644 --- a/src/models/decoder_only_pipeline.h +++ b/src/models/decoder_only_pipeline.h @@ -47,6 +47,7 @@ struct IntermediatePipelineState : State { bool SupportsPrimaryDevice() const; size_t id_; + bool contains_kv_cache_output_ = false; // Specifies if the intermediate pipeline state output contains present kv cache private: const DecoderOnlyPipelineModel& model_; diff --git a/src/models/embeddings.cpp b/src/models/embeddings.cpp index b452734e96..69a6ddade1 100644 --- a/src/models/embeddings.cpp +++ b/src/models/embeddings.cpp @@ -103,6 +103,8 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) { const uint16_t* full_data = full_embeddings->GetTensorData(); + const int mem_copy_factor = (type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) ? 
2 : 1; + if (window_index_ == 0) { num_windows_ = (sequence_length + window_size_ - 1) / window_size_; shape_ = { @@ -114,7 +116,7 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) { embeddings_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_); std::copy_n( full_data, - window_size_ * hidden_size * 2, + window_size_ * hidden_size * mem_copy_factor, embeddings_->GetTensorMutableData()); } else if (window_index_ < num_windows_) { @@ -125,8 +127,8 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) { }; embeddings_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_); std::copy_n( - full_data + window_index_ * window_size_ * hidden_size * 2, - window_size_ * hidden_size * 2, + full_data + window_index_ * window_size_ * hidden_size * mem_copy_factor, + window_size_ * hidden_size * mem_copy_factor, embeddings_->GetTensorMutableData()); } else { @@ -134,8 +136,8 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) { shape_ = {1, 1, hidden_size}; embeddings_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_); std::copy_n( - full_data + (sequence_length - 1) * hidden_size * 2, - hidden_size * 2, + full_data + (sequence_length - 1) * hidden_size * mem_copy_factor, + hidden_size * mem_copy_factor, embeddings_->GetTensorMutableData()); } diff --git a/src/models/multi_decoder_pipeline_modal.cpp b/src/models/multi_decoder_pipeline_modal.cpp index c6b7e44ad2..e89e9275af 100644 --- a/src/models/multi_decoder_pipeline_modal.cpp +++ b/src/models/multi_decoder_pipeline_modal.cpp @@ -44,9 +44,10 @@ int64_t GetNumAudioTokens(const std::vector& extra_inputs, return 0; } -int64_t GetImageFeatureBatchSize(const std::vector& extra_inputs) { +int64_t GetImageFeatureBatchSize(const std::vector& extra_inputs, + const std::string& pixel_values_name) { for (size_t i = 0; i < extra_inputs.size(); ++i) { - if (extra_inputs[i].name == Config::Defaults::PixelValuesName) { + if 
(extra_inputs[i].name == pixel_values_name) { assert(extra_inputs[i].tensor->ort_tensor_); const auto num_dims = extra_inputs[i].tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape().size(); if (num_dims < 3) { @@ -119,7 +120,7 @@ void VisionPipelineState::SetExtraInputs(const std::vector& extra_in model_.config_->model.vision.outputs.image_features, num_images_, num_image_tokens_); for (const auto& ei : extra_inputs) { - if (ei.name == "pixel_values") { + if (ei.name == model_.config_->model.vision.inputs.pixel_values) { pixel_values_tensor_ = ei.tensor; break; } @@ -192,7 +193,7 @@ DeviceSpan VisionPipelineState::Run(int current_length, for (int64_t i = 0; i < total_images; ++i) { auto pixel_values_i = MakeSingleImagePixelValues(pixel_values_tensor_, i, model_.p_device_); - extra_inputs_.Replace("pixel_values", pixel_values_i); + extra_inputs_.Replace(model_.config_->model.vision.inputs.pixel_values, pixel_values_i); State::Run(*model_.vision_session_); @@ -338,8 +339,19 @@ static NameToLayerIdxMap GeneratePastKeyNameToLayerIdxMap(const Config& config) return m; } -static std::vector GetLayerIndicesSetFromPastKeyNameInputs( - const NameToLayerIdxMap& past_key_name_to_layer_idx, std::span inputs) { +static NameToLayerIdxMap GeneratePresentKeyNameToLayerIdxMap(const Config& config) { + const size_t num_layers = config.model.decoder.num_hidden_layers; + const std::string& present_key_name_template = config.model.decoder.outputs.present_key_names; + NameToLayerIdxMap m{}; + for (size_t i = 0; i < num_layers; ++i) { + m.emplace(ComposeKeyValueName(present_key_name_template, static_cast(i)), i); + } + return m; +} + +static std::vector GetLayerIndicesSetFromPastAndPresentKeyNames( + const NameToLayerIdxMap& past_key_name_to_layer_idx, const NameToLayerIdxMap& present_key_name_to_layer_idx, + std::span inputs, std::span outputs) { std::vector layer_indices{}; for (const auto& input_name : inputs) { const auto it = past_key_name_to_layer_idx.find(input_name); @@ 
-347,6 +359,12 @@ static std::vector GetLayerIndicesSetFromPastKeyNameInputs( layer_indices.push_back(it->second); } } + for (const auto& output_name : outputs) { + const auto it = present_key_name_to_layer_idx.find(output_name); + if (it != present_key_name_to_layer_idx.end()) { + layer_indices.push_back(it->second); + } + } // sort and remove duplicates std::sort(layer_indices.begin(), layer_indices.end()); layer_indices.erase(std::unique(layer_indices.begin(), layer_indices.end()), @@ -354,6 +372,17 @@ static std::vector GetLayerIndicesSetFromPastKeyNameInputs( return layer_indices; } +static bool ContainsPresentKeyNameOutputs(const NameToLayerIdxMap& present_key_name_to_layer_idx, + std::span outputs) { + for (const auto& output_name : outputs) { + const auto it = present_key_name_to_layer_idx.find(output_name); + if (it != present_key_name_to_layer_idx.end()) { + return true; + } + } + return false; +} + DecoderPipelineState::DecoderPipelineState(const MultiModalPipelineLanguageModel& model, DeviceSpan sequence_lengths, const GeneratorParams& params) @@ -384,6 +413,7 @@ DecoderPipelineState::DecoderPipelineState(const MultiModalPipelineLanguageModel if (do_key_value_cache_partial_update_) { const auto past_key_name_to_layer_idx = GeneratePastKeyNameToLayerIdxMap(*model_.config_); + const auto present_key_name_to_layer_idx = GeneratePresentKeyNameToLayerIdxMap(*model_.config_); std::map, size_t> layer_indices_to_update_record_idx{}; std::unordered_set layer_indices_encountered{}; @@ -391,8 +421,10 @@ DecoderPipelineState::DecoderPipelineState(const MultiModalPipelineLanguageModel for (size_t i = 0; i < config_pipeline.size(); ++i) { const auto& pipeline_model = config_pipeline[i]; - const auto layer_indices = GetLayerIndicesSetFromPastKeyNameInputs(past_key_name_to_layer_idx, - pipeline_model.inputs); + const auto layer_indices = GetLayerIndicesSetFromPastAndPresentKeyNames(past_key_name_to_layer_idx, present_key_name_to_layer_idx, + pipeline_model.inputs, 
pipeline_model.outputs); + + pipeline_states_[i]->contains_kv_cache_output_ = ContainsPresentKeyNameOutputs(present_key_name_to_layer_idx, pipeline_model.outputs); if (layer_indices.empty()) { continue; @@ -560,8 +592,8 @@ void DecoderPipelineState::RunPipeline(int total_length, DeviceSpan& ne // Run the intermediate pipeline state pipeline_state->Run(total_length, next_tokens, next_indices); - // If there is any partial KV cache update to start, enqueue it. - if (partial_kv_cache_update_record) { + // If there is any partial KV cache update to start and if KV cache is present in the output, enqueue it. + if (partial_kv_cache_update_record && pipeline_state->contains_kv_cache_output_) { assert(key_value_cache_update_worker_thread_.has_value()); auto update_fn = [&key_value_cache = *key_value_cache_.get(), layer_indices = partial_kv_cache_update_record->layer_indices, @@ -702,7 +734,7 @@ MultiModalDecoderPipelineState::MultiModalDecoderPipelineState(const MultiModalP void MultiModalDecoderPipelineState::SetExtraInputs(const std::vector& extra_inputs) { num_image_tokens_ = GetNumImageTokens(extra_inputs); num_audio_tokens_ = GetNumAudioTokens(extra_inputs, model_.config_->model.speech.inputs.audio_sizes); - num_images_ = GetImageFeatureBatchSize(extra_inputs); + num_images_ = GetImageFeatureBatchSize(extra_inputs, model_.config_->model.vision.inputs.pixel_values); if (model_.vision_session_) { vision_state_->SetExtraInputs(extra_inputs, num_images_, num_image_tokens_); diff --git a/src/models/multi_decoder_pipeline_modal.h b/src/models/multi_decoder_pipeline_modal.h index 5793eaf248..eb2543fe17 100644 --- a/src/models/multi_decoder_pipeline_modal.h +++ b/src/models/multi_decoder_pipeline_modal.h @@ -105,6 +105,7 @@ struct IntermediateDecoderPipelineState : State { bool SupportsPrimaryDevice() const; size_t id_; + bool contains_kv_cache_output_ = false; // Specifies if the intermediate pipeline state output contains present kv cache private: const 
MultiModalPipelineLanguageModel& model_; diff --git a/src/models/windowed_kv_cache.cpp b/src/models/windowed_kv_cache.cpp index 24e0114fed..d7eccc1a06 100644 --- a/src/models/windowed_kv_cache.cpp +++ b/src/models/windowed_kv_cache.cpp @@ -140,6 +140,7 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) { const auto& layer_state = per_layer_states_[layer_idx]; const auto window_size = layer_state.window_size; + const auto seq_len = layer_state.window_index * layer_state.window_size; const auto& key_cache_shape_in = layer_state.key_cache_shape_in; const auto& key_cache_shape_out = layer_state.key_cache_shape_out; const auto& value_cache_shape_in = layer_state.value_cache_shape_in; @@ -151,10 +152,10 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) { int64_t num_key_cache_chunks = key_cache_shape_in[0] * key_cache_shape_in[2]; for (int64_t j = 0; j < num_key_cache_chunks; ++j) { { - cpu_span key_cache_dst(key_cache_in_data + j * key_cache_shape_in[3], - key_cache_shape_in[3] - window_size); - cpu_span key_cache_src(key_cache_in_data + j * key_cache_shape_in[3] + window_size, - key_cache_shape_in[3] - window_size); + cpu_span key_cache_dst(key_cache_in_data + j * key_cache_shape_in[3] + key_cache_shape_in[3] - seq_len - window_size, + seq_len); + cpu_span key_cache_src(key_cache_in_data + j * key_cache_shape_in[3] + key_cache_shape_in[3] - seq_len, + seq_len); std::copy(key_cache_src.begin(), key_cache_src.end(), key_cache_dst.begin()); } { @@ -171,11 +172,12 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) { for (int64_t j = 0; j < value_cache_shape_in[0]; ++j) { { - cpu_span value_cache_dst(value_cache_in_data + (j * value_cache_shape_in[2] * value_cache_shape_in[3]), - (value_cache_shape_in[2] - window_size) * value_cache_shape_in[3]); + cpu_span value_cache_dst(value_cache_in_data + (j * value_cache_shape_in[2] * value_cache_shape_in[3]) + + ((value_cache_shape_in[2] - seq_len - window_size) * value_cache_shape_in[3]), + seq_len * 
value_cache_shape_in[3]); cpu_span value_cache_src(value_cache_in_data + (j * value_cache_shape_in[2] * value_cache_shape_in[3]) + - (window_size * value_cache_shape_in[3]), - (value_cache_shape_in[2] - window_size) * value_cache_shape_in[3]); + ((value_cache_shape_in[2] - seq_len) * value_cache_shape_in[3]), + seq_len * value_cache_shape_in[3]); std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin()); } { @@ -287,6 +289,7 @@ void WindowedKeyValueCache::TransitionLayerToTokenGeneration(size_t layer_idx) { value_caches_out_[layer_idx] = OrtValue::CreateTensor(Allocator(), updated_value_cache_shape_out, type_); // update values in per-layer state + layer_state.window_index = layer_state.window_index * layer_state.window_size / updated_window_size; layer_state.window_size = updated_window_size; layer_state.key_cache_shape_in = updated_key_cache_shape_in; layer_state.value_cache_shape_in = updated_value_cache_shape_in;