|
| 1 | +package sk.ainet.apps.kgemma.cli |
| 2 | + |
| 3 | +import sk.ainet.apps.kgemma.Gemma3nIngestion |
| 4 | +import sk.ainet.apps.kgemma.Gemma3nLoadConfig |
| 5 | +import sk.ainet.apps.kllama.GGUFTokenizer |
| 6 | +import sk.ainet.apps.llm.Tokenizer |
| 7 | +import sk.ainet.context.DirectCpuExecutionContext |
| 8 | +import sk.ainet.io.JvmRandomAccessSource |
| 9 | +import sk.ainet.io.model.QuantPolicy |
| 10 | +import sk.ainet.lang.tensor.data.MemorySegmentTensorDataFactory |
| 11 | +import sk.ainet.lang.types.FP32 |
| 12 | +import java.lang.foreign.Arena |
| 13 | +import java.nio.file.Path |
| 14 | +import kotlinx.coroutines.runBlocking |
| 15 | +import kotlin.io.path.exists |
| 16 | +import kotlin.io.path.extension |
| 17 | +import kotlin.io.path.isDirectory |
| 18 | +import kotlin.io.path.readText |
| 19 | +import kotlin.system.exitProcess |
| 20 | +import kotlin.time.measureTime |
| 21 | + |
/** Supported on-disk model formats accepted by this CLI. */
private enum class ModelFormat {
    /** Single-file GGUF container; the tokenizer is read from the same file. */
    GGUF,
    /** SafeTensors file, or a directory with model.safetensors / a sharded index. */
    SAFETENSORS,
}
| 23 | + |
/**
 * Parsed command-line arguments.
 *
 * @property modelPath path to a .gguf file or a SafeTensors directory
 * @property prompt prompt text to seed generation with
 * @property steps number of tokens to generate
 * @property temperature sampling temperature
 */
private data class CliArgs(
    val modelPath: Path,
    val prompt: String,
    val steps: Int,
    val temperature: Float,
)
| 30 | + |
/**
 * Prints usage information and terminates the process.
 *
 * With no [errorMessage] (e.g. `--help`) the text goes to stdout and the
 * process exits 0. With an [errorMessage] everything — including the usage
 * text — goes to stderr and the process exits 1, so that stdout stays clean
 * for piped generation output.
 *
 * Never returns ([Nothing]): always calls [exitProcess].
 */
private fun usage(errorMessage: String? = null): Nothing {
    // On error, keep stdout untouched; all diagnostics belong on stderr.
    val out = if (errorMessage == null) System.out else System.err

    if (errorMessage != null) {
        out.println("Error: $errorMessage")
        out.println()
    }

    out.println("Usage: kgemma <model> <prompt> [steps] [temperature]")
    out.println("  model        Path to .gguf model or SafeTensors directory (required)")
    out.println("  prompt       Prompt text (required)")
    out.println("  steps        Generation steps (default: 32)")
    out.println("  temperature  Sampling temperature (default: 0.8)")
    out.println()
    out.println("Example:")
    out.println("  kgemma models/gemma-3-270m-it-Q8_0.gguf \"Hello, how are you?\" 32 0.8")
    exitProcess(if (errorMessage == null) 0 else 1)
}
| 47 | + |
/**
 * Parses positional CLI arguments into [CliArgs].
 *
 * Expected order: `<model> <prompt> [steps] [temperature]`. Missing optional
 * arguments fall back to 32 steps and temperature 0.8. Any malformed or
 * out-of-range value reports the problem and exits via [usage].
 *
 * Never returns normally on invalid input ([usage] calls `exitProcess`).
 */
private fun parseArgs(args: Array<String>): CliArgs {
    if (args.isEmpty()) usage("Missing arguments.")
    if (args[0] == "-h" || args[0] == "--help") usage()

    val modelPath = Path.of(args[0])
    val prompt = args.getOrElse(1) { usage("Prompt is required.") }

    // args[2]/args[3] are safe to reference in the error paths: the defaults
    // always parse, so a parse failure implies the argument was supplied.
    val steps = args.getOrElse(2) { "32" }.toIntOrNull() ?: usage("Invalid steps value '${args[2]}'.")
    if (steps <= 0) usage("Steps must be positive, got $steps.")

    val temperature = args.getOrElse(3) { "0.8" }.toFloatOrNull() ?: usage("Invalid temperature '${args[3]}'.")
    if (temperature < 0f) usage("Temperature must be non-negative, got $temperature.")

    return CliArgs(modelPath, prompt, steps, temperature)
}
| 59 | + |
/**
 * Infers the model format from [path].
 *
 * A directory is accepted only when it contains `model.safetensors` or a
 * sharded `model.safetensors.index.json`; otherwise the file extension
 * decides. Fails with [IllegalStateException] for anything unrecognized.
 */
