Skip to content

Commit faa1bc2

Browse files
authored
sampling : delegate input allocation to the scheduler (ggml-org#19266)
* sampling : delegate input allocation to the scheduler * graph : compute backend samplers only if needed
1 parent 32b17ab commit faa1bc2

3 files changed

Lines changed: 33 additions & 73 deletions

File tree

src/llama-context.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,11 +1027,7 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
10271027
llama_sampler_chain_n(sampler) > 0;
10281028

10291029
if (sampler && can_offload) {
1030-
ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output());
1031-
auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output());
1032-
if (host_buft) {
1033-
buft = host_buft;
1034-
}
1030+
auto * buft = ggml_backend_dev_buffer_type(model.dev_output());
10351031

10361032
sampler->iface->backend_init(sampler, buft);
10371033

src/llama-graph.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2419,6 +2419,9 @@ void llm_graph_context::build_sampling() const {
24192419
return;
24202420
}
24212421

2422+
std::array<ggml_tensor *, 2> outs;
2423+
outs[0] = res->t_logits;
2424+
24222425
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
24232426
res->add_input(std::move(inp_sampling));
24242427

@@ -2439,14 +2442,14 @@ void llm_graph_context::build_sampling() const {
24392442
// add a dummy row of logits
24402443
// this trick makes the graph static, regardless of which samplers are activated
24412444
// this is important in order to minimize graph reallocations
2442-
// TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550)
24432445
ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
24442446

24452447
for (const auto & [seq_id, sampler] : samplers) {
24462448
const auto it = seq_to_logit_row.find(seq_id);
24472449

24482450
// inactive samplers always work on the first row
2449-
const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
2451+
const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
2452+
const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
24502453

24512454
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
24522455
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
@@ -2463,22 +2466,26 @@ void llm_graph_context::build_sampling() const {
24632466

24642467
if (data.sampled != nullptr) {
24652468
res->t_sampled[seq_id] = data.sampled;
2466-
ggml_build_forward_expand(gf, data.sampled);
2469+
outs[1] = data.sampled;
2470+
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
24672471
}
24682472

24692473
if (data.probs != nullptr) {
24702474
res->t_sampled_probs[seq_id] = data.probs;
2471-
ggml_build_forward_expand(gf, data.probs);
2475+
outs[1] = data.probs;
2476+
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
24722477
}
24732478

24742479
if (data.logits != nullptr) {
24752480
res->t_sampled_logits[seq_id] = data.logits;
2476-
ggml_build_forward_expand(gf, data.logits);
2481+
outs[1] = data.logits;
2482+
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
24772483
}
24782484

24792485
if (data.candidates != nullptr) {
24802486
res->t_candidates[seq_id] = data.candidates;
2481-
ggml_build_forward_expand(gf, data.candidates);
2487+
outs[1] = data.candidates;
2488+
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
24822489
}
24832490
}
24842491

src/llama-sampling.cpp

Lines changed: 19 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,11 +1025,7 @@ struct llama_sampler_dist : public llama_sampler_backend {
10251025

10261026
std::mt19937 rng;
10271027

1028-
// backend input
1029-
struct ggml_tensor * inp_uniform;
1030-
1031-
ggml_context_ptr inp_ctx;
1032-
ggml_backend_buffer_ptr inp_buf;
1028+
ggml_tensor * inp_uniform;
10331029
};
10341030

10351031
static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
@@ -1138,37 +1134,10 @@ static bool llama_sampler_dist_backend_init(
11381134
ggml_backend_buffer_type_t buft) {
11391135
auto * sctx = (llama_sampler_dist *) smpl->ctx;
11401136

1141-
// allocate inputs
1142-
{
1143-
ggml_init_params params = {
1144-
/*.mem_size =*/ ggml_tensor_overhead(),
1145-
/*.mem_buffer =*/ nullptr,
1146-
/*.no_alloc =*/ true,
1147-
};
1148-
1149-
sctx->inp_ctx.reset(ggml_init(params));
1150-
1151-
// Create the uniform random scalar input tensor. This will be set by
1152-
// llama_sampler_dist_backend_set_input after this graph is built.
1153-
sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
1154-
ggml_set_name (sctx->inp_uniform, "uniform");
1155-
ggml_set_input(sctx->inp_uniform);
1156-
1157-
// Allocate all tensors from our context to the backend
1158-
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
1159-
1160-
ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
1161-
}
1162-
11631137
const bool res = llama_sampler_backend_support(smpl, buft);
11641138

11651139
sctx->init(res);
11661140

1167-
if (!res) {
1168-
sctx->inp_ctx.reset(nullptr);
1169-
sctx->inp_buf.reset(nullptr);
1170-
}
1171-
11721141
return res;
11731142
}
11741143

@@ -1178,8 +1147,13 @@ static void llama_sampler_dist_backend_apply(
11781147
struct ggml_cgraph * gf,
11791148
struct llama_sampler_data * data) {
11801149
GGML_UNUSED(gf);
1150+
11811151
auto * sctx = (llama_sampler_dist *) smpl->ctx;
11821152

1153+
sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
1154+
ggml_set_name (sctx->inp_uniform, "uniform");
1155+
ggml_set_input(sctx->inp_uniform);
1156+
11831157
struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
11841158
ggml_set_name(probs, "dist_probs");
11851159

@@ -1226,6 +1200,7 @@ static void llama_sampler_dist_backend_apply(
12261200

12271201
static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
12281202
auto * sctx = (llama_sampler_dist *) smpl->ctx;
1203+
12291204
GGML_ASSERT(sctx->inp_uniform != nullptr);
12301205

12311206
// We sample in double precision and cast to float to match rnd numbers of
@@ -1262,8 +1237,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
12621237
/* .seed_cur = */ seed_cur,
12631238
/* .rng = */ std::mt19937(seed_cur),
12641239
/* .inp_uniform = */ nullptr,
1265-
/* .inp_ctx = */ nullptr,
1266-
/* .inp_buf = */ nullptr,
12671240
}
12681241
);
12691242
}
@@ -3461,9 +3434,6 @@ struct llama_sampler_logit_bias : public llama_sampler_backend {
34613434

34623435
struct ggml_tensor * inp_logit_bias;
34633436
struct ggml_tensor * inp_logit_idxs;
3464-
3465-
ggml_context_ptr inp_ctx;
3466-
ggml_backend_buffer_ptr inp_buf;
34673437
};
34683438

34693439
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
@@ -3526,6 +3496,16 @@ static void llama_sampler_logit_bias_backend_apply(
35263496
return;
35273497
}
35283498

3499+
const size_t n = sctx->logit_bias.size();
3500+
3501+
sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n);
3502+
ggml_set_name(sctx->inp_logit_bias, "logit_bias");
3503+
ggml_set_input(sctx->inp_logit_bias);
3504+
3505+
sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n);
3506+
ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
3507+
ggml_set_input(sctx->inp_logit_idxs);
3508+
35293509
ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
35303510

