@@ -77,6 +77,24 @@ def _build_speculate_metrics_baseline(
7777)
7878
7979
80+ def _assert_speculate_metrics_match (actual , baseline , label ):
81+ """Per-field comparison against a tolerance-based baseline.
82+
83+ Avoids whole-dict ``==`` so that an AssertionError isn't masked by a
84+ TypeError when json.dumps tries to serialize pytest.approx wrappers in
85+ the failure message.
86+ """
87+ missing = set (baseline ) - set (actual )
88+ extra = set (actual ) - set (baseline )
89+ assert not missing and not extra , (
90+ f"[{ label } ] speculate_metrics keys mismatch: missing={ missing } , extra={ extra } , "
91+ f"got_keys={ sorted (actual .keys ())} "
92+ )
93+ for key , expected in baseline .items ():
94+ got = actual [key ]
95+ assert got == expected , f"[{ label } ] field '{ key } ' mismatch:\n " f" got: { got } \n " f" expected: { expected } "
96+
97+
8098@pytest .fixture (scope = "session" , autouse = True )
8199def setup_and_run_server ():
82100 """
@@ -276,12 +294,10 @@ def test_mtp_ngram_speculate_metrics(api_url):
276294 f"sum(accepted_tokens_per_head) ({ sum (accepted_per_head )} )"
277295 )
278296
279- # Baseline comparison — exact match against the values captured in the reference environment.
297+ # Baseline comparison (tolerance-based) against values captured in the reference environment.
280298 if BASELINE_SPECULATE_METRICS is not None :
281- assert speculate_metrics == BASELINE_SPECULATE_METRICS , (
282- f"speculate_metrics mismatch\n "
283- f"got: { json .dumps (speculate_metrics , indent = 2 )} \n "
284- f"baseline: { json .dumps (BASELINE_SPECULATE_METRICS , indent = 2 )} "
299+ _assert_speculate_metrics_match (
300+ speculate_metrics , BASELINE_SPECULATE_METRICS , label = "test_mtp_ngram_speculate_metrics"
285301 )
286302
287303
@@ -336,10 +352,10 @@ def test_mtp_ngram_speculate_metrics_with_logprobs(api_url):
336352 assert len (accepted_per_head ) == 6
337353 assert speculate_metrics ["accepted_tokens" ] == sum (accepted_per_head )
338354
339- # Baseline comparison — exact match against the values captured in the reference environment.
355+ # Baseline comparison (tolerance-based) against values captured in the reference environment.
340356 if BASELINE_SPECULATE_METRICS_WITH_LOGPROBS is not None :
341- assert speculate_metrics == BASELINE_SPECULATE_METRICS_WITH_LOGPROBS , (
342- f" speculate_metrics mismatch \n "
343- f"got: { json . dumps ( speculate_metrics , indent = 2 ) } \n "
344- f"baseline: { json . dumps ( BASELINE_SPECULATE_METRICS_WITH_LOGPROBS , indent = 2 ) } "
357+ _assert_speculate_metrics_match (
358+ speculate_metrics ,
359+ BASELINE_SPECULATE_METRICS_WITH_LOGPROBS ,
360+ label = "test_mtp_ngram_speculate_metrics_with_logprobs" ,
345361 )
0 commit comments