Skip to content

Commit 7633baf

Browse files
Merge pull request #270 from bernardladenthin/claude/laughing-albattani-bl5h9h
Harden TTS WAV validation and upstream enum pinning
2 parents 2375a40 + 02192f3 commit 7633baf

5 files changed

Lines changed: 123 additions & 20 deletions

File tree

CLAUDE.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -410,11 +410,14 @@ Instead the helpers are **DERIVED mechanically at configure time** from the pinn
410410
generated definitions. The in-memory WAV writer (`tts_wav.hpp`) is ours, not extracted.
411411

412412
**Fail-loud on drift (same contract as `patches/`):** the generator asserts every anchor — the
413-
`int main(` split point, each `static <signature>` it de-statics, and both speaker literals. If an
414-
upgrade renames a helper or moves a literal, the **configure step aborts** with a pointer to the
415-
generator; if upstream changes a *type*, `tts_upstream.h` stops matching and the **link fails**.
416-
Either way a silent divergence is impossible. On a llama.cpp bump, re-verify the generator the same
417-
way you re-verify `patches/`.
413+
`int main(` split point, each `static <signature>` it de-statics, the `outetts_version` enum
414+
(enumerators + order, kept ODR-identical to the hand-written copy in `tts_upstream.h`), both
415+
`prompt_add` overloads the header declares (the bare `void prompt_add(` prefix de-statics all three
416+
upstream overloads, so the two the header relies on are pinned individually), and both speaker
417+
literals. If an upgrade renames a helper, reorders the enum, or moves a literal, the **configure step
418+
aborts** with a pointer to the generator; if upstream changes a *type*, `tts_upstream.h` stops
419+
matching and the **link fails**. Either way a silent divergence is impossible. On a llama.cpp bump,
420+
re-verify the generator the same way you re-verify `patches/`.
418421

419422
## Upgrading/Downgrading llama.cpp Version
420423

cmake/generate-tts-upstream.cmake

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,37 @@ foreach(sig IN LISTS JLLAMA_TTS_DESTATIC)
5656
string(REPLACE "static ${sig}" "${sig}" PREMAIN "${PREMAIN}")
5757
endforeach()
5858

