Skip to content

Commit edc5074

Browse files
committed
Address PR #28581 review comments
- Drop the now-unused indirect dispatch buffer entirely: remove the GPU tensor allocation in ApplyFlashAttention, the indirect_buffer / prepare_indirect_dispatch plumbing from CopyKVCacheProgram and SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram (host + WGSL), and the indirect_buffer parameter from ComputeFlashAttentionDecodeQKT / ComputeFlashAttentionDecodeSplitVxScore. Both decode programs now use direct SetDispatchGroupSize, so the indirect buffer was being allocated, bound, and written by thread 0 of the prologue but never read. - Add a clarifying comment at both decode dispatch sites noting that, despite the legacy flag name, use_indirect_dispatch now selects worst-case tiling for a direct dispatch rather than issuing an indirect dispatch. Pure cleanup. Validated on Qwen3-1.7B (Vulkan + D3D12 1k prompts, graph capture on/off); no perf change vs. baseline.
1 parent 79e0495 commit edc5074

3 files changed

Lines changed: 24 additions & 89 deletions

File tree

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

Lines changed: 19 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,9 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha
2626
const auto& present_key = sh.AddOutput("present_key", ShaderUsage::UseUniform);
2727
const auto& present_value = sh.AddOutput("present_value", ShaderUsage::UseUniform);
2828

29-
if (prepare_indirect_dispatch_) {
30-
sh.AddOutput("indirect_buffer", ShaderUsage::None);
31-
}
32-
3329
return WGSL_TEMPLATE_APPLY(sh, "bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template",
3430
WGSL_TEMPLATE_PARAMETER(interleaved, interleaved_),
3531
WGSL_TEMPLATE_PARAMETER(multi_rotary_cache_concat_offset, multi_rotary_cache_concat_offset_),
36-
WGSL_TEMPLATE_PARAMETER(prepare_indirect_dispatch, prepare_indirect_dispatch_),
3732
WGSL_TEMPLATE_PARAMETER(use_multi_rotary_cache_concat, multi_rotary_cache_concat_offset_ > 0),
3833
WGSL_TEMPLATE_VARIABLE(cos_cache, cos_cache),
3934
WGSL_TEMPLATE_VARIABLE(packed_qkv, packed_qkv),
@@ -62,10 +57,6 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
6257
if (use_seqlen_k_) {
6358
shader.AddInput("seqlen_k", ShaderUsage::None);
6459
}
65-
// If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output
66-
if (prepare_indirect_dispatch_) {
67-
shader.AddOutput("indirect_buffer", ShaderUsage::None);
68-
}
6960

7061
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.copy_size")
7162
<< " let output_indices = " << copy_kv_shape.OffsetToIndices("global_idx") << ";\n"
@@ -85,18 +76,6 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
8576
shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n";
8677
}
8778

88-
// Add indirect dispatch logic for thread 0
89-
if (prepare_indirect_dispatch_) {
90-
// TODO: Add NormalizeDispatchGroupSize logic here to avoid exceeding max dispatch size.
91-
shader.MainFunctionBody() << " // Prepare indirect dispatch buffer for thread 0\n"
92-
<< " if (global_idx == 0u) {\n"
93-
<< " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n"
94-
<< " indirect_buffer[0] = num_total_seq_length_tile;\n"
95-
<< " indirect_buffer[1] = uniforms.num_heads;\n"
96-
<< " indirect_buffer[2] = 1u;\n"
97-
<< " }\n\n";
98-
}
99-
10079
if (has_past_) {
10180
const auto& past_key = shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
10281
shader.AddInput("past_value", ShaderUsage::UseUniform);
@@ -120,11 +99,10 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
12099
Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters,
121100
const Tensor* K, const Tensor* past_key, Tensor* present_key,
122101
const Tensor* V, const Tensor* past_value, Tensor* present_value,
123-
uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer) {
102+
const Tensor* seqlen_k) {
124103
// CopyKVCache takes past key/value and current key/value and copies them to present key and value.
125104
// This makes it so that FlashAttention only needs to look at present key and value, and saves
126105
// number of input buffers in the shader, which we run out of (<=8) without this optimization.
127-
// If indirect_buffer is provided, also prepare indirect dispatch buffer for flash attention.
128106
const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1);
129107
// has_past means non-static kv cache with valid past data
130108
bool has_past = !parameters.past_present_share_buffer_ && past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0;
@@ -136,12 +114,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
136114
TensorShape copy_kv_shape{parameters.batch_size_, num_heads, copy_sequence_length, parameters.head_size_ / components};
137115
int64_t copy_size = copy_kv_shape.Size();
138116

