@@ -183,6 +183,80 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
183183 return context.RunProgram (program);
184184}
185185
186+ // Apply rotary embedding to Q only. Reuses FusedQKRotaryEmbeddingProgram with k_global_shape[2]=0
187+ // so that K branches are never executed in the shader.
188+ Status RunRotaryEmbeddingQOnly (onnxruntime::webgpu::ComputeContext& context,
189+ const WebgpuAttentionParameters& params,
190+ const Tensor* query_in,
191+ const Tensor* seqlen_k,
192+ const Tensor* cos_cache,
193+ const Tensor* sin_cache,
194+ Tensor* query_out) {
195+ const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t >(cos_cache->Shape ()[1 ]);
196+ const auto head_size = params.head_size_ ;
197+
198+ // Build Q domain
199+ const auto hidden_size_q = params.hidden_size_ ;
200+ const TensorShape q_global_shape ({params.batch_size_ , params.sequence_length_ ,
201+ hidden_size_q / head_size,
202+ static_cast <int64_t >(head_size - half_rotary_embedding_dim)});
203+ const auto rank = q_global_shape.NumDimensions ();
204+ std::vector<uint32_t > q_global_dims (rank);
205+ std::vector<uint32_t > q_global_strides (rank);
206+ for (size_t j = 0 ; j < rank; ++j) {
207+ q_global_dims[j] = gsl::narrow_cast<uint32_t >(q_global_shape[j]);
208+ q_global_strides[j] = gsl::narrow_cast<uint32_t >(q_global_shape.SizeFromDimension (j + 1 ));
209+ }
210+
211+ // K domain with 0 heads — shader condition `bsnh[2] < k_global_shape[2]` is never true.
212+ std::vector<uint32_t > k_global_dims = {gsl::narrow_cast<uint32_t >(params.batch_size_ ),
213+ gsl::narrow_cast<uint32_t >(params.sequence_length_ ),
214+ 0u ,
215+ gsl::narrow_cast<uint32_t >(head_size - half_rotary_embedding_dim)};
216+
217+ const auto q_domain_size = gsl::narrow_cast<uint32_t >(q_global_shape.Size ());
218+
219+ const auto q_input_output_strides = std::vector<uint32_t >(
220+ {gsl::narrow_cast<uint32_t >(query_in->Shape ().SizeFromDimension (1 )),
221+ gsl::narrow_cast<uint32_t >(hidden_size_q),
222+ gsl::narrow_cast<uint32_t >(head_size),
223+ 1u });
224+
225+ // K strides are unused but must be provided for uniform layout. Use Q strides as placeholder.
226+ const auto k_input_output_strides = q_input_output_strides;
227+
228+ // WebGPU requires valid buffer bindings even for unused K. Use query_in as k_input (never read)
229+ // and a minimal 1-element tensor as k_output (never written).
230+ Tensor k_dummy_out = context.CreateGPUTensor (query_in->DataType (), TensorShape ({1 }));
231+
232+ FusedQKRotaryEmbeddingProgram program (params.rotary_interleaved_ );
233+ program
234+ .CacheHint (params.rotary_interleaved_ , " q_only" )
235+ .AddInputs ({
236+ {query_in, ProgramTensorMetadataDependency::TypeAndRank},
237+ {query_in, ProgramTensorMetadataDependency::Rank}, // k_input placeholder (never read)
238+ {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
239+ {cos_cache, ProgramTensorMetadataDependency::Rank},
240+ {sin_cache, ProgramTensorMetadataDependency::Rank},
241+ })
242+ .AddOutputs ({
243+ {query_out, ProgramTensorMetadataDependency::None},
244+ {&k_dummy_out, ProgramTensorMetadataDependency::None},
245+ })
246+ .SetDispatchGroupSize ((q_domain_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE )
247+ .AddUniformVariables ({
248+ {params.scale_ },
249+ {gsl::make_span (q_global_dims)},
250+ {gsl::make_span (q_global_strides)},
251+ {gsl::make_span (q_input_output_strides)},
252+ {gsl::make_span (k_global_dims)},
253+ {gsl::make_span (k_input_output_strides)},
254+ {q_domain_size},
255+ });
256+
257+ return context.RunProgram (program);
258+ }
259+
186260Status GroupQueryAttention::ComputeInternal (onnxruntime::webgpu::ComputeContext& context) const {
187261 const Tensor* query = context.Input <Tensor>(0 );
188262 const Tensor* key = context.Input <Tensor>(1 );
@@ -256,7 +330,6 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
256330
257331 Tensor qRotary;
258332 Tensor kRotary ;
259- Tensor kDummy ; // Placeholder for rotary when kv_empty and key has zero sequence length
260333
261334 // kv_sequence_length==0 fast path: K/V inputs are empty (shared KV layer).
262335 // Skip all K/V processing; only apply RoPE to Q if needed.
@@ -275,30 +348,21 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
275348
276349 if (kv_empty) {
277350 // KV inputs are empty - shared KV layer. Only need to extract Q and optionally apply RoPE to Q.
278- // Avoid creating zero-sized K/V tensors as WebGPU may reject zero-element storage buffers.
279351 if (parameters.is_packed_qkv_ ) {
280- // Extract Q from packed QKV. Create non-zero K/V with sequence_length=1 to satisfy SplitPackedQKV,
281- // but they won't be used for attention (past_key/past_value provide the KV context).
352+ // Extract only Q from packed QKV — no need to allocate K/V split tensors.
282353 qSplit = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.hidden_size_ }));
283- kSplit = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.kv_hidden_size_ }));
284- vSplit = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.kv_hidden_size_ }));
285- ORT_RETURN_IF_ERROR (SplitPackedQKV (context, parameters, query, &qSplit, &kSplit , &vSplit, parameters.kv_hidden_size_ ));
354+ ORT_RETURN_IF_ERROR (ExtractQFromPackedQKV (context, parameters, query, &qSplit));
286355 parameters.is_packed_qkv_ = false ;
287356 parameters.qkv_format_ = Q_K_V_BSNH ;
288357 query = &qSplit;
289- // K/V from split are discarded — attention will use past_key/past_value instead.
290358 }
291359 if (do_rotary_) {
292- // Apply RoPE to Q only. Use the fused kernel with a dummy K to avoid zero-element buffers.
293- // K output is discarded since attention uses past_key/past_value.
360+ // Apply RoPE to Q only — K doesn't need rotation since we reuse another layer's already-rotated KV cache.
294361 qRotary = context.CreateGPUTensor (query->DataType (), query->Shape ());
295- kDummy = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.kv_hidden_size_ }));
296- kRotary = context.CreateGPUTensor (query->DataType (), TensorShape ({parameters.batch_size_ , parameters.sequence_length_ , parameters.kv_hidden_size_ }));
297- ORT_RETURN_IF_ERROR (RunFusedQKRotaryEmbedding (context, parameters,
298- query, &kDummy ,
299- seqlen_k,
300- cos_cache, sin_cache,
301- &qRotary, &kRotary ));
362+ ORT_RETURN_IF_ERROR (RunRotaryEmbeddingQOnly (context, parameters,
363+ query, seqlen_k,
364+ cos_cache, sin_cache,
365+ &qRotary));
302366 query = &qRotary;
303367 }
304368 } else if (parameters.is_packed_qkv_ && do_rotary_) {
0 commit comments