59+
# --- 2a. pin the outetts_version enum against the hand-written copy in tts_upstream.h ---
60+
# src/main/cpp/tts_upstream.h re-declares `enum outetts_version { OUTETTS_V0_2, OUTETTS_V0_3 }` because
61+
# it cannot include the generated TU. The two definitions live in different translation units and must
62+
# stay token-identical: if upstream reorders/renames/extends the enum, the generated TU and the header
63+
# would bind the same name to different integer values (a silent miscompile). Capture the upstream enum
64+
# body and compare its enumerator list so a drift fails the configure with a pointer to update the header.
65+
string(REGEX MATCH "enum[ \t\r\n]+outetts_version[ \t\r\n]*{([^}]*)}" _enum_match "${PREMAIN}")
66+
if(_enum_match STREQUAL "")
67+
message(FATAL_ERROR "generate-tts-upstream: 'enum outetts_version' not found in tts.cpp — upstream changed; update cmake/generate-tts-upstream.cmake and src/main/cpp/tts_upstream.h")
68+
endif()
69+
set(_enum_body "${CMAKE_MATCH_1}")
70+
string(REGEX REPLACE "//[^\n]*" "" _enum_body "${_enum_body}") # strip any line comments
71+
string(REGEX REPLACE "[ \t\r\n]+" "" _enum_body "${_enum_body}") # strip all whitespace
72+
string(REGEX REPLACE ",+$" "" _enum_body "${_enum_body}") # strip a trailing comma
73+
if(NOT _enum_body STREQUAL "OUTETTS_V0_2,OUTETTS_V0_3")
74+
message(FATAL_ERROR "generate-tts-upstream: upstream 'enum outetts_version' enumerators are now '${_enum_body}' (expected 'OUTETTS_V0_2,OUTETTS_V0_3'). Update the matching enum in src/main/cpp/tts_upstream.h to keep the two definitions ODR-identical, then update this assertion in cmake/generate-tts-upstream.cmake")
75+
endif()
76+
77+
# --- 2b. verify BOTH prompt_add overloads that tts_upstream.h declares are present ---
78+
# `void prompt_add(` is shared by three upstream overloads; the de-static REPLACE above (correctly) gives
79+
# all of them external linkage, but the single string(FIND) only proves >=1 exists. tts_upstream.h
80+
# declares exactly two — (llama_tokens&, const llama_tokens&) and the (vocab, txt, add_special,
81+
# parse_special) builder — and tts_engine.cpp links against them. Pin both here (whitespace-tolerant) so
82+
# dropping or renaming either fails the configure with a clear pointer instead of a cryptic link error.
83+
if(NOT PREMAIN MATCHES "void[ \t]+prompt_add[ \t]*\\([^)]*const[ \t]+llama_tokens[ \t]*&[ \t]*tokens[ \t]*\\)")
84+
message(FATAL_ERROR "generate-tts-upstream: the prompt_add(llama_tokens&, const llama_tokens&) overload declared in src/main/cpp/tts_upstream.h was not found in tts.cpp — upstream changed; update the de-static list and src/main/cpp/tts_upstream.h")
85+
endif()
86+
if(NOT PREMAIN MATCHES "void[ \t]+prompt_add[ \t]*\\([^)]*vocab[^)]*add_special[^)]*parse_special[^)]*\\)")
87+
message(FATAL_ERROR "generate-tts-upstream: the prompt_add(llama_tokens&, const llama_vocab*, const std::string&, bool, bool) overload declared in src/main/cpp/tts_upstream.h was not found in tts.cpp — upstream changed; update the de-static list and src/main/cpp/tts_upstream.h")
88+
endif()
89+
5990
# --- 3. extract the two default-speaker literals from inside main() ---
6091
# audio_text: a single-line std::string audio_text = "<|text_start|>the<|text_sep|>...";
6192
# The leading "<|text_start|>the<|text_sep|>" disambiguates it from the empty-seed literal

src/main/cpp/tts_engine.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121
#include "llama.h"
2222
#include "sampling.h"
2323

24+
// Full json definition: tts_upstream.h only forward-declares nlohmann::ordered_json (keeping the heavy
25+
// header out of the shared interface), but this TU constructs the empty-object speaker argument for
26+
// get_tts_version(), which needs the complete type.
27+
#include <nlohmann/json.hpp>
28+
2429
#include <algorithm>
2530
#include <cstdint>
2631
#include <regex>
@@ -67,7 +72,9 @@ tts_engine *engine_init(const std::string &ttc_model_path, const std::string &ct
6772
return nullptr;
6873
}
6974
engine->vocab = llama_model_get_vocab(engine->model_ttc);
70-
engine->tts_version = get_tts_version(engine->model_ttc);
75+
// Explicit empty-object speaker: tts_upstream.h declares no default (it forward-declares json), so
76+
// the default lives only in the generated TU. We always use the built-in default speaker profile.
77+
engine->tts_version = get_tts_version(engine->model_ttc, nlohmann::ordered_json::object());
7178

7279
// Codes-to-speech (CTS) vocoder, loaded in embedding mode.
7380
params.model.path = cts_model_path;
@@ -202,13 +209,19 @@ bool engine_synthesize(tts_engine *engine, const std::string &text, int n_predic
202209
}
203210
llama_synchronize(engine->ctx_cts);
204211

212+
// llama_model_n_embd_out (not llama_model_n_embd): read the vocoder's OUTPUT embedding width, which
213+
// is what llama_get_embeddings returns here. This matches upstream tts.cpp, which also queries
214+
// llama_model_n_embd_out at this step.
205215
const int n_embd = llama_model_n_embd_out(engine->model_cts);
206216
const float *embd = llama_get_embeddings(engine->ctx_cts);
207217
std::vector<float> audio = embd_to_audio(embd, n_codes, n_embd, engine->n_threads);
208218
llama_batch_free(cts_batch);
209219

