Skip to content

Commit 511301d

Browse files
feich-msclaude
andcommitted
[WebGPU] Optimize GQA kv_empty path: Q-only extraction and rotary
Eliminate unnecessary K/V allocations in the kv_empty (shared KV layer) path: - Add ExtractQFromPackedQKV to extract only Q from packed QKV without allocating K/V split tensors - Add RunRotaryEmbeddingQOnly that applies rotary to Q only by setting k_global_shape[2]=0, making K branches in the fused shader a no-op - Remove kDummy placeholder and redundant K/V tensor allocations Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
1 parent eee6926 commit 511301d

5 files changed

Lines changed: 141 additions & 37 deletions

File tree

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -75,25 +75,38 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
7575
Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const {
7676
// Inputs: packed_qkv [B, S, D], outputs: Q, K, V [B, S, D]
7777
const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform);
78-
const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
79-
const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
80-
const auto& value = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
81-
sh.MainFunctionBody()
82-
<< sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size")
83-
<< " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n"
84-
<< " let batch = packed_qkv_indices[0];\n"
85-
<< " let seq = packed_qkv_indices[1];\n"
86-
<< " let d = packed_qkv_indices[2];\n"
87-
<< " let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n"
88-
<< " if (d < uniforms.hidden_size) {\n"
89-
<< " " << query.SetByIndices("vec3<u32>(batch, seq, d)", "input_data") << ";\n"
90-
<< " } else if (d < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n"
91-
<< " let kd = d - uniforms.hidden_size;\n"
92-
<< " " << key.SetByIndices("vec3<u32>(batch, seq, kd)", "input_data") << ";\n"
93-
<< " } else {\n"
94-
<< " let vd = d - uniforms.hidden_size - uniforms.kv_hidden_size;\n"
95-
<< " " << value.SetByIndices("vec3<u32>(batch, seq, vd)", "input_data") << ";\n"
96-
<< " }\n";
78+
if (q_only_) {
79+
const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
80+
sh.MainFunctionBody()
81+
<< sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size")
82+
<< " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n"
83+
<< " let batch = packed_qkv_indices[0];\n"
84+
<< " let seq = packed_qkv_indices[1];\n"
85+
<< " let d = packed_qkv_indices[2];\n"
86+
<< " if (d < uniforms.hidden_size) {\n"
87+
<< " " << query.SetByIndices("vec3<u32>(batch, seq, d)", packed_qkv.GetByOffset("global_idx")) << ";\n"
88+
<< " }\n";
89+
} else {
90+
const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
91+
const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
92+
const auto& value = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform);
93+
sh.MainFunctionBody()
94+
<< sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size")
95+
<< " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n"
96+
<< " let batch = packed_qkv_indices[0];\n"
97+
<< " let seq = packed_qkv_indices[1];\n"
98+
<< " let d = packed_qkv_indices[2];\n"
99+
<< " let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n"
100+
<< " if (d < uniforms.hidden_size) {\n"
101+
<< " " << query.SetByIndices("vec3<u32>(batch, seq, d)", "input_data") << ";\n"
102+
<< " } else if (d < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n"
103+
<< " let kd = d - uniforms.hidden_size;\n"
104+
<< " " << key.SetByIndices("vec3<u32>(batch, seq, kd)", "input_data") << ";\n"
105+
<< " } else {\n"
106+
<< " let vd = d - uniforms.hidden_size - uniforms.kv_hidden_size;\n"
107+
<< " " << value.SetByIndices("vec3<u32>(batch, seq, vd)", "input_data") << ";\n"
108+
<< " }\n";
109+
}
97110
return Status::OK();
98111
}
99112

@@ -116,6 +129,25 @@ Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const Webgpu
116129
return context.RunProgram(program);
117130
}
118131

