/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama2 runner that includes preprocessing and post-processing
// logic. The module takes a string as input and emits a string as output.

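// Example usage (a minimal sketch; the model and tokenizer paths below are
// hypothetical, and the omitted constructor/generate() arguments are assumed
// to have defaults declared in runner.h):
//
//   example::Runner runner("llama2.pte", "tokenizer.json",
//                          /*temperature=*/0.8f);
//   runner.load();
//   runner.generate("Tell me a story.",
//                   [](const std::string &piece) { /* stream each piece */ });
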
#include "runner.h"
#include "util.h"

#include <cinttypes> // PRId64, used by the metadata log statements below
#include <ctime>
#include <fstream>
#include <iostream>
#include <stdexcept> // std::runtime_error

namespace example {

using ::executorch::extension::Module;
using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

namespace llm = ::executorch::extension::llm;

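// Reads the entire file at `path` into a string; used below to pull the
// tokenizer's serialized JSON blob into memory.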
std::string loadBytesFromFile(const std::string &path) {
  std::ifstream fs(path, std::ios::in | std::ios::binary);
  if (fs.fail()) {
    throw std::runtime_error("Failed to open tokenizer file: " + path);
  }
  std::string data;
  fs.seekg(0, std::ios::end);
  size_t size = static_cast<size_t>(fs.tellg());
  fs.seekg(0, std::ios::beg);
  data.resize(size);
  if (!fs.read(data.data(), static_cast<std::streamsize>(size))) {
    throw std::runtime_error("Failed to read tokenizer file: " + path);
  }
  return data;
}

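// Names of metadata methods the exported model may provide. Each returns a
// scalar describing how the model was exported; missing methods fall back to
// the defaults set in the constructor.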
namespace {
static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
static constexpr auto kBosId = "get_bos_id";
static constexpr auto kEosIds = "get_eos_ids";
static constexpr auto kMaxSeqLen = "get_max_seq_len";
static constexpr auto kMaxContextLen = "get_max_context_len";
static constexpr auto kVocabSize = "get_vocab_size";
static constexpr auto kUseKVCache = "use_kv_cache";
static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
} // namespace

Runner::Runner(const std::string &model_path, const std::string &tokenizer_path,
               const float temperature,
               std::optional<const std::string> data_path)
    // NOTE: we observed a ~2x loading performance increase on iPhone 15
    // and a ~5% improvement on Galaxy S22 by switching to FileDataLoader
    // instead of MmapDataLoader + UseMlockIgnoreErrors.
    : temperature_(temperature), tokenizer_path_(tokenizer_path),
      metadata_({
          {kEnableDynamicShape, false},
          {kMaxSeqLen, 128},
          {kMaxContextLen, 128},
          {kUseKVCache, true},
          {kUseSDPAWithKVCache, false},
      }) {
  if (data_path.has_value()) {
    module_ = std::make_unique<Module>(model_path, data_path.value(),
                                       Module::LoadMode::File);
  } else {
    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
  }
  ET_LOG(Info, "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
         model_path.c_str(), tokenizer_path.c_str());
}

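// The runner counts as loaded only once the forward method, the tokenizer,
// and all three generation helpers have been instantiated.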
bool Runner::is_loaded() const {
  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
         text_prefiller_ && text_token_generator_;
}

Error Runner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

  // Load the tokenizer from its serialized JSON blob.
  auto blob = loadBytesFromFile(tokenizer_path_);
  tokenizer_ = tokenizers::Tokenizer::FromBlobJSON(blob);

  ET_LOG(Info, "Reading metadata from model");

  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>();
  metadata_[kVocabSize] = tokenizer_->GetVocabSize();

  const auto method_names =
      ET_UNWRAP(module_->method_names(), "Failed reading method names");

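  // Override each default metadata value with the value the model reports,
  // whenever the exported program provides a matching method.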
  for (auto &pair : metadata_) {
    const auto &method_name = pair.first;
    auto &value = pair.second;
    if (method_names.count(method_name)) {
      value = ET_UNWRAP(module_->get(method_name))
                  .toScalar()
                  .to<decltype(metadata_)::mapped_type>();
    } else {
      ET_LOG(Info, "Method %s not found, using the default value %" PRId64,
             method_name.c_str(), value);
    }
    ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
  }
  if (method_names.count(kEosIds)) {
    eos_ids->clear();
    for (const auto &eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
      auto value = eos_id.toScalar().to<int64_t>();
      eos_ids->emplace(value);
      ET_LOG(Info, "eos_id = %" PRId64, value);
    }
  }
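
  // Wire up the generation pipeline: TextDecoderRunner executes the model's
  // forward(), TextPrefiller pushes the whole prompt through it, and
  // TextTokenGenerator drives the token-by-token decode loop.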
  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
      module_.get(), metadata_.at(kUseKVCache), metadata_.at(kVocabSize),
      temperature_);
  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
      text_decoder_runner_.get(), metadata_.at(kUseKVCache),
      metadata_.at(kEnableDynamicShape));
  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
      tokenizer_.get(), text_decoder_runner_.get(), metadata_.at(kUseKVCache),
      std::move(eos_ids), &stats_);

  return Error::Ok;
}

// Don't print with the same priority during warmup.
#define RUNNER_ET_LOG(warmup, format, ...) \
  do {                                     \
    if (warmup) {                          \
      ET_LOG(Debug, format, __VA_ARGS__);  \
    } else {                               \
      ET_LOG(Info, format, __VA_ARGS__);   \
    }                                      \
  } while (0)

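// Runs the full generation flow: tokenize the prompt, prefill it through the
// model, then decode tokens one at a time until an EOS id or the sequence
// length cap is reached. Each decoded piece is streamed to token_callback.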
Error Runner::generate(const std::string &prompt,
                       std::function<void(const std::string &)> token_callback,
                       std::function<void(const llm::Stats &)> stats_callback,
                       bool echo, bool warmup) {
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be empty");
  if (!is_loaded()) {
    stats_.model_load_start_ms = llm::time_in_ms();
    ET_CHECK_OK_OR_RETURN_ERROR(load());
    stats_.model_load_end_ms = llm::time_in_ms();
  }

  if (warmup) {
    ET_LOG(Info, "Doing a warmup run...");
  }

  RUNNER_ET_LOG(warmup, "RSS after loading model: %f MiB (0 if unsupported)",
                llm::get_rss_bytes() / 1024.0 / 1024.0);

  // Wrap token_callback so each piece is also printed to stdout (suppressed
  // during warmup runs).
  std::function<void(const std::string &)> wrapped_callback =
      [token_callback, warmup](const std::string &piece) {
        if (!warmup) {
          llm::safe_printf(piece.c_str());
          fflush(stdout);
        }
        if (token_callback) {
          token_callback(piece);
        }
      };

  // First-token time only measures how long it takes to encode the prompt
  // and return one response token.
  stats_.inference_start_ms = llm::time_in_ms();
  shouldStop_ = false;

  // generate() does not take a seq_len argument, so generation is capped at
  // the model's max sequence length.
  const int32_t seq_len = metadata_.at(kMaxSeqLen);

  // Encode the (string) prompt into a token sequence.
  std::vector<int32_t> prompt_tokens = tokenizer_->Encode(prompt);
  std::vector<uint64_t> prompt_tokens_uint64(prompt_tokens.begin(),
                                             prompt_tokens.end());
  int num_prompt_tokens = prompt_tokens.size();

  if (num_prompt_tokens < 1) {
    ET_LOG(Error,
           "num_prompt_tokens %d < 1, expected at least 1 token to be passed "
           "to generate()!",
           num_prompt_tokens);
    return Error::InvalidArgument;
  } else if (num_prompt_tokens >= seq_len) {
    ET_LOG(Error,
           "num_prompt_tokens %d >= seq_len %d, the prompt does not fit "
           "within the model's max sequence length!",
           num_prompt_tokens, seq_len);
    return Error::InvalidArgument;
  }

  // Prefill first: feed all prompt tokens to the model and get the first
  // predicted token after the prompt. After that we enter the generate loop.

  // Print the prompt ahead of any generated text.
  if (echo) {
    wrapped_callback(prompt);
  }
  int64_t pos = 0;
  auto prefill_res = text_prefiller_->prefill(prompt_tokens_uint64, pos);
  stats_.first_token_ms = llm::time_in_ms();
  stats_.prompt_eval_end_ms = llm::time_in_ms();
  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
  uint64_t cur_token = prefill_res.get();

  // Print the first token from prefill. There is no prev_token yet, so use
  // cur_token for it.
  wrapped_callback(tokenizer_->Decode(
      std::vector<int32_t>{static_cast<int32_t>(cur_token)}));
  RUNNER_ET_LOG(warmup, "RSS after prompt prefill: %f MiB (0 if unsupported)",
                llm::get_rss_bytes() / 1024.0 / 1024.0);

  // Start the main generation loop.
  prompt_tokens_uint64.push_back(cur_token);
  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
      prompt_tokens_uint64, num_prompt_tokens, seq_len, wrapped_callback));

  stats_.inference_end_ms = llm::time_in_ms();
  if (!warmup) {
    printf("\n");
  }
  RUNNER_ET_LOG(
      warmup, "RSS after finishing text generation: %f MiB (0 if unsupported)",
      llm::get_rss_bytes() / 1024.0 / 1024.0);

  if (num_prompt_tokens + num_generated_tokens == seq_len) {
    RUNNER_ET_LOG(warmup, "Sequence length (%i tokens) reached!", seq_len);
  }

  stats_.num_prompt_tokens = num_prompt_tokens;
  stats_.num_generated_tokens = num_generated_tokens;

  if (warmup) {
    ET_LOG(Info, "Warmup run finished!");
  } else {
    // Do not print the report during warmup.
    ::executorch::llm::print_report(stats_);
  }
  if (stats_callback) {
    stats_callback(stats_);
  }

  return Error::Ok;
}

Error Runner::warmup(const std::string &prompt) {
  Error err = generate(prompt,
                       /*token_callback=*/nullptr,
                       /*stats_callback=*/nullptr,
                       /*echo=*/false,
                       /*warmup=*/true);
  stats_.reset();
  return err;
}

void Runner::stop() {
  if (is_loaded()) {
    text_token_generator_->stop();
  } else {
    ET_LOG(Error, "Token generator is not loaded, cannot stop");
  }
}

} // namespace example