@@ -518,34 +518,83 @@ internal object JvmQuantizedVectorKernels {
518518 * Q4_0 dot-product for a single block of 32 elements stored in a MemorySegment.
519519 *
520520 * Q4_0 block layout: 2 bytes f16 scale + 16 bytes packed nibbles (32 values).
521- * Each byte packs two 4-bit codes: lo nibble = first, hi nibble = second.
521+ * Each byte packs two 4-bit codes — adjacent elements share a byte:
522+ * `code[2k] = byte[k] & 0x0F`, `code[2k+1] = byte[k] >>> 4`. Subtract
523+ * 8 for sign correction.
522524 *
523- * Uses the preferred vector species (AVX-256 gives 8-wide, AVX-512 gives 16-wide).
525+ * Two-stage SIMD: a scalar byte-pair unpack writes 32 sign-corrected
526+ * floats into [codeBuf] (16 byte loads from the MemorySegment, two
527+ * nibbles per load — same memory traffic as the prior fully-scalar
528+ * implementation), then a vector FMA loop dot-products [codeBuf]
529+ * with the matching input slice. The nibble-pair-per-byte layout
530+ * makes a fully-fused `ByteVector` pipeline (a la
531+ * [PanamaVectorQ4KMatmulKernel]) awkward without strided gather or
532+ * lane-interleave shuffles, so this kernel keeps the scratch +
533+ * SIMD-dot shape — same approach Q4_K used before its
534+ * fused-pipeline rewrite (PR #562).
535+ *
536+ * @param codeBuf scratch FloatArray of length >= 32, supplied by
537+ * the caller so allocation amortizes across blocks.
524538 */
525539 fun dotQ4_0BlockMemSeg (
526540 input : FloatArray ,
527541 inputOffset : Int ,
528542 weightSeg : MemorySegment ,
529543 blockByteOffset : Long ,
544+ codeBuf : FloatArray ,
530545 ): Float {
531546 val blockSize = 32
532547 val codesOffset = blockByteOffset + 2
533548
534549 // Read f16 scale
535550 val scale = halfToFloat(read2BytesLE(weightSeg, blockByteOffset))
536551
537- // Q4_0: 16 packed bytes → 32 nibbles. Unpack all 32 codes to a reusable scratch array.
538- // This is still scalar unpacking but avoids per-iteration FloatArray allocation.
539- var sum = 0f
540- for (idx in 0 until blockSize ) {
541- val packedByte = weightSeg.get(JAVA_BYTE_LE , codesOffset + (idx / 2 ) .toLong()).toInt() and 0xFF
542- val code = ( if (idx % 2 == 0 ) (packedByte and 0x0F ) else (packedByte ushr 4 ) ).toFloat() - 8f
543- sum + = input[inputOffset + idx] * code
552+ // Unpack 16 packed bytes → 32 sign-corrected nibbles. Two
553+ // nibbles per byte load means half the byte traffic of the
554+ // straight scalar dot product.
555+ for (k in 0 until 16 ) {
556+ val b = weightSeg.get(JAVA_BYTE_LE , codesOffset + k .toLong()).toInt() and 0xFF
557+ codeBuf[ 2 * k] = (b and 0x0F ).toFloat() - 8f
558+ codeBuf[ 2 * k + 1 ] = (b ushr 4 ).toFloat() - 8f
544559 }
545560
546- return sum * scale
561+ // SIMD FMA dot product.
562+ val step = floatSpecies.length()
563+ var accVec = FloatVector .zero(floatSpecies)
564+ var idx = 0
565+ val loopBound = floatSpecies.loopBound(blockSize)
566+ while (idx < loopBound) {
567+ val iv = FloatVector .fromArray(floatSpecies, input, inputOffset + idx)
568+ val cv = FloatVector .fromArray(floatSpecies, codeBuf, idx)
569+ accVec = iv.fma(cv, accVec)
570+ idx + = step
571+ }
572+ var acc = accVec.reduceLanes(VectorOperators .ADD )
573+ // Scalar tail (only fires if floatSpecies.length() doesn't divide 32 — rare).
574+ while (idx < blockSize) {
575+ acc + = input[inputOffset + idx] * codeBuf[idx]
576+ idx++
577+ }
578+
579+ return acc * scale
547580 }
548581
582+ /* *
583+ * Backwards-compatible overload that allocates its own scratch
584+ * buffer. Existing callers that don't pass one still work; the
585+ * matmul-level [matmulF32Q4_0MemSeg] hoists the allocation out of
586+ * the per-block loop.
587+ */
588+ fun dotQ4_0BlockMemSeg (
589+ input : FloatArray ,
590+ inputOffset : Int ,
591+ weightSeg : MemorySegment ,
592+ blockByteOffset : Long ,
593+ ): Float = dotQ4_0BlockMemSeg(
594+ input, inputOffset, weightSeg, blockByteOffset,
595+ codeBuf = FloatArray (32 ),
596+ )
597+
549598 /* *
550599 * F32 x Q4_0 matrix-vector multiply using MemorySegment for packed Q4 weights.
551600 *
@@ -569,14 +618,16 @@ internal object JvmQuantizedVectorKernels {
569618 val blockSize = 32
570619 val bytesPerBlock = 18L // 2 bytes scale + 16 bytes codes
571620 val blocksPerRow = (inputDim + blockSize - 1 ) / blockSize
621+ // Scratch hoisted out of the per-block loop — see dotQ4_0BlockMemSeg kdoc.
622+ val codeBuf = FloatArray (blockSize)
572623
573624 for (o in 0 until outputDim) {
574625 var acc = 0f
575626 for (blockIdx in 0 until blocksPerRow) {
576627 val blockOff = weightByteOffset +
577628 (o.toLong() * blocksPerRow + blockIdx) * bytesPerBlock
578629 val inputStart = blockIdx * blockSize
579- acc + = dotQ4_0BlockMemSeg(input, inputStart, weightSeg, blockOff)
630+ acc + = dotQ4_0BlockMemSeg(input, inputStart, weightSeg, blockOff, codeBuf )
580631 }
581632 output[outputOffset + o] = acc
582633 }
0 commit comments