@@ -32,26 +32,31 @@ import kotlin.reflect.KClass
/**
 * Loads Apertus model weights from GGUF data.
 *
 * Instances are created through [fromSource] (sequential read) or
 * [fromRandomAccess] (random-access read); exactly one provider is
 * non-null for a given instance.
 *
 * @property quantPolicy how quantized tensor payloads are treated on load.
 * @property preTransposed when true, 2D weights keep their GGUF column-major
 *   layout reinterpreted as row-major, skipping the data transpose step.
 */
public class ApertusWeightLoader private constructor(
    private val sourceProvider: (() -> Source)?,
    private val randomAccessProvider: (() -> RandomAccessSource)?,
    private val quantPolicy: QuantPolicy = QuantPolicy.RAW_BYTES,
    private val preTransposed: Boolean = false,
) {

    public companion object {
        /** Builds a loader that reads the GGUF stream from [sourceProvider]. */
        public fun fromSource(
            sourceProvider: () -> Source,
            quantPolicy: QuantPolicy = QuantPolicy.RAW_BYTES,
            preTransposed: Boolean = false,
        ): ApertusWeightLoader = ApertusWeightLoader(
            sourceProvider = sourceProvider,
            randomAccessProvider = null,
            quantPolicy = quantPolicy,
            preTransposed = preTransposed,
        )

        /** Builds a loader backed by a [RandomAccessSource] (suited to large files). */
        public fun fromRandomAccess(
            randomAccessProvider: () -> RandomAccessSource,
            quantPolicy: QuantPolicy = QuantPolicy.RAW_BYTES,
            preTransposed: Boolean = false,
        ): ApertusWeightLoader = ApertusWeightLoader(
            sourceProvider = null,
            randomAccessProvider = randomAccessProvider,
            quantPolicy = quantPolicy,
            preTransposed = preTransposed,
        )
    }
5762
@@ -115,7 +120,7 @@ public class ApertusWeightLoader private constructor(
115120 extractXIELUParamsFromReader(reader, tensorByName, metadata.blockCount, xieluParams)
116121 }
117122
118- return ApertusWeights (metadata, byName, xieluParams)
123+ return ApertusWeights (metadata, byName, xieluParams, preTransposed )
119124 }
120125
121126 // ============== Streaming loading ==============
@@ -157,46 +162,65 @@ public class ApertusWeightLoader private constructor(
157162 extractXIELUParamsFromStreaming(reader, tensorByName, metadata.blockCount, xieluParams)
158163 }
159164
160- ApertusWeights (metadata, byName, xieluParams)
165+ ApertusWeights (metadata, byName, xieluParams, preTransposed )
161166 }
162167 }
163168
164169 // ============== xIELU parameter extraction ==============
165170
166171 /* *
167- * Extract xIELU params from GGUF metadata fields (global, same for all layers).
168- * Fields: xielu.alpha_p, xielu.alpha_n, xielu.beta, xielu.eps
172+ * Extract xIELU params from GGUF metadata fields.
173+ *
174+ * Fields are arrays of FLOAT32 with one value per layer:
175+ * xielu.alpha_p, xielu.alpha_n, xielu.beta, xielu.eps
169176 */
170177 private fun extractXIELUParams (
171178 fields : Map <String , ReaderField >,
172179 blockCount : Int ,
173180 out : MutableMap <Int , ApertusXIELUParams >
174181 ) {
175- val alphaP = fields[" xielu.alpha_p" ]?.scalarFloat() ? : return
176- val alphaN = fields[" xielu.alpha_n" ]?.scalarFloat() ? : return
177- val beta = fields[" xielu.beta" ]?.scalarFloat() ? : return
178- val eps = fields[" xielu.eps" ]?.scalarFloat() ? : return
179- val params = ApertusXIELUParams (alphaP, alphaN, beta, eps)
182+ val alphaPField = fields[" xielu.alpha_p" ] ? : return
183+ val alphaNField = fields[" xielu.alpha_n" ] ? : return
184+ val betaField = fields[" xielu.beta" ] ? : return
185+ val epsField = fields[" xielu.eps" ] ? : return
186+
187+ val alphaPArr = alphaPField.floatArray()
188+ val alphaNArr = alphaNField.floatArray()
189+ val betaArr = betaField.floatArray()
190+ val epsArr = epsField.floatArray()
191+
180192 for (layer in 0 until blockCount) {
181- out [layer] = params
193+ out [layer] = ApertusXIELUParams (
194+ alphaP = alphaPArr.getOrElse(layer) { alphaPArr.first() },
195+ alphaN = alphaNArr.getOrElse(layer) { alphaNArr.first() },
196+ beta = betaArr.getOrElse(layer) { betaArr.first() },
197+ eps = epsArr.getOrElse(layer) { epsArr.first() }
198+ )
182199 }
183200 }
184201
185202 /* *
186- * Extract xIELU params from streaming GGUF metadata (global, same for all layers).
203+ * Extract xIELU params from streaming GGUF metadata.
204+ *
205+ * Values are per-layer arrays (one FLOAT32 per layer).
187206 */
188207 private fun extractXIELUParamsFromStreamingMeta (
189208 fields : Map <String , Any ?>,
190209 blockCount : Int ,
191210 out : MutableMap <Int , ApertusXIELUParams >
192211 ) {
193- val alphaP = fields[" xielu.alpha_p" ]?.toFloatValue () ? : return
194- val alphaN = fields[" xielu.alpha_n" ]?.toFloatValue () ? : return
195- val beta = fields[" xielu.beta" ]?.toFloatValue () ? : return
196- val eps = fields[" xielu.eps" ]?.toFloatValue () ? : return
197- val params = ApertusXIELUParams (alphaP, alphaN, beta, eps)
212+ val alphaPArr = fields[" xielu.alpha_p" ]?.asFloatArray () ? : return
213+ val alphaNArr = fields[" xielu.alpha_n" ]?.asFloatArray () ? : return
214+ val betaArr = fields[" xielu.beta" ]?.asFloatArray () ? : return
215+ val epsArr = fields[" xielu.eps" ]?.asFloatArray () ? : return
216+
198217 for (layer in 0 until blockCount) {
199- out [layer] = params
218+ out [layer] = ApertusXIELUParams (
219+ alphaP = alphaPArr.getOrElse(layer) { alphaPArr.first() },
220+ alphaN = alphaNArr.getOrElse(layer) { alphaNArr.first() },
221+ beta = betaArr.getOrElse(layer) { betaArr.first() },
222+ eps = epsArr.getOrElse(layer) { epsArr.first() }
223+ )
200224 }
201225 }
202226
@@ -515,6 +539,15 @@ public class ApertusWeightLoader private constructor(
515539 }
516540 }
517541
542+ /* *
543+ * Create a tensor from dequantized float data.
544+ *
545+ * For 2D tensors from GGUF (stored column-major with shape [out, in]):
546+ * - Normal mode: transposes to row-major [in, out] (requires `.t()` in runtime)
547+ * - Pre-transposed mode: interprets column-major as row-major [in, out] directly,
548+ * skipping the data transpose. The weights can then be used directly in matmul
549+ * without `.t()`, saving ~50% memory.
550+ */
518551 @Suppress(" UNCHECKED_CAST" )
519552 private fun <T : DType , V > createTensor (
520553 ctx : ExecutionContext ,
@@ -525,9 +558,16 @@ public class ApertusWeightLoader private constructor(
525558 return if (originalShape.rank == 2 ) {
526559 val rows = originalShape[0 ]
527560 val cols = originalShape[1 ]
528- val transposed = DequantOps .transposeColumnMajorToRowMajor(data, rows, cols)
529- val newShape = Shape (cols, rows)
530- ctx.fromFloatArray<T , Float >(newShape, dtype, transposed) as Tensor <T , V >
561+ if (preTransposed) {
562+ // Column-major [out, in] is equivalent to row-major [in, out]
563+ // Skip data transpose — weights are already in matmul-ready layout
564+ val newShape = Shape (cols, rows)
565+ ctx.fromFloatArray<T , Float >(newShape, dtype, data) as Tensor <T , V >
566+ } else {
567+ val transposed = DequantOps .transposeColumnMajorToRowMajor(data, rows, cols)
568+ val newShape = Shape (cols, rows)
569+ ctx.fromFloatArray<T , Float >(newShape, dtype, transposed) as Tensor <T , V >
570+ }
531571 } else {
532572 ctx.fromFloatArray<T , Float >(originalShape, dtype, data) as Tensor <T , V >
533573 }
@@ -568,6 +608,25 @@ public class ApertusWeightLoader private constructor(
568608 }
569609 }
570610
611+ /* *
612+ * Extract a float array from a ReaderField (GGUF ARRAY of FLOAT32).
613+ * Each element is stored as a separate part; data indices point to them.
614+ */
615+ private fun ReaderField.floatArray (): FloatArray {
616+ return FloatArray (data.size) { idx ->
617+ val partIdx = data[idx]
618+ val part = parts.getOrNull(partIdx) ? : error(" Missing part $partIdx for field $name " )
619+ val value = (part as List <* >).firstOrNull()
620+ ? : error(" Empty part for field $name at index $idx " )
621+ when (value) {
622+ is Float -> value
623+ is Double -> value.toFloat()
624+ is Number -> value.toFloat()
625+ else -> error(" Unsupported array element type ${value::class } for field $name " )
626+ }
627+ }
628+ }
629+
571630 private fun ReaderField.stringValue (): String {
572631 val idx = data.firstOrNull() ? : 0
573632 val part = parts.getOrNull(idx) ? : error(" Missing data part for field $name " )
@@ -601,6 +660,25 @@ public class ApertusWeightLoader private constructor(
601660 else -> null
602661 }
603662
663+ /* *
664+ * Convert a streaming metadata value (array or scalar) to a FloatArray.
665+ */
666+ @Suppress(" UNCHECKED_CAST" )
667+ private fun Any?.asFloatArray (): FloatArray? = when (this ) {
668+ is FloatArray -> this
669+ is List <* > -> FloatArray (size) { i ->
670+ when (val v = get(i)) {
671+ is Float -> v
672+ is Double -> v.toFloat()
673+ is Number -> v.toFloat()
674+ else -> return null
675+ }
676+ }
677+ is Float -> floatArrayOf(this )
678+ is Double -> floatArrayOf(this .toFloat())
679+ else -> null
680+ }
681+
604682 private fun inferEmbeddingFromTensor (tensors : List <ReaderTensor >): Int {
605683 val token = tensors.firstOrNull { it.name == ApertusTensorNames .TOKEN_EMBEDDINGS }
606684 ? : error(" Cannot infer embedding length without token embeddings tensor" )
/**
 * Load Apertus runtime weights from a GGUF [Source], producing tensors of [dtype].
 *
 * @param preTransposed forwarded to [ApertusWeightLoader]; when true, 2D
 *   weights are loaded in matmul-ready layout without a data transpose.
 */
public suspend fun <T : DType> loadApertusRuntimeWeights(
    ctx: ExecutionContext,
    sourceProvider: () -> Source,
    dtype: KClass<T>,
    quantPolicy: QuantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
    preTransposed: Boolean = false,
): ApertusRuntimeWeights<T> {
    val loader = ApertusWeightLoader.fromSource(
        sourceProvider = sourceProvider,
        quantPolicy = quantPolicy,
        preTransposed = preTransposed,
    )
    return ApertusWeightMapper.map(loader.loadToMap<T, Float>(ctx, dtype))
}
@@ -652,8 +732,9 @@ public suspend fun <T : DType> loadApertusRuntimeWeights(
/** FP32 convenience overload of [loadApertusRuntimeWeights]. */
public suspend fun loadApertusRuntimeWeights(
    ctx: ExecutionContext,
    sourceProvider: () -> Source,
    quantPolicy: QuantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
    preTransposed: Boolean = false,
): ApertusRuntimeWeights<FP32> =
    loadApertusRuntimeWeights(
        ctx = ctx,
        sourceProvider = sourceProvider,
        dtype = FP32::class,
        quantPolicy = quantPolicy,
        preTransposed = preTransposed,
    )
657738
658739/* *
659740 * Load Apertus runtime weights from a GGUF source (streaming, for large files).
/**
 * Load Apertus runtime weights from a GGUF source via random access
 * (streaming; suited to large files), producing tensors of [dtype].
 *
 * @param preTransposed forwarded to [ApertusWeightLoader]; when true, 2D
 *   weights are loaded in matmul-ready layout without a data transpose.
 */
public suspend fun <T : DType> loadApertusRuntimeWeightsStreaming(
    ctx: ExecutionContext,
    randomAccessProvider: () -> RandomAccessSource,
    dtype: KClass<T>,
    quantPolicy: QuantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
    preTransposed: Boolean = false,
): ApertusRuntimeWeights<T> {
    val loader = ApertusWeightLoader.fromRandomAccess(
        randomAccessProvider = randomAccessProvider,
        quantPolicy = quantPolicy,
        preTransposed = preTransposed,
    )
    return ApertusWeightMapper.map(loader.loadToMap<T, Float>(ctx, dtype))
}
@@ -675,5 +758,6 @@ public suspend fun <T : DType> loadApertusRuntimeWeightsStreaming(
/** FP32 convenience overload of [loadApertusRuntimeWeightsStreaming]. */
public suspend fun loadApertusRuntimeWeightsStreaming(
    ctx: ExecutionContext,
    randomAccessProvider: () -> RandomAccessSource,
    quantPolicy: QuantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
    preTransposed: Boolean = false,
): ApertusRuntimeWeights<FP32> =
    loadApertusRuntimeWeightsStreaming(
        ctx = ctx,
        randomAccessProvider = randomAccessProvider,
        dtype = FP32::class,
        quantPolicy = quantPolicy,
        preTransposed = preTransposed,
    )
0 commit comments