Skip to content

Commit 9022637

Browse files
author
Philippe Vaillancourt
committed
fix(select): Fix bug in the way the UCB1 was calculated on state brought in from the Transposition t
When calculating the UCB1 for a node, the current algorithm uses the parent node's visit count as a proxy for the total of the siblings' visit count. This creates a problem when an existing state is brought in from the transposition table. At that moment, the parent's visit count is no longer reprensentatiove of the sum of its children visit count. Fix the algorithm so that it uses the correct number.
1 parent 3388c05 commit 9022637

5 files changed

Lines changed: 11 additions & 14 deletions

File tree

src/controller.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import {
99
import { MCTSFacade, DefaultMCTSFacade } from './mcts/mcts'
1010
import { DataStore } from './data-store'
1111
import { DefaultSelect } from './mcts/select/select'
12-
import { DefaultExpand } from './mcts/expand/expand'
12+
import { DefaultExpand, Expand } from './mcts/expand/expand'
1313
import { UCB1, DefaultUCB1, DefaultBestChild } from './mcts/select/best-child/best-child'
1414
import {
1515
DefaultSimulate,

src/mcts/expand/expand.ts

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ export interface Expand<State, Action> {
2020
*
2121
* @hidden
2222
* @internal
23-
* @export
24-
* @class DefaultExpand
25-
* @implements {Expand<State, Action>}
2623
* @template State
2724
* @template Action
2825
*/

src/mcts/select/best-child/best-child.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ export class DefaultBestChild<State extends Playerwise, Action>
4444
if (!node.children.length) {
4545
return undefined
4646
}
47-
47+
const sumChildVisits = node.children.reduce((p, c) => p + c.mctsState.visits, 0)
4848
const selectedNode = node.children.reduce((p, c) => {
49-
return this.UCB1_.run(node.mctsState, p.mctsState, exploit) >
50-
this.UCB1_.run(node.mctsState, c.mctsState, exploit)
49+
return this.UCB1_.run(sumChildVisits, p.mctsState, exploit) >
50+
this.UCB1_.run(sumChildVisits, c.mctsState, exploit)
5151
? p
5252
: c
5353
})
@@ -66,7 +66,7 @@ export class DefaultBestChild<State extends Playerwise, Action>
6666
* @template Action
6767
*/
6868
export interface UCB1<State, Action> {
69-
run(parent: MCTSState<State, Action>, child: MCTSState<State, Action>, exploit?: boolean): number
69+
run(sumChildVisits: number, child: MCTSState<State, Action>, exploit?: boolean): number
7070
}
7171

7272
/**
@@ -90,10 +90,10 @@ export class DefaultUCB1<State, Action> implements UCB1<State, Action> {
9090
* @returns {number}
9191
* @memberof DefaultUCB1
9292
*/
93-
run(parent: MCTSState<State, Action>, child: MCTSState<State, Action>, exploit = false): number {
93+
run(sumChildVisits: number, child: MCTSState<State, Action>, exploit = false): number {
9494
if (exploit) this.explorationParam_ = 0
9595
const exploitationTerm = child.reward / child.visits
96-
const explorationTerm = Math.sqrt(Math.log(parent.visits) / child.visits)
96+
const explorationTerm = Math.sqrt(Math.log(sumChildVisits) / child.visits)
9797
return exploitationTerm + this.explorationParam_ * explorationTerm
9898
}
9999
}

src/mcts/select/select.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ export class DefaultSelect<State extends Playerwise, Action> implements Select<S
4141
const child = this.bestChild_.run(node)
4242
if (!child) return this.expand_.run(node)
4343
if (node.isNotFullyExpanded()) {
44-
const ucb1 = this.ucb1_.run(node.mctsState, child.mctsState)
44+
const sumChildVisits = node.children.reduce((p, c) => p + c.mctsState.visits, 0)
45+
const ucb1 = this.ucb1_.run(sumChildVisits, child.mctsState)
4546
if (ucb1 < this.fpuParam_) {
4647
return this.expand_.run(node)
4748
}

test/mcts.test.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,11 @@ describe('The DefaultUCB1 function', () => {
9898
board: ticTacToeBoard,
9999
player: 1
100100
}
101-
const parent = new MCTSState(state)
101+
102102
const child = new MCTSState(state)
103-
parent.visits = 300
104103
child.visits = 100
105104
child.reward = 50
106-
expect(ucb1.run(parent, child)).toBeCloseTo(0.8377)
105+
expect(ucb1.run(300, child)).toBeCloseTo(0.8377)
107106
})
108107
})
109108
})

0 commit comments

Comments
 (0)