
Commit a453045

Merge pull request #556 from SKaiNET-developers/feature/ISSUE-555-q4k-q5k-arena-leak
Q4_K/Q5_K canonical ggml layout + FP32 MemSeg arena leak fix
2 parents 9786960 + 063bdb3 commit a453045

15 files changed

Lines changed: 1112 additions & 354 deletions


skainet-backends/skainet-backend-cpu/build.gradle.kts

Lines changed: 5 additions & 1 deletion

@@ -58,7 +58,11 @@ kotlin {
             implementation(project(":skainet-lang:skainet-lang-models"))
         }

-        val jvmMain by getting
+        val jvmMain by getting {
+            dependencies {
+                implementation(libs.kotlinx.coroutines)
+            }
+        }
         val jvmTest by getting {
             dependencies {
                 implementation(libs.kotlin.test)

skainet-backends/skainet-backend-cpu/src/jvmMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOpsJvm.kt

Lines changed: 23 additions & 7 deletions

@@ -108,19 +108,31 @@ internal class DefaultCpuOpsJvm(
             @Suppress("UNCHECKED_CAST")
             return newTensor(transposed as TensorData<T, V>, tensor.dtype, tensor)
         }
-        // MemorySegment FP32 fast path: physical transpose via SIMD
+        // MemorySegment FP32 fast path: physical transpose via SIMD.
+        // Uses Arena.ofAuto() so the result segment is reclaimed by GC
+        // when the wrapping Tensor is no longer reachable. Earlier
+        // ofConfined() builds leaked an arena per call, blowing 32+ GiB
+        // of direct memory in inference loops (every layer × every
+        // forward pass).
         if (data is MemorySegmentBackedData) {
-            val arena = Arena.ofConfined()
+            val arena = Arena.ofAuto()
             val result = MemorySegmentTensorData<T>(Shape(cols, rows), arena)
             val src = data as MemorySegmentBackedData
-            val srcOff = src.segmentByteOffset
-            val dstOff = result.segmentByteOffset
+            val floatLayout = java.lang.foreign.ValueLayout.JAVA_FLOAT
+            // Bulk-load source into FloatArray, transpose via tight scalar
+            // loop (JIT auto-vectorizes), bulk-write destination. Replaces
+            // O(rows*cols) per-element VarHandle.get/set which dominated
+            // attention-path transposes.
+            val srcArr = FloatArray(rows * cols)
+            java.lang.foreign.MemorySegment.copy(src.segment, floatLayout, src.segmentByteOffset, srcArr, 0, rows * cols)
+            val dstArr = FloatArray(rows * cols)
             for (r in 0 until rows) {
+                val rowBase = r * cols
                 for (c in 0 until cols) {
-                    val v = src.segment.get(java.lang.foreign.ValueLayout.JAVA_FLOAT, srcOff + (r.toLong() * cols + c) * 4)
-                    result.segment.set(java.lang.foreign.ValueLayout.JAVA_FLOAT, dstOff + (c.toLong() * rows + r) * 4, v)
+                    dstArr[c * rows + r] = srcArr[rowBase + c]
                 }
             }
+            java.lang.foreign.MemorySegment.copy(dstArr, 0, result.segment, floatLayout, result.segmentByteOffset, rows * cols)
             @Suppress("UNCHECKED_CAST")
             return newTensor(result as TensorData<T, V>, tensor.dtype, tensor)
         }
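
For context, a minimal standalone sketch of the patched fast path, assuming a plain MemorySegment input of rows*cols row-major floats (the MemorySegmentTensorData/Shape wrappers and non-zero byte offsets from the diff are omitted, and transposeFp32 is a hypothetical name). It needs the Java 22+ FFM API:

import java.lang.foreign.Arena
import java.lang.foreign.MemorySegment
import java.lang.foreign.ValueLayout

fun transposeFp32(src: MemorySegment, rows: Int, cols: Int): MemorySegment {
    val n = rows * cols
    // GC-managed arena, as in the fix: freed once the segment is unreachable.
    val dst = Arena.ofAuto().allocate(ValueLayout.JAVA_FLOAT, n.toLong())

    // One bulk copy segment -> heap array instead of n per-element gets.
    val srcArr = FloatArray(n)
    MemorySegment.copy(src, ValueLayout.JAVA_FLOAT, 0L, srcArr, 0, n)

    // Tight scalar loop over heap arrays; amenable to JIT auto-vectorization.
    val dstArr = FloatArray(n)
    for (r in 0 until rows) {
        val rowBase = r * cols
        for (c in 0 until cols) {
            dstArr[c * rows + r] = srcArr[rowBase + c]
        }
    }

    // One bulk copy heap array -> segment instead of n per-element sets.
    MemorySegment.copy(dstArr, 0, dst, ValueLayout.JAVA_FLOAT, 0L, n)
    return dst
}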
@@ -750,7 +762,11 @@ internal class DefaultCpuOpsJvm(
         val aMemSeg = a.data as? MemorySegmentBackedData
         val bMemSeg = b.data as? MemorySegmentBackedData
         if (aMemSeg != null && bMemSeg != null) {
-            val arena = Arena.ofConfined()
+            // Same fix as the transpose path above: use Arena.ofAuto so the
+            // matmul output segment is GC-reclaimable. Per-call ofConfined()
+            // leaks ~tens of MB per matmul, which over a 35-layer Gemma 4
+            // forward pass exhausts the JVM direct-memory cap.
+            val arena = Arena.ofAuto()
             val result = MemorySegmentTensorData<T>(Shape(m, n), arena)
             val blockedThresholdMS = 16 * 16
             if (m >= blockedThresholdMS || n >= blockedThresholdMS || k >= blockedThresholdMS) {
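
To make the failure mode concrete, a hedged illustration of the two arena lifetimes (not SKaiNET code; the function names and the 4 MB allocation size are invented):

import java.lang.foreign.Arena
import java.lang.foreign.MemorySegment
import java.lang.foreign.ValueLayout

// Pre-fix pattern: a confined arena frees its native memory only when
// close() is called. Handing the segment out and never closing leaks the
// bytes: GC collects the Arena object, but not the off-heap allocation.
fun leakyAllocate(floats: Long): MemorySegment =
    Arena.ofConfined().allocate(ValueLayout.JAVA_FLOAT, floats)

// Post-fix pattern: an automatic arena ties the native memory to GC
// reachability, so it is reclaimed once nothing references the segment.
fun gcManagedAllocate(floats: Long): MemorySegment =
    Arena.ofAuto().allocate(ValueLayout.JAVA_FLOAT, floats)

fun main() {
    repeat(1_000) { leakyAllocate(1_000_000) }     // ~4 GB leaked off-heap
    repeat(1_000) { gcManagedAllocate(1_000_000) } // footprint stays flat
}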
