Skip to content

Commit 4e6a606

Browse files
authored
fix: initialize add_inputs_embeds to avoid UnboundLocalError in eagle3 (#221)
1 parent 5c9955e commit 4e6a606

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

  • angelslim/compressor/speculative/utils

angelslim/compressor/speculative/utils/util.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,7 @@ def initialize_tree(input_ids, inputs_embeds, model, past_key_values, logits_pro
105105
token = token[None, None]
106106
input_ids = torch.cat((input_ids, token.to(input_ids.device)), dim=1)
107107
# add embedding
108+
add_inputs_embeds = None
108109
if inputs_embeds is not None:
109110
add_inputs_embeds = torch.cat(
110111
[inputs_embeds, model.eagle_layer.embed_tokens(token)], dim=1
@@ -322,16 +323,18 @@ def update_inference_inputs(
322323
]
323324

324325
# add embedding
326+
tmp_inputs_embeds = None
325327
if inputs_embeds is not None:
326328
add_inputs_embeds = model.eagle_layer.embed_tokens.weight[
327329
sample_token.squeeze(0).tolist()
328330
].unsqueeze(0)
331+
tmp_inputs_embeds = torch.cat([inputs_embeds, add_inputs_embeds], dim=1)
329332

330333
draft_tokens, retrieve_indices, tree_mask, tree_position_ids, early_stop_signal = (
331334
model.eagle_layer.topK_genrate(
332335
accept_hidden_state_new,
333336
input_ids=torch.cat((input_ids, sample_token.to(input_ids.device)), dim=1),
334-
inputs_embeds=torch.cat([inputs_embeds, add_inputs_embeds], dim=1),
337+
inputs_embeds=tmp_inputs_embeds,
335338
logits_processor=logits_processor,
336339
)
337340
)

0 commit comments

Comments
 (0)