1+ #include < numeric>
12#include < random>
23#include < sstream>
34
@@ -42,11 +43,15 @@ GenerationResult ASR::generate(std::span<float> waveform, float temperature,
4243 std::vector<float > encoderOutput = this ->encode (waveform);
4344
4445 std::vector<uint64_t > sequenceIds = this ->getInitialSequence (options);
46+ std::vector<uint64_t > cachedTokens = sequenceIds;
4547 const size_t initialSequenceLenght = sequenceIds.size ();
4648 std::vector<float > scores;
4749
48- while (std::cmp_less_equal (sequenceIds.size (), ASR::kMaxDecodeLength )) {
49- std::vector<float > logits = this ->decode (sequenceIds, encoderOutput);
50+ uint64_t startPos = 0 ;
51+ while (std::cmp_less_equal (startPos + sequenceIds.size (),
52+ ASR::kMaxDecodeLength )) {
53+ std::vector<float > logits =
54+ this ->decode (sequenceIds, startPos, encoderOutput);
5055
5156 // intentionally comparing float to float
5257 // temperatures are predefined, so this is safe
@@ -74,16 +79,20 @@ GenerationResult ASR::generate(std::span<float> waveform, float temperature,
7479 nextProb = probs[nextId];
7580 }
7681
77- sequenceIds.push_back (nextId);
82+ // Move the startPos pointer by the amount of tokens we processed
83+ startPos += sequenceIds.size ();
84+ sequenceIds = {nextId};
85+ cachedTokens.push_back (nextId);
7886 scores.push_back (nextProb);
7987
8088 if (nextId == this ->endOfTranscriptionToken ) {
8189 break ;
8290 }
8391 }
8492
85- return {.tokens = std::vector<uint64_t >(
86- sequenceIds.cbegin () + initialSequenceLenght, sequenceIds.cend ()),
93+ return {.tokens = std::vector<uint64_t >(cachedTokens.cbegin () +
94+ initialSequenceLenght,
95+ cachedTokens.cend ()),
8796 .scores = scores};
8897}
8998
@@ -318,13 +327,19 @@ std::vector<float> ASR::encode(std::span<float> waveform) const {
318327 return {dataPtr, dataPtr + outputNumel};
319328}
320329
321- std::vector<float > ASR::decode (std::span<const uint64_t > tokens,
330+ std::vector<float > ASR::decode (std::span<uint64_t > tokens, uint64_t startPos ,
322331 std::span<float > encoderOutput) const {
323332 std::vector<int32_t > tokenShape = {1 , static_cast <int32_t >(tokens.size ())};
324- auto tokensLong = std::vector<int64_t > (tokens.begin (), tokens. end ()) ;
333+ std::vector<int32_t > positionShape = { static_cast < int32_t > (tokens.size ())} ;
325334
326335 auto tokenTensor = executorch::extension::make_tensor_ptr (
327- tokenShape, tokensLong.data (), ScalarType::Long);
336+ tokenShape, tokens.data (), ScalarType::Long);
337+
338+ // Populate cache position vector
339+ std::vector<uint64_t > cachePositions (tokens.size ());
340+ std::iota (cachePositions.begin (), cachePositions.end (), startPos);
341+ auto positionTensor = executorch::extension::make_tensor_ptr (
342+ positionShape, cachePositions.data (), ScalarType::Long);
328343
329344 const auto encoderOutputSize = static_cast <int32_t >(encoderOutput.size ());
330345 std::vector<int32_t > encShape = {1 , ASR::kNumFrames ,
@@ -333,7 +348,7 @@ std::vector<float> ASR::decode(std::span<const uint64_t> tokens,
333348 std::move (encShape), encoderOutput.data (), ScalarType::Float);
334349
335350 const auto decoderResult =
336- this ->decoder ->forward ({tokenTensor, encoderTensor});
351+ this ->decoder ->forward ({tokenTensor, positionTensor, encoderTensor});
337352
338353 if (!decoderResult.ok ()) {
339354 throw RnExecutorchError (decoderResult.error (),
0 commit comments