Skip to content

Commit 1c23135

Browse files
committed
Optimize by removing unnecessary copies
1 parent 485c797 commit 1c23135

4 files changed

Lines changed: 80 additions & 61 deletions

File tree

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,22 @@ void SpeechToText::unload() noexcept { transcriber_->unload(); }
3131

3232
std::shared_ptr<OwningArrayBuffer>
3333
SpeechToText::encode(std::span<float> waveform) const {
34-
std::vector<float> encoderOutput = transcriber_->encode(waveform);
35-
return std::make_shared<OwningArrayBuffer>(encoderOutput);
34+
executorch::aten::Tensor encoderOutputTensor = transcriber_->encode(waveform);
35+
36+
return std::make_shared<OwningArrayBuffer>(
37+
encoderOutputTensor.const_data_ptr(),
38+
sizeof(float) * encoderOutputTensor.numel());
3639
}
3740

3841
std::shared_ptr<OwningArrayBuffer>
3942
SpeechToText::decode(std::span<uint64_t> tokens,
4043
std::span<float> encoderOutput) const {
41-
std::vector<float> decoderOutput =
44+
executorch::aten::Tensor decoderOutputTensor =
4245
transcriber_->decode(tokens, encoderOutput);
43-
return std::make_shared<OwningArrayBuffer>(decoderOutput);
46+
47+
return std::make_shared<OwningArrayBuffer>(
48+
decoderOutputTensor.const_data_ptr(),
49+
sizeof(float) * decoderOutputTensor.numel());
4450
}
4551

