Skip to content

Commit 06f6c71

Browse files
feich-ms and claude committed
Replace RunRotaryEmbeddingQOnly with standalone RotaryEmbeddingWithOffsetProgram
Introduce RotaryEmbeddingWithOffsetProgram as a separate class that computes each token's position from a uniform offset (position_offset + sequence_index) instead of requiring a position_ids tensor input. This avoids the need for a RangeProgram dispatch and keeps the original RotaryEmbeddingProgram untouched.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
1 parent ef6547e commit 06f6c71

3 files changed

Lines changed: 105 additions & 20 deletions

File tree

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -183,22 +183,56 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
183183
return context.RunProgram(program);
184184
}
185185

186-
// Apply rotary embedding to Q only. Reuses RunFusedQKRotaryEmbedding with a 1-element
187-
// dummy K output and query_in as dummy K input. The shader skips K because k_global_dims[2]=0.
188-
Status RunRotaryEmbeddingQOnly(onnxruntime::webgpu::ComputeContext& context,
189-
const WebgpuAttentionParameters& params,
190-
const Tensor* query_in,
191-
const Tensor* seqlen_k,
192-
const Tensor* cos_cache,
193-
const Tensor* sin_cache,
194-
Tensor* query_out) {
195-
Tensor k_dummy_out = context.CreateGPUTensor(query_in->DataType(), TensorShape({1}));
196-
// Temporarily patch kv_num_heads to 0 so RunFusedQKRotaryEmbedding builds k_global_shape[2]=0.
197-
WebgpuAttentionParameters params_q_only = params;
198-
params_q_only.kv_num_heads_ = 0;
199-
params_q_only.kv_hidden_size_ = 0;
200-
return RunFusedQKRotaryEmbedding(context, params_q_only, query_in, query_in,
201-
seqlen_k, cos_cache, sin_cache, query_out, &k_dummy_out);
186+
// Apply rotary embedding to a single tensor using RotaryEmbeddingWithOffsetProgram.
187+
// Position for each token = past_sequence_length + sequence_index.
188+
Status RunRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
189+
const Tensor* input,
190+
const Tensor* cos_cache,
191+
const Tensor* sin_cache,
192+
Tensor* output,
193+
int batch_size,
194+
int sequence_length,
195+
int hidden_size,
196+
int head_size,
197+
int past_sequence_length,
198+
float scale,
199+
bool rotary_interleaved) {
200+
const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
201+
const auto num_heads = hidden_size / head_size;
202+
203+
const TensorShape global_shape({static_cast<int64_t>(batch_size),
204+
static_cast<int64_t>(sequence_length),
205+
static_cast<int64_t>(num_heads),
206+
static_cast<int64_t>(head_size - half_rotary_embedding_dim)});
207+
const auto rank = global_shape.NumDimensions();
208+
std::vector<uint32_t> global_dims(rank);
209+
std::vector<uint32_t> global_strides(rank);
210+
for (size_t j = 0; j < rank; ++j) {
211+
global_dims[j] = gsl::narrow_cast<uint32_t>(global_shape[j]);
212+
global_strides[j] = gsl::narrow_cast<uint32_t>(global_shape.SizeFromDimension(j + 1));
213+
}
214+
215+
const auto output_size = gsl::narrow_cast<uint32_t>(global_shape.Size());
216+
const auto input_output_strides = std::vector<uint32_t>({
217+
gsl::narrow_cast<uint32_t>(input->Shape().SizeFromDimension(1)),
218+
gsl::narrow_cast<uint32_t>(hidden_size),
219+
gsl::narrow_cast<uint32_t>(head_size),
220+
1u});
221+
222+
RotaryEmbeddingWithOffsetProgram program(rotary_interleaved);
223+
program
224+
.CacheHint(rotary_interleaved)
225+
.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
226+
{cos_cache, ProgramTensorMetadataDependency::Rank},
227+
{sin_cache, ProgramTensorMetadataDependency::Rank}})
228+
.AddOutput({output, ProgramTensorMetadataDependency::None})
229+
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
230+
.AddUniformVariables({{scale},
231+
{gsl::make_span(global_dims)},
232+
{gsl::make_span(global_strides)},
233+
{gsl::make_span(input_output_strides)},
234+
{static_cast<uint32_t>(past_sequence_length)}});
235+
return context.RunProgram(program);
202236
}
203237

204238
Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
@@ -299,10 +333,12 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
299333
if (do_rotary_) {
300334
// Apply RoPE to Q only — K doesn't need rotation since we reuse another layer's already-rotated KV cache.
301335
qRotary = context.CreateGPUTensor(query->DataType(), query->Shape());
302-
ORT_RETURN_IF_ERROR(RunRotaryEmbeddingQOnly(context, parameters,
303-
query, seqlen_k,
304-
cos_cache, sin_cache,
305-
&qRotary));
336+
ORT_RETURN_IF_ERROR(RunRotaryEmbedding(context,
337+
query, cos_cache, sin_cache, &qRotary,
338+
parameters.batch_size_, parameters.sequence_length_,
339+
parameters.hidden_size_, parameters.head_size_,
340+
parameters.past_sequence_length_,
341+
parameters.scale_, parameters.rotary_interleaved_));
306342
query = &qRotary;
307343
}
308344
} else if (parameters.is_packed_qkv_ && do_rotary_) {

onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,38 @@ Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) c
134134
return Status::OK();
135135
}
136136

