Skip to content

Commit 689a283

Browse files
Merge pull request #179 from SKaiNET-developers/fix/gemma-board-embed-nocopy
fix(gemma): keep tied Q8_0 lm_head packed in eager NATIVE_OPTIMIZED path (#178)
2 parents aec06ba + f94ce6c commit 689a283

2 files changed

Lines changed: 38 additions & 15 deletions

File tree

llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaPackedWeights.kt

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import sk.ainet.io.gguf.GGMLQuantizationType
55
import sk.ainet.io.gguf.dequant.DequantOps
66
import sk.ainet.lang.tensor.Shape
77
import sk.ainet.lang.tensor.Tensor
8+
import sk.ainet.lang.tensor.data.DenseFloatArrayTensorData
89
import sk.ainet.lang.tensor.data.IntArrayTensorData
910
import sk.ainet.lang.tensor.data.TensorData
1011
import sk.ainet.lang.types.DType
@@ -48,8 +49,12 @@ public fun convertGemmaWeightsPacked(
4849
tensor // unknown 2-D layout — leave as-is
4950
} else {
5051
val bytes = extractRawBytes(tensor.data)
51-
val isEmbed = name == Gemma4TensorNames.TOKEN_EMBEDDINGS ||
52-
name == Gemma4TensorNames.OUTPUT_WEIGHT
52+
// Only the token-embedding table is gathered (row lookup) and so
53+
// must be FP32 here. `output`/lm_head is a real matmul weight —
54+
// it stays packed (FunctionGemma's tied output is Q8_0 → NEON
55+
// Q8_0 kernel, transposed lazily by ops.transpose) instead of a
56+
// second ~0.67 GB FP32 copy that would OOM the 1.9 GB board.
57+
val isEmbed = name == Gemma4TensorNames.TOKEN_EMBEDDINGS
5358
val packed = if (!isEmbed) packGemmaKQuant<FP32>(bytes, qt, shape) else null
5459
when {
5560
packed != null -> {
@@ -76,7 +81,11 @@ private fun dequantNoTranspose(
7681
ctx: ExecutionContext,
7782
): Tensor<DType, Any> {
7883
val floats = DequantOps.dequantFromBytes(bytes, qt, shape.volume)
79-
return ctx.fromFloatArray<FP32, Float>(shape, FP32::class, floats) as Tensor<DType, Any>
84+
// Wrap the dequant array directly (no-copy) rather than ctx.fromFloatArray,
85+
// which routes through BufferHandleFactory.owned and allocates a second
86+
// full-size buffer — for the 262k×640 FP32 token_embd (~0.67 GB) that
87+
// transient double is itself enough to OOM the 1.9 GB board.
88+
return ctx.fromData(DenseFloatArrayTensorData<FP32>(shape, floats), FP32::class) as Tensor<DType, Any>
8089
}
8190

8291
/**

llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaQuantLayout.kt

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import sk.ainet.lang.tensor.Shape
55
import sk.ainet.lang.tensor.data.Q4_KBlockTensorData
66
import sk.ainet.lang.tensor.data.Q5_KBlockTensorData
77
import sk.ainet.lang.tensor.data.Q6_KBlockTensorData
8+
import sk.ainet.lang.tensor.data.Q8_0BlockTensorData
89
import sk.ainet.lang.tensor.data.TensorData
910
import sk.ainet.lang.types.DType
1011

@@ -66,8 +67,8 @@ internal fun relayoutKSeriesRowMajorToBlockMajor(
6667
bytes: ByteArray,
6768
shape: Shape,
6869
bytesPerBlock: Int,
70+
blockSize: Int = 256,
6971
): ByteArray {
70-
val blockSize = 256
7172
require(shape.rank == 2) { "K-series weight must be 2D, got rank ${shape.rank}" }
7273
val outDim = shape[0]
7374
val inDim = shape[1]
@@ -88,19 +89,31 @@ internal fun relayoutKSeriesRowMajorToBlockMajor(
8889
return out
8990
}
9091

91-
/** Bytes per ggml block for the K-quant types this packer handles. */
92-
private fun kQuantBytesPerBlock(qt: GGMLQuantizationType): Int? = when (qt) {
93-
GGMLQuantizationType.Q4_K -> 144
94-
GGMLQuantizationType.Q5_K -> 176
95-
GGMLQuantizationType.Q6_K -> 210
92+
/**
93+
* Block geometry `(blockElems, bytesPerBlock)` for the quant types this packer
94+
* handles. The K-series are 256-element super-blocks; Q8_0 is a 32-element block
95+
* (f16 scale + 32 int8). All four have a first-class CPU matmul kernel + a lazy
96+
* transpose in `ops.transpose`, so all four can stay packed instead of FP32.
97+
*/
98+
private fun quantBlockLayout(qt: GGMLQuantizationType): Pair<Int, Int>? = when (qt) {
99+
GGMLQuantizationType.Q4_K -> 256 to 144
100+
GGMLQuantizationType.Q5_K -> 256 to 176
101+
GGMLQuantizationType.Q6_K -> 256 to 210
102+
GGMLQuantizationType.Q8_0 -> 32 to 34
96103
else -> null
97104
}
98105

99106
/**
100-
* Pack raw GGUF K-quant `bytes` of logical `[out, in]` shape into the
101-
* heap-packed block tensor data the matmul kernels read directly (Q4_K / Q5_K /
102-
* Q6_K). Performs the row-major → block-major relayout. Returns `null` for
103-
* non-K-quant types (caller dequantizes those to FP32).
107+
* Pack raw GGUF `bytes` of logical `[out, in]` shape into the heap-packed block
108+
* tensor data the matmul kernels read directly (Q4_K / Q5_K / Q6_K / Q8_0).
109+
* Performs the row-major → block-major relayout. Returns `null` for types
110+
* without a packed kernel (caller dequantizes those to FP32).
111+
*
112+
* Q8_0 matters for gemma's tied `output`/lm_head: FunctionGemma's token_embd is
113+
* Q8_0, so keeping the lm_head packed (vs ~0.67 GB FP32) is what lets the eager
114+
* decode fit the 1.9 GB board, and it runs on the NEON Q8_0 kernel. (Requires
115+
* the Q8_0 case in `ops.transpose` — engine — so `linearProject` can transpose
116+
* the packed weight; see transformers #178.)
104117
*
105118
* commonMain → works on JVM and Kotlin/Native alike (no MemSeg / Arena).
106119
*/
@@ -109,13 +122,14 @@ internal fun <T : DType> packGemmaKQuant(
109122
qt: GGMLQuantizationType,
110123
shape: Shape,
111124
): TensorData<T, *>? {
112-
val bpb = kQuantBytesPerBlock(qt) ?: return null
113-
val relaid = relayoutKSeriesRowMajorToBlockMajor(bytes, shape, bpb)
125+
val (blockElems, bpb) = quantBlockLayout(qt) ?: return null
126+
val relaid = relayoutKSeriesRowMajorToBlockMajor(bytes, shape, bpb, blockElems)
114127
@Suppress("UNCHECKED_CAST")
115128
return when (qt) {
116129
GGMLQuantizationType.Q4_K -> Q4_KBlockTensorData(shape, relaid) as TensorData<T, *>
117130
GGMLQuantizationType.Q5_K -> Q5_KBlockTensorData(shape, relaid) as TensorData<T, *>
118131
GGMLQuantizationType.Q6_K -> Q6_KBlockTensorData(shape, relaid) as TensorData<T, *>
132+
GGMLQuantizationType.Q8_0 -> Q8_0BlockTensorData(shape, relaid) as TensorData<T, *>
119133
else -> null
120134
}
121135
}

0 commit comments

Comments
 (0)