Skip to content

Commit e2169ab

Browse files
committed
Improve llms loading smoke test
1 parent 7f2c882 commit e2169ab

5 files changed

Lines changed: 213 additions & 13 deletions

File tree

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package sk.ainet.apps.kgemma.cli
2+
3+
import sk.ainet.apps.kgemma.Gemma3nIngestion
4+
import sk.ainet.apps.kgemma.Gemma3nLoadConfig
5+
import sk.ainet.apps.kllama.GGUFTokenizer
6+
import sk.ainet.apps.llm.Tokenizer
7+
import sk.ainet.context.DirectCpuExecutionContext
8+
import sk.ainet.io.JvmRandomAccessSource
9+
import sk.ainet.io.model.QuantPolicy
10+
import sk.ainet.lang.tensor.data.MemorySegmentTensorDataFactory
11+
import sk.ainet.lang.types.FP32
12+
import java.lang.foreign.Arena
13+
import java.nio.file.Path
14+
import kotlinx.coroutines.runBlocking
15+
import kotlin.io.path.exists
16+
import kotlin.io.path.extension
17+
import kotlin.io.path.isDirectory
18+
import kotlin.io.path.readText
19+
import kotlin.system.exitProcess
20+
import kotlin.time.measureTime
21+
22+
// On-disk model formats this CLI knows how to load.
private enum class ModelFormat { GGUF, SAFETENSORS }
23+
24+
// Parsed command-line arguments for the kgemma CLI.
private data class CliArgs(
    // Path to a .gguf file or a SafeTensors model directory.
    val modelPath: Path,
    // Prompt text fed verbatim to the tokenizer.
    val prompt: String,
    // Number of generation steps (tokens to produce).
    val steps: Int,
    // Sampling temperature forwarded to the runtime.
    val temperature: Float
)
30+
31+
/**
 * Prints CLI usage information and terminates the process.
 *
 * When [errorMessage] is non-null it is written to stderr first and the
 * process exits with status 1; otherwise the exit status is 0.
 */
private fun usage(errorMessage: String? = null): Nothing {
    errorMessage?.let {
        System.err.println("Error: $it")
        System.err.println()
    }

    val lines = listOf(
        "Usage: kgemma <model> <prompt> [steps] [temperature]",
        " model Path to .gguf model or SafeTensors directory (required)",
        " prompt Prompt text (required)",
        " steps Generation steps (default: 32)",
        " temperature Sampling temperature (default: 0.8)",
        "",
        "Example:",
        " kgemma models/gemma-3-270m-it-Q8_0.gguf \"Hello, how are you?\" 32 0.8",
    )
    lines.forEach(::println)
    exitProcess(if (errorMessage == null) 0 else 1)
}
47+
48+
/**
 * Parses positional CLI arguments into [CliArgs].
 *
 * Expected: model path (required), prompt (required), steps
 * (optional, default 32) and temperature (optional, default 0.8).
 * Any invalid value terminates the process via [usage].
 */
private fun parseArgs(args: Array<String>): CliArgs {
    if (args.isEmpty()) usage("Missing arguments.")
    if (args[0] == "-h" || args[0] == "--help") usage()

    val modelPath = Path.of(args[0])
    val prompt = args.getOrElse(1) { usage("Prompt is required.") }
    val steps = args.getOrElse(2) { "32" }.toIntOrNull() ?: usage("Invalid steps value '${args[2]}'.")
    // Non-positive step counts are meaningless for generation; reject them
    // here instead of handing them to the runtime.
    if (steps <= 0) usage("Steps must be positive, got $steps.")
    val temperature = args.getOrElse(3) { "0.8" }.toFloatOrNull() ?: usage("Invalid temperature '${args[3]}'.")
    // Reject NaN/Infinity and negative temperatures — they would silently
    // corrupt sampling downstream.
    if (!temperature.isFinite() || temperature < 0f) usage("Temperature must be a non-negative finite number, got $temperature.")

    return CliArgs(modelPath, prompt, steps, temperature)
}
59+
60+
/**
 * Determines the model format from a filesystem path.
 *
 * A directory must contain `model.safetensors` or its sharded index file;
 * a regular file is classified by its extension. Fails with an error for
 * anything else.
 */
private fun detectFormat(path: Path): ModelFormat {
    if (!path.isDirectory()) {
        return when (path.extension.lowercase()) {
            "gguf" -> ModelFormat.GGUF
            "safetensors" -> ModelFormat.SAFETENSORS
            else -> error("Unsupported model format: ${path.extension}. Use .gguf or .safetensors")
        }
    }

    val hasWeights = path.resolve("model.safetensors").exists() ||
        path.resolve("model.safetensors.index.json").exists()
    if (!hasWeights) {
        error("Directory $path does not contain model.safetensors or model.safetensors.index.json")
    }
    return ModelFormat.SAFETENSORS
}
73+
74+
/**
 * kgemma CLI entry point.
 *
 * Loads a Gemma model (GGUF file or SafeTensors directory), loads the
 * matching tokenizer (embedded in the GGUF, or `tokenizer.json` next to the
 * SafeTensors weights), streams generated tokens to stdout and reports
 * throughput in tokens/second.
 */