137+
// Emits the WGSL main-function body for rotary embedding driven by a uniform position offset.
//
// Each invocation decodes its global index into a 4-component index `bsnh` and either
//   * rotates the pair (i, j) with the cos/sin entries at row
//     `uniforms.position_offset + bsnh[1]`, copying the pair through unchanged when that row
//     is not below `cos_cache_shape[0]`, or
//   * copies a single element `k` through unchanged when `bsnh[3]` is not below
//     `cos_cache_shape[1]`.
// Always returns Status::OK().
Status RotaryEmbeddingWithOffsetProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform);
  const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform);
  const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);

  // WGSL boolean literal that picks between the interleaved and the split-half pairing.
  const char* const interleaved_literal = interleaved_ ? "true" : "false";

  // Accessor expressions that appear more than once in the emitted code.
  const auto in_i = input.GetByOffset("i");
  const auto in_j = input.GetByOffset("j");
  const auto cos_at_pos = cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])");
  const auto sin_at_pos = sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])");

  auto& body = shader.MainFunctionBody();

  // Index decoding and the out-of-range guard.
  body << " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n"
       << " let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n"
       << " let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n"
       << " if (global_idx >= size) { return; }\n";

  // Rotated region: locate the pair (i, j) and the cache row for this token.
  body << " if (bsnh[3] < half_rotary_emb_dim) {\n"
       << " let position_id = uniforms.position_offset + bsnh[1];\n"
       << " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_literal << ");\n"
       << " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_literal << ");\n"
       << " let max_position = uniforms.cos_cache_shape[0];\n";

  // Rows outside the cache are copied through unchanged.
  body << " if (position_id >= max_position) {\n"
       << " " << output.SetByOffset("i", in_i) << "\n"
       << " " << output.SetByOffset("j", in_j) << "\n"
       << " } else {\n";

  // The rotation itself: (re, im) = (x_i * cos - x_j * sin, x_i * sin + x_j * cos).
  body << " let re = " << in_i << " * " << cos_at_pos << " - " << in_j << " * " << sin_at_pos << ";\n"
       << " " << output.SetByOffset("i", "re") << "\n"
       << " let im = " << in_i << " * " << sin_at_pos << " + " << in_j << " * " << cos_at_pos << ";\n"
       << " " << output.SetByOffset("j", "im") << "\n"
       << " }\n";

  // Pass-through region: elements beyond the rotated pairs are copied as-is.
  body << " } else { \n"
       << " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n"
       << " " << output.SetByOffset("k", input.GetByOffset("k")) << "\n"
       << " }";

  return Status::OK();
}
168+
137169
RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) {
138170
scale_ = info.GetAttrOrDefault<float>("scale", 1.0);
139171
rotary_embedding_dim_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));

onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,23 @@ class RotaryEmbeddingProgram final : public Program<RotaryEmbeddingProgram> {
2929
const bool interleaved_;
3030
};
3131

32+
// WebGPU program that applies rotary embedding using a uniform position offset
// (`position_offset` + sequence index) rather than a position_ids tensor input.
class RotaryEmbeddingWithOffsetProgram final : public Program<RotaryEmbeddingWithOffsetProgram> {
 public:
  // `interleaved` selects the interleaved pairing in the generated shader.
  // Marked explicit so a bare bool can never convert into a program object implicitly.
  explicit RotaryEmbeddingWithOffsetProgram(bool interleaved)
      : Program{"RotaryEmbeddingWithOffset"}, interleaved_{interleaved} {}

  // Writes the WGSL main-function body for this program into `sh`.
  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Uniforms, in the order they must be supplied to AddUniformVariables:
  //   scale               - float32 scalar.
  //   global_shape        - uint32 dispatch extents used to decode the global index.
  //   global_stride       - uint32 strides matching global_shape.
  //   input_output_stride - uint32 element strides of the input/output tensors.
  //   position_offset     - uint32 offset added to each token's sequence index.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"scale", ProgramUniformVariableDataType::Float32},
                                          {"global_shape", ProgramUniformVariableDataType::Uint32},
                                          {"global_stride", ProgramUniformVariableDataType::Uint32},
                                          {"input_output_stride", ProgramUniformVariableDataType::Uint32},
                                          {"position_offset", ProgramUniformVariableDataType::Uint32});

 private:
  // Fixed at construction; read when the shader code is generated.
  const bool interleaved_;
};
48+
3249
class FusedQKRotaryEmbeddingProgram final : public Program<FusedQKRotaryEmbeddingProgram> {
3350
public:
3451
FusedQKRotaryEmbeddingProgram(bool interleaved) : Program{"FusedQKRotaryEmbedding"}, interleaved_{interleaved} {}

0 commit comments

Comments
 (0)