Skip to content

Commit 2c50945

Browse files
BasedestBerryclaude
andcommitted
Move ONNX inference to Web Worker for smooth animations
Run maia3 inference in a dedicated Web Worker so the main thread is never blocked. Board animations, eval bar transitions, and UI interactions stay fully responsive while the ~120ms batch inference runs in the background. Restore all 21 rating levels (step 100). - Add public/maia-worker.js with model loading, caching, and inference - Add public/ort/ with onnxruntime-web WASM runtime files - Refactor maia.ts to communicate via postMessage/onmessage - Simplify useEngineAnalysis back to single batch (no chunking needed) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 731ced8 commit 2c50945

6 files changed

Lines changed: 407 additions & 343 deletions

File tree

public/maia-worker.js

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
/**
2+
* Maia3 Web Worker — runs ONNX inference off the main thread.
3+
*
4+
* Messages FROM main thread:
5+
* { type: 'init', modelUrl, modelVersion }
6+
* { type: 'download' }
7+
* { type: 'inference', id, tokens, eloSelfs, eloOppos, batchSize }
8+
*
9+
* Messages TO main thread:
10+
* { type: 'status', status }
11+
* { type: 'progress', progress }
12+
* { type: 'error', message, id? }
13+
* { type: 'inference-result', id, logitsMove, logitsValue }
14+
*/
15+
16+
importScripts('/ort/ort.wasm.min.js')
17+
18+
const ORT = ort
19+
ORT.env.wasm.wasmPaths = '/ort/'
20+
21+
// ── IndexedDB storage (mirrors MaiaModelStorage) ─────────────────────────────
22+
23+
const DB_NAME = 'MaiaModels'
24+
const STORE_NAME = 'models'
25+
const MODEL_KEY = 'maia-rapid-model'
26+
27+
function openDB() {
28+
return new Promise((resolve, reject) => {
29+
const request = indexedDB.open(DB_NAME, 1)
30+
request.onerror = () => reject(request.error)
31+
request.onsuccess = () => resolve(request.result)
32+
request.onupgradeneeded = (event) => {
33+
const db = event.target.result
34+
if (!db.objectStoreNames.contains(STORE_NAME)) {
35+
db.createObjectStore(STORE_NAME, { keyPath: 'id' })
36+
}
37+
}
38+
})
39+
}
40+
41+
async function getCachedModel(modelUrl, modelVersion) {
42+
const db = await openDB()
43+
const tx = db.transaction([STORE_NAME], 'readonly')
44+
const store = tx.objectStore(STORE_NAME)
45+
46+
const data = await new Promise((resolve, reject) => {
47+
const req = store.get(MODEL_KEY)
48+
req.onsuccess = () => resolve(req.result || null)
49+
req.onerror = () => reject(req.error)
50+
})
51+
52+
if (!data) return null
53+
54+
if (data.version && data.version !== modelVersion) {
55+
const rwTx = db.transaction([STORE_NAME], 'readwrite')
56+
rwTx.objectStore(STORE_NAME).delete(MODEL_KEY)
57+
return null
58+
}
59+
60+
return await data.data.arrayBuffer()
61+
}
62+
63+
async function storeModel(modelUrl, modelVersion, buffer) {
64+
const db = await openDB()
65+
const tx = db.transaction([STORE_NAME], 'readwrite')
66+
const store = tx.objectStore(STORE_NAME)
67+
68+
await new Promise((resolve, reject) => {
69+
const req = store.put({
70+
id: MODEL_KEY,
71+
url: modelUrl,
72+
version: modelVersion,
73+
data: new Blob([buffer]),
74+
timestamp: Date.now(),
75+
size: buffer.byteLength,
76+
})
77+
req.onsuccess = () => resolve()
78+
req.onerror = () => reject(req.error)
79+
})
80+
}
81+
82+
// ── Worker state ─────────────────────────────────────────────────────────────
83+
84+
let session = null
85+
let modelUrl = null
86+
let modelVersion = null
87+
88+
async function initSession(buffer) {
89+
session = await ORT.InferenceSession.create(buffer)
90+
}
91+
92+
// ── Message handler ──────────────────────────────────────────────────────────
93+
94+
self.onmessage = async (e) => {
95+
const msg = e.data
96+
97+
try {
98+
switch (msg.type) {
99+
case 'init': {
100+
modelUrl = msg.modelUrl
101+
modelVersion = msg.modelVersion
102+
postMessage({ type: 'status', status: 'loading' })
103+
104+
const buffer = await getCachedModel(modelUrl, modelVersion)
105+
if (buffer) {
106+
await initSession(buffer)
107+
postMessage({ type: 'status', status: 'ready' })
108+
} else {
109+
postMessage({ type: 'status', status: 'no-cache' })
110+
}
111+
break
112+
}
113+
114+
case 'download': {
115+
const response = await fetch(modelUrl)
116+
if (!response.ok) throw new Error('Failed to fetch model')
117+
118+
const reader = response.body.getReader()
119+
const contentLength = +(response.headers.get('Content-Length') || 0)
120+
const chunks = []
121+
let receivedLength = 0
122+
let lastReportedProgress = 0
123+
124+
while (true) {
125+
const { done, value } = await reader.read()
126+
if (done) break
127+
chunks.push(value)
128+
receivedLength += value.length
129+
const currentProgress = Math.floor(
130+
(receivedLength / contentLength) * 100,
131+
)
132+
if (currentProgress >= lastReportedProgress + 10) {
133+
postMessage({ type: 'progress', progress: currentProgress })
134+
lastReportedProgress = currentProgress
135+
}
136+
}
137+
138+
const buffer = new Uint8Array(receivedLength)
139+
let position = 0
140+
for (const chunk of chunks) {
141+
buffer.set(chunk, position)
142+
position += chunk.length
143+
}
144+
145+
await storeModel(modelUrl, modelVersion, buffer.buffer)
146+
await initSession(buffer.buffer)
147+
postMessage({ type: 'status', status: 'ready' })
148+
break
149+
}
150+
151+
case 'inference': {
152+
if (!session) {
153+
postMessage({
154+
type: 'error',
155+
message: 'Model not initialized',
156+
id: msg.id,
157+
})
158+
return
159+
}
160+
161+
const { id, tokens, eloSelfs, eloOppos, batchSize } = msg
162+
163+
const feeds = {
164+
tokens: new ORT.Tensor('float32', new Float32Array(tokens), [
165+
batchSize,
166+
64,
167+
12,
168+
]),
169+
elo_self: new ORT.Tensor('float32', new Float32Array(eloSelfs), [
170+
batchSize,
171+
]),
172+
elo_oppo: new ORT.Tensor('float32', new Float32Array(eloOppos), [
173+
batchSize,
174+
]),
175+
}
176+
177+
const result = await session.run(feeds)
178+
179+
const logitsMove = new Float32Array(result.logits_move.data)
180+
const logitsValue = new Float32Array(result.logits_value.data)
181+
182+
postMessage(
183+
{
184+
type: 'inference-result',
185+
id,
186+
logitsMove: logitsMove.buffer,
187+
logitsValue: logitsValue.buffer,
188+
},
189+
[logitsMove.buffer, logitsValue.buffer],
190+
)
191+
break
192+
}
193+
}
194+
} catch (err) {
195+
postMessage({
196+
type: 'error',
197+
message: err.message || 'Unknown worker error',
198+
id: msg.id,
199+
})
200+
}
201+
}

0 commit comments

Comments
 (0)