@@ -105,6 +105,7 @@ def initialize_tree(input_ids, inputs_embeds, model, past_key_values, logits_pro
105105 token = token [None , None ]
106106 input_ids = torch .cat ((input_ids , token .to (input_ids .device )), dim = 1 )
107107 # add embedding
108+ add_inputs_embeds = None
108109 if inputs_embeds is not None :
109110 add_inputs_embeds = torch .cat (
110111 [inputs_embeds , model .eagle_layer .embed_tokens (token )], dim = 1
@@ -322,16 +323,18 @@ def update_inference_inputs(
322323 ]
323324
324325 # add embedding
326+ tmp_inputs_embeds = None
325327 if inputs_embeds is not None :
326328 add_inputs_embeds = model .eagle_layer .embed_tokens .weight [
327329 sample_token .squeeze (0 ).tolist ()
328330 ].unsqueeze (0 )
331+ tmp_inputs_embeds = torch .cat ([inputs_embeds , add_inputs_embeds ], dim = 1 )
329332
330333 draft_tokens , retrieve_indices , tree_mask , tree_position_ids , early_stop_signal = (
331334 model .eagle_layer .topK_genrate (
332335 accept_hidden_state_new ,
333336 input_ids = torch .cat ((input_ids , sample_token .to (input_ids .device )), dim = 1 ),
334- inputs_embeds = torch . cat ([ inputs_embeds , add_inputs_embeds ], dim = 1 ) ,
337+ inputs_embeds = tmp_inputs_embeds ,
335338 logits_processor = logits_processor ,
336339 )
337340 )
0 commit comments