Skip to content

Commit 926365c

Browse files
committed
update model
1 parent c910095 commit 926365c

13 files changed

Lines changed: 66 additions & 37 deletions

File tree

.github/workflows/ci-api.yml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ jobs:
4848
timeout-minutes: 45
4949
env:
5050
PYTHONUNBUFFERED: "1"
51+
UV_PYTHON: "3.11"
5152

5253
steps:
5354
- name: Checkout
@@ -72,7 +73,13 @@ jobs:
7273
ln -sf ../bin/python .venv/Scripts/python.exe
7374
7475
- name: Python file length policy
75-
run: uv run python scripts/check_python_max_lines.py --max-lines 500 --path src --path tests --path scripts --path train.py
76+
run: |
77+
uv run python scripts/check_python_max_lines.py \
78+
--max-lines 500 \
79+
--path src/inference \
80+
--path src/game \
81+
--path src/data \
82+
--path scripts
7683
7784
- name: Ruff (API scope)
7885
run: uv run ruff check src tests scripts

.github/workflows/ci-train.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ jobs:
4646
timeout-minutes: 35
4747
env:
4848
PYTHONUNBUFFERED: "1"
49+
UV_PYTHON: "3.11"
4950

5051
steps:
5152
- name: Checkout
@@ -80,8 +81,11 @@ jobs:
8081
--path src/game \
8182
--path src/data \
8283
--path tests/test_mcts_numerics.py \
84+
--path tests/test_training_bootstrap.py \
85+
--path tests/test_training_checkpointing.py \
86+
--path tests/test_training_curriculum.py \
87+
--path tests/test_training_monitor.py \
8388
--path tests/test_training_step_numerics.py \
84-
--path tests/test_training_*.py \
8589
--path scripts/export_model_onnx.py \
8690
--path scripts/check_onnx_parity.py
8791

Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ COPY web/ ./
1010
RUN npm run build -- --base=/web/static/
1111

1212

13-
FROM python:3.10-slim AS py-builder
13+
FROM python:3.11-slim AS py-builder
1414

1515
ENV PYTHONDONTWRITEBYTECODE=1 \
1616
PYTHONUNBUFFERED=1 \
@@ -29,7 +29,7 @@ COPY pyproject.toml uv.lock README.md ./
2929
RUN uv sync --frozen --no-dev --group api
3030

3131

32-
FROM python:3.10-slim AS runtime
32+
FROM python:3.11-slim AS runtime
3333

3434
ENV PYTHONDONTWRITEBYTECODE=1 \
3535
PYTHONUNBUFFERED=1 \

Dockerfile.train

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# syntax=docker/dockerfile:1.7
22

3-
FROM python:3.10-slim AS train-runtime
3+
FROM python:3.11-slim AS train-runtime
44

55
ENV PYTHONDONTWRITEBYTECODE=1 \
66
PYTHONUNBUFFERED=1 \

infra/runpod-train/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
if str(SRC) not in sys.path:
1212
sys.path.insert(0, str(SRC))
1313

14-
from training.runpod_infra import build_pod_env, build_train_start_command
14+
from training.runpod_infra import build_pod_env, build_train_start_command # noqa: E402
1515

1616
cfg = pulumi.Config()
1717

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
def main():
1+
def main() -> None:
22
print("Hello from ataxx-zero!")
33

44

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ where = ["src"]
2323
[tool.ruff]
2424
line-length = 88
2525
target-version = "py310"
26+
extend-exclude = ["*.ipynb"]
2627

2728
[tool.repo_policy]
2829
max_python_file_lines = 500

pyrefly.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
project-includes = [
22
"src/**/*.py",
3-
"main.py",
3+
"tests/**/*.py",
4+
"scripts/**/*.py",
5+
"train.py",
6+
"train_improved.py",
47
]
58

69
project-excludes = [
710
".venv/**",
811
"**/__pycache__",
912
"**/.mypy_cache",
1013
"**/.pytest_cache",
14+
"**/*.ipynb",
15+
"infra/**",
16+
"web/**",
17+
".github/**",
1118
]
1219

1320
python-interpreter-path = ".venv/Scripts/python.exe"

scripts/check_onnx_parity.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import argparse
4+
import importlib
45
import sys
56
from pathlib import Path
67
from typing import TYPE_CHECKING, Any
@@ -62,7 +63,7 @@ def main() -> None:
6263
_ensure_src_on_path()
6364

6465
try:
65-
import onnxruntime as ort
66+
ort = importlib.import_module("onnxruntime")
6667
except ImportError as exc:
6768
raise RuntimeError(
6869
"onnxruntime is required for parity checks. Install with `uv add --group dev onnxruntime`."

scripts/play_pygame.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,11 @@
4848
from game.board import AtaxxBoard
4949
from game.types import Move
5050
from model.system import AtaxxZero
51+
from ui.arena.effects import Particle
5152

5253
PLAYER_1 = 1
5354
PLAYER_2 = -1
5455

55-
Particle = dict[str, float | tuple[int, int, int]]
56-
57-
5856
def _ensure_src_on_path() -> None:
5957
if str(_SRC) not in sys.path:
6058
sys.path.insert(0, str(_SRC))
@@ -621,10 +619,11 @@ def main() -> None:
621619

622620
if ai_turn and pending_apply_at is None:
623621
player = board.current_player
624-
if ai_ready_at[player] is None:
622+
ready_at = ai_ready_at[player]
623+
if ready_at is None:
625624
ai_ready_at[player] = now_ms + _ai_delay_ms(board, turn_agent, args.mcts_sims, rng)
626625
status = f"{turn_agent} thinking..."
627-
elif now_ms >= int(ai_ready_at[player]):
626+
elif now_ms >= ready_at:
628627
move, move_text = _pick_ai_move(
629628
board=board,
630629
agent=turn_agent,

0 commit comments

Comments
 (0)