Skip to content

Commit 42ed716

Browse files
committed
[Web] Improve large tensor loading in wasm runtime
1 parent 4b7b7de commit 42ed716

2 files changed

Lines changed: 169 additions & 16 deletions

File tree

web/emcc/wasm_runtime.cc

Lines changed: 29 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -130,32 +130,51 @@ void ArrayDecodeStorage(Tensor cpu_arr, TVMFFIByteArray* bytes, const std::strin
130130
const char* byte_data = bytes->data;
131131
const size_t byte_size = bytes->size;
132132
if (format == "f32-to-bf16" && dtype == "float32") {
133-
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
134-
uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
135133
TVM_FFI_ICHECK(cpu_arr.IsContiguous());
136134
size_t size = 1;
137135
for (int i = 0; i < cpu_arr->ndim; ++i) {
138136
size *= cpu_arr->shape[i];
139137
}
140-
TVM_FFI_ICHECK_EQ(size, byte_size / 2);
141-
for (size_t i = 0; i < size; ++i) {
142-
data[i] = static_cast<uint32_t>(bf16[i]) << 16;
138+
// The "f32-to-bf16" format encodes a float32 tensor as packed bf16 (2
139+
// bytes per element). When the byte_size matches that expectation, expand
140+
// back to f32. Any other byte_size -- typically the native float32 width
141+
// (4 bytes per element), i.e. a payload that is already raw float32 -- falls
142+
// through to the generic byte copy. This makes the loader tolerant of weight
143+
// shards produced by older / alternate quantisation pipelines that retain
144+
// the "f32-to-bf16" tag without performing the bf16 truncation.
145+
if (byte_size == size * sizeof(uint16_t)) {
146+
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
147+
uint32_t* data =
148+
reinterpret_cast<uint32_t*>(static_cast<char*>(cpu_arr->data) + cpu_arr->byte_offset);
149+
for (size_t i = 0; i < size; ++i) {
150+
data[i] = static_cast<uint32_t>(bf16[i]) << 16;
151+
}
152+
return;
143153
}
144-
} else {
145-
cpu_arr.CopyFromBytes(byte_data, byte_size);
146154
}
155+
cpu_arr.CopyFromBytes(byte_data, byte_size);
156+
}
157+
158+
int64_t StorageSizeBytes(int64_t num_elements, const std::string& dtype) {
159+
TVM_FFI_ICHECK_GE(num_elements, 0);
160+
TVMFFIByteArray dtype_bytes{dtype.data(), dtype.size()};
161+
DLDataType dl_dtype;
162+
TVM_FFI_ICHECK_EQ(TVMFFIDataTypeFromString(&dtype_bytes, &dl_dtype), 0);
163+
return static_cast<int64_t>(
164+
ffi::GetDataSize(static_cast<size_t>(num_elements), dl_dtype));
147165
}
148166

149167
TVM_FFI_STATIC_INIT_BLOCK() {
150168
namespace refl = tvm::ffi::reflection;
151-
refl::GlobalDef().def_packed(
152-
"tvmjs.array.decode_storage", [](ffi::PackedArgs args, ffi::Any* ret) {
169+
refl::GlobalDef()
170+
.def_packed("tvmjs.array.decode_storage", [](ffi::PackedArgs args, ffi::Any* ret) {
153171
Tensor cpu_arr = args[0].cast<Tensor>();
154172
TVMFFIByteArray* bytes = args[1].cast<TVMFFIByteArray*>();
155173
std::string format = args[2].cast<ffi::String>().operator std::string();
156174
std::string dtype = args[3].cast<ffi::String>().operator std::string();
157175
ArrayDecodeStorage(cpu_arr, bytes, format, dtype);
158-
});
176+
})
177+
.def("tvmjs.runtime.StorageSizeBytes", StorageSizeBytes);
159178
}
160179

161180
// Concatenate n TVMArrays

web/src/runtime.ts

Lines changed: 140 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -172,6 +172,7 @@ class RuntimeContext implements Disposable {
172172
tensorCacheRemove: PackedFunc;
173173
tensorCacheClear: PackedFunc;
174174
arrayDecodeStorage: PackedFunc;
175+
storageSizeBytes: PackedFunc;
175176
paramModuleFromCache: PackedFunc;
176177
paramModuleFromCacheByName: PackedFunc;
177178
makeShapeTuple: PackedFunc;
@@ -207,6 +208,7 @@ class RuntimeContext implements Disposable {
207208
this.tensorCacheUpdate = getGlobalFunc("vm.builtin.tensor_cache.update");
208209
this.tensorCacheClear = getGlobalFunc("vm.builtin.tensor_cache.clear");
209210
this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage");
211+
this.storageSizeBytes = getGlobalFunc("tvmjs.runtime.StorageSizeBytes");
210212
this.paramModuleFromCache = getGlobalFunc("vm.builtin.param_module_from_cache");
211213
this.paramModuleFromCacheByName = getGlobalFunc("vm.builtin.param_module_from_cache_by_name");
212214
this.makeShapeTuple = getGlobalFunc("ffi.Shape");
@@ -230,6 +232,7 @@ class RuntimeContext implements Disposable {
230232
this.tensorCacheRemove.dispose();
231233
this.tensorCacheUpdate.dispose();
232234
this.arrayDecodeStorage.dispose();
235+
this.storageSizeBytes.dispose();
233236
this.paramModuleFromCache.dispose();
234237
this.paramModuleFromCacheByName.dispose();
235238
this.makeShapeTuple.dispose();
@@ -1010,9 +1013,11 @@ export class Instance implements Disposable {
10101013
*/
10111014
withNewScope<T>(action: () => T): T {
10121015
this.beginScope();
1013-
const val = action();
1014-
this.endScope();
1015-
return val;
1016+
try {
1017+
return action();
1018+
} finally {
1019+
this.endScope();
1020+
}
10161021
}
10171022

10181023
/**
@@ -1323,6 +1328,19 @@ export class Instance implements Disposable {
13231328
artifactCache: ArtifactCacheTemplate,
13241329
signal?: AbortSignal,
13251330
) {
1331+
// Avoid a single JS-to-wasm byte-array call for multi-hundred-MiB
1332+
// tensor-cache records. The cap is a conservative per-call staging size,
1333+
// independent of the final tensor allocation size. Smaller records keep
1334+
// the existing full-record path.
1335+
const maxChunkBytes = 128 * 1024 * 1024;
1336+
const storageSizeBytes = (numElements: number, dtype: string): number | undefined => {
1337+
try {
1338+
return this.ctx.storageSizeBytes(new Scalar(numElements, "int"), dtype) as number;
1339+
} catch {
1340+
// Unknown dtypes can still use the original full-record loading path.
1341+
return undefined;
1342+
}
1343+
};
13261344
const perf = compact.getPerformance();
13271345
const tstart = perf.now();
13281346
let totalBytes = 0;
@@ -1421,9 +1439,68 @@ export class Instance implements Disposable {
14211439
this.empty(rec.shape, rec.dtype, this.cpu())
14221440
)
14231441
});
1424-
const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes);
1442+
const shardBytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer);
1443+
const recSource =
1444+
rec.byteOffset === 0 && rec.nbytes === shardBytes.byteLength
1445+
? shardBytes
1446+
: shardBytes.subarray(rec.byteOffset, rec.byteOffset + rec.nbytes);
1447+
let canChunkRecord =
1448+
rec.nbytes > maxChunkBytes &&
1449+
rec.shape.length >= 1 &&
1450+
Number.isInteger(rec.shape[0]) &&
1451+
rec.shape[0] > 0 &&
1452+
rec.nbytes % rec.shape[0] === 0;
1453+
const outerDim = canChunkRecord ? rec.shape[0] : 1;
1454+
const sourceStrideBytes = canChunkRecord ? rec.nbytes / outerDim : rec.nbytes;
1455+
let targetStrideBytes = 0;
1456+
if (canChunkRecord) {
1457+
const numElements = rec.shape.reduce((acc, value) => acc * value, 1);
1458+
const targetBytes = storageSizeBytes(numElements, rec.dtype);
1459+
canChunkRecord =
1460+
sourceStrideBytes <= maxChunkBytes &&
1461+
targetBytes !== undefined &&
1462+
targetBytes % outerDim === 0;
1463+
if (canChunkRecord) {
1464+
targetStrideBytes = targetBytes / outerDim;
1465+
}
1466+
}
1467+
const copyRecordToTensor = (targetTensor: Tensor, sourceBytes: Uint8Array) => {
1468+
if (!canChunkRecord) {
1469+
this.ctx.arrayDecodeStorage(targetTensor, sourceBytes, rec.format, rec.dtype);
1470+
return;
1471+
}
1472+
const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / sourceStrideBytes));
1473+
for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) {
1474+
const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset);
1475+
const sourceByteOffset = outerOffset * sourceStrideBytes;
1476+
const targetByteOffset = outerOffset * targetStrideBytes;
1477+
const chunkBytes = outerCount * sourceStrideBytes;
1478+
const chunkShape = rec.shape.slice();
1479+
chunkShape[0] = outerCount;
1480+
const chunkView = this.withNewScope(() => {
1481+
const chunkShapeTuple = this.makeShapeTuple(chunkShape);
1482+
return this.detachFromCurrentScope(
1483+
this.ctx.tensorCreateView(
1484+
targetTensor,
1485+
chunkShapeTuple,
1486+
rec.dtype,
1487+
new Scalar(targetByteOffset, "int"),
1488+
)
1489+
);
1490+
});
1491+
const chunkSource = sourceBytes.subarray(
1492+
sourceByteOffset,
1493+
sourceByteOffset + chunkBytes,
1494+
);
1495+
try {
1496+
this.ctx.arrayDecodeStorage(chunkView, chunkSource, rec.format, rec.dtype);
1497+
} finally {
1498+
chunkView.dispose();
1499+
}
1500+
}
1501+
};
14251502
// first sync copy to cpu.
1426-
this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype);
1503+
copyRecordToTensor(cpu_arr, recSource);
14271504
// then async stream into GPU if needed
14281505
if (device.deviceType === DeviceStrToEnum.cpu) {
14291506
this.tensorCacheUpdate(rec.name, cpu_arr, false);
@@ -1435,7 +1512,42 @@ export class Instance implements Disposable {
14351512
this.empty(rec.shape, rec.dtype, device)
14361513
)
14371514
});
1438-
gpu_arr.copyFrom(cpu_arr);
1515+
if (!canChunkRecord) {
1516+
gpu_arr.copyFrom(cpu_arr);
1517+
} else {
1518+
const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / sourceStrideBytes));
1519+
for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) {
1520+
const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset);
1521+
const targetByteOffset = outerOffset * targetStrideBytes;
1522+
const chunkShape = rec.shape.slice();
1523+
chunkShape[0] = outerCount;
1524+
const [cpuView, gpuView] = this.withNewScope(() => {
1525+
const chunkShapeTuple = this.makeShapeTuple(chunkShape);
1526+
const cView = this.ctx.tensorCreateView(
1527+
cpu_arr,
1528+
chunkShapeTuple,
1529+
rec.dtype,
1530+
new Scalar(targetByteOffset, "int"),
1531+
);
1532+
const gView = this.ctx.tensorCreateView(
1533+
gpu_arr,
1534+
chunkShapeTuple,
1535+
rec.dtype,
1536+
new Scalar(targetByteOffset, "int"),
1537+
);
1538+
return [
1539+
this.detachFromCurrentScope(cView),
1540+
this.detachFromCurrentScope(gView),
1541+
];
1542+
});
1543+
try {
1544+
gpuView.copyFrom(cpuView);
1545+
} finally {
1546+
cpuView.dispose();
1547+
gpuView.dispose();
1548+
}
1549+
}
1550+
}
14391551
await device.sync();
14401552
this.tensorCacheUpdate(rec.name, gpu_arr, false);
14411553
cpu_arr.dispose();
@@ -2258,6 +2370,28 @@ export class Instance implements Disposable {
22582370
case TypeIndex.kTVMFFIOpaquePtr: {
22592371
return this.memory.loadPointer(valuePtr);
22602372
}
2373+
case TypeIndex.kTVMFFIShape: {
2374+
const shapeObjPtr = this.memory.loadPointer(valuePtr);
2375+
if (shapeObjPtr === 0) {
2376+
return null;
2377+
}
2378+
if (callbackArg) {
2379+
const shapeCellPtr = shapeObjPtr + SizeOf.ObjectHeader;
2380+
const shapeDataPtr = this.memory.loadPointer(shapeCellPtr);
2381+
const shapeLen = this.memory.loadUSize(shapeCellPtr + this.memory.sizeofPtr());
2382+
const result = new Array<number>(shapeLen);
2383+
for (let i = 0; i < shapeLen; ++i) {
2384+
result[i] = this.memory.loadI64(shapeDataPtr + i * SizeOf.I64);
2385+
}
2386+
this.lib.checkCall(
2387+
(this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(shapeObjPtr)
2388+
);
2389+
return result;
2390+
}
2391+
return this.ctx.attachToCurrentScope(
2392+
new TVMObject(shapeObjPtr, this.lib, this.ctx)
2393+
);
2394+
}
22612395
case TypeIndex.kTVMFFITensor: {
22622396
return this.ctx.attachToCurrentScope(
22632397
new Tensor(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false)

0 commit comments

Comments
 (0)