-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathwebgpu_ep_bootstrapper.cc
More file actions
370 lines (305 loc) · 13.1 KB
/
Copy pathwebgpu_ep_bootstrapper.cc
File metadata and controls
370 lines (305 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "ep_detection/webgpu_ep_bootstrapper.h"
#include "ep_detection/ep_utils.h"
#include "http/http_client.h"
#include "http/http_download.h"
#include "logger.h"
#include "util/file_lock.h"
#include "util/sha256.h"
#include "util/zip_extract.h"
#include <fmt/format.h>
#include <nlohmann/json.hpp>
#include <algorithm>
#include <atomic>
#include <cctype>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <iterator>
#include <random>
#include <stdexcept>
#include <string>
#include <system_error>
#include <unordered_map>
#include <utility>
#include <vector>
#ifdef _WIN32
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#endif
namespace {
constexpr const char* kPackageFileName = "webgpu-ep.zip";
constexpr const char* kLockFileName = "webgpu-ep.lock";
constexpr const char* kStagingDirName = "webgpu-ep-staging";
constexpr const char* kUserAgent = "FoundryLocal";
constexpr int kMaxInstallAttempts = 5;
// Manifest zip URL — atomically contains manifest.json and manifest.json.sig.
constexpr const char* kManifestZipUrl =
"https://foundrypackages-ffhrdhbxb7gpdreh.b02.azurefd.net/webgpu_manifest_prod.zip";
// RSA-4096 public key used to verify the manifest signature.
// Corresponds to the private key used by official WebGPU Plugin EP Publishing Pipeline.
constexpr const char* kManifestSigningPublicKey = R"PEM(
-----BEGIN PUBLIC KEY-----
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA1YwPWIQ7UJZ0EOVfRIeU
AiI6G9nwmQ+0RGmBKKNPeuTt8To7EUBfs2yjHs1nS159oEbI9wmN+SRhTx72fyo7
EEbQ2kYB/d+/znqrpTinHiyfrn6dEzqJzj5diTfXkVbm5+uueqxoxN6TAUwZqsdO
wveft1DiSU8G0NRx3QPxBACZx199ObiQgqDQycTbc7qaRUy9rkcrMimvXKIaui3z
fmxQtzF6WkRnN4Xf+jkzxgua0xSHkcdYpDu+M39iynqEkSChzv+h0NIE/B05z9/y
+6/EjFETYB2LuSr7N3EOMj1eTff/oFqwBk1gBuLxNxHjTtH1+DxpygIxz9Dy2OY5
jG46Io9Eg8q7UMW4aSm/YS/Sqt8KzqOG59XvLtADDlaS+8+KDV0K9Jwq1WXBbqXd
gXlUjLdIh+UAgF0zv5N8MGoS9BxvBNr932XkUV5VC26JgU3tPqiiiSXfPParBSJt
wt/PSpQDqkcWE9VsRmCe5pAgmv3AQlv+jSLlB8aDdCP8/+/AoI7St4n7STl8QtPl
XXWmO8EJwqEXFpaitcpNyzuol6/7H4mQV6XeNjezjmTWeedvxWcZXi1Pxp/FfOEK
iJxrPNMxlZZA26WvTEhc0vi9hxYxTsZKWuenZoGvgR2/sy2tqbEV3/4JhowQ6K56
MvdOj/vvArK/BIwPJnCYv4kCAwEAAQ==
-----END PUBLIC KEY-----
)PEM";
// Platform key used to look up this platform's package in the manifest.
#if defined(_WIN32) && defined(_M_ARM64)
constexpr const char* kPlatformKey = "win-arm64";
#elif defined(_WIN32)
constexpr const char* kPlatformKey = "win-x64";
#elif defined(__APPLE__)
constexpr const char* kPlatformKey = "macos-arm64";
#else
constexpr const char* kPlatformKey = "linux-x64";
#endif
// Platform-specific EP library filename.
#if defined(_WIN32)
constexpr const char* kWebGpuProviderLib = "onnxruntime_providers_webgpu.dll";
#elif defined(__APPLE__)
constexpr const char* kWebGpuProviderLib = "libonnxruntime_providers_webgpu.dylib";
#else
constexpr const char* kWebGpuProviderLib = "libonnxruntime_providers_webgpu.so";
#endif
constexpr const char* kRegistrationName = "Foundry.WebGPU";
/// Parsed manifest entry for a single platform.
struct ManifestPackageInfo {
std::string url;
std::vector<std::pair<std::string, std::string>> sha256; // filename -> expected SHA256 hash
};
/// Fetch the manifest zip (atomically containing manifest.json and manifest.json.sig) from CDN,
/// extract it, verify the signature, and extract the package info for this platform.
ManifestPackageInfo FetchManifest(fl::ILogger& logger) {
logger.Log(fl::LogLevel::Debug, fmt::format("WebGPU EP: fetching manifest zip from {}", kManifestZipUrl));
// Download manifest zip atomically (contains both manifest.json and manifest.json.sig)
auto zip_path = std::filesystem::temp_directory_path() / "webgpu_manifest_temp.zip";
if (!HttpDownloadFile(kManifestZipUrl, zip_path, kUserAgent, nullptr, nullptr, logger)) {
throw std::runtime_error("WebGPU EP: failed to download manifest zip");
}
// Extract to temporary directory
auto extract_dir = std::filesystem::temp_directory_path() / kStagingDirName;
if (std::filesystem::exists(extract_dir)) {
std::filesystem::remove_all(extract_dir);
}
std::filesystem::create_directories(extract_dir);
if (!ExtractZip(zip_path, extract_dir, logger)) {
std::filesystem::remove_all(extract_dir);
std::filesystem::remove(zip_path);
throw std::runtime_error("WebGPU EP: failed to extract manifest zip");
}
// Read manifest and signature from extracted files
auto manifest_file = extract_dir / "manifest.json";
auto sig_file = extract_dir / "manifest.json.sig";
if (!std::filesystem::exists(manifest_file)) {
std::filesystem::remove_all(extract_dir);
std::filesystem::remove(zip_path);
throw std::runtime_error("WebGPU EP: manifest.json not found in manifest zip");
}
if (!std::filesystem::exists(sig_file)) {
std::filesystem::remove_all(extract_dir);
std::filesystem::remove(zip_path);
throw std::runtime_error("WebGPU EP: manifest.json.sig not found in manifest zip");
}
// Read both files as strings
std::ifstream manifest_stream(manifest_file, std::ios::binary);
std::string body((std::istreambuf_iterator<char>(manifest_stream)), std::istreambuf_iterator<char>());
manifest_stream.close();
std::ifstream sig_stream(sig_file, std::ios::binary);
std::string sig((std::istreambuf_iterator<char>(sig_stream)), std::istreambuf_iterator<char>());
sig_stream.close();
// Trim any trailing whitespace (CDN may append \r\n).
while (!sig.empty() && (sig.back() == '\n' || sig.back() == '\r' || sig.back() == ' ')) {
sig.pop_back();
}
// Verify signature
if (!fl::VerifyRsaSha256Signature(body, sig, kManifestSigningPublicKey, logger)) {
std::filesystem::remove_all(extract_dir);
std::filesystem::remove(zip_path);
throw std::runtime_error(
"WebGPU EP: manifest signature verification failed — refusing to use manifest");
}
logger.Log(fl::LogLevel::Debug, "WebGPU EP: manifest signature verified");
// Clean up temporary files before parsing
std::filesystem::remove_all(extract_dir);
std::filesystem::remove(zip_path);
auto manifest = nlohmann::json::parse(body);
if (!manifest.contains("packages") || !manifest["packages"].is_object()) {
throw std::runtime_error(
fmt::format("WebGPU EP: manifest is invalid — missing 'packages' field. "
"Raw content (first 200 chars): {}",
body.substr(0, 200)));
}
const auto& packages = manifest["packages"];
if (!packages.contains(kPlatformKey)) {
std::string available;
for (auto it = packages.begin(); it != packages.end(); ++it) {
if (!available.empty()) available += ", ";
available += it.key();
}
throw std::runtime_error(
fmt::format("WebGPU EP: manifest does not contain a package for platform '{}'. "
"Available platforms: {}",
kPlatformKey, available));
}
const auto& pkg = packages[kPlatformKey];
ManifestPackageInfo info;
info.url = pkg.at("url").get<std::string>();
for (const auto& [filename, hash] : pkg.at("sha256").items()) {
info.sha256.push_back({filename, hash.get<std::string>()});
}
logger.Log(fl::LogLevel::Information,
fmt::format("WebGPU EP: manifest fetched for platform '{}'", kPlatformKey));
return info;
}
} // anonymous namespace
namespace fl {
namespace {

/// Recursively removes a directory when the guard is destroyed. Errors are ignored on purpose:
/// cleanup must never replace the error (or result) that caused the scope to exit.
class ScopedRemoveAll {
 public:
  explicit ScopedRemoveAll(std::filesystem::path path) : path_(std::move(path)) {}
  ~ScopedRemoveAll() {
    std::error_code ec;
    std::filesystem::remove_all(path_, ec);
  }
  ScopedRemoveAll(const ScopedRemoveAll&) = delete;
  ScopedRemoveAll& operator=(const ScopedRemoveAll&) = delete;

