Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 201 additions & 0 deletions public/maia-worker.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/**
* Maia3 Web Worker — runs ONNX inference off the main thread.
*
* Messages FROM main thread:
* { type: 'init', modelUrl, modelVersion }
* { type: 'download' }
* { type: 'inference', id, tokens, eloSelfs, eloOppos, batchSize }
*
* Messages TO main thread:
* { type: 'status', status }
* { type: 'progress', progress }
* { type: 'error', message, id? }
* { type: 'inference-result', id, logitsMove, logitsValue }
*/

importScripts('/ort/ort.wasm.min.js')

const ORT = ort
ORT.env.wasm.wasmPaths = '/ort/'

// ── IndexedDB storage (mirrors MaiaModelStorage) ─────────────────────────────

const DB_NAME = 'MaiaModels'
const STORE_NAME = 'models'
const MODEL_KEY = 'maia-rapid-model'

function openDB() {
return new Promise((resolve, reject) => {
const request = indexedDB.open(DB_NAME, 1)
request.onerror = () => reject(request.error)
request.onsuccess = () => resolve(request.result)
request.onupgradeneeded = (event) => {
const db = event.target.result
if (!db.objectStoreNames.contains(STORE_NAME)) {
db.createObjectStore(STORE_NAME, { keyPath: 'id' })
}
}
})
}

async function getCachedModel(modelUrl, modelVersion) {
const db = await openDB()
const tx = db.transaction([STORE_NAME], 'readonly')
const store = tx.objectStore(STORE_NAME)

const data = await new Promise((resolve, reject) => {
const req = store.get(MODEL_KEY)
req.onsuccess = () => resolve(req.result || null)
req.onerror = () => reject(req.error)
})

if (!data) return null

if (data.version && data.version !== modelVersion) {
const rwTx = db.transaction([STORE_NAME], 'readwrite')
rwTx.objectStore(STORE_NAME).delete(MODEL_KEY)
return null
}

return await data.data.arrayBuffer()
}

async function storeModel(modelUrl, modelVersion, buffer) {
const db = await openDB()
const tx = db.transaction([STORE_NAME], 'readwrite')
const store = tx.objectStore(STORE_NAME)

await new Promise((resolve, reject) => {
const req = store.put({
id: MODEL_KEY,
url: modelUrl,
version: modelVersion,
data: new Blob([buffer]),
timestamp: Date.now(),
size: buffer.byteLength,
})
req.onsuccess = () => resolve()
req.onerror = () => reject(req.error)
})
}

// ── Worker state ─────────────────────────────────────────────────────────────

let session = null
let modelUrl = null
let modelVersion = null

async function initSession(buffer) {
session = await ORT.InferenceSession.create(buffer)
}

// ── Message handler ──────────────────────────────────────────────────────────

self.onmessage = async (e) => {
const msg = e.data

try {
switch (msg.type) {
case 'init': {
modelUrl = msg.modelUrl
modelVersion = msg.modelVersion
postMessage({ type: 'status', status: 'loading' })

const buffer = await getCachedModel(modelUrl, modelVersion)
if (buffer) {
await initSession(buffer)
postMessage({ type: 'status', status: 'ready' })
} else {
postMessage({ type: 'status', status: 'no-cache' })
}
break
}

case 'download': {
const response = await fetch(modelUrl)
if (!response.ok) throw new Error('Failed to fetch model')

const reader = response.body.getReader()
const contentLength = +(response.headers.get('Content-Length') || 0)
const chunks = []
let receivedLength = 0
let lastReportedProgress = 0

while (true) {
const { done, value } = await reader.read()
if (done) break
chunks.push(value)
receivedLength += value.length
const currentProgress = Math.floor(
(receivedLength / contentLength) * 100,
)
if (currentProgress >= lastReportedProgress + 10) {
postMessage({ type: 'progress', progress: currentProgress })
lastReportedProgress = currentProgress
}
}

const buffer = new Uint8Array(receivedLength)
let position = 0
for (const chunk of chunks) {
buffer.set(chunk, position)
position += chunk.length
}

await storeModel(modelUrl, modelVersion, buffer.buffer)
await initSession(buffer.buffer)
postMessage({ type: 'status', status: 'ready' })
break
}

case 'inference': {
if (!session) {
postMessage({
type: 'error',
message: 'Model not initialized',
id: msg.id,
})
return
}

const { id, tokens, eloSelfs, eloOppos, batchSize } = msg

const feeds = {
tokens: new ORT.Tensor('float32', new Float32Array(tokens), [
batchSize,
64,
12,
]),
elo_self: new ORT.Tensor('float32', new Float32Array(eloSelfs), [
batchSize,
]),
elo_oppo: new ORT.Tensor('float32', new Float32Array(eloOppos), [
batchSize,
]),
}

const result = await session.run(feeds)

const logitsMove = new Float32Array(result.logits_move.data)
const logitsValue = new Float32Array(result.logits_value.data)

postMessage(
{
type: 'inference-result',
id,
logitsMove: logitsMove.buffer,
logitsValue: logitsValue.buffer,
},
[logitsMove.buffer, logitsValue.buffer],
)
break
}
}
} catch (err) {
postMessage({
type: 'error',
message: err.message || 'Unknown worker error',
id: msg.id,
})
}
}
Loading
Loading