Skip to content

Commit 081aea0

Browse files
committed
Apply review suggestions
1 parent 9baccc5 commit 081aea0

File tree

10 files changed

+124
-77
lines changed

10 files changed

+124
-77
lines changed

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,21 @@ SpeechToText::SpeechToText(const std::string &modelName,
2727
}
2828
}
2929

30+
SpeechToText::SpeechToText(SpeechToText &&other) noexcept
31+
: callInvoker_(std::move(other.callInvoker_)),
32+
transcriber_(std::move(other.transcriber_)),
33+
streamer_(std::move(other.streamer_)),
34+
isStreaming_(other.isStreaming_.load()),
35+
readyToProcess_(other.readyToProcess_.load()) {}
36+
3037
void SpeechToText::unload() noexcept { transcriber_->unload(); }
3138

3239
std::shared_ptr<OwningArrayBuffer>
3340
SpeechToText::encode(std::span<float> waveform) const {
3441
executorch::aten::Tensor encoderOutputTensor = transcriber_->encode(waveform);
3542

3643
return std::make_shared<OwningArrayBuffer>(
37-
encoderOutputTensor.const_data_ptr(),
44+
encoderOutputTensor.const_data_ptr<float>(),
3845
sizeof(float) * encoderOutputTensor.numel());
3946
}
4047

@@ -45,7 +52,7 @@ SpeechToText::decode(std::span<uint64_t> tokens,
4552
transcriber_->decode(tokens, encoderOutput);
4653

4754
return std::make_shared<OwningArrayBuffer>(
48-
decoderOutputTensor.const_data_ptr(),
55+
decoderOutputTensor.const_data_ptr<float>(),
4956
sizeof(float) * decoderOutputTensor.numel());
5057
}
5158

@@ -137,12 +144,12 @@ void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
137144
if (readyToProcess_ && streamer_->isReady()) {
138145
ProcessResult res = streamer_->process(options);
139146

140-
TranscriptionResult cRes =
147+
TranscriptionResult committedRes =
141148
wordsToResult(res.committed, languageOption, verbose);
142-
TranscriptionResult ncRes =
149+
TranscriptionResult nonCommittedRes =
143150
wordsToResult(res.nonCommitted, languageOption, verbose);
144151

145-
nativeCallback(cRes, ncRes, false);
152+
nativeCallback(committedRes, nonCommittedRes, false);
146153
readyToProcess_ = false;
147154
}
148155

@@ -151,7 +158,7 @@ void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
151158
// running transcriptions too rapidly (before the audio buffer is filled
152159
// with significant amount of new data) can cause streamer to commit wrong
153160
// phrases.
154-
std::this_thread::sleep_for(std::chrono::milliseconds(75));
161+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
155162
}
156163

