-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBest_Checkpoint_Finder.py
More file actions
62 lines (52 loc) · 2.37 KB
/
Copy pathBest_Checkpoint_Finder.py
File metadata and controls
62 lines (52 loc) · 2.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# -*- coding: utf-8 -*-
import os, json, glob, re
def _load_trainer_state(state_path: str) -> dict:
    """Parse one ``trainer_state.json`` file, making sure the handle is closed."""
    with open(state_path, "r", encoding="utf-8") as fh:
        return json.load(fh)


def _checkpoint_step(path: str) -> int:
    """Return N for a path ending in ``checkpoint-N`` (either separator), else -1."""
    match = re.search(r"checkpoint-(\d+)$", path.replace("\\", "/"))
    return int(match.group(1)) if match else -1


def _resolve_checkpoint(ckpt: str, run_dir: str) -> str:
    """Return ``ckpt`` if it exists, else the folder of the same name under ``run_dir``.

    The re-rooted path is returned without checking that it exists.
    """
    if os.path.exists(ckpt):
        return ckpt
    # os.path.basename on POSIX does not split on "\\", so a path recorded on
    # Windows would otherwise be joined whole; normalise separators first and
    # drop a trailing one so basename yields the checkpoint folder name.
    folder = os.path.basename(ckpt.replace("\\", "/").rstrip("/"))
    return os.path.join(run_dir, folder)


def get_best_checkpoint(run_dir: str) -> str:
    """Return the path of the best checkpoint found in a training run directory.

    The ``trainer_state.json`` / ``checkpoint-<step>`` layout read here is the
    one written by the Hugging Face ``Trainer``. Lookup order:

    1. ``best_model_checkpoint`` from ``<run_dir>/trainer_state.json``.
    2. ``best_model_checkpoint`` from the ``trainer_state.json`` inside the
       ``checkpoint-<step>`` folders, newest step first.
    3. Fallback: the ``checkpoint-<step>`` folder whose own most recent
       ``eval_loss`` (last ``eval_loss`` entry in its ``log_history``) is the
       lowest; ties go to the earliest step.

    A recorded ``best_model_checkpoint`` that does not exist as given (e.g. it
    was stored relative to another working directory, or with Windows
    separators) is re-rooted under ``run_dir`` by its folder name.

    Args:
        run_dir: Directory holding ``trainer_state.json`` and/or
            ``checkpoint-*`` subfolders.

    Returns:
        Path to the best checkpoint folder.

    Raises:
        FileNotFoundError: If none of the three strategies yields a checkpoint.
    """
    # 1) Prefer the root trainer_state.json if it exists.
    root_state = os.path.join(run_dir, "trainer_state.json")
    if os.path.exists(root_state):
        ckpt = _load_trainer_state(root_state).get("best_model_checkpoint")
        if ckpt:
            return _resolve_checkpoint(ckpt, run_dir)

    # Sort numerically by step so "checkpoint-1000" follows "checkpoint-900"
    # (string order would not); names without a step sort first (key -1).
    ckpt_dirs = sorted(
        glob.glob(os.path.join(run_dir, "checkpoint-*")), key=_checkpoint_step
    )

    # 2) Otherwise, take best_model_checkpoint from the newest checkpoint state
    #    that records one.
    for ckpt_dir in reversed(ckpt_dirs):
        state_path = os.path.join(ckpt_dir, "trainer_state.json")
        if not os.path.exists(state_path):
            continue
        ckpt = _load_trainer_state(state_path).get("best_model_checkpoint")
        if ckpt:
            return _resolve_checkpoint(ckpt, run_dir)

    # 3) Last resort: choose the checkpoint whose own latest eval_loss is lowest.
    best = None  # (score, ckpt_dir)
    for ckpt_dir in ckpt_dirs:
        state_path = os.path.join(ckpt_dir, "trainer_state.json")
        if not os.path.exists(state_path):
            continue
        log_history = _load_trainer_state(state_path).get("log_history", [])
        eval_losses = [r["eval_loss"] for r in log_history if "eval_loss" in r]
        if not eval_losses:
            continue
        # A checkpoint's log_history is cumulative (it also holds everything
        # logged before it), so min() over it would credit this folder with
        # losses measured on earlier weights. The last eval_loss is the
        # evaluation closest to this checkpoint's own step.
        score = float(eval_losses[-1])
        # Strict "<" keeps the earliest checkpoint on ties.
        if best is None or score < best[0]:
            best = (score, ckpt_dir)

    if best is None:
        raise FileNotFoundError(
            f"Could not determine the best checkpoint in {run_dir!r}: no "
            "trainer_state.json with best_model_checkpoint or eval_loss in "
            "run_dir or any checkpoint-* folder."
        )
    return best[1]
if __name__ == "__main__":
    # Script entry point: only runs when the file is executed directly, so
    # importing get_best_checkpoint elsewhere no longer prints or raises.
    import sys

    # Run directory to inspect; pass one on the command line to override.
    # Other runs inspected previously (relative to the working directory):
    #   Step14/model_runs_chunk_pool/BERT
    #   Step15/model_runs_chunk_pool/BERT
    #   Step17/model_runs_chunk_pool/DeBERTaV3
    #   Step17/model_runs_chunk_pool/DistilBERT
    #   Step17/model_runs_chunk_pool/FinBERT
    # os.path.join keeps the default portable; the former backslash literal
    # only resolved on Windows.
    run_dir = (
        sys.argv[1]
        if len(sys.argv) > 1
        else os.path.join("Step17", "model_runs_chunk_pool", "BERT")
    )
    best_ckpt = get_best_checkpoint(run_dir)
    print("Best checkpoint =", best_ckpt)