Skip to content

Commit 71eb0c7

Browse files
committed
eagle3: fix params bug
1 parent db164a1 commit 71eb0c7

2 files changed

Lines changed: 5 additions & 5 deletions

File tree

common/speculative.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,8 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
404404
//
405405
// Performance is overall good but there is waste in verify cycle:
406406
// process() runs encoder + decoder on the *full* verify batch including rows for
407-
// rejected drafts. The KV at those positions is then dropped.
408-
//
407+
// rejected drafts. The KV at those positions is then dropped.
408+
//
409409
// TODO: Not sure if we need optimization for this waste?
410410
// If so we may need hybrid stash:
411411
// in verify mode, have process() only stash features and let draft() seed run
@@ -486,8 +486,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
486486
}
487487

488488
// turn on extraction of the draft model's pre-norm hidden state
489-
// (used both for the encoder output g_embd and the decoder pre-norm output)
490-
llama_set_embeddings_pre_norm(ctx_dft, true);
489+
// (used both for the encoder output g_embd and the decoder pre-norm output).
490+
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
491491

492492
pending_g_last.assign(n_seq, std::vector<float>(n_embd_dec, 0.0f));
493493
pending_pos_last.assign(n_seq, -1);

src/llama-context.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ struct llama_context {
230230
// map the output row index `i` to batch index
231231
int64_t output_resolve_row(int32_t i) const;
232232

233-
// async-copy enabled layer-input tensors (per cparams.output_layer_inp)
233+
// async-copy enabled layer-input tensors (per cparams.output_layer_inp)
234234
// from backend into host-side embd_layer_inp buffers
235235
void extract_layer_inputs(const llm_graph_result * res);
236236

0 commit comments

Comments
 (0)