Skip to content

Commit 8519d05

Browse files
committed
Improve training signal and checkpoint comparison tooling
1 parent 523b04e commit 8519d05

15 files changed

Lines changed: 894 additions & 131 deletions

scripts/compare_checkpoints.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import json
5+
import sys
6+
from pathlib import Path
7+
from typing import TYPE_CHECKING
8+
9+
import numpy as np
10+
import torch
11+
12+
if TYPE_CHECKING:
13+
from engine.mcts import MCTS
14+
from game.board import AtaxxBoard
15+
16+
17+
def _ensure_src_on_path() -> None:
    """Prepend the repository's ``src`` directory to ``sys.path`` if it is absent."""
    src_dir = str(Path(__file__).resolve().parents[1] / "src")
    if src_dir not in sys.path:
        sys.path.insert(0, src_dir)
22+
23+
24+
def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options for the duel script."""
    cli = argparse.ArgumentParser(
        description="Run a short automated duel between two Ataxx checkpoints.",
    )
    add = cli.add_argument
    # Required checkpoint locations.
    add("--checkpoint-a", required=True, help="Path to checkpoint A (.pt/.ckpt).")
    add("--checkpoint-b", required=True, help="Path to checkpoint B (.pt/.ckpt).")
    # Match and search configuration.
    add("--games", type=int, default=8, help="Number of games to play.")
    add("--device", default="auto", choices=["auto", "cpu", "cuda"])
    add("--mcts-sims", "--sims", type=int, default=96)
    add("--c-puct", type=float, default=1.5)
    add("--seed", type=int, default=42)
    add("--json", action="store_true", help="Print machine-readable JSON summary.")
    return cli.parse_args()
37+
38+
39+
def _resolve_device(device: str) -> str:
    """Map a requested device string to one that is actually usable on this host."""
    cuda_ok = torch.cuda.is_available()
    if device == "auto":
        return "cuda" if cuda_ok else "cpu"
    if device == "cuda" and not cuda_ok:
        print("CUDA requested but not available; falling back to CPU.")
        return "cpu"
    return device
46+
47+
48+
def _pick_model_action_idx(board: AtaxxBoard, mcts: MCTS) -> int:
    """Greedily pick the action index with the highest MCTS visit probability."""
    visit_probs = mcts.run(board=board, add_dirichlet_noise=False, temperature=0.0)
    best_idx = np.argmax(visit_probs)
    return int(best_idx)
51+
52+
53+
def main() -> None:
    """Entry point: duel two checkpoints for a fixed schedule of games and report results."""
    args = _parse_args()
    _ensure_src_on_path()

    # Imported lazily so `src` is on sys.path before package resolution happens.
    from engine.mcts import MCTS
    from game.actions import ACTION_SPACE
    from game.board import AtaxxBoard
    from inference.checkpoint_duel_runtime import (
        build_match_schedule,
        load_system_from_checkpoint,
        summarize_match_results,
    )

    checkpoint_a = Path(args.checkpoint_a)
    checkpoint_b = Path(args.checkpoint_b)
    if not checkpoint_a.exists():
        raise FileNotFoundError(f"Checkpoint A not found: {checkpoint_a}")
    if not checkpoint_b.exists():
        raise FileNotFoundError(f"Checkpoint B not found: {checkpoint_b}")

    device = _resolve_device(args.device)
    system_a = load_system_from_checkpoint(checkpoint_a, device=device)
    system_b = load_system_from_checkpoint(checkpoint_b, device=device)
    # One search tree per checkpoint; both share the same search budget and c_puct.
    mcts_a = MCTS(model=system_a.model, c_puct=args.c_puct, n_simulations=args.mcts_sims, device=device)
    mcts_b = MCTS(model=system_b.model, c_puct=args.c_puct, n_simulations=args.mcts_sims, device=device)

    # Schedule alternates which checkpoint moves first; at least one game is played.
    schedule = build_match_schedule(games=max(1, int(args.games)))
    rng = np.random.default_rng(seed=int(args.seed))
    results: list[dict[str, int]] = []

    for idx, (checkpoint_a_player, checkpoint_b_player) in enumerate(schedule, start=1):
        board = AtaxxBoard()
        # Re-seed torch/numpy per game so each game is reproducible from --seed.
        turn_seed = int(rng.integers(0, 2**31 - 1))
        torch.manual_seed(turn_seed)
        np.random.seed(turn_seed)
        turns = 0
        while not board.is_game_over():
            turns += 1
            if board.current_player == checkpoint_a_player:
                action_idx = _pick_model_action_idx(board, mcts_a)
            elif board.current_player == checkpoint_b_player:
                action_idx = _pick_model_action_idx(board, mcts_b)
            else:
                # Should be unreachable: every side is assigned to exactly one checkpoint.
                raise RuntimeError("Unexpected player assignment while comparing checkpoints.")
            board.step(ACTION_SPACE.decode(action_idx))

        winner = board.get_result()
        results.append(
            {
                "winner": int(winner),
                "turns": turns,
                "checkpoint_a_player": checkpoint_a_player,
            },
        )
        # Per-game progress line; "p1"/"p2" records which side checkpoint A played.
        color_a = "p1" if checkpoint_a_player == 1 else "p2"
        print(
            f"[{idx}/{len(schedule)}] "
            f"checkpoint_a={color_a} winner={winner} turns={turns}",
        )

    summary = summarize_match_results(results=results)
    output: dict[str, float | int | str] = {
        **summary,
        "checkpoint_a": str(checkpoint_a),
        "checkpoint_b": str(checkpoint_b),
        "device": device,
        "mcts_sims": int(args.mcts_sims),
    }

    # Machine-readable mode: emit one JSON document and nothing else after it.
    if args.json:
        print(json.dumps(output, indent=2))
        return

    # Human-readable summary.
    print("")
    print("Summary")
    print(f" checkpoint_a: {checkpoint_a}")
    print(f" checkpoint_b: {checkpoint_b}")
    print(f" games: {summary['games']}")
    print(f" checkpoint_a_wins: {summary['checkpoint_a_wins']}")
    print(f" checkpoint_b_wins: {summary['checkpoint_b_wins']}")
    print(f" draws: {summary['draws']}")
    print(f" checkpoint_a_score: {float(summary['checkpoint_a_score']):.3f}")
    print(f" avg_turns: {float(summary['avg_turns']):.1f}")


if __name__ == "__main__":
    main()

src/game/board.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,12 @@ def get_result(self) -> int:
242242
return WIN_P2
243243
return DRAW
244244

245+
def is_forced_draw(self) -> bool:
    """Expose loop/cap draws so training can punish non-terminating play."""
    if not self.is_game_over():
        return False
    if self.half_moves >= 100:
        return True
    # A position seen three or more times counts as a repetition draw.
    most_repeats = max(self._position_counts.values(), default=0)
    return most_repeats >= 3
250+
245251
def get_canonical_form(self) -> np.ndarray:
246252
"""
247253
Current-player perspective:
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
from typing import TYPE_CHECKING, Any
5+
6+
import torch
7+
8+
if TYPE_CHECKING:
9+
from model.system import AtaxxZero
10+
11+
MatchSchedule = list[tuple[int, int]]
12+
MatchResult = dict[str, int]
13+
14+
15+
def build_match_schedule(*, games: int) -> MatchSchedule:
    """Build (player-of-A, player-of-B) pairs, alternating which side A starts on."""
    pairs: MatchSchedule = []
    for game_index in range(games):
        # Even-indexed games: A plays side 1; odd-indexed: A plays side -1.
        a_side = 1 if game_index % 2 == 0 else -1
        pairs.append((a_side, -a_side))
    return pairs
