Skip to content

Commit bfec0b1

Browse files
fs-eire and Copilot authored
Add webgpu plugin EP pipeline (#27841)
### Description This PR introduces a packaging pipeline for ONNX Runtime WebGPU EP plugin for the following platforms: - win/x64 - win/arm64 - linux/x64 - mac/arm64 Key changes: **CI/CD Pipeline Additions and Improvements:** * Added a new Azure Pipelines YAML pipeline (`plugin-webgpu-pipeline.yml`) to automate building and packaging the WebGPU plugin for Windows, Linux, and macOS, with parameterized builds for architecture, API version, package version, and build type. The pipeline validates parameter combinations and orchestrates platform-specific packaging stages. * Introduced modular pipeline stage templates for Linux (`plugin-linux-webgpu-stage.yml`), macOS (`plugin-mac-webgpu-stage.yml`), and a top-level packaging stage (`plugin-webgpu-packaging-stage.yml`) to manage platform-specific build, artifact staging, and publishing processes. [[1]](diffhunk://#diff-8d9766b9dfb672636229c848b58bd4beb8469d8a2bc0aab7adfa332a04b49c25R1-R96) [[2]](diffhunk://#diff-c97395e205146bf044dce86c089595772fd09e1f36e898fc20fb831583568a39R1-R106) [[3]](diffhunk://#diff-4c2ad2fa235a30f8f589fa28c573e8e0969e997f1eafbdb3305b04743960b538R1-R75) **Plugin Versioning and Build Configuration:** * Updated `onnxruntime_providers_webgpu.cmake` to default the plugin EP version to `ORT_VERSION` with a `-dev` suffix, unless explicitly specified, and to pass this version via a preprocessor definition for consistent version reporting. * Changed the plugin's reported version in `Factory::GetVersionImpl` to use the new `ORT_PLUGIN_EP_VERSION` macro, ensuring the runtime-reported version matches the build configuration. **Codebase Maintenance:** * Added a missing `<mutex>` include to `allocator.h` to ensure thread safety and proper compilation. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent b73b3dd commit bfec0b1

14 files changed

Lines changed: 995 additions & 26 deletions

File tree

cmake/onnxruntime_providers_webgpu.cmake

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@
8484
add_definitions("-DONNX_NAMESPACE=onnx")
8585
add_definitions("-DONNX_USE_LITE_PROTO=1")
8686

87+
# Default plugin EP version to ORT_VERSION with "-dev" suffix if not explicitly provided.
88+
if(NOT DEFINED onnxruntime_PLUGIN_EP_VERSION)
89+
set(onnxruntime_PLUGIN_EP_VERSION "${ORT_VERSION}-dev")
90+
endif()
91+
92+
# Set preprocessor definition for plugin EP version
93+
target_compile_definitions(onnxruntime_providers_webgpu PRIVATE ORT_PLUGIN_EP_VERSION="${onnxruntime_PLUGIN_EP_VERSION}")
94+
8795
# Set preprocessor definitions used in onnxruntime_providers_webgpu.rc
8896
if(WIN32)
8997
set(WEBGPU_DLL_FILE_DESCRIPTION "ONNX Runtime WebGPU Provider")

include/onnxruntime/ep/adapter/op_kernel.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,11 @@ struct OpKernelContext {
155155
}
156156
bool GetUseDeterministicCompute() const {
157157
// TODO(fs-eire): Implement GetUseDeterministicCompute().
// Until that is done this accessor unconditionally reports false. The commented-out
// branch below sketches the intended gate on CurrentOrtApiVersion().
158+
// if (CurrentOrtApiVersion() >= 25) {
159+
// return /* TBD: wait for GetUseDeterministicCompute to be added in ORT API v25 */;
160+
// } else {
158161
return false;
162+
// }
159163
}
160164
void* GetGPUComputeStream() const {
161165
return context_.GetGPUComputeStream();

include/onnxruntime/ep/api.h

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
#pragma once
55

6+
#include <charconv>
7+
#include <cstring>
68
#include <mutex>
79
#include <optional>
810
#include <stdexcept>
@@ -26,8 +28,30 @@ struct ApiPtrs {
2628

2729
namespace detail {
2830
inline std::optional<ApiPtrs> g_api_ptrs;
31+
32+
// Extracts the API version from an ORT version string.
//
// The string is expected to look like "1.{API_VERSION}.*": a literal major version of 1,
// a dot, the decimal API version, another dot, and then anything.
//
// version_str: NUL-terminated version string; may be null.
// api_version: receives the parsed API version on success; left untouched on failure.
// Returns true when the API version was parsed, false otherwise.
inline bool TryGetAPIVersionFromVersionString(const char* version_str, uint32_t& api_version) {
  if (version_str == nullptr) {
    return false;
  }

  // The second character is only inspected when the first one matched, so an empty
  // string is never read past its terminator.
  const bool has_expected_major = version_str[0] == '1' && version_str[1] == '.';
  if (!has_expected_major) {
    return false;
  }

  // The API version is the text between the first and the second dot.
  const char* const minor_first = version_str + 2;
  const char* const minor_last = std::strchr(minor_first, '.');
  if (minor_last == nullptr) {
    return false;
  }

  uint32_t parsed = 0;
  const std::from_chars_result outcome = std::from_chars(minor_first, minor_last, parsed);

  // Succeed only when the whole component was consumed as a number.
  const bool fully_parsed = outcome.ec == std::errc{} && outcome.ptr == minor_last;
  if (fully_parsed) {
    api_version = parsed;
  }
  return fully_parsed;
}
3050

51+
inline uint32_t g_current_ort_api_version{};
52+
53+
} // namespace detail
54+
3155
/// <summary>
3256
/// Get the global instance of ApiPtrs.
3357
/// </summary>
@@ -45,16 +69,59 @@ inline const ApiPtrs& Api() {
4569
inline void ApiInit(const OrtApiBase* ort_api_base) {
4670
static std::once_flag init_flag;
4771
std::call_once(init_flag, [&]() {
48-
// Manual init for the C++ API
49-
const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);
72+
// The following initialization process is composed of 3 steps:
73+
// 1) Get the ORT API version string
74+
// 2) Try to parse the ORT API version from the version string. If parsing fails, we assume the version is 24.
75+
// 3) Get the ORT API for the parsed version and initialize the global API instance with it.
76+
constexpr uint32_t ORT_BASE_API_VERSION = 24;
77+
const char* version_str = ort_api_base->GetVersionString();
78+
if (!version_str) {
79+
version_str = "unknown";
80+
}
81+
uint32_t current_ort_version = 0;
82+
if (!detail::TryGetAPIVersionFromVersionString(version_str, current_ort_version)) {
83+
// If we fail to parse the version string, we can still try to get the API for the base version and hope it works.
84+
current_ort_version = ORT_BASE_API_VERSION;
85+
}
86+
if (current_ort_version < ORT_BASE_API_VERSION) {
87+
throw std::runtime_error("Failed to initialize EP API: the minimum required ORT API version is " + std::to_string(ORT_BASE_API_VERSION) +
88+
", but the current version is \"" + version_str +
89+
"\" (parsed API version: " + std::to_string(current_ort_version) + ").");
90+
}
91+
92+
const OrtApi* ort_api = ort_api_base->GetApi(current_ort_version);
93+
if (!ort_api) {
94+
throw std::runtime_error("Failed to initialize EP API: the current ORT version is \"" + std::string(version_str) +
95+
"\" but it does not support the parsed API version " + std::to_string(current_ort_version) + ".");
96+
}
97+
98+
detail::g_current_ort_api_version = current_ort_version;
99+
50100
const OrtEpApi* ep_api = ort_api->GetEpApi();
51101
const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi();
102+
if (!ep_api || !model_editor_api) {
103+
throw std::runtime_error("Failed to initialize EP API: GetEpApi or GetModelEditorApi returned null.");
104+
}
105+
106+
// Manual init for the C++ API
52107
Ort::InitApi(ort_api);
53108

54109
// Initialize the global API instance
55110
detail::g_api_ptrs.emplace(*ort_api, *ep_api, *model_editor_api);
56111
});
57112
}
58113

114+
/// <summary>
115+
/// Get the current ORT API version that the EP API has been initialized with.
116+
///
117+
/// This function should be called after ApiInit() to get the actual API version.
118+
/// </summary>
119+
inline uint32_t CurrentOrtApiVersion() {
120+
if (!detail::g_api_ptrs.has_value()) {
121+
throw std::logic_error("onnxruntime::ep::CurrentOrtApiVersion() called before ApiInit().");
122+
}
123+
return detail::g_current_ort_api_version;
124+
}
125+
59126
} // namespace ep
60127
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/ep/factory.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ uint32_t ORT_API_CALL Factory::GetVendorIdImpl(const OrtEpFactory* /*this_ptr*/)
6666
}
6767

6868
const char* ORT_API_CALL Factory::GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
69-
return "0.1.0";
69+
// Reports the plugin EP version supplied by the build through the ORT_PLUGIN_EP_VERSION macro.
return ORT_PLUGIN_EP_VERSION;
7070
}
7171

