55
66from typing import Callable , Iterable , TYPE_CHECKING
77
8+ import numpy as np
89import torch
910
1011if TYPE_CHECKING :
1112 from torch import Tensor
1213
13- from .base import ModelBase , TextModel , gguf
14+ from .base import ModelBase , TextModel , gguf , logger
1415
1516
1617@ModelBase .register (
2122 "VLlama3ForCausalLM" ,
2223 "LlavaForConditionalGeneration" ,
2324 "VoxtralForConditionalGeneration" ,
25+ "LlamaForCausalLMEagle3" ,
26+ "Eagle3Speculator" ,
27+ "Eagle3DraftModel" ,
2428 "IQuestCoderForCausalLM" ,
2529 "LlamaModel" )
2630class LlamaModel (TextModel ):
@@ -39,7 +43,61 @@ def __init__(self, *args, **kwargs):
3943 hparams = ModelBase .load_hparams (self .dir_model , is_mistral_format = False )
4044 self .origin_hf_arch = hparams .get ('architectures' , [None ])[0 ]
4145
46+ # Detect eagle3 draft checkpoint by hparams (some models don't use a distinct HF arch name)
47+ if "draft_vocab_size" in self .hparams and self .hparams ["num_hidden_layers" ] == 1 :
48+ self .is_eagle3 = True
49+ self .model_arch = gguf .MODEL_ARCH .EAGLE3
50+ logger .info ("Detected EAGLE-3 draft model, switching to EAGLE3 architecture" )
51+ # Re-initialize tensor_map with eagle3 architecture
52+ self .tensor_map = gguf .get_tensor_name_map (self .model_arch , self .block_count )
53+ # Update gguf_writer architecture
54+ self .gguf_writer .arch = gguf .MODEL_ARCH_NAMES [self .model_arch ]
55+ self .gguf_writer .add_architecture ()
56+ if self .target_model_dir is None :
57+ raise ValueError (
58+ "EAGLE-3 model requires --target-model-dir to be specified. "
59+ "Please provide the path to the target model directory to read config.json"
60+ )
61+ # Read both eagle3 raw config and target model config
62+ with open (self .dir_model / "config.json" , 'r' , encoding = 'utf-8' ) as f :
63+ eagle3_raw_config = json .load (f )
64+ with open (self .target_model_dir / "config.json" , 'r' , encoding = 'utf-8' ) as f :
65+ target_config = json .load (f )
66+
67+ if "text_config" in target_config :
68+ target_config = {** target_config , ** target_config ["text_config" ]}
69+ self .target_vocab_size = target_config ["vocab_size" ]
70+
71+ # target_layers: derived from target model layer count (low/mid/high)
72+ target_num_layers = target_config ["num_hidden_layers" ]
73+ target_layers = [2 , target_num_layers // 2 , target_num_layers - 3 ]
74+ logger .info (f"EAGLE-3: target_layers = { target_layers } (target model has { target_num_layers } layers)" )
75+ self .gguf_writer .add_array (f"{ self .gguf_writer .arch } .target_layers" , target_layers )
76+
77+ # target_hidden_size: prefer eagle3 config, fallback to target config
78+ if eagle3_raw_config .get ("target_hidden_size" ) is not None :
79+ target_hidden_size = eagle3_raw_config ["target_hidden_size" ]
80+ src = "EAGLE-3 config"
81+ else :
82+ target_hidden_size = target_config ["hidden_size" ]
83+ src = "target model config"
84+ logger .info (f"EAGLE-3: target_hidden_size = { target_hidden_size } (from { src } )" )
85+ self .gguf_writer .add_uint32 (f"{ self .gguf_writer .arch } .target_hidden_size" , target_hidden_size )
86+
87+ # norm_before_residual (RedHat-style eagle3 specific)
88+ norm_before_residual = eagle3_raw_config .get ("norm_before_residual" , False )
89+ logger .info (f"EAGLE-3: norm_before_residual = { norm_before_residual } " )
90+ self .gguf_writer .add_bool (f"{ self .gguf_writer .arch } .norm_before_residual" , norm_before_residual )
91+
4292 def set_vocab (self ):
93+ # eagle3: use tokenizer from target model if provided
94+ original_dir_model = None
95+ if getattr (self , 'is_eagle3' , False ):
96+ assert self .target_model_dir is not None
97+ logger .info (f"EAGLE-3: Using tokenizer from target model: { self .target_model_dir } " )
98+ original_dir_model = self .dir_model
99+ self .dir_model = self .target_model_dir
100+
43101 if self .origin_hf_arch == "GlmasrModel" :
44102 return self ._set_vocab_glmedge ()
45103
@@ -85,6 +143,10 @@ def set_vocab(self):
85143 if self .hparams .get ("vocab_size" , 32000 ) == 49152 :
86144 self .gguf_writer .add_add_bos_token (False )
87145
146+ # eagle3: Restore original dir_model
147+ if original_dir_model is not None :
148+ self .dir_model = original_dir_model
149+
88150 def set_gguf_parameters (self ):
89151 super ().set_gguf_parameters ()
90152 hparams = self .hparams
@@ -129,7 +191,49 @@ def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Ca
129191
130192 return super ().filter_tensors ((name , gen ))
131193
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
    """Index the checkpoint tensors, normalizing EAGLE-3 draft-layer names.

    The parent index is built first. If the hyperparameters carry a nested
    ``transformer_layer_config`` entry (Eagle3Speculator layout), its keys
    are merged on top of ``self.hparams``; this mutates ``self.hparams``.

    A checkpoint is then treated as an EAGLE-3 draft when ``draft_vocab_size``
    is present and ``num_hidden_layers`` equals 1. For such checkpoints the
    single decoder layer is renamed so it sits under ``model.layers.0.``:

    * ``midlayer.<rest>``  -> ``model.layers.0.<rest>``
    * ``layers.0.<rest>``  -> ``model.layers.0.<rest>``

    All other tensor names are passed through unchanged.

    Args:
        remote_hf_model_id: forwarded unchanged to the parent implementation.

    Returns:
        Mapping from (possibly renamed) tensor name to its lazy loader. For
        non-EAGLE-3 checkpoints this is the parent's mapping object itself.
    """
    indexed = super().index_tensors(remote_hf_model_id)

    # Eagle3Speculator keeps the decoder settings in a nested dict; hoist them
    # so the lookups below (and later hparams consumers) see them at top level.
    nested_key = "transformer_layer_config"
    if nested_key in self.hparams:
        self.hparams = {**self.hparams, **self.hparams[nested_key]}

    # Same hparams-based detection as used elsewhere in this class: not every
    # EAGLE-3 checkpoint advertises a distinct HF architecture name.
    looks_like_eagle3 = "draft_vocab_size" in self.hparams and self.hparams["num_hidden_layers"] == 1
    if not looks_like_eagle3:
        return indexed

    logger.info("EAGLE-3: renaming midlayer.* / layers.0.* to model.layers.0.*")

    mid_prefix = "midlayer."
    bare_layer_prefix = "layers.0."  # Eagle3Speculator format

    def to_canonical(tensor_name):
        # Map one checkpoint tensor name onto the model.layers.0.* namespace.
        if tensor_name.startswith(mid_prefix):
            return "model.layers.0." + tensor_name[len(mid_prefix):]
        if tensor_name.startswith(bare_layer_prefix):
            return "model." + tensor_name
        return tensor_name

    # Built in the parent's iteration order; a later entry replaces an earlier
    # one if two source names map to the same canonical name.
    return {to_canonical(tensor_name): loader for tensor_name, loader in indexed.items()}
217+
132218 def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
219+ # eagle3: special tensors that bypass standard llama mapping
220+ if getattr (self , 'is_eagle3' , False ):
221+ if name == "fc.weight" :
222+ yield (name , data_torch )
223+ return
224+ if name == "d2t" :
225+ # store for manual int64 handling in prepare_tensors (avoid F32 conversion)
226+ if not hasattr (self , '_eagle3_int_tensors' ):
227+ self ._eagle3_int_tensors = {}
228+ self ._eagle3_int_tensors [name ] = data_torch
229+ return
230+ if name == "t2d" :
231+ # not used at runtime, skip
232+ return
233+ if name .endswith (".hidden_norm.weight" ):
234+ yield (self .format_tensor_name (gguf .MODEL_TENSOR .ATTN_NORM_2 , bid ), data_torch )
235+ return
236+
133237 n_head = self .find_hparam (["n_heads" , "num_attention_heads" ])
134238 n_kv_head = self .find_hparam (["n_kv_heads" , "num_key_value_heads" ])
135239
@@ -205,8 +309,33 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
205309 yield (self .format_tensor_name (gguf .MODEL_TENSOR .ROPE_FREQS ), torch .tensor (rope_factors , dtype = torch .float32 ))
206310
207311 def prepare_tensors (self ):
312+ # eagle3: collect d2t original dtype before parent converts tensors to F32
313+ eagle3_original_dtypes = {}
314+ if getattr (self , 'is_eagle3' , False ):
315+ for name , data_torch in self .get_tensors ():
316+ if name == "d2t" :
317+ eagle3_original_dtypes [name ] = data_torch .dtype
318+
208319 super ().prepare_tensors ()
209320
321+ # eagle3: write d2t as absolute target token ids
322+ if getattr (self , 'is_eagle3' , False ) and hasattr (self , '_eagle3_int_tensors' ):
323+ for name , data_torch in self ._eagle3_int_tensors .items ():
324+ old_dtype = eagle3_original_dtypes .get (name , data_torch .dtype )
325+ data = data_torch .to (torch .int64 ).cpu ().numpy ()
326+ if name == "d2t" :
327+ data = data .reshape (- 1 )
328+ data = data + np .arange (data .size , dtype = np .int64 )
329+ if np .any ((data < 0 ) | (data >= self .target_vocab_size )):
330+ raise ValueError (f"EAGLE-3 d2t target ids out of range for target vocab size { self .target_vocab_size } " )
331+ if np .unique (data ).size != data .size :
332+ raise ValueError ("EAGLE-3 d2t contains duplicate target ids" )
333+ data_qtype = gguf .GGMLQuantizationType .I64
334+
335+ shape_str = f"{{{ ', ' .join (str (n ) for n in reversed (data .shape ))} }}"
336+ logger .info (f"{ name + ',' :<30} { old_dtype } --> { data_qtype .name } , shape = { shape_str } " )
337+ self .gguf_writer .add_tensor (name , data , raw_dtype = data_qtype )
338+
210339 if self ._experts is not None :
211340 # flatten `list[dict[str, Tensor]]` into `list[str]`
212341 experts = [k for d in self ._experts for k in d .keys ()]
0 commit comments