24+
25+
26+
def summarize_match_results(*, results: list[MatchResult]) -> dict[str, float | int]:
    """Aggregate per-game duel records into win/draw counts and a score for checkpoint A."""
    if not results:
        # No games played: report an all-zero summary.
        return {
            "games": 0,
            "checkpoint_a_wins": 0,
            "checkpoint_b_wins": 0,
            "draws": 0,
            "checkpoint_a_score": 0.0,
            "avg_turns": 0.0,
        }

    games = len(results)
    a_wins = 0
    b_wins = 0
    draws = 0
    turn_total = 0
    for record in results:
        outcome = int(record["winner"])
        a_side = int(record["checkpoint_a_player"])
        turn_total += int(record["turns"])
        if outcome == 0:
            draws += 1
        elif outcome == a_side:
            a_wins += 1
        else:
            b_wins += 1

    # Standard match scoring: a draw is worth half a point to each side.
    score_a = (a_wins + 0.5 * draws) / games
    return {
        "games": games,
        "checkpoint_a_wins": a_wins,
        "checkpoint_b_wins": b_wins,
        "draws": draws,
        "checkpoint_a_score": score_a,
        "avg_turns": turn_total / games,
    }
62+
63+
64+
def load_system_from_checkpoint(checkpoint_path: Path, *, device: str) -> AtaxxZero:
    """Load an AtaxxZero system from a Lightning ``.ckpt`` or raw ``.pt`` checkpoint.

    Args:
        checkpoint_path: Checkpoint file; the loader is chosen by file suffix.
        device: Torch device string the model is moved to.

    Returns:
        The restored system, in eval mode and moved to ``device``.

    Raises:
        ValueError: If a non-``.ckpt`` payload is not a dict or lacks ``state_dict``.
    """
    from model.system import AtaxxZero

    if checkpoint_path.suffix == ".ckpt":
        system = AtaxxZero.load_from_checkpoint(str(checkpoint_path), map_location=device)
    else:
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load checkpoints from trusted sources.
        payload = torch.load(str(checkpoint_path), map_location=device, weights_only=False)
        if not isinstance(payload, dict):
            raise ValueError("Invalid checkpoint format: expected dictionary.")
        state_dict_obj = payload.get("state_dict")
        if not isinstance(state_dict_obj, dict):
            raise ValueError("Checkpoint dictionary must contain key 'state_dict'.")

        hparams = payload.get("hparams")
        kwargs: dict[str, Any] = {}
        if isinstance(hparams, dict):
            # Forward only architecture hyper-parameters the constructor accepts.
            allowed = {"d_model", "nhead", "num_layers", "dim_feedforward", "dropout"}
            kwargs = {key: hparams[key] for key in allowed if key in hparams}

        system = AtaxxZero(**kwargs)
        system.load_state_dict(state_dict_obj)

    # Bug fix: the .ckpt branch previously returned without eval()/to(device),
    # leaving dropout active during evaluation. Normalize both paths here.
    system.eval()
    system.to(device)
    return system
