Skip to content

Commit aad8486

Browse files
committed
feat(models): add INT8 quantization option for AI model builds
Add --int8 flag to allow choosing between INT8 and INT4 quantization levels.

- INT8: Better compatibility, ~50% size reduction (using dynamic quantization)
- INT4: Maximum compression, ~75% size reduction (using MatMulNBitsQuantizer, default)

Implementation:
- Parse --int8 flag, default to INT4 for backward compatibility
- Update quantizeModel to support both quantization methods
- Update file naming to include quantization level (e.g., minilm-l6-int8.onnx)
- Update checkpoint keys to separate INT8 and INT4 builds
- Update copyToDist to use dynamic file names based on quantization level

Usage:
  node packages/models/scripts/build.mjs          # INT4 (default)
  node packages/models/scripts/build.mjs --int8   # INT8
  node packages/models/scripts/build.mjs --int4   # INT4 (explicit)
1 parent 216ad50 commit aad8486

File tree

1 file changed

+79
-46
lines changed

1 file changed

+79
-46
lines changed

packages/models/scripts/build.mjs

Lines changed: 79 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,22 @@
33
/**
44
* Build script for @socketsecurity/models.
55
*
6-
* Downloads AI models from Hugging Face, converts to ONNX, and applies INT4 quantization.
6+
* Downloads AI models from Hugging Face, converts to ONNX, and applies quantization.
77
*
88
* Workflow:
99
* 1. Download models from Hugging Face (with fallbacks)
1010
* 2. Convert to ONNX if needed
11-
* 3. Apply INT4 quantization for maximum compression (99.8% size reduction)
11+
* 3. Apply quantization (INT4 or INT8) for compression
1212
* 4. Output quantized ONNX models
13+
*
14+
* Options:
15+
* --int8 Use INT8 quantization (better compatibility, ~50% size reduction)
16+
* --int4 Use INT4 quantization (maximum compression, ~75% size reduction, default)
17+
* --minilm Build MiniLM-L6 model only
18+
* --codet5 Build CodeT5 model only
19+
* --all Build all models
20+
* --force Force rebuild even if checkpoints exist
21+
* --clean Clean all checkpoints before building
1322
*/
1423

1524
import { existsSync } from 'node:fs'
@@ -48,6 +57,9 @@ const NO_SELF_UPDATE = args.includes('--no-self-update')
4857
const BUILD_MINILM = args.includes('--minilm') || (!args.includes('--codet5') && !args.includes('--all'))
4958
const BUILD_CODET5 = args.includes('--codet5') || args.includes('--all')
5059

60+
// Quantization level (default: INT4 for maximum compression).
61+
const QUANT_LEVEL = args.includes('--int8') ? 'INT8' : 'INT4'
62+
5163
const __filename = fileURLToPath(import.meta.url)
5264
const __dirname = dirname(__filename)
5365
const ROOT = join(__dirname, '..')
@@ -182,40 +194,44 @@ async function convertToOnnx(modelKey) {
182194
}
183195

