2828import itertools
2929import multiprocessing
3030import os
31+ import uuid
32+ from etils import epath
3133
3234# Module-level persistent pool state.
3335_POOL = None
@@ -71,7 +73,7 @@ def silent_worker_init():
7173 pass
7274
7375
74- def verify_math_worker (golds , predictions ):
76+ def verify_math_worker (idx , golds , predictions , parse_log_path , verify_log_path ):
7577 """Worker-side math_verify grader."""
7678 try :
7779 from math_verify import parse , verify
@@ -80,6 +82,16 @@ def verify_math_worker(golds, predictions):
8082 gold_targets = (ExprExtractionConfig (), LatexExtractionConfig ())
8183 pred_targets = (ExprExtractionConfig (), LatexExtractionConfig ())
8284
85+ log_file = parse_log_path / f"{ idx } .txt"
86+ log_content = (
87+ "START ============================\n "
88+ f"Index: { idx } \n "
89+ f"Golds: { golds } \n "
90+ f"Predictions: { predictions } \n "
91+ "END ==============================\n "
92+ )
93+ log_file .write_text (log_content )
94+
8395 extracted_predictions = list (
8496 itertools .chain .from_iterable (parse (pred , pred_targets , parsing_timeout = None ) for pred in predictions )
8597 )
@@ -89,7 +101,17 @@ def verify_math_worker(golds, predictions):
89101 if not extracted_predictions or not extracted_golds :
90102 return 0.0
91103
92- print ("Calling verify for extracted_golds: " , extracted_golds , " and predictions: " , extracted_predictions , " golds: " , golds , " predictions: " , predictions )
104+ log_file = verify_log_path / f"{ idx } .txt"
105+ log_content = (
106+ "START ============================\n "
107+ f"Index: { idx } \n "
108+ f"Golds: { golds } \n "
109+ f"Extracted Golds: { extracted_golds } \n "
110+ f"Predictions: { predictions } \n "
111+ f"Extracted Predictions: { extracted_predictions } \n "
112+ "END ==============================\n "
113+ )
114+ log_file .write_text (log_content )
93115 return max (
94116 (1.0 if any (verify (gold , pred , timeout_seconds = None ) for gold in extracted_golds ) else 0.0 )
95117 for pred in extracted_predictions
@@ -150,6 +172,31 @@ def math_verify_pool(items, timeout=15, num_procs=None, log_fn=None):
150172
151173 print (f"math_verify_pool called for { len (items )} items" )
152174
175+ call_id = uuid .uuid4 ().hex
176+ print ("Call id: " , call_id )
177+
178+ logs_dir = "gs://nicogrande-maxtext-logs/debug/logs/run/rl/math_verify"
179+ pool_log_path = epath .Path (tmvp_config .base_output_directory ) / "pool"
180+ pool_log_path .mkdir (parents = True , exist_ok = True )
181+
182+ parse_log_path = epath .Path (tmvp_config .base_output_directory ) / "parse" / f"{ call_id } "
183+ parse_log_path .mkdir (parents = True , exist_ok = True )
184+
185+ verify_log_path = epath .Path (tmvp_config .base_output_directory ) / "verify" / f"{ call_id } "
186+ verify_log_path .mkdir (parents = True , exist_ok = True )
187+
188+ result_log_path = epath .Path (tmvp_config .base_output_directory ) / "worker" / f"{ call_id } "
189+ result_log_path .mkdir (parents = True , exist_ok = True )
190+
191+ # Write this call's full item list to <pool_log_path>/<call_id>.txt.
192+ log_file = pool_log_path / f"{ call_id } .txt"
193+ log_content = (
194+ "START ============================\n "
195+ f"Items: { items } \n "
196+ "END ==============================\n "
197+ )
198+ log_file .write_text (log_content )
199+
153200 cpu_count = multiprocessing .cpu_count ()
154201 if num_procs is None :
155202 num_procs = min (_DEFAULT_MAX_PROCS , len (items ), cpu_count )
@@ -161,10 +208,18 @@ def math_verify_pool(items, timeout=15, num_procs=None, log_fn=None):
161208 pool = _get_pool (num_procs )
162209 saw_timeout = False
163210 try :
164- jobs = [pool .apply_async (verify_math_worker , (golds , predictions )) for (_ , golds , predictions ) in items ]
211+ jobs = [pool .apply_async (verify_math_worker , (idx , golds , predictions , parse_log_path , verify_log_path )) for (idx , golds , predictions ) in items ]
165212 for i , job in enumerate (jobs ):
166213 try :
167214 results [i ] = float (job .get (timeout = timeout ))
215+ log_file = result_log_path / f"{ i } .txt"
216+ log_content = (
217+ "START ============================\n "
218+ f"Idx: { i } \n "
219+ f"Score: { results [i ]} \n "
220+ "END ==============================\n "
221+ )
222+ log_file .write_text (log_content )
168223 cnt += 1
169224 except multiprocessing .TimeoutError :
170225 if log_fn is not None :
0 commit comments