Skip to content

Commit 3ea9b5f

Browse files
Merge pull request #563 from SKaiNET-developers/feature/jvm-q4k-memseg-simd
perf(q4_k): SIMD-fy matmulF32Q4_KMemSeg via ByteVector.fromMemorySegment
2 parents 9cc73aa + 2b88cef commit 3ea9b5f

1 file changed

Lines changed: 57 additions & 71 deletions

File tree

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt

Lines changed: 57 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -545,10 +545,11 @@ internal object JvmQuantizedVectorKernels {
545545
val subBlockSize = 32
546546
val bytesPerBlock = 144L
547547
val blocksPerRow = (inputDim + blockSize - 1) / blockSize
548-
val codeBuf = FloatArray(subBlockSize)
549548
val scaleIdxBuf = IntArray(8)
550549
val minIdxBuf = IntArray(8)
551-
val sumsBuf = FloatArray(2)
550+
551+
val floatStep = floatSpecies.length()
552+
val byteLoadLen = byteSpeciesForFloat.length()
552553

553554
for (o in 0 until outputDim) {
554555
var acc = 0f
@@ -582,31 +583,68 @@ internal object JvmQuantizedVectorKernels {
582583
minIdxBuf[sb] = low4M or (high2M shl 4)
583584
}
584585

586+
// 4 strided qs groups; each carries sbLo (lo nibbles) and sbHi (hi nibbles).
587+
// Single ByteVector load per chunk feeds both nibble accumulators —
588+
// mirrors the SIMD pipeline in PanamaVectorQ4KMatmulKernel for the
589+
// ByteArray-backed path; this kernel reads from MemorySegment via
590+
// ByteVector.fromMemorySegment for mmap'd weight buffers.
585591
for (groupJ in 0 until 4) {
586592
val qsRegion = codesOff + groupJ * 32L
587-
588593
val sbLo = 2 * groupJ
594+
val sbHi = sbLo + 1
589595
val inputStartLo = blockIdx * blockSize + sbLo * subBlockSize
590-
if (inputStartLo < inputDim) {
591-
dotQ4_KHalfNibbleSubBlockMemSeg(
592-
input, inputStartLo, weightSeg, qsRegion,
593-
hiNibble = false, codeBuf, sumsBuf
596+
val inputStartHi = inputStartLo + subBlockSize
597+
598+
var codeAccLo = FloatVector.zero(floatSpecies)
599+
var inputAccLo = FloatVector.zero(floatSpecies)
600+
var codeAccHi = FloatVector.zero(floatSpecies)
601+
var inputAccHi = FloatVector.zero(floatSpecies)
602+
var idx = 0
603+
604+
while (idx + floatStep <= subBlockSize) {
605+
val byteVec = ByteVector.fromMemorySegment(
606+
byteSpeciesForFloat, weightSeg, qsRegion + idx, ByteOrder.LITTLE_ENDIAN,
594607
)
595-
val scale = d * scaleIdxBuf[sbLo]
596-
val offset = dMin * minIdxBuf[sbLo]
597-
acc += sumsBuf[0] * scale - sumsBuf[1] * offset
608+
val loBytes = byteVec.and(0x0F.toByte())
609+
val hiBytes = byteVec.lanewise(VectorOperators.LSHR, 4.toByte())
610+
val codeVecLo = loBytes.castShape(floatSpecies, 0) as FloatVector
611+
val codeVecHi = hiBytes.castShape(floatSpecies, 0) as FloatVector
612+
val inVecLo = FloatVector.fromArray(floatSpecies, input, inputStartLo + idx)
613+
val inVecHi = FloatVector.fromArray(floatSpecies, input, inputStartHi + idx)
614+
codeAccLo = inVecLo.fma(codeVecLo, codeAccLo)
615+
inputAccLo = inVecLo.add(inputAccLo)
616+
codeAccHi = inVecHi.fma(codeVecHi, codeAccHi)
617+
inputAccHi = inVecHi.add(inputAccHi)
618+
idx += floatStep
598619
}
599620

600-
val sbHi = 2 * groupJ + 1
601-
val inputStartHi = inputStartLo + subBlockSize
621+
var codeSumLo = codeAccLo.reduceLanes(VectorOperators.ADD)
622+
var inputSumLo = inputAccLo.reduceLanes(VectorOperators.ADD)
623+
var codeSumHi = codeAccHi.reduceLanes(VectorOperators.ADD)
624+
var inputSumHi = inputAccHi.reduceLanes(VectorOperators.ADD)
625+
626+
while (idx < subBlockSize) {
627+
val b = weightSeg.get(JAVA_BYTE_LE, qsRegion + idx).toInt() and 0xFF
628+
val codeLo = (b and 0x0F).toFloat()
629+
val codeHi = (b ushr 4).toFloat()
630+
val vLo = input[inputStartLo + idx]
631+
val vHi = input[inputStartHi + idx]
632+
codeSumLo += vLo * codeLo
633+
inputSumLo += vLo
634+
codeSumHi += vHi * codeHi
635+
inputSumHi += vHi
636+
idx++
637+
}
638+
639+
val scaleLo = d * scaleIdxBuf[sbLo]
640+
val offsetLo = dMin * minIdxBuf[sbLo]
641+
val scaleHi = d * scaleIdxBuf[sbHi]
642+
val offsetHi = dMin * minIdxBuf[sbHi]
643+
if (inputStartLo < inputDim) {
644+
acc += codeSumLo * scaleLo - inputSumLo * offsetLo
645+
}
602646
if (inputStartHi < inputDim) {
603-
dotQ4_KHalfNibbleSubBlockMemSeg(
604-
input, inputStartHi, weightSeg, qsRegion,
605-
hiNibble = true, codeBuf, sumsBuf
606-
)
607-
val scale = d * scaleIdxBuf[sbHi]
608-
val offset = dMin * minIdxBuf[sbHi]
609-
acc += sumsBuf[0] * scale - sumsBuf[1] * offset
647+
acc += codeSumHi * scaleHi - inputSumHi * offsetHi
610648
}
611649
}
612650
}
@@ -615,58 +653,6 @@ internal object JvmQuantizedVectorKernels {
615653
}
616654
}
617655

618-
/**
619-
* MemSeg-reading counterpart to `dotQ4_KHalfNibbleSubBlock`. Same
620-
* canonical strided-nibble layout; reads the 32-byte qs group through
621-
* `MemorySegment.get`.
622-
*/
623-
private fun dotQ4_KHalfNibbleSubBlockMemSeg(
624-
input: FloatArray,
625-
inputOffset: Int,
626-
weightSeg: MemorySegment,
627-
qsOffset: Long,
628-
hiNibble: Boolean,
629-
codeBuf: FloatArray,
630-
sumsOut: FloatArray,
631-
) {
632-
if (hiNibble) {
633-
for (i in 0 until SUB_BLOCK_SIZE) {
634-
val b = weightSeg.get(JAVA_BYTE_LE, qsOffset + i.toLong()).toInt() and 0xFF
635-
codeBuf[i] = (b ushr 4).toFloat()
636-
}
637-
} else {
638-
for (i in 0 until SUB_BLOCK_SIZE) {
639-
val b = weightSeg.get(JAVA_BYTE_LE, qsOffset + i.toLong()).toInt() and 0xFF
640-
codeBuf[i] = (b and 0x0F).toFloat()
641-
}
642-
}
643-
644-
val step = floatSpecies.length()
645-
var codeAcc = FloatVector.zero(floatSpecies)
646-
var inputAcc = FloatVector.zero(floatSpecies)
647-
var idx = 0
648-
val loopBound = floatSpecies.loopBound(SUB_BLOCK_SIZE)
649-
while (idx < loopBound) {
650-
val iv = FloatVector.fromArray(floatSpecies, input, inputOffset + idx)
651-
val cv = FloatVector.fromArray(floatSpecies, codeBuf, idx)
652-
codeAcc = iv.fma(cv, codeAcc)
653-
inputAcc = iv.add(inputAcc)
654-
idx += step
655-
}
656-
var codeSum = codeAcc.reduceLanes(VectorOperators.ADD)
657-
var inputSum = inputAcc.reduceLanes(VectorOperators.ADD)
658-
659-
while (idx < SUB_BLOCK_SIZE) {
660-
val v = input[inputOffset + idx]
661-
codeSum += v * codeBuf[idx]
662-
inputSum += v
663-
idx++
664-
}
665-
666-
sumsOut[0] = codeSum
667-
sumsOut[1] = inputSum
668-
}
669-
670656
/**
671657
* Byte species matching the float species lane count — used for loading
672658
* exactly `floatSpecies.length()` bytes from a MemorySegment so that

0 commit comments

Comments
 (0)