35313511
cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
@@ -3562,6 +3542,8 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm
35623542
static bool llama_sampler_logit_bias_backend_init(
35633543
struct llama_sampler * smpl,
35643544
ggml_backend_buffer_type_t buft) {
3545+
GGML_UNUSED(buft);
3546+
35653547
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
35663548

35673549
sctx->init(true);
@@ -3570,29 +3552,6 @@ static bool llama_sampler_logit_bias_backend_init(
35703552
return true;
35713553
}
35723554

3573-
ggml_init_params params = {
3574-
/*.mem_size =*/ 2*ggml_tensor_overhead(),
3575-
/*.mem_buffer =*/ nullptr,
3576-
/*.no_alloc =*/ true,
3577-
};
3578-
3579-
sctx->inp_ctx.reset(ggml_init(params));
3580-
3581-
const size_t n = sctx->logit_bias.size();
3582-
3583-
sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
3584-
ggml_set_name(sctx->inp_logit_bias, "logit_bias");
3585-
ggml_set_input(sctx->inp_logit_bias);
3586-
3587-
sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
3588-
ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
3589-
ggml_set_input(sctx->inp_logit_idxs);
3590-
3591-
// Allocate all tensors from our context to the backend
3592-
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
3593-
3594-
ggml_backend_buffer_clear(sctx->inp_buf.get(), 0);
3595-
35963555
return true;
35973556
}
35983557

@@ -3628,8 +3587,6 @@ struct llama_sampler * llama_sampler_init_logit_bias(
36283587
/* .to_search = */ {},
36293588
/* .inp_logit_bias = */ nullptr,
36303589
/* .inp_logit_idxs = */ nullptr,
3631-
/* .inp_ctx = */ nullptr,
3632-
/* .inp_buf = */ nullptr,
36333590
}
36343591
);
36353592
}

0 commit comments

Comments
 (0)