Skip to content

Commit 6662c35

Browse files
Merge pull request #130 from SKaiNET-developers/feat/native-kllama-dsl-swap
feat(kllama-native): swap CLI to DSL path; drop GPU stubs
2 parents c4d5f61 + 35aac6b commit 6662c35

1 file changed

Lines changed: 49 additions & 44 deletions

File tree

  • llm-runtime/kllama/src/nativeMain/kotlin/sk/ainet/apps/kllama/cli

llm-runtime/kllama/src/nativeMain/kotlin/sk/ainet/apps/kllama/cli/Main.kt

Lines changed: 49 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -4,56 +4,51 @@ import kotlinx.coroutines.runBlocking
44
import kotlinx.io.buffered
55
import kotlinx.io.files.Path
66
import kotlinx.io.files.SystemFileSystem
7+
import kotlin.reflect.KClass
78
import kotlin.time.measureTime
89
import sk.ainet.apps.kllama.CpuAttentionBackend
910
import sk.ainet.apps.kllama.GGUFTokenizer
10-
import sk.ainet.apps.kllama.LlamaIngestion
11-
import sk.ainet.apps.kllama.LlamaLoadConfig
11+
import sk.ainet.apps.kllama.Llama2DotCWeightLoader
12+
import sk.ainet.apps.kllama.TokenizerUtils
13+
import sk.ainet.apps.llm.InferenceRuntime
14+
import sk.ainet.apps.llm.OptimizedLLMMode
15+
import sk.ainet.apps.llm.OptimizedLLMRuntime
1216
import sk.ainet.apps.llm.Tokenizer
13-
import sk.ainet.apps.kllama.GpuAttentionBackend
1417
import sk.ainet.apps.llm.backend.BackendRegistry
1518
import sk.ainet.apps.llm.backend.availableNames
1619
import sk.ainet.apps.llm.backend.bestAvailable
1720
import sk.ainet.apps.llm.backend.find
18-
import sk.ainet.models.llama.LlamaRuntime
19-
import sk.ainet.models.llama.LlamaRuntimeInterface
20-
import sk.ainet.apps.kllama.Llama2DotCWeightLoader
21-
import sk.ainet.apps.kllama.TokenizerUtils
22-
import sk.ainet.models.llama.LlamaRuntimeWeights
23-
import sk.ainet.io.model.QuantPolicy
21+
import sk.ainet.apps.llm.generate
2422
import sk.ainet.context.ExecutionContext
23+
import sk.ainet.io.model.QuantPolicy
2524
import sk.ainet.lang.types.DType
2625
import sk.ainet.lang.types.FP16
2726
import sk.ainet.lang.types.FP32
28-
import kotlin.reflect.KClass
27+
import sk.ainet.models.llama.DecoderGgufWeightLoader
28+
import sk.ainet.models.llama.LlamaNetworkLoader
29+
import sk.ainet.models.llama.LlamaRuntime
30+
import sk.ainet.models.llama.LlamaRuntimeWeights
2931

3032
/**
 * Prints the command-line help for `kllama` to standard output and then aborts.
 *
 * The help text lists the positional arguments and the supported flags. Two lines are
 * built at call time: the default backend name comes from
 * `BackendRegistry.bestAvailable().name`, and the list of available backends comes from
 * `BackendRegistry.availableNames()`.
 *
 * In this diff view, the lines following a `-` gutter marker are the pre-change help text
 * (which still advertised `--gpu-opt` and `--graph`), and the line following the `+`
 * gutter marker is the post-change synopsis without those flags.
 *
 * @throws IllegalArgumentException always, with the message "Invalid arguments"; the
 * `Nothing` return type reflects that this function never returns normally.
 */
private fun usage(): Nothing {
31-
println("Usage: kllama <model> [tokenizer] <prompt> [steps=64] [temperature=0.8] [--backend=cpu] [--gpu-opt] [--dtype=fp16|fp32]")
33+
println("Usage: kllama <model> [tokenizer] <prompt> [steps=64] [temperature=0.8] [--backend=cpu] [--dtype=fp16|fp32]")
3234
println(" <model> Path to .gguf or .bin model")
3335
println(" <tokenizer> Path to tokenizer.bin (required for .bin, optional for .gguf)")
3436
println(" <prompt> Text prompt")
3537
println(" --backend=NAME Execution backend (default: ${BackendRegistry.bestAvailable().name})")
36-
println(" --gpu-opt Use GPU-optimized runtime (reduces CPU roundtrips)")
37-
println(" --graph Use MPSGraph compiled execution (Metal backend only)")
3838
println(" --dtype=TYPE Tensor dtype: fp16 or fp32 (default: fp32)")
3939
println(" --list-backends List available backends and exit")
4040
println("Available backends: ${BackendRegistry.availableNames().joinToString(", ")}")
4141
throw IllegalArgumentException("Invalid arguments")
4242
}
4343

