Commit 6dd8fb6

IgorSwat and msluszniak authored
fix!: speech to text live transcription (#816)
## Description

Various improvements & adjustments in the Speech-to-Text module. The list of changes includes:

- Adjusting the native implementation to the new format of Whisper models (single file, bundled encode & decode methods)
- Refactoring the native implementation to support multiple STT models in the future
- Fixing improper behavior of Whisper streaming

### Introduces a breaking change?

- [x] Yes
- [ ] No

### Type of change

- [x] Bug fix (change which fixes an issue)
- [ ] New feature (change which adds functionality)
- [ ] Documentation update (improves or adds clarity to existing documentation)
- [x] Other (chores, tests, code style improvements etc.)

### Tested on

- [x] iOS
- [x] Android

### Testing instructions

You can run the tests defined for the Speech-to-Text module, as well as test it manually with the 'speech' demo app (SpeechToText screen).

### Checklist

- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have updated the documentation accordingly
- [x] My changes generate no new warnings

Co-authored-by: Mateusz Słuszniak <mateusz.sluszniak@swmansion.com>
1 parent ce065d2 commit 6dd8fb6

File tree: 42 files changed, +1766 −908 lines changed

Note: large commits have some content hidden by default, so some file paths below may be missing.

apps/speech/screens/SpeechToTextScreen.tsx

Lines changed: 4 additions & 4 deletions

```diff
@@ -50,7 +50,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
   const [liveTranscribing, setLiveTranscribing] = useState(false);
   const scrollViewRef = useRef<ScrollView>(null);

-  const recorder = new AudioRecorder();
+  const recorder = useRef(new AudioRecorder());

   useEffect(() => {
     AudioManager.setAudioSessionOptions({
@@ -115,7 +115,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {

     const sampleRate = 16000;

-    recorder.onAudioReady(
+    recorder.current.onAudioReady(
       {
         sampleRate,
         bufferLength: 0.1 * sampleRate,
@@ -131,7 +131,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
     if (!success) {
       console.warn('Cannot start audio session correctly');
     }
-    const result = recorder.start();
+    const result = recorder.current.start();
     if (result.status === 'error') {
       console.warn('Recording problems: ', result.message);
     }
@@ -177,7 +177,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
   const handleStopTranscribeFromMicrophone = () => {
     isRecordingRef.current = false;

-    recorder.stop();
+    recorder.current.stop();
     model.streamStop();
     console.log('Live transcription stopped');
     setLiveTranscribing(false);
```

docs/docs/03-hooks/01-natural-language-processing/useSpeechToText.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -66,7 +66,7 @@ Since speech-to-text models can only process audio segments up to 30 seconds lon

 `useSpeechToText` takes [`SpeechToTextProps`](../../06-api-reference/interfaces/SpeechToTextProps.md) that consists of:

-- `model` of type [`SpeechToTextConfig`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md), containing the [`isMultilingual` flag](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#ismultilingual), [tokenizer source](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#tokenizersource), [encoder source](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#encodersource), and [decoder source](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#decodersource).
+- `model` of type [`SpeechToTextConfig`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md), containing the [`isMultilingual` flag](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#ismultilingual), [tokenizer source](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#tokenizersource) and [model source](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#modelsource).
 - An optional flag [`preventLoad`](../../06-api-reference/interfaces/SpeechToTextProps.md#preventload) which prevents auto-loading of the model.

 You need more details? Check the following resources:
```

docs/docs/04-typescript-api/01-natural-language-processing/SpeechToTextModule.md

Lines changed: 1 addition & 3 deletions

```diff
@@ -45,9 +45,7 @@ Create an instance of [`SpeechToTextModule`](../../06-api-reference/classes/Spee
 - [`model`](../../06-api-reference/classes/SpeechToTextModule.md#model) - Object containing:
   - [`isMultilingual`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#ismultilingual) - Flag indicating if model is multilingual.

-  - [`encoderSource`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#encodersource) - The location of the used encoder.
-
-  - [`decoderSource`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#decodersource) - The location of the used decoder.
+  - [`modelSource`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#modelsource) - The location of the used model (bundled encoder + decoder functionality).

   - [`tokenizerSource`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#tokenizersource) - The location of the used tokenizer.
```

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 5 additions & 4 deletions

```diff
@@ -17,11 +17,11 @@
 #include <rnexecutorch/metaprogramming/TypeConcepts.h>
 #include <rnexecutorch/models/object_detection/Types.h>
 #include <rnexecutorch/models/ocr/Types.h>
-#include <rnexecutorch/models/speech_to_text/types/Segment.h>
-#include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>
+#include <rnexecutorch/models/speech_to_text/common/types/Segment.h>
+#include <rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h>
 #include <rnexecutorch/models/voice_activity_detection/Types.h>

-using namespace rnexecutorch::models::speech_to_text::types;
+using namespace rnexecutorch::models::speech_to_text;

 namespace rnexecutorch::jsi_conversion {

@@ -513,7 +513,8 @@ inline jsi::Value getJsiValue(const Segment &seg, jsi::Runtime &runtime) {
     jsi::Object wordObj(runtime);
     wordObj.setProperty(
         runtime, "word",
-        jsi::String::createFromUtf8(runtime, seg.words[i].content));
+        jsi::String::createFromUtf8(runtime, seg.words[i].content +
+                                                 seg.words[i].punctations));
     wordObj.setProperty(runtime, "start",
                         static_cast<double>(seg.words[i].start));
     wordObj.setProperty(runtime, "end", static_cast<double>(seg.words[i].end));
```

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,12 +3,12 @@
 #include <string>
 #include <vector>

-#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
 #include <ReactCommon/CallInvoker.h>
 #include <executorch/extension/module/module.h>
 #include <jsi/jsi.h>
 #include <rnexecutorch/host_objects/JSTensorViewIn.h>
 #include <rnexecutorch/host_objects/JSTensorViewOut.h>
+#include <rnexecutorch/metaprogramming/ConstructorHelpers.h>

 namespace rnexecutorch {
 namespace models {
```
Lines changed: 82 additions & 69 deletions

```diff
@@ -1,54 +1,66 @@
 #include <thread>

 #include "SpeechToText.h"
+#include "common/types/TranscriptionResult.h"
+#include "whisper/ASR.h"
+#include "whisper/OnlineASR.h"
 #include <rnexecutorch/Error.h>
 #include <rnexecutorch/ErrorCodes.h>
-#include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>

 namespace rnexecutorch::models::speech_to_text {

-using namespace ::executorch::extension;
-using namespace asr;
-using namespace types;
-using namespace stream;
-
-SpeechToText::SpeechToText(const std::string &encoderSource,
-                           const std::string &decoderSource,
+SpeechToText::SpeechToText(const std::string &modelName,
+                           const std::string &modelSource,
                            const std::string &tokenizerSource,
                            std::shared_ptr<react::CallInvoker> callInvoker)
-    : callInvoker(std::move(callInvoker)),
-      encoder(std::make_unique<BaseModel>(encoderSource, this->callInvoker)),
-      decoder(std::make_unique<BaseModel>(decoderSource, this->callInvoker)),
-      tokenizer(std::make_unique<TokenizerModule>(tokenizerSource,
-                                                  this->callInvoker)),
-      asr(std::make_unique<ASR>(this->encoder.get(), this->decoder.get(),
-                                this->tokenizer.get())),
-      processor(std::make_unique<OnlineASRProcessor>(this->asr.get())),
-      isStreaming(false), readyToProcess(false) {}
-
-void SpeechToText::unload() noexcept {
-  this->encoder->unload();
-  this->decoder->unload();
+    : callInvoker_(std::move(callInvoker)) {
+  // Switch between the ASR implementations based on model name
+  if (modelName == "whisper") {
+    transcriber_ = std::make_unique<whisper::ASR>(modelSource, tokenizerSource,
+                                                  callInvoker_);
+    streamer_ = std::make_unique<whisper::stream::OnlineASR>(
+        static_cast<const whisper::ASR *>(transcriber_.get()));
+  } else {
+    throw rnexecutorch::RnExecutorchError(
+        rnexecutorch::RnExecutorchErrorCode::InvalidConfig,
+        "[SpeechToText]: Invalid model name: " + modelName);
+  }
 }

+SpeechToText::SpeechToText(SpeechToText &&other) noexcept
+    : callInvoker_(std::move(other.callInvoker_)),
+      transcriber_(std::move(other.transcriber_)),
+      streamer_(std::move(other.streamer_)),
+      isStreaming_(other.isStreaming_.load()),
+      readyToProcess_(other.readyToProcess_.load()) {}
+
+void SpeechToText::unload() noexcept { transcriber_->unload(); }
+
 std::shared_ptr<OwningArrayBuffer>
 SpeechToText::encode(std::span<float> waveform) const {
-  std::vector<float> encoderOutput = this->asr->encode(waveform);
-  return std::make_shared<OwningArrayBuffer>(encoderOutput);
+  executorch::aten::Tensor encoderOutputTensor = transcriber_->encode(waveform);
+
+  return std::make_shared<OwningArrayBuffer>(
+      encoderOutputTensor.const_data_ptr<float>(),
+      sizeof(float) * encoderOutputTensor.numel());
 }

 std::shared_ptr<OwningArrayBuffer>
 SpeechToText::decode(std::span<uint64_t> tokens,
                      std::span<float> encoderOutput) const {
-  std::vector<float> decoderOutput = this->asr->decode(tokens, encoderOutput);
-  return std::make_shared<OwningArrayBuffer>(decoderOutput);
+  executorch::aten::Tensor decoderOutputTensor =
+      transcriber_->decode(tokens, encoderOutput);
+
+  return std::make_shared<OwningArrayBuffer>(
+      decoderOutputTensor.const_data_ptr<float>(),
+      sizeof(float) * decoderOutputTensor.numel());
 }

 TranscriptionResult SpeechToText::transcribe(std::span<float> waveform,
                                              std::string languageOption,
                                              bool verbose) const {
   DecodingOptions options(languageOption, verbose);
-  std::vector<Segment> segments = this->asr->transcribe(waveform, options);
+  std::vector<Segment> segments = transcriber_->transcribe(waveform, options);

   std::string fullText;
   for (const auto &segment : segments) {
@@ -70,8 +82,7 @@ TranscriptionResult SpeechToText::transcribe(std::span<float> waveform,
 }

 size_t SpeechToText::getMemoryLowerBound() const noexcept {
-  return this->encoder->getMemoryLowerBound() +
-         this->decoder->getMemoryLowerBound();
+  return transcriber_->getMemoryLowerBound();
 }

 namespace {
@@ -83,7 +94,7 @@ TranscriptionResult wordsToResult(const std::vector<Word> &words,

   std::string fullText;
   for (const auto &w : words) {
-    fullText += w.content;
+    fullText += w.content + w.punctations;
   }
   res.text = fullText;

@@ -105,68 +116,70 @@ TranscriptionResult wordsToResult(const std::vector<Word> &words,

 void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
                           std::string languageOption, bool verbose) {
-  if (this->isStreaming) {
+  if (isStreaming_) {
     throw RnExecutorchError(RnExecutorchErrorCode::StreamingInProgress,
                             "Streaming is already in progress!");
   }

-  auto nativeCallback = [this, callback,
-                         verbose](const TranscriptionResult &committed,
-                                  const TranscriptionResult &nonCommitted,
-                                  bool isDone) {
-    // This moves execution to the JS thread
-    this->callInvoker->invokeAsync(
-        [callback, committed, nonCommitted, isDone, verbose](jsi::Runtime &rt) {
-          jsi::Value jsiCommitted =
-              rnexecutorch::jsi_conversion::getJsiValue(committed, rt);
-          jsi::Value jsiNonCommitted =
-              rnexecutorch::jsi_conversion::getJsiValue(nonCommitted, rt);
-
-          callback->call(rt, std::move(jsiCommitted),
-                         std::move(jsiNonCommitted), jsi::Value(isDone));
-        });
-  };
-
-  this->isStreaming = true;
+  auto nativeCallback =
+      [this, callback](const TranscriptionResult &committed,
+                       const TranscriptionResult &nonCommitted, bool isDone) {
+        // This moves execution to the JS thread
+        callInvoker_->invokeAsync(
+            [callback, committed, nonCommitted, isDone](jsi::Runtime &rt) {
+              jsi::Value jsiCommitted =
+                  rnexecutorch::jsi_conversion::getJsiValue(committed, rt);
+              jsi::Value jsiNonCommitted =
+                  rnexecutorch::jsi_conversion::getJsiValue(nonCommitted, rt);
+
+              callback->call(rt, std::move(jsiCommitted),
+                             std::move(jsiNonCommitted), jsi::Value(isDone));
+            });
+      };
+
+  isStreaming_ = true;
   DecodingOptions options(languageOption, verbose);

-  while (this->isStreaming) {
-    if (!this->readyToProcess ||
-        this->processor->audioBuffer.size() < SpeechToText::kMinAudioSamples) {
-      std::this_thread::sleep_for(std::chrono::milliseconds(100));
-      continue;
-    }
-
-    ProcessResult res = this->processor->processIter(options);
-
-    TranscriptionResult cRes =
-        wordsToResult(res.committed, languageOption, verbose);
-    TranscriptionResult ncRes =
-        wordsToResult(res.nonCommitted, languageOption, verbose);
-
-    nativeCallback(cRes, ncRes, false);
-    this->readyToProcess = false;
+  while (isStreaming_) {
+    if (readyToProcess_ && streamer_->isReady()) {
+      ProcessResult res = streamer_->process(options);
+
+      TranscriptionResult committedRes =
+          wordsToResult(res.committed, languageOption, verbose);
+      TranscriptionResult nonCommittedRes =
+          wordsToResult(res.nonCommitted, languageOption, verbose);
+
+      nativeCallback(committedRes, nonCommittedRes, false);
+      readyToProcess_ = false;
+    }
+
+    // Add a minimal pause between transcriptions.
+    // The reasoning is very simple: with the current liberal threshold values,
+    // running transcriptions too rapidly (before the audio buffer is filled
+    // with significant amount of new data) can cause streamer to commit wrong
+    // phrases.
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
   }

-  std::vector<Word> finalWords = this->processor->finish();
+  std::vector<Word> finalWords = streamer_->finish();
   TranscriptionResult finalRes =
       wordsToResult(finalWords, languageOption, verbose);

   nativeCallback(finalRes, {}, true);
-  this->resetStreamState();
+  resetStreamState();
 }

-void SpeechToText::streamStop() { this->isStreaming = false; }
+void SpeechToText::streamStop() { isStreaming_ = false; }

 void SpeechToText::streamInsert(std::span<float> waveform) {
-  this->processor->insertAudioChunk(waveform);
-  this->readyToProcess = true;
+  streamer_->insertAudioChunk(waveform);
+  readyToProcess_ = true;
 }

 void SpeechToText::resetStreamState() {
-  this->isStreaming = false;
-  this->readyToProcess = false;
-  this->processor = std::make_unique<OnlineASRProcessor>(this->asr.get());
+  isStreaming_ = false;
+  readyToProcess_ = false;
+  streamer_->reset();
 }

 } // namespace rnexecutorch::models::speech_to_text
```

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h

Lines changed: 22 additions & 20 deletions

```diff
@@ -1,21 +1,26 @@
 #pragma once

-#include "rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h"
-#include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>
+#include <atomic>
 #include <span>
 #include <string>
 #include <vector>

+#include "common/schema/ASR.h"
+#include "common/schema/OnlineASR.h"
+#include "common/types/TranscriptionResult.h"
+
 namespace rnexecutorch {

 namespace models::speech_to_text {

 class SpeechToText {
 public:
-  explicit SpeechToText(const std::string &encoderSource,
-                        const std::string &decoderSource,
-                        const std::string &tokenizerSource,
-                        std::shared_ptr<react::CallInvoker> callInvoker);
+  SpeechToText(const std::string &modelName, const std::string &modelSource,
+               const std::string &tokenizerSource,
+               std::shared_ptr<react::CallInvoker> callInvoker);
+
+  // Required because of std::atomic usage
+  SpeechToText(SpeechToText &&other) noexcept;

   void unload() noexcept;
   [[nodiscard(
@@ -25,9 +30,9 @@ class SpeechToText {
       "Registered non-void function")]] std::shared_ptr<OwningArrayBuffer>
   decode(std::span<uint64_t> tokens, std::span<float> encoderOutput) const;
   [[nodiscard("Registered non-void function")]]
-  types::TranscriptionResult transcribe(std::span<float> waveform,
-                                        std::string languageOption,
-                                        bool verbose) const;
+  TranscriptionResult transcribe(std::span<float> waveform,
+                                 std::string languageOption,
+                                 bool verbose) const;

   [[nodiscard("Registered non-void function")]]
   std::vector<char> transcribeStringOnly(std::span<float> waveform,
@@ -42,20 +47,17 @@ class SpeechToText {
   void streamInsert(std::span<float> waveform);

 private:
-  std::shared_ptr<react::CallInvoker> callInvoker;
-  std::unique_ptr<BaseModel> encoder;
-  std::unique_ptr<BaseModel> decoder;
-  std::unique_ptr<TokenizerModule> tokenizer;
-  std::unique_ptr<asr::ASR> asr;
+  void resetStreamState();

-  // Stream
-  std::unique_ptr<stream::OnlineASRProcessor> processor;
-  bool isStreaming;
-  bool readyToProcess;
+  std::shared_ptr<react::CallInvoker> callInvoker_;

-  constexpr static int32_t kMinAudioSamples = 16000; // 1 second
+  // ASR-like module (both static transcription & streaming)
+  std::unique_ptr<schema::ASR> transcriber_ = nullptr;

-  void resetStreamState();
+  // Online ASR-like module (streaming only)
+  std::unique_ptr<schema::OnlineASR> streamer_ = nullptr;
+  std::atomic<bool> isStreaming_ = false;
+  std::atomic<bool> readyToProcess_ = false;
 };

 } // namespace models::speech_to_text
```

0 commit comments