|
5 | 5 | from game.board import AtaxxBoard |
6 | 6 | from game.types import Move |
7 | 7 |
|
# Difficulty levels accepted by the heuristic bots. "easy"/"normal"/"hard"
# are the classic tiers; "apex"/"gambit"/"sentinel" are themed specialist
# styles added alongside them (see the _score_* helpers below).
HEURISTIC_LEVELS: tuple[str, ...] = (
    "easy",
    "normal",
    "hard",
    "apex",
    "gambit",
    "sentinel",
)
# Frozen set mirror of HEURISTIC_LEVELS for O(1) membership checks.
HEURISTIC_LEVEL_SET = frozenset(HEURISTIC_LEVELS)
# Level used when a caller does not specify one explicitly.
DEFAULT_HEURISTIC_LEVEL = "normal"
8 | 18 |
|
9 | | -def _score_move(state: AtaxxBoard, move: Move) -> float: |
| 19 | + |
def is_supported_heuristic_level(level: str) -> bool:
    """Tell whether *level* names one of the supported heuristic difficulties."""
    supported = level in HEURISTIC_LEVEL_SET
    return supported
| 22 | + |
| 23 | + |
def heuristic_mode_from_level(level: str) -> str:
    """Map a difficulty *level* to its mode identifier string.

    Raises:
        ValueError: if *level* is not a supported heuristic level.
    """
    if is_supported_heuristic_level(level):
        return f"heuristic_{level}"
    raise ValueError(f"Unsupported heuristic level: {level}")
| 28 | + |
| 29 | + |
| 30 | +def _chebyshev_distance(move: Move) -> int: |
10 | 31 | r1, c1, r2, c2 = move |
| 32 | + return max(abs(r1 - r2), abs(c1 - c2)) |
| 33 | + |
| 34 | + |
| 35 | +def _count_targets_in_radius( |
| 36 | + board: AtaxxBoard, |
| 37 | + *, |
| 38 | + row: int, |
| 39 | + col: int, |
| 40 | + target: int, |
| 41 | + radius: int, |
| 42 | +) -> int: |
| 43 | + board_size = board.grid.shape[0] |
| 44 | + r_min = max(0, row - radius) |
| 45 | + r_max = min(board_size, row + radius + 1) |
| 46 | + c_min = max(0, col - radius) |
| 47 | + c_max = min(board_size, col + radius + 1) |
| 48 | + window = board.grid[r_min:r_max, c_min:c_max] |
| 49 | + return int(np.sum(window == target)) |
| 50 | + |
| 51 | + |
| 52 | +def _mobility_advantage(after_move: AtaxxBoard) -> float: |
| 53 | + opponent_moves = len(after_move.get_valid_moves(player=after_move.current_player)) |
| 54 | + own_moves = len(after_move.get_valid_moves(player=-after_move.current_player)) |
| 55 | + return float(own_moves - opponent_moves) |
| 56 | + |
| 57 | + |
def _score_move(state: AtaxxBoard, move: Move) -> float:
    """Greedy one-ply value of *move* for state.current_player.

    Combines material swing (own pieces gained plus opponent pieces removed)
    with a small bonus for cloning (distance-1) moves and for landing near
    the centre. NOTE(review): the centre term hard-codes coordinate 3, i.e.
    a 7x7 board.
    """
    src_row, src_col, dst_row, dst_col = move
    mover = state.current_player
    own_before = int(np.sum(state.grid == mover))
    opp_before = int(np.sum(state.grid == -mover))
    preview = state.copy()
    preview.step(move)
    own_after = int(np.sum(preview.grid == mover))
    opp_after = int(np.sum(preview.grid == -mover))
    material_swing = (own_after - own_before) + (opp_before - opp_after)
    # Inline Chebyshev distance: a distance-1 move is a clone in Ataxx.
    step_span = max(abs(src_row - dst_row), abs(src_col - dst_col))
    clone_bonus = 0.15 if step_span == 1 else 0.0
    center_bonus = 0.05 * (3 - abs(dst_row - 3) + 3 - abs(dst_col - 3))
    return float(material_swing) + clone_bonus + center_bonus
