Skip to content

Commit 2ee6d1d

Browse files
committed
Remove special tokens
1 parent f42351b commit 2ee6d1d

File tree

2 files changed: 16 additions and 21 deletions.
  • packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper

2 files changed: 16 additions and 21 deletions.

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -241,12 +241,6 @@ std::vector<Segment> ASR::generate(std::span<float> waveform,
241241
}
242242
}
243243

244-
rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
245-
"[ASR] Raw transcription results (tokens): ", bestTokens);
246-
rnexecutorch::log(rnexecutorch::LOG_LEVEL::Info,
247-
"[ASR] Raw transcription results (text): ",
248-
tokenizer_->decode(bestTokens, true));
249-
250244
return this->calculateWordLevelTimestamps(bestTokens, waveform,
251245
bestAvgLogProb, bestTemperature,
252246
bestCompressionRatio);
@@ -344,7 +338,6 @@ std::vector<Segment> ASR::calculateWordLevelTimestamps(
344338
if (words.size()) {
345339
Segment seg;
346340
seg.words = std::move(words);
347-
// seg.tokens = {}; // WTF ?
348341
seg.tokens = tokens;
349342
seg.avgLogprob = avgLogProb;
350343
seg.temperature = temperature;
@@ -369,17 +362,19 @@ std::vector<Segment> ASR::calculateWordLevelTimestamps(
369362
const uint64_t end = generatedTokens[generatedTokensSize - 2];
370363
auto words = this->estimateWordLevelTimestampsLinear(tokens, start, end);
371364

365+
if (words.empty()) {
366+
return {};
367+
}
368+
372369
Segment seg;
373370
seg.words = std::move(words);
374371
seg.tokens = tokens;
375372
seg.avgLogprob = avgLogProb;
376373
seg.temperature = temperature;
377374
seg.compressionRatio = compressionRatio;
378375

379-
if (!seg.words.empty()) {
380-
seg.start = seg.words.front().start;
381-
seg.end = seg.words.back().end;
382-
}
376+
seg.start = seg.words.front().start;
377+
seg.end = seg.words.back().end;
383378

384379
segments.push_back(std::move(seg));
385380

@@ -409,8 +404,12 @@ ASR::estimateWordLevelTimestampsLinear(std::span<const uint64_t> tokens,
409404
std::vector<std::string> wordsStr;
410405
std::string word;
411406
while (iss >> word) {
412-
wordsStr.emplace_back(" ");
413-
wordsStr.back().append(word);
407+
// Detect special tokens such as [BLANK_AUDIO] by searching for square
408+
// bracket
409+
if (word.find('[') == std::string::npos) {
410+
wordsStr.emplace_back(" ");
411+
wordsStr.back().append(word);
412+
}
414413
}
415414

416415
size_t numChars = 0;

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,9 @@ ProcessResult OnlineASR::process(const DecodingOptions &options) {
9292
// (assuming some fixed words per second frequency).
9393
const float freshDuration = newEnd - establishedEnd;
9494
const float epsilon = std::max(
95-
0.F, 0.8F * (freshDuration -
96-
static_cast<float>(noNewWords /
97-
params::kStreamWordsPerSecond)));
98-
const float beforeScaleStart = hypothesisBuffer_.fresh_[i].start;
99-
const float beforeScaleEnd = hypothesisBuffer_.fresh_[i].end;
95+
0.F, 0.85F * (freshDuration -
96+
static_cast<float>(noNewWords /
97+
params::kStreamWordsPerSecond)));
10098
float scale = (freshDuration - epsilon) / (newEnd - newBegin);
10199
hypothesisBuffer_.fresh_[i].start =
102100
(hypothesisBuffer_.fresh_[i].start - newEnd) * scale + newEnd;
@@ -134,9 +132,7 @@ ProcessResult OnlineASR::process(const DecodingOptions &options) {
134132
std::vector<Word> OnlineASR::finish() {
135133
// We always push the last remaining hypothesis, even if it's not
136134
// confirmed in second iteration.
137-
auto remaining = hypothesisBuffer_.hypothesis_;
138-
139-
reset();
135+
std::deque<Word> remaining = hypothesisBuffer_.hypothesis_;
140136

141137
return move_to_vector(remaining);
142138
}

0 commit comments

Comments
 (0)