@@ -4,56 +4,51 @@ import kotlinx.coroutines.runBlocking
 import kotlinx.io.buffered
 import kotlinx.io.files.Path
 import kotlinx.io.files.SystemFileSystem
+import kotlin.reflect.KClass
 import kotlin.time.measureTime
 import sk.ainet.apps.kllama.CpuAttentionBackend
 import sk.ainet.apps.kllama.GGUFTokenizer
-import sk.ainet.apps.kllama.LlamaIngestion
-import sk.ainet.apps.kllama.LlamaLoadConfig
+import sk.ainet.apps.kllama.Llama2DotCWeightLoader
+import sk.ainet.apps.kllama.TokenizerUtils
+import sk.ainet.apps.llm.InferenceRuntime
+import sk.ainet.apps.llm.OptimizedLLMMode
+import sk.ainet.apps.llm.OptimizedLLMRuntime
 import sk.ainet.apps.llm.Tokenizer
-import sk.ainet.apps.kllama.GpuAttentionBackend
 import sk.ainet.apps.llm.backend.BackendRegistry
 import sk.ainet.apps.llm.backend.availableNames
 import sk.ainet.apps.llm.backend.bestAvailable
 import sk.ainet.apps.llm.backend.find
-import sk.ainet.models.llama.LlamaRuntime
-import sk.ainet.models.llama.LlamaRuntimeInterface
-import sk.ainet.apps.kllama.Llama2DotCWeightLoader
-import sk.ainet.apps.kllama.TokenizerUtils
-import sk.ainet.models.llama.LlamaRuntimeWeights
-import sk.ainet.io.model.QuantPolicy
+import sk.ainet.apps.llm.generate
 import sk.ainet.context.ExecutionContext
+import sk.ainet.io.model.QuantPolicy
 import sk.ainet.lang.types.DType
 import sk.ainet.lang.types.FP16
 import sk.ainet.lang.types.FP32
-import kotlin.reflect.KClass
+import sk.ainet.models.llama.DecoderGgufWeightLoader
+import sk.ainet.models.llama.LlamaNetworkLoader
+import sk.ainet.models.llama.LlamaRuntime
+import sk.ainet.models.llama.LlamaRuntimeWeights
 
 private fun usage(): Nothing {
-    println("Usage: kllama <model> [tokenizer] <prompt> [steps=64] [temperature=0.8] [--backend=cpu] [--gpu-opt] [--dtype=fp16|fp32]")
+    println("Usage: kllama <model> [tokenizer] <prompt> [steps=64] [temperature=0.8] [--backend=cpu] [--dtype=fp16|fp32]")
     println("  <model>          Path to .gguf or .bin model")
     println("  <tokenizer>      Path to tokenizer.bin (required for .bin, optional for .gguf)")
     println("  <prompt>         Text prompt")
     println("  --backend=NAME   Execution backend (default: ${BackendRegistry.bestAvailable().name})")
-    println("  --gpu-opt        Use GPU-optimized runtime (reduces CPU roundtrips)")
-    println("  --graph          Use MPSGraph compiled execution (Metal backend only)")
     println("  --dtype=TYPE     Tensor dtype: fp16 or fp32 (default: fp32)")
     println("  --list-backends  List available backends and exit")
     println("Available backends: ${BackendRegistry.availableNames().joinToString(", ")}")
     throw IllegalArgumentException("Invalid arguments")
 }
 
 fun main(args: Array<String>) = runBlocking {
-    // Register platform-specific backends
     registerPlatformBackends()
 
     var backendName: String? = null
-    var useGpuOpt = false
-    var useGraph = false
     var dtypeStr = "fp32"
     val filteredArgs = args.filter { arg ->
         when {
             arg.startsWith("--backend=") -> { backendName = arg.substringAfter("="); false }
-            arg == "--gpu-opt" -> { useGpuOpt = true; false }
-            arg == "--graph" -> { useGraph = true; useGpuOpt = true; false }
             arg.startsWith("--dtype=") -> { dtypeStr = arg.substringAfter("=").lowercase(); false }
             arg == "--list-backends" -> {
                 val providers = BackendRegistry.providers()
@@ -107,52 +102,62 @@ fun main(args: Array<String>) = runBlocking {
     val ctx = provider.createContext()
 
     when (dtypeStr) {
-        "fp16" -> runInference<FP16>(ctx, FP16::class, isGguf, modelPathStr, modelPath, useGpuOpt, useGraph, tokenizerPathStr, prompt, steps, temperature)
-        "fp32" -> runInference<FP32>(ctx, FP32::class, isGguf, modelPathStr, modelPath, useGpuOpt, useGraph, tokenizerPathStr, prompt, steps, temperature)
+        "fp16" -> runInference<FP16>(ctx, FP16::class, isGguf, modelPathStr, modelPath, tokenizerPathStr, prompt, steps, temperature)
+        "fp32" -> runInference<FP32>(ctx, FP32::class, isGguf, modelPathStr, modelPath, tokenizerPathStr, prompt, steps, temperature)
         else -> error("Unsupported dtype: $dtypeStr. Use fp16 or fp32.")
     }
 }
 
-private suspend fun <T : DType> runInference(
+// Reified so we can call `LlamaNetworkLoader.fromWeights<T, V>` and
+// `DecoderGgufWeightLoader.loadToMap<T, V>` (both `inline reified T`).
+// The legacy `LlamaRuntime<T>` ctor doesn't need reification — only the
+// DSL path does.
+@Suppress("DuplicatedCode")
+private suspend inline fun <reified T : DType> runInference(
     ctx: ExecutionContext,
     dtype: KClass<T>,
     isGguf: Boolean,
     modelPathStr: String,
     modelPath: Path,
-    useGpuOpt: Boolean,
-    useGraph: Boolean,
     tokenizerPathStr: String?,
     prompt: String,
     steps: Int,
-    temperature: Float
+    temperature: Float,
 ) {
-    val runtimeWeights = if (isGguf) {
-        val ingestion = LlamaIngestion<T>(
+    val runtime: InferenceRuntime<T>
+    val vocabSize: Int
+
+    if (isGguf) {
+        // DSL path. Native has no MemorySegment, so QuantPolicy.DEQUANTIZE_TO_FP32
+        // is the only viable choice.
+        println("Loading GGUF model from $modelPathStr (Llama, DSL streaming, dtype=${dtype.simpleName})...")
+        val weights = DecoderGgufWeightLoader(
+            sourceProvider = { SystemFileSystem.source(modelPath).buffered() },
+            quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
+        ).loadToMap<T, Float>(ctx)
+        val model = LlamaNetworkLoader.fromWeights(weights)
+        runtime = OptimizedLLMRuntime(
+            model = model,
             ctx = ctx,
+            mode = OptimizedLLMMode.DIRECT,
             dtype = dtype,
-            config = LlamaLoadConfig(
-                quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
-                allowQuantized = false
-            )
+            bos = weights.metadata.bosTokenId,
         )
-        println("Loading GGUF model from $modelPathStr (dtype=${dtype.simpleName})...")
-        ingestion.load {
-            SystemFileSystem.source(modelPath).buffered()
-        }
+        vocabSize = weights.metadata.vocabSize
     } else {
+        // BIN (Karpathy llama2.c format) — kept on legacy LlamaRuntime; the
+        // .bin loader returns LlamaRuntimeWeights directly. Migrating .bin
+        // to the DSL path requires a converter and isn't in scope here.
         println("Loading Karpathy .bin model from $modelPathStr...")
         @Suppress("UNCHECKED_CAST")
-        Llama2DotCWeightLoader.load(ctx, SystemFileSystem.source(modelPath).buffered()) as LlamaRuntimeWeights<T>
+        val runtimeWeights = Llama2DotCWeightLoader.load(ctx, SystemFileSystem.source(modelPath).buffered())
+            as LlamaRuntimeWeights<T>
+        val cpuBackend = CpuAttentionBackend<T>(ctx, runtimeWeights, dtype)
+        @Suppress("DEPRECATION")
+        runtime = LlamaRuntime<T>(ctx, runtimeWeights, cpuBackend, dtype)
+        vocabSize = runtimeWeights.metadata.vocabSize
     }
 
-    val graphAccelerator = if (useGraph) {
-        println("Compiling MPSGraph layer graphs...")
-        createGraphAccelerator(ctx, runtimeWeights, dtype, 1e-5f)
-    } else null
-
-    val cpuBackend = CpuAttentionBackend<T>(ctx, runtimeWeights, dtype)
-    val runtime = LlamaRuntime<T>(ctx, runtimeWeights, cpuBackend, dtype, graphAccelerator = graphAccelerator)
-
     val tokenizer: Tokenizer = if (isGguf && tokenizerPathStr == null) {
         println("Loading embedded GGUF tokenizer...")
         GGUFTokenizer.fromSource(SystemFileSystem.source(modelPath).buffered())
@@ -161,7 +166,7 @@ private suspend fun <T : DType> runInference(
         val tPath = Path(tPathStr)
         if (!SystemFileSystem.exists(tPath)) error("Tokenizer not found: $tPathStr")
         println("Loading tokenizer from $tPathStr...")
-        TokenizerUtils.buildTokenizer(SystemFileSystem.source(tPath).buffered(), runtimeWeights.metadata.vocabSize)
+        TokenizerUtils.buildTokenizer(SystemFileSystem.source(tPath).buffered(), vocabSize)
     }
 
     val promptTokens = tokenizer.encode(prompt)
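
For reference, the new GGUF branch condenses to the sketch below, specialized to FP32. It uses only constructors and calls that appear verbatim in this diff (`DecoderGgufWeightLoader`, `loadToMap`, `LlamaNetworkLoader.fromWeights`, `OptimizedLLMRuntime`); the wrapper function itself and its `ctx`/`modelPath` parameters are illustrative assumptions, not part of the commit.

    // Sketch of the DSL loading path added in this commit, FP32 only.
    // The function wrapper is hypothetical; the body mirrors runInference.
    private suspend fun loadGgufFp32(ctx: ExecutionContext, modelPath: Path): InferenceRuntime<FP32> {
        val weights = DecoderGgufWeightLoader(
            sourceProvider = { SystemFileSystem.source(modelPath).buffered() },
            quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32, // native has no MemorySegment
        ).loadToMap<FP32, Float>(ctx)
        return OptimizedLLMRuntime(
            model = LlamaNetworkLoader.fromWeights(weights),
            ctx = ctx,
            mode = OptimizedLLMMode.DIRECT,
            dtype = FP32::class,
            bos = weights.metadata.bosTokenId,
        )
    }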