Skip to content

Commit bab2ffb

Browse files
committed
Add whisper kv-cache & fix demo app permissions
1 parent 1b7363d commit bab2ffb

File tree

4 files changed

+31
-13
lines changed

4 files changed

+31
-13
lines changed

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ SpeechToText::encode(std::span<float> waveform) const {
4040
std::shared_ptr<OwningArrayBuffer>
4141
SpeechToText::decode(std::span<uint64_t> tokens,
4242
std::span<float> encoderOutput) const {
43-
std::vector<float> decoderOutput = this->asr->decode(tokens, encoderOutput);
43+
std::vector<float> decoderOutput =
44+
this->asr->decode(tokens, 0, encoderOutput);
4445
return std::make_shared<OwningArrayBuffer>(decoderOutput);
4546
}
4647

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <numeric>
12
#include <random>
23
#include <sstream>
34

@@ -42,11 +43,15 @@ GenerationResult ASR::generate(std::span<float> waveform, float temperature,
4243
std::vector<float> encoderOutput = this->encode(waveform);
4344

4445
std::vector<uint64_t> sequenceIds = this->getInitialSequence(options);
46+
std::vector<uint64_t> cachedTokens = sequenceIds;
4547
const size_t initialSequenceLenght = sequenceIds.size();
4648
std::vector<float> scores;
4749

48-
while (std::cmp_less_equal(sequenceIds.size(), ASR::kMaxDecodeLength)) {
49-
std::vector<float> logits = this->decode(sequenceIds, encoderOutput);
50+
uint64_t startPos = 0;
51+
while (std::cmp_less_equal(startPos + sequenceIds.size(),
52+
ASR::kMaxDecodeLength)) {
53+
std::vector<float> logits =
54+
this->decode(sequenceIds, startPos, encoderOutput);
5055

5156
// intentionally comparing float to float
5257
// temperatures are predefined, so this is safe
@@ -74,16 +79,20 @@ GenerationResult ASR::generate(std::span<float> waveform, float temperature,
7479
nextProb = probs[nextId];
7580
}
7681

77-
sequenceIds.push_back(nextId);
82+
// Move the startPos pointer by the amount of tokens we processed
83+
startPos += sequenceIds.size();
84+
sequenceIds = {nextId};
85+
cachedTokens.push_back(nextId);
7886
scores.push_back(nextProb);
7987

8088
if (nextId == this->endOfTranscriptionToken) {
8189
break;
8290
}
8391
}
8492

85-
return {.tokens = std::vector<uint64_t>(
86-
sequenceIds.cbegin() + initialSequenceLenght, sequenceIds.cend()),
93+
return {.tokens = std::vector<uint64_t>(cachedTokens.cbegin() +
94+
initialSequenceLenght,
95+
cachedTokens.cend()),
8796
.scores = scores};
8897
}
8998

@@ -318,13 +327,19 @@ std::vector<float> ASR::encode(std::span<float> waveform) const {
318327
return {dataPtr, dataPtr + outputNumel};
319328
}
320329

321-
std::vector<float> ASR::decode(std::span<const uint64_t> tokens,
330+
std::vector<float> ASR::decode(std::span<uint64_t> tokens, uint64_t startPos,
322331
std::span<float> encoderOutput) const {
323332
std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
324-
auto tokensLong = std::vector<int64_t>(tokens.begin(), tokens.end());
333+
std::vector<int32_t> positionShape = {static_cast<int32_t>(tokens.size())};
325334

326335
auto tokenTensor = executorch::extension::make_tensor_ptr(
327-
tokenShape, tokensLong.data(), ScalarType::Long);
336+
tokenShape, tokens.data(), ScalarType::Long);
337+
338+
// Populate cache position vector
339+
std::vector<uint64_t> cachePositions(tokens.size());
340+
std::iota(cachePositions.begin(), cachePositions.end(), startPos);
341+
auto positionTensor = executorch::extension::make_tensor_ptr(
342+
positionShape, cachePositions.data(), ScalarType::Long);
328343

329344
const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
330345
std::vector<int32_t> encShape = {1, ASR::kNumFrames,
@@ -333,7 +348,7 @@ std::vector<float> ASR::decode(std::span<const uint64_t> tokens,
333348
std::move(encShape), encoderOutput.data(), ScalarType::Float);
334349

335350
const auto decoderResult =
336-
this->decoder->forward({tokenTensor, encoderTensor});
351+
this->decoder->forward({tokenTensor, positionTensor, encoderTensor});
337352

338353
if (!decoderResult.ok()) {
339354
throw RnExecutorchError(decoderResult.error(),

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class ASR {
1717
transcribe(std::span<float> waveform,
1818
const types::DecodingOptions &options) const;
1919
std::vector<float> encode(std::span<float> waveform) const;
20-
std::vector<float> decode(std::span<const uint64_t> tokens,
20+
std::vector<float> decode(std::span<uint64_t> tokens, uint64_t startPos,
2121
std::span<float> encoderOutput) const;
2222

2323
private:

packages/react-native-executorch/src/constants/modelUrls.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,10 @@ export const STYLE_TRANSFER_UDNIE = {
418418

419419
// S2T
420420
const WHISPER_TINY_EN_TOKENIZER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/tokenizer.json`;
421-
const WHISPER_TINY_EN_ENCODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`;
422-
const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`;
421+
// const WHISPER_TINY_EN_ENCODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`;
422+
// const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`;
423+
const WHISPER_TINY_EN_ENCODER = `${URL_PREFIX}-whisper-tiny.en/resolve/kv-cache/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`;
424+
const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/resolve/kv-cache/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`;
423425

424426
const WHISPER_TINY_EN_ENCODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_encoder_xnnpack.pte`;
425427
const WHISPER_TINY_EN_DECODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_decoder_xnnpack.pte`;

0 commit comments

Comments
 (0)