Skip to content

Commit c910095

Browse files
committed
add update workflows
1 parent d23661c commit c910095

12 files changed

Lines changed: 143 additions & 107 deletions

File tree

.github/workflows/build-app-image.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,7 @@ jobs:
7575
with:
7676
context: .
7777
file: Dockerfile
78+
target: runtime
7879
push: ${{ github.event_name != 'pull_request' }}
7980
tags: ${{ steps.meta.outputs.tags }}
8081
labels: ${{ steps.meta.outputs.labels }}

.github/workflows/ci-api.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ jobs:
5656
- name: Setup Python
5757
uses: actions/setup-python@v5
5858
with:
59-
python-version: "3.10"
59+
python-version: "3.11"
6060

6161
- name: Setup uv
6262
uses: astral-sh/setup-uv@v4

.github/workflows/ci-train.yml

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ jobs:
5454
- name: Setup Python
5555
uses: actions/setup-python@v5
5656
with:
57-
python-version: "3.10"
57+
python-version: "3.11"
5858

5959
- name: Setup uv
6060
uses: astral-sh/setup-uv@v4
@@ -70,7 +70,20 @@ jobs:
7070
ln -sf ../bin/python .venv/Scripts/python.exe
7171
7272
- name: Python file length policy
73-
run: uv run python scripts/check_python_max_lines.py --max-lines 500 --path train.py --path src --path tests --path scripts
73+
run: |
74+
uv run python scripts/check_python_max_lines.py \
75+
--max-lines 500 \
76+
--path train.py \
77+
--path src/training \
78+
--path src/engine \
79+
--path src/model \
80+
--path src/game \
81+
--path src/data \
82+
--path tests/test_mcts_numerics.py \
83+
--path tests/test_training_step_numerics.py \
84+
--path tests/test_training_*.py \
85+
--path scripts/export_model_onnx.py \
86+
--path scripts/check_onnx_parity.py
7487
7588
- name: Ruff (train scope)
7689
run: uv run ruff check train.py src/engine src/model src/game src/data tests scripts

Dockerfile.train

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,9 @@ RUN apt-get update && \
1616
RUN pip install --no-cache-dir uv
1717

1818
COPY pyproject.toml uv.lock README.md ./
19+
COPY src ./src
1920
RUN uv sync --frozen --no-dev --group train --group export
2021

21-
COPY src ./src
2222
COPY train.py ./train.py
2323
COPY train_improved.py ./train_improved.py
2424

src/training/eval_runtime.py

Lines changed: 98 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,98 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import numpy as np
6+
7+
from game.actions import ACTION_SPACE
8+
from training.config_runtime import cfg_bool, cfg_int
9+
from training.selfplay_runtime import compute_action_probs, heuristic_move
10+
11+
if TYPE_CHECKING:
12+
from engine.mcts import MCTS
13+
from model.system import AtaxxZero
14+
15+
16+
def _play_eval_episode(
17+
mcts: MCTS,
18+
rng: np.random.Generator,
19+
heuristic_level: str,
20+
) -> int:
21+
from game.board import AtaxxBoard
22+
23+
board = AtaxxBoard()
24+
root = None
25+
model_player = 1 if float(rng.random()) >= 0.5 else -1
26+
while not board.is_game_over():
27+
if board.current_player == model_player:
28+
probs, root = compute_action_probs(
29+
board=board,
30+
mcts=mcts,
31+
root=root,
32+
add_noise=False,
33+
temperature=0.0,
34+
)
35+
action_idx = int(np.argmax(probs))
36+
board.step(ACTION_SPACE.decode(action_idx))
37+
root = mcts.advance_root(root, action_idx)
38+
continue
39+
move = heuristic_move(board, rng, heuristic_level)
40+
board.step(move)
41+
root = mcts.advance_root(root, ACTION_SPACE.encode(move))
42+
winner = board.get_result()
43+
if winner == model_player:
44+
return 1
45+
if winner == 0:
46+
return 0
47+
return -1
48+
49+
50+
def evaluate_model(
51+
system: AtaxxZero,
52+
device: str,
53+
games: int,
54+
sims: int,
55+
c_puct: float,
56+
heuristic_level: str,
57+
seed: int,
58+
) -> dict[str, float | int | str]:
59+
from engine.mcts import MCTS
60+
61+
system.eval()
62+
system.to(device)
63+
mcts = MCTS(
64+
model=system.model,
65+
c_puct=c_puct,
66+
n_simulations=sims,
67+
device=device,
68+
use_amp=cfg_bool("mcts_use_amp"),
69+
cache_size=max(0, cfg_int("mcts_cache_size")),
70+
leaf_batch_size=max(1, cfg_int("mcts_leaf_batch_size")),
71+
)
72+
rng = np.random.default_rng(seed=seed)
73+
wins = 0
74+
losses = 0
75+
draws = 0
76+
for _ in range(games):
77+
outcome = _play_eval_episode(mcts, rng, heuristic_level)
78+
if outcome > 0:
79+
wins += 1
80+
elif outcome < 0:
81+
losses += 1
82+
else:
83+
draws += 1
84+
score = (wins + 0.5 * draws) / max(1, games)
85+
return {
86+
"games": games,
87+
"wins": wins,
88+
"losses": losses,
89+
"draws": draws,
90+
"score": score,
91+
"heuristic_level": heuristic_level,
92+
"sims": sims,
93+
}
94+
95+
96+
__all__ = [
97+
"evaluate_model",
98+
]

