/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
* @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
*/
// A simple llama2 runner that includes preprocessing and post-processing
// logic. The module takes a string as input and emits a string as output.
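//
// Typical usage (a minimal sketch; the module path, tokenizer path, and
// config values below are placeholders, not defined in this file):
//
//   executorch::extension::Module module("/path/to/llama.pte");
//   executorch::extension::llm::GenerationConfig config;
//   example::Runner runner(&module, "/path/to/tokenizer.json", config);
//   runner.generate("Tell me a story.", config,
//                   [](const std::string &piece) { /* stream each token */ },
//                   /*stats_callback=*/nullptr);
//
// generate() calls load() lazily, so an explicit load() call is optional.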
#include "runner.h"
#include "constants.h"
#include "util.h"
#include <cinttypes> // PRId64 / PRId32 used in ET_LOG format strings
#include <cstdint>
#include <ctime>
#include <rnexecutorch/Error.h>
namespace example {
using namespace executorch::extension::llm;
using ::executorch::extension::Module;
using ::executorch::runtime::Error;
using ::executorch::runtime::Result;
Runner::Runner(Module *module, const std::string &tokenizer_path,
const llm::GenerationConfig &config)
: config_(config), module_(module), tokenizer_path_(tokenizer_path),
tokenizer_(std::make_unique<tokenizers::HFTokenizer>()),
// Conservative defaults; load() overwrites these with values read from
// the exported model's metadata when the corresponding methods exist.
metadata_({
{kEnableDynamicShape, false},
{kMaxSeqLen, 128},
{kMaxContextLen, 128},
{kUseKVCache, true},
{kUseSDPAWithKVCache, false},
}) {}
bool Runner::is_loaded() const {
return module_->is_loaded() && tokenizer_->is_loaded() &&
text_decoder_runner_ && text_prefiller_ && text_token_generator_;
}
Error Runner::load() {
if (is_loaded()) {
return Error::Ok;
}
auto status = tokenizer_->load(tokenizer_path_);
if (status != tokenizers::Error::Ok) {
throw rnexecutorch::RnExecutorchError(
rnexecutorch::RnExecutorchErrorCode::TokenizerError,
"Unexpected issue occured while loading tokenizer");
};
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
ET_LOG(Info, "Reading metadata from model");
auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
metadata_[kVocabSize] = tokenizer_->vocab_size();
// Load model metadata
const auto method_names =
ET_UNWRAP(module_->method_names(), "Failed reading method names");
for (auto &pair : metadata_) {
const auto &method_name = pair.first;
auto &value = pair.second;
if (method_names.count(method_name)) {
value = ET_UNWRAP(module_->get(method_name))
.toScalar()
.to<decltype(metadata_)::mapped_type>();
} else {
ET_LOG(Info, "Method %s not found, using the default value %" PRId64,
method_name.c_str(), value);
}
ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
}
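// For example, if the exported model exposes a method named by kMaxSeqLen
// that returns 2048, the conservative default of 128 set in the constructor
// is replaced with 2048 here.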
// Load EOS token ids
if (method_names.count(kEosIds)) {
eos_ids->clear();
for (const auto &eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
auto value = eos_id.toScalar().to<int64_t>();
eos_ids->emplace(value);
ET_LOG(Info, "eos_id = %" PRId64, value);
}
}
// Determine missing config values.
// If the user does not directly specify configuration parameters such as
// max_seq_len (i.e. leaves them at their default values), they are resolved
// from the exported model's metadata.
if (config_.max_seq_len < 0)
config_.max_seq_len = static_cast<int32_t>(metadata_.at(kMaxSeqLen));
if (config_.max_context_length < 0)
config_.max_context_length =
static_cast<int32_t>(metadata_.at(kMaxContextLen));
if (config_.max_new_tokens < 0)
config_.max_new_tokens =
std::min(config_.max_seq_len, config_.max_context_length);
// Boolean flags stay enabled only if the model's metadata also enables them.
if (config_.enable_dynamic_shape)
config_.enable_dynamic_shape =
static_cast<bool>(metadata_.at(kEnableDynamicShape));
if (config_.enable_kv_cache)
config_.enable_kv_cache = static_cast<bool>(metadata_.at(kUseKVCache));
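// For example, with the config fields left negative (unspecified) and model
// metadata kMaxSeqLen = kMaxContextLen = 128, this resolves to
// max_seq_len = 128, max_context_length = 128 and max_new_tokens = 128.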
io_manager_ = std::make_unique<llm::IOManager>(*module_);
text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
module_, io_manager_.get(), config_.temperature, config_.topp);
text_prefiller_ = std::make_unique<llm::TextPrefiller>(
text_decoder_runner_.get(), config_.enable_kv_cache,
config_.enable_dynamic_shape, config_.max_seq_len);
text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
tokenizer_.get(), text_decoder_runner_.get(), config_.enable_kv_cache,
std::move(eos_ids), &stats_);
return Error::Ok;
}
// Log at Debug instead of Info priority during warmup runs.
// Wrapped in do { } while (0) so the macro expands to a single statement
// and is safe inside unbraced if/else.
#define RUNNER_ET_LOG(warmup, format, ...) \
do { \
if (warmup) { \
ET_LOG(Debug, format, __VA_ARGS__); \
} else { \
ET_LOG(Info, format, __VA_ARGS__); \
} \
} while (0)
Error Runner::generate(const std::string &prompt,
const llm::GenerationConfig &generation_config,
std::function<void(const std::string &)> token_callback,
std::function<void(const llm::Stats &)> stats_callback) {
// Validate the prompt and make sure the model and tokenizer are loaded.
ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be empty");
if (!is_loaded()) {
stats_.model_load_start_ms = llm::time_in_ms();
ET_CHECK_OK_OR_RETURN_ERROR(load());
stats_.model_load_end_ms = llm::time_in_ms();
}
if (generation_config.warming) {
ET_LOG(Info, "Doing a warmup run...");
}
RUNNER_ET_LOG(generation_config.warming,
"RSS after loading model: %f MiB (0 if unsupported)",
llm::get_rss_bytes() / 1024.0 / 1024.0);
// Wrap token_callback so each piece is also printed to stdout (skipped
// during warmup).
std::function<void(const std::string &)> wrapped_callback =
[token_callback, &generation_config](const std::string &piece) {
if (!generation_config.warming) {
llm::safe_printf(piece.c_str());
fflush(stdout);
}
if (token_callback) {
token_callback(piece);
}
};
// First token time only measures the time it takes to encode the prompt and
// return a response token.
stats_.inference_start_ms = llm::time_in_ms();
shouldStop_ = false;
// Override main config fields with given generation config if specified
int32_t max_seq_len = generation_config.max_seq_len >= 0
? generation_config.max_seq_len
: config_.max_seq_len;
int32_t max_context_length = generation_config.max_context_length >= 0
? generation_config.max_context_length
: config_.max_context_length;
int32_t new_tokens_limit = generation_config.max_new_tokens >= 0
? generation_config.max_new_tokens
: config_.max_new_tokens;
float temperature = generation_config.temperature >= 0.F
? generation_config.temperature
: config_.temperature;
float topp =
generation_config.topp >= 0.F ? generation_config.topp : config_.topp;
int64_t context_len_left = static_cast<int64_t>(max_context_length) - pos_;
// If the tokenizer.json in use defines a post_processor field, setting the
// bos or eos argument to any value other than the provided constant (which
// is 0) runs the post_processor with the 'add_special_token' flag.
auto encodeResult =
tokenizer_->encode(prompt, numOfAddedBoSTokens, numOfAddedEoSTokens);
if (!encodeResult.ok()) {
throw rnexecutorch::RnExecutorchError(
rnexecutorch::RnExecutorchErrorCode::TokenizerError,
"Unexpected issue occured while encoding: " +
std::to_string(static_cast<int32_t>(encodeResult.error())));
}
// The (string) prompt encoded as a token sequence.
std::vector<uint64_t> prompt_tokens = encodeResult.get();
int num_prompt_tokens = static_cast<int>(prompt_tokens.size());
ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens >= 1, InvalidArgument,
"Expected at least 1 prompt token");
ET_CHECK_OR_RETURN_ERROR(num_prompt_tokens < max_seq_len, InvalidArgument,
"num_prompt_tokens %d >= max_seq_len %" PRId32
"; max seq length exceeded - please increase the max "
"seq len value in your export script",
num_prompt_tokens, max_seq_len);
// Determine max_new_tokens using resolve_max_new_tokens(); the tokens
// already consumed (pos_) were subtracted above via context_len_left.
int32_t max_new_tokens = resolve_max_new_tokens(
num_prompt_tokens, max_seq_len, static_cast<int32_t>(context_len_left),
new_tokens_limit);
ET_LOG(Info,
"Max new tokens resolved: %d, given pos_ %" PRId64
", num_prompt_tokens %zu, max_context_len %" PRId64,
max_new_tokens, pos_, prompt_tokens.size(),
static_cast<int64_t>(max_context_length));
ET_CHECK_OR_RETURN_ERROR(max_new_tokens > 0, InvalidArgument,
"Max new tokens %d is less than or equal to 0",
max_new_tokens);
// Prefill first: feed all prompt tokens to the model and get the next
// predicted token after the prompt. After that we enter the generation loop.
// Echo the prompt back through the callback if requested.
if (generation_config.echo) {
wrapped_callback(prompt);
}
auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_);
stats_.first_token_ms = llm::time_in_ms();
stats_.prompt_eval_end_ms = llm::time_in_ms();
ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
uint64_t cur_token = prefill_res.get();
auto decodeResult = tokenizer_->decode({cur_token});
if (!decodeResult.ok()) {
throw rnexecutorch::RnExecutorchError(
rnexecutorch::RnExecutorchErrorCode::TokenizerError,
"Unexpected issue occurred while decoding: " +
std::to_string(static_cast<int32_t>(decodeResult.error())));
}
// Forward the first token produced by prefill to the callback.
const std::string cur_decoded = decodeResult.get();
wrapped_callback(cur_decoded);
RUNNER_ET_LOG(generation_config.warming,
"RSS after prompt prefill: %f MiB (0 if unsupported)",
llm::get_rss_bytes() / 1024.0 / 1024.0);
// start the main loop
prompt_tokens.push_back(cur_token);
int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
prompt_tokens, pos_, max_new_tokens - 1, temperature, topp,
wrapped_callback));
pos_ += num_generated_tokens;
stats_.inference_end_ms = llm::time_in_ms();
if (!generation_config.warming) {
printf("\n");
}
RUNNER_ET_LOG(
generation_config.warming,
"RSS after finishing text generation: %f MiB (0 if unsupported)",
llm::get_rss_bytes() / 1024.0 / 1024.0);
if (num_generated_tokens == max_new_tokens) {
RUNNER_ET_LOG(generation_config.warming, "Max new tokens %i reached!",
max_new_tokens);
}
stats_.num_prompt_tokens = num_prompt_tokens;
stats_.num_generated_tokens = num_generated_tokens;
if (generation_config.warming) {
ET_LOG(Info, "Warmup run finished!");
} else {
// Do not print report during warmup
#ifndef TEST_BUILD
::executorch::llm::print_report(stats_);
#endif
}
if (stats_callback) {
stats_callback(stats_);
}
return Error::Ok;
}
Error Runner::warmup(const std::string &prompt) {
// Create a GenerationConfig for warmup
llm::GenerationConfig config{.echo = false, .warming = true};
// Call generate with the warmup config
Error err = generate(prompt, config,
/*token_callback=*/nullptr,
/*stats_callback=*/nullptr);
// Reset stats after warmup
reset();
return err;
}
void Runner::stop() {
if (is_loaded()) {
text_token_generator_->stop();
} else {
ET_LOG(Error, "Token generator is not loaded, cannot stop");
}
}
void Runner::reset() {
stats_.reset();
pos_ = 0;
}
void Runner::set_count_interval(size_t count_interval) {
text_token_generator_->set_count_interval(count_interval);
}
void Runner::set_time_interval(size_t time_interval) {
text_token_generator_->set_time_interval(time_interval);
}
void Runner::set_temperature(float temperature) noexcept {
config_.temperature = temperature;
if (text_decoder_runner_) {
text_decoder_runner_->set_temperature(temperature);
}
}
void Runner::set_topp(float topp) noexcept {
config_.topp = topp;
if (text_decoder_runner_) {
text_decoder_runner_->set_topp(topp);
}
}
int32_t Runner::resolve_max_new_tokens(int32_t num_prompt_tokens,
int32_t max_seq_len,
int32_t max_context_len,
int32_t max_new_tokens) const {
int32_t result;
if (max_seq_len == -1 && max_new_tokens == -1) {
// Both are -1, use max context len minus prompt tokens
result = max_context_len - num_prompt_tokens;
} else if (max_seq_len == -1 && max_new_tokens != -1) {
// Only max_new_tokens is specified
result = std::min(max_new_tokens, max_context_len - num_prompt_tokens);
} else if (max_seq_len != -1 && max_new_tokens == -1) {
// Only seq_len is specified
result = std::min(max_seq_len, max_context_len) - num_prompt_tokens;
} else {
// Both are specified
result =
std::min(std::min(max_seq_len, max_context_len) - num_prompt_tokens,
max_new_tokens);
}
// Ensure result is not negative
return std::max(0, result);
}
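// Worked example: num_prompt_tokens = 16, max_seq_len = 128,
// max_context_len = 128, max_new_tokens = -1 takes the "only seq_len
// specified" branch and yields min(128, 128) - 16 = 112 new tokens.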
int32_t Runner::count_text_tokens(const std::string &text) const {
auto encodeResult =
tokenizer_->encode(text, numOfAddedBoSTokens, numOfAddedEoSTokens);
if (!encodeResult.ok()) {
throw rnexecutorch::RnExecutorchError(
rnexecutorch::RnExecutorchErrorCode::TokenizerError,
"Encoding failed during token count check.");
}
return static_cast<int32_t>(encodeResult.get().size());
}
} // namespace example