package sk.ainet.exec.kernel

import jdk.incubator.vector.ByteVector
import jdk.incubator.vector.FloatVector
import jdk.incubator.vector.VectorOperators
import jdk.incubator.vector.VectorSpecies
import sk.ainet.backend.api.kernel.Q4KMatmulKernel
import sk.ainet.exec.tensor.ops.parallelChunks

/**
 * SIMD-vectorized Q4_K matmul on the JDK Vector API.
 *
 * Pipeline per 32-byte qs slab (which carries two adjacent sub-blocks
 * — sub-block `2j` in lo nibbles, sub-block `2j+1` in hi nibbles):
 * 1. `ByteVector.fromArray(byteSpeciesForFloat, weight, qsRegion+idx)` — single load.
 * 2. `loNibVec = byteVec.and(0x0F.toByte())`,
 *    `hiNibVec = byteVec.lanewise(LSHR, 4)` — extract both nibbles.
 * 3. `castShape(floatSpecies, 0)` — widen + I2F.
 * 4. `inputVec.fma(codeFloatVec, codeAcc)` — accumulate `Σ(input·code)`
 *    per sub-block; track `inputAcc = Σ(input)` separately for the
 *    lazy-`dmin` correction (see the identity below).
 * 5. After all super-blocks for a given output cell, sum across
 *    sub-blocks: `acc += scale[s] · codeSum[s] − offset[s] · inputSum[s]`
 *    with `scale[s] = d · scaleIdx[s]` and `offset[s] = dMin · minIdx[s]`.
 *
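 * The lazy-`dmin` correction in steps 4 and 5 rests on a simple identity:
 * for one sub-block with scale `s` and offset `m`, each weight dequantizes
 * to `s·code − m`, so
 *
 *     Σ(input · (s·code − m)) = s · Σ(input·code) − m · Σ(input)
 *
 * and only `Σ(input·code)` and `Σ(input)` need to be carried per sub-block;
 * the `dmin` term is applied once per sub-block instead of per element.
 *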
 * Compared to [sk.ainet.exec.tensor.ops.JvmQuantizedVectorKernels.matmulQ4_KVec]:
 * - Replaces the scalar 32-iteration nibble unpack into a scratch
 *   `FloatArray` with a single `ByteVector` load + `castShape` per
 *   `floatSpecies.length()` elements.
 * - Folds the lo and hi nibble passes into a single byte load (the
 *   existing kernel issues the byte load twice, once per nibble).
 *
 * Numerically equivalent to the existing partial-vec kernel up to FMA and
 * reordered-reduction effects; verified by parity tests with a tolerance
 * of `1e-5 · inputDim`.
 */
public object PanamaVectorQ4KMatmulKernel : Q4KMatmulKernel {

    private const val BLOCK_SIZE = 256
    private const val SUB_BLOCK_SIZE = 32
    private const val SUB_BLOCKS_PER_BLOCK = 8
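    // 144 = 2 (d) + 2 (dMin) + 12 (packed scales/mins) + 128 (qs nibbles) bytes per 256-element super-block.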
    private const val BYTES_PER_BLOCK = 144

    private val floatSpecies: VectorSpecies<Float> = FloatVector.SPECIES_PREFERRED

    /**
     * Byte species sized so `castShape(floatSpecies, 0)` consumes
     * exactly `floatSpecies.length()` bytes — same convention as
     * [sk.ainet.exec.tensor.ops.JvmQuantizedVectorKernels.byteSpeciesForFloat].
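     * (On a 4-lane float species the 8-byte load is wider than strictly
     * needed; `castShape` with part 0 widens only the lowest
     * `floatSpecies.length()` lanes, so the surplus bytes are ignored.)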
     */
    private val byteSpeciesForFloat: VectorSpecies<Byte> = when (floatSpecies.length()) {
        16 -> ByteVector.SPECIES_128
        else -> ByteVector.SPECIES_64 // covers 4-wide (NEON) and 8-wide (AVX2)
    }

    override fun matmul(
        input: FloatArray, inputOffset: Int,
        weight: ByteArray, weightByteOffset: Int,
        inputDim: Int, outputDim: Int,
        output: FloatArray, outputOffset: Int,
    ) {
        require(inputDim % BLOCK_SIZE == 0) {
            "PanamaVectorQ4KMatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim"
        }
        if (outputDim == 0 || inputDim == 0) return
        val blocksPerInputDim = inputDim / BLOCK_SIZE

        parallelChunks(outputDim) { startO, endO ->
            // Per-task scratch — must not be shared across worker threads.
            val scaleIdx = IntArray(SUB_BLOCKS_PER_BLOCK)
            val minIdx = IntArray(SUB_BLOCKS_PER_BLOCK)
            for (o in startO until endO) {
                var acc = 0f
                for (blockIdx in 0 until blocksPerInputDim) {
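                    // Blocks for a given blockIdx are stored contiguously across all
                    // outputs, so output column o's super-block starts at
                    // (blockIdx * outputDim + o) * BYTES_PER_BLOCK past weightByteOffset.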
                    val blockBase = weightByteOffset + (blockIdx * outputDim + o) * BYTES_PER_BLOCK

                    // d, dMin (FP16 LE).
                    val dBits = (weight[blockBase + 1].toInt() and 0xFF shl 8) or
                        (weight[blockBase].toInt() and 0xFF)
                    val dMinBits = (weight[blockBase + 3].toInt() and 0xFF shl 8) or
                        (weight[blockBase + 2].toInt() and 0xFF)
                    val d = halfToFloat(dBits)
                    val dMin = halfToFloat(dMinBits)

                    // Sub-scale decode via ggml `get_scale_min_k4`.
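                    // The 12 scale bytes pack eight 6-bit scales and eight 6-bit mins:
                    // sub-blocks 0-3 read a full low-6-bit field, sub-blocks 4-7 combine a
                    // low nibble with two high bits spilled into the first eight bytes.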
                    val scalesOffset = blockBase + 4
                    for (sb in 0 until 4) {
                        scaleIdx[sb] = weight[scalesOffset + sb].toInt() and 0x3F
                        minIdx[sb] = weight[scalesOffset + sb + 4].toInt() and 0x3F
                    }
                    for (sb in 4 until 8) {
                        val low4S = weight[scalesOffset + sb + 4].toInt() and 0x0F
                        val high2S = (weight[scalesOffset + sb - 4].toInt() and 0xFF) ushr 6
                        scaleIdx[sb] = low4S or (high2S shl 4)
                        val low4M = (weight[scalesOffset + sb + 4].toInt() and 0xFF) ushr 4
                        val high2M = (weight[scalesOffset + sb].toInt() and 0xFF) ushr 6
                        minIdx[sb] = low4M or (high2M shl 4)
                    }

                    // 4 consecutive 32-byte qs groups; each carries sbLo (lo nibbles) and sbHi (hi nibbles).
                    val codesOffset = blockBase + 16
                    val inputBlockBase = inputOffset + blockIdx * BLOCK_SIZE
                    for (groupJ in 0 until 4) {
                        val qsRegion = codesOffset + groupJ * 32
                        val sbLo = 2 * groupJ
                        val sbHi = sbLo + 1
                        val inputStartLo = inputBlockBase + sbLo * SUB_BLOCK_SIZE
                        val inputStartHi = inputStartLo + SUB_BLOCK_SIZE

                        var codeAccLo = FloatVector.zero(floatSpecies)
                        var inputAccLo = FloatVector.zero(floatSpecies)
                        var codeAccHi = FloatVector.zero(floatSpecies)
                        var inputAccHi = FloatVector.zero(floatSpecies)

                        val floatStep = floatSpecies.length()
                        val byteLoadLen = byteSpeciesForFloat.length()
                        var idx = 0

                        // SIMD body — single byte load feeds both nibble vectors.
                        while (idx + floatStep <= SUB_BLOCK_SIZE &&
                            qsRegion + idx + byteLoadLen <= weight.size
                        ) {
                            val inVecLo = FloatVector.fromArray(floatSpecies, input, inputStartLo + idx)
                            val inVecHi = FloatVector.fromArray(floatSpecies, input, inputStartHi + idx)
                            val byteVec = ByteVector.fromArray(byteSpeciesForFloat, weight, qsRegion + idx)
                            val loBytes = byteVec.and(0x0F.toByte())
                            val hiBytes = byteVec.lanewise(VectorOperators.LSHR, 4.toByte())
                            val codeVecLo = loBytes.castShape(floatSpecies, 0) as FloatVector
                            val codeVecHi = hiBytes.castShape(floatSpecies, 0) as FloatVector
                            codeAccLo = inVecLo.fma(codeVecLo, codeAccLo)
                            inputAccLo = inVecLo.add(inputAccLo)
                            codeAccHi = inVecHi.fma(codeVecHi, codeAccHi)
                            inputAccHi = inVecHi.add(inputAccHi)
                            idx += floatStep
                        }

                        var codeSumLo = codeAccLo.reduceLanes(VectorOperators.ADD)
                        var inputSumLo = inputAccLo.reduceLanes(VectorOperators.ADD)
                        var codeSumHi = codeAccHi.reduceLanes(VectorOperators.ADD)
                        var inputSumHi = inputAccHi.reduceLanes(VectorOperators.ADD)
                        // Scalar tail: every preferred float species length divides 32, so this
                        // only runs when the SIMD loop's bounds check stops early because a full
                        // ByteVector load would read past the end of `weight`.
                        while (idx < SUB_BLOCK_SIZE) {
                            val byte = weight[qsRegion + idx].toInt() and 0xFF
                            val codeLo = (byte and 0x0F).toFloat()
                            val codeHi = (byte ushr 4).toFloat()
                            val vLo = input[inputStartLo + idx]
                            val vHi = input[inputStartHi + idx]
                            codeSumLo += vLo * codeLo
                            inputSumLo += vLo
                            codeSumHi += vHi * codeHi
                            inputSumHi += vHi
                            idx++
                        }

                        val scaleLo = d * scaleIdx[sbLo]
                        val offsetLo = dMin * minIdx[sbLo]
                        val scaleHi = d * scaleIdx[sbHi]
                        val offsetHi = dMin * minIdx[sbHi]
                        acc += codeSumLo * scaleLo - inputSumLo * offsetLo
                        acc += codeSumHi * scaleHi - inputSumHi * offsetHi
                    }
                }
                output[outputOffset + o] = acc
            }
        }
    }

    /**
     * IEEE 754 binary16 → binary32 conversion. Mirrors the helper used
     * inside `JvmQuantizedVectorKernels` and `Q4_KTensorData` — kept
     * private to this file rather than depending on either, since both
     * are `internal` in their respective modules.
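     * For reference: 0x3C00 decodes to 1.0f, 0xC000 to -2.0f, and 0x0001 to
     * the smallest positive subnormal, 2^-24 (about 5.96e-8).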
     */
    private fun halfToFloat(hbits: Int): Float {
        val sign = (hbits ushr 15) and 0x1
        val exp = (hbits ushr 10) and 0x1F
        val frac = hbits and 0x3FF
        return when {
            exp == 0 -> {
                if (frac == 0) {
                    if (sign == 0) 0.0f else -0.0f
                } else {
                    val f = frac / 1024.0f * (1.0f / 16384.0f)
                    if (sign == 0) f else -f
                }
            }
            exp == 0x1F -> {
                if (frac == 0) {
                    if (sign == 0) Float.POSITIVE_INFINITY else Float.NEGATIVE_INFINITY
                } else {
                    Float.NaN
                }
            }
            else -> {
                val bits = (sign shl 31) or ((exp - 15 + 127) shl 23) or (frac shl 13)
                Float.fromBits(bits)
            }
        }
    }
}