
Commit ce5a39a

Various STT streaming fixes

Parent: b54e469

6 files changed: 304 additions & 39 deletions

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp

Lines changed: 30 additions & 8 deletions
@@ -4,11 +4,14 @@
 
 #include "ASR.h"
 #include "Constants.h"
+#include "Params.h"
 #include <executorch/extension/tensor/tensor_ptr.h>
 #include <rnexecutorch/Error.h>
 #include <rnexecutorch/data_processing/Numerical.h>
 #include <rnexecutorch/data_processing/gzip.h>
 
+#include <rnexecutorch/Log.h>
+
 namespace rnexecutorch::models::speech_to_text::whisper {
 
 using executorch::runtime::etensor::ScalarType;
@@ -30,18 +33,24 @@ ASR::ASR(const std::string &modelSource, const std::string &tokenizerSource,
 */
 std::vector<Segment> ASR::transcribe(std::span<float> waveform,
                                      const DecodingOptions &options) const {
-  int32_t seek = 0;
+  // Use floats to prevent downcasting and timestamp mismatches
+  float seek = 0.f;
   std::vector<Segment> results;
 
+  const float waveformSize = static_cast<float>(waveform.size());
+  const float waveformSkipBoundary =
+      static_cast<float>((constants::kChunkSize - params::kChunkBreakBuffer) *
+                         constants::kSamplingRate);
+
   // We loop through the input audio waveform and process it in 30s chunks.
   // This is determined by Whisper models' strict 30s audio length requirement.
-  while (std::cmp_less(seek * constants::kSamplingRate, waveform.size())) {
+  while (seek * constants::kSamplingRate < waveformSize) {
     // Calculate chunk bounds and extract the chunk.
-    int32_t start = seek * constants::kSamplingRate;
+    float start = seek * constants::kSamplingRate;
     const auto end =
-        std::min<int32_t>(static_cast<int32_t>((seek + constants::kChunkSize) *
-                                               constants::kSamplingRate),
-                          static_cast<int32_t>(waveform.size()));
+        std::min<float>(static_cast<float>((seek + constants::kChunkSize) *
+                                           constants::kSamplingRate),
+                        waveformSize);
     auto chunk = waveform.subspan(start, end - start);
 
     if (std::cmp_less(chunk.size(), constants::kMinChunkSamples)) {
@@ -71,7 +80,12 @@ std::vector<Segment> ASR::transcribe(std::span<float> waveform,
     }
 
     if (!segments.empty() && !segments.back().words.empty()) {
-      seek = static_cast<int32_t>(segments.back().words.back().end);
+      // This prevents additional segments from appearing unless the audio
+      // length is very close to the max chunk size, i.e. some words could be
+      // spoken near the breakpoint.
+      seek = waveformSize < waveformSkipBoundary
+                 ? seek + constants::kChunkSize
+                 : segments.back().words.back().end;
     }
     results.insert(results.end(), std::make_move_iterator(segments.begin()),
                    std::make_move_iterator(segments.end()));
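The new seek policy above is the core of the transcribe() fix. Here is a minimal, self-contained sketch of the same chunking loop, with placeholder constants standing in for constants::kChunkSize, constants::kSamplingRate, and params::kChunkBreakBuffer, and a fixed stand-in for the last decoded word's end time (which the real code takes from the model's word-level timestamps):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Placeholder values; the real ones live in Constants.h / Params.h.
constexpr float kChunkSize = 30.0f;       // seconds per Whisper window
constexpr int32_t kSamplingRate = 16000;  // samples per second
constexpr float kChunkBreakBuffer = 2.0f; // hypothetical value, in seconds

int main() {
  std::vector<float> waveform(75 * kSamplingRate); // 75 s of audio
  const float waveformSize = static_cast<float>(waveform.size());
  const float waveformSkipBoundary =
      (kChunkSize - kChunkBreakBuffer) * kSamplingRate;

  // Keeping seek as float means chunk bounds and word timestamps are
  // computed in the same domain, so no downcast can make them disagree.
  float seek = 0.f;
  while (seek * kSamplingRate < waveformSize) {
    const float start = seek * kSamplingRate;
    const float end =
        std::min((seek + kChunkSize) * kSamplingRate, waveformSize);
    std::printf("chunk: [%.0f, %.0f) samples\n", start, end);

    // Stand-in for segments.back().words.back().end.
    const float lastWordEnd = seek + kChunkSize - 0.3f;

    // Short inputs jump a whole chunk; long inputs resume at the last
    // decoded word so speech near the 30 s breakpoint is not lost.
    seek = waveformSize < waveformSkipBoundary ? seek + kChunkSize
                                               : lastWordEnd;
  }
}

For a 75 s input this prints three chunks, each resuming slightly before the previous 30 s boundary.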
@@ -226,6 +240,12 @@ std::vector<Segment> ASR::generate(std::span<float> waveform,
     }
   }
 
+  rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
+                    "[ASR] Raw transcription results (tokens): ", bestTokens);
+  rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
+                    "[ASR] Raw transcription results (text): ",
+                    tokenizer_->decode(bestTokens, true));
+
   return this->calculateWordLevelTimestamps(bestTokens, waveform,
                                             bestAvgLogProb, bestTemperature,
                                             bestCompressionRatio);
@@ -323,7 +343,8 @@ std::vector<Segment> ASR::calculateWordLevelTimestamps(
   if (words.size()) {
     Segment seg;
     seg.words = std::move(words);
-    seg.tokens = {};
+    // Keep the segment's tokens instead of clearing them.
+    seg.tokens = tokens;
     seg.avgLogprob = avgLogProb;
     seg.temperature = temperature;
     seg.compressionRatio = compressionRatio;
@@ -382,6 +403,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span<const uint64_t> tokens,
                                        uint64_t start, uint64_t end) const {
   const std::vector<uint64_t> tokensVec(tokens.begin(), tokens.end());
   const std::string segmentText = tokenizer_->decode(tokensVec, true);
+
   std::istringstream iss(segmentText);
   std::vector<std::string> wordsStr;
   std::string word;
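The transcribe() hunks above pull params::kChunkBreakBuffer from a new Params.h that this commit view does not include, and the HypothesisBuffer diff further below references three more values from the same header. Judging only from the call sites, a plausible shape is sketched here; the identifiers are the ones the diff uses, but every value is a guess, not the commit's actual tuning:

#pragma once

#include <cstddef>

namespace rnexecutorch::models::speech_to_text::whisper::params {

// Seconds subtracted from the 30 s window when deciding whether the input
// is long enough to need word-boundary-aware seeking. (Hypothetical value.)
constexpr float kChunkBreakBuffer = 3.0f;

// Words starting earlier than lastCommittedTime_ minus this threshold are
// dropped as stale on insert; replaces the old 0.5f literal.
// (Hypothetical value.)
constexpr float kStreamFreshThreshold = 1.0f;

// Longest committed/fresh overlap searched for, in words; replaces the old
// literal 5. (Hypothetical value.)
constexpr std::size_t kStreamMaxOverlapSize = 5;

// Maximum start-time difference tolerated when matching overlapping words.
// (Hypothetical value.)
constexpr float kStreamMaxOverlapTimestampDiff = 1.0f;

} // namespace rnexecutorch::models::speech_to_text::whisper::params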

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h

Lines changed: 5 additions & 3 deletions
@@ -21,15 +21,17 @@ constexpr static int32_t kNumFrames = 1500;
 
 // Sampling rate expected by Whisper and the model's audio pipeline (16 kHz)
 constexpr static int32_t kSamplingRate = 16000;
+constexpr static int32_t kSamplesPerMilisecond = kSamplingRate / 1000;
 
 // Time precision used by Whisper timestamps: each token spans 0.02 seconds
 constexpr static float kTimePrecision = 0.02f;
 
 // Special token constants
 namespace tokens {
-inline const std::string kStartOfTranscript = "<|startoftranscript|>";
-inline const std::string kEndOfTranscript = "<|endoftext|>";
-inline const std::string kBeginTimestamp = "<|0.00|>";
+static const std::string kStartOfTranscript = "<|startoftranscript|>";
+static const std::string kEndOfTranscript = "<|endoftext|>";
+static const std::string kBeginTimestamp = "<|0.00|>";
+static const std::string kBlankAudio = "[BLANK_AUDIO]";
 } // namespace tokens
 
 } // namespace rnexecutorch::models::speech_to_text::whisper::constants
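The new kSamplesPerMilisecond constant (16 at the 16 kHz rate; the identifier keeps the commit's spelling) is a small integer bridge between millisecond timestamps and sample offsets. A quick arithmetic check:

#include <cstdint>
#include <cstdio>

constexpr int32_t kSamplingRate = 16000;
constexpr int32_t kSamplesPerMilisecond = kSamplingRate / 1000; // = 16

int main() {
  const int32_t ms = 250; // a 250 ms span, for example
  std::printf("%d ms -> %d samples\n", ms, ms * kSamplesPerMilisecond);
  // prints: 250 ms -> 4000 samples
}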
HypothesisBuffer.cpp

Lines changed: 58 additions & 26 deletions
@@ -1,45 +1,56 @@
 #include "HypothesisBuffer.h"
+#include "Params.h"
+#include "Utils.h"
+#include <cmath>
+#include <rnexecutorch/Log.h>
 
 namespace rnexecutorch::models::speech_to_text::whisper::stream {
 
 void HypothesisBuffer::insert(std::span<const Word> newWords, float offset) {
+  rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
+                    "[HypothesisBuffer] Inserting " +
+                        std::to_string(newWords.size()) +
+                        " words with offset " + std::to_string(offset) + "s.");
+
   fresh_.clear();
   for (const auto &word : newWords) {
     const float newStart = word.start + offset;
-    if (newStart > lastCommittedTime_ - 0.5f) {
+    // Only accept words that start after or near the last committed time to
+    // avoid stale data
+    if (newStart > lastCommittedTime_ - params::kStreamFreshThreshold) {
       fresh_.emplace_back(word.content, newStart, word.end + offset);
     }
   }
+  rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
+                    "[HypothesisBuffer] Filtered " +
+                        std::to_string(fresh_.size()) +
+                        " words into 'fresh' buffer.");
 
   if (!fresh_.empty() && !committedInBuffer_.empty()) {
     const float a = fresh_.front().start;
-    if (std::fabs(a - lastCommittedTime_) < 1.0f) {
+    // Check for overlap with already committed history to avoid duplicates in
+    // the stream
+    if (std::fabs(a - lastCommittedTime_) < 2.0f) {
       const size_t cn = committedInBuffer_.size();
       const size_t nn = fresh_.size();
-      const std::size_t maxCheck = std::min<std::size_t>({cn, nn, 5});
-      for (size_t i = 1; i <= maxCheck; i++) {
-        std::string c;
-        for (auto it = committedInBuffer_.cend() - i;
-             it != committedInBuffer_.cend(); ++it) {
-          if (!c.empty()) {
-            c += ' ';
-          }
-          c += it->content;
-        }
-
-        std::string tail;
-        auto it = fresh_.cbegin();
-        for (size_t k = 0; k < i; k++, it++) {
-          if (!tail.empty()) {
-            tail += ' ';
-          }
-          tail += it->content;
-        }
-
-        if (c == tail) {
-          fresh_.erase(fresh_.begin(), fresh_.begin() + i);
-          break;
-        }
+
+      rnexecutorch::log(
+          rnexecutorch::LOG_LEVEL::Info,
+          "[HypothesisBuffer] Checking for overlap. cn=" + std::to_string(cn) +
+              ", nn=" + std::to_string(nn) +
+              ", maxCheck=" + std::to_string(params::kStreamMaxOverlapSize));
+
+      size_t overlapSize = utils::findLargestOverlapingFragment(
+          committedInBuffer_, fresh_, params::kStreamMaxOverlapSize,
+          params::kStreamMaxOverlapTimestampDiff);
+
+      if (overlapSize > 0) {
+        rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
+                          "[HypothesisBuffer] Detected overlap of " +
+                              std::to_string(overlapSize) +
+                              " words with committed history. Erasing "
+                              "duplicates from 'fresh'.");
+        fresh_.erase(fresh_.begin(), fresh_.begin() + overlapSize);
       }
     }
   }
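The hand-rolled string-concatenation loop is replaced by utils::findLargestOverlapingFragment from a Utils.h that is not shown in this commit view. A sketch of an implementation consistent with the call site above follows; the Word fields mirror the ones used throughout the diff, and the per-word timestamp check is an assumption (the removed code compared text only):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <deque>
#include <string>

struct Word {
  std::string content;
  float start;
  float end;
};

// Largest n (up to maxSize) such that the last n committed words match the
// first n fresh words, comparing content exactly and start times to within
// maxTimestampDiff seconds. Returns 0 when nothing lines up.
std::size_t findLargestOverlapingFragment(const std::deque<Word> &committed,
                                          const std::deque<Word> &fresh,
                                          std::size_t maxSize,
                                          float maxTimestampDiff) {
  const std::size_t limit = std::min({committed.size(), fresh.size(), maxSize});
  for (std::size_t n = limit; n > 0; --n) {
    bool match = true;
    for (std::size_t k = 0; k < n; ++k) {
      const Word &c = committed[committed.size() - n + k];
      const Word &f = fresh[k];
      if (c.content != f.content ||
          std::fabs(c.start - f.start) > maxTimestampDiff) {
        match = false;
        break;
      }
    }
    if (match) {
      return n;
    }
  }
  return 0;
}

int main() {
  std::deque<Word> committed{{"the", 0.0f, 0.2f}, {"quick", 0.2f, 0.5f}};
  std::deque<Word> fresh{{"quick", 0.25f, 0.5f}, {"brown", 0.5f, 0.8f}};
  std::printf("overlap = %zu\n",
              findLargestOverlapingFragment(committed, fresh, 5, 1.0f));
  // prints: overlap = 1 ("quick" repeats at the window seam)
}

Note the sketch searches largest-first, matching the function's name; the removed loop stopped at the first (smallest) match.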
@@ -48,6 +59,8 @@ void HypothesisBuffer::insert(std::span<const Word> newWords, float offset) {
 std::deque<Word> HypothesisBuffer::flush() {
   std::deque<Word> commit;
 
+  // Find stable prefix: words that haven't changed between last and current
+  // iteration
   while (!fresh_.empty() && !buffer_.empty()) {
     if (fresh_.front().content != buffer_.front().content) {
       break;
@@ -59,19 +72,36 @@ std::deque<Word> HypothesisBuffer::flush() {
 
   if (!commit.empty()) {
     lastCommittedTime_ = commit.back().end;
+    rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
+                      "[HypothesisBuffer] Found stable prefix. Committing " +
+                          std::to_string(commit.size()) +
+                          " words. New lastCommittedTime: " +
+                          std::to_string(lastCommittedTime_) + "s.");
   }
 
+  // Current 'fresh' (remaining) becomes the new 'buffer' for next iteration
+  // comparison
   buffer_ = std::move(fresh_);
   fresh_.clear();
+
   committedInBuffer_.insert(committedInBuffer_.end(), commit.begin(),
                             commit.end());
+
   return commit;
 }
 
 void HypothesisBuffer::popCommitted(float time) {
+  size_t count = 0;
   while (!committedInBuffer_.empty() &&
          committedInBuffer_.front().end <= time) {
     committedInBuffer_.pop_front();
+    count++;
+  }
+  if (count > 0) {
+    rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
+                      "[HypothesisBuffer] Popped " + std::to_string(count) +
+                          " old words from committed history up to " +
+                          std::to_string(time) + "s.");
   }
 }
 
@@ -81,6 +111,8 @@ void HypothesisBuffer::reset() {
   buffer_.clear();
   fresh_.clear();
   committedInBuffer_.clear();
+
+  lastCommittedTime_ = 0.f;
 }
 
 } // namespace rnexecutorch::models::speech_to_text::whisper::stream
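Taken together, insert(), flush(), and popCommitted() implement a local-agreement streaming policy: a word is only committed once two consecutive hypotheses agree on it. A simplified, self-contained demonstration of just the stable-prefix rule in flush() (a stand-in using bare strings, not the class from this commit):

#include <cstdio>
#include <deque>
#include <string>
#include <utility>

std::deque<std::string> buffer; // previous hypothesis, not yet committed
std::deque<std::string> fresh;  // newest hypothesis

std::deque<std::string> flush() {
  std::deque<std::string> commit;
  // Words identical at the front of both hypotheses are considered stable.
  while (!fresh.empty() && !buffer.empty() &&
         fresh.front() == buffer.front()) {
    commit.push_back(fresh.front());
    fresh.pop_front();
    buffer.pop_front();
  }
  buffer = std::move(fresh); // remainder is next iteration's baseline
  fresh.clear();
  return commit;
}

int main() {
  fresh = {"the", "quick", "brown", "fix"};
  flush(); // commits nothing: there is no previous hypothesis to agree with
  fresh = {"the", "quick", "brown", "fox", "jumps"};
  for (const auto &w : flush()) { // commits "the quick brown"
    std::printf("%s ", w.c_str());
  }
  std::printf("\n"); // "fix"/"fox" disagreed, so "fox jumps" stays buffered
}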
