Skip to content

Commit 810a303

Browse files
committed
Remove GCS logging
1 parent a4581d3 commit 810a303

1 file changed

Lines changed: 2 additions & 57 deletions

File tree

src/maxtext/trainers/post_train/rl/math_verify_pool.py

Lines changed: 2 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def are_equal_under_sympy(gold, prediction):
9494
return False
9595

9696

97-
def verify_math_worker(idx, golds, predictions, parse_log_path=None, verify_log_path=None):
97+
def verify_math_worker(idx, golds, predictions):
9898
"""Worker-side math_verify grader."""
9999
try:
100100
from math_verify import parse, verify
@@ -104,17 +104,6 @@ def verify_math_worker(idx, golds, predictions, parse_log_path=None, verify_log_
104104
gold_targets = (ExprExtractionConfig(), LatexExtractionConfig())
105105
pred_targets = (ExprExtractionConfig(), LatexExtractionConfig())
106106

107-
if parse_log_path is not None:
108-
log_file = parse_log_path / f"{idx}.txt"
109-
log_content = (
110-
"START ============================\n"
111-
f"Index: {idx}\n"
112-
f"Golds: {golds}\n"
113-
f"Predictions: {predictions}\n"
114-
"END ==============================\n"
115-
)
116-
log_file.write_text(log_content)
117-
118107
extracted_predictions = list(
119108
itertools.chain.from_iterable(parse(pred, pred_targets, parsing_timeout=None) for pred in predictions)
120109
)
@@ -124,19 +113,6 @@ def verify_math_worker(idx, golds, predictions, parse_log_path=None, verify_log_
124113
if not extracted_predictions or not extracted_golds:
125114
return idx, 0.0
126115

127-
if verify_log_path is not None:
128-
log_file = verify_log_path / f"{idx}.txt"
129-
log_content = (
130-
"START ============================\n"
131-
f"Index: {idx}\n"
132-
f"Golds: {golds}\n"
133-
f"Extracted Golds: {extracted_golds}\n"
134-
f"Predictions: {predictions}\n"
135-
f"Extracted Predictions: {extracted_predictions}\n"
136-
"END ==============================\n"
137-
)
138-
log_file.write_text(log_content)
139-
140116
for gold in extracted_golds:
141117
for pred in extracted_predictions:
142118
if isinstance(gold, (Basic, MatrixBase)) and isinstance(pred, (Basic, MatrixBase)):
@@ -228,29 +204,6 @@ def math_verify_pool(tmvp_config, items, scores, timeout=15, num_procs=None, log
228204
return scores
229205

230206
call_id = uuid.uuid4().hex
231-
logs_dir = tmvp_config.base_output_directory + tmvp_config.run_name + "/math_verify"
232-
pool_log_path = epath.Path(logs_dir) / "pool"
233-
pool_log_path.mkdir(parents=True, exist_ok=True)
234-
235-
parse_log_path = epath.Path(logs_dir) / "parse" / f"{call_id}"
236-
parse_log_path.mkdir(parents=True, exist_ok=True)
237-
238-
verify_log_path = epath.Path(logs_dir) / "verify" / f"{call_id}"
239-
verify_log_path.mkdir(parents=True, exist_ok=True)
240-
241-
result_log_path = epath.Path(logs_dir) / "worker" / f"{call_id}"
242-
result_log_path.mkdir(parents=True, exist_ok=True)
243-
244-
# log pool
245-
log_file = pool_log_path / f"{call_id}.txt"
246-
log_content = (
247-
"START ============================\n"
248-
f"Total Items: {len(items)}\n"
249-
f"Items: {items}\n"
250-
"END ==============================\n"
251-
)
252-
log_file.write_text(log_content)
253-
254207
cpu_count = multiprocessing.cpu_count()
255208
if num_procs is None:
256209
num_procs = min(_DEFAULT_MAX_PROCS, len(items), cpu_count)
@@ -260,7 +213,7 @@ def math_verify_pool(tmvp_config, items, scores, timeout=15, num_procs=None, log
260213
cnt = 0
261214
timout_job_cnt = 0
262215
pool = _get_pool(num_procs)
263-
active_jobs = [(idx, pool.apply_async(verify_math_worker, (idx, golds, predictions, parse_log_path, verify_log_path))) for (idx, golds, predictions) in items]
216+
active_jobs = [(idx, pool.apply_async(verify_math_worker, (idx, golds, predictions))) for (idx, golds, predictions) in items]
264217
start_time = time.time()
265218
global_timeout = 300
266219
while active_jobs and (time.time() - start_time < global_timeout):
@@ -271,14 +224,6 @@ def math_verify_pool(tmvp_config, items, scores, timeout=15, num_procs=None, log
271224
# .get(0) returns immediately since ready() was true
272225
_, score = job.get(0)
273226
scores[idx] = max(scores[idx], tmvp_config.reward_exact_answer)
274-
log_file = result_log_path / f"{idx}.txt"
275-
log_content = (
276-
"START ============================\n"
277-
f"Idx: {idx}\n"
278-
f"Score: {scores[idx]}\n"
279-
"END ==============================\n"
280-
)
281-
log_file.write_text(log_content)
282227
cnt += 1
283228
except Exception as e:
284229
if log_fn is not None:

0 commit comments

Comments
 (0)