
Commit 3d3c6ff

michalharakal and claude committed
fix(apertus): force-dequant token_embd under NATIVE_OPTIMIZED
ApertusWeightLoader.streamingTensorToTensor / readerTensorToTensor wrap
quantized weights in a byte-level rank-1 shape under
QuantPolicy.NATIVE_OPTIMIZED so the native FFM kernels can address the
block layout directly. That works for matmul (the kernel knows the
logical shape from metadata) but breaks Embedding.gather, which requires
the logical rank-2 [vocab, dim] shape: a rank-1 weight tensor fails with
"gather: unsupported input rank 1".

Surfaced by ApertusNetworkLoader.fromGguf().load() against a real
unsloth/Apertus-8B-Instruct-2509 Q4_K_S file: token_embd is stored as
Q4_K in the GGUF and gets the byte-level shape, so the very first
forward pass through the embedding layer dies before any logit math.

Add loadStreamingTensor / loadReaderTensor wrappers around the existing
*ToTensor helpers. They route token_embd.weight through the dequant path
(DequantOps.dequantFromBytes → createTensor with the logical
[vocab, dim] shape) when quantPolicy is NATIVE_OPTIMIZED and the tensor
is a quantized type. All other tensors keep their NATIVE_OPTIMIZED
byte-level layout for kernel dispatch.

The integration test class kdoc documents the next blocker preventing
end-to-end inference: linearProject in MultiHeadAttention calls
ops.transpose on byte-shape weights for the Q/K/V/O and FFN projections,
which Gemma solves via Q4_KBlockTensorData but Apertus does not yet
implement. Tracked as #100.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 583ebbc commit 3d3c6ff

2 files changed: 86 additions & 4 deletions
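To see why the byte-level layout is rank-1, it helps to run the numbers for GGML's standard Q4_K format, which packs super-blocks of 256 elements into 144 bytes each. A back-of-the-envelope sketch in Kotlin follows; the vocab/dim figures are illustrative placeholders, not values read from the Apertus GGUF:

// Back-of-the-envelope sketch of the shape mismatch. The Q4_K block geometry
// (256 elements per 144-byte super-block) is standard GGML; vocab and dim
// are illustrative, not read from the model file.
fun main() {
    val vocab = 131_072L              // illustrative vocabulary size
    val dim = 4_096L                  // illustrative embedding width
    val nElements = vocab * dim       // 536_870_912 logical elements

    // Shape Embedding.gather needs: logical rank-2 [vocab, dim].
    val logicalShape = listOf(vocab, dim)

    // Shape NATIVE_OPTIMIZED actually produces: one flat buffer of raw blocks.
    val elementsPerBlock = 256L
    val bytesPerBlock = 144L
    val byteLevelShape = listOf(nElements / elementsPerBlock * bytesPerBlock)

    println("logical:    rank ${logicalShape.size}, $logicalShape")    // rank 2, [131072, 4096]
    println("byte-level: rank ${byteLevelShape.size}, $byteLevelShape") // rank 1, [301989888]
}

A quant-aware matmul kernel can recover the logical [vocab, dim] view from metadata, but Embedding.gather sees only the flat rank-1 tensor and has no rows to index, hence the "gather: unsupported input rank 1" error.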


llm-inference/apertus/src/commonMain/kotlin/sk/ainet/models/apertus/ApertusWeightLoader.kt

Lines changed: 58 additions & 4 deletions
@@ -120,12 +120,13 @@ public class ApertusWeightLoader private constructor(
         requiredTensorNames(metadata).forEach { name ->
             val rt = tensorByName[name]
                 ?: error("Missing required tensor in GGUF payload: $name")
-            byName[name] = readerTensorToTensor(ctx, dtype, reader, rt)
+            byName[name] = loadReaderTensor(ctx, dtype, reader, rt, name)
         }

         // Load optional rope_freqs tensor
         tensorByName[ApertusTensorNames.ROPE_FREQS]?.let { rt ->
-            byName[ApertusTensorNames.ROPE_FREQS] = readerTensorToTensor(ctx, dtype, reader, rt)
+            byName[ApertusTensorNames.ROPE_FREQS] =
+                loadReaderTensor(ctx, dtype, reader, rt, ApertusTensorNames.ROPE_FREQS)
         }

         // Extract xIELU params: try metadata fields first, then per-layer tensors
@@ -162,12 +163,13 @@
         requiredTensorNames(metadata).forEach { name ->
             val st = tensorByName[name]
                 ?: error("Missing required tensor in GGUF payload: $name")
-            byName[name] = streamingTensorToTensor(ctx, dtype, reader, st)
+            byName[name] = loadStreamingTensor(ctx, dtype, reader, st, name)
         }

         // Load optional rope_freqs tensor
         tensorByName[ApertusTensorNames.ROPE_FREQS]?.let { st ->
-            byName[ApertusTensorNames.ROPE_FREQS] = streamingTensorToTensor(ctx, dtype, reader, st)
+            byName[ApertusTensorNames.ROPE_FREQS] =
+                loadStreamingTensor(ctx, dtype, reader, st, ApertusTensorNames.ROPE_FREQS)
         }

         // Extract xIELU params: try metadata fields first, then per-layer tensors
@@ -560,6 +562,58 @@

     // ============== Tensor conversion ==============

+    /**
+     * NATIVE_OPTIMIZED stores quantized tensors as byte-level rank-1 buffers so the
+     * native FFM kernels can address the raw block layout directly. That works for
+     * matmul (the kernel knows the logical shape from metadata) but breaks the
+     * token embedding, where `Embedding.gather()` requires the logical rank-2
+     * `[vocab, dim]` shape. Force `token_embd.weight` through the dequant path so
+     * the embedding lookup gets a real `[vocab, dim]` FP32/FP16 tensor regardless
+     * of the policy chosen for the rest of the model.
+     */
+    private fun <T : DType, V> loadStreamingTensor(
+        ctx: ExecutionContext,
+        dtype: KClass<T>,
+        reader: StreamingGGUFReader,
+        st: StreamingTensorInfo,
+        name: String
+    ): Tensor<T, V> {
+        if (name == ApertusTensorNames.TOKEN_EMBEDDINGS &&
+            quantPolicy == QuantPolicy.NATIVE_OPTIMIZED &&
+            st.tensorType != GGMLQuantizationType.F32 &&
+            st.tensorType != GGMLQuantizationType.F16 &&
+            st.tensorType != GGMLQuantizationType.BF16
+        ) {
+            val shape = Shape(*st.shape.map { it.toInt() }.toIntArray())
+            val bytes = reader.loadTensorData(st)
+            val floats = DequantOps.dequantFromBytes(bytes, st.tensorType, st.nElements.toInt())
+            return createTensor(ctx, dtype, shape, floats)
+        }
+        return streamingTensorToTensor(ctx, dtype, reader, st)
+    }
+
+    private fun <T : DType, V> loadReaderTensor(
+        ctx: ExecutionContext,
+        dtype: KClass<T>,
+        reader: GGUFReader,
+        rt: ReaderTensor,
+        name: String
+    ): Tensor<T, V> {
+        if (name == ApertusTensorNames.TOKEN_EMBEDDINGS &&
+            quantPolicy == QuantPolicy.NATIVE_OPTIMIZED &&
+            rt.tensorType != GGMLQuantizationType.F32 &&
+            rt.tensorType != GGMLQuantizationType.F16 &&
+            rt.tensorType != GGMLQuantizationType.BF16
+        ) {
+            val shape = Shape(*rt.shape.map { it.toInt() }.toIntArray())
+            val raw = if (rt.data.isEmpty()) reader.materialize(rt) else rt.data
+            val bytes: ByteArray = DequantOps.toByteArray(raw, rt.name)
+            val floats = DequantOps.dequantFromBytes(bytes, rt.tensorType, rt.nElements)
+            return createTensor(ctx, dtype, shape, floats)
+        }
+        return readerTensorToTensor(ctx, dtype, reader, rt)
+    }
+
     @Suppress("UNCHECKED_CAST")
     private fun <T : DType, V> readerTensorToTensor(
         ctx: ExecutionContext,
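For reference, the guard shared by the two new wrappers reduces to a single predicate. The restatement below is illustrative, not shipped code: the GGMLQuantizationType values mirror the diff, while QuantPolicy.DEQUANT is a placeholder name for the non-native policy:

// Standalone restatement of the dequant-forcing guard from the diff above.
// GGMLQuantizationType values mirror the diff; QuantPolicy.DEQUANT is a
// placeholder, and this helper is illustrative rather than the shipped code.
enum class GGMLQuantizationType { F32, F16, BF16, Q4_K, Q5_K, Q6_K }
enum class QuantPolicy { DEQUANT, NATIVE_OPTIMIZED }

fun mustForceDequant(
    name: String,
    policy: QuantPolicy,
    type: GGMLQuantizationType
): Boolean =
    name == "token_embd.weight" &&                // only the embedding table
        policy == QuantPolicy.NATIVE_OPTIMIZED && // only when bytes would stay packed
        type != GGMLQuantizationType.F32 &&       // float tensors already carry their
        type != GGMLQuantizationType.F16 &&       // logical shape, so the guard
        type != GGMLQuantizationType.BF16         // leaves them alone

Every other (name, policy, type) combination falls through to the existing *ToTensor helpers, so kernel-facing weights keep their packed byte-level layout.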

llm-inference/apertus/src/jvmTest/kotlin/sk/ainet/models/apertus/ApertusRealGgufLoadingTest.kt

Lines changed: 28 additions & 0 deletions
@@ -180,6 +180,34 @@ class ApertusRealGgufLoadingTest {
         println("[real-load fromGguf NATIVE_OPTIMIZED] top-modules=${topNames.size}")
     }

+    /**
+     * End-to-end inference (forward / generate / tool calling) is intentionally
+     * NOT covered here.
+     *
+     * `ApertusNetworkLoader.fromGguf().load()` succeeds end-to-end (verified by
+     * the test above), and the embedding lookup works after the
+     * `loadStreamingTensor` token-embd dequant special case. But the rest of
+     * the forward pass — Q/K/V/O projections, FFN matmuls — relies on the
+     * standard `linearProject(ops, input, weight) = ops.matmul(input, ops.transpose(weight))`
+     * helper, which assumes a logical rank-2 weight. Under
+     * `QuantPolicy.NATIVE_OPTIMIZED` the loader stores quantized weights as
+     * raw byte-level rank-1 `Int8` tensors so the native FFM kernels can
+     * address the block layout directly — but `ops.transpose(byteShape)` then
+     * fails.
+     *
+     * Gemma's Q4_K end-to-end test works because Gemma's loader uses
+     * `Q4_KBlockTensorData(logicalShape, blockMajorBytes)` with a lazy
+     * `transpose` override and a quant-aware `matmul` dispatch (see
+     * `GemmaDslQ4KTest`, `relayoutQ4_KRowMajorToBlockMajor`). Apertus's
+     * loader stores raw Int8 bytes instead, so `linearProject` blows up at
+     * the first attention projection.
+     *
+     * Tracking issue: #100. The Apertus loader needs per-quant-type
+     * tensor-data wrappers (`Q4_KBlockTensorData` / `Q5_KBlockTensorData` /
+     * `Q6_KBlockTensorData`) with row-major → block-major relayout,
+     * mirroring Gemma's path.
+     */
+
     private fun locateModel(): File? {
         System.getenv("APERTUS_GGUF_PATH")?.let { p ->
             val f = File(p)
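The kdoc above outlines the shape of the follow-up tracked in #100: per-quant-type tensor-data wrappers with a lazy transpose and row-major → block-major relayout. A minimal sketch of what such a wrapper could look like, assuming hypothetical surrounding interfaces (Gemma's actual Q4_KBlockTensorData may differ):

// Minimal sketch of a block-major quantized tensor-data wrapper with a lazy
// transpose, in the spirit of what the kdoc describes for Gemma. The class
// name echoes the kdoc; the fields and the matmul hook are assumptions.
class Q4KBlockTensorDataSketch(
    val logicalShape: IntArray,      // e.g. [outFeatures, inFeatures]
    val blockMajorBytes: ByteArray,  // raw Q4_K blocks after row-major → block-major relayout
    val transposed: Boolean = false  // lazy transpose flag: no byte movement
) {
    // ops.transpose(weight) must not touch packed bytes; it only swaps the
    // logical dims and flips the flag, deferring the real work to the kernel.
    fun transpose(): Q4KBlockTensorDataSketch = Q4KBlockTensorDataSketch(
        logicalShape = intArrayOf(logicalShape[1], logicalShape[0]),
        blockMajorBytes = blockMajorBytes,   // shared buffer, untouched
        transposed = !transposed
    )
}

// A quant-aware matmul would then dispatch on the wrapper instead of
// demanding a dequantized rank-2 float tensor (kernel entry is hypothetical):
//   when (weight) {
//       is Q4KBlockTensorDataSketch -> nativeQ4KMatmul(input, weight)
//       else                        -> genericMatmul(input, weight)
//   }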
