Skip to content

Commit da1f067

Browse files
committed
Add legacy checkpoint support for inference service
1 parent 44ca9e5 commit da1f067

3 files changed

Lines changed: 193 additions & 3 deletions

File tree

src/inference/legacy_model.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
from game.actions import ACTION_SPACE
7+
from game.constants import BOARD_SIZE
8+
9+
10+
class LegacyAtaxxTransformerNet(nn.Module):
    """Transformer net for historical checkpoints (3 input channels, flattened policy).

    Kept separate from the current architecture so old ``model.*`` state_dicts
    keep loading unchanged: every registered parameter and submodule name below
    must stay exactly as it was in the legacy training code.
    """

    def __init__(
        self,
        d_model: int = 128,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.board_size = BOARD_SIZE
        self.num_cells = self.board_size * self.board_size
        self.num_actions = ACTION_SPACE.num_actions
        # Legacy observations carry 3 planes (the current service defaults to 4).
        self.num_input_channels = 3

        # One token per board cell (projected from its 3 channel values) plus a
        # learned CLS token that feeds the value head.
        self.input_proj = nn.Linear(self.num_input_channels, d_model)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_cells + 1, d_model))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=False,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Legacy policy head: all cell tokens flattened into one wide vector.
        self.policy_head = nn.Sequential(
            nn.LayerNorm(d_model * self.num_cells),
            nn.Linear(d_model * self.num_cells, self.num_actions),
        )
        self.value_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 1),
            nn.Tanh(),
        )

    def forward(
        self,
        x: torch.Tensor,
        action_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(policy_logits, value)`` for a batch of ``(B, 3, H, W)`` boards.

        When *action_mask* is given, positions where it is <= 0 are filled with
        the most negative representable logit so softmax drives them to ~0.
        """
        bsz = x.size(0)
        # (B, C, H, W) -> (B, cells, C): one token per board cell.
        cell_feats = x.permute(0, 2, 3, 1).reshape(
            bsz,
            self.num_cells,
            self.num_input_channels,
        )
        cell_tokens = self.input_proj(cell_feats)

        cls_tokens = self.cls_token.expand(bsz, -1, -1)
        sequence = torch.cat([cls_tokens, cell_tokens], dim=1) + self.pos_embed
        encoded = self.encoder(sequence)

        pooled = encoded[:, 0]  # CLS token -> value head
        flat_cells = encoded[:, 1:].reshape(bsz, -1)  # cell tokens -> policy head
        policy_logits = self.policy_head(flat_cells)
        if action_mask is not None:
            neg_inf = torch.finfo(policy_logits.dtype).min
            policy_logits = policy_logits.masked_fill(action_mask <= 0, neg_inf)

        value = self.value_head(pooled)
        return policy_logits, value
81+
82+
83+
class LegacyAtaxxSystem(nn.Module):
    """Wrapper that reproduces the ``model.*`` key prefix of legacy state_dicts."""

    def __init__(
        self,
        d_model: int = 128,
        nhead: int = 8,
        num_layers: int = 6,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        arch_kwargs = {
            "d_model": d_model,
            "nhead": nhead,
            "num_layers": num_layers,
            "dim_feedforward": dim_feedforward,
            "dropout": dropout,
        }
        # Registering the net under the attribute name ``model`` is what gives
        # every parameter the ``model.*`` prefix expected by old checkpoints.
        self.model = LegacyAtaxxTransformerNet(**arch_kwargs)
102+

src/inference/service.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
if TYPE_CHECKING:
1717
from engine.mcts import MCTS
18-
from model.system import AtaxxZero
1918

2019
InferenceMode = Literal["fast", "strong"]
2120

@@ -57,6 +56,19 @@ def run(self, output_names: list[str] | None, input_feed: dict[str, Any]) -> lis
5756
...
5857

5958

59+
class _SystemLike(Protocol):
    """Structural type for any model system the inference service can drive.

    Satisfied by both the current system and the legacy wrapper, so the
    service can hold either without a hard import of the concrete classes.
    """

    # The wrapped network module; the service reads `num_input_channels` off it.
    model: Any

    def eval(self) -> _SystemLike:
        ...

    def to(self, device: str) -> _SystemLike:
        ...

    def load_state_dict(self, state_dict: dict[str, object]) -> object:
        ...
71+
6072
@lru_cache(maxsize=1)
6173
def _get_torch_module() -> ModuleType | None:
6274
"""Import torch lazily so API startup does not hard-fail in lightweight runtimes."""
@@ -95,11 +107,15 @@ def __init__(
95107
self.c_puct = float(c_puct)
96108
self.model_kwargs: ModelInitKwargs = model_kwargs or {}
97109

98-
self.system: AtaxxZero | None = None
110+
self.system: _SystemLike | None = None
111+
self._model_input_channels = 4
99112
if self.checkpoint_path.exists():
100113
self.system = self._load_system()
101114
self.system.eval()
102115
self.system.to(self.device)
116+
self._model_input_channels = int(
117+
getattr(self.system.model, "num_input_channels", 4)
118+
)
103119

104120
self._onnx_session: _OnnxSessionLike | None = None
105121
self._onnx_last_error: str | None = None
@@ -132,7 +148,29 @@ def _require_torch() -> ModuleType:
132148
)
133149
return torch_module
134150

135-
def _load_system(self) -> AtaxxZero:
151+
@staticmethod
152+
def _is_legacy_state_dict(state_dict: dict[str, Any]) -> bool:
153+
has_legacy_policy = "model.policy_head.1.weight" in state_dict
154+
has_spatial_policy = "model.policy_src_proj.weight" in state_dict
155+
input_weight = state_dict.get("model.input_proj.weight")
156+
input_channels = None
157+
if hasattr(input_weight, "shape"):
158+
shape = tuple(input_weight.shape)
159+
if len(shape) == 2:
160+
input_channels = int(shape[1])
161+
return has_legacy_policy and not has_spatial_policy and input_channels == 3
162+
163+
@staticmethod
def _extract_arch_kwargs(raw_kwargs: ModelInitKwargs) -> dict[str, Any]:
    """Keep only the architecture hyper-parameters the legacy net accepts."""
    selected: dict[str, Any] = {}
    for key in ("d_model", "nhead", "num_layers", "dim_feedforward", "dropout"):
        if key in raw_kwargs:
            selected[key] = raw_kwargs[key]
    return selected
167+
168+
def _build_legacy_system(self) -> _SystemLike:
    """Instantiate the legacy 3-channel system for historical checkpoints."""
    # Imported lazily, matching this file's pattern of deferring
    # torch-dependent imports until a model is actually loaded.
    from inference.legacy_model import LegacyAtaxxSystem

    return LegacyAtaxxSystem(**self._extract_arch_kwargs(self.model_kwargs))
172+
173+
def _load_system(self) -> _SystemLike:
136174
from model.system import AtaxxZero
137175

138176
torch_module = self._require_torch()
@@ -157,6 +195,16 @@ def _load_system(self) -> AtaxxZero:
157195
try:
158196
system.load_state_dict(state_dict_obj)
159197
except RuntimeError as exc:
198+
if self._is_legacy_state_dict(state_dict_obj):
199+
legacy_system = self._build_legacy_system()
200+
try:
201+
legacy_system.load_state_dict(state_dict_obj)
202+
return legacy_system
203+
except RuntimeError as legacy_exc:
204+
raise ValueError(
205+
"Checkpoint incompatible con architecture policy_head espacial; "
206+
"reentrena o usa carga parcial manual (strict=False)."
207+
) from legacy_exc
160208
raise ValueError(
161209
"Checkpoint incompatible con architecture policy_head espacial; "
162210
"reentrena o usa carga parcial manual (strict=False)."
@@ -270,6 +318,8 @@ def _fast_result(self, board: AtaxxBoard) -> InferenceResult:
270318
torch_module = self._require_torch()
271319
mask_np = self._legal_action_mask(board)
272320
obs = board.get_observation()
321+
if obs.shape[0] != self._model_input_channels:
322+
obs = obs[: self._model_input_channels]
273323

274324
obs_tensor = torch_module.from_numpy(obs).unsqueeze(0).to(self.device)
275325
mask_tensor = torch_module.from_numpy(mask_np).unsqueeze(0).to(self.device)
@@ -302,6 +352,10 @@ def _strong_result(self, board: AtaxxBoard) -> InferenceResult:
302352
if self.system is None:
303353
# If no torch model is available, degrade gracefully to fast ONNX/Torch.
304354
return self._fast_result(board)
355+
if self._model_input_channels != 4:
356+
# Legacy checkpoints were trained with 3-channel observations and do
357+
# not support the current MCTS path that batches 4-channel states.
358+
return self._fast_result(board)
305359
torch_module = self._require_torch()
306360
mcts = self._ensure_mcts()
307361
probs = mcts.run(board=board, add_dirichlet_noise=False, temperature=0.0)
@@ -311,6 +365,8 @@ def _strong_result(self, board: AtaxxBoard) -> InferenceResult:
311365
# Value still comes from raw net (current-player perspective), which is stable and cheap.
312366
mask_np = self._legal_action_mask(board)
313367
obs = board.get_observation()
368+
if obs.shape[0] != self._model_input_channels:
369+
obs = obs[: self._model_input_channels]
314370
obs_tensor = torch_module.from_numpy(obs).unsqueeze(0).to(self.device)
315371
mask_tensor = torch_module.from_numpy(mask_np).unsqueeze(0).to(self.device)
316372
with torch_module.no_grad():

tests/test_inference_service.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from game.actions import ACTION_SPACE
1616
from game.board import AtaxxBoard
17+
from inference.legacy_model import LegacyAtaxxSystem
1718
from inference.service import InferenceService
1819
from model.system import AtaxxZero
1920

@@ -121,6 +122,37 @@ def test_rejects_missing_checkpoint(self) -> None:
121122
with self.assertRaises(FileNotFoundError):
122123
InferenceService(checkpoint_path="does/not/exist/model.pt", device="cpu")
123124

125+
def test_loads_legacy_checkpoint_and_predicts(self) -> None:
    """A legacy 3-channel checkpoint loads and serves a legal prediction."""
    arch = {
        "d_model": 64,
        "nhead": 8,
        "num_layers": 2,
        "dim_feedforward": 128,
        "dropout": 0.0,
    }
    with tempfile.TemporaryDirectory() as tmp_dir:
        legacy_system = LegacyAtaxxSystem(**arch)
        ckpt_path = Path(tmp_dir) / "legacy.pt"
        torch.save({"state_dict": legacy_system.state_dict()}, ckpt_path)

        service = InferenceService(
            checkpoint_path=ckpt_path,
            device="cpu",
            model_kwargs=arch,
        )
        board = AtaxxBoard()
        result = service.predict(board, mode="strong")

        # The service cannot run the MCTS "strong" path with a 3-channel
        # model, so the result must report the "fast" fallback mode.
        legal_idxs = {ACTION_SPACE.encode(move) for move in board.get_valid_moves()}
        self.assertEqual(result.mode, "fast")
        self.assertIn(result.action_idx, legal_idxs)
155+
124156
def test_rejects_invalid_mode(self) -> None:
125157
with tempfile.TemporaryDirectory() as tmp_dir:
126158
system = self._tiny_system()

0 commit comments

Comments
 (0)