1616"""
1717RL Evaluation Module.
1818"""
19+ import collections
1920import json
2021
2122from tqdm .auto import tqdm
@@ -87,7 +88,7 @@ def generate_responses(
8788 return multiple_call_responses
8889
8990
90- def score_responses (tmvp_config , question , responses , answers ):
91+ def score_responses (tmvp_config , question , responses , answers , eval_mode = "pass" ):
9192 """
9293 Score a set of responses for a single question.
9394
@@ -96,6 +97,7 @@ def score_responses(tmvp_config, question, responses, answers):
9697 question: The evaluation question
9798 responses: List of generated responses for this question
9899 answers: List of acceptable answers for this question
100+ eval_mode: The evaluation mode to use ("pass" for pass@K, "maj" for maj@K)
99101
100102 Returns:
101103 Tuple of (is_correct, is_partially_correct, has_correct_format)
@@ -112,6 +114,35 @@ def score_responses(tmvp_config, question, responses, answers):
112114 is_partially_correct = False
113115 has_correct_format = False
114116
117+ if eval_mode == "maj" :
118+ extracted_answers = []
119+ for response in responses :
120+ match_format = utils_rl .get_match_format_regex (tmvp_config )
121+ if match_format .search (response ) is not None :
122+ has_correct_format = True
123+
124+ extracted_response = utils_rl .extract_answer (response , tmvp_config )
125+ extracted_answers .append (extracted_response )
126+
127+ if not extracted_answers :
128+ return False , False , False
129+
130+ counter = collections .Counter (extracted_answers )
131+ majority_answer = counter .most_common (1 )[0 ][0 ]
132+
133+ try :
134+ is_correct , is_partially_correct = utils_rl .check_correctness (majority_answer , answers , tmvp_config )
135+ if tmvp_config .debug .rl :
136+ max_logging .log (f"Majority Answer: { majority_answer } (Count: { counter [majority_answer ]} )" )
137+ max_logging .log (f"Result is_correct: { is_correct } " )
138+ max_logging .log (f"Result is_partially_correct: { is_partially_correct } " )
139+ except Exception as e :
140+ if tmvp_config .debug .rl :
141+ max_logging .log (f"Evaluation Exception on majority answer: { e } " )
142+ max_logging .log ("SKIPPED" )
143+
144+ return is_correct , is_partially_correct , has_correct_format
145+
115146 for response in responses :
116147 match_format = utils_rl .get_match_format_regex (tmvp_config )
117148 if match_format .search (response ) is not None :
@@ -144,6 +175,7 @@ def evaluate(
144175 num_passes = 1 ,
145176 corr_lst = False ,
146177 make_lst = False ,
178+ eval_mode = None ,
147179):
148180 """
149181 Computes accuracy and percentage of outputs matching the format.
@@ -155,10 +187,14 @@ def evaluate(
155187 num_passes: Number of generation passes
156188 corr_lst: If True, only include correct responses in the list
157189 make_lst: If True, return a list of (question, answer, responses)
190+ eval_mode: Override for the evaluation mode ("pass" or "maj").
158191
159192 Returns:
160193 Tuple of statistics and optionally the response list
161194 """
195+ if eval_mode is None :
196+ eval_mode = getattr (tmvp_config , "eval_mode" , "pass" )
197+
162198 response_lst = []
163199 corr = 0
164200 partially_corr = 0
@@ -187,6 +223,7 @@ def evaluate(
187223 question = question ,
188224 responses = responses ,
189225 answers = answer ,
226+ eval_mode = eval_mode ,
190227 )
191228
192229 # Update counters
0 commit comments