Skip to content

Commit 410f5a8

Browse files
authored
+rotemb, +rmsnorm, reshape->opset-25, transpose->opset-24 (#27752)
For the WebGPU EP: add the ONNX RotaryEmbedding op and the ONNX RMSNormalization op, and extend Reshape to opset 25 and Transpose to opset 24.
1 parent f427e3e commit 410f5a8

7 files changed

Lines changed: 369 additions & 6 deletions

File tree

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/webgpu_supported_types.h"
5+
#include "core/providers/webgpu/llm/rotary_embedding.h"
6+
#include "contrib_ops/webgpu/bert/rotary_embedding.h"
7+
#include "core/providers/webgpu/generator/range.h"
8+
9+
namespace onnxruntime {
10+
namespace webgpu {
11+
12+
// Registers the RotaryEmbedding kernel class for the ONNX-domain "RotaryEmbedding" operator,
// version 23, on the WebGPU execution provider.
// Type constraint "T" takes the WebGPU-supported float types; "M" is limited to int64 tensors.
ONNX_OPERATOR_KERNEL_EX(RotaryEmbedding, kOnnxDomain, 23, kWebGpuExecutionProvider,
                        (*KernelDefBuilder::Create())
                            .TypeConstraint("T", WebGpuSupportedFloatTypes())
                            .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
                        RotaryEmbedding);
21+
22+
// Captures the operator attributes at kernel construction time.
// Every attribute defaults to 0 when it is not present on the node; `interleaved` is
// considered enabled only when its value is exactly 1.
RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info)
    : WebGpuKernel(info),
      num_heads_(static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0))),
      rotary_embedding_dim_(static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0))),
      interleaved_(info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1) {
}
27+
28+
// Computes the ONNX RotaryEmbedding operator by dispatching the contrib RotaryEmbeddingProgram.
//
// ONNX inputs: X(0), cos_cache(1), sin_cache(2), position_ids(3, optional).
// Output 0 is allocated with the same shape as X.
//
// Returns INVALID_ARGUMENT when X is not rank 3 or rank 4, when cos_cache has no dimensions,
// or when the derived head size cannot hold the rotary half-dimension. An empty X produces an
// empty output without dispatching any GPU work. Throws (ORT_ENFORCE) when
// 'rotary_embedding_dim' is set, X is rank 3 and 'num_heads' was not provided.
Status RotaryEmbedding::ComputeInternal(ComputeContext& context) const {
  const auto* input = context.Input<Tensor>(0);
  const auto* cos_cache = context.Input<Tensor>(1);
  const auto* sin_cache = context.Input<Tensor>(2);
  const auto* position_ids = context.Input<Tensor>(3);  // optional

  const auto& input_shape = input->Shape();
  const auto input_rank = input_shape.NumDimensions();
  // Only the rank-3 and rank-4 layouts have input/output strides defined below; reject anything
  // else up front instead of handing the shader an empty strides vector.
  if (input_rank != 3 && input_rank != 4) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "RotaryEmbedding input must be rank 3 or rank 4. Got rank ", input_rank);
  }

  const auto& cos_cache_shape = cos_cache->Shape();
  if (cos_cache_shape.NumDimensions() == 0) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "RotaryEmbedding cos_cache must have at least one dimension.");
  }

  auto* output = context.Output(0, input_shape);

  // Nothing to compute for an empty tensor. Returning here also avoids dividing by a zero
  // sequence length when deriving hidden_size below.
  if (input_shape.Size() == 0) {
    return Status::OK();
  }

  const auto batch_size = onnxruntime::narrow<uint32_t>(input_shape[0]);
  const auto batch_stride = onnxruntime::narrow<uint32_t>(input_shape.SizeFromDimension(1));
  const auto sequence_length = onnxruntime::narrow<uint32_t>(input_shape[input_rank - 2]);
  const auto hidden_size = batch_stride / sequence_length;
  const auto half_rotary_embedding_dim =
      onnxruntime::narrow<uint32_t>(cos_cache_shape[cos_cache_shape.NumDimensions() - 1]);

  // Compute head_size: when rotary_embedding_dim is not set, head_size = rotary_dim (= 2 * half).
  // When rotary_embedding_dim is set, derive head_size from the 4D input shape or num_heads attribute.
  uint32_t head_size;
  if (rotary_embedding_dim_ == 0) {
    head_size = half_rotary_embedding_dim * 2;
  } else if (input_rank == 4) {
    // 4D input: [batch, num_heads, seq, head_size]
    head_size = onnxruntime::narrow<uint32_t>(input_shape[3]);
  } else {
    ORT_ENFORCE(num_heads_ > 0,
                "Attribute 'num_heads' must be provided when 'rotary_embedding_dim' is specified "
                "and input is not rank-4 (batch, num_heads, sequence, head).");
    head_size = hidden_size / static_cast<uint32_t>(num_heads_);
  }

  // head_size is used as a divisor and as the minuend of an unsigned subtraction below, so it
  // must be non-zero and at least as large as the rotary half-dimension.
  if (head_size == 0 || half_rotary_embedding_dim > head_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "RotaryEmbedding: derived head_size (", head_size,
                           ") is inconsistent with the cos/sin cache last dimension (",
                           half_rotary_embedding_dim, ").");
  }

  const TensorShape global_shape({batch_size,
                                  sequence_length,
                                  hidden_size / head_size,
                                  head_size - half_rotary_embedding_dim});

  const auto rank = global_shape.NumDimensions();
  std::vector<uint32_t> global_dims(rank);
  std::vector<uint32_t> global_strides(rank);
  for (size_t j = 0; j < rank; ++j) {
    global_dims[j] = onnxruntime::narrow<uint32_t>(global_shape[j]);
    global_strides[j] = onnxruntime::narrow<uint32_t>(global_shape.SizeFromDimension(j + 1));
  }

  // The dispatch below is sized from the element count of global_shape.
  const auto output_size = onnxruntime::narrow<uint32_t>(global_shape.Size());

  // The rank was validated above, so exactly one of the two layouts applies.
  const std::vector<uint32_t> input_output_strides =
      input_rank == 3
          ? std::vector<uint32_t>({batch_stride, hidden_size, head_size, 1})
          : std::vector<uint32_t>({batch_stride, head_size, sequence_length * head_size, 1});

  // Program setup shared by both code paths once the inputs have been added: output, dispatch
  // size, uniforms and indices, followed by the actual run.
  const auto finish_and_run = [&](contrib::webgpu::RotaryEmbeddingProgram& program) -> Status {
    program
        .AddOutput({output, ProgramTensorMetadataDependency::None})
        .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
        .AddUniformVariables({{1.0f},
                              {gsl::make_span(global_dims)},
                              {gsl::make_span(global_strides)},
                              {gsl::make_span(input_output_strides)}})
        .AddIndices(TensorShape{1, 1});
    return context.RunProgram(program);
  };

  // The contrib RotaryEmbeddingProgram expects inputs in order:
  //   input(0), position_ids(1), cos_cache(2), sin_cache(3)
  // The ONNX op has: X(0), cos_cache(1), sin_cache(2), position_ids(3, optional)

  if (position_ids != nullptr) {
    // position_ids provided: cos/sin cache is 2D (max_pos, D/2)
    contrib::webgpu::RotaryEmbeddingProgram program{interleaved_};
    program
        .CacheHint(interleaved_)
        .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
                    {position_ids, ProgramTensorMetadataDependency::Rank},
                    {cos_cache, ProgramTensorMetadataDependency::Rank},
                    {sin_cache, ProgramTensorMetadataDependency::Rank}});
    return finish_and_run(program);
  }

  // position_ids NOT provided: cos/sin cache is 3D (B, S, D/2)
  // Reshape to 2D (B*S, D/2) and generate sequential position_ids.
  const auto total_seq = batch_size * sequence_length;
  const TensorShape cache_2d_shape({static_cast<int64_t>(total_seq),
                                    static_cast<int64_t>(half_rotary_embedding_dim)});

  // Generate position_ids [0, 1, ..., B*S-1] reshaped as (B, S) on GPU using RangeProgram
  const TensorShape pos_ids_shape({static_cast<int64_t>(batch_size),
                                   static_cast<int64_t>(sequence_length)});
  Tensor pos_ids_tensor = context.CreateGPUTensor(DataTypeImpl::GetType<int64_t>(), pos_ids_shape);
  {
    RangeProgram range_program{ONNX_NAMESPACE::TensorProto_DataType_INT64};
    // start and delta are passed to the program as the raw bit patterns of int32 values.
    int32_t start_i32 = 0;
    int32_t delta_i32 = 1;
    range_program
        .AddOutput({&pos_ids_tensor, ProgramTensorMetadataDependency::Type})
        .SetDispatchGroupSize((total_seq + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
        .AddUniformVariables({
            total_seq,
            std::bit_cast<uint32_t>(start_i32),
            std::bit_cast<uint32_t>(delta_i32),
        });
    ORT_RETURN_IF_ERROR(context.RunProgram(range_program));
  }

  contrib::webgpu::RotaryEmbeddingProgram program{interleaved_};
  program
      .CacheHint(interleaved_)
      .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
                  {&pos_ids_tensor, ProgramTensorMetadataDependency::Rank},
                  {cos_cache, ProgramTensorMetadataDependency::Rank, cache_2d_shape, 1},
                  {sin_cache, ProgramTensorMetadataDependency::Rank, cache_2d_shape, 1}});
  return finish_and_run(program);
}
144+
145+
} // namespace webgpu
146+
} // namespace onnxruntime
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/webgpu_kernel.h"
7+
8+
namespace onnxruntime {
9+
namespace webgpu {
10+
11+
class RotaryEmbedding final : public WebGpuKernel {
12+
public:
13+
RotaryEmbedding(const OpKernelInfo& info);
14+
Status ComputeInternal(ComputeContext& context) const override;
15+
16+
private:
17+
int num_heads_;
18+
int rotary_embedding_dim_;
19+
bool interleaved_;
20+
};
21+
22+
} // namespace webgpu
23+
} // namespace onnxruntime
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/shader_helper.h"
5+
#include "core/providers/webgpu/webgpu_supported_types.h"
6+
#include "core/providers/webgpu/webgpu_utils.h"
7+
#include "core/providers/webgpu/nn/rms_norm.h"
8+
#include "core/providers/webgpu/nn/layer_norm.h"
9+
10+
namespace onnxruntime {
11+
namespace webgpu {
12+
13+
static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) {
14+
int64_t rank = static_cast<int64_t>(tensor_rank);
15+
if (axis < -rank && axis >= rank) {
16+
ORT_THROW("invalid axis: ", axis);
17+
}
18+
return onnxruntime::narrow<size_t>(axis < 0 ? axis + rank : axis);
19+
}
20+
21+
// Builds a one-dimensional shape whose only extent is the element count of `shape` divided
// (integer division) by `components`.
static TensorShape GetOverrideShape(const TensorShape& shape, int components) {
  return TensorShape{shape.Size() / components};
}
25+
26+
// Runs RMS normalization over X.shape[axis:] by reusing LayerNormProgram in its "simplified"
// configuration (no bias, no mean output).
//
// Inputs: X(0), scale(1). Output 0 (Y) has the shape of X. Output index 1 is requested as well;
// it is wired into the program only when the context returns a non-null tensor for it.
// Returns INVALID_ARGUMENT when scale has a higher rank than X or its element count differs
// from the normalized size.
Status RMSNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
  const auto* input = context.Input(0);
  const auto* gamma = context.Input(1);

