Skip to content

Commit 00b80c0

Browse files
Merge pull request #564 from SKaiNET-developers/feature/jvm-q6k-simd-dequant
perf(q6_k): SIMD-fy dequantQ6_KBlock via ByteVector ql + qh extraction
2 parents 3ea9b5f + 697574f commit 00b80c0

1 file changed

Lines changed: 80 additions & 24 deletions

File tree

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt

Lines changed: 80 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -359,9 +359,18 @@ internal object JvmQuantizedVectorKernels {
359359
* Dequantize one 256-element Q6_K block into [scratch] starting at
360360
* [scratchOffset]. Mirrors
361361
* `DequantOps.dequantQ6KFromBytes` line-for-line — see that method for
362-
* the authoritative spec. Scalar implementation; the hot loop is the
363-
* SIMD dot-product that follows, so per-block dequant cost is
364-
* amortized by the outputDim-wide loop.
362+
* the authoritative spec.
363+
*
364+
* SIMD-fused via `ByteVector`: per `floatStep`-wide chunk of `l`,
365+
* loads one slice of `ql[qlBase+l]`, one of `ql[qlBase+l+32]`, and
366+
* one of `qh[qhBase+l]`, then assembles q1..q4 = `((qlNibble) |
367+
* ((qhSlice) << 4)) − 32` per code lane via byte AND/LSHR/OR ops,
368+
* widens to FloatVector, multiplies by per-sub-block `d·scale`, and
369+
* stores to four 32-element output regions in `scratch`. Replaces
370+
* the prior scalar loop's 32 iterations × 4 codes/iteration of
371+
* scalar shifts and multiplies with three ByteVector loads + 4 scaled
372+
* stores per chunk. Scalar tail fires only when `floatStep` doesn't
373+
* divide 16 (rare), or when the vector-load bounds guard rejects a chunk.
365374
*/
366375
private fun dequantQ6_KBlock(
367376
packedWeights: ByteArray,
@@ -378,32 +387,79 @@ internal object JvmQuantizedVectorKernels {
378387
(packedWeights[dOffset].toInt() and 0xFF)
379388
val d = halfToFloat(dBits)
380389

390+
val floatStep = floatSpecies.length()
391+
val byteLoadLen = byteSpeciesForFloat.length()
392+
381393
for (half in 0..1) {
382394
val qlBase = qlBase0 + half * 64
383395
val qhBase = qhBase0 + half * 32
384396
val scBase = scBase0 + half * 8
385397
val outBase = scratchOffset + half * 128
386-
for (l in 0 until 32) {
387-
val isIdx = l / 16
388-
389-
val ql0 = packedWeights[qlBase + l].toInt() and 0xFF
390-
val ql32 = packedWeights[qlBase + l + 32].toInt() and 0xFF
391-
val qhL = packedWeights[qhBase + l].toInt() and 0xFF
392-
393-
val q1 = ((ql0 and 0x0F) or ((qhL and 0x03) shl 4)) - 32
394-
val q2 = ((ql32 and 0x0F) or (((qhL ushr 2) and 0x03) shl 4)) - 32
395-
val q3 = ((ql0 ushr 4) or (((qhL ushr 4) and 0x03) shl 4)) - 32
396-
val q4 = ((ql32 ushr 4) or (((qhL ushr 6) and 0x03) shl 4)) - 32
397-
398-
val sc1 = packedWeights[scBase + isIdx + 0].toInt() // signed
399-
val sc2 = packedWeights[scBase + isIdx + 2].toInt()
400-
val sc3 = packedWeights[scBase + isIdx + 4].toInt()
401-
val sc4 = packedWeights[scBase + isIdx + 6].toInt()
402-
403-
scratch[outBase + l + 0] = d * sc1 * q1
404-
scratch[outBase + l + 32] = d * sc2 * q2
405-
scratch[outBase + l + 64] = d * sc3 * q3
406-
scratch[outBase + l + 96] = d * sc4 * q4
398+
399+
for (isIdx in 0..1) {
400+
val sc1 = d * packedWeights[scBase + isIdx + 0].toInt()
401+
val sc2 = d * packedWeights[scBase + isIdx + 2].toInt()
402+
val sc3 = d * packedWeights[scBase + isIdx + 4].toInt()
403+
val sc4 = d * packedWeights[scBase + isIdx + 6].toInt()
404+
val sc1Vec = FloatVector.broadcast(floatSpecies, sc1)
405+
val sc2Vec = FloatVector.broadcast(floatSpecies, sc2)
406+
val sc3Vec = FloatVector.broadcast(floatSpecies, sc3)
407+
val sc4Vec = FloatVector.broadcast(floatSpecies, sc4)
408+
val negThirtyTwo = FloatVector.broadcast(floatSpecies, -32f)
409+
410+
val lStart = isIdx * 16
411+
val lEnd = lStart + 16
412+
var l = lStart
413+
while (l + floatStep <= lEnd &&
414+
qlBase + l + byteLoadLen <= packedWeights.size
415+
) {
416+
val ql0Vec = ByteVector.fromArray(byteSpeciesForFloat, packedWeights, qlBase + l)
417+
val ql32Vec = ByteVector.fromArray(byteSpeciesForFloat, packedWeights, qlBase + l + 32)
418+
val qhVec = ByteVector.fromArray(byteSpeciesForFloat, packedWeights, qhBase + l)
419+
420+
val ql0Lo = ql0Vec.and(0x0F.toByte())
421+
val ql0Hi = ql0Vec.lanewise(VectorOperators.LSHR, 4.toByte())
422+
val ql32Lo = ql32Vec.and(0x0F.toByte())
423+
val ql32Hi = ql32Vec.lanewise(VectorOperators.LSHR, 4.toByte())
424+
425+
val qh1 = qhVec.and(0x03.toByte())
426+
val qh2 = qhVec.lanewise(VectorOperators.LSHR, 2.toByte()).and(0x03.toByte())
427+
val qh3 = qhVec.lanewise(VectorOperators.LSHR, 4.toByte()).and(0x03.toByte())
428+
val qh4 = qhVec.lanewise(VectorOperators.LSHR, 6.toByte())
429+
430+
val q1Bytes = ql0Lo.or(qh1.lanewise(VectorOperators.LSHL, 4.toByte()))
431+
val q2Bytes = ql32Lo.or(qh2.lanewise(VectorOperators.LSHL, 4.toByte()))
432+
val q3Bytes = ql0Hi.or(qh3.lanewise(VectorOperators.LSHL, 4.toByte()))
433+
val q4Bytes = ql32Hi.or(qh4.lanewise(VectorOperators.LSHL, 4.toByte()))
434+
435+
val q1F = (q1Bytes.castShape(floatSpecies, 0) as FloatVector).add(negThirtyTwo)
436+
val q2F = (q2Bytes.castShape(floatSpecies, 0) as FloatVector).add(negThirtyTwo)
437+
val q3F = (q3Bytes.castShape(floatSpecies, 0) as FloatVector).add(negThirtyTwo)
438+
val q4F = (q4Bytes.castShape(floatSpecies, 0) as FloatVector).add(negThirtyTwo)
439+
440+
q1F.mul(sc1Vec).intoArray(scratch, outBase + l + 0)
441+
q2F.mul(sc2Vec).intoArray(scratch, outBase + l + 32)
442+
q3F.mul(sc3Vec).intoArray(scratch, outBase + l + 64)
443+
q4F.mul(sc4Vec).intoArray(scratch, outBase + l + 96)
444+
445+
l += floatStep
446+
}
447+
448+
// Scalar tail: fires if floatStep doesn't divide 16, or if the bounds guard above ended the vector loop early.
449+
while (l < lEnd) {
450+
val ql0 = packedWeights[qlBase + l].toInt() and 0xFF
451+
val ql32 = packedWeights[qlBase + l + 32].toInt() and 0xFF
452+
val qhL = packedWeights[qhBase + l].toInt() and 0xFF
453+
val q1 = ((ql0 and 0x0F) or ((qhL and 0x03) shl 4)) - 32
454+
val q2 = ((ql32 and 0x0F) or (((qhL ushr 2) and 0x03) shl 4)) - 32
455+
val q3 = ((ql0 ushr 4) or (((qhL ushr 4) and 0x03) shl 4)) - 32
456+
val q4 = ((ql32 ushr 4) or (((qhL ushr 6) and 0x03) shl 4)) - 32
457+
scratch[outBase + l + 0] = sc1 * q1
458+
scratch[outBase + l + 32] = sc2 * q2
459+
scratch[outBase + l + 64] = sc3 * q3
460+
scratch[outBase + l + 96] = sc4 * q4
461+
l++
462+
}
407463
}
408464
}
409465
}

0 commit comments

Comments
 (0)