private fun detectFormat(path: Path): ModelFormat {
    if (path.isDirectory()) {
        val hasSafeTensors = path.resolve("model.safetensors").exists() ||
            path.resolve("model.safetensors.index.json").exists()
        if (!hasSafeTensors) {
            error("Directory $path does not contain model.safetensors or model.safetensors.index.json")
        }
        return ModelFormat.SAFETENSORS
    }

    val ext = path.extension.lowercase()
    if (ext == "gguf") return ModelFormat.GGUF
    if (ext == "safetensors") return ModelFormat.SAFETENSORS
    error("Unsupported model format: ${path.extension}. Use .gguf or .safetensors")
}
| 73 | + |
/**
 * Resolves the directory holding SafeTensors artifacts: the path itself when
 * it is a directory, otherwise its parent (falling back to the path when it
 * has no parent, e.g. a bare relative file name).
 */
private fun resolveModelDir(modelPath: Path): Path =
    if (modelPath.isDirectory()) modelPath else modelPath.parent ?: modelPath

/**
 * CLI entry point: parses arguments, loads a Gemma model (GGUF or
 * SafeTensors) plus its tokenizer, then streams generated tokens to stdout
 * and reports throughput.
 */
fun main(args: Array<String>) {
    // Argument parsing and format detection are synchronous and may exit the
    // process; do them before entering the coroutine scope.
    val cliArgs = parseArgs(args)
    val modelPath = cliArgs.modelPath

    if (!modelPath.exists()) error("Model not found: $modelPath")

    val format = detectFormat(modelPath)

    runBlocking {
        val memSegFactory = MemorySegmentTensorDataFactory()
        val ctx = DirectCpuExecutionContext(tensorDataFactory = memSegFactory)

        // Release native memory segments on JVM exit (normal exit or Ctrl-C).
        Runtime.getRuntime().addShutdownHook(Thread {
            memSegFactory.close()
        })

        val ingestion = Gemma3nIngestion<FP32>(
            ctx = ctx,
            dtype = FP32::class,
            config = Gemma3nLoadConfig(
                quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
                allowQuantized = true
            )
        )

        val runtime = when (format) {
            ModelFormat.GGUF -> {
                println("Loading Gemma GGUF model from $modelPath (streaming mode)...")
                ingestion.loadRuntimeStreaming {
                    JvmRandomAccessSource.open(modelPath.toString())
                }
            }
            ModelFormat.SAFETENSORS -> {
                val modelDir = resolveModelDir(modelPath)
                // Prefer the sharded index when present; fall back to the
                // single-file layout.
                val indexPath = modelDir.resolve("model.safetensors.index.json")
                val safetensorsPath = if (indexPath.exists()) {
                    indexPath.toString()
                } else {
                    modelDir.resolve("model.safetensors").toString()
                }
                println("Loading Gemma SafeTensors model from $safetensorsPath...")
                ingestion.loadRuntimeFromSafeTensors(safetensorsPath)
            }
        }

        // GGUF embeds its tokenizer; SafeTensors models ship tokenizer.json
        // alongside the weights.
        val tokenizer: Tokenizer = when (format) {
            ModelFormat.GGUF -> {
                println("Loading embedded GGUF tokenizer...")
                JvmRandomAccessSource.open(modelPath.toString()).use { source ->
                    GGUFTokenizer.fromRandomAccessSource(source)
                }
            }
            ModelFormat.SAFETENSORS -> {
                val modelDir = resolveModelDir(modelPath)
                val tokenizerFile = modelDir.resolve("tokenizer.json")
                if (!tokenizerFile.exists()) error("tokenizer.json not found in $modelDir")
                println("Loading tokenizer from $tokenizerFile...")
                GGUFTokenizer.fromTokenizerJson(tokenizerFile.readText())
            }
        }

        val promptTokens = tokenizer.encode(cliArgs.prompt)

        println("Generating ${cliArgs.steps} tokens with temperature=${cliArgs.temperature}...")
        println("---")
        print(cliArgs.prompt)

        val elapsed = measureTime {
            runtime.generate(prompt = promptTokens, steps = cliArgs.steps, temperature = cliArgs.temperature) { id ->
                print(tokenizer.decode(id))
            }
        }.inWholeMilliseconds.coerceAtLeast(1) // guard: sub-ms runs would divide by zero

        val tokPerSec = cliArgs.steps / elapsed.toDouble() * 1000
        println("\n---")
        println("tok/s: $tokPerSec")
    }
}
0 commit comments