Skip to content

Commit 2b8353b

Browse files
committed
Maia3 integration: tensor + maia formatting fixes
1 parent 0c27543 commit 2b8353b

2 files changed

Lines changed: 26 additions & 14 deletions

File tree

src/lib/engine/maia.ts

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,11 @@ class Maia {
166166

167167
const outs = await this.model.run(feeds)
168168

169-
const logitsMove = pickOutput(outs, ['logits_move', 'logits_maia', 'policy'])
169+
const logitsMove = pickOutput(outs, [
170+
'logits_move',
171+
'logits_maia',
172+
'policy',
173+
])
170174
const logitsValue = pickOutput(outs, ['logits_value', 'value'])
171175

172176
if (!logitsMove || !logitsValue) {
@@ -185,7 +189,11 @@ class Maia {
185189
return { policy, value }
186190
}
187191

188-
async batchEvaluate(boards: string[], eloSelfs: number[], eloOppos: number[]) {
192+
async batchEvaluate(
193+
boards: string[],
194+
eloSelfs: number[],
195+
eloOppos: number[],
196+
) {
189197
if (!this.model) throw new Error('Maia model not initialized')
190198

191199
const batchSize = boards.length
@@ -226,7 +234,11 @@ class Maia {
226234
const outs = await this.model.run(feeds)
227235
const end = performance.now()
228236

229-
const logitsMove = pickOutput(outs, ['logits_move', 'logits_maia', 'policy'])
237+
const logitsMove = pickOutput(outs, [
238+
'logits_move',
239+
'logits_maia',
240+
'policy',
241+
])
230242
const logitsValue = pickOutput(outs, ['logits_value', 'value'])
231243

232244
if (!logitsMove || !logitsValue) {
@@ -353,13 +365,14 @@ function pickOutput(
353365
const dims = (t as any).dims as number[] | undefined
354366
return dims && dims.length === 2 && dims[1] === 3
355367
})
356-
if (preferredNames.includes('logits_value') || preferredNames.includes('value')) {
368+
if (
369+
preferredNames.includes('logits_value') ||
370+
preferredNames.includes('value')
371+
) {
357372
if (valueCandidate) return valueCandidate
358373
}
359374

360-
const policyCandidate = vals
361-
.slice()
362-
.sort((a, b) => b.size - a.size)[0]
375+
const policyCandidate = vals.slice().sort((a, b) => b.size - a.size)[0]
363376

364377
return policyCandidate ?? null
365378
}
@@ -381,9 +394,11 @@ function makeScalarTensor(
381394
if (t.includes('float')) {
382395
return new Tensor('float32', Float32Array.from([value]), [1])
383396
}
384-
return new Tensor('int64', BigInt64Array.from([BigInt(Math.trunc(value))]), [
385-
1,
386-
])
397+
return new Tensor(
398+
'int64',
399+
BigInt64Array.from([BigInt(Math.trunc(value))]),
400+
[1],
401+
)
387402
}
388403

389404
function makeVectorTensor(

src/lib/engine/tensor.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@ import { Chess, Move } from 'chess.ts'
1818
Constants
1919
========================================================= */
2020

21-
const PIECE_ORDER = [
22-
'P', 'N', 'B', 'R', 'Q', 'K',
23-
'p', 'n', 'b', 'r', 'q', 'k',
24-
]
21+
const PIECE_ORDER = ['P', 'N', 'B', 'R', 'Q', 'K', 'p', 'n', 'b', 'r', 'q', 'k']
2522

2623
export const POLICY_SIZE = 4352
2724
const PROMO_ORDER = ['q', 'r', 'b', 'n']

0 commit comments

Comments
 (0)