fun main(args: Array<String>) {
    runBlocking {
        val cliArgs = parseArgs(args)
        val modelPath = cliArgs.modelPath

        if (!modelPath.exists()) error("Model not found: $modelPath")

        val format = detectFormat(modelPath)

        val memSegFactory = MemorySegmentTensorDataFactory()
        val ctx = DirectCpuExecutionContext(tensorDataFactory = memSegFactory)

        // Release native memory segments when the JVM exits.
        Runtime.getRuntime().addShutdownHook(Thread {
            memSegFactory.close()
        })

        val ingestion = Gemma3nIngestion<FP32>(
            ctx = ctx,
            dtype = FP32::class,
            config = Gemma3nLoadConfig(
                quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
                allowQuantized = true
            )
        )

        val runtime = when (format) {
            ModelFormat.GGUF -> {
                println("Loading Gemma GGUF model from $modelPath (streaming mode)...")
                ingestion.loadRuntimeStreaming {
                    JvmRandomAccessSource.open(modelPath.toString())
                }
            }
            ModelFormat.SAFETENSORS -> {
                // Accept either the directory itself or a file inside it.
                val modelDir = if (modelPath.isDirectory()) modelPath else modelPath.parent ?: modelPath
                val indexPath = modelDir.resolve("model.safetensors.index.json")
                // Prefer the sharded index when present.
                val safetensorsPath = if (indexPath.exists()) {
                    indexPath.toString()
                } else {
                    modelDir.resolve("model.safetensors").toString()
                }
                println("Loading Gemma SafeTensors model from $safetensorsPath...")
                ingestion.loadRuntimeFromSafeTensors(safetensorsPath)
            }
        }

        // Load tokenizer from GGUF or from tokenizer.json in model directory
        val tokenizer: Tokenizer = when (format) {
            ModelFormat.GGUF -> {
                println("Loading embedded GGUF tokenizer...")
                JvmRandomAccessSource.open(modelPath.toString()).use { source ->
                    GGUFTokenizer.fromRandomAccessSource(source)
                }
            }
            ModelFormat.SAFETENSORS -> {
                val modelDir = if (modelPath.isDirectory()) modelPath else modelPath.parent ?: modelPath
                val tokenizerFile = modelDir.resolve("tokenizer.json")
                if (!tokenizerFile.exists()) error("tokenizer.json not found in $modelDir")
                println("Loading tokenizer from $tokenizerFile...")
                GGUFTokenizer.fromTokenizerJson(tokenizerFile.readText())
            }
        }

        val promptTokens = tokenizer.encode(cliArgs.prompt)

        println("Generating ${cliArgs.steps} tokens with temperature=${cliArgs.temperature}...")
        println("---")
        print(cliArgs.prompt)

        val elapsed = measureTime {
            runtime.generate(prompt = promptTokens, steps = cliArgs.steps, temperature = cliArgs.temperature) { id ->
                print(tokenizer.decode(id))
            }
        }.inWholeMilliseconds

        // Guard against division by zero: sub-millisecond runs would
        // otherwise print tok/s = Infinity.
        val tokPerSec = cliArgs.steps / elapsed.coerceAtLeast(1).toDouble() * 1000
        println("\n---")
        println("tok/s: $tokPerSec")
    }
}

skainet-apps/skainet-kllama-cli/build.gradle.kts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
plugins {
22
kotlin("jvm")
33
alias(libs.plugins.shadow)
4+
application
5+
}
6+
7+
application {
8+
mainClass.set("sk.ainet.apps.kllama.cli.MainKt")
49
}
510

611
dependencies {

skainet-models/skainet-model-gemma/src/commonMain/kotlin/sk/ainet/models/gemma/Gemma3nWeightLoader.kt

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,22 @@ public class Gemma3nWeightLoader private constructor(
172172
}
173173
}
174174