157164
std::vector<Word> finalWords = streamer_->finish();

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <atomic>
34
#include <span>
45
#include <string>
56
#include <vector>
@@ -14,10 +15,12 @@ namespace models::speech_to_text {
1415

1516
class SpeechToText {
1617
public:
17-
explicit SpeechToText(const std::string &modelName,
18-
const std::string &modelSource,
19-
const std::string &tokenizerSource,
20-
std::shared_ptr<react::CallInvoker> callInvoker);
18+
SpeechToText(const std::string &modelName, const std::string &modelSource,
19+
const std::string &tokenizerSource,
20+
std::shared_ptr<react::CallInvoker> callInvoker);
21+
22+
// Required because of std::atomic usage
23+
SpeechToText(SpeechToText &&other) noexcept;
2124

2225
void unload() noexcept;
2326
[[nodiscard(
@@ -53,8 +56,8 @@ class SpeechToText {
5356

5457
// Online ASR-like module (streaming only)
5558
std::unique_ptr<schema::OnlineASR> streamer_ = nullptr;
56-
bool isStreaming_ = false;
57-
bool readyToProcess_ = false;
59+
std::atomic<bool> isStreaming_ = false;
60+
std::atomic<bool> readyToProcess_ = false;
5861
};
5962

6063
} // namespace models::speech_to_text

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Word.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ struct Word {
99
float start;
1010
float end;
1111

12-
std::string punctations =
13-
""; // Trailing punctations which appear after the main content
12+
std::string
13+
punctations; // Trailing punctuation which appears after the main content
1414
};
1515

1616
} // namespace rnexecutorch::models::speech_to_text

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ class ASR : public models::BaseModel, public schema::ASR {
161161
std::unique_ptr<TokenizerModule> tokenizer_;
162162

163163
// Tokenization helper definitions
164-
const Token startOfTranscriptionToken_;
165-
const Token endOfTranscriptionToken_;
166-
const Token timestampBeginToken_;
164+
Token startOfTranscriptionToken_;
165+
Token endOfTranscriptionToken_;
166+
Token timestampBeginToken_;
167167
};
168168

169169
} // namespace rnexecutorch::models::speech_to_text::whisper

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/HypothesisBuffer.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "HypothesisBuffer.h"
22
#include "Params.h"
33
#include "Utils.h"
4+
5+
#include <algorithm>
46
#include <cmath>
57

68
namespace rnexecutorch::models::speech_to_text::whisper::stream {
@@ -17,7 +19,9 @@ void HypothesisBuffer::insert(std::span<const Word> words, float offset) {
1719
size_t firstFreshWordIdx = 0;
1820
if (!committed_.empty()) {
1921
std::optional<size_t> lastMatchingWordIdx =
20-
findCommittedSuffix(words, 5, 6.F, 5);
22+
findCommittedSuffix(words, params::kStreamCommitedSuffixSearchSize,
23+
params::kStreamMaxOverlapTimestampDiff1,
24+
params::kStreamWordsPerErrorRate);
2125
firstFreshWordIdx = lastMatchingWordIdx.value_or(0);
2226
}
2327

@@ -48,7 +52,7 @@ void HypothesisBuffer::insert(std::span<const Word> words, float offset) {
4852
// which were just repeated after some time.
4953
size_t overlapSize = utils::findLargestOverlapingFragment(
5054
committed_, fresh_, params::kStreamMaxOverlapSize,
51-
params::kStreamMaxOverlapTimestampDiff);
55+
params::kStreamMaxOverlapTimestampDiff2);
5256

5357
if (overlapSize > 0) {
5458
fresh_.erase(fresh_.begin(), fresh_.begin() + overlapSize);
@@ -124,24 +128,24 @@ std::optional<size_t> HypothesisBuffer::findCommittedSuffix(
124128

125129
// Iterate backwards through 'words' to find the most recent occurrence of a
126130
// suffix of 'committed_' (or the full 'committed_' sequence).
127-
for (int i = static_cast<int>(words.size()) - 1; i >= 0; --i) {
131+
for (int32_t i = static_cast<int32_t>(words.size()) - 1; i >= 0; --i) {
128132
bool match = true;
129133
size_t matchedCount = 0;
130134
size_t contentMistakeCount = 0;
131135

132136
// Linearly interpolate tolerance if we are at the beginning and can't check
133137
// all committed words.
134138
float effectiveTolerance = timestampDiffTolerance;
135-
if (i < static_cast<int>(committedToMatchSize) - 1) {
139+
if (i < static_cast<int32_t>(committedToMatchSize) - 1) {
136140
effectiveTolerance *=
137141
static_cast<float>(i + 1) / static_cast<float>(committedToMatchSize);
138142
}
139143

140144
// Try to match backwards from words[i] and committed_.back()
141145
for (size_t j = 0; j < committedToMatchSize; ++j) {
142-
int wordsIdx = i - static_cast<int>(j);
143-
int committedIdx =
144-
static_cast<int>(committed_.size()) - 1 - static_cast<int>(j);
146+
int32_t wordsIdx = i - static_cast<int32_t>(j);
147+
int32_t committedIdx =
148+
static_cast<int32_t>(committed_.size()) - 1 - static_cast<int32_t>(j);
145149

146150
if (wordsIdx < 0) {
147151
// We reached the beginning of the words span.
@@ -153,8 +157,8 @@ std::optional<size_t> HypothesisBuffer::findCommittedSuffix(
153157
const Word &w2 = committed_[committedIdx];
154158

155159
// Check timestamps within tolerance
156-
if (std::abs(w1.start - w2.start) > effectiveTolerance ||
157-
std::abs(w1.end - w2.end) > effectiveTolerance) {
160+
if (std::max(std::abs(w1.start - w2.start), std::abs(w1.end - w2.end)) >
161+
effectiveTolerance) {
158162
match = false;
159163
break;
160164
}

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ OnlineASR::OnlineASR(const ASR *asr) : asr_(asr) {
2424
}
2525

2626
void OnlineASR::insertAudioChunk(std::span<const float> audio) {
27-
std::lock_guard<std::mutex> lock(audioBufferMutex_);
27+
std::scoped_lock<std::mutex> lock(audioBufferMutex_);
2828
audioBuffer_.insert(audioBuffer_.end(), audio.begin(), audio.end());
2929
}
3030

@@ -33,10 +33,16 @@ bool OnlineASR::isReady() const {
3333
}
3434

3535
ProcessResult OnlineASR::process(const DecodingOptions &options) {
36-
std::unique_lock<std::mutex> lock(audioBufferMutex_);
36+
std::vector<float> audioCopy;
37+
38+
// Copy the audio buffer to avoid keeping the lock during the entire
39+
// transcription process.
40+
{
41+
std::scoped_lock<std::mutex> lock(audioBufferMutex_);
42+
audioCopy = audioBuffer_;
43+
}
3744

3845
std::vector<Segment> transcriptions = asr_->transcribe(audioBuffer_, options);
39-
lock.unlock();
4046

4147
if (transcriptions.empty()) {
4248
return {.committed = {}, .nonCommitted = {}};
@@ -106,30 +112,32 @@ ProcessResult OnlineASR::process(const DecodingOptions &options) {
106112

107113
// Since Whisper does not accept waveforms longer than 30 seconds, we need
108114
// to cut the audio at some safe point.
109-
lock.lock();
110-
const float audioDuration =
111-
static_cast<float>(audioBuffer_.size()) / constants::kSamplingRate;
112-
if (audioDuration > params::kStreamChunkThreshold) {
113-
// Leave some portion of audio in, to improve model behavior
114-
// in future iterations.
115-
const float erasePoint =
116-
hypothesisBuffer_.lastCommittedTime_ == lastSentenceEnd_
117-
? audioDuration
118-
: std::min(lastSentenceEnd_, params::kStreamChunkThreshold);
119-
const float minEraseDuration =
120-
audioDuration - params::kStreamAudioBufferMaxReserve;
121-
const float maxEraseDuration =
122-
audioDuration - params::kStreamAudioBufferMinReserve;
123-
const float eraseDuration = std::clamp(erasePoint - bufferTimeOffset_,
124-
minEraseDuration, maxEraseDuration);
125-
const size_t nSamplesToErase =
126-
static_cast<size_t>(eraseDuration * constants::kSamplingRate);
127-
128-
audioBuffer_.erase(audioBuffer_.begin(),
129-
audioBuffer_.begin() + nSamplesToErase);
130-
bufferTimeOffset_ += eraseDuration;
115+
{
116+
std::scoped_lock<std::mutex> lock(audioBufferMutex_);
117+
118+
const float audioDuration =
119+
static_cast<float>(audioBuffer_.size()) / constants::kSamplingRate;
120+
if (audioDuration > params::kStreamChunkThreshold) {
121+
// Leave some portion of audio in, to improve model behavior
122+
// in future iterations.
123+
const float erasePoint =
124+
hypothesisBuffer_.lastCommittedTime_ == lastSentenceEnd_
125+
? audioDuration
126+
: std::min(lastSentenceEnd_, params::kStreamChunkThreshold);
127+
const float minEraseDuration =
128+
audioDuration - params::kStreamAudioBufferMaxReserve;
129+
const float maxEraseDuration =
130+
audioDuration - params::kStreamAudioBufferMinReserve;
131+
const float eraseDuration = std::clamp(
132+
erasePoint - bufferTimeOffset_, minEraseDuration, maxEraseDuration);
133+
const size_t nSamplesToErase =
134+
static_cast<size_t>(eraseDuration * constants::kSamplingRate);
135+
136+
audioBuffer_.erase(audioBuffer_.begin(),
137+
audioBuffer_.begin() + nSamplesToErase);
138+
bufferTimeOffset_ += eraseDuration;
139+
}
131140
}
132-
lock.unlock();
133141

134142
return {.committed = move_to_vector(committed),
135143
.nonCommitted = move_to_vector(nonCommitted)};
@@ -144,7 +152,7 @@ std::vector<Word> OnlineASR::finish() {
144152
}
145153

146154
void OnlineASR::reset() {
147-
std::lock_guard<std::mutex> lock(audioBufferMutex_);
155+
std::scoped_lock<std::mutex> lock(audioBufferMutex_);
148156

149157
hypothesisBuffer_.reset();
150158
bufferTimeOffset_ = 0.f;

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class OnlineASR : public schema::OnlineASR {
6666
// Stores the increasing amounts of streamed audio.
6767
// Cleared from time to time after reaching a threshold size.
6868
std::vector<float> audioBuffer_ = {};
69-
std::mutex audioBufferMutex_;
69+
mutable std::mutex audioBufferMutex_;
7070
float bufferTimeOffset_ = 0.F; // Audio buffer offset
7171

7272
// Helper buffers - hypothesis buffer

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,17 @@ constexpr static int32_t kChunkBreakBuffer = 2; // [s]
2525
* Determines the maximum timestamp difference available for a word to be
2626
* considered as fresh in streaming algorithm.
2727
*/
28-
constexpr static float kStreamFreshThreshold = 2.F; // [s], originally 0.5
28+
constexpr static float kStreamFreshThreshold = 3.F; // [s], originally 0.5
29+
30+
/**
31+
* The size of the most recent committed suffix searched in
32+
* fresh words string.
33+
*
34+
* For example, if the committed buffer contains ["I", "did", "a", "very", "nasty",
35+
* "thing."], and kStreamCommitedSuffixSearchSize = 3, then we search for
36+
* ["very", "nasty", "thing."] suffix.
37+
*/
38+
constexpr static size_t kStreamCommitedSuffixSearchSize = 5;
2939

3040
/**
3141
* Determines the maximum expected size of overlapping fragments between
@@ -40,8 +50,28 @@ constexpr static size_t kStreamMaxOverlapSize =
4050
/**
4151
* Similar to kMaxStreamOverlapSize, but this one determines
4252
* the maximum allowed timestamp difference between the overlapping fragments.
53+
*
54+
* It's the first, more strict threshold, used when searching for recently
55+
* committed entries.
56+
*/
57+
constexpr static float kStreamMaxOverlapTimestampDiff1 = 6.F; // [s]
58+
59+
/**
60+
* Similar to kMaxStreamOverlapSize, but this one determines
61+
* the maximum allowed timestamp difference between the overlapping fragments.
62+
*
63+
* It's the second, more liberal threshold, used in overlap correction
64+
* algorithm.
65+
*/
66+
constexpr static float kStreamMaxOverlapTimestampDiff2 = 15.F; // [s]
67+
68+
/**
69+
* Number of words per 1 allowed mistake (error correction).
70+
*
71+
* For example, if kStreamWordsPerErrorRate = 4, then we allow maximum 1 mistake
72+
* in a 4 word string.
4373
*/
44-
constexpr static float kStreamMaxOverlapTimestampDiff = 15.F; // [s]
74+
constexpr static size_t kStreamWordsPerErrorRate = 5;
4575

4676
/**
4777
* A threshold which, when exceeded, causes the main streaming audio buffer to be

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Utils.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ namespace rnexecutorch::models::speech_to_text::whisper::utils {
1010

1111
// Compares two strings without case-sensitivity.
1212
inline bool equalsIgnoreCase(const std::string &a, const std::string &b) {
13-
if (a.size() != b.size())
13+
if (a.size() != b.size()) {
1414
return false;
15+
}
1516
return std::equal(a.begin(), a.end(), b.begin(), [](char c1, char c2) {
1617
return std::tolower(static_cast<unsigned char>(c1)) ==
1718
std::tolower(static_cast<unsigned char>(c2));
@@ -55,13 +56,14 @@ inline size_t findLargestOverlapingFragment(const Container &suffixVec,
5556
if (equalsIgnoreCase(suffixVec[i].content, prefixVec[0].content)) {
5657
size_t calculatedSize = suffixVec.size() - i;
5758

58-
bool isEqual = std::equal(
59-
suffixVec.begin() + i, suffixVec.end(), prefixVec.begin(),
60-
[maxTimestampDiff](const Word &sWord, const Word &pWord) {
61-
return equalsIgnoreCase(sWord.content, pWord.content) &&
62-
std::fabs(sWord.start - pWord.start) <= maxTimestampDiff &&
63-
std::fabs(sWord.end - pWord.end) <= maxTimestampDiff;
64-
});
59+
bool isEqual =
60+
std::equal(suffixVec.begin() + i, suffixVec.end(), prefixVec.begin(),
61+
[maxTimestampDiff](const Word &sWord, const Word &pWord) {
62+
return equalsIgnoreCase(sWord.content, pWord.content) &&
63+
std::max(std::fabs(sWord.start - pWord.start),
64+
std::fabs(sWord.end - pWord.end)) <=
65+
maxTimestampDiff;
66+
});
6567

6668
if (isEqual) {
6769
return calculatedSize;

0 commit comments

Comments
 (0)