Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 35 additions & 23 deletions public/maia-worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -116,38 +116,50 @@ self.onmessage = async (e) => {
}

case 'download': {
postMessage({ type: 'status', status: 'downloading' })
postMessage({ type: 'progress', progress: 0 })
const response = await fetch(modelUrl)
if (!response.ok) throw new Error('Failed to fetch model')

const reader = response.body.getReader()
const contentLength = +(response.headers.get('Content-Length') || 0)
const chunks = []
let receivedLength = 0
let lastReportedProgress = 0

while (true) {
const { done, value } = await reader.read()
if (done) break
chunks.push(value)
receivedLength += value.length
const currentProgress = Math.floor(
(receivedLength / contentLength) * 100,
)
if (currentProgress >= lastReportedProgress + 10) {
postMessage({ type: 'progress', progress: currentProgress })
lastReportedProgress = currentProgress
let buffer

if (response.body && typeof response.body.getReader === 'function') {
const reader = response.body.getReader()
const contentLength = +(response.headers.get('Content-Length') || 0)
const chunks = []
let receivedLength = 0
let lastReportedProgress = 0

while (true) {
const { done, value } = await reader.read()
if (done) break
chunks.push(value)
receivedLength += value.length

if (contentLength > 0) {
const currentProgress = Math.floor(
(receivedLength / contentLength) * 100,
)
if (currentProgress >= lastReportedProgress + 10) {
postMessage({ type: 'progress', progress: currentProgress })
lastReportedProgress = currentProgress
}
}
}
}

const buffer = new Uint8Array(receivedLength)
let position = 0
for (const chunk of chunks) {
buffer.set(chunk, position)
position += chunk.length
buffer = new Uint8Array(receivedLength)
let position = 0
for (const chunk of chunks) {
buffer.set(chunk, position)
position += chunk.length
}
} else {
buffer = new Uint8Array(await response.arrayBuffer())
}

await storeModel(modelUrl, modelVersion, buffer.buffer)
await initSession(buffer.buffer)
postMessage({ type: 'progress', progress: 100 })
postMessage({ type: 'status', status: 'ready' })
break
}
Expand Down
12 changes: 8 additions & 4 deletions src/components/Common/DownloadModelModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,18 @@ export const DownloadModelModal: React.FC<Props> = ({
</div>

<div className="mt-4 flex w-full flex-col items-end justify-end gap-2 md:mt-6 md:flex-row">
{progress ? (
{isDownloading || progress > 0 ? (
<div className="relative order-2 flex h-8 w-full items-center overflow-hidden rounded-md border border-glass-border bg-glass px-3 md:order-1 md:h-10 md:flex-1">
<p className="z-10 text-xs text-white/90 md:text-sm">
{Math.round(progress)}%
{progress > 0
? `${Math.round(progress)}%`
: 'Starting download...'}
</p>
<div
className="absolute left-0 top-0 z-0 h-full rounded-l-md bg-human-4 transition-all duration-500 ease-out"
style={{ width: `${progress}%` }}
className={`absolute left-0 top-0 z-0 h-full rounded-l-md bg-human-4 transition-all duration-500 ease-out ${
progress === 0 ? 'animate-pulse' : ''
}`}
style={{ width: `${progress > 0 ? progress : 12}%` }}
/>
</div>
) : null}
Expand Down
35 changes: 33 additions & 2 deletions src/lib/engine/maia.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,18 @@ interface PendingInference {
reject: (error: Error) => void
}

interface PendingDownload {
resolve: () => void
reject: (error: Error) => void
}

class Maia {
private worker: Worker | null = null
private options: MaiaOptions
private storage: MaiaModelStorage
private pendingInferences: Map<number, PendingInference> = new Map()
private pendingDownload: PendingDownload | null = null
private downloadPromise: Promise<void> | null = null
private nextRequestId = 0

constructor(options: MaiaOptions) {
Expand All @@ -50,6 +57,12 @@ class Maia {
switch (msg.type) {
case 'status':
this.options.setStatus(msg.status)
if (msg.status === 'ready') {
this.options.setProgress(100)
this.pendingDownload?.resolve()
this.pendingDownload = null
this.downloadPromise = null
}
break

case 'progress':
Expand All @@ -66,6 +79,9 @@ class Maia {
} else {
this.options.setError(msg.message)
this.options.setStatus('error')
this.pendingDownload?.reject(new Error(msg.message))
this.pendingDownload = null
this.downloadPromise = null
}
break
}
Expand All @@ -86,16 +102,31 @@ class Maia {

this.worker.onerror = (err) => {
console.error('Maia worker error:', err)
this.options.setError(err.message || 'Worker crashed')
const error = new Error(err.message || 'Worker crashed')
this.options.setError(error.message)
this.options.setStatus('error')
this.pendingDownload?.reject(error)
this.pendingDownload = null
this.downloadPromise = null
}

this.worker.postMessage({ type: 'init', modelUrl, modelVersion })
}

public async downloadModel() {
if (!this.worker) throw new Error('Worker not initialized')
this.worker.postMessage({ type: 'download' })
if (this.downloadPromise) {
return this.downloadPromise
}

this.options.setProgress(0)

this.downloadPromise = new Promise<void>((resolve, reject) => {
this.pendingDownload = { resolve, reject }
this.worker!.postMessage({ type: 'download' })
})

return this.downloadPromise
}

public async getStorageInfo() {
Expand Down
71 changes: 64 additions & 7 deletions src/lib/engine/stockfish.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import {
import { StockfishModelStorage } from './stockfishStorage'

const DEFAULT_NNUE_FETCH_TIMEOUT_MS = 30000
const DEFAULT_STOCKFISH_MODULE_INIT_TIMEOUT_MS = 15000
const STOCKFISH_CACHE_LOOKUP_TIMEOUT_MS = 1500
const STOCKFISH_CACHE_WRITE_TIMEOUT_MS = 5000
type StockfishInitPhase =
| 'idle'
| 'loading-module'
Expand Down Expand Up @@ -1784,6 +1787,39 @@ const fetchWithTimeout = async (
}
}

const withTimeout = async <T>(
promise: Promise<T>,
timeoutMs: number,
message: string,
): Promise<T> => {
if (timeoutMs <= 0) {
return promise
}

let timeoutId: ReturnType<typeof setTimeout> | null = null

try {
return await Promise.race([
promise,
new Promise<T>((_, reject) => {
timeoutId = setTimeout(() => reject(new Error(message)), timeoutMs)
}),
])
} finally {
if (timeoutId) {
clearTimeout(timeoutId)
}
}
}

const getStockfishModuleInitTimeoutMs = (): number => {
const raw = process.env.NEXT_PUBLIC_STOCKFISH_MODULE_INIT_TIMEOUT_MS
const parsed = raw ? parseInt(raw, 10) : NaN
return Number.isFinite(parsed) && parsed > 0
? parsed
: DEFAULT_STOCKFISH_MODULE_INIT_TIMEOUT_MS
}

const loadNnueModel = async (
modelUrl: string,
storage: StockfishModelStorage,
Expand All @@ -1792,7 +1828,14 @@ const loadNnueModel = async (
forceRefresh = false,
): Promise<ArrayBuffer> => {
if (!forceRefresh) {
const cachedModel = await storage.getModel(modelUrl)
const cachedModel = await withTimeout(
storage.getModel(modelUrl),
STOCKFISH_CACHE_LOOKUP_TIMEOUT_MS,
`Timed out while checking cached Stockfish model: ${modelUrl}`,
).catch((error) => {
console.warn(error)
return null
})
if (cachedModel) {
return cachedModel
}
Expand All @@ -1807,7 +1850,13 @@ const loadNnueModel = async (
}

const buffer = await response.arrayBuffer()
await storage.storeModel(modelUrl, buffer)
void withTimeout(
storage.storeModel(modelUrl, buffer),
STOCKFISH_CACHE_WRITE_TIMEOUT_MS,
`Timed out while caching Stockfish model: ${modelUrl}`,
).catch((error) => {
console.warn(error)
})
return buffer
}

Expand Down Expand Up @@ -1839,16 +1888,24 @@ const setupStockfish = async (
process.env.NEXT_PUBLIC_STOCKFISH_NNUE_BASE_URL ??
'https://raw.githubusercontent.com/CSSLab/maia-platform-frontend/e23a50e/public/stockfish'
const storage = new StockfishModelStorage()
await storage.requestPersistentStorage()
const timeoutMs = getNnueFetchTimeoutMs()
const moduleInitTimeoutMs = getStockfishModuleInitTimeoutMs()
let nnueUrls: [string, string] | null = null

void storage.requestPersistentStorage().catch((error) => {
console.warn('Failed to request persistent storage for Stockfish:', error)
})

const createInstance = async (): Promise<StockfishWeb> => {
onPhaseChange?.('loading-module')
return makeModule.default({
wasmMemory: sharedWasmMemory(2560),
locateFile: (name: string) => `/stockfish/${name}`,
})
return withTimeout(
makeModule.default({
wasmMemory: sharedWasmMemory(2560),
locateFile: (name: string) => `/stockfish/${name}`,
}),
moduleInitTimeoutMs,
`Stockfish module initialization timed out after ${moduleInitTimeoutMs}ms`,
)
}

const loadWeightsIntoInstance = async (
Expand Down
2 changes: 1 addition & 1 deletion src/types/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export interface MaiaEngine {
maia?: Maia
status: MaiaStatus
progress: number
downloadModel: () => void
downloadModel: () => Promise<void>
}

export interface StockfishEngine {
Expand Down
Loading