@@ -108,19 +108,31 @@ internal class DefaultCpuOpsJvm(
             @Suppress("UNCHECKED_CAST")
             return newTensor(transposed as TensorData<T, V>, tensor.dtype, tensor)
         }
-        // MemorySegment FP32 fast path: physical transpose via SIMD
+        // MemorySegment FP32 fast path: physical transpose via SIMD.
+        // Uses Arena.ofAuto() so the result segment is reclaimed by GC
+        // when the wrapping Tensor is no longer reachable. Earlier
+        // ofConfined() builds leaked an arena per call, blowing 32+ GiB
+        // of direct memory in inference loops (every layer × every
+        // forward pass).
         if (data is MemorySegmentBackedData) {
-            val arena = Arena.ofConfined()
+            val arena = Arena.ofAuto()
             val result = MemorySegmentTensorData<T>(Shape(cols, rows), arena)
             val src = data as MemorySegmentBackedData
-            val srcOff = src.segmentByteOffset
-            val dstOff = result.segmentByteOffset
+            val floatLayout = java.lang.foreign.ValueLayout.JAVA_FLOAT
+            // Bulk-load source into FloatArray, transpose via tight scalar
+            // loop (JIT auto-vectorizes), bulk-write destination. Replaces
+            // O(rows*cols) per-element VarHandle.get/set which dominated
+            // attention-path transposes.
+            val srcArr = FloatArray(rows * cols)
+            java.lang.foreign.MemorySegment.copy(src.segment, floatLayout, src.segmentByteOffset, srcArr, 0, rows * cols)
+            val dstArr = FloatArray(rows * cols)
             for (r in 0 until rows) {
+                val rowBase = r * cols
                 for (c in 0 until cols) {
-                    val v = src.segment.get(java.lang.foreign.ValueLayout.JAVA_FLOAT, srcOff + (r.toLong() * cols + c) * 4)
-                    result.segment.set(java.lang.foreign.ValueLayout.JAVA_FLOAT, dstOff + (c.toLong() * rows + r) * 4, v)
+                    dstArr[c * rows + r] = srcArr[rowBase + c]
                 }
             }
+            java.lang.foreign.MemorySegment.copy(dstArr, 0, result.segment, floatLayout, result.segmentByteOffset, rows * cols)
             @Suppress("UNCHECKED_CAST")
             return newTensor(result as TensorData<T, V>, tensor.dtype, tensor)
         }
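
Note: for readers unfamiliar with the FFM bulk-copy idiom this hunk switches to, here is a minimal standalone sketch of the same technique (not part of the diff). It assumes a tightly packed float32 source segment read from offset 0; the name transposeF32 and its signature are illustrative only.

import java.lang.foreign.Arena
import java.lang.foreign.MemorySegment
import java.lang.foreign.ValueLayout

// Bulk-load the source segment into a FloatArray, transpose with a tight
// scalar loop (a shape the JIT can auto-vectorize), then bulk-store the
// result into a freshly allocated segment in the given arena.
fun transposeF32(src: MemorySegment, rows: Int, cols: Int, arena: Arena): MemorySegment {
    val n = rows * cols
    val srcArr = FloatArray(n)
    MemorySegment.copy(src, ValueLayout.JAVA_FLOAT, 0L, srcArr, 0, n)
    val dstArr = FloatArray(n)
    for (r in 0 until rows) {
        val rowBase = r * cols
        for (c in 0 until cols) {
            dstArr[c * rows + r] = srcArr[rowBase + c]
        }
    }
    val dst = arena.allocate(ValueLayout.JAVA_FLOAT, n.toLong())
    MemorySegment.copy(dstArr, 0, dst, ValueLayout.JAVA_FLOAT, 0L, n)
    return dst
}

Two MemorySegment.copy calls replace 2*rows*cols individual get/set calls, which is where the per-element VarHandle overhead mentioned in the hunk comment comes from.
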
@@ -750,7 +762,11 @@ internal class DefaultCpuOpsJvm(
         val aMemSeg = a.data as? MemorySegmentBackedData
         val bMemSeg = b.data as? MemorySegmentBackedData
         if (aMemSeg != null && bMemSeg != null) {
-            val arena = Arena.ofConfined()
+            // Same fix as the transpose path above: use Arena.ofAuto so the
+            // matmul output segment is GC-reclaimable. Per-call ofConfined()
+            // leaks ~tens of MB per matmul, which over a 35-layer Gemma 4
+            // forward pass exhausts the JVM direct-memory cap.
+            val arena = Arena.ofAuto()
             val result = MemorySegmentTensorData<T>(Shape(m, n), arena)
             val blockedThresholdMS = 16 * 16
             if (m >= blockedThresholdMS || n >= blockedThresholdMS || k >= blockedThresholdMS) {
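
Note: the lifetime difference both hunks rely on, as a minimal sketch (not part of the diff; function names are illustrative). A confined arena frees its segments only on an explicit close(), so a per-call arena whose segment escapes into a returned Tensor can never be closed; an automatic arena ties the native memory to reachability.

import java.lang.foreign.Arena
import java.lang.foreign.MemorySegment

// ofConfined(): memory lives until close(). Returning the segment means
// close() can never safely run, so the allocation is pinned forever.
fun confinedLeak(bytes: Long): MemorySegment {
    val arena = Arena.ofConfined()
    return arena.allocate(bytes) // leaked: no caller can close this arena
}

// ofAuto(): the GC reclaims the native memory once the segment (and any
// Tensor wrapping it) becomes unreachable, so escaping results are safe.
fun autoManaged(bytes: Long): MemorySegment {
    val arena = Arena.ofAuto()
    return arena.allocate(bytes) // reclaimed automatically after last use
}
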