139-
// Determine if we need to prepare indirect dispatch
140-
bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
141117
bool use_seqlen_k = (seqlen_k != nullptr);
142118
bool kv_BNSH = parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH || parameters.qkv_format_ == Q_K_V_BNSH;
143-
CopyKVCacheProgram program{"CopyKVCache", has_past, kv_BNSH, parameters.past_present_share_buffer_,
144-
prepare_indirect_dispatch, use_seqlen_k};
119+
CopyKVCacheProgram program{"CopyKVCache", has_past, kv_BNSH, parameters.past_present_share_buffer_, use_seqlen_k};
145120
if (kv_BNSH) {
146121
program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
147122
{V, ProgramTensorMetadataDependency::TypeAndRank, components}});
@@ -164,19 +139,13 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
164139
program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components},
165140
{present_value, ProgramTensorMetadataDependency::Rank, components}});
166141

167-
if (prepare_indirect_dispatch) {
168-
program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None});
169-
}
170-
171142
program.AddIndices(std::move(copy_kv_shape));
172143
program.SetDispatchGroupSize(static_cast<uint32_t>((copy_size + 63) / 64))
173144
.SetWorkgroupSize(64)
174-
.CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_, prepare_indirect_dispatch, use_seqlen_k)
145+
.CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_, use_seqlen_k)
175146
.AddUniformVariables({{static_cast<uint32_t>(copy_size)},
176147
{static_cast<uint32_t>(parameters.total_sequence_length_)},
177-
{static_cast<uint32_t>(parameters.kv_sequence_length_)},
178-
{tile_size},
179-
{static_cast<uint32_t>(parameters.num_heads_)}});
148+
{static_cast<uint32_t>(parameters.kv_sequence_length_)}});
180149

181150
return context.RunProgram(program);
182151
}
@@ -246,8 +215,7 @@ Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader)
246215

247216
Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
248217
const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata, const Tensor* seqlen_k,
249-
const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length) {
250-
ORT_UNUSED_PARAMETER(indirect_buffer);
218+
const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length) {
251219
const float alpha = parameters.scale_ == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size_))
252220
: parameters.scale_;
253221

@@ -282,6 +250,8 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
282250
// vkAllocateDescriptorSets), which dominates decode tps under graph capture.
283251
// num_present_sequence_length_tile is the worst-case tile count for the static
284252
// KV cache; the shader already masks workgroups via seqlens_k.
253+
// Despite the legacy flag name, use_indirect_dispatch selects worst-case tiling
254+
// for a direct dispatch here rather than issuing an indirect dispatch.
285255
const uint32_t qkt_dispatch_tiles = use_indirect_dispatch ? num_present_sequence_length_tile : num_total_seq_length_tile;
286256
program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * qkt_dispatch_tiles);
287257
program.SetWorkgroupSize(64)
@@ -330,14 +300,12 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
330300
Tensor* present_value,
331301
const Tensor* seqlen_k,
332302
const WebgpuAttentionParameters& parameters,
333-
const Tensor* indirect_buffer,
334303
uint32_t num_total_seq_length_tile,
335304
uint32_t num_present_sequence_length_tile,
336305
uint32_t tile_size,
337306
bool use_indirect_dispatch,
338307
uint32_t present_sequence_length,
339308
const Tensor* head_sink) {
340-
ORT_UNUSED_PARAMETER(indirect_buffer);
341309
const int components = 4;
342310
const bool has_head_sink = head_sink != nullptr;
343311
int head_size_vec = parameters.v_head_size_ / components;
@@ -356,6 +324,8 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
356324
// See FlashAttentionDecodeQKT above: avoid indirect dispatch to skip Dawn's
357325
// per-call TransformIndirectDispatchBuffer overhead on Vulkan. The shader masks
358326
// out-of-range workgroups via seqlens_k.
327+
// Despite the legacy flag name, use_indirect_dispatch selects worst-case tiling
328+
// for a direct dispatch here rather than issuing an indirect dispatch.
359329
const uint32_t splitvx_dispatch_tiles = use_indirect_dispatch ? num_present_sequence_length_tile : num_total_seq_length_tile;
360330
program.SetDispatchGroupSize(batch_heads * splitvx_dispatch_tiles);
361331
program.CacheHint(tile_size, head_size_vec, use_indirect_dispatch, has_head_sink)
@@ -447,20 +417,13 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
447417
// Declare query_output at function scope to ensure it persists throughout the function
448418
Tensor query_output;
449419

450-
// Create indirect dispatch buffer if using indirect dispatch
451-
Tensor* indirect_buffer_ptr = nullptr;
452-
Tensor indirect_buffer;
453-
454-
// Prepare indirect dispatch buffer for decode path with static KV cache
420+
// Whether the decode path can apply worst-case-tile direct dispatch with seqlens_k
421+
// masking in the shader. Only valid for sequence_length == 1 with static KV cache
422+
// under graph capture, where seqlen_k is non-null.
455423
const bool use_indirect_dispatch = parameters.sequence_length_ == 1 &&
456424
parameters.past_present_share_buffer_ &&
457425
seqlen_k != nullptr &&
458426
context.IsGraphCaptureEnabled();
459-
if (use_indirect_dispatch) {
460-
const TensorShape indirect_buffer_shape{3}; // 3 uint32 values for dispatch dimensions
461-
indirect_buffer = context.CreateGPUTensor(DataTypeImpl::GetType<uint32_t>(), indirect_buffer_shape);
462-
indirect_buffer_ptr = &indirect_buffer;
463-
}
464427

465428
const bool do_rotary = (cos_cache != nullptr && sin_cache != nullptr);
466429

@@ -474,11 +437,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
474437
ORT_RETURN_IF_ERROR(RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(context, parameters,
475438
Q, seqlen_k,
476439
cos_cache, sin_cache,
477-
&query_output, present_key, present_value,
478-
indirect_buffer_ptr, tile_size));
440+
&query_output, present_key, present_value));
479441
Q = &query_output;
480442
} else {
481-
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr, indirect_buffer_ptr));
443+
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, use_seqlen_k ? seqlen_k : nullptr));
482444
}
483445

