Skip to content

Commit ef01e21

Browse files
Merge pull request #169 from SKaiNET-developers/fix/gemma-memseg-dequant-kernelless-quants
fix(gemma): dequant kernel-less quant types in NATIVE_OPTIMIZED instead of leaving raw bytes
2 parents b41e978 + bb1025f commit ef01e21

1 file changed

Lines changed: 30 additions & 7 deletions

File tree

llm-inference/gemma/src/jvmMain/kotlin/sk/ainet/models/gemma/GemmaMemSegConverter.kt

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -197,19 +197,42 @@ private fun <T : DType, V> convertOne(
197197
ctx.fromData(data as TensorData<FP32, Float>, advertisedDtype) as Tensor<T, V>
198198
}
199199
GGMLQuantizationType.Q5_K -> {
200-
// No native matmul kernel yet for Q5_K (not needed for Gemma 4
201-
// E2B Q4_K_M — it has no Q5_K tensors). Fall back to dequant.
202-
val elemCount = shape.volume
203-
val floats = DequantOps.dequantFromBytes(bytes, qt, elemCount)
204-
ctx.fromFloatArray<FP32, Float>(shape, advertisedDtype, floats) as Tensor<T, V>
200+
// No native matmul kernel yet for Q5_K. Fall back to a correct FP32 dequant.
201+
dequantPackedToFp32<T, V>(bytes, qt, shape, ctx)
205202
}
206203
else -> {
207-
println("WARNING: GemmaMemSegConverter: unsupported quant type $qt for '$name'; leaving as-is")
208-
tensor
204+
// Any other quant type without a packed SIMD kernel (Q5_0/Q5_1/Q4_1/Q2_K/…)
205+
// would otherwise be left as raw 1-D bytes, which `linearProject` then can't
206+
// transpose ("Transpose requires at least 2 dimensions"). Dequantize to a
207+
// correct FP32 `[out, in]` weight so the DSL path runs; the supported packed
208+
// types (Q4_0/Q8_0/Q4_K/Q6_K) above keep their fast SIMD form. This trades
209+
// those tensors' memory savings for correctness until a packed kernel exists.
210+
dequantPackedToFp32<T, V>(bytes, qt, shape, ctx)
209211
}
210212
}
211213
}
212214

215+
/**
216+
* Dequantize raw GGUF quant `bytes` of logical shape `[out, in]` to a canonical FP32
217+
* `[out, in]` row-major weight — the same layout `Gemma4WeightLoader.createTensor` produces
218+
* on the `DEQUANTIZE_TO_FP32` path. GGUF stores K/legacy-quant blocks column-major within a
219+
* row, so the dequantized floats are transposed column-major → row-major (rows = `in`,
220+
* cols = `out`) to match what `linearProject` (`x @ W.t()`) expects.
221+
*/
222+
@Suppress("UNCHECKED_CAST")
223+
private fun <T : DType, V> dequantPackedToFp32(
224+
bytes: ByteArray,
225+
qt: GGMLQuantizationType,
226+
shape: Shape,
227+
ctx: ExecutionContext,
228+
): Tensor<T, V> {
229+
val floats = DequantOps.dequantFromBytes(bytes, qt, shape.volume)
230+
val out = shape[0]
231+
val inDim = shape[1]
232+
val rowMajor = DequantOps.transposeColumnMajorToRowMajor(floats, inDim, out)
233+
return ctx.fromFloatArray<FP32, Float>(shape, FP32::class, rowMajor) as Tensor<T, V>
234+
}
235+
213236
/**
214237
* Wrap the raw Q-series bytes of `per_layer_token_embd.weight` in a
215238
* [GemmaPerLayerTokenEmbedTensorData] that dequants one row at a time.

0 commit comments

Comments
 (0)