Skip to content

Commit 5b5bcfd

Browse files
committed
Add logs for math verify pool
1 parent 29fe1e9 commit 5b5bcfd

1 file changed

Lines changed: 58 additions & 3 deletions

File tree

src/maxtext/trainers/post_train/rl/math_verify_pool.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
import itertools
2929
import multiprocessing
3030
import os
31+
import uuid
32+
from etils import epath
3133

3234
# Module-level persistent pool state.
3335
_POOL = None
@@ -71,7 +73,7 @@ def silent_worker_init():
7173
pass
7274

7375

74-
def verify_math_worker(golds, predictions):
76+
def verify_math_worker(idx, golds, predictions, parse_log_path, verify_log_path):
7577
"""Worker-side math_verify grader."""
7678
try:
7779
from math_verify import parse, verify
@@ -80,6 +82,16 @@ def verify_math_worker(golds, predictions):
8082
gold_targets = (ExprExtractionConfig(), LatexExtractionConfig())
8183
pred_targets = (ExprExtractionConfig(), LatexExtractionConfig())
8284

85+
log_file = parse_log_path / f"{idx}.txt"
86+
log_content = (
87+
"START ============================\n"
88+
f"Index: {idx}\n"
89+
f"Golds: {golds}\n"
90+
f"Predictions: {predictions}\n"
91+
"END ==============================\n"
92+
)
93+
log_file.write_text(log_content)
94+
8395
extracted_predictions = list(
8496
itertools.chain.from_iterable(parse(pred, pred_targets, parsing_timeout=None) for pred in predictions)
8597
)
@@ -89,7 +101,17 @@ def verify_math_worker(golds, predictions):
89101
if not extracted_predictions or not extracted_golds:
90102
return 0.0
91103

92-
print("Calling verify for extracted_golds: ", extracted_golds, " and predictions: ", extracted_predictions, " golds: ", golds, " predictions: ", predictions)
104+
log_file = verify_log_path / f"{idx}.txt"
105+
log_content = (
106+
"START ============================\n"
107+
f"Index: {idx}\n"
108+
f"Golds: {golds}\n"
109+
f"Extracted Golds: {extracted_golds}\n"
110+
f"Predictions: {predictions}\n"
111+
f"Extracted Predictions: {extracted_predictions}\n"
112+
"END ==============================\n"
113+
)
114+
log_file.write_text(log_content)
93115
return max(
94116
(1.0 if any(verify(gold, pred, timeout_seconds=None) for gold in extracted_golds) else 0.0)
95117
for pred in extracted_predictions
@@ -150,6 +172,31 @@ def math_verify_pool(items, timeout=15, num_procs=None, log_fn=None):
150172

151173
print(f"math_verify_pool called for {len(items)} items")
152174

175+
call_id = uuid.uuid4().hex
176+
print("Call id: ", call_id)
177+
178+
logs_dir = "gs://nicogrande-maxtext-logs/debug/logs/run/rl/math_verify"
179+
pool_log_path = epath.Path(logs_dir) / "pool"
180+
pool_log_path.mkdir(parents=True, exist_ok=True)
181+
182+
parse_log_path = epath.Path(logs_dir) / "parse" / f"{call_id}"
183+
parse_log_path.mkdir(parents=True, exist_ok=True)
184+
185+
verify_log_path = epath.Path(logs_dir) / "verify" / f"{call_id}"
186+
verify_log_path.mkdir(parents=True, exist_ok=True)
187+
188+
result_log_path = epath.Path(logs_dir) / "worker" / f"{call_id}"
189+
result_log_path.mkdir(parents=True, exist_ok=True)
190+
191+
# log pool
192+
log_file = pool_log_path / f"{call_id}.txt"
193+
log_content = (
194+
"START ============================\n"
195+
f"Items: {items}\n"
196+
"END ==============================\n"
197+
)
198+
log_file.write_text(log_content)
199+
153200
cpu_count = multiprocessing.cpu_count()
154201
if num_procs is None:
155202
num_procs = min(_DEFAULT_MAX_PROCS, len(items), cpu_count)
@@ -161,10 +208,18 @@ def math_verify_pool(items, timeout=15, num_procs=None, log_fn=None):
161208
pool = _get_pool(num_procs)
162209
saw_timeout = False
163210
try:
164-
jobs = [pool.apply_async(verify_math_worker, (golds, predictions)) for (_, golds, predictions) in items]
211+
jobs = [pool.apply_async(verify_math_worker, (idx, golds, predictions, parse_log_path, verify_log_path)) for (idx, golds, predictions) in items]
165212
for i, job in enumerate(jobs):
166213
try:
167214
results[i] = float(job.get(timeout=timeout))
215+
log_file = result_log_path / f"{i}.txt"
216+
log_content = (
217+
"START ============================\n"
218+
f"Idx: {i}\n"
219+
f"Score: {results[i]}\n"
220+
"END ==============================\n"
221+
)
222+
log_file.write_text(log_content)
168223
cnt += 1
169224
except multiprocessing.TimeoutError:
170225
if log_fn is not None:

0 commit comments

Comments
 (0)