Skip to content

Commit 2b84824

Browse files
Merge pull request #647 from SKaiNET-developers/chore/resync-api-dumps
chore: resync API dumps to current develop source
2 parents e6e6783 + e4b16f8 commit 2b84824

28 files changed

Lines changed: 2999 additions & 34 deletions

File tree

skainet-backends/skainet-backend-api/src/commonMain/kotlin/sk/ainet/backend/api/kernel/KernelProvider.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ public interface KernelProvider {
6767
*/
6868
public fun matmulQ8_0(): Q8_0MatmulKernel? = null
6969

70+
/**
71+
* F32 × Q4_0 matmul kernel exposed by this provider, or `null` if
72+
* this provider does not specialize Q4_0. Same fall-through pattern as [matmulQ8_0].
73+
*/
74+
public fun matmulQ4_0(): Q4_0MatmulKernel? = null
75+
7076
/**
7177
* Capability query: does this provider carry a kernel for
7278
* [opName] with the given [dtypeKeys]?
@@ -100,6 +106,7 @@ public interface KernelProvider {
100106
"BFloat16" -> matmulBf16() != null
101107
"Q4_K" -> matmulQ4K() != null
102108
"Q8_0" -> matmulQ8_0() != null
109+
"Q4_0" -> matmulQ4_0() != null
103110
else -> false
104111
}
105112
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package sk.ainet.backend.api.kernel
2+
3+
/**
4+
* F32 input × Q4_0-packed weights matrix-vector multiply, in canonical
5+
* ggml block layout.
6+
*
7+
* output[outputOffset + o] = Σ_j input[inputOffset + j] · dequant(weight[o, j])
8+
* for j ∈ [0, inputDim), o ∈ [0, outputDim)
9+
*
10+
* Block layout (32-element block, 18 bytes/block; see
11+
* [sk.ainet.lang.tensor.data.Q4_0BlockTensorData] kdoc):
12+
* - bytes 0..1 : `d` (block scale, FP16 LE)
13+
* - bytes 2..17 : 16 bytes packing 32 4-bit codes (split layout — low
14+
* nibbles decode elements 0..15, high nibbles decode elements 16..31)
15+
*
16+
* Per element: `dequant = (code - 8) * d` (the `- 8` bias centres the
17+
* unsigned 4-bit code around zero). Q4_0 has no per-block min / offset.
18+
*
19+
* Implementations MUST NOT mutate `input` or `weight`. They MAY assume
20+
* the arrays do not alias each other or `output`. They MUST fully
21+
* write the `outputDim` floats starting at `output[outputOffset]`.
22+
*
23+
* Packed-weight row-major contract: `weight` holds 18-byte blocks indexed `[blockIdx][o]`; the
24+
* block for output row `o` and input block index `blockIdx` starts at byte
25+
* `weightByteOffset + (blockIdx * outputDim + o) * 18`. This matches `Q4_0BlockTensorData.packedData`.
26+
*
27+
* `inputDim` MUST be a multiple of 32 (the Q4_0 block size).
28+
*/
29+
public interface Q4_0MatmulKernel {
30+
/**
31+
* @param input FP32 input vector (single row).
32+
* @param inputOffset element offset into [input] where the row starts.
33+
* @param weight packed Q4_0 bytes for the full `outputDim × inputDim` weight tensor.
34+
* @param weightByteOffset byte offset into [weight] where block (0, 0) starts.
35+
* @param inputDim contraction dimension (must be a multiple of 32).
36+
* @param outputDim number of output cells.
37+
* @param output FP32 output vector.
38+
* @param outputOffset element offset into [output] where the row starts.
39+
*/
40+
public fun matmul(
41+
input: FloatArray, inputOffset: Int,
42+
weight: ByteArray, weightByteOffset: Int,
43+
inputDim: Int, outputDim: Int,
44+
output: FloatArray, outputOffset: Int,
45+
)
46+
}

skainet-backends/skainet-backend-cpu/api/jvm/skainet-backend-cpu.api

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,20 @@ public final class sk/ainet/context/DirectCpuExecutionContext : sk/ainet/context
1717
public fun getHooks ()Lsk/ainet/lang/nn/hooks/ForwardHooks;
1818
public fun getInTraining ()Z
1919
public fun getMemoryInfo ()Lsk/ainet/context/MemoryInfo;
20+
public fun getMemoryPlanner ()Lsk/ainet/lang/tensor/storage/MemoryPlanner;
21+
public fun getMemoryTracker ()Lsk/ainet/lang/tensor/storage/MemoryTracker;
2022
public fun getObservers ()Lsk/ainet/context/ExecutionObserverRegistry;
2123
public fun getOps ()Lsk/ainet/lang/tensor/ops/TensorOps;
2224
public fun getPhase ()Lsk/ainet/context/Phase;
25+
public fun getScratch ()Lsk/ainet/lang/tensor/scratch/ScratchPool;
2326
public fun getTensorDataFactory ()Lsk/ainet/lang/tensor/data/TensorDataFactory;
2427
public fun ones (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;)Lsk/ainet/lang/tensor/Tensor;
28+
public fun placeholder (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;)Lsk/ainet/lang/tensor/Tensor;
2529
public fun registerObserver (Lsk/ainet/context/ExecutionObserver;)V
2630
public fun unregisterObserver (Lsk/ainet/context/ExecutionObserver;)V
31+
public fun wrapByteArray (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;[B)Lsk/ainet/lang/tensor/Tensor;
32+
public fun wrapFloatArray (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;[F)Lsk/ainet/lang/tensor/Tensor;
33+
public fun wrapIntArray (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;[I)Lsk/ainet/lang/tensor/Tensor;
2734
public fun zeros (Lsk/ainet/lang/tensor/Shape;Lkotlin/reflect/KClass;)Lsk/ainet/lang/tensor/Tensor;
2835
}
2936

@@ -33,6 +40,103 @@ public final class sk/ainet/context/DirectCpuExecutionContext$Companion {
3340
public static synthetic fun create$default (Lsk/ainet/context/DirectCpuExecutionContext$Companion;Lsk/ainet/context/Phase;ILjava/lang/Object;)Lsk/ainet/context/DirectCpuExecutionContext;
3441
}
3542

43+
public final class sk/ainet/exec/kernel/PanamaVectorBf16MatmulKernel : sk/ainet/backend/api/kernel/Bf16MatmulKernel {
44+
public static final field INSTANCE Lsk/ainet/exec/kernel/PanamaVectorBf16MatmulKernel;
45+
public fun matmul ([FII[BII[FIIIII)V
46+
}
47+
48+
public final class sk/ainet/exec/kernel/PanamaVectorKernelProvider : sk/ainet/backend/api/kernel/KernelProvider {
49+
public static final field INSTANCE Lsk/ainet/exec/kernel/PanamaVectorKernelProvider;
50+
public fun getName ()Ljava/lang/String;
51+
public fun getPriority ()I
52+
public fun isAvailable ()Z
53+
public fun matmulBf16 ()Lsk/ainet/backend/api/kernel/Bf16MatmulKernel;
54+
public fun matmulFp32 ()Lsk/ainet/backend/api/kernel/Fp32MatmulKernel;
55+
public fun matmulQ4K ()Lsk/ainet/backend/api/kernel/Q4KMatmulKernel;
56+
public fun matmulQ4_0 ()Lsk/ainet/backend/api/kernel/Q4_0MatmulKernel;
57+
public fun matmulQ8_0 ()Lsk/ainet/backend/api/kernel/Q8_0MatmulKernel;
58+
public fun supports (Ljava/lang/String;Ljava/util/List;)Z
59+
}
60+
61+
public final class sk/ainet/exec/kernel/PanamaVectorKernelProviderFactory : sk/ainet/backend/api/kernel/KernelProvider {
62+
public fun <init> ()V
63+
public fun getName ()Ljava/lang/String;
64+
public fun getPriority ()I
65+
public fun isAvailable ()Z
66+
public fun matmulBf16 ()Lsk/ainet/backend/api/kernel/Bf16MatmulKernel;
67+
public fun matmulFp32 ()Lsk/ainet/backend/api/kernel/Fp32MatmulKernel;
68+
public fun matmulQ4K ()Lsk/ainet/backend/api/kernel/Q4KMatmulKernel;
69+
public fun matmulQ4_0 ()Lsk/ainet/backend/api/kernel/Q4_0MatmulKernel;
70+
public fun matmulQ8_0 ()Lsk/ainet/backend/api/kernel/Q8_0MatmulKernel;
71+
public fun supports (Ljava/lang/String;Ljava/util/List;)Z
72+
}
73+
74+
public final class sk/ainet/exec/kernel/PanamaVectorMatmulKernel : sk/ainet/backend/api/kernel/Fp32MatmulKernel {
75+
public static final field INSTANCE Lsk/ainet/exec/kernel/PanamaVectorMatmulKernel;
76+
public fun matmul ([FII[FII[FIIIII)V
77+
}
78+
79+
public final class sk/ainet/exec/kernel/PanamaVectorQ4KMatmulKernel : sk/ainet/backend/api/kernel/Q4KMatmulKernel {
80+
public static final field INSTANCE Lsk/ainet/exec/kernel/PanamaVectorQ4KMatmulKernel;
81+
public fun matmul ([FI[BIII[FI)V
82+
}
83+
84+
public final class sk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernel : sk/ainet/backend/api/kernel/Q4_0MatmulKernel {
85+
public static final field INSTANCE Lsk/ainet/exec/kernel/PanamaVectorQ4_0MatmulKernel;
86+
public fun matmul ([FI[BIII[FI)V
87+
}
88+
89+
public final class sk/ainet/exec/kernel/PanamaVectorQ8_0MatmulKernel : sk/ainet/backend/api/kernel/Q8_0MatmulKernel {
90+
public static final field INSTANCE Lsk/ainet/exec/kernel/PanamaVectorQ8_0MatmulKernel;
91+
public fun matmul ([FI[BIII[FI)V
92+
}
93+
94+
public final class sk/ainet/exec/kernel/ScalarBf16MatmulKernel : sk/ainet/backend/api/kernel/Bf16MatmulKernel {
95+
public static final field INSTANCE Lsk/ainet/exec/kernel/ScalarBf16MatmulKernel;
96+
public fun matmul ([FII[BII[FIIIII)V
97+
}
98+
99+
public final class sk/ainet/exec/kernel/ScalarKernelProvider : sk/ainet/backend/api/kernel/KernelProvider {
100+
public static final field INSTANCE Lsk/ainet/exec/kernel/ScalarKernelProvider;
101+
public fun getName ()Ljava/lang/String;
102+
public fun getPriority ()I
103+
public fun isAvailable ()Z
104+
public fun matmulBf16 ()Lsk/ainet/backend/api/kernel/Bf16MatmulKernel;
105+
public fun matmulFp32 ()Lsk/ainet/backend/api/kernel/Fp32MatmulKernel;
106+
public fun matmulQ4K ()Lsk/ainet/backend/api/kernel/Q4KMatmulKernel;
107+
public fun matmulQ4_0 ()Lsk/ainet/backend/api/kernel/Q4_0MatmulKernel;
108+
public fun matmulQ8_0 ()Lsk/ainet/backend/api/kernel/Q8_0MatmulKernel;
109+
public fun supports (Ljava/lang/String;Ljava/util/List;)Z
110+
}
111+
112+
public final class sk/ainet/exec/kernel/ScalarKernelProviderFactory : sk/ainet/backend/api/kernel/KernelProvider {
113+
public fun <init> ()V
114+
public fun getName ()Ljava/lang/String;
115+
public fun getPriority ()I
116+
public fun isAvailable ()Z
117+
public fun matmulBf16 ()Lsk/ainet/backend/api/kernel/Bf16MatmulKernel;
118+
public fun matmulFp32 ()Lsk/ainet/backend/api/kernel/Fp32MatmulKernel;
119+
public fun matmulQ4K ()Lsk/ainet/backend/api/kernel/Q4KMatmulKernel;
120+
public fun matmulQ4_0 ()Lsk/ainet/backend/api/kernel/Q4_0MatmulKernel;
121+
public fun matmulQ8_0 ()Lsk/ainet/backend/api/kernel/Q8_0MatmulKernel;
122+
public fun supports (Ljava/lang/String;Ljava/util/List;)Z
123+
}
124+
125+
public final class sk/ainet/exec/kernel/ScalarMatmulKernel : sk/ainet/backend/api/kernel/Fp32MatmulKernel {
126+
public static final field INSTANCE Lsk/ainet/exec/kernel/ScalarMatmulKernel;
127+
public fun matmul ([FII[FII[FIIIII)V
128+
}
129+
130+
public final class sk/ainet/exec/kernel/ScalarQ4_0MatmulKernel : sk/ainet/backend/api/kernel/Q4_0MatmulKernel {
131+
public static final field INSTANCE Lsk/ainet/exec/kernel/ScalarQ4_0MatmulKernel;
132+
public fun matmul ([FI[BIII[FI)V
133+
}
134+
135+
public final class sk/ainet/exec/kernel/ScalarQ8_0MatmulKernel : sk/ainet/backend/api/kernel/Q8_0MatmulKernel {
136+
public static final field INSTANCE Lsk/ainet/exec/kernel/ScalarQ8_0MatmulKernel;
137+
public fun matmul ([FI[BIII[FI)V
138+
}
139+
36140
public final class sk/ainet/exec/tensor/ops/DefaultCpuOps : sk/ainet/exec/tensor/ops/DefaultCpuOpsBase {
37141
public fun <init> (Lsk/ainet/lang/tensor/data/TensorDataFactory;)V
38142
}
@@ -49,7 +153,9 @@ public class sk/ainet/exec/tensor/ops/DefaultCpuOpsBase : sk/ainet/lang/tensor/o
49153
public fun conv1d (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;IIII)Lsk/ainet/lang/tensor/Tensor;
50154
public fun conv2d (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lkotlin/Pair;Lkotlin/Pair;Lkotlin/Pair;I)Lsk/ainet/lang/tensor/Tensor;
51155
public fun conv3d (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lkotlin/Triple;Lkotlin/Triple;Lkotlin/Triple;I)Lsk/ainet/lang/tensor/Tensor;
156+
public fun convTranspose1d (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;IIIII)Lsk/ainet/lang/tensor/Tensor;
52157
public fun convert (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/types/DType;)Lsk/ainet/lang/tensor/Tensor;
158+
public fun cos (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
53159
public fun divScalar (Lsk/ainet/lang/tensor/Tensor;Ljava/lang/Number;)Lsk/ainet/lang/tensor/Tensor;
54160
public fun divide (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
55161
protected final fun elementwise (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;Lkotlin/jvm/functions/Function3;)Lsk/ainet/lang/tensor/Tensor;
@@ -64,6 +170,9 @@ public class sk/ainet/exec/tensor/ops/DefaultCpuOpsBase : sk/ainet/lang/tensor/o
64170
protected final fun gradStateFrom ([Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/GradState;
65171
public fun indexSelect (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;I)Lsk/ainet/lang/tensor/Tensor;
66172
public fun leakyRelu (Lsk/ainet/lang/tensor/Tensor;F)Lsk/ainet/lang/tensor/Tensor;
173+
public fun log (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
174+
public fun log10 (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
175+
public fun log2 (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
67176
public fun logSoftmax (Lsk/ainet/lang/tensor/Tensor;I)Lsk/ainet/lang/tensor/Tensor;
68177
public fun lt (Lsk/ainet/lang/tensor/Tensor;F)Lsk/ainet/lang/tensor/Tensor;
69178
protected final fun mapIndex ([ILsk/ainet/lang/tensor/Shape;)[I
@@ -75,6 +184,9 @@ public class sk/ainet/exec/tensor/ops/DefaultCpuOpsBase : sk/ainet/lang/tensor/o
75184
public fun narrow (Lsk/ainet/lang/tensor/Tensor;III)Lsk/ainet/lang/tensor/Tensor;
76185
protected final fun newTensor (Lsk/ainet/lang/tensor/data/TensorData;Lkotlin/reflect/KClass;[Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
77186
public fun pad2d (Lsk/ainet/lang/tensor/Tensor;IIII)Lsk/ainet/lang/tensor/Tensor;
187+
public fun permute (Lsk/ainet/lang/tensor/Tensor;[I)Lsk/ainet/lang/tensor/Tensor;
188+
public fun pow (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
189+
public fun powScalar (Lsk/ainet/lang/tensor/Tensor;Ljava/lang/Number;)Lsk/ainet/lang/tensor/Tensor;
78190
public fun rdivScalar (Ljava/lang/Number;Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
79191
public fun relu (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
80192
protected final fun requireSameDType (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;)V
@@ -84,13 +196,15 @@ public class sk/ainet/exec/tensor/ops/DefaultCpuOpsBase : sk/ainet/lang/tensor/o
84196
public fun sigmoid (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
85197
public fun sign (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
86198
public fun silu (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
199+
public fun sin (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
87200
public fun softmax (Lsk/ainet/lang/tensor/Tensor;I)Lsk/ainet/lang/tensor/Tensor;
88201
public fun split (Lsk/ainet/lang/tensor/Tensor;II)Ljava/util/List;
89202
public fun sqrt (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
90203
public fun squeeze (Lsk/ainet/lang/tensor/Tensor;Ljava/lang/Integer;)Lsk/ainet/lang/tensor/Tensor;
91204
public fun subScalar (Lsk/ainet/lang/tensor/Tensor;Ljava/lang/Number;)Lsk/ainet/lang/tensor/Tensor;
92205
public fun subtract (Lsk/ainet/lang/tensor/Tensor;Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
93206
public fun sum (Lsk/ainet/lang/tensor/Tensor;Ljava/lang/Integer;)Lsk/ainet/lang/tensor/Tensor;
207+
public fun tanh (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
94208
public fun transpose (Lsk/ainet/lang/tensor/Tensor;)Lsk/ainet/lang/tensor/Tensor;
95209
public fun tril (Lsk/ainet/lang/tensor/Tensor;I)Lsk/ainet/lang/tensor/Tensor;
96210
public fun unfold (Lsk/ainet/lang/tensor/Tensor;III)Lsk/ainet/lang/tensor/Tensor;
@@ -115,6 +229,15 @@ protected final class sk/ainet/exec/tensor/ops/DefaultCpuOpsBase$CpuTensor : sk/
115229
public fun zeroGrad ()V
116230
}
117231

232+
public final class sk/ainet/exec/tensor/ops/JvmTurboQuantKernels {
233+
public static final field INSTANCE Lsk/ainet/exec/tensor/ops/JvmTurboQuantKernels;
234+
public final fun absMax ([FII)F
235+
public final fun dequantize ([B[F[FI)V
236+
public static synthetic fun dequantize$default (Lsk/ainet/exec/tensor/ops/JvmTurboQuantKernels;[B[F[FIILjava/lang/Object;)V
237+
public final fun quantize ([FI)Lsk/ainet/lang/tensor/ops/turboquant/QuantizedVector;
238+
public final fun walshHadamardButterfly ([FII)V
239+
}
240+
118241
public final class sk/ainet/java/SKaiNET {
119242
public static final field INSTANCE Lsk/ainet/java/SKaiNET;
120243
public static final fun context ()Lsk/ainet/context/ExecutionContext;

skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/kernel/ScalarKernelProvider.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sk.ainet.exec.kernel
33
import sk.ainet.backend.api.kernel.Bf16MatmulKernel
44
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
55
import sk.ainet.backend.api.kernel.KernelProvider
6+
import sk.ainet.backend.api.kernel.Q4_0MatmulKernel
67
import sk.ainet.backend.api.kernel.Q8_0MatmulKernel
78

89
/**
@@ -25,4 +26,5 @@ public object ScalarKernelProvider : KernelProvider {
2526
override fun matmulFp32(): Fp32MatmulKernel = ScalarMatmulKernel
2627
override fun matmulBf16(): Bf16MatmulKernel = ScalarBf16MatmulKernel
2728
override fun matmulQ8_0(): Q8_0MatmulKernel = ScalarQ8_0MatmulKernel
29+
override fun matmulQ4_0(): Q4_0MatmulKernel = ScalarQ4_0MatmulKernel
2830
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package sk.ainet.exec.kernel
2+
3+
import sk.ainet.backend.api.kernel.Q4_0MatmulKernel
4+
5+
/**
6+
* Scalar reference implementation of [Q4_0MatmulKernel] — straight
7+
* per-block dequant + per-element FMA, no SIMD. Always available on
8+
* every KMP target. Used as:
9+
*
10+
* - The correctness reference that accelerated kernels (Panama Vector,
11+
* native FFM) must match within floating-point summation-order tolerance.
12+
* - A guaranteed fallback when no accelerated provider is registered.
13+
*
14+
* Block layout (32-element block, 18 bytes):
15+
* - bytes 0..1 : FP16 little-endian scale (`d`)
16+
* - bytes 2..17: 16 bytes packing 32 4-bit codes (split layout)
17+
*
18+
* Dequant per element: `(code - 8) * d`. No min / offset.
19+
*
20+
* Performance is intentionally modest; production paths should pick the
21+
* Panama Vector or native variant via the kernel registry.
22+
*/
23+
public object ScalarQ4_0MatmulKernel : Q4_0MatmulKernel {
24+
25+
private const val BLOCK_SIZE = 32
26+
private const val BYTES_PER_BLOCK = 18
27+
28+
override fun matmul(
29+
input: FloatArray, inputOffset: Int,
30+
weight: ByteArray, weightByteOffset: Int,
31+
inputDim: Int, outputDim: Int,
32+
output: FloatArray, outputOffset: Int,
33+
) {
34+
require(inputDim % BLOCK_SIZE == 0) {
35+
"ScalarQ4_0MatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim"
36+
}
37+
if (outputDim == 0 || inputDim == 0) {
38+
if (outputDim > 0) {
39+
for (o in 0 until outputDim) output[outputOffset + o] = 0f
40+
}
41+
return
42+
}
43+
val blocksPerInputDim = inputDim / BLOCK_SIZE
44+
45+
for (o in 0 until outputDim) {
46+
var acc = 0f
47+
for (blockIdx in 0 until blocksPerInputDim) {
48+
val blockBase = weightByteOffset + (blockIdx * outputDim + o) * BYTES_PER_BLOCK
49+
// FP16 scale: two LE bytes.
50+
val dBits = (weight[blockBase].toInt() and 0xFF) or
51+
((weight[blockBase + 1].toInt() and 0xFF) shl 8)
52+
val d = halfToFloat(dBits)
53+
// 32 codes, blockIdx-th window of the input vector. Split
54+
// layout: low nibbles → elements 0..15, high → 16..31.
55+
val inputBase = inputOffset + blockIdx * BLOCK_SIZE
56+
val codesBase = blockBase + 2
57+
for (j in 0 until 16) {
58+
val b = weight[codesBase + j].toInt() and 0xFF
59+
val lo = (b and 0x0F) - 8
60+
val hi = (b ushr 4) - 8
61+
acc += input[inputBase + j] * lo * d
62+
acc += input[inputBase + 16 + j] * hi * d
63+
}
64+
}
65+
output[outputOffset + o] = acc
66+
}
67+
}
68+
69+
/**
70+
* Convert a 16-bit IEEE-754 half-precision value (low 16 bits of
71+
* [hbits]) to FP32. Mirrors [ScalarQ8_0MatmulKernel]'s inlined helper
72+
* — the skainet-lang-core dequant helper is internal to that module.
73+
*/
74+
private fun halfToFloat(hbits: Int): Float {
75+
val sign = (hbits and 0x8000) shl 16
76+
val exp = (hbits and 0x7C00) shr 10
77+
val mant = hbits and 0x03FF
78+
return when (exp) {
79+
0 -> {
80+
if (mant == 0) Float.fromBits(sign)
81+
else {
82+
var m = mant
83+
var e = -14
84+
while ((m and 0x400) == 0) {
85+
m = m shl 1
86+
e--
87+
}
88+
m = m and 0x3FF
89+
Float.fromBits(sign or ((e + 127) shl 23) or (m shl 13))
90+
}
91+
}
92+
31 -> Float.fromBits(sign or (0xFF shl 23) or (mant shl 13))
93+
else -> Float.fromBits(sign or ((exp - 15 + 127) shl 23) or (mant shl 13))
94+
}
95+
}
96+
}

0 commit comments

Comments
 (0)