@@ -425,16 +425,23 @@ def _base_model_lm_head(self):
@property
def _base_llm_config(self):
    """Return the llm config for the base model, from LLM or VLM.

    Resolution order:
      1. ``self.config.text_config`` -- nested text/language-model config.
      2. ``self.config.llm_config`` -- alternative nested name that the
         earlier revision of this property handled.
      3. ``self.config`` itself -- a plain LLM with no nested config.

    Reading ``text_config`` unconditionally raises ``AttributeError`` for
    configs that do not define it, so every nested attribute is probed
    with a default and the top-level config is the final fallback.
    """
    config = self.config
    for attr_name in ("text_config", "llm_config"):
        nested = getattr(config, attr_name, None)
        # A missing attribute and an explicit ``None`` are both treated as
        # "no nested config here"; keep looking.
        if nested is not None:
            return nested
    return config
429430
430431 def _find_base_model_parts (self ):
431432 """Find model parts from different models and set base_{part}_path attributes."""
432433 base_model_parts_mapping = {
433- "base_model_path" : ["model" , "backbone" , "language_model.backbone" ],
434+ "base_model_path" : [
435+ "model.language_model" ,
436+ "model" ,
437+ "backbone" ,
438+ "language_model.backbone" ,
439+ ],
434440 "base_model_embeddings_path" : [
435441 "model.embed_tokens" ,
436442 "backbone.embeddings" ,
437443 "language_model.backbone.embeddings" ,
444+ "model.language_model.embed_tokens" ,
438445 ],
439446 "base_model_lm_head_path" : ["lm_head" , "language_model.lm_head" ],
440447 }
@@ -747,7 +754,8 @@ def _llm_or_vlm_embedding(self, input_ids, kwargs):
747754 del vit_embeds
748755 return tok_embeds .reshape (bs , seq_len , hid_size )
749756 else :
750- raise ValueError (f"VLM model type { self .config .model_type } not supported" )
757+ breakpoint ()
758+ # raise ValueError(f"VLM model type {self.config.model_type} not supported")
751759
752760 def _base_model_forward (
753761 self ,
@@ -769,6 +777,7 @@ def _base_model_forward(
769777 ** kwargs ,
770778 )
771779 past_key_values = getattr (outputs , "past_key_values" , None )
780+ input_embeds = outputs .hidden_states [0 ]
772781 base_model_hidden_states = outputs .hidden_states [- 1 ]
773782 base_model_logits = outputs .logits
774783
@@ -780,7 +789,13 @@ def _base_model_forward(
780789 labels = labels .view (- 1 )
781790 base_model_loss = loss_fct (loss_logits , labels )
782791
783- return base_model_hidden_states , base_model_logits , base_model_loss , past_key_values
792+ return (
793+ input_embeds ,
794+ base_model_hidden_states ,
795+ base_model_logits ,
796+ base_model_loss ,
797+ past_key_values ,
798+ )
784799
785800 def _map_logits_to_draft_vocab (self , full_logits ):
786801 reverse_mapping = (
@@ -872,16 +887,20 @@ def forward(
872887 base_model_logits = self .lm_head (base_model_hidden_states )
873888 base_model_loss , past_key_values = None , None
874889 else :
875- base_model_hidden_states , base_model_logits , base_model_loss , past_key_values = (
876- self ._base_model_forward (
877- input_ids ,
878- attention_mask ,
879- position_ids ,
880- past_key_values ,
881- self .eagle_freeze_base_model ,
882- labels ,
883- ** kwargs ,
884- )
890+ (
891+ base_input_embeds ,
892+ base_model_hidden_states ,
893+ base_model_logits ,
894+ base_model_loss ,
895+ past_key_values ,
896+ ) = self ._base_model_forward (
897+ input_ids ,
898+ attention_mask ,
899+ position_ids ,
900+ past_key_values ,
901+ self .eagle_freeze_base_model ,
902+ labels ,
903+ ** kwargs ,
885904 )
886905
887906 if not isinstance (past_key_values , Cache ):
@@ -912,7 +931,8 @@ def forward(
912931 eagle_cache ,
913932 )
914933 with torch .no_grad ():
915- inputs_embeds = self ._llm_or_vlm_embedding (eagle_input_ids , kwargs )
934+ # inputs_embeds = self._llm_or_vlm_embedding(eagle_input_ids, kwargs)
935+ inputs_embeds = base_input_embeds .roll (- 1 , 1 )
916936
917937 past_key_values .eagle_cache = eagle_cache
918938
0 commit comments