Skip to content

Commit ea943e4

Browse files
committed
Refactor STT native implementation
1 parent bab2ffb commit ea943e4

File tree

33 files changed

+1092
-712
lines changed

33 files changed

+1092
-712
lines changed

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,11 +18,11 @@
1818
#include <rnexecutorch/models/object_detection/Constants.h>
1919
#include <rnexecutorch/models/object_detection/Types.h>
2020
#include <rnexecutorch/models/ocr/Types.h>
21-
#include <rnexecutorch/models/speech_to_text/types/Segment.h>
22-
#include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>
21+
#include <rnexecutorch/models/speech_to_text/common/types/Segment.h>
22+
#include <rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h>
2323
#include <rnexecutorch/models/voice_activity_detection/Types.h>
2424

25-
using namespace rnexecutorch::models::speech_to_text::types;
25+
using namespace rnexecutorch::models::speech_to_text;
2626

2727
namespace rnexecutorch::jsi_conversion {
2828

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,12 +3,12 @@
33
#include <string>
44
#include <vector>
55

6-
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
76
#include <ReactCommon/CallInvoker.h>
87
#include <executorch/extension/module/module.h>
98
#include <jsi/jsi.h>
109
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
1110
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
11+
#include <rnexecutorch/metaprogramming/ConstructorHelpers.h>
1212

1313
namespace rnexecutorch {
1414
namespace models {

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp

Lines changed: 39 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -1,55 +1,54 @@
11
#include <thread>
22

33
#include "SpeechToText.h"
4+
#include "common/types/TranscriptionResult.h"
5+
#include "whisper/ASR.h"
6+
#include "whisper/OnlineASR.h"
47
#include <rnexecutorch/Error.h>
58
#include <rnexecutorch/ErrorCodes.h>
6-
#include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>
79

810
namespace rnexecutorch::models::speech_to_text {
911

10-
using namespace ::executorch::extension;
11-
using namespace asr;
12-
using namespace types;
13-
using namespace stream;
14-
15-
SpeechToText::SpeechToText(const std::string &encoderSource,
16-
const std::string &decoderSource,
12+
SpeechToText::SpeechToText(const std::string &modelName,
13+
const std::string &modelSource,
1714
const std::string &tokenizerSource,
1815
std::shared_ptr<react::CallInvoker> callInvoker)
19-
: callInvoker(std::move(callInvoker)),
20-
encoder(std::make_unique<BaseModel>(encoderSource, this->callInvoker)),
21-
decoder(std::make_unique<BaseModel>(decoderSource, this->callInvoker)),
22-
tokenizer(std::make_unique<TokenizerModule>(tokenizerSource,
23-
this->callInvoker)),
24-
asr(std::make_unique<ASR>(this->encoder.get(), this->decoder.get(),
25-
this->tokenizer.get())),
26-
processor(std::make_unique<OnlineASRProcessor>(this->asr.get())),
27-
isStreaming(false), readyToProcess(false) {}
28-
29-
void SpeechToText::unload() noexcept {
30-
this->encoder->unload();
31-
this->decoder->unload();
16+
: callInvoker_(std::move(callInvoker)), isStreaming_(false),
17+
readyToProcess_(false) {
18+
// Switch between the ASR implementations based on model name
19+
if (modelName == "whisper") {
20+
transcriber_ = std::make_unique<whisper::ASR>(modelSource, tokenizerSource,
21+
callInvoker_);
22+
streamer_ = std::make_unique<whisper::stream::OnlineASR>(
23+
static_cast<const whisper::ASR *>(transcriber_.get()));
24+
} else {
25+
throw rnexecutorch::RnExecutorchError(
26+
rnexecutorch::RnExecutorchErrorCode::InvalidConfig,
27+
"[SpeechToText]: Invalid model name: " + modelName);
28+
}
3229
}
3330

31+
void SpeechToText::unload() noexcept { transcriber_->unload(); }
32+
3433
std::shared_ptr<OwningArrayBuffer>
3534
SpeechToText::encode(std::span<float> waveform) const {
36-
std::vector<float> encoderOutput = this->asr->encode(waveform);
35+
std::vector<float> encoderOutput = transcriber_->encode(waveform);
3736
return std::make_shared<OwningArrayBuffer>(encoderOutput);
3837
}
3938

4039
std::shared_ptr<OwningArrayBuffer>
4140
SpeechToText::decode(std::span<uint64_t> tokens,
4241
std::span<float> encoderOutput) const {
4342
std::vector<float> decoderOutput =
44-
this->asr->decode(tokens, 0, encoderOutput);
43+
transcriber_->decode(tokens, encoderOutput);
4544
return std::make_shared<OwningArrayBuffer>(decoderOutput);
4645
}
4746

4847
TranscriptionResult SpeechToText::transcribe(std::span<float> waveform,
4948
std::string languageOption,
5049
bool verbose) const {
5150
DecodingOptions options(languageOption, verbose);
52-
std::vector<Segment> segments = this->asr->transcribe(waveform, options);
51+
std::vector<Segment> segments = transcriber_->transcribe(waveform, options);
5352

5453
std::string fullText;
5554
for (const auto &segment : segments) {
@@ -71,8 +70,7 @@ TranscriptionResult SpeechToText::transcribe(std::span<float> waveform,
7170
}
7271

7372
size_t SpeechToText::getMemoryLowerBound() const noexcept {
74-
return this->encoder->getMemoryLowerBound() +
75-
this->decoder->getMemoryLowerBound();
73+
return transcriber_->getMemoryLowerBound();
7674
}
7775

7876
namespace {
@@ -106,7 +104,7 @@ TranscriptionResult wordsToResult(const std::vector<Word> &words,
106104

107105
void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
108106
std::string languageOption, bool verbose) {
109-
if (this->isStreaming) {
107+
if (isStreaming_) {
110108
throw RnExecutorchError(RnExecutorchErrorCode::StreamingInProgress,
111109
"Streaming is already in progress!");
112110
}
@@ -116,7 +114,7 @@ void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
116114
const TranscriptionResult &nonCommitted,
117115
bool isDone) {
118116
// This moves execution to the JS thread
119-
this->callInvoker->invokeAsync(
117+
callInvoker_->invokeAsync(
120118
[callback, committed, nonCommitted, isDone, verbose](jsi::Runtime &rt) {
121119
jsi::Value jsiCommitted =
122120
rnexecutorch::jsi_conversion::getJsiValue(committed, rt);
@@ -128,46 +126,45 @@ void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
128126
});
129127
};
130128

131-
this->isStreaming = true;
129+
isStreaming_ = true;
132130
DecodingOptions options(languageOption, verbose);
133131

134-
while (this->isStreaming) {
135-
if (!this->readyToProcess ||
136-
this->processor->audioBuffer.size() < SpeechToText::kMinAudioSamples) {
132+
while (isStreaming_) {
133+
if (!readyToProcess_ || !streamer_->ready()) {
137134
std::this_thread::sleep_for(std::chrono::milliseconds(100));
138135
continue;
139136
}
140137

141-
ProcessResult res = this->processor->processIter(options);
138+
ProcessResult res = streamer_->process(options);
142139

143140
TranscriptionResult cRes =
144141
wordsToResult(res.committed, languageOption, verbose);
145142
TranscriptionResult ncRes =
146143
wordsToResult(res.nonCommitted, languageOption, verbose);
147144

148145
nativeCallback(cRes, ncRes, false);
149-
this->readyToProcess = false;
146+
readyToProcess_ = false;
150147
}
151148

152-
std::vector<Word> finalWords = this->processor->finish();
149+
std::vector<Word> finalWords = streamer_->finish();
153150
TranscriptionResult finalRes =
154151
wordsToResult(finalWords, languageOption, verbose);
155152

156153
nativeCallback(finalRes, {}, true);
157-
this->resetStreamState();
154+
resetStreamState();
158155
}
159156

160-
void SpeechToText::streamStop() { this->isStreaming = false; }
157+
void SpeechToText::streamStop() { isStreaming_ = false; }
161158

162159
void SpeechToText::streamInsert(std::span<float> waveform) {
163-
this->processor->insertAudioChunk(waveform);
164-
this->readyToProcess = true;
160+
streamer_->insertAudioChunk(waveform);
161+
readyToProcess_ = true;
165162
}
166163

167164
void SpeechToText::resetStreamState() {
168-
this->isStreaming = false;
169-
this->readyToProcess = false;
170-
this->processor = std::make_unique<OnlineASRProcessor>(this->asr.get());
165+
isStreaming_ = false;
166+
readyToProcess_ = false;
167+
streamer_->reset();
171168
}
172169

173170
} // namespace rnexecutorch::models::speech_to_text

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,21 @@
11
#pragma once
22

3-
#include "rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h"
4-
#include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>
53
#include <span>
64
#include <string>
75
#include <vector>
86

7+
#include "common/schema/ASR.h"
8+
#include "common/schema/OnlineASR.h"
9+
#include "common/types/TranscriptionResult.h"
10+
911
namespace rnexecutorch {
1012

1113
namespace models::speech_to_text {
1214

1315
class SpeechToText {
1416
public:
15-
explicit SpeechToText(const std::string &encoderSource,
16-
const std::string &decoderSource,
17+
explicit SpeechToText(const std::string &modelName,
18+
const std::string &modelSource,
1719
const std::string &tokenizerSource,
1820
std::shared_ptr<react::CallInvoker> callInvoker);
1921

@@ -25,9 +27,9 @@ class SpeechToText {
2527
"Registered non-void function")]] std::shared_ptr<OwningArrayBuffer>
2628
decode(std::span<uint64_t> tokens, std::span<float> encoderOutput) const;
2729
[[nodiscard("Registered non-void function")]]
28-
types::TranscriptionResult transcribe(std::span<float> waveform,
29-
std::string languageOption,
30-
bool verbose) const;
30+
TranscriptionResult transcribe(std::span<float> waveform,
31+
std::string languageOption,
32+
bool verbose) const;
3133

3234
[[nodiscard("Registered non-void function")]]
3335
std::vector<char> transcribeStringOnly(std::span<float> waveform,
@@ -42,20 +44,18 @@ class SpeechToText {
4244
void streamInsert(std::span<float> waveform);
4345

4446
private:
45-
std::shared_ptr<react::CallInvoker> callInvoker;
46-
std::unique_ptr<BaseModel> encoder;
47-
std::unique_ptr<BaseModel> decoder;
48-
std::unique_ptr<TokenizerModule> tokenizer;
49-
std::unique_ptr<asr::ASR> asr;
47+
// Helper functions
48+
void resetStreamState();
5049

51-
// Stream
52-
std::unique_ptr<stream::OnlineASRProcessor> processor;
53-
bool isStreaming;
54-
bool readyToProcess;
50+
std::shared_ptr<react::CallInvoker> callInvoker_;
5551

56-
constexpr static int32_t kMinAudioSamples = 16000; // 1 second
52+
// ASR-like module (both static transcription & streaming)
53+
std::unique_ptr<schema::ASR> transcriber_ = nullptr;
5754

58-
void resetStreamState();
55+
// Online ASR-like module (streaming only)
56+
std::unique_ptr<schema::OnlineASR> streamer_ = nullptr;
57+
bool isStreaming_ = false;
58+
bool readyToProcess_ = true;
5959
};
6060

6161
} // namespace models::speech_to_text

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h

Lines changed: 0 additions & 65 deletions
This file was deleted.
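packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h (new file; its path header is missing from this capture — path inferred from the "common/schema/ASR.h" include and the `schema` namespace, to be confirmed)

Lines changed: 39 additions & 0 deletions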
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,39 @@
1+
#pragma once
2+
3+
#include <cinttypes>
4+
#include <span>
5+
#include <vector>
6+
7+
#include "../types/DecodingOptions.h"
8+
#include "../types/Segment.h"
9+
#include <rnexecutorch/models/BaseModel.h>
10+
11+
namespace rnexecutorch::models::speech_to_text::schema {
12+
13+
/**
14+
* @brief Abstract base class for Automatic Speech Recognition (ASR) models.
15+
*
16+
* Provides a unified interface for speech-to-text models like Whisper, allowing
17+
* for transcription of raw audio waveforms into text segments, as well as
18+
* access to lower-level model components like encoding and decoding.
19+
*/
20+
class ASR {
21+
public:
22+
virtual ~ASR() = default;
23+
24+
std::vector<Segment> virtual transcribe(
25+
std::span<float> waveform, const DecodingOptions &options) const = 0;
26+
27+
virtual std::vector<float> encode(std::span<float> waveform) const = 0;
28+
29+
virtual std::vector<float> decode(std::span<uint64_t> tokens,
30+
std::span<float> encoderOutput,
31+
uint64_t startPos = 0) const = 0;
32+
33+
// Standard ExecuTorch model methods for compatibility with the rest of the
34+
// API.
35+
virtual void unload() noexcept = 0;
36+
virtual std::size_t getMemoryLowerBound() const noexcept = 0;
37+
};
38+
39+
} // namespace rnexecutorch::models::speech_to_text::schema

0 commit comments

Comments
 (0)