Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions apps/speech/screens/SpeechToTextScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
const [liveTranscribing, setLiveTranscribing] = useState(false);
const scrollViewRef = useRef<ScrollView>(null);

const recorder = new AudioRecorder();
const recorder = useRef(new AudioRecorder());

useEffect(() => {
AudioManager.setAudioSessionOptions({
Expand Down Expand Up @@ -115,7 +115,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {

const sampleRate = 16000;

recorder.onAudioReady(
recorder.current.onAudioReady(
{
sampleRate,
bufferLength: 0.1 * sampleRate,
Expand All @@ -131,7 +131,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
if (!success) {
console.warn('Cannot start audio session correctly');
}
const result = recorder.start();
const result = recorder.current.start();
if (result.status === 'error') {
console.warn('Recording problems: ', result.message);
}
Expand Down Expand Up @@ -177,7 +177,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
const handleStopTranscribeFromMicrophone = () => {
isRecordingRef.current = false;

recorder.stop();
recorder.current.stop();
model.streamStop();
console.log('Live transcription stopped');
setLiveTranscribing(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Since speech-to-text models can only process audio segments up to 30 seconds lon

`useSpeechToText` takes [`SpeechToTextProps`](../../06-api-reference/interfaces/SpeechToTextProps.md) that consists of:

- `model` of type [`SpeechToTextConfig`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md), containing the [`isMultilingual` flag](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#ismultilingual), [tokenizer source](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#tokenizersource), [encoder source](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#encodersource), and [decoder source](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#decodersource).
- `model` of type [`SpeechToTextConfig`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md), containing the [`isMultilingual` flag](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#ismultilingual), [tokenizer source](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#tokenizersource) and [model source](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#modelsource).
- An optional flag [`preventLoad`](../../06-api-reference/interfaces/SpeechToTextProps.md#preventload) which prevents auto-loading of the model.

You need more details? Check the following resources:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ Create an instance of [`SpeechToTextModule`](../../06-api-reference/classes/Spee
- [`model`](../../06-api-reference/classes/SpeechToTextModule.md#model) - Object containing:
- [`isMultilingual`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#ismultilingual) - Flag indicating if model is multilingual.

- [`encoderSource`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#encodersource) - The location of the used encoder.

- [`decoderSource`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#decodersource) - The location of the used decoder.
- [`modelSource`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#modelsource) - The location of the used model (bundled encoder + decoder functionality).

- [`tokenizerSource`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#tokenizersource) - The location of the used tokenizer.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
#include <rnexecutorch/metaprogramming/TypeConcepts.h>
#include <rnexecutorch/models/object_detection/Types.h>
#include <rnexecutorch/models/ocr/Types.h>
#include <rnexecutorch/models/speech_to_text/types/Segment.h>
#include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>
#include <rnexecutorch/models/speech_to_text/common/types/Segment.h>
#include <rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h>
#include <rnexecutorch/models/voice_activity_detection/Types.h>

using namespace rnexecutorch::models::speech_to_text::types;
using namespace rnexecutorch::models::speech_to_text;

namespace rnexecutorch::jsi_conversion {

Expand Down Expand Up @@ -513,7 +513,8 @@ inline jsi::Value getJsiValue(const Segment &seg, jsi::Runtime &runtime) {
jsi::Object wordObj(runtime);
wordObj.setProperty(
runtime, "word",
jsi::String::createFromUtf8(runtime, seg.words[i].content));
jsi::String::createFromUtf8(runtime, seg.words[i].content +
seg.words[i].punctations));
wordObj.setProperty(runtime, "start",
static_cast<double>(seg.words[i].start));
wordObj.setProperty(runtime, "end", static_cast<double>(seg.words[i].end));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
#include <string>
#include <vector>

#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
#include <ReactCommon/CallInvoker.h>
#include <executorch/extension/module/module.h>
#include <jsi/jsi.h>
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
#include <rnexecutorch/metaprogramming/ConstructorHelpers.h>
Comment thread
msluszniak marked this conversation as resolved.

namespace rnexecutorch {
namespace models {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,54 +1,66 @@
#include <thread>

#include "SpeechToText.h"
#include "common/types/TranscriptionResult.h"
#include "whisper/ASR.h"
#include "whisper/OnlineASR.h"
#include <rnexecutorch/Error.h>
#include <rnexecutorch/ErrorCodes.h>
#include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>

namespace rnexecutorch::models::speech_to_text {

using namespace ::executorch::extension;
using namespace asr;
using namespace types;
using namespace stream;

SpeechToText::SpeechToText(const std::string &encoderSource,
const std::string &decoderSource,
SpeechToText::SpeechToText(const std::string &modelName,
const std::string &modelSource,
const std::string &tokenizerSource,
std::shared_ptr<react::CallInvoker> callInvoker)
: callInvoker(std::move(callInvoker)),
encoder(std::make_unique<BaseModel>(encoderSource, this->callInvoker)),
decoder(std::make_unique<BaseModel>(decoderSource, this->callInvoker)),
tokenizer(std::make_unique<TokenizerModule>(tokenizerSource,
this->callInvoker)),
asr(std::make_unique<ASR>(this->encoder.get(), this->decoder.get(),
this->tokenizer.get())),
processor(std::make_unique<OnlineASRProcessor>(this->asr.get())),
isStreaming(false), readyToProcess(false) {}

void SpeechToText::unload() noexcept {
this->encoder->unload();
this->decoder->unload();
: callInvoker_(std::move(callInvoker)) {
// Switch between the ASR implementations based on model name
if (modelName == "whisper") {
Comment thread
IgorSwat marked this conversation as resolved.
transcriber_ = std::make_unique<whisper::ASR>(modelSource, tokenizerSource,
callInvoker_);
streamer_ = std::make_unique<whisper::stream::OnlineASR>(
static_cast<const whisper::ASR *>(transcriber_.get()));
} else {
throw rnexecutorch::RnExecutorchError(
rnexecutorch::RnExecutorchErrorCode::InvalidConfig,
"[SpeechToText]: Invalid model name: " + modelName);
}
}

SpeechToText::SpeechToText(SpeechToText &&other) noexcept
: callInvoker_(std::move(other.callInvoker_)),
transcriber_(std::move(other.transcriber_)),
streamer_(std::move(other.streamer_)),
isStreaming_(other.isStreaming_.load()),
readyToProcess_(other.readyToProcess_.load()) {}

void SpeechToText::unload() noexcept { transcriber_->unload(); }

std::shared_ptr<OwningArrayBuffer>
SpeechToText::encode(std::span<float> waveform) const {
std::vector<float> encoderOutput = this->asr->encode(waveform);
return std::make_shared<OwningArrayBuffer>(encoderOutput);
executorch::aten::Tensor encoderOutputTensor = transcriber_->encode(waveform);

return std::make_shared<OwningArrayBuffer>(
encoderOutputTensor.const_data_ptr<float>(),
sizeof(float) * encoderOutputTensor.numel());
}

std::shared_ptr<OwningArrayBuffer>
SpeechToText::decode(std::span<uint64_t> tokens,
std::span<float> encoderOutput) const {
std::vector<float> decoderOutput = this->asr->decode(tokens, encoderOutput);
return std::make_shared<OwningArrayBuffer>(decoderOutput);
executorch::aten::Tensor decoderOutputTensor =
transcriber_->decode(tokens, encoderOutput);

return std::make_shared<OwningArrayBuffer>(
decoderOutputTensor.const_data_ptr<float>(),
sizeof(float) * decoderOutputTensor.numel());
}

TranscriptionResult SpeechToText::transcribe(std::span<float> waveform,
std::string languageOption,
bool verbose) const {
DecodingOptions options(languageOption, verbose);
std::vector<Segment> segments = this->asr->transcribe(waveform, options);
std::vector<Segment> segments = transcriber_->transcribe(waveform, options);

std::string fullText;
for (const auto &segment : segments) {
Expand All @@ -70,8 +82,7 @@ TranscriptionResult SpeechToText::transcribe(std::span<float> waveform,
}

size_t SpeechToText::getMemoryLowerBound() const noexcept {
return this->encoder->getMemoryLowerBound() +
this->decoder->getMemoryLowerBound();
return transcriber_->getMemoryLowerBound();
}

namespace {
Expand All @@ -83,7 +94,7 @@ TranscriptionResult wordsToResult(const std::vector<Word> &words,

std::string fullText;
for (const auto &w : words) {
fullText += w.content;
fullText += w.content + w.punctations;
}
res.text = fullText;

Expand All @@ -105,68 +116,70 @@ TranscriptionResult wordsToResult(const std::vector<Word> &words,

void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
std::string languageOption, bool verbose) {
if (this->isStreaming) {
if (isStreaming_) {
throw RnExecutorchError(RnExecutorchErrorCode::StreamingInProgress,
"Streaming is already in progress!");
}

auto nativeCallback = [this, callback,
verbose](const TranscriptionResult &committed,
const TranscriptionResult &nonCommitted,
bool isDone) {
// This moves execution to the JS thread
this->callInvoker->invokeAsync(
[callback, committed, nonCommitted, isDone, verbose](jsi::Runtime &rt) {
jsi::Value jsiCommitted =
rnexecutorch::jsi_conversion::getJsiValue(committed, rt);
jsi::Value jsiNonCommitted =
rnexecutorch::jsi_conversion::getJsiValue(nonCommitted, rt);

callback->call(rt, std::move(jsiCommitted),
std::move(jsiNonCommitted), jsi::Value(isDone));
});
};

this->isStreaming = true;
auto nativeCallback =
[this, callback](const TranscriptionResult &committed,
const TranscriptionResult &nonCommitted, bool isDone) {
// This moves execution to the JS thread
callInvoker_->invokeAsync(
[callback, committed, nonCommitted, isDone](jsi::Runtime &rt) {
jsi::Value jsiCommitted =
rnexecutorch::jsi_conversion::getJsiValue(committed, rt);
jsi::Value jsiNonCommitted =
rnexecutorch::jsi_conversion::getJsiValue(nonCommitted, rt);

callback->call(rt, std::move(jsiCommitted),
std::move(jsiNonCommitted), jsi::Value(isDone));
});
};

isStreaming_ = true;
DecodingOptions options(languageOption, verbose);

while (this->isStreaming) {
if (!this->readyToProcess ||
this->processor->audioBuffer.size() < SpeechToText::kMinAudioSamples) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
continue;
}
while (isStreaming_) {
if (readyToProcess_ && streamer_->isReady()) {
ProcessResult res = streamer_->process(options);

ProcessResult res = this->processor->processIter(options);
TranscriptionResult committedRes =
wordsToResult(res.committed, languageOption, verbose);
TranscriptionResult nonCommittedRes =
wordsToResult(res.nonCommitted, languageOption, verbose);

TranscriptionResult cRes =
wordsToResult(res.committed, languageOption, verbose);
TranscriptionResult ncRes =
wordsToResult(res.nonCommitted, languageOption, verbose);
nativeCallback(committedRes, nonCommittedRes, false);
readyToProcess_ = false;
}

nativeCallback(cRes, ncRes, false);
this->readyToProcess = false;
// Add a minimal pause between transcriptions.
// The reasoning is very simple: with the current liberal threshold values,
// running transcriptions too rapidly (before the audio buffer is filled
// with significant amount of new data) can cause streamer to commit wrong
// phrases.
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}

std::vector<Word> finalWords = this->processor->finish();
std::vector<Word> finalWords = streamer_->finish();
TranscriptionResult finalRes =
wordsToResult(finalWords, languageOption, verbose);

nativeCallback(finalRes, {}, true);
this->resetStreamState();
resetStreamState();
}

void SpeechToText::streamStop() { this->isStreaming = false; }
void SpeechToText::streamStop() { isStreaming_ = false; }

void SpeechToText::streamInsert(std::span<float> waveform) {
this->processor->insertAudioChunk(waveform);
this->readyToProcess = true;
streamer_->insertAudioChunk(waveform);
readyToProcess_ = true;
}

void SpeechToText::resetStreamState() {
this->isStreaming = false;
this->readyToProcess = false;
this->processor = std::make_unique<OnlineASRProcessor>(this->asr.get());
isStreaming_ = false;
readyToProcess_ = false;
streamer_->reset();
}

} // namespace rnexecutorch::models::speech_to_text
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
#pragma once

#include "rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h"
#include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>
#include <atomic>
#include <span>
#include <string>
#include <vector>

#include "common/schema/ASR.h"
#include "common/schema/OnlineASR.h"
#include "common/types/TranscriptionResult.h"

namespace rnexecutorch {

namespace models::speech_to_text {

class SpeechToText {
public:
explicit SpeechToText(const std::string &encoderSource,
const std::string &decoderSource,
const std::string &tokenizerSource,
std::shared_ptr<react::CallInvoker> callInvoker);
SpeechToText(const std::string &modelName, const std::string &modelSource,
const std::string &tokenizerSource,
std::shared_ptr<react::CallInvoker> callInvoker);

// Required because of std::atomic usage
SpeechToText(SpeechToText &&other) noexcept;

void unload() noexcept;
[[nodiscard(
Expand All @@ -25,9 +30,9 @@ class SpeechToText {
"Registered non-void function")]] std::shared_ptr<OwningArrayBuffer>
decode(std::span<uint64_t> tokens, std::span<float> encoderOutput) const;
[[nodiscard("Registered non-void function")]]
types::TranscriptionResult transcribe(std::span<float> waveform,
std::string languageOption,
bool verbose) const;
TranscriptionResult transcribe(std::span<float> waveform,
std::string languageOption,
bool verbose) const;

[[nodiscard("Registered non-void function")]]
std::vector<char> transcribeStringOnly(std::span<float> waveform,
Expand All @@ -42,20 +47,17 @@ class SpeechToText {
void streamInsert(std::span<float> waveform);

private:
std::shared_ptr<react::CallInvoker> callInvoker;
std::unique_ptr<BaseModel> encoder;
std::unique_ptr<BaseModel> decoder;
std::unique_ptr<TokenizerModule> tokenizer;
std::unique_ptr<asr::ASR> asr;
void resetStreamState();

// Stream
std::unique_ptr<stream::OnlineASRProcessor> processor;
bool isStreaming;
bool readyToProcess;
std::shared_ptr<react::CallInvoker> callInvoker_;

constexpr static int32_t kMinAudioSamples = 16000; // 1 second
// ASR-like module (both static transcription & streaming)
std::unique_ptr<schema::ASR> transcriber_ = nullptr;

void resetStreamState();
// Online ASR-like module (streaming only)
std::unique_ptr<schema::OnlineASR> streamer_ = nullptr;
std::atomic<bool> isStreaming_ = false;
std::atomic<bool> readyToProcess_ = false;
};

} // namespace models::speech_to_text
Expand Down
Loading
Loading