21 | 70 |
|
22 | 71 |
|
def _best_reply_penalty(after_move: AtaxxBoard) -> float:
    """Greedy value of the opponent's strongest immediate reply on *after_move*.

    Returns -2.0 when the opponent has no legal reply, so callers that
    subtract a multiple of this value treat locking the opponent out as a
    bonus rather than a risk.
    """
    replies = after_move.get_valid_moves()
    if len(replies) == 0:
        return -2.0
    best = float("-inf")
    for reply in replies:
        best = max(best, _score_move(after_move, reply))
    return float(best)
| 77 | + |
| 78 | + |
| 79 | +def _softmax_choice( |
| 80 | + rng: np.random.Generator, |
| 81 | + scored_moves: list[tuple[Move, float]], |
| 82 | + *, |
| 83 | + temperature: float, |
| 84 | +) -> Move: |
| 85 | + scores = np.asarray([score for _, score in scored_moves], dtype=np.float32) |
| 86 | + logits = (scores - float(np.max(scores))) / temperature |
| 87 | + probs = np.exp(logits) |
| 88 | + probs = probs / float(np.sum(probs)) |
| 89 | + pick_idx = int(rng.choice(len(scored_moves), p=probs)) |
| 90 | + return scored_moves[pick_idx][0] |
| 91 | + |
| 92 | + |
def _score_apex(board: AtaxxBoard, move: Move) -> float:
    """Score *move* for the "apex" level using a selective two-ply lookahead.

    Starts from the greedy one-ply score plus a mobility term, then penalizes
    lines where the opponent's strongest replies cannot be answered by an
    equally strong counter of ours.
    """
    greedy = _score_move(board, move)
    preview = board.copy()
    preview.step(move)
    replies = preview.get_valid_moves()
    mobility = _mobility_advantage(preview)
    if len(replies) == 0:
        # Locking the opponent out entirely earns a flat bonus.
        return greedy + 3.0 + 0.2 * mobility

    # Only the opponent's three strongest replies are explored; for each one,
    # measure how well our best counter-move recovers the position.
    top_replies = sorted(
        replies,
        key=lambda reply: _score_move(preview, reply),
        reverse=True,
    )[:3]
    sharpest_line = float("-inf")
    for reply in top_replies:
        counter_board = preview.copy()
        counter_board.step(reply)
        counters = counter_board.get_valid_moves()
        if len(counters) > 0:
            counter_best = max(_score_move(counter_board, counter) for counter in counters)
        else:
            counter_best = -2.5
        line_value = _score_move(preview, reply) - 0.55 * float(counter_best)
        sharpest_line = max(sharpest_line, float(line_value))

    return greedy - 0.92 * sharpest_line + 0.2 * mobility
| 123 | + |
| 124 | + |
def _score_gambit(board: AtaxxBoard, move: Move) -> float:
    """Score *move* for the aggressive "gambit" level.

    Builds on the greedy one-ply score with a bonus for jump moves, for
    landing on a flank (edge row/column), and for enemy pressure within two
    cells of the landing square, minus penalties for enemies directly
    adjacent to it and for strong opponent replies.
    """
    _, _, r2, c2 = move
    base = _score_move(board, move)
    after = board.copy()
    after.step(move)
    # After step() the side to move is the opponent of the mover.
    enemy = after.current_player
    frontier_risk = _count_targets_in_radius(
        after,
        row=r2,
        col=c2,
        target=enemy,
        radius=1,
    )
    pressure_ring = _count_targets_in_radius(
        after,
        row=r2,
        col=c2,
        target=enemy,
        radius=2,
    )
    jump_bonus = 0.55 if _chebyshev_distance(move) == 2 else -0.12
    # Fix: derive the last row/column index from the board instead of
    # hard-coding 6 (7x7 only), matching how _count_targets_in_radius reads
    # the size from grid.shape. Identical behavior on a 7x7 board.
    edge = after.grid.shape[0] - 1
    flank_bonus = 0.35 if r2 in {0, edge} or c2 in {0, edge} else 0.0
    hard_guard = _best_reply_penalty(after)
    return (
        base
        - 0.58 * hard_guard
        + 0.46 * float(pressure_ring)
        + jump_bonus
        + flank_bonus
        - 0.42 * float(frontier_risk)
    )
