@@ -545,10 +545,11 @@ internal object JvmQuantizedVectorKernels {
545545 val subBlockSize = 32
546546 val bytesPerBlock = 144L
547547 val blocksPerRow = (inputDim + blockSize - 1 ) / blockSize
548- val codeBuf = FloatArray (subBlockSize)
549548 val scaleIdxBuf = IntArray (8 )
550549 val minIdxBuf = IntArray (8 )
551- val sumsBuf = FloatArray (2 )
550+
551+ val floatStep = floatSpecies.length()
552+ val byteLoadLen = byteSpeciesForFloat.length()
552553
553554 for (o in 0 until outputDim) {
554555 var acc = 0f
@@ -582,31 +583,68 @@ internal object JvmQuantizedVectorKernels {
582583 minIdxBuf[sb] = low4M or (high2M shl 4 )
583584 }
584585
586+ // 4 strided qs groups; each carries sbLo (lo nibbles) and sbHi (hi nibbles).
587+ // Single ByteVector load per chunk feeds both nibble accumulators —
588+ // mirrors the SIMD pipeline in PanamaVectorQ4KMatmulKernel for the
589+ // ByteArray-backed path; this kernel reads from MemorySegment via
590+ // ByteVector.fromMemorySegment for mmap'd weight buffers.
585591 for (groupJ in 0 until 4 ) {
586592 val qsRegion = codesOff + groupJ * 32L
587-
588593 val sbLo = 2 * groupJ
594+ val sbHi = sbLo + 1
589595 val inputStartLo = blockIdx * blockSize + sbLo * subBlockSize
590- if (inputStartLo < inputDim) {
591- dotQ4_KHalfNibbleSubBlockMemSeg(
592- input, inputStartLo, weightSeg, qsRegion,
593- hiNibble = false , codeBuf, sumsBuf
596+ val inputStartHi = inputStartLo + subBlockSize
597+
598+ var codeAccLo = FloatVector .zero(floatSpecies)
599+ var inputAccLo = FloatVector .zero(floatSpecies)
600+ var codeAccHi = FloatVector .zero(floatSpecies)
601+ var inputAccHi = FloatVector .zero(floatSpecies)
602+ var idx = 0
603+
604+ while (idx + floatStep <= subBlockSize) {
605+ val byteVec = ByteVector .fromMemorySegment(
606+ byteSpeciesForFloat, weightSeg, qsRegion + idx, ByteOrder .LITTLE_ENDIAN ,
594607 )
595- val scale = d * scaleIdxBuf[sbLo]
596- val offset = dMin * minIdxBuf[sbLo]
597- acc + = sumsBuf[0 ] * scale - sumsBuf[1 ] * offset
608+ val loBytes = byteVec.and (0x0F .toByte())
609+ val hiBytes = byteVec.lanewise(VectorOperators .LSHR , 4 .toByte())
610+ val codeVecLo = loBytes.castShape(floatSpecies, 0 ) as FloatVector
611+ val codeVecHi = hiBytes.castShape(floatSpecies, 0 ) as FloatVector
612+ val inVecLo = FloatVector .fromArray(floatSpecies, input, inputStartLo + idx)
613+ val inVecHi = FloatVector .fromArray(floatSpecies, input, inputStartHi + idx)
614+ codeAccLo = inVecLo.fma(codeVecLo, codeAccLo)
615+ inputAccLo = inVecLo.add(inputAccLo)
616+ codeAccHi = inVecHi.fma(codeVecHi, codeAccHi)
617+ inputAccHi = inVecHi.add(inputAccHi)
618+ idx + = floatStep
598619 }
599620
600- val sbHi = 2 * groupJ + 1
601- val inputStartHi = inputStartLo + subBlockSize
621+ var codeSumLo = codeAccLo.reduceLanes(VectorOperators .ADD )
622+ var inputSumLo = inputAccLo.reduceLanes(VectorOperators .ADD )
623+ var codeSumHi = codeAccHi.reduceLanes(VectorOperators .ADD )
624+ var inputSumHi = inputAccHi.reduceLanes(VectorOperators .ADD )
625+
626+ while (idx < subBlockSize) {
627+ val b = weightSeg.get(JAVA_BYTE_LE , qsRegion + idx).toInt() and 0xFF
628+ val codeLo = (b and 0x0F ).toFloat()
629+ val codeHi = (b ushr 4 ).toFloat()
630+ val vLo = input[inputStartLo + idx]
631+ val vHi = input[inputStartHi + idx]
632+ codeSumLo + = vLo * codeLo
633+ inputSumLo + = vLo
634+ codeSumHi + = vHi * codeHi
635+ inputSumHi + = vHi
636+ idx++
637+ }
638+
639+ val scaleLo = d * scaleIdxBuf[sbLo]
640+ val offsetLo = dMin * minIdxBuf[sbLo]
641+ val scaleHi = d * scaleIdxBuf[sbHi]
642+ val offsetHi = dMin * minIdxBuf[sbHi]
643+ if (inputStartLo < inputDim) {
644+ acc + = codeSumLo * scaleLo - inputSumLo * offsetLo
645+ }
602646 if (inputStartHi < inputDim) {
603- dotQ4_KHalfNibbleSubBlockMemSeg(
604- input, inputStartHi, weightSeg, qsRegion,
605- hiNibble = true , codeBuf, sumsBuf
606- )
607- val scale = d * scaleIdxBuf[sbHi]
608- val offset = dMin * minIdxBuf[sbHi]
609- acc + = sumsBuf[0 ] * scale - sumsBuf[1 ] * offset
647+ acc + = codeSumHi * scaleHi - inputSumHi * offsetHi
610648 }
611649 }
612650 }
@@ -615,58 +653,6 @@ internal object JvmQuantizedVectorKernels {
615653 }
616654 }
617655
618- /* *
619- * MemSeg-reading counterpart to `dotQ4_KHalfNibbleSubBlock`. Same
620- * canonical strided-nibble layout; reads the 32-byte qs group through
621- * `MemorySegment.get`.
622- */
623- private fun dotQ4_KHalfNibbleSubBlockMemSeg (
624- input : FloatArray ,
625- inputOffset : Int ,
626- weightSeg : MemorySegment ,
627- qsOffset : Long ,
628- hiNibble : Boolean ,
629- codeBuf : FloatArray ,
630- sumsOut : FloatArray ,
631- ) {
632- if (hiNibble) {
633- for (i in 0 until SUB_BLOCK_SIZE ) {
634- val b = weightSeg.get(JAVA_BYTE_LE , qsOffset + i.toLong()).toInt() and 0xFF
635- codeBuf[i] = (b ushr 4 ).toFloat()
636- }
637- } else {
638- for (i in 0 until SUB_BLOCK_SIZE ) {
639- val b = weightSeg.get(JAVA_BYTE_LE , qsOffset + i.toLong()).toInt() and 0xFF
640- codeBuf[i] = (b and 0x0F ).toFloat()
641- }
642- }
643-
644- val step = floatSpecies.length()
645- var codeAcc = FloatVector .zero(floatSpecies)
646- var inputAcc = FloatVector .zero(floatSpecies)
647- var idx = 0
648- val loopBound = floatSpecies.loopBound(SUB_BLOCK_SIZE )
649- while (idx < loopBound) {
650- val iv = FloatVector .fromArray(floatSpecies, input, inputOffset + idx)
651- val cv = FloatVector .fromArray(floatSpecies, codeBuf, idx)
652- codeAcc = iv.fma(cv, codeAcc)
653- inputAcc = iv.add(inputAcc)
654- idx + = step
655- }
656- var codeSum = codeAcc.reduceLanes(VectorOperators .ADD )
657- var inputSum = inputAcc.reduceLanes(VectorOperators .ADD )
658-
659- while (idx < SUB_BLOCK_SIZE ) {
660- val v = input[inputOffset + idx]
661- codeSum + = v * codeBuf[idx]
662- inputSum + = v
663- idx++
664- }
665-
666- sumsOut[0 ] = codeSum
667- sumsOut[1 ] = inputSum
668- }
669-
670656 /* *
671657 * Byte species matching the float species lane count — used for loading
672658 * exactly `floatSpecies.length()` bytes from a MemorySegment so that
0 commit comments