Commit f2b8463

michalharakal and claude committed
Support per-layer xIELU arrays and lazy transpose for memory efficiency
- Fix xIELU metadata extraction: params are per-layer FLOAT32 arrays (32 values each), not global scalars
- Add context length limiting (-c flag) to avoid KV cache overflow
- Remove pre-transpose in ApertusRuntime to halve peak memory (transpose per-layer during forward pass instead)
- Add preTransposed flag to weight loader (for future use)
- First successful end-to-end inference with Apertus-8B GGUF

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2b4c1f4 commit f2b8463

5 files changed

Lines changed: 166 additions & 75 deletions
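For orientation before the diffs: the first bullet hinges on what xIELU is. Below is a minimal elementwise sketch of the activation in its commonly published form (a trainable quadratic branch for positive inputs, a clamped exponential branch for negative ones). The formula is quoted from the published description, not from this commit, and the parameter names simply mirror ApertusXIELUParams(alphaP, alphaN, beta, eps) used by the loader below.

    import kotlin.math.expm1
    import kotlin.math.min

    // Hedged sketch, not this repo's applyXIELU. One value of each parameter
    // exists per layer, which is exactly what the metadata fix below extracts.
    fun xielu(x: Float, alphaP: Float, alphaN: Float, beta: Float, eps: Float): Float =
        if (x > 0f) {
            alphaP * x * x + beta * x                    // quadratic positive branch
        } else {
            alphaN * (expm1(min(x, eps)) - x) + beta * x // clamped exponential branch
        }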


llm-inference/apertus/src/commonMain/kotlin/sk/ainet/models/apertus/ApertusRuntime.kt

Lines changed: 9 additions & 29 deletions
@@ -37,27 +37,6 @@ public class ApertusRuntime<T : DType>(
     random: Random = Random.Default
 ) : DecoderRuntime<T>(random) {

-    private class TransposedLayerWeights<T : DType>(
-        val wqT: Tensor<T, Float>,
-        val wkT: Tensor<T, Float>,
-        val wvT: Tensor<T, Float>,
-        val woT: Tensor<T, Float>,
-        val ffnDownT: Tensor<T, Float>,
-        val ffnUpT: Tensor<T, Float>,
-    )
-
-    private val transposedLayers: List<TransposedLayerWeights<T>> = weights.layers.map { layer ->
-        TransposedLayerWeights(
-            wqT = layer.wq.t(),
-            wkT = layer.wk.t(),
-            wvT = layer.wv.t(),
-            woT = layer.wo.t(),
-            ffnDownT = layer.ffnDown.t(),
-            ffnUpT = layer.ffnUp.t(),
-        )
-    }
-    private val outputWeightT: Tensor<T, Float> = weights.outputWeight.t()
-
     // ---- DecoderRuntime abstract properties ----
     override val dim: Int = weights.metadata.embeddingLength
     override val seqLen: Int = weights.metadata.contextLength
@@ -102,22 +81,23 @@
         )
     }

+    private val outputWeightT: Tensor<T, Float> = weights.outputWeight.t()
+
     // ---- DecoderRuntime template methods ----

     override fun embedToken(tokenId: Int): Tensor<T, Float> =
         embedding.forward(intArrayOf(tokenId), ctx)

     override fun runLayer(layerIdx: Int, x: Tensor<T, Float>): Tensor<T, Float> {
-        val tl = transposedLayers[layerIdx]
         val layer = weights.layers[layerIdx]

         // 1. Attention norm
         val attnNorm = attnNorms[layerIdx].forward(x, ctx)

-        // 2. QKV projections
-        val q = attnNorm.matmul(tl.wqT)
-        val k = attnNorm.matmul(tl.wkT)
-        val v = attnNorm.matmul(tl.wvT)
+        // 2. QKV projections (transpose on the fly to avoid double-memory peak)
+        val q = attnNorm.matmul(layer.wq.t())
+        val k = attnNorm.matmul(layer.wk.t())
+        val v = attnNorm.matmul(layer.wv.t())

         // 3. QK-norm: per-head RMSNorm on Q and K
         val qNormed = applyPerHeadRMSNorm(q, nHeads, headDim, layer.qNorm)
@@ -127,15 +107,15 @@
         val attnOut = attentionBackend.attention(qNormed, kNormed, v, layerIdx, position)

         // 5. Output projection + residual
-        val afterAttn = x + attnOut.matmul(tl.woT)
+        val afterAttn = x + attnOut.matmul(layer.wo.t())

         // 6. FFN norm
         val ffnNorm = ffnNorms[layerIdx].forward(afterAttn, ctx)

         // 7. Ungated MLP: up → xIELU → down
-        val up = ffnNorm.matmul(tl.ffnUpT)
+        val up = ffnNorm.matmul(layer.ffnUp.t())
         val activated = applyXIELU(up, layer.xieluParams)
-        val ffnOut = activated.matmul(tl.ffnDownT)
+        val ffnOut = activated.matmul(layer.ffnDown.t())

         // 8. Residual
         return afterAttn + ffnOut
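Back-of-the-envelope reasoning for the change above: eagerly pre-transposing every projection keeps a full second copy of all layer weights alive for the runtime's lifetime, whereas transposing inside runLayer materializes roughly one layer's worth of transposed matrices at a time. A self-contained estimate follows; the shapes are illustrative stand-ins, not Apertus-8B's actual dimensions (only the 32-layer count comes from the commit message).

    // Illustrative arithmetic only: FP32 weights, hypothetical shapes.
    fun fp32Bytes(rows: Int, cols: Int): Long = rows.toLong() * cols * 4

    fun main() {
        // Stand-ins for wq/wk/wv/wo and ffnUp/ffnDown of one layer.
        val shapes = listOf(
            4096 to 4096, 4096 to 4096, 4096 to 4096, 4096 to 4096,
            4096 to 14336, 14336 to 4096,
        )
        val perLayer = shapes.sumOf { (r, c) -> fp32Bytes(r, c) }
        val mib = 1L shl 20
        println("eager: ${32 * perLayer / mib} MiB of transposed copies held for the whole run")
        println("lazy:  ~${perLayer / mib} MiB transient per layer, reclaimable after use")
    }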

llm-inference/apertus/src/commonMain/kotlin/sk/ainet/models/apertus/ApertusRuntimeWeights.kt

Lines changed: 6 additions & 3 deletions
@@ -69,7 +69,8 @@ public data class ApertusRuntimeWeights<T : DType>(
     val layers: List<ApertusLayerWeights<T>>,
     val outputNorm: Tensor<T, Float>,
     val outputWeight: Tensor<T, Float>,
-    val ropeFreqs: Tensor<T, Float>? = null
+    val ropeFreqs: Tensor<T, Float>? = null,
+    val preTransposed: Boolean = false
 )

 /**
@@ -99,7 +100,8 @@ public object ApertusTensorNames {
 public data class ApertusWeights<T : DType, V>(
     val metadata: ApertusModelMetadata,
     val tensors: Map<String, Tensor<T, V>>,
-    val xieluParams: Map<Int, ApertusXIELUParams> = emptyMap()
+    val xieluParams: Map<Int, ApertusXIELUParams> = emptyMap(),
+    val preTransposed: Boolean = false
 )

 /**
@@ -150,7 +152,8 @@ public object ApertusWeightMapper {
         layers = layers,
         outputNorm = outputNorm,
         outputWeight = outputWeight,
-        ropeFreqs = ropeFreqs
+        ropeFreqs = ropeFreqs,
+        preTransposed = weights.preTransposed
         )
     }
 }

llm-inference/apertus/src/commonMain/kotlin/sk/ainet/models/apertus/ApertusWeightLoader.kt

Lines changed: 117 additions & 33 deletions
@@ -32,26 +32,31 @@ import kotlin.reflect.KClass
 public class ApertusWeightLoader private constructor(
     private val sourceProvider: (() -> Source)?,
     private val randomAccessProvider: (() -> RandomAccessSource)?,
-    private val quantPolicy: QuantPolicy = QuantPolicy.RAW_BYTES
+    private val quantPolicy: QuantPolicy = QuantPolicy.RAW_BYTES,
+    private val preTransposed: Boolean = false
 ) {

     public companion object {
         public fun fromSource(
             sourceProvider: () -> Source,
-            quantPolicy: QuantPolicy = QuantPolicy.RAW_BYTES
+            quantPolicy: QuantPolicy = QuantPolicy.RAW_BYTES,
+            preTransposed: Boolean = false
         ): ApertusWeightLoader = ApertusWeightLoader(
             sourceProvider = sourceProvider,
             randomAccessProvider = null,
-            quantPolicy = quantPolicy
+            quantPolicy = quantPolicy,
+            preTransposed = preTransposed
         )

         public fun fromRandomAccess(
             randomAccessProvider: () -> RandomAccessSource,
-            quantPolicy: QuantPolicy = QuantPolicy.RAW_BYTES
+            quantPolicy: QuantPolicy = QuantPolicy.RAW_BYTES,
+            preTransposed: Boolean = false
         ): ApertusWeightLoader = ApertusWeightLoader(
             sourceProvider = null,
             randomAccessProvider = randomAccessProvider,
-            quantPolicy = quantPolicy
+            quantPolicy = quantPolicy,
+            preTransposed = preTransposed
         )
     }

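For reference, constructing the loader with the extended factory. The parameter list is exactly what this hunk declares; the kotlinx-io file plumbing is an assumption about the surrounding project, not something the diff shows.

    import kotlinx.io.buffered
    import kotlinx.io.files.Path
    import kotlinx.io.files.SystemFileSystem

    // Assumed file access; only fromSource's signature comes from the diff.
    val loader = ApertusWeightLoader.fromSource(
        sourceProvider = { SystemFileSystem.source(Path("apertus-8b.gguf")).buffered() },
        quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
        preTransposed = false, // default; the flag is "for future use" per the commit message
    )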
@@ -115,7 +120,7 @@
             extractXIELUParamsFromReader(reader, tensorByName, metadata.blockCount, xieluParams)
         }

-        return ApertusWeights(metadata, byName, xieluParams)
+        return ApertusWeights(metadata, byName, xieluParams, preTransposed)
     }

     // ============== Streaming loading ==============
@@ -157,46 +162,65 @@
                 extractXIELUParamsFromStreaming(reader, tensorByName, metadata.blockCount, xieluParams)
             }

-            ApertusWeights(metadata, byName, xieluParams)
+            ApertusWeights(metadata, byName, xieluParams, preTransposed)
         }
     }

     // ============== xIELU parameter extraction ==============

     /**
-     * Extract xIELU params from GGUF metadata fields (global, same for all layers).
-     * Fields: xielu.alpha_p, xielu.alpha_n, xielu.beta, xielu.eps
+     * Extract xIELU params from GGUF metadata fields.
+     *
+     * Fields are arrays of FLOAT32 with one value per layer:
+     * xielu.alpha_p, xielu.alpha_n, xielu.beta, xielu.eps
      */
     private fun extractXIELUParams(
         fields: Map<String, ReaderField>,
         blockCount: Int,
         out: MutableMap<Int, ApertusXIELUParams>
     ) {
-        val alphaP = fields["xielu.alpha_p"]?.scalarFloat() ?: return
-        val alphaN = fields["xielu.alpha_n"]?.scalarFloat() ?: return
-        val beta = fields["xielu.beta"]?.scalarFloat() ?: return
-        val eps = fields["xielu.eps"]?.scalarFloat() ?: return
-        val params = ApertusXIELUParams(alphaP, alphaN, beta, eps)
+        val alphaPField = fields["xielu.alpha_p"] ?: return
+        val alphaNField = fields["xielu.alpha_n"] ?: return
+        val betaField = fields["xielu.beta"] ?: return
+        val epsField = fields["xielu.eps"] ?: return
+
+        val alphaPArr = alphaPField.floatArray()
+        val alphaNArr = alphaNField.floatArray()
+        val betaArr = betaField.floatArray()
+        val epsArr = epsField.floatArray()
+
         for (layer in 0 until blockCount) {
-            out[layer] = params
+            out[layer] = ApertusXIELUParams(
+                alphaP = alphaPArr.getOrElse(layer) { alphaPArr.first() },
+                alphaN = alphaNArr.getOrElse(layer) { alphaNArr.first() },
+                beta = betaArr.getOrElse(layer) { betaArr.first() },
+                eps = epsArr.getOrElse(layer) { epsArr.first() }
+            )
         }
     }

     /**
-     * Extract xIELU params from streaming GGUF metadata (global, same for all layers).
+     * Extract xIELU params from streaming GGUF metadata.
+     *
+     * Values are per-layer arrays (one FLOAT32 per layer).
      */
     private fun extractXIELUParamsFromStreamingMeta(
         fields: Map<String, Any?>,
         blockCount: Int,
         out: MutableMap<Int, ApertusXIELUParams>
     ) {
-        val alphaP = fields["xielu.alpha_p"]?.toFloatValue() ?: return
-        val alphaN = fields["xielu.alpha_n"]?.toFloatValue() ?: return
-        val beta = fields["xielu.beta"]?.toFloatValue() ?: return
-        val eps = fields["xielu.eps"]?.toFloatValue() ?: return
-        val params = ApertusXIELUParams(alphaP, alphaN, beta, eps)
+        val alphaPArr = fields["xielu.alpha_p"]?.asFloatArray() ?: return
+        val alphaNArr = fields["xielu.alpha_n"]?.asFloatArray() ?: return
+        val betaArr = fields["xielu.beta"]?.asFloatArray() ?: return
+        val epsArr = fields["xielu.eps"]?.asFloatArray() ?: return
+
         for (layer in 0 until blockCount) {
-            out[layer] = params
+            out[layer] = ApertusXIELUParams(
+                alphaP = alphaPArr.getOrElse(layer) { alphaPArr.first() },
+                alphaN = alphaNArr.getOrElse(layer) { alphaNArr.first() },
+                beta = betaArr.getOrElse(layer) { betaArr.first() },
+                eps = epsArr.getOrElse(layer) { epsArr.first() }
+            )
         }
     }
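The getOrElse fallback in both extractors is quiet compatibility work: with the corrected format each field is a per-layer FLOAT32 array (32 entries for Apertus-8B, per the commit message), but a single-entry array degrades to broadcasting its first value across all layers instead of throwing. A tiny illustration with a placeholder value:

    // Placeholder value; demonstrates the fallback semantics only.
    val alphaPArr = floatArrayOf(0.8f)                   // old-style: one global value
    val perLayer = FloatArray(32) { layer ->
        alphaPArr.getOrElse(layer) { alphaPArr.first() } // every layer gets 0.8f
    }
    check(perLayer.all { it == 0.8f })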

@@ -515,6 +539,15 @@
         }
     }

+    /**
+     * Create a tensor from dequantized float data.
+     *
+     * For 2D tensors from GGUF (stored column-major with shape [out, in]):
+     * - Normal mode: transposes to row-major [in, out] (requires `.t()` in runtime)
+     * - Pre-transposed mode: interprets column-major as row-major [in, out] directly,
+     *   skipping the data transpose. The weights can then be used directly in matmul
+     *   without `.t()`, saving ~50% memory.
+     */
     @Suppress("UNCHECKED_CAST")
     private fun <T : DType, V> createTensor(
         ctx: ExecutionContext,
@@ -525,9 +558,16 @@
         return if (originalShape.rank == 2) {
             val rows = originalShape[0]
             val cols = originalShape[1]
-            val transposed = DequantOps.transposeColumnMajorToRowMajor(data, rows, cols)
-            val newShape = Shape(cols, rows)
-            ctx.fromFloatArray<T, Float>(newShape, dtype, transposed) as Tensor<T, V>
+            if (preTransposed) {
+                // Column-major [out, in] is equivalent to row-major [in, out]
+                // Skip data transpose — weights are already in matmul-ready layout
+                val newShape = Shape(cols, rows)
+                ctx.fromFloatArray<T, Float>(newShape, dtype, data) as Tensor<T, V>
+            } else {
+                val transposed = DequantOps.transposeColumnMajorToRowMajor(data, rows, cols)
+                val newShape = Shape(cols, rows)
+                ctx.fromFloatArray<T, Float>(newShape, dtype, transposed) as Tensor<T, V>
+            }
         } else {
             ctx.fromFloatArray<T, Float>(originalShape, dtype, data) as Tensor<T, V>
         }
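The doc comment's claim that a column-major [out, in] buffer can be relabeled row-major [in, out] without touching the data is pure index arithmetic, verified below on tiny stand-in dimensions:

    // Element (r, c) of a column-major [rows, cols] buffer sits at c * rows + r,
    // which is exactly where row-major [cols, rows] places element (c, r).
    fun colMajor(r: Int, c: Int, rows: Int) = c * rows + r
    fun rowMajor(r: Int, c: Int, cols: Int) = r * cols + c

    fun main() {
        val rows = 3
        val cols = 2
        for (r in 0 until rows) for (c in 0 until cols) {
            check(colMajor(r, c, rows) == rowMajor(c, r, rows))
        }
        println("same buffer, two valid shape labels: no data movement needed")
    }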
@@ -568,6 +608,25 @@
         }
     }

+    /**
+     * Extract a float array from a ReaderField (GGUF ARRAY of FLOAT32).
+     * Each element is stored as a separate part; data indices point to them.
+     */
+    private fun ReaderField.floatArray(): FloatArray {
+        return FloatArray(data.size) { idx ->
+            val partIdx = data[idx]
+            val part = parts.getOrNull(partIdx) ?: error("Missing part $partIdx for field $name")
+            val value = (part as List<*>).firstOrNull()
+                ?: error("Empty part for field $name at index $idx")
+            when (value) {
+                is Float -> value
+                is Double -> value.toFloat()
+                is Number -> value.toFloat()
+                else -> error("Unsupported array element type ${value::class} for field $name")
+            }
+        }
+    }
+
     private fun ReaderField.stringValue(): String {
         val idx = data.firstOrNull() ?: 0
         val part = parts.getOrNull(idx) ?: error("Missing data part for field $name")
@@ -601,6 +660,25 @@
         else -> null
     }

+    /**
+     * Convert a streaming metadata value (array or scalar) to a FloatArray.
+     */
+    @Suppress("UNCHECKED_CAST")
+    private fun Any?.asFloatArray(): FloatArray? = when (this) {
+        is FloatArray -> this
+        is List<*> -> FloatArray(size) { i ->
+            when (val v = get(i)) {
+                is Float -> v
+                is Double -> v.toFloat()
+                is Number -> v.toFloat()
+                else -> return null
+            }
+        }
+        is Float -> floatArrayOf(this)
+        is Double -> floatArrayOf(this.toFloat())
+        else -> null
+    }
+
     private fun inferEmbeddingFromTensor(tensors: List<ReaderTensor>): Int {
         val token = tensors.firstOrNull { it.name == ApertusTensorNames.TOKEN_EMBEDDINGS }
             ?: error("Cannot infer embedding length without token embeddings tensor")
@@ -639,11 +717,13 @@ public suspend fun <T : DType> loadApertusRuntimeWeights(
     ctx: ExecutionContext,
     sourceProvider: () -> Source,
     dtype: KClass<T>,
-    quantPolicy: QuantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32
+    quantPolicy: QuantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
+    preTransposed: Boolean = false
 ): ApertusRuntimeWeights<T> {
     val loader = ApertusWeightLoader.fromSource(
         sourceProvider = sourceProvider,
-        quantPolicy = quantPolicy
+        quantPolicy = quantPolicy,
+        preTransposed = preTransposed
     )
     val loaded = loader.loadToMap<T, Float>(ctx, dtype)
     return ApertusWeightMapper.map(loaded)
@@ -652,8 +732,9 @@
 public suspend fun loadApertusRuntimeWeights(
     ctx: ExecutionContext,
     sourceProvider: () -> Source,
-    quantPolicy: QuantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32
-): ApertusRuntimeWeights<FP32> = loadApertusRuntimeWeights(ctx, sourceProvider, FP32::class, quantPolicy)
+    quantPolicy: QuantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
+    preTransposed: Boolean = false
+): ApertusRuntimeWeights<FP32> = loadApertusRuntimeWeights(ctx, sourceProvider, FP32::class, quantPolicy, preTransposed)

 /**
  * Load Apertus runtime weights from a GGUF source (streaming, for large files).
@@ -662,11 +743,13 @@ public suspend fun <T : DType> loadApertusRuntimeWeightsStreaming(
     ctx: ExecutionContext,
     randomAccessProvider: () -> RandomAccessSource,
     dtype: KClass<T>,
-    quantPolicy: QuantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32
+    quantPolicy: QuantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
+    preTransposed: Boolean = false
 ): ApertusRuntimeWeights<T> {
     val loader = ApertusWeightLoader.fromRandomAccess(
         randomAccessProvider = randomAccessProvider,
-        quantPolicy = quantPolicy
+        quantPolicy = quantPolicy,
+        preTransposed = preTransposed
     )
     val loaded = loader.loadToMap<T, Float>(ctx, dtype)
     return ApertusWeightMapper.map(loaded)
@@ -675,5 +758,6 @@
 public suspend fun loadApertusRuntimeWeightsStreaming(
     ctx: ExecutionContext,
     randomAccessProvider: () -> RandomAccessSource,
-    quantPolicy: QuantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32
-): ApertusRuntimeWeights<FP32> = loadApertusRuntimeWeightsStreaming(ctx, randomAccessProvider, FP32::class, quantPolicy)
+    quantPolicy: QuantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
+    preTransposed: Boolean = false
+): ApertusRuntimeWeights<FP32> = loadApertusRuntimeWeightsStreaming(ctx, randomAccessProvider, FP32::class, quantPolicy, preTransposed)
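Putting the new flag together with the streaming entry point, as one might call it for an 8B GGUF. The call follows the signatures in these hunks; openRandomAccess is a hypothetical helper, since the diff does not show how the repo opens a RandomAccessSource.

    // Hedged end-to-end sketch; openRandomAccess() is hypothetical.
    suspend fun loadApertus8B(ctx: ExecutionContext): ApertusRuntimeWeights<FP32> =
        loadApertusRuntimeWeightsStreaming(
            ctx = ctx,
            randomAccessProvider = { openRandomAccess("apertus-8b.gguf") },
            quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
            preTransposed = false, // keep false: ApertusRuntime still calls .t() itself
        )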
