Skip to content

Commit 1e49140

Browse files
committed
fix(memory): heap-wrap remaining hot-path tensor allocs — stop direct-memory leak
Same root cause as 319c394 (sliceView): ctx.fromFloatArray copies the input FloatArray into a fresh MemorySegment from Arena.ofAuto(). Direct memory doesn't put pressure on the GC, so the per-forward auto-arenas accumulate until -XX:MaxDirectMemorySize is exhausted.

Empirically: the smoke test went from a 45 GB direct-memory OOM mid-prefill to 271 MB of net direct-memory growth across the full 27 min forward, with the resident JVM staying inside the 32 GB cap.

Fixed sites (all on the per-token / per-layer path):
- RoPE.applyRoPESplitHalf: cos/sin tables (sliding layers, partial=1.0)
- RoPE.applyRoPESplitHalfFull: cos/sin tables (full layers, partial=0.25)
- MultiHeadAttention.buildSlidingCausalMask: mask tensor (every block using the sliding path, every forward)
- GemmaModel softcap: scale + inv scalar tensors (every forward)
- PaddedSharedPositionalKVCache.padHeadDim: padded V (Gemma 4 value-head padding when src/target head_dim differ)

Each site now wraps the FloatArray as DenseFloatArrayTensorData and goes through ctx.fromData, which keeps the storage on the heap and lets the GC reclaim it normally.

The tool-call format regression on the smoke test prompt is tracked separately; this commit only fixes the OOM that kept the smoke test from running to completion.
1 parent 5741d7b commit 1e49140

4 files changed

Lines changed: 54 additions & 13 deletions

File tree

llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer/KVCache.kt

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -473,10 +473,15 @@ public class PaddedSharedPositionalKVCache<T : DType, V>(
473473
// remaining [srcHeadDim, targetHeadDim) stays zero
474474
}
475475
}
476-
return ctx.fromFloatArray<T, V>(
477-
sk.ainet.lang.tensor.Shape(nKV, seq, targetHeadDim),
478-
t.dtype,
479-
out
476+
// Heap-backed wrap — fromFloatArray would copy into a fresh
477+
// Arena.ofAuto MemorySegment per call; padHeadDim runs every
478+
// attention forward when src/target head_dim differ (Gemma 4
479+
// value-head padding), so direct memory accumulates without GC
480+
// pressure. Same root cause as commit 319c394.
481+
val padShape = sk.ainet.lang.tensor.Shape(nKV, seq, targetHeadDim)
482+
return ctx.fromData(
483+
sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(padShape, out) as sk.ainet.lang.tensor.data.TensorData<T, V>,
484+
t.dtype
480485
)
481486
}
482487
}

llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer/MultiHeadAttention.kt

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,15 @@ public class MultiHeadAttention<T : DType, V>(
322322
data[qi * seqKV + ki] = if (allowed) 0f else neg
323323
}
324324
}
325-
return ctx.fromFloatArray(Shape(1, 1, seqQ, seqKV), dtype, data)
325+
// Heap-backed wrap — fromFloatArray would copy into a fresh
326+
// Arena.ofAuto MemorySegment every forward (× layers using the
327+
// sliding-mask path), and direct memory doesn't pressure the GC.
328+
// Same root cause as the sliceView leak (commit 319c394).
329+
val maskShape = Shape(1, 1, seqQ, seqKV)
330+
return ctx.fromData(
331+
sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(maskShape, data) as sk.ainet.lang.tensor.data.TensorData<T, V>,
332+
dtype
333+
)
326334
}
327335

