Skip to content

Commit 4488f97

Browse files
committed
fix ci ut
1 parent 9abc1e1 commit 4488f97

6 files changed

Lines changed: 44 additions & 15 deletions

File tree

fastdeploy/entrypoints/llm.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -450,11 +450,7 @@ def _build_prompt_logprobs(
450450
tensors.
451451
"""
452452

453-
token_ids, logprobs, ranks = (
454-
prompt_logprobs_tensors.logprob_token_ids,
455-
prompt_logprobs_tensors.logprobs,
456-
prompt_logprobs_tensors.selected_token_ranks,
457-
)
453+
token_ids, logprobs, ranks = prompt_logprobs_tensors[:3]
458454

459455
# Detokenize non-incrementally.
460456
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]

fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -968,11 +968,7 @@ def _build_prompt_logprobs(
968968
tensors.
969969
"""
970970

971-
token_ids, logprobs, ranks = (
972-
prompt_logprobs_tensors.logprob_token_ids,
973-
prompt_logprobs_tensors.logprobs,
974-
prompt_logprobs_tensors.selected_token_ranks,
975-
)
971+
token_ids, logprobs, ranks = prompt_logprobs_tensors[:3]
976972

977973
# Normalize to plain Python lists (support both Tensor and list inputs)
978974
if hasattr(token_ids, "tolist"):

fastdeploy/entrypoints/openai/serving_completion.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -900,11 +900,7 @@ def _build_prompt_logprobs(
900900
tensors.
901901
"""
902902

903-
token_ids, logprobs, ranks = (
904-
prompt_logprobs_tensors.logprob_token_ids,
905-
prompt_logprobs_tensors.logprobs,
906-
prompt_logprobs_tensors.selected_token_ranks,
907-
)
903+
token_ids, logprobs, ranks = prompt_logprobs_tensors[:3]
908904

909905
# Normalize to plain Python lists (support both Tensor and list inputs)
910906
if hasattr(token_ids, "tolist"):

tests/ce/server/test_logprobs.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@
33
from core import TEMPLATE, URL, build_request_payload, send_request
44

55

6+
def _strip_logits_stats(obj):
7+
"""Recursively remove 'logits_stats' keys from logprobs response."""
8+
if isinstance(obj, dict):
9+
obj.pop("logits_stats", None)
10+
for v in obj.values():
11+
_strip_logits_stats(v)
12+
elif isinstance(obj, list):
13+
for item in obj:
14+
_strip_logits_stats(item)
15+
16+
617
def test_unstream_with_logprobs():
718
"""
819
测试非流式响应开启 logprobs 后,返回的 token 概率信息是否正确。
@@ -21,6 +32,7 @@ def test_unstream_with_logprobs():
2132
response = send_request(URL, payload)
2233
print(json.dumps(response.json(), indent=2, ensure_ascii=False))
2334
resp_json = response.json()
35+
_strip_logits_stats(resp_json)
2436

2537
# 校验返回内容与概率信息
2638
assert resp_json["choices"][0]["message"]["content"] == "牛顿的"
@@ -99,6 +111,7 @@ def test_stream_with_logprobs():
99111
print(json.dumps(result_chunk, indent=2, ensure_ascii=False))
100112
break
101113

114+
_strip_logits_stats(result_chunk)
102115
# 校验概率字段
103116
assert result_chunk["choices"][0]["delta"]["content"] == "牛顿"
104117
assert result_chunk["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿"
@@ -184,6 +197,7 @@ def test_stream_with_temp_scaled_logprobs():
184197
print(json.dumps(result_chunk, indent=2, ensure_ascii=False))
185198
break
186199

200+
_strip_logits_stats(result_chunk)
187201
# 校验概率字段
188202
assert result_chunk["choices"][0]["delta"]["content"] == "牛顿"
189203
assert result_chunk["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿"
@@ -229,6 +243,7 @@ def test_stream_with_top_p_normalized_logprobs():
229243
print(json.dumps(result_chunk, indent=2, ensure_ascii=False))
230244
break
231245

246+
_strip_logits_stats(result_chunk)
232247
# 校验概率字段
233248
assert result_chunk["choices"][0]["delta"]["content"] == "牛顿"
234249
assert result_chunk["choices"][0]["logprobs"]["content"][0]["token"] == "牛顿"

tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@
2323
import pytest
2424
import requests
2525

26+
27+
def _strip_logits_stats(obj):
28+
"""Recursively remove 'logits_stats' keys from logprobs response."""
29+
if isinstance(obj, dict):
30+
obj.pop("logits_stats", None)
31+
for v in obj.values():
32+
_strip_logits_stats(v)
33+
elif isinstance(obj, list):
34+
for item in obj:
35+
_strip_logits_stats(item)
36+
37+
2638
tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
2739
sys.path.insert(0, tests_dir)
2840

@@ -606,6 +618,7 @@ def test_non_stream_with_logprobs(api_url):
606618
resp_json = send_request(url=api_url, payload=payload).json()
607619

608620
logprobs = resp_json["choices"][0]["logprobs"]
621+
_strip_logits_stats(logprobs)
609622

610623
base_path = os.getenv("MODEL_PATH")
611624
if base_path:

tests/e2e/4cards_cases/test_ernie_21b_tp1_dp4_mtp.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@
2323
import pytest
2424
import requests
2525

26+
27+
def _strip_logits_stats(obj):
28+
"""Recursively remove 'logits_stats' keys from logprobs response."""
29+
if isinstance(obj, dict):
30+
obj.pop("logits_stats", None)
31+
for v in obj.values():
32+
_strip_logits_stats(v)
33+
elif isinstance(obj, list):
34+
for item in obj:
35+
_strip_logits_stats(item)
36+
37+
2638
tests_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
2739
sys.path.insert(0, tests_dir)
2840

@@ -512,6 +524,7 @@ def test_non_stream_with_logprobs(api_url):
512524
resp_json = send_request(url=api_url, payload=payload).json()
513525

514526
logprobs = resp_json["choices"][0]["logprobs"]
527+
_strip_logits_stats(logprobs)
515528

516529
base_path = os.getenv("MODEL_PATH")
517530

0 commit comments

Comments
 (0)