Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -606,8 +606,17 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
return Status::OK();
}

bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
return !parameters.is_packed_qkv_ &&
bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
const bool kv_empty = parameters.kv_sequence_length_ == 0;
// FlashAttention here does not implement right-padded per-batch prefill, so the
// first disjunction restricts it to inputs where padding cannot occur:
// - batch_size_ == 1: single sequence, no padding possible.
// - seqlen_k == nullptr: no per-batch lengths, padding inexpressible.
// - kv_empty (shared-KV layer): FA is mandatory; that path takes a different shader.
// The remaining conjuncts exclude packed-QKV (handled by a separate rotary kernel),
// mismatched head/value sizes, and head_size alignments unsupported by the kernel.
return (parameters.batch_size_ == 1 || seqlen_k == nullptr || kv_empty) &&
!parameters.is_packed_qkv_ &&
parameters.head_size_ == parameters.v_head_size_ &&
((context.AdapterInfo().vendor == std::string_view{"qualcomm"} && parameters.head_size_ % 8 == 0) || parameters.head_size_ % 4 == 0);
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr,
const Tensor* cos_cache = nullptr, const Tensor* sin_cache = nullptr, const Tensor* head_sink = nullptr);

bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context);
bool CanApplyFlashAttention(const WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k = nullptr);

// Split packed QKV with Q/K rotary embedding and copy KV cache fusion
Status RunSplitPackedQKVWithRotaryEmbeddingAndCopyKV(onnxruntime::webgpu::ComputeContext& context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
// Create a temporary parameters copy with is_packed_qkv_ set to false to check if flash attention can be applied after unpacking
WebgpuAttentionParameters temp_params = parameters;
temp_params.is_packed_qkv_ = false;
will_use_flash_attention = CanApplyFlashAttention(temp_params, context);
will_use_flash_attention = CanApplyFlashAttention(temp_params, context, seqlen_k);
}

if (kv_empty) {
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " let seqlen_i = " << position_ids_or_seqlens.GetByOffset("batch_idx") << ";\n"
<< " let seqlen = u32(seqlen_i);\n"
" let total_seqlen = seqlen + 1u;\n"
" let past_seqlen = total_seqlen - uniforms.global_shape[1];\n"
" // Right-padded batches with prompt shorter than global_shape[1] would underflow u32; clamp to 0.\n"
" let past_seqlen = select(total_seqlen - uniforms.global_shape[1], 0u, total_seqlen <= uniforms.global_shape[1]);\n"
" let position_id = past_seqlen + bsnh[1];\n"
<< " let i = dot(bsnh, uniforms.input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let j = i + select(half_rotary_emb_dim, 1u, " << interleaved_str << ");\n"
Expand Down Expand Up @@ -200,7 +201,8 @@ Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) c
<< " let seqlen_i = " << seqlens.GetByOffset("batch_idx") << ";\n"
<< " let seqlen = u32(seqlen_i);\n"
<< " let total_seqlen = seqlen + 1u;\n"
<< " let past_seqlen = total_seqlen - uniforms.q_global_shape[1];\n"
<< " // Right-padded batches with prompt shorter than q_global_shape[1] would underflow u32; clamp to 0.\n"
<< " let past_seqlen = select(total_seqlen - uniforms.q_global_shape[1], 0u, total_seqlen <= uniforms.q_global_shape[1]);\n"
<< " let position_id = past_seqlen + sequence_idx;\n"
<< " let qi = dot(bsnh, uniforms.q_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let qj = qi + select(half_rotary_dim, 1u, " << interleaved_str << ");\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ $MAIN {
let seqlen_i = seqlens.getByOffset(batch_idx);
let seqlen = u32(seqlen_i);
let total_seqlen = seqlen + 1u;
let past_seqlen = total_seqlen - uniforms.sequence_length;
// Right-padded batches with prompt shorter than sequence_length would underflow u32; clamp to 0.
let past_seqlen = select(total_seqlen - uniforms.sequence_length, 0u, total_seqlen <= uniforms.sequence_length);
let position_id = past_seqlen + seq_idx;
#if use_multi_rotary_cache_concat
let base_position = select(0u, multi_rotary_cache_concat_offset, total_seqlen > multi_rotary_cache_concat_offset);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ $MAIN {
let seqlen = u32(seqlen_i);
let total_seqlen = seqlen + 1u;

let past_seqlen = total_seqlen - uniforms.sequence_length;
// Right-padded batches with prompt shorter than sequence_length would underflow u32; clamp to 0.
let past_seqlen = select(total_seqlen - uniforms.sequence_length, 0u, total_seqlen <= uniforms.sequence_length);
// `position_id` is used to get cos/sin cache and also as the time step index in present_key/present_value
let position_id = past_seqlen + seq_idx;
#if use_multi_rotary_cache_concat
Expand Down
Loading
Loading