@@ -44,9 +44,10 @@ int64_t GetNumAudioTokens(const std::vector<ExtraInput>& extra_inputs,
4444 return 0 ;
4545}
4646
47- int64_t GetImageFeatureBatchSize (const std::vector<ExtraInput>& extra_inputs) {
47+ int64_t GetImageFeatureBatchSize (const std::vector<ExtraInput>& extra_inputs,
48+ const std::string& pixel_values_name) {
4849 for (size_t i = 0 ; i < extra_inputs.size (); ++i) {
49- if (extra_inputs[i].name == Config::Defaults::PixelValuesName ) {
50+ if (extra_inputs[i].name == pixel_values_name ) {
5051 assert (extra_inputs[i].tensor ->ort_tensor_ );
5152 const auto num_dims = extra_inputs[i].tensor ->ort_tensor_ ->GetTensorTypeAndShapeInfo ()->GetShape ().size ();
5253 if (num_dims < 3 ) {
@@ -119,7 +120,7 @@ void VisionPipelineState::SetExtraInputs(const std::vector<ExtraInput>& extra_in
119120 model_.config_ ->model .vision .outputs .image_features ,
120121 num_images_, num_image_tokens_);
121122 for (const auto & ei : extra_inputs) {
122- if (ei.name == " pixel_values" ) {
123+ if (ei.name == model_. config_ -> model . vision . inputs . pixel_values ) {
123124 pixel_values_tensor_ = ei.tensor ;
124125 break ;
125126 }
@@ -192,7 +193,7 @@ DeviceSpan<float> VisionPipelineState::Run(int current_length,
192193
193194 for (int64_t i = 0 ; i < total_images; ++i) {
194195 auto pixel_values_i = MakeSingleImagePixelValues (pixel_values_tensor_, i, model_.p_device_ );
195- extra_inputs_.Replace (" pixel_values" , pixel_values_i);
196+ extra_inputs_.Replace (model_. config_ -> model . vision . inputs . pixel_values , pixel_values_i);
196197
197198 State::Run (*model_.vision_session_ );
198199
@@ -338,22 +339,50 @@ static NameToLayerIdxMap GeneratePastKeyNameToLayerIdxMap(const Config& config)
338339 return m;
339340}
340341
341- static std::vector<size_t > GetLayerIndicesSetFromPastKeyNameInputs (
342- const NameToLayerIdxMap& past_key_name_to_layer_idx, std::span<const std::string> inputs) {
342+ static NameToLayerIdxMap GeneratePresentKeyNameToLayerIdxMap (const Config& config) {
343+ const size_t num_layers = config.model .decoder .num_hidden_layers ;
344+ const std::string& present_key_name_template = config.model .decoder .outputs .present_key_names ;
345+ NameToLayerIdxMap m{};
346+ for (size_t i = 0 ; i < num_layers; ++i) {
347+ m.emplace (ComposeKeyValueName (present_key_name_template, static_cast <int >(i)), i);
348+ }
349+ return m;
350+ }
351+
352+ static std::vector<size_t > GetLayerIndicesSetFromPastAndPresentKeyNames (
353+ const NameToLayerIdxMap& past_key_name_to_layer_idx, const NameToLayerIdxMap& present_key_name_to_layer_idx,
354+ std::span<const std::string> inputs, std::span<const std::string> outputs) {
343355 std::vector<size_t > layer_indices{};
344356 for (const auto & input_name : inputs) {
345357 const auto it = past_key_name_to_layer_idx.find (input_name);
346358 if (it != past_key_name_to_layer_idx.end ()) {
347359 layer_indices.push_back (it->second );
348360 }
349361 }
362+ for (const auto & output_name : outputs) {
363+ const auto it = present_key_name_to_layer_idx.find (output_name);
364+ if (it != present_key_name_to_layer_idx.end ()) {
365+ layer_indices.push_back (it->second );
366+ }
367+ }
350368 // sort and remove duplicates
351369 std::sort (layer_indices.begin (), layer_indices.end ());
352370 layer_indices.erase (std::unique (layer_indices.begin (), layer_indices.end ()),
353371 layer_indices.end ());
354372 return layer_indices;
355373}
356374
375+ static bool ContainsPresentKeyNameOutputs (const NameToLayerIdxMap& present_key_name_to_layer_idx,
376+ std::span<const std::string> outputs) {
377+ for (const auto & output_name : outputs) {
378+ const auto it = present_key_name_to_layer_idx.find (output_name);
379+ if (it != present_key_name_to_layer_idx.end ()) {
380+ return true ;
381+ }
382+ }
383+ return false ;
384+ }
385+
357386DecoderPipelineState::DecoderPipelineState (const MultiModalPipelineLanguageModel& model,
358387 DeviceSpan<int32_t > sequence_lengths,
359388 const GeneratorParams& params)
@@ -384,15 +413,18 @@ DecoderPipelineState::DecoderPipelineState(const MultiModalPipelineLanguageModel
384413
385414 if (do_key_value_cache_partial_update_) {
386415 const auto past_key_name_to_layer_idx = GeneratePastKeyNameToLayerIdxMap (*model_.config_ );
416+ const auto present_key_name_to_layer_idx = GeneratePresentKeyNameToLayerIdxMap (*model_.config_ );
387417
388418 std::map<std::vector<size_t >, size_t > layer_indices_to_update_record_idx{};
389419 std::unordered_set<size_t > layer_indices_encountered{};
390420
391421 for (size_t i = 0 ; i < config_pipeline.size (); ++i) {
392422 const auto & pipeline_model = config_pipeline[i];
393423
394- const auto layer_indices = GetLayerIndicesSetFromPastKeyNameInputs (past_key_name_to_layer_idx,
395- pipeline_model.inputs );
424+ const auto layer_indices = GetLayerIndicesSetFromPastAndPresentKeyNames (past_key_name_to_layer_idx, present_key_name_to_layer_idx,
425+ pipeline_model.inputs , pipeline_model.outputs );
426+
427+ pipeline_states_[i]->constains_kv_cache_output_ = ContainsPresentKeyNameOutputs (present_key_name_to_layer_idx, pipeline_model.outputs );
396428
397429 if (layer_indices.empty ()) {
398430 continue ;
@@ -560,8 +592,8 @@ void DecoderPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>& ne
560592 // Run the intermediate pipeline state
561593 pipeline_state->Run (total_length, next_tokens, next_indices);
562594
563- // If there is any partial KV cache update to start, enqueue it.
564- if (partial_kv_cache_update_record) {
595+ // If there is any partial KV cache update to start and if KV cache is present in the output , enqueue it.
596+ if (partial_kv_cache_update_record && pipeline_state-> constains_kv_cache_output_ ) {
565597 assert (key_value_cache_update_worker_thread_.has_value ());
566598 auto update_fn = [&key_value_cache = *key_value_cache_.get (),
567599 layer_indices = partial_kv_cache_update_record->layer_indices ,
@@ -702,7 +734,7 @@ MultiModalDecoderPipelineState::MultiModalDecoderPipelineState(const MultiModalP
702734void MultiModalDecoderPipelineState::SetExtraInputs (const std::vector<ExtraInput>& extra_inputs) {
703735 num_image_tokens_ = GetNumImageTokens (extra_inputs);
704736 num_audio_tokens_ = GetNumAudioTokens (extra_inputs, model_.config_ ->model .speech .inputs .audio_sizes );
705- num_images_ = GetImageFeatureBatchSize (extra_inputs);
737+ num_images_ = GetImageFeatureBatchSize (extra_inputs, model_. config_ -> model . vision . inputs . pixel_values );
706738
707739 if (model_.vision_session_ ) {
708740 vision_state_->SetExtraInputs (extra_inputs, num_images_, num_image_tokens_);
0 commit comments