diff --git a/public/maia-worker.js b/public/maia-worker.js index 7378b0f1..a048ebb4 100644 --- a/public/maia-worker.js +++ b/public/maia-worker.js @@ -116,38 +116,50 @@ self.onmessage = async (e) => { } case 'download': { + postMessage({ type: 'status', status: 'downloading' }) + postMessage({ type: 'progress', progress: 0 }) const response = await fetch(modelUrl) if (!response.ok) throw new Error('Failed to fetch model') - const reader = response.body.getReader() - const contentLength = +(response.headers.get('Content-Length') || 0) - const chunks = [] - let receivedLength = 0 - let lastReportedProgress = 0 - - while (true) { - const { done, value } = await reader.read() - if (done) break - chunks.push(value) - receivedLength += value.length - const currentProgress = Math.floor( - (receivedLength / contentLength) * 100, - ) - if (currentProgress >= lastReportedProgress + 10) { - postMessage({ type: 'progress', progress: currentProgress }) - lastReportedProgress = currentProgress + let buffer + + if (response.body && typeof response.body.getReader === 'function') { + const reader = response.body.getReader() + const contentLength = +(response.headers.get('Content-Length') || 0) + const chunks = [] + let receivedLength = 0 + let lastReportedProgress = 0 + + while (true) { + const { done, value } = await reader.read() + if (done) break + chunks.push(value) + receivedLength += value.length + + if (contentLength > 0) { + const currentProgress = Math.floor( + (receivedLength / contentLength) * 100, + ) + if (currentProgress >= lastReportedProgress + 10) { + postMessage({ type: 'progress', progress: currentProgress }) + lastReportedProgress = currentProgress + } + } } - } - const buffer = new Uint8Array(receivedLength) - let position = 0 - for (const chunk of chunks) { - buffer.set(chunk, position) - position += chunk.length + buffer = new Uint8Array(receivedLength) + let position = 0 + for (const chunk of chunks) { + buffer.set(chunk, position) + position += chunk.length + } + } else { + buffer = new Uint8Array(await response.arrayBuffer()) } await storeModel(modelUrl, modelVersion, buffer.buffer) await initSession(buffer.buffer) + postMessage({ type: 'progress', progress: 100 }) postMessage({ type: 'status', status: 'ready' }) break } diff --git a/src/components/Common/DownloadModelModal.tsx b/src/components/Common/DownloadModelModal.tsx index 1c473147..819f39c0 100644 --- a/src/components/Common/DownloadModelModal.tsx +++ b/src/components/Common/DownloadModelModal.tsx @@ -126,14 +126,18 @@ export const DownloadModelModal: React.FC = ({
- {progress ? ( + {isDownloading || progress > 0 ? (

- {Math.round(progress)}% + {progress > 0 + ? `${Math.round(progress)}%` + : 'Starting download...'}

0 ? progress : 12}%` }} />
) : null} diff --git a/src/lib/engine/maia.ts b/src/lib/engine/maia.ts index 8d7695dc..0520ece9 100644 --- a/src/lib/engine/maia.ts +++ b/src/lib/engine/maia.ts @@ -24,11 +24,18 @@ interface PendingInference { reject: (error: Error) => void } +interface PendingDownload { + resolve: () => void + reject: (error: Error) => void +} + class Maia { private worker: Worker | null = null private options: MaiaOptions private storage: MaiaModelStorage private pendingInferences: Map = new Map() + private pendingDownload: PendingDownload | null = null + private downloadPromise: Promise | null = null private nextRequestId = 0 constructor(options: MaiaOptions) { @@ -50,6 +57,12 @@ class Maia { switch (msg.type) { case 'status': this.options.setStatus(msg.status) + if (msg.status === 'ready') { + this.options.setProgress(100) + this.pendingDownload?.resolve() + this.pendingDownload = null + this.downloadPromise = null + } break case 'progress': @@ -66,6 +79,9 @@ class Maia { } else { this.options.setError(msg.message) this.options.setStatus('error') + this.pendingDownload?.reject(new Error(msg.message)) + this.pendingDownload = null + this.downloadPromise = null } break } @@ -86,8 +102,12 @@ class Maia { this.worker.onerror = (err) => { console.error('Maia worker error:', err) - this.options.setError(err.message || 'Worker crashed') + const error = new Error(err.message || 'Worker crashed') + this.options.setError(error.message) this.options.setStatus('error') + this.pendingDownload?.reject(error) + this.pendingDownload = null + this.downloadPromise = null } this.worker.postMessage({ type: 'init', modelUrl, modelVersion }) @@ -95,7 +115,18 @@ class Maia { public async downloadModel() { if (!this.worker) throw new Error('Worker not initialized') - this.worker.postMessage({ type: 'download' }) + if (this.downloadPromise) { + return this.downloadPromise + } + + this.options.setProgress(0) + + this.downloadPromise = new Promise((resolve, reject) => { + this.pendingDownload = { resolve, reject } + this.worker!.postMessage({ type: 'download' }) + }) + + return this.downloadPromise } public async getStorageInfo() { diff --git a/src/lib/engine/stockfish.ts b/src/lib/engine/stockfish.ts index c0074c74..fdbeb869 100644 --- a/src/lib/engine/stockfish.ts +++ b/src/lib/engine/stockfish.ts @@ -11,6 +11,9 @@ import { import { StockfishModelStorage } from './stockfishStorage' const DEFAULT_NNUE_FETCH_TIMEOUT_MS = 30000 +const DEFAULT_STOCKFISH_MODULE_INIT_TIMEOUT_MS = 15000 +const STOCKFISH_CACHE_LOOKUP_TIMEOUT_MS = 1500 +const STOCKFISH_CACHE_WRITE_TIMEOUT_MS = 5000 type StockfishInitPhase = | 'idle' | 'loading-module' @@ -1784,6 +1787,39 @@ const fetchWithTimeout = async ( } } +const withTimeout = async ( + promise: Promise, + timeoutMs: number, + message: string, +): Promise => { + if (timeoutMs <= 0) { + return promise + } + + let timeoutId: ReturnType | null = null + + try { + return await Promise.race([ + promise, + new Promise((_, reject) => { + timeoutId = setTimeout(() => reject(new Error(message)), timeoutMs) + }), + ]) + } finally { + if (timeoutId) { + clearTimeout(timeoutId) + } + } +} + +const getStockfishModuleInitTimeoutMs = (): number => { + const raw = process.env.NEXT_PUBLIC_STOCKFISH_MODULE_INIT_TIMEOUT_MS + const parsed = raw ? parseInt(raw, 10) : NaN + return Number.isFinite(parsed) && parsed > 0 + ? parsed + : DEFAULT_STOCKFISH_MODULE_INIT_TIMEOUT_MS +} + const loadNnueModel = async ( modelUrl: string, storage: StockfishModelStorage, @@ -1792,7 +1828,14 @@ const loadNnueModel = async ( forceRefresh = false, ): Promise => { if (!forceRefresh) { - const cachedModel = await storage.getModel(modelUrl) + const cachedModel = await withTimeout( + storage.getModel(modelUrl), + STOCKFISH_CACHE_LOOKUP_TIMEOUT_MS, + `Timed out while checking cached Stockfish model: ${modelUrl}`, + ).catch((error) => { + console.warn(error) + return null + }) if (cachedModel) { return cachedModel } @@ -1807,7 +1850,13 @@ const loadNnueModel = async ( } const buffer = await response.arrayBuffer() - await storage.storeModel(modelUrl, buffer) + void withTimeout( + storage.storeModel(modelUrl, buffer), + STOCKFISH_CACHE_WRITE_TIMEOUT_MS, + `Timed out while caching Stockfish model: ${modelUrl}`, + ).catch((error) => { + console.warn(error) + }) return buffer } @@ -1839,16 +1888,24 @@ const setupStockfish = async ( process.env.NEXT_PUBLIC_STOCKFISH_NNUE_BASE_URL ?? 'https://raw.githubusercontent.com/CSSLab/maia-platform-frontend/e23a50e/public/stockfish' const storage = new StockfishModelStorage() - await storage.requestPersistentStorage() const timeoutMs = getNnueFetchTimeoutMs() + const moduleInitTimeoutMs = getStockfishModuleInitTimeoutMs() let nnueUrls: [string, string] | null = null + void storage.requestPersistentStorage().catch((error) => { + console.warn('Failed to request persistent storage for Stockfish:', error) + }) + const createInstance = async (): Promise => { onPhaseChange?.('loading-module') - return makeModule.default({ - wasmMemory: sharedWasmMemory(2560), - locateFile: (name: string) => `/stockfish/${name}`, - }) + return withTimeout( + makeModule.default({ + wasmMemory: sharedWasmMemory(2560), + locateFile: (name: string) => `/stockfish/${name}`, + }), + moduleInitTimeoutMs, + `Stockfish module initialization timed out after ${moduleInitTimeoutMs}ms`, + ) } const loadWeightsIntoInstance = async ( diff --git a/src/types/engine.ts b/src/types/engine.ts index 57bb9fbe..447431b1 100644 --- a/src/types/engine.ts +++ b/src/types/engine.ts @@ -26,7 +26,7 @@ export interface MaiaEngine { maia?: Maia status: MaiaStatus progress: number - downloadModel: () => void + downloadModel: () => Promise } export interface StockfishEngine {