4652
TranscriptionResult SpeechToText::transcribe(std::span<float> waveform,

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,15 @@ class ASR {
2222
virtual ~ASR() = default;
2323

2424
virtual std::vector<Segment>
25-
transcribe(std::span<float> waveform,
25+
transcribe(std::span<const float> waveform,
2626
const DecodingOptions &options) const = 0;
2727

28-
virtual std::vector<float> encode(std::span<float> waveform) const = 0;
28+
virtual executorch::aten::Tensor
29+
encode(std::span<const float> waveform) const = 0;
2930

30-
virtual std::vector<float> decode(std::span<uint64_t> tokens,
31-
std::span<float> encoderOutput,
32-
uint64_t startPos = 0) const = 0;
31+
virtual executorch::aten::Tensor decode(std::span<uint64_t> tokens,
32+
std::span<const float> encoderOutput,
33+
uint64_t startPos = 0) const = 0;
3334

3435
// Standard ExecuTorch model methods for compatibility with the rest of the
3536
// API.

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ ASR::ASR(const std::string &modelSource, const std::string &tokenizerSource,
3030
/**
3131
* Whisper inference - full transcription
3232
*/
33-
std::vector<Segment> ASR::transcribe(std::span<float> waveform,
33+
std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
3434
const DecodingOptions &options) const {
3535
// Use floats to prevent downcasting and timestamp mismatches
3636
float seek = 0.f;
@@ -99,11 +99,12 @@ std::vector<Segment> ASR::transcribe(std::span<float> waveform,
9999
* The input is a standard audio waveform, although it is implicitly converted
100100
* to a log mel format inside the encoder call.
101101
*/
102-
std::vector<float> ASR::encode(std::span<float> waveform) const {
102+
executorch::aten::Tensor ASR::encode(std::span<const float> waveform) const {
103103
auto inputShape = {static_cast<int32_t>(waveform.size())};
104104

105105
const auto modelInputTensor = executorch::extension::make_tensor_ptr(
106-
std::move(inputShape), waveform.data(), ScalarType::Float);
106+
std::move(inputShape), const_cast<float *>(waveform.data()),
107+
ScalarType::Float);
107108

108109
const auto encoderResult = this->execute("encode", {modelInputTensor});
109110

@@ -113,21 +114,17 @@ std::vector<float> ASR::encode(std::span<float> waveform) const {
113114
"Ensure the model input is correct.");
114115
}
115116

116-
const auto encoderOutputTensor = encoderResult.get().at(0).toTensor();
117-
const auto outputNumel = encoderOutputTensor.numel();
118-
119-
const float *const dataPtr = encoderOutputTensor.const_data_ptr<float>();
120-
return {dataPtr, dataPtr + outputNumel};
117+
return encoderResult.get().at(0).toTensor();
121118
}
122119

123120
/**
124121
* Whisper inference - decoding phase
125122
*
126123
* An autoregressive decoder, called with increasing amount of input tokens.
127124
*/
128-
std::vector<float> ASR::decode(std::span<uint64_t> tokens,
129-
std::span<float> encoderOutput,
130-
uint64_t startPos) const {
125+
executorch::aten::Tensor ASR::decode(std::span<uint64_t> tokens,
126+
std::span<const float> encoderOutput,
127+
uint64_t startPos) const {
131128
std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
132129
std::vector<int32_t> positionShape = {static_cast<int32_t>(tokens.size())};
133130

@@ -144,7 +141,8 @@ std::vector<float> ASR::decode(std::span<uint64_t> tokens,
144141
std::vector<int32_t> encShape = {1, constants::kNumFrames,
145142
encoderOutputSize / constants::kNumFrames};
146143
auto encoderTensor = executorch::extension::make_tensor_ptr(
147-
std::move(encShape), encoderOutput.data(), ScalarType::Float);
144+
std::move(encShape), const_cast<float *>(encoderOutput.data()),
145+
ScalarType::Float);
148146

149147
const auto decoderResult =
150148
this->execute("decode", {tokenTensor, positionTensor, encoderTensor});
@@ -155,16 +153,7 @@ std::vector<float> ASR::decode(std::span<uint64_t> tokens,
155153
"Ensure the model inputs are correct.");
156154
}
157155

158-
const auto logitsTensor = decoderResult.get().at(0).toTensor();
159-
const int32_t outputNumel = static_cast<int32_t>(logitsTensor.numel());
160-
161-
const size_t innerDim = logitsTensor.size(1);
162-
const size_t dictSize = logitsTensor.size(2);
163-
164-
const float *const dataPtr =
165-
logitsTensor.const_data_ptr<float>() + (innerDim - 1) * dictSize;
166-
167-
return {dataPtr, dataPtr + outputNumel / innerDim};
156+
return decoderResult.get().at(0).toTensor();
168157
}
169158

170159
void ASR::unload() noexcept { BaseModel::unload(); }
@@ -197,14 +186,18 @@ ASR::createInitialSequence(const DecodingOptions &options) const {
197186
/**
198187
* Helper functions - generation wrapper, with fallback
199188
*/
200-
std::vector<Segment> ASR::generate(std::span<float> waveform,
189+
std::vector<Segment> ASR::generate(std::span<const float> waveform,
201190
const DecodingOptions &options) const {
202191
// A fixed pool of available temperatures
203192
constexpr std::array<float, 6> temperatures = {0.0f, 0.2f, 0.4f,
204193
0.6f, 0.8f, 1.0f};
205194

206195
// Calculate audio features just once to save time.
207-
std::vector<float> encoderOutput = this->encode(waveform);
196+
executorch::aten::Tensor encoderFeaturesTensor = this->encode(waveform);
197+
const float *encoderFeaturesData =
198+
encoderFeaturesTensor.const_data_ptr<float>();
199+
std::span<const float> encoderFeatures(
200+
encoderFeaturesData, encoderFeaturesData + encoderFeaturesTensor.numel());
208201

209202
std::vector<uint64_t> bestTokens;
210203
float bestAvgLogProb = -std::numeric_limits<float>::infinity();
@@ -213,7 +206,7 @@ std::vector<Segment> ASR::generate(std::span<float> waveform,
213206

214207
for (auto t : temperatures) {
215208
auto [tokens, scores] =
216-
this->generate(waveform, options, t, {encoderOutput});
209+
this->generate(waveform, options, t, {encoderFeatures});
217210

218211
const float cumLogProb = std::transform_reduce(
219212
scores.begin(), scores.end(), 0.0f, std::plus<>(),
@@ -248,15 +241,20 @@ std::vector<Segment> ASR::generate(std::span<float> waveform,
248241
* Helper functions - generation wrapper, single-temperature inference
249242
*/
250243
GenerationResult
251-
ASR::generate(std::span<float> waveform, const DecodingOptions &options,
244+
ASR::generate(std::span<const float> waveform, const DecodingOptions &options,
252245
float temperature,
253-
std::optional<std::span<float>> encoderOutput) const {
254-
std::vector<float> encoderOutputData = !encoderOutput.has_value()
255-
? this->encode(waveform)
256-
: std::vector<float>();
257-
std::span<float> encodings = encoderOutput.has_value()
258-
? encoderOutput.value()
259-
: std::span<float>(encoderOutputData);
246+
std::optional<std::span<const float>> encoderOutput) const {
247+
std::span<const float> encoderFeatures;
248+
if (encoderOutput.has_value()) {
249+
encoderFeatures = encoderOutput.value();
250+
} else {
251+
executorch::aten::Tensor encoderFeaturesTensor = this->encode(waveform);
252+
const float *encoderFeaturesData =
253+
encoderFeaturesTensor.const_data_ptr<float>();
254+
encoderFeatures =
255+
std::span(encoderFeaturesData,
256+
encoderFeaturesData + encoderFeaturesTensor.numel());
257+
}
260258

261259
std::vector<uint64_t> sequenceIds = this->createInitialSequence(options);
262260
std::vector<uint64_t> cachedTokens = sequenceIds;
@@ -266,7 +264,17 @@ ASR::generate(std::span<float> waveform, const DecodingOptions &options,
266264
uint64_t startPos = 0;
267265
while (std::cmp_less_equal(startPos + sequenceIds.size(),
268266
constants::kMaxDecodeLength)) {
269-
std::vector<float> logits = this->decode(sequenceIds, encodings, startPos);
267+
executorch::aten::Tensor logitsTensor =
268+
this->decode(sequenceIds, encoderFeatures, startPos);
269+
270+
const size_t logitsInnerDim = logitsTensor.size(1);
271+
const size_t logitsDictSize = logitsTensor.size(2);
272+
const float *logitsData = logitsTensor.const_data_ptr<float>() +
273+
(logitsInnerDim - 1) * logitsDictSize;
274+
// Needs to be float* without const for compatibility with utility functions
275+
std::span<float> logits(const_cast<float *>(logitsData),
276+
const_cast<float *>(logitsData) +
277+
logitsTensor.numel() / logitsInnerDim);
270278

271279
// intentionally comparing float to float
272280
// temperatures are predefined, so this is safe
@@ -276,7 +284,7 @@ ASR::generate(std::span<float> waveform, const DecodingOptions &options,
276284
numerical::softmaxWithTemperature(logits, temperature);
277285
}
278286

279-
const std::vector<float> &probs = logits;
287+
auto probs = logits;
280288

281289
uint64_t nextId;
282290
float nextProb;
@@ -311,9 +319,11 @@ ASR::generate(std::span<float> waveform, const DecodingOptions &options,
311319
.scores = scores};
312320
}
313321

314-
std::vector<Segment> ASR::calculateWordLevelTimestamps(
315-
std::span<const uint64_t> generatedTokens, const std::span<float> waveform,
316-
float avgLogProb, float temperature, float compressionRatio) const {
322+
std::vector<Segment>
323+
ASR::calculateWordLevelTimestamps(std::span<const uint64_t> generatedTokens,
324+
const std::span<const float> waveform,
325+
float avgLogProb, float temperature,
326+
float compressionRatio) const {
317327
const size_t generatedTokensSize = generatedTokens.size();
318328
if (generatedTokensSize < 2 ||
319329
generatedTokens[generatedTokensSize - 1] != endOfTranscriptionToken_ ||

packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,19 @@ class ASR : public models::BaseModel, public schema::ASR {
3737
* @param options Control variables for decoding process.
3838
*/
3939
std::vector<Segment> virtual transcribe(
40-
std::span<float> waveform, const DecodingOptions &options) const override;
40+
std::span<const float> waveform,
41+
const DecodingOptions &options) const override;
4142

4243
/**
4344
* Encodes the input audio waveform into mel spectrogram embeddings.
4445
*
4546
* @param waveform Input audio waveform sampled at 16kHz.
46-
* @return Flat vector containing the encoder's output features.
47+
* @return Float tensor containing the encoder's output features.
4748
* The output tensor shape: [1, 1500, 384] for Whisper
4849
* models.
4950
*/
50-
std::vector<float> encode(std::span<float> waveform) const override;
51+
executorch::aten::Tensor
52+
encode(std::span<const float> waveform) const override;
5153

5254
/**
5355
* Decodes a sequence of tokens into logits given the encoded audio features.
@@ -58,12 +60,12 @@ class ASR : public models::BaseModel, public schema::ASR {
5860
* embeddings.
5961
* @param startPos The starting position in the sequence (used for KV
6062
* caching).
61-
* @return A vector of floats representing the output logits for
62-
* the next token.
63+
* @return A tensor representing the output logits for the next
64+
* token.
6365
*/
64-
std::vector<float> decode(std::span<Token> tokens,
65-
std::span<float> encoderOutput,
66-
uint64_t startPos = 0) const override;
66+
executorch::aten::Tensor decode(std::span<Token> tokens,
67+
std::span<const float> encoderOutput,
68+
uint64_t startPos = 0) const override;
6769

6870
// Standard ExecuTorch model methods for compatibility with the rest of the
6971
// API.
@@ -95,7 +97,7 @@ class ASR : public models::BaseModel, public schema::ASR {
9597
* encode's input.
9698
* @param options Control variables for decoding process.
9799
*/
98-
std::vector<Segment> generate(std::span<float> waveform,
100+
std::vector<Segment> generate(std::span<const float> waveform,
99101
const DecodingOptions &options) const;
100102

101103
/**
@@ -112,10 +114,10 @@ class ASR : public models::BaseModel, public schema::ASR {
112114
* @param encoderOutput An optional parameter. If provided, the encoding phase
113115
* is skipped and the provided value is used instead.
114116
*/
115-
GenerationResult
116-
generate(std::span<float> waveform, const DecodingOptions &options,
117-
float temperature,
118-
std::optional<std::span<float>> encoderOutput = std::nullopt) const;
117+
GenerationResult generate(
118+
std::span<const float> waveform, const DecodingOptions &options,
119+
float temperature,
120+
std::optional<std::span<const float>> encoderOutput = std::nullopt) const;
119121

120122
/**
121123
* Calculates word-level timestamps for a sequence of generated tokens.
@@ -134,7 +136,7 @@ class ASR : public models::BaseModel, public schema::ASR {
134136
*/
135137
std::vector<Segment>
136138
calculateWordLevelTimestamps(std::span<const uint64_t> generatedTokens,
137-
const std::span<float> waveform,
139+
const std::span<const float> waveform,
138140
float avgLogProb, float temperature,
139141
float compressionRatio) const;
140142

0 commit comments

Comments
 (0)