Skip to content

Commit e8ddf01

Browse files
ruixiang63 and Dogacel committed
eagle3 : fix ubatch handling in embd_layer_inp extraction and encoder
Co-authored-by: Doğaç Eldenk <dogacel@gmail.com>
1 parent acc31b1 commit e8ddf01

4 files changed

Lines changed: 73 additions & 29 deletions

File tree

common/speculative.cpp

Lines changed: 33 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -432,6 +432,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
432432

433433
// scratch buffer for concatenated target features [n_tokens, n_embd_enc]
434434
std::vector<float> features_buf;
435+
std::vector<float> g_embd_buf;
435436

436437
common_speculative_impl_draft_eagle3(const common_params_speculative & params, uint32_t n_seq)
437438
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq)
@@ -569,25 +570,39 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
569570
}
570571
}
571572

572-
llama_batch enc_batch = {
573-
/*.n_tokens =*/ n_tokens,
574-
/*.token =*/ nullptr,
575-
/*.embd =*/ features_buf.data(),
576-
/*.pos =*/ nullptr,
577-
/*.n_seq_id =*/ nullptr,
578-
/*.seq_id =*/ nullptr,
579-
/*.logits =*/ nullptr,
580-
};
581-
int rc = llama_encode(ctx_dft, enc_batch);
582-
if (rc != 0) {
583-
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d)\n",
584-
__func__, rc, (int) n_tokens);
585-
return false;
573+
g_embd_buf.resize((size_t) n_tokens * n_embd_dec);
574+
575+
// llama_encode() requires the full encoder batch to fit in n_ubatch.
576+
// Allow batch > ubatch: eagle3's per-token encoder can be chunked safely.
577+
const int32_t n_ubatch_dft = (int32_t) llama_n_ubatch(ctx_dft);
578+
for (int32_t i = 0; i < n_tokens; i += n_ubatch_dft) {
579+
const int32_t n_chunk = std::min(n_ubatch_dft, n_tokens - i);
580+
581+
llama_batch enc_batch = {
582+
/*.n_tokens =*/ n_chunk,
583+
/*.token =*/ nullptr,
584+
/*.embd =*/ features_buf.data() + (size_t) i * n_embd_enc,
585+
/*.pos =*/ nullptr,
586+
/*.n_seq_id =*/ nullptr,
587+
/*.seq_id =*/ nullptr,
588+
/*.logits =*/ nullptr,
589+
};
590+
const int32_t rc = llama_encode(ctx_dft, enc_batch);
591+
if (rc != 0) {
592+
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
593+
__func__, rc, (int) n_chunk, (int) i);
594+
return false;
595+
}
596+
597+
// g_embd has shape [n_chunk, n_embd_dec] in ctx_dft's pre-norm embeddings buffer.
598+
const float * g_embd_chunk = llama_get_embeddings_pre_norm(ctx_dft);
599+
GGML_ASSERT(g_embd_chunk && "EAGLE3 encoder produced no output.");
600+
std::memcpy(g_embd_buf.data() + (size_t) i * n_embd_dec,
601+
g_embd_chunk,
602+
(size_t) n_chunk * n_embd_dec * sizeof(float));
586603
}
587604

588-
// g_embd has shape [n_tokens, n_embd_dec] in ctx_dft's pre-norm embeddings buffer
589-
const float * g_embd = llama_get_embeddings_pre_norm(ctx_dft);
590-
GGML_ASSERT(g_embd && "EAGLE3 encoder produced no output.");
605+
const float * g_embd = g_embd_buf.data();
591606

592607
const size_t row_bytes = (size_t) n_embd_dec * sizeof(float);
593608

@@ -648,7 +663,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
648663
}
649664

650665
if (batch.n_tokens > 0) {
651-
rc = llama_decode(ctx_dft, batch);
666+
const int32_t rc = llama_decode(ctx_dft, batch);
652667
if (rc != 0) {
653668
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
654669
__func__, rc, (int) batch.n_tokens, (int) batch_in.pos[0]);

src/llama-context.cpp

Lines changed: 37 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1260,10 +1260,10 @@ void llama_context::set_output_layer_inp(uint32_t layer_id, bool enable) {
12601260
}
12611261

12621262
float * llama_context::get_output_layer_inp(uint32_t layer_id) {
1263-
if (layer_id >= embd_layer_inp.size() || embd_layer_inp[layer_id].empty()) {
1263+
if (layer_id >= embd_layer_inp.size() || !embd_layer_inp[layer_id].has_data()) {
12641264
return nullptr;
12651265
}
1266-
return embd_layer_inp[layer_id].data();
1266+
return embd_layer_inp[layer_id].data;
12671267
}
12681268

12691269
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
@@ -1960,7 +1960,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
19601960
}
19611961
}
19621962

1963-
extract_layer_inputs(res);
1963+
extract_layer_inputs(res, n_tokens_prev, ubatch.n_tokens);
19641964

19651965
// extract pre-norm embeddings (hidden state before the final output norm)
19661966
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
@@ -2081,6 +2081,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20812081

20822082
size_t backend_float_count = 0;
20832083
size_t backend_token_count = 0;
2084+
size_t embd_layer_inp_float_count = 0;
20842085

20852086
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
20862087
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
@@ -2092,6 +2093,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
20922093
embd_pre_norm.size = (size_t) n_embd * n_batch;
20932094
}
20942095

