Skip to content

Commit 1877fe7

Browse files
committed
fix(kokoro): voice loading, Synth method selection, async model init
- Voice loading: read all rows from voice file instead of truncating to kMaxInputTokens (128). Voice files have 510 rows; upstream discards 382. Changed voice_ from fixed std::array to std::vector sized from file. - Synthesizer method selection: discover forward_N methods at construction (same pattern DurationPredictor already uses). Falls back to "forward" for older models. Uses execute() instead of forward() for named methods. - voiceID bounds: use three-way min(phonemes-1, dpTokens-1, voice_.size()-1) to prevent OOB access. Upstream had a latent OOB bug with voiceID=noTokens on a 128-element array. - Async model construction: move heavy model init off the JS thread via Promise + GlobalThreadPool. JSI object creation stays on JS thread. JS side adds await to loadTextToSpeechKokoro() and loadLLM(). - Pad indices to inputDurationLimit before Synthesizer to avoid XNNPACK shape mismatch on repeated calls with varying duration predictions. - Perf: skip durPadded copy when DP/Synth use same token count, use resize() for silence padding, move-capture in streaming callback.
1 parent f293cb2 commit 1877fe7

File tree

9 files changed

+172
-145
lines changed

9 files changed

+172
-145
lines changed

packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <algorithm>
34
#include <memory>
45
#include <span>
56
#include <string>
@@ -44,6 +45,14 @@ class DurationPredictor : public BaseModel {
4445
// Returns maximum supported amount of input tokens.
4546
size_t getTokensLimit() const;
4647

48+
// Returns the token count of the forward method that would be selected
49+
// for a given input size. E.g., input 37 -> returns 64 (forward_64).
50+
size_t getMethodTokenCount(size_t inputSize) const {
51+
auto it = std::ranges::find_if(forwardMethods_,
52+
[inputSize](const auto &e) { return e.second >= inputSize; });
53+
return (it != forwardMethods_.end()) ? it->second : forwardMethods_.back().second;
54+
}
55+
4756
private:
4857
// Helper function - duration scalling
4958
// Performs integer scaling on the durations tensor to ensure the sum of

packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.cpp

Lines changed: 67 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -38,38 +38,39 @@ Kokoro::Kokoro(const std::string &lang, const std::string &taggerDataSource,
3838
}
3939

4040
void Kokoro::loadVoice(const std::string &voiceSource) {
41-
constexpr size_t rows = static_cast<size_t>(constants::kMaxInputTokens);
42-
constexpr size_t cols = static_cast<size_t>(constants::kVoiceRefSize); // 256
43-
const size_t expectedCount = rows * cols;
44-
const std::streamsize expectedBytes =
45-
static_cast<std::streamsize>(expectedCount * sizeof(float));
41+
constexpr size_t cols = static_cast<size_t>(constants::kVoiceRefSize);
42+
constexpr size_t bytesPerRow = cols * sizeof(float);
4643

4744
std::ifstream in(voiceSource, std::ios::binary);
4845
if (!in) {
4946
throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed,
50-
"[Kokoro::loadSingleVoice]: cannot open file: " +
47+
"[Kokoro::loadVoice]: cannot open file: " +
5148
voiceSource);
5249
}
5350

54-
// Check the file size
51+
// Determine number of rows from file size
5552
in.seekg(0, std::ios::end);
56-
const std::streamsize fileSize = in.tellg();
53+
const auto fileSize = static_cast<size_t>(in.tellg());
5754
in.seekg(0, std::ios::beg);
58-
if (fileSize < expectedBytes) {
55+
56+
if (fileSize < bytesPerRow) {
5957
throw RnExecutorchError(
6058
RnExecutorchErrorCode::FileReadFailed,
61-
"[Kokoro::loadSingleVoice]: file too small: expected at least " +
62-
std::to_string(expectedBytes) + " bytes, got " +
59+
"[Kokoro::loadVoice]: file too small: need at least " +
60+
std::to_string(bytesPerRow) + " bytes for one row, got " +
6361
std::to_string(fileSize));
6462
}
6563

66-
// Read [rows, 1, cols] as contiguous floats directly into voice_
67-
// ([rows][cols])
68-
if (!in.read(reinterpret_cast<char *>(voice_.data()->data()),
69-
expectedBytes)) {
64+
const size_t rows = fileSize / bytesPerRow;
65+
const auto readBytes = static_cast<std::streamsize>(rows * bytesPerRow);
66+
67+
// Resize voice vector to hold all rows from the file
68+
voice_.resize(rows);
69+
70+
if (!in.read(reinterpret_cast<char *>(voice_.data()->data()), readBytes)) {
7071
throw RnExecutorchError(
7172
RnExecutorchErrorCode::FileReadFailed,
72-
"[Kokoro::loadSingleVoice]: failed to read voice weights");
73+
"[Kokoro::loadVoice]: failed to read voice weights");
7374
}
7475
}
7576

@@ -98,13 +99,10 @@ std::vector<float> Kokoro::generate(std::string text, float speed) {
9899
size_t pauseMs = params::kPauseValues.contains(lastPhoneme)
99100
? params::kPauseValues.at(lastPhoneme)
100101
: params::kDefaultPause;
101-
std::vector<float> pause(pauseMs * constants::kSamplesPerMilisecond, 0.F);
102102

103-
// Add audio part and pause to the main audio vector
104-
audio.insert(audio.end(), std::make_move_iterator(audioPart.begin()),
105-
std::make_move_iterator(audioPart.end()));
106-
audio.insert(audio.end(), std::make_move_iterator(pause.begin()),
107-
std::make_move_iterator(pause.end()));
103+
// Add audio part and silence pause to the main audio vector
104+
audio.insert(audio.end(), audioPart.begin(), audioPart.end());
105+
audio.resize(audio.size() + pauseMs * constants::kSamplesPerMilisecond, 0.F);
108106
}
109107

110108
return audio;
@@ -118,12 +116,13 @@ void Kokoro::stream(std::string text, float speed,
118116
}
119117

120118
// Build a full callback function
121-
auto nativeCallback = [this, callback](const std::vector<float> &audioVec) {
119+
auto nativeCallback = [this, callback](std::vector<float> audioVec) {
122120
if (this->isStreaming_) {
123-
this->callInvoker_->invokeAsync([callback, audioVec](jsi::Runtime &rt) {
124-
callback->call(rt,
125-
rnexecutorch::jsi_conversion::getJsiValue(audioVec, rt));
126-
});
121+
this->callInvoker_->invokeAsync(
122+
[callback, audioVec = std::move(audioVec)](jsi::Runtime &rt) {
123+
callback->call(
124+
rt, rnexecutorch::jsi_conversion::getJsiValue(audioVec, rt));
125+
});
127126
}
128127
};
129128

@@ -166,14 +165,12 @@ void Kokoro::stream(std::string text, float speed,
166165
size_t pauseMs = params::kPauseValues.contains(lastPhoneme)
167166
? params::kPauseValues.at(lastPhoneme)
168167
: params::kDefaultPause;
169-
std::vector<float> pause(pauseMs * constants::kSamplesPerMilisecond, 0.F);
170168

171-
// Add pause to the audio vector
172-
audioPart.insert(audioPart.end(), std::make_move_iterator(pause.begin()),
173-
std::make_move_iterator(pause.end()));
169+
// Append silence pause directly
170+
audioPart.resize(audioPart.size() + pauseMs * constants::kSamplesPerMilisecond, 0.F);
174171

175172
// Push the audio right away to the JS side
176-
nativeCallback(audioPart);
173+
nativeCallback(std::move(audioPart));
177174
}
178175

179176
// Mark the end of the streaming process
@@ -188,41 +185,62 @@ std::vector<float> Kokoro::synthesize(const std::u32string &phonemes,
188185
return {};
189186
}
190187

191-
// Clamp the input to not go beyond number of input token limits
192-
// Note that 2 tokens are always reserved for pre- and post-fix padding,
193-
// so we effectively take at most (maxNoInputTokens_ - 2) tokens.
194-
size_t noTokens = std::clamp(phonemes.size() + 2, constants::kMinInputTokens,
188+
// Clamp token count: phonemes + 2 padding tokens (leading + trailing zero)
189+
size_t dpTokens = std::clamp(phonemes.size() + 2,
190+
constants::kMinInputTokens,
195191
context_.inputTokensLimit);
196192

197-
// Map phonemes to tokens
198-
const auto tokens = utils::tokenize(phonemes, {noTokens});
193+
// Map phonemes to tokens, padded to dpTokens
194+
auto tokens = utils::tokenize(phonemes, {dpTokens});
199195

200196
// Select the appropriate voice vector
201-
size_t voiceID = std::min(phonemes.size() - 1, noTokens);
197+
size_t voiceID = std::min({phonemes.size() - 1, dpTokens - 1,
198+
voice_.size() - 1});
202199
auto &voice = voice_[voiceID];
203200

204-
// Initialize text mask
205-
// Exclude all the paddings apart from first and last one.
206-
size_t realInputLength = std::min(phonemes.size() + 2, noTokens);
207-
std::vector<uint8_t> textMask(noTokens, false);
201+
// Initialize text mask for DP
202+
size_t realInputLength = std::min(phonemes.size() + 2, dpTokens);
203+
std::vector<uint8_t> textMask(dpTokens, false);
208204
std::fill(textMask.begin(), textMask.begin() + realInputLength, true);
209205

210206
// Inference 1 - DurationPredictor
211-
// The resulting duration vector is already scalled at this point
212207
auto [d, indices, effectiveDuration] = durationPredictor_.generate(
213208
std::span(tokens),
214209
std::span(reinterpret_cast<bool *>(textMask.data()), textMask.size()),
215210
std::span(voice).last(constants::kVoiceRefHalfSize), speed);
216211

212+
// --- Synthesizer phase ---
213+
// The Synthesizer may have different method sizes than the DP.
214+
// Pad all inputs to the Synthesizer's selected method size.
215+
size_t synthTokens = synthesizer_.getMethodTokenCount(dpTokens);
216+
size_t dCols = d.sizes().back(); // 640
217+
218+
// Pad tokens and textMask to synthTokens (no-op when synthTokens == dpTokens)
219+
tokens.resize(synthTokens, 0);
220+
textMask.resize(synthTokens, false);
221+
222+
// Pad indices to the maximum duration limit
223+
indices.resize(context_.inputDurationLimit, 0);
224+
225+
// Prepare duration data for Synthesizer.
226+
// When sizes match, pass the DP tensor directly to avoid a 320KB copy.
227+
size_t durSize = synthTokens * dCols;
228+
std::vector<float> durPadded;
229+
float *durPtr;
230+
if (synthTokens == dpTokens) {
231+
durPtr = d.mutable_data_ptr<float>();
232+
} else {
233+
durPadded.resize(durSize, 0.0f);
234+
std::copy_n(d.const_data_ptr<float>(), dpTokens * dCols, durPadded.data());
235+
durPtr = durPadded.data();
236+
}
237+
217238
// Inference 2 - Synthesizer
218239
auto decoding = synthesizer_.generate(
219240
std::span(tokens),
220241
std::span(reinterpret_cast<bool *>(textMask.data()), textMask.size()),
221242
std::span(indices),
222-
// Note that we reduce the size of d tensor to match the initial number of
223-
// input tokens
224-
std::span<float>(d.mutable_data_ptr<float>(),
225-
noTokens * d.sizes().back()),
243+
std::span<float>(durPtr, durSize),
226244
std::span(voice));
227245
auto audioTensor = decoding->at(0).toTensor();
228246

@@ -233,9 +251,7 @@ std::vector<float> Kokoro::synthesize(const std::u32string &phonemes,
233251
auto croppedAudio =
234252
utils::stripAudio(audio, paddingMs * constants::kSamplesPerMilisecond);
235253

236-
std::vector<float> result(croppedAudio.begin(), croppedAudio.end());
237-
238-
return result;
254+
return {croppedAudio.begin(), croppedAudio.end()};
239255
}
240256

241257
std::size_t Kokoro::getMemoryLowerBound() const noexcept {

packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,9 @@ class Kokoro {
5858
DurationPredictor durationPredictor_;
5959
Synthesizer synthesizer_;
6060

61-
// Voice array
62-
// There is a separate voice vector for each of the possible numbers of input
63-
// tokens.
64-
std::array<std::array<float, constants::kVoiceRefSize>,
65-
constants::kMaxInputTokens>
66-
voice_;
61+
// Voice array — dynamically sized to match the voice file.
62+
// Each row is a style vector for a given input token count.
63+
std::vector<std::array<float, constants::kVoiceRefSize>> voice_;
6764

6865
// Extra control variables
6966
bool isStreaming_ = false;

packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.cpp

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,34 @@ Synthesizer::Synthesizer(const std::string &modelSource,
1313
const Context &modelContext,
1414
std::shared_ptr<react::CallInvoker> callInvoker)
1515
: BaseModel(modelSource, callInvoker), context_(modelContext) {
16-
const auto inputTensors = getAllInputShapes("forward");
16+
// Discover all forward methods (forward, forward_8, forward_32, etc.)
17+
auto availableMethods = module_->method_names();
18+
if (availableMethods.ok()) {
19+
const auto &names = *availableMethods;
20+
for (const auto &name : names) {
21+
if (name.rfind("forward", 0) == 0) {
22+
const auto inputTensors = getAllInputShapes(name);
23+
CHECK_SIZE(inputTensors, 5);
24+
CHECK_SIZE(inputTensors[0], 2);
25+
CHECK_SIZE(inputTensors[1], 2);
26+
CHECK_SIZE(inputTensors[2], 1);
27+
size_t inputSize = inputTensors[0][1];
28+
forwardMethods_.emplace_back(name, inputSize);
29+
}
30+
}
31+
std::stable_sort(forwardMethods_.begin(), forwardMethods_.end(),
32+
[](const auto &a, const auto &b) { return a.second < b.second; });
33+
}
1734

18-
// Perform checks to validate model's compatibility with native code
19-
CHECK_SIZE(inputTensors, 5);
20-
CHECK_SIZE(
21-
inputTensors[0],
22-
2); // input tokens must be of shape {1, T}, where T is number of tokens
23-
CHECK_SIZE(
24-
inputTensors[1],
25-
2); // text mask must be of shape {1, T}, where T is number of tokens
26-
CHECK_SIZE(inputTensors[2],
27-
1); // indices must be of shape {D}, where D is a maximum duration
35+
// Fallback: if no methods discovered, validate "forward" directly
36+
if (forwardMethods_.empty()) {
37+
const auto inputTensors = getAllInputShapes("forward");
38+
CHECK_SIZE(inputTensors, 5);
39+
CHECK_SIZE(inputTensors[0], 2);
40+
CHECK_SIZE(inputTensors[1], 2);
41+
CHECK_SIZE(inputTensors[2], 1);
42+
forwardMethods_.emplace_back("forward", inputTensors[0][1]);
43+
}
2844
}
2945

3046
Result<std::vector<EValue>> Synthesizer::generate(std::span<const Token> tokens,
@@ -54,14 +70,19 @@ Result<std::vector<EValue>> Synthesizer::generate(std::span<const Token> tokens,
5470
auto voiceRefTensor = make_tensor_ptr({1, constants::kVoiceRefSize},
5571
ref_s.data(), ScalarType::Float);
5672

57-
// Execute the appropriate "forward_xyz" method, based on given method name
58-
auto results = forward(
73+
// Select appropriate forward method based on token count
74+
auto it = std::find_if(forwardMethods_.begin(), forwardMethods_.end(),
75+
[noTokens](const auto &entry) { return static_cast<int32_t>(entry.second) >= noTokens; });
76+
std::string selectedMethod = (it != forwardMethods_.end()) ? it->first : forwardMethods_.back().first;
77+
78+
// Execute the selected forward method
79+
auto results = execute(selectedMethod,
5980
{tokensTensor, textMaskTensor, indicesTensor, durTensor, voiceRefTensor});
6081

6182
if (!results.ok()) {
6283
throw RnExecutorchError(
6384
RnExecutorchErrorCode::InvalidModelOutput,
64-
"[Kokoro::Synthesizer] Failed to execute method forward"
85+
"[Kokoro::Synthesizer] Failed to execute method " + selectedMethod +
6586
", error: " +
6687
std::to_string(static_cast<uint32_t>(results.error())));
6788
}
@@ -72,13 +93,12 @@ Result<std::vector<EValue>> Synthesizer::generate(std::span<const Token> tokens,
7293
}
7394

7495
size_t Synthesizer::getTokensLimit() const {
75-
// Returns tokens input (shape {1, T}) second dim
76-
return getInputShape("forward", 0)[1];
96+
return forwardMethods_.empty() ? 0 : forwardMethods_.back().second;
7797
}
7898

7999
size_t Synthesizer::getDurationLimit() const {
80-
// Returns indices vector first dim (shape {D})
81-
return getInputShape("forward", 2)[0];
100+
if (forwardMethods_.empty()) return 0;
101+
return getInputShape(forwardMethods_.back().first, 2)[0];
82102
}
83103

84104
} // namespace rnexecutorch::models::text_to_speech::kokoro

packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <algorithm>
34
#include <memory>
45
#include <span>
56
#include <string>
@@ -49,7 +50,17 @@ class Synthesizer : public BaseModel {
4950
size_t getTokensLimit() const;
5051
size_t getDurationLimit() const;
5152

53+
// Returns the token count of the forward method that would be selected
54+
// for a given input size. E.g., input 37 -> returns 64 (forward_64).
55+
size_t getMethodTokenCount(size_t inputSize) const {
56+
auto it = std::ranges::find_if(forwardMethods_,
57+
[inputSize](const auto &e) { return e.second >= inputSize; });
58+
return (it != forwardMethods_.end()) ? it->second : forwardMethods_.back().second;
59+
}
60+
5261
private:
62+
// Forward methods discovered at construction (e.g. forward_8, forward_64, forward_128)
63+
std::vector<std::pair<std::string, size_t>> forwardMethods_;
5364
// Shared model context
5465
// A const reference to singleton in Kokoro.
5566
const Context &context_;

0 commit comments

Comments (0)