@@ -26,14 +26,9 @@ Status SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram::GenerateShaderCode(Sha
2626 const auto & present_key = sh.AddOutput (" present_key" , ShaderUsage::UseUniform);
2727 const auto & present_value = sh.AddOutput (" present_value" , ShaderUsage::UseUniform);
2828
29- if (prepare_indirect_dispatch_) {
30- sh.AddOutput (" indirect_buffer" , ShaderUsage::None);
31- }
32-
3329 return WGSL_TEMPLATE_APPLY (sh, " bert/split_packed_qkv_with_rotary_embedding_and_copykv.wgsl.template" ,
3430 WGSL_TEMPLATE_PARAMETER (interleaved, interleaved_),
3531 WGSL_TEMPLATE_PARAMETER (multi_rotary_cache_concat_offset, multi_rotary_cache_concat_offset_),
36- WGSL_TEMPLATE_PARAMETER (prepare_indirect_dispatch, prepare_indirect_dispatch_),
3732 WGSL_TEMPLATE_PARAMETER (use_multi_rotary_cache_concat, multi_rotary_cache_concat_offset_ > 0 ),
3833 WGSL_TEMPLATE_VARIABLE (cos_cache, cos_cache),
3934 WGSL_TEMPLATE_VARIABLE (packed_qkv, packed_qkv),
@@ -62,10 +57,6 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
6257 if (use_seqlen_k_) {
6358 shader.AddInput (" seqlen_k" , ShaderUsage::None);
6459 }
65- // If prepare_indirect_dispatch is enabled, add seqlen_k input and indirect_buffer output
66- if (prepare_indirect_dispatch_) {
67- shader.AddOutput (" indirect_buffer" , ShaderUsage::None);
68- }
6960
7061 shader.MainFunctionBody () << shader.GuardAgainstOutOfBoundsWorkgroupSizes (" uniforms.copy_size" )
7162 << " let output_indices = " << copy_kv_shape.OffsetToIndices (" global_idx" ) << " ;\n "
@@ -85,18 +76,6 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
8576 shader.MainFunctionBody () << " let present_offset = " << present_key.IndicesToOffset (" present_key_indices_t(batch, num_head_id, sequence_id, head_size_id)" ) << " ;\n " ;
8677 }
8778
88- // Add indirect dispatch logic for thread 0
89- if (prepare_indirect_dispatch_) {
90- // TODO: Add NormalizeDispatchGroupSize logic here to avoid exceeding max dispatch size.
91- shader.MainFunctionBody () << " // Prepare indirect dispatch buffer for thread 0\n "
92- << " if (global_idx == 0u) {\n "
93- << " let num_total_seq_length_tile = (total_seq_length + uniforms.tile_size - 1u) / uniforms.tile_size;\n "
94- << " indirect_buffer[0] = num_total_seq_length_tile;\n "
95- << " indirect_buffer[1] = uniforms.num_heads;\n "
96- << " indirect_buffer[2] = 1u;\n "
97- << " }\n\n " ;
98- }
99-
10079 if (has_past_) {
10180 const auto & past_key = shader.AddInput (" past_key" , ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
10281 shader.AddInput (" past_value" , ShaderUsage::UseUniform);
@@ -120,11 +99,10 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const {
12099Status CopyKVCache (onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters,
121100 const Tensor* K, const Tensor* past_key, Tensor* present_key,
122101 const Tensor* V, const Tensor* past_value, Tensor* present_value,
123- uint32_t tile_size, const Tensor* seqlen_k, Tensor* indirect_buffer ) {
102+ const Tensor* seqlen_k) {
124103 // CopyKVCache takes past key/value and current key/value and copies them to present key and value.
125104 // This makes it so that FlashAttention only needs to look at present key and value, and saves
126105 // number of input buffers in the shader, which we run out of (<=8) without this optimization.
127- // If indirect_buffer is provided, also prepare indirect dispatch buffer for flash attention.
128106 const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1 );
129107 // has_past means non-static kv cache with valid past data
130108 bool has_past = !parameters.past_present_share_buffer_ && past_key != nullptr && past_value != nullptr && past_key->SizeInBytes () > 0 ;
@@ -136,12 +114,9 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
136114 TensorShape copy_kv_shape{parameters.batch_size_ , num_heads, copy_sequence_length, parameters.head_size_ / components};
137115 int64_t copy_size = copy_kv_shape.Size ();
138116
139- // Determine if we need to prepare indirect dispatch
140- bool prepare_indirect_dispatch = (indirect_buffer != nullptr );
141117 bool use_seqlen_k = (seqlen_k != nullptr );
142118 bool kv_BNSH = parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH || parameters.qkv_format_ == Q_K_V_BNSH;
143- CopyKVCacheProgram program{" CopyKVCache" , has_past, kv_BNSH, parameters.past_present_share_buffer_ ,
144- prepare_indirect_dispatch, use_seqlen_k};
119+ CopyKVCacheProgram program{" CopyKVCache" , has_past, kv_BNSH, parameters.past_present_share_buffer_ , use_seqlen_k};
145120 if (kv_BNSH) {
146121 program.AddInputs ({{K, ProgramTensorMetadataDependency::TypeAndRank, components},
147122 {V, ProgramTensorMetadataDependency::TypeAndRank, components}});
@@ -164,19 +139,13 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt
164139 program.AddOutputs ({{present_key, ProgramTensorMetadataDependency::Rank, components},
165140 {present_value, ProgramTensorMetadataDependency::Rank, components}});
166141
167- if (prepare_indirect_dispatch) {
168- program.AddOutput ({indirect_buffer, ProgramTensorMetadataDependency::None});
169- }
170-
171142 program.AddIndices (std::move (copy_kv_shape));
172143 program.SetDispatchGroupSize (static_cast <uint32_t >((copy_size + 63 ) / 64 ))
173144 .SetWorkgroupSize (64 )
174- .CacheHint (has_past, parameters.qkv_format_ , parameters.past_present_share_buffer_ , prepare_indirect_dispatch, use_seqlen_k)
145+ .CacheHint (has_past, parameters.qkv_format_ , parameters.past_present_share_buffer_ , use_seqlen_k)
175146 .AddUniformVariables ({{static_cast <uint32_t >(copy_size)},
176147 {static_cast <uint32_t >(parameters.total_sequence_length_ )},
177- {static_cast <uint32_t >(parameters.kv_sequence_length_ )},
178- {tile_size},
179- {static_cast <uint32_t >(parameters.num_heads_ )}});
148+ {static_cast <uint32_t >(parameters.kv_sequence_length_ )}});
180149
181150 return context.RunProgram (program);
182151}
@@ -246,8 +215,7 @@ Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader)
246215
247216Status ComputeFlashAttentionDecodeQKT (onnxruntime::webgpu::ComputeContext& context, const Tensor* Q,
248217 const Tensor* attention_bias, Tensor* output, Tensor* present_key, Tensor* metadata, const Tensor* seqlen_k,
249- const WebgpuAttentionParameters& parameters, const Tensor* indirect_buffer, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length) {
250- ORT_UNUSED_PARAMETER (indirect_buffer);
218+ const WebgpuAttentionParameters& parameters, uint32_t num_total_seq_length_tile, uint32_t num_present_sequence_length_tile, uint32_t tile_size, bool use_indirect_dispatch, uint32_t present_sequence_length) {
251219 const float alpha = parameters.scale_ == 0 .0f ? 1 .f / sqrt (static_cast <float >(parameters.head_size_ ))
252220 : parameters.scale_ ;
253221
@@ -282,6 +250,8 @@ Status ComputeFlashAttentionDecodeQKT(onnxruntime::webgpu::ComputeContext& conte
282250 // vkAllocateDescriptorSets), which dominates decode tps under graph capture.
283251 // num_present_sequence_length_tile is the worst-case tile count for the static
284252 // KV cache; the shader already masks workgroups via seqlens_k.
253+ // Despite the legacy flag name, use_indirect_dispatch selects worst-case tiling
254+ // for a direct dispatch here rather than issuing an indirect dispatch.
285255 const uint32_t qkt_dispatch_tiles = use_indirect_dispatch ? num_present_sequence_length_tile : num_total_seq_length_tile;
286256 program.SetDispatchGroupSize (parameters.batch_size_ * parameters.num_heads_ * qkt_dispatch_tiles);
287257 program.SetWorkgroupSize (64 )
@@ -330,14 +300,12 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
330300 Tensor* present_value,
331301 const Tensor* seqlen_k,
332302 const WebgpuAttentionParameters& parameters,
333- const Tensor* indirect_buffer,
334303 uint32_t num_total_seq_length_tile,
335304 uint32_t num_present_sequence_length_tile,
336305 uint32_t tile_size,
337306 bool use_indirect_dispatch,
338307 uint32_t present_sequence_length,
339308 const Tensor* head_sink) {
340- ORT_UNUSED_PARAMETER (indirect_buffer);
341309 const int components = 4 ;
342310 const bool has_head_sink = head_sink != nullptr ;
343311 int head_size_vec = parameters.v_head_size_ / components;
@@ -356,6 +324,8 @@ Status ComputeFlashAttentionDecodeSplitVxScore(onnxruntime::webgpu::ComputeConte
356324 // See FlashAttentionDecodeQKT above: avoid indirect dispatch to skip Dawn's
357325 // per-call TransformIndirectDispatchBuffer overhead on Vulkan. The shader masks
358326 // out-of-range workgroups via seqlens_k.
327+ // Despite the legacy flag name, use_indirect_dispatch selects worst-case tiling
328+ // for a direct dispatch here rather than issuing an indirect dispatch.
359329 const uint32_t splitvx_dispatch_tiles = use_indirect_dispatch ? num_present_sequence_length_tile : num_total_seq_length_tile;
360330 program.SetDispatchGroupSize (batch_heads * splitvx_dispatch_tiles);
361331 program.CacheHint (tile_size, head_size_vec, use_indirect_dispatch, has_head_sink)
@@ -447,20 +417,13 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
447417 // Declare query_output at function scope to ensure it persists throughout the function
448418 Tensor query_output;
449419
450- // Create indirect dispatch buffer if using indirect dispatch
451- Tensor* indirect_buffer_ptr = nullptr ;
452- Tensor indirect_buffer;
453-
454- // Prepare indirect dispatch buffer for decode path with static KV cache
420+ // Despite its legacy name, this flag no longer triggers an indirect dispatch: it selects
421+ // worst-case-tile direct dispatch, relying on seqlens_k masking in the shader. It is set only when
422+ // sequence_length == 1, the KV cache is static, seqlen_k is non-null, and graph capture is enabled.
455423 const bool use_indirect_dispatch = parameters.sequence_length_ == 1 &&
456424 parameters.past_present_share_buffer_ &&
457425 seqlen_k != nullptr &&
458426 context.IsGraphCaptureEnabled ();
459- if (use_indirect_dispatch) {
460- const TensorShape indirect_buffer_shape{3 }; // 3 uint32 values for dispatch dimensions
461- indirect_buffer = context.CreateGPUTensor (DataTypeImpl::GetType<uint32_t >(), indirect_buffer_shape);
462- indirect_buffer_ptr = &indirect_buffer;
463- }
464427
465428 const bool do_rotary = (cos_cache != nullptr && sin_cache != nullptr );
466429
@@ -474,11 +437,10 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
474437 ORT_RETURN_IF_ERROR (RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV (context, parameters,
475438 Q, seqlen_k,
476439 cos_cache, sin_cache,
477- &query_output, present_key, present_value,
478- indirect_buffer_ptr, tile_size));
440+ &query_output, present_key, present_value));
479441 Q = &query_output;
480442 } else {
481- ORT_RETURN_IF_ERROR (CopyKVCache (context, parameters, K, past_key, present_key, V, past_value, present_value, tile_size, use_seqlen_k ? seqlen_k : nullptr , indirect_buffer_ptr ));
443+ ORT_RETURN_IF_ERROR (CopyKVCache (context, parameters, K, past_key, present_key, V, past_value, present_value, use_seqlen_k ? seqlen_k : nullptr ));
482444 }
483445
484446 if (parameters.sequence_length_ > 1 ) {
@@ -555,7 +517,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
555517 const TensorShape metadata_shape (metadata_dims);
556518 Tensor metadata = context.CreateGPUTensor (DataTypeImpl::GetType<float >(), metadata_shape);
557519 ORT_RETURN_IF_ERROR (ComputeFlashAttentionDecodeQKT (context, Q, attention_bias, &qk, present_key, &metadata, seqlen_k,
558- parameters, indirect_buffer_ptr, num_total_seq_length_tile,
520+ parameters, num_total_seq_length_tile,
559521 num_present_sequence_length_tile, tile_size, use_indirect_dispatch,
560522 present_sequence_length));
561523
@@ -564,7 +526,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
564526 const TensorShape out_split_vx_shape (out_split_vx_dims);
565527 Tensor out_split_vx = context.CreateGPUTensor (Q->DataType (), out_split_vx_shape);
566528 ORT_RETURN_IF_ERROR (ComputeFlashAttentionDecodeSplitVxScore (context, &metadata, &qk, &out_split_vx, present_value,
567- seqlen_k, parameters, indirect_buffer_ptr,
529+ seqlen_k, parameters,
568530 num_total_seq_length_tile,
569531 num_present_sequence_length_tile, tile_size,
570532 use_indirect_dispatch, present_sequence_length,
@@ -591,9 +553,7 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
591553 const Tensor* sin_cache,
592554 Tensor* query,
593555 Tensor* present_key,
594- Tensor* present_value,
595- Tensor* indirect_buffer,
596- uint32_t tile_size) {
556+ Tensor* present_value) {
597557 const auto half_rotary_embedding_dim = gsl::narrow_cast<uint32_t >(cos_cache->Shape ()[1 ]);
598558 const auto head_size = params.head_size_ ;
599559
@@ -619,12 +579,11 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
619579 // Extract present_sequence_length from present_key tensor shape
620580 const uint32_t present_sequence_length = gsl::narrow_cast<uint32_t >(present_key->Shape ()[2 ]);
621581
622- const bool prepare_indirect_dispatch = (indirect_buffer != nullptr );
623582 const uint32_t multi_rotary_cache_concat_offset = context.MultiRotaryCacheConcatOffset ();
624583
625- SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program (params.rotary_interleaved_ , prepare_indirect_dispatch, multi_rotary_cache_concat_offset);
584+ SplitPackedQKVWithRotaryEmbeddingAndCopyKVProgram program (params.rotary_interleaved_ , multi_rotary_cache_concat_offset);
626585 program
627- .CacheHint (params.rotary_interleaved_ , prepare_indirect_dispatch, multi_rotary_cache_concat_offset)
586+ .CacheHint (params.rotary_interleaved_ , multi_rotary_cache_concat_offset)
628587 .AddInput ({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components})
629588 .AddInputs ({
630589 {seqlen_k, ProgramTensorMetadataDependency::TypeAndRank},
@@ -635,10 +594,6 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
635594 {present_key, ProgramTensorMetadataDependency::None, components},
636595 {present_value, ProgramTensorMetadataDependency::None, components}});
637596
638- if (prepare_indirect_dispatch) {
639- program.AddOutput ({indirect_buffer, ProgramTensorMetadataDependency::None});
640- }
641-
642597 program.AddUniformVariables ({
643598 {static_cast <uint32_t >(params.sequence_length_ )},
644599 {static_cast <uint32_t >(params.hidden_size_ / components)},
@@ -648,7 +603,6 @@ Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::Comput
648603 {static_cast <uint32_t >(head_size_vec)},
649604 {static_cast <uint32_t >(half_rotary_embedding_dim_vec)},
650605 {present_sequence_length},
651- {tile_size},
652606 {static_cast <uint32_t >(dispatch_size)},
653607 });
654608
0 commit comments