Skip to content

Commit dd7b32f

Browse files
feat: add fallback to opening book for initial analysis
1 parent 5018803 commit dd7b32f

2 files changed

Lines changed: 109 additions & 15 deletions

File tree

src/api/play/play.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,32 @@ export const getGameMove = async (
102102
return res.json()
103103
}
104104

105+
export const getBookMoves = async (fen: string) => {
106+
const res = await fetch(buildUrl(`play/get_book_moves?fen=${fen}`), {
107+
method: 'POST',
108+
headers: {
109+
Accept: 'application/json',
110+
'Content-Type': 'application/json',
111+
},
112+
body: JSON.stringify({
113+
moves: [],
114+
maia_names: [
115+
'maia_kdd_1100',
116+
'maia_kdd_1200',
117+
'maia_kdd_1300',
118+
'maia_kdd_1400',
119+
'maia_kdd_1500',
120+
'maia_kdd_1600',
121+
'maia_kdd_1700',
122+
'maia_kdd_1800',
123+
'maia_kdd_1900',
124+
],
125+
}),
126+
})
127+
128+
return res.json()
129+
}
130+
105131
export const submitGameMove = async (
106132
gameId: string,
107133
moves: string[],

src/hooks/useAnalysisController/useAnalysisController.ts

Lines changed: 83 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { Chess } from 'chess.ts'
2-
import React, { useEffect, useMemo, useState, useCallback } from 'react'
2+
import { useEffect, useMemo, useState } from 'react'
33

4+
import { getBookMoves } from 'src/api'
45
import {
56
GameTree,
67
AnalyzedGame,
@@ -44,6 +45,7 @@ export const useAnalysisController = (game: AnalyzedGame) => {
4445
)
4546

4647
const [analysisState, setAnalysisState] = useState(0)
48+
const inProgressAnalyses = useMemo(() => new Set<string>(), [])
4749

4850
const {
4951
maia,
@@ -70,29 +72,95 @@ export const useAnalysisController = (game: AnalyzedGame) => {
7072
if (!controller.currentNode) return
7173

7274
const board = new Chess(controller.currentNode.fen)
75+
const nodeFen = controller.currentNode.fen
76+
7377
;(async () => {
7478
if (
7579
maiaStatus !== 'ready' ||
7680
!controller.currentNode ||
77-
controller.currentNode.analysis.maia
81+
controller.currentNode.analysis.maia ||
82+
inProgressAnalyses.has(nodeFen)
7883
)
7984
return
8085

81-
const { result } = await maia.batchEvaluate(
82-
Array(9).fill(board.fen()),
83-
[1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900],
84-
[1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900],
85-
)
86+
inProgressAnalyses.add(nodeFen)
8687

87-
const maiaEval: { [key: string]: MaiaEvaluation } = {}
88-
MAIA_MODELS.forEach((model, index) => {
89-
maiaEval[model] = result[index]
90-
})
88+
try {
89+
// When the ply is within the first 10 ply, analyze via the server API rather than the client-side ONNX Maia2 model
90+
if (controller.currentNode.moveNumber <= 5) {
91+
const bookMoves = await getBookMoves(board.fen())
92+
93+
const targetFormat: { [key: string]: MaiaEvaluation } = {}
94+
const missingModels: string[] = []
95+
96+
MAIA_MODELS.forEach((model, index) => {
97+
const sortedMoves = Object.entries(bookMoves[model] || {})
98+
.sort(
99+
([, valueA], [, valueB]) =>
100+
(valueB as number) - (valueA as number),
101+
)
102+
.reduce((acc, [key, value]) => ({ ...acc, [key]: value }), {})
103+
104+
if (Object.keys(sortedMoves).length === 0) {
105+
missingModels.push(model)
106+
}
107+
108+
targetFormat[model] = {
109+
value: 0,
110+
policy: sortedMoves,
111+
}
112+
})
113+
114+
// If we have some missing ratings, use the client-side ONNX model to analyze them
115+
if (missingModels.length > 0) {
116+
console.log(
117+
'Falling back to client-side ONNX model for:',
118+
missingModels,
119+
)
120+
const ratingLevels = missingModels.map((name) =>
121+
parseInt(name.slice(-4)),
122+
)
123+
124+
const { result } = await maia.batchEvaluate(
125+
Array(missingModels.length).fill(board.fen()),
126+
ratingLevels,
127+
ratingLevels,
128+
)
129+
130+
missingModels.forEach((model, index) => {
131+
targetFormat[model] = result[index]
132+
})
133+
}
91134

92-
controller.currentNode.addMaiaAnalysis(maiaEval, currentMaiaModel)
93-
setAnalysisState((state) => state + 1)
135+
controller.currentNode.addMaiaAnalysis(targetFormat, currentMaiaModel)
136+
setAnalysisState((state) => state + 1)
137+
return
138+
}
139+
140+
const { result } = await maia.batchEvaluate(
141+
Array(9).fill(board.fen()),
142+
[1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900],
143+
[1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900],
144+
)
145+
146+
const maiaEval: { [key: string]: MaiaEvaluation } = {}
147+
MAIA_MODELS.forEach((model, index) => {
148+
maiaEval[model] = result[index]
149+
})
150+
151+
controller.currentNode.addMaiaAnalysis(maiaEval, currentMaiaModel)
152+
setAnalysisState((state) => state + 1)
153+
} finally {
154+
inProgressAnalyses.delete(nodeFen)
155+
}
94156
})()
95-
}, [maiaStatus, controller.currentNode, analysisState, currentMaiaModel])
157+
}, [
158+
maiaStatus,
159+
controller.currentNode,
160+
analysisState,
161+
currentMaiaModel,
162+
inProgressAnalyses,
163+
])
96164

97165
useEffect(() => {
98166
if (!controller.currentNode) return
@@ -405,7 +473,6 @@ export const useAnalysisController = (game: AnalyzedGame) => {
405473
if (!controller.currentNode) return
406474
const maia = controller.currentNode.analysis.maia
407475
const stockfish = moveEvaluation?.stockfish
408-
409476
const candidates: string[][] = []
410477

411478
if (!maia) return
@@ -551,6 +618,7 @@ export const useAnalysisController = (game: AnalyzedGame) => {
551618
const topMaiaMove = Object.entries(maia.policy).sort(
552619
(a, b) => b[1] - a[1],
553620
)[0]
621+
554622
const topStockfishMoves = Object.entries(stockfish.cp_vec)
555623
.sort((a, b) => (isBlackTurn ? a[1] - b[1] : b[1] - a[1]))
556624
.slice(0, 3)

0 commit comments

Comments
 (0)