2096+
for (bool enabled : cparams.output_layer_inp) {
2097+
if (enabled) {
2098+
embd_layer_inp_float_count += (size_t) n_embd * n_batch;
2099+
}
2100+
}
2101+
20952102
// Allocate backend sampling output buffers if there are backend samplers configured.
20962103
const bool has_sampling = !sampling.samplers.empty();
20972104
if (has_sampling) {
@@ -2106,8 +2113,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
21062113

21072114
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
21082115
const size_t new_size =
2109-
(logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) +
2110-
( backend_token_count) * sizeof(llama_token);
2116+
(logits.size + embd.size + embd_pre_norm.size + embd_layer_inp_float_count + backend_float_count) * sizeof(float) +
2117+
( backend_token_count) * sizeof(llama_token);
21112118

21122119
// alloc only when more than the current capacity is required
21132120
// TODO: also consider shrinking the buffer
@@ -2124,6 +2131,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
21242131
logits.data = nullptr;
21252132
embd.data = nullptr;
21262133
embd_pre_norm.data = nullptr;
2134+
for (auto & layer_inp : embd_layer_inp) {
2135+
layer_inp = {nullptr, 0};
2136+
}
21272137
}
21282138

21292139
auto * buft = ggml_backend_cpu_buffer_type();
@@ -2155,6 +2165,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
21552165
embd_pre_norm = has_embd_pre_norm ? buffer_view<float>{(float *) (base + offset), embd_pre_norm.size} : buffer_view<float>{nullptr, 0};
21562166
offset += embd_pre_norm.size * sizeof(float);
21572167

2168+
for (uint32_t il = 0; il < embd_layer_inp.size(); ++il) {
2169+
if (cparams.output_layer_inp[il]) {
2170+
embd_layer_inp[il] = buffer_view<float>{(float *) (base + offset), (size_t) n_embd * n_batch};
2171+
offset += embd_layer_inp[il].size * sizeof(float);
2172+
} else {
2173+
embd_layer_inp[il] = buffer_view<float>{nullptr, 0};
2174+
}
2175+
}
2176+
21582177
if (has_sampling) {
21592178
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
21602179
offset += sampling.logits.size * sizeof(float);
@@ -2199,20 +2218,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
21992218
return n_outputs_max;
22002219
}
22012220

2202-
void llama_context::extract_layer_inputs(const llm_graph_result * res) {
2221+
void llama_context::extract_layer_inputs(const llm_graph_result * res, size_t token_offset, size_t n_tokens) {
22032222
for (uint32_t il = 0; il < cparams.output_layer_inp.size(); ++il) {
22042223
if (!cparams.output_layer_inp[il]) {
22052224
continue;
22062225
}
2226+
if (!embd_layer_inp[il].has_data()) {
2227+
continue;
2228+
}
22072229
ggml_tensor * t = res->get_layer_inp((int) il);
22082230
if (!t) {
22092231
continue;
22102232
}
22112233
const size_t nbytes = ggml_nbytes(t);
2212-
embd_layer_inp[il].resize(nbytes / sizeof(float));
2234+
const size_t nfloats = nbytes / sizeof(float);
2235+
GGML_ASSERT(n_tokens > 0);
2236+
GGML_ASSERT(nfloats % n_tokens == 0);
2237+
2238+
const size_t row_floats = nfloats / n_tokens;
2239+
const size_t dst_offset = token_offset * row_floats;
2240+
GGML_ASSERT(dst_offset + nfloats <= embd_layer_inp[il].size);
2241+
22132242
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t);
22142243
GGML_ASSERT(backend != nullptr);
2215-
ggml_backend_tensor_get_async(backend, t, embd_layer_inp[il].data(), 0, nbytes);
2244+
ggml_backend_tensor_get_async(backend, t, embd_layer_inp[il].data + dst_offset, 0, nbytes);
22162245
}
22172246
}
22182247

src/llama-context.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -232,7 +232,7 @@ struct llama_context {
232232

233233
// async-copy enabled layer-input tensors (per cparams.output_layer_inp)
234234
// from backend into host-side embd_layer_inp buffers
235-
void extract_layer_inputs(const llm_graph_result * res);
235+
void extract_layer_inputs(const llm_graph_result * res, size_t token_offset, size_t n_tokens);
236236

237237
//
238238
// graph
@@ -364,7 +364,7 @@ struct llama_context {
364364

365365
// host buffer for output layer input embeddings, per layer
366366
// populated when cparams.output_layer_inp[il] is true
367-
std::vector<std::vector<float>> embd_layer_inp;
367+
std::vector<buffer_view<float>> embd_layer_inp;
368368

369369
// keep copies of the per-sequence memory on the device
370370
std::map<llama_seq_id, llama_memory_buffers> mem_storage;

src/llama-ext.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,7 @@ LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx,
112112
// set if the layer input embeddings should be outputed
113113
LLAMA_API void llama_set_output_layer_inp(struct llama_context * ctx, uint32_t layer_id, bool enable);
114114

115-
// read back the input embeddings of the specified layer for the most recent ubatch
115+
// read back the input embeddings of the specified layer for the most recent decode batch
116116
// the layer must have been enabled via llama_set_output_layer_inp
117117
LLAMA_API float * llama_get_output_layer_inp(struct llama_context * ctx, uint32_t layer_id);
118118

0 commit comments

Comments
 (0)