328336
private fun repeatKVHeads(t: Tensor<T, V>, repeats: Int, ops: sk.ainet.lang.tensor.ops.TensorOps): Tensor<T, V> {

llm-core/src/commonMain/kotlin/sk/ainet/lang/nn/transformer/RoPE.kt

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,21 @@ public class RoPE<T : DType, V>(
191191
}
192192
}
193193
val cosShape = Shape(seqLen, halfRotary)
194-
val cosTensor: Tensor<T, V> = ctx.fromFloatArray(cosShape, input.dtype, cosData)
195-
val sinTensor: Tensor<T, V> = ctx.fromFloatArray(cosShape, input.dtype, sinData)
194+
// Heap-backed wrap, NOT ctx.fromFloatArray — fromFloatArray would
195+
// copy these transient cos/sin tables into fresh MemorySegments
196+
// from Arena.ofAuto(). RoPE runs twice per MHA (Q, K) × every
197+
// layer × every forward, and direct-memory pressure doesn't trigger
198+
// GC, so the auto-arenas accumulate until -XX:MaxDirectMemorySize
199+
// is exhausted. Same root-cause class as the sliceView leak
200+
// (commit 319c394). Heap arrays follow normal GC.
201+
val cosTensor: Tensor<T, V> = ctx.fromData(
202+
sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(cosShape, cosData) as sk.ainet.lang.tensor.data.TensorData<T, V>,
203+
input.dtype
204+
)
205+
val sinTensor: Tensor<T, V> = ctx.fromData(
206+
sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(cosShape, sinData) as sk.ainet.lang.tensor.data.TensorData<T, V>,
207+
input.dtype
208+
)
196209

197210
// Standard 2D rotation: (a, b) -> (a*cos - b*sin, a*sin + b*cos)
198211
val rotA = ops.subtract(ops.multiply(A, cosTensor), ops.multiply(C, sinTensor))
@@ -219,8 +232,16 @@ public class RoPE<T : DType, V>(
219232
}
220233

221234
val cosShape = Shape(seqLen, halfRotary)
222-
val cosTensor: Tensor<T, V> = ctx.fromFloatArray(cosShape, input.dtype, cosData)
223-
val sinTensor: Tensor<T, V> = ctx.fromFloatArray(cosShape, input.dtype, sinData)
235+
// Heap-backed wrap — see applyRoPESplitHalf for why fromFloatArray
236+
// is poison on the hot path (direct-memory leak via Arena.ofAuto).
237+
val cosTensor: Tensor<T, V> = ctx.fromData(
238+
sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(cosShape, cosData) as sk.ainet.lang.tensor.data.TensorData<T, V>,
239+
input.dtype
240+
)
241+
val sinTensor: Tensor<T, V> = ctx.fromData(
242+
sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(cosShape, sinData) as sk.ainet.lang.tensor.data.TensorData<T, V>,
243+
input.dtype
244+
)
224245

225246
val rotEven = ops.subtract(ops.multiply(even, cosTensor), ops.multiply(odd, sinTensor))
226247
val rotOdd = ops.add(ops.multiply(odd, cosTensor), ops.multiply(even, sinTensor))

llm-inference/gemma/src/commonMain/kotlin/sk/ainet/models/gemma/GemmaModel.kt

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,18 @@ public class GemmaModel<T : DType, V>(
128128
// onto degenerate attractor tokens during decode.
129129
if (finalLogitSoftcapping > 0f) {
130130
val ops = ctx.ops
131-
val scale = ctx.fromFloatArray<T, V>(
132-
sk.ainet.lang.tensor.Shape(1), dtype, floatArrayOf(1f / finalLogitSoftcapping)
131+
// Heap-backed scalar wrap — fromFloatArray copies even
132+
// single-float tables into a fresh Arena.ofAuto MemorySegment;
133+
// running per forward step accumulates direct memory the GC
134+
// can't see. Same root cause as commit 319c394.
135+
val scaleShape = sk.ainet.lang.tensor.Shape(1)
136+
val scale: Tensor<T, V> = ctx.fromData(
137+
sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(scaleShape, floatArrayOf(1f / finalLogitSoftcapping)) as sk.ainet.lang.tensor.data.TensorData<T, V>,
138+
dtype
133139
)
134-
val inv = ctx.fromFloatArray<T, V>(
135-
sk.ainet.lang.tensor.Shape(1), dtype, floatArrayOf(finalLogitSoftcapping)
140+
val inv: Tensor<T, V> = ctx.fromData(
141+
sk.ainet.lang.tensor.data.DenseFloatArrayTensorData<T>(scaleShape, floatArrayOf(finalLogitSoftcapping)) as sk.ainet.lang.tensor.data.TensorData<T, V>,
142+
dtype
136143
)
137144
logits = ops.multiply(ops.tanh(ops.multiply(logits, scale)), inv)
138145
}

0 commit comments

Comments
 (0)