diff --git a/README.md b/README.md index 517e9e1..b909902 100644 --- a/README.md +++ b/README.md @@ -258,7 +258,7 @@ The first startup may spend extra time downloading assets if `models/` does not An Android ONNX Runtime smoke example is available under [`examples/android_onnx_runtime`](./examples/android_onnx_runtime). -The example loads the exported MOSS-TTS-Nano ONNX graphs and the MOSS-Audio-Tokenizer-Nano ONNX decoder on device, synthesizes short pre-tokenized demo prompts, and writes a WAV file from Android. It is intentionally minimal and keeps model files outside the APK for local testing. +The example loads the exported MOSS-TTS-Nano ONNX graphs and the MOSS-Audio-Tokenizer-Nano ONNX decoder on device, tokenizes custom text with a small Kotlin tokenizer, and writes a WAV file from Android. It is intentionally minimal and keeps model files outside the APK for local testing. ### Export TTS-only ONNX Weights diff --git a/README_zh.md b/README_zh.md index 8e47698..edf0c31 100644 --- a/README_zh.md +++ b/README_zh.md @@ -253,7 +253,7 @@ python app_onnx.py \ Android ONNX Runtime smoke 示例位于 [`examples/android_onnx_runtime`](./examples/android_onnx_runtime)。 -该示例会在 Android 设备端加载导出的 MOSS-TTS-Nano ONNX 图和 MOSS-Audio-Tokenizer-Nano ONNX 解码器,合成短的预分词 demo prompt,并写出 WAV 文件。示例刻意保持最小化,并将模型文件保留在 APK 外部,便于本地测试。 +该示例会在 Android 设备端加载导出的 MOSS-TTS-Nano ONNX 图和 MOSS-Audio-Tokenizer-Nano ONNX 解码器,通过小型 Kotlin tokenizer 对自定义文本分词,并写出 WAV 文件。示例刻意保持最小化,并将模型文件保留在 APK 外部,便于本地测试。 ### 导出仅 TTS 的 ONNX 权重 diff --git a/examples/android_onnx_runtime/README.md b/examples/android_onnx_runtime/README.md index 56dfcdb..5b44be3 100644 --- a/examples/android_onnx_runtime/README.md +++ b/examples/android_onnx_runtime/README.md @@ -9,7 +9,7 @@ It intentionally stays minimal: - no model files committed to git - no app-specific business logic -The demo synthesizes two pre-tokenized prompts so the ONNX path can be tested without adding a large SentencePiece JNI dependency to the first Android example. +The demo includes a small pure Kotlin tokenizer for `tokenizer.model`, so you can synthesize custom text directly on Android without adding a SentencePiece JNI dependency. ## Model Files @@ -58,7 +58,7 @@ adb push MOSS-Audio-Tokenizer-Nano-ONNX \ Open `examples/android_onnx_runtime` in Android Studio, connect a device, and run the `app` configuration. -Tap either demo button. The app writes a WAV file to its cache directory and prints the output path on screen. +Type custom text and tap `Generate custom text WAV`, or tap either pre-tokenized demo button. The app writes a WAV file to its cache directory and prints the output path on screen. The sample uses: @@ -69,15 +69,23 @@ The sample uses: ## Custom Text -For custom text input, tokenize with `tokenizer.model` using the same SentencePiece model used by the Python ONNX runtime, then pass the resulting token ids into `MossOnnxDemoEngine.synthesize`. +Custom text is handled by `SimpleSentencePieceTokenizer`, which reads the exported `tokenizer.model` and returns the text token ids used by `MossOnnxDemoEngine.synthesize`. -For a production Android app, add one of the following tokenizer paths: +You can also call the engine directly: -- a small SentencePiece JNI wrapper -- a pre-tokenization service or build step -- another Android-compatible SentencePiece implementation +```kotlin +MossOnnxDemoEngine( + modelRoot = modelRoot, + outputDir = cacheDir, +).use { engine -> + engine.synthesizeText( + text = "Hello world!", + outputFile = File(cacheDir, "custom.wav"), + ) +} +``` -The ONNX Runtime code is independent from the tokenizer as long as it receives the correct `IntArray` token ids. +The tokenizer intentionally implements only the inference-time pieces needed by the exported Nano `tokenizer.model`: Java NFKC-style normalization, whitespace escaping, Unigram segmentation, and BPE merge ranking. It does not interpret the full SentencePiece `precompiled_charsmap`, so compare its output against the Python tokenizer first if you replace the tokenizer model or rely on unusual normalization rules. ## Notes @@ -85,3 +93,4 @@ The ONNX Runtime code is independent from the tokenizer as long as it receives t - The demo caps generation to `maxFrames = 160` for faster smoke testing. - The decoded ONNX codec output is stereo; this example averages channels and writes a mono WAV for simplicity. - Keep the model files outside the APK for local testing. Bundling them into app assets is possible but increases APK size substantially. +- Unit tests use a handcrafted tokenizer fixture by default. To compare against a real Nano tokenizer locally, run `MOSS_TOKENIZER_MODEL=/path/to/tokenizer.model ./gradlew :app:testDebugUnitTest --rerun-tasks`. diff --git a/examples/android_onnx_runtime/app/build.gradle.kts b/examples/android_onnx_runtime/app/build.gradle.kts index ce55e1f..5e4a53b 100644 --- a/examples/android_onnx_runtime/app/build.gradle.kts +++ b/examples/android_onnx_runtime/app/build.gradle.kts @@ -1,3 +1,5 @@ +import org.gradle.api.tasks.testing.Test + plugins { id("com.android.application") id("org.jetbrains.kotlin.android") @@ -27,4 +29,10 @@ android { dependencies { implementation("com.microsoft.onnxruntime:onnxruntime-android:1.20.0") + + testImplementation("junit:junit:4.13.2") +} + +tasks.withType().configureEach { + inputs.property("MOSS_TOKENIZER_MODEL", System.getenv("MOSS_TOKENIZER_MODEL").orEmpty()) } diff --git a/examples/android_onnx_runtime/app/src/main/java/com/openmoss/ttsnano/onnxruntime/MainActivity.kt b/examples/android_onnx_runtime/app/src/main/java/com/openmoss/ttsnano/onnxruntime/MainActivity.kt index d1e73cb..48860ad 100644 --- a/examples/android_onnx_runtime/app/src/main/java/com/openmoss/ttsnano/onnxruntime/MainActivity.kt +++ b/examples/android_onnx_runtime/app/src/main/java/com/openmoss/ttsnano/onnxruntime/MainActivity.kt @@ -4,8 +4,10 @@ import android.app.Activity import android.os.Bundle import android.os.Handler import android.os.Looper +import android.text.InputType import android.view.ViewGroup import android.widget.Button +import android.widget.EditText import android.widget.LinearLayout import android.widget.ScrollView import android.widget.TextView @@ -14,6 +16,8 @@ import java.io.File class MainActivity : Activity() { private val mainHandler = Handler(Looper.getMainLooper()) private lateinit var logView: TextView + private lateinit var customTextInput: EditText + private lateinit var generateCustomButton: Button private lateinit var generateEnglishButton: Button private lateinit var generateChineseButton: Button @@ -25,6 +29,18 @@ class MainActivity : Activity() { textSize = 14f setTextIsSelectable(true) } + customTextInput = EditText(this).apply { + setText("Hello world!") + hint = "Custom text" + inputType = InputType.TYPE_CLASS_TEXT or InputType.TYPE_TEXT_FLAG_MULTI_LINE + minLines = 2 + } + generateCustomButton = Button(this).apply { + text = "Generate custom text WAV" + setOnClickListener { + runCustomText(customTextInput.text.toString()) + } + } generateEnglishButton = Button(this).apply { text = "Generate English demo WAV" setOnClickListener { @@ -41,6 +57,8 @@ class MainActivity : Activity() { val content = LinearLayout(this).apply { orientation = LinearLayout.VERTICAL setPadding(32, 32, 32, 32) + addView(customTextInput) + addView(generateCustomButton) addView(generateEnglishButton) addView(generateChineseButton) addView( @@ -53,7 +71,46 @@ class MainActivity : Activity() { } setContentView(ScrollView(this).apply { addView(content) }) appendLog("Place model files under:\n${modelRoot().absolutePath}") - appendLog("Tap a button to synthesize a short pre-tokenized demo prompt.") + appendLog("Enter text or tap a pre-tokenized demo prompt.") + } + + private fun runCustomText(text: String) { + val trimmedText = text.trim() + if (trimmedText.isEmpty()) { + appendLog("[custom] text is empty") + return + } + setButtonsEnabled(false) + appendLog("\n[custom] starting synthesis: $trimmedText") + Thread { + try { + val outputFile = File(cacheDir, "moss_tts_nano_android_custom.wav") + MossOnnxDemoEngine( + modelRoot = modelRoot(), + outputDir = cacheDir, + cpuThreads = 2, + ).use { engine -> + val result = engine.synthesizeText( + text = trimmedText, + outputFile = outputFile, + voice = "Junhao", + maxFrames = 160, + seed = 1234L, + ) + appendLogFromWorker( + "[custom] done: ${result.outputFile.absolutePath}\n" + + "frames=${result.generatedFrames} " + + "sampleRate=${result.sampleRate}Hz " + + "durationMs=${result.durationMs} " + + "elapsedMs=${result.elapsedMs}", + ) + } + } catch (error: Throwable) { + appendLogFromWorker("[custom] failed: ${error.javaClass.simpleName}: ${error.message}") + } finally { + mainHandler.post { setButtonsEnabled(true) } + } + }.start() } private fun runDemo(label: String, textTokenIds: IntArray) { @@ -95,6 +152,7 @@ class MainActivity : Activity() { } private fun setButtonsEnabled(enabled: Boolean) { + generateCustomButton.isEnabled = enabled generateEnglishButton.isEnabled = enabled generateChineseButton.isEnabled = enabled } diff --git a/examples/android_onnx_runtime/app/src/main/java/com/openmoss/ttsnano/onnxruntime/MossOnnxDemoEngine.kt b/examples/android_onnx_runtime/app/src/main/java/com/openmoss/ttsnano/onnxruntime/MossOnnxDemoEngine.kt index 6f1602a..62b4a0c 100644 --- a/examples/android_onnx_runtime/app/src/main/java/com/openmoss/ttsnano/onnxruntime/MossOnnxDemoEngine.kt +++ b/examples/android_onnx_runtime/app/src/main/java/com/openmoss/ttsnano/onnxruntime/MossOnnxDemoEngine.kt @@ -30,6 +30,9 @@ class MossOnnxDemoEngine( private val codecMeta = CodecMeta.fromJson(readJson(codecMetaPath)) private val ttsDir = ttsMetaPath.parentFile ?: manifestDir private val codecDir = codecMetaPath.parentFile ?: manifestDir + private val textTokenizer by lazy { + SimpleSentencePieceTokenizer.fromFile(File(ttsDir, "tokenizer.model")) + } private val sessionOptions = OrtSession.SessionOptions().apply { setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT) setIntraOpNumThreads(cpuThreads.coerceAtLeast(1)) @@ -65,6 +68,22 @@ class MossOnnxDemoEngine( ) } + fun synthesizeText( + text: String, + outputFile: File = File(outputDir, "moss_tts_nano_android_custom.wav"), + voice: String = "Junhao", + maxFrames: Int = 160, + seed: Long = 1234L, + ): SynthesisResult { + return synthesize( + textTokenIds = textTokenizer.encode(text), + outputFile = outputFile, + voice = voice, + maxFrames = maxFrames, + seed = seed, + ) + } + override fun close() { codecDecodeSession.close() localFixedFrameSession.close() diff --git a/examples/android_onnx_runtime/app/src/main/java/com/openmoss/ttsnano/onnxruntime/SimpleSentencePieceTokenizer.kt b/examples/android_onnx_runtime/app/src/main/java/com/openmoss/ttsnano/onnxruntime/SimpleSentencePieceTokenizer.kt new file mode 100644 index 0000000..ac88c1c --- /dev/null +++ b/examples/android_onnx_runtime/app/src/main/java/com/openmoss/ttsnano/onnxruntime/SimpleSentencePieceTokenizer.kt @@ -0,0 +1,320 @@ +package com.openmoss.ttsnano.onnxruntime + +import java.io.File +import java.text.Normalizer + +/** + * Small SentencePiece tokenizer for the exported Nano tokenizer.model. + * + * This keeps the Android example self-contained and avoids native tokenizer + * dependencies. It intentionally implements only the model fields needed for + * inference-time encoding. + */ +class SimpleSentencePieceTokenizer private constructor( + private val pieces: List, + private val normalizerSpec: NormalizerSpec, +) { + private val unknownId = pieces.indexOfFirst { it.type == PieceType.UNKNOWN }.takeIf { it >= 0 } ?: 0 + private val pieceByText = pieces + .withIndex() + .filter { (_, piece) -> piece.type.isTokenizable && piece.text.isNotEmpty() } + .associateBy { (_, piece) -> piece.text } + private val piecesByFirstChar = pieces + .withIndex() + .filter { (_, piece) -> piece.type.isTokenizable && piece.text.isNotEmpty() } + .groupBy { (_, piece) -> piece.text[0] } + + fun encode(text: String): IntArray { + if (text.isBlank()) { + return IntArray(0) + } + val normalized = normalize(text) + if (normalized.isEmpty()) { + return IntArray(0) + } + if (normalizerSpec.modelType == ModelType.BPE) { + return encodeBpe(normalized) + } + + val bestScores = DoubleArray(normalized.length + 1) { Double.NEGATIVE_INFINITY } + val bestNextIndex = IntArray(normalized.length) { -1 } + val bestPieceId = IntArray(normalized.length) { unknownId } + bestScores[normalized.length] = 0.0 + + for (index in normalized.length - 1 downTo 0) { + val candidates = piecesByFirstChar[normalized[index]].orEmpty() + for ((pieceId, piece) in candidates) { + if (!normalized.startsWith(piece.text, index)) { + continue + } + val nextIndex = index + piece.text.length + val score = piece.score + bestScores[nextIndex] + if (score > bestScores[index]) { + bestScores[index] = score + bestNextIndex[index] = nextIndex + bestPieceId[index] = pieceId + } + } + + if (bestNextIndex[index] < 0) { + val nextIndex = normalized.offsetByCodePoints(index, 1) + bestScores[index] = bestScores[nextIndex] + bestNextIndex[index] = nextIndex + bestPieceId[index] = unknownId + } + } + + val ids = ArrayList() + var index = 0 + while (index < normalized.length) { + ids += bestPieceId[index] + index = bestNextIndex[index] + } + return ids.toIntArray() + } + + private fun encodeBpe(normalized: String): IntArray { + val tokens = ArrayList() + var index = 0 + while (index < normalized.length) { + val nextIndex = normalized.offsetByCodePoints(index, 1) + tokens += normalized.substring(index, nextIndex) + index = nextIndex + } + + while (tokens.size > 1) { + var bestIndex = -1 + var bestScore = Float.POSITIVE_INFINITY + for (tokenIndex in 0 until tokens.lastIndex) { + val merged = tokens[tokenIndex] + tokens[tokenIndex + 1] + val candidate = pieceByText[merged]?.value ?: continue + if (candidate.score < bestScore) { + bestScore = candidate.score + bestIndex = tokenIndex + } + } + if (bestIndex < 0) { + break + } + tokens[bestIndex] = tokens[bestIndex] + tokens.removeAt(bestIndex + 1) + } + + return IntArray(tokens.size) { tokenIndex -> + pieceByText[tokens[tokenIndex]]?.index ?: unknownId + } + } + + private fun normalize(text: String): String { + var normalized = if (normalizerSpec.name.contains("nfkc", ignoreCase = true)) { + Normalizer.normalize(text, Normalizer.Form.NFKC) + } else { + text + } + if (normalizerSpec.removeExtraWhitespaces) { + normalized = normalized.trim().replace(Regex("\\s+"), " ") + } + if (normalizerSpec.addDummyPrefix) { + normalized = " $normalized" + } + if (normalizerSpec.escapeWhitespaces) { + normalized = normalized.replace(Regex("\\s"), "▁") + } + return normalized + } + + companion object { + fun fromFile(modelFile: File): SimpleSentencePieceTokenizer { + require(modelFile.isFile) { "Missing tokenizer model: ${modelFile.absolutePath}" } + return fromModelBytes(modelFile.readBytes()) + } + + fun fromModelBytes(bytes: ByteArray): SimpleSentencePieceTokenizer { + return SentencePieceModelParser(bytes).parse() + } + } + + private data class Piece( + val text: String, + val score: Float, + val type: PieceType, + ) + + private enum class PieceType(val isTokenizable: Boolean) { + NORMAL(true), + UNKNOWN(false), + CONTROL(false), + USER_DEFINED(true), + UNUSED(false), + BYTE(true), + } + + private data class NormalizerSpec( + val modelType: ModelType = ModelType.UNIGRAM, + val name: String = "identity", + val addDummyPrefix: Boolean = true, + val removeExtraWhitespaces: Boolean = true, + val escapeWhitespaces: Boolean = true, + ) + + private enum class ModelType { + UNIGRAM, + BPE, + } + + private class SentencePieceModelParser(private val bytes: ByteArray) { + fun parse(): SimpleSentencePieceTokenizer { + val pieces = ArrayList() + var normalizerSpec = NormalizerSpec() + var modelType = ModelType.UNIGRAM + val reader = ProtoReader(bytes) + while (!reader.isAtEnd()) { + when (val tag = reader.readTag()) { + tag(1, WireType.LENGTH_DELIMITED) -> { + pieces += parsePiece(reader.readBytes()) + } + tag(2, WireType.LENGTH_DELIMITED) -> { + modelType = parseTrainerSpec(reader.readBytes()) + } + tag(3, WireType.LENGTH_DELIMITED) -> { + normalizerSpec = parseNormalizerSpec(reader.readBytes()) + } + else -> reader.skip(tag) + } + } + require(pieces.isNotEmpty()) { "No SentencePiece entries found in tokenizer model" } + return SimpleSentencePieceTokenizer(pieces, normalizerSpec.copy(modelType = modelType)) + } + + private fun parsePiece(pieceBytes: ByteArray): Piece { + var text = "" + var score = 0f + var type = PieceType.NORMAL + val reader = ProtoReader(pieceBytes) + while (!reader.isAtEnd()) { + when (val tag = reader.readTag()) { + tag(1, WireType.LENGTH_DELIMITED) -> text = reader.readString() + tag(2, WireType.FIXED32) -> score = reader.readFloat() + tag(3, WireType.VARINT) -> type = when (reader.readVarint().toInt()) { + 2 -> PieceType.UNKNOWN + 3 -> PieceType.CONTROL + 4 -> PieceType.USER_DEFINED + 5 -> PieceType.UNUSED + 6 -> PieceType.BYTE + else -> PieceType.NORMAL + } + else -> reader.skip(tag) + } + } + return Piece(text = text, score = score, type = type) + } + + private fun parseNormalizerSpec(specBytes: ByteArray): NormalizerSpec { + var name = "identity" + var addDummyPrefix = true + var removeExtraWhitespaces = true + var escapeWhitespaces = true + val reader = ProtoReader(specBytes) + while (!reader.isAtEnd()) { + when (val tag = reader.readTag()) { + tag(1, WireType.LENGTH_DELIMITED) -> name = reader.readString() + tag(3, WireType.VARINT) -> addDummyPrefix = reader.readVarint() != 0L + tag(4, WireType.VARINT) -> removeExtraWhitespaces = reader.readVarint() != 0L + tag(5, WireType.VARINT) -> escapeWhitespaces = reader.readVarint() != 0L + else -> reader.skip(tag) + } + } + return NormalizerSpec( + name = name, + addDummyPrefix = addDummyPrefix, + removeExtraWhitespaces = removeExtraWhitespaces, + escapeWhitespaces = escapeWhitespaces, + ) + } + + private fun parseTrainerSpec(specBytes: ByteArray): ModelType { + val reader = ProtoReader(specBytes) + while (!reader.isAtEnd()) { + when (val tag = reader.readTag()) { + tag(3, WireType.VARINT) -> { + return if (reader.readVarint().toInt() == 2) ModelType.BPE else ModelType.UNIGRAM + } + else -> reader.skip(tag) + } + } + return ModelType.UNIGRAM + } + } + + private class ProtoReader(private val bytes: ByteArray) { + private var offset = 0 + + fun isAtEnd(): Boolean = offset >= bytes.size + + fun readTag(): Int = readVarint().toInt() + + fun readVarint(): Long { + var shift = 0 + var result = 0L + while (shift < 64) { + val value = readByte().toInt() and 0xff + result = result or ((value and 0x7f).toLong() shl shift) + if (value and 0x80 == 0) { + return result + } + shift += 7 + } + error("Invalid varint in tokenizer model") + } + + fun readBytes(): ByteArray { + val length = readVarint().toInt() + require(length >= 0 && offset + length <= bytes.size) { "Invalid length-delimited field" } + return bytes.copyOfRange(offset, offset + length).also { + offset += length + } + } + + fun readString(): String = readBytes().toString(Charsets.UTF_8) + + fun readFloat(): Float { + require(offset + 4 <= bytes.size) { "Invalid fixed32 field" } + val bits = (bytes[offset].toInt() and 0xff) or + ((bytes[offset + 1].toInt() and 0xff) shl 8) or + ((bytes[offset + 2].toInt() and 0xff) shl 16) or + ((bytes[offset + 3].toInt() and 0xff) shl 24) + offset += 4 + return java.lang.Float.intBitsToFloat(bits) + } + + fun skip(tag: Int) { + when (tag and 0x7) { + WireType.VARINT.id -> readVarint() + WireType.FIXED64.id -> skipBytes(8) + WireType.LENGTH_DELIMITED.id -> skipBytes(readVarint().toInt()) + WireType.FIXED32.id -> skipBytes(4) + else -> error("Unsupported protobuf wire type: ${tag and 0x7}") + } + } + + private fun readByte(): Byte { + require(offset < bytes.size) { "Unexpected end of tokenizer model" } + return bytes[offset++] + } + + private fun skipBytes(count: Int) { + require(count >= 0 && offset + count <= bytes.size) { "Invalid protobuf field length" } + offset += count + } + } +} + +private enum class WireType(val id: Int) { + VARINT(0), + FIXED64(1), + LENGTH_DELIMITED(2), + FIXED32(5), +} + +private fun tag(fieldNumber: Int, wireType: WireType): Int { + return (fieldNumber shl 3) or wireType.id +} diff --git a/examples/android_onnx_runtime/app/src/test/java/com/openmoss/ttsnano/onnxruntime/SimpleSentencePieceTokenizerTest.kt b/examples/android_onnx_runtime/app/src/test/java/com/openmoss/ttsnano/onnxruntime/SimpleSentencePieceTokenizerTest.kt new file mode 100644 index 0000000..30ac846 --- /dev/null +++ b/examples/android_onnx_runtime/app/src/test/java/com/openmoss/ttsnano/onnxruntime/SimpleSentencePieceTokenizerTest.kt @@ -0,0 +1,163 @@ +package com.openmoss.ttsnano.onnxruntime + +import org.junit.Assert.assertArrayEquals +import org.junit.Assume.assumeTrue +import org.junit.Test +import java.io.File + +class SimpleSentencePieceTokenizerTest { + @Test + fun encodesTextWithWhitespacePrefixAndNfkcNormalization() { + val tokenizer = SimpleSentencePieceTokenizer.fromModelBytes( + buildModel( + piece("", 0f, 2), + piece("▁Hello", -0.1f), + piece("▁world", -0.1f), + piece("!", -0.1f), + piece("▁你", -0.1f), + piece("好", -0.1f), + piece(",", -0.1f), + piece("世界", -0.1f), + ), + ) + + assertArrayEquals(intArrayOf(1, 2, 3), tokenizer.encode("Hello world!")) + assertArrayEquals(intArrayOf(4, 5, 6, 7, 3), tokenizer.encode("你好,世界!")) + } + + @Test + fun blankTextEncodesToNoTokens() { + val tokenizer = SimpleSentencePieceTokenizer.fromModelBytes( + buildModel( + piece("", 0f, 2), + piece("▁", -0.1f), + ), + ) + + assertArrayEquals(IntArray(0), tokenizer.encode("")) + assertArrayEquals(IntArray(0), tokenizer.encode(" \n\t ")) + } + + @Test + fun prefersTheHighestScoredUnigramPath() { + val tokenizer = SimpleSentencePieceTokenizer.fromModelBytes( + buildModel( + piece("", 0f, 2), + piece("▁a", -2.0f), + piece("▁ab", -0.1f), + piece("b", -0.1f), + ), + ) + + assertArrayEquals(intArrayOf(2), tokenizer.encode("ab")) + } + + @Test + fun bpeModelMergesPairsByRank() { + val tokenizer = SimpleSentencePieceTokenizer.fromModelBytes( + buildModel( + 2, + piece("", 0f, 2), + piece("▁", 0f), + piece("H", 0f), + piece("e", 0f), + piece("l", 0f), + piece("o", 0f), + piece("▁H", -1f), + piece("▁He", -2f), + piece("ll", -3f), + piece("▁Hell", -4f), + piece("▁Hello", -5f), + ), + ) + + assertArrayEquals(intArrayOf(10), tokenizer.encode("Hello")) + } + + @Test + fun encodesRealMossTokenizerWhenModelPathIsProvided() { + val modelPath = System.getenv("MOSS_TOKENIZER_MODEL").orEmpty() + assumeTrue("Set MOSS_TOKENIZER_MODEL to run this optional integration test", modelPath.isNotBlank()) + val tokenizer = SimpleSentencePieceTokenizer.fromFile(File(modelPath)) + + assertArrayEquals(intArrayOf(7026, 1177, 11449), tokenizer.encode("Hello world!")) + assertArrayEquals(intArrayOf(3985, 10445, 10364, 1260, 11449), tokenizer.encode("你好,世界!")) + } + + private fun buildModel(vararg pieces: ByteArray): ByteArray { + return buildModel(1, *pieces) + } + + private fun buildModel(modelType: Int, vararg pieces: ByteArray): ByteArray { + return proto { + pieces.forEach { bytes -> + fieldBytes(1, bytes) + } + fieldBytes(2) { + fieldVarint(3, modelType) + } + fieldBytes(3) { + fieldString(1, "nmt_nfkc") + fieldVarint(3, 1) + fieldVarint(4, 1) + fieldVarint(5, 1) + } + } + } + + private fun piece(text: String, score: Float, type: Int = 1): ByteArray { + return proto { + fieldString(1, text) + fieldFloat(2, score) + fieldVarint(3, type) + } + } + + private fun proto(block: ProtoWriter.() -> Unit): ByteArray { + return ProtoWriter().apply(block).toByteArray() + } + +} + +private class ProtoWriter { + private val bytes = ArrayList() + + fun fieldString(fieldNumber: Int, value: String) { + fieldBytes(fieldNumber, value.toByteArray(Charsets.UTF_8)) + } + + fun fieldBytes(fieldNumber: Int, value: ByteArray) { + writeVarint((fieldNumber shl 3) or 2) + writeVarint(value.size) + value.forEach { bytes += it } + } + + fun fieldBytes(fieldNumber: Int, block: ProtoWriter.() -> Unit) { + fieldBytes(fieldNumber, ProtoWriter().apply(block).toByteArray()) + } + + fun fieldVarint(fieldNumber: Int, value: Int) { + writeVarint((fieldNumber shl 3) or 0) + writeVarint(value) + } + + fun fieldFloat(fieldNumber: Int, value: Float) { + writeVarint((fieldNumber shl 3) or 5) + val bits = java.lang.Float.floatToIntBits(value) + bytes += (bits and 0xff).toByte() + bytes += ((bits ushr 8) and 0xff).toByte() + bytes += ((bits ushr 16) and 0xff).toByte() + bytes += ((bits ushr 24) and 0xff).toByte() + } + + fun toByteArray(): ByteArray = bytes.toByteArray() + + private fun writeVarint(value: Int) { + var remaining = value + while (remaining and 0x7f.inv() != 0) { + bytes += ((remaining and 0x7f) or 0x80).toByte() + remaining = remaining ushr 7 + } + bytes += remaining.toByte() + } +}