@@ -3033,6 +3033,8 @@ server_context_meta server_context::get_meta() const {
         /* fim_rep_token          */ llama_vocab_fim_rep(impl->vocab),
         /* fim_sep_token          */ llama_vocab_fim_sep(impl->vocab),
 
+        /* logit_bias_eog         */ impl->params_base.sampling.logit_bias_eog,
+
         /* model_vocab_type       */ llama_vocab_type(impl->vocab),
         /* model_vocab_n_tokens   */ llama_vocab_n_tokens(impl->vocab),
         /* model_n_ctx_train      */ llama_model_n_ctx_train(impl->model),
@@ -3117,6 +3119,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
             ctx_server.vocab,
             params,
             meta->slot_n_ctx,
+            meta->logit_bias_eog,
             data);
         task.id_slot = json_value(data, "id_slot", -1);
 
@@ -39,6 +39,9 @@ struct server_context_meta {
     llama_token fim_rep_token;
     llama_token fim_sep_token;
 
+    // sampling
+    std::vector<llama_logit_bias> logit_bias_eog;
+
     // model meta
     enum llama_vocab_type model_vocab_type;
     int32_t model_vocab_n_tokens;
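
For reference, `logit_bias_eog` holds one entry per end-of-generation (EOG) token, each with a bias of negative infinity. Below is a minimal illustrative sketch of how such a vector could be built from the public vocab API; it assumes `llama_vocab_is_eog` and `llama_vocab_n_tokens` from llama.h and is not the exact code this PR relies on:

// Illustrative sketch, not the PR's code: collect a -inf bias for every
// EOG (end-of-generation) token so that, with ignore_eos, sampling can
// never emit one of them.
#include <cmath>   // INFINITY
#include <vector>
#include "llama.h"

static std::vector<llama_logit_bias> make_logit_bias_eog(const llama_vocab * vocab) {
    std::vector<llama_logit_bias> biases;
    const int32_t n_vocab = llama_vocab_n_tokens(vocab);
    for (llama_token tok = 0; tok < n_vocab; ++tok) {
        if (llama_vocab_is_eog(vocab, tok)) {
            biases.push_back({ tok, -INFINITY });
        }
    }
    return biases;
}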
@@ -239,6 +239,7 @@ task_params server_task::params_from_json_cmpl(
     const llama_vocab * vocab,
     const common_params & params_base,
     const int n_ctx_slot,
+    const std::vector<llama_logit_bias> & logit_bias_eog,
     const json & data) {
     task_params params;
 
@@ -562,7 +563,7 @@ task_params server_task::params_from_json_cmpl(
         if (params.sampling.ignore_eos) {
             params.sampling.logit_bias.insert(
                 params.sampling.logit_bias.end(),
-                defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
+                logit_bias_eog.begin(), logit_bias_eog.end());
         }
     }
 
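
The change above makes the EOG biases an explicit argument (taken from the `server_context_meta` snapshot) instead of reading them off `defaults`, presumably so request parsing depends only on the meta snapshot. For intuition on what appending these entries accomplishes, here is a conceptual sketch, not llama.cpp's actual sampler code, of how a logit bias is applied:

// Conceptual sketch (not llama.cpp's sampler code): applying logit biases
// before softmax; a -inf bias zeroes the token's sampling probability.
#include <vector>
#include "llama.h"

static void apply_logit_bias(std::vector<float> & logits,
                             const std::vector<llama_logit_bias> & biases) {
    for (const auto & lb : biases) {
        logits[lb.token] += lb.bias; // += -inf  ->  token can never be sampled
    }
}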
@@ -209,6 +209,7 @@ struct server_task {
         const llama_vocab * vocab,
         const common_params & params_base,
         const int n_ctx_slot,
+        const std::vector<llama_logit_bias> & logit_bias_eog,
         const json & data);
 
     // utility function
+import pytest
+from utils import *
+
+server = ServerPreset.tinyllama2()
+
+
+@pytest.fixture(autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+
+
+def test_ignore_eos_populates_logit_bias():
+    """ignore_eos=true must add EOG logit biases to generation_settings."""
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": 8,
+        "prompt": "Once upon a time",
+        "ignore_eos": True,
+        "temperature": 0.0,
+    })
+    assert res.status_code == 200
+    # EOG token biases must be present with -inf bias
+    logit_bias = res.body["generation_settings"]["logit_bias"]
+    assert len(logit_bias) > 0
+    for entry in logit_bias:
+        assert entry["bias"] is None  # null in JSON represents -inf
+
+
+def test_ignore_eos_false_no_logit_bias():
+    """ignore_eos=false (default) must NOT add EOG logit biases."""
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": 8,
+        "prompt": "Once upon a time",
+        "ignore_eos": False,
+        "temperature": 0.0,
+    })
+    assert res.status_code == 200
+    logit_bias = res.body["generation_settings"]["logit_bias"]
+    assert len(logit_bias) == 0
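
A note on the `entry["bias"] is None` assertion: JSON has no encoding for non-finite floats, and nlohmann::json (which the server uses for responses) serializes them as null by default. A standalone illustration, assuming that library:

// Hypothetical standalone example, not server code: nlohmann::json cannot
// represent non-finite floats in JSON, so -INFINITY serializes as null,
// which is why the Python test sees None for each EOG bias entry.
#include <nlohmann/json.hpp>
#include <cmath>
#include <iostream>

int main() {
    nlohmann::json j;
    j["bias"] = -INFINITY;
    std::cout << j.dump() << "\n"; // prints {"bias":null}
    return 0;
}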