
Commit e299c3f

chore: add runner to includes
1 parent dea16ea commit e299c3f

12 files changed

Lines changed: 1390 additions & 0 deletions

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// An interface for LLM runners. Developers can create their own runner that
// implements their own load and generation logic to run the model.

#pragma once

#include <functional>
#include <string>

#include "stats.h"
#include <executorch/extension/module/module.h>

namespace executorch {
namespace extension {
namespace llm {

class ET_EXPERIMENTAL IRunner {
 public:
  virtual ~IRunner() = default;

  // Checks if the model is loaded.
  virtual bool is_loaded() const = 0;

  // Load the model and tokenizer.
  virtual ::executorch::runtime::Error load() = 0;

  // Generate the output tokens.
  virtual ::executorch::runtime::Error
  generate(const std::string &prompt,
           std::function<void(const std::string &)> token_callback = {},
           std::function<void(const ::executorch::extension::llm::Stats &)>
               stats_callback = {},
           bool echo = true, bool warming = false) = 0;

  // Stop the generation.
  virtual void stop() = 0;
};

} // namespace llm
} // namespace extension
} // namespace executorch
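
The interface above is deliberately small: four pure-virtual methods plus a destructor. For illustration only (not part of this commit), here is a minimal sketch of a custom runner that satisfies the contract by loading nothing and simply echoing the prompt back through the token callback. The EchoRunner name and its trivial bodies are assumptions; the include paths mirror the ones used in this commit.

// Illustrative only: a trivial IRunner that performs no real inference.
#include <functional>
#include <string>

#include "irunner.h" // the interface added above (assumed relative path)
#include "stats.h"

namespace example {

class EchoRunner : public ::executorch::extension::llm::IRunner {
 public:
  bool is_loaded() const override { return loaded_; }

  ::executorch::runtime::Error load() override {
    loaded_ = true; // nothing to load in this sketch
    return ::executorch::runtime::Error::Ok;
  }

  ::executorch::runtime::Error
  generate(const std::string &prompt,
           std::function<void(const std::string &)> token_callback = {},
           std::function<void(const ::executorch::extension::llm::Stats &)>
               stats_callback = {},
           bool echo = true, bool warming = false) override {
    // "Generate" by streaming the prompt back as a single piece.
    if (echo && token_callback) {
      token_callback(prompt);
    }
    if (stats_callback) {
      stats_callback(stats_);
    }
    return ::executorch::runtime::Error::Ok;
  }

  void stop() override {} // nothing to interrupt in this sketch

 private:
  bool loaded_{false};
  ::executorch::extension::llm::Stats stats_;
};

} // namespace example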
Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama2 runner that includes preprocessing and post processing logic.
// The module takes in a string as input and emits a string as output.

#include "runner.h"
#include "util.h"
#include <ctime>
#include <fstream>
#include <iostream>

namespace example {

using ::executorch::extension::Module;
using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

namespace llm = ::executorch::extension::llm;

std::string loadBytesFromFile(const std::string &path) {
  std::ifstream fs(path, std::ios::in | std::ios::binary);
  if (fs.fail()) {
    throw std::runtime_error("Failed to open tokenizer file");
  }
  std::string data;
  fs.seekg(0, std::ios::end);
  size_t size = static_cast<size_t>(fs.tellg());
  fs.seekg(0, std::ios::beg);
  data.resize(size);
  fs.read(data.data(), size);
  return data;
}

namespace {
static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
static constexpr auto kBosId = "get_bos_id";
static constexpr auto kEosIds = "get_eos_ids";
static constexpr auto kMaxSeqLen = "get_max_seq_len";
static constexpr auto kMaxContextLen = "get_max_context_len";
static constexpr auto kVocabSize = "get_vocab_size";
static constexpr auto kUseKVCache = "use_kv_cache";
static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
} // namespace

Runner::Runner(const std::string &model_path, const std::string &tokenizer_path,
               const float temperature,
               std::optional<const std::string> data_path)
    // NOTE: we observed ~2x loading performance increase on iPhone 15
    // and a ~5% improvement on Galaxy S22 by switching to
    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
    : temperature_(temperature), tokenizer_path_(tokenizer_path),
      metadata_({
          {kEnableDynamicShape, false},
          {kMaxSeqLen, 128},
          {kMaxContextLen, 128},
          {kUseKVCache, true},
          {kUseSDPAWithKVCache, false},
      }) {
  if (data_path.has_value()) {
    module_ = std::make_unique<Module>(model_path, data_path.value(),
                                       Module::LoadMode::File);
  } else {
    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
  }
  ET_LOG(Info, "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
         model_path.c_str(), tokenizer_path.c_str());
}

bool Runner::is_loaded() const {
  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
         text_prefiller_ && text_token_generator_;
}

Error Runner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

  // Load the tokenizer.
  auto blob = loadBytesFromFile(tokenizer_path_);
  tokenizer_ = tokenizers::Tokenizer::FromBlobJSON(blob);

  ET_LOG(Info, "Reading metadata from model");

  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
  metadata_[kVocabSize] = tokenizer_->GetVocabSize();

  const auto method_names =
      ET_UNWRAP(module_->method_names(), "Failed reading method names");

  for (auto &pair : metadata_) {
    const auto &method_name = pair.first;
    auto &value = pair.second;
    if (method_names.count(method_name)) {
      value = ET_UNWRAP(module_->get(method_name))
                  .toScalar()
                  .to<decltype(metadata_)::mapped_type>();
    } else {
      ET_LOG(Info, "Method %s not found, using the default value %" PRId64,
             method_name.c_str(), value);
    }
    ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
  }
  if (method_names.count(kEosIds)) {
    eos_ids->clear();
    for (const auto &eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
      auto value = eos_id.toScalar().to<int64_t>();
      eos_ids->emplace(value);
      ET_LOG(Info, "eos_id = %" PRId64, value);
    }
  }
  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
      module_.get(), metadata_.at(kUseKVCache), metadata_.at(kVocabSize),
      temperature_);
  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
      text_decoder_runner_.get(), metadata_.at(kUseKVCache),
      metadata_.at(kEnableDynamicShape));

  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
      tokenizer_.get(), text_decoder_runner_.get(), metadata_.at(kUseKVCache),
      std::move(eos_ids), &stats_);

  return Error::Ok;
}

// Don't print with the same priority during warmup.
#define RUNNER_ET_LOG(warmup, format, ...) \
  if (warmup) {                            \
    ET_LOG(Debug, format, __VA_ARGS__);    \
  } else {                                 \
    ET_LOG(Info, format, __VA_ARGS__);     \
  }

Error Runner::generate(const std::string &prompt,
                       std::function<void(const std::string &)> token_callback,
                       std::function<void(const llm::Stats &)> stats_callback,
                       bool echo, bool warmup) {
  // Prepare the inputs.
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be empty");
  if (!is_loaded()) {
    stats_.model_load_start_ms = llm::time_in_ms();
    ET_CHECK_OK_OR_RETURN_ERROR(load());
    stats_.model_load_end_ms = llm::time_in_ms();
  }

  if (warmup) {
    ET_LOG(Info, "Doing a warmup run...");
  }

  RUNNER_ET_LOG(warmup, "RSS after loading model: %f MiB (0 if unsupported)",
                llm::get_rss_bytes() / 1024.0 / 1024.0);

  // Wrap the token_callback with a print function.
  std::function<void(const std::string &)> wrapped_callback =
      [token_callback, warmup](const std::string &piece) {
        if (!warmup) {
          llm::safe_printf(piece.c_str());
          fflush(stdout);
        }
        if (token_callback) {
          token_callback(piece);
        }
      };
  // First token time only measures the time it takes to encode the prompt and
  // return a response token.
  stats_.inference_start_ms = llm::time_in_ms();
  shouldStop_ = false;

  // No sequence length is passed in, so cap generation at the model's
  // maximum sequence length.
  int32_t seq_len = metadata_.at(kMaxSeqLen);

  // Encode the (string) prompt into a token sequence.
  std::vector<int32_t> prompt_tokens = tokenizer_->Encode(prompt);
  std::vector<uint64_t> prompt_tokens_uint64(prompt_tokens.begin(),
                                             prompt_tokens.end());
  int num_prompt_tokens = prompt_tokens.size();

  if (num_prompt_tokens < 1) {
    ET_LOG(Error,
           "num_prompt_tokens %d < 1, expected at least 1 token to be passed "
           "to generate()!",
           num_prompt_tokens);
    return Error::InvalidArgument;
  } else if (num_prompt_tokens >= seq_len) {
    ET_LOG(Error,
           "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - "
           "please increase the seq_len value passed to generate()!",
           num_prompt_tokens, seq_len);
    return Error::InvalidArgument;
  }

  // Prefill first.
  // Feed all prompt tokens to the model and get the next predicted token
  // after the prompt. After that we enter the generate loop.

  // Print the prompt.
  if (echo) {
    wrapped_callback(prompt);
  }
  int64_t pos = 0;
  auto prefill_res = text_prefiller_->prefill(prompt_tokens_uint64, pos);
  stats_.first_token_ms = llm::time_in_ms();
  stats_.prompt_eval_end_ms = llm::time_in_ms();
  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
  uint64_t cur_token = prefill_res.get();

  // Print the first token from prefill. No prev_token so use cur_token for it.
  wrapped_callback(tokenizer_->Decode(
      std::vector<int32_t>{static_cast<int32_t>(cur_token)}));
  RUNNER_ET_LOG(warmup, "RSS after prompt prefill: %f MiB (0 if unsupported)",
                llm::get_rss_bytes() / 1024.0 / 1024.0);

  // Start the main loop.
  prompt_tokens_uint64.push_back(cur_token);
  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
      prompt_tokens_uint64, num_prompt_tokens, seq_len, wrapped_callback));

  stats_.inference_end_ms = llm::time_in_ms();
  if (!warmup) {
    printf("\n");
  }
  RUNNER_ET_LOG(
      warmup, "RSS after finishing text generation: %f MiB (0 if unsupported)",
      llm::get_rss_bytes() / 1024.0 / 1024.0);

  if (num_prompt_tokens + num_generated_tokens == seq_len) {
    RUNNER_ET_LOG(warmup, "Sequence length (%i tokens) reached!", seq_len);
  }

  stats_.num_prompt_tokens = num_prompt_tokens;
  stats_.num_generated_tokens = num_generated_tokens;

  if (warmup) {
    ET_LOG(Info, "Warmup run finished!");
  } else {
    // Do not print the report during warmup.
    ::executorch::llm::print_report(stats_);
  }
  if (stats_callback) {
    stats_callback(stats_);
  }

  return Error::Ok;
}

Error Runner::warmup(const std::string &prompt) {
  Error err = generate(prompt,
                       /*token_callback=*/nullptr,
                       /*stats_callback=*/nullptr,
                       /*echo=*/false,
                       /*warmup=*/true);
  stats_.reset();
  return err;
}

void Runner::stop() {
  if (is_loaded()) {
    text_token_generator_->stop();
  } else {
    ET_LOG(Error, "Token generator is not loaded, cannot stop");
  }
}
} // namespace example
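
The generate() implementation above stamps wall-clock milestones and token counts onto the Stats object (model_load_start/end, inference_start, prompt_eval_end, first_token, inference_end, num_prompt_tokens, num_generated_tokens) before invoking stats_callback. As a hedged sketch, assuming those fields are plain public members as the assignments above suggest, a callback could turn them into throughput numbers:

// Illustrative stats consumer; field names are taken from the assignments
// in Runner::generate() above, and all times are in milliseconds.
#include <cstdio>

#include "stats.h"

static void report_throughput(const ::executorch::extension::llm::Stats &s) {
  const double prefill_s =
      (s.prompt_eval_end_ms - s.inference_start_ms) / 1000.0;
  const double decode_s =
      (s.inference_end_ms - s.prompt_eval_end_ms) / 1000.0;
  if (prefill_s > 0) {
    std::printf("prefill: %.2f tok/s\n", s.num_prompt_tokens / prefill_s);
  }
  if (decode_s > 0) {
    std::printf("decode:  %.2f tok/s\n", s.num_generated_tokens / decode_s);
  }
}

Such a function could be passed as the stats_callback argument of generate(), alongside or instead of a token callback.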
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama2 runner that includes preprocessing and post processing logic.
// The module takes in a string as input and emits a string as output.

#pragma once

#include "irunner.h"
#include "stats.h"
#include "text_decoder_runner.h"
#include "text_prefiller.h"
#include "text_token_generator.h"
#include <cstdint>
#include <executorch/extension/module/module.h>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <tokenizers-cpp/tokenizers_cpp.h>
#include <unordered_map>

namespace example {

class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
 public:
  explicit Runner(const std::string &model_path,
                  const std::string &tokenizer_path,
                  const float temperature = 0.8f,
                  std::optional<const std::string> data_path = std::nullopt);

  bool is_loaded() const;
  ::executorch::runtime::Error load();
  ::executorch::runtime::Error
  generate(const std::string &prompt,
           std::function<void(const std::string &)> token_callback = {},
           std::function<void(const ::executorch::extension::llm::Stats &)>
               stats_callback = {},
           bool echo = true, bool warming = false);
  ::executorch::runtime::Error warmup(const std::string &prompt);
  void stop();

 private:
  float temperature_;
  bool shouldStop_{false};

  // Model
  std::unique_ptr<::executorch::extension::Module> module_;
  std::string tokenizer_path_;
  std::unique_ptr<tokenizers::Tokenizer> tokenizer_;
  std::unordered_map<std::string, int64_t> metadata_;
  std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
      text_decoder_runner_;
  std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller_;
  std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
      text_token_generator_;

  // Stats
  ::executorch::extension::llm::Stats stats_;
};
} // namespace example
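
Putting the pieces together, here is a hedged sketch of how this Runner might be driven from a small program. The model and tokenizer paths are placeholders, and only the calls declared in the header above (the constructor, warmup(), and generate() with its callbacks) are used.

// Illustrative usage only; paths are placeholders, error handling is minimal.
#include <cstdio>
#include <string>

#include "runner.h"

int main() {
  // Temperature and data_path keep their declared defaults.
  example::Runner runner("/path/to/model.pte", "/path/to/tokenizer.json");

  // Optional warmup pass: echo is off and stats are reset afterwards.
  if (runner.warmup("Hello") != ::executorch::runtime::Error::Ok) {
    std::fprintf(stderr, "warmup failed\n");
    return 1;
  }

  // Stream generated text; the runner already prints each piece, so the
  // token callback here only marks where custom handling would go.
  auto err = runner.generate(
      "Tell me a short story.",
      [](const std::string &piece) { (void)piece; },
      [](const ::executorch::extension::llm::Stats &stats) {
        (void)stats; // e.g. inspect stats.num_generated_tokens here
      });

  return err == ::executorch::runtime::Error::Ok ? 0 : 1;
}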
