|
| 1 | +package sk.ainet.models.llama |
| 2 | + |
| 3 | +import sk.ainet.context.ExecutionContext |
| 4 | +import sk.ainet.io.gguf.GGMLQuantizationType |
| 5 | +import sk.ainet.io.gguf.dequant.DequantOps |
| 6 | +import sk.ainet.lang.tensor.Shape |
| 7 | +import sk.ainet.lang.tensor.Tensor |
| 8 | +import sk.ainet.lang.tensor.data.DenseFloatArrayTensorData |
| 9 | +import sk.ainet.lang.tensor.data.IntArrayTensorData |
| 10 | +import sk.ainet.lang.tensor.data.TensorData |
| 11 | +import sk.ainet.lang.types.DType |
| 12 | +import sk.ainet.lang.types.FP32 |
| 13 | + |
/**
 * Converts the raw-byte quantized tensors produced by a `NATIVE_OPTIMIZED` Llama load into the
 * forms the DSL matmul path consumes. This is the Llama counterpart of
 * `convertGemmaWeightsPacked`; it lives in commonMain and uses no `java.lang.foreign`, so it is
 * meant to run on Kotlin/Native (the board) and on the JVM alike.
 *
 * Each tensor is handled by the first rule that matches:
 *
 * 1. No entry in `quantTypes` (norms, f32) → passed through untouched.
 * 2. `logicalShapeFor` returns `null` (unknown 2-D layout) → passed through untouched.
 * 3. `token_embd` → FP32 dequant in natural `[vocab, embed]` order. It is gathered by row
 *    lookup rather than matmul'd, so it is neither packed nor transposed.
 * 4. `packLlamaKQuant` returns packed data (Q4_K / Q5_K / Q6_K / Q8_0 matmul weights, incl.
 *    output/lm_head) → wrapped as-is, keeping the GGUF footprint for the in-kernel dequant matmul.
 * 5. Anything else quantized without a packed kernel → FP32 dequant transposed to `[out, in]`.
 *
 * @param weights the loaded weights; returned unchanged when `quantTypes` is empty.
 * @param ctx execution context used to create the replacement tensors.
 * @return a new [DecoderGgufWeights] carrying the same metadata and `quantTypes` with the
 *   converted tensors in their original iteration order, or [weights] itself if nothing is quantized.
 */
@Suppress("UNCHECKED_CAST")
public fun convertLlamaWeightsPacked(
    weights: DecoderGgufWeights<*, *>,
    ctx: ExecutionContext,
): DecoderGgufWeights<*, *> {
    val typed = weights as DecoderGgufWeights<DType, Any>
    val quantTypes = typed.quantTypes
    if (quantTypes.isEmpty()) return weights

    val converted = linkedMapOf<String, Tensor<DType, Any>>()
    for ((tensorName, original) in typed.tensors) {
        // Rule 1: not quantized (norms, f32) — keep the tensor exactly as loaded.
        val quantType = quantTypes[tensorName]
        if (quantType == null) {
            converted[tensorName] = original
            continue
        }

        // Rule 2: unknown 2-D layout — leave as-is.
        val logicalShape = logicalShapeFor(tensorName, typed.metadata)
        if (logicalShape == null) {
            converted[tensorName] = original
            continue
        }

        val rawBytes = extractRawBytes(original.data)
        if (tensorName == LlamaTensorNames.TOKEN_EMBEDDINGS) {
            // Rule 3: token_embd is gathered (row lookup), so it must be FP32 and is never packed.
            converted[tensorName] = dequantNoTranspose(rawBytes, quantType, logicalShape, ctx)
            continue
        }

        // Rules 4 and 5: stay packed when a packed form is available, otherwise fall back to FP32.
        val packed = packLlamaKQuant<FP32>(rawBytes, quantType, logicalShape)
        converted[tensorName] = if (packed != null) {
            ctx.fromData(packed as TensorData<FP32, Float>, FP32::class) as Tensor<DType, Any>
        } else {
            dequantTransposed(rawBytes, quantType, logicalShape, ctx)
        }
    }

    return DecoderGgufWeights(typed.metadata, converted, typed.quantTypes) as DecoderGgufWeights<*, *>
}
| 65 | + |
/**
 * Dequantizes [bytes] to FP32 and wraps the values in a tensor of the given [shape] without any
 * reordering, i.e. in natural `[rows, cols]` order. Used for embeddings, which are gathered
 * rather than matmul'd.
 *
 * @param bytes raw quantized payload.
 * @param qt quantization type describing how [bytes] is encoded.
 * @param shape logical shape of the result; its volume is the number of values dequantized.
 * @param ctx execution context that builds the resulting tensor.
 */
