Skip to content

Commit e4a0799

Browse files
Merge pull request #195 from SKaiNET-developers/fix/llama-gguf-orientation
feat(llama): NATIVE_OPTIMIZED packed weight path (mirror Gemma)
2 parents 80e3cc1 + ccbd87e commit e4a0799

3 files changed

Lines changed: 227 additions & 1 deletion

File tree

llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/LlamaNetworkLoader.kt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,18 @@ public class LlamaNetworkLoader @PublishedApi internal constructor(
156156
}
157157
}
158158

159-
return applyWeightsToNetwork(weights)
159+
// NATIVE_OPTIMIZED keeps quantized tensors as raw 1-D bytes; convert them to the packed /
160+
// FP32 forms the DSL matmul + gather paths consume (mirrors the Gemma packed path).
161+
val ggufPolicy = (weightsProvider as? WeightsProvider.GgufSource)?.quantPolicy
162+
?: (weightsProvider as? WeightsProvider.GgufRandomAccess)?.quantPolicy
163+
val finalWeights: DecoderGgufWeights<T, V> = if (ggufPolicy == QuantPolicy.NATIVE_OPTIMIZED) {
164+
@Suppress("UNCHECKED_CAST")
165+
convertLlamaWeightsPacked(weights, ctx) as DecoderGgufWeights<T, V>
166+
} else {
167+
weights
168+
}
169+
170+
return applyWeightsToNetwork(finalWeights)
160171
}
161172

