Skip to content

Commit 4cd1da9

Browse files
Merge pull request #98 from SKaiNET-developers/fix/apertus-real-loading
fix(apertus): real-model loading — UInt metadata + quantized shape
2 parents 68c2ff1 + 62f22bd commit 4cd1da9

8 files changed

Lines changed: 576 additions & 14 deletions

File tree

llm-core/build.gradle.kts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ kotlin {
5858

5959
val jvmMain by getting
6060

61+
val jvmTest by getting {
62+
dependencies {
63+
implementation(libs.kotlin.test)
64+
implementation(libs.junit)
65+
implementation(libs.skainet.io.gguf)
66+
implementation(libs.skainet.io.core)
67+
}
68+
}
69+
6170
// Shared source set for all non-JVM targets (manual BackendRegistry)
6271
val registryBasedMain by creating {
6372
dependsOn(commonMain.get())

llm-core/src/commonMain/kotlin/sk/ainet/apps/llm/UnifiedModelLoader.kt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package sk.ainet.apps.llm
22

33
import sk.ainet.io.RandomAccessSource
44
import sk.ainet.io.gguf.StreamingGGUFReader
5+
import sk.ainet.io.gguf.getInt
56
import sk.ainet.lang.types.DType
67

78
/**
@@ -54,11 +55,11 @@ public object UnifiedModelLoader {
5455
GGUFModelInfo(
5556
architecture = arch,
5657
family = family,
57-
contextLength = (fields["${arch}.context_length"] as? Number)?.toInt() ?: 4096,
58-
vocabSize = (fields["${arch}.vocab_size"] as? Number)?.toInt()
58+
contextLength = fields.getInt("${arch}.context_length") ?: 4096,
59+
vocabSize = fields.getInt("${arch}.vocab_size")
5960
?: ((fields["tokenizer.ggml.tokens"] as? List<*>)?.size ?: 0),
60-
blockCount = (fields["${arch}.block_count"] as? Number)?.toInt() ?: 0,
61-
embeddingLength = (fields["${arch}.embedding_length"] as? Number)?.toInt() ?: 0,
61+
blockCount = fields.getInt("${arch}.block_count") ?: 0,
62+
embeddingLength = fields.getInt("${arch}.embedding_length") ?: 0,
6263
fields = fields
6364
)
6465
}
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package sk.ainet.apps.llm
2+
3+
import sk.ainet.io.JvmRandomAccessSource
4+
import sk.ainet.io.gguf.export.GGUFWriter
5+
import sk.ainet.io.gguf.export.GgufWriteRequest
6+
import java.nio.file.Files
7+
import kotlin.test.Test
8+
import kotlin.test.assertEquals
9+
10+
/**
11+
* Regression tests for [UnifiedModelLoader.peek] handling of GGUF metadata
12+
* fields stored as unsigned integer types.
13+
*
14+
* Before the fix, `(fields[...] as? Number)?.toInt()` silently returned null
15+
* for `UInt`/`ULong` values (they are not subtypes of `Number` in Kotlin),
16+
* causing every modern GGUF — which uses uint32 dimensions — to fall back
17+
* to the defaults: contextLength=4096, blockCount=0, embeddingLength=0.
18+
* A blockCount of 0 yields a model with zero transformer layers.
19+
*/
20+
class UnifiedModelLoaderUIntMetadataTest {
21+
22+
@Test
23+
fun peek_reads_uint32_metadata_fields() {
24+
val bytes = buildGgufBytes(
25+
arch = "apertus",
26+
metadata = mapOf(
27+
"apertus.context_length" to 8192u,
28+
"apertus.block_count" to 32u,
29+
"apertus.embedding_length" to 4096u,
30+
"apertus.vocab_size" to 128256u
31+
)
32+
)
33+
34+
val info = peekFromBytes(bytes)
35+
36+
assertEquals("apertus", info.architecture)
37+
assertEquals(8192, info.contextLength)
38+
assertEquals(32, info.blockCount)
39+
assertEquals(4096, info.embeddingLength)
40+
assertEquals(128256, info.vocabSize)
41+
}
42+
43+
@Test
44+
fun peek_reads_uint64_metadata_fields() {
45+
val bytes = buildGgufBytes(
46+
arch = "apertus",
47+
metadata = mapOf(
48+
"apertus.context_length" to 8192uL,
49+
"apertus.block_count" to 32uL,
50+
"apertus.embedding_length" to 4096uL,
51+
"apertus.vocab_size" to 128256uL
52+
)
53+
)
54+
55+
val info = peekFromBytes(bytes)
56+
57+
assertEquals(8192, info.contextLength)
58+
assertEquals(32, info.blockCount)
59+
assertEquals(4096, info.embeddingLength)
60+
assertEquals(128256, info.vocabSize)
61+
}
62+
63+
@Test
64+
fun peek_reads_int32_metadata_fields() {
65+
val bytes = buildGgufBytes(
66+
arch = "apertus",
67+
metadata = mapOf(
68+
"apertus.context_length" to 8192,
69+
"apertus.block_count" to 32,
70+
"apertus.embedding_length" to 4096,
71+
"apertus.vocab_size" to 128256
72+
)
73+
)
74+
75+
val info = peekFromBytes(bytes)
76+
77+
assertEquals(8192, info.contextLength)
78+
assertEquals(32, info.blockCount)
79+
assertEquals(4096, info.embeddingLength)
80+
assertEquals(128256, info.vocabSize)
81+
}
82+
83+
@Test
84+
fun peek_falls_back_to_defaults_when_fields_missing() {
85+
val bytes = buildGgufBytes(
86+
arch = "apertus",
87+
metadata = emptyMap()
88+
)
89+
90+
val info = peekFromBytes(bytes)
91+
92+
assertEquals("apertus", info.architecture)
93+
assertEquals(4096, info.contextLength) // default
94+
assertEquals(0, info.blockCount)
95+
assertEquals(0, info.embeddingLength)
96+
assertEquals(0, info.vocabSize)
97+
}
98+
99+
private fun buildGgufBytes(arch: String, metadata: Map<String, Any>): ByteArray {
100+
val merged = LinkedHashMap<String, Any>()
101+
merged["general.architecture"] = arch
102+
merged.putAll(metadata)
103+
val request = GgufWriteRequest(
104+
metadata = merged,
105+
tensors = emptyList(),
106+
tensorMap = emptyMap()
107+
)
108+
return GGUFWriter.writeToByteArray(request).second
109+
}
110+
111+
private fun peekFromBytes(bytes: ByteArray): GGUFModelInfo {
112+
val tempFile = Files.createTempFile("uint-meta", ".gguf").toFile()
113+
tempFile.deleteOnExit()
114+
tempFile.writeBytes(bytes)
115+
return UnifiedModelLoader.peek { JvmRandomAccessSource.open(tempFile) }
116+
}
117+
}

llm-inference/apertus/build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,5 @@ kotlin {
7171

7272
tasks.withType<Test>().configureEach {
7373
jvmArgs("--enable-preview", "--add-modules", "jdk.incubator.vector", "-XX:MaxDirectMemorySize=12g")
74-
maxHeapSize = "6g"
74+
maxHeapSize = (findProperty("apertusTestMaxHeap") as? String) ?: "6g"
7575
}

llm-inference/apertus/src/commonMain/kotlin/sk/ainet/models/apertus/ApertusWeightLoader.kt

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,13 @@ public class ApertusWeightLoader private constructor(
120120
requiredTensorNames(metadata).forEach { name ->
121121
val rt = tensorByName[name]
122122
?: error("Missing required tensor in GGUF payload: $name")
123-
byName[name] = readerTensorToTensor(ctx, dtype, reader, rt)
123+
byName[name] = loadReaderTensor(ctx, dtype, reader, rt, name)
124124
}
125125

126126
// Load optional rope_freqs tensor
127127
tensorByName[ApertusTensorNames.ROPE_FREQS]?.let { rt ->
128-
byName[ApertusTensorNames.ROPE_FREQS] = readerTensorToTensor(ctx, dtype, reader, rt)
128+
byName[ApertusTensorNames.ROPE_FREQS] =
129+
loadReaderTensor(ctx, dtype, reader, rt, ApertusTensorNames.ROPE_FREQS)
129130
}
130131

131132
// Extract xIELU params: try metadata fields first, then per-layer tensors
@@ -162,12 +163,13 @@ public class ApertusWeightLoader private constructor(
162163
requiredTensorNames(metadata).forEach { name ->
163164
val st = tensorByName[name]
164165
?: error("Missing required tensor in GGUF payload: $name")
165-
byName[name] = streamingTensorToTensor(ctx, dtype, reader, st)
166+
byName[name] = loadStreamingTensor(ctx, dtype, reader, st, name)
166167
}
167168

168169
// Load optional rope_freqs tensor
169170
tensorByName[ApertusTensorNames.ROPE_FREQS]?.let { st ->
170-
byName[ApertusTensorNames.ROPE_FREQS] = streamingTensorToTensor(ctx, dtype, reader, st)
171+
byName[ApertusTensorNames.ROPE_FREQS] =
172+
loadStreamingTensor(ctx, dtype, reader, st, ApertusTensorNames.ROPE_FREQS)
171173
}
172174

173175
// Extract xIELU params: try metadata fields first, then per-layer tensors
@@ -560,6 +562,58 @@ public class ApertusWeightLoader private constructor(
560562

561563
// ============== Tensor conversion ==============
562564

565+
/**
566+
* NATIVE_OPTIMIZED stores quantized tensors as byte-level rank-1 buffers so the
567+
* native FFM kernels can address the raw block layout directly. That works for
568+
* matmul (the kernel knows the logical shape from metadata) but breaks the
569+
* token embedding, where `Embedding.gather()` requires the logical rank-2
570+
* `[vocab, dim]` shape. Force `token_embd.weight` through the dequant path so
571+
* the embedding lookup gets a real `[vocab, dim]` FP32/FP16 tensor regardless
572+
* of the policy chosen for the rest of the model.
573+
*/
574+
private fun <T : DType, V> loadStreamingTensor(
575+
ctx: ExecutionContext,
576+
dtype: KClass<T>,
577+
reader: StreamingGGUFReader,
578+
st: StreamingTensorInfo,
579+
name: String
580+
): Tensor<T, V> {
581+
if (name == ApertusTensorNames.TOKEN_EMBEDDINGS &&
582+
quantPolicy == QuantPolicy.NATIVE_OPTIMIZED &&
583+
st.tensorType != GGMLQuantizationType.F32 &&
584+
st.tensorType != GGMLQuantizationType.F16 &&
585+
st.tensorType != GGMLQuantizationType.BF16
586+
) {
587+
val shape = Shape(*st.shape.map { it.toInt() }.toIntArray())
588+
val bytes = reader.loadTensorData(st)
589+
val floats = DequantOps.dequantFromBytes(bytes, st.tensorType, st.nElements.toInt())
590+
return createTensor(ctx, dtype, shape, floats)
591+
}
592+
return streamingTensorToTensor(ctx, dtype, reader, st)
593+
}
594+
595+
private fun <T : DType, V> loadReaderTensor(
596+
ctx: ExecutionContext,
597+
dtype: KClass<T>,
598+
reader: GGUFReader,
599+
rt: ReaderTensor,
600+
name: String
601+
): Tensor<T, V> {
602+
if (name == ApertusTensorNames.TOKEN_EMBEDDINGS &&
603+
quantPolicy == QuantPolicy.NATIVE_OPTIMIZED &&
604+
rt.tensorType != GGMLQuantizationType.F32 &&
605+
rt.tensorType != GGMLQuantizationType.F16 &&
606+
rt.tensorType != GGMLQuantizationType.BF16
607+
) {
608+
val shape = Shape(*rt.shape.map { it.toInt() }.toIntArray())
609+
val raw = if (rt.data.isEmpty()) reader.materialize(rt) else rt.data
610+
val bytes: ByteArray = DequantOps.toByteArray(raw, rt.name)
611+
val floats = DequantOps.dequantFromBytes(bytes, rt.tensorType, rt.nElements)
612+
return createTensor(ctx, dtype, shape, floats)
613+
}
614+
return readerTensorToTensor(ctx, dtype, reader, rt)
615+
}
616+
563617
@Suppress("UNCHECKED_CAST")
564618
private fun <T : DType, V> readerTensorToTensor(
565619
ctx: ExecutionContext,
@@ -631,7 +685,7 @@ public class ApertusWeightLoader private constructor(
631685
}
632686

633687
@Suppress("UNCHECKED_CAST")
634-
private fun <T : DType, V> streamingTensorToTensor(
688+
internal fun <T : DType, V> streamingTensorToTensor(
635689
ctx: ExecutionContext,
636690
dtype: KClass<T>,
637691
reader: StreamingGGUFReader,
@@ -676,9 +730,19 @@ public class ApertusWeightLoader private constructor(
676730
GGMLQuantizationType.IQ4_NL, GGMLQuantizationType.IQ4_XS,
677731
GGMLQuantizationType.TQ1_0, GGMLQuantizationType.TQ2_0 -> {
678732
when (quantPolicy) {
679-
QuantPolicy.RAW_BYTES, QuantPolicy.NATIVE_OPTIMIZED -> {
733+
QuantPolicy.RAW_BYTES -> {
734+
require(dtype == Int8::class) {
735+
"Quantized tensor ${st.name} requires dtype Int8 with quantPolicy=RAW_BYTES"
736+
}
680737
ctx.fromByteArray<Int8, Byte>(shape, Int8::class, bytes) as Tensor<T, V>
681738
}
739+
QuantPolicy.NATIVE_OPTIMIZED -> {
740+
// Store raw quantized bytes as Int8, even though the caller-declared dtype may be FP32 (mixed mode).
741+
// The streaming reader reports the logical tensor shape, but since we keep the raw quantized block bytes here, store them under a rank-1 byte-level shape.
742+
val byteShape = Shape(bytes.size)
743+
@Suppress("UNCHECKED_CAST")
744+
ctx.fromByteArray<Int8, Byte>(byteShape, Int8::class, bytes) as Tensor<T, V>
745+
}
682746
QuantPolicy.DEQUANTIZE_TO_FP32 -> {
683747
val floats = DequantOps.dequantFromBytes(bytes, st.tensorType, st.nElements.toInt())
684748
createTensor(ctx, dtype, shape, floats)

0 commit comments

Comments
 (0)