210-
// Zero the first 0.25 s (suppresses a leading click).
220+
// 24 kHz mono — the OuteTTS / WavTokenizer output rate.
211221
const int n_sr = 24000;
222+
// Zero the first 0.25 s, mirroring upstream tts.cpp's post-vocoder cleanup (it suppresses a leading
223+
// click). The `&& i < audio.size()` guard is ours: it keeps the loop in-bounds for clips shorter
224+
// than 0.25 s, where upstream's fixed 24000/4 bound would write past the end of the buffer.
212225
for (int i = 0; i < n_sr / 4 && i < (int)audio.size(); ++i) {
213226
audio[i] = 0.0f;
214227
}

src/main/cpp/tts_upstream.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,21 @@
1515
#include <string>
1616
#include <vector>
1717

18-
#include <nlohmann/json.hpp>
18+
// Forward declarations only. This shared interface header names nlohmann::ordered_json once (the
19+
// get_tts_version() speaker parameter) but never instantiates it, so it must not pull the full
20+
// ~25k-line <nlohmann/json.hpp> into every translation unit that includes it. The single caller that
21+
// constructs the empty-object default (tts_engine.cpp) includes the full <nlohmann/json.hpp> itself.
22+
#include <nlohmann/json_fwd.hpp>
1923

2024
#include "common.h" // llama_tokens
2125
#include "llama.h" // llama_model, llama_vocab, llama_token
2226

23-
// Mirrors the upstream enum (identical definition; ODR-compatible across translation units).
27+
// Mirrors the upstream enum (identical definition; ODR-compatible across translation units). The
28+
// generated TU carries upstream's own copy, so these enumerators and their order MUST stay
29+
// token-identical to upstream — otherwise the two definitions assign different integer values to the
30+
// same name (a silent miscompile). cmake/generate-tts-upstream.cmake asserts the upstream enum still
31+
// reads `{ OUTETTS_V0_2, OUTETTS_V0_3 }` at configure time and fails loud (pointing here) if a
32+
// llama.cpp bump changes it.
2433
enum outetts_version { OUTETTS_V0_2, OUTETTS_V0_3 };
2534

2635
// --- derived from upstream tts.cpp (defined in the generated translation unit) ---
@@ -40,7 +49,11 @@ void prompt_init(llama_tokens &prompt, const llama_vocab *vocab);
4049
std::vector<llama_token> prepare_guide_tokens(const llama_vocab *vocab, const std::string &str,
4150
outetts_version tts_version);
4251

43-
outetts_version get_tts_version(llama_model *model, nlohmann::ordered_json speaker = nlohmann::ordered_json::object());
52+
// No default argument here on purpose: constructing nlohmann::ordered_json::object() needs the full
53+
// json definition, which this header deliberately does not include (see the json_fwd note above). The
54+
// sole caller (tts_engine.cpp) passes an explicit empty object; the generated TU keeps upstream's own
55+
// default, so its internal calls are unaffected.
56+
outetts_version get_tts_version(llama_model *model, nlohmann::ordered_json speaker);
4457

4558
// Default OuteTTS speaker profile, extracted from upstream main() into the generated TU.
4659
extern const std::string jllama_tts_default_audio_text;

src/test/java/net/ladenthin/llama/TtsIntegrationTest.java

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
import static org.junit.jupiter.api.Assertions.assertTrue;
1010

