2727from datasets import Dataset
2828
2929
# Largest allowed gap between a token's logit and the maximum logit for that
# token to still be considered tied with the top prediction. Used as the
# default `epsilon` of the tie-margin helpers in this module.
LOGIT_TIE_EPSILON = 0.05
31+
32+
3033def _to_plain_list (value ):
3134 """convert a list of tensors to a list of plain values"""
3235 new_value = []
@@ -37,6 +40,83 @@ def _to_plain_list(value):
3740 return new_value
3841
3942
def to_plain_data(value):
    """Convert nested tensor-like values to plain Python data.

    Any object exposing a ``tolist()`` method is first replaced by the result
    of that call. If the (possibly converted) value is a list, each item is
    converted recursively so tensor-like items nested inside are handled too.
    Every other value (scalars, strings, ``None``, tuples, dicts, ...) is
    returned unchanged.

    Args:
        value: Arbitrary value, possibly tensor-like or a list containing
            tensor-like items.

    Returns:
        The converted value; lists are rebuilt as new lists.
    """
    if hasattr(value, "tolist"):
        value = value.tolist()

    if isinstance(value, list):
        return [to_plain_data(item) for item in value]

    return value
52+
53+
def first_generated_token_id(model_response: dict) -> int | None:
    """Return the first generated token id for the first sequence.

    ``model_response["output_tokens"]`` is converted with ``to_plain_data``
    and may be either a list of sequences (``[[t0, t1, ...], ...]``) or a
    flat list of token ids (``[t0, t1, ...]``).

    Args:
        model_response: Mapping that may contain an ``"output_tokens"`` entry.

    Returns:
        The first token id of the first sequence, or ``None`` when the entry
        is missing, empty, or not a sequence.
    """
    raw_tokens = model_response.get("output_tokens")
    # Compare against None explicitly: taking the truth value of a
    # multi-element array/tensor (as `raw or []` would) raises.
    if raw_tokens is None:
        return None

    output_tokens = to_plain_data(raw_tokens)
    if not isinstance(output_tokens, (list, tuple)) or not output_tokens:
        return None

    first_sequence = output_tokens[0]
    if isinstance(first_sequence, (list, tuple)):
        return first_sequence[0] if first_sequence else None

    # Flat layout: the first element is itself the token id. It is returned
    # as-is so that a legitimate token id of 0 is not mistaken for "missing".
    return first_sequence
65+
66+
def first_step_logits(model_response: dict) -> list[float] | None:
    """Return the logits for the first generated token, if available.

    ``model_response["logits"]`` is converted with ``to_plain_data``. When its
    first element is itself a sequence, that element is returned (one row per
    step); otherwise the converted value is treated as a single flat row and
    returned whole.

    Args:
        model_response: Mapping that may contain a ``"logits"`` entry.

    Returns:
        The first row of logits, or ``None`` when the entry is missing, empty,
        or not a sequence.
    """
    logits = to_plain_data(model_response.get("logits"))
    # A scalar payload cannot be indexed; treat it like missing logits.
    if not isinstance(logits, (list, tuple)) or not logits:
        return None

    first_step = logits[0]
    if isinstance(first_step, (list, tuple)):
        # NOTE(review): assumes at most two levels of nesting (step, vocab);
        # confirm no batch dimension is present in stored logits.
        return first_step

    return logits
78+
79+
def is_within_logit_tie_margin(logits: list[float], token_id: int, epsilon: float = LOGIT_TIE_EPSILON) -> bool:
    """Check whether a token is within epsilon of the maximum logit.

    Args:
        logits: Logit values indexed by token id.
        token_id: Index of the token to test.
        epsilon: Largest allowed gap between the maximum logit and the
            token's logit (inclusive).

    Returns:
        ``True`` when ``max(logits) - logits[token_id] <= epsilon``; ``False``
        when ``token_id`` is negative or not a valid index into ``logits``
        (which also covers empty ``logits``).
    """
    # Reject negatives explicitly so Python's negative indexing cannot
    # silently select a token from the end of the list.
    if token_id < 0 or token_id >= len(logits):
        return False

    return max(logits) - logits[token_id] <= epsilon
86+
87+
def is_tied_choice_prediction(current: dict, reference: dict, epsilon: float = LOGIT_TIE_EPSILON) -> bool:
    """Return True when two different MCQ predictions are both within the tie margin.

    The two samples are considered tied when all of the following hold:

    * both carry the same non-empty ``doc["choices"]`` with at least two entries;
    * both have a first generated token, and those tokens differ;
    * the reference sample has first-step logits, and both tokens lie within
      ``epsilon`` of the maximum of those logits;
    * if the current sample also has first-step logits, both tokens lie within
      ``epsilon`` of the maximum of those logits as well.

    Args:
        current: Sample detail mapping with optional ``"doc"`` and
            ``"model_response"`` entries.
        reference: Sample detail mapping of the same shape to compare against.
        epsilon: Tie margin forwarded to ``is_within_logit_tie_margin``.

    Returns:
        ``True`` when the predictions differ only by a logit tie, else ``False``.
    """
    # `.get(key, {})` still yields None when the key is present with a None
    # value, so fall back with `or {}`. Choices are normalised to plain data so
    # array-valued choices can be truth-tested and compared with `!=`.
    current_choices = to_plain_data((current.get("doc") or {}).get("choices"))
    reference_choices = to_plain_data((reference.get("doc") or {}).get("choices"))
    if not current_choices or current_choices != reference_choices or len(current_choices) < 2:
        return False

    current_response = current.get("model_response") or {}
    reference_response = reference.get("model_response") or {}

    current_token = first_generated_token_id(current_response)
    reference_token = first_generated_token_id(reference_response)
    if current_token is None or reference_token is None or current_token == reference_token:
        return False

    # Reference logits are mandatory; current logits are checked only if present.
    reference_logits = first_step_logits(reference_response)
    if reference_logits is None:
        return False

    logit_rows = [reference_logits]
    current_logits = first_step_logits(current_response)
    if current_logits is not None:
        logit_rows.append(current_logits)

    candidate_tokens = (current_token, reference_token)
    return all(
        is_within_logit_tie_margin(row, token_id, epsilon)
        for row in logit_rows
        for token_id in candidate_tokens
    )
118+
119+
40120def load_sample_details (details_dir : str ):
41121 """Load sample-level details from parquet files in the details directory."""
42122 details = {}
@@ -140,6 +220,10 @@ def _compare_single_sample(current, reference, sample_index):
140220 if "doc" in current and "doc" in reference :
141221 sample_diff .update (_compare_doc_info (current , reference ))
142222
223+ if sample_diff and set (sample_diff ).issubset ({"output_tokens_difference" , "metric_differences" }):
224+ if is_tied_choice_prediction (current , reference ):
225+ return {}
226+
143227 if sample_diff :
144228 sample_diff ["sample_index" ] = sample_index
145229
0 commit comments