132+
Status ExtractQFromPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params,
133+
const Tensor* packedQKV, Tensor* query) {
134+
const int components = std::min({GetMaxComponents(params.hidden_size_), GetMaxComponents(params.kv_hidden_size_), GetMaxComponents(params.v_hidden_size_)});
135+
SplitPackedQKVProgram program(/*q_only=*/true);
136+
auto input_size = packedQKV->Shape().Size();
137+
const uint32_t vectorized_input_size = static_cast<uint32_t>(input_size / components);
138+
program
139+
.CacheHint("q_only", components)
140+
.AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components})
141+
.AddOutput({query, ProgramTensorMetadataDependency::TypeAndRank, components})
142+
.AddUniformVariables({
143+
{vectorized_input_size},
144+
{static_cast<uint32_t>(params.hidden_size_ / components)},
145+
{static_cast<uint32_t>(params.kv_hidden_size_ / components)},
146+
})
147+
.SetDispatchGroupSize((vectorized_input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
148+
return context.RunProgram(program);
149+
}
150+
119151
void InitVarStub(std::ostringstream& ss, bool has_seqlen_k) {
120152
if (has_seqlen_k) {
121153
ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n";

onnxruntime/contrib_ops/webgpu/bert/attention.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,16 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
3434

3535
class SplitPackedQKVProgram final : public Program<SplitPackedQKVProgram> {
3636
public:
37-
SplitPackedQKVProgram() : Program{"SplitPackedQKV"} {}
37+
SplitPackedQKVProgram(bool q_only = false) : Program{"SplitPackedQKV"}, q_only_(q_only) {}
3838

3939
Status GenerateShaderCode(ShaderHelper& sh) const override;
4040

4141
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32},
4242
{"hidden_size", ProgramUniformVariableDataType::Uint32},
4343
{"kv_hidden_size", ProgramUniformVariableDataType::Uint32});
44+
45+
private:
46+
bool q_only_;
4447
};
4548

4649
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {

onnxruntime/contrib_ops/webgpu/bert/attention_common.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h
124124
Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params,
125125
const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val, int kv_hidden_size);
126126

127+
Status ExtractQFromPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params,
128+
const Tensor* packedQKV, Tensor* query);
129+
127130
Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
128131
const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
129132
Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context,

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
470470
// kv_sequence_length==0: K/V inputs are empty (shared KV layer).
471471
// Skip CopyKVCache and fused split+rotary+copyKV.
472472
// Use past_key/past_value directly as the present buffers for attention.
473+
// Note: do_rotary is always false here because GQA passes cos_cache=nullptr, sin_cache=nullptr
474+
// for kv_empty layers (rotary is applied to Q separately in GQA before calling ApplyFlashAttention).
473475
ORT_ENFORCE(!do_rotary, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV should not be used with kv_sequence_length==0.");
474476
ORT_ENFORCE(past_key != nullptr && past_value != nullptr,
475477
"kv_empty path requires past KV context (KV-shared layers reuse another layer's cache).");

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 81 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,80 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
183183
return context.RunProgram(program);
184184
}
185185