162173
/**
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package sk.ainet.models.llama
2+
3+
import sk.ainet.context.ExecutionContext
4+
import sk.ainet.io.gguf.GGMLQuantizationType
5+
import sk.ainet.io.gguf.dequant.DequantOps
6+
import sk.ainet.lang.tensor.Shape
7+
import sk.ainet.lang.tensor.Tensor
8+
import sk.ainet.lang.tensor.data.DenseFloatArrayTensorData
9+
import sk.ainet.lang.tensor.data.IntArrayTensorData
10+
import sk.ainet.lang.tensor.data.TensorData
11+
import sk.ainet.lang.types.DType
12+
import sk.ainet.lang.types.FP32
13+
14+
/**
15+
* commonMain (Kotlin/Native-capable) converter for `NATIVE_OPTIMIZED` Llama weights — the Llama
16+
* analogue of `convertGemmaWeightsPacked`. Turns the raw-byte quantized tensors a NATIVE_OPTIMIZED
17+
* load produces into the forms the DSL matmul path consumes:
18+
*
19+
* - **Q4_K / Q5_K / Q6_K / Q8_0 matmul weights** → heap-packed `Q*BlockTensorData` (keep the
20+
* GGUF footprint; run the in-kernel dequant matmul, NEON on the board).
21+
* - **token_embd** → FP32 dequant in `[vocab, embed]` order (gathered, not matmul'd; no transpose).
22+
* - **everything else quantized without a packed kernel** → FP32 dequant transposed to `[out, in]`.
23+
*
24+
* No `java.lang.foreign` — runs on the board (Kotlin/Native) and JVM alike.
25+
*/
26+
public fun convertLlamaWeightsPacked(
27+
weights: DecoderGgufWeights<*, *>,
28+
ctx: ExecutionContext,
29+
): DecoderGgufWeights<*, *> {
30+
@Suppress("UNCHECKED_CAST")
31+
val typed = weights as DecoderGgufWeights<DType, Any>
32+
val quantTypes = typed.quantTypes
33+
if (quantTypes.isEmpty()) return weights
34+
35+
val newTensors = linkedMapOf<String, Tensor<DType, Any>>()
36+
for ((name, tensor) in typed.tensors) {
37+
val qt = quantTypes[name]
38+
newTensors[name] = when {
39+
qt == null -> tensor // not quantized (norms, f32)
40+
else -> {
41+
val shape = logicalShapeFor(name, typed.metadata)
42+
if (shape == null) {
43+
tensor // unknown 2-D layout — leave as-is
44+
} else {
45+
val bytes = extractRawBytes(tensor.data)
46+
// token_embd is gathered (row lookup) → must be FP32. Other matrices (incl.
47+
// output/lm_head) stay packed and run the in-kernel matmul.
48+
val isEmbed = name == LlamaTensorNames.TOKEN_EMBEDDINGS
49+
val packed = if (!isEmbed) packLlamaKQuant<FP32>(bytes, qt, shape) else null
50+
when {
51+
packed != null -> {
52+
@Suppress("UNCHECKED_CAST")
53+
ctx.fromData(packed as TensorData<FP32, Float>, FP32::class) as Tensor<DType, Any>
54+
}
55+
isEmbed -> dequantNoTranspose(bytes, qt, shape, ctx)
56+
else -> dequantTransposed(bytes, qt, shape, ctx)
57+
}
58+
}
59+
}
60+
}
61+
}
62+
@Suppress("UNCHECKED_CAST")
63+
return DecoderGgufWeights(typed.metadata, newTensors, typed.quantTypes) as DecoderGgufWeights<*, *>
64+
}
65+
66+
/** Dequant to FP32 in natural `[rows, cols]` order (embeddings — gathered, not matmul'd). */
67+
@Suppress("UNCHECKED_CAST")
68+
private fun dequantNoTranspose(
69+
bytes: ByteArray,
70+
qt: GGMLQuantizationType,
71+
shape: Shape,
72+
ctx: ExecutionContext,
73+
): Tensor<DType, Any> {
74+
val floats = DequantOps.dequantFromBytes(bytes, qt, shape.volume)
75+
return ctx.fromData(DenseFloatArrayTensorData<FP32>(shape, floats), FP32::class) as Tensor<DType, Any>
76+
}
77+
78+
/** Dequant to canonical FP32 `[out, in]` row-major (GGUF is column-major within a row). */
79+
@Suppress("UNCHECKED_CAST")
80+
private fun dequantTransposed(
81+
bytes: ByteArray,
82+
qt: GGMLQuantizationType,
83+
shape: Shape,
84+
ctx: ExecutionContext,
85+
): Tensor<DType, Any> {
86+
val floats = DequantOps.dequantFromBytes(bytes, qt, shape.volume)
87+
val out = shape[0]
88+
val inDim = shape[1]
89+
val rowMajor = DequantOps.transposeColumnMajorToRowMajor(floats, inDim, out)
90+
return ctx.fromFloatArray<FP32, Float>(shape, FP32::class, rowMajor) as Tensor<DType, Any>
91+
}
92+
93+
/** Read raw packed bytes back from a NATIVE_OPTIMIZED quant tensor (JVM IntArray / Native Byte). */
94+
internal fun extractRawBytes(data: TensorData<*, *>): ByteArray {
95+
if (data is IntArrayTensorData<*>) {
96+
val buf = data.buffer
97+
return ByteArray(buf.size) { buf[it].toByte() }
98+
}
99+
val n = data.shape.volume
100+
@Suppress("UNCHECKED_CAST")
101+
val d = data as TensorData<*, Any?>
102+
return ByteArray(n) {
103+
when (val v = d[it]) {
104+
is Byte -> v
105+
is Int -> v.toByte()
106+
else -> error("convertLlamaWeightsPacked: cannot read bytes from ${data::class.simpleName}")
107+
}
108+
}
109+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package sk.ainet.models.llama
2+
3+
import sk.ainet.io.gguf.GGMLQuantizationType
4+
import sk.ainet.lang.tensor.Shape
5+
import sk.ainet.lang.tensor.data.Q4_KBlockTensorData
6+
import sk.ainet.lang.tensor.data.Q5_KBlockTensorData
7+
import sk.ainet.lang.tensor.data.Q6_KBlockTensorData
8+
import sk.ainet.lang.tensor.data.Q8_0BlockTensorData
9+
import sk.ainet.lang.tensor.data.TensorData
10+
import sk.ainet.lang.types.DType
11+
12+
/**
13+
* Platform-neutral (commonMain) layout helpers for Llama quantized weights — the Llama analogue
14+
* of `GemmaQuantLayout`. A `NATIVE_OPTIMIZED` load stores quantized tensors as 1-D byte arrays,
15+
* so the converter needs the original `[out, in]` shape (from metadata) to relayout blocks.
16+
*/
17+
18+
/**
19+
* Recover the logical 2-D `[out, in]` shape of a Llama weight from its GGUF name + metadata.
20+
* Null for tensors without a 2-D matmul layout (norms etc.). Llama has uniform per-layer dims,
21+
* so metadata is authoritative.
22+
*/
23+
internal fun logicalShapeFor(name: String, metadata: LlamaModelMetadata): Shape? {
24+
val embed = metadata.embeddingLength
25+
val vocab = metadata.vocabSize
26+
val headDim = if (metadata.headCount > 0) embed / metadata.headCount else 0
27+
val qDim = metadata.headCount * headDim
28+
val kvDim = metadata.kvHeadCount * headDim
29+
val ffn = metadata.feedForwardLength
30+
return when {
31+
name == LlamaTensorNames.TOKEN_EMBEDDINGS -> Shape(vocab, embed)
32+
name == LlamaTensorNames.OUTPUT_WEIGHT -> Shape(vocab, embed)
33+
name.startsWith("blk.") -> when {
34+
name.endsWith(".attn_q.weight") -> Shape(qDim, embed)
35+
name.endsWith(".attn_k.weight") -> Shape(kvDim, embed)
36+
name.endsWith(".attn_v.weight") -> Shape(kvDim, embed)
37+
name.endsWith(".attn_output.weight") -> Shape(embed, qDim)
38+
name.endsWith(".ffn_gate.weight") -> Shape(ffn, embed)
39+
name.endsWith(".ffn_up.weight") -> Shape(ffn, embed)
40+
name.endsWith(".ffn_down.weight") -> Shape(embed, ffn)
41+
else -> null
42+
}
43+
else -> null
44+
}
45+
}
46+
47+
/**
48+
* Re-layout GGUF K-series bytes from row-major block order to the input-block-major order the
49+
* `matmulQ{K}` kernels expect. For a `[outDim, inDim]` weight with `inDim % 256 == 0` this is a
50+
* block-level 2-D transpose; bytes inside a block are untouched. (Mirror of GemmaQuantLayout.)
51+
*/
52+
internal fun relayoutKSeriesRowMajorToBlockMajor(
53+
bytes: ByteArray,
54+
shape: Shape,
55+
bytesPerBlock: Int,
56+
blockSize: Int = 256,
57+
): ByteArray {
58+
require(shape.rank == 2) { "K-series weight must be 2D, got rank ${shape.rank}" }
59+
val outDim = shape[0]
60+
val inDim = shape[1]
61+
require(inDim % blockSize == 0) { "K-series weight inDim ($inDim) must be a multiple of $blockSize" }
62+
val blocksPerRow = inDim / blockSize
63+
val expected = outDim.toLong() * blocksPerRow.toLong() * bytesPerBlock.toLong()
64+
require(bytes.size.toLong() >= expected) {
65+
"K-series byte buffer ${bytes.size} < expected $expected for [$outDim, $inDim] @ ${bytesPerBlock}B/block"
66+
}
67+
val out = ByteArray(bytes.size)
68+
for (r in 0 until outDim) {
69+
for (b in 0 until blocksPerRow) {
70+
val srcOff = (r * blocksPerRow + b) * bytesPerBlock
71+
val dstOff = (b * outDim + r) * bytesPerBlock
72+
bytes.copyInto(out, dstOff, srcOff, srcOff + bytesPerBlock)
73+
}
74+
}
75+
return out
76+
}
77+
78+
private fun quantBlockLayout(qt: GGMLQuantizationType): Pair<Int, Int>? = when (qt) {
79+
GGMLQuantizationType.Q4_K -> 256 to 144
80+
GGMLQuantizationType.Q5_K -> 256 to 176
81+
GGMLQuantizationType.Q6_K -> 256 to 210
82+
GGMLQuantizationType.Q8_0 -> 32 to 34
83+
else -> null
84+
}
85+
86+
/**
87+
* Pack raw GGUF `bytes` of logical `[out, in]` shape into heap-packed block tensor data the
88+
* matmul kernels read directly (Q4_K / Q5_K / Q6_K / Q8_0), with the row-major → block-major
89+
* relayout. Null for types without a packed kernel (caller dequantizes those to FP32).
90+
*/
91+
internal fun <T : DType> packLlamaKQuant(
92+
bytes: ByteArray,
93+
qt: GGMLQuantizationType,
94+
shape: Shape,
95+
): TensorData<T, *>? {
96+
val (blockElems, bpb) = quantBlockLayout(qt) ?: return null
97+
val relaid = relayoutKSeriesRowMajorToBlockMajor(bytes, shape, bpb, blockElems)
98+
@Suppress("UNCHECKED_CAST")
99+
return when (qt) {
100+
GGMLQuantizationType.Q4_K -> Q4_KBlockTensorData(shape, relaid) as TensorData<T, *>
101+
GGMLQuantizationType.Q5_K -> Q5_KBlockTensorData(shape, relaid) as TensorData<T, *>
102+
GGMLQuantizationType.Q6_K -> Q6_KBlockTensorData(shape, relaid) as TensorData<T, *>
103+
GGMLQuantizationType.Q8_0 -> Q8_0BlockTensorData(shape, relaid) as TensorData<T, *>
104+
else -> null
105+
}
106+
}

0 commit comments

Comments
 (0)