7272
OrtStatus* ORT_API_CALL Factory::GetSupportedDevicesImpl(

onnxruntime/core/session/onnxruntime_c_api.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
#include "core/session/onnxruntime_session_options_config_keys.h"
5656
#include "core/session/ort_apis.h"
5757
#include "core/session/ort_env.h"
58+
#include "core/session/ort_version_check.h"
5859
#include "core/session/utils.h"
5960

6061
#if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)
@@ -4830,6 +4831,9 @@ ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) {
48304831
}
48314832

48324833
ORT_API(const char*, OrtApis::GetVersionString) {
4834+
// Compile-time guard: the build fails here if ORT_VERSION does not pass IsOrtVersionValid().
static_assert(onnxruntime::version_check::IsOrtVersionValid(ORT_VERSION),
4835+
"ORT_VERSION must be in the format '1.Y.Z' where Y and Z are non-negative integers without leading "
4836+
"zeros, and Y must equal ORT_API_VERSION");
48334837
return ORT_VERSION;
48344838
}
48354839

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <cstdint>
7+
#include <string_view>
8+
9+
#include "core/session/onnxruntime_c_api.h"
10+
11+
namespace onnxruntime::version_check {
12+
13+
// A simple consteval-friendly result type for ParseUint.
14+
// std::optional triggers an internal compiler error in MSVC 14.44 when used with consteval.
15+
struct ParseUintResult {
16+
uint32_t value;
17+
bool has_value;
18+
19+
consteval bool operator==(uint32_t other) const { return has_value && value == other; }
20+
consteval bool operator!=(uint32_t other) const { return !(*this == other); }
21+
};
22+
23+
// The "no value" result returned by ParseUint on every failure path.
inline consteval ParseUintResult ParseUintNone() {
  return ParseUintResult{0, false};
}

// Parses a non-negative decimal integer from `str` at compile time.
// Fails (has_value == false) when the input is empty, has a leading zero (other than
// the single digit "0"), contains a non-digit character, or does not fit in uint32_t.
consteval ParseUintResult ParseUint(std::string_view str) {
  // size() > 1 is checked first, so front() is never evaluated on an empty view.
  const bool has_leading_zero = str.size() > 1 && str.front() == '0';
  if (str.empty() || has_leading_zero) {
    return ParseUintNone();
  }

  // Accumulate in 64 bits and range-check after every digit, so the accumulator
  // itself can never wrap.
  uint64_t accumulated = 0;
  for (size_t i = 0; i < str.size(); ++i) {
    const char ch = str[i];
    const bool is_digit = ch >= '0' && ch <= '9';
    if (!is_digit) {
      return ParseUintNone();
    }
    accumulated = accumulated * 10 + static_cast<uint64_t>(ch - '0');
    if (accumulated > UINT32_MAX) {
      return ParseUintNone();
    }
  }
  return ParseUintResult{static_cast<uint32_t>(accumulated), true};
}
39+
40+
// Validates a version string at compile time.
41+
// It must be in the format "1.Y.Z" where:
42+
// - Major version is 1
43+
// - Y and Z are non-negative integers without leading zeros
44+
// - Y (minor version) must equal expected_api_version (defaults to ORT_API_VERSION)
45+
consteval bool IsOrtVersionValid(std::string_view version, uint32_t expected_api_version = ORT_API_VERSION) {
46+
size_t first_dot = version.find('.');
47+
if (first_dot == std::string_view::npos) return false;
48+
size_t second_dot = version.find('.', first_dot + 1);
49+
if (second_dot == std::string_view::npos) return false;
50+
if (version.find('.', second_dot + 1) != std::string_view::npos) return false; // Exactly two dots
51+
std::string_view major = version.substr(0, first_dot);
52+
std::string_view minor = version.substr(first_dot + 1, second_dot - first_dot - 1);
53+
std::string_view patch = version.substr(second_dot + 1);
54+
if (major != "1") {
55+
return false;
56+
}
57+
auto minor_val = ParseUint(minor);
58+
auto patch_val = ParseUint(patch);
59+
if (!minor_val.has_value || !patch_val.has_value) {
60+
return false;
61+
}
62+
if (minor_val.value != expected_api_version) {
63+
return false;
64+
}
65+
return true;
66+
}
67+
68+
} // namespace onnxruntime::version_check
Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,64 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
#include "onnxruntime_config.h"
45
#include "core/session/onnxruntime_cxx_api.h"
6+
#include "core/session/ort_version_check.h"
57

