Skip to content

Commit 2e97c5f

Browse files
authored
backend sampling: support returning post-sampling probs (ggml-org#22622)
* server: Never return 0.0 post-sampling probabilities * backend sampling: support returning post-sampling probs
1 parent 5d5d2e1 commit 2e97c5f

4 files changed

Lines changed: 80 additions & 16 deletions

File tree

common/sampling.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
547547
auto & chain = gsmpl->chain;
548548
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
549549

550+
gsmpl->set_logits(ctx, idx);
551+
550552
// Check if a backend sampler has already sampled a token in which case we
551553
// return that token id directly.
552554
{
@@ -558,17 +560,17 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
558560
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
559561
GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");
560562

561-
// TODO: simplify
562-
gsmpl->cur.resize(1);
563-
gsmpl->cur[0] = { id, 0.0f, 1.0f };
564-
cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
563+
for (size_t i = 0; i < cur_p.size; ++i) {
564+
if (cur_p.data[i].id == id) {
565+
cur_p.selected = i;
566+
break;
567+
}
568+
}
565569

566570
return id;
567571
}
568572
}
569573

570-
gsmpl->set_logits(ctx, idx);
571-
572574
// apply reasoning budget first
573575
llama_sampler_apply(rbudget, &cur_p);
574576

tools/server/server-context.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,7 +1317,7 @@ struct server_context_impl {
13171317
return false;
13181318
}
13191319

1320-
const bool need_logits = task.params.sampling.n_probs > 0;
1320+
const bool need_pre_sample_logits = task.params.sampling.n_probs > 0 && !task.params.post_sampling_probs;
13211321

13221322
bool backend_sampling = true;
13231323

@@ -1326,8 +1326,8 @@ struct server_context_impl {
13261326
// TODO: speculative decoding requires multiple samples per batch - not supported yet
13271327
backend_sampling &= !(slot.can_speculate() && common_speculative_n_max(slot.spec.get(), task.params.speculative) > 0);
13281328

1329-
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
1330-
backend_sampling &= !need_logits;
1329+
// TODO: getting pre sampling logits is not yet supported with backend sampling
1330+
backend_sampling &= !need_pre_sample_logits;
13311331

13321332
// TODO: tmp until backend sampling is fully implemented
13331333
if (backend_sampling) {
@@ -1504,6 +1504,12 @@ struct server_context_impl {
15041504
// set probability for top n_probs tokens
15051505
result.probs.reserve(n_probs);
15061506
for (size_t i = 0; i < n_probs; i++) {
1507+
// Some samplers do return 0.0 probabilities, others don't.
1508+
// Filter 0.0 probailities, to ensure the behavior is consistent.
1509+
if (cur_p->data[i].p == 0.0) {
1510+
break;
1511+
}
1512+
15071513
result.probs.push_back({
15081514
cur_p->data[i].id,
15091515
common_token_to_piece(ctx, cur_p->data[i].id, special),

tools/server/tests/unit/test_completion.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -491,29 +491,82 @@ def test_n_probs_post_sampling():
491491
global server
492492
server.start()
493493
res = server.make_request("POST", "/completion", data={
494-
"prompt": "I believe the meaning of life is",
494+
"prompt": "Today was the day. Today I would finally become a",
495495
"n_probs": 10,
496-
"temperature": 0.0,
496+
"temperature": 1.0,
497497
"n_predict": 5,
498498
"post_sampling_probs": True,
499499
})
500500
assert res.status_code == 200
501501
assert "completion_probabilities" in res.body
502502
assert len(res.body["completion_probabilities"]) == 5
503-
for tok in res.body["completion_probabilities"]:
503+
for (i, tok) in enumerate(res.body["completion_probabilities"]):
504504
assert "id" in tok and tok["id"] > 0
505505
assert "token" in tok and type(tok["token"]) == str
506506
assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
507507
assert "bytes" in tok and type(tok["bytes"]) == list
508-
assert len(tok["top_probs"]) == 10
508+
assert "top_probs" in tok and type(tok["top_probs"]) == list
509+
509510
for prob in tok["top_probs"]:
510511
assert "id" in prob and prob["id"] > 0
511512
assert "token" in prob and type(prob["token"]) == str
512-
assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
513+
# 0.0 probability tokens should never be returned by the server
514+
assert "prob" in prob and 0.0 < prob["prob"] <= 1.0
513515
assert "bytes" in prob and type(prob["bytes"]) == list
514-
# because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
515-
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
516516

517+
if i == 0:
518+
# The prompt is vague enough that we should get at least 10 possibilities
519+
# for the first token.
520+
assert len(tok["top_probs"]) == 10
521+
522+
if len(tok["top_probs"]) < 10:
523+
# Getting less than the requested number of probabilities should only happen
524+
# if the ones we did get already sum to 1.0.
525+
assert sum(p["prob"] for p in tok["top_probs"]) == pytest.approx(1.0)
526+
527+
def test_n_probs_post_backend_sampling():
528+
"""Verify that the same probabilities are returned with and without backend sampling."""
529+
global server
530+
server.backend_sampling = True
531+
server.start()
532+
533+
def make_request(backend_sampling):
534+
n_predict = 20
535+
536+
res = server.make_request("POST", "/completion", data={
537+
"prompt": "The countries of Europe, in random order, are:",
538+
"n_probs": 10,
539+
"n_predict": n_predict,
540+
"post_sampling_probs": True,
541+
"seed": 4242,
542+
"backend_sampling": backend_sampling,
543+
})
544+
assert res.status_code == 200
545+
546+
total_probs = 0
547+
completions = res.body["completion_probabilities"]
548+
assert len(completions) == n_predict
549+
for tok in completions:
550+
# Handling of 0.0 probabilities differs between samplers and backend sampling. Filter them to normalize the
551+
# data.
552+
tok["top_probs"] = [x for x in tok["top_probs"] if x["prob"] > 0.0]
553+
total_probs += len(tok["top_probs"])
554+
# Verify that we got at least two top probs on average, to ensure the effectiveness of the test.
555+
assert total_probs >= 2 * n_predict
556+
return completions
557+
558+
def verify_token(a, b):
559+
assert a["id"] == b["id"]
560+
assert a["token"] == b["token"]
561+
assert a["bytes"] == b["bytes"]
562+
assert a["prob"] == pytest.approx(b["prob"], abs=0.01)
563+
564+
for (a, b) in zip(make_request(True), make_request(False)):
565+
verify_token(a, b)
566+
assert len(a["top_probs"]) == len(b["top_probs"])
567+
568+
for (aa, bb) in zip(a["top_probs"], b["top_probs"]):
569+
verify_token(aa, bb)
517570

518571
@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
519572
def test_logit_bias(tokenize, openai_style):

tools/server/tests/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class ServerProcess:
108108
no_cache_idle_slots: bool = False
109109
log_path: str | None = None
110110
webui_mcp_proxy: bool = False
111+
backend_sampling: bool = False
111112
gcp_compat: bool = False
112113

113114
# session variables
@@ -252,6 +253,8 @@ def start(self, timeout_seconds: int = DEFAULT_HTTP_TIMEOUT) -> None:
252253
server_args.append("--no-cache-idle-slots")
253254
if self.webui_mcp_proxy:
254255
server_args.append("--webui-mcp-proxy")
256+
if self.backend_sampling:
257+
server_args.append("--backend_sampling")
255258
if self.gcp_compat:
256259
env["AIP_MODE"] = "PREDICTION"
257260

0 commit comments

Comments
 (0)