Skip to content

Commit 2673a80

Browse files
committed
[Web] Improve large tensor loading in wasm runtime
1 parent 4b7b7de commit 2673a80

2 files changed

Lines changed: 170 additions & 13 deletions

File tree

web/emcc/wasm_runtime.cc

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -130,20 +130,29 @@ void ArrayDecodeStorage(Tensor cpu_arr, TVMFFIByteArray* bytes, const std::strin
130130
const char* byte_data = bytes->data;
131131
const size_t byte_size = bytes->size;
132132
if (format == "f32-to-bf16" && dtype == "float32") {
133-
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
134-
uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
135133
TVM_FFI_ICHECK(cpu_arr.IsContiguous());
136134
size_t size = 1;
137135
for (int i = 0; i < cpu_arr->ndim; ++i) {
138136
size *= cpu_arr->shape[i];
139137
}
140-
TVM_FFI_ICHECK_EQ(size, byte_size / 2);
141-
for (size_t i = 0; i < size; ++i) {
142-
data[i] = static_cast<uint32_t>(bf16[i]) << 16;
138+
// The "f32-to-bf16" format encodes a float32 tensor as packed bf16 (2
139+
// bytes per element). When the byte_size matches that expectation, expand
140+
// back to f32. If the byte_size matches the native float32 width
141+
// (4 bytes per element), the payload is already raw float32; fall through
142+
// to the generic byte copy. This makes the loader tolerant of weight
143+
// shards produced by older / alternate quantisation pipelines that retain
144+
// the "f32-to-bf16" tag without performing the bf16 truncation.
145+
if (byte_size == size * sizeof(uint16_t)) {
146+
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
147+
uint32_t* data =
148+
reinterpret_cast<uint32_t*>(static_cast<char*>(cpu_arr->data) + cpu_arr->byte_offset);
149+
for (size_t i = 0; i < size; ++i) {
150+
data[i] = static_cast<uint32_t>(bf16[i]) << 16;
151+
}
152+
return;
143153
}
144-
} else {
145-
cpu_arr.CopyFromBytes(byte_data, byte_size);
146154
}
155+
cpu_arr.CopyFromBytes(byte_data, byte_size);
147156
}
148157

149158
TVM_FFI_STATIC_INIT_BLOCK() {

web/src/runtime.ts

Lines changed: 154 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1010,9 +1010,11 @@ export class Instance implements Disposable {
10101010
*/
10111011
withNewScope<T>(action: () => T): T {
10121012
this.beginScope();
1013-
const val = action();
1014-
this.endScope();
1015-
return val;
1013+
try {
1014+
return action();
1015+
} finally {
1016+
this.endScope();
1017+
}
10161018
}
10171019

10181020
/**
@@ -1323,6 +1325,45 @@ export class Instance implements Disposable {
13231325
artifactCache: ArtifactCacheTemplate,
13241326
signal?: AbortSignal,
13251327
) {
1328+
// Avoid a single JS-to-wasm byte-array call for multi-hundred-MiB
1329+
// tensor-cache records. The cap is a conservative per-call staging size,
1330+
// independent of the final tensor allocation size. Smaller records keep
1331+
// the existing full-record path.
1332+
const maxChunkBytes = 128 * 1024 * 1024;
1333+
const storageBytes = (dtype: string) => {
1334+
if (dtype === "bool") {
1335+
return 1;
1336+
}
1337+
1338+
if (dtype.startsWith("boolx")) {
1339+
const lanes = Number(dtype.slice("boolx".length));
1340+
if (Number.isInteger(lanes) && lanes > 0) {
1341+
return lanes;
1342+
}
1343+
}
1344+
1345+
for (const prefix of ["bfloat", "complex", "float", "uint", "int"]) {
1346+
if (!dtype.startsWith(prefix)) {
1347+
continue;
1348+
}
1349+
1350+
const widthAndLanes = dtype.slice(prefix.length);
1351+
const vectorLaneSeparator = widthAndLanes.indexOf("x");
1352+
let bitsText = widthAndLanes;
1353+
let lanes = 1;
1354+
if (vectorLaneSeparator !== -1) {
1355+
bitsText = widthAndLanes.slice(0, vectorLaneSeparator);
1356+
lanes = Number(widthAndLanes.slice(vectorLaneSeparator + 1));
1357+
}
1358+
const bits = Number(bitsText);
1359+
if (Number.isInteger(bits) && bits > 0 &&
1360+
Number.isInteger(lanes) && lanes > 0) {
1361+
return (bits * lanes + 7) >> 3;
1362+
}
1363+
}
1364+
1365+
throw new Error("Cannot determine storage width of dtype " + dtype);
1366+
};
13261367
const perf = compact.getPerformance();
13271368
const tstart = perf.now();
13281369
let totalBytes = 0;
@@ -1421,9 +1462,59 @@ export class Instance implements Disposable {
14211462
this.empty(rec.shape, rec.dtype, this.cpu())
14221463
)
14231464
});
1424-
const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes);
1465+
const shardBytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer);
1466+
const recSource =
1467+
rec.byteOffset === 0 && rec.nbytes === shardBytes.byteLength
1468+
? shardBytes
1469+
: shardBytes.subarray(rec.byteOffset, rec.byteOffset + rec.nbytes);
1470+
const canChunkRecord =
1471+
rec.nbytes > maxChunkBytes &&
1472+
rec.shape.length >= 1 &&
1473+
Number.isInteger(rec.shape[0]) &&
1474+
rec.shape[0] > 0 &&
1475+
rec.nbytes % rec.shape[0] === 0;
1476+
const outerDim = canChunkRecord ? rec.shape[0] : 1;
1477+
const sourceStrideBytes = canChunkRecord ? rec.nbytes / outerDim : rec.nbytes;
1478+
const targetBytes = rec.shape.reduce((acc, value) => acc * value, 1) *
1479+
storageBytes(rec.dtype);
1480+
const targetStrideBytes = canChunkRecord ? targetBytes / outerDim : targetBytes;
1481+
const copyRecordToTensor = (targetTensor: Tensor, sourceBytes: Uint8Array) => {
1482+
if (!canChunkRecord) {
1483+
this.ctx.arrayDecodeStorage(targetTensor, sourceBytes, rec.format, rec.dtype);
1484+
return;
1485+
}
1486+
const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / sourceStrideBytes));
1487+
for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) {
1488+
const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset);
1489+
const sourceByteOffset = outerOffset * sourceStrideBytes;
1490+
const targetByteOffset = outerOffset * targetStrideBytes;
1491+
const chunkBytes = outerCount * sourceStrideBytes;
1492+
const chunkShape = rec.shape.slice();
1493+
chunkShape[0] = outerCount;
1494+
const chunkView = this.withNewScope(() => {
1495+
const chunkShapeTuple = this.makeShapeTuple(chunkShape);
1496+
return this.detachFromCurrentScope(
1497+
this.ctx.tensorCreateView(
1498+
targetTensor,
1499+
chunkShapeTuple,
1500+
rec.dtype,
1501+
new Scalar(targetByteOffset, "int"),
1502+
)
1503+
);
1504+
});
1505+
const chunkSource = sourceBytes.subarray(
1506+
sourceByteOffset,
1507+
sourceByteOffset + chunkBytes,
1508+
);
1509+
try {
1510+
this.ctx.arrayDecodeStorage(chunkView, chunkSource, rec.format, rec.dtype);
1511+
} finally {
1512+
chunkView.dispose();
1513+
}
1514+
}
1515+
};
14251516
// first sync copy to cpu.
1426-
this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype);
1517+
copyRecordToTensor(cpu_arr, recSource);
14271518
// then async stream into GPU if needed
14281519
if (device.deviceType === DeviceStrToEnum.cpu) {
14291520
this.tensorCacheUpdate(rec.name, cpu_arr, false);
@@ -1435,7 +1526,42 @@ export class Instance implements Disposable {
14351526
this.empty(rec.shape, rec.dtype, device)
14361527
)
14371528
});
1438-
gpu_arr.copyFrom(cpu_arr);
1529+
if (!canChunkRecord) {
1530+
gpu_arr.copyFrom(cpu_arr);
1531+
} else {
1532+
const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / sourceStrideBytes));
1533+
for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) {
1534+
const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset);
1535+
const targetByteOffset = outerOffset * targetStrideBytes;
1536+
const chunkShape = rec.shape.slice();
1537+
chunkShape[0] = outerCount;
1538+
const [cpuView, gpuView] = this.withNewScope(() => {
1539+
const chunkShapeTuple = this.makeShapeTuple(chunkShape);
1540+
const cView = this.ctx.tensorCreateView(
1541+
cpu_arr,
1542+
chunkShapeTuple,
1543+
rec.dtype,
1544+
new Scalar(targetByteOffset, "int"),
1545+
);
1546+
const gView = this.ctx.tensorCreateView(
1547+
gpu_arr,
1548+
chunkShapeTuple,
1549+
rec.dtype,
1550+
new Scalar(targetByteOffset, "int"),
1551+
);
1552+
return [
1553+
this.detachFromCurrentScope(cView),
1554+
this.detachFromCurrentScope(gView),
1555+
];
1556+
});
1557+
try {
1558+
gpuView.copyFrom(cpuView);
1559+
} finally {
1560+
cpuView.dispose();
1561+
gpuView.dispose();
1562+
}
1563+
}
1564+
}
14391565
await device.sync();
14401566
this.tensorCacheUpdate(rec.name, gpu_arr, false);
14411567
cpu_arr.dispose();
@@ -2258,6 +2384,28 @@ export class Instance implements Disposable {
22582384
case TypeIndex.kTVMFFIOpaquePtr: {
22592385
return this.memory.loadPointer(valuePtr);
22602386
}
2387+
case TypeIndex.kTVMFFIShape: {
2388+
const shapeObjPtr = this.memory.loadPointer(valuePtr);
2389+
if (shapeObjPtr === 0) {
2390+
return null;
2391+
}
2392+
if (callbackArg) {
2393+
const shapeCellPtr = shapeObjPtr + SizeOf.ObjectHeader;
2394+
const shapeDataPtr = this.memory.loadPointer(shapeCellPtr);
2395+
const shapeLen = this.memory.loadUSize(shapeCellPtr + this.memory.sizeofPtr());
2396+
const result = new Array<number>(shapeLen);
2397+
for (let i = 0; i < shapeLen; ++i) {
2398+
result[i] = this.memory.loadI64(shapeDataPtr + i * SizeOf.I64);
2399+
}
2400+
this.lib.checkCall(
2401+
(this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(shapeObjPtr)
2402+
);
2403+
return result;
2404+
}
2405+
return this.ctx.attachToCurrentScope(
2406+
new TVMObject(shapeObjPtr, this.lib, this.ctx)
2407+
);
2408+
}
22612409
case TypeIndex.kTVMFFITensor: {
22622410
return this.ctx.attachToCurrentScope(
22632411
new Tensor(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false)

0 commit comments

Comments
 (0)