@@ -30,7 +30,7 @@ ASR::ASR(const std::string &modelSource, const std::string &tokenizerSource,
3030/* *
3131 * Whisper inference - full transcription
3232 */
33- std::vector<Segment> ASR::transcribe (std::span<float > waveform,
33+ std::vector<Segment> ASR::transcribe (std::span<const float > waveform,
3434 const DecodingOptions &options) const {
3535 // Use floats to prevent downcasting and timestamp mismatches
3636 float seek = 0 .f ;
@@ -99,11 +99,12 @@ std::vector<Segment> ASR::transcribe(std::span<float> waveform,
9999 * The input is a standard audio waveform, altough it is implicitly converted
100100 * to a log mel format inside the encoder call.
101101 */
102- std::vector< float > ASR::encode (std::span<float > waveform) const {
102+ executorch::aten::Tensor ASR::encode (std::span<const float > waveform) const {
103103 auto inputShape = {static_cast <int32_t >(waveform.size ())};
104104
105105 const auto modelInputTensor = executorch::extension::make_tensor_ptr (
106- std::move (inputShape), waveform.data (), ScalarType::Float);
106+ std::move (inputShape), const_cast <float *>(waveform.data ()),
107+ ScalarType::Float);
107108
108109 const auto encoderResult = this ->execute (" encode" , {modelInputTensor});
109110
@@ -113,21 +114,17 @@ std::vector<float> ASR::encode(std::span<float> waveform) const {
113114 " Ensure the model input is correct." );
114115 }
115116
116- const auto encoderOutputTensor = encoderResult.get ().at (0 ).toTensor ();
117- const auto outputNumel = encoderOutputTensor.numel ();
118-
119- const float *const dataPtr = encoderOutputTensor.const_data_ptr <float >();
120- return {dataPtr, dataPtr + outputNumel};
117+ return encoderResult.get ().at (0 ).toTensor ();
121118}
122119
123120/* *
124121 * Whisper inference - decoding phase
125122 *
126123 * An autoregressive decoder, called with increasing amount of input tokens.
127124 */
128- std::vector< float > ASR::decode (std::span<uint64_t > tokens,
129- std::span<float > encoderOutput,
130- uint64_t startPos) const {
125+ executorch::aten::Tensor ASR::decode (std::span<uint64_t > tokens,
126+ std::span<const float > encoderOutput,
127+ uint64_t startPos) const {
131128 std::vector<int32_t > tokenShape = {1 , static_cast <int32_t >(tokens.size ())};
132129 std::vector<int32_t > positionShape = {static_cast <int32_t >(tokens.size ())};
133130
@@ -144,7 +141,8 @@ std::vector<float> ASR::decode(std::span<uint64_t> tokens,
144141 std::vector<int32_t > encShape = {1 , constants::kNumFrames ,
145142 encoderOutputSize / constants::kNumFrames };
146143 auto encoderTensor = executorch::extension::make_tensor_ptr (
147- std::move (encShape), encoderOutput.data (), ScalarType::Float);
144+ std::move (encShape), const_cast <float *>(encoderOutput.data ()),
145+ ScalarType::Float);
148146
149147 const auto decoderResult =
150148 this ->execute (" decode" , {tokenTensor, positionTensor, encoderTensor});
@@ -155,16 +153,7 @@ std::vector<float> ASR::decode(std::span<uint64_t> tokens,
155153 " Ensure the model inputs are correct." );
156154 }
157155
158- const auto logitsTensor = decoderResult.get ().at (0 ).toTensor ();
159- const int32_t outputNumel = static_cast <int32_t >(logitsTensor.numel ());
160-
161- const size_t innerDim = logitsTensor.size (1 );
162- const size_t dictSize = logitsTensor.size (2 );
163-
164- const float *const dataPtr =
165- logitsTensor.const_data_ptr <float >() + (innerDim - 1 ) * dictSize;
166-
167- return {dataPtr, dataPtr + outputNumel / innerDim};
156+ return decoderResult.get ().at (0 ).toTensor ();
168157}
169158
170159void ASR::unload () noexcept { BaseModel::unload (); }
@@ -197,14 +186,18 @@ ASR::createInitialSequence(const DecodingOptions &options) const {
197186/* *
198187 * Helper functions - generation wrapper, with fallback
199188 */
200- std::vector<Segment> ASR::generate (std::span<float > waveform,
189+ std::vector<Segment> ASR::generate (std::span<const float > waveform,
201190 const DecodingOptions &options) const {
202191 // A fixed pool of available temperatures
203192 constexpr std::array<float , 6 > temperatures = {0 .0f , 0 .2f , 0 .4f ,
204193 0 .6f , 0 .8f , 1 .0f };
205194
206195 // Calculate audio features just once to save time.
207- std::vector<float > encoderOutput = this ->encode (waveform);
196+ executorch::aten::Tensor encoderFeaturesTensor = this ->encode (waveform);
197+ const float *encoderFeaturesData =
198+ encoderFeaturesTensor.const_data_ptr <float >();
199+ std::span<const float > encoderFeatures (
200+ encoderFeaturesData, encoderFeaturesData + encoderFeaturesTensor.numel ());
208201
209202 std::vector<uint64_t > bestTokens;
210203 float bestAvgLogProb = -std::numeric_limits<float >::infinity ();
@@ -213,7 +206,7 @@ std::vector<Segment> ASR::generate(std::span<float> waveform,
213206
214207 for (auto t : temperatures) {
215208 auto [tokens, scores] =
216- this ->generate (waveform, options, t, {encoderOutput });
209+ this ->generate (waveform, options, t, {encoderFeatures });
217210
218211 const float cumLogProb = std::transform_reduce (
219212 scores.begin (), scores.end (), 0 .0f , std::plus<>(),
@@ -248,15 +241,20 @@ std::vector<Segment> ASR::generate(std::span<float> waveform,
248241 * Helper functions - generation wrapper, single-temperature inference
249242 */
250243GenerationResult
251- ASR::generate (std::span<float > waveform, const DecodingOptions &options,
244+ ASR::generate (std::span<const float > waveform, const DecodingOptions &options,
252245 float temperature,
253- std::optional<std::span<float >> encoderOutput) const {
254- std::vector<float > encoderOutputData = !encoderOutput.has_value ()
255- ? this ->encode (waveform)
256- : std::vector<float >();
257- std::span<float > encodings = encoderOutput.has_value ()
258- ? encoderOutput.value ()
259- : std::span<float >(encoderOutputData);
246+ std::optional<std::span<const float >> encoderOutput) const {
247+ std::span<const float > encoderFeatures;
248+ if (encoderOutput.has_value ()) {
249+ encoderFeatures = encoderOutput.value ();
250+ } else {
251+ executorch::aten::Tensor encoderFeaturesTensor = this ->encode (waveform);
252+ const float *encoderFeaturesData =
253+ encoderFeaturesTensor.const_data_ptr <float >();
254+ encoderFeatures =
255+ std::span (encoderFeaturesData,
256+ encoderFeaturesData + encoderFeaturesTensor.numel ());
257+ }
260258
261259 std::vector<uint64_t > sequenceIds = this ->createInitialSequence (options);
262260 std::vector<uint64_t > cachedTokens = sequenceIds;
@@ -266,7 +264,17 @@ ASR::generate(std::span<float> waveform, const DecodingOptions &options,
266264 uint64_t startPos = 0 ;
267265 while (std::cmp_less_equal (startPos + sequenceIds.size (),
268266 constants::kMaxDecodeLength )) {
269- std::vector<float > logits = this ->decode (sequenceIds, encodings, startPos);
267+ executorch::aten::Tensor logitsTensor =
268+ this ->decode (sequenceIds, encoderFeatures, startPos);
269+
270+ const size_t logitsInnerDim = logitsTensor.size (1 );
271+ const size_t logitsDictSize = logitsTensor.size (2 );
272+ const float *logitsData = logitsTensor.const_data_ptr <float >() +
273+ (logitsInnerDim - 1 ) * logitsDictSize;
274+ // Needs to be float* without const for compatibility with utility functions
275+ std::span<float > logits (const_cast <float *>(logitsData),
276+ const_cast <float *>(logitsData) +
277+ logitsTensor.numel () / logitsInnerDim);
270278
271279 // intentionally comparing float to float
272280 // temperatures are predefined, so this is safe
@@ -276,7 +284,7 @@ ASR::generate(std::span<float> waveform, const DecodingOptions &options,
276284 numerical::softmaxWithTemperature (logits, temperature);
277285 }
278286
279- const std::vector< float > & probs = logits;
287+ auto probs = logits;
280288
281289 uint64_t nextId;
282290 float nextProb;
@@ -311,9 +319,11 @@ ASR::generate(std::span<float> waveform, const DecodingOptions &options,
311319 .scores = scores};
312320}
313321
314- std::vector<Segment> ASR::calculateWordLevelTimestamps (
315- std::span<const uint64_t > generatedTokens, const std::span<float > waveform,
316- float avgLogProb, float temperature, float compressionRatio) const {
322+ std::vector<Segment>
323+ ASR::calculateWordLevelTimestamps (std::span<const uint64_t > generatedTokens,
324+ const std::span<const float > waveform,
325+ float avgLogProb, float temperature,
326+ float compressionRatio) const {
317327 const size_t generatedTokensSize = generatedTokens.size ();
318328 if (generatedTokensSize < 2 ||
319329 generatedTokens[generatedTokensSize - 1 ] != endOfTranscriptionToken_ ||
0 commit comments