Skip to content

Commit ae7691b

Browse files
michalharakalclaude
andcommitted
fix(cpu-ops): lazy transpose for Q4_0 too; cover all packed matmul dtypes
Follow-up to #736 (Q8_0). The transpose lazy-rewrap `when` was still missing Q4_0 — a packed type chooseQuantizedMatmulHeap dispatches — so a packed Q4_0 matmul weight through linearProject (matmul(x, transpose(W))) hit the generic FP32 path and threw `Byte cannot be cast to Float`. Add the Q4_0 case so the `when` now covers EVERY packed type that can be a matmul weight (Q4_K/Q5_K/Q6_K/Q5_0/Q5_1/Q8_0/Q4_0). Adds `transpose_preserves_every_packed_quant_type` to PackedMatmulDispatchTest: transposes a 2-D tensor of each of the 7 packed types and asserts the shape flips and the packed encoding is preserved (no FP32 fallback / no crash). Content-agnostic, runs on every platform (jvm + linuxX64). See SKaiNET-transformers#178. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent cd2bfd2 commit ae7691b

2 files changed

Lines changed: 49 additions & 4 deletions

File tree

skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import sk.ainet.lang.tensor.data.Q5_1TensorData
2424
import sk.ainet.lang.tensor.data.Q5_1BlockTensorData
2525
import sk.ainet.lang.tensor.data.Q5_0TensorData
2626
import sk.ainet.lang.tensor.data.Q5_0BlockTensorData
27+
import sk.ainet.lang.tensor.data.Q4_0BlockTensorData
2728
import sk.ainet.lang.tensor.data.Q8_0BlockTensorData
2829
import sk.ainet.lang.tensor.data.TensorData
2930
import sk.ainet.lang.tensor.data.TensorDataFactory
@@ -607,12 +608,15 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
607608
is Q6_KTensorData -> return newTensor(Q6_KBlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
608609
is Q5_1TensorData -> return newTensor(Q5_1BlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
609610
is Q5_0TensorData -> return newTensor(Q5_0BlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
610-
// Q8_0 lazy transpose: rewrap the same input-block-major bytes with
611-
// flipped shape (bytes are layout-agnostic to the [out,in] kernel
612-
// convention) so a packed Q8_0 weight (e.g. gemma's tied lm_head)
611+
// Q8_0 / Q4_0 lazy transpose: rewrap the same input-block-major bytes
612+
// with flipped shape (bytes are layout-agnostic to the [out,in] kernel
613+
// convention) so a packed weight (e.g. gemma's tied Q8_0 lm_head)
613614
// survives linearProject's transpose instead of hitting the generic
614-
// FP32 path (Byte→Float ClassCastException). See transformers #178.
615+
// FP32 path (Byte→Float ClassCastException). This `when` now covers
616+
// every quant type chooseQuantizedMatmulHeap dispatches — i.e. every
617+
// packed type that can be a matmul weight. See transformers #178.
615618
is Q8_0TensorData -> return newTensor(Q8_0BlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
619+
is Q4_0TensorData -> return newTensor(Q4_0BlockTensorData(Shape(cols, rows), d.packedData) as TensorData<T, V>, tensor.dtype, tensor)
616620
else -> {}
617621
}
618622
}

skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/tensor/ops/PackedMatmulDispatchTest.kt

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,17 @@ package sk.ainet.exec.tensor.ops
33
import kotlin.math.abs
44
import kotlin.random.Random
55
import kotlin.test.Test
6+
import kotlin.test.assertEquals
67
import kotlin.test.assertTrue
78
import sk.ainet.context.DirectCpuExecutionContext
89
import sk.ainet.lang.tensor.Shape
10+
import sk.ainet.lang.tensor.data.Q4_0BlockTensorData
911
import sk.ainet.lang.tensor.data.Q4_KBlockTensorData
12+
import sk.ainet.lang.tensor.data.Q5_0BlockTensorData
1013
import sk.ainet.lang.tensor.data.Q5_1BlockTensorData
14+
import sk.ainet.lang.tensor.data.Q5_KBlockTensorData
1115
import sk.ainet.lang.tensor.data.Q6_KBlockTensorData
16+
import sk.ainet.lang.tensor.data.Q8_0BlockTensorData
1217
import sk.ainet.lang.tensor.data.TensorData
1318
import sk.ainet.lang.types.FP32
1419

@@ -129,4 +134,40 @@ class PackedMatmulDispatchTest {
129134
@Test fun q5_1_through_ops_matmul_transpose() = run("Q5_1", inDim = 128, outDim = 16, seed = 7)
130135
@Test fun q4_k_through_ops_matmul_transpose() = run("Q4_K", inDim = 256, outDim = 12, seed = 8)
131136
@Test fun q6_k_through_ops_matmul_transpose() = run("Q6_K", inDim = 512, outDim = 8, seed = 9)
137+
138+
/**
139+
* `ops.transpose` must lazily rewrap EVERY packed quant type that can be a
140+
* matmul weight (the full `chooseQuantizedMatmulHeap` set) — flipping the
141+
* shape while keeping the same packed bytes — instead of falling into the
142+
* generic FP32 path, which casts the Byte-backed buffer to Float and throws
143+
* `ClassCastException`. Regression guard for transformers #178 (Q8_0/Q4_0
144+
* were the gaps). Content-agnostic: zero bytes, sized per block geometry.
145+
*/
146+
@Test
147+
fun transpose_preserves_every_packed_quant_type() {
148+
val outDim = 8
149+
// name -> (blockElems, bytesPerBlock, builder)
150+
val cases: List<Triple<String, Pair<Int, Int>, (Shape, ByteArray) -> TensorData<FP32, Float>>> = listOf(
151+
Triple("Q4_K", 256 to 144) { s, b -> Q4_KBlockTensorData(s, b) as TensorData<FP32, Float> },
152+
Triple("Q5_K", 256 to 176) { s, b -> Q5_KBlockTensorData(s, b) as TensorData<FP32, Float> },
153+
Triple("Q6_K", 256 to 210) { s, b -> Q6_KBlockTensorData(s, b) as TensorData<FP32, Float> },
154+
Triple("Q8_0", 32 to 34) { s, b -> Q8_0BlockTensorData(s, b) as TensorData<FP32, Float> },
155+
Triple("Q4_0", 32 to 18) { s, b -> Q4_0BlockTensorData(s, b) as TensorData<FP32, Float> },
156+
Triple("Q5_0", 32 to 22) { s, b -> Q5_0BlockTensorData(s, b) as TensorData<FP32, Float> },
157+
Triple("Q5_1", 32 to 24) { s, b -> Q5_1BlockTensorData(s, b) as TensorData<FP32, Float> },
158+
)
159+
for ((name, geom, build) in cases) {
160+
val (blockElems, bpb) = geom
161+
val inDim = blockElems // one block per row
162+
val bytes = ByteArray(outDim * (inDim / blockElems) * bpb)
163+
val w = ctx.fromData(build(Shape(outDim, inDim), bytes), FP32::class)
164+
// The bug threw here for unhandled packed types.
165+
val t = ctx.ops.transpose(w)
166+
assertEquals(Shape(inDim, outDim), t.shape, "$name: transpose did not flip shape")
167+
assertTrue(
168+
t.data::class.simpleName?.contains("Block") == true,
169+
"$name: transpose dropped the packed encoding (got ${t.data::class.simpleName})",
170+
)
171+
}
172+
}
132173
}

0 commit comments

Comments
 (0)