Skip to content

Commit 9c4334d

Browse files
committed
[Web] Improve large tensor loading in wasm runtime
1 parent 4b7b7de commit 9c4334d

2 files changed

Lines changed: 153 additions & 26 deletions

File tree

web/emcc/wasm_runtime.cc

Lines changed: 34 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,25 @@
3232
#include <tvm/ffi/reflection/registry.h>
3333
#include <tvm/runtime/logging.h>
3434

35+
// FFI core must come before runtime .cc includes in this single translation
36+
// unit. Otherwise, static initialisation can resolve ffi globals before
37+
// object.cc registers them, leading to crashes at module init time.
38+
#include "3rdparty/tvm-ffi/src/ffi/backtrace.cc"
39+
#include "3rdparty/tvm-ffi/src/ffi/container.cc"
40+
#include "3rdparty/tvm-ffi/src/ffi/dtype.cc"
41+
#include "3rdparty/tvm-ffi/src/ffi/error.cc"
42+
#include "3rdparty/tvm-ffi/src/ffi/function.cc"
43+
#include "3rdparty/tvm-ffi/src/ffi/object.cc"
44+
#include "3rdparty/tvm-ffi/src/ffi/tensor.cc"
45+
#include "3rdparty/tvm-ffi/src/ffi/extra/env_c_api.cc"
46+
#include "3rdparty/tvm-ffi/src/ffi/extra/env_context.cc"
47+
#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc"
48+
#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc"
49+
#include "3rdparty/tvm-ffi/src/ffi/extra/library_module.cc"
50+
#include "3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc"
51+
#include "3rdparty/tvm-ffi/src/ffi/extra/module.cc"
52+
#include "3rdparty/tvm-ffi/src/ffi/testing/testing.cc"
53+
3554
#include "src/runtime/cpu_device_api.cc"
3655
#include "src/runtime/device_api.cc"
3756
#include "src/runtime/extra/contrib/sort/sort.cc"
@@ -46,22 +65,6 @@
4665
#include "src/runtime/tensor.cc"
4766
#include "src/runtime/timer.cc"
4867
#include "src/runtime/workspace_pool.cc"
49-
// relax setup
50-
#include "3rdparty/tvm-ffi/src/ffi/backtrace.cc"
51-
#include "3rdparty/tvm-ffi/src/ffi/container.cc"
52-
#include "3rdparty/tvm-ffi/src/ffi/dtype.cc"
53-
#include "3rdparty/tvm-ffi/src/ffi/error.cc"
54-
#include "3rdparty/tvm-ffi/src/ffi/extra/env_c_api.cc"
55-
#include "3rdparty/tvm-ffi/src/ffi/extra/env_context.cc"
56-
#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc"
57-
#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc"
58-
#include "3rdparty/tvm-ffi/src/ffi/extra/library_module.cc"
59-
#include "3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc"
60-
#include "3rdparty/tvm-ffi/src/ffi/extra/module.cc"
61-
#include "3rdparty/tvm-ffi/src/ffi/function.cc"
62-
#include "3rdparty/tvm-ffi/src/ffi/object.cc"
63-
#include "3rdparty/tvm-ffi/src/ffi/tensor.cc"
64-
#include "3rdparty/tvm-ffi/src/ffi/testing/testing.cc"
6568
#include "src/runtime/memory/memory_manager.cc"
6669
#include "src/runtime/vm/attn_backend.cc"
6770
#include "src/runtime/vm/builtin.cc"
@@ -130,20 +133,28 @@ void ArrayDecodeStorage(Tensor cpu_arr, TVMFFIByteArray* bytes, const std::strin
130133
const char* byte_data = bytes->data;
131134
const size_t byte_size = bytes->size;
132135
if (format == "f32-to-bf16" && dtype == "float32") {
133-
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
134-
uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
135136
TVM_FFI_ICHECK(cpu_arr.IsContiguous());
136137
size_t size = 1;
137138
for (int i = 0; i < cpu_arr->ndim; ++i) {
138139
size *= cpu_arr->shape[i];
139140
}
140-
TVM_FFI_ICHECK_EQ(size, byte_size / 2);
141-
for (size_t i = 0; i < size; ++i) {
142-
data[i] = static_cast<uint32_t>(bf16[i]) << 16;
141+
// The "f32-to-bf16" format encodes a float32 tensor as packed bf16 (2
142+
// bytes per element). When the byte_size matches that expectation, expand
143+
// back to f32. If the byte_size matches the native float32 width
144+
// (4 bytes per element), the payload is already raw float32; fall through
145+
// to the generic byte copy. This makes the loader tolerant of weight
146+
// shards produced by older / alternate quantisation pipelines that retain
147+
// the "f32-to-bf16" tag without performing the bf16 truncation.
148+
if (size == byte_size / 2) {
149+
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
150+
uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
151+
for (size_t i = 0; i < size; ++i) {
152+
data[i] = static_cast<uint32_t>(bf16[i]) << 16;
153+
}
154+
return;
143155
}
144-
} else {
145-
cpu_arr.CopyFromBytes(byte_data, byte_size);
146156
}
157+
cpu_arr.CopyFromBytes(byte_data, byte_size);
147158
}
148159

149160
TVM_FFI_STATIC_INIT_BLOCK() {

web/src/runtime.ts

Lines changed: 119 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1323,6 +1323,16 @@ export class Instance implements Disposable {
13231323
artifactCache: ArtifactCacheTemplate,
13241324
signal?: AbortSignal,
13251325
) {
1326+
const maxChunkBytes = 128 * 1024 * 1024;
1327+
const storageBytes = (dtype: string) => {
1328+
const match = dtype.match(/(\d+)(?:x(\d+))?$/);
1329+
if (match === null) {
1330+
throw new Error("Cannot determine storage width of dtype " + dtype);
1331+
}
1332+
const bits = Number(match[1]);
1333+
const lanes = match[2] === undefined ? 1 : Number(match[2]);
1334+
return (bits * lanes + 7) >> 3;
1335+
};
13261336
const perf = compact.getPerformance();
13271337
const tstart = perf.now();
13281338
let totalBytes = 0;
@@ -1421,9 +1431,59 @@ export class Instance implements Disposable {
14211431
this.empty(rec.shape, rec.dtype, this.cpu())
14221432
)
14231433
});
1424-
const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes);
1434+
const shardBytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer);
1435+
const recSource =
1436+
rec.byteOffset === 0 && rec.nbytes === shardBytes.byteLength
1437+
? shardBytes
1438+
: shardBytes.subarray(rec.byteOffset, rec.byteOffset + rec.nbytes);
1439+
const canChunkRecord =
1440+
rec.nbytes > maxChunkBytes &&
1441+
rec.shape.length >= 1 &&
1442+
Number.isInteger(rec.shape[0]) &&
1443+
rec.shape[0] > 0 &&
1444+
rec.nbytes % rec.shape[0] === 0;
1445+
const outerDim = canChunkRecord ? rec.shape[0] : 1;
1446+
const sourceStrideBytes = canChunkRecord ? rec.nbytes / outerDim : rec.nbytes;
1447+
const targetBytes = rec.shape.reduce((acc, value) => acc * value, 1) *
1448+
storageBytes(rec.dtype);
1449+
const targetStrideBytes = canChunkRecord ? targetBytes / outerDim : targetBytes;
1450+
const copyRecordToTensor = (targetTensor: Tensor, sourceBytes: Uint8Array) => {
1451+
if (!canChunkRecord) {
1452+
this.ctx.arrayDecodeStorage(targetTensor, sourceBytes, rec.format, rec.dtype);
1453+
return;
1454+
}
1455+
const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / sourceStrideBytes));
1456+
for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) {
1457+
const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset);
1458+
const sourceByteOffset = outerOffset * sourceStrideBytes;
1459+
const targetByteOffset = outerOffset * targetStrideBytes;
1460+
const chunkBytes = outerCount * sourceStrideBytes;
1461+
const chunkShape = rec.shape.slice();
1462+
chunkShape[0] = outerCount;
1463+
const chunkShapeTuple = this.makeShapeTuple(chunkShape);
1464+
const chunkView = this.withNewScope(() => {
1465+
return this.detachFromCurrentScope(
1466+
this.ctx.tensorCreateView(
1467+
targetTensor,
1468+
chunkShapeTuple,
1469+
rec.dtype,
1470+
new Scalar(targetByteOffset, "int"),
1471+
)
1472+
);
1473+
});
1474+
const chunkSource = sourceBytes.subarray(
1475+
sourceByteOffset,
1476+
sourceByteOffset + chunkBytes,
1477+
);
1478+
try {
1479+
this.ctx.arrayDecodeStorage(chunkView, chunkSource, rec.format, rec.dtype);
1480+
} finally {
1481+
chunkView.dispose();
1482+
}
1483+
}
1484+
};
14251485
// first sync copy to cpu.
1426-
this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype);
1486+
copyRecordToTensor(cpu_arr, recSource);
14271487
// then async stream into GPU if needed
14281488
if (device.deviceType === DeviceStrToEnum.cpu) {
14291489
this.tensorCacheUpdate(rec.name, cpu_arr, false);
@@ -1435,7 +1495,44 @@ export class Instance implements Disposable {
14351495
this.empty(rec.shape, rec.dtype, device)
14361496
)
14371497
});
1438-
gpu_arr.copyFrom(cpu_arr);
1498+
if (!canChunkRecord) {
1499+
gpu_arr.copyFrom(cpu_arr);
1500+
} else {
1501+
const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / sourceStrideBytes));
1502+
for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) {
1503+
const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset);
1504+
const targetByteOffset = outerOffset * targetStrideBytes;
1505+
const chunkShape = rec.shape.slice();
1506+
chunkShape[0] = outerCount;
1507+
const chunkShapeTuple = this.makeShapeTuple(chunkShape);
1508+
const [cpuView, gpuView] = this.withNewScope(() => {
1509+
return [
1510+
this.detachFromCurrentScope(
1511+
this.ctx.tensorCreateView(
1512+
cpu_arr,
1513+
chunkShapeTuple,
1514+
rec.dtype,
1515+
new Scalar(targetByteOffset, "int"),
1516+
)
1517+
),
1518+
this.detachFromCurrentScope(
1519+
this.ctx.tensorCreateView(
1520+
gpu_arr,
1521+
chunkShapeTuple,
1522+
rec.dtype,
1523+
new Scalar(targetByteOffset, "int"),
1524+
)
1525+
),
1526+
];
1527+
});
1528+
try {
1529+
gpuView.copyFrom(cpuView);
1530+
} finally {
1531+
cpuView.dispose();
1532+
gpuView.dispose();
1533+
}
1534+
}
1535+
}
14391536
await device.sync();
14401537
this.tensorCacheUpdate(rec.name, gpu_arr, false);
14411538
cpu_arr.dispose();
@@ -2258,6 +2355,25 @@ export class Instance implements Disposable {
22582355
case TypeIndex.kTVMFFIOpaquePtr: {
22592356
return this.memory.loadPointer(valuePtr);
22602357
}
2358+
case TypeIndex.kTVMFFIShape: {
2359+
const shapeObjPtr = this.memory.loadPointer(valuePtr);
2360+
if (callbackArg) {
2361+
const shapeCellPtr = shapeObjPtr + SizeOf.ObjectHeader;
2362+
const shapeDataPtr = this.memory.loadPointer(shapeCellPtr);
2363+
const shapeLen = this.memory.loadUSize(shapeCellPtr + this.memory.sizeofPtr());
2364+
const result = new Array<number>(shapeLen);
2365+
for (let i = 0; i < shapeLen; ++i) {
2366+
result[i] = this.memory.loadI64(shapeDataPtr + i * SizeOf.I64);
2367+
}
2368+
this.lib.checkCall(
2369+
(this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(shapeObjPtr)
2370+
);
2371+
return result;
2372+
}
2373+
return this.ctx.attachToCurrentScope(
2374+
new TVMObject(shapeObjPtr, this.lib, this.ctx)
2375+
);
2376+
}
22612377
case TypeIndex.kTVMFFITensor: {
22622378
return this.ctx.attachToCurrentScope(
22632379
new Tensor(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false)

0 commit comments

Comments
 (0)