1515
1616if TYPE_CHECKING :
1717 from engine .mcts import MCTS
18- from model .system import AtaxxZero
1918
2019InferenceMode = Literal ["fast" , "strong" ]
2120
@@ -57,6 +56,19 @@ def run(self, output_names: list[str] | None, input_feed: dict[str, Any]) -> lis
5756 ...
5857
5958
59+ class _SystemLike (Protocol ):
60+ model : Any
61+
62+ def eval (self ) -> _SystemLike :
63+ ...
64+
65+ def to (self , device : str ) -> _SystemLike :
66+ ...
67+
68+ def load_state_dict (self , state_dict : dict [str , object ]) -> object :
69+ ...
70+
71+
6072@lru_cache (maxsize = 1 )
6173def _get_torch_module () -> ModuleType | None :
6274 """Import torch lazily so API startup does not hard-fail in lightweight runtimes."""
@@ -95,11 +107,15 @@ def __init__(
95107 self .c_puct = float (c_puct )
96108 self .model_kwargs : ModelInitKwargs = model_kwargs or {}
97109
98- self .system : AtaxxZero | None = None
110+ self .system : _SystemLike | None = None
111+ self ._model_input_channels = 4
99112 if self .checkpoint_path .exists ():
100113 self .system = self ._load_system ()
101114 self .system .eval ()
102115 self .system .to (self .device )
116+ self ._model_input_channels = int (
117+ getattr (self .system .model , "num_input_channels" , 4 )
118+ )
103119
104120 self ._onnx_session : _OnnxSessionLike | None = None
105121 self ._onnx_last_error : str | None = None
@@ -132,7 +148,29 @@ def _require_torch() -> ModuleType:
132148 )
133149 return torch_module
134150
135- def _load_system (self ) -> AtaxxZero :
151+ @staticmethod
152+ def _is_legacy_state_dict (state_dict : dict [str , Any ]) -> bool :
153+ has_legacy_policy = "model.policy_head.1.weight" in state_dict
154+ has_spatial_policy = "model.policy_src_proj.weight" in state_dict
155+ input_weight = state_dict .get ("model.input_proj.weight" )
156+ input_channels = None
157+ if hasattr (input_weight , "shape" ):
158+ shape = tuple (input_weight .shape )
159+ if len (shape ) == 2 :
160+ input_channels = int (shape [1 ])
161+ return has_legacy_policy and not has_spatial_policy and input_channels == 3
162+
163+ @staticmethod
164+ def _extract_arch_kwargs (raw_kwargs : ModelInitKwargs ) -> dict [str , Any ]:
165+ allowed = ("d_model" , "nhead" , "num_layers" , "dim_feedforward" , "dropout" )
166+ return {key : raw_kwargs [key ] for key in allowed if key in raw_kwargs }
167+
168+ def _build_legacy_system (self ) -> _SystemLike :
169+ from inference .legacy_model import LegacyAtaxxSystem
170+
171+ return LegacyAtaxxSystem (** self ._extract_arch_kwargs (self .model_kwargs ))
172+
173+ def _load_system (self ) -> _SystemLike :
136174 from model .system import AtaxxZero
137175
138176 torch_module = self ._require_torch ()
@@ -157,6 +195,16 @@ def _load_system(self) -> AtaxxZero:
157195 try :
158196 system .load_state_dict (state_dict_obj )
159197 except RuntimeError as exc :
198+ if self ._is_legacy_state_dict (state_dict_obj ):
199+ legacy_system = self ._build_legacy_system ()
200+ try :
201+ legacy_system .load_state_dict (state_dict_obj )
202+ return legacy_system
203+ except RuntimeError as legacy_exc :
204+ raise ValueError (
205+ "Checkpoint incompatible con architecture policy_head espacial; "
206+ "reentrena o usa carga parcial manual (strict=False)."
207+ ) from legacy_exc
160208 raise ValueError (
161209 "Checkpoint incompatible con architecture policy_head espacial; "
162210 "reentrena o usa carga parcial manual (strict=False)."
@@ -270,6 +318,8 @@ def _fast_result(self, board: AtaxxBoard) -> InferenceResult:
270318 torch_module = self ._require_torch ()
271319 mask_np = self ._legal_action_mask (board )
272320 obs = board .get_observation ()
321+ if obs .shape [0 ] != self ._model_input_channels :
322+ obs = obs [: self ._model_input_channels ]
273323
274324 obs_tensor = torch_module .from_numpy (obs ).unsqueeze (0 ).to (self .device )
275325 mask_tensor = torch_module .from_numpy (mask_np ).unsqueeze (0 ).to (self .device )
@@ -302,6 +352,10 @@ def _strong_result(self, board: AtaxxBoard) -> InferenceResult:
302352 if self .system is None :
303353 # If no torch model is available, degrade gracefully to fast ONNX/Torch.
304354 return self ._fast_result (board )
355+ if self ._model_input_channels != 4 :
356+ # Legacy checkpoints were trained with 3-channel observations and do
357+ # not support the current MCTS path that batches 4-channel states.
358+ return self ._fast_result (board )
305359 torch_module = self ._require_torch ()
306360 mcts = self ._ensure_mcts ()
307361 probs = mcts .run (board = board , add_dirichlet_noise = False , temperature = 0.0 )
@@ -311,6 +365,8 @@ def _strong_result(self, board: AtaxxBoard) -> InferenceResult:
311365 # Value still comes from raw net (current-player perspective), which is stable and cheap.
312366 mask_np = self ._legal_action_mask (board )
313367 obs = board .get_observation ()
368+ if obs .shape [0 ] != self ._model_input_channels :
369+ obs = obs [: self ._model_input_channels ]
314370 obs_tensor = torch_module .from_numpy (obs ).unsqueeze (0 ).to (self .device )
315371 mask_tensor = torch_module .from_numpy (mask_np ).unsqueeze (0 ).to (self .device )
316372 with torch_module .no_grad ():
0 commit comments