1111
import java.io.File;
12+
import java.nio.ByteBuffer;
13+
import java.nio.ByteOrder;
14+
import java.nio.charset.StandardCharsets;
1215
import java.util.concurrent.TimeUnit;
1316
import org.junit.jupiter.api.Assumptions;
1417
import org.junit.jupiter.api.DisplayName;
@@ -26,8 +29,11 @@
2629
*/
2730
public class TtsIntegrationTest {
2831

32+
/** Canonical RIFF/WAVE header size in bytes (16-bit PCM, no extra chunks). */
33+
private static final int WAV_HEADER_BYTES = 44;
34+
2935
@Test
30-
@DisplayName("synthesize() returns a well-formed 16-bit WAV byte stream")
36+
@DisplayName("synthesize() returns a well-formed, non-silent 24 kHz mono 16-bit WAV")
3137
@Timeout(value = 300_000, unit = TimeUnit.MILLISECONDS)
3238
public void synthesizesWellFormedWav() {
3339
String ttc = System.getProperty(TestConstants.PROP_TTS_TTC_MODEL);
@@ -45,15 +51,52 @@ public void synthesizesWellFormedWav() {
4551
byte[] wav = tts.synthesize("hello from llama");
4652

4753
assertNotNull(wav, "WAV bytes must not be null");
48-
assertTrue(wav.length > 44, "WAV must carry a header plus samples; got " + wav.length + " bytes");
49-
assertEquals('R', (char) wav[0]);
50-
assertEquals('I', (char) wav[1]);
51-
assertEquals('F', (char) wav[2]);
52-
assertEquals('F', (char) wav[3]);
53-
assertEquals('W', (char) wav[8]);
54-
assertEquals('A', (char) wav[9]);
55-
assertEquals('V', (char) wav[10]);
56-
assertEquals('E', (char) wav[11]);
54+
// A bare 44-byte header with no payload is not a valid clip: require real samples beyond it.
55+
assertTrue(
56+
wav.length > WAV_HEADER_BYTES,
57+
"WAV must carry a header plus samples; got " + wav.length + " bytes");
58+
59+
// RIFF/WAVE container magic.
60+
assertEquals("RIFF", tag(wav, 0), "RIFF magic");
61+
assertEquals("WAVE", tag(wav, 8), "WAVE magic");
62+
assertEquals("fmt ", tag(wav, 12), "fmt subchunk tag");
63+
assertEquals("data", tag(wav, 36), "data subchunk tag");
64+
65+
// fmt fields must match the documented output format: 24 kHz mono 16-bit PCM. A mis-loaded
66+
// model that still framed a header would not silently pass with the wrong rate/channels.
67+
ByteBuffer header = ByteBuffer.wrap(wav).order(ByteOrder.LITTLE_ENDIAN);
68+
assertEquals(1, header.getShort(20) & 0xFFFF, "audio format must be PCM (1)");
69+
assertEquals(1, header.getShort(22) & 0xFFFF, "must be mono (1 channel)");
70+
assertEquals(24_000, header.getInt(24), "sample rate must be 24 kHz");
71+
assertEquals(16, header.getShort(34) & 0xFFFF, "must be 16-bit samples");
72+
73+
// Declared chunk sizes must be self-consistent with the actual byte-array length.
74+
assertEquals(wav.length - 8, header.getInt(4), "RIFF chunk size must equal fileLength - 8");
75+
int dataSize = header.getInt(40);
76+
assertEquals(wav.length - WAV_HEADER_BYTES, dataSize, "data chunk size must equal fileLength - 44");
77+
assertEquals(0, dataSize % 2, "16-bit PCM data size must be even");
78+
79+
// The clip must contain real audio, not just the zeroed 0.25 s lead-in (or the all-silent
80+
// buffer a mis-configured model could still frame inside an otherwise valid header). The
81+
// original `length > 44` check passed on a single padding byte; scan the PCM payload instead.
82+
assertTrue(
83+
hasNonZeroSample(wav, WAV_HEADER_BYTES),
84+
"synthesized PCM must contain audible (non-zero) samples, not pure silence");
85+
}
86+
}
87+
88+
/** Reads the 4-byte ASCII chunk tag at {@code offset}. */
89+
private static String tag(byte[] wav, int offset) {
90+
return new String(wav, offset, 4, StandardCharsets.US_ASCII);
91+
}
92+
93+
/** Returns {@code true} if any byte of the PCM payload at or after {@code from} is non-zero. */
94+
private static boolean hasNonZeroSample(byte[] wav, int from) {
95+
for (int i = from; i < wav.length; i++) {
96+
if (wav[i] != 0) {
97+
return true;
98+
}
5799
}
100+
return false;
58101
}
59102
}

0 commit comments

Comments
 (0)