@@ -38,38 +38,39 @@ Kokoro::Kokoro(const std::string &lang, const std::string &taggerDataSource,
3838}
3939
4040void Kokoro::loadVoice (const std::string &voiceSource) {
41- constexpr size_t rows = static_cast <size_t >(constants::kMaxInputTokens );
42- constexpr size_t cols = static_cast <size_t >(constants::kVoiceRefSize ); // 256
43- const size_t expectedCount = rows * cols;
44- const std::streamsize expectedBytes =
45- static_cast <std::streamsize>(expectedCount * sizeof (float ));
41+ constexpr size_t cols = static_cast <size_t >(constants::kVoiceRefSize );
42+ constexpr size_t bytesPerRow = cols * sizeof (float );
4643
4744 std::ifstream in (voiceSource, std::ios::binary);
4845 if (!in) {
4946 throw RnExecutorchError (RnExecutorchErrorCode::FileReadFailed,
50- " [Kokoro::loadSingleVoice ]: cannot open file: " +
47+ " [Kokoro::loadVoice ]: cannot open file: " +
5148 voiceSource);
5249 }
5350
54- // Check the file size
51+ // Determine number of rows from file size
5552 in.seekg (0 , std::ios::end);
56- const std::streamsize fileSize = in.tellg ();
53+ const auto fileSize = static_cast < size_t >( in.tellg () );
5754 in.seekg (0 , std::ios::beg);
58- if (fileSize < expectedBytes) {
55+
56+ if (fileSize < bytesPerRow) {
5957 throw RnExecutorchError (
6058 RnExecutorchErrorCode::FileReadFailed,
61- " [Kokoro::loadSingleVoice ]: file too small: expected at least " +
62- std::to_string (expectedBytes ) + " bytes, got " +
59+ " [Kokoro::loadVoice ]: file too small: need at least " +
60+ std::to_string (bytesPerRow ) + " bytes for one row , got " +
6361 std::to_string (fileSize));
6462 }
6563
66- // Read [rows, 1, cols] as contiguous floats directly into voice_
67- // ([rows][cols])
68- if (!in.read (reinterpret_cast <char *>(voice_.data ()->data ()),
69- expectedBytes)) {
64+ const size_t rows = fileSize / bytesPerRow;
65+ const auto readBytes = static_cast <std::streamsize>(rows * bytesPerRow);
66+
67+ // Resize voice vector to hold all rows from the file
68+ voice_.resize (rows);
69+
70+ if (!in.read (reinterpret_cast <char *>(voice_.data ()->data ()), readBytes)) {
7071 throw RnExecutorchError (
7172 RnExecutorchErrorCode::FileReadFailed,
72- " [Kokoro::loadSingleVoice ]: failed to read voice weights" );
73+ " [Kokoro::loadVoice ]: failed to read voice weights" );
7374 }
7475}
7576
@@ -98,13 +99,10 @@ std::vector<float> Kokoro::generate(std::string text, float speed) {
9899 size_t pauseMs = params::kPauseValues .contains (lastPhoneme)
99100 ? params::kPauseValues .at (lastPhoneme)
100101 : params::kDefaultPause ;
101- std::vector<float > pause (pauseMs * constants::kSamplesPerMilisecond , 0 .F );
102102
103- // Add audio part and pause to the main audio vector
104- audio.insert (audio.end (), std::make_move_iterator (audioPart.begin ()),
105- std::make_move_iterator (audioPart.end ()));
106- audio.insert (audio.end (), std::make_move_iterator (pause.begin ()),
107- std::make_move_iterator (pause.end ()));
103+ // Add audio part and silence pause to the main audio vector
104+ audio.insert (audio.end (), audioPart.begin (), audioPart.end ());
105+ audio.resize (audio.size () + pauseMs * constants::kSamplesPerMilisecond , 0 .F );
108106 }
109107
110108 return audio;
@@ -118,12 +116,13 @@ void Kokoro::stream(std::string text, float speed,
118116 }
119117
120118 // Build a full callback function
121- auto nativeCallback = [this , callback](const std::vector<float > & audioVec) {
119+ auto nativeCallback = [this , callback](std::vector<float > audioVec) {
122120 if (this ->isStreaming_ ) {
123- this ->callInvoker_ ->invokeAsync ([callback, audioVec](jsi::Runtime &rt) {
124- callback->call (rt,
125- rnexecutorch::jsi_conversion::getJsiValue (audioVec, rt));
126- });
121+ this ->callInvoker_ ->invokeAsync (
122+ [callback, audioVec = std::move (audioVec)](jsi::Runtime &rt) {
123+ callback->call (
124+ rt, rnexecutorch::jsi_conversion::getJsiValue (audioVec, rt));
125+ });
127126 }
128127 };
129128
@@ -166,14 +165,12 @@ void Kokoro::stream(std::string text, float speed,
166165 size_t pauseMs = params::kPauseValues .contains (lastPhoneme)
167166 ? params::kPauseValues .at (lastPhoneme)
168167 : params::kDefaultPause ;
169- std::vector<float > pause (pauseMs * constants::kSamplesPerMilisecond , 0 .F );
170168
171- // Add pause to the audio vector
172- audioPart.insert (audioPart.end (), std::make_move_iterator (pause.begin ()),
173- std::make_move_iterator (pause.end ()));
169+ // Append silence pause directly
170+ audioPart.resize (audioPart.size () + pauseMs * constants::kSamplesPerMilisecond , 0 .F );
174171
175172 // Push the audio right away to the JS side
176- nativeCallback (audioPart);
173+ nativeCallback (std::move ( audioPart) );
177174 }
178175
179176 // Mark the end of the streaming process
@@ -188,41 +185,62 @@ std::vector<float> Kokoro::synthesize(const std::u32string &phonemes,
188185 return {};
189186 }
190187
191- // Clamp the input to not go beyond number of input token limits
192- // Note that 2 tokens are always reserved for pre- and post-fix padding,
193- // so we effectively take at most (maxNoInputTokens_ - 2) tokens.
194- size_t noTokens = std::clamp (phonemes.size () + 2 , constants::kMinInputTokens ,
188+ // Clamp token count: phonemes + 2 padding tokens (leading + trailing zero)
189+ size_t dpTokens = std::clamp (phonemes.size () + 2 ,
190+ constants::kMinInputTokens ,
195191 context_.inputTokensLimit );
196192
197- // Map phonemes to tokens
198- const auto tokens = utils::tokenize (phonemes, {noTokens });
193+ // Map phonemes to tokens, padded to dpTokens
194+ auto tokens = utils::tokenize (phonemes, {dpTokens });
199195
200196 // Select the appropriate voice vector
201- size_t voiceID = std::min (phonemes.size () - 1 , noTokens);
197+ size_t voiceID = std::min ({phonemes.size () - 1 , dpTokens - 1 ,
198+ voice_.size () - 1 });
202199 auto &voice = voice_[voiceID];
203200
204- // Initialize text mask
205- // Exclude all the paddings apart from first and last one.
206- size_t realInputLength = std::min (phonemes.size () + 2 , noTokens);
207- std::vector<uint8_t > textMask (noTokens, false );
201+ // Initialize text mask for DP
202+ size_t realInputLength = std::min (phonemes.size () + 2 , dpTokens);
203+ std::vector<uint8_t > textMask (dpTokens, false );
208204 std::fill (textMask.begin (), textMask.begin () + realInputLength, true );
209205
210206 // Inference 1 - DurationPredictor
211- // The resulting duration vector is already scalled at this point
212207 auto [d, indices, effectiveDuration] = durationPredictor_.generate (
213208 std::span (tokens),
214209 std::span (reinterpret_cast <bool *>(textMask.data ()), textMask.size ()),
215210 std::span (voice).last (constants::kVoiceRefHalfSize ), speed);
216211
212+ // --- Synthesizer phase ---
213+ // The Synthesizer may have different method sizes than the DP.
214+ // Pad all inputs to the Synthesizer's selected method size.
215+ size_t synthTokens = synthesizer_.getMethodTokenCount (dpTokens);
216+ size_t dCols = d.sizes ().back (); // 640
217+
218+ // Pad tokens and textMask to synthTokens (no-op when synthTokens == dpTokens)
219+ tokens.resize (synthTokens, 0 );
220+ textMask.resize (synthTokens, false );
221+
222+ // Pad indices to the maximum duration limit
223+ indices.resize (context_.inputDurationLimit , 0 );
224+
225+ // Prepare duration data for Synthesizer.
226+ // When sizes match, pass the DP tensor directly to avoid a 320KB copy.
227+ size_t durSize = synthTokens * dCols;
228+ std::vector<float > durPadded;
229+ float *durPtr;
230+ if (synthTokens == dpTokens) {
231+ durPtr = d.mutable_data_ptr <float >();
232+ } else {
233+ durPadded.resize (durSize, 0 .0f );
234+ std::copy_n (d.const_data_ptr <float >(), dpTokens * dCols, durPadded.data ());
235+ durPtr = durPadded.data ();
236+ }
237+
217238 // Inference 2 - Synthesizer
218239 auto decoding = synthesizer_.generate (
219240 std::span (tokens),
220241 std::span (reinterpret_cast <bool *>(textMask.data ()), textMask.size ()),
221242 std::span (indices),
222- // Note that we reduce the size of d tensor to match the initial number of
223- // input tokens
224- std::span<float >(d.mutable_data_ptr <float >(),
225- noTokens * d.sizes ().back ()),
243+ std::span<float >(durPtr, durSize),
226244 std::span (voice));
227245 auto audioTensor = decoding->at (0 ).toTensor ();
228246
@@ -233,9 +251,7 @@ std::vector<float> Kokoro::synthesize(const std::u32string &phonemes,
233251 auto croppedAudio =
234252 utils::stripAudio (audio, paddingMs * constants::kSamplesPerMilisecond );
235253
236- std::vector<float > result (croppedAudio.begin (), croppedAudio.end ());
237-
238- return result;
254+ return {croppedAudio.begin (), croppedAudio.end ()};
239255}
240256
241257std::size_t Kokoro::getMemoryLowerBound () const noexcept {
0 commit comments