  const auto& input_shape = input->Shape();
  const size_t input_rank = input_shape.NumDimensions();

  // Split the input at `axis`: leading dims are independent rows, trailing dims are normalized.
  const size_t axis = NormalizeAxis(axis_, input_rank);
  const uint32_t norm_count = onnxruntime::narrow<uint32_t>(input_shape.SizeToDimension(axis));
  const int64_t norm_size = input_shape.SizeFromDimension(axis);
  const int components = GetMaxComponents(norm_size);
  const uint32_t norm_size_vectorized = onnxruntime::narrow<uint32_t>((norm_size + components - 1) / components);

  const auto& gamma_shape = gamma->Shape();
  const auto gamma_size = gamma_shape.Size();
  if (gamma_shape.NumDimensions() > input_rank) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Scale and (optional) bias must match X.shape[axis:] or be NumPy-broadcastable to it."
                           " Scale/Bias rank cannot exceed Input rank.");
  }
  if (gamma_size != norm_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Size of X.shape()[axis:] == ", norm_size,
                           ". Size of scale must match this. Got scale size of ",
                           gamma_size);
  }

  auto* normalized = context.Output(0, input_shape);

  // Shape for the inverse-std-dev output: X's dims before `axis`, then 1 for every other dim.
  TensorShapeVector stats_dims;
  for (size_t d = 0; d < input_rank; ++d) {
    stats_dims.push_back(d < axis ? input_shape[d] : 1);
  }
  TensorShape stats_shape(stats_dims);
  auto* inv_std_dev = context.Output(1, stats_shape);

  // Outputs are allocated above; an empty input needs no GPU work.
  if (input_shape.Size() == 0) {
    return Status::OK();
  }

  // The split-norm-dimension path is taken only for a single row whose length is a multiple of 512.
  const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1;
  const bool wants_inv_std_dev = inv_std_dev != nullptr;

  LayerNormProgram program{/*has_bias=*/false, /*simplified=*/true, /*has_mean_output=*/false,
                           /*has_inv_std_dev_output=*/wants_inv_std_dev, split_norm_dim};

  // Tensors are bound as flat 1-D views with `components` elements per item.
  program.CacheHint(components, /*simplified=*/true, split_norm_dim)
      .AddInputs({{input, ProgramTensorMetadataDependency::Type, GetOverrideShape(input->Shape(), components), components},
                  {gamma, ProgramTensorMetadataDependency::Type, GetOverrideShape(gamma->Shape(), components), components}})
      .AddOutputs({{normalized, ProgramTensorMetadataDependency::None, GetOverrideShape(normalized->Shape(), components), components}})
      .AddUniformVariables({{static_cast<uint32_t>(components)},
                            {static_cast<uint32_t>(norm_count)},
                            {static_cast<uint32_t>(norm_size)},
                            {static_cast<uint32_t>(norm_size_vectorized)},
                            {static_cast<float>(epsilon_)}});

  if (!split_norm_dim) {
    program.SetDispatchGroupSize(norm_count);
  } else {
    const uint32_t workgroup_size_x = 128;
    const uint32_t dispatch_size_x = onnxruntime::narrow<uint32_t>(norm_size / (workgroup_size_x * components));
    program.SetDispatchGroupSize(dispatch_size_x, 1, 1)
        .SetWorkgroupSize(workgroup_size_x);
  }

  // The optional output is appended after Y so that it follows Y in the program's output list.
  if (wants_inv_std_dev) {
    program.AddOutputs({{inv_std_dev, ProgramTensorMetadataDependency::None}});
  }

  return context.RunProgram(program);
}
113+
114+
// Registers the RMSNorm kernel class for the ONNX-domain "RMSNormalization" operator,
// version 23, on the WebGPU execution provider.
// Type constraints "T" and "V" both take the WebGPU-supported float types.
ONNX_OPERATOR_KERNEL_EX(
    RMSNormalization,
    kOnnxDomain,
    23,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        .TypeConstraint("V", WebGpuSupportedFloatTypes()),
    RMSNorm);
