Skip to content

Commit 804cd74

Browse files
michalharakalclaude
andcommitted
fix(llama): dequantize Q4_1 (and all non-packed quant types) in DecoderGgufMemSegConverter
DecoderGgufMemSegConverter only handled Q4_0/Q8_0 (packed) and Q4_K/Q5_K/Q6_K (dequant); every other quant type fell into an else branch that logged a warning and passed the raw quant bytes through unchanged. The forward pass then crashed deep inside matmul with a dtype/layout mismatch (e.g. Q4_1 Qwen3 models: 'unsupported quant type Q4_1 for blk.0.ffn_down.weight'). Route the else branch through DequantOps.dequantFromBytes to FP32 — the same memory-for-correctness trade-off already used for K-quants. This covers Q4_1, Q5_0, Q5_1, Q8_1, IQ4_NL/XS, TQ1/2_0, etc. (all already implemented in skainet-io-gguf). DequantOps throws for genuinely unknown types, so an unsupported model now fails explicitly at load time instead of silently passing through and crashing later inside matmul. Adds a regression test that a Q4_1 weight is dequantized to its logical 2D FP32 shape rather than passed through as 1D bytes. Closes #654 Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 0a2185c commit 804cd74

2 files changed

Lines changed: 76 additions & 16 deletions

File tree

llm-inference/llama/src/jvmMain/kotlin/sk/ainet/models/llama/DecoderGgufMemSegConverter.kt

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,20 @@ import java.lang.foreign.Arena
2727
* [Q8MemorySegmentTensorData] with the **logical** matrix shape derived
2828
* from metadata. Upstream `DefaultCpuOpsJvm.matmul` and `transpose`
2929
* detect the markers and dispatch quant-aware kernels at forward time.
30-
* - **Q4_K / Q5_K / Q6_K** → dequantized to FP32. The packed K-quant kernels
31-
* are MemSeg-only on a hot path the DSL doesn't yet route through, so this
32-
* trades memory for correctness. Same trade-off the legacy converter
33-
* makes for K-quants.
30+
* - **Every other quant type** (Q4_1, Q5_0, Q5_1, Q8_1, the K-quants
31+
* Q4_K / Q5_K / Q6_K, IQ4_NL/XS, TQ1/2_0, ...) → dequantized to FP32. None
32+
* of these has a packed MemSeg kernel on the hot path the DSL routes
33+
* through, so this trades memory for correctness — the same trade-off the
34+
* legacy converter makes for K-quants. [DequantOps.dequantFromBytes] throws
35+
* for genuinely unknown types, so an unsupported model fails explicitly at
36+
* load time instead of silently passing bytes through and crashing later
37+
* inside matmul (see issue #654).
3438
* - **token_embd.weight** → always dequantized to FP32 regardless of quant
3539
* type. The Embedding layer consumes this via `gather`, not matmul, so it
3640
* needs real floats with the logical 2D shape — packed quant bytes would
3741
* be misread as FP32 values, and the loader's intermediate Int8 wrapper
3842
* stores a 1D byte-count shape that `gather` rejects.
3943
* - **FP32 (no entry in `quantTypes`)** → passed through unchanged.
40-
* - **Other quant types** → warning logged, passed through (will fail later
41-
* if the model actually hits them via matmul).
4244
*
4345
* Why logical shape matters here: the loader stores raw quant bytes via
4446
* `ctx.fromByteArray(Shape(bytes.size), Int8, bytes)` — a 1D byte-count
@@ -168,19 +170,17 @@ public object DecoderGgufMemSegConverter {
168170
@Suppress("UNCHECKED_CAST")
169171
ctx.fromData(newData as TensorData<FP32, Float>, FP32::class)
170172
}
171-
GGMLQuantizationType.Q4_K,
172-
GGMLQuantizationType.Q5_K,
173-
GGMLQuantizationType.Q6_K -> {
173+
// Every other GGUF quant type (Q4_1, Q5_0, Q5_1, Q8_1, the
174+
// K-quants, IQ4_NL/XS, TQ1/2_0, ...) has no packed MemSeg kernel
175+
// on the DSL forward path, so dequantize to FP32 here — the same
176+
// memory-for-correctness trade-off the K-quants already made.
177+
// DequantOps throws for genuinely unknown types, which turns what
178+
// used to be a silent pass-through (and a confusing crash deep
179+
// inside matmul) into an explicit failure at load time. See #654.
180+
else -> {
174181
val floats = DequantOps.dequantFromBytes(bytes, quantType, logicalShape.volume)
175182
ctx.fromFloatArray(logicalShape, FP32::class, floats)
176183
}
177-
else -> {
178-
println(
179-
"WARNING: DecoderGgufMemSegConverter: unsupported quant type $quantType for '$name'; " +
180-
"passing through unchanged. Forward pass may fail at matmul.",
181-
)
182-
tensor
183-
}
184184
}
185185
}
186186

