@@ -359,9 +359,18 @@ internal object JvmQuantizedVectorKernels {
359359 * Dequantize one 256-element Q6_K block into [scratch] starting at
360360 * [scratchOffset]. Mirrors
361361 * `DequantOps.dequantQ6KFromBytes` line-for-line — see that method for
362- * the authoritative spec. Scalar implementation; the hot loop is the
363- * SIMD dot-product that follows, so per-block dequant cost is
364- * amortized by the outputDim-wide loop.
362+ * the authoritative spec.
363+ *
364+ * SIMD-fused via `ByteVector`: per `floatStep`-wide chunk of `l`,
365+ * loads one slice each of `ql[qlBase+l]`, `ql[qlBase+l+32]`, and
366+ * `qh[qhBase+l]`, then assembles q1..q4 = `(qlNibble) |
367+ * ((qhBits) << 4) - 32` per code lane via byte AND/LSHR/OR ops,
368+ * widens to FloatVector, multiplies by per-sub-block `d*scale`, and
369+ * stores to four 32-element output regions in `scratch`. Replaces
370+ * the prior scalar loop's 32 iterations x 4 codes/iteration with
371+ * three ByteVector loads plus four widen-multiply-store steps per
372+ * chunk. The scalar tail fires only when `floatStep` doesn't divide
373+ * 16 or a vector load would overrun `packedWeights` (both rare).
365374 */
366375 private fun dequantQ6_KBlock (
367376 packedWeights : ByteArray ,
@@ -378,32 +387,79 @@ internal object JvmQuantizedVectorKernels {
378387 (packedWeights[dOffset].toInt() and 0xFF )
379388 val d = halfToFloat(dBits)
380389
390+ val floatStep = floatSpecies.length()
391+ val byteLoadLen = byteSpeciesForFloat.length()
392+
381393 for (half in 0 .. 1 ) {
382394 val qlBase = qlBase0 + half * 64
383395 val qhBase = qhBase0 + half * 32
384396 val scBase = scBase0 + half * 8
385397 val outBase = scratchOffset + half * 128
386- for (l in 0 until 32 ) {
387- val isIdx = l / 16
388-
389- val ql0 = packedWeights[qlBase + l].toInt() and 0xFF
390- val ql32 = packedWeights[qlBase + l + 32 ].toInt() and 0xFF
391- val qhL = packedWeights[qhBase + l].toInt() and 0xFF
392-
393- val q1 = ((ql0 and 0x0F ) or ((qhL and 0x03 ) shl 4 )) - 32
394- val q2 = ((ql32 and 0x0F ) or (((qhL ushr 2 ) and 0x03 ) shl 4 )) - 32
395- val q3 = ((ql0 ushr 4 ) or (((qhL ushr 4 ) and 0x03 ) shl 4 )) - 32
396- val q4 = ((ql32 ushr 4 ) or (((qhL ushr 6 ) and 0x03 ) shl 4 )) - 32
397-
398- val sc1 = packedWeights[scBase + isIdx + 0 ].toInt() // signed
399- val sc2 = packedWeights[scBase + isIdx + 2 ].toInt()
400- val sc3 = packedWeights[scBase + isIdx + 4 ].toInt()
401- val sc4 = packedWeights[scBase + isIdx + 6 ].toInt()
402-
403- scratch[outBase + l + 0 ] = d * sc1 * q1
404- scratch[outBase + l + 32 ] = d * sc2 * q2
405- scratch[outBase + l + 64 ] = d * sc3 * q3
406- scratch[outBase + l + 96 ] = d * sc4 * q4
398+
399+ for (isIdx in 0 .. 1 ) {
400+ val sc1 = d * packedWeights[scBase + isIdx + 0 ].toInt()
401+ val sc2 = d * packedWeights[scBase + isIdx + 2 ].toInt()
402+ val sc3 = d * packedWeights[scBase + isIdx + 4 ].toInt()
403+ val sc4 = d * packedWeights[scBase + isIdx + 6 ].toInt()
404+ val sc1Vec = FloatVector .broadcast(floatSpecies, sc1)
405+ val sc2Vec = FloatVector .broadcast(floatSpecies, sc2)
406+ val sc3Vec = FloatVector .broadcast(floatSpecies, sc3)
407+ val sc4Vec = FloatVector .broadcast(floatSpecies, sc4)
408+ val negThirtyTwo = FloatVector .broadcast(floatSpecies, - 32f )
409+
410+ val lStart = isIdx * 16
411+ val lEnd = lStart + 16
412+ var l = lStart
413+ while (l + floatStep <= lEnd &&
414+ qlBase + l + byteLoadLen <= packedWeights.size
415+ ) {
416+ val ql0Vec = ByteVector .fromArray(byteSpeciesForFloat, packedWeights, qlBase + l)
417+ val ql32Vec = ByteVector .fromArray(byteSpeciesForFloat, packedWeights, qlBase + l + 32 )
418+ val qhVec = ByteVector .fromArray(byteSpeciesForFloat, packedWeights, qhBase + l)
419+
420+ val ql0Lo = ql0Vec.and (0x0F .toByte())
421+ val ql0Hi = ql0Vec.lanewise(VectorOperators .LSHR , 4 .toByte())
422+ val ql32Lo = ql32Vec.and (0x0F .toByte())
423+ val ql32Hi = ql32Vec.lanewise(VectorOperators .LSHR , 4 .toByte())
424+
425+ val qh1 = qhVec.and (0x03 .toByte())
426+ val qh2 = qhVec.lanewise(VectorOperators .LSHR , 2 .toByte()).and (0x03 .toByte())
427+ val qh3 = qhVec.lanewise(VectorOperators .LSHR , 4 .toByte()).and (0x03 .toByte())
428+ val qh4 = qhVec.lanewise(VectorOperators .LSHR , 6 .toByte())
429+
430+ val q1Bytes = ql0Lo.or (qh1.lanewise(VectorOperators .LSHL , 4 .toByte()))
431+ val q2Bytes = ql32Lo.or (qh2.lanewise(VectorOperators .LSHL , 4 .toByte()))
432+ val q3Bytes = ql0Hi.or (qh3.lanewise(VectorOperators .LSHL , 4 .toByte()))
433+ val q4Bytes = ql32Hi.or (qh4.lanewise(VectorOperators .LSHL , 4 .toByte()))
434+
435+ val q1F = (q1Bytes.castShape(floatSpecies, 0 ) as FloatVector ).add(negThirtyTwo)
436+ val q2F = (q2Bytes.castShape(floatSpecies, 0 ) as FloatVector ).add(negThirtyTwo)
437+ val q3F = (q3Bytes.castShape(floatSpecies, 0 ) as FloatVector ).add(negThirtyTwo)
438+ val q4F = (q4Bytes.castShape(floatSpecies, 0 ) as FloatVector ).add(negThirtyTwo)
439+
440+ q1F.mul(sc1Vec).intoArray(scratch, outBase + l + 0 )
441+ q2F.mul(sc2Vec).intoArray(scratch, outBase + l + 32 )
442+ q3F.mul(sc3Vec).intoArray(scratch, outBase + l + 64 )
443+ q4F.mul(sc4Vec).intoArray(scratch, outBase + l + 96 )
444+
445+ l + = floatStep
446+ }
447+
448+ // Scalar tail (only fires if floatStep doesn't divide 16).
449+ while (l < lEnd) {
450+ val ql0 = packedWeights[qlBase + l].toInt() and 0xFF
451+ val ql32 = packedWeights[qlBase + l + 32 ].toInt() and 0xFF
452+ val qhL = packedWeights[qhBase + l].toInt() and 0xFF
453+ val q1 = ((ql0 and 0x0F ) or ((qhL and 0x03 ) shl 4 )) - 32
454+ val q2 = ((ql32 and 0x0F ) or (((qhL ushr 2 ) and 0x03 ) shl 4 )) - 32
455+ val q3 = ((ql0 ushr 4 ) or (((qhL ushr 4 ) and 0x03 ) shl 4 )) - 32
456+ val q4 = ((ql32 ushr 4 ) or (((qhL ushr 6 ) and 0x03 ) shl 4 )) - 32
457+ scratch[outBase + l + 0 ] = sc1 * q1
458+ scratch[outBase + l + 32 ] = sc2 * q2
459+ scratch[outBase + l + 64 ] = sc3 * q3
460+ scratch[outBase + l + 96 ] = sc4 * q4
461+ l++
462+ }
407463 }
408464 }
409465 }
0 commit comments