119+
120+
} // namespace webgpu
121+
} // namespace onnxruntime
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/webgpu_kernel.h"
7+
8+
namespace onnxruntime {
9+
namespace webgpu {
10+
11+
class RMSNorm final : public WebGpuKernel {
12+
public:
13+
RMSNorm(const OpKernelInfo& info) : WebGpuKernel(info) {
14+
info.GetAttrOrDefault<int64_t>("axis", &axis_, -1);
15+
info.GetAttrOrDefault<float>("epsilon", &epsilon_, 1e-05f);
16+
}
17+
18+
Status ComputeInternal(ComputeContext& context) const override;
19+
20+
private:
21+
int64_t axis_;
22+
float epsilon_;
23+
};
24+
25+
} // namespace webgpu
26+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/tensor/reshape.cc

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,31 @@ namespace webgpu {
1111
ONNX_OPERATOR_KERNEL_EX(
1212
Reshape,
1313
kOnnxDomain,
14-
21,
14+
25,
15+
kWebGpuExecutionProvider,
16+
(*KernelDefBuilder::Create())
17+
.TypeConstraint("T", WebGpuSupportedNumberTypes())
18+
.TypeConstraint("shape", DataTypeImpl::GetTensorType<int64_t>())
19+
.Alias(0, 0)
20+
.InputMemoryType(OrtMemTypeCPU, 1),
21+
Reshape);
22+
23+
// Registers the Reshape kernel class for the ONNX-domain "Reshape" operator, versions 23 to 24,
// on the WebGPU execution provider. The kernel definition declares an alias between input 0
// and output 0 and requests CPU memory for input 1.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Reshape, kOnnxDomain, 23, 24, kWebGpuExecutionProvider,
                                  (*KernelDefBuilder::Create())
                                      .TypeConstraint("T", WebGpuSupportedNumberTypes())
                                      .TypeConstraint("shape", DataTypeImpl::GetTensorType<int64_t>())
                                      .Alias(0, 0)
                                      .InputMemoryType(OrtMemTypeCPU, 1),
                                  Reshape);
34+
35+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
36+
Reshape,
37+
kOnnxDomain,
38+
21, 22,
1539
kWebGpuExecutionProvider,
1640
(*KernelDefBuilder::Create())
1741
.TypeConstraint("T", WebGpuSupportedNumberTypes())

onnxruntime/core/providers/webgpu/tensor/transpose.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,19 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
6363
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
6464
Transpose);
6565

66+
// Registers the Transpose kernel class for the ONNX-domain "Transpose" operator, version 23
// only (start and end version are both 23), on the WebGPU execution provider.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Transpose, kOnnxDomain, 23, 23, kWebGpuExecutionProvider,
                                  (*KernelDefBuilder::Create())
                                      .TypeConstraint("T", WebGpuSupportedNumberTypes()),
                                  Transpose);
74+
6675
ONNX_OPERATOR_KERNEL_EX(
6776
Transpose,
6877
kOnnxDomain,
69-
23,
78+
24,
7079
kWebGpuExecutionProvider,
7180
(*KernelDefBuilder::Create())
7281
.TypeConstraint("T", WebGpuSupportedNumberTypes()),

0 commit comments

Comments
 (0)