@@ -120,12 +120,13 @@ public class ApertusWeightLoader private constructor(
         requiredTensorNames(metadata).forEach { name ->
             val rt = tensorByName[name]
                 ?: error("Missing required tensor in GGUF payload: $name")
-            byName[name] = readerTensorToTensor(ctx, dtype, reader, rt)
+            byName[name] = loadReaderTensor(ctx, dtype, reader, rt, name)
         }
 
         // Load optional rope_freqs tensor
         tensorByName[ApertusTensorNames.ROPE_FREQS]?.let { rt ->
-            byName[ApertusTensorNames.ROPE_FREQS] = readerTensorToTensor(ctx, dtype, reader, rt)
+            byName[ApertusTensorNames.ROPE_FREQS] =
+                loadReaderTensor(ctx, dtype, reader, rt, ApertusTensorNames.ROPE_FREQS)
         }
 
         // Extract xIELU params: try metadata fields first, then per-layer tensors
@@ -162,12 +163,13 @@ public class ApertusWeightLoader private constructor(
         requiredTensorNames(metadata).forEach { name ->
             val st = tensorByName[name]
                 ?: error("Missing required tensor in GGUF payload: $name")
-            byName[name] = streamingTensorToTensor(ctx, dtype, reader, st)
+            byName[name] = loadStreamingTensor(ctx, dtype, reader, st, name)
         }
 
         // Load optional rope_freqs tensor
         tensorByName[ApertusTensorNames.ROPE_FREQS]?.let { st ->
-            byName[ApertusTensorNames.ROPE_FREQS] = streamingTensorToTensor(ctx, dtype, reader, st)
+            byName[ApertusTensorNames.ROPE_FREQS] =
+                loadStreamingTensor(ctx, dtype, reader, st, ApertusTensorNames.ROPE_FREQS)
         }
 
         // Extract xIELU params: try metadata fields first, then per-layer tensors
@@ -560,6 +562,58 @@ public class ApertusWeightLoader private constructor(
 
     // ============== Tensor conversion ==============
 
+    /**
+     * NATIVE_OPTIMIZED stores quantized tensors as byte-level rank-1 buffers so the
+     * native FFM kernels can address the raw block layout directly. That works for
+     * matmul (the kernel knows the logical shape from metadata) but breaks the
+     * token embedding, where `Embedding.gather()` requires the logical rank-2
+     * `[vocab, dim]` shape. Force `token_embd.weight` through the dequant path so
+     * the embedding lookup gets a real `[vocab, dim]` FP32/FP16 tensor regardless
+     * of the policy chosen for the rest of the model.
+     */
+    private fun <T : DType, V> loadStreamingTensor(
+        ctx: ExecutionContext,
+        dtype: KClass<T>,
+        reader: StreamingGGUFReader,
+        st: StreamingTensorInfo,
+        name: String
+    ): Tensor<T, V> {
+        if (name == ApertusTensorNames.TOKEN_EMBEDDINGS &&
+            quantPolicy == QuantPolicy.NATIVE_OPTIMIZED &&
+            st.tensorType != GGMLQuantizationType.F32 &&
+            st.tensorType != GGMLQuantizationType.F16 &&
+            st.tensorType != GGMLQuantizationType.BF16
+        ) {
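+            // Bypass NATIVE_OPTIMIZED for the token embedding: rebuild the logical
+            // [vocab, dim] shape and dequantize so Embedding.gather() gets real floats.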
+            val shape = Shape(*st.shape.map { it.toInt() }.toIntArray())
+            val bytes = reader.loadTensorData(st)
+            val floats = DequantOps.dequantFromBytes(bytes, st.tensorType, st.nElements.toInt())
+            return createTensor(ctx, dtype, shape, floats)
+        }
+        return streamingTensorToTensor(ctx, dtype, reader, st)
+    }
+
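+    /** Same token-embedding dequant override as [loadStreamingTensor], for the non-streaming [GGUFReader] path. */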
+    private fun <T : DType, V> loadReaderTensor(
+        ctx: ExecutionContext,
+        dtype: KClass<T>,
+        reader: GGUFReader,
+        rt: ReaderTensor,
+        name: String
+    ): Tensor<T, V> {
+        if (name == ApertusTensorNames.TOKEN_EMBEDDINGS &&
+            quantPolicy == QuantPolicy.NATIVE_OPTIMIZED &&
+            rt.tensorType != GGMLQuantizationType.F32 &&
+            rt.tensorType != GGMLQuantizationType.F16 &&
+            rt.tensorType != GGMLQuantizationType.BF16
+        ) {
+            val shape = Shape(*rt.shape.map { it.toInt() }.toIntArray())
+            val raw = if (rt.data.isEmpty()) reader.materialize(rt) else rt.data
+            val bytes: ByteArray = DequantOps.toByteArray(raw, rt.name)
+            val floats = DequantOps.dequantFromBytes(bytes, rt.tensorType, rt.nElements)
+            return createTensor(ctx, dtype, shape, floats)
+        }
+        return readerTensorToTensor(ctx, dtype, reader, rt)
+    }
+
     @Suppress("UNCHECKED_CAST")
     private fun <T : DType, V> readerTensorToTensor(
         ctx: ExecutionContext,
@@ -631,7 +685,7 @@ public class ApertusWeightLoader private constructor(
     }
 
     @Suppress("UNCHECKED_CAST")
-    private fun <T : DType, V> streamingTensorToTensor(
+    internal fun <T : DType, V> streamingTensorToTensor(
         ctx: ExecutionContext,
         dtype: KClass<T>,
         reader: StreamingGGUFReader,
@@ -676,9 +730,19 @@ public class ApertusWeightLoader private constructor(
             GGMLQuantizationType.IQ4_NL, GGMLQuantizationType.IQ4_XS,
             GGMLQuantizationType.TQ1_0, GGMLQuantizationType.TQ2_0 -> {
                 when (quantPolicy) {
-                    QuantPolicy.RAW_BYTES, QuantPolicy.NATIVE_OPTIMIZED -> {
+                    QuantPolicy.RAW_BYTES -> {
+                        require(dtype == Int8::class) {
+                            "Quantized tensor ${st.name} requires dtype Int8 with quantPolicy=RAW_BYTES"
+                        }
                         ctx.fromByteArray<Int8, Byte>(shape, Int8::class, bytes) as Tensor<T, V>
                     }
+                    QuantPolicy.NATIVE_OPTIMIZED -> {
+                        // Store raw quantized bytes; dtype can be FP32 (mixed mode).
+                        // The streaming reader reports the logical shape, which doesn't
+                        // match the raw byte count, so store under a rank-1 byte-level shape.
+                        val byteShape = Shape(bytes.size)
+                        @Suppress("UNCHECKED_CAST")
+                        ctx.fromByteArray<Int8, Byte>(byteShape, Int8::class, bytes) as Tensor<T, V>
+                    }
                     QuantPolicy.DEQUANTIZE_TO_FP32 -> {
                         val floats = DequantOps.dequantFromBytes(bytes, st.tensorType, st.nElements.toInt())
                         createTensor(ctx, dtype, shape, floats)