Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions src/lib/engine/stockfish.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { Chess } from 'chess.ts'
import { cpToWinrate } from 'src/lib'
import StockfishWeb from 'lila-stockfish-web'
import { StockfishEvaluation } from 'src/types'
import { StockfishModelStorage } from './stockfishStorage'

class Engine {
private fen: string
Expand Down Expand Up @@ -339,6 +340,27 @@ const sharedWasmMemory = (lo: number, hi = 32767): WebAssembly.Memory => {
}
}

const loadNnueModel = async (
modelUrl: string,
storage: StockfishModelStorage,
): Promise<ArrayBuffer> => {
const cachedModel = await storage.getModel(modelUrl)
if (cachedModel) {
return cachedModel
}

const response = await fetch(modelUrl)
if (!response.ok) {
throw new Error(
`Failed to fetch Stockfish NNUE model (${response.status}) from ${modelUrl}`,
)
}

const buffer = await response.arrayBuffer()
await storage.storeModel(modelUrl, buffer)
return buffer
}

const setupStockfish = (): Promise<StockfishWeb> => {
return new Promise<StockfishWeb>((resolve, reject) => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand All @@ -355,17 +377,17 @@ const setupStockfish = (): Promise<StockfishWeb> => {
const nnueBaseUrl =
process.env.NEXT_PUBLIC_STOCKFISH_NNUE_BASE_URL ??
'https://raw.githubusercontent.com/CSSLab/maia-platform-frontend/e23a50e/public/stockfish'
const storage = new StockfishModelStorage()
await storage.requestPersistentStorage()

const nnue0Url = `${nnueBaseUrl}/${instance.getRecommendedNnue(0)}`
const nnue1Url = `${nnueBaseUrl}/${instance.getRecommendedNnue(1)}`

// Load NNUE models before resolving
Promise.all([
fetch(`${nnueBaseUrl}/${instance.getRecommendedNnue(0)}`),
fetch(`${nnueBaseUrl}/${instance.getRecommendedNnue(1)}`),
loadNnueModel(nnue0Url, storage),
loadNnueModel(nnue1Url, storage),
])
.then((responses) => {
return Promise.all([
responses[0].arrayBuffer(),
responses[1].arrayBuffer(),
])
})
.then((buffers) => {
instance.setNnueBuffer(new Uint8Array(buffers[0]), 0)
instance.setNnueBuffer(new Uint8Array(buffers[1]), 1)
Expand Down
134 changes: 134 additions & 0 deletions src/lib/engine/stockfishStorage.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
interface NnueStorage {
id: string
url: string
data: Blob
timestamp: number
size: number
}

export class StockfishModelStorage {
private dbName = 'StockfishModels'
private storeName = 'models'
private version = 1
private db: IDBDatabase | null = null

async openDB(): Promise<IDBDatabase | null> {
if (typeof indexedDB === 'undefined') {
return null
}

if (this.db) return this.db

return new Promise((resolve, reject) => {
const request = indexedDB.open(this.dbName, this.version)

request.onerror = () => reject(request.error)
request.onsuccess = () => {
this.db = request.result
resolve(request.result)
}

request.onupgradeneeded = (event) => {
const db = (event.target as IDBOpenDBRequest).result
if (!db.objectStoreNames.contains(this.storeName)) {
const store = db.createObjectStore(this.storeName, { keyPath: 'id' })
store.createIndex('timestamp', 'timestamp', { unique: false })
}
}
})
}

async getModel(modelUrl: string): Promise<ArrayBuffer | null> {
try {
const db = await this.openDB()
if (!db) return null

const transaction = db.transaction([this.storeName], 'readonly')
const store = transaction.objectStore(this.storeName)

const modelData = await new Promise<NnueStorage | null>(
(resolve, reject) => {
const request = store.get(modelUrl)
request.onsuccess = () => resolve(request.result || null)
request.onerror = () => reject(request.error)
},
)

if (!modelData) {
return null
}

if (modelData.url !== modelUrl) {
await this.deleteModel(modelUrl)
return null
}

return modelData.data.arrayBuffer()
} catch (error) {
console.warn('Stockfish cache read failed:', error)
return null
}
}

async storeModel(modelUrl: string, buffer: ArrayBuffer): Promise<void> {
try {
const db = await this.openDB()
if (!db) return

const transaction = db.transaction([this.storeName], 'readwrite')
const store = transaction.objectStore(this.storeName)

const modelData: NnueStorage = {
id: modelUrl,
url: modelUrl,
data: new Blob([buffer]),
timestamp: Date.now(),
size: buffer.byteLength,
}

await new Promise<void>((resolve, reject) => {
const request = store.put(modelData)
request.onsuccess = () => resolve()
request.onerror = () => reject(request.error)
})
} catch (error) {
console.warn('Stockfish cache write failed:', error)
}
}

async deleteModel(modelUrl: string): Promise<void> {
try {
const db = await this.openDB()
if (!db) return

const transaction = db.transaction([this.storeName], 'readwrite')
const store = transaction.objectStore(this.storeName)

await new Promise<void>((resolve, reject) => {
const request = store.delete(modelUrl)
request.onsuccess = () => resolve()
request.onerror = () => reject(request.error)
})
} catch (error) {
console.warn('Stockfish cache delete failed:', error)
}
}

async requestPersistentStorage(): Promise<boolean> {
try {
if (
typeof navigator !== 'undefined' &&
'storage' in navigator &&
'persist' in navigator.storage
) {
return navigator.storage.persist()
}
return false
} catch (error) {
console.warn('Failed to request persistent storage:', error)
return false
}
}
}

export default StockfishModelStorage
Loading