4444
fun main(args: Array<String>) = runBlocking {
45-
// Register platform-specific backends
4645
registerPlatformBackends()
4746

4847
var backendName: String? = null
49-
var useGpuOpt = false
50-
var useGraph = false
5148
var dtypeStr = "fp32"
5249
val filteredArgs = args.filter { arg ->
5350
when {
5451
arg.startsWith("--backend=") -> { backendName = arg.substringAfter("="); false }
55-
arg == "--gpu-opt" -> { useGpuOpt = true; false }
56-
arg == "--graph" -> { useGraph = true; useGpuOpt = true; false }
5752
arg.startsWith("--dtype=") -> { dtypeStr = arg.substringAfter("=").lowercase(); false }
5853
arg == "--list-backends" -> {
5954
val providers = BackendRegistry.providers()
@@ -107,52 +102,62 @@ fun main(args: Array<String>) = runBlocking {
107102
val ctx = provider.createContext()
108103

109104
when (dtypeStr) {
110-
"fp16" -> runInference<FP16>(ctx, FP16::class, isGguf, modelPathStr, modelPath, useGpuOpt, useGraph, tokenizerPathStr, prompt, steps, temperature)
111-
"fp32" -> runInference<FP32>(ctx, FP32::class, isGguf, modelPathStr, modelPath, useGpuOpt, useGraph, tokenizerPathStr, prompt, steps, temperature)
105+
"fp16" -> runInference<FP16>(ctx, FP16::class, isGguf, modelPathStr, modelPath, tokenizerPathStr, prompt, steps, temperature)
106+
"fp32" -> runInference<FP32>(ctx, FP32::class, isGguf, modelPathStr, modelPath, tokenizerPathStr, prompt, steps, temperature)
112107
else -> error("Unsupported dtype: $dtypeStr. Use fp16 or fp32.")
113108
}
114109
}
115110

116-
private suspend fun <T : DType> runInference(
111+
// Reified so we can call `LlamaNetworkLoader.fromWeights<T, V>` and
112+
// `DecoderGgufWeightLoader.loadToMap<T, V>` (both `inline reified T`).
113+
// The legacy `LlamaRuntime<T>` ctor doesn't need reification — only the
114+
// DSL path does.
115+
@Suppress("DuplicatedCode")
116+
private suspend inline fun <reified T : DType> runInference(
117117
ctx: ExecutionContext,
118118
dtype: KClass<T>,
119119
isGguf: Boolean,
120120
modelPathStr: String,
121121
modelPath: Path,
122-
useGpuOpt: Boolean,
123-
useGraph: Boolean,
124122
tokenizerPathStr: String?,
125123
prompt: String,
126124
steps: Int,
127-
temperature: Float
125+
temperature: Float,
128126
) {
129-
val runtimeWeights = if (isGguf) {
130-
val ingestion = LlamaIngestion<T>(
127+
val runtime: InferenceRuntime<T>
128+
val vocabSize: Int
129+
130+
if (isGguf) {
131+
// DSL path. Native has no MemorySegment, so QuantPolicy.DEQUANTIZE_TO_FP32
132+
// is the only viable choice.
133+
println("Loading GGUF model from $modelPathStr (Llama, DSL streaming, dtype=${dtype.simpleName})...")
134+
val weights = DecoderGgufWeightLoader(
135+
sourceProvider = { SystemFileSystem.source(modelPath).buffered() },
136+
quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
137+
).loadToMap<T, Float>(ctx)
138+
val model = LlamaNetworkLoader.fromWeights(weights)
139+
runtime = OptimizedLLMRuntime(
140+
model = model,
131141
ctx = ctx,
142+
mode = OptimizedLLMMode.DIRECT,
132143
dtype = dtype,
133-
config = LlamaLoadConfig(
134-
quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
135-
allowQuantized = false
136-
)
144+
bos = weights.metadata.bosTokenId,
137145
)
138-
println("Loading GGUF model from $modelPathStr (dtype=${dtype.simpleName})...")
139-
ingestion.load {
140-
SystemFileSystem.source(modelPath).buffered()
141-
}
146+
vocabSize = weights.metadata.vocabSize
142147
} else {
148+
// BIN (Karpathy llama2.c format) — kept on legacy LlamaRuntime; the
149+
// .bin loader returns LlamaRuntimeWeights directly. Migrating .bin
150+
// to the DSL path requires a converter and isn't in scope here.
143151
println("Loading Karpathy .bin model from $modelPathStr...")
144152
@Suppress("UNCHECKED_CAST")
145-
Llama2DotCWeightLoader.load(ctx, SystemFileSystem.source(modelPath).buffered()) as LlamaRuntimeWeights<T>
153+
val runtimeWeights = Llama2DotCWeightLoader.load(ctx, SystemFileSystem.source(modelPath).buffered())
154+
as LlamaRuntimeWeights<T>
155+
val cpuBackend = CpuAttentionBackend<T>(ctx, runtimeWeights, dtype)
156+
@Suppress("DEPRECATION")
157+
runtime = LlamaRuntime<T>(ctx, runtimeWeights, cpuBackend, dtype)
158+
vocabSize = runtimeWeights.metadata.vocabSize
146159
}
147160

148-
val graphAccelerator = if (useGraph) {
149-
println("Compiling MPSGraph layer graphs...")
150-
createGraphAccelerator(ctx, runtimeWeights, dtype, 1e-5f)
151-
} else null
152-
153-
val cpuBackend = CpuAttentionBackend<T>(ctx, runtimeWeights, dtype)
154-
val runtime = LlamaRuntime<T>(ctx, runtimeWeights, cpuBackend, dtype, graphAccelerator = graphAccelerator)
155-
156161
val tokenizer: Tokenizer = if (isGguf && tokenizerPathStr == null) {
157162
println("Loading embedded GGUF tokenizer...")
158163
GGUFTokenizer.fromSource(SystemFileSystem.source(modelPath).buffered())
@@ -161,7 +166,7 @@ private suspend fun <T : DType> runInference(
161166
val tPath = Path(tPathStr)
162167
if (!SystemFileSystem.exists(tPath)) error("Tokenizer not found: $tPathStr")
163168
println("Loading tokenizer from $tPathStr...")
164-
TokenizerUtils.buildTokenizer(SystemFileSystem.source(tPath).buffered(), runtimeWeights.metadata.vocabSize)
169+
TokenizerUtils.buildTokenizer(SystemFileSystem.source(tPath).buffered(), vocabSize)
165170
}
166171

167172
val promptTokens = tokenizer.encode(prompt)

0 commit comments

Comments
 (0)