forked from microsoft/onnxruntime-genai
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdecoder_only_pipeline.h
More file actions
104 lines (72 loc) · 3.65 KB
/
decoder_only_pipeline.h
File metadata and controls
104 lines (72 loc) · 3.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <future>
#include <optional>
#include "../worker_thread.h"
#include "model.h"
#include "input_ids.h"
#include "logits.h"
#include "kv_cache.h"
#include "windowed_kv_cache.h"
#include "position_inputs.h"
#include "extra_inputs.h"
namespace Generators {
// A Model that owns a collection of ONNX Runtime sessions (`sessions_`) instead of a single one.
// NOTE(review): presumably each session is one stage of the decoder pipeline and the stages are
// run in order — confirm against the implementation file.
struct DecoderOnlyPipelineModel : Model {
// Takes ownership of `config`.
// NOTE(review): `ort_env` is presumably the environment bound to `ort_env_` below — confirm in the .cpp.
DecoderOnlyPipelineModel(std::unique_ptr<Config> config, OrtEnv& ort_env);
// Non-copyable.
DecoderOnlyPipelineModel(const DecoderOnlyPipelineModel&) = delete;
DecoderOnlyPipelineModel& operator=(const DecoderOnlyPipelineModel&) = delete;
// Overrides Model::CreateState and returns a newly created, caller-owned State.
// NOTE(review): expected to produce a DecoderOnlyPipelineState — confirm in the .cpp.
std::unique_ptr<State> CreateState(DeviceSpan<int32_t> sequence_lengths,
const GeneratorParams& params) const override;
// Sessions owned by this model.
// NOTE(review): presumably one entry per pipeline stage, in execution order — confirm.
std::vector<std::unique_ptr<OrtSession>> sessions_;
// Non-owning reference to the ORT environment; the referenced object must outlive this model.
OrtEnv& ort_env_;
};
// A State associated with one position in the pipeline of a DecoderOnlyPipelineModel.
// NOTE(review): presumably wraps the inputs/outputs of a single pipeline session and is driven
// by DecoderOnlyPipelineState — confirm against the implementation file.
struct IntermediatePipelineState : State {
// `pipeline_state_index` identifies this state within the pipeline.
// NOTE(review): presumably stored in `id_` and used to select the matching session — confirm.
IntermediatePipelineState(const DecoderOnlyPipelineModel& model, const GeneratorParams& params,
size_t pipeline_state_index);
// Non-copyable.
IntermediatePipelineState(const IntermediatePipelineState&) = delete;
IntermediatePipelineState& operator=(const IntermediatePipelineState&) = delete;
// Overrides State::Run. Note that `next_tokens` is taken by non-const reference.
DeviceSpan<float> Run(int current_length, DeviceSpan<int32_t>& next_tokens,
DeviceSpan<int32_t> next_indices) override;
// Name-based queries.
// NOTE(review): presumably report whether this state has an input/output with the given
// name — confirm in the .cpp.
bool HasInput(std::string_view name) const;
bool HasOutput(std::string_view name) const;
// NOTE(review): presumably reports whether this state can run on the model's primary
// device — confirm in the .cpp.
bool SupportsPrimaryDevice() const;
// Identifier of this state within the pipeline.
size_t id_;
// The member name's spelling ("constains") is left untouched: it is a public member and may be
// referenced from other files.
bool constains_kv_cache_output_ = false; // Specifies whether the intermediate pipeline state output contains the present KV cache
private:
// Non-owning reference to the model this state belongs to; the model must outlive the state.
const DecoderOnlyPipelineModel& model_;
};
// Top-level State for a DecoderOnlyPipelineModel. It owns one IntermediatePipelineState per entry
// in `pipeline_states_` together with the inputs/outputs declared below (input ids, logits,
// KV cache, position inputs, extra inputs).
//
// Data members are initialized in declaration order (C++ rule), and several of them are
// constructed from `*this`, so do not reorder the member declarations without checking the
// constructor in the .cpp.
struct DecoderOnlyPipelineState : State {
DecoderOnlyPipelineState(const DecoderOnlyPipelineModel& model, DeviceSpan<int32_t> sequence_lengths,
const GeneratorParams& params);
// Non-copyable.
DecoderOnlyPipelineState(const DecoderOnlyPipelineState&) = delete;
DecoderOnlyPipelineState& operator=(const DecoderOnlyPipelineState&) = delete;
// Overrides State::SetExtraInputs.
void SetExtraInputs(const std::vector<ExtraInput>& extra_inputs) override;
// Overrides State::Run. Note that `next_tokens` is taken by non-const reference.
DeviceSpan<float> Run(int total_length, DeviceSpan<int32_t>& next_tokens,
DeviceSpan<int32_t> next_indices) override;
// Overrides State::GetOutput; looks up an output by name.
// NOTE(review): ownership of the returned pointer is not visible here — presumably non-owning;
// confirm in the .cpp.
OrtValue* GetOutput(const char* name) override;
// Public (non-virtual) entry point taking the same arguments as Run but returning nothing.
// NOTE(review): presumably executes every pipeline state once, with Run() wrapping it — confirm.
void RunPipeline(int total_length, DeviceSpan<int32_t>& next_tokens,
DeviceSpan<int32_t> next_indices);
private:
// Internal helpers; see the .cpp for their exact behavior.
void UpdateKeyValueCache(DeviceSpan<int32_t> beam_indices, int total_length);
void UpdateInputsOutputs(DeviceSpan<int32_t>& next_tokens, DeviceSpan<int32_t> next_indices,
int total_length);
// Non-owning reference to the model this state belongs to; the model must outlive the state.
const DecoderOnlyPipelineModel& model_;
// Owned per-stage states.
std::vector<std::unique_ptr<IntermediatePipelineState>> pipeline_states_;
// Bookkeeping for updating only a subset of the KV cache layers at a time.
struct PartialKeyValueCacheUpdateRecord {
std::vector<size_t> layer_indices{}; // indicates which layers of the KV cache are to be updated
std::future<void> outstanding_update{}; // future for an outstanding update task
};
// Per its name: maps a pipeline state id to an index into `partial_kv_cache_update_records_`.
std::map<size_t, size_t> pipeline_state_id_to_partial_kv_cache_update_record_idx_;
std::vector<PartialKeyValueCacheUpdateRecord> partial_kv_cache_update_records_;
// Stores all the outputs from the previous pipeline state(s)
std::unordered_map<std::string, std::unique_ptr<OrtValue>> ortvalue_store_;
std::unique_ptr<InputIDs> input_ids_;
// Constructed from `*this`, i.e. bound to this state object.
Logits logits_{*this};
std::unique_ptr<KeyValueCache> key_value_cache_;
// Fixed at construction (const).
// NOTE(review): presumably selects whether the partial-update records above are used — confirm.
const bool do_key_value_cache_partial_update_;
// Empty unless a worker thread is created.
// NOTE(review): presumably only engaged when partial KV cache updates are enabled, and used to
// run the tasks tracked by `outstanding_update` — confirm in the .cpp.
std::optional<WorkerThread> key_value_cache_update_worker_thread_{};
std::unique_ptr<PositionInputs> position_inputs_;
// Constructed from `*this`, i.e. bound to this state object.
ExtraInputs extra_inputs_{*this};
};
} // namespace Generators