@@ -183,22 +183,56 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
183183 return context.RunProgram (program);
184184}
185185
186- // Apply rotary embedding to Q only. Reuses RunFusedQKRotaryEmbedding with a 1-element
187- // dummy K output and query_in as dummy K input. The shader skips K because k_global_dims[2]=0.
188- Status RunRotaryEmbeddingQOnly (onnxruntime::webgpu::ComputeContext& context,
189- const WebgpuAttentionParameters& params,
190- const Tensor* query_in,
191- const Tensor* seqlen_k,
192- const Tensor* cos_cache,
193- const Tensor* sin_cache,
194- Tensor* query_out) {
195- Tensor k_dummy_out = context.CreateGPUTensor (query_in->DataType (), TensorShape ({1 }));
196- // Temporarily patch kv_num_heads to 0 so RunFusedQKRotaryEmbedding builds k_global_shape[2]=0.
197- WebgpuAttentionParameters params_q_only = params;
198- params_q_only.kv_num_heads_ = 0 ;
199- params_q_only.kv_hidden_size_ = 0 ;
200- return RunFusedQKRotaryEmbedding (context, params_q_only, query_in, query_in,
201- seqlen_k, cos_cache, sin_cache, query_out, &k_dummy_out);
186+ // Apply rotary embedding to a single tensor using RotaryEmbeddingWithOffsetProgram.
187+ // Position for each token = past_sequence_length + sequence_index.
188+ Status RunRotaryEmbedding (onnxruntime::webgpu::ComputeContext& context,
189+ const Tensor* input,
190+ const Tensor* cos_cache,
191+ const Tensor* sin_cache,
192+ Tensor* output,
193+ int batch_size,
194+ int sequence_length,
195+ int hidden_size,
196+ int head_size,
197+ int past_sequence_length,
198+ float scale,
199+ bool rotary_interleaved) {
200+ const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t >(cos_cache->Shape ()[1 ]);
201+ const auto num_heads = hidden_size / head_size;
202+
203+ const TensorShape global_shape ({static_cast <int64_t >(batch_size),
204+ static_cast <int64_t >(sequence_length),
205+ static_cast <int64_t >(num_heads),
206+ static_cast <int64_t >(head_size - half_rotary_embedding_dim)});
207+ const auto rank = global_shape.NumDimensions ();
208+ std::vector<uint32_t > global_dims (rank);
209+ std::vector<uint32_t > global_strides (rank);
210+ for (size_t j = 0 ; j < rank; ++j) {
211+ global_dims[j] = gsl::narrow_cast<uint32_t >(global_shape[j]);
212+ global_strides[j] = gsl::narrow_cast<uint32_t >(global_shape.SizeFromDimension (j + 1 ));
213+ }
214+
215+ const auto output_size = gsl::narrow_cast<uint32_t >(global_shape.Size ());
216+ const auto input_output_strides = std::vector<uint32_t >({
217+ gsl::narrow_cast<uint32_t >(input->Shape ().SizeFromDimension (1 )),
218+ gsl::narrow_cast<uint32_t >(hidden_size),
219+ gsl::narrow_cast<uint32_t >(head_size),
220+ 1u });
221+
222+ RotaryEmbeddingWithOffsetProgram program (rotary_interleaved);
223+ program
224+ .CacheHint (rotary_interleaved)
225+ .AddInputs ({{input, ProgramTensorMetadataDependency::TypeAndRank},
226+ {cos_cache, ProgramTensorMetadataDependency::Rank},
227+ {sin_cache, ProgramTensorMetadataDependency::Rank}})
228+ .AddOutput ({output, ProgramTensorMetadataDependency::None})
229+ .SetDispatchGroupSize ((output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE )
230+ .AddUniformVariables ({{scale},
231+ {gsl::make_span (global_dims)},
232+ {gsl::make_span (global_strides)},
233+ {gsl::make_span (input_output_strides)},
234+ {static_cast <uint32_t >(past_sequence_length)}});
235+ return context.RunProgram (program);
202236}
203237
204238Status GroupQueryAttention::ComputeInternal (onnxruntime::webgpu::ComputeContext& context) const {
@@ -299,10 +333,12 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
299333 if (do_rotary_) {
300334 // Apply RoPE to Q only — K doesn't need rotation since we reuse another layer's already-rotated KV cache.
301335 qRotary = context.CreateGPUTensor (query->DataType (), query->Shape ());
302- ORT_RETURN_IF_ERROR (RunRotaryEmbeddingQOnly (context, parameters,
303- query, seqlen_k,
304- cos_cache, sin_cache,
305- &qRotary));
336+ ORT_RETURN_IF_ERROR (RunRotaryEmbedding (context,
337+ query, cos_cache, sin_cache, &qRotary,
338+ parameters.batch_size_ , parameters.sequence_length_ ,
339+ parameters.hidden_size_ , parameters.head_size_ ,
340+ parameters.past_sequence_length_ ,
341+ parameters.scale_ , parameters.rotary_interleaved_ ));
306342 query = &qRotary;
307343 }
308344 } else if (parameters.is_packed_qkv_ && do_rotary_) {
0 commit comments