Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 113 additions & 29 deletions js/llama_webgpu_bridge.js
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,26 @@ function toUint8Array(value) {
return null;
}

/**
 * Removes characters at the end of a streamed decode that may still change
 * once more bytes arrive: any run of trailing U+FFFD replacement characters
 * (the decoder's marker for a truncated multi-byte UTF-8 sequence) and, after
 * that, a single dangling high (lead) surrogate awaiting its low half.
 *
 * @param {string} text - Decoded text from an in-progress generation stream.
 * @returns {string} The stable prefix of `text`; '' for non-string input.
 */
function trimUnstableUtf8Tail(text) {
  if (typeof text !== 'string' || text.length === 0) {
    return '';
  }
  // No /u flag: we deliberately match individual UTF-16 code units, so a lone
  // high surrogate at the very end is caught while a complete surrogate pair
  // (which ends in a LOW surrogate) is left intact.
  return text.replace(/\uFFFD+$/, '').replace(/[\uD800-\uDBFF]$/, '');
}

function toFloat32Array(value) {
if (!value) {
return null;
Expand Down Expand Up @@ -3831,7 +3851,8 @@ class LlamaWebGpuBridgeRuntime {
const shouldYieldForResponsiveness =
!(typeof WorkerGlobalScope !== 'undefined' && globalThis instanceof WorkerGlobalScope);
const yieldInterval = shouldYieldForResponsiveness ? 4 : 0;
let streamed = shouldEmitCurrentText ? '' : null;
let streamed = '';
let emittedStableText = '';

while (generated < nPredict) {
if (this._abortRequested || options.signal?.aborted) {
Expand Down Expand Up @@ -3888,19 +3909,25 @@ class LlamaWebGpuBridgeRuntime {
}

generated += 1;
const piece = this._core.ccall('llamadart_webgpu_last_piece', 'string', [], []) || '';
if (piece.length === 0) {
const fullText = this._core.ccall('llamadart_webgpu_last_output', 'string', [], []) || '';
streamed = fullText;
const stableText = trimUnstableUtf8Tail(fullText);

if (!stableText.startsWith(emittedStableText)) {
emittedStableText = '';
}

const deltaText = stableText.slice(emittedStableText.length);
if (deltaText.length === 0) {
continue;
}
emittedStableText = stableText;

if (typeof options.onToken === 'function') {
const piecePayload = emitTokenText ? piece : textEncoder.encode(piece);
if (shouldEmitCurrentText) {
streamed += piece;
options.onToken(piecePayload, streamed);
} else {
options.onToken(piecePayload, null);
}
const piecePayload = emitTokenText
? deltaText
: textEncoder.encode(deltaText);
options.onToken(piecePayload, shouldEmitCurrentText ? fullText : null);
}

if (yieldInterval > 0 && (generated % yieldInterval) === 0) {
Expand All @@ -3909,6 +3936,17 @@ class LlamaWebGpuBridgeRuntime {
}

const text = this._core.ccall('llamadart_webgpu_last_output', 'string', [], []) || streamed || '';
if (typeof options.onToken === 'function') {
const tailText = text.startsWith(emittedStableText)
? text.slice(emittedStableText.length)
: '';
if (tailText.length > 0) {
const piecePayload = emitTokenText
? tailText
: textEncoder.encode(tailText);
options.onToken(piecePayload, shouldEmitCurrentText ? text : null);
}
}
return text;
} finally {
if (generationStarted) {
Expand Down Expand Up @@ -4203,6 +4241,40 @@ export class LlamaWebGpuBridge {
return sanitized;
}

_createCpuSafeMultimodalLoadOptions(options = {}) {
const sanitized = this._sanitizeModelLoadOptions(options);
sanitized.nGpuLayers = 0;

if (Number.isFinite(Number(sanitized.nCtx)) && Number(sanitized.nCtx) > 4096) {
sanitized.nCtx = 4096;
}

if (!Number.isFinite(Number(sanitized.nThreads)) || Number(sanitized.nThreads) <= 0) {
sanitized.nThreads = 4;
} else {
sanitized.nThreads = Math.min(4, Math.max(1, Math.trunc(Number(sanitized.nThreads))));
}

sanitized.nThreadsBatch = sanitized.nThreads;

if (!Number.isFinite(Number(sanitized.nBatch)) || Number(sanitized.nBatch) <= 0) {
sanitized.nBatch = 128;
} else {
sanitized.nBatch = Math.min(128, Math.max(32, Math.trunc(Number(sanitized.nBatch))));
}

if (!Number.isFinite(Number(sanitized.nUbatch)) || Number(sanitized.nUbatch) <= 0) {
sanitized.nUbatch = Math.min(64, sanitized.nBatch);
} else {
sanitized.nUbatch = Math.min(
sanitized.nBatch,
Math.min(64, Math.max(1, Math.trunc(Number(sanitized.nUbatch)))),
);
}

return sanitized;
}

_rememberLoadedModel(url, options = {}) {
const normalizedUrl = String(url || '').trim();
if (normalizedUrl.length === 0) {
Expand Down Expand Up @@ -4277,7 +4349,9 @@ export class LlamaWebGpuBridge {
return false;
}

const selectedOptions = this._sanitizeModelLoadOptions(this._loadedModelOptions || {});
const selectedOptions = this._sanitizeModelLoadOptions(
this._loadedModelOptions || {},
);

const applyWorkerSafeMode = async () => {
await this._callWorker('loadModelFromUrl', [this._loadedModelUrl, selectedOptions]);
Expand Down Expand Up @@ -4381,12 +4455,29 @@ export class LlamaWebGpuBridge {
}

const forceReloadRequested = options?._llamadartForceRuntimeReload === true;
const mediaPartsRequested = this._hasMediaParts(options);
const shouldEnsureMultimodalInRuntime =
this._hasMediaParts(options)
mediaPartsRequested
&& typeof this._loadedMmProjUrl === 'string'
&& this._loadedMmProjUrl.length > 0;
const workerTimedOut = this._isWorkerTimeoutError(fallbackError);
const forcedCpuFallback = this._isForcedCpuMultimodalFallbackError(fallbackError);
const dispatchWorkgroupFallback = this._isDispatchWorkgroupLimitError(fallbackError);
const loadedGpuLayers = Number(this._loadedModelOptions?.nGpuLayers);
const metadataGpuLayers = Number(this._metadata?.['llamadart.webgpu.n_gpu_layers']);
const modelLoadedWithGpu = Number.isFinite(loadedGpuLayers)
? loadedGpuLayers !== 0
: (Number.isFinite(metadataGpuLayers) ? metadataGpuLayers !== 0 : true);
const shouldUseCpuMultimodalFallback =
mediaPartsRequested
&& modelLoadedWithGpu
&& (dispatchWorkgroupFallback || forcedCpuFallback || workerTimedOut);

if (Number(this._runtime?._modelBytes) > 0 && !forceReloadRequested) {
if (
Number(this._runtime?._modelBytes) > 0
&& !forceReloadRequested
&& !shouldUseCpuMultimodalFallback
) {
if (shouldEnsureMultimodalInRuntime) {
const runtimeSupportsMedia =
(typeof this._runtime.supportsVision === 'function' && this._runtime.supportsVision())
Expand All @@ -4407,25 +4498,18 @@ export class LlamaWebGpuBridge {
return;
}

const loadOptions = this._sanitizeModelLoadOptions(this._loadedModelOptions || {});
const workerTimedOut = this._isWorkerTimeoutError(fallbackError);
const forcedCpuFallback = this._isForcedCpuMultimodalFallbackError(fallbackError);
const forceCpuMultimodalFallback =
this._hasMediaParts(options)
&& (this._isDispatchWorkgroupLimitError(fallbackError)
|| forcedCpuFallback)
&& Number(loadOptions.nGpuLayers) !== 0;

if (forceCpuMultimodalFallback) {
loadOptions.nGpuLayers = 0;
if (Number.isFinite(loadOptions.nCtx) && Number(loadOptions.nCtx) > 4096) {
loadOptions.nCtx = 4096;
}

const loadOptions = shouldUseCpuMultimodalFallback
? this._createCpuSafeMultimodalLoadOptions(this._loadedModelOptions || {})
: this._sanitizeModelLoadOptions(this._loadedModelOptions || {});
if (shouldUseCpuMultimodalFallback) {
if (forcedCpuFallback) {
this._emitBridgeWarn(
'llamadart: using CPU fallback for multimodal generation stability.',
);
} else if (workerTimedOut) {
this._emitBridgeWarn(
'llamadart: retrying multimodal generation with CPU fallback after worker timeout.',
);
} else {
this._emitBridgeWarn(
'llamadart: retrying multimodal generation with CPU fallback after WebGPU workgroup limit failure.',
Expand All @@ -4448,7 +4532,7 @@ export class LlamaWebGpuBridge {
if (workerTimedOut) {
this._runtime._runtimeNotes.push('worker_fallback_timeout');
}
if (forceCpuMultimodalFallback) {
if (shouldUseCpuMultimodalFallback) {
this._runtime._runtimeNotes.push('worker_fallback_cpu_multimodal');
}
}
Expand Down
8 changes: 8 additions & 0 deletions src/llama_webgpu_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,14 @@ std::string normalize_media_markers(const std::string & prompt, const size_t med
replace_all_inplace(normalized, "<|image|>", marker);
replace_all_inplace(normalized, "<img>", marker);
replace_all_inplace(normalized, "<|img|>", marker);
replace_all_inplace(
normalized,
"<|vision_start|><|image_pad|><|vision_end|>",
marker);
replace_all_inplace(
normalized,
"<|vision_start|><|video_pad|><|vision_end|>",
marker);
replace_all_inplace(normalized, "<audio>", marker);
replace_all_inplace(normalized, "<|audio|>", marker);

Expand Down
Loading