484446
if (parameters.sequence_length_ > 1) {
@@ -555,7 +517,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
555517
const TensorShape metadata_shape(metadata_dims);
556518
Tensor metadata = context.CreateGPUTensor(DataTypeImpl::GetType<float>(), metadata_shape);
557519
ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeQKT(context, Q, attention_bias, &qk, present_key, &metadata, seqlen_k,
558-
parameters, indirect_buffer_ptr, num_total_seq_length_tile,
520+
parameters, num_total_seq_length_tile,
559521
num_present_sequence_length_tile, tile_size, use_indirect_dispatch,
560522
present_sequence_length));
561523

@@ -564,7 +526,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
564526
const TensorShape out_split_vx_shape(out_split_vx_dims);
565527
Tensor out_split_vx = context.CreateGPUTensor(Q->DataType(), out_split_vx_shape);
566528
ORT_RETURN_IF_ERROR(ComputeFlashAttentionDecodeSplitVxScore(context, &metadata, &qk, &out_split_vx, present_value,
567-
seqlen_k, parameters, indirect_buffer_ptr,
529+
seqlen_k, parameters,
568530
num_total_seq_length_tile,
569531
num_present_sequence_length_tile, tile_size,
570532
use_indirect_dispatch, present_sequence_length,
@@ -591,9 +553,7 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
591553
const Tensor* sin_cache,
592554
Tensor* query,
593555
Tensor* present_key,
594-
Tensor* present_value,
595-
Tensor* indirect_buffer,
596-
uint32_t tile_size) {
556+
Tensor* present_value) {
597557
const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t>(cos_cache->Shape()[1]);
598558
const auto head_size = params.head_size_;
599559

@@ -619,12 +579,11 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
619579
// Extract present_sequence_length from present_key tensor shape
620580
const uint32_t present_sequence_length = gsl::narrow_cast<uint32_t>(present_key->Shape()[2]);
621581

622-
const bool prepare_indirect_dispatch = (indirect_buffer != nullptr);
623582
const uint32_t multi_rotary_cache_concat_offset = context.MultiRotaryCacheConcatOffset();
624583

625-
SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, prepare_indirect_dispatch, multi_rotary_cache_concat_offset);
584+
SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program(params.rotary_interleaved_, multi_rotary_cache_concat_offset);
626585
program
627-
.CacheHint(params.rotary_interleaved_, prepare_indirect_dispatch, multi_rotary_cache_concat_offset)
586+
.CacheHint(params.rotary_interleaved_, multi_rotary_cache_concat_offset)
628587
.AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components})
629588
.AddInputs({
630589
{seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
@@ -635,10 +594,6 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
635594
{present_key, ProgramTensorMetadataDependency::None, components},
636595
{present_value, ProgramTensorMetadataDependency::None, components}});
637596

638-
if (prepare_indirect_dispatch) {
639-
program.AddOutput({indirect_buffer, ProgramTensorMetadataDependency::None});
640-
}
641-
642597
program.AddUniformVariables({
643598
{static_cast<uint32_t>(params.sequence_length_)},
644599
{static_cast<uint32_t>(params.hidden_size_ / components)},
@@ -648,7 +603,6 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
648603
{static_cast<uint32_t>(head_size_vec)},
649604
{static_cast<uint32_t>(half_rotary_embedding_dim_vec)},
650605
{present_sequence_length},
651-
{tile_size},
652606
{static_cast<uint32_t>(dispatch_size)},
653607
});
654608

0 commit comments

Comments
 (0)