-
Notifications
You must be signed in to change notification settings - Fork 75
feat: port LLMs to C++ #415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 30 commits
e299c3f
45938f3
5cda73f
0383fcc
6a217bf
d985c25
189440a
6e6703d
6fdd91b
bc83f01
158265f
b00c5f0
6fdf271
9264242
3740b5b
2c5bd57
23d61ff
ec40ff8
4f2810e
64785cc
c3a7d17
fcda895
50b19cc
d80e855
5117e65
c3b1a84
afb1912
2cf6c6a
2acd171
ffe6387
3826a29
cf72d6a
ec82b0e
f23587b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
#include "LLM.h"

#include <filesystem>
#include <stdexcept>
#include <system_error>

#include <executorch/extension/tensor/tensor.h>
#include <rnexecutorch/Log.h>
|
|
||
| namespace rnexecutorch { | ||
| using namespace facebook; | ||
| using executorch::extension::TensorPtr; | ||
| using executorch::runtime::Error; | ||
|
|
||
| LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource, | ||
|
chmjkb marked this conversation as resolved.
|
||
| std::shared_ptr<react::CallInvoker> callInvoker) | ||
| : runner(std::make_unique<example::Runner>(modelSource, tokenizerSource)), | ||
| callInvoker(callInvoker) { | ||
| auto loadResult = runner->load(); | ||
| if (loadResult != Error::Ok) { | ||
| throw std::runtime_error("Failed to load LLM runner"); | ||
| } | ||
|
chmjkb marked this conversation as resolved.
|
||
| memorySizeLowerBound = | ||
| std::filesystem::file_size(std::filesystem::path(modelSource)) + | ||
| std::filesystem::file_size(std::filesystem::path(tokenizerSource)); | ||
| } | ||
|
|
||
| void LLM::generate(std::string input, std::shared_ptr<jsi::Function> callback) { | ||
| if (!runner || !runner->is_loaded()) { | ||
| throw std::runtime_error("Runner is not loaded"); | ||
| } | ||
|
|
||
| // Create a native callback that will invoke the JS callback on the JS thread | ||
| auto nativeCallback = [this, callback](const std::string &token) { | ||
| callInvoker->invokeAsync([callback, token](jsi::Runtime &runtime) { | ||
| callback->call(runtime, jsi::String::createFromUtf8(runtime, token)); | ||
| }); | ||
| }; | ||
|
|
||
| auto error = runner->generate(input, nativeCallback, {}, false); | ||
| if (error != executorch::runtime::Error::Ok) { | ||
| throw std::runtime_error("Failed to generate text, error: " + | ||
| std::to_string(static_cast<int>(error))); | ||
|
chmjkb marked this conversation as resolved.
Outdated
|
||
| } | ||
| } | ||
|
|
||
| void LLM::interrupt() { | ||
| if (!runner || !runner->is_loaded()) { | ||
| throw std::runtime_error("Can't interrupt a model that's not loaded!"); | ||
| } | ||
| runner->stop(); | ||
| } | ||
|
|
||
| std::size_t LLM::getMemoryLowerBound() { return memorySizeLowerBound; } | ||
|
|
||
| void LLM::unload() { runner.reset(nullptr); } | ||
|
|
||
| } // namespace rnexecutorch | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| #pragma once | ||
|
|
||
| #include <memory> | ||
| #include <string> | ||
|
|
||
| #include <ReactCommon/CallInvoker.h> | ||
| #include <jsi/jsi.h> | ||
| #include <runner/runner.h> | ||
|
|
||
| namespace rnexecutorch { | ||
| using namespace facebook; | ||
|
|
||
// Thin C++ wrapper exposing the ExecuTorch example LLM runner to React
// Native. The constructor loads the model; generated tokens are streamed
// back to JavaScript through the CallInvoker.
class LLM {
public:
  // Loads model + tokenizer from the given file paths.
  // Throws std::runtime_error when loading fails.
  LLM(const std::string &modelSource, const std::string &tokenizerSource,
      std::shared_ptr<react::CallInvoker> callInvoker);

  // Generates tokens for `input`, invoking `callback` once per produced
  // token on the JS thread. Throws if the runner is not loaded.
  void generate(std::string input, std::shared_ptr<jsi::Function> callback);
  // Stops an in-flight generation; throws if the model is not loaded.
  void interrupt();
  // Destroys the runner, freeing model memory.
  void unload();
  // Combined on-disk size (bytes) of the model + tokenizer files,
  // measured once at construction.
  std::size_t getMemoryLowerBound();

private:
  size_t memorySizeLowerBound;
  std::unique_ptr<example::Runner> runner;
  std::shared_ptr<react::CallInvoker> callInvoker;
};
| } // namespace rnexecutorch |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,6 +32,7 @@ | |
| * SOFTWARE. | ||
| */ | ||
|
|
||
| // #include <executorch/extension/llm/sampler/sampler.h> | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not exactly sure why that line was here in the first place. Below you can see a line: #include "sampler.h"Which literally includes the same thing. Also, we can't access |
||
| #include "sampler.h" | ||
| #include <algorithm> | ||
|
|
||
|
|
@@ -184,9 +185,10 @@ template <typename T> int32_t Sampler::sample(T *logits) { | |
| } | ||
|
|
||
| template int32_t Sampler::sample<float>(float *logits); | ||
| template int32_t Sampler::sample<exec_aten::Half>(exec_aten::Half *logits); | ||
| template int32_t | ||
| Sampler::sample<exec_aten::BFloat16>(exec_aten::BFloat16 *logits); | ||
| Sampler::sample<executorch::aten::Half>(executorch::aten::Half *logits); | ||
| template int32_t | ||
| Sampler::sample<executorch::aten::BFloat16>(executorch::aten::BFloat16 *logits); | ||
|
|
||
| } // namespace llm | ||
| } // namespace extension | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.