184196
/**
185-
* Apply INT4 quantization for maximum compression.
197+
* Apply quantization for compression.
198+
*
199+
* Supports two quantization levels:
200+
* - INT4: MatMulNBitsQuantizer with RTN weight-only quantization (maximum compression).
201+
* - INT8: Dynamic quantization (better compatibility, moderate compression).
186202
*
187-
* Uses MatMulNBitsQuantizer with RTN (Round To Nearest) weight-only quantization:
188-
* - Converts MatMul operators to MatMulNBits (INT4).
189-
* - Results in significant size reduction (e.g., 86MB → ~20MB for MiniLM).
190-
* - Model remains fully functional with minimal accuracy loss.
203+
* Results in significant size reduction with minimal accuracy loss.
191204
*/
192-
async function quantizeModel(modelKey) {
193-
if (!(await shouldRun(PACKAGE_NAME, `quantized-${modelKey}`, FORCE_BUILD))) {
205+
async function quantizeModel(modelKey, quantLevel) {
206+
const suffix = quantLevel.toLowerCase()
207+
const checkpointKey = `quantized-${modelKey}-${suffix}`
208+
209+
if (!(await shouldRun(PACKAGE_NAME, checkpointKey, FORCE_BUILD))) {
194210
// Return existing quantized paths.
195211
const modelDir = join(MODELS, modelKey)
196212
if (modelKey === 'codet5') {
197213
return [
198-
join(modelDir, 'encoder_model.int4.onnx'),
199-
join(modelDir, 'decoder_model.int4.onnx')
214+
join(modelDir, `encoder_model.${suffix}.onnx`),
215+
join(modelDir, `decoder_model.${suffix}.onnx`)
200216
]
201217
}
202-
return [join(modelDir, 'model.int4.onnx')]
218+
return [join(modelDir, `model.${suffix}.onnx`)]
203219
}
204220

205-
logger.step(`Applying INT4 quantization to ${modelKey}`)
221+
logger.step(`Applying ${quantLevel} quantization to ${modelKey}`)
206222

207223
const modelDir = join(MODELS, modelKey)
208224

209225
// Different files for codet5 (encoder/decoder) vs minilm (single model).
210226
const models = modelKey === 'codet5'
211227
? [
212-
{ input: 'encoder_model.onnx', output: 'encoder_model.int4.onnx' },
213-
{ input: 'decoder_model.onnx', output: 'decoder_model.int4.onnx' }
228+
{ input: 'encoder_model.onnx', output: `encoder_model.${suffix}.onnx` },
229+
{ input: 'decoder_model.onnx', output: `decoder_model.${suffix}.onnx` }
214230
]
215-
: [{ input: 'model.onnx', output: 'model.int4.onnx' }]
231+
: [{ input: 'model.onnx', output: `model.${suffix}.onnx` }]
216232

217233
const quantizedPaths = []
218-
let method = 'INT4'
234+
let method = quantLevel
219235

220236
for (const { input, output } of models) {
221237
const onnxPath = join(modelDir, input)
@@ -230,19 +246,31 @@ async function quantizeModel(modelKey) {
230246
let quantSize
231247

232248
try {
233-
await execAsync(
234-
`python3 -c "` +
235-
`from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer, RTNWeightOnlyQuantConfig; ` +
236-
`from onnxruntime.quantization import quant_utils; ` +
237-
`from pathlib import Path; ` +
238-
`quant_config = RTNWeightOnlyQuantConfig(); ` +
239-
`model = quant_utils.load_model_with_shape_infer(Path('${onnxPath}')); ` +
240-
`quant = MatMulNBitsQuantizer(model, algo_config=quant_config); ` +
241-
`quant.process(); ` +
242-
`quant.model.save_model_to_file('${quantPath}', True)` +
243-
`"`,
244-
{ stdio: 'inherit' }
245-
)
249+
if (quantLevel === 'INT8') {
250+
// INT8: Use dynamic quantization (simpler, more compatible).
251+
await execAsync(
252+
`python3 -c "` +
253+
`from onnxruntime.quantization import quantize_dynamic, QuantType; ` +
254+
`quantize_dynamic('${onnxPath}', '${quantPath}', weight_type=QuantType.QUInt8)` +
255+
`"`,
256+
{ stdio: 'inherit' }
257+
)
258+
} else {
259+
// INT4: Use MatMulNBitsQuantizer (maximum compression).
260+
await execAsync(
261+
`python3 -c "` +
262+
`from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer, RTNWeightOnlyQuantConfig; ` +
263+
`from onnxruntime.quantization import quant_utils; ` +
264+
`from pathlib import Path; ` +
265+
`quant_config = RTNWeightOnlyQuantConfig(); ` +
266+
`model = quant_utils.load_model_with_shape_infer(Path('${onnxPath}')); ` +
267+
`quant = MatMulNBitsQuantizer(model, algo_config=quant_config); ` +
268+
`quant.process(); ` +
269+
`quant.model.save_model_to_file('${quantPath}', True)` +
270+
`"`,
271+
{ stdio: 'inherit' }
272+
)
273+
}
246274

247275
// Get sizes.
248276
originalSize = (await readFile(onnxPath)).length
@@ -251,7 +279,7 @@ async function quantizeModel(modelKey) {
251279

252280
logger.substep(`${input}: ${(originalSize / 1024 / 1024).toFixed(2)} MB → ${(quantSize / 1024 / 1024).toFixed(2)} MB (${savings}% savings)`)
253281
} catch (e) {
254-
logger.warn(`INT4 quantization failed for ${input}, using FP32 model: ${e.message}`)
282+
logger.warn(`${quantLevel} quantization failed for ${input}, using FP32 model: ${e.message}`)
255283
// Copy the original ONNX model as the "quantized" version.
256284
await copyFile(onnxPath, quantPath)
257285
method = 'FP32'
@@ -263,9 +291,10 @@ async function quantizeModel(modelKey) {
263291
}
264292

265293
logger.success(`Quantized to ${method}`)
266-
await createCheckpoint(PACKAGE_NAME, `quantized-${modelKey}`, {
294+
await createCheckpoint(PACKAGE_NAME, checkpointKey, {
267295
modelKey,
268296
method,
297+
quantLevel,
269298
})
270299

271300
return quantizedPaths
@@ -274,26 +303,27 @@ async function quantizeModel(modelKey) {
274303
/**
275304
* Copy quantized models and tokenizers to dist.
276305
*/
277-
async function copyToDist(modelKey, quantizedPaths) {
306+
async function copyToDist(modelKey, quantizedPaths, quantLevel) {
278307
logger.step('Copying models to dist')
279308

280309
await mkdir(DIST, { recursive: true })
281310

282311
const modelDir = join(MODELS, modelKey)
312+
const suffix = quantLevel.toLowerCase()
283313

284314
if (modelKey === 'codet5') {
285315
// CodeT5: encoder, decoder, tokenizer.
286-
await copyFile(quantizedPaths[0], join(DIST, 'codet5-encoder.onnx'))
287-
await copyFile(quantizedPaths[1], join(DIST, 'codet5-decoder.onnx'))
316+
await copyFile(quantizedPaths[0], join(DIST, `codet5-encoder-${suffix}.onnx`))
317+
await copyFile(quantizedPaths[1], join(DIST, `codet5-decoder-${suffix}.onnx`))
288318
await copyFile(join(modelDir, 'tokenizer.json'), join(DIST, 'codet5-tokenizer.json'))
289319

290-
logger.success('Copied codet5 models to dist/')
320+
logger.success(`Copied codet5 models (${quantLevel}) to dist/`)
291321
} else {
292322
// MiniLM: single model + tokenizer.
293-
await copyFile(quantizedPaths[0], join(DIST, 'minilm-l6.onnx'))
323+
await copyFile(quantizedPaths[0], join(DIST, `minilm-l6-${suffix}.onnx`))
294324
await copyFile(join(modelDir, 'tokenizer.json'), join(DIST, 'minilm-l6-tokenizer.json'))
295325

296-
logger.success('Copied minilm-l6 models to dist/')
326+
logger.success(`Copied minilm-l6 model (${quantLevel}) to dist/`)
297327
}
298328
}
299329

@@ -303,12 +333,15 @@ async function copyToDist(modelKey, quantizedPaths) {
303333
async function main() {
304334
logger.info('Building @socketsecurity/models')
305335
logger.info('='.repeat(60))
336+
logger.info(`Quantization: ${QUANT_LEVEL}`)
306337
logger.info('')
307338

308339
const startTime = Date.now()
309340

341+
const suffix = QUANT_LEVEL.toLowerCase()
342+
310343
// Clean checkpoints if requested or if output is missing.
311-
const outputMissing = !existsSync(join(DIST, 'minilm-l6.onnx')) && !existsSync(join(DIST, 'codet5-encoder.onnx'))
344+
const outputMissing = !existsSync(join(DIST, `minilm-l6-${suffix}.onnx`)) && !existsSync(join(DIST, `codet5-encoder-${suffix}.onnx`))
312345

313346
if (CLEAN_BUILD || outputMissing) {
314347
if (outputMissing) {
@@ -330,8 +363,8 @@ async function main() {
330363

331364
await downloadModel('minilm-l6')
332365
await convertToOnnx('minilm-l6')
333-
const quantizedPaths = await quantizeModel('minilm-l6')
334-
await copyToDist('minilm-l6', quantizedPaths)
366+
const quantizedPaths = await quantizeModel('minilm-l6', QUANT_LEVEL)
367+
await copyToDist('minilm-l6', quantizedPaths, QUANT_LEVEL)
335368
}
336369

337370
// Build CodeT5 if requested.
@@ -342,8 +375,8 @@ async function main() {
342375

343376
await downloadModel('codet5')
344377
await convertToOnnx('codet5')
345-
const quantizedPaths = await quantizeModel('codet5')
346-
await copyToDist('codet5', quantizedPaths)
378+
const quantizedPaths = await quantizeModel('codet5', QUANT_LEVEL)
379+
await copyToDist('codet5', quantizedPaths, QUANT_LEVEL)
347380
}
348381

349382
const duration = ((Date.now() - startTime) / 1000).toFixed(1)
@@ -357,12 +390,12 @@ async function main() {
357390
logger.substep(`Output: ${DIST}`)
358391

359392
if (BUILD_MINILM) {
360-
logger.substep(' - minilm-l6.onnx (INT4 quantized)')
393+
logger.substep(` - minilm-l6-${suffix}.onnx (${QUANT_LEVEL} quantized)`)
361394
logger.substep(' - minilm-l6-tokenizer.json')
362395
}
363396
if (BUILD_CODET5) {
364-
logger.substep(' - codet5-encoder.onnx (INT4 quantized)')
365-
logger.substep(' - codet5-decoder.onnx (INT4 quantized)')
397+
logger.substep(` - codet5-encoder-${suffix}.onnx (${QUANT_LEVEL} quantized)`)
398+
logger.substep(` - codet5-decoder-${suffix}.onnx (${QUANT_LEVEL} quantized)`)
366399
logger.substep(' - codet5-tokenizer.json')
367400
}
368401
} catch (error) {

0 commit comments

Comments
 (0)