Skip to content

Commit 4048425

Browse files
committed
Add web bridge native load option parity
1 parent 4a1abc4 commit 4048425

2 files changed

Lines changed: 259 additions & 8 deletions

File tree

js/llama_webgpu_bridge.js

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,44 @@ function parsePositiveInteger(value) {
108108
return Math.trunc(numeric);
109109
}
110110

111+
/**
 * Coerce an arbitrary value to a truncated integer.
 *
 * @param {*} value - Candidate value; anything `Number()` can consume.
 * @param {number} [fallback=0] - Returned when the value is not a finite number.
 * @returns {number} The integer part of the value, or the fallback.
 */
function parseInteger(value, fallback = 0) {
  const asNumber = Number(value);
  if (Number.isFinite(asNumber)) {
    return Math.trunc(asNumber);
  }
  return fallback;
}
118+
119+
/**
 * Interpret a value as a boolean flag.
 *
 * Booleans pass through unchanged; finite numbers map to their truthiness
 * (non-zero is true). Everything else yields the fallback.
 *
 * @param {*} value - Candidate flag value.
 * @param {boolean} [fallback=false] - Result for unusable inputs.
 * @returns {boolean} The interpreted flag.
 */
function parseBooleanFlag(value, fallback = false) {
  switch (typeof value) {
    case 'boolean':
      return value;
    case 'number':
      // NaN/Infinity are not meaningful flags; fall back instead.
      return Number.isFinite(value) ? value !== 0 : fallback;
    default:
      return fallback;
  }
}
128+
129+
/**
 * Interpret a value as a tri-state flag for the native layer.
 *
 * @param {*} value - Candidate flag value.
 * @returns {number} 1 (enabled), 0 (disabled), or -1 (unset/invalid — lets
 *   the native side pick its own default).
 */
function parseOptionalBooleanFlag(value) {
  const isUsableNumber = typeof value === 'number' && Number.isFinite(value);
  if (typeof value !== 'boolean' && !isUsableNumber) {
    return -1;
  }
  const enabled = typeof value === 'boolean' ? value : value !== 0;
  return enabled ? 1 : 0;
}
138+
139+
/**
 * Parse a value as an integer restricted to an allowed enum set.
 *
 * @param {*} value - Candidate value; coerced via `parseInteger`.
 * @param {Array<number>} allowed - The permitted integer values.
 * @param {number} fallback - Result when the value is unusable or not allowed.
 * @returns {number} A member of `allowed`, or the fallback.
 */
function parseEnumValue(value, allowed, fallback) {
  const candidate = parseInteger(value, fallback);
  if (allowed.includes(candidate)) {
    return candidate;
  }
  return fallback;
}
143+
144+
/**
 * Coerce a value to a strictly positive finite number.
 *
 * @param {*} value - Candidate value; anything `Number()` can consume.
 * @returns {number} The number when finite and > 0, otherwise 0 (the
 *   "unset" sentinel used by the native load options).
 */
