@@ -56,6 +56,13 @@ def __init__(
5656 self ._cache_hits = 0
5757 self ._cache_misses = 0
5858
59+ @staticmethod
60+ def _sample_tied_index (candidate_indices : np .ndarray ) -> int :
61+ if candidate_indices .size == 1 :
62+ return int (candidate_indices [0 ])
63+ picked = int (np .random .randint (0 , candidate_indices .size ))
64+ return int (candidate_indices [picked ])
65+
5966 def run (
6067 self ,
6168 board : AtaxxBoard ,
@@ -232,24 +239,31 @@ def _add_dirichlet_noise(self, node: MCTSNode, alpha: float, frac: float) -> Non
232239 child .prior = (1.0 - frac ) * child .prior + frac * float (noise [idx ])
233240
234241 def _select_child (self , node : MCTSNode ) -> tuple [int , MCTSNode ]:
235- best_action = - 1
236- best_child : MCTSNode | None = None
237242 best_score = - float ("inf" )
243+ tied_actions : list [int ] = []
244+ tied_children : list [MCTSNode ] = []
238245 sqrt_parent = math .sqrt (node .visit_count + 1 )
239246
240247 for action_idx , child in node .children .items ():
241248 # child.value() is from child-player perspective; negate for parent.
242249 q_value = - child .value ()
243250 u_value = self .c_puct * child .prior * sqrt_parent / (1 + child .visit_count )
244251 score = q_value + u_value
245- if score > best_score :
252+ # Early training often produces flat priors/value estimates. If we always
253+ # keep the first child on exact ties, search collapses into one opening.
254+ if score > (best_score + 1e-12 ):
246255 best_score = score
247- best_action = action_idx
248- best_child = child
256+ tied_actions = [action_idx ]
257+ tied_children = [child ]
258+ continue
259+ if math .isclose (score , best_score , rel_tol = 0.0 , abs_tol = 1e-12 ):
260+ tied_actions .append (action_idx )
261+ tied_children .append (child )
249262
250- if best_child is None :
263+ if len ( tied_children ) == 0 :
251264 raise RuntimeError ("No child selected from a non-empty node." )
252- return best_action , best_child
265+ picked = self ._sample_tied_index (np .arange (len (tied_children ), dtype = np .int64 ))
266+ return tied_actions [picked ], tied_children [picked ]
253267
254268 def _expand (self , node : MCTSNode , board : AtaxxBoard ) -> float :
255269 """
@@ -281,8 +295,10 @@ def _get_action_probs(self, root: MCTSNode, temperature: float) -> np.ndarray:
281295 )
282296
283297 if temperature <= 0.0 :
284- best_idx = int (np .argmax (visit_counts ))
285- probs [int (actions [best_idx ])] = 1.0
298+ max_visits = float (np .max (visit_counts ))
299+ best_indices = np .flatnonzero (visit_counts == max_visits )
300+ chosen = self ._sample_tied_index (best_indices )
301+ probs [int (actions [chosen ])] = 1.0
286302 return probs
287303
288304 adjusted = np .power (visit_counts , 1.0 / temperature )
0 commit comments