Skip to content

Commit f70bfb2

Browse files
committed
fix test
1 parent c284a28 commit f70bfb2

1 file changed

Lines changed: 26 additions & 10 deletions

File tree

tests/e2e/test_ernie_21b_mtp_ngram.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,24 @@ def _build_speculate_metrics_baseline(
7777
)
7878

7979

80+
def _assert_speculate_metrics_match(actual, baseline, label):
81+
"""Per-field comparison against a tolerance-based baseline.
82+
83+
Avoids whole-dict ``==`` so that an AssertionError isn't masked by a
84+
TypeError when json.dumps tries to serialize pytest.approx wrappers in
85+
the failure message.
86+
"""
87+
missing = set(baseline) - set(actual)
88+
extra = set(actual) - set(baseline)
89+
assert not missing and not extra, (
90+
f"[{label}] speculate_metrics keys mismatch: missing={missing}, extra={extra}, "
91+
f"got_keys={sorted(actual.keys())}"
92+
)
93+
for key, expected in baseline.items():
94+
got = actual[key]
95+
assert got == expected, f"[{label}] field '{key}' mismatch:\n" f" got: {got}\n" f" expected: {expected}"
96+
97+
8098
@pytest.fixture(scope="session", autouse=True)
8199
def setup_and_run_server():
82100
"""
@@ -276,12 +294,10 @@ def test_mtp_ngram_speculate_metrics(api_url):
276294
f"sum(accepted_tokens_per_head) ({sum(accepted_per_head)})"
277295
)
278296

279-
# Baseline comparison — exact match against the values captured in the reference environment.
297+
# Baseline comparison (tolerance-based) against values captured in the reference environment.
280298
if BASELINE_SPECULATE_METRICS is not None:
281-
assert speculate_metrics == BASELINE_SPECULATE_METRICS, (
282-
f"speculate_metrics mismatch\n"
283-
f"got: {json.dumps(speculate_metrics, indent=2)}\n"
284-
f"baseline: {json.dumps(BASELINE_SPECULATE_METRICS, indent=2)}"
299+
_assert_speculate_metrics_match(
300+
speculate_metrics, BASELINE_SPECULATE_METRICS, label="test_mtp_ngram_speculate_metrics"
285301
)
286302

287303

@@ -336,10 +352,10 @@ def test_mtp_ngram_speculate_metrics_with_logprobs(api_url):
336352
assert len(accepted_per_head) == 6
337353
assert speculate_metrics["accepted_tokens"] == sum(accepted_per_head)
338354

339-
# Baseline comparison — exact match against the values captured in the reference environment.
355+
# Baseline comparison (tolerance-based) against values captured in the reference environment.
340356
if BASELINE_SPECULATE_METRICS_WITH_LOGPROBS is not None:
341-
assert speculate_metrics == BASELINE_SPECULATE_METRICS_WITH_LOGPROBS, (
342-
f"speculate_metrics mismatch\n"
343-
f"got: {json.dumps(speculate_metrics, indent=2)}\n"
344-
f"baseline: {json.dumps(BASELINE_SPECULATE_METRICS_WITH_LOGPROBS, indent=2)}"
357+
_assert_speculate_metrics_match(
358+
speculate_metrics,
359+
BASELINE_SPECULATE_METRICS_WITH_LOGPROBS,
360+
label="test_mtp_ngram_speculate_metrics_with_logprobs",
345361
)

0 commit comments

Comments
 (0)