186+
// Apply rotary embedding to Q only. Reuses FusedQKRotaryEmbeddingProgram with k_global_shape[2]=0
187+
// so that K branches are never executed in the shader.
188+
Status RunRotaryEmbeddingQOnly(onnxruntime::webgpu::ComputeContext& context,
189+
const WebgpuAttentionParameters& params,
190+
const Tensor* query_in,
191+
const Tensor* seqlen_k,
192+
const Tensor* cos_cache,
193+
const Tensor* sin_cache,
194+
Tensor* query_out) {
195+
const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
196+
const auto head_size = params.head_size_;
197+
198+
// Build Q domain
199+
const auto hidden_size_q = params.hidden_size_;
200+
const TensorShape q_global_shape({params.batch_size_, params.sequence_length_,
201+
hidden_size_q / head_size,
202+
static_cast<int64_t>(head_size - half_rotary_embedding_dim)});
203+
const auto rank = q_global_shape.NumDimensions();
204+
std::vector<uint32_t> q_global_dims(rank);
205+
std::vector<uint32_t> q_global_strides(rank);
206+
for (size_t j = 0; j < rank; ++j) {
207+
q_global_dims[j] = gsl::narrow_cast<uint32_t>(q_global_shape[j]);
208+
q_global_strides[j] = gsl::narrow_cast<uint32_t>(q_global_shape.SizeFromDimension(j + 1));
209+
}
210+
211+
// K domain with 0 heads — shader condition `bsnh[2] < k_global_shape[2]` is never true.
212+
std::vector<uint32_t> k_global_dims = {gsl::narrow_cast<uint32_t>(params.batch_size_),
213+
gsl::narrow_cast<uint32_t>(params.sequence_length_),
214+
0u,
215+
gsl::narrow_cast<uint32_t>(head_size - half_rotary_embedding_dim)};
216+
217+
const auto q_domain_size = gsl::narrow_cast<uint32_t>(q_global_shape.Size());
218+
219+
const auto q_input_output_strides = std::vector<uint32_t>(
220+
{gsl::narrow_cast<uint32_t>(query_in->Shape().SizeFromDimension(1)),
221+
gsl::narrow_cast<uint32_t>(hidden_size_q),
222+
gsl::narrow_cast<uint32_t>(head_size),
223+
1u});
224+
225+
// K strides are unused but must be provided for uniform layout. Use Q strides as placeholder.
226+
const auto k_input_output_strides = q_input_output_strides;
227+
228+
// WebGPU requires valid buffer bindings even for unused K. Use query_in as k_input (never read)
229+
// and a minimal 1-element tensor as k_output (never written).
230+
Tensor k_dummy_out = context.CreateGPUTensor(query_in->DataType(), TensorShape({1}));
231+
232+
FusedQKRotaryEmbeddingProgram program(params.rotary_interleaved_);
233+
program
234+
.CacheHint(params.rotary_interleaved_, "q_only")
235+
.AddInputs({
236+
{query_in, ProgramTensorMetadataDependency::TypeAndRank},
237+
{query_in, ProgramTensorMetadataDependency::Rank}, // k_input placeholder (never read)
238+
{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
239+
{cos_cache, ProgramTensorMetadataDependency::Rank},
240+
{sin_cache, ProgramTensorMetadataDependency::Rank},
241+
})
242+
.AddOutputs({
243+
{query_out, ProgramTensorMetadataDependency::None},
244+
{&k_dummy_out, ProgramTensorMetadataDependency::None},
245+
})
246+
.SetDispatchGroupSize((q_domain_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
247+
.AddUniformVariables({
248+
{params.scale_},
249+
{gsl::make_span(q_global_dims)},
250+
{gsl::make_span(q_global_strides)},
251+
{gsl::make_span(q_input_output_strides)},
252+
{gsl::make_span(k_global_dims)},
253+
{gsl::make_span(k_input_output_strides)},
254+
{q_domain_size},
255+
});
256+
257+
return context.RunProgram(program);
258+
}
259+
186260
Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
187261
const Tensor* query = context.Input<Tensor>(0);
188262
const Tensor* key = context.Input<Tensor>(1);
@@ -256,7 +330,6 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
256330

257331
Tensor qRotary;
258332
Tensor kRotary;
259-
Tensor kDummy; // Placeholder for rotary when kv_empty and key has zero sequence length
260333

261334
// kv_sequence_length==0 fast path: K/V inputs are empty (shared KV layer).
262335
// Skip all K/V processing; only apply RoPE to Q if needed.
@@ -275,30 +348,21 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
275348

276349
if (kv_empty) {
277350
// KV inputs are empty - shared KV layer. Only need to extract Q and optionally apply RoPE to Q.
278-
// Avoid creating zero-sized K/V tensors as WebGPU may reject zero-element storage buffers.
279351
if (parameters.is_packed_qkv_) {
280-
// Extract Q from packed QKV. Create non-zero K/V with sequence_length=1 to satisfy SplitPackedQKV,
281-
// but they won't be used for attention (past_key/past_value provide the KV context).
352+
// Extract only Q from packed QKV — no need to allocate K/V split tensors.
282353
qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}));
283-
kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
284-
vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
285-
ORT_RETURN_IF_ERROR(SplitPackedQKV(context, parameters, query, &qSplit, &kSplit, &vSplit, parameters.kv_hidden_size_));
354+
ORT_RETURN_IF_ERROR(ExtractQFromPackedQKV(context, parameters, query, &qSplit));
286355
parameters.is_packed_qkv_ = false;
287356
parameters.qkv_format_ = Q_K_V_BSNH;
288357
query = &qSplit;
289-
// K/V from split are discarded — attention will use past_key/past_value instead.
290358
}
291359
if (do_rotary_) {
292-
// Apply RoPE to Q only. Use the fused kernel with a dummy K to avoid zero-element buffers.
293-
// K output is discarded since attention uses past_key/past_value.
360+
// Apply RoPE to Q only — K doesn't need rotation since we reuse another layer's already-rotated KV cache.
294361
qRotary = context.CreateGPUTensor(query->DataType(), query->Shape());
295-
kDummy = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
296-
kRotary = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
297-
ORT_RETURN_IF_ERROR(RunFusedQKRotaryEmbedding(context, parameters,
298-
query, &kDummy,
299-
seqlen_k,
300-
cos_cache, sin_cache,
301-
&qRotary, &kRotary));
362+
ORT_RETURN_IF_ERROR(RunRotaryEmbeddingQOnly(context, parameters,
363+
query, seqlen_k,
364+
cos_cache, sin_cache,
365+
&qRotary));
302366
query = &qRotary;
303367
}
304368
} else if (parameters.is_packed_qkv_ && do_rotary_) {

0 commit comments

Comments
 (0)