@Suppress("UNCHECKED_CAST")
private fun dequantNoTranspose(
    bytes: ByteArray,
    qt: GGMLQuantizationType,
    shape: Shape,
    ctx: ExecutionContext,
): Tensor<DType, Any> {
    val elementCount = shape.volume
    val values = DequantOps.dequantFromBytes(bytes, qt, elementCount)
    val dense = DenseFloatArrayTensorData<FP32>(shape, values)
    return ctx.fromData(dense, FP32::class) as Tensor<DType, Any>
}
| 77 | + |
/**
 * Dequantizes [bytes] to FP32 and reorders the values into canonical `[out, in]` row-major
 * layout (GGUF is column-major within a row).
 *
 * Reads `shape[0]` as the output dimension and `shape[1]` as the input dimension, so [shape] is
 * expected to have at least two dimensions.
 *
 * @param bytes raw quantized payload.
 * @param qt quantization type describing how [bytes] is encoded.
 * @param shape logical `[out, in]` shape of the result.
 * @param ctx execution context that builds the resulting tensor.
 */
@Suppress("UNCHECKED_CAST")
private fun dequantTransposed(
    bytes: ByteArray,
    qt: GGMLQuantizationType,
    shape: Shape,
    ctx: ExecutionContext,
): Tensor<DType, Any> {
    val columnMajor = DequantOps.dequantFromBytes(bytes, qt, shape.volume)
    val outDim = shape[0]
    val inDim = shape[1]
    // The transpose helper takes the input dimension first, then the output dimension.
    val transposed = DequantOps.transposeColumnMajorToRowMajor(columnMajor, inDim, outDim)
    val tensor = ctx.fromFloatArray<FP32, Float>(shape, FP32::class, transposed)
    return tensor as Tensor<DType, Any>
}
| 92 | + |
/**
 * Reads the raw packed bytes back out of a `NATIVE_OPTIMIZED` quant tensor.
 *
 * Two storage forms are handled (JVM IntArray / Native Byte):
 * - [IntArrayTensorData]: every buffer entry is narrowed to a byte with `toByte()`.
 * - any other [TensorData]: `shape.volume` elements are read by index, each of which must be a
 *   `Byte` or an `Int`.
 *
 * @param data tensor storage holding one packed byte per element.
 * @return a fresh [ByteArray] with one byte per element of [data].
 * @throws IllegalStateException if an element of a non-IntArray storage is neither `Byte` nor `Int`.
 */
internal fun extractRawBytes(data: TensorData<*, *>): ByteArray {
    if (data is IntArrayTensorData<*>) {
        // Fast path: copy straight from the backing buffer, keeping only the low 8 bits.
        val ints = data.buffer
        val narrowed = ByteArray(ints.size)
        for (i in 0 until ints.size) {
            narrowed[i] = ints[i].toByte()
        }
        return narrowed
    }

    val elementCount = data.shape.volume
    @Suppress("UNCHECKED_CAST")
    val elements = data as TensorData<*, Any?>
    val result = ByteArray(elementCount)
    for (i in 0 until elementCount) {
        val element = elements[i]
        result[i] = when (element) {
            is Byte -> element
            is Int -> element.toByte()
            else -> error("convertLlamaWeightsPacked: cannot read bytes from ${data::class.simpleName}")
        }
    }
    return result
}
0 commit comments