Skip to content

Commit 9d2fdcf

Browse files
committed
eagle3: support Gemma4 eagle3 from RedHatAI
1 parent 71eb0c7 commit 9d2fdcf

3 files changed

Lines changed: 8 additions & 0 deletions

File tree

conversion/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,9 @@
124124
"LlamaBidirectionalModel": "llama",
125125
"LlamaForCausalLM": "llama",
126126
"LlamaModel": "llama",
127+
"Eagle3DraftModel": "llama",
128+
"Eagle3Speculator": "llama",
129+
"LlamaForCausalLMEagle3": "llama",
127130
"LlavaForConditionalGeneration": "llama",
128131
"LlavaStableLMEpochForCausalLM": "stablelm",
129132
"MPTForCausalLM": "mpt",

conversion/llama.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,9 @@ def __init__(self, *args, **kwargs):
6363
with open(self.target_model_dir / "config.json", 'r', encoding='utf-8') as f:
6464
target_config = json.load(f)
6565

66+
if "text_config" in target_config:
67+
target_config = {**target_config, **target_config["text_config"]}
68+
6669
# extract_layers: derived from target model layer count (low/mid/high)
6770
target_num_layers = target_config["num_hidden_layers"]
6871
extract_layers = [2, target_num_layers // 2, target_num_layers - 3]

src/models/gemma4.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -183,6 +183,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
183183
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
184184
const int n_rot_l = hparams.n_rot(il);
185185

186+
res->t_layer_inp[il] = inpL;
187+
186188
// norm
187189
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
188190
cb(cur, "attn_norm", il);

0 commit comments

Comments
 (0)