-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[Web] Improve large tensor loading in wasm runtime #19771
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -172,6 +172,7 @@ class RuntimeContext implements Disposable { | |||||||||||||||||||
| tensorCacheRemove: PackedFunc; | ||||||||||||||||||||
| tensorCacheClear: PackedFunc; | ||||||||||||||||||||
| arrayDecodeStorage: PackedFunc; | ||||||||||||||||||||
| storageSizeBytes: PackedFunc; | ||||||||||||||||||||
| paramModuleFromCache: PackedFunc; | ||||||||||||||||||||
| paramModuleFromCacheByName: PackedFunc; | ||||||||||||||||||||
| makeShapeTuple: PackedFunc; | ||||||||||||||||||||
|
|
@@ -207,6 +208,7 @@ class RuntimeContext implements Disposable { | |||||||||||||||||||
| this.tensorCacheUpdate = getGlobalFunc("vm.builtin.tensor_cache.update"); | ||||||||||||||||||||
| this.tensorCacheClear = getGlobalFunc("vm.builtin.tensor_cache.clear"); | ||||||||||||||||||||
| this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage"); | ||||||||||||||||||||
| this.storageSizeBytes = getGlobalFunc("tvmjs.runtime.StorageSizeBytes"); | ||||||||||||||||||||
| this.paramModuleFromCache = getGlobalFunc("vm.builtin.param_module_from_cache"); | ||||||||||||||||||||
| this.paramModuleFromCacheByName = getGlobalFunc("vm.builtin.param_module_from_cache_by_name"); | ||||||||||||||||||||
| this.makeShapeTuple = getGlobalFunc("ffi.Shape"); | ||||||||||||||||||||
|
|
@@ -230,6 +232,7 @@ class RuntimeContext implements Disposable { | |||||||||||||||||||
| this.tensorCacheRemove.dispose(); | ||||||||||||||||||||
| this.tensorCacheUpdate.dispose(); | ||||||||||||||||||||
| this.arrayDecodeStorage.dispose(); | ||||||||||||||||||||
| this.storageSizeBytes.dispose(); | ||||||||||||||||||||
| this.paramModuleFromCache.dispose(); | ||||||||||||||||||||
| this.paramModuleFromCacheByName.dispose(); | ||||||||||||||||||||
| this.makeShapeTuple.dispose(); | ||||||||||||||||||||
|
|
@@ -1010,9 +1013,11 @@ export class Instance implements Disposable { | |||||||||||||||||||
| */ | ||||||||||||||||||||
| withNewScope<T>(action: () => T): T { | ||||||||||||||||||||
| this.beginScope(); | ||||||||||||||||||||
| const val = action(); | ||||||||||||||||||||
| this.endScope(); | ||||||||||||||||||||
| return val; | ||||||||||||||||||||
| try { | ||||||||||||||||||||
| return action(); | ||||||||||||||||||||
| } finally { | ||||||||||||||||||||
| this.endScope(); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
|
|
||||||||||||||||||||
| /** | ||||||||||||||||||||
|
|
@@ -1323,6 +1328,19 @@ export class Instance implements Disposable { | |||||||||||||||||||
| artifactCache: ArtifactCacheTemplate, | ||||||||||||||||||||
| signal?: AbortSignal, | ||||||||||||||||||||
| ) { | ||||||||||||||||||||
| // Avoid a single JS-to-wasm byte-array call for multi-hundred-MiB | ||||||||||||||||||||
| // tensor-cache records. The cap is a conservative per-call staging size, | ||||||||||||||||||||
| // independent of the final tensor allocation size. Smaller records keep | ||||||||||||||||||||
| // the existing full-record path. | ||||||||||||||||||||
| const maxChunkBytes = 128 * 1024 * 1024; | ||||||||||||||||||||
| const storageSizeBytes = (numElements: number, dtype: string): number | undefined => { | ||||||||||||||||||||
| try { | ||||||||||||||||||||
| return this.ctx.storageSizeBytes(new Scalar(numElements, "int"), dtype) as number; | ||||||||||||||||||||
| } catch { | ||||||||||||||||||||
| // Unknown dtypes can still use the original full-record loading path. | ||||||||||||||||||||
| return undefined; | ||||||||||||||||||||
| } | ||||||||||||||||||||
| }; | ||||||||||||||||||||
| const perf = compact.getPerformance(); | ||||||||||||||||||||
| const tstart = perf.now(); | ||||||||||||||||||||
| let totalBytes = 0; | ||||||||||||||||||||
|
|
@@ -1421,9 +1439,68 @@ export class Instance implements Disposable { | |||||||||||||||||||
| this.empty(rec.shape, rec.dtype, this.cpu()) | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| }); | ||||||||||||||||||||
| const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes); | ||||||||||||||||||||
| const shardBytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer); | ||||||||||||||||||||
| const recSource = | ||||||||||||||||||||
| rec.byteOffset === 0 && rec.nbytes === shardBytes.byteLength | ||||||||||||||||||||
| ? shardBytes | ||||||||||||||||||||
| : shardBytes.subarray(rec.byteOffset, rec.byteOffset + rec.nbytes); | ||||||||||||||||||||
| let canChunkRecord = | ||||||||||||||||||||
| rec.nbytes > maxChunkBytes && | ||||||||||||||||||||
| rec.shape.length >= 1 && | ||||||||||||||||||||
| Number.isInteger(rec.shape[0]) && | ||||||||||||||||||||
| rec.shape[0] > 0 && | ||||||||||||||||||||
| rec.nbytes % rec.shape[0] === 0; | ||||||||||||||||||||
| const outerDim = canChunkRecord ? rec.shape[0] : 1; | ||||||||||||||||||||
| const sourceStrideBytes = canChunkRecord ? rec.nbytes / outerDim : rec.nbytes; | ||||||||||||||||||||
| let targetStrideBytes = 0; | ||||||||||||||||||||
| if (canChunkRecord) { | ||||||||||||||||||||
| const numElements = rec.shape.reduce((acc, value) => acc * value, 1); | ||||||||||||||||||||
| const targetBytes = storageSizeBytes(numElements, rec.dtype); | ||||||||||||||||||||
| canChunkRecord = | ||||||||||||||||||||
| sourceStrideBytes <= maxChunkBytes && | ||||||||||||||||||||
| targetBytes !== undefined && | ||||||||||||||||||||
| targetBytes % outerDim === 0; | ||||||||||||||||||||
| if (canChunkRecord) { | ||||||||||||||||||||
| targetStrideBytes = targetBytes / outerDim; | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
| const copyRecordToTensor = (targetTensor: Tensor, sourceBytes: Uint8Array) => { | ||||||||||||||||||||
| if (!canChunkRecord) { | ||||||||||||||||||||
| this.ctx.arrayDecodeStorage(targetTensor, sourceBytes, rec.format, rec.dtype); | ||||||||||||||||||||
| return; | ||||||||||||||||||||
| } | ||||||||||||||||||||
| const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / sourceStrideBytes)); | ||||||||||||||||||||
| for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) { | ||||||||||||||||||||
| const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset); | ||||||||||||||||||||
| const sourceByteOffset = outerOffset * sourceStrideBytes; | ||||||||||||||||||||
| const targetByteOffset = outerOffset * targetStrideBytes; | ||||||||||||||||||||
| const chunkBytes = outerCount * sourceStrideBytes; | ||||||||||||||||||||
| const chunkShape = rec.shape.slice(); | ||||||||||||||||||||
| chunkShape[0] = outerCount; | ||||||||||||||||||||
| const chunkView = this.withNewScope(() => { | ||||||||||||||||||||
| const chunkShapeTuple = this.makeShapeTuple(chunkShape); | ||||||||||||||||||||
| return this.detachFromCurrentScope( | ||||||||||||||||||||
| this.ctx.tensorCreateView( | ||||||||||||||||||||
| targetTensor, | ||||||||||||||||||||
| chunkShapeTuple, | ||||||||||||||||||||
| rec.dtype, | ||||||||||||||||||||
| new Scalar(targetByteOffset, "int"), | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| ); | ||||||||||||||||||||
| }); | ||||||||||||||||||||
| const chunkSource = sourceBytes.subarray( | ||||||||||||||||||||
| sourceByteOffset, | ||||||||||||||||||||
| sourceByteOffset + chunkBytes, | ||||||||||||||||||||
| ); | ||||||||||||||||||||
| try { | ||||||||||||||||||||
| this.ctx.arrayDecodeStorage(chunkView, chunkSource, rec.format, rec.dtype); | ||||||||||||||||||||
| } finally { | ||||||||||||||||||||
| chunkView.dispose(); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
| }; | ||||||||||||||||||||
| // first sync copy to cpu. | ||||||||||||||||||||
| this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype); | ||||||||||||||||||||
| copyRecordToTensor(cpu_arr, recSource); | ||||||||||||||||||||
| // then async stream into GPU if needed | ||||||||||||||||||||
| if (device.deviceType === DeviceStrToEnum.cpu) { | ||||||||||||||||||||
| this.tensorCacheUpdate(rec.name, cpu_arr, false); | ||||||||||||||||||||
|
|
@@ -1435,7 +1512,42 @@ export class Instance implements Disposable { | |||||||||||||||||||
| this.empty(rec.shape, rec.dtype, device) | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| }); | ||||||||||||||||||||
| gpu_arr.copyFrom(cpu_arr); | ||||||||||||||||||||
| if (!canChunkRecord) { | ||||||||||||||||||||
| gpu_arr.copyFrom(cpu_arr); | ||||||||||||||||||||
| } else { | ||||||||||||||||||||
| const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / sourceStrideBytes)); | ||||||||||||||||||||
| for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) { | ||||||||||||||||||||
| const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset); | ||||||||||||||||||||
| const targetByteOffset = outerOffset * targetStrideBytes; | ||||||||||||||||||||
| const chunkShape = rec.shape.slice(); | ||||||||||||||||||||
| chunkShape[0] = outerCount; | ||||||||||||||||||||
| const [cpuView, gpuView] = this.withNewScope(() => { | ||||||||||||||||||||
| const chunkShapeTuple = this.makeShapeTuple(chunkShape); | ||||||||||||||||||||
| const cView = this.ctx.tensorCreateView( | ||||||||||||||||||||
| cpu_arr, | ||||||||||||||||||||
| chunkShapeTuple, | ||||||||||||||||||||
| rec.dtype, | ||||||||||||||||||||
| new Scalar(targetByteOffset, "int"), | ||||||||||||||||||||
| ); | ||||||||||||||||||||
| const gView = this.ctx.tensorCreateView( | ||||||||||||||||||||
| gpu_arr, | ||||||||||||||||||||
| chunkShapeTuple, | ||||||||||||||||||||
| rec.dtype, | ||||||||||||||||||||
| new Scalar(targetByteOffset, "int"), | ||||||||||||||||||||
| ); | ||||||||||||||||||||
| return [ | ||||||||||||||||||||
| this.detachFromCurrentScope(cView), | ||||||||||||||||||||
| this.detachFromCurrentScope(gView), | ||||||||||||||||||||
| ]; | ||||||||||||||||||||
| }); | ||||||||||||||||||||
|
Comment on lines
+1524
to
+1542
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the GPU copy path, const [cpuView, gpuView] = this.withNewScope(() => {
const cView = this.ctx.tensorCreateView(
cpu_arr,
chunkShapeTuple,
rec.dtype,
new Scalar(targetByteOffset, "int"),
);
const gView = this.ctx.tensorCreateView(
gpu_arr,
chunkShapeTuple,
rec.dtype,
new Scalar(targetByteOffset, "int"),
);
return [
this.detachFromCurrentScope(cView),
this.detachFromCurrentScope(gView),
];
}); |
||||||||||||||||||||
| try { | ||||||||||||||||||||
| gpuView.copyFrom(cpuView); | ||||||||||||||||||||
| } finally { | ||||||||||||||||||||
| cpuView.dispose(); | ||||||||||||||||||||
| gpuView.dispose(); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
| } | ||||||||||||||||||||
| await device.sync(); | ||||||||||||||||||||
| this.tensorCacheUpdate(rec.name, gpu_arr, false); | ||||||||||||||||||||
| cpu_arr.dispose(); | ||||||||||||||||||||
|
|
@@ -2258,6 +2370,28 @@ export class Instance implements Disposable { | |||||||||||||||||||
| case TypeIndex.kTVMFFIOpaquePtr: { | ||||||||||||||||||||
| return this.memory.loadPointer(valuePtr); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| case TypeIndex.kTVMFFIShape: { | ||||||||||||||||||||
| const shapeObjPtr = this.memory.loadPointer(valuePtr); | ||||||||||||||||||||
| if (shapeObjPtr === 0) { | ||||||||||||||||||||
| return null; | ||||||||||||||||||||
| } | ||||||||||||||||||||
| if (callbackArg) { | ||||||||||||||||||||
|
Comment on lines
+2373
to
+2378
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a defensive null check for
Suggested change
|
||||||||||||||||||||
| const shapeCellPtr = shapeObjPtr + SizeOf.ObjectHeader; | ||||||||||||||||||||
| const shapeDataPtr = this.memory.loadPointer(shapeCellPtr); | ||||||||||||||||||||
| const shapeLen = this.memory.loadUSize(shapeCellPtr + this.memory.sizeofPtr()); | ||||||||||||||||||||
| const result = new Array<number>(shapeLen); | ||||||||||||||||||||
| for (let i = 0; i < shapeLen; ++i) { | ||||||||||||||||||||
| result[i] = this.memory.loadI64(shapeDataPtr + i * SizeOf.I64); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| this.lib.checkCall( | ||||||||||||||||||||
| (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(shapeObjPtr) | ||||||||||||||||||||
| ); | ||||||||||||||||||||
| return result; | ||||||||||||||||||||
| } | ||||||||||||||||||||
| return this.ctx.attachToCurrentScope( | ||||||||||||||||||||
| new TVMObject(shapeObjPtr, this.lib, this.ctx) | ||||||||||||||||||||
| ); | ||||||||||||||||||||
| } | ||||||||||||||||||||
| case TypeIndex.kTVMFFITensor: { | ||||||||||||||||||||
| return this.ctx.attachToCurrentScope( | ||||||||||||||||||||
| new Tensor(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false) | ||||||||||||||||||||
|
|
||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how is the number determined, would be good to have a sense of what WebGPU runtime supports, the motivation here is not as clear