forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmultimodal_runner.cpp
More file actions
246 lines (212 loc) · 7.57 KB
/
multimodal_runner.cpp
File metadata and controls
246 lines (212 loc) · 7.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// Implementation of MultimodalRunner for multimodal input and text output LLMs
#include <executorch/extension/llm/runner/constants.h>
#include <executorch/extension/llm/runner/multimodal_runner.h>
#include <executorch/extension/llm/runner/util.h>
#include <executorch/runtime/platform/runtime.h>
#include <pytorch/tokenizers/hf_tokenizer.h>
#include <pytorch/tokenizers/sentencepiece.h>
#ifdef CUDA_AVAILABLE
#include <executorch/backends/cuda/runtime/memory_tracker.h>
#endif
namespace executorch::extension::llm {
using ::executorch::extension::Module;
using ::executorch::runtime::Error;
using ::executorch::runtime::Result;
/**
 * Constructs a runner that takes ownership of every component handed to it.
 *
 * Each argument is moved into the matching member and the running position
 * `pos_` starts at zero. When built with CUDA_AVAILABLE, a CudaMemoryTracker
 * is created as well and its total/free byte readings are copied into
 * `stats_` right away.
 *
 * NOTE(review): the CUDA path dereferences `stats_` unconditionally, so
 * callers presumably must pass a non-null `stats` -- confirm.
 */
MultimodalRunner::MultimodalRunner(
    std::unordered_map<std::string, int64_t> metadata,
    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
    std::unique_ptr<Module> module,
    std::unique_ptr<MultimodalDecoderRunner> text_decoder_runner,
    std::unique_ptr<MultimodalPrefiller> multimodal_prefiller,
    std::unique_ptr<IOManager> io_manager,
    std::unique_ptr<TextTokenGenerator> text_token_generator,
    std::unique_ptr<Stats> stats)
    : metadata_(std::move(metadata)),
      tokenizer_(std::move(tokenizer)),
      module_(std::move(module)),
      text_decoder_runner_(std::move(text_decoder_runner)),
      multimodal_prefiller_(std::move(multimodal_prefiller)),
      io_manager_(std::move(io_manager)),
      text_token_generator_(std::move(text_token_generator)),
      stats_(std::move(stats)),
      pos_(0) {
#ifdef CUDA_AVAILABLE
  cuda_memory_tracker_ =
      std::make_unique<::executorch::backends::cuda::CudaMemoryTracker>();
  // Read the tracker straight after construction so the numbers describe the
  // GPU before any model loading is triggered through this runner.
  auto& tracker = *cuda_memory_tracker_;
  stats_->gpu_total_bytes = tracker.total_bytes();
  stats_->gpu_free_before_load_bytes = tracker.last_free_bytes();
#endif
}
// Reports whether both the multimodal prefiller's method and the text token
// generator are loaded. The generator is only queried when the prefiller
// already reports loaded.
bool MultimodalRunner::is_loaded() {
  if (!multimodal_prefiller_->is_method_loaded()) {
    return false;
  }
  return text_token_generator_->is_loaded();
}
// Loads the multimodal prefiller and then the text token generator, unless
// is_loaded() already reports true. The load start/end timestamps are written
// to stats_; the end timestamp is only recorded when both loads pass their
// ET_CHECK_OK_OR_RETURN_ERROR checks. CUDA builds additionally sample the
// memory tracker after loading and copy its readings into stats_.
// Returns Error::Ok when already loaded or when both loads succeed.
Error MultimodalRunner::load() {
  if (!is_loaded()) {
    stats_->model_load_start_ms = time_in_ms();
    ET_CHECK_OK_OR_RETURN_ERROR(multimodal_prefiller_->load());
    ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
    stats_->model_load_end_ms = time_in_ms();

#ifdef CUDA_AVAILABLE
    // Take the sample first so the values read below reflect the post-load
    // state.
    cuda_memory_tracker_->log_sample("after_load");
    stats_->gpu_total_bytes = cuda_memory_tracker_->total_bytes();
    stats_->gpu_free_after_load_bytes =
        cuda_memory_tracker_->last_free_bytes();
    stats_->gpu_peak_usage_mb = cuda_memory_tracker_->peak_usage_mb();
#endif
  }
  return Error::Ok;
}
// Logs through ET_LOG at Debug priority when `warmup` is true and at Info
// priority otherwise, so warmup runs do not print with the same priority as
// real runs.
//
// The body is wrapped in do { ... } while (0) so that the macro expands to a
// single statement: `RUNNER_ET_LOG(...);` can then be used as the unbraced
// branch of an if/else without breaking the surrounding else (a bare if/else
// expansion followed by `;` would orphan it). `warmup` is parenthesized so an
// arbitrary expression can be passed. At least one argument must follow
// `format`, because __VA_ARGS__ is expanded after a comma.
#define RUNNER_ET_LOG(warmup, format, ...)  \
  do {                                      \
    if ((warmup)) {                         \
      ET_LOG(Debug, format, __VA_ARGS__);   \
    } else {                                \
      ET_LOG(Info, format, __VA_ARGS__);    \
    }                                       \
  } while (0)
// Feeds every entry of `inputs` to the multimodal prefiller, in order, passing
// the runner's current position `pos_` with each call. The model is loaded
// first if is_loaded() reports false. Stops at the first failing entry and
// returns that entry's error; returns Error::Ok once all entries have been
// prefilled (including when `inputs` is empty). The values carried by
// successful prefill results are discarded.
Error MultimodalRunner::prefill(const std::vector<MultimodalInput>& inputs) {
  if (!is_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(load());
  }

  for (const MultimodalInput& entry : inputs) {
    if (auto outcome = multimodal_prefiller_->prefill(entry, pos_);
        !outcome.ok()) {
      return outcome.error();
    }
  }

  return Error::Ok;
}
/**
 * Runs one generation pass: prefills every entry of `inputs`, emits the token
 * produced by the last prefill, then hands off to the text token generator
 * for the remaining tokens.
 *
 * @param inputs Multimodal inputs, prefilled in vector order. Must not be
 *     empty.
 * @param config Generation settings. This function reads `warming`, `echo`,
 *     `ignore_eos`, `temperature` and calls `resolve_max_new_tokens()`.
 * @param token_callback Optional. When set, it is invoked with every text
 *     piece emitted here (echoed prompt text, the first decoded token, and
 *     whatever the token generator passes to the wrapped callback).
 * @param stats_callback Optional. When set, it is invoked once with `*stats_`
 *     just before a successful return.
 * @return Error::Ok on success. Error::InvalidArgument when `inputs` is
 *     empty, when decoding the first token fails, or when the resolved
 *     max_new_tokens is <= 0. Otherwise the error reported by a failing
 *     prefill or by the token generator. A failing load() is handled by
 *     ET_CHECK_OK_OR_RETURN_ERROR (presumably returning that error).
 *
 * Side effects visible here: writes timing/token fields of `*stats_`, advances
 * `pos_` by the number of generated tokens, and, when not warming, prints
 * pieces via safe_printf followed by fflush(stdout) plus a final newline.
 */
Error MultimodalRunner::generate(
const std::vector<MultimodalInput>& inputs,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
// Reject an empty request before doing any work (including loading).
if (inputs.empty()) {
ET_LOG(Error, "MultimodalInput vector cannot be empty");
return Error::InvalidArgument;
}
// Load lazily so callers do not have to call load() themselves.
if (!is_loaded()) {
ET_CHECK_OK_OR_RETURN_ERROR(load());
}
if (config.warming) {
ET_LOG(Info, "Doing a warmup run...");
}
RUNNER_ET_LOG(
config.warming,
"RSS after loading model: %f MiB (0 if unsupported)",
get_rss_bytes() / 1024.0 / 1024.0);
// Wrap the token_callback with print function
// The wrapper captures token_callback and config by copy. It prints each
// piece unless this is a warmup run, then forwards the piece to
// token_callback if one was supplied.
std::function<void(const std::string&)> wrapped_callback =
[token_callback, config](const std::string& piece) {
if (!config.warming) {
safe_printf(piece.c_str());
fflush(stdout);
}
if (token_callback) {
token_callback(piece);
}
};
// Record the inference start time.
// NOTE(review): nothing is reset here; in particular pos_ keeps whatever
// value it had on entry.
stats_->inference_start_ms = time_in_ms();
// Holds the value returned by the most recent prefill; after the loop it is
// the value from the last input.
uint64_t prefill_next_token = 0;
// Process multimodal inputs in order
for (size_t i = 0; i < inputs.size(); ++i) {
const MultimodalInput& input = inputs[i];
// Note: `i` is logged as-is, i.e. zero-based, against the total count.
ET_LOG(
Info,
"Prefilling input %zu/%zu, type: %s",
i,
inputs.size(),
input.type_name());
// With echo enabled, only the final input is echoed, and only if it is text.
if (config.echo && i == inputs.size() - 1 && input.is_text()) {
wrapped_callback(input.get_text());
}
// pos_ is handed to the prefiller; presumably it is taken by reference and
// advanced there, since it is read back as the prompt token count below --
// confirm against MultimodalPrefiller::prefill.
auto prefill_result = multimodal_prefiller_->prefill(input, pos_);
if (!prefill_result.ok()) {
return prefill_result.error();
}
prefill_next_token = prefill_result.get();
}
// Both timestamps are taken right after the last prefill, before the first
// token is decoded or emitted.
stats_->first_token_ms = time_in_ms();
stats_->prompt_eval_end_ms = time_in_ms();
stats_->num_prompt_tokens = pos_;
// The same token id is passed for both decode arguments.
// NOTE(review): presumably these are (previous, current) token ids -- confirm
// against the Tokenizer::decode signature.
auto decode_result =
tokenizer_->decode(prefill_next_token, prefill_next_token);
if (!decode_result.ok()) {
ET_LOG(
Error,
"Tokenizers error code %d",
static_cast<uint32_t>(decode_result.error()));
return Error::InvalidArgument;
}
// Emit the first generated piece before the generation loop starts.
wrapped_callback(std::move(*decode_result));
RUNNER_ET_LOG(
config.warming,
"RSS after multimodal input processing: %f MiB (0 if unsupported)",
get_rss_bytes() / 1024.0 / 1024.0);
// Resolve max_new_tokens based on config
int64_t max_context_len =
metadata_.at(kMaxContextLen) - 0; // No start_pos offset
int32_t max_new_tokens = config.resolve_max_new_tokens(max_context_len, pos_);
ET_LOG(
Info,
"Max new tokens resolved: %d, pos_ %" PRId64 ", max_context_len %" PRId64,
max_new_tokens,
pos_,
max_context_len);
// Note: this check runs after the first piece has already been passed to
// wrapped_callback above.
ET_CHECK_OR_RETURN_ERROR(
max_new_tokens > 0,
InvalidArgument,
"Max new tokens %d is less than or equal to 0",
max_new_tokens);
// Set ignore_eos based on config
text_token_generator_->set_ignore_eos(config.ignore_eos);
// Generate tokens using the text token generator
// The generator is seeded with only the token produced by prefill.
std::vector<uint64_t> prompt_tokens = {prefill_next_token};
auto generate_result = text_token_generator_->generate(
/*tokens=*/prompt_tokens,
/*start_pos=*/pos_,
/*max_new_tokens=*/max_new_tokens -
1, // Subtract 1 because prefill already generated 1 token
/*temperature=*/config.temperature,
/*token_callback=*/wrapped_callback);
if (!generate_result.ok()) {
return generate_result.error();
}
// Advance the running position by the count the generator reported.
int64_t num_generated_tokens = generate_result.get();
pos_ += num_generated_tokens;
// Update stats
stats_->num_generated_tokens = num_generated_tokens;
// Finalize stats and call callback
stats_->inference_end_ms = time_in_ms();
#ifdef CUDA_AVAILABLE
cuda_memory_tracker_->log_sample("after_generate");
stats_->gpu_free_after_generate_bytes =
cuda_memory_tracker_->last_free_bytes();
// update peak in case it changed after generation
stats_->gpu_peak_usage_mb = cuda_memory_tracker_->peak_usage_mb();
#endif
// Terminate the printed output with a newline; nothing was printed during
// warmup, so nothing is terminated then.
if (!config.warming) {
printf("\n");
}
if (config.warming) {
ET_LOG(Info, "Warmup run finished!");
} else {
// Do not print report during warmup
print_report(*stats_);
}
// The stats callback runs for warmup and non-warmup runs alike.
if (stats_callback) {
stats_callback(*stats_);
}
return Error::Ok;
}
} // namespace executorch::extension::llm