Skip to content

Commit 523b04e

Browse files
committed
Improve training pipeline and local policy checks
1 parent 5728edf commit 523b04e

28 files changed

Lines changed: 1242 additions & 332 deletions

README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ Main entrypoint is now:
5656
uv run python train.py
5757
```
5858

59-
`train_improved.py` is kept as compatibility wrapper.
59+
`train_improved.py` is kept as a notebook compatibility wrapper and re-exports
60+
`CONFIG`, `parse_args`, `apply_cli_overrides`, `validate_config`, and `main`.
6061

6162
## Entrenamiento remoto (RunPod + Pulumi)
6263

@@ -204,7 +205,7 @@ uv run python train.py --no-onnx --quiet --keep-local-ckpts 2 --keep-log-version
204205
Kaggle 2x T4 (use both GPUs):
205206

206207
```bash
207-
uv run python train.py --no-onnx --quiet --devices 2 --strategy ddp --keep-local-ckpts 2 --keep-log-versions 1 --hf --hf-repo-id your_user/ataxx-zero --iterations 40 --episodes 70 --sims 420 --epochs 5 --batch-size 96 --lr 9e-4 --weight-decay 1e-4 --save-every 3
208+
uv run python train.py --no-onnx --quiet --devices 2 --strategy ddp_spawn --precision 16-mixed --num-workers 2 --persistent-workers --keep-local-ckpts 256 --keep-log-versions 2 --hf --hf-repo-id your_user/ataxx-zero --hf-run-id policy_spatial_v3 --hf-bootstrap-run-id policy_spatial_v2 --hf-reset-iteration --iterations 220 --episodes 20 --sims 160 --epochs 2 --batch-size 224 --lr 3e-4 --weight-decay 1e-4 --save-every 1 --warmup-games 240 --warmup-epochs 3 --warmup-heuristic-levels hard,apex,sentinel --eval-every 6 --eval-games 12 --eval-sims 160 --eval-heuristic-levels hard,apex,sentinel --restore-best-on-regression --eval-regression-delta 0.03 --eval-regression-patience 2 --allow-selfplay-fallback --max-pending-hf-uploads 6 --hf-upload-timeout-s 900
208209
```
209210

210211
Kaggle estable con `opponent pool` (recomendado):
@@ -216,9 +217,12 @@ uv run python train.py --no-onnx --quiet --devices 1 --strategy auto --keep-loca
216217
Kaggle estable + evaluacion automatica + best checkpoint:
217218

218219
```bash
219-
uv run python train.py --no-onnx --quiet --devices 1 --strategy auto --num-workers 3 --persistent-workers --keep-local-ckpts 2 --keep-log-versions 1 --hf --hf-repo-id your_user/ataxx-zero --iterations 40 --episodes 70 --sims 420 --epochs 5 --batch-size 96 --lr 9e-4 --weight-decay 1e-4 --save-every 3 --strict-probs --eval-every 3 --eval-games 12 --eval-sims 220 --eval-heuristic-level hard --opp-self 0.85 --opp-heuristic 0.12 --opp-random 0.03 --opp-heu-easy 0.05 --opp-heu-normal 0.20 --opp-heu-hard 0.75 --model-swap-prob 0.5
220+
uv run python train.py --no-onnx --quiet --devices 1 --strategy auto --num-workers 2 --persistent-workers --keep-local-ckpts 256 --keep-log-versions 2 --hf --hf-repo-id your_user/ataxx-zero --hf-run-id policy_spatial_v3 --hf-bootstrap-run-id policy_spatial_v2 --hf-reset-iteration --iterations 160 --episodes 16 --sims 128 --epochs 2 --batch-size 192 --lr 3e-4 --weight-decay 1e-4 --save-every 1 --warmup-games 180 --warmup-epochs 2 --warmup-heuristic-levels hard,apex,sentinel --eval-every 6 --eval-games 12 --eval-sims 128 --eval-heuristic-levels hard,apex,sentinel --restore-best-on-regression --eval-regression-delta 0.03 --eval-regression-patience 2 --allow-selfplay-fallback --max-pending-hf-uploads 6 --hf-upload-timeout-s 900
220221
```
221222

223+
Si Kaggle te asigna una `P100`, el trainer cae automáticamente a CPU para evitar el
224+
crash de compatibilidad `sm_60`; para mixed precision y DDP real necesitas `T4 x2`.
225+
222226
If your environment is missing ONNX tooling, use:
223227

224228
```bash

