|
| 1 | +package sk.ainet.exec.kernel |
| 2 | + |
| 3 | +import jdk.incubator.vector.FloatVector |
| 4 | +import jdk.incubator.vector.VectorOperators |
| 5 | +import jdk.incubator.vector.VectorSpecies |
| 6 | +import sk.ainet.backend.api.kernel.Q5_1MatmulKernel |
| 7 | + |
/**
 * FP32 × Q5_1 matrix-vector product built on the JDK Vector API.
 *
 * Weights are stored block-major: the 24-byte block for input block `b` and output
 * row `o` starts at `weightByteOffset + (b * outputDim + o) * 24`. Each block holds,
 * in order, a little-endian FP16 scale `d`, a little-endian FP16 offset `m`, four
 * bytes of high bits `qh`, and sixteen bytes of packed 4-bit codes `qs`.
 *
 * For every block the 32 values are expanded to FP32 as `d * code + m`, where
 * `code` is the 4-bit nibble with the matching `qh` bit placed at bit 4. Low nibbles
 * fill elements 0..15 (using `qh` bits 0..15) and high nibbles fill elements 16..31
 * (using `qh` bits 16..31). The expanded block is then multiplied against the
 * corresponding 32-float input window with vector FMAs and reduced per block.
 *
 * Results are intended to agree with [ScalarQ5_1MatmulKernel] up to FMA and
 * reordered-reduction rounding differences.
 */
public object PanamaVectorQ5_1MatmulKernel : Q5_1MatmulKernel {

    private const val BLOCK_SIZE = 32
    private const val BYTES_PER_BLOCK = 24
    private val floatSpecies: VectorSpecies<Float> = FloatVector.SPECIES_PREFERRED

    /**
     * Writes `outputDim` dot products into [output], starting at [outputOffset].
     *
     * @param input activations; `inputDim` floats are read starting at [inputOffset]
     * @param weight Q5_1 blocks in the block-major layout described on this object
     * @param weightByteOffset byte index of the first block inside [weight]
     * @param inputDim number of input elements; must be a multiple of 32
     * @param outputDim number of output rows to produce
     * @throws IllegalArgumentException if [inputDim] is not a multiple of 32
     */
    override fun matmul(
        input: FloatArray, inputOffset: Int,
        weight: ByteArray, weightByteOffset: Int,
        inputDim: Int, outputDim: Int,
        output: FloatArray, outputOffset: Int,
    ) {
        require(inputDim % BLOCK_SIZE == 0) {
            "PanamaVectorQ5_1MatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim"
        }
        if (outputDim == 0) return
        if (inputDim == 0) {
            // Nothing to accumulate: every requested output is an empty sum.
            var row = 0
            while (row < outputDim) {
                output[outputOffset + row] = 0f
                row++
            }
            return
        }

        val blockCount = inputDim / BLOCK_SIZE
        val lanes = floatSpecies.length()
        // Largest multiple of the lane count that fits in one block; the remainder
        // (only non-empty when the species is wider than a block) runs as scalar code.
        val vectorEnd = floatSpecies.loopBound(BLOCK_SIZE)
        // Scratch for one dequantized block, reused across all blocks and rows.
        val dequantized = FloatArray(BLOCK_SIZE)

        var row = 0
        while (row < outputDim) {
            var sum = 0f
            var block = 0
            while (block < blockCount) {
                val blockStart = weightByteOffset + (block * outputDim + row) * BYTES_PER_BLOCK
                dequantizeBlock(weight, blockStart, dequantized)

                val window = inputOffset + block * BLOCK_SIZE
                var laneSums = FloatVector.zero(floatSpecies)
                var pos = 0
                while (pos < vectorEnd) {
                    val x = FloatVector.fromArray(floatSpecies, input, window + pos)
                    val w = FloatVector.fromArray(floatSpecies, dequantized, pos)
                    laneSums = x.fma(w, laneSums)
                    pos += lanes
                }
                // Horizontal reduction is done once per block, then the scalar tail follows.
                sum += laneSums.reduceLanes(VectorOperators.ADD)
                while (pos < BLOCK_SIZE) {
                    sum += input[window + pos] * dequantized[pos]
                    pos++
                }
                block++
            }
            output[outputOffset + row] = sum
            row++
        }
    }

    /** Expands the 24-byte Q5_1 block starting at [base] into 32 floats in [dst]. */
    private fun dequantizeBlock(weight: ByteArray, base: Int, dst: FloatArray) {
        val scale = halfToFloat(readHalfBits(weight, base))
        val offset = halfToFloat(readHalfBits(weight, base + 2))
        // Pack the four qh bytes into one little-endian word: bit i is the fifth bit of element i.
        val highBits = (weight[base + 4].toInt() and 0xFF) or
            ((weight[base + 5].toInt() and 0xFF) shl 8) or
            ((weight[base + 6].toInt() and 0xFF) shl 16) or
            ((weight[base + 7].toInt() and 0xFF) shl 24)
        val nibbleStart = base + 8
        val half = BLOCK_SIZE / 2
        for (j in 0 until half) {
            val packed = weight[nibbleStart + j].toInt() and 0xFF
            val lowCode = (packed and 0x0F) or (((highBits ushr j) and 1) shl 4)
            val highCode = (packed ushr 4) or (((highBits ushr (j + half)) and 1) shl 4)
            dst[j] = scale * lowCode + offset
            dst[half + j] = scale * highCode + offset
        }
    }

    /** Reads two bytes at [at] as an unsigned little-endian 16-bit value. */
    private fun readHalfBits(weight: ByteArray, at: Int): Int {
        val high = weight[at + 1].toInt() and 0xFF
        val low = weight[at].toInt() and 0xFF
        return (high shl 8) or low
    }

    /**
     * Converts IEEE 754 binary16 bits (in the low 16 bits of [hbits]) to a [Float].
     * Handles signed zero, subnormals, normals, and infinity/NaN. Same conversion as
     * [ScalarQ5_1MatmulKernel].
     */
    private fun halfToFloat(hbits: Int): Float {
        val signBit = (hbits and 0x8000) shl 16
        val exponent = (hbits ushr 10) and 0x1F
        val fraction = hbits and 0x03FF

        val bits = when {
            exponent == 0 && fraction == 0 -> signBit
            exponent == 0 -> {
                // Subnormal half: shift the fraction left until its top set bit reaches
                // bit 10 (the implicit-one position), lowering the exponent to match.
                val shift = fraction.countLeadingZeroBits() - 21
                val normalized = (fraction shl shift) and 0x3FF
                signBit or ((127 - 14 - shift) shl 23) or (normalized shl 13)
            }
            exponent == 31 -> signBit or (0xFF shl 23) or (fraction shl 13)
            // Normal number: rebias the exponent from 15 to 127.
            else -> signBit or ((exponent + (127 - 15)) shl 23) or (fraction shl 13)
        }
        return Float.fromBits(bits)
    }
}
0 commit comments