@@ -491,29 +491,82 @@ def test_n_probs_post_sampling():
491491 global server
492492 server .start ()
493493 res = server .make_request ("POST" , "/completion" , data = {
494- "prompt" : "I believe the meaning of life is " ,
494+ "prompt" : "Today was the day. Today I would finally become a " ,
495495 "n_probs" : 10 ,
496- "temperature" : 0 .0 ,
496+ "temperature" : 1 .0 ,
497497 "n_predict" : 5 ,
498498 "post_sampling_probs" : True ,
499499 })
500500 assert res .status_code == 200
501501 assert "completion_probabilities" in res .body
502502 assert len (res .body ["completion_probabilities" ]) == 5
503- for tok in res .body ["completion_probabilities" ]:
503+ for ( i , tok ) in enumerate ( res .body ["completion_probabilities" ]) :
504504 assert "id" in tok and tok ["id" ] > 0
505505 assert "token" in tok and type (tok ["token" ]) == str
506506 assert "prob" in tok and 0.0 < tok ["prob" ] <= 1.0
507507 assert "bytes" in tok and type (tok ["bytes" ]) == list
508- assert len (tok ["top_probs" ]) == 10
508+ assert "top_probs" in tok and type (tok ["top_probs" ]) == list
509+
509510 for prob in tok ["top_probs" ]:
510511 assert "id" in prob and prob ["id" ] > 0
511512 assert "token" in prob and type (prob ["token" ]) == str
512- assert "prob" in prob and 0.0 <= prob ["prob" ] <= 1.0
513+ # 0.0 probability tokens should never be returned by the server
514+ assert "prob" in prob and 0.0 < prob ["prob" ] <= 1.0
513515 assert "bytes" in prob and type (prob ["bytes" ]) == list
514- # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
515- assert any (prob ["prob" ] == 1.0 for prob in tok ["top_probs" ])
516516
517+ if i == 0 :
518+ # The prompt is vague enough that we should get at least 10 possibilities
519+ # for the first token.
520+ assert len (tok ["top_probs" ]) == 10
521+
522+ if len (tok ["top_probs" ]) < 10 :
523+ # Getting less than the requested number of probabilities should only happen
524+ # if the ones we did get already sum to 1.0.
525+ assert sum (p ["prob" ] for p in tok ["top_probs" ]) == pytest .approx (1.0 )
526+
527+ def test_n_probs_post_backend_sampling ():
528+ """Verify that the same probabilities are returned with and without backend sampling."""
529+ global server
530+ server .backend_sampling = True
531+ server .start ()
532+
533+ def make_request (backend_sampling ):
534+ n_predict = 20
535+
536+ res = server .make_request ("POST" , "/completion" , data = {
537+ "prompt" : "The countries of Europe, in random order, are:" ,
538+ "n_probs" : 10 ,
539+ "n_predict" : n_predict ,
540+ "post_sampling_probs" : True ,
541+ "seed" : 4242 ,
542+ "backend_sampling" : backend_sampling ,
543+ })
544+ assert res .status_code == 200
545+
546+ total_probs = 0
547+ completions = res .body ["completion_probabilities" ]
548+ assert len (completions ) == n_predict
549+ for tok in completions :
550+ # Handling of 0.0 probabilities differs between samplers and backend sampling. Filter them to normalize the
551+ # data.
552+ tok ["top_probs" ] = [x for x in tok ["top_probs" ] if x ["prob" ] > 0.0 ]
553+ total_probs += len (tok ["top_probs" ])
554+ # Verify that we got at least two top probs on average, to ensure the effectiveness of the test.
555+ assert total_probs >= 2 * n_predict
556+ return completions
557+
558+ def verify_token (a , b ):
559+ assert a ["id" ] == b ["id" ]
560+ assert a ["token" ] == b ["token" ]
561+ assert a ["bytes" ] == b ["bytes" ]
562+ assert a ["prob" ] == pytest .approx (b ["prob" ], abs = 0.01 )
563+
564+ for (a , b ) in zip (make_request (True ), make_request (False )):
565+ verify_token (a , b )
566+ assert len (a ["top_probs" ]) == len (b ["top_probs" ])
567+
568+ for (aa , bb ) in zip (a ["top_probs" ], b ["top_probs" ]):
569+ verify_token (aa , bb )
517570
518571@pytest .mark .parametrize ("tokenize,openai_style" , [(False , False ), (False , True ), (True , False ), (True , True )])
519572def test_logit_bias (tokenize , openai_style ):
0 commit comments