Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/models/decoder_only_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,50 @@ static NameToLayerIdxMap GeneratePastKeyNameToLayerIdxMap(const Config& config)
return m;
}

static std::vector<size_t> GetLayerIndicesSetFromPastKeyNameInputs(
const NameToLayerIdxMap& past_key_name_to_layer_idx, std::span<const std::string> inputs) {
static NameToLayerIdxMap GeneratePresentKeyNameToLayerIdxMap(const Config& config) {
  // Maps each per-layer present-key output name (expanded from the config's
  // name template) to its zero-based layer index.
  NameToLayerIdxMap name_to_idx{};
  const std::string& name_template = config.model.decoder.outputs.present_key_names;
  const size_t layer_count = config.model.decoder.num_hidden_layers;
  for (size_t layer = 0; layer < layer_count; ++layer) {
    name_to_idx.emplace(ComposeKeyValueName(name_template, static_cast<int>(layer)), layer);
  }
  return name_to_idx;
}

static std::vector<size_t> GetLayerIndicesSetFromPastAndPresentKeyNames(
const NameToLayerIdxMap& past_key_name_to_layer_idx, const NameToLayerIdxMap& present_key_name_to_layer_idx,
std::span<const std::string> inputs, std::span<const std::string> outputs) {
std::vector<size_t> layer_indices{};
for (const auto& input_name : inputs) {
const auto it = past_key_name_to_layer_idx.find(input_name);
if (it != past_key_name_to_layer_idx.end()) {
layer_indices.push_back(it->second);
}
}
for (const auto& output_name : outputs) {
const auto it = present_key_name_to_layer_idx.find(output_name);
if (it != present_key_name_to_layer_idx.end()) {
layer_indices.push_back(it->second);
}
}
// sort and remove duplicates
std::sort(layer_indices.begin(), layer_indices.end());
layer_indices.erase(std::unique(layer_indices.begin(), layer_indices.end()),
layer_indices.end());
return layer_indices;
}

static bool ContainsPresentKeyNameOutputs(const NameToLayerIdxMap& present_key_name_to_layer_idx,
                                          std::span<const std::string> outputs) {
  // Returns true if any output name is a recognized present-key (KV cache)
  // output, i.e. appears in the name -> layer index map.
  return std::any_of(outputs.begin(), outputs.end(),
                     [&present_key_name_to_layer_idx](const std::string& output_name) {
                       return present_key_name_to_layer_idx.find(output_name) !=
                              present_key_name_to_layer_idx.end();
                     });
}

DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineModel& model,
DeviceSpan<int32_t> sequence_lengths,
const GeneratorParams& params)
Expand All @@ -129,15 +157,18 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode

