Skip to content

Commit d48f172

Browse files
Merge pull request #565 from SKaiNET-developers/feature/jvm-q4_0-simd-dot
perf(q4_0): partial-vec dotQ4_0BlockMemSeg via scratch + SIMD FMA
2 parents 00b80c0 + 7252dc3 commit d48f172

1 file changed

Lines changed: 62 additions & 11 deletions

File tree

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/JvmQuantizedVectorKernels.kt

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -518,34 +518,83 @@ internal object JvmQuantizedVectorKernels {
518518
* Q4_0 dot-product for a single block of 32 elements stored in a MemorySegment.
519519
*
520520
* Q4_0 block layout: 2 bytes f16 scale + 16 bytes packed nibbles (32 values).
521-
* Each byte packs two 4-bit codes: lo nibble = first, hi nibble = second.
521+
* Each byte packs two 4-bit codes — adjacent elements share a byte:
522+
* `code[2k] = byte[k] & 0x0F`, `code[2k+1] = byte[k] >>> 4`. Subtract
523+
* 8 for sign correction.
522524
*
523-
* Uses the preferred vector species (AVX-256 gives 8-wide, AVX-512 gives 16-wide).
525+
* Two-stage SIMD: a scalar byte-pair unpack writes 32 sign-corrected
526+
* floats into [codeBuf] (16 byte loads from the MemorySegment, two
527+
* nibbles per load — half the byte loads of the prior fully-scalar
528+
* implementation), then a vector FMA loop dot-products [codeBuf]
529+
* with the matching input slice. The nibble-pair-per-byte layout
530+
* makes a fully-fused `ByteVector` pipeline (a la
531+
* [PanamaVectorQ4KMatmulKernel]) awkward without strided gather or
532+
* lane-interleave shuffles, so this kernel keeps the scratch +
533+
* SIMD-dot shape — same approach Q4_K used before its
534+
* fused-pipeline rewrite (PR #562).
535+
*
536+
* @param codeBuf scratch FloatArray of length >= 32, supplied by
537+
* the caller so allocation amortizes across blocks.
524538
*/
525539
fun dotQ4_0BlockMemSeg(
526540
input: FloatArray,
527541
inputOffset: Int,
528542
weightSeg: MemorySegment,
529543
blockByteOffset: Long,
544+
codeBuf: FloatArray,
530545
): Float {
531546
val blockSize = 32
532547
val codesOffset = blockByteOffset + 2
533548

534549
// Read f16 scale
535550
val scale = halfToFloat(read2BytesLE(weightSeg, blockByteOffset))
536551

537-
// Q4_0: 16 packed bytes → 32 nibbles. Unpack all 32 codes to a reusable scratch array.
538-
// This is still scalar unpacking but avoids per-iteration FloatArray allocation.
539-
var sum = 0f
540-
for (idx in 0 until blockSize) {
541-
val packedByte = weightSeg.get(JAVA_BYTE_LE, codesOffset + (idx / 2).toLong()).toInt() and 0xFF
542-
val code = (if (idx % 2 == 0) (packedByte and 0x0F) else (packedByte ushr 4)).toFloat() - 8f
543-
sum += input[inputOffset + idx] * code
552+
// Unpack 16 packed bytes → 32 sign-corrected nibbles. Two
553+
// nibbles per byte load means half the byte traffic of the
554+
// straight scalar dot product.
555+
for (k in 0 until 16) {
556+
val b = weightSeg.get(JAVA_BYTE_LE, codesOffset + k.toLong()).toInt() and 0xFF
557+
codeBuf[2 * k] = (b and 0x0F).toFloat() - 8f
558+
codeBuf[2 * k + 1] = (b ushr 4).toFloat() - 8f
544559
}
545560

546-
return sum * scale
561+
// SIMD FMA dot product.
562+
val step = floatSpecies.length()
563+
var accVec = FloatVector.zero(floatSpecies)
564+
var idx = 0
565+
val loopBound = floatSpecies.loopBound(blockSize)
566+
while (idx < loopBound) {
567+
val iv = FloatVector.fromArray(floatSpecies, input, inputOffset + idx)
568+
val cv = FloatVector.fromArray(floatSpecies, codeBuf, idx)
569+
accVec = iv.fma(cv, accVec)
570+
idx += step
571+
}
572+
var acc = accVec.reduceLanes(VectorOperators.ADD)
573+
// Scalar tail — defensive only: float species lane counts are powers of two (≤ 16 lanes), all of which divide 32, so this loop should never execute on current hardware.
574+
while (idx < blockSize) {
575+
acc += input[inputOffset + idx] * codeBuf[idx]
576+
idx++
577+
}
578+
579+
return acc * scale
547580
}
548581

582+
/**
583+
* Backwards-compatible overload that allocates its own scratch
584+
* buffer. Existing callers that don't pass one still work; the
585+
* matmul-level [matmulF32Q4_0MemSeg] hoists the allocation out of
586+
* the per-block loop.
587+
*/
588+
fun dotQ4_0BlockMemSeg(
589+
input: FloatArray,
590+
inputOffset: Int,
591+
weightSeg: MemorySegment,
592+
blockByteOffset: Long,
593+
): Float = dotQ4_0BlockMemSeg(
594+
input, inputOffset, weightSeg, blockByteOffset,
595+
codeBuf = FloatArray(32),
596+
)
597+
549598
/**
550599
* F32 x Q4_0 matrix-vector multiply using MemorySegment for packed Q4 weights.
551600
*
@@ -569,14 +618,16 @@ internal object JvmQuantizedVectorKernels {
569618
val blockSize = 32
570619
val bytesPerBlock = 18L // 2 bytes scale + 16 bytes codes
571620
val blocksPerRow = (inputDim + blockSize - 1) / blockSize
621+
// Scratch hoisted out of the per-block loop — see dotQ4_0BlockMemSeg kdoc.
622+
val codeBuf = FloatArray(blockSize)
572623

573624
for (o in 0 until outputDim) {
574625
var acc = 0f
575626
for (blockIdx in 0 until blocksPerRow) {
576627
val blockOff = weightByteOffset +
577628
(o.toLong() * blocksPerRow + blockIdx) * bytesPerBlock
578629
val inputStart = blockIdx * blockSize
579-
acc += dotQ4_0BlockMemSeg(input, inputStart, weightSeg, blockOff)
630+
acc += dotQ4_0BlockMemSeg(input, inputStart, weightSeg, blockOff, codeBuf)
580631
}
581632
output[outputOffset + o] = acc
582633
}

0 commit comments

Comments
 (0)