11#include < thread>
22
33#include " SpeechToText.h"
4+ #include " common/types/TranscriptionResult.h"
5+ #include " whisper/ASR.h"
6+ #include " whisper/OnlineASR.h"
47#include < rnexecutorch/Error.h>
58#include < rnexecutorch/ErrorCodes.h>
6- #include < rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>
79
810namespace rnexecutorch ::models::speech_to_text {
911
10- using namespace ::executorch::extension;
11- using namespace asr ;
12- using namespace types ;
13- using namespace stream ;
14-
15- SpeechToText::SpeechToText (const std::string &encoderSource,
16- const std::string &decoderSource,
12+ SpeechToText::SpeechToText (const std::string &modelName,
13+ const std::string &modelSource,
1714 const std::string &tokenizerSource,
1815 std::shared_ptr<react::CallInvoker> callInvoker)
19- : callInvoker (std::move(callInvoker)),
20- encoder (std::make_unique<BaseModel>(encoderSource, this ->callInvoker)),
21- decoder(std::make_unique<BaseModel>(decoderSource, this ->callInvoker)),
22- tokenizer(std::make_unique<TokenizerModule>(tokenizerSource,
23- this ->callInvoker)) ,
24- asr(std::make_unique<ASR>( this ->encoder.get(), this->decoder.get(),
25- this->tokenizer.get())),
26- processor(std::make_unique<OnlineASRProcessor>( this ->asr .get())),
27- isStreaming( false ), readyToProcess( false ) {}
28-
29- void SpeechToText::unload () noexcept {
30- this -> encoder -> unload ( );
31- this -> decoder -> unload ();
16+ : callInvoker_ (std::move(callInvoker)), isStreaming_( false ),
17+ readyToProcess_ ( false ) {
18+ // Switch between the ASR implementations based on model name
19+ if (modelName == " whisper " ) {
20+ transcriber_ = std::make_unique<whisper::ASR>(modelSource, tokenizerSource ,
21+ callInvoker_);
22+ streamer_ = std::make_unique<whisper::stream::OnlineASR>(
23+ static_cast < const whisper::ASR *>(transcriber_ .get ()));
24+ } else {
25+ throw rnexecutorch::RnExecutorchError (
26+ rnexecutorch::RnExecutorchErrorCode::InvalidConfig,
27+ " [SpeechToText]: Invalid model name: " + modelName );
28+ }
3229}
3330
31+ void SpeechToText::unload () noexcept { transcriber_->unload (); }
32+
3433std::shared_ptr<OwningArrayBuffer>
3534SpeechToText::encode (std::span<float > waveform) const {
36- std::vector<float > encoderOutput = this -> asr ->encode (waveform);
35+ std::vector<float > encoderOutput = transcriber_ ->encode (waveform);
3736 return std::make_shared<OwningArrayBuffer>(encoderOutput);
3837}
3938
4039std::shared_ptr<OwningArrayBuffer>
4140SpeechToText::decode (std::span<uint64_t > tokens,
4241 std::span<float > encoderOutput) const {
4342 std::vector<float > decoderOutput =
44- this -> asr -> decode (tokens, 0 , encoderOutput);
43+ transcriber_-> decode (tokens, encoderOutput);
4544 return std::make_shared<OwningArrayBuffer>(decoderOutput);
4645}
4746
4847TranscriptionResult SpeechToText::transcribe (std::span<float > waveform,
4948 std::string languageOption,
5049 bool verbose) const {
5150 DecodingOptions options (languageOption, verbose);
52- std::vector<Segment> segments = this -> asr ->transcribe (waveform, options);
51+ std::vector<Segment> segments = transcriber_ ->transcribe (waveform, options);
5352
5453 std::string fullText;
5554 for (const auto &segment : segments) {
@@ -71,8 +70,7 @@ TranscriptionResult SpeechToText::transcribe(std::span<float> waveform,
7170}
7271
7372size_t SpeechToText::getMemoryLowerBound () const noexcept {
74- return this ->encoder ->getMemoryLowerBound () +
75- this ->decoder ->getMemoryLowerBound ();
73+ return transcriber_->getMemoryLowerBound ();
7674}
7775
7876namespace {
@@ -106,7 +104,7 @@ TranscriptionResult wordsToResult(const std::vector<Word> &words,
106104
107105void SpeechToText::stream (std::shared_ptr<jsi::Function> callback,
108106 std::string languageOption, bool verbose) {
109- if (this -> isStreaming ) {
107+ if (isStreaming_ ) {
110108 throw RnExecutorchError (RnExecutorchErrorCode::StreamingInProgress,
111109 " Streaming is already in progress!" );
112110 }
@@ -116,7 +114,7 @@ void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
116114 const TranscriptionResult &nonCommitted,
117115 bool isDone) {
118116 // This moves execution to the JS thread
119- this -> callInvoker ->invokeAsync (
117+ callInvoker_ ->invokeAsync (
120118 [callback, committed, nonCommitted, isDone, verbose](jsi::Runtime &rt) {
121119 jsi::Value jsiCommitted =
122120 rnexecutorch::jsi_conversion::getJsiValue (committed, rt);
@@ -128,46 +126,45 @@ void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
128126 });
129127 };
130128
131- this -> isStreaming = true ;
129+ isStreaming_ = true ;
132130 DecodingOptions options (languageOption, verbose);
133131
134- while (this ->isStreaming ) {
135- if (!this ->readyToProcess ||
136- this ->processor ->audioBuffer .size () < SpeechToText::kMinAudioSamples ) {
132+ while (isStreaming_) {
133+ if (!readyToProcess_ || !streamer_->ready ()) {
137134 std::this_thread::sleep_for (std::chrono::milliseconds (100 ));
138135 continue ;
139136 }
140137
141- ProcessResult res = this -> processor -> processIter (options);
138+ ProcessResult res = streamer_-> process (options);
142139
143140 TranscriptionResult cRes =
144141 wordsToResult (res.committed , languageOption, verbose);
145142 TranscriptionResult ncRes =
146143 wordsToResult (res.nonCommitted , languageOption, verbose);
147144
148145 nativeCallback (cRes, ncRes, false );
149- this -> readyToProcess = false ;
146+ readyToProcess_ = false ;
150147 }
151148
152- std::vector<Word> finalWords = this -> processor ->finish ();
149+ std::vector<Word> finalWords = streamer_ ->finish ();
153150 TranscriptionResult finalRes =
154151 wordsToResult (finalWords, languageOption, verbose);
155152
156153 nativeCallback (finalRes, {}, true );
157- this -> resetStreamState ();
154+ resetStreamState ();
158155}
159156
160- void SpeechToText::streamStop () { this -> isStreaming = false ; }
157+ void SpeechToText::streamStop () { isStreaming_ = false ; }
161158
162159void SpeechToText::streamInsert (std::span<float > waveform) {
163- this -> processor ->insertAudioChunk (waveform);
164- this -> readyToProcess = true ;
160+ streamer_ ->insertAudioChunk (waveform);
161+ readyToProcess_ = true ;
165162}
166163
167164void SpeechToText::resetStreamState () {
168- this -> isStreaming = false ;
169- this -> readyToProcess = false ;
170- this -> processor = std::make_unique<OnlineASRProcessor>( this -> asr . get () );
165+ isStreaming_ = false ;
166+ readyToProcess_ = false ;
167+ streamer_-> reset ( );
171168}
172169
173170} // namespace rnexecutorch::models::speech_to_text
0 commit comments