Skip to content

Commit 8eeab85

Browse files
chore: developing the package
1 parent b8f9f16 commit 8eeab85

6 files changed

Lines changed: 188 additions & 38 deletions

File tree

embedder/index.js

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
11
import { mkdir, readFile, writeFile } from 'node:fs/promises'
22
import { dirname } from 'node:path'
3+
import { toCompressed, toBase64UrlString } from '@z-base/bytecodec'
34

45
const js = String.raw
56

67
const modelPath = './models/quantized_models/eng/model.int4.onnx'
78
const modelDataPath = './models/quantized_models/eng/model.int4.onnx.data'
89
const tokenizerModelPath = './models/quantized_models/eng/tokenizer.model'
910

10-
const outPath = './src/Model/class.ts'
11-
12-
function toUint8ArraySource(bytes) {
13-
return `new Uint8Array([${Array.from(bytes).join(',')}])`
14-
}
11+
const outPath = './src/models/index.ts'
1512

1613
const [model, modelData, tokenizerModel] = await Promise.all([
1714
readFile(modelPath),
@@ -22,21 +19,40 @@ const [model, modelData, tokenizerModel] = await Promise.all([
2219
const ts = js`
2320
import * as ort from 'onnxruntime-web'
2421
import { SentencePieceProcessor } from '@agnai/sentencepiece-js'
22+
import { fromCompressed, fromBase64UrlString } from '@z-base/bytecodec'
2523
2624
export async function createInferenceSession(): Promise<ort.InferenceSession> {
27-
return ort.InferenceSession.create(${toUint8ArraySource(model)}, {
28-
externalData: [
29-
{
30-
path: 'model.int4.onnx.data',
31-
data: ${toUint8ArraySource(modelData)},
32-
},
33-
],
34-
})
25+
return ort.InferenceSession.create(
26+
await fromCompressed(
27+
fromBase64UrlString(${JSON.stringify(toBase64UrlString(await toCompressed(model)))})
28+
),
29+
{
30+
externalData: [
31+
{
32+
path: 'model.int4.onnx.data',
33+
data: await fromCompressed(
34+
fromBase64UrlString(${JSON.stringify(toBase64UrlString(await toCompressed(modelData)))})
35+
),
36+
},
37+
],
38+
}
39+
)
3540
}
3641
37-
export async function createTokenProcessor():Promise<SentencePieceProcessor> {
42+
export async function createTokenProcessor(): Promise<SentencePieceProcessor> {
3843
const tokenProcessor = new SentencePieceProcessor()
39-
await tokenProcessor.load(${toUint8ArraySource(tokenizerModel)})
44+
45+
const tokenizerModelBytes = await fromCompressed(
46+
fromBase64UrlString(${JSON.stringify(toBase64UrlString(await toCompressed(tokenizerModel)))})
47+
)
48+
49+
const tokenizerModelBlobBytes = Uint8Array.from(tokenizerModelBytes)
50+
51+
const tokenizerModelUrl = URL.createObjectURL(
52+
new Blob([tokenizerModelBlobBytes], { type: 'application/octet-stream' })
53+
)
54+
55+
await tokenProcessor.load(tokenizerModelUrl)
4056
return tokenProcessor
4157
}
4258
`.trimStart()

src/.types/index.d.ts

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,17 @@ declare module '@agnai/sentencepiece-js' {
2020

2121
export interface SentencePieceProcessorBinding {
2222
Load(model: SentencePieceStringViewHandle): SentencePieceStatus
23-
EncodeAsIds(text: SentencePieceStringViewHandle): SentencePieceVector<number>
24-
EncodeAsPieces(text: SentencePieceStringViewHandle): SentencePieceVector<string>
23+
EncodeAsIds(
24+
text: SentencePieceStringViewHandle
25+
): SentencePieceVector<number>
26+
EncodeAsPieces(
27+
text: SentencePieceStringViewHandle
28+
): SentencePieceVector<string>
2529
DecodeIds(ids: SentencePieceVector<number>): string
26-
LoadVocabulary(vocab: SentencePieceStringViewHandle, threshold: number): SentencePieceStatus
30+
LoadVocabulary(
31+
vocab: SentencePieceStringViewHandle,
32+
threshold: number
33+
): SentencePieceStatus
2734
}
2835

2936
export interface SentencePieceModule {

src/Model/class.ts

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
export { createInferenceSession, createTokenProcessor } from './Model/class.js'
1+
export { createInferenceSession, createTokenProcessor } from './models/index.js'

src/models/index.ts

Lines changed: 38 additions & 0 deletions
Large diffs are not rendered by default.

test/index.js

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import * as ort from 'onnxruntime-web'
2+
import { cleanText } from '@agnai/sentencepiece-js'
3+
import { createInferenceSession, createTokenProcessor } from '../dist/index.js'
4+
5+
const MAX_INPUT_LENGTH = 256
6+
const MAX_GENERATION_LENGTH = 513
7+
const BOS_TOKEN_ID = 1
8+
const EOS_TOKEN_ID = 2
9+
10+
const input = cleanText(`- 3
11+
y TRADER JOE'S
12+
2001 Greenville Ave
13+
Dallas TX 75206
14+
Store #403 - (469) 334-0614
15+
OPEN 8:00AM TO 9:00PM DAILY
16+
R-CARROTS SHREDDED 10 0Z 1.29
17+
R-CUCUMBERS PERSIAN 1 LB 1.99
18+
TOMATOES CRUSHED NO SALT 1.59
19+
TOMATOES WHOLE NO SALT W/BASIL 1.59
20+
ORGANIC OLD_FASHIONED OATMEAL ~~ 2.69
21+
MINI-PEARL TOMATOES. . 2.49
22+
PKG SHREDDED MOZZARELLA LITET 3.9
23+
EGGS 1 DOZ ORGANIC BROWN. 3.79
24+
BEANS GARBANZO 0.89
25+
SPROUTED CA STYLE Zea
26+
A-AVOCADOS HASS BAG ACT 2:39
27+
A-APPLE BAG JAZZ 2 |B gr
28+
A-PEPPER BELL EACH XL RED 0.99
29+
GROCERY NON TAXABLE 0.98
30+
260.49
31+
BANANAS ORGANIC 0.87
32+
3kA 6 0.29/EA
33+
CREAMY SALTED PEANUT BUT TER 2.49
34+
WHL WHT PITA BREAD 1.69
35+
GROCERY NON TAXABLE 1.38
36+
260.69
37+
SUBTOTAL $38.68
38+
TOTAL $38.68
39+
CASH $40.00
40+
CHANGE $1.32
41+
ITEMS 22 Higgins, Ryan
42+
06-28-2014 12:34PM 0403 04 1346 4683
43+
THANK YOU FOR SHOPPING AT
44+
TRADER JOE'S
45+
www. trader joes .com
46+
`)
47+
48+
function toInt64Tensor(values, dims) {
49+
return new ort.Tensor('int64', BigInt64Array.from(values, BigInt), dims)
50+
}
51+
52+
function argmax(values) {
53+
let bestIndex = 0
54+
let bestValue = Number.NEGATIVE_INFINITY
55+
56+
for (let index = 0; index < values.length; index += 1) {
57+
if (values[index] > bestValue) {
58+
bestValue = values[index]
59+
bestIndex = index
60+
}
61+
}
62+
63+
return bestIndex
64+
}
65+
66+
function getNextTokenId(logits) {
67+
const [, targetLength, vocabSize] = logits.dims
68+
const offset = (targetLength - 1) * vocabSize
69+
const stepLogits = logits.data.subarray(offset, offset + vocabSize)
70+
71+
return argmax(stepLogits)
72+
}
73+
74+
const tokenizer = await createTokenProcessor()
75+
const session = await createInferenceSession()
76+
77+
const tokenIds = tokenizer.encodeIds(input).slice(0, MAX_INPUT_LENGTH)
78+
const attentionMask = tokenIds.map(() => 1)
79+
const decoderTokenIds = [BOS_TOKEN_ID]
80+
81+
for (let step = 0; step < MAX_GENERATION_LENGTH; step += 1) {
82+
const outputs = await session.run({
83+
input_ids: toInt64Tensor(tokenIds, [1, tokenIds.length]),
84+
attention_mask: toInt64Tensor(attentionMask, [1, attentionMask.length]),
85+
decoder_input_ids: toInt64Tensor(decoderTokenIds, [
86+
1,
87+
decoderTokenIds.length,
88+
]),
89+
})
90+
91+
const nextTokenId = getNextTokenId(outputs.logits)
92+
if (nextTokenId === EOS_TOKEN_ID) {
93+
break
94+
}
95+
96+
decoderTokenIds.push(nextTokenId)
97+
}
98+
99+
const outputTokenIds = decoderTokenIds.slice(1)
100+
const outputText = tokenizer.decodeIds(outputTokenIds)
101+
102+
console.log({
103+
inputLength: input.length,
104+
tokenCount: tokenIds.length,
105+
outputTokenCount: outputTokenIds.length,
106+
outputTokenIds,
107+
outputText,
108+
})

0 commit comments

Comments
 (0)