88+
89+
90+
# Explicit public surface of this module.
__all__ = [
    "build_match_schedule",
    "load_system_from_checkpoint",
    "summarize_match_results",
]

src/training/bootstrap.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,15 @@
88
from data.replay_buffer import TrainingExample
99
from game.actions import ACTION_SPACE
1010
from game.board import AtaxxBoard
11+
from training.config_runtime import cfg_bool
12+
from training.reward_runtime import (
13+
HistoryEntry,
14+
compute_state_potential,
15+
compute_transition_shaping_reward,
16+
history_to_examples,
17+
)
1118

1219
HeuristicLevel = Literal["easy", "normal", "hard", "apex", "gambit", "sentinel"]
13-
HistoryEntry = tuple[np.ndarray, np.ndarray, int]
1420

1521

1622
def _one_hot_policy(action_idx: int) -> np.ndarray:
@@ -19,23 +25,6 @@ def _one_hot_policy(action_idx: int) -> np.ndarray:
1925
return policy
2026

2127

22-
def history_to_examples(
23-
game_history: list[HistoryEntry],
24-
winner: int,
25-
) -> list[TrainingExample]:
26-
"""Convert per-turn history into value targets from the acting player's perspective."""
27-
examples: list[TrainingExample] = []
28-
for observation, policy, player_at_turn in game_history:
29-
if winner == 0:
30-
z = 0.0
31-
elif winner == player_at_turn:
32-
z = 1.0
33-
else:
34-
z = -1.0
35-
examples.append((observation, policy, z))
36-
return examples
37-
38-
3928
def generate_imitation_data(
4029
*,
4130
n_games: int,
@@ -60,16 +49,34 @@ def generate_imitation_data(
6049
for _ in range(n_games):
6150
board = AtaxxBoard()
6251
game_history: list[HistoryEntry] = []
52+
shaping_enabled = cfg_bool("reward_shaping_enabled")
6353

6454
while not board.is_game_over():
6555
player_at_turn = int(board.current_player)
56+
observation = board.get_observation()
6657
move = heuristic_move(board=board, rng=rng, level=heuristic_level)
6758
action_idx = ACTION_SPACE.encode(move)
6859
policy = _one_hot_policy(action_idx)
69-
game_history.append((board.get_observation(), policy, player_at_turn))
60+
shaping_reward = 0.0
61+
before_potential = 0.0
62+
if shaping_enabled:
63+
before_potential = compute_state_potential(board, player_at_turn)
7064
board.step(move)
65+
if shaping_enabled:
66+
after_potential = compute_state_potential(board, player_at_turn)
67+
shaping_reward = compute_transition_shaping_reward(
68+
before_potential=before_potential,
69+
after_potential=after_potential,
70+
)
71+
game_history.append((observation, policy, player_at_turn, shaping_reward))
7172

7273
winner = board.get_result()
73-
all_examples.extend(history_to_examples(game_history=game_history, winner=winner))
74+
all_examples.extend(
75+
history_to_examples(
76+
game_history=game_history,
77+
winner=winner,
78+
forced_draw=board.is_forced_draw(),
79+
),
80+
)
7481

7582
return all_examples

0 commit comments

Comments
 (0)