Commit d67b5d9
feat: support different input format in LLM text prefiller (#661)
1 parent f9fcb04

8 files changed: 91 additions & 49 deletions

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp

Lines changed: 4 additions & 3 deletions
@@ -8,13 +8,14 @@ namespace rnexecutorch::models {
 
 using namespace facebook;
 using namespace executorch::extension;
+using ::executorch::extension::module::Module;
 using ::executorch::runtime::Error;
 
 BaseModel::BaseModel(const std::string &modelSource,
-                     std::shared_ptr<react::CallInvoker> callInvoker)
+                     std::shared_ptr<react::CallInvoker> callInvoker,
+                     Module::LoadMode loadMode)
     : callInvoker(callInvoker),
-      module_(std::make_unique<Module>(
-          modelSource, Module::LoadMode::MmapUseMlockIgnoreErrors)) {
+      module_(std::make_unique<Module>(modelSource, loadMode)) {
   Error loadError = module_->load();
   if (loadError != Error::Ok) {
     throw std::runtime_error("Failed to load model: Error " +
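The new loadMode parameter (defaulted in BaseModel.h below) lets each model pick its loader strategy instead of hard-coding mmap. A minimal sketch, not part of this commit, of a hypothetical subclass opting into file-based loading; the class name VisionModel is invented, and the snippet assumes BaseModel.h from this diff, which brings Module and react into scope:

#include <rnexecutorch/models/BaseModel.h>

#include <memory>
#include <string>

namespace rnexecutorch::models {

// Hypothetical model that reads the whole .pte file up front
// (Module::LoadMode::File) instead of relying on the default
// MmapUseMlockIgnoreErrors mapping.
class VisionModel : public BaseModel {
public:
  VisionModel(const std::string &modelSource,
              std::shared_ptr<react::CallInvoker> callInvoker)
      : BaseModel(modelSource, callInvoker, Module::LoadMode::File) {}
};

} // namespace rnexecutorch::models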

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h

Lines changed: 9 additions & 4 deletions
@@ -13,12 +13,16 @@
 namespace rnexecutorch {
 namespace models {
 using namespace facebook;
+using executorch::extension::module::Module;
 using executorch::runtime::EValue;
 using executorch::runtime::Result;
+
 class BaseModel {
 public:
-  BaseModel(const std::string &modelSource,
-            std::shared_ptr<react::CallInvoker> callInvoker);
+  BaseModel(
+      const std::string &modelSource,
+      std::shared_ptr<react::CallInvoker> callInvoker,
+      Module::LoadMode loadMode = Module::LoadMode::MmapUseMlockIgnoreErrors);
   std::size_t getMemoryLowerBound() const noexcept;
   void unload() noexcept;
   std::vector<int32_t> getInputShape(std::string method_name, int32_t index);
@@ -42,12 +46,13 @@ class BaseModel {
   std::shared_ptr<react::CallInvoker> callInvoker;
   std::unique_ptr<executorch::extension::Module> module_;
 
-private:
   std::size_t memorySizeLowerBound{0};
+
+private:
   std::vector<int32_t> getTensorShape(const executorch::aten::Tensor &tensor);
 };
 } // namespace models
 
 REGISTER_CONSTRUCTOR(models::BaseModel, std::string,
                      std::shared_ptr<react::CallInvoker>);
-} // namespace rnexecutorch
+} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp

Lines changed: 16 additions & 3 deletions
@@ -8,21 +8,34 @@
 namespace rnexecutorch::models::llm {
 using namespace facebook;
 using executorch::extension::TensorPtr;
+using executorch::extension::module::Module;
 using executorch::runtime::Error;
 
 LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource,
          std::shared_ptr<react::CallInvoker> callInvoker)
-    : runner(std::make_unique<example::Runner>(modelSource, tokenizerSource)),
-      callInvoker(callInvoker) {
-
+    : BaseModel(modelSource, callInvoker, Module::LoadMode::File),
+      runner(std::make_unique<example::Runner>(module_.get(), modelSource,
+                                               tokenizerSource)) {
   auto loadResult = runner->load();
   if (loadResult != Error::Ok) {
     throw std::runtime_error("Failed to load LLM runner, error code: " +
                              std::to_string(static_cast<int>(loadResult)));
   }
+
   memorySizeLowerBound =
       std::filesystem::file_size(std::filesystem::path(modelSource)) +
       std::filesystem::file_size(std::filesystem::path(tokenizerSource));
+
+  // Determine the input mode
+  auto tokensTensorShape = getInputShape("forward", 0);
+  auto positionsTensorShape = getInputShape("forward", 1);
+  if (tokensTensorShape.size() != 2 || positionsTensorShape.size() != 1) {
+    throw std::runtime_error("Unsupported LLM input format");
+  }
+  if (positionsTensorShape[0] != 1 &&
+      tokensTensorShape[1] == positionsTensorShape[0]) {
+    runner->set_extended_input_mode(true);
+  }
 }
 
 void LLM::generate(std::string input, std::shared_ptr<jsi::Function> callback) {
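The constructor now detects the exported model's prefill layout from the forward method's input shapes: rank-2 tokens plus rank-1 positions are required, and a positions length equal to the token count (rather than 1) selects the extended mode. A self-contained sketch of the same heuristic; the function name and the sample shapes are invented for illustration:

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Mirrors the shape check in the LLM constructor above.
static bool isExtendedInputMode(const std::vector<int32_t> &tokensShape,
                                const std::vector<int32_t> &positionsShape) {
  if (tokensShape.size() != 2 || positionsShape.size() != 1) {
    throw std::runtime_error("Unsupported LLM input format");
  }
  return positionsShape[0] != 1 && tokensShape[1] == positionsShape[0];
}

int main() {
  std::cout << isExtendedInputMode({1, 128}, {128}) << '\n'; // 1: extended
  std::cout << isExtendedInputMode({1, 128}, {1}) << '\n';   // 0: standard
}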

packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h

Lines changed: 9 additions & 5 deletions
@@ -3,16 +3,16 @@
 #include <memory>
 #include <string>
 
-#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
 #include <ReactCommon/CallInvoker.h>
 #include <jsi/jsi.h>
+#include <rnexecutorch/models/BaseModel.h>
 #include <runner/runner.h>
 
 namespace rnexecutorch {
 namespace models::llm {
 using namespace facebook;
 
-class LLM {
+class LLM : public BaseModel {
 public:
   explicit LLM(const std::string &modelSource,
                const std::string &tokenizerSource,
@@ -27,12 +27,16 @@ class LLM {
   void setTimeInterval(size_t timeInterval);
 
 private:
-  size_t memorySizeLowerBound;
   std::unique_ptr<example::Runner> runner;
-  std::shared_ptr<react::CallInvoker> callInvoker;
+
+  // A typical input for parallel processing in an exported LLM model consists
+  // of two tensors of shapes [1, N] and [1], where N is the number of tokens.
+  // However, some exported models require inputs of shapes [1, N] and [N],
+  // which must be flagged before using the LLM runner.
+  bool extended_input_mode_ = false;
 };
 } // namespace models::llm
 
 REGISTER_CONSTRUCTOR(models::llm::LLM, std::string, std::string,
                      std::shared_ptr<react::CallInvoker>);
-} // namespace rnexecutorch
+} // namespace rnexecutorch
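To make the comment concrete: for a four-token prompt starting at position 0, the two layouts differ only in the positions input. A tiny illustration (token values are made up):

#include <cstdint>
#include <vector>

int main() {
  // Tokens are always a [1, N] batch; here N = 4.
  std::vector<int64_t> tokens = {101, 42, 7, 13};

  // Standard layout: a single [1] start position for the whole prompt.
  std::vector<int64_t> standard_positions = {0};

  // Extended layout: one position per token, i.e. an [N] tensor.
  std::vector<int64_t> extended_positions = {0, 1, 2, 3};
  return 0;
}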

packages/react-native-executorch/common/runner/runner.cpp

Lines changed: 19 additions & 22 deletions
@@ -47,27 +47,19 @@ static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 } // namespace
 
-Runner::Runner(const std::string &model_path, const std::string &tokenizer_path,
-               const float temperature,
+Runner::Runner(Module *module, const std::string &model_path,
+               const std::string &tokenizer_path,
+               const bool extended_input_mode, const float temperature,
                std::optional<const std::string> data_path)
-    // NOTE: we observed ~2x loading performance increase on iPhone 15
-    // and a ~5% improvement on Galaxy S22 by switching to
-    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : temperature_(temperature), tokenizer_path_(tokenizer_path),
-      metadata_({
-          {kEnableDynamicShape, false},
-          {kMaxSeqLen, 128},
-          {kMaxContextLen, 128},
-          {kUseKVCache, true},
-          {kUseSDPAWithKVCache, false},
-      }) {
-  if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(model_path, data_path.value(),
-                                       Module::LoadMode::File);
-  } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
-  ET_LOG(Info, "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
+    : module_(module), temperature_(temperature),
+      tokenizer_path_(tokenizer_path), metadata_({
+          {kEnableDynamicShape, false},
+          {kMaxSeqLen, 128},
+          {kMaxContextLen, 128},
+          {kUseKVCache, true},
+          {kUseSDPAWithKVCache, false},
+      }) {
+  ET_LOG(Info, "Creating LLM runner: model_path=%s, tokenizer_path=%s",
          model_path.c_str(), tokenizer_path.c_str());
 }
 
@@ -116,7 +108,7 @@ Error Runner::load() {
     }
   }
   text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache), metadata_.at(kVocabSize),
+      module_, metadata_.at(kUseKVCache), metadata_.at(kVocabSize),
       temperature_);
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(), metadata_.at(kUseKVCache),
@@ -206,7 +198,8 @@ Error Runner::generate(const std::string &prompt,
     wrapped_callback(prompt);
   }
   int64_t pos = 0;
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens_uint64, pos);
+  auto prefill_res = text_prefiller_->prefill(prompt_tokens_uint64, pos,
+                                              extend_position_input_);
   stats_.first_token_ms = llm::time_in_ms();
   stats_.prompt_eval_end_ms = llm::time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
@@ -269,6 +262,10 @@ void Runner::stop() {
   }
 }
 
+void Runner::set_extended_input_mode(bool extend_position_input) {
+  extend_position_input_ = extend_position_input;
+}
+
 void Runner::set_count_interval(size_t count_interval) {
   text_token_generator_->set_count_interval(count_interval);
 }

packages/react-native-executorch/common/runner/runner.h

Lines changed: 14 additions & 7 deletions
@@ -29,10 +29,13 @@ namespace example {
 
 class Runner : public executorch::extension::llm::IRunner {
 public:
-  explicit Runner(const std::string &model_path,
-                  const std::string &tokenizer_path,
-                  const float temperature = 0.8f,
-                  std::optional<const std::string> data_path = std::nullopt);
+  explicit Runner(
+      ::executorch::extension::Module *module,
+      const std::string &model_path, // TODO: consider removing this arg since
+                                     // it is only used for debug purposes
+      const std::string &tokenizer_path, const bool extended_input_mode = false,
+      const float temperature = 0.8f,
+      std::optional<const std::string> data_path = std::nullopt);
 
   bool is_loaded() const;
   ::executorch::runtime::Error load();
@@ -43,6 +46,7 @@ class Runner : public executorch::extension::llm::IRunner {
                       stats_callback = {},
                   bool echo = true, bool warming = false);
   ::executorch::runtime::Error warmup(const std::string &prompt);
+  void set_extended_input_mode(bool extend_position_input);
   void set_count_interval(size_t count_interval);
   void set_time_interval(size_t time_interval);
   void stop();
@@ -51,10 +55,13 @@ class Runner : public executorch::extension::llm::IRunner {
 
 private:
   float temperature_;
+  bool extend_position_input_{false};
   bool shouldStop_{false};
 
-  // model
-  std::unique_ptr<::executorch::extension::Module> module_;
+  // Main model
+  ::executorch::extension::Module *module_;
+
+  // Subcomponents
   std::string tokenizer_path_;
   std::unique_ptr<tokenizers::Tokenizer> tokenizer_;
   std::unordered_map<std::string, int64_t> metadata_;
@@ -65,4 +72,4 @@ class Runner : public executorch::extension::llm::IRunner {
       text_token_generator_;
 };
 
-} // namespace example
+} // namespace example
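Since the Runner now borrows the Module instead of owning it, the caller (in this repository, LLM via BaseModel) must keep the Module alive for the Runner's lifetime. A minimal usage sketch with placeholder paths, assuming runner.h from this diff and the standard ExecuTorch Module header:

#include <executorch/extension/module/module.h>
#include <runner/runner.h>

#include <memory>

int main() {
  using ::executorch::extension::Module;

  // Owned here; the Runner only stores a raw pointer, so `module`
  // must outlive `runner`.
  auto module = std::make_unique<Module>("/path/to/model.pte",
                                         Module::LoadMode::File);

  example::Runner runner(module.get(), "/path/to/model.pte",
                         "/path/to/tokenizer.bin");

  // Only for models exported with [1, N] tokens and [N] positions; in this
  // repo the LLM constructor derives this automatically from the exported
  // input shapes.
  runner.set_extended_input_mode(true);

  return runner.load() == ::executorch::runtime::Error::Ok ? 0 : 1;
}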

packages/react-native-executorch/common/runner/text_prefiller.cpp

Lines changed: 18 additions & 4 deletions
@@ -10,6 +10,7 @@
 // LLM.
 
 #include "text_prefiller.h"
+#include <numeric>
 
 namespace executorch {
 namespace extension {
@@ -21,8 +22,8 @@ TextPrefiller::TextPrefiller(TextDecoderRunner *text_decoder_runner,
       enable_parallel_prefill_(enable_parallel_prefill) {}
 
 ::executorch::runtime::Result<uint64_t>
-TextPrefiller::prefill(std::vector<uint64_t> &prompt_tokens,
-                       int64_t &start_pos) {
+TextPrefiller::prefill(std::vector<uint64_t> &prompt_tokens, int64_t &start_pos,
+                       bool extend_position_input) {
   ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null");
   if (!text_decoder_runner_->is_method_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
@@ -38,8 +39,21 @@ TextPrefiller::prefill(std::vector<uint64_t> &prompt_tokens,
   auto tokens = from_blob(prompt_tokens.data(), {1, num_prompt_tokens},
                           executorch::aten::ScalarType::Long);
 
-  auto start_pos_tensor =
-      from_blob(&start_pos, {1}, executorch::aten::ScalarType::Long);
+  std::unique_ptr<std::vector<int64_t>> extended_start_pos = nullptr;
+  if (extend_position_input) {
+    extended_start_pos =
+        std::make_unique<std::vector<int64_t>>(num_prompt_tokens);
+
+    // Fill the starting positions with values from [start_pos, start_pos +
+    // num_prompt_tokens)
+    std::iota(extended_start_pos->begin(), extended_start_pos->end(),
+              start_pos);
+  }
+
+  auto start_pos_tensor = from_blob(
+      extend_position_input ? extended_start_pos->data() : &start_pos,
+      {extend_position_input ? num_prompt_tokens : 1},
+      executorch::aten::ScalarType::Long);
 
   auto outputs_res = text_decoder_runner_->step(tokens, start_pos_tensor);
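For the extended layout, the prefiller now materializes one position per prompt token with std::iota. A tiny worked example of the values it produces (start position and token count chosen for illustration):

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Four prompt tokens starting at position 5: the standard layout would
  // pass the single scalar {5}; the extended layout passes one position
  // per token.
  int64_t start_pos = 5;
  std::vector<int64_t> extended_start_pos(4);
  std::iota(extended_start_pos.begin(), extended_start_pos.end(), start_pos);

  for (int64_t p : extended_start_pos) {
    std::cout << p << ' '; // prints: 5 6 7 8
  }
  std::cout << '\n';
}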
4559

packages/react-native-executorch/common/runner/text_prefiller.h

Lines changed: 2 additions & 1 deletion
@@ -30,7 +30,8 @@ class TextPrefiller {
    * @return The next token of the LLM Module after prefill.
    */
   ::executorch::runtime::Result<uint64_t>
-  prefill(std::vector<uint64_t> &prompt_tokens, int64_t &start_pos);
+  prefill(std::vector<uint64_t> &prompt_tokens, int64_t &start_pos,
+          bool extend_position_input = false);
 
 private:
   TextDecoderRunner *text_decoder_runner_;
