1010if TYPE_CHECKING :
1111 from torch import Tensor
1212
13- from .base import ModelBase , TextModel , gguf
13+ from .base import ModelBase , TextModel , gguf , logger
1414
1515
1616@ModelBase .register (
2121 "VLlama3ForCausalLM" ,
2222 "LlavaForConditionalGeneration" ,
2323 "VoxtralForConditionalGeneration" ,
24+ "LlamaForCausalLMEagle3" ,
25+ "Eagle3Speculator" ,
26+ "Eagle3DraftModel" ,
2427 "IQuestCoderForCausalLM" ,
2528 "LlamaModel" )
2629class LlamaModel (TextModel ):
@@ -39,7 +42,57 @@ def __init__(self, *args, **kwargs):
3942 hparams = ModelBase .load_hparams (self .dir_model , is_mistral_format = False )
4043 self .origin_hf_arch = hparams .get ('architectures' , [None ])[0 ]
4144
45+ # Detect eagle3 draft checkpoint by hparams (some models don't use a distinct HF arch name)
46+ if "draft_vocab_size" in self .hparams and self .hparams ["num_hidden_layers" ] == 1 :
47+ self .is_eagle3 = True
48+ self .model_arch = gguf .MODEL_ARCH .EAGLE3
49+ logger .info ("Detected EAGLE-3 draft model, switching to EAGLE3 architecture" )
50+ # Re-initialize tensor_map with eagle3 architecture
51+ self .tensor_map = gguf .get_tensor_name_map (self .model_arch , self .block_count )
52+ # Update gguf_writer architecture
53+ self .gguf_writer .arch = gguf .MODEL_ARCH_NAMES [self .model_arch ]
54+ self .gguf_writer .add_architecture ()
55+ if self .target_model_dir is None :
56+ raise ValueError (
57+ "EAGLE-3 model requires --target-model-dir to be specified. "
58+ "Please provide the path to the target model directory to read config.json"
59+ )
60+ # Read both eagle3 raw config and target model config
61+ with open (self .dir_model / "config.json" , 'r' , encoding = 'utf-8' ) as f :
62+ eagle3_raw_config = json .load (f )
63+ with open (self .target_model_dir / "config.json" , 'r' , encoding = 'utf-8' ) as f :
64+ target_config = json .load (f )
65+
66+ # extract_layers: derived from target model layer count (low/mid/high)
67+ target_num_layers = target_config ["num_hidden_layers" ]
68+ extract_layers = [2 , target_num_layers // 2 , target_num_layers - 3 ]
69+ logger .info (f"EAGLE-3: extract_layers = { extract_layers } (target model has { target_num_layers } layers)" )
70+ self .gguf_writer .add_array (f"{ self .gguf_writer .arch } .extract_layers" , extract_layers )
71+
72+ # target_hidden_size: prefer eagle3 config, fallback to target config
73+ if eagle3_raw_config .get ("target_hidden_size" ) is not None :
74+ target_hidden_size = eagle3_raw_config ["target_hidden_size" ]
75+ src = "EAGLE-3 config"
76+ else :
77+ target_hidden_size = target_config ["hidden_size" ]
78+ src = "target model config"
79+ logger .info (f"EAGLE-3: target_hidden_size = { target_hidden_size } (from { src } )" )
80+ self .gguf_writer .add_uint32 (f"{ self .gguf_writer .arch } .target_hidden_size" , target_hidden_size )
81+
82+ # norm_before_residual (RedHat-style eagle3 specific)
83+ norm_before_residual = eagle3_raw_config .get ("norm_before_residual" , False )
84+ logger .info (f"EAGLE-3: norm_before_residual = { norm_before_residual } " )
85+ self .gguf_writer .add_bool (f"{ self .gguf_writer .arch } .norm_before_residual" , norm_before_residual )
86+
4287 def set_vocab (self ):
88+ # eagle3: use tokenizer from target model if provided
89+ original_dir_model = None
90+ if getattr (self , 'is_eagle3' , False ):
91+ assert self .target_model_dir is not None
92+ logger .info (f"EAGLE-3: Using tokenizer from target model: { self .target_model_dir } " )
93+ original_dir_model = self .dir_model
94+ self .dir_model = self .target_model_dir
95+
4396 if self .origin_hf_arch == "GlmasrModel" :
4497 return self ._set_vocab_glmedge ()
4598
@@ -83,6 +136,10 @@ def set_vocab(self):
83136 if self .hparams .get ("vocab_size" , 32000 ) == 49152 :
84137 self .gguf_writer .add_add_bos_token (False )
85138
139+ # eagle3: Restore original dir_model
140+ if original_dir_model is not None :
141+ self .dir_model = original_dir_model
142+
86143 def set_gguf_parameters (self ):
87144 super ().set_gguf_parameters ()
88145 hparams = self .hparams
@@ -127,7 +184,49 @@ def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Ca
127184
128185 return super ().filter_tensors ((name , gen ))
129186
187+ def index_tensors (self , remote_hf_model_id : str | None = None ) -> dict [str , Callable [[], Tensor ]]:
188+ tensors = super ().index_tensors (remote_hf_model_id )
189+
190+ # Handle Eagle3Speculator nested config
191+ if "transformer_layer_config" in self .hparams :
192+ self .hparams = {** self .hparams , ** self .hparams ["transformer_layer_config" ]}
193+
194+ # eagle3 detection
195+ if "draft_vocab_size" in self .hparams and self .hparams ["num_hidden_layers" ] == 1 :
196+ logger .info ("EAGLE-3: renaming midlayer.* / layers.0.* to model.layers.0.*" )
197+ new_tensors = {}
198+ for name , gen in tensors .items ():
199+ if name .startswith ("midlayer." ):
200+ new_name = "model.layers.0." + name [len ("midlayer." ):]
201+ new_tensors [new_name ] = gen
202+ elif name .startswith ("layers.0." ): # Eagle3Speculator format
203+ new_name = "model." + name
204+ new_tensors [new_name ] = gen
205+ else :
206+ new_tensors [name ] = gen
207+ return new_tensors
208+
209+ return tensors
210+
130211 def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
212+ # eagle3: special tensors that bypass standard llama mapping
213+ if getattr (self , 'is_eagle3' , False ):
214+ if name == "fc.weight" :
215+ yield (name , data_torch )
216+ return
217+ if name == "d2t" :
218+ # store for manual int64 handling in prepare_tensors (avoid F32 conversion)
219+ if not hasattr (self , '_eagle3_int_tensors' ):
220+ self ._eagle3_int_tensors = {}
221+ self ._eagle3_int_tensors [name ] = data_torch
222+ return
223+ if name == "t2d" :
224+ # not used at runtime, skip
225+ return
226+ if name == "model.layers.0.hidden_norm.weight" :
227+ yield ("blk.0.hidden_norm.weight" , data_torch )
228+ return
229+
131230 n_head = self .find_hparam (["n_heads" , "num_attention_heads" ])
132231 n_kv_head = self .find_hparam (["n_kv_heads" , "num_key_value_heads" ])
133232
@@ -203,8 +302,26 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
203302 yield (self .format_tensor_name (gguf .MODEL_TENSOR .ROPE_FREQS ), torch .tensor (rope_factors , dtype = torch .float32 ))
204303
205304 def prepare_tensors (self ):
305+ # eagle3: collect d2t original dtype before parent converts tensors to F32
306+ eagle3_original_dtypes = {}
307+ if getattr (self , 'is_eagle3' , False ):
308+ for name , data_torch in self .get_tensors ():
309+ if name == "d2t" :
310+ eagle3_original_dtypes [name ] = data_torch .dtype
311+
206312 super ().prepare_tensors ()
207313
314+ # eagle3: write d2t as int64 directly (not converted to F32)
315+ if getattr (self , 'is_eagle3' , False ) and hasattr (self , '_eagle3_int_tensors' ):
316+ for name , data_torch in self ._eagle3_int_tensors .items ():
317+ old_dtype = eagle3_original_dtypes .get (name , data_torch .dtype )
318+ data = data_torch .to (torch .int64 ).numpy ()
319+ data_qtype = gguf .GGMLQuantizationType .I64
320+
321+ shape_str = f"{{{ ', ' .join (str (n ) for n in reversed (data .shape ))} }}"
322+ logger .info (f"{ name + ',' :<30} { old_dtype } --> { data_qtype .name } , shape = { shape_str } " )
323+ self .gguf_writer .add_tensor (name , data , raw_dtype = data_qtype )
324+
208325 if self ._experts is not None :
209326 # flatten `list[dict[str, Tensor]]` into `list[str]`
210327 experts = [k for d in self ._experts for k in d .keys ()]
0 commit comments