Skip to content

Commit 9c8a8cd

Browse files
committed
Fix deterministic MCTS tie-breaking
1 parent 8519d05 commit 9c8a8cd

2 files changed

Lines changed: 80 additions & 10 deletions

File tree

src/engine/mcts.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ def __init__(
5656
self._cache_hits = 0
5757
self._cache_misses = 0
5858

59+
@staticmethod
60+
def _sample_tied_index(candidate_indices: np.ndarray) -> int:
61+
if candidate_indices.size == 1:
62+
return int(candidate_indices[0])
63+
picked = int(np.random.randint(0, candidate_indices.size))
64+
return int(candidate_indices[picked])
65+
5966
def run(
6067
self,
6168
board: AtaxxBoard,
@@ -232,24 +239,31 @@ def _add_dirichlet_noise(self, node: MCTSNode, alpha: float, frac: float) -> Non
232239
child.prior = (1.0 - frac) * child.prior + frac * float(noise[idx])
233240

234241
def _select_child(self, node: MCTSNode) -> tuple[int, MCTSNode]:
235-
best_action = -1
236-
best_child: MCTSNode | None = None
237242
best_score = -float("inf")
243+
tied_actions: list[int] = []
244+
tied_children: list[MCTSNode] = []
238245
sqrt_parent = math.sqrt(node.visit_count + 1)
239246

240247
for action_idx, child in node.children.items():
241248
# child.value() is from child-player perspective; negate for parent.
242249
q_value = -child.value()
243250
u_value = self.c_puct * child.prior * sqrt_parent / (1 + child.visit_count)
244251
score = q_value + u_value
245-
if score > best_score:
252+
# Early training often produces flat priors/value estimates. If we always
253+
# keep the first child on exact ties, search collapses into one opening.
254+
if score > (best_score + 1e-12):
246255
best_score = score
247-
best_action = action_idx
248-
best_child = child
256+
tied_actions = [action_idx]
257+
tied_children = [child]
258+
continue
259+
if math.isclose(score, best_score, rel_tol=0.0, abs_tol=1e-12):
260+
tied_actions.append(action_idx)
261+
tied_children.append(child)
249262

250-
if best_child is None:
263+
if len(tied_children) == 0:
251264
raise RuntimeError("No child selected from a non-empty node.")
252-
return best_action, best_child
265+
picked = self._sample_tied_index(np.arange(len(tied_children), dtype=np.int64))
266+
return tied_actions[picked], tied_children[picked]
253267

254268
def _expand(self, node: MCTSNode, board: AtaxxBoard) -> float:
255269
"""
@@ -281,8 +295,10 @@ def _get_action_probs(self, root: MCTSNode, temperature: float) -> np.ndarray:
281295
)
282296

283297
if temperature <= 0.0:
284-
best_idx = int(np.argmax(visit_counts))
285-
probs[int(actions[best_idx])] = 1.0
298+
max_visits = float(np.max(visit_counts))
299+
best_indices = np.flatnonzero(visit_counts == max_visits)
300+
chosen = self._sample_tied_index(best_indices)
301+
probs[int(actions[chosen])] = 1.0
286302
return probs
287303

288304
adjusted = np.power(visit_counts, 1.0 / temperature)

tests/test_mcts_numerics.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))
1212

13-
from engine.mcts import MCTS
13+
from engine.mcts import MCTS, MCTSNode
1414
from game.actions import ACTION_SPACE
1515
from game.board import AtaxxBoard
1616
from model.transformer import AtaxxTransformerNet
@@ -217,6 +217,60 @@ def forward(
217217
self.assertGreaterEqual(float(stats_second["hit_rate"]), 0.0)
218218
self.assertLessEqual(float(stats_second["hit_rate"]), 1.0)
219219

220+
def test_select_child_does_not_always_pick_first_on_exact_tie(self) -> None:
221+
model = AtaxxTransformerNet(
222+
d_model=64,
223+
nhead=8,
224+
num_layers=2,
225+
dim_feedforward=128,
226+
dropout=0.0,
227+
)
228+
mcts = MCTS(model=model, c_puct=1.5, n_simulations=1, device="cpu")
229+
root = MCTSNode(prior=1.0)
230+
root.visit_count = 4
231+
root.children = {
232+
11: MCTSNode(prior=0.5),
233+
23: MCTSNode(prior=0.5),
234+
}
235+
chosen_actions: set[int] = set()
236+
237+
for seed in range(32):
238+
np.random.seed(seed)
239+
action_idx, _child = mcts._select_child(root)
240+
chosen_actions.add(action_idx)
241+
242+
self.assertEqual(chosen_actions, {11, 23})
243+
244+
def test_temperature_zero_breaks_visit_ties_without_fixed_first_action(self) -> None:
245+
class UniformModel(nn.Module):
246+
def forward(
247+
self,
248+
board_tensor: torch.Tensor,
249+
action_mask: torch.Tensor | None = None,
250+
) -> tuple[torch.Tensor, torch.Tensor]:
251+
batch = board_tensor.shape[0]
252+
logits = torch.zeros((batch, ACTION_SPACE.num_actions), dtype=torch.float32)
253+
value = torch.zeros((batch, 1), dtype=torch.float32)
254+
if action_mask is not None:
255+
logits = logits.masked_fill(action_mask <= 0, -1e9)
256+
return logits, value
257+
258+
board = AtaxxBoard()
259+
chosen_actions: set[int] = set()
260+
for seed in range(32):
261+
np.random.seed(seed)
262+
mcts = MCTS(
263+
model=UniformModel(),
264+
c_puct=1.5,
265+
n_simulations=0,
266+
device="cpu",
267+
cache_size=0,
268+
)
269+
probs = mcts.run(board=board, add_dirichlet_noise=False, temperature=0.0)
270+
chosen_actions.add(int(np.argmax(probs)))
271+
272+
self.assertGreater(len(chosen_actions), 1)
273+
220274

221275
if __name__ == "__main__":
222276
unittest.main()

0 commit comments

Comments
 (0)