6-
#include <cstdint>
7-
#include <charconv>
8-
#include <optional>
9-
#include <string>
10-
#include <vector>
11-
12-
#include "absl/strings/str_split.h"
138
#include "gtest/gtest.h"
149

15-
TEST(CApiTest, VersionConsistencyWithApiVersion) {
16-
const auto version_string = Ort::GetVersionString();
17-
const std::vector<std::string> version_string_components = absl::StrSplit(version_string, '.');
18-
ASSERT_EQ(version_string_components.size(), size_t{3});
19-
20-
auto to_uint32_t = [](const std::string& s) -> std::optional<uint32_t> {
21-
uint32_t result{};
22-
if (std::from_chars(s.data(), s.data() + s.size(), result).ec == std::errc{}) {
23-
return result;
24-
}
25-
return std::nullopt;
26-
};
27-
28-
ASSERT_NE(to_uint32_t(version_string_components[0]), std::nullopt);
29-
ASSERT_EQ(to_uint32_t(version_string_components[1]), uint32_t{ORT_API_VERSION});
30-
ASSERT_NE(to_uint32_t(version_string_components[2]), std::nullopt);
10+
using onnxruntime::version_check::IsOrtVersionValid;
11+
using onnxruntime::version_check::ParseUint;
12+
13+
// Compile-time tests for ParseUint
14+
static_assert(ParseUint("0") == 0u);
15+
static_assert(ParseUint("1") == 1u);
16+
static_assert(ParseUint("25") == 25u);
17+
static_assert(ParseUint("123") == 123u);
18+
static_assert(ParseUint("4294967295") == 4294967295u); // UINT32_MAX
19+
static_assert(!(ParseUint("4294967296").has_value)); // UINT32_MAX + 1 overflows
20+
static_assert(!(ParseUint("").has_value)); // empty
21+
static_assert(!(ParseUint("01").has_value)); // leading zero
22+
static_assert(!(ParseUint("00").has_value)); // leading zero
23+
static_assert(!(ParseUint("abc").has_value)); // non-digit
24+
static_assert(!(ParseUint("1a").has_value)); // trailing non-digit
25+
static_assert(!(ParseUint("-1").has_value)); // negative sign
26+
static_assert(!(ParseUint("1.0").has_value)); // contains dot
27+
static_assert(ParseUint("0").has_value);
28+
static_assert(!ParseUint("").has_value);
29+
30+
// Compile-time tests for IsOrtVersionValid (default expected_api_version = ORT_API_VERSION)
31+
static_assert(IsOrtVersionValid(ORT_VERSION)); // current version must be valid
32+
33+
// Invalid formats
34+
static_assert(!IsOrtVersionValid(""));
35+
static_assert(!IsOrtVersionValid("1"));
36+
static_assert(!IsOrtVersionValid("1.0"));
37+
static_assert(!IsOrtVersionValid("1.0.0.0")); // too many dots
38+
static_assert(!IsOrtVersionValid("2.0.0")); // major != 1
39+
static_assert(!IsOrtVersionValid("1.02.0")); // leading zero in minor
40+
static_assert(!IsOrtVersionValid("1.0.01")); // leading zero in patch
41+
static_assert(!IsOrtVersionValid("1..0")); // empty minor
42+
static_assert(!IsOrtVersionValid("1.0.")); // empty patch
43+
static_assert(!IsOrtVersionValid(".1.0")); // empty major
44+
static_assert(!IsOrtVersionValid("abc")); // non-numeric
45+
static_assert(!IsOrtVersionValid("1.abc.0")); // non-numeric minor
46+
static_assert(!IsOrtVersionValid("1.0.abc")); // non-numeric patch
47+
48+
// Compile-time tests for IsOrtVersionValid with explicit expected_api_version
49+
static_assert(IsOrtVersionValid("1.0.0", 0));
50+
static_assert(IsOrtVersionValid("1.1.0", 1));
51+
static_assert(IsOrtVersionValid("1.25.0", 25));
52+
static_assert(IsOrtVersionValid("1.25.3", 25));
53+
static_assert(IsOrtVersionValid("1.100.0", 100));
54+
static_assert(!IsOrtVersionValid("1.25.0", 24)); // minor doesn't match expected
55+
static_assert(!IsOrtVersionValid("1.25.0", 26)); // minor doesn't match expected
56+
static_assert(!IsOrtVersionValid("1.0.0", 1)); // minor 0 != expected 1
57+
static_assert(!IsOrtVersionValid("2.0.0", 0)); // major != 1
58+
static_assert(!IsOrtVersionValid("1.02.0", 2)); // leading zero in minor
59+
static_assert(!IsOrtVersionValid("1.0.01", 0)); // leading zero in patch
60+
61+
TEST(CApiTest, VersionIsValid) {
62+
// Runtime sanity check — the version string returned by the API is the expected one.
// The format of ORT_VERSION itself is checked at compile time by the static_asserts earlier in this file.
63+
EXPECT_STREQ(Ort::GetVersionString().c_str(), ORT_VERSION);
3164
}

0 commit comments

Comments
 (0)