llm-inference/llama/src/jvmTest/kotlin/sk/ainet/models/llama/DecoderGgufMemSegConverterTest.kt

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,36 @@ class DecoderGgufMemSegConverterTest {
104104
}
105105
}
106106

107+
@Test
108+
fun `Q4_1 tensor is dequantized to FP32 with logical shape`() {
109+
// Regression for #654: Q4_1 used to hit the silent pass-through
110+
// `else` branch and crash later inside matmul. It must now be
111+
// dequantized to a 2D FP32 tensor with the logical matrix shape.
112+
// ffn_down logical shape is (dim, ffn); size the raw fixture to match.
113+
val rawQ4_1 = rawQ4_1Tensor(rows = dim, cols = ffn)
114+
val weights = DecoderGgufWeights<FP32, Float>(
115+
metadata = metadata,
116+
tensors = mapOf("blk.0.ffn_down.weight" to rawQ4_1),
117+
quantTypes = mapOf("blk.0.ffn_down.weight" to GGMLQuantizationType.Q4_1),
118+
)
119+
120+
Arena.ofConfined().use { arena ->
121+
val out = DecoderGgufMemSegConverter.convert(weights, ctx, arena)
122+
val down = out.tensors.getValue("blk.0.ffn_down.weight")
123+
124+
assertEquals(
125+
Shape(dim, ffn),
126+
down.shape,
127+
"Q4_1 weight must be dequantized to its logical 2D shape, not passed through as 1D bytes",
128+
)
129+
assertTrue(
130+
down.data !is Q4MemorySegmentMarker && down.data !is Q8MemorySegmentMarker,
131+
"Q4_1 has no packed MemSeg path; it must be plain dequantized FP32, got ${down.data::class.simpleName}",
132+
)
133+
assertTrue(out.quantTypes.isEmpty(), "quantTypes should be cleared post-convert")
134+
}
135+
}
136+
107137
@Test
108138
fun `tensor count and key set are preserved`() {
109139
val q4 = rawQ4Tensor(dim, dim)
@@ -157,6 +187,36 @@ class DecoderGgufMemSegConverterTest {
157187
return tensor as Tensor<FP32, Float>
158188
}
159189

190+
/** Build a raw-byte tensor that simulates a NATIVE_OPTIMIZED Q4_1 load. */
191+
private fun rawQ4_1Tensor(rows: Int, cols: Int): Tensor<FP32, Float> {
192+
val nElements = rows * cols
193+
val blockSize = 32
194+
val bytesPerBlock = 20 // 2B d (f16) + 2B m (f16) + 16B packed nibbles
195+
val nBlocks = nElements / blockSize
196+
val nBytes = nBlocks * bytesPerBlock
197+
198+
val bytes = ByteArray(nBytes)
199+
for (block in 0 until nBlocks) {
200+
val off = block * bytesPerBlock
201+
// f16 scale d = 0.5
202+
val dBits = floatToHalf(0.5f)
203+
bytes[off] = (dBits and 0xFF).toByte()
204+
bytes[off + 1] = ((dBits shr 8) and 0xFF).toByte()
205+
// f16 min m = 0.25
206+
val mBits = floatToHalf(0.25f)
207+
bytes[off + 2] = (mBits and 0xFF).toByte()
208+
bytes[off + 3] = ((mBits shr 8) and 0xFF).toByte()
209+
// Nibble codes: 8 on both halves for simplicity (w = d*8 + m)
210+
for (i in 0 until 16) {
211+
bytes[off + 4 + i] = 0x88.toByte()
212+
}
213+
}
214+
215+
val tensor = ctx.fromByteArray<Int8, Byte>(Shape(nBytes), Int8::class, bytes)
216+
@Suppress("UNCHECKED_CAST")
217+
return tensor as Tensor<FP32, Float>
218+
}
219+
160220
/** Build a raw-byte tensor that simulates a NATIVE_OPTIMIZED Q8_0 load. */
161221
private fun rawQ8Tensor(rows: Int, cols: Int): Tensor<FP32, Float> {
162222
val nElements = rows * cols

0 commit comments

Comments
 (0)