if (do_key_value_cache_partial_update_) {
const auto past_key_name_to_layer_idx = GeneratePastKeyNameToLayerIdxMap(*model_.config_);
const auto present_key_name_to_layer_idx = GeneratePresentKeyNameToLayerIdxMap(*model_.config_);

std::map<std::vector<size_t>, size_t> layer_indices_to_update_record_idx{};
std::unordered_set<size_t> layer_indices_encountered{};

for (size_t i = 0; i < config_pipeline.size(); ++i) {
const auto& pipeline_model = config_pipeline[i];

const auto layer_indices = GetLayerIndicesSetFromPastKeyNameInputs(past_key_name_to_layer_idx,
pipeline_model.inputs);
const auto layer_indices = GetLayerIndicesSetFromPastAndPresentKeyNames(past_key_name_to_layer_idx, present_key_name_to_layer_idx,
pipeline_model.inputs, pipeline_model.outputs);

pipeline_states_[i]->constains_kv_cache_output_ = ContainsPresentKeyNameOutputs(present_key_name_to_layer_idx, pipeline_model.outputs);

if (layer_indices.empty()) {
continue;
Expand Down Expand Up @@ -308,8 +339,8 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
// Run the intermediate pipeline state
pipeline_state->Run(total_length, next_tokens, next_indices);

// If there is any partial KV cache update to start, enqueue it.
if (partial_kv_cache_update_record) {
// If there is any partial KV cache update to start and if KV cache is present in the output, enqueue it.
if (partial_kv_cache_update_record && pipeline_state->constains_kv_cache_output_) {
assert(key_value_cache_update_worker_thread_.has_value());
auto update_fn = [&key_value_cache = *key_value_cache_.get(),
layer_indices = partial_kv_cache_update_record->layer_indices,
Expand Down
1 change: 1 addition & 0 deletions src/models/decoder_only_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct IntermediatePipelineState : State {
bool SupportsPrimaryDevice() const;

size_t id_;
bool constains_kv_cache_output_ = false;  // Specifies if the intermediate pipeline state output contains the present KV cache

private:
const DecoderOnlyPipelineModel& model_;
Expand Down
12 changes: 7 additions & 5 deletions src/models/embeddings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) {

const uint16_t* full_data = full_embeddings->GetTensorData<uint16_t>();

const int mem_copy_factor = (type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) ? 2 : 1;

if (window_index_ == 0) {
num_windows_ = (sequence_length + window_size_ - 1) / window_size_;
shape_ = {
Expand All @@ -114,7 +116,7 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) {
embeddings_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
std::copy_n(
full_data,
window_size_ * hidden_size * 2,
window_size_ * hidden_size * mem_copy_factor,
embeddings_->GetTensorMutableData<uint16_t>());

} else if (window_index_ < num_windows_) {
Expand All @@ -125,17 +127,17 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) {
};
embeddings_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
std::copy_n(
full_data + window_index_ * window_size_ * hidden_size * 2,
window_size_ * hidden_size * 2,
full_data + window_index_ * window_size_ * hidden_size * mem_copy_factor,
window_size_ * hidden_size * mem_copy_factor,
embeddings_->GetTensorMutableData<uint16_t>());

} else {
// Final token case (e.g., generated token)
shape_ = {1, 1, hidden_size};
embeddings_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
std::copy_n(
full_data + (sequence_length - 1) * hidden_size * 2,
hidden_size * 2,
full_data + (sequence_length - 1) * hidden_size * mem_copy_factor,
hidden_size * mem_copy_factor,
embeddings_->GetTensorMutableData<uint16_t>());

}
Expand Down
54 changes: 43 additions & 11 deletions src/models/multi_decoder_pipeline_modal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ int64_t GetNumAudioTokens(const std::vector<ExtraInput>& extra_inputs,
return 0;
}

int64_t GetImageFeatureBatchSize(const std::vector<ExtraInput>& extra_inputs) {
int64_t GetImageFeatureBatchSize(const std::vector<ExtraInput>& extra_inputs,
const std::string& pixel_values_name) {
for (size_t i = 0; i < extra_inputs.size(); ++i) {
if (extra_inputs[i].name == Config::Defaults::PixelValuesName) {
if (extra_inputs[i].name == pixel_values_name) {
assert(extra_inputs[i].tensor->ort_tensor_);
const auto num_dims = extra_inputs[i].tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape().size();
if (num_dims < 3) {
Expand Down Expand Up @@ -119,7 +120,7 @@ void VisionPipelineState::SetExtraInputs(const std::vector<ExtraInput>& extra_in
model_.config_->model.vision.outputs.image_features,
num_images_, num_image_tokens_);
for (const auto& ei : extra_inputs) {
if (ei.name == "pixel_values") {
if (ei.name == model_.config_->model.vision.inputs.pixel_values) {
pixel_values_tensor_ = ei.tensor;
break;
}
Expand Down Expand Up @@ -192,7 +193,7 @@ DeviceSpan<float> VisionPipelineState::Run(int current_length,

for (int64_t i = 0; i < total_images; ++i) {
auto pixel_values_i = MakeSingleImagePixelValues(pixel_values_tensor_, i, model_.p_device_);
extra_inputs_.Replace("pixel_values", pixel_values_i);
extra_inputs_.Replace(model_.config_->model.vision.inputs.pixel_values, pixel_values_i);

State::Run(*model_.vision_session_);

Expand Down Expand Up @@ -338,22 +339,50 @@ static NameToLayerIdxMap GeneratePastKeyNameToLayerIdxMap(const Config& config)
return m;
}

static std::vector<size_t> GetLayerIndicesSetFromPastKeyNameInputs(
const NameToLayerIdxMap& past_key_name_to_layer_idx, std::span<const std::string> inputs) {
static NameToLayerIdxMap GeneratePresentKeyNameToLayerIdxMap(const Config& config) {
  // Build the lookup table from present-key output name to layer index, one
  // entry per hidden layer, using the config's name template.
  NameToLayerIdxMap map{};
  const size_t layer_count = config.model.decoder.num_hidden_layers;
  const std::string& key_template = config.model.decoder.outputs.present_key_names;
  for (size_t layer_idx = 0; layer_idx < layer_count; ++layer_idx) {
    map.emplace(ComposeKeyValueName(key_template, static_cast<int>(layer_idx)), layer_idx);
  }
  return map;
}

static std::vector<size_t> GetLayerIndicesSetFromPastAndPresentKeyNames(
    const NameToLayerIdxMap& past_key_name_to_layer_idx, const NameToLayerIdxMap& present_key_name_to_layer_idx,
    std::span<const std::string> inputs, std::span<const std::string> outputs) {
  // Gather every layer index whose past-key name appears among the inputs or
  // whose present-key name appears among the outputs.
  std::vector<size_t> indices{};
  for (const auto& name : inputs) {
    if (const auto match = past_key_name_to_layer_idx.find(name); match != past_key_name_to_layer_idx.end()) {
      indices.push_back(match->second);
    }
  }
  for (const auto& name : outputs) {
    if (const auto match = present_key_name_to_layer_idx.find(name); match != present_key_name_to_layer_idx.end()) {
      indices.push_back(match->second);
    }
  }
  // Deduplicate: sort, then drop adjacent repeats.
  std::sort(indices.begin(), indices.end());
  indices.erase(std::unique(indices.begin(), indices.end()), indices.end());
  return indices;
}

static bool ContainsPresentKeyNameOutputs(const NameToLayerIdxMap& present_key_name_to_layer_idx,
std::span<const std::string> outputs) {
for (const auto& output_name : outputs) {
const auto it = present_key_name_to_layer_idx.find(output_name);
if (it != present_key_name_to_layer_idx.end()) {
return true;
}
}
return false;
}

DecoderPipelineState::DecoderPipelineState(const MultiModalPipelineLanguageModel& model,
DeviceSpan<int32_t> sequence_lengths,
const GeneratorParams& params)
Expand Down Expand Up @@ -384,15 +413,18 @@ DecoderPipelineState::DecoderPipelineState(const MultiModalPipelineLanguageModel

if (do_key_value_cache_partial_update_) {
const auto past_key_name_to_layer_idx = GeneratePastKeyNameToLayerIdxMap(*model_.config_);
const auto present_key_name_to_layer_idx = GeneratePresentKeyNameToLayerIdxMap(*model_.config_);

std::map<std::vector<size_t>, size_t> layer_indices_to_update_record_idx{};
std::unordered_set<size_t> layer_indices_encountered{};

for (size_t i = 0; i < config_pipeline.size(); ++i) {
const auto& pipeline_model = config_pipeline[i];

const auto layer_indices = GetLayerIndicesSetFromPastKeyNameInputs(past_key_name_to_layer_idx,
pipeline_model.inputs);
const auto layer_indices = GetLayerIndicesSetFromPastAndPresentKeyNames(past_key_name_to_layer_idx, present_key_name_to_layer_idx,
pipeline_model.inputs, pipeline_model.outputs);

pipeline_states_[i]->constains_kv_cache_output_ = ContainsPresentKeyNameOutputs(present_key_name_to_layer_idx, pipeline_model.outputs);

if (layer_indices.empty()) {
continue;
Expand Down Expand Up @@ -560,8 +592,8 @@ void DecoderPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>& ne
// Run the intermediate pipeline state
pipeline_state->Run(total_length, next_tokens, next_indices);

// If there is any partial KV cache update to start, enqueue it.
if (partial_kv_cache_update_record) {
// If there is any partial KV cache update to start and if KV cache is present in the output, enqueue it.
if (partial_kv_cache_update_record && pipeline_state->constains_kv_cache_output_) {
assert(key_value_cache_update_worker_thread_.has_value());
auto update_fn = [&key_value_cache = *key_value_cache_.get(),
layer_indices = partial_kv_cache_update_record->layer_indices,
Expand Down Expand Up @@ -702,7 +734,7 @@ MultiModalDecoderPipelineState::MultiModalDecoderPipelineState(const MultiModalP
void MultiModalDecoderPipelineState::SetExtraInputs(const std::vector<ExtraInput>& extra_inputs) {
num_image_tokens_ = GetNumImageTokens(extra_inputs);
num_audio_tokens_ = GetNumAudioTokens(extra_inputs, model_.config_->model.speech.inputs.audio_sizes);
num_images_ = GetImageFeatureBatchSize(extra_inputs);
num_images_ = GetImageFeatureBatchSize(extra_inputs, model_.config_->model.vision.inputs.pixel_values);

if (model_.vision_session_) {
vision_state_->SetExtraInputs(extra_inputs, num_images_, num_image_tokens_);
Expand Down
1 change: 1 addition & 0 deletions src/models/multi_decoder_pipeline_modal.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ struct IntermediateDecoderPipelineState : State {
bool SupportsPrimaryDevice() const;

size_t id_;
bool constains_kv_cache_output_ = false;  // Specifies if the intermediate pipeline state output contains the present KV cache

private:
const MultiModalPipelineLanguageModel& model_;
Expand Down
19 changes: 11 additions & 8 deletions src/models/windowed_kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) {
const auto& layer_state = per_layer_states_[layer_idx];

const auto window_size = layer_state.window_size;
const auto seq_len = layer_state.window_index * layer_state.window_size;
const auto& key_cache_shape_in = layer_state.key_cache_shape_in;
const auto& key_cache_shape_out = layer_state.key_cache_shape_out;
const auto& value_cache_shape_in = layer_state.value_cache_shape_in;
Expand All @@ -151,10 +152,10 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) {
int64_t num_key_cache_chunks = key_cache_shape_in[0] * key_cache_shape_in[2];
for (int64_t j = 0; j < num_key_cache_chunks; ++j) {
{
cpu_span<uint8_t> key_cache_dst(key_cache_in_data + j * key_cache_shape_in[3],
key_cache_shape_in[3] - window_size);
cpu_span<uint8_t> key_cache_src(key_cache_in_data + j * key_cache_shape_in[3] + window_size,
key_cache_shape_in[3] - window_size);
cpu_span<uint8_t> key_cache_dst(key_cache_in_data + j * key_cache_shape_in[3] + key_cache_shape_in[3] - seq_len - window_size,
seq_len);
cpu_span<uint8_t> key_cache_src(key_cache_in_data + j * key_cache_shape_in[3] + key_cache_shape_in[3] - seq_len,
seq_len);
std::copy(key_cache_src.begin(), key_cache_src.end(), key_cache_dst.begin());
}
{
Expand All @@ -171,11 +172,12 @@ void WindowedKeyValueCache::SlideLayer(size_t layer_idx) {

for (int64_t j = 0; j < value_cache_shape_in[0]; ++j) {
{
cpu_span<uint8_t> value_cache_dst(value_cache_in_data + (j * value_cache_shape_in[2] * value_cache_shape_in[3]),
(value_cache_shape_in[2] - window_size) * value_cache_shape_in[3]);
cpu_span<uint8_t> value_cache_dst(value_cache_in_data + (j * value_cache_shape_in[2] * value_cache_shape_in[3]) +
((value_cache_shape_in[2] - seq_len - window_size) * value_cache_shape_in[3]),
seq_len * value_cache_shape_in[3]);
cpu_span<uint8_t> value_cache_src(value_cache_in_data + (j * value_cache_shape_in[2] * value_cache_shape_in[3]) +
(window_size * value_cache_shape_in[3]),
(value_cache_shape_in[2] - window_size) * value_cache_shape_in[3]);
((value_cache_shape_in[2] - seq_len) * value_cache_shape_in[3]),
seq_len * value_cache_shape_in[3]);
std::copy(value_cache_src.begin(), value_cache_src.end(), value_cache_dst.begin());
}
{
Expand Down Expand Up @@ -287,6 +289,7 @@ void WindowedKeyValueCache::TransitionLayerToTokenGeneration(size_t layer_idx) {
value_caches_out_[layer_idx] = OrtValue::CreateTensor(Allocator(), updated_value_cache_shape_out, type_);

// update values in per-layer state
layer_state.window_index = layer_state.window_index * layer_state.window_size / updated_window_size;
layer_state.window_size = updated_window_size;
layer_state.key_cache_shape_in = updated_key_cache_shape_in;
layer_state.value_cache_shape_in = updated_value_cache_shape_in;
Expand Down
Loading