175+
// Output weight: use dedicated tensor or fall back to weight tying (reuse token embeddings)
176+
val outputRt = tensorByName[Gemma3nTensorNames.OUTPUT_WEIGHT]
177+
if (outputRt != null) {
178+
val tensor: Tensor<T, V> = readerTensorToTensor(ctx, dtype, reader, outputRt, metadata)
179+
onTensorLoaded(Gemma3nTensorNames.OUTPUT_WEIGHT, tensor)
180+
if (quantPolicy == QuantPolicy.RAW_BYTES && outputRt.tensorType != GGMLQuantizationType.F32) {
181+
quantCallback?.invoke(Gemma3nTensorNames.OUTPUT_WEIGHT, outputRt.tensorType)
182+
}
183+
} else {
184+
// Weight tying: reuse token_embd.weight as output.weight (common in Gemma models)
185+
val embedRt = tensorByName[Gemma3nTensorNames.TOKEN_EMBEDDINGS]
186+
?: error("Missing both output.weight and token_embd.weight — cannot resolve LM head")
187+
val tensor: Tensor<T, V> = readerTensorToTensor(ctx, dtype, reader, embedRt, metadata)
188+
onTensorLoaded(Gemma3nTensorNames.OUTPUT_WEIGHT, tensor)
189+
}
190+
175191
// Optional tensors
176192
loadOptionalTensors(ctx, dtype, reader, tensorByName, onTensorLoaded, metadata)
177193

@@ -209,6 +225,22 @@ public class Gemma3nWeightLoader private constructor(
209225
}
210226
}
211227

228+
// Output weight: use dedicated tensor or fall back to weight tying (reuse token embeddings)
229+
val outputSt = tensorByName[Gemma3nTensorNames.OUTPUT_WEIGHT]
230+
if (outputSt != null) {
231+
val tensor: Tensor<T, V> = streamingTensorToTensor(ctx, dtype, reader, outputSt, metadata)
232+
onTensorLoaded(Gemma3nTensorNames.OUTPUT_WEIGHT, tensor)
233+
if (quantPolicy == QuantPolicy.RAW_BYTES && outputSt.tensorType != GGMLQuantizationType.F32) {
234+
quantCallback?.invoke(Gemma3nTensorNames.OUTPUT_WEIGHT, outputSt.tensorType)
235+
}
236+
} else {
237+
// Weight tying: reuse token_embd.weight as output.weight (common in Gemma models)
238+
val embedSt = tensorByName[Gemma3nTensorNames.TOKEN_EMBEDDINGS]
239+
?: error("Missing both output.weight and token_embd.weight — cannot resolve LM head")
240+
val tensor: Tensor<T, V> = streamingTensorToTensor(ctx, dtype, reader, embedSt, metadata)
241+
onTensorLoaded(Gemma3nTensorNames.OUTPUT_WEIGHT, tensor)
242+
}
243+
212244
// Optional tensors
213245
loadOptionalStreamingTensors(ctx, dtype, reader, tensorByName, onTensorLoaded, metadata)
214246

@@ -421,7 +453,8 @@ public class Gemma3nWeightLoader private constructor(
421453
val names = mutableListOf<String>()
422454
names += Gemma3nTensorNames.TOKEN_EMBEDDINGS
423455
names += Gemma3nTensorNames.OUTPUT_NORM
424-
names += Gemma3nTensorNames.OUTPUT_WEIGHT
456+
// OUTPUT_WEIGHT is handled separately — many Gemma models use weight tying
457+
// (no output.weight tensor; the token embedding is reused as the LM head).
425458

426459
repeat(metadata.blockCount) { layer ->
427460
names += Gemma3nTensorNames.inputLayernorm(layer)

smoke-models.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
{
99
"name": "Llama-3.2-1B-Q8",
1010
"runner": "kllama",
11-
"model": "~/.lmstudio/models/llama-3.2-1b/llama-3.2-1b-q8_0.gguf",
11+
"model": "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q8_0.gguf",
1212
"format": "gguf"
1313
},
1414
{
1515
"name": "Gemma-2B-SafeTensors",
1616
"runner": "kgemma",
17-
"model": "~/.cache/huggingface/models/gemma-2b",
18-
"format": "safetensors",
17+
"model": "unsloth/gemma-3-270m-it-GGUF/gemma-3-270m-it-Q8_0.gguf",
18+
"format": "gguf",
1919
"steps": 16
2020
},
2121
{

smoke-test.sh

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
# ./smoke-test.sh /path/to/models # scan custom directory (legacy)
1313
# ./smoke-test.sh model1.gguf model2.gguf # run specific files (legacy)
1414
#
15+
# Environment variables:
16+
# MODELS_ROOT Root directory for resolving relative model paths in the
17+
# JSON config. Absolute paths (/ or ~/) are unaffected.
18+
# In legacy mode, used as the default scan directory.
19+
#
1520
set -euo pipefail
1621

1722
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
@@ -32,7 +37,7 @@ separator() {
3237
# Maps runner name → Gradle task
3338
runner_task() {
3439
case "$1" in
35-
kllama) echo ":skainet-apps:skainet-kllama:run" ;;
40+
kllama) echo ":skainet-apps:skainet-kllama-cli:run" ;;
3641
kgemma) echo ":skainet-apps:skainet-kgemma:jvmRun" ;;
3742
kbert) echo ":skainet-apps:skainet-kbert-cli:run" ;;
3843
*) echo "UNKNOWN"; return 1 ;;
@@ -42,9 +47,9 @@ runner_task() {
4247
# Maps runner name → compile task
4348
runner_compile_task() {
4449
case "$1" in
45-
kllama) echo ":skainet-apps:skainet-kllama:jvmMainClasses" ;;
50+
kllama) echo ":skainet-apps:skainet-kllama-cli:classes" ;;
4651
kgemma) echo ":skainet-apps:skainet-kgemma:jvmMainClasses" ;;
47-
kbert) echo ":skainet-apps:skainet-kbert-cli:jvmMainClasses" ;;
52+
kbert) echo ":skainet-apps:skainet-kbert-cli:mainClasses" ;;
4853
*) echo "UNKNOWN"; return 1 ;;
4954
esac
5055
}
@@ -54,7 +59,7 @@ runner_args() {
5459
local runner="$1" model="$2" prompt="$3" steps="$4" temp="$5" doc="${6:-}"
5560

5661
case "$runner" in
57-
kllama) echo "-m ${model} -s ${steps} -k ${temp} ${prompt}" ;;
62+
kllama) echo "-m ${model} -s ${steps} -k ${temp} \"${prompt}\"" ;;
5863
kgemma) echo "${model} \"${prompt}\" ${steps} ${temp}" ;;
5964
kbert)
6065
if [[ -n "$doc" ]]; then
@@ -66,11 +71,15 @@ runner_args() {
6671
esac
6772
}
6873

