Skip to content

Commit 9d05fc4

Browse files
Merge pull request #575 from SKaiNET-developers/feature/native-fp32-matmul
feat(native-cpu): native FFM FP32 SGEMM kernel (PR 5 of 5)
2 parents 87a5730 + 2af135d commit 9d05fc4

8 files changed

Lines changed: 410 additions & 6 deletions

File tree

skainet-backends/skainet-backend-native-cpu/native/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ endif()
1212
add_library(skainet_kernels SHARED
1313
src/skainet_smoke.c
1414
src/q4k_matmul.c
15+
src/fp32_matmul.c
1516
)
1617

1718
target_include_directories(skainet_kernels PUBLIC

skainet-backends/skainet-backend-native-cpu/native/include/skainet_kernels.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,22 @@ SKAINET_API void skainet_q4k_matmul(
6060
int32_t output_offset
6161
);
6262

63+
/*
64+
* Row-major FP32 SGEMM: C(m, n) = A(m, k) * B(k, n).
65+
*
66+
* Strides are in floats (not bytes). For a contiguous parent matrix
67+
* `a_stride == k`, `b_stride == n`, `c_stride == n`. The kernel zeros
68+
* the m×n output block before accumulating, so callers always get
69+
* `C = A·B` (not `C += A·B`). `k == 0` zeros the block; `m == 0`
70+
* or `n == 0` is a no-op.
71+
*/
72+
SKAINET_API void skainet_fp32_matmul(
73+
const float* a, int32_t a_offset, int32_t a_stride,
74+
const float* b, int32_t b_offset, int32_t b_stride,
75+
float* c, int32_t c_offset, int32_t c_stride,
76+
int32_t m, int32_t n, int32_t k
77+
);
78+
6379
#ifdef __cplusplus
6480
}
6581
#endif
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#include "skainet_kernels.h"

#include <stddef.h>
#include <stdint.h>

/*
 * Native row-major SGEMM backing the
 * sk.ainet.backend.api.kernel.Fp32MatmulKernel SPI:
 *
 *     C(m, n) = A(m, k) * B(k, n)
 *
 * All strides are expressed in float elements, never bytes. A fully
 * contiguous parent matrix passes `a_stride == k`, `b_stride == n`,
 * `c_stride == n`; sub-block callers pass the parent's larger strides
 * plus matching offsets.
 *
 * Per output row the kernel first zeroes the row, then accumulates
 * rank-1 updates (one broadcast A scalar times a contiguous B row into
 * a contiguous C row). The innermost `c_row[col] += scale * b_row[col]`
 * loop streams two contiguous float runs of length n, which compilers
 * auto-vectorize into FMA under -O3 -ffast-math (vfmadd231ps on
 * x86_64, fmla on AArch64).
 *
 * Caller contract:
 *   - The m×n block of C is fully overwritten (never accumulated into).
 *   - k == 0 zeroes the m×n block.
 *   - m == 0 || n == 0 is a no-op.
 *   - Negative m / n / k are caller errors; the Kotlin wrapper rejects
 *     them, and the loop guards here degrade them to a no-op (or
 *     zero-only pass for negative k) defensively.
 */