scripts/bootstrap_model_bot.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ async def _ensure_model_bot(args: argparse.Namespace) -> None:
3030
from api.db.enums import AgentType, BotKind
3131
from api.db.models import BotProfile, ModelVersion, User
3232
from api.db.session import get_engine, get_sessionmaker
33+
from api.modules.model_versions.repository import ModelVersionRepository
34+
from api.modules.model_versions.service import ModelVersionService
3335
from api.modules.ranking.repository import RankingRepository
3436
from api.modules.ranking.service import RankingService
3537

@@ -64,18 +66,12 @@ async def _ensure_model_bot(args: argparse.Namespace) -> None:
6466
await session.refresh(version)
6567

6668
if args.activate_version:
67-
await session.execute(
68-
# Keep one global active version when requested explicitly.
69-
ModelVersion.__table__.update()
70-
.where(col(ModelVersion.id) != version.id)
71-
.values(is_active=False)
69+
# Reuse the repository/service flow so activation semantics stay
70+
# consistent with the API and type-check cleanly.
71+
version_service = ModelVersionService(
72+
repository=ModelVersionRepository(session=session)
7273
)
73-
await session.execute(
74-
ModelVersion.__table__.update()
75-
.where(col(ModelVersion.id) == version.id)
76-
.values(is_active=True)
77-
)
78-
await session.commit()
74+
version = await version_service.activate_model_version(version.id)
7975

8076
user_stmt = select(User).where(col(User.username) == args.username)
8177
user = (await session.execute(user_stmt)).scalars().first()

src/data/dataset.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections import deque
4+
from typing import TYPE_CHECKING
45

56
import numpy as np
67
import torch
@@ -14,6 +15,9 @@
1415
_N_TRANSFORMS = 8
1516
_POLICY_INDEX_MAPS: np.ndarray | None = None
1617

18+
if TYPE_CHECKING:
19+
from data.replay_buffer import TrainingExample
20+
1721

1822
def _rotate_coord_ccw(r: int, c: int, k: int, size: int) -> tuple[int, int]:
1923
rr, cc = r, c
@@ -98,25 +102,65 @@ def _augment_policy(policy: np.ndarray, transform_id: int) -> np.ndarray:
98102
return pi_aug
99103

100104

105+
def split_train_val_examples(
    *,
    all_examples: list[TrainingExample],
    val_split: float,
    shuffle: bool,
    seed: int,
) -> tuple[list[TrainingExample], list[TrainingExample]]:
    """Partition examples into disjoint train/validation lists.

    ``val_split`` is the fraction of ``all_examples`` held out for validation
    (truncated toward zero). Without shuffling the split is a simple suffix
    hold-out; with shuffling, validation indices are drawn without replacement
    using a seeded generator so the split is reproducible. Both returned lists
    preserve the chronological order of the input.
    """
    total = len(all_examples)
    if total == 0:
        return [], []
    val_count = min(max(0, int(total * val_split)), total)
    train_count = total - val_count
    if val_count == 0:
        # Nothing held out: everything trains, validation is empty.
        return list(all_examples), []
    if not shuffle:
        # Deterministic suffix hold-out keeps the newest samples in val.
        return list(all_examples[:train_count]), list(all_examples[train_count:])

    rng = np.random.default_rng(seed=seed)
    chosen = np.sort(rng.choice(total, size=val_count, replace=False))
    chosen_set = {int(i) for i in chosen.tolist()}
    # Keep the train set in chronological order so "recent" remains meaningful.
    train_examples = [ex for idx, ex in enumerate(all_examples) if idx not in chosen_set]
    val_examples = [all_examples[int(idx)] for idx in chosen]
    return train_examples, val_examples
132+
133+
101134
class AtaxxDataset(Dataset[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]):
102135
"""Dataset wrapper from replay buffer examples."""
103136

104137
def __init__(
105138
self,
106-
buffer: ReplayBuffer,
139+
buffer: ReplayBuffer | None = None,
107140
augment: bool = True,
108141
reference_buffer: bool = False,
109142
val_split: float = 0.1,
143+
examples: list[TrainingExample] | None = None,
110144
) -> None:
111145
self.augment = augment
112146
self.examples: list[tuple[np.ndarray, np.ndarray, float]] | deque[
113147
tuple[np.ndarray, np.ndarray, float]
114148
]
149+
if examples is not None:
150+
self.examples = list(examples)
151+
return
152+
if buffer is None:
153+
self.examples = []
154+
return
155+
115156
raw_examples = list(buffer.buffer) if reference_buffer else buffer.get_all()
116-
n_val = int(len(raw_examples) * val_split)
117-
n_train = len(raw_examples) - n_val
118-
# Keep train/validation disjoint so val loss is a true hold-out metric.
119-
self.examples = raw_examples[:n_train]
157+
train_examples, _ = split_train_val_examples(
158+
all_examples=raw_examples,
159+
val_split=val_split,
160+
shuffle=False,
161+
seed=0,
162+
)
163+
self.examples = train_examples
120164

121165
def __len__(self) -> int:
122166
return len(self.examples)
@@ -140,11 +184,26 @@ def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor, torch.Ten
140184
class ValidationDataset(Dataset[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]):
141185
"""Hold-out validation split from replay buffer."""
142186

143-
def __init__(self, buffer: ReplayBuffer, split: float = 0.1) -> None:
187+
def __init__(
188+
self,
189+
buffer: ReplayBuffer | None = None,
190+
split: float = 0.1,
191+
examples: list[TrainingExample] | None = None,
192+
) -> None:
193+
if examples is not None:
194+
self.examples = list(examples)
195+
return
196+
if buffer is None:
197+
self.examples = []
198+
return
144199
all_examples = buffer.get_all()
145-
n_val = int(len(all_examples) * split)
146-
n_train = len(all_examples) - n_val
147-
self.examples = all_examples[n_train:] if n_val > 0 else []
200+
_, val_examples = split_train_val_examples(
201+
all_examples=all_examples,
202+
val_split=split,
203+
shuffle=False,
204+
seed=0,
205+
)
206+
self.examples = val_examples
148207

149208
def __len__(self) -> int:
150209
return len(self.examples)

src/data/replay_buffer.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,47 @@
1111
TrainingExample = tuple[Observation, PolicyTarget, float]
1212

1313

14+
def sample_recent_mix(
    examples: list[TrainingExample],
    *,
    recent_fraction: float,
    recent_window_fraction: float,
    seed: int | None = None,
    sample_size: int | None = None,
) -> list[TrainingExample]:
    """Draw a training batch biased toward the newest examples.

    By default the result has the same size as ``examples`` and mixes two
    pools:

    - ``recent_fraction`` of the draws come from the trailing
      ``recent_window_fraction`` slice of ``examples``,
    - the remaining draws come from the full pool.

    Sampling is with replacement in both pools, and the combined result is
    shuffled before returning. An empty input yields an empty list.
    """
    if not examples:
        return []

    pool_size = len(examples)
    if sample_size is None:
        draw_total = pool_size
    else:
        draw_total = max(1, min(int(sample_size), pool_size))

    # The "recent" window is the trailing slice; never smaller than one item.
    window_len = max(1, round(pool_size * recent_window_fraction))
    window = examples[-window_len:]
    from_recent = min(draw_total, max(0, round(draw_total * recent_fraction)))
    from_global = draw_total - from_recent

    rng = np.random.default_rng(seed=seed)
    chosen: list[TrainingExample] = []
    if from_recent > 0:
        recent_idx = rng.integers(0, len(window), size=from_recent, endpoint=False)
        chosen.extend(window[int(i)] for i in recent_idx)
    if from_global > 0:
        global_idx = rng.integers(0, pool_size, size=from_global, endpoint=False)
        chosen.extend(examples[int(i)] for i in global_idx)
    if len(chosen) > 1:
        # Shuffle so recent and global draws are interleaved for training.
        perm = rng.permutation(len(chosen))
        chosen = [chosen[int(i)] for i in perm]
    return chosen
53+
54+
1455
class ReplayBuffer:
1556
"""FIFO replay buffer for self-play training examples."""
1657

src/training/bootstrap.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
import numpy as np
66

7-
from agents.heuristic import heuristic_move
7+
from agents.heuristic import heuristic_move, is_supported_heuristic_level
88
from data.replay_buffer import TrainingExample
99
from game.actions import ACTION_SPACE
1010
from game.board import AtaxxBoard
1111

12-
HeuristicLevel = Literal["easy", "normal", "hard"]
12+
HeuristicLevel = Literal["easy", "normal", "hard", "apex", "gambit", "sentinel"]
1313
HistoryEntry = tuple[np.ndarray, np.ndarray, int]
1414

1515

@@ -51,6 +51,8 @@ def generate_imitation_data(
5151
"""
5252
if n_games <= 0:
5353
return []
54+
if not is_supported_heuristic_level(heuristic_level):
55+
raise ValueError(f"Unsupported heuristic level for warmup: {heuristic_level}")
5456

5557
rng = np.random.default_rng(seed=seed)
5658
all_examples: list[TrainingExample] = []

0 commit comments

Comments
 (0)