src/training/selfplay_runtime.py

Lines changed: 0 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -247,86 +247,6 @@ def history_to_examples(
247247
return examples
248248

249249

250-
def _play_eval_episode(
251-
mcts: MCTS,
252-
rng: np.random.Generator,
253-
heuristic_level: str,
254-
) -> int:
255-
from game.board import AtaxxBoard
256-
257-
board = AtaxxBoard()
258-
root = None
259-
model_player = 1 if float(rng.random()) >= 0.5 else -1
260-
while not board.is_game_over():
261-
if board.current_player == model_player:
262-
probs, root = compute_action_probs(
263-
board=board,
264-
mcts=mcts,
265-
root=root,
266-
add_noise=False,
267-
temperature=0.0,
268-
)
269-
action_idx = int(np.argmax(probs))
270-
board.step(ACTION_SPACE.decode(action_idx))
271-
root = mcts.advance_root(root, action_idx)
272-
continue
273-
move = heuristic_move(board, rng, heuristic_level)
274-
board.step(move)
275-
root = mcts.advance_root(root, ACTION_SPACE.encode(move))
276-
winner = board.get_result()
277-
if winner == model_player:
278-
return 1
279-
if winner == 0:
280-
return 0
281-
return -1
282-
283-
284-
def evaluate_model(
285-
system: AtaxxZero,
286-
device: str,
287-
games: int,
288-
sims: int,
289-
c_puct: float,
290-
heuristic_level: str,
291-
seed: int,
292-
) -> dict[str, float | int | str]:
293-
from engine.mcts import MCTS
294-
295-
system.eval()
296-
system.to(device)
297-
mcts = MCTS(
298-
model=system.model,
299-
c_puct=c_puct,
300-
n_simulations=sims,
301-
device=device,
302-
use_amp=cfg_bool("mcts_use_amp"),
303-
cache_size=max(0, cfg_int("mcts_cache_size")),
304-
leaf_batch_size=max(1, cfg_int("mcts_leaf_batch_size")),
305-
)
306-
rng = np.random.default_rng(seed=seed)
307-
wins = 0
308-
losses = 0
309-
draws = 0
310-
for _ in range(games):
311-
outcome = _play_eval_episode(mcts, rng, heuristic_level)
312-
if outcome > 0:
313-
wins += 1
314-
elif outcome < 0:
315-
losses += 1
316-
else:
317-
draws += 1
318-
score = (wins + 0.5 * draws) / max(1, games)
319-
return {
320-
"games": games,
321-
"wins": wins,
322-
"losses": losses,
323-
"draws": draws,
324-
"score": score,
325-
"heuristic_level": heuristic_level,
326-
"sims": sims,
327-
}
328-
329-
330250
def execute_self_play(
331251
system: AtaxxZero,
332252
buffer: ReplayBuffer,
@@ -501,6 +421,5 @@ def execute_self_play(
501421

502422

503423
__all__ = [
504-
"evaluate_model",
505424
"execute_self_play",
506425
]

train.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,9 +42,10 @@
4242
parse_args,
4343
validate_config,
4444
)
45+
from training.eval_runtime import evaluate_model # noqa: E402
4546
from training.monitor import TrainingMonitor # noqa: E402
4647
from training.progress_callbacks import EpochPulseCallback # noqa: E402
47-
from training.selfplay_runtime import evaluate_model, execute_self_play # noqa: E402
48+
from training.selfplay_runtime import execute_self_play # noqa: E402
4849
from training.trainer_runtime import ( # noqa: E402
4950
build_trainer,
5051
export_onnx,

web/src/pages/match/MatchPage.test.tsx

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import { fireEvent, render, screen, waitFor } from "@testing-library/react";
22
import type { ReactNode } from "react";
33
import { beforeEach, describe, expect, it, vi } from "vitest";
4+
import type { PersistedGameWsEvent } from "@/features/match/persistence";
45
import type { BoardState } from "@/features/match/types";
56
import { MatchPage } from "@/pages/match/MatchPage";
67

@@ -465,26 +466,26 @@ describe("MatchPage queued human vs human", () => {
465466
}),
466467
);
467468

468-
let wsEventHandler: ((event: import("@/features/match/persistence").PersistedGameWsEvent) => void) | null = null;
469-
openPersistedGameSocketMock.mockImplementation(
470-
(_token: string, _gameId: string, onEvent: (event: import("@/features/match/persistence").PersistedGameWsEvent) => void) => {
471-
wsEventHandler = onEvent;
472-
return {
473-
close: vi.fn(),
474-
onclose: null,
475-
onmessage: null,
476-
};
477-
},
478-
);
469+
let wsEventHandler: ((event: PersistedGameWsEvent) => void) | null = null;
470+
openPersistedGameSocketMock.mockImplementation((...args: unknown[]) => {
471+
// Vitest mocks default to unknown/any signatures; cast the event callback explicitly.
472+
wsEventHandler = args[2] as (event: PersistedGameWsEvent) => void;
473+
return {
474+
close: vi.fn(),
475+
onclose: null,
476+
onmessage: null,
477+
};
478+
});
479479

480480
render(<MatchPage />);
481481

482482
await waitFor(() => {
483483
expect(screen.getByText(/humano vs humano/i)).toBeInTheDocument();
484484
});
485-
expect(wsEventHandler).not.toBeNull();
486-
487-
wsEventHandler?.({
485+
if (wsEventHandler === null) {
486+
throw new Error("Expected websocket handler to be initialized.");
487+
}
488+
(wsEventHandler as (event: PersistedGameWsEvent) => void)({
488489
type: "game.closed",
489490
game_id: "game-h2h",
490491
reason: "deleted_by_participant",

web/src/pages/match/MatchPage.tsx

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import { useCallback, useEffect, useMemo, useRef, useState, type MouseEvent } from "react";
22
import { Link, useLocation, useNavigate } from "react-router-dom";
3-
import { AnimatePresence, motion } from "framer-motion";
3+
import { AnimatePresence, motion, type Variants } from "framer-motion";
44
import {
55
Activity,
66
ArrowRight,
@@ -83,12 +83,12 @@ const SFX = {
8383
queueDeploy: "/sfx/queue_accept.ogg",
8484
} as const;
8585

86-
const panelSectionVariants = {
86+
const panelSectionVariants: Variants = {
8787
hidden: { opacity: 0, y: 8 },
8888
show: (delay = 0) => ({
8989
opacity: 1,
9090
y: 0,
91-
transition: { duration: 0.32, ease: "easeOut", delay },
91+
transition: { duration: 0.32, ease: "easeOut" as const, delay },
9292
}),
9393
};
9494

web/src/test/render.tsx

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
11
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
22
import { render, type RenderOptions } from "@testing-library/react";
3-
import type { ReactElement } from "react";
4-
import { MemoryRouter, type InitialEntry } from "react-router-dom";
3+
import type { ComponentProps, ReactElement } from "react";
4+
import { MemoryRouter } from "react-router-dom";
5+
6+
type MemoryRouterEntry = NonNullable<ComponentProps<typeof MemoryRouter>["initialEntries"]>[number];
57

68
export function renderWithProviders(
79
ui: ReactElement,
8-
{ route = "/", ...options }: { route?: InitialEntry } & Omit<RenderOptions, "wrapper"> = {},
10+
{ route = "/", ...options }: { route?: MemoryRouterEntry } & Omit<RenderOptions, "wrapper"> = {},
911
) {
1012
const queryClient = new QueryClient({
1113
defaultOptions: {

0 commit comments

Comments
 (0)