diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp
index b82d24316e..3122867bde 100644
--- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp
@@ -1,5 +1,6 @@
 #include "RnExecutorchInstaller.h"
 
+#include
 #include
 #include
 #include
@@ -42,6 +43,11 @@ void RnExecutorchInstaller::injectJSIBindings(
       *jsiRuntime, "loadExecutorchModule",
       RnExecutorchInstaller::loadModel<BaseModel>(
           jsiRuntime, jsCallInvoker, "loadExecutorchModule"));
+
+  jsiRuntime->global().setProperty(
+      *jsiRuntime, "loadTokenizerModule",
+      RnExecutorchInstaller::loadModel<TokenizerModule>(
+          jsiRuntime, jsCallInvoker, "loadTokenizerModule"));
 }
 } // namespace rnexecutorch
\ No newline at end of file
diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h
index 42d0a86bdc..53e0401262 100644
--- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h
+++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h
@@ -28,6 +28,8 @@
 REGISTER_CONSTRUCTOR(ObjectDetection, std::string,
                      std::shared_ptr<react::CallInvoker>);
 REGISTER_CONSTRUCTOR(BaseModel, std::string,
                      std::shared_ptr<react::CallInvoker>);
+REGISTER_CONSTRUCTOR(TokenizerModule, std::string,
+                     std::shared_ptr<react::CallInvoker>);
 
 using namespace facebook;
diff --git a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp
new file mode 100644
index 0000000000..42664b4384
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp
@@ -0,0 +1,52 @@
+#include "TokenizerModule.h"
+#include
+#include
+#include
+#include
+
+namespace rnexecutorch {
+using namespace facebook;
+
+TokenizerModule::TokenizerModule(
+    std::string source, std::shared_ptr<react::CallInvoker> callInvoker)
+    : memorySizeLowerBound(std::filesystem::file_size(source)),
+      tokenizer(tokenizers::Tokenizer::FromBlobJSON(
+          fileutils::loadBytesFromFile(source))) {}
+
+void TokenizerModule::ensureTokenizerLoaded(
+    const std::string &methodName) const {
+  if (!tokenizer) {
+    throw std::runtime_error(
+        methodName + " function was called on an uninitialized tokenizer!");
+  }
+}
+
+std::vector<int32_t> TokenizerModule::encode(std::string s) const {
+  ensureTokenizerLoaded("encode");
+  return tokenizer->Encode(s);
+}
+
+std::string TokenizerModule::decode(std::vector<int32_t> vec,
+                                    bool skipSpecialTokens) const {
+  ensureTokenizerLoaded("decode");
+  return tokenizer->Decode(vec, skipSpecialTokens);
+}
+
+size_t TokenizerModule::getVocabSize() const {
+  ensureTokenizerLoaded("getVocabSize");
+  return tokenizer->GetVocabSize();
+}
+
+std::string TokenizerModule::idToToken(int32_t tokenId) const {
+  ensureTokenizerLoaded("idToToken");
+  return tokenizer->IdToToken(tokenId);
+}
+
+int32_t TokenizerModule::tokenToId(std::string token) const {
+  ensureTokenizerLoaded("tokenToId");
+  return tokenizer->TokenToId(token);
+}
+std::size_t TokenizerModule::getMemoryLowerBound() const noexcept {
+  return memorySizeLowerBound;
+}
+} // namespace rnexecutorch
\ No newline at end of file
diff --git a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h
new file mode 100644
index 0000000000..284ef403d6
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include
+#include
+#include
+
+namespace rnexecutorch {
+using namespace facebook;
+
+class TokenizerModule {
+public:
+  explicit TokenizerModule(std::string source,
+                           std::shared_ptr<react::CallInvoker> callInvoker);
+  std::vector<int32_t> encode(std::string s) const;
+  std::string decode(std::vector<int32_t> vec, bool skipSpecialTokens) const;
+  std::string idToToken(int32_t tokenId) const;
+  int32_t tokenToId(std::string token) const;
+  std::size_t getVocabSize() const;
+  std::size_t getMemoryLowerBound() const noexcept;
+
+private:
+  void ensureTokenizerLoaded(const std::string &methodName) const;
+  std::unique_ptr<tokenizers::Tokenizer> tokenizer;
+  const std::size_t memorySizeLowerBound{0};
+};
+} // namespace rnexecutorch
\ No newline at end of file
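Taken together, the installer binding and the `TokenizerModule` class surface the tokenizer to JavaScript as a host object returned by `global.loadTokenizerModule`. The following is a minimal sketch of what that surface is expected to look like from the JS side; the file path is a placeholder, and the promise-returning signatures follow from `promiseHostFunction` in `ModelHostObject.h`.

```ts
// Sketch only: assumes a tokenizer.json that has already been downloaded to a
// local path (the path below is hypothetical). Every method registered on the
// host object is exposed as a promise-returning function.
async function tokenizerRoundTrip() {
  const tokenizer = global.loadTokenizerModule('/path/to/tokenizer.json');

  const ids: number[] = await tokenizer.encode('Hello world');
  const text: string = await tokenizer.decode(ids, true);
  const vocabSize: number = await tokenizer.getVocabSize();

  console.log({ ids, text, vocabSize });
}
```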
diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/FileUtils.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/FileUtils.h
index 98035b07b9..d8fa83145f 100644
--- a/packages/react-native-executorch/common/rnexecutorch/data_processing/FileUtils.h
+++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/FileUtils.h
@@ -1,6 +1,8 @@
 #pragma once
 
 #include
+#include
+#include
 #include
 
 namespace rnexecutorch::fileutils {
@@ -11,4 +13,18 @@ inline std::string getTimeID() {
           .count());
 }
 
+inline std::string loadBytesFromFile(const std::string &path) {
+  std::ifstream fs(path, std::ios::in | std::ios::binary);
+  if (fs.fail()) {
+    throw std::runtime_error("Failed to open tokenizer file");
+  }
+  std::string data;
+  fs.seekg(0, std::ios::end);
+  size_t size = static_cast<size_t>(fs.tellg());
+  fs.seekg(0, std::ios::beg);
+  data.resize(size);
+  fs.read(data.data(), size);
+  return data;
+};
+
 } // namespace rnexecutorch::fileutils
diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
index f723db5ce9..d04eed18f1 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
@@ -41,6 +41,21 @@ inline std::string getValue(const jsi::Value &val, jsi::Runtime &runtime) {
   return val.getString(runtime).utf8(runtime);
 }
 
+template <>
+inline std::vector<int32_t>
+getValue<std::vector<int32_t>>(const jsi::Value &val, jsi::Runtime &runtime) {
+  jsi::Array array = val.asObject(runtime).asArray(runtime);
+  size_t length = array.size(runtime);
+  std::vector<int32_t> result;
+  result.reserve(length);
+
+  for (size_t i = 0; i < length; ++i) {
+    jsi::Value element = array.getValueAtIndex(runtime, i);
+    result.push_back(getValue<int32_t>(element, runtime));
+  }
+  return result;
+}
+
 template <>
 inline JSTensorViewIn getValue(const jsi::Value &val,
                                jsi::Runtime &runtime) {
@@ -182,20 +197,18 @@ getJsiValue(const std::vector> &vec,
             jsi::Runtime &runtime) {
   return jsi::Value(runtime, array);
 }
 
-inline jsi::Value
-getJsiValue(const std::vector<std::shared_ptr<JSTensorViewOut>> &vec,
-            jsi::Runtime &runtime) {
+inline jsi::Value getJsiValue(const std::vector<JSTensorViewOut> &vec,
+                              jsi::Runtime &runtime) {
   jsi::Array array(runtime, vec.size());
   for (size_t i = 0; i < vec.size(); i++) {
     jsi::Object tensorObj(runtime);
 
-    tensorObj.setProperty(runtime, "sizes",
-                          getJsiValue(vec[i]->sizes, runtime));
+    tensorObj.setProperty(runtime, "sizes", getJsiValue(vec[i].sizes, runtime));
 
     tensorObj.setProperty(runtime, "scalarType",
-                          jsi::Value(static_cast<int>(vec[i]->scalarType)));
+                          jsi::Value(static_cast<int>(vec[i].scalarType)));
 
-    jsi::ArrayBuffer arrayBuffer(runtime, vec[i]->dataPtr);
+    jsi::ArrayBuffer arrayBuffer(runtime, vec[i].dataPtr);
 
     tensorObj.setProperty(runtime, "dataPtr", arrayBuffer);
 
     array.setValueAtIndex(runtime, i, tensorObj);
diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
index 95f5e01071..fa3ec3bab7 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h
@@ -7,7 +7,7 @@
 
 #include
 
-#include
+#include
 #include
 #include
 #include
@@ -45,6 +45,31 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
                                        promiseHostFunction<&Model::generate>,
                                        "generate"));
     }
+
+    if constexpr (meta::HasEncode<Model>) {
+      addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject,
+                                       promiseHostFunction<&Model::encode>,
+                                       "encode"));
+    }
+
+    if constexpr (meta::SameAs<Model, TokenizerModule>) {
+      addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject,
+                                       promiseHostFunction<&Model::encode>,
+                                       "encode"));
+
+      addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject,
+                                       promiseHostFunction<&Model::decode>,
+                                       "decode"));
+      addFunctions(JSI_EXPORT_FUNCTION(
+          ModelHostObject, promiseHostFunction<&Model::getVocabSize>,
+          "getVocabSize"));
+      addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject,
+                                       promiseHostFunction<&Model::idToToken>,
+                                       "idToToken"));
+      addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject,
+                                       promiseHostFunction<&Model::tokenToId>,
+                                       "tokenToId"));
+    }
   }
 
   // A generic host function that resolves a promise with a result of a
@@ -76,8 +101,8 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
     std::thread([this, promise,
                  argsConverted = std::move(argsConverted)]() {
       try {
-        auto result =
-            std::apply(std::bind_front(FnPtr, model), argsConverted);
+        auto result = std::apply(std::bind_front(FnPtr, model),
+                                 std::move(argsConverted));
         // The result is copied. It should either be quickly copiable,
         // or passed with a shared_ptr.
         callInvoker->invokeAsync([promise,
diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h
index ff7a5fa53c..622106dd51 100644
--- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h
+++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h
@@ -14,6 +14,11 @@ constexpr std::size_t getArgumentCount(R (Model::*f)(Types...)) {
   return sizeof...(Types);
 }
 
+template <typename R, typename Model, typename... Types>
+constexpr std::size_t getArgumentCount(R (Model::*f)(Types...) const) {
+  return sizeof...(Types);
+}
+
 template
 std::tuple fillTupleFromArgs(std::index_sequence, const jsi::Value *args,
@@ -34,4 +39,12 @@ std::tuple createArgsTupleFromJsi(R (Model::*f)(Types...),
                                   const jsi::Value *args,
                                   jsi::Runtime &runtime) {
   return fillTupleFromArgs(std::index_sequence_for<Types...>{}, args, runtime);
 }
+
+template <typename R, typename Model, typename... Types>
+std::tuple createArgsTupleFromJsi(R (Model::*f)(Types...) const,
+                                  const jsi::Value *args,
+                                  jsi::Runtime &runtime) {
+  return fillTupleFromArgs(std::index_sequence_for<Types...>{}, args,
+                           runtime);
+}
 } // namespace rnexecutorch::meta
\ No newline at end of file
diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h
index ae5111ba37..253a532388 100644
--- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h
+++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h
@@ -8,11 +8,24 @@ namespace rnexecutorch::meta {
 template <typename T, typename U>
 concept DerivedFromOrSameAs = std::is_base_of_v;
 
+template <typename T, typename U>
+concept SameAs = std::is_same_v<T, U>;
+
 template <typename T>
 concept HasGenerate = requires(T t) {
   { &T::generate };
 };
 
+template <typename T>
+concept HasEncode = requires(T t) {
+  { &T::encode };
+};
+
+template <typename T>
+concept HasDecode = requires(T t) {
+  { &T::decode };
+};
+
 template <typename T>
 concept IsNumeric = std::is_arithmetic_v<T>;
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
index d4f846518d..64e223ec11 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
@@ -83,7 +83,7 @@ BaseModel::getAllInputShapes(std::string methodName) {
   return output;
 }
 
-std::vector<std::shared_ptr<JSTensorViewOut>>
+std::vector<JSTensorViewOut>
 BaseModel::forwardJS(const std::vector<JSTensorViewIn> tensorViewVec) {
   if (!module) {
     throw std::runtime_error("Model not loaded: Cannot perform forward pass");
   }
@@ -114,7 +114,7 @@ BaseModel::forwardJS(const std::vector<JSTensorViewIn> tensorViewVec) {
   }
 
   auto &outputs = result.get();
-  std::vector<std::shared_ptr<JSTensorViewOut>> output;
+  std::vector<JSTensorViewOut> output;
   output.reserve(outputs.size());
 
   // Convert ET outputs to a vector of JSTensorViewOut which are later
@@ -125,8 +125,7 @@
     size_t bufferSize = outputTensor.numel() * outputTensor.element_size();
     auto buffer = std::make_shared(bufferSize);
     std::memcpy(buffer->data(), outputTensor.const_data_ptr(), bufferSize);
-    auto jsTensor = std::make_shared<JSTensorViewOut>(
-        sizes, outputTensor.scalar_type(), buffer);
+    auto jsTensor = JSTensorViewOut(sizes, outputTensor.scalar_type(), buffer);
     output.emplace_back(jsTensor);
   }
   return output;
 }
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
index 79e426f199..d8b7add02f 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
@@ -22,7 +22,7 @@ class BaseModel {
   std::vector getInputShape(std::string method_name, int index);
   std::vector>
   getAllInputShapes(std::string methodName = "forward");
-  std::vector<std::shared_ptr<JSTensorViewOut>>
+  std::vector<JSTensorViewOut>
   forwardJS(std::vector<JSTensorViewIn> tensorViewVec);
 
 protected:
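With `forwardJS` now returning `JSTensorViewOut` by value, the JS-visible shape of each output element is unchanged. Based on the `setProperty` calls in `JsiConversions.h`, each element still looks roughly like the interface below; the TypeScript type is illustrative only and is not exported by the package.

```ts
// Assumed JS-side shape of one element of the array produced from forwardJS,
// derived from the properties set in JsiConversions.h.
interface ForwardOutputTensor {
  sizes: number[]; // tensor dimensions
  scalarType: number; // ExecuTorch scalar type enum value, passed as an int
  dataPtr: ArrayBuffer; // copy of the output tensor's bytes
}
```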
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp
index 52e3bbece2..a2c23330c4 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp
@@ -1,10 +1,7 @@
 #include "StyleTransfer.h"
 
-#include
 #include
-#include
-
 #include
 #include
@@ -42,7 +39,7 @@ std::string StyleTransfer::postprocess(const Tensor &tensor,
 }
 
 std::string StyleTransfer::generate(std::string imageSource) {
-  auto [inputTensor, originalSize] = 
+  auto [inputTensor, originalSize] =
       imageprocessing::readImageToTensor(imageSource, getAllInputShapes()[0]);
 
   auto forwardResult = BaseModel::forward(inputTensor);
diff --git a/packages/react-native-executorch/src/controllers/SpeechToTextController.ts b/packages/react-native-executorch/src/controllers/SpeechToTextController.ts
index c34d3759dd..2d2217a3d0 100644
--- a/packages/react-native-executorch/src/controllers/SpeechToTextController.ts
+++ b/packages/react-native-executorch/src/controllers/SpeechToTextController.ts
@@ -22,6 +22,7 @@ export class SpeechToTextController {
   public isReady = false;
   public isGenerating = false;
 
+  private tokenizerModule: TokenizerModule;
   private overlapSeconds!: number;
   private windowSize!: number;
   private chunks: number[][] = [];
@@ -60,6 +61,7 @@
     windowSize?: number;
     streamingConfig?: keyof typeof MODES;
   }) {
+    this.tokenizerModule = new TokenizerModule();
     this.decodedTranscribeCallback = async (seq) =>
       transcribeCallback(await this.tokenIdsToText(seq));
     this.modelDownloadProgressCallback = modelDownloadProgressCallback;
@@ -97,10 +99,9 @@
     this.config = MODEL_CONFIGS[modelName];
 
     try {
-      await TokenizerModule.load(
+      await this.tokenizerModule.load(
         tokenizerSource || this.config.tokenizer.source
       );
-
       [encoderSource, decoderSource] =
         await ResourceFetcher.fetchMultipleResources(
           this.modelDownloadProgressCallback,
@@ -196,10 +197,13 @@
       return [this.config.tokenizer.bos];
     }
     // FIXME: I should use .getTokenId for the BOS as well, should remove it from config
-    const langTokenId = await TokenizerModule.tokenToId(`<|${audioLanguage}|>`);
-    const transcribeTokenId = await TokenizerModule.tokenToId('<|transcribe|>');
+    const langTokenId = await this.tokenizerModule.tokenToId(
+      `<|${audioLanguage}|>`
+    );
+    const transcribeTokenId =
+      await this.tokenizerModule.tokenToId('<|transcribe|>');
     const noTimestampsTokenId =
-      await TokenizerModule.tokenToId('<|notimestamps|>');
+      await this.tokenizerModule.tokenToId('<|notimestamps|>');
     const startingTokenIds = [
       this.config.tokenizer.bos,
       langTokenId,
@@ -294,7 +298,7 @@
   private async tokenIdsToText(tokenIds: number[]): Promise<string> {
     try {
-      return TokenizerModule.decode(tokenIds, true);
+      return await this.tokenizerModule.decode(tokenIds, true);
     } catch (e) {
       this.onErrorCallback(
         new Error(`An error has occurred when decoding the token ids: ${e}`)
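The controller now owns its own `TokenizerModule` instance instead of going through the former static API, so Whisper's special-token ids are resolved at runtime via `tokenToId`. A rough standalone sketch of that pattern follows; the import of `TokenizerModule` is omitted, and the tokenizer URL and the `<|en|>` language token are placeholders.

```ts
// Hypothetical standalone version of what getStartingTokenIds does with the
// per-instance tokenizer; the source URL and '<|en|>' are placeholders.
async function resolveWhisperSpecialTokens() {
  const tokenizer = new TokenizerModule();
  await tokenizer.load('https://example.com/whisper-tokenizer.json');

  const langTokenId = await tokenizer.tokenToId('<|en|>');
  const transcribeTokenId = await tokenizer.tokenToId('<|transcribe|>');
  const noTimestampsTokenId = await tokenizer.tokenToId('<|notimestamps|>');
  return [langTokenId, transcribeTokenId, noTimestampsTokenId];
}
```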
diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts
index 29b9734c98..ee66acbc2f 100644
--- a/packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts
+++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts
@@ -1,4 +1,4 @@
-import { useEffect, useState } from 'react';
+import { useEffect, useRef, useState } from 'react';
 import { TokenizerModule } from '../../modules/natural_language_processing/TokenizerModule';
 import { ResourceSource } from '../../types/common';
 import { ETError, getError } from '../../Error';
@@ -14,13 +14,14 @@ export const useTokenizer = ({
   const [isReady, setIsReady] = useState(false);
   const [isGenerating, setIsGenerating] = useState(false);
   const [downloadProgress, setDownloadProgress] = useState(0);
+  const tokenizerModuleRef = useRef<TokenizerModule | null>(null);
 
   useEffect(() => {
     const loadModule = async () => {
       try {
         setIsReady(false);
-        TokenizerModule.onDownloadProgress(setDownloadProgress);
-        await TokenizerModule.load(tokenizerSource);
+        tokenizerModuleRef.current = new TokenizerModule();
+        await tokenizerModuleRef.current.load(tokenizerSource, setDownloadProgress);
         setIsReady(true);
       } catch (err) {
         setError((err as Error).message);
@@ -32,15 +33,14 @@
   }, [tokenizerSource, preventLoad]);
 
   const stateWrapper = <T extends (...args: any[]) => Promise<any>>(fn: T) => {
-    const boundFn = fn.bind(TokenizerModule);
-
     return async (...args: Parameters<T>): Promise<Awaited<ReturnType<T>>> => {
-      if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
+      if (!isReady || !tokenizerModuleRef.current)
+        throw new Error(getError(ETError.ModuleNotLoaded));
       if (isGenerating) throw new Error(getError(ETError.ModelGenerating));
       setIsGenerating(true);
       try {
-        return await boundFn(...args);
+        return await fn.apply(tokenizerModuleRef.current, args);
       } finally {
         setIsGenerating(false);
       }
     };
@@ -52,10 +52,10 @@
     isReady,
     isGenerating,
     downloadProgress,
-    decode: stateWrapper(TokenizerModule.decode),
-    encode: stateWrapper(TokenizerModule.encode),
-    getVocabSize: stateWrapper(TokenizerModule.getVocabSize),
-    idToToken: stateWrapper(TokenizerModule.idToToken),
-    tokenToId: stateWrapper(TokenizerModule.tokenToId),
+    decode: stateWrapper(TokenizerModule.prototype.decode),
+    encode: stateWrapper(TokenizerModule.prototype.encode),
+    getVocabSize: stateWrapper(TokenizerModule.prototype.getVocabSize),
+    idToToken: stateWrapper(TokenizerModule.prototype.idToToken),
+    tokenToId: stateWrapper(TokenizerModule.prototype.tokenToId),
   };
 };
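For completeness, a minimal usage sketch of the reworked hook; the package import path and the tokenizer source are assumptions. Every returned method is async and throws if the module is not ready yet or while another call is still in flight.

```ts
// Minimal sketch: logs the token count of a piece of text once the tokenizer
// is ready. The import path and tokenizer source URL are assumed.
import { useEffect } from 'react';
import { useTokenizer } from 'react-native-executorch';

export function useTokenCount(text: string) {
  const { isReady, encode } = useTokenizer({
    tokenizerSource: 'https://example.com/tokenizer.json',
  });

  useEffect(() => {
    if (!isReady) return;
    encode(text).then((ids) => console.log('token count:', ids.length));
  }, [isReady, text]);
}
```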
diff --git a/packages/react-native-executorch/src/index.tsx b/packages/react-native-executorch/src/index.tsx
index cb6d2ca24b..0af19b0c3e 100644
--- a/packages/react-native-executorch/src/index.tsx
+++ b/packages/react-native-executorch/src/index.tsx
@@ -9,13 +9,16 @@ declare global {
   var loadClassification: (source: string) => any;
   var loadObjectDetection: (source: string) => any;
   var loadExecutorchModule: (source: string) => any;
+  var loadTokenizerModule: (source: string) => any;
 }
 // eslint-disable no-var
 
 if (
   global.loadStyleTransfer == null ||
   global.loadImageSegmentation == null ||
   global.loadExecutorchModule == null ||
-  global.loadClassification == null
+  global.loadClassification == null ||
+  global.loadObjectDetection == null ||
+  global.loadTokenizerModule == null
 ) {
   if (!ETInstallerNativeModule) {
     throw new Error(
diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts
index 589a147d78..370b5b556a 100644
--- a/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts
+++ b/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts
@@ -1,34 +1,37 @@
-import { TokenizerNativeModule } from '../../native/RnExecutorchModules';
 import { ResourceSource } from '../../types/common';
-import { BaseModule } from '../BaseModule';
+import { ResourceFetcher } from '../../utils/ResourceFetcher';
 
-export class TokenizerModule extends BaseModule {
-  protected static override nativeModule = TokenizerNativeModule;
+export class TokenizerModule {
+  nativeModule: any;
 
-  static override async load(tokenizerSource: ResourceSource) {
-    await super.load([tokenizerSource]);
+  async load(
+    modelSource: ResourceSource,
+    onDownloadProgressCallback: (_: number) => void = () => {}
+  ): Promise<void> {
+    const paths = await ResourceFetcher.fetchMultipleResources(
+      onDownloadProgressCallback,
+      modelSource
+    );
+    this.nativeModule = global.loadTokenizerModule(paths[0] || '');
   }
 
-  static async decode(
-    input: number[],
-    skipSpecialTokens = false
-  ): Promise<string> {
-    return await this.nativeModule.decode(input, skipSpecialTokens);
+  async encode(s: string) {
+    return await this.nativeModule.encode(s);
   }
 
-  static async encode(input: string): Promise<number[]> {
-    return await this.nativeModule.encode(input);
+  async decode(tokens: number[], skipSpecialTokens: boolean = true) {
+    return await this.nativeModule.decode(tokens, skipSpecialTokens);
   }
 
-  static async getVocabSize(): Promise<number> {
+  async getVocabSize(): Promise<number> {
     return await this.nativeModule.getVocabSize();
   }
 
-  static async idToToken(tokenId: number): Promise<string> {
-    return await this.nativeModule.idToToken(tokenId);
+  async idToToken(tokenId: number): Promise<string> {
+    return this.nativeModule.idToToken(tokenId);
   }
 
-  static async tokenToId(token: string): Promise<number> {
+  async tokenToId(token: string): Promise<number> {
     return await this.nativeModule.tokenToId(token);
   }
 }
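A usage sketch of the instance-based module outside of the hook; the source URL is a placeholder and the import path from the package root is assumed. `load()` fetches the resource, reports progress through the optional callback, and then binds the native host object created by `global.loadTokenizerModule`.

```ts
// Sketch, assuming TokenizerModule is re-exported from the package root.
import { TokenizerModule } from 'react-native-executorch';

async function demo() {
  const tokenizer = new TokenizerModule();
  await tokenizer.load('https://example.com/tokenizer.json', (progress) =>
    console.log('tokenizer download progress:', progress)
  );

  const ids = await tokenizer.encode('react native executorch');
  console.log(await tokenizer.decode(ids)); // skipSpecialTokens defaults to true
  console.log(await tokenizer.getVocabSize());
}
```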
diff --git a/packages/react-native-executorch/third-party/android/libs/tokenizers-cpp/libtokenizers_c.a b/packages/react-native-executorch/third-party/android/libs/tokenizers-cpp/libtokenizers_c.a
new file mode 100644
index 0000000000..873866fbcc
Binary files /dev/null and b/packages/react-native-executorch/third-party/android/libs/tokenizers-cpp/libtokenizers_c.a differ
diff --git a/packages/react-native-executorch/third-party/android/libs/tokenizers-cpp/libtokenizers_cpp.a b/packages/react-native-executorch/third-party/android/libs/tokenizers-cpp/libtokenizers_cpp.a
new file mode 100644
index 0000000000..eb820ccd48
Binary files /dev/null and b/packages/react-native-executorch/third-party/android/libs/tokenizers-cpp/libtokenizers_cpp.a differ
diff --git a/packages/react-native-executorch/third-party/include/tokenizers-cpp/tokenizers_c.h b/packages/react-native-executorch/third-party/include/tokenizers-cpp/tokenizers_c.h
new file mode 100644
index 0000000000..42a59e94e5
--- /dev/null
+++ b/packages/react-native-executorch/third-party/include/tokenizers-cpp/tokenizers_c.h
@@ -0,0 +1,61 @@
+/*!
+ * Copyright (c) 2023 by Contributors
+ * \file tokenizers_c.h
+ * \brief C binding to tokenizers rust library
+ */
+#ifndef TOKENIZERS_C_H_
+#define TOKENIZERS_C_H_
+
+// The C API
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include
+#include
+
+typedef void *TokenizerHandle;
+
+typedef struct {
+  int *token_ids;
+  size_t len;
+} TokenizerEncodeResult;
+
+TokenizerHandle tokenizers_new_from_str(const char *json, size_t len);
+
+TokenizerHandle byte_level_bpe_tokenizers_new_from_str(
+    const char *vocab, size_t vocab_len, const char *merges, size_t merges_len,
+    const char *added_tokens, size_t added_tokens_len);
+
+void tokenizers_encode(TokenizerHandle handle, const char *data, size_t len,
+                       int add_special_token, TokenizerEncodeResult *result);
+
+void tokenizers_encode_batch(TokenizerHandle handle, const char **data,
+                             size_t *len, size_t num_seqs,
+                             int add_special_token,
+                             TokenizerEncodeResult *results);
+
+void tokenizers_free_encode_results(TokenizerEncodeResult *results,
+                                    size_t num_seqs);
+
+void tokenizers_decode(TokenizerHandle handle, const uint32_t *data,
+                       size_t len, int skip_special_token);
+
+void tokenizers_get_decode_str(TokenizerHandle handle, const char **data,
+                               size_t *len);
+
+void tokenizers_get_vocab_size(TokenizerHandle handle, size_t *size);
+
+void tokenizers_id_to_token(TokenizerHandle handle, uint32_t id,
+                            const char **data, size_t *len);
+
+// tokenizers_token_to_id stores -1 to *id if the token is not in the vocab
+void tokenizers_token_to_id(TokenizerHandle handle, const char *token,
+                            size_t len, int32_t *id);
+
+void tokenizers_free(TokenizerHandle handle);
+
+#ifdef __cplusplus
+}
+#endif
+#endif // TOKENIZERS_C_H_
diff --git a/packages/react-native-executorch/third-party/include/tokenizers-cpp/tokenizers_cpp.h b/packages/react-native-executorch/third-party/include/tokenizers-cpp/tokenizers_cpp.h
new file mode 100644
index 0000000000..72c9c33a8a
--- /dev/null
+++ b/packages/react-native-executorch/third-party/include/tokenizers-cpp/tokenizers_cpp.h
@@ -0,0 +1,118 @@
+/*!
+ * Copyright (c) 2023 by Contributors
+ * \file tokenizers_cpp.h
+ * \brief A C++ binding to common set of tokenizers
+ */
+#ifndef TOKENIZERS_CPP_H_
+#define TOKENIZERS_CPP_H_
+
+#include
+#include
+#include
+
+namespace tokenizers {
+
+/*!
+ * \brief a universal tokenizer that loads
+ *  either HF's tokenizer or sentence piece,
+ *  depending on the constructor
+ */
+class Tokenizer {
+public:
+  /*! \brief virtual destructor */
+  virtual ~Tokenizer() {}
+
+  /*!
+   * \brief Encode text into ids.
+   * \param text The input text.
+   * \returns The encoded token ids.
+   */
+  virtual std::vector<int32_t> Encode(const std::string &text) = 0;
+
+  /*!
+   * \brief Encode a batch of texts into ids.
+   * \param texts The input texts.
+   * \returns The encoded token ids.
+   */
+  virtual std::vector<std::vector<int32_t>>
+  EncodeBatch(const std::vector<std::string> &texts) {
+    // Fall back when the derived class does not implement this function.
+    std::vector<std::vector<int32_t>> ret;
+    ret.reserve(texts.size());
+    for (const auto &text : texts) {
+      ret.push_back(Encode(text));
+    }
+    return ret;
+  }
+
+  /*!
+   * \brief Decode token ids into text.
+   * \param text The token ids.
+   * \returns The decoded text.
+   */
+  virtual std::string Decode(const std::vector<int32_t> &ids) = 0;
+
+  virtual std::string Decode(const std::vector<int32_t> &ids,
+                             bool skip_special_tokens) = 0;
+
+  /*!
+   * \brief Returns the vocabulary size. Special tokens are considered.
+   */
+  virtual size_t GetVocabSize() = 0;
+
+  /*!
+   * \brief Convert the given id to its corresponding token if it exists. If
+   * not, return an empty string.
+   */
+  virtual std::string IdToToken(int32_t token_id) = 0;
+
+  /*!
+   * \brief Convert the given token to its corresponding id if it exists. If
+   * not, return -1.
+   */
+  virtual int32_t TokenToId(const std::string &token) = 0;
+
+  //---------------------------------------------------
+  // Factory functions from byte-blobs
+  // These factory function takes in in-memory blobs
+  // so the library can be independent from filesystem
+  //---------------------------------------------------
+  /*!
+   * \brief Create HF tokenizer from a single in-memory json blob.
+   *
+   * \param json_blob The json blob.
+   * \return The created tokenzier.
+   */
+  static std::unique_ptr<Tokenizer> FromBlobJSON(const std::string &json_blob);
+  /*!
+   * \brief Create BPE tokenizer
+   *
+   * \param vocab_blob The blob that contains vocabs.
+   * \param merges_blob The blob that contains the merges.
+   * \param added_tokens The added tokens.
+   * \return The created tokenizer.
+   */
+  static std::unique_ptr<Tokenizer>
+  FromBlobByteLevelBPE(const std::string &vocab_blob,
+                       const std::string &merges_blob,
+                       const std::string &added_tokens = "");
+  /*!
+   * \brief Create SentencePiece.
+   *
+   * \param model_blob The blob that contains vocabs.
+   * \return The created tokenizer.
+   */
+  static std::unique_ptr<Tokenizer>
+  FromBlobSentencePiece(const std::string &model_blob);
+  /*!
+   * \brief Create RWKVWorldTokenizer.
+   *
+   * \param model_blob The blob that contains vocabs.
+   * \return The created tokenizer.
+   */
+  static std::unique_ptr<Tokenizer>
+  FromBlobRWKVWorld(const std::string &model_blob);
+};
+
+} // namespace tokenizers
+#endif // TOKENIZERS_CPP_H_