69-
# Expand ~ to $HOME in a path
74+
# Expand ~ to $HOME in a path; prepend MODELS_ROOT for relative paths
7075
expand_path() {
7176
local p="$1"
7277
if [[ "$p" == "~/"* ]]; then
7378
echo "${HOME}/${p#\~/}"
79+
elif [[ "$p" == /* ]]; then
80+
echo "$p"
81+
elif [[ -n "${MODELS_ROOT:-}" ]]; then
82+
echo "${MODELS_ROOT%/}/${p}"
7483
else
7584
echo "$p"
7685
fi
@@ -129,6 +138,7 @@ print(f'DEF_TEMP={d.get(\"temperature\", 0.0)}')
129138

130139
echo -e "${BOLD}SKaiNET Smoke Test${RESET} (config: $(basename "$CONFIG_FILE"))"
131140
echo -e "Models: ${CYAN}${MODEL_COUNT}${RESET}"
141+
[[ -n "${MODELS_ROOT:-}" ]] && echo -e "Models root: ${MODELS_ROOT}"
132142
echo -e "Default prompt: \"${DEF_PROMPT}\""
133143
echo -e "Default steps: ${DEF_STEPS}"
134144
echo -e "Default temperature: ${DEF_TEMP}"
@@ -253,8 +263,8 @@ fi
253263
PROMPT="${SMOKE_PROMPT:-The capital of France is}"
254264
STEPS="${SMOKE_STEPS:-32}"
255265
TEMP="${SMOKE_TEMP:-0.0}"
256-
MODEL_DIR="${LEGACY_ARGS[0]:-$HOME/.lmstudio/models}"
257-
TASK=":skainet-apps:skainet-kllama:run"
266+
MODEL_DIR="${LEGACY_ARGS[0]:-${MODELS_ROOT:-$HOME/.lmstudio/models}}"
267+
TASK=":skainet-apps:skainet-kllama-cli:run"
258268

259269
models=()
260270

@@ -296,7 +306,7 @@ separator
296306

297307
# ── Ensure project compiles ────────────────────────────────────────────
298308
echo -e "${YELLOW}Compiling kllama (JVM)...${RESET}"
299-
if ! $GRADLE :skainet-apps:skainet-kllama:jvmMainClasses --quiet 2>&1; then
309+
if ! $GRADLE :skainet-apps:skainet-kllama-cli:classes --quiet 2>&1; then
300310
echo -e "${RED}Compilation failed.${RESET}"
301311
exit 1
302312
fi
@@ -320,7 +330,7 @@ for model in "${models[@]}"; do
320330
output_file=$(mktemp)
321331
exit_code=0
322332

323-
$GRADLE "$TASK" --quiet --args="-m ${model} -s ${STEPS} -k ${TEMP} ${PROMPT}" \
333+
$GRADLE "$TASK" --quiet --args="-m ${model} -s ${STEPS} -k ${TEMP} \"${PROMPT}\"" \
324334
> "$output_file" 2>&1 || exit_code=$?
325335

326336
end_ts=$(python3 -c 'import time; print(time.time())')

0 commit comments

Comments
 (0)