 private:
  std::filesystem::path path_;
};

/// Replaces `target` with `staging` while keeping the previous install recoverable.
///
/// The previous install is renamed aside to "<target>.bak", then `staging` is renamed to
/// `target`; if that second rename fails, the previous install is moved back. Unlike deleting
/// `target` up front, a failure (e.g. a file inside `target` still in use) can no longer leave a
/// partially deleted install behind.
///
/// @throws std::filesystem::filesystem_error if the old install cannot be moved aside or the new
///         one cannot be moved into place (the old install is restored on a best-effort basis).
void PromoteStagingDir(const std::filesystem::path& staging, const std::filesystem::path& target) {
  auto backup = target;
  backup += ".bak";
  std::error_code ignored;
  // Leftover from an earlier interrupted promotion; nothing depends on it.
  std::filesystem::remove_all(backup, ignored);

  const bool had_previous = std::filesystem::exists(target);
  if (had_previous) {
    std::filesystem::rename(target, backup);
  }
  try {
    std::filesystem::rename(staging, target);
  } catch (...) {
    if (had_previous) {
      std::filesystem::rename(backup, target, ignored);  // best-effort rollback
    }
    throw;
  }
  std::filesystem::remove_all(backup, ignored);
}

#ifdef _WIN32
/// Makes `dir` the first entry of this process's PATH.
/// Does nothing when `dir` is already the first entry, so repeated calls do not keep growing PATH.
/// @return false if SetEnvironmentVariableW fails.
bool PrependToProcessPath(const std::filesystem::path& dir) {
  const std::wstring entry = dir.wstring();
  std::wstring current;
  DWORD size = GetEnvironmentVariableW(L"PATH", nullptr, 0);  // includes the terminating null
  while (size > 0) {
    current.resize(size);
    const DWORD written = GetEnvironmentVariableW(L"PATH", current.data(), size);
    if (written < size) {  // success (count excludes the null), or 0 if PATH disappeared
      current.resize(written);
      break;
    }
    size = written;  // PATH grew between the two calls; retry with the new required size
  }

  const std::wstring prefix = entry + L";";
  if (current == entry || current.compare(0, prefix.size(), prefix) == 0) {
    return true;
  }
  const std::wstring updated = current.empty() ? entry : prefix + current;
  return SetEnvironmentVariableW(L"PATH", updated.c_str()) != 0;
}
#endif

}  // namespace

WebGpuEpBootstrapper::WebGpuEpBootstrapper(std::string ep_dir, EpRegistrationCallback register_ep)
    : ep_dir_(std::move(ep_dir)), register_ep_(std::move(register_ep)) {}

const std::string& WebGpuEpBootstrapper::Name() const {
  return name_;
}

bool WebGpuEpBootstrapper::IsRegistered() const {
  return registered_;
}

/// Ensures the WebGPU plugin EP is installed under `ep_dir_` and registered through
/// `register_ep_`.
///
/// Flow:
/// 1. Fetch and verify the signed manifest.
/// 2. Reuse the local install if every listed file matches its SHA256 hash.
/// 3. Otherwise, under a cross-process file lock, download into a staging directory, verify it,
///    and promote it into place.
/// 4. Register the provider library.
///
/// @param force        Re-download and re-register even if already registered or up to date;
///                     also bypasses the kMaxInstallAttempts limit.
/// @param progress_cb  Optional. Receives (name, percent): the download maps to 0–80%, then 90%
///                     after install and 100% once registered. Returning false during the download
///                     sets the cancel flag passed to HttpDownloadFile.
/// @param logger       Diagnostics sink.
/// @return true if the EP is registered. Failures and std::exception errors are logged and
///         reported as false.
bool WebGpuEpBootstrapper::DownloadAndRegister(bool force,
                                               const ProgressCallback& progress_cb,
                                               ILogger& logger) {
  if (registered_ && !force) {
    if (progress_cb) {
      progress_cb(name_, 100.0f);
    }
    return true;
  }
  if (!force && attempts_ >= kMaxInstallAttempts) {
    logger.Log(LogLevel::Warning, "WebGPU EP: max install attempts reached");
    return false;
  }
  attempts_++;

  try {
    // Strip trailing separators: for "…/webgpu/" parent_path() is the EP directory itself, which
    // would put the lock file and staging directory inside the directory being replaced.
    auto ep_dir = std::filesystem::path(ep_dir_);
    while (!ep_dir.has_filename() && ep_dir.has_relative_path()) {
      ep_dir = ep_dir.parent_path();
    }
    // A bare relative name ("webgpu") has an empty parent; use the current directory rather than
    // handing an empty path to create_directories().
    auto parent_dir = ep_dir.parent_path();
    if (parent_dir.empty()) {
      parent_dir = ".";
    }

    // Fetch manifest before acquiring lock (avoid holding lock during network I/O)
    auto manifest = FetchManifest(logger);

    // True when `dir` holds every file listed in the manifest with the expected SHA256 hash.
    auto verify_package = [&](const std::filesystem::path& dir) -> bool {
      // An empty list would vacuously "verify" any directory, even a missing one.
      if (manifest.sha256.empty()) {
        logger.Log(LogLevel::Warning, "WebGPU EP: manifest lists no files to verify");
        return false;
      }
      // VerifyEpPackage takes an initializer_list, so files are checked one at a time.
      for (const auto& [filename, expected_hash] : manifest.sha256) {
        const bool verified =
            VerifyEpPackage(dir, {{filename, expected_hash}}, "WebGPU EP", logger);
        logger.Log(LogLevel::Debug,
                   fmt::format("WebGPU EP: verifying SHA256 of '{}': {}", filename, verified));
        if (!verified) {
          return false;
        }
      }
      return true;
    };

    if (!force && verify_package(ep_dir)) {
      logger.Log(LogLevel::Debug, "WebGPU EP: local binaries match manifest, skipping download");
    } else {
      // Ensure parent directory exists for the lock file
      std::filesystem::create_directories(parent_dir);
      auto lock_path = parent_dir / kLockFileName;
      // Cross-process lock to prevent concurrent installs
      FileLock lock(lock_path);
      // Re-check after acquiring lock (another process may have completed the update)
      if (!force && verify_package(ep_dir)) {
        logger.Log(LogLevel::Debug, "WebGPU EP: another process already completed the update");
      } else {
        auto staging_dir = parent_dir / kStagingDirName;
        std::filesystem::remove_all(staging_dir);  // leftovers from an interrupted install
        std::filesystem::create_directories(staging_dir);
        // Removes the staging directory on every exit from this scope: failure, exception, or
        // success (after promotion it no longer exists, so this is a no-op). Declared after
        // `lock`, so it runs while the lock is still held.
        const ScopedRemoveAll staging_cleanup(staging_dir);
        auto zip_path = staging_dir / kPackageFileName;

        logger.Log(LogLevel::Information,
                   fmt::format("WebGPU EP: downloading for {} from CDN", kPlatformKey));
        logger.Log(LogLevel::Debug, fmt::format("WebGPU EP: download URL is {}", manifest.url));
        std::atomic<bool> cancel_flag{false};
        auto download_progress = [&](float pct) {
          // 0–80% for download phase
          if (progress_cb && !progress_cb(name_, pct * 0.8f)) {
            cancel_flag.store(true);
          }
        };
        if (!HttpDownloadFile(manifest.url, zip_path, kUserAgent,
                              &cancel_flag, download_progress, logger)) {
          if (cancel_flag.load()) {
            logger.Log(LogLevel::Information, "WebGPU EP: download cancelled");
          } else {
            logger.Log(LogLevel::Warning,
                       "WebGPU EP: download failed (see prior log for details)");
          }
          return false;
        }

        logger.Log(LogLevel::Information,
                   fmt::format("WebGPU EP: extracting package to {}", staging_dir.string()));
        if (!ExtractZip(zip_path, staging_dir, logger)) {
          logger.Log(LogLevel::Warning, "WebGPU EP: extraction failed");
          return false;
        }
        // The archive itself is not part of the install.
        std::filesystem::remove(zip_path);

        if (!verify_package(staging_dir)) {
          logger.Log(LogLevel::Warning,
                     fmt::format("WebGPU EP: verification failed after extraction (attempt {})",
                                 attempts_));
          return false;
        }
        logger.Log(LogLevel::Debug,
                   fmt::format("WebGPU EP: staging verification succeeded, promoting to {}",
                               ep_dir.string()));
        PromoteStagingDir(staging_dir, ep_dir);
        logger.Log(LogLevel::Information, "WebGPU EP: successfully installed");
      }
    }

    if (progress_cb) {
      progress_cb(name_, 90.0f);
    }

#ifdef _WIN32
    // WebGPU EP may delay-load additional dependencies from the same directory, so make that
    // directory the first PATH entry for the rest of the process lifetime.
    if (!PrependToProcessPath(ep_dir)) {
      logger.Log(LogLevel::Warning, "WebGPU EP: failed to prepend install directory to PATH");
    }
#endif

    // Register with ORT
    auto provider_path = ep_dir / kWebGpuProviderLib;
    if (!register_ep_(kRegistrationName, provider_path)) {
      logger.Log(LogLevel::Warning, "WebGPU EP: ORT registration failed");
      return false;
    }
    registered_ = true;
    if (progress_cb) {
      progress_cb(name_, 100.0f);
    }
    logger.Log(LogLevel::Information,
               fmt::format("WebGPU EP: ready (install_path={})", ep_dir.string()));
    return true;
  } catch (const std::exception& e) {
    logger.Log(LogLevel::Warning, fmt::format("WebGPU EP: error: {}", e.what()));
    return false;
  }
}
}  // namespace fl