Skip to content

Commit e212240

Browse files
michalharakal and claude committed
fix: support tied embeddings in LlamaWeightLoader
Small models like Qwen2.5-0.5B/1.5B tie their input and output embeddings (output.weight = token_embd.weight) to save parameters. The GGUF file omits output.weight in this case, causing "Missing required tensor" errors during load.

Detect missing output.weight when token_embd.weight is present and alias the lookup to reuse the embedding tensor as the LM head. Logs "Tied word embeddings" when this path is taken.

Verified with Qwen2.5-0.5B-Instruct-Q8_0 which now loads correctly and reaches the tool calling demo. (Output quality still limited by a separate byte-level BPE tokenizer issue.)

Refs: #49
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e595a02 commit e212240

1 file changed

Lines changed: 29 additions & 2 deletions

File tree

llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/LlamaWeightLoader.kt

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -243,8 +243,20 @@ public class LlamaWeightLoader private constructor(
243243
val required = requiredTensorNames(metadata)
244244
val tensorByName = reader.tensors.associateBy { it.name }
245245

246+
// Tied embeddings (Qwen2.5-0.5B/1.5B, Gemma, etc.): reuse token_embd.weight as output.weight
247+
val tiedEmbeddings = tensorByName[LlamaTensorNames.OUTPUT_WEIGHT] == null &&
248+
tensorByName[LlamaTensorNames.TOKEN_EMBEDDINGS] != null
249+
if (tiedEmbeddings) {
250+
println("Tied word embeddings: output.weight = token_embd.weight")
251+
}
252+
246253
required.forEach { name ->
247-
val rt = tensorByName[name]
254+
val lookupName = if (name == LlamaTensorNames.OUTPUT_WEIGHT && tiedEmbeddings) {
255+
LlamaTensorNames.TOKEN_EMBEDDINGS
256+
} else {
257+
name
258+
}
259+
val rt = tensorByName[lookupName]
248260
?: error("Missing required tensor in GGUF payload: $name")
249261
validateTensorShape(name, rt, metadata)
250262
val tensor: Tensor<T, V> = readerTensorToTensor(ctx, dtype, reader, rt)
@@ -299,9 +311,24 @@ public class LlamaWeightLoader private constructor(
299311
val required = requiredTensorNames(metadata)
300312
val tensorByName = reader.tensors.associateBy { it.name }
301313

314+
// Tied embeddings: small models (Qwen2.5-0.5B/1.5B, etc.) omit output.weight
315+
// and reuse token_embd.weight as the LM head. Detect and alias.
316+
val tiedEmbeddings = tensorByName[LlamaTensorNames.OUTPUT_WEIGHT] == null &&
317+
tensorByName[LlamaTensorNames.TOKEN_EMBEDDINGS] != null
318+
if (tiedEmbeddings) {
319+
println("Tied word embeddings: output.weight = token_embd.weight")
320+
}
321+
302322
required.forEach { name ->
303-
val st = tensorByName[name]
323+
val lookupName = if (name == LlamaTensorNames.OUTPUT_WEIGHT && tiedEmbeddings) {
324+
LlamaTensorNames.TOKEN_EMBEDDINGS
325+
} else {
326+
name
327+
}
328+
val st = tensorByName[lookupName]
304329
?: error("Missing required tensor in GGUF payload: $name")
330+
// Shape validation uses the logical name (e.g., OUTPUT_WEIGHT) even when
331+
// the physical tensor is TOKEN_EMBEDDINGS — both must have [vocab, dim] shape.
305332
validateStreamingTensorShape(name, st, metadata)
306333
val tensor: Tensor<T, V> = streamingTensorToTensor(ctx, dtype, reader, st)
307334
onTensorLoaded(name, tensor)

0 commit comments

Comments
 (0)