SKAINET_API void skainet_fp32_matmul(
    const float* SKAINET_RESTRICT a, int32_t a_offset, int32_t a_stride,
    const float* SKAINET_RESTRICT b, int32_t b_offset, int32_t b_stride,
    float* SKAINET_RESTRICT c, int32_t c_offset, int32_t c_stride,
    int32_t m, int32_t n, int32_t k
) {
    if (m <= 0 || n <= 0) {
        return;
    }

    for (int32_t row = 0; row < m; ++row) {
        float* SKAINET_RESTRICT c_row = c + c_offset + (size_t) row * c_stride;

        /* Overwrite semantics: every output row starts from zero, which
         * also satisfies the k == 0 "zero the block" contract. */
        for (int32_t col = 0; col < n; ++col) {
            c_row[col] = 0.0f;
        }

        /* Skip A/B entirely (including their pointer arithmetic) when
         * there is nothing to accumulate. */
        if (k <= 0) {
            continue;
        }

        const float* SKAINET_RESTRICT a_row = a + a_offset + (size_t) row * a_stride;
        for (int32_t depth = 0; depth < k; ++depth) {
            const float scale = a_row[depth];
            const float* SKAINET_RESTRICT b_row = b + b_offset + (size_t) depth * b_stride;
            /* Contiguous dual-stream FMA loop; `scale` broadcasts. */
            for (int32_t col = 0; col < n; ++col) {
                c_row[col] += scale * b_row[col];
            }
        }
    }
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package sk.ainet.exec.kernel
2+
3+
import java.lang.foreign.Arena
4+
import java.lang.foreign.FunctionDescriptor
5+
import java.lang.foreign.Linker
6+
import java.lang.foreign.MemorySegment
7+
import java.lang.foreign.ValueLayout
8+
import java.lang.invoke.MethodHandle
9+
import sk.ainet.backend.api.kernel.Fp32MatmulKernel
10+
11+
/**
12+
* Native (FFM) implementation of [Fp32MatmulKernel].
13+
*
14+
* Wraps the bundled C symbol
15+
*
16+
* void skainet_fp32_matmul(
17+
* const float* a, int32_t a_offset, int32_t a_stride,
18+
* const float* b, int32_t b_offset, int32_t b_stride,
19+
* float* c, int32_t c_offset, int32_t c_stride,
20+
* int32_t m, int32_t n, int32_t k);
21+
*
22+
* The C kernel is a tight i-p-j outer-product accumulator over rows
23+
* of C; the inner `c[j] += a*b[j]` loop streams two contiguous arrays
24+
* and auto-vectorizes into FMA under -O3 -ffast-math (vfmadd231ps on
25+
* x86_64, fmla on AArch64).
26+
*
27+
* Numerical parity vs [PanamaVectorMatmulKernel] is asserted by
28+
* [NativeFp32MatmulKernelParityTest] within FMA + reordered-reduction
29+
* tolerance (the same `1e-5 * k` bar Panama uses against the scalar
30+
* reference).
31+
*
32+
* PR 5 of the staged native-FFM rollout — wraps the rollout per the
33+
* `native-ffm-plan` asciidoc. Single-threaded, no cache blocking;
34+
* future work could add parallelChunks-style row blocking and B-tile
35+
* packing, but the scalar C path already lands well within the SPI
36+
* contract on host-arch CPUs.
37+
*/
38+
internal object NativeFp32MatmulKernel : Fp32MatmulKernel {

    /** True iff the bundled native library loaded and `skainet_fp32_matmul` resolved. */
    fun isAvailable(): Boolean = handle != null

    /**
     * Row-major FP32 SGEMM via the native kernel: `out = a * b` over the
     * m×n block addressed by `outOffset`/`outStride`.
     *
     * Strides are in floats. The m×n block is fully overwritten; all other
     * floats of [out] (offset prefix, inter-row gaps of a strided
     * sub-block) are preserved.
     *
     * @throws IllegalArgumentException if m, n, or k is negative.
     * @throws IllegalStateException if invoked while the native library is
     *   unavailable — callers must gate on [isAvailable].
     */
    override fun matmul(
        a: FloatArray, aOffset: Int, aStride: Int,
        b: FloatArray, bOffset: Int, bStride: Int,
        out: FloatArray, outOffset: Int, outStride: Int,
        m: Int, n: Int, k: Int,
    ) {
        require(m >= 0 && n >= 0 && k >= 0) {
            "NativeFp32MatmulKernel: m, n, k must be non-negative; got m=$m n=$n k=$k"
        }
        if (m == 0 || n == 0) return

        val mh = handle
            ?: error("NativeFp32MatmulKernel.matmul invoked while native library unavailable")

        // Float reach of each off-heap copy: offset + last-row start + row
        // length. For non-contiguous strides this exceeds the matrix's
        // element count because the strides skip past floats belonging to
        // the parent matrix; allocating to the full reach keeps the native
        // pointer arithmetic simple and matches Kotlin-side bounds.
        // (m >= 1 and n >= 1 are guaranteed by the early return above.)
        val aReachFloats = if (k == 0) 0 else aOffset + (m - 1) * aStride + k
        val bReachFloats = if (k == 0) 0 else bOffset + (k - 1) * bStride + n
        val cReachFloats = outOffset + (m - 1) * outStride + n

        Arena.ofConfined().use { arena ->
            val aBytes = aReachFloats.toLong() * java.lang.Float.BYTES
            val bBytes = bReachFloats.toLong() * java.lang.Float.BYTES
            val cBytes = cReachFloats.toLong() * java.lang.Float.BYTES
            val align = ValueLayout.JAVA_FLOAT.byteAlignment()

            val aSeg: MemorySegment = if (aBytes > 0) arena.allocate(aBytes, align) else MemorySegment.NULL
            val bSeg: MemorySegment = if (bBytes > 0) arena.allocate(bBytes, align) else MemorySegment.NULL
            val cSeg: MemorySegment = arena.allocate(cBytes, align)

            if (aReachFloats > 0) {
                MemorySegment.copy(a, 0, aSeg, ValueLayout.JAVA_FLOAT, 0L, aReachFloats)
            }
            if (bReachFloats > 0) {
                MemorySegment.copy(b, 0, bSeg, ValueLayout.JAVA_FLOAT, 0L, bReachFloats)
            }
            // Bug fix: seed cSeg with the current contents of `out`. The
            // kernel writes only the m×n block, but the copy-back below
            // spans the full reach; without this pre-copy the offset prefix
            // and inter-row gap floats of a strided sub-block would be
            // clobbered with the arena's zero-fill, corrupting the caller's
            // parent matrix.
            MemorySegment.copy(out, 0, cSeg, ValueLayout.JAVA_FLOAT, 0L, cReachFloats)

            mh.invoke(
                aSeg, aOffset, aStride,
                bSeg, bOffset, bStride,
                cSeg, outOffset, outStride,
                m, n, k,
            )

            MemorySegment.copy(cSeg, ValueLayout.JAVA_FLOAT, 0L, out, 0, cReachFloats)
        }
    }

    // Lazily-resolved FFM downcall handle for `skainet_fp32_matmul`;
    // null when the bundled library or the symbol is unavailable on this
    // host (isAvailable() then reports false).
    private val handle: MethodHandle? by lazy {
        val lookup = NativeLibraryLoader.lookup() ?: return@lazy null
        val symbol = lookup.find("skainet_fp32_matmul").orElse(null) ?: return@lazy null
        val descriptor = FunctionDescriptor.ofVoid(
            ValueLayout.ADDRESS,  // a
            ValueLayout.JAVA_INT, // a_offset
            ValueLayout.JAVA_INT, // a_stride
            ValueLayout.ADDRESS,  // b
            ValueLayout.JAVA_INT, // b_offset
            ValueLayout.JAVA_INT, // b_stride
            ValueLayout.ADDRESS,  // c
            ValueLayout.JAVA_INT, // c_offset
            ValueLayout.JAVA_INT, // c_stride
            ValueLayout.JAVA_INT, // m
            ValueLayout.JAVA_INT, // n
            ValueLayout.JAVA_INT, // k
        )
        runCatching { Linker.nativeLinker().downcallHandle(symbol, descriptor) }.getOrNull()
    }
}

skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,18 @@ import sk.ainet.backend.api.kernel.Q4KMemSegMatmulKernel
2626
*
2727
* Staged rollout cursor (see `native-ffm-plan` asciidoc):
2828
* - PR 2: real Q4_K matmul wired into the heap SPI.
29-
* - PR 3 (this commit): MemSeg-input zero-copy sibling.
30-
* - Later: native `matmulFp32`, `matmulQ6K`, `matmulQ8_0`.
29+
* - PR 3: MemSeg-input zero-copy sibling.
30+
* - PR 5 (this commit): native FP32 matmul wired into [matmulFp32].
31+
* - Later: native `matmulQ6K`, `matmulQ8_0` (need new SPI accessors).
3132
*/
3233
public object NativeKernelProvider : KernelProvider, MemSegKernelProvider {
3334
override val name: String = "native-ffm"
3435
override val priority: Int = 100
3536

3637
override fun isAvailable(): Boolean = NativeQ4KMatmulKernel.isAvailable()
3738

38-
override fun matmulFp32(): Fp32MatmulKernel? = null
39+
override fun matmulFp32(): Fp32MatmulKernel? =
40+
if (NativeFp32MatmulKernel.isAvailable()) NativeFp32MatmulKernel else null
3941

4042
override fun matmulQ4K(): Q4KMatmulKernel? =
4143
if (NativeQ4KMatmulKernel.isAvailable()) NativeQ4KMatmulKernel else null

skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeFfmPipelineTest.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,16 @@ class NativeFfmPipelineTest {
3232
}
3333

3434
@Test
35-
fun `provider exposes Q4_K kernel when the native lib loads`() {
35+
fun `provider exposes Q4_K and FP32 kernels when the native lib loads`() {
3636
assertEquals("native-ffm", NativeKernelProvider.name)
3737
assertEquals(100, NativeKernelProvider.priority)
3838
assertTrue(
3939
NativeKernelProvider.isAvailable(),
4040
"Native kernel provider reports unavailable on this host — " +
4141
"bundled libskainet_kernels missing or skainet_q4k_matmul unresolved",
4242
)
43-
// FP32 matmul ships in a later PR; Q4_K is wired through PR 2.
44-
assertEquals(null, NativeKernelProvider.matmulFp32())
43+
// PR 5 wires both Q4_K (PR 2) and FP32 (this PR) through the SPI.
44+
assertNotNull(NativeKernelProvider.matmulFp32())
4545
assertNotNull(NativeKernelProvider.matmulQ4K())
4646
}
4747

0 commit comments

Comments
 (0)