function parsePositiveNumber(value) {
  const asNumber = Number(value);
  if (!Number.isFinite(asNumber) || asNumber <= 0) {
    return 0;
  }
  return asNumber;
}
148+
111149
function parseTotalFromContentRangeHeader(contentRangeHeader) {
112150
if (typeof contentRangeHeader !== 'string' || contentRangeHeader.length === 0) {
113151
return 0;
@@ -1309,6 +1347,17 @@ class LlamaWebGpuBridgeRuntime {
13091347
this._nGpuLayers = Number.isFinite(config.nGpuLayers)
13101348
? Number(config.nGpuLayers)
13111349
: -1;
1350+
this._nSeqMax = 0;
1351+
this._useMmap = false;
1352+
this._useMlock = false;
1353+
this._flashAttention = -1;
1354+
this._cacheTypeK = 1;
1355+
this._cacheTypeV = 1;
1356+
this._kvUnified = -1;
1357+
this._ropeFrequencyBase = 0;
1358+
this._ropeFrequencyScale = 0;
1359+
this._splitMode = -1;
1360+
this._mainGpu = -1;
13121361
this._isSafari = isSafariUserAgent(this._config.userAgent ?? globalThis.navigator?.userAgent ?? '');
13131362
this._coreVariant = 'uninitialized';
13141363
this._preferMemory64 = this._config.preferMemory64 !== false;
@@ -1943,6 +1992,70 @@ class LlamaWebGpuBridgeRuntime {
19431992
}
19441993
}
19451994

1995+
_resolveNativeLoadOptions(options = {}) {
1996+
this._nSeqMax = parsePositiveInteger(options.nSeqMax);
1997+
this._useMmap = parseBooleanFlag(options.useMmap, false);
1998+
this._useMlock = parseBooleanFlag(options.useMlock, false);
1999+
this._flashAttention = parseEnumValue(options.flashAttention, [-1, 0, 1], -1);
2000+
this._cacheTypeK = parseEnumValue(options.cacheTypeK, [1, 2, 8], 1);
2001+
this._cacheTypeV = parseEnumValue(options.cacheTypeV, [1, 2, 8], 1);
2002+
this._kvUnified = parseOptionalBooleanFlag(options.kvUnified);
2003+
this._ropeFrequencyBase = parsePositiveNumber(options.ropeFrequencyBase);
2004+
this._ropeFrequencyScale = parsePositiveNumber(options.ropeFrequencyScale);
2005+
this._splitMode = parseEnumValue(options.splitMode, [0, 1, 2, 3], -1);
2006+
this._mainGpu = parseInteger(options.mainGpu, -1);
2007+
if (this._mainGpu < 0) {
2008+
this._mainGpu = -1;
2009+
}
2010+
2011+
const wantsQuantizedKvCache = this._cacheTypeK !== 1 || this._cacheTypeV !== 1;
2012+
if (this._flashAttention === 0 && wantsQuantizedKvCache) {
2013+
throw new Error(
2014+
'Non-F16 KV cache requires flashAttention to be auto or enabled.',
2015+
);
2016+
}
2017+
if (this._flashAttention === -1 && wantsQuantizedKvCache) {
2018+
this._flashAttention = 1;
2019+
this._runtimeNotes.push('flash_attention:auto_enabled_for_kv_cache');
2020+
}
2021+
if (this._kvUnified < 0 && this._nSeqMax > 1) {
2022+
this._kvUnified = 1;
2023+
this._runtimeNotes.push('kv_unified:auto_enabled_for_sequences');
2024+
}
2025+
}
2026+
2027+
_nativeLoadOptionValues() {
2028+
return [
2029+
this._nSeqMax,
2030+
this._useMmap ? 1 : 0,
2031+
this._useMlock ? 1 : 0,
2032+
this._flashAttention,
2033+
this._cacheTypeK,
2034+
this._cacheTypeV,
2035+
this._kvUnified,
2036+
this._ropeFrequencyBase,
2037+
this._ropeFrequencyScale,
2038+
this._splitMode,
2039+
this._mainGpu,
2040+
];
2041+
}
2042+
2043+
_nativeLoadOptionTypes() {
2044+
return [
2045+
'number',
2046+
'number',
2047+
'number',
2048+
'number',
2049+
'number',
2050+
'number',
2051+
'number',
2052+
'number',
2053+
'number',
2054+
'number',
2055+
'number',
2056+
];
2057+
}
2058+
19462059
async _tryLoadModelFromRemoteFetchBackend(core, url, options = {}) {
19472060
if (!this._canUseRemoteFetchBackend(options)) {
19482061
return { loaded: false, sizeBytes: null };
@@ -2011,6 +2124,7 @@ class LlamaWebGpuBridgeRuntime {
20112124
'number',
20122125
'number',
20132126
'number',
2127+
...this._nativeLoadOptionTypes(),
20142128
],
20152129
[
20162130
remoteFetchUrl,
@@ -2021,6 +2135,7 @@ class LlamaWebGpuBridgeRuntime {
20212135
this._nUbatch,
20222136
this._nGpuLayers,
20232137
chunkBytes,
2138+
...this._nativeLoadOptionValues(),
20242139
],
20252140
{ async: true },
20262141
),
@@ -2906,6 +3021,8 @@ class LlamaWebGpuBridgeRuntime {
29063021
this._nUbatch = this._nBatch;
29073022
}
29083023

3024+
this._resolveNativeLoadOptions(options);
3025+
29093026
if (Number.isFinite(this._threadPoolSizeHint) && this._threadPoolSizeHint > 0) {
29103027
this._pushRuntimeNote(`thread_pool_size:${this._threadPoolSizeHint}`);
29113028
}
@@ -2927,6 +3044,9 @@ class LlamaWebGpuBridgeRuntime {
29273044
if (this._nUbatch > 0) {
29283045
this._pushRuntimeNote(`n_ubatch:${this._nUbatch}`);
29293046
}
3047+
if (this._nSeqMax > 0) {
3048+
this._pushRuntimeNote(`n_seq_max:${this._nSeqMax}`);
3049+
}
29303050
if (isCpuModelMode && !Number.isFinite(requestedBatch) && !Number.isFinite(requestedUbatch)) {
29313051
this._runtimeNotes.push('cpu_batch_tuned_default');
29323052
}
@@ -3154,7 +3274,16 @@ class LlamaWebGpuBridgeRuntime {
31543274
await core.ccall(
31553275
'llamadart_webgpu_load_model',
31563276
'number',
3157-
['string', 'number', 'number', 'number', 'number', 'number', 'number'],
3277+
[
3278+
'string',
3279+
'number',
3280+
'number',
3281+
'number',
3282+
'number',
3283+
'number',
3284+
'number',
3285+
...this._nativeLoadOptionTypes(),
3286+
],
31583287
[
31593288
this._modelPath,
31603289
this._nCtx,
@@ -3163,6 +3292,7 @@ class LlamaWebGpuBridgeRuntime {
31633292
this._nBatch,
31643293
this._nUbatch,
31653294
this._nGpuLayers,
3295+
...this._nativeLoadOptionValues(),
31663296
],
31673297
{ async: true },
31683298
),
@@ -3287,6 +3417,7 @@ class LlamaWebGpuBridgeRuntime {
32873417
'number',
32883418
'number',
32893419
'number',
3420+
...this._nativeLoadOptionTypes(),
32903421
],
32913422
[
32923423
reloadUrl,
@@ -3297,6 +3428,7 @@ class LlamaWebGpuBridgeRuntime {
32973428
this._nUbatch,
32983429
candidateLayers,
32993430
remoteFetchReloadChunkBytes,
3431+
...this._nativeLoadOptionValues(),
33003432
],
33013433
{ async: true },
33023434
),
@@ -3306,7 +3438,16 @@ class LlamaWebGpuBridgeRuntime {
33063438
await core.ccall(
33073439
'llamadart_webgpu_load_model',
33083440
'number',
3309-
['string', 'number', 'number', 'number', 'number', 'number', 'number'],
3441+
[
3442+
'string',
3443+
'number',
3444+
'number',
3445+
'number',
3446+
'number',
3447+
'number',
3448+
'number',
3449+
...this._nativeLoadOptionTypes(),
3450+
],
33103451
[
33113452
this._modelPath,
33123453
this._nCtx,
@@ -3315,6 +3456,7 @@ class LlamaWebGpuBridgeRuntime {
33153456
this._nBatch,
33163457
this._nUbatch,
33173458
candidateLayers,
3459+
...this._nativeLoadOptionValues(),
33183460
],
33193461
{ async: true },
33203462
),
@@ -4041,6 +4183,20 @@ class LlamaWebGpuBridgeRuntime {
40414183
'llamadart.webgpu.n_threads_batch': String(this._threadsBatch),
40424184
'llamadart.webgpu.n_batch': this._nBatch > 0 ? String(this._nBatch) : '',
40434185
'llamadart.webgpu.n_ubatch': this._nUbatch > 0 ? String(this._nUbatch) : '',
4186+
'llamadart.webgpu.n_seq_max': this._nSeqMax > 0 ? String(this._nSeqMax) : '',
4187+
'llamadart.webgpu.flash_attention': String(this._flashAttention),
4188+
'llamadart.webgpu.cache_type_k': String(this._cacheTypeK),
4189+
'llamadart.webgpu.cache_type_v': String(this._cacheTypeV),
4190+
'llamadart.webgpu.kv_unified':
4191+
this._kvUnified >= 0 ? String(this._kvUnified) : '',
4192+
'llamadart.webgpu.rope_freq_base':
4193+
this._ropeFrequencyBase > 0 ? String(this._ropeFrequencyBase) : '',
4194+
'llamadart.webgpu.rope_freq_scale':
4195+
this._ropeFrequencyScale > 0 ? String(this._ropeFrequencyScale) : '',
4196+
'llamadart.webgpu.split_mode':
4197+
this._splitMode >= 0 ? String(this._splitMode) : '',
4198+
'llamadart.webgpu.main_gpu':
4199+
this._mainGpu >= 0 ? String(this._mainGpu) : '',
40444200
'llamadart.webgpu.thread_pool_size':
40454201
Number.isFinite(this._threadPoolSizeHint) && this._threadPoolSizeHint > 0
40464202
? String(this._threadPoolSizeHint)

0 commit comments

Comments
 (0)