| 156 | + |
| 157 | + |
def _score_sentinel(board: AtaxxBoard, move: Move) -> float:
    """Score *move* for the defensive "sentinel" level.

    Rewards cloning, central play, friendly support around the landing square
    and mobility, while penalizing adjacent enemies and strong opponent
    replies.
    """
    _, _, dest_row, dest_col = move
    base = _score_move(board, move)
    after = board.copy()
    after.step(move)
    enemy = after.current_player
    friendly = -enemy
    frontier_risk = _count_targets_in_radius(
        after,
        row=dest_row,
        col=dest_col,
        target=enemy,
        radius=1,
    )
    # Subtract one so the piece that just landed is not counted as support.
    local_support = (
        _count_targets_in_radius(
            after,
            row=dest_row,
            col=dest_col,
            target=friendly,
            radius=1,
        )
        - 1
    )
    mobility = _mobility_advantage(after)
    center_bonus = 0.18 * (3 - abs(dest_row - 3) + 3 - abs(dest_col - 3))
    clone_bias = 0.4 if _chebyshev_distance(move) == 1 else -0.06
    hard_guard = _best_reply_penalty(after)
    # Accumulate terms in the same left-to-right order as a single expression
    # so float results match exactly.
    score = base
    score -= 0.56 * hard_guard
    score += 0.34 * mobility
    score += 0.36 * float(local_support)
    score += center_bonus
    score += clone_bias
    score -= 0.5 * float(frontier_risk)
    return score
| 195 | + |
| 196 | + |
23 | 197 | def heuristic_move( |
24 | 198 | board: AtaxxBoard, |
25 | 199 | rng: np.random.Generator, |
26 | | - level: str = "normal", |
| 200 | + level: str = DEFAULT_HEURISTIC_LEVEL, |
27 | 201 | ) -> Move | None: |
| 202 | + if not is_supported_heuristic_level(level): |
| 203 | + raise ValueError(f"Unsupported heuristic level: {level}") |
| 204 | + |
28 | 205 | valid_moves = board.get_valid_moves() |
29 | 206 | if len(valid_moves) == 0: |
30 | 207 | return None |
31 | 208 |
|
32 | 209 | if level == "easy": |
33 | | - scores = np.asarray([_score_move(board, move) for move in valid_moves], dtype=np.float32) |
34 | | - scores = scores - float(np.min(scores)) + 0.2 |
35 | | - probs = scores / float(np.sum(scores)) |
36 | | - return valid_moves[int(rng.choice(len(valid_moves), p=probs))] |
| 210 | + scored_moves = [(move, _score_move(board, move)) for move in valid_moves] |
| 211 | + # Easy should still punish obvious blunders while keeping variety. |
| 212 | + return _softmax_choice(rng, scored_moves, temperature=0.85) |
37 | 213 |
|
38 | 214 | scored_moves: list[tuple[Move, float]] = [] |
39 | 215 | for move in valid_moves: |
40 | 216 | score = _score_move(board, move) |
41 | 217 | if level == "hard": |
42 | 218 | scratch = board.copy() |
43 | 219 | scratch.step(move) |
44 | | - opp_moves = scratch.get_valid_moves() |
45 | | - if len(opp_moves) > 0: |
46 | | - opp_best = max(_score_move(scratch, opp_move) for opp_move in opp_moves) |
47 | | - score -= 0.65 * opp_best |
| 220 | + score -= 0.65 * _best_reply_penalty(scratch) |
| 221 | + score += 0.12 * _mobility_advantage(scratch) |
| 222 | + elif level == "apex": |
| 223 | + score = _score_apex(board, move) |
| 224 | + elif level == "gambit": |
| 225 | + score = _score_gambit(board, move) |
| 226 | + elif level == "sentinel": |
| 227 | + score = _score_sentinel(board, move) |
48 | 228 | scored_moves.append((move, score)) |
49 | 229 |
|
50 | 230 | if level == "normal": |
51 | 231 | # Normal is deliberately non-greedy to avoid repetitive games. |
52 | | - scores = np.asarray([score for _, score in scored_moves], dtype=np.float32) |
53 | | - temperature = 0.35 |
54 | | - logits = (scores - float(np.max(scores))) / temperature |
55 | | - probs = np.exp(logits) |
56 | | - probs = probs / float(np.sum(probs)) |
57 | | - pick_idx = int(rng.choice(len(scored_moves), p=probs)) |
58 | | - return scored_moves[pick_idx][0] |
| 232 | + return _softmax_choice(rng, scored_moves, temperature=0.35) |
59 | 233 |
|
60 | 234 | best_score = max(score for _, score in scored_moves) |
61 | 235 | best_moves = [move for move, score in scored_moves if score == best_score] |
|
0 commit comments