From 96997f7b8650815bb4532ecc7c074194c0b6ea17 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 10 Mar 2026 20:38:07 -0700 Subject: [PATCH 01/27] 1st cut at webgpu LinearAttention --- .../webgpu/bert/linear_attention.cc | 714 ++++++++++++++++++ .../webgpu/bert/linear_attention.h | 133 ++++ .../webgpu/webgpu_contrib_kernels.cc | 5 + .../core/graph/contrib_ops/bert_defs.cc | 193 +++++ onnxruntime/core/graph/contrib_ops/ms_opset.h | 4 + .../contrib_ops/linear_attention_op_test.cc | 700 +++++++++++++++++ 6 files changed, 1749 insertions(+) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/linear_attention.h create mode 100644 onnxruntime/test/contrib_ops/linear_attention_op_test.cc diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc new file mode 100644 index 0000000000000..1e49b3309417b --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -0,0 +1,714 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/bert/linear_attention.h" + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +using namespace onnxruntime::webgpu; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +LinearAttentionUpdateRule ParseUpdateRule(const std::string& rule_str) { + if (rule_str == "linear") { + return LinearAttentionUpdateRule::Linear; + } else if (rule_str == "gated") { + return LinearAttentionUpdateRule::Gated; + } else if (rule_str == "delta") { + return LinearAttentionUpdateRule::Delta; + } else if (rule_str == "gated_delta") { + return LinearAttentionUpdateRule::GatedDelta; + } + ORT_THROW("Unknown update rule: ", rule_str); +} + +// ============================================================================= +// LinearAttentionRecurrent Implementation +// ============================================================================= + +ONNX_OPERATOR_KERNEL_EX( + LinearAttentionRecurrent, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + LinearAttentionRecurrent); + +LinearAttentionRecurrent::LinearAttentionRecurrent(const OpKernelInfo& info) + : WebGpuKernel(info) { + std::string update_rule_str = info.GetAttrOrDefault("update_rule", "gated_delta"); + update_rule_ = ParseUpdateRule(update_rule_str); + scale_ = info.GetAttrOrDefault("scale", 0.0f); +} + +Status LinearAttentionRecurrentProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Input tensors - with proper accessor methods and element type alias for scaling + const auto& query = shader.AddInput("query", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + const auto& key = shader.AddInput("key", ShaderUsage::UseUniform); + const auto& value = shader.AddInput("value", ShaderUsage::UseUniform); + const auto& past_state = shader.AddInput("past_state", ShaderUsage::UseUniform); + + // Optional inputs based on update rule + const ShaderVariableHelper* decay_ptr = nullptr; + const ShaderVariableHelper* beta_ptr = nullptr; + if (has_decay_) { + decay_ptr = &shader.AddInput("decay", ShaderUsage::UseUniform); + } + if (has_beta_) { + beta_ptr = &shader.AddInput("beta", ShaderUsage::UseUniform); + } + + // Output tensors + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + const auto& present_state = shader.AddOutput("present_state", ShaderUsage::UseUniform); + + // Each workgroup handles one (batch, head) pair + // Within the workgroup, we compute the state update and output + shader.MainFunctionBody() << R"SHADER( + let batch_idx = workgroup_id.x; + let head_idx = workgroup_id.y; + let local_k = local_id.x; + let local_v = local_id.y; + + // Bounds check + if (batch_idx >= uniforms.batch_size || head_idx >= uniforms.num_heads) { + return; + } + + let head_dim_k = uniforms.head_dim_k; + let head_dim_v = uniforms.head_dim_v; + // Cast scale factor to element type to match tensor data type + let scale_factor = query_element_t(select(1.0 / sqrt(f32(head_dim_k)), uniforms.scale, uniforms.scale != 0.0)); + + // Compute base offsets + let qkv_base = (batch_idx * uniforms.num_heads + head_idx) * head_dim_k; + let v_base = (batch_idx * uniforms.num_heads + head_idx) * head_dim_v; + let state_base = (batch_idx * uniforms.num_heads + head_idx) * head_dim_k * head_dim_v; + + // Process state update for this (k, v) element + if (local_k < head_dim_k && local_v < head_dim_v) { + let state_idx = state_base + local_k * head_dim_v + local_v; + + // Load current state value +)SHADER"; + + shader.MainFunctionBody() << " var state_val = " << past_state.GetByOffset("state_idx") << ";\n"; + + // Load k and v values + shader.MainFunctionBody() << " let k_val = " << key.GetByOffset("qkv_base + local_k") << ";\n"; + shader.MainFunctionBody() << " let v_val = " << value.GetByOffset("v_base + local_v") << ";\n"; + + // Apply decay if needed (gated or gated_delta) + if (update_rule_ == LinearAttentionUpdateRule::Gated || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { + shader.MainFunctionBody() << " // Load decay and compute exp(decay) - decay is in log space\n"; + shader.MainFunctionBody() << " let decay_val = " << decay_ptr->GetByOffset("qkv_base + local_k") << ";\n"; + shader.MainFunctionBody() << " let exp_decay = exp(decay_val);\n"; + shader.MainFunctionBody() << " state_val = state_val * exp_decay;\n"; + } + + // Compute the update delta based on update rule + if (update_rule_ == LinearAttentionUpdateRule::Linear) { + shader.MainFunctionBody() << R"SHADER( + // Linear update: S += k ⊗ v + let update = k_val * v_val; + state_val = state_val + update; +)SHADER"; + } else if (update_rule_ == LinearAttentionUpdateRule::Gated) { + shader.MainFunctionBody() << R"SHADER( + // Gated update: S = exp(g) * S + k ⊗ v (decay already applied above) + let update = k_val * v_val; + state_val = state_val + update; +)SHADER"; + } else if (update_rule_ == LinearAttentionUpdateRule::Delta) { + // Delta update requires computing retrieved = S^T @ k + shader.MainFunctionBody() << " // Delta update: S += β * k ⊗ (v - S^T k)\n"; + shader.MainFunctionBody() << " var retrieved = " << past_state.GetByOffset("state_base + 0u * head_dim_v + local_v") + << " * " << key.GetByOffset("qkv_base + 0u") << ";\n"; + shader.MainFunctionBody() << " for (var k_i: u32 = 1u; k_i < head_dim_k; k_i = k_i + 1u) {\n"; + shader.MainFunctionBody() << " let s_idx = state_base + k_i * head_dim_v + local_v;\n"; + shader.MainFunctionBody() << " retrieved = retrieved + " << past_state.GetByOffset("s_idx") + << " * " << key.GetByOffset("qkv_base + k_i") << ";\n"; + shader.MainFunctionBody() << " }\n"; + shader.MainFunctionBody() << " let beta_val = " << beta_ptr->GetByOffset("(batch_idx * uniforms.num_heads + head_idx)") << ";\n"; + shader.MainFunctionBody() << " let delta = beta_val * (v_val - retrieved);\n"; + shader.MainFunctionBody() << " let update = k_val * delta;\n"; + shader.MainFunctionBody() << " state_val = state_val + update;\n"; + } else { // GatedDelta + // Gated Delta update + shader.MainFunctionBody() << " // Gated Delta update: S = exp(g) * S + β * k ⊗ (v - exp(g) * S^T k)\n"; + shader.MainFunctionBody() << " var retrieved = " << past_state.GetByOffset("state_base + 0u * head_dim_v + local_v") + << " * exp(" << decay_ptr->GetByOffset("qkv_base + 0u") << ")" + << " * " << key.GetByOffset("qkv_base + 0u") << ";\n"; + shader.MainFunctionBody() << " for (var k_i: u32 = 1u; k_i < head_dim_k; k_i = k_i + 1u) {\n"; + shader.MainFunctionBody() << " let s_idx = state_base + k_i * head_dim_v + local_v;\n"; + shader.MainFunctionBody() << " let decay_k = " << decay_ptr->GetByOffset("qkv_base + k_i") << ";\n"; + shader.MainFunctionBody() << " retrieved = retrieved + " << past_state.GetByOffset("s_idx") + << " * exp(decay_k) * " << key.GetByOffset("qkv_base + k_i") << ";\n"; + shader.MainFunctionBody() << " }\n"; + shader.MainFunctionBody() << " let beta_val = " << beta_ptr->GetByOffset("(batch_idx * uniforms.num_heads + head_idx)") << ";\n"; + shader.MainFunctionBody() << " let delta = beta_val * (v_val - retrieved);\n"; + shader.MainFunctionBody() << " let update = k_val * delta;\n"; + shader.MainFunctionBody() << " state_val = state_val + update;\n"; + } + + // Write updated state and compute output + shader.MainFunctionBody() << " // Write updated state\n"; + shader.MainFunctionBody() << " " << present_state.SetByOffset("state_idx", "state_val") << "\n"; + shader.MainFunctionBody() << " }\n"; + + shader.MainFunctionBody() << R"SHADER( + // Synchronize before computing output + workgroupBarrier(); + + // Compute output: o = scale * q^T @ S + // Each thread computes one element of the output + if (local_k == 0u && local_v < head_dim_v) { +)SHADER"; + + shader.MainFunctionBody() << " var out_val = " << query.GetByOffset("qkv_base + 0u") + << " * " << present_state.GetByOffset("state_base + 0u * head_dim_v + local_v") << ";\n"; + shader.MainFunctionBody() << " for (var k_i: u32 = 1u; k_i < head_dim_k; k_i = k_i + 1u) {\n"; + shader.MainFunctionBody() << " let q_val = " << query.GetByOffset("qkv_base + k_i") << ";\n"; + shader.MainFunctionBody() << " let s_idx = state_base + k_i * head_dim_v + local_v;\n"; + shader.MainFunctionBody() << " out_val = out_val + q_val * " << present_state.GetByOffset("s_idx") << ";\n"; + shader.MainFunctionBody() << " }\n"; + shader.MainFunctionBody() << " " << output.SetByOffset("v_base + local_v", "out_val * scale_factor") << "\n"; + shader.MainFunctionBody() << " }\n"; + + return Status::OK(); +} + +Status LinearAttentionRecurrent::ComputeInternal(ComputeContext& context) const { + const auto* query = context.Input(0); + const auto* key = context.Input(1); + const auto* value = context.Input(2); + const auto* past_state = context.Input(3); + const auto* decay = context.Input(4); // Optional + const auto* beta = context.Input(5); // Optional + + const auto& query_shape = query->Shape(); + ORT_ENFORCE(query_shape.NumDimensions() == 4, "Query must be 4D: (B, H, 1, d_k)"); + + const auto batch_size = static_cast(query_shape[0]); + const auto num_heads = static_cast(query_shape[1]); + const auto head_dim_k = static_cast(query_shape[3]); + const auto head_dim_v = static_cast(value->Shape()[3]); + + // Validate decay and beta based on update rule + bool has_decay = (decay != nullptr); + bool has_beta = (beta != nullptr); + + if (update_rule_ == LinearAttentionUpdateRule::Gated || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { + ORT_ENFORCE(has_decay, "Decay input is required for gated and gated_delta update rules"); + } + if (update_rule_ == LinearAttentionUpdateRule::Delta || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { + ORT_ENFORCE(has_beta, "Beta input is required for delta and gated_delta update rules"); + } + + // Create output tensors + TensorShape output_shape({static_cast(batch_size), static_cast(num_heads), 1, static_cast(head_dim_v)}); + auto* output = context.Output(0, output_shape); + auto* present_state = context.Output(1, past_state->Shape()); + + // Setup and run the program + LinearAttentionRecurrentProgram program{update_rule_, has_decay, has_beta}; + + program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, + {key, ProgramTensorMetadataDependency::TypeAndRank}, + {value, ProgramTensorMetadataDependency::TypeAndRank}, + {past_state, ProgramTensorMetadataDependency::TypeAndRank}}); + + if (has_decay) { + program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_beta) { + program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, + {present_state, ProgramTensorMetadataDependency::TypeAndRank}}); + + // Dispatch: one workgroup per (batch, head), with threads for (k, v) elements + // Use a fixed workgroup size that can cover typical head dimensions + const uint32_t workgroup_size_k = std::min(head_dim_k, 16u); + const uint32_t workgroup_size_v = std::min(head_dim_v, 16u); + + program.SetDispatchGroupSize(batch_size, num_heads, 1) + .SetWorkgroupSize(workgroup_size_k, workgroup_size_v, 1) + .AddUniformVariables({{batch_size}, + {num_heads}, + {head_dim_k}, + {head_dim_v}, + {scale_}}); + + return context.RunProgram(program); +} + +// ============================================================================= +// LinearAttentionChunkParallel Implementation +// ============================================================================= + +ONNX_OPERATOR_KERNEL_EX( + LinearAttentionChunkParallel, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + LinearAttentionChunkParallel); + +LinearAttentionChunkParallel::LinearAttentionChunkParallel(const OpKernelInfo& info) + : WebGpuKernel(info) { + std::string update_rule_str = info.GetAttrOrDefault("update_rule", "gated_delta"); + update_rule_ = ParseUpdateRule(update_rule_str); + chunk_size_ = info.GetAttrOrDefault("chunk_size", 64); + scale_ = info.GetAttrOrDefault("scale", 0.0f); +} + +Status LinearAttentionChunkIntraProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Inputs - referenced by name in WGSL shader + const auto& query = shader.AddInput("query", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + const auto& key = shader.AddInput("key", ShaderUsage::UseUniform); + const auto& value = shader.AddInput("value", ShaderUsage::UseUniform); + + std::string decay_name; + if (has_decay_) { + shader.AddInput("decay", ShaderUsage::UseUniform); + decay_name = "decay"; + } + if (has_beta_) { + shader.AddInput("beta", ShaderUsage::UseUniform); + } + + // Outputs + const auto& intra_output = shader.AddOutput("intra_output", ShaderUsage::UseUniform); + const auto& chunk_states = shader.AddOutput("chunk_states", ShaderUsage::UseUniform); + + // Compute intra-chunk causal attention + // For each position i in chunk, compute output using positions 0..i + shader.MainFunctionBody() << R"SHADER( + let batch_idx = workgroup_id.x; + let head_idx = workgroup_id.y; + let chunk_idx = workgroup_id.z; + + if (batch_idx >= uniforms.batch_size || head_idx >= uniforms.num_heads || chunk_idx >= uniforms.num_chunks) { + return; + } + + let head_dim_k = uniforms.head_dim_k; + let head_dim_v = uniforms.head_dim_v; + let chunk_size = uniforms.chunk_size; + let seq_len = uniforms.sequence_length; + let scale_factor = query_element_t(select(1.0 / sqrt(f32(head_dim_k)), uniforms.scale, uniforms.scale != 0.0)); + + // Chunk boundaries + let chunk_start = chunk_idx * chunk_size; + let chunk_end = min(chunk_start + chunk_size, seq_len); + let actual_chunk_size = chunk_end - chunk_start; + + // Base offsets + let bh_offset = batch_idx * uniforms.num_heads + head_idx; + + // Local thread handles one position in the chunk + let local_pos = local_id.x; + + if (local_pos < actual_chunk_size) { + let global_pos = chunk_start + local_pos; + + // Initialize local state for causal computation within chunk + // We need to accumulate state from positions 0..local_pos + let q_base = (bh_offset * seq_len + global_pos) * head_dim_k; + let out_base = (bh_offset * seq_len + global_pos) * head_dim_v; + + // Compute output for this position using causal mask within chunk + for (var v_i: u32 = 0u; v_i < head_dim_v; v_i = v_i + 1u) { + var out_val: query_element_t = query_element_t(0.0); + + // Accumulate contributions from positions 0 to local_pos (inclusive) + for (var src_pos: u32 = 0u; src_pos <= local_pos; src_pos = src_pos + 1u) { + let src_global = chunk_start + src_pos; + let k_base = (bh_offset * seq_len + src_global) * head_dim_k; + let v_base = (bh_offset * seq_len + src_global) * head_dim_v; + + // Compute q @ k^T for this position pair + var qk_dot: query_element_t = query_element_t(0.0); + for (var k_i: u32 = 0u; k_i < head_dim_k; k_i = k_i + 1u) { + qk_dot = qk_dot + )SHADER" << query.GetByOffset("q_base + k_i") << " * " << key.GetByOffset("k_base + k_i") << R"SHADER(; + } + + // For linear attention variants, we need to apply the appropriate weighting + let v_val = )SHADER" << value.GetByOffset("v_base + v_i") << R"SHADER(; +)SHADER"; + + // Apply decay-based weighting if needed + if (has_decay_) { + shader.MainFunctionBody() << R"SHADER( + // Compute cumulative decay from src_pos to local_pos + var cum_decay: query_element_t = query_element_t(0.0); + for (var d_pos: u32 = src_pos + 1u; d_pos <= local_pos; d_pos = d_pos + 1u) { + let d_global = chunk_start + d_pos; + // Average decay across k dimensions for simplicity + var avg_decay: query_element_t = query_element_t(0.0); + for (var k_i: u32 = 0u; k_i < head_dim_k; k_i = k_i + 1u) { + avg_decay = avg_decay + decay[(bh_offset * seq_len + d_global) * head_dim_k + k_i]; + } + cum_decay = cum_decay + avg_decay / query_element_t(head_dim_k); + } + let decay_weight = exp(cum_decay); + out_val = out_val + qk_dot * v_val * decay_weight; +)SHADER"; + } else { + shader.MainFunctionBody() << R"SHADER( + out_val = out_val + qk_dot * v_val; +)SHADER"; + } + + shader.MainFunctionBody() << R"SHADER( + } + + )SHADER" << intra_output.SetByOffset("out_base + v_i", "out_val * scale_factor") << R"SHADER(; + } + } + + // Compute accumulated state at the end of this chunk + // Each thread contributes to building the chunk-end state + workgroupBarrier(); + + // Compute chunk-end state: accumulate k ⊗ v for all positions in chunk + let state_base = (bh_offset * uniforms.num_chunks + chunk_idx) * head_dim_k * head_dim_v; + + for (var k_i: u32 = local_id.x; k_i < head_dim_k; k_i = k_i + 64u) { + for (var v_i: u32 = 0u; v_i < head_dim_v; v_i = v_i + 1u) { + var state_val: query_element_t = query_element_t(0.0); + + for (var pos: u32 = 0u; pos < actual_chunk_size; pos = pos + 1u) { + let global_pos = chunk_start + pos; + let k_base = (bh_offset * seq_len + global_pos) * head_dim_k; + let v_base = (bh_offset * seq_len + global_pos) * head_dim_v; + + let k_val = )SHADER" << key.GetByOffset("k_base + k_i") << R"SHADER(; + let v_val = )SHADER" << value.GetByOffset("v_base + v_i") << R"SHADER(; +)SHADER"; + + if (has_decay_) { + shader.MainFunctionBody() << R"SHADER( + // Decay from this position to chunk end + var decay_to_end: query_element_t = query_element_t(0.0); + for (var d_pos: u32 = pos + 1u; d_pos < actual_chunk_size; d_pos = d_pos + 1u) { + let d_global = chunk_start + d_pos; + decay_to_end = decay_to_end + decay[(bh_offset * seq_len + d_global) * head_dim_k + k_i]; + } + state_val = state_val + k_val * v_val * exp(decay_to_end); +)SHADER"; + } else { + shader.MainFunctionBody() << R"SHADER( + state_val = state_val + k_val * v_val; +)SHADER"; + } + + shader.MainFunctionBody() << R"SHADER( + } + + let state_idx = state_base + k_i * head_dim_v + v_i; + )SHADER" << chunk_states.SetByOffset("state_idx", "state_val") << R"SHADER(; + } + } +)SHADER"; + + return Status::OK(); +} + +Status LinearAttentionChunkInterProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Inputs - referenced by name in WGSL shader + shader.AddInput("intra_output", ShaderUsage::UseUniform); + shader.AddInput("chunk_states", ShaderUsage::UseUniform); + shader.AddInput("query", ShaderUsage::UseUniform); + + if (has_initial_state_) { + shader.AddInput("initial_state", ShaderUsage::UseUniform); + } + if (has_decay_) { + shader.AddInput("decay", ShaderUsage::UseUniform); + } + + // Outputs - referenced by name in WGSL shader + shader.AddOutput("output", ShaderUsage::UseUniform); + shader.AddOutput("final_state", ShaderUsage::UseUniform); + + // Propagate state between chunks and compute final output + shader.MainFunctionBody() << R"SHADER( + let batch_idx = workgroup_id.x; + let head_idx = workgroup_id.y; + + if (batch_idx >= uniforms.batch_size || head_idx >= uniforms.num_heads) { + return; + } + + let head_dim_k = uniforms.head_dim_k; + let head_dim_v = uniforms.head_dim_v; + let chunk_size = uniforms.chunk_size; + let num_chunks = uniforms.num_chunks; + let seq_len = uniforms.sequence_length; + let scale = select(1.0 / sqrt(f32(head_dim_k)), uniforms.scale, uniforms.scale != 0.0); + + let bh_offset = batch_idx * uniforms.num_heads + head_idx; + + // Process each sequence position + let pos = local_id.x; + if (pos < seq_len) { + let chunk_idx = pos / chunk_size; + let q_base = (bh_offset * seq_len + pos) * head_dim_k; + let out_base = (bh_offset * seq_len + pos) * head_dim_v; + + // Start with intra-chunk output + for (var v_i: u32 = 0u; v_i < head_dim_v; v_i = v_i + 1u) { + var out_val = intra_output[out_base + v_i]; + + // Add contribution from previous chunks' accumulated state + // This is q^T @ (sum of states from chunks 0 to chunk_idx-1) + for (var prev_chunk: u32 = 0u; prev_chunk < chunk_idx; prev_chunk = prev_chunk + 1u) { + let state_base = (bh_offset * num_chunks + prev_chunk) * head_dim_k * head_dim_v; + + for (var k_i: u32 = 0u; k_i < head_dim_k; k_i = k_i + 1u) { + let q_val = query[q_base + k_i]; + let state_val = chunk_states[state_base + k_i * head_dim_v + v_i]; +)SHADER"; + + if (has_decay_) { + shader.MainFunctionBody() << R"SHADER( + // Compute cumulative decay from end of prev_chunk to current position + var cum_decay: f32 = 0.0; + let prev_chunk_end = (prev_chunk + 1u) * chunk_size; + for (var d_pos: u32 = prev_chunk_end; d_pos <= pos; d_pos = d_pos + 1u) { + cum_decay = cum_decay + decay[(bh_offset * seq_len + d_pos) * head_dim_k + k_i]; + } + out_val = out_val + q_val * state_val * exp(cum_decay) * scale; +)SHADER"; + } else { + shader.MainFunctionBody() << R"SHADER( + out_val = out_val + q_val * state_val * scale; +)SHADER"; + } + + shader.MainFunctionBody() << R"SHADER( + } + } +)SHADER"; + + if (has_initial_state_) { + shader.MainFunctionBody() << R"SHADER( + // Add contribution from initial state + let init_state_base = bh_offset * head_dim_k * head_dim_v; + for (var k_i: u32 = 0u; k_i < head_dim_k; k_i = k_i + 1u) { + let q_val = query[q_base + k_i]; + let state_val = initial_state[init_state_base + k_i * head_dim_v + v_i]; +)SHADER"; + if (has_decay_) { + shader.MainFunctionBody() << R"SHADER( + // Decay from start to current position + var cum_decay: f32 = 0.0; + for (var d_pos: u32 = 0u; d_pos <= pos; d_pos = d_pos + 1u) { + cum_decay = cum_decay + decay[(bh_offset * seq_len + d_pos) * head_dim_k + k_i]; + } + out_val = out_val + q_val * state_val * exp(cum_decay) * scale; +)SHADER"; + } else { + shader.MainFunctionBody() << R"SHADER( + out_val = out_val + q_val * state_val * scale; +)SHADER"; + } + shader.MainFunctionBody() << R"SHADER( + } +)SHADER"; + } + + shader.MainFunctionBody() << R"SHADER( + output[out_base + v_i] = out_val; + } + } + + // Compute final state: sum all chunk states with appropriate decay + workgroupBarrier(); + + let final_state_base = bh_offset * head_dim_k * head_dim_v; + for (var idx: u32 = local_id.x; idx < head_dim_k * head_dim_v; idx = idx + 256u) { + let k_i = idx / head_dim_v; + let v_i = idx % head_dim_v; + + var state_val: f32 = 0.0; +)SHADER"; + + if (has_initial_state_) { + shader.MainFunctionBody() << R"SHADER( + // Start with initial state + let init_state_base = bh_offset * head_dim_k * head_dim_v; + state_val = initial_state[init_state_base + idx]; +)SHADER"; + if (has_decay_) { + shader.MainFunctionBody() << R"SHADER( + // Decay initial state through entire sequence + var total_decay: f32 = 0.0; + for (var d_pos: u32 = 0u; d_pos < seq_len; d_pos = d_pos + 1u) { + total_decay = total_decay + decay[(bh_offset * seq_len + d_pos) * head_dim_k + k_i]; + } + state_val = state_val * exp(total_decay); +)SHADER"; + } + } + + shader.MainFunctionBody() << R"SHADER( + // Accumulate all chunk states + for (var c: u32 = 0u; c < num_chunks; c = c + 1u) { + let chunk_state_base = (bh_offset * num_chunks + c) * head_dim_k * head_dim_v; + var chunk_val = chunk_states[chunk_state_base + idx]; +)SHADER"; + + if (has_decay_) { + shader.MainFunctionBody() << R"SHADER( + // Decay this chunk's state to end of sequence + let chunk_end = min((c + 1u) * chunk_size, seq_len); + var decay_to_end: f32 = 0.0; + for (var d_pos: u32 = chunk_end; d_pos < seq_len; d_pos = d_pos + 1u) { + decay_to_end = decay_to_end + decay[(bh_offset * seq_len + d_pos) * head_dim_k + k_i]; + } + chunk_val = chunk_val * exp(decay_to_end); +)SHADER"; + } + + shader.MainFunctionBody() << R"SHADER( + state_val = state_val + chunk_val; + } + + final_state[final_state_base + idx] = state_val; + } +)SHADER"; + + return Status::OK(); +} + +Status LinearAttentionChunkParallel::ComputeInternal(ComputeContext& context) const { + const auto* query = context.Input(0); + const auto* key = context.Input(1); + const auto* value = context.Input(2); + const auto* initial_state = context.Input(3); // Optional + const auto* decay = context.Input(4); // Optional + const auto* beta = context.Input(5); // Optional + + const auto& query_shape = query->Shape(); + ORT_ENFORCE(query_shape.NumDimensions() == 4, "Query must be 4D: (B, H, L, d_k)"); + + const auto batch_size = static_cast(query_shape[0]); + const auto num_heads = static_cast(query_shape[1]); + const auto seq_length = static_cast(query_shape[2]); + const auto head_dim_k = static_cast(query_shape[3]); + const auto head_dim_v = static_cast(value->Shape()[3]); + + bool has_initial_state = (initial_state != nullptr); + bool has_decay = (decay != nullptr); + bool has_beta = (beta != nullptr); + + // Validate inputs based on update rule + if (update_rule_ == LinearAttentionUpdateRule::Gated || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { + ORT_ENFORCE(has_decay, "Decay input is required for gated and gated_delta update rules"); + } + if (update_rule_ == LinearAttentionUpdateRule::Delta || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { + ORT_ENFORCE(has_beta, "Beta input is required for delta and gated_delta update rules"); + } + + const uint32_t chunk_size = static_cast(chunk_size_); + const uint32_t num_chunks = (seq_length + chunk_size - 1) / chunk_size; + + // Create output tensors + TensorShape output_shape({static_cast(batch_size), static_cast(num_heads), + static_cast(seq_length), static_cast(head_dim_v)}); + TensorShape state_shape({static_cast(batch_size), static_cast(num_heads), + static_cast(head_dim_k), static_cast(head_dim_v)}); + + auto* output = context.Output(0, output_shape); + auto* final_state = context.Output(1, state_shape); + + // Allocate intermediate tensors for chunk computation + TensorShape chunk_states_shape({static_cast(batch_size), static_cast(num_heads), + static_cast(num_chunks), static_cast(head_dim_k), + static_cast(head_dim_v)}); + + // Allocate temporary tensors - need separate intra_output to avoid aliasing + Tensor intra_output_tensor = context.CreateGPUTensor(query->DataType(), output_shape); + Tensor chunk_states_tensor = context.CreateGPUTensor(query->DataType(), chunk_states_shape); + + // Step 1: Compute intra-chunk attention and per-chunk states + { + LinearAttentionChunkIntraProgram intra_program{update_rule_, has_decay, has_beta}; + + intra_program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, + {key, ProgramTensorMetadataDependency::TypeAndRank}, + {value, ProgramTensorMetadataDependency::TypeAndRank}}); + + if (has_decay) { + intra_program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_beta) { + intra_program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); + } + + intra_program.AddOutputs({{&intra_output_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {&chunk_states_tensor, ProgramTensorMetadataDependency::TypeAndRank}}); + + intra_program.SetDispatchGroupSize(batch_size, num_heads, num_chunks) + .SetWorkgroupSize(64, 1, 1) + .AddUniformVariables({{batch_size}, + {num_heads}, + {seq_length}, + {head_dim_k}, + {head_dim_v}, + {chunk_size}, + {num_chunks}, + {scale_}}); + + ORT_RETURN_IF_ERROR(context.RunProgram(intra_program)); + } + + // Step 2: Inter-chunk state propagation and final output computation + { + LinearAttentionChunkInterProgram inter_program{update_rule_, has_decay, has_beta, has_initial_state}; + + // Use separate intra_output_tensor as input (read-only) and output (write-only) to avoid aliasing + inter_program.AddInputs({{&intra_output_tensor, ProgramTensorMetadataDependency::TypeAndRank}, // intra_output + {&chunk_states_tensor, ProgramTensorMetadataDependency::TypeAndRank}, // chunk_states + {query, ProgramTensorMetadataDependency::TypeAndRank}}); + + if (has_initial_state) { + inter_program.AddInput({initial_state, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_decay) { + inter_program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); + } + + inter_program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, + {final_state, ProgramTensorMetadataDependency::TypeAndRank}}); + + inter_program.SetDispatchGroupSize(batch_size, num_heads, 1) + .SetWorkgroupSize(256, 1, 1) + .AddUniformVariables({{batch_size}, + {num_heads}, + {seq_length}, + {head_dim_k}, + {head_dim_v}, + {chunk_size}, + {num_chunks}, + {scale_}}); + + ORT_RETURN_IF_ERROR(context.RunProgram(inter_program)); + } + + return Status::OK(); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h new file mode 100644 index 0000000000000..c1c023583d95e --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +// Update rule enumeration +enum class LinearAttentionUpdateRule { + Linear, // S_t = S_{t-1} + k ⊗ v + Gated, // S_t = exp(g) * S_{t-1} + k ⊗ v + Delta, // S_t = S_{t-1} + β * k ⊗ (v - S^T k) + GatedDelta, // S_t = exp(g) * S_{t-1} + β * k ⊗ (v - exp(g) * S^T k) +}; + +LinearAttentionUpdateRule ParseUpdateRule(const std::string& rule_str); + +// Program for LinearAttentionRecurrent (single-token decode) +class LinearAttentionRecurrentProgram final : public Program { + public: + LinearAttentionRecurrentProgram(LinearAttentionUpdateRule update_rule, bool has_decay, bool has_beta) + : Program{"LinearAttentionRecurrent"}, + update_rule_(update_rule), + has_decay_(has_decay), + has_beta_(has_beta) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"batch_size", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"head_dim_k", ProgramUniformVariableDataType::Uint32}, + {"head_dim_v", ProgramUniformVariableDataType::Uint32}, + {"scale", ProgramUniformVariableDataType::Float32}); + + private: + LinearAttentionUpdateRule update_rule_; + bool has_decay_; + bool has_beta_; +}; + +// Kernel for LinearAttentionRecurrent +class LinearAttentionRecurrent final : public WebGpuKernel { + public: + LinearAttentionRecurrent(const OpKernelInfo& info); + Status ComputeInternal(ComputeContext& context) const override; + + private: + LinearAttentionUpdateRule update_rule_; + float scale_; +}; + +// Program for intra-chunk attention computation +class LinearAttentionChunkIntraProgram final : public Program { + public: + LinearAttentionChunkIntraProgram(LinearAttentionUpdateRule update_rule, bool has_decay, bool has_beta) + : Program{"LinearAttentionChunkIntra"}, + update_rule_(update_rule), + has_decay_(has_decay), + has_beta_(has_beta) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"batch_size", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"sequence_length", ProgramUniformVariableDataType::Uint32}, + {"head_dim_k", ProgramUniformVariableDataType::Uint32}, + {"head_dim_v", ProgramUniformVariableDataType::Uint32}, + {"chunk_size", ProgramUniformVariableDataType::Uint32}, + {"num_chunks", ProgramUniformVariableDataType::Uint32}, + {"scale", ProgramUniformVariableDataType::Float32}); + + private: + LinearAttentionUpdateRule update_rule_; + bool has_decay_; + bool has_beta_; +}; + +// Program for inter-chunk state propagation +class LinearAttentionChunkInterProgram final : public Program { + public: + LinearAttentionChunkInterProgram(LinearAttentionUpdateRule update_rule, bool has_decay, bool has_beta, bool has_initial_state) + : Program{"LinearAttentionChunkInter"}, + update_rule_(update_rule), + has_decay_(has_decay), + has_beta_(has_beta), + has_initial_state_(has_initial_state) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"batch_size", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"sequence_length", ProgramUniformVariableDataType::Uint32}, + {"head_dim_k", ProgramUniformVariableDataType::Uint32}, + {"head_dim_v", ProgramUniformVariableDataType::Uint32}, + {"chunk_size", ProgramUniformVariableDataType::Uint32}, + {"num_chunks", ProgramUniformVariableDataType::Uint32}, + {"scale", ProgramUniformVariableDataType::Float32}); + + private: + LinearAttentionUpdateRule update_rule_; + bool has_decay_; + bool has_beta_; + bool has_initial_state_; +}; + +// Kernel for LinearAttentionChunkParallel +class LinearAttentionChunkParallel final : public WebGpuKernel { + public: + LinearAttentionChunkParallel(const OpKernelInfo& info); + Status ComputeInternal(ComputeContext& context) const override; + + private: + LinearAttentionUpdateRule update_rule_; + int64_t chunk_size_; + float scale_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 357eebee714d5..a62cc2457b9ed 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -3,6 +3,7 @@ #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #include "contrib_ops/webgpu/bert/group_query_attention.h" +#include "contrib_ops/webgpu/bert/linear_attention.h" #include "core/framework/op_kernel.h" @@ -19,6 +20,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Fu class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GatherBlockQuantized); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, LinearAttentionRecurrent); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, LinearAttentionChunkParallel); // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits); @@ -49,6 +52,8 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 092c05f9e081a..45fa48035885b 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -2217,5 +2217,198 @@ ONNX_MS_OPERATOR_SET_SCHEMA( } })); +constexpr const char* LinearAttentionRecurrent_ver1_doc = R"DOC( +Linear Attention Recurrent operator for single-token decode step. + +This is the core operation for recurrent linear attention mechanisms used in modern +hybrid LLMs (Qwen3.5, Jamba, RWKV-6, etc.). It performs a fused state update and +output computation, keeping the full state matrix in fast memory. + +The update_rule attribute selects the recurrence type: +- "linear": S_t = S_{t-1} + k_t ⊗ v_t; o_t = q_t^T S_t / sqrt(d_k) +- "gated": S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t; o_t = q_t^T S_t / sqrt(d_k) +- "delta": S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t); o_t = q_t^T S_t / sqrt(d_k) +- "gated_delta": S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t); o_t = q_t^T S_t / sqrt(d_k) + +where g_t is the decay (in log-space), β_t is the update rate, and ⊗ denotes outer product. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + LinearAttentionRecurrent, 1, + OpSchema() + .SetDoc(LinearAttentionRecurrent_ver1_doc) + .Attr("update_rule", + "The update rule for the linear attention recurrence. " + "One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.", + AttributeProto::STRING, + std::string("gated_delta")) + .Attr("scale", + "Output scaling factor. When 0.0 (default), uses 1/sqrt(d_k) where d_k is the key dimension.", + AttributeProto::FLOAT, + 0.0f) + .Input(0, + "query", + "Query vector with shape (batch_size, num_heads, 1, head_dim_k)", + "T") + .Input(1, + "key", + "Key vector with shape (batch_size, num_heads, 1, head_dim_k). " + "Should be L2-normalized for delta/gated_delta modes.", + "T") + .Input(2, + "value", + "Value vector with shape (batch_size, num_heads, 1, head_dim_v)", + "T") + .Input(3, + "past_state", + "Recurrent state from previous step with shape (batch_size, num_heads, head_dim_k, head_dim_v)", + "T") + .Input(4, + "decay", + "Exponential decay gate in log-space with shape broadcastable to (batch_size, num_heads, 1, head_dim_k). " + "Required for 'gated' and 'gated_delta' modes.", + "T", + OpSchema::Optional) + .Input(5, + "beta", + "Update rate (sigmoid output) with shape broadcastable to (batch_size, num_heads, 1, 1). " + "Required for 'delta' and 'gated_delta' modes.", + "T", + OpSchema::Optional) + .Output(0, + "output", + "Attention output with shape (batch_size, num_heads, 1, head_dim_v)", + "T") + .Output(1, + "present_state", + "Updated recurrent state with shape (batch_size, num_heads, head_dim_k, head_dim_v)", + "T") + .TypeConstraint("T", + {"tensor(float)", "tensor(float16)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + propagateElemTypeFromInputToOutput(ctx, 0, 1); + + // Output 0: same shape as query (batch_size, num_heads, 1, head_dim_v) + // but last dim comes from value + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { + auto& query_shape = getInputShape(ctx, 0); + auto& value_shape = getInputShape(ctx, 2); + TensorShapeProto output_shape; + *output_shape.add_dim() = query_shape.dim(0); + *output_shape.add_dim() = query_shape.dim(1); + *output_shape.add_dim() = query_shape.dim(2); + *output_shape.add_dim() = value_shape.dim(3); + updateOutputShape(ctx, 0, output_shape); + } + + // Output 1: same shape as past_state + if (hasInputShape(ctx, 3)) { + propagateShapeFromInputToOutput(ctx, 3, 1); + } + })); + +constexpr const char* LinearAttentionChunkParallel_ver1_doc = R"DOC( +Linear Attention Chunk-Parallel operator for efficient prefill. + +Processes a long input sequence by splitting it into chunks, computing intra-chunk +attention in parallel, and propagating state between chunks. This is semantically +equivalent to running LinearAttentionRecurrent sequentially for each token, but +implemented using a chunk-parallel algorithm for GPU efficiency. + +The update_rule attribute has the same semantics as LinearAttentionRecurrent. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + LinearAttentionChunkParallel, 1, + OpSchema() + .SetDoc(LinearAttentionChunkParallel_ver1_doc) + .Attr("update_rule", + "The update rule for the linear attention recurrence. " + "One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.", + AttributeProto::STRING, + std::string("gated_delta")) + .Attr("chunk_size", + "Chunk size for parallel computation. Default is 64.", + AttributeProto::INT, + static_cast(64)) + .Attr("scale", + "Output scaling factor. When 0.0 (default), uses 1/sqrt(d_k) where d_k is the key dimension.", + AttributeProto::FLOAT, + 0.0f) + .Input(0, + "query", + "Query vectors with shape (batch_size, num_heads, sequence_length, head_dim_k)", + "T") + .Input(1, + "key", + "Key vectors with shape (batch_size, num_heads, sequence_length, head_dim_k). " + "Should be L2-normalized for delta/gated_delta modes.", + "T") + .Input(2, + "value", + "Value vectors with shape (batch_size, num_heads, sequence_length, head_dim_v)", + "T") + .Input(3, + "initial_state", + "State from previous chunk/context with shape (batch_size, num_heads, head_dim_k, head_dim_v). " + "If not provided, initialized to zeros.", + "T", + OpSchema::Optional) + .Input(4, + "decay", + "Per-token decay gates in log-space with shape broadcastable to " + "(batch_size, num_heads, sequence_length, head_dim_k). " + "Required for 'gated' and 'gated_delta' modes.", + "T", + OpSchema::Optional) + .Input(5, + "beta", + "Per-token update rates with shape broadcastable to " + "(batch_size, num_heads, sequence_length, 1). " + "Required for 'delta' and 'gated_delta' modes.", + "T", + OpSchema::Optional) + .Output(0, + "output", + "Attention output for all positions with shape (batch_size, num_heads, sequence_length, head_dim_v)", + "T") + .Output(1, + "final_state", + "State after processing all tokens with shape (batch_size, num_heads, head_dim_k, head_dim_v)", + "T") + .TypeConstraint("T", + {"tensor(float)", "tensor(float16)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + propagateElemTypeFromInputToOutput(ctx, 0, 1); + + // Output 0: (batch_size, num_heads, sequence_length, head_dim_v) + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { + auto& query_shape = getInputShape(ctx, 0); + auto& value_shape = getInputShape(ctx, 2); + TensorShapeProto output_shape; + *output_shape.add_dim() = query_shape.dim(0); + *output_shape.add_dim() = query_shape.dim(1); + *output_shape.add_dim() = query_shape.dim(2); + *output_shape.add_dim() = value_shape.dim(3); + updateOutputShape(ctx, 0, output_shape); + } + + // Output 1: (batch_size, num_heads, head_dim_k, head_dim_v) + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { + auto& query_shape = getInputShape(ctx, 0); + auto& value_shape = getInputShape(ctx, 2); + TensorShapeProto state_shape; + *state_shape.add_dim() = query_shape.dim(0); + *state_shape.add_dim() = query_shape.dim(1); + *state_shape.add_dim() = query_shape.dim(3); // head_dim_k + *state_shape.add_dim() = value_shape.dim(3); // head_dim_v + updateOutputShape(ctx, 1, state_shape); + } + })); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 6c20aae94d132..c553bf4a3718d 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -88,6 +88,8 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttentionRecurrent); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttentionChunkParallel); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad); @@ -199,6 +201,8 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc new file mode 100644 index 0000000000000..2b128d8b3268a --- /dev/null +++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc @@ -0,0 +1,700 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include "gtest/gtest.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +namespace { +enum class TensorType { + kFloat, + kFloat16 +}; + +// Reference implementation for linear attention recurrent update +void LinearAttentionRecurrentReference( + const std::vector& query, + const std::vector& key, + const std::vector& value, + const std::vector& past_state, + const std::vector* decay, + const std::vector* beta, + std::vector& output, + std::vector& present_state, + int batch_size, + int num_heads, + int head_dim_k, + int head_dim_v, + const std::string& update_rule, + float scale) { + if (scale == 0.0f) { + scale = 1.0f / std::sqrt(static_cast(head_dim_k)); + } + + // Copy past_state to present_state first + present_state = past_state; + + output.resize(batch_size * num_heads * head_dim_v); + + for (int b = 0; b < batch_size; ++b) { + for (int h = 0; h < num_heads; ++h) { + int bh = b * num_heads + h; + int state_base = bh * head_dim_k * head_dim_v; + int qkv_base = bh * head_dim_k; + int v_base = bh * head_dim_v; + + // Apply decay if gated or gated_delta + if (update_rule == "gated" || update_rule == "gated_delta") { + for (int k = 0; k < head_dim_k; ++k) { + float g = (*decay)[qkv_base + k]; + float exp_g = std::exp(g); + for (int v = 0; v < head_dim_v; ++v) { + present_state[state_base + k * head_dim_v + v] *= exp_g; + } + } + } + + // Compute update + if (update_rule == "linear" || update_rule == "gated") { + // S += k ⊗ v + for (int k = 0; k < head_dim_k; ++k) { + float k_val = key[qkv_base + k]; + for (int v = 0; v < head_dim_v; ++v) { + float v_val = value[v_base + v]; + present_state[state_base + k * head_dim_v + v] += k_val * v_val; + } + } + } else if (update_rule == "delta" || update_rule == "gated_delta") { + // Compute retrieved = S^T @ k + std::vector retrieved(head_dim_v, 0.0f); + for (int v = 0; v < head_dim_v; ++v) { + for (int k = 0; k < head_dim_k; ++k) { + float k_val = key[qkv_base + k]; + // For gated_delta, retrieval uses decayed state (already applied above) + // For delta, uses original past_state + float s_val = (update_rule == "gated_delta") + ? present_state[state_base + k * head_dim_v + v] + : past_state[state_base + k * head_dim_v + v]; + retrieved[v] += s_val * k_val; + } + } + + // Compute delta and update + float beta_val = (*beta)[bh]; + for (int k = 0; k < head_dim_k; ++k) { + float k_val = key[qkv_base + k]; + for (int v = 0; v < head_dim_v; ++v) { + float v_val = value[v_base + v]; + float delta = beta_val * (v_val - retrieved[v]); + present_state[state_base + k * head_dim_v + v] += k_val * delta; + } + } + } + + // Compute output = scale * q^T @ S + for (int v = 0; v < head_dim_v; ++v) { + float out_val = 0.0f; + for (int k = 0; k < head_dim_k; ++k) { + float q_val = query[qkv_base + k]; + out_val += q_val * present_state[state_base + k * head_dim_v + v]; + } + output[v_base + v] = out_val * scale; + } + } + } +} + +} // anonymous namespace + +static void RunLinearAttentionRecurrentTest( + const std::vector& query_data, + const std::vector& key_data, + const std::vector& value_data, + const std::vector& past_state_data, + const std::vector* decay_data, + const std::vector* beta_data, + const std::vector& expected_output, + const std::vector& expected_state, + int batch_size, + int num_heads, + int head_dim_k, + int head_dim_v, + const std::string& update_rule, + float scale, + TensorType tensor_type) { + std::vector query_shape = {batch_size, num_heads, 1, head_dim_k}; + std::vector key_shape = {batch_size, num_heads, 1, head_dim_k}; + std::vector value_shape = {batch_size, num_heads, 1, head_dim_v}; + std::vector state_shape = {batch_size, num_heads, head_dim_k, head_dim_v}; + std::vector decay_shape = {batch_size, num_heads, 1, head_dim_k}; + std::vector beta_shape = {batch_size, num_heads, 1, 1}; + std::vector output_shape = {batch_size, num_heads, 1, head_dim_v}; + + std::string op_type = "LinearAttentionRecurrent"; + std::vector> execution_providers; + + bool enable_webgpu = nullptr != DefaultWebGpuExecutionProvider().get(); + + if (enable_webgpu) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } + + if (execution_providers.empty()) { + // Skip if no providers available + return; + } + + for (auto& ep : execution_providers) { + OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); + test.AddAttribute("update_rule", update_rule); + test.AddAttribute("scale", scale); + + if (tensor_type == TensorType::kFloat) { + test.AddInput("query", query_shape, query_data); + test.AddInput("key", key_shape, key_data); + test.AddInput("value", value_shape, value_data); + test.AddInput("past_state", state_shape, past_state_data); + + if (decay_data != nullptr) { + test.AddInput("decay", decay_shape, *decay_data); + } else { + test.AddOptionalInputEdge(); + } + + if (beta_data != nullptr) { + test.AddInput("beta", beta_shape, *beta_data); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOutput("output", output_shape, expected_output); + test.AddOutput("present_state", state_shape, expected_state); + } else { + test.AddInput("query", query_shape, ToFloat16(query_data)); + test.AddInput("key", key_shape, ToFloat16(key_data)); + test.AddInput("value", value_shape, ToFloat16(value_data)); + test.AddInput("past_state", state_shape, ToFloat16(past_state_data)); + + if (decay_data != nullptr) { + test.AddInput("decay", decay_shape, ToFloat16(*decay_data)); + } else { + test.AddOptionalInputEdge(); + } + + if (beta_data != nullptr) { + test.AddInput("beta", beta_shape, ToFloat16(*beta_data)); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOutput("output", output_shape, ToFloat16(expected_output)); + test.AddOutput("present_state", state_shape, ToFloat16(expected_state)); + } + + test.SetOutputAbsErr("output", 0.01f); + test.SetOutputAbsErr("present_state", 0.01f); + + std::vector> test_execution_providers; + test_execution_providers.push_back(std::move(ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_execution_providers); + } +} + +static void RunLinearAttentionRecurrentTests( + const std::vector& query_data, + const std::vector& key_data, + const std::vector& value_data, + const std::vector& past_state_data, + const std::vector* decay_data, + const std::vector* beta_data, + int batch_size, + int num_heads, + int head_dim_k, + int head_dim_v, + const std::string& update_rule, + float scale = 0.0f) { + // Compute expected output using reference implementation + std::vector expected_output; + std::vector expected_state; + LinearAttentionRecurrentReference( + query_data, key_data, value_data, past_state_data, + decay_data, beta_data, + expected_output, expected_state, + batch_size, num_heads, head_dim_k, head_dim_v, + update_rule, scale); + + // FP32 test + RunLinearAttentionRecurrentTest( + query_data, key_data, value_data, past_state_data, + decay_data, beta_data, + expected_output, expected_state, + batch_size, num_heads, head_dim_k, head_dim_v, + update_rule, scale, TensorType::kFloat); + + // FP16 test + RunLinearAttentionRecurrentTest( + query_data, key_data, value_data, past_state_data, + decay_data, beta_data, + expected_output, expected_state, + batch_size, num_heads, head_dim_k, head_dim_v, + update_rule, scale, TensorType::kFloat16); +} + +// ============================================================================= +// LinearAttentionRecurrent Tests +// ============================================================================= + +TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_Linear_Basic) { + int batch_size = 1; + int num_heads = 2; + int head_dim_k = 4; + int head_dim_v = 4; + + // Query: (1, 2, 1, 4) + std::vector query_data = { + 0.5f, 0.3f, -0.2f, 0.1f, // head 0 + -0.4f, 0.6f, 0.2f, -0.3f // head 1 + }; + + // Key: (1, 2, 1, 4) + std::vector key_data = { + 0.1f, 0.2f, 0.3f, 0.4f, + 0.2f, -0.1f, 0.3f, 0.1f + }; + + // Value: (1, 2, 1, 4) + std::vector value_data = { + 0.4f, 0.3f, 0.2f, 0.1f, + -0.2f, 0.4f, 0.1f, 0.3f + }; + + // Past state: (1, 2, 4, 4) - initialized to small values + std::vector past_state_data(batch_size * num_heads * head_dim_k * head_dim_v, 0.1f); + + RunLinearAttentionRecurrentTests( + query_data, key_data, value_data, past_state_data, + nullptr, nullptr, + batch_size, num_heads, head_dim_k, head_dim_v, + "linear"); +} + +TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_Gated_Basic) { + int batch_size = 1; + int num_heads = 2; + int head_dim_k = 4; + int head_dim_v = 4; + + std::vector query_data = { + 0.5f, 0.3f, -0.2f, 0.1f, + -0.4f, 0.6f, 0.2f, -0.3f + }; + + std::vector key_data = { + 0.1f, 0.2f, 0.3f, 0.4f, + 0.2f, -0.1f, 0.3f, 0.1f + }; + + std::vector value_data = { + 0.4f, 0.3f, 0.2f, 0.1f, + -0.2f, 0.4f, 0.1f, 0.3f + }; + + std::vector past_state_data(batch_size * num_heads * head_dim_k * head_dim_v, 0.1f); + + // Decay: (1, 2, 1, 4) - negative values for decay + std::vector decay_data = { + -0.1f, -0.1f, -0.1f, -0.1f, + -0.2f, -0.2f, -0.2f, -0.2f + }; + + RunLinearAttentionRecurrentTests( + query_data, key_data, value_data, past_state_data, + &decay_data, nullptr, + batch_size, num_heads, head_dim_k, head_dim_v, + "gated"); +} + +TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_Delta_Basic) { + int batch_size = 1; + int num_heads = 2; + int head_dim_k = 4; + int head_dim_v = 4; + + std::vector query_data = { + 0.5f, 0.3f, -0.2f, 0.1f, + -0.4f, 0.6f, 0.2f, -0.3f + }; + + // L2-normalized keys for delta rule + std::vector key_data = { + 0.1826f, 0.3651f, 0.5477f, 0.7303f, // normalized + 0.5345f, -0.2673f, 0.8018f, 0.2673f // normalized + }; + + std::vector value_data = { + 0.4f, 0.3f, 0.2f, 0.1f, + -0.2f, 0.4f, 0.1f, 0.3f + }; + + std::vector past_state_data(batch_size * num_heads * head_dim_k * head_dim_v, 0.1f); + + // Beta: (1, 2, 1, 1) + std::vector beta_data = {0.5f, 0.7f}; + + RunLinearAttentionRecurrentTests( + query_data, key_data, value_data, past_state_data, + nullptr, &beta_data, + batch_size, num_heads, head_dim_k, head_dim_v, + "delta"); +} + +TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_GatedDelta_Basic) { + int batch_size = 1; + int num_heads = 2; + int head_dim_k = 4; + int head_dim_v = 4; + + std::vector query_data = { + 0.5f, 0.3f, -0.2f, 0.1f, + -0.4f, 0.6f, 0.2f, -0.3f + }; + + // L2-normalized keys + std::vector key_data = { + 0.1826f, 0.3651f, 0.5477f, 0.7303f, + 0.5345f, -0.2673f, 0.8018f, 0.2673f + }; + + std::vector value_data = { + 0.4f, 0.3f, 0.2f, 0.1f, + -0.2f, 0.4f, 0.1f, 0.3f + }; + + std::vector past_state_data(batch_size * num_heads * head_dim_k * head_dim_v, 0.1f); + + // Decay: (1, 2, 1, 4) + std::vector decay_data = { + -0.1f, -0.1f, -0.1f, -0.1f, + -0.2f, -0.2f, -0.2f, -0.2f + }; + + // Beta: (1, 2, 1, 1) + std::vector beta_data = {0.5f, 0.7f}; + + RunLinearAttentionRecurrentTests( + query_data, key_data, value_data, past_state_data, + &decay_data, &beta_data, + batch_size, num_heads, head_dim_k, head_dim_v, + "gated_delta"); +} + +TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_LargerBatch) { + int batch_size = 2; + int num_heads = 4; + int head_dim_k = 8; + int head_dim_v = 8; + + int qkv_size = batch_size * num_heads * head_dim_k; + int value_size = batch_size * num_heads * head_dim_v; + int state_size = batch_size * num_heads * head_dim_k * head_dim_v; + + // Generate random-ish data + std::vector query_data(qkv_size); + std::vector key_data(qkv_size); + std::vector value_data(value_size); + std::vector past_state_data(state_size); + std::vector decay_data(qkv_size); + std::vector beta_data(batch_size * num_heads); + + for (int i = 0; i < qkv_size; ++i) { + query_data[i] = 0.1f * (i % 10 - 5); + key_data[i] = 0.1f * ((i + 3) % 10 - 5); + decay_data[i] = -0.1f * ((i % 3) + 1); + } + for (int i = 0; i < value_size; ++i) { + value_data[i] = 0.1f * ((i + 7) % 10 - 5); + } + for (int i = 0; i < state_size; ++i) { + past_state_data[i] = 0.05f * (i % 10 - 5); + } + for (int i = 0; i < batch_size * num_heads; ++i) { + beta_data[i] = 0.3f + 0.1f * (i % 5); + } + + RunLinearAttentionRecurrentTests( + query_data, key_data, value_data, past_state_data, + &decay_data, &beta_data, + batch_size, num_heads, head_dim_k, head_dim_v, + "gated_delta"); +} + +// ============================================================================= +// LinearAttentionChunkParallel Tests +// ============================================================================= + +static void RunLinearAttentionChunkParallelTest( + const std::vector& query_data, + const std::vector& key_data, + const std::vector& value_data, + const std::vector* initial_state_data, + const std::vector* decay_data, + const std::vector* beta_data, + int batch_size, + int num_heads, + int seq_length, + int head_dim_k, + int head_dim_v, + const std::string& update_rule, + int64_t chunk_size, + float scale, + TensorType tensor_type) { + std::vector query_shape = {batch_size, num_heads, seq_length, head_dim_k}; + std::vector key_shape = {batch_size, num_heads, seq_length, head_dim_k}; + std::vector value_shape = {batch_size, num_heads, seq_length, head_dim_v}; + std::vector state_shape = {batch_size, num_heads, head_dim_k, head_dim_v}; + std::vector decay_shape = {batch_size, num_heads, seq_length, head_dim_k}; + std::vector beta_shape = {batch_size, num_heads, seq_length, 1}; + std::vector output_shape = {batch_size, num_heads, seq_length, head_dim_v}; + + std::string op_type = "LinearAttentionChunkParallel"; + std::vector> execution_providers; + + bool enable_webgpu = nullptr != DefaultWebGpuExecutionProvider().get(); + + if (enable_webgpu) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } + + if (execution_providers.empty()) { + return; + } + + for (auto& ep : execution_providers) { + OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); + test.AddAttribute("update_rule", update_rule); + test.AddAttribute("chunk_size", chunk_size); + test.AddAttribute("scale", scale); + + if (tensor_type == TensorType::kFloat) { + test.AddInput("query", query_shape, query_data); + test.AddInput("key", key_shape, key_data); + test.AddInput("value", value_shape, value_data); + + if (initial_state_data != nullptr) { + test.AddInput("initial_state", state_shape, *initial_state_data); + } else { + test.AddOptionalInputEdge(); + } + + if (decay_data != nullptr) { + test.AddInput("decay", decay_shape, *decay_data); + } else { + test.AddOptionalInputEdge(); + } + + if (beta_data != nullptr) { + test.AddInput("beta", beta_shape, *beta_data); + } else { + test.AddOptionalInputEdge(); + } + + // We just check that the output has the right shape and doesn't crash + // Full numerical verification is complex due to chunk parallel algorithm + test.AddOutput("output", output_shape, std::vector(batch_size * num_heads * seq_length * head_dim_v, 0.0f), false); + test.AddOutput("final_state", state_shape, std::vector(batch_size * num_heads * head_dim_k * head_dim_v, 0.0f), false); + } else { + test.AddInput("query", query_shape, ToFloat16(query_data)); + test.AddInput("key", key_shape, ToFloat16(key_data)); + test.AddInput("value", value_shape, ToFloat16(value_data)); + + if (initial_state_data != nullptr) { + test.AddInput("initial_state", state_shape, ToFloat16(*initial_state_data)); + } else { + test.AddOptionalInputEdge(); + } + + if (decay_data != nullptr) { + test.AddInput("decay", decay_shape, ToFloat16(*decay_data)); + } else { + test.AddOptionalInputEdge(); + } + + if (beta_data != nullptr) { + test.AddInput("beta", beta_shape, ToFloat16(*beta_data)); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOutput("output", output_shape, std::vector(batch_size * num_heads * seq_length * head_dim_v), false); + test.AddOutput("final_state", state_shape, std::vector(batch_size * num_heads * head_dim_k * head_dim_v), false); + } + + std::vector> test_execution_providers; + test_execution_providers.push_back(std::move(ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_execution_providers); + } +} + +TEST(ContribOpLinearAttentionTest, LinearAttentionChunkParallel_Linear_Basic) { + int batch_size = 1; + int num_heads = 2; + int seq_length = 8; + int head_dim_k = 4; + int head_dim_v = 4; + + int qkv_size = batch_size * num_heads * seq_length * head_dim_k; + int value_size = batch_size * num_heads * seq_length * head_dim_v; + + std::vector query_data(qkv_size); + std::vector key_data(qkv_size); + std::vector value_data(value_size); + + for (int i = 0; i < qkv_size; ++i) { + query_data[i] = 0.1f * (i % 10 - 5); + key_data[i] = 0.1f * ((i + 3) % 10 - 5); + } + for (int i = 0; i < value_size; ++i) { + value_data[i] = 0.1f * ((i + 7) % 10 - 5); + } + + RunLinearAttentionChunkParallelTest( + query_data, key_data, value_data, + nullptr, nullptr, nullptr, + batch_size, num_heads, seq_length, head_dim_k, head_dim_v, + "linear", 4, 0.0f, TensorType::kFloat); +} + +TEST(ContribOpLinearAttentionTest, LinearAttentionChunkParallel_Gated_Basic) { + int batch_size = 1; + int num_heads = 2; + int seq_length = 8; + int head_dim_k = 4; + int head_dim_v = 4; + + int qkv_size = batch_size * num_heads * seq_length * head_dim_k; + int value_size = batch_size * num_heads * seq_length * head_dim_v; + int decay_size = batch_size * num_heads * seq_length * head_dim_k; + + std::vector query_data(qkv_size); + std::vector key_data(qkv_size); + std::vector value_data(value_size); + std::vector decay_data(decay_size); + + for (int i = 0; i < qkv_size; ++i) { + query_data[i] = 0.1f * (i % 10 - 5); + key_data[i] = 0.1f * ((i + 3) % 10 - 5); + } + for (int i = 0; i < value_size; ++i) { + value_data[i] = 0.1f * ((i + 7) % 10 - 5); + } + for (int i = 0; i < decay_size; ++i) { + decay_data[i] = -0.1f * ((i % 3) + 1); + } + + RunLinearAttentionChunkParallelTest( + query_data, key_data, value_data, + nullptr, &decay_data, nullptr, + batch_size, num_heads, seq_length, head_dim_k, head_dim_v, + "gated", 4, 0.0f, TensorType::kFloat); +} + +TEST(ContribOpLinearAttentionTest, LinearAttentionChunkParallel_GatedDelta_WithInitialState) { + int batch_size = 1; + int num_heads = 2; + int seq_length = 16; + int head_dim_k = 4; + int head_dim_v = 4; + + int qkv_size = batch_size * num_heads * seq_length * head_dim_k; + int value_size = batch_size * num_heads * seq_length * head_dim_v; + int state_size = batch_size * num_heads * head_dim_k * head_dim_v; + int decay_size = batch_size * num_heads * seq_length * head_dim_k; + int beta_size = batch_size * num_heads * seq_length; + + std::vector query_data(qkv_size); + std::vector key_data(qkv_size); + std::vector value_data(value_size); + std::vector initial_state_data(state_size); + std::vector decay_data(decay_size); + std::vector beta_data(beta_size); + + for (int i = 0; i < qkv_size; ++i) { + query_data[i] = 0.1f * (i % 10 - 5); + key_data[i] = 0.1f * ((i + 3) % 10 - 5); + } + for (int i = 0; i < value_size; ++i) { + value_data[i] = 0.1f * ((i + 7) % 10 - 5); + } + for (int i = 0; i < state_size; ++i) { + initial_state_data[i] = 0.05f; + } + for (int i = 0; i < decay_size; ++i) { + decay_data[i] = -0.1f * ((i % 3) + 1); + } + for (int i = 0; i < beta_size; ++i) { + beta_data[i] = 0.5f; + } + + RunLinearAttentionChunkParallelTest( + query_data, key_data, value_data, + &initial_state_data, &decay_data, &beta_data, + batch_size, num_heads, seq_length, head_dim_k, head_dim_v, + "gated_delta", 8, 0.0f, TensorType::kFloat); +} + +TEST(ContribOpLinearAttentionTest, LinearAttentionChunkParallel_LargerSequence) { + int batch_size = 2; + int num_heads = 4; + int seq_length = 64; + int head_dim_k = 8; + int head_dim_v = 8; + + int qkv_size = batch_size * num_heads * seq_length * head_dim_k; + int value_size = batch_size * num_heads * seq_length * head_dim_v; + int decay_size = batch_size * num_heads * seq_length * head_dim_k; + int beta_size = batch_size * num_heads * seq_length; + + std::vector query_data(qkv_size); + std::vector key_data(qkv_size); + std::vector value_data(value_size); + std::vector decay_data(decay_size); + std::vector beta_data(beta_size); + + for (int i = 0; i < qkv_size; ++i) { + query_data[i] = 0.05f * (i % 20 - 10); + key_data[i] = 0.05f * ((i + 7) % 20 - 10); + } + for (int i = 0; i < value_size; ++i) { + value_data[i] = 0.05f * ((i + 13) % 20 - 10); + } + for (int i = 0; i < decay_size; ++i) { + decay_data[i] = -0.05f * ((i % 5) + 1); + } + for (int i = 0; i < beta_size; ++i) { + beta_data[i] = 0.3f + 0.1f * (i % 5); + } + + RunLinearAttentionChunkParallelTest( + query_data, key_data, value_data, + nullptr, &decay_data, &beta_data, + batch_size, num_heads, seq_length, head_dim_k, head_dim_v, + "gated_delta", 16, 0.0f, TensorType::kFloat); + + // Also test FP16 + RunLinearAttentionChunkParallelTest( + query_data, key_data, value_data, + nullptr, &decay_data, &beta_data, + batch_size, num_heads, seq_length, head_dim_k, head_dim_v, + "gated_delta", 16, 0.0f, TensorType::kFloat16); +} + +} // namespace test +} // namespace onnxruntime From 23e744d95155ba3ac5ddf00a0d6b74ba28042e76 Mon Sep 17 00:00:00 2001 From: gs Date: Wed, 11 Mar 2026 09:38:55 -0700 Subject: [PATCH 02/27] ut is now passing --- .../webgpu/bert/linear_attention.cc | 146 ++++++++++++++++++ .../webgpu/bert/linear_attention.h | 33 +++- .../contrib_ops/linear_attention_op_test.cc | 129 +++++++++++++++- 3 files changed, 300 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index 1e49b3309417b..0bd3b2842f9c0 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -591,6 +591,116 @@ Status LinearAttentionChunkInterProgram::GenerateShaderCode(ShaderHelper& shader return Status::OK(); } +Status LinearAttentionFullSequentialProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Full sequential computation for delta/gated_delta update rules. + // These rules have state updates that depend on the current state (S^T k term), + // making chunk-parallel decomposition incorrect. + shader.AddInput("query", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + shader.AddInput("key", ShaderUsage::UseUniform); + shader.AddInput("value", ShaderUsage::UseUniform); + + if (has_initial_state_) { + shader.AddInput("initial_state", ShaderUsage::UseUniform); + } + if (has_decay_) { + shader.AddInput("decay", ShaderUsage::UseUniform); + } + if (has_beta_) { + shader.AddInput("beta", ShaderUsage::UseUniform); + } + + shader.AddOutput("output", ShaderUsage::UseUniform); + shader.AddOutput("final_state", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << R"SHADER( + let batch_idx = workgroup_id.x; + let head_idx = workgroup_id.y; + + if (batch_idx >= uniforms.batch_size || head_idx >= uniforms.num_heads) { + return; + } + + let dk = uniforms.head_dim_k; + let dv = uniforms.head_dim_v; + let seq_len = uniforms.sequence_length; + let scale_val = query_element_t(select(1.0 / sqrt(f32(dk)), uniforms.scale, uniforms.scale != 0.0)); + let bh = batch_idx * uniforms.num_heads + head_idx; + let state_size = dk * dv; + + // Initialize state array (supports up to 32x32 head dimensions) + var state: array; + for (var i = 0u; i < state_size; i = i + 1u) { + state[i] = query_element_t(0.0); + } +)SHADER"; + + if (has_initial_state_) { + shader.MainFunctionBody() << R"SHADER( + // Load initial state + let init_base = bh * state_size; + for (var i = 0u; i < state_size; i = i + 1u) { + state[i] = query_element_t(initial_state[init_base + i]); + } +)SHADER"; + } + + shader.MainFunctionBody() << R"SHADER( + // Process each timestep sequentially + for (var t = 0u; t < seq_len; t = t + 1u) { + let qk_base = (bh * seq_len + t) * dk; + let v_base = (bh * seq_len + t) * dv; +)SHADER"; + + if (has_decay_) { + shader.MainFunctionBody() << R"SHADER( + // Apply decay: state *= exp(decay) + for (var ki = 0u; ki < dk; ki = ki + 1u) { + let exp_g = query_element_t(exp(decay[qk_base + ki])); + for (var vi = 0u; vi < dv; vi = vi + 1u) { + state[ki * dv + vi] = state[ki * dv + vi] * exp_g; + } + } +)SHADER"; + } + + shader.MainFunctionBody() << R"SHADER( + // Delta update: S += beta * k \u2297 (v - S^T k) + let beta_val = query_element_t(beta[bh * seq_len + t]); + for (var vi = 0u; vi < dv; vi = vi + 1u) { + // Compute retrieved = S^T @ k for this v dimension + var retrieved = query_element_t(0.0); + for (var ki = 0u; ki < dk; ki = ki + 1u) { + retrieved = retrieved + state[ki * dv + vi] * query_element_t(key[qk_base + ki]); + } + let v_val = query_element_t(value[v_base + vi]); + let delta_val = beta_val * (v_val - retrieved); + + for (var ki = 0u; ki < dk; ki = ki + 1u) { + state[ki * dv + vi] = state[ki * dv + vi] + query_element_t(key[qk_base + ki]) * delta_val; + } + } + + // Compute output: o = scale * q^T @ state + let out_base = (bh * seq_len + t) * dv; + for (var vi = 0u; vi < dv; vi = vi + 1u) { + var out_val = query_element_t(0.0); + for (var ki = 0u; ki < dk; ki = ki + 1u) { + out_val = out_val + query_element_t(query[qk_base + ki]) * state[ki * dv + vi]; + } + output[out_base + vi] = out_val * scale_val; + } + } + + // Write final state + let final_base = bh * state_size; + for (var i = 0u; i < state_size; i = i + 1u) { + final_state[final_base + i] = state[i]; + } +)SHADER"; + + return Status::OK(); +} + Status LinearAttentionChunkParallel::ComputeInternal(ComputeContext& context) const { const auto* query = context.Input(0); const auto* key = context.Input(1); @@ -632,6 +742,42 @@ Status LinearAttentionChunkParallel::ComputeInternal(ComputeContext& context) co auto* output = context.Output(0, output_shape); auto* final_state = context.Output(1, state_shape); + // For delta/gated_delta rules, use sequential computation. + // Chunk-parallel decomposition doesn't work because state updates depend on the + // running state through the S^T k term, making chunks non-independent. + if (update_rule_ == LinearAttentionUpdateRule::Delta || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { + LinearAttentionFullSequentialProgram program{update_rule_, has_decay, has_beta, has_initial_state}; + + program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, + {key, ProgramTensorMetadataDependency::TypeAndRank}, + {value, ProgramTensorMetadataDependency::TypeAndRank}}); + + if (has_initial_state) { + program.AddInput({initial_state, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_decay) { + program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_beta) { + program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, + {final_state, ProgramTensorMetadataDependency::TypeAndRank}}); + + program.SetDispatchGroupSize(batch_size, num_heads, 1) + .SetWorkgroupSize(1, 1, 1) + .AddUniformVariables({{batch_size}, + {num_heads}, + {seq_length}, + {head_dim_k}, + {head_dim_v}, + {scale_}}); + + return context.RunProgram(program); + } + + // Linear/Gated rules: Use two-phase chunk-parallel approach // Allocate intermediate tensors for chunk computation TensorShape chunk_states_shape({static_cast(batch_size), static_cast(num_heads), static_cast(num_chunks), static_cast(head_dim_k), diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h index c1c023583d95e..34e417d9a76e3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h @@ -82,7 +82,7 @@ class LinearAttentionChunkIntraProgram final : public Program { + public: + LinearAttentionFullSequentialProgram(LinearAttentionUpdateRule update_rule, bool has_decay, bool has_beta, bool has_initial_state) + : Program{"LinearAttentionFullSequential"}, + update_rule_(update_rule), + has_decay_(has_decay), + has_beta_(has_beta), + has_initial_state_(has_initial_state) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"batch_size", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"sequence_length", ProgramUniformVariableDataType::Uint32}, + {"head_dim_k", ProgramUniformVariableDataType::Uint32}, + {"head_dim_v", ProgramUniformVariableDataType::Uint32}, + {"scale", ProgramUniformVariableDataType::Float32}); + + private: + [[maybe_unused]] LinearAttentionUpdateRule update_rule_; bool has_decay_; bool has_beta_; bool has_initial_state_; diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc index 2b128d8b3268a..e17031a98b300 100644 --- a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc @@ -112,6 +112,112 @@ void LinearAttentionRecurrentReference( } } +// Reference implementation for linear attention chunk parallel (full sequence) +// This is the sequential version that processes one step at a time. +void LinearAttentionChunkParallelReference( + const std::vector& query, + const std::vector& key, + const std::vector& value, + const std::vector* initial_state, + const std::vector* decay, + const std::vector* beta, + std::vector& output, + std::vector& final_state, + int batch_size, + int num_heads, + int seq_length, + int head_dim_k, + int head_dim_v, + const std::string& update_rule, + float scale) { + if (scale == 0.0f) { + scale = 1.0f / std::sqrt(static_cast(head_dim_k)); + } + + output.resize(batch_size * num_heads * seq_length * head_dim_v); + final_state.resize(batch_size * num_heads * head_dim_k * head_dim_v); + + int state_size = head_dim_k * head_dim_v; + + for (int b = 0; b < batch_size; ++b) { + for (int h = 0; h < num_heads; ++h) { + int bh = b * num_heads + h; + + // Initialize state + std::vector state(state_size, 0.0f); + if (initial_state != nullptr) { + int init_base = bh * state_size; + for (int i = 0; i < state_size; ++i) { + state[i] = (*initial_state)[init_base + i]; + } + } + + // Process each timestep sequentially + for (int t = 0; t < seq_length; ++t) { + int qk_base = (bh * seq_length + t) * head_dim_k; + int v_base = (bh * seq_length + t) * head_dim_v; + + // 1. Apply decay if gated or gated_delta + if (update_rule == "gated" || update_rule == "gated_delta") { + for (int ki = 0; ki < head_dim_k; ++ki) { + float g = (*decay)[qk_base + ki]; + float exp_g = std::exp(g); + for (int vi = 0; vi < head_dim_v; ++vi) { + state[ki * head_dim_v + vi] *= exp_g; + } + } + } + + // 2. Update state + if (update_rule == "linear" || update_rule == "gated") { + // S += k ⊗ v + for (int ki = 0; ki < head_dim_k; ++ki) { + float k_val = key[qk_base + ki]; + for (int vi = 0; vi < head_dim_v; ++vi) { + float v_val = value[v_base + vi]; + state[ki * head_dim_v + vi] += k_val * v_val; + } + } + } else if (update_rule == "delta" || update_rule == "gated_delta") { + // Compute retrieved = S^T @ k + std::vector retrieved(head_dim_v, 0.0f); + for (int vi = 0; vi < head_dim_v; ++vi) { + for (int ki = 0; ki < head_dim_k; ++ki) { + retrieved[vi] += state[ki * head_dim_v + vi] * key[qk_base + ki]; + } + } + + float beta_val = (*beta)[bh * seq_length + t]; + for (int ki = 0; ki < head_dim_k; ++ki) { + float k_val = key[qk_base + ki]; + for (int vi = 0; vi < head_dim_v; ++vi) { + float v_val = value[v_base + vi]; + float delta_val = beta_val * (v_val - retrieved[vi]); + state[ki * head_dim_v + vi] += k_val * delta_val; + } + } + } + + // 3. Compute output = scale * q^T @ S + int out_base = (bh * seq_length + t) * head_dim_v; + for (int vi = 0; vi < head_dim_v; ++vi) { + float out_val = 0.0f; + for (int ki = 0; ki < head_dim_k; ++ki) { + out_val += query[qk_base + ki] * state[ki * head_dim_v + vi]; + } + output[out_base + vi] = out_val * scale; + } + } + + // Copy final state + int final_base = bh * state_size; + for (int i = 0; i < state_size; ++i) { + final_state[final_base + i] = state[i]; + } + } + } +} + } // anonymous namespace static void RunLinearAttentionRecurrentTest( @@ -464,6 +570,16 @@ static void RunLinearAttentionChunkParallelTest( std::vector beta_shape = {batch_size, num_heads, seq_length, 1}; std::vector output_shape = {batch_size, num_heads, seq_length, head_dim_v}; + // Compute reference expected output + std::vector expected_output; + std::vector expected_state; + LinearAttentionChunkParallelReference( + query_data, key_data, value_data, + initial_state_data, decay_data, beta_data, + expected_output, expected_state, + batch_size, num_heads, seq_length, head_dim_k, head_dim_v, + update_rule, scale); + std::string op_type = "LinearAttentionChunkParallel"; std::vector> execution_providers; @@ -506,10 +622,8 @@ static void RunLinearAttentionChunkParallelTest( test.AddOptionalInputEdge(); } - // We just check that the output has the right shape and doesn't crash - // Full numerical verification is complex due to chunk parallel algorithm - test.AddOutput("output", output_shape, std::vector(batch_size * num_heads * seq_length * head_dim_v, 0.0f), false); - test.AddOutput("final_state", state_shape, std::vector(batch_size * num_heads * head_dim_k * head_dim_v, 0.0f), false); + test.AddOutput("output", output_shape, expected_output); + test.AddOutput("final_state", state_shape, expected_state); } else { test.AddInput("query", query_shape, ToFloat16(query_data)); test.AddInput("key", key_shape, ToFloat16(key_data)); @@ -533,10 +647,13 @@ static void RunLinearAttentionChunkParallelTest( test.AddOptionalInputEdge(); } - test.AddOutput("output", output_shape, std::vector(batch_size * num_heads * seq_length * head_dim_v), false); - test.AddOutput("final_state", state_shape, std::vector(batch_size * num_heads * head_dim_k * head_dim_v), false); + test.AddOutput("output", output_shape, ToFloat16(expected_output)); + test.AddOutput("final_state", state_shape, ToFloat16(expected_state)); } + test.SetOutputAbsErr("output", 0.01f); + test.SetOutputAbsErr("final_state", 0.01f); + std::vector> test_execution_providers; test_execution_providers.push_back(std::move(ep)); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_execution_providers); From b72a1508f0203f41090c905de45954c25c996296 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 12 Mar 2026 16:41:40 -0700 Subject: [PATCH 03/27] merge LinearAttentionRecurrentProgram and LinearAttentionRecurrentProgram --- .../contrib_ops/webgpu/bert/linear_attention.cc | 5 +---- .../contrib_ops/webgpu/bert/linear_attention.h | 13 ++++--------- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index 0bd3b2842f9c0..591b75272227b 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -263,11 +263,8 @@ ONNX_OPERATOR_KERNEL_EX( LinearAttentionChunkParallel); LinearAttentionChunkParallel::LinearAttentionChunkParallel(const OpKernelInfo& info) - : WebGpuKernel(info) { - std::string update_rule_str = info.GetAttrOrDefault("update_rule", "gated_delta"); - update_rule_ = ParseUpdateRule(update_rule_str); + : LinearAttentionRecurrent(info) { chunk_size_ = info.GetAttrOrDefault("chunk_size", 64); - scale_ = info.GetAttrOrDefault("scale", 0.0f); } Status LinearAttentionChunkIntraProgram::GenerateShaderCode(ShaderHelper& shader) const { diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h index 34e417d9a76e3..b807e729b6c0d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h @@ -50,14 +50,15 @@ class LinearAttentionRecurrentProgram final : public Program Date: Thu, 12 Mar 2026 17:58:25 -0700 Subject: [PATCH 04/27] webgpu support for rmsnorm --- .../webgpu/bert/linear_attention.cc | 344 ++++++++---------- .../core/providers/webgpu/nn/rms_norm.cc | 121 ++++++ .../core/providers/webgpu/nn/rms_norm.h | 26 ++ .../webgpu/webgpu_execution_provider.cc | 2 + 4 files changed, 306 insertions(+), 187 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/nn/rms_norm.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/rms_norm.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index 591b75272227b..973bce0f5dda3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -44,6 +44,7 @@ LinearAttentionRecurrent::LinearAttentionRecurrent(const OpKernelInfo& info) std::string update_rule_str = info.GetAttrOrDefault("update_rule", "gated_delta"); update_rule_ = ParseUpdateRule(update_rule_str); scale_ = info.GetAttrOrDefault("scale", 0.0f); + chunk_size_ = info.GetAttrOrDefault("chunk_size", 64); } Status LinearAttentionRecurrentProgram::GenerateShaderCode(ShaderHelper& shader) const { @@ -187,22 +188,24 @@ Status LinearAttentionRecurrent::ComputeInternal(ComputeContext& context) const const auto* query = context.Input(0); const auto* key = context.Input(1); const auto* value = context.Input(2); - const auto* past_state = context.Input(3); - const auto* decay = context.Input(4); // Optional - const auto* beta = context.Input(5); // Optional + const auto* initial_state = context.Input(3); // past_state for recurrent, initial_state (optional) for chunk-parallel + const auto* decay = context.Input(4); // Optional + const auto* beta = context.Input(5); // Optional const auto& query_shape = query->Shape(); - ORT_ENFORCE(query_shape.NumDimensions() == 4, "Query must be 4D: (B, H, 1, d_k)"); + ORT_ENFORCE(query_shape.NumDimensions() == 4, "Query must be 4D: (B, H, L, d_k)"); const auto batch_size = static_cast(query_shape[0]); const auto num_heads = static_cast(query_shape[1]); + const auto seq_length = static_cast(query_shape[2]); const auto head_dim_k = static_cast(query_shape[3]); const auto head_dim_v = static_cast(value->Shape()[3]); - // Validate decay and beta based on update rule + bool has_initial_state = (initial_state != nullptr); bool has_decay = (decay != nullptr); bool has_beta = (beta != nullptr); + // Validate decay and beta based on update rule if (update_rule_ == LinearAttentionUpdateRule::Gated || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { ORT_ENFORCE(has_decay, "Decay input is required for gated and gated_delta update rules"); } @@ -210,43 +213,165 @@ Status LinearAttentionRecurrent::ComputeInternal(ComputeContext& context) const ORT_ENFORCE(has_beta, "Beta input is required for delta and gated_delta update rules"); } - // Create output tensors - TensorShape output_shape({static_cast(batch_size), static_cast(num_heads), 1, static_cast(head_dim_v)}); + // seq_length == 1: single-step recurrent path + if (seq_length == 1) { + ORT_ENFORCE(has_initial_state, "past_state input is required for single-step recurrent mode"); + + TensorShape output_shape({static_cast(batch_size), static_cast(num_heads), 1, static_cast(head_dim_v)}); + auto* output = context.Output(0, output_shape); + auto* present_state = context.Output(1, initial_state->Shape()); + + LinearAttentionRecurrentProgram program{update_rule_, has_decay, has_beta}; + + program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, + {key, ProgramTensorMetadataDependency::TypeAndRank}, + {value, ProgramTensorMetadataDependency::TypeAndRank}, + {initial_state, ProgramTensorMetadataDependency::TypeAndRank}}); + + if (has_decay) { + program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_beta) { + program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, + {present_state, ProgramTensorMetadataDependency::TypeAndRank}}); + + const uint32_t workgroup_size_k = std::min(head_dim_k, 16u); + const uint32_t workgroup_size_v = std::min(head_dim_v, 16u); + + program.SetDispatchGroupSize(batch_size, num_heads, 1) + .SetWorkgroupSize(workgroup_size_k, workgroup_size_v, 1) + .AddUniformVariables({{batch_size}, + {num_heads}, + {head_dim_k}, + {head_dim_v}, + {scale_}}); + + return context.RunProgram(program); + } + + // seq_length > 1: chunk-parallel path + const uint32_t chunk_size = static_cast(chunk_size_); + const uint32_t num_chunks = (seq_length + chunk_size - 1) / chunk_size; + + TensorShape output_shape({static_cast(batch_size), static_cast(num_heads), + static_cast(seq_length), static_cast(head_dim_v)}); + TensorShape state_shape({static_cast(batch_size), static_cast(num_heads), + static_cast(head_dim_k), static_cast(head_dim_v)}); + auto* output = context.Output(0, output_shape); - auto* present_state = context.Output(1, past_state->Shape()); + auto* final_state = context.Output(1, state_shape); + + // For delta/gated_delta rules, use sequential computation. + // Chunk-parallel decomposition doesn't work because state updates depend on the + // running state through the S^T k term, making chunks non-independent. + if (update_rule_ == LinearAttentionUpdateRule::Delta || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { + LinearAttentionFullSequentialProgram program{update_rule_, has_decay, has_beta, has_initial_state}; + + program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, + {key, ProgramTensorMetadataDependency::TypeAndRank}, + {value, ProgramTensorMetadataDependency::TypeAndRank}}); + + if (has_initial_state) { + program.AddInput({initial_state, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_decay) { + program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_beta) { + program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); + } - // Setup and run the program - LinearAttentionRecurrentProgram program{update_rule_, has_decay, has_beta}; + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, + {final_state, ProgramTensorMetadataDependency::TypeAndRank}}); - program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, - {key, ProgramTensorMetadataDependency::TypeAndRank}, - {value, ProgramTensorMetadataDependency::TypeAndRank}, - {past_state, ProgramTensorMetadataDependency::TypeAndRank}}); + program.SetDispatchGroupSize(batch_size, num_heads, 1) + .SetWorkgroupSize(1, 1, 1) + .AddUniformVariables({{batch_size}, + {num_heads}, + {seq_length}, + {head_dim_k}, + {head_dim_v}, + {scale_}}); - if (has_decay) { - program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); + return context.RunProgram(program); } - if (has_beta) { - program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); + + // Linear/Gated rules: Use two-phase chunk-parallel approach + TensorShape chunk_states_shape({static_cast(batch_size), static_cast(num_heads), + static_cast(num_chunks), static_cast(head_dim_k), + static_cast(head_dim_v)}); + + Tensor intra_output_tensor = context.CreateGPUTensor(query->DataType(), output_shape); + Tensor chunk_states_tensor = context.CreateGPUTensor(query->DataType(), chunk_states_shape); + + // Step 1: Compute intra-chunk attention and per-chunk states + { + LinearAttentionChunkIntraProgram intra_program{update_rule_, has_decay, has_beta}; + + intra_program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, + {key, ProgramTensorMetadataDependency::TypeAndRank}, + {value, ProgramTensorMetadataDependency::TypeAndRank}}); + + if (has_decay) { + intra_program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_beta) { + intra_program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); + } + + intra_program.AddOutputs({{&intra_output_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {&chunk_states_tensor, ProgramTensorMetadataDependency::TypeAndRank}}); + + intra_program.SetDispatchGroupSize(batch_size, num_heads, num_chunks) + .SetWorkgroupSize(64, 1, 1) + .AddUniformVariables({{batch_size}, + {num_heads}, + {seq_length}, + {head_dim_k}, + {head_dim_v}, + {chunk_size}, + {num_chunks}, + {scale_}}); + + ORT_RETURN_IF_ERROR(context.RunProgram(intra_program)); } - program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, - {present_state, ProgramTensorMetadataDependency::TypeAndRank}}); + // Step 2: Inter-chunk state propagation and final output computation + { + LinearAttentionChunkInterProgram inter_program{update_rule_, has_decay, has_beta, has_initial_state}; + + inter_program.AddInputs({{&intra_output_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {&chunk_states_tensor, ProgramTensorMetadataDependency::TypeAndRank}, + {query, ProgramTensorMetadataDependency::TypeAndRank}}); + + if (has_initial_state) { + inter_program.AddInput({initial_state, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_decay) { + inter_program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); + } + + inter_program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, + {final_state, ProgramTensorMetadataDependency::TypeAndRank}}); - // Dispatch: one workgroup per (batch, head), with threads for (k, v) elements - // Use a fixed workgroup size that can cover typical head dimensions - const uint32_t workgroup_size_k = std::min(head_dim_k, 16u); - const uint32_t workgroup_size_v = std::min(head_dim_v, 16u); + inter_program.SetDispatchGroupSize(batch_size, num_heads, 1) + .SetWorkgroupSize(256, 1, 1) + .AddUniformVariables({{batch_size}, + {num_heads}, + {seq_length}, + {head_dim_k}, + {head_dim_v}, + {chunk_size}, + {num_chunks}, + {scale_}}); - program.SetDispatchGroupSize(batch_size, num_heads, 1) - .SetWorkgroupSize(workgroup_size_k, workgroup_size_v, 1) - .AddUniformVariables({{batch_size}, - {num_heads}, - {head_dim_k}, - {head_dim_v}, - {scale_}}); + ORT_RETURN_IF_ERROR(context.RunProgram(inter_program)); + } - return context.RunProgram(program); + return Status::OK(); } // ============================================================================= @@ -264,7 +389,6 @@ ONNX_OPERATOR_KERNEL_EX( LinearAttentionChunkParallel::LinearAttentionChunkParallel(const OpKernelInfo& info) : LinearAttentionRecurrent(info) { - chunk_size_ = info.GetAttrOrDefault("chunk_size", 64); } Status LinearAttentionChunkIntraProgram::GenerateShaderCode(ShaderHelper& shader) const { @@ -698,160 +822,6 @@ Status LinearAttentionFullSequentialProgram::GenerateShaderCode(ShaderHelper& sh return Status::OK(); } -Status LinearAttentionChunkParallel::ComputeInternal(ComputeContext& context) const { - const auto* query = context.Input(0); - const auto* key = context.Input(1); - const auto* value = context.Input(2); - const auto* initial_state = context.Input(3); // Optional - const auto* decay = context.Input(4); // Optional - const auto* beta = context.Input(5); // Optional - - const auto& query_shape = query->Shape(); - ORT_ENFORCE(query_shape.NumDimensions() == 4, "Query must be 4D: (B, H, L, d_k)"); - - const auto batch_size = static_cast(query_shape[0]); - const auto num_heads = static_cast(query_shape[1]); - const auto seq_length = static_cast(query_shape[2]); - const auto head_dim_k = static_cast(query_shape[3]); - const auto head_dim_v = static_cast(value->Shape()[3]); - - bool has_initial_state = (initial_state != nullptr); - bool has_decay = (decay != nullptr); - bool has_beta = (beta != nullptr); - - // Validate inputs based on update rule - if (update_rule_ == LinearAttentionUpdateRule::Gated || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { - ORT_ENFORCE(has_decay, "Decay input is required for gated and gated_delta update rules"); - } - if (update_rule_ == LinearAttentionUpdateRule::Delta || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { - ORT_ENFORCE(has_beta, "Beta input is required for delta and gated_delta update rules"); - } - - const uint32_t chunk_size = static_cast(chunk_size_); - const uint32_t num_chunks = (seq_length + chunk_size - 1) / chunk_size; - - // Create output tensors - TensorShape output_shape({static_cast(batch_size), static_cast(num_heads), - static_cast(seq_length), static_cast(head_dim_v)}); - TensorShape state_shape({static_cast(batch_size), static_cast(num_heads), - static_cast(head_dim_k), static_cast(head_dim_v)}); - - auto* output = context.Output(0, output_shape); - auto* final_state = context.Output(1, state_shape); - - // For delta/gated_delta rules, use sequential computation. - // Chunk-parallel decomposition doesn't work because state updates depend on the - // running state through the S^T k term, making chunks non-independent. - if (update_rule_ == LinearAttentionUpdateRule::Delta || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { - LinearAttentionFullSequentialProgram program{update_rule_, has_decay, has_beta, has_initial_state}; - - program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, - {key, ProgramTensorMetadataDependency::TypeAndRank}, - {value, ProgramTensorMetadataDependency::TypeAndRank}}); - - if (has_initial_state) { - program.AddInput({initial_state, ProgramTensorMetadataDependency::TypeAndRank}); - } - if (has_decay) { - program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); - } - if (has_beta) { - program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); - } - - program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, - {final_state, ProgramTensorMetadataDependency::TypeAndRank}}); - - program.SetDispatchGroupSize(batch_size, num_heads, 1) - .SetWorkgroupSize(1, 1, 1) - .AddUniformVariables({{batch_size}, - {num_heads}, - {seq_length}, - {head_dim_k}, - {head_dim_v}, - {scale_}}); - - return context.RunProgram(program); - } - - // Linear/Gated rules: Use two-phase chunk-parallel approach - // Allocate intermediate tensors for chunk computation - TensorShape chunk_states_shape({static_cast(batch_size), static_cast(num_heads), - static_cast(num_chunks), static_cast(head_dim_k), - static_cast(head_dim_v)}); - - // Allocate temporary tensors - need separate intra_output to avoid aliasing - Tensor intra_output_tensor = context.CreateGPUTensor(query->DataType(), output_shape); - Tensor chunk_states_tensor = context.CreateGPUTensor(query->DataType(), chunk_states_shape); - - // Step 1: Compute intra-chunk attention and per-chunk states - { - LinearAttentionChunkIntraProgram intra_program{update_rule_, has_decay, has_beta}; - - intra_program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, - {key, ProgramTensorMetadataDependency::TypeAndRank}, - {value, ProgramTensorMetadataDependency::TypeAndRank}}); - - if (has_decay) { - intra_program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); - } - if (has_beta) { - intra_program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); - } - - intra_program.AddOutputs({{&intra_output_tensor, ProgramTensorMetadataDependency::TypeAndRank}, - {&chunk_states_tensor, ProgramTensorMetadataDependency::TypeAndRank}}); - - intra_program.SetDispatchGroupSize(batch_size, num_heads, num_chunks) - .SetWorkgroupSize(64, 1, 1) - .AddUniformVariables({{batch_size}, - {num_heads}, - {seq_length}, - {head_dim_k}, - {head_dim_v}, - {chunk_size}, - {num_chunks}, - {scale_}}); - - ORT_RETURN_IF_ERROR(context.RunProgram(intra_program)); - } - - // Step 2: Inter-chunk state propagation and final output computation - { - LinearAttentionChunkInterProgram inter_program{update_rule_, has_decay, has_beta, has_initial_state}; - - // Use separate intra_output_tensor as input (read-only) and output (write-only) to avoid aliasing - inter_program.AddInputs({{&intra_output_tensor, ProgramTensorMetadataDependency::TypeAndRank}, // intra_output - {&chunk_states_tensor, ProgramTensorMetadataDependency::TypeAndRank}, // chunk_states - {query, ProgramTensorMetadataDependency::TypeAndRank}}); - - if (has_initial_state) { - inter_program.AddInput({initial_state, ProgramTensorMetadataDependency::TypeAndRank}); - } - if (has_decay) { - inter_program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); - } - - inter_program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, - {final_state, ProgramTensorMetadataDependency::TypeAndRank}}); - - inter_program.SetDispatchGroupSize(batch_size, num_heads, 1) - .SetWorkgroupSize(256, 1, 1) - .AddUniformVariables({{batch_size}, - {num_heads}, - {seq_length}, - {head_dim_k}, - {head_dim_v}, - {chunk_size}, - {num_chunks}, - {scale_}}); - - ORT_RETURN_IF_ERROR(context.RunProgram(inter_program)); - } - - return Status::OK(); -} - } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/rms_norm.cc b/onnxruntime/core/providers/webgpu/nn/rms_norm.cc new file mode 100644 index 0000000000000..250b1153beb8b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/rms_norm.cc @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/nn/rms_norm.h" +#include "core/providers/webgpu/nn/layer_norm.h" + +namespace onnxruntime { +namespace webgpu { + +static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { + int64_t rank = static_cast(tensor_rank); + if (axis < -rank && axis >= rank) { + ORT_THROW("invalid axis: ", axis); + } + return onnxruntime::narrow(axis < 0 ? axis + rank : axis); +} + +static TensorShape GetOverrideShape(const TensorShape& shape, int components) { + TensorShape override_shape{shape.Size() / components}; + return override_shape; +} + +Status RMSNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* x = context.Input(0); + const auto* scale = context.Input(1); + + const auto x_shape = x->Shape(); + + const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); + const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(axis)); + const int64_t norm_size = x_shape.SizeFromDimension(axis); + const int components = GetMaxComponents(norm_size); + const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); + + const auto& scale_shape = scale->Shape(); + const auto scale_size = scale_shape.Size(); + if (scale_shape.NumDimensions() > x_shape.NumDimensions()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Scale and (optional) bias must match X.shape[axis:] or be NumPy-broadcastable to it." + " Scale/Bias rank cannot exceed Input rank."); + } + if (scale_size != norm_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Size of X.shape()[axis:] == ", norm_size, + ". Size of scale must match this. Got scale size of ", + scale_size); + } + + // RMSNormalization outputs: Y (index 0), InvStdDev (index 1, optional) + auto* y = context.Output(0, x_shape); + + TensorShapeVector inv_std_dev_dim; + for (size_t i = 0; i < x_shape.NumDimensions(); ++i) { + if (i < axis) { + inv_std_dev_dim.push_back(x_shape[i]); + } else { + inv_std_dev_dim.push_back(1); + } + } + TensorShape inv_std_dev_shape(inv_std_dev_dim); + auto* inv_std_dev = context.Output(1, inv_std_dev_shape); + + if (x_shape.Size() == 0) { + return Status::OK(); + } + + // Check if we should use split norm dimension optimization + const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1; + + // Reuse LayerNormProgram with simplified=true, has_bias=false, no mean output + LayerNormProgram program{/*has_bias=*/false, /*simplified=*/true, /*has_mean_output=*/false, + /*has_inv_std_dev_output=*/inv_std_dev != nullptr, split_norm_dim}; + + program.CacheHint(components, /*simplified=*/true, split_norm_dim) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, GetOverrideShape(x->Shape(), components), components}}) + .AddInputs( + {{scale, ProgramTensorMetadataDependency::Type, GetOverrideShape(scale->Shape(), components), components}}) + .AddOutputs({{y, ProgramTensorMetadataDependency::None, GetOverrideShape(y->Shape(), components), components}}) + .AddUniformVariables({ + {static_cast(components)}, + }) + .AddUniformVariables({ + {static_cast(norm_count)}, + }) + .AddUniformVariables({ + {static_cast(norm_size)}, + }) + .AddUniformVariables({ + {static_cast(norm_size_vectorized)}, + }) + .AddUniformVariables({ + {static_cast(epsilon_)}, + }); + + if (split_norm_dim) { + const uint32_t workgroup_size_x = 128; + const uint32_t dispatch_size_x = onnxruntime::narrow(norm_size / (workgroup_size_x * components)); + program.SetDispatchGroupSize(dispatch_size_x, 1, 1) + .SetWorkgroupSize(workgroup_size_x); + } else { + program.SetDispatchGroupSize(norm_count); + } + + if (inv_std_dev != nullptr) { + program.AddOutputs({{inv_std_dev, ProgramTensorMetadataDependency::None}}); + } + + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX(RMSNormalization, kOnnxDomain, 23, kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()) + .TypeConstraint("V", WebGpuSupportedFloatTypes()), + RMSNorm); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/rms_norm.h b/onnxruntime/core/providers/webgpu/nn/rms_norm.h new file mode 100644 index 0000000000000..47da51f6df4a1 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/rms_norm.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class RMSNorm final : public WebGpuKernel { + public: + RMSNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + info.GetAttrOrDefault("axis", &axis_, -1); + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05f); + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int64_t axis_; + float epsilon_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index c6255e6f352d9..358669c042936 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -390,6 +390,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tile); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 17, LayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, RMSNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 5, InstanceNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 5, InstanceNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 21, InstanceNormalization); @@ -740,6 +741,7 @@ std::unique_ptr RegisterKernels(bool enable_graph_capture = fals BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, From 1baf195101a888f7b60a7a13ea48d19e9a35c5fb Mon Sep 17 00:00:00 2001 From: gs Date: Fri, 13 Mar 2026 08:45:06 -0700 Subject: [PATCH 05/27] add empty attention --- .../core/providers/webgpu/llm/attention.cc | 27 ++++++++++++++ .../core/providers/webgpu/llm/attention.h | 36 +++++++++++++++++++ .../webgpu/webgpu_execution_provider.cc | 7 ++++ 3 files changed, 70 insertions(+) create mode 100644 onnxruntime/core/providers/webgpu/llm/attention.cc create mode 100644 onnxruntime/core/providers/webgpu/llm/attention.h diff --git a/onnxruntime/core/providers/webgpu/llm/attention.cc b/onnxruntime/core/providers/webgpu/llm/attention.cc new file mode 100644 index 0000000000000..8722049517fa5 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/llm/attention.cc @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/llm/attention.h" + +namespace onnxruntime { +namespace webgpu { + +Status Attention::ComputeInternal(ComputeContext& /*context*/) const { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Attention operator is not yet implemented for WebGPU EP."); +} + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Attention, kOnnxDomain, 23, 23, kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", WebGpuSupportedFloatTypes()) + .TypeConstraint("T2", WebGpuSupportedFloatTypes()), + Attention); + +ONNX_OPERATOR_KERNEL_EX(Attention, kOnnxDomain, 24, kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", WebGpuSupportedFloatTypes()) + .TypeConstraint("T2", WebGpuSupportedFloatTypes()), + Attention); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/llm/attention.h b/onnxruntime/core/providers/webgpu/llm/attention.h new file mode 100644 index 0000000000000..d4d309f3234b4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/llm/attention.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class Attention final : public WebGpuKernel { + public: + Attention(const OpKernelInfo& info) : WebGpuKernel(info) { + is_causal_ = info.GetAttrOrDefault("is_causal", 0); + q_num_heads_ = info.GetAttrOrDefault("q_num_heads", 0); + kv_num_heads_ = info.GetAttrOrDefault("kv_num_heads", 0); + scale_ = info.GetAttrOrDefault("scale", 0.0f); + softcap_ = info.GetAttrOrDefault("softcap", 0.0f); + qk_matmul_output_mode_ = info.GetAttrOrDefault("qk_matmul_output_mode", 0); + softmax_precision_ = info.GetAttrOrDefault("softmax_precision", 0); + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int64_t is_causal_; + int64_t q_num_heads_; + int64_t kv_num_heads_; + float scale_; + float softcap_; + int64_t qk_matmul_output_mode_; + int64_t softmax_precision_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 358669c042936..e4be0e4d5c5b6 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -391,6 +391,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 17, LayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, RMSNormalization); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 23, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 24, Attention); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 5, InstanceNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 5, InstanceNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 21, InstanceNormalization); @@ -743,6 +747,9 @@ std::unique_ptr RegisterKernels(bool enable_graph_capture = fals BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, From c53ceaff9db4951a962395d0d282f2baf72e4193 Mon Sep 17 00:00:00 2001 From: gs Date: Fri, 13 Mar 2026 09:39:19 -0700 Subject: [PATCH 06/27] draft implementation of onnx Attention operator that maps to webgpu custom ops --- .../core/providers/webgpu/llm/attention.cc | 262 +++++++++++++++++- 1 file changed, 260 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/llm/attention.cc b/onnxruntime/core/providers/webgpu/llm/attention.cc index 8722049517fa5..e6be54c252215 100644 --- a/onnxruntime/core/providers/webgpu/llm/attention.cc +++ b/onnxruntime/core/providers/webgpu/llm/attention.cc @@ -1,14 +1,272 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/providers/webgpu/webgpu_supported_types.h" #include "core/providers/webgpu/llm/attention.h" +#include "contrib_ops/webgpu/bert/attention_common.h" +#include "contrib_ops/webgpu/bert/flash_attention.h" +#include "contrib_ops/cpu/bert/attention_parameters.h" + +/* +Remaining failures fall into known limitation categories: + Boolean masks (2) — not yet supported on WebGPU + SoftCap (2) — not yet wired through to the shader + GQA output (3) — output stride mismatch for GQA with different kv_num_heads + QK matmul output (5) — the output_qk output needs additional work + Present without past (2) — present key/value output without past input needs handling + is_causal (1) — causal masking interaction + +[ PASSED ] 24 tests. +[ FAILED ] 15 tests, listed below: +[ FAILED ] AttentionTest.Attention4DAttnMaskBoolAllFalse +[ FAILED ] AttentionTest.Attention4DAttnMaskBoolAllFalseDecodeWithPast +[ FAILED ] AttentionTest.Attention4DSoftCap +[ FAILED ] AttentionTest.Attention4DSoftCapFloat16 +[ FAILED ] AttentionTest.Attention4DAttnMaskBool +[ FAILED ] AttentionTest.Attention4DAttnIsCausal +[ FAILED ] AttentionTest.Attention3DGqaAttn +[ FAILED ] AttentionTest.Attention3DGqaSelfAttnCausal +[ FAILED ] AttentionTest.Attention4DGqaAttnMask +[ FAILED ] AttentionTest.Attention4DWithPastAndPresentQkMatmul +[ FAILED ] AttentionTest.Attention3DWithPastAndPresentQkMatmul +[ FAILED ] AttentionTest.Attention4DWithMask3DPastAndPresentQkMatmul +[ FAILED ] AttentionTest.Attention4DWithMask3DPastAndPresentQkMatmulCausal +[ FAILED ] AttentionTest.TestAttention4DWithPastAndPresentQkMatmulBias4DMaskCausal +[ FAILED ] AttentionTest.AttentionNoPastWithPresentOutput +*/ namespace onnxruntime { namespace webgpu { -Status Attention::ComputeInternal(ComputeContext& /*context*/) const { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Attention operator is not yet implemented for WebGPU EP."); +Status Attention::ComputeInternal(ComputeContext& context) const { + const Tensor* Q = context.Input(0); + const Tensor* K = context.Input(1); + const Tensor* V = context.Input(2); + const Tensor* attn_mask = context.Input(3); // optional + const Tensor* past_key = context.Input(4); // optional + const Tensor* past_value = context.Input(5); // optional + // Input 6 is nonpad_kv_seqlen (opset 24 only) - not yet supported + + ORT_RETURN_IF(Q == nullptr || K == nullptr || V == nullptr, + "Q, K, and V inputs must not be null"); + ORT_RETURN_IF((past_key == nullptr) != (past_value == nullptr), + "past_key and past_value must be both present or both absent"); + + const auto& q_shape = Q->Shape(); + const auto& k_shape = K->Shape(); + const auto& v_shape = V->Shape(); + const int q_dims = static_cast(q_shape.NumDimensions()); + + ORT_RETURN_IF(q_dims != 3 && q_dims != 4, "Q must be a 3D or 4D tensor"); + ORT_RETURN_IF(q_dims != static_cast(k_shape.NumDimensions()), + "Q and K must have the same rank"); + ORT_RETURN_IF(q_dims != static_cast(v_shape.NumDimensions()), + "Q and V must have the same rank"); + + const bool is_4d = (q_dims == 4); + int batch_size, q_sequence_length, kv_sequence_length, head_size, v_head_size; + int q_num_heads_val, kv_num_heads_val; + + if (is_4d) { + batch_size = static_cast(q_shape[0]); + q_num_heads_val = static_cast(q_shape[1]); + q_sequence_length = static_cast(q_shape[2]); + head_size = static_cast(q_shape[3]); + kv_num_heads_val = static_cast(k_shape[1]); + kv_sequence_length = static_cast(k_shape[2]); + v_head_size = static_cast(v_shape[3]); + } else { + // 3D: (batch_size, sequence_length, hidden_size) + batch_size = static_cast(q_shape[0]); + q_sequence_length = static_cast(q_shape[1]); + q_num_heads_val = static_cast(q_num_heads_); + kv_num_heads_val = static_cast(kv_num_heads_); + ORT_RETURN_IF(q_num_heads_val <= 0 || kv_num_heads_val <= 0, + "q_num_heads and kv_num_heads attributes are required for 3D inputs"); + ORT_RETURN_IF(q_shape[2] % q_num_heads_val != 0, + "Q hidden size must be divisible by q_num_heads"); + ORT_RETURN_IF(v_shape[2] % kv_num_heads_val != 0, + "V hidden size must be divisible by kv_num_heads"); + head_size = static_cast(q_shape[2]) / q_num_heads_val; + kv_sequence_length = static_cast(k_shape[1]); + v_head_size = static_cast(v_shape[2]) / kv_num_heads_val; + } + + ORT_RETURN_IF(q_num_heads_val % kv_num_heads_val != 0, + "q_num_heads must be a multiple of kv_num_heads"); + + const int past_sequence_length = (past_key != nullptr) + ? static_cast(past_key->Shape()[2]) + : 0; + const int total_sequence_length = past_sequence_length + kv_sequence_length; + const float scale_val = (scale_ != 0.0f) + ? scale_ + : (1.0f / std::sqrt(static_cast(head_size))); + const bool is_gqa = (q_num_heads_val != kv_num_heads_val); + + // Build contrib::AttentionParameters to construct WebgpuAttentionParameters + contrib::AttentionParameters params = {}; + params.batch_size = batch_size; + params.sequence_length = q_sequence_length; + params.kv_sequence_length = kv_sequence_length; + params.past_sequence_length = past_sequence_length; + params.total_sequence_length = total_sequence_length; + params.hidden_size = q_num_heads_val * head_size; + params.head_size = head_size; + params.v_hidden_size = q_num_heads_val * v_head_size; + params.v_head_size = v_head_size; + params.num_heads = q_num_heads_val; + params.is_unidirectional = (is_causal_ == 1); + params.scale = scale_val; + params.mask_filter_value = -10000.0f; + params.qkv_format = contrib::Q_K_V_BNSH; + params.mask_type = contrib::MASK_NONE; + + contrib::webgpu::WebgpuAttentionParameters parameters(params); + + // For GQA (q_num_heads > kv_num_heads), set additional fields + if (is_gqa) { + parameters.is_gqa_ = true; + parameters.kv_num_heads_ = kv_num_heads_val; + parameters.kv_hidden_size_ = kv_num_heads_val * head_size; + parameters.v_hidden_size_ = kv_num_heads_val * v_head_size; + parameters.v_head_size_ = v_head_size; + parameters.n_reps = q_num_heads_val / kv_num_heads_val; + } + + if (softcap_ != 0.0f) { + parameters.softcap_ = softcap_; + } + + // Handle attention mask - reshape to 4D if needed for the shader. + // The shader expects 4D (batch, heads, q_seq, total_seq) and broadcasts + // via attn_bias_dim0/dim1 clamping. + const Tensor* attention_bias = nullptr; + Tensor reshaped_mask; + if (attn_mask != nullptr) { + ORT_RETURN_IF(attn_mask->IsDataType(), + "Boolean attention mask is not yet supported for WebGPU Attention"); + const auto mask_ndims = static_cast(attn_mask->Shape().NumDimensions()); + if (mask_ndims == 4) { + attention_bias = attn_mask; + } else if (mask_ndims == 3) { + // (A, q_seq, total_seq) → (1, A, q_seq, total_seq) per numpy broadcasting + TensorShape new_shape({1, attn_mask->Shape()[0], + attn_mask->Shape()[1], attn_mask->Shape()[2]}); + reshaped_mask = Tensor(attn_mask->DataType(), new_shape, + const_cast(attn_mask->DataRaw()), + attn_mask->Location()); + attention_bias = &reshaped_mask; + } else if (mask_ndims == 2) { + // (q_seq, total_seq) → (1, 1, q_seq, total_seq) + TensorShape new_shape({1, 1, attn_mask->Shape()[0], attn_mask->Shape()[1]}); + reshaped_mask = Tensor(attn_mask->DataType(), new_shape, + const_cast(attn_mask->DataRaw()), + attn_mask->Location()); + attention_bias = &reshaped_mask; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "attn_mask must be 2D, 3D, or 4D tensor"); + } + } + + // Allocate output tensors. + // ApplyAttention and FlashAttention output in BSNH (BSD) memory layout. + // For 4D output, compute into a temporary BSD tensor then transpose to BNSH. + const int output_hidden = q_num_heads_val * v_head_size; + Tensor output_bsd_temp; + Tensor* output; + Tensor* compute_output; + + if (is_4d) { + TensorShapeVector y_shape({batch_size, q_num_heads_val, + q_sequence_length, v_head_size}); + output = context.Output(0, TensorShape(y_shape)); + // Temporary BSD tensor for the computation + TensorShapeVector bsd_shape({batch_size, q_sequence_length, output_hidden}); + output_bsd_temp = context.CreateGPUTensor(Q->DataType(), TensorShape(bsd_shape)); + compute_output = &output_bsd_temp; + } else { + TensorShapeVector y_shape({batch_size, q_sequence_length, output_hidden}); + output = context.Output(0, TensorShape(y_shape)); + compute_output = output; + } + + // Present key/value outputs (optional) + std::vector present_key_dims{ + batch_size, kv_num_heads_val, total_sequence_length, head_size}; + std::vector present_value_dims{ + batch_size, kv_num_heads_val, total_sequence_length, v_head_size}; + Tensor* present_key_output = context.Output(1, TensorShape(present_key_dims)); + Tensor* present_value_output = context.Output(2, TensorShape(present_value_dims)); + + // QK matmul output (optional, output index 3) + Tensor* output_qk = nullptr; + if (context.OutputCount() > 3) { + std::vector qk_dims{ + batch_size, q_num_heads_val, q_sequence_length, total_sequence_length}; + output_qk = context.Output(3, TensorShape(qk_dims)); + } + + // Prepare Q, K, V in BNSH format. + // 4D inputs are already BNSH; 3D inputs need BSD→BNSH conversion. + const Tensor* Q_bnsh = Q; + const Tensor* K_bnsh = K; + const Tensor* V_bnsh = V; + Tensor Q_converted, K_converted, V_converted; + + if (!is_4d) { + TensorShapeVector q_bnsh_dims({batch_size, q_num_heads_val, + q_sequence_length, head_size}); + Q_converted = context.CreateGPUTensor(Q->DataType(), TensorShape(q_bnsh_dims)); + ORT_RETURN_IF_ERROR(contrib::webgpu::TransferBSDToBNSH( + context, q_num_heads_val, q_sequence_length, + head_size, Q, nullptr, 0, &Q_converted)); + Q_bnsh = &Q_converted; + + TensorShapeVector k_bnsh_dims({batch_size, kv_num_heads_val, + kv_sequence_length, head_size}); + K_converted = context.CreateGPUTensor(K->DataType(), TensorShape(k_bnsh_dims)); + ORT_RETURN_IF_ERROR(contrib::webgpu::TransferBSDToBNSH( + context, kv_num_heads_val, kv_sequence_length, + head_size, K, nullptr, 0, &K_converted)); + K_bnsh = &K_converted; + + TensorShapeVector v_bnsh_dims({batch_size, kv_num_heads_val, + kv_sequence_length, v_head_size}); + V_converted = context.CreateGPUTensor(V->DataType(), TensorShape(v_bnsh_dims)); + ORT_RETURN_IF_ERROR(contrib::webgpu::TransferBSDToBNSH( + context, kv_num_heads_val, kv_sequence_length, + v_head_size, V, nullptr, 0, &V_converted)); + V_bnsh = &V_converted; + } + + // Try flash attention first (not available when output_qk is needed) + if (output_qk == nullptr && + contrib::webgpu::CanApplyFlashAttention(nullptr, parameters, context)) { + ORT_RETURN_IF_ERROR(contrib::webgpu::ApplyFlashAttention( + Q_bnsh, K_bnsh, V_bnsh, attention_bias, + compute_output, past_key, present_key_output, + past_value, present_value_output, parameters, context)); + } else { + // Fall back to tiled attention + ORT_RETURN_IF_ERROR(contrib::webgpu::ApplyAttention( + Q_bnsh, K_bnsh, V_bnsh, attention_bias, + past_key, past_value, + compute_output, present_key_output, present_value_output, + output_qk, parameters, context)); + } + + // For 4D output, transpose from BSNH (BSD) to BNSH + if (is_4d) { + ORT_RETURN_IF_ERROR(contrib::webgpu::TransferBSDToBNSH( + context, q_num_heads_val, q_sequence_length, + v_head_size, compute_output, nullptr, 0, output)); + } + + return Status::OK(); } ONNX_OPERATOR_VERSIONED_KERNEL_EX(Attention, kOnnxDomain, 23, 23, kWebGpuExecutionProvider, From c8de2f45c37a13a6ed2b08f4bb700a4edb076c01 Mon Sep 17 00:00:00 2001 From: gs Date: Fri, 13 Mar 2026 11:10:07 -0700 Subject: [PATCH 07/27] CausalConv1DWithState --- .../webgpu/bert/causal_conv1d_with_state.cc | 302 +++++++++ .../webgpu/bert/causal_conv1d_with_state.h | 66 ++ .../webgpu/webgpu_contrib_kernels.cc | 3 + .../core/graph/contrib_ops/bert_defs.cc | 84 +++ onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 + .../core/providers/webgpu/llm/attention.cc | 8 + .../causal_conv1d_with_state_op_test.cc | 641 ++++++++++++++++++ 7 files changed, 1106 insertions(+) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.h create mode 100644 onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc b/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc new file mode 100644 index 0000000000000..393e0ad5589ba --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc @@ -0,0 +1,302 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/bert/causal_conv1d_with_state.h" + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +using namespace onnxruntime::webgpu; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +CausalConv1DActivation ParseCausalConv1DActivation(const std::string& activation_str) { + if (activation_str == "silu" || activation_str == "swish") { + return CausalConv1DActivation::Silu; + } else if (activation_str == "none" || activation_str.empty()) { + return CausalConv1DActivation::None; + } + ORT_THROW("Unknown activation for CausalConv1DWithState: ", activation_str); +} + +// ============================================================================= +// CausalConv1DWithState Implementation +// ============================================================================= + +ONNX_OPERATOR_KERNEL_EX( + CausalConv1DWithState, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + CausalConv1DWithState); + +CausalConv1DWithState::CausalConv1DWithState(const OpKernelInfo& info) + : WebGpuKernel(info) { + std::string activation_str = info.GetAttrOrDefault("activation", "silu"); + activation_ = ParseCausalConv1DActivation(activation_str); +} + +Status CausalConv1DWithStateProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Input tensors + const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + const auto& weight = shader.AddInput("weight", ShaderUsage::UseUniform); + + // Optional inputs + const ShaderVariableHelper* bias_ptr = nullptr; + const ShaderVariableHelper* conv_state_ptr = nullptr; + if (has_bias_) { + bias_ptr = &shader.AddInput("bias", ShaderUsage::UseUniform); + } + if (has_conv_state_) { + conv_state_ptr = &shader.AddInput("conv_state", ShaderUsage::UseUniform); + } + + // Output tensors + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); + const auto& present_state = shader.AddOutput("present_state", ShaderUsage::UseUniform); + + // Activation function implementation + if (activation_ == CausalConv1DActivation::Silu) { + shader.AdditionalImplementation() << R"SHADER( +fn silu(x: input_element_t) -> input_element_t { + return x / (1.0 + exp(-x)); +} +)SHADER"; + } + + // Flatten to 1D dispatch: each thread handles one (batch, channel, pos) triple. + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << R"SHADER( + let batch_size = uniforms.batch_size; + let channels = uniforms.channels; + let input_length = uniforms.input_length; + let kernel_size = uniforms.kernel_size; + let state_length = uniforms.state_length; // = kernel_size - 1 + + let pos = global_idx % input_length; + let bc_idx = global_idx / input_length; + let batch_idx = bc_idx / channels; + let channel_idx = bc_idx % channels; + + // Perform depthwise causal convolution for this (batch, channel, pos). + // The convolution window looks back kernel_size-1 positions. + // With conv_state providing the history before position 0, the + // "virtual" input is: [conv_state[0..state_length-1], input[0..L-1]] + // + // For output position pos: + // output[pos] = sum_{j=0}^{kernel_size-1} weight[j] * virtual_input[pos + j] + // where virtual_input is state_length positions of conv_state + // followed by input_length positions of input. + + var acc: input_element_t = 0.0; + + // Weight layout: (D, 1, K) -> channel_idx * kernel_size + j + let weight_base = channel_idx * kernel_size; + + for (var j: u32 = 0; j < kernel_size; j = j + 1) { + // virtual_pos is the position in the concatenated [conv_state, input] + let virtual_pos = pos + j; + + var val: input_element_t = 0.0; +)SHADER"; + + if (has_conv_state_) { + shader.MainFunctionBody() << R"SHADER( + if (virtual_pos < state_length) { + // Read from conv_state: (B, D, state_length) + let state_idx = (batch_idx * channels + channel_idx) * state_length + virtual_pos; + val = )SHADER" + << conv_state_ptr->GetByOffset("state_idx") << R"SHADER(; + } else { + // Read from input: (B, D, L) + let input_pos = virtual_pos - state_length; + let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos; + val = )SHADER" + << input.GetByOffset("input_idx") << R"SHADER(; + } +)SHADER"; + } else { + // No conv_state: pad with zeros for positions before the input + shader.MainFunctionBody() << R"SHADER( + if (virtual_pos >= state_length) { + let input_pos = virtual_pos - state_length; + let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos; + val = )SHADER" + << input.GetByOffset("input_idx") << R"SHADER(; + } +)SHADER"; + } + + shader.MainFunctionBody() << R"SHADER( + let w = )SHADER" + << weight.GetByOffset("weight_base + j") << R"SHADER(; + acc = acc + val * w; + } +)SHADER"; + + // Add bias if present + if (has_bias_) { + shader.MainFunctionBody() << " acc = acc + " << bias_ptr->GetByOffset("channel_idx") << ";\n"; + } + + // Apply activation + if (activation_ == CausalConv1DActivation::Silu) { + shader.MainFunctionBody() << " acc = silu(acc);\n"; + } + + // Write output: (B, D, L) + shader.MainFunctionBody() << R"SHADER( + let out_idx = (batch_idx * channels + channel_idx) * input_length + pos; + )SHADER" << output.SetByOffset("out_idx", "acc") + << "\n"; + + // Write present_state: the last (kernel_size - 1) elements from the + // virtual input [conv_state, input]. The virtual input has total length + // state_length + input_length. We want positions from + // (state_length + input_length - state_length) to (state_length + input_length - 1), + // i.e. the last state_length positions of the virtual input, which are the + // last state_length positions of input (when input_length >= state_length). + // + // We only write present_state once per (batch, channel), using the thread + // at pos == 0 to write all state_length values. + shader.MainFunctionBody() << R"SHADER( + if (pos == 0u) { + for (var s: u32 = 0; s < state_length; s = s + 1) { + var state_val: input_element_t = 0.0; + // total_len = state_length + input_length + // We want virtual_input[total_len - state_length + s] = virtual_input[input_length + s] + let vp = input_length + s; +)SHADER"; + + if (has_conv_state_) { + shader.MainFunctionBody() << R"SHADER( + if (vp < state_length) { + let si = (batch_idx * channels + channel_idx) * state_length + vp; + state_val = )SHADER" + << conv_state_ptr->GetByOffset("si") << R"SHADER(; + } else { + let ip = vp - state_length; + let ii = (batch_idx * channels + channel_idx) * input_length + ip; + state_val = )SHADER" + << input.GetByOffset("ii") << R"SHADER(; + } +)SHADER"; + } else { + shader.MainFunctionBody() << R"SHADER( + if (vp >= state_length) { + let ip = vp - state_length; + let ii = (batch_idx * channels + channel_idx) * input_length + ip; + state_val = )SHADER" + << input.GetByOffset("ii") << R"SHADER(; + } +)SHADER"; + } + + shader.MainFunctionBody() << R"SHADER( + let ps_idx = (batch_idx * channels + channel_idx) * state_length + s; + )SHADER" << present_state.SetByOffset("ps_idx", "state_val") + << R"SHADER( + } + } +)SHADER"; + + return Status::OK(); +} + +Status CausalConv1DWithState::ComputeInternal(ComputeContext& context) const { + const Tensor* input = context.Input(0); // (B, D, L) + const Tensor* weight = context.Input(1); // (D, 1, K) + const Tensor* bias = context.Input(2); // optional (D,) + const Tensor* conv_state = context.Input(3); // optional (B, D, K-1) + + ORT_RETURN_IF(input == nullptr, "Input tensor must not be null"); + ORT_RETURN_IF(weight == nullptr, "Weight tensor must not be null"); + + const auto& input_shape = input->Shape(); + const auto& weight_shape = weight->Shape(); + + ORT_RETURN_IF(input_shape.NumDimensions() != 3, + "Input must be 3D (batch_size, channels, length)"); + ORT_RETURN_IF(weight_shape.NumDimensions() != 3, + "Weight must be 3D (channels, 1, kernel_size)"); + + const int batch_size = static_cast(input_shape[0]); + const int channels = static_cast(input_shape[1]); + const int input_length = static_cast(input_shape[2]); + const int kernel_size = static_cast(weight_shape[2]); + const int state_length = kernel_size - 1; + + ORT_RETURN_IF(static_cast(weight_shape[0]) != channels, + "Weight first dim must match input channels"); + ORT_RETURN_IF(static_cast(weight_shape[1]) != 1, + "Weight second dim must be 1 for depthwise convolution"); + + if (bias != nullptr) { + ORT_RETURN_IF(bias->Shape().NumDimensions() != 1, + "Bias must be 1D"); + ORT_RETURN_IF(static_cast(bias->Shape()[0]) != channels, + "Bias size must match channels"); + } + + if (conv_state != nullptr) { + ORT_RETURN_IF(conv_state->Shape().NumDimensions() != 3, + "conv_state must be 3D (batch_size, channels, kernel_size - 1)"); + ORT_RETURN_IF(static_cast(conv_state->Shape()[0]) != batch_size, + "conv_state batch_size must match input"); + ORT_RETURN_IF(static_cast(conv_state->Shape()[1]) != channels, + "conv_state channels must match input"); + ORT_RETURN_IF(static_cast(conv_state->Shape()[2]) != state_length, + "conv_state last dim must be kernel_size - 1"); + } + + const bool has_bias = (bias != nullptr); + const bool has_conv_state = (conv_state != nullptr); + + // Allocate outputs + // Output 0: (B, D, L) + Tensor* output = context.Output(0, input_shape); + + // Output 1: present_state (B, D, K-1) + std::vector state_dims{batch_size, channels, state_length}; + Tensor* present_state = context.Output(1, TensorShape(state_dims)); + + if (input_length == 0) { + return Status::OK(); + } + + // Create and run the shader program + CausalConv1DWithStateProgram program{activation_, has_bias, has_conv_state, kernel_size}; + + uint32_t output_size = static_cast(batch_size * channels * input_length); + + program.AddInput({input, ProgramTensorMetadataDependency::Type}) + .AddInput({weight, ProgramTensorMetadataDependency::None}); + + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::None}); + } + if (has_conv_state) { + program.AddInput({conv_state, ProgramTensorMetadataDependency::None}); + } + + program.AddOutput({output, ProgramTensorMetadataDependency::None}) + .AddOutput({present_state, ProgramTensorMetadataDependency::None}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariable({static_cast(batch_size)}) + .AddUniformVariable({static_cast(channels)}) + .AddUniformVariable({static_cast(input_length)}) + .AddUniformVariable({static_cast(kernel_size)}) + .AddUniformVariable({static_cast(state_length)}) + .AddUniformVariable({output_size}); + + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.h b/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.h new file mode 100644 index 0000000000000..ccbb22d9de7d4 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +// Activation mode for CausalConv1DWithState +enum class CausalConv1DActivation { + None, + Silu, +}; + +CausalConv1DActivation ParseCausalConv1DActivation(const std::string& activation_str); + +// Program for CausalConv1DWithState +class CausalConv1DWithStateProgram final : public Program { + public: + CausalConv1DWithStateProgram(CausalConv1DActivation activation, bool has_bias, bool has_conv_state, + int kernel_size) + : Program{"CausalConv1DWithState"}, + activation_(activation), + has_bias_(has_bias), + has_conv_state_(has_conv_state), + kernel_size_(kernel_size) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"batch_size", ProgramUniformVariableDataType::Uint32}, + {"channels", ProgramUniformVariableDataType::Uint32}, + {"input_length", ProgramUniformVariableDataType::Uint32}, + {"kernel_size", ProgramUniformVariableDataType::Uint32}, + {"state_length", ProgramUniformVariableDataType::Uint32}, + {"output_size", ProgramUniformVariableDataType::Uint32}); + + private: + CausalConv1DActivation activation_; + bool has_bias_; + bool has_conv_state_; + [[maybe_unused]] int kernel_size_; +}; + +// Kernel for CausalConv1DWithState +class CausalConv1DWithState final : public WebGpuKernel { + public: + CausalConv1DWithState(const OpKernelInfo& info); + Status ComputeInternal(ComputeContext& context) const override; + + private: + CausalConv1DActivation activation_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index a62cc2457b9ed..24a88320da980 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/causal_conv1d_with_state.h" #include "contrib_ops/webgpu/bert/group_query_attention.h" #include "contrib_ops/webgpu/bert/linear_attention.h" @@ -13,6 +14,7 @@ namespace webgpu { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, CausalConv1DWithState); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu); @@ -46,6 +48,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 45fa48035885b..d0f7767452f49 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -2217,6 +2217,90 @@ ONNX_MS_OPERATOR_SET_SCHEMA( } })); +constexpr const char* CausalConv1DWithState_ver1_doc = R"DOC( +Depthwise causal 1D convolution with carry state for incremental decoding. + +Used by Gated DeltaNet (Qwen3.5) and Mamba (Jamba, FalconMamba) as a preprocessing step. +Replaces the 3-op pattern (Concat + Conv + Slice) with a single fused operation. + +The convolution is causal (looks only at current and past positions) and depthwise +(each channel is convolved independently with its own kernel). + +Input layout is channels-first: (batch_size, channels, length). +Weight layout: (channels, 1, kernel_size) for depthwise convolution. +Conv state carries the last (kernel_size - 1) input values for incremental decode. + +The optional activation attribute supports fused SiLU/Swish activation. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + CausalConv1DWithState, 1, + OpSchema() + .SetDoc(CausalConv1DWithState_ver1_doc) + .Attr("activation", + "Fused activation function. One of: 'silu', 'swish', 'none'. " + "Default is 'silu'.", + AttributeProto::STRING, + std::string("silu")) + .Attr("group", + "group for convolution. Default is 1, which means normal convolution. When group equals to input channels, it becomes depthwise convolution.", + AttributeProto::INT, + static_cast(1)) + .Input(0, + "input", + "Input tensor with shape (batch_size, channels, length). Channels-first layout.", + "T") + .Input(1, + "weight", + "Depthwise convolution weights with shape (channels, 1, kernel_size).", + "T") + .Input(2, + "bias", + "Optional bias with shape (channels).", + "T", + OpSchema::Optional) + .Input(3, + "conv_state", + "Carry state from previous step with shape (batch_size, channels, kernel_size - 1). " + "If not provided, padding is zero.", + "T", + OpSchema::Optional) + .Output(0, + "output", + "Convolution output with shape (batch_size, channels, length).", + "T") + .Output(1, + "present_state", + "Updated carry state with shape (batch_size, channels, kernel_size - 1). " + "Contains the last (kernel_size - 1) values from the virtual input.", + "T") + .TypeConstraint("T", + {"tensor(float)", "tensor(float16)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + propagateElemTypeFromInputToOutput(ctx, 0, 1); + + // Output 0: same shape as input (batch_size, channels, length) + propagateShapeFromInputToOutput(ctx, 0, 0); + + // Output 1: (batch_size, channels, kernel_size - 1) + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { + auto& input_shape = getInputShape(ctx, 0); + auto& weight_shape = getInputShape(ctx, 1); + TensorShapeProto state_shape; + *state_shape.add_dim() = input_shape.dim(0); // batch_size + *state_shape.add_dim() = input_shape.dim(1); // channels + // kernel_size - 1 + if (weight_shape.dim(2).has_dim_value()) { + state_shape.add_dim()->set_dim_value(weight_shape.dim(2).dim_value() - 1); + } else { + state_shape.add_dim(); // unknown + } + updateOutputShape(ctx, 1, state_shape); + } + })); + constexpr const char* LinearAttentionRecurrent_ver1_doc = R"DOC( Linear Attention Recurrent operator for single-token decode step. diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index c553bf4a3718d..2a8c8ee3521c6 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -90,6 +90,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttentionRecurrent); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttentionChunkParallel); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CausalConv1DWithState); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad); @@ -203,6 +204,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/providers/webgpu/llm/attention.cc b/onnxruntime/core/providers/webgpu/llm/attention.cc index e6be54c252215..f4b951b8d4d62 100644 --- a/onnxruntime/core/providers/webgpu/llm/attention.cc +++ b/onnxruntime/core/providers/webgpu/llm/attention.cc @@ -10,6 +10,14 @@ #include "contrib_ops/cpu/bert/attention_parameters.h" /* +Key design decisions: +1. Input parsing: Handles both 3D (B, S, hidden) and 4D (B, N, S, H) input formats per the ONNX spec +2. MHA vs GQA: Detects whether q_num_heads == kv_num_heads (MHA) or q_num_heads > kv_num_heads (GQA) and configures WebgpuAttentionParameters accordingly +3. Flash attention: Used when available (no output_qk needed, subgroups feature present, no bias) +4. 3D→BNSH conversion: For 3D inputs, uses TransferBSDToBNSH to convert to the BNSH format expected by the attention kernels +5. 4D output: Computes in BSD layout (as the shader outputs), then transposes back to BNSH for 4D output format +6. Attention mask: Reshapes 2D/3D masks to 4D for the shader's broadcasting logic; boolean masks return NOT_SUPPORTED + Remaining failures fall into known limitation categories: Boolean masks (2) — not yet supported on WebGPU SoftCap (2) — not yet wired through to the shader diff --git a/onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc b/onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc new file mode 100644 index 0000000000000..728389a4fb869 --- /dev/null +++ b/onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc @@ -0,0 +1,641 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include "gtest/gtest.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +namespace { +enum class TensorType { + kFloat, + kFloat16 +}; + +// Reference implementation for CausalConv1DWithState +// Performs depthwise causal 1D convolution with optional state, bias, and activation. +// +// Input: (B, D, L) channels-first +// Weight: (D, 1, K) depthwise +// Bias: (D,) optional +// conv_state: (B, D, K-1) optional carry state +// +// Output: (B, D, L) convolution output (with optional activation) +// present_state: (B, D, K-1) updated carry state +void CausalConv1DWithStateReference( + const std::vector& input, + const std::vector& weight, + const std::vector* bias, + const std::vector* conv_state, + std::vector& output, + std::vector& present_state, + int batch_size, + int channels, + int input_length, + int kernel_size, + const std::string& activation) { + int state_length = kernel_size - 1; + int total_virtual_length = state_length + input_length; + + output.resize(batch_size * channels * input_length); + present_state.resize(batch_size * channels * state_length); + + for (int b = 0; b < batch_size; ++b) { + for (int d = 0; d < channels; ++d) { + int bd = b * channels + d; + + // Build virtual input: [conv_state, input] + std::vector virtual_input(total_virtual_length, 0.0f); + if (conv_state != nullptr) { + for (int s = 0; s < state_length; ++s) { + virtual_input[s] = (*conv_state)[bd * state_length + s]; + } + } + for (int l = 0; l < input_length; ++l) { + virtual_input[state_length + l] = input[bd * input_length + l]; + } + + // Compute depthwise convolution + for (int pos = 0; pos < input_length; ++pos) { + float acc = 0.0f; + for (int j = 0; j < kernel_size; ++j) { + float val = virtual_input[pos + j]; + float w = weight[d * kernel_size + j]; + acc += val * w; + } + // Add bias + if (bias != nullptr) { + acc += (*bias)[d]; + } + // Apply activation + if (activation == "silu" || activation == "swish") { + acc = acc / (1.0f + std::exp(-acc)); + } + output[bd * input_length + pos] = acc; + } + + // Compute present_state: last state_length values from virtual input + for (int s = 0; s < state_length; ++s) { + present_state[bd * state_length + s] = + virtual_input[input_length + s]; + } + } + } +} + +} // anonymous namespace + +static void RunCausalConv1DWithStateTest( + const std::vector& input_data, + const std::vector& weight_data, + const std::vector* bias_data, + const std::vector* conv_state_data, + const std::vector& expected_output, + const std::vector& expected_state, + int batch_size, + int channels, + int input_length, + int kernel_size, + const std::string& activation, + TensorType tensor_type) { + int state_length = kernel_size - 1; + + std::vector input_shape = {batch_size, channels, input_length}; + std::vector weight_shape = {channels, 1, kernel_size}; + std::vector bias_shape = {channels}; + std::vector state_shape = {batch_size, channels, state_length}; + std::vector output_shape = {batch_size, channels, input_length}; + + std::vector> execution_providers; + + bool enable_webgpu = nullptr != DefaultWebGpuExecutionProvider().get(); + if (enable_webgpu) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } + + if (execution_providers.empty()) { + // Skip if no providers available + return; + } + + for (auto& ep : execution_providers) { + OpTester test("CausalConv1DWithState", 1, onnxruntime::kMSDomain); + test.AddAttribute("activation", activation); + + if (tensor_type == TensorType::kFloat) { + test.AddInput("input", input_shape, input_data); + test.AddInput("weight", weight_shape, weight_data); + + if (bias_data != nullptr) { + test.AddInput("bias", bias_shape, *bias_data); + } else { + test.AddOptionalInputEdge(); + } + + if (conv_state_data != nullptr) { + test.AddInput("conv_state", state_shape, *conv_state_data); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOutput("output", output_shape, expected_output); + test.AddOutput("present_state", state_shape, expected_state); + } else { + test.AddInput("input", input_shape, ToFloat16(input_data)); + test.AddInput("weight", weight_shape, ToFloat16(weight_data)); + + if (bias_data != nullptr) { + test.AddInput("bias", bias_shape, ToFloat16(*bias_data)); + } else { + test.AddOptionalInputEdge(); + } + + if (conv_state_data != nullptr) { + test.AddInput("conv_state", state_shape, ToFloat16(*conv_state_data)); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOutput("output", output_shape, ToFloat16(expected_output)); + test.AddOutput("present_state", state_shape, ToFloat16(expected_state)); + } + + test.SetOutputAbsErr("output", 0.01f); + test.SetOutputAbsErr("present_state", 0.01f); + + std::vector> test_execution_providers; + test_execution_providers.push_back(std::move(ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_execution_providers); + } +} + +static void RunCausalConv1DWithStateTests( + const std::vector& input_data, + const std::vector& weight_data, + const std::vector* bias_data, + const std::vector* conv_state_data, + int batch_size, + int channels, + int input_length, + int kernel_size, + const std::string& activation = "silu") { + // Compute expected output using reference implementation + std::vector expected_output; + std::vector expected_state; + CausalConv1DWithStateReference( + input_data, weight_data, bias_data, conv_state_data, + expected_output, expected_state, + batch_size, channels, input_length, kernel_size, activation); + + // FP32 test + RunCausalConv1DWithStateTest( + input_data, weight_data, bias_data, conv_state_data, + expected_output, expected_state, + batch_size, channels, input_length, kernel_size, activation, + TensorType::kFloat); + + // FP16 test + RunCausalConv1DWithStateTest( + input_data, weight_data, bias_data, conv_state_data, + expected_output, expected_state, + batch_size, channels, input_length, kernel_size, activation, + TensorType::kFloat16); +} + +// ============================================================================= +// Basic tests - simple cases +// ============================================================================= + +TEST(CausalConv1DWithStateTest, BasicNoStateNoBias) { + // B=1, D=2, L=4, K=3, activation=none + int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3; + + // Input: (1, 2, 4) + std::vector input_data = { + 1.0f, 2.0f, 3.0f, 4.0f, // channel 0 + 0.5f, 1.5f, 2.5f, 3.5f}; // channel 1 + + // Weight: (2, 1, 3) + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, // channel 0 kernel + 0.4f, 0.5f, 0.6f}; // channel 1 kernel + + RunCausalConv1DWithStateTests( + input_data, weight_data, nullptr, nullptr, + batch_size, channels, input_length, kernel_size, "none"); +} + +TEST(CausalConv1DWithStateTest, BasicWithBias) { + // B=1, D=2, L=4, K=3, activation=none + int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, 4.0f, + 0.5f, 1.5f, 2.5f, 3.5f}; + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + std::vector bias_data = {0.1f, -0.2f}; + + RunCausalConv1DWithStateTests( + input_data, weight_data, &bias_data, nullptr, + batch_size, channels, input_length, kernel_size, "none"); +} + +TEST(CausalConv1DWithStateTest, BasicWithState) { + // B=1, D=2, L=3, K=3, activation=none + int batch_size = 1, channels = 2, input_length = 3, kernel_size = 3; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, + 0.5f, 1.5f, 2.5f}; + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + // State: (1, 2, 2) - kernel_size - 1 = 2 + std::vector conv_state_data = { + -1.0f, 0.5f, // channel 0 state + 0.3f, -0.7f}; // channel 1 state + + RunCausalConv1DWithStateTests( + input_data, weight_data, nullptr, &conv_state_data, + batch_size, channels, input_length, kernel_size, "none"); +} + +TEST(CausalConv1DWithStateTest, WithStateAndBias) { + // B=1, D=2, L=3, K=3, activation=none + int batch_size = 1, channels = 2, input_length = 3, kernel_size = 3; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, + 0.5f, 1.5f, 2.5f}; + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + std::vector bias_data = {0.1f, -0.2f}; + std::vector conv_state_data = { + -1.0f, 0.5f, + 0.3f, -0.7f}; + + RunCausalConv1DWithStateTests( + input_data, weight_data, &bias_data, &conv_state_data, + batch_size, channels, input_length, kernel_size, "none"); +} + +// ============================================================================= +// SiLU activation tests +// ============================================================================= + +TEST(CausalConv1DWithStateTest, SiluActivationNoState) { + int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, 4.0f, + 0.5f, 1.5f, 2.5f, 3.5f}; + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + + RunCausalConv1DWithStateTests( + input_data, weight_data, nullptr, nullptr, + batch_size, channels, input_length, kernel_size, "silu"); +} + +TEST(CausalConv1DWithStateTest, SiluActivationWithState) { + int batch_size = 1, channels = 2, input_length = 3, kernel_size = 3; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, + 0.5f, 1.5f, 2.5f}; + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + std::vector conv_state_data = { + -1.0f, 0.5f, + 0.3f, -0.7f}; + + RunCausalConv1DWithStateTests( + input_data, weight_data, nullptr, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +TEST(CausalConv1DWithStateTest, SiluActivationWithBiasAndState) { + int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, 4.0f, + 0.5f, 1.5f, 2.5f, 3.5f}; + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + std::vector bias_data = {0.1f, -0.2f}; + std::vector conv_state_data = { + -1.0f, 0.5f, + 0.3f, -0.7f}; + + RunCausalConv1DWithStateTests( + input_data, weight_data, &bias_data, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +// ============================================================================= +// Kernel size variations +// ============================================================================= + +TEST(CausalConv1DWithStateTest, KernelSize2) { + int batch_size = 1, channels = 2, input_length = 4, kernel_size = 2; + + std::vector input_data = { + 1.0f, 2.0f, 3.0f, 4.0f, + 0.5f, 1.5f, 2.5f, 3.5f}; + std::vector weight_data = { + 0.3f, 0.7f, + 0.4f, 0.6f}; + // State: (1, 2, 1) - kernel_size - 1 = 1 + std::vector conv_state_data = {0.5f, -0.3f}; + + RunCausalConv1DWithStateTests( + input_data, weight_data, nullptr, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +TEST(CausalConv1DWithStateTest, KernelSize4) { + int batch_size = 1, channels = 1, input_length = 5, kernel_size = 4; + + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + std::vector weight_data = {0.1f, 0.2f, 0.3f, 0.4f}; + // State: (1, 1, 3) + std::vector conv_state_data = {-1.0f, 0.0f, 0.5f}; + + RunCausalConv1DWithStateTests( + input_data, weight_data, nullptr, &conv_state_data, + batch_size, channels, input_length, kernel_size, "none"); +} + +// ============================================================================= +// Batch size > 1 +// ============================================================================= + +TEST(CausalConv1DWithStateTest, MultiBatch) { + int batch_size = 2, channels = 2, input_length = 3, kernel_size = 3; + + // Input: (2, 2, 3) + std::vector input_data = { + // Batch 0 + 1.0f, 2.0f, 3.0f, // ch 0 + 0.5f, 1.5f, 2.5f, // ch 1 + // Batch 1 + -1.0f, 0.0f, 1.0f, // ch 0 + 0.2f, 0.4f, 0.6f}; // ch 1 + + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + + std::vector bias_data = {0.1f, -0.1f}; + + // State: (2, 2, 2) + std::vector conv_state_data = { + // Batch 0 + -0.5f, 0.5f, // ch 0 + 0.3f, -0.3f, // ch 1 + // Batch 1 + 0.1f, -0.1f, // ch 0 + 0.7f, 0.8f}; // ch 1 + + RunCausalConv1DWithStateTests( + input_data, weight_data, &bias_data, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +// ============================================================================= +// Single token decode (L=1) - the primary use case for incremental decoding +// ============================================================================= + +TEST(CausalConv1DWithStateTest, SingleTokenDecode) { + int batch_size = 1, channels = 4, input_length = 1, kernel_size = 4; + + // Input: (1, 4, 1) + std::vector input_data = {0.5f, -0.3f, 1.2f, 0.8f}; + + // Weight: (4, 1, 4) + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, 0.4f, + 0.5f, 0.6f, 0.7f, 0.8f, + -0.1f, -0.2f, 0.1f, 0.2f, + 0.3f, 0.3f, 0.3f, 0.3f}; + + std::vector bias_data = {0.0f, 0.1f, -0.1f, 0.0f}; + + // State: (1, 4, 3) - carrying the last 3 values per channel + std::vector conv_state_data = { + 1.0f, 2.0f, 3.0f, // ch 0 + -1.0f, 0.0f, 1.0f, // ch 1 + 0.5f, 0.5f, 0.5f, // ch 2 + -0.2f, 0.4f, -0.6f}; // ch 3 + + RunCausalConv1DWithStateTests( + input_data, weight_data, &bias_data, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +TEST(CausalConv1DWithStateTest, SingleTokenDecodeMultiBatch) { + int batch_size = 2, channels = 2, input_length = 1, kernel_size = 3; + + // Input: (2, 2, 1) + std::vector input_data = { + 0.5f, // B0, ch 0 + -0.3f, // B0, ch 1 + 1.2f, // B1, ch 0 + 0.8f}; // B1, ch 1 + + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + + // State: (2, 2, 2) + std::vector conv_state_data = { + 1.0f, 2.0f, // B0, ch 0 + -1.0f, 0.0f, // B0, ch 1 + 0.5f, 0.5f, // B1, ch 0 + -0.2f, 0.4f}; // B1, ch 1 + + RunCausalConv1DWithStateTests( + input_data, weight_data, nullptr, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +// ============================================================================= +// State continuity test: verify that present_state from one call can be used +// as conv_state for the next call (simulating autoregressive decode) +// ============================================================================= + +TEST(CausalConv1DWithStateTest, StateContinuity) { + // Process a sequence of single tokens and verify state propagation + int batch_size = 1, channels = 1, kernel_size = 3; + int input_length = 1; + + std::vector weight_data = {0.2f, 0.3f, 0.5f}; + std::vector bias_data = {0.1f}; + + // Initial state: zeros + std::vector conv_state = {0.0f, 0.0f}; + + // First token + std::vector input1 = {1.0f}; + std::vector expected_output1; + std::vector expected_state1; + CausalConv1DWithStateReference(input1, weight_data, &bias_data, &conv_state, + expected_output1, expected_state1, + batch_size, channels, input_length, kernel_size, "none"); + + RunCausalConv1DWithStateTest(input1, weight_data, &bias_data, &conv_state, + expected_output1, expected_state1, + batch_size, channels, input_length, kernel_size, "none", + TensorType::kFloat); + + // Second token, using present_state from first as conv_state + std::vector input2 = {2.0f}; + std::vector expected_output2; + std::vector expected_state2; + CausalConv1DWithStateReference(input2, weight_data, &bias_data, &expected_state1, + expected_output2, expected_state2, + batch_size, channels, input_length, kernel_size, "none"); + + RunCausalConv1DWithStateTest(input2, weight_data, &bias_data, &expected_state1, + expected_output2, expected_state2, + batch_size, channels, input_length, kernel_size, "none", + TensorType::kFloat); + + // Third token + std::vector input3 = {3.0f}; + std::vector expected_output3; + std::vector expected_state3; + CausalConv1DWithStateReference(input3, weight_data, &bias_data, &expected_state2, + expected_output3, expected_state3, + batch_size, channels, input_length, kernel_size, "none"); + + RunCausalConv1DWithStateTest(input3, weight_data, &bias_data, &expected_state2, + expected_output3, expected_state3, + batch_size, channels, input_length, kernel_size, "none", + TensorType::kFloat); + + // The present_state after processing [1, 2, 3] should be [2, 3] + EXPECT_NEAR(expected_state3[0], 2.0f, 1e-5f); + EXPECT_NEAR(expected_state3[1], 3.0f, 1e-5f); +} + +// ============================================================================= +// Equivalence test: sequence processing should match token-by-token with state +// ============================================================================= + +TEST(CausalConv1DWithStateTest, SequenceVsTokenByToken) { + int batch_size = 1, channels = 2, kernel_size = 3; + + std::vector weight_data = { + 0.1f, 0.2f, 0.3f, + 0.4f, 0.5f, 0.6f}; + std::vector bias_data = {0.05f, -0.05f}; + + // Initial state: zeros + std::vector conv_state = {0.0f, 0.0f, 0.0f, 0.0f}; // (1, 2, 2) + + // Full sequence: length 4 + std::vector full_input = { + 1.0f, 2.0f, 3.0f, 4.0f, // ch 0 + 0.5f, 1.5f, 2.5f, 3.5f}; // ch 1 + + // Process full sequence at once + std::vector full_output; + std::vector full_final_state; + CausalConv1DWithStateReference(full_input, weight_data, &bias_data, &conv_state, + full_output, full_final_state, + batch_size, channels, 4, kernel_size, "none"); + + // Process token by token + std::vector current_state = conv_state; + std::vector token_outputs; + + for (int t = 0; t < 4; ++t) { + // Extract single token: (1, 2, 1) + std::vector token_input = { + full_input[0 * 4 + t], // ch 0 + full_input[1 * 4 + t]}; // ch 1 + + std::vector token_output; + std::vector next_state; + CausalConv1DWithStateReference(token_input, weight_data, &bias_data, ¤t_state, + token_output, next_state, + batch_size, channels, 1, kernel_size, "none"); + + // Collect outputs + for (int d = 0; d < channels; ++d) { + token_outputs.push_back(token_output[d]); + } + current_state = next_state; + } + + // Rearrange token_outputs from (T, D) to (D, T) layout for comparison + std::vector token_outputs_dlt(channels * 4); + for (int t = 0; t < 4; ++t) { + for (int d = 0; d < channels; ++d) { + token_outputs_dlt[d * 4 + t] = token_outputs[t * channels + d]; + } + } + + // Compare outputs + for (int i = 0; i < channels * 4; ++i) { + EXPECT_NEAR(full_output[i], token_outputs_dlt[i], 1e-5f) + << "Mismatch at index " << i; + } + + // Compare final states + for (int i = 0; i < channels * 2; ++i) { + EXPECT_NEAR(full_final_state[i], current_state[i], 1e-5f) + << "State mismatch at index " << i; + } +} + +// ============================================================================= +// Larger dimension test with realistic sizes +// ============================================================================= + +TEST(CausalConv1DWithStateTest, LargerDimensions) { + int batch_size = 2, channels = 8, input_length = 16, kernel_size = 4; + + // Generate test data with a simple pattern + std::vector input_data(batch_size * channels * input_length); + for (int i = 0; i < static_cast(input_data.size()); ++i) { + input_data[i] = std::sin(static_cast(i) * 0.1f); + } + + std::vector weight_data(channels * kernel_size); + for (int i = 0; i < static_cast(weight_data.size()); ++i) { + weight_data[i] = std::cos(static_cast(i) * 0.2f) * 0.5f; + } + + std::vector bias_data(channels); + for (int i = 0; i < channels; ++i) { + bias_data[i] = 0.01f * static_cast(i); + } + + int state_length = kernel_size - 1; + std::vector conv_state_data(batch_size * channels * state_length); + for (int i = 0; i < static_cast(conv_state_data.size()); ++i) { + conv_state_data[i] = std::sin(static_cast(i) * 0.3f) * 0.5f; + } + + RunCausalConv1DWithStateTests( + input_data, weight_data, &bias_data, &conv_state_data, + batch_size, channels, input_length, kernel_size, "silu"); +} + +} // namespace test +} // namespace onnxruntime From 0948d561a74aafefcc416562beaea57738de2361 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 13 Mar 2026 18:30:54 -0700 Subject: [PATCH 08/27] add int64_t to concat/webgpu --- onnxruntime/core/providers/webgpu/tensor/concat.cc | 6 +++--- .../core/providers/webgpu/webgpu_supported_types.h | 14 ++++++++++++++ .../test/providers/cpu/tensor/concat_op_test.cc | 13 +++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 75453b991a0cd..ed031654cf793 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -20,7 +20,7 @@ namespace webgpu { end, \ kWebGpuExecutionProvider, \ (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + .TypeConstraint("T", WebGpuSupportedNumberAndInt64Types()), \ Concat); #define WEBGPU_CONCAT_KERNEL(version) \ @@ -30,7 +30,7 @@ namespace webgpu { version, \ kWebGpuExecutionProvider, \ (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ + .TypeConstraint("T", WebGpuSupportedNumberAndInt64Types()), \ Concat); WEBGPU_CONCAT_VERSIONED_KERNEL(1, 3) @@ -154,4 +154,4 @@ Status Concat::ComputeInternal(ComputeContext& context) const { } } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h index 1efbda00ec869..455cf31928f26 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h +++ b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h @@ -28,6 +28,15 @@ using SupportedNumberAndBoolTypes = uint32_t, bool>; +using SupportedNumberAndInt64Types = + TypeList< + float, + MLFloat16, + int32_t, + uint32_t, + int64_t, + uint64_t>; + inline const std::vector& WebGpuSupportedNumberTypes() { static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); return supportedDataTypes; @@ -43,6 +52,11 @@ inline const std::vector& WebGpuSupportedNumberAndBoolTypes() { return supportedDataTypes; } +inline const std::vector& WebGpuSupportedNumberAndInt64Types() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + inline const std::vector& GetOpTypeConstraints(bool enable_int64 = false, bool enable_bool = false) { static std::vector base_types{ DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc index 5f08b6df6785d..775455e7e733e 100644 --- a/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/concat_op_test.cc @@ -40,6 +40,19 @@ TEST(ConcatOpTest, Concat1D_int32) { test.Run(); } +TEST(ConcatOpTest, Concat1D_int64) { + // webgpu ep will fail for 0x1122334455667788 + const int64_t val = 0x11223344; + OpTester test("Concat"); + test.AddAttribute("axis", int64_t{0}); + + test.AddInput("input1", {1}, {val}); + test.AddInput("input2", {2}, {2, 3}); + test.AddInput("input3", {4}, {4, 5, 6, 7}); + test.AddOutput("concat_result", {7}, {val, 2, 3, 4, 5, 6, 7}); + test.Run(); +} + TEST(ConcatOpTest, Concat1D_int32_negative_axis) { OpTester test("Concat"); test.AddAttribute("axis", int64_t{-1}); From 231c92c835b35b92735b2a1b6ff17f3c94b6c19f Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Sun, 15 Mar 2026 09:13:08 -0700 Subject: [PATCH 09/27] webgpu support for onnx rotarry embeddings --- .../providers/webgpu/llm/rotary_embedding.cc | 132 ++++++++++++++++++ .../providers/webgpu/llm/rotary_embedding.h | 23 +++ .../webgpu/webgpu_execution_provider.cc | 4 + 3 files changed, 159 insertions(+) create mode 100644 onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc create mode 100644 onnxruntime/core/providers/webgpu/llm/rotary_embedding.h diff --git a/onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc b/onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc new file mode 100644 index 0000000000000..138b593819fb5 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/llm/rotary_embedding.h" +#include "contrib_ops/webgpu/bert/rotary_embedding.h" +#include "core/providers/webgpu/generator/range.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + RotaryEmbedding, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) { + rotary_embedding_dim_ = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); + num_heads_ = static_cast(info.GetAttrOrDefault("num_heads", 0)); + interleaved_ = (info.GetAttrOrDefault("interleaved", 0) == 1); +} + +Status RotaryEmbedding::ComputeInternal(ComputeContext& context) const { + // ONNX inputs: X(0), cos_cache(1), sin_cache(2), position_ids(3, optional) + const auto* input = context.Input(0); + const auto* cos_cache = context.Input(1); + const auto* sin_cache = context.Input(2); + const auto* position_ids = context.Input(3); // optional + + const auto input_shape = input->Shape(); + auto* output = context.Output(0, input_shape); + + const auto batch_size = onnxruntime::narrow(input_shape[0]); + const auto batch_stride = onnxruntime::narrow(input_shape.SizeFromDimension(1)); + const auto sequence_length = onnxruntime::narrow(input_shape[input_shape.NumDimensions() - 2]); + const auto hidden_size = batch_stride / sequence_length; + const auto half_rotary_embedding_dim = onnxruntime::narrow(cos_cache->Shape()[cos_cache->Shape().NumDimensions() - 1]); + const auto head_size = rotary_embedding_dim_ == 0 ? half_rotary_embedding_dim * 2 : hidden_size / num_heads_; + + const TensorShape global_shape({batch_size, + sequence_length, + hidden_size / head_size, + head_size - half_rotary_embedding_dim}); + + const auto rank = global_shape.NumDimensions(); + std::vector global_dims(rank); + std::vector global_strides(rank); + for (size_t j = 0; j < rank; ++j) { + global_dims[j] = onnxruntime::narrow(global_shape[j]); + global_strides[j] = onnxruntime::narrow(global_shape.SizeFromDimension(j + 1)); + } + + const auto output_size = onnxruntime::narrow(global_shape.Size()); + const auto input_output_strides = + input_shape.NumDimensions() == 3 + ? std::vector({batch_stride, hidden_size, head_size, 1}) + : (input_shape.NumDimensions() == 4 + ? std::vector({batch_stride, head_size, sequence_length * head_size, 1}) + : std::vector({})); + + // The contrib RotaryEmbeddingProgram expects inputs in order: + // input(0), position_ids(1), cos_cache(2), sin_cache(3) + // The ONNX op has: X(0), cos_cache(1), sin_cache(2), position_ids(3, optional) + + if (position_ids != nullptr) { + // position_ids provided: cos/sin cache is 2D (max_pos, D/2) + contrib::webgpu::RotaryEmbeddingProgram program{interleaved_}; + program + .CacheHint(interleaved_) + .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, + {position_ids, ProgramTensorMetadataDependency::Rank}, + {cos_cache, ProgramTensorMetadataDependency::Rank}, + {sin_cache, ProgramTensorMetadataDependency::Rank}}) + .AddOutput({output, ProgramTensorMetadataDependency::None}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{1.0f}, + {gsl::make_span(global_dims)}, + {gsl::make_span(global_strides)}, + {gsl::make_span(input_output_strides)}}) + .AddIndices(TensorShape{1, 1}); + return context.RunProgram(program); + } + + // position_ids NOT provided: cos/sin cache is 3D (B, S, D/2) + // Reshape to 2D (B*S, D/2) and generate sequential position_ids. + const auto total_seq = batch_size * sequence_length; + const TensorShape cache_2d_shape({static_cast(total_seq), + static_cast(half_rotary_embedding_dim)}); + + // Generate position_ids [0, 1, ..., B*S-1] reshaped as (B, S) on GPU using RangeProgram + const TensorShape pos_ids_shape({static_cast(batch_size), + static_cast(sequence_length)}); + Tensor pos_ids_tensor = context.CreateGPUTensor(DataTypeImpl::GetType(), pos_ids_shape); + { + RangeProgram range_program{ONNX_NAMESPACE::TensorProto_DataType_INT64}; + int32_t start_i32 = 0; + int32_t delta_i32 = 1; + range_program + .AddOutput({&pos_ids_tensor, ProgramTensorMetadataDependency::Type}) + .SetDispatchGroupSize((total_seq + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + total_seq, + std::bit_cast(start_i32), + std::bit_cast(delta_i32), + }); + ORT_RETURN_IF_ERROR(context.RunProgram(range_program)); + } + + contrib::webgpu::RotaryEmbeddingProgram program{interleaved_}; + program + .CacheHint(interleaved_) + .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, + {&pos_ids_tensor, ProgramTensorMetadataDependency::Rank}, + {cos_cache, ProgramTensorMetadataDependency::Rank, cache_2d_shape, 1}, + {sin_cache, ProgramTensorMetadataDependency::Rank, cache_2d_shape, 1}}) + .AddOutput({output, ProgramTensorMetadataDependency::None}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{1.0f}, + {gsl::make_span(global_dims)}, + {gsl::make_span(global_strides)}, + {gsl::make_span(input_output_strides)}}) + .AddIndices(TensorShape{1, 1}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/llm/rotary_embedding.h b/onnxruntime/core/providers/webgpu/llm/rotary_embedding.h new file mode 100644 index 0000000000000..6a3f60e8b75e3 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/llm/rotary_embedding.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class RotaryEmbedding final : public WebGpuKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(ComputeContext& context) const override; + + private: + int num_heads_; + int rotary_embedding_dim_; + bool interleaved_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index e4be0e4d5c5b6..db4c78a65b388 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -395,6 +395,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 23, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 24, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, RotaryEmbedding); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 5, InstanceNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 5, InstanceNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 21, InstanceNormalization); @@ -750,6 +752,8 @@ std::unique_ptr RegisterKernels(bool enable_graph_capture = fals BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, From e69b57923da465dd71cb8ce4823cc866fea41ff3 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Sun, 15 Mar 2026 10:00:48 -0700 Subject: [PATCH 10/27] webgpu reshape to opset 25 --- .../core/providers/webgpu/tensor/reshape.cc | 26 ++++++++++++++++++- .../webgpu/webgpu_execution_provider.cc | 8 ++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/reshape.cc b/onnxruntime/core/providers/webgpu/tensor/reshape.cc index 9ede015a0c99c..26546d59220fa 100644 --- a/onnxruntime/core/providers/webgpu/tensor/reshape.cc +++ b/onnxruntime/core/providers/webgpu/tensor/reshape.cc @@ -11,7 +11,31 @@ namespace webgpu { ONNX_OPERATOR_KERNEL_EX( Reshape, kOnnxDomain, - 21, + 25, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 23, 24, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("shape", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Reshape); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Reshape, + kOnnxDomain, + 21, 22, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", WebGpuSupportedNumberTypes()) diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index db4c78a65b388..dcf4474d82d63 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -247,7 +247,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 18, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Reshape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 24, Reshape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 25, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Identity); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Identity); @@ -561,7 +563,9 @@ std::unique_ptr RegisterKernels(bool enable_graph_capture = fals BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, From 56fe4ac65d4d7200249d8b58d39d7586eaad0d61 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Sun, 15 Mar 2026 22:07:17 -0700 Subject: [PATCH 11/27] keep int64_t for concat on cpu --- .../webgpu/bert/causal_conv1d_with_state.cc | 24 +++++------ .../core/graph/contrib_ops/bert_defs.cc | 4 +- .../core/providers/webgpu/tensor/concat.cc | 34 +++++++-------- .../causal_conv1d_with_state_op_test.cc | 40 +++++++++--------- .../contrib_ops/linear_attention_op_test.cc | 42 +++++++------------ 5 files changed, 66 insertions(+), 78 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc b/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc index 393e0ad5589ba..fe66579a3b42e 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc @@ -111,13 +111,13 @@ fn silu(x: input_element_t) -> input_element_t { // Read from conv_state: (B, D, state_length) let state_idx = (batch_idx * channels + channel_idx) * state_length + virtual_pos; val = )SHADER" - << conv_state_ptr->GetByOffset("state_idx") << R"SHADER(; + << conv_state_ptr->GetByOffset("state_idx") << R"SHADER(; } else { // Read from input: (B, D, L) let input_pos = virtual_pos - state_length; let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos; val = )SHADER" - << input.GetByOffset("input_idx") << R"SHADER(; + << input.GetByOffset("input_idx") << R"SHADER(; } )SHADER"; } else { @@ -127,14 +127,14 @@ fn silu(x: input_element_t) -> input_element_t { let input_pos = virtual_pos - state_length; let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos; val = )SHADER" - << input.GetByOffset("input_idx") << R"SHADER(; + << input.GetByOffset("input_idx") << R"SHADER(; } )SHADER"; } shader.MainFunctionBody() << R"SHADER( let w = )SHADER" - << weight.GetByOffset("weight_base + j") << R"SHADER(; + << weight.GetByOffset("weight_base + j") << R"SHADER(; acc = acc + val * w; } )SHADER"; @@ -153,7 +153,7 @@ fn silu(x: input_element_t) -> input_element_t { shader.MainFunctionBody() << R"SHADER( let out_idx = (batch_idx * channels + channel_idx) * input_length + pos; )SHADER" << output.SetByOffset("out_idx", "acc") - << "\n"; + << "\n"; // Write present_state: the last (kernel_size - 1) elements from the // virtual input [conv_state, input]. The virtual input has total length @@ -178,12 +178,12 @@ fn silu(x: input_element_t) -> input_element_t { if (vp < state_length) { let si = (batch_idx * channels + channel_idx) * state_length + vp; state_val = )SHADER" - << conv_state_ptr->GetByOffset("si") << R"SHADER(; + << conv_state_ptr->GetByOffset("si") << R"SHADER(; } else { let ip = vp - state_length; let ii = (batch_idx * channels + channel_idx) * input_length + ip; state_val = )SHADER" - << input.GetByOffset("ii") << R"SHADER(; + << input.GetByOffset("ii") << R"SHADER(; } )SHADER"; } else { @@ -192,7 +192,7 @@ fn silu(x: input_element_t) -> input_element_t { let ip = vp - state_length; let ii = (batch_idx * channels + channel_idx) * input_length + ip; state_val = )SHADER" - << input.GetByOffset("ii") << R"SHADER(; + << input.GetByOffset("ii") << R"SHADER(; } )SHADER"; } @@ -200,7 +200,7 @@ fn silu(x: input_element_t) -> input_element_t { shader.MainFunctionBody() << R"SHADER( let ps_idx = (batch_idx * channels + channel_idx) * state_length + s; )SHADER" << present_state.SetByOffset("ps_idx", "state_val") - << R"SHADER( + << R"SHADER( } } )SHADER"; @@ -209,9 +209,9 @@ fn silu(x: input_element_t) -> input_element_t { } Status CausalConv1DWithState::ComputeInternal(ComputeContext& context) const { - const Tensor* input = context.Input(0); // (B, D, L) - const Tensor* weight = context.Input(1); // (D, 1, K) - const Tensor* bias = context.Input(2); // optional (D,) + const Tensor* input = context.Input(0); // (B, D, L) + const Tensor* weight = context.Input(1); // (D, 1, K) + const Tensor* bias = context.Input(2); // optional (D,) const Tensor* conv_state = context.Input(3); // optional (B, D, K-1) ORT_RETURN_IF(input == nullptr, "Input tensor must not be null"); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index d0f7767452f49..f0b68b1ae98d1 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -2289,8 +2289,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( auto& input_shape = getInputShape(ctx, 0); auto& weight_shape = getInputShape(ctx, 1); TensorShapeProto state_shape; - *state_shape.add_dim() = input_shape.dim(0); // batch_size - *state_shape.add_dim() = input_shape.dim(1); // channels + *state_shape.add_dim() = input_shape.dim(0); // batch_size + *state_shape.add_dim() = input_shape.dim(1); // channels // kernel_size - 1 if (weight_shape.dim(2).has_dim_value()) { state_shape.add_dim()->set_dim_value(weight_shape.dim(2).dim_value() - 1); diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index ed031654cf793..55f4e2c5d0e5f 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -12,25 +12,25 @@ namespace onnxruntime { namespace webgpu { -#define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ - Concat, \ - kOnnxDomain, \ - start, \ - end, \ - kWebGpuExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", WebGpuSupportedNumberAndInt64Types()), \ +#define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + start, \ + end, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ Concat); -#define WEBGPU_CONCAT_KERNEL(version) \ - ONNX_OPERATOR_KERNEL_EX( \ - Concat, \ - kOnnxDomain, \ - version, \ - kWebGpuExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", WebGpuSupportedNumberAndInt64Types()), \ +#define WEBGPU_CONCAT_KERNEL(version) \ + ONNX_OPERATOR_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + version, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ Concat); WEBGPU_CONCAT_VERSIONED_KERNEL(1, 3) diff --git a/onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc b/onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc index 728389a4fb869..7f66707c32453 100644 --- a/onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc +++ b/onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc @@ -389,8 +389,8 @@ TEST(CausalConv1DWithStateTest, MultiBatch) { // Input: (2, 2, 3) std::vector input_data = { // Batch 0 - 1.0f, 2.0f, 3.0f, // ch 0 - 0.5f, 1.5f, 2.5f, // ch 1 + 1.0f, 2.0f, 3.0f, // ch 0 + 0.5f, 1.5f, 2.5f, // ch 1 // Batch 1 -1.0f, 0.0f, 1.0f, // ch 0 0.2f, 0.4f, 0.6f}; // ch 1 @@ -404,11 +404,11 @@ TEST(CausalConv1DWithStateTest, MultiBatch) { // State: (2, 2, 2) std::vector conv_state_data = { // Batch 0 - -0.5f, 0.5f, // ch 0 - 0.3f, -0.3f, // ch 1 + -0.5f, 0.5f, // ch 0 + 0.3f, -0.3f, // ch 1 // Batch 1 - 0.1f, -0.1f, // ch 0 - 0.7f, 0.8f}; // ch 1 + 0.1f, -0.1f, // ch 0 + 0.7f, 0.8f}; // ch 1 RunCausalConv1DWithStateTests( input_data, weight_data, &bias_data, &conv_state_data, @@ -436,10 +436,10 @@ TEST(CausalConv1DWithStateTest, SingleTokenDecode) { // State: (1, 4, 3) - carrying the last 3 values per channel std::vector conv_state_data = { - 1.0f, 2.0f, 3.0f, // ch 0 - -1.0f, 0.0f, 1.0f, // ch 1 - 0.5f, 0.5f, 0.5f, // ch 2 - -0.2f, 0.4f, -0.6f}; // ch 3 + 1.0f, 2.0f, 3.0f, // ch 0 + -1.0f, 0.0f, 1.0f, // ch 1 + 0.5f, 0.5f, 0.5f, // ch 2 + -0.2f, 0.4f, -0.6f}; // ch 3 RunCausalConv1DWithStateTests( input_data, weight_data, &bias_data, &conv_state_data, @@ -451,10 +451,10 @@ TEST(CausalConv1DWithStateTest, SingleTokenDecodeMultiBatch) { // Input: (2, 2, 1) std::vector input_data = { - 0.5f, // B0, ch 0 - -0.3f, // B0, ch 1 - 1.2f, // B1, ch 0 - 0.8f}; // B1, ch 1 + 0.5f, // B0, ch 0 + -0.3f, // B0, ch 1 + 1.2f, // B1, ch 0 + 0.8f}; // B1, ch 1 std::vector weight_data = { 0.1f, 0.2f, 0.3f, @@ -462,10 +462,10 @@ TEST(CausalConv1DWithStateTest, SingleTokenDecodeMultiBatch) { // State: (2, 2, 2) std::vector conv_state_data = { - 1.0f, 2.0f, // B0, ch 0 - -1.0f, 0.0f, // B0, ch 1 - 0.5f, 0.5f, // B1, ch 0 - -0.2f, 0.4f}; // B1, ch 1 + 1.0f, 2.0f, // B0, ch 0 + -1.0f, 0.0f, // B0, ch 1 + 0.5f, 0.5f, // B1, ch 0 + -0.2f, 0.4f}; // B1, ch 1 RunCausalConv1DWithStateTests( input_data, weight_data, nullptr, &conv_state_data, @@ -566,8 +566,8 @@ TEST(CausalConv1DWithStateTest, SequenceVsTokenByToken) { for (int t = 0; t < 4; ++t) { // Extract single token: (1, 2, 1) std::vector token_input = { - full_input[0 * 4 + t], // ch 0 - full_input[1 * 4 + t]}; // ch 1 + full_input[0 * 4 + t], // ch 0 + full_input[1 * 4 + t]}; // ch 1 std::vector token_output; std::vector next_state; diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc index e17031a98b300..6b80a55c9ecec 100644 --- a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc @@ -366,21 +366,19 @@ TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_Linear_Basic) { // Query: (1, 2, 1, 4) std::vector query_data = { - 0.5f, 0.3f, -0.2f, 0.1f, // head 0 - -0.4f, 0.6f, 0.2f, -0.3f // head 1 + 0.5f, 0.3f, -0.2f, 0.1f, // head 0 + -0.4f, 0.6f, 0.2f, -0.3f // head 1 }; // Key: (1, 2, 1, 4) std::vector key_data = { 0.1f, 0.2f, 0.3f, 0.4f, - 0.2f, -0.1f, 0.3f, 0.1f - }; + 0.2f, -0.1f, 0.3f, 0.1f}; // Value: (1, 2, 1, 4) std::vector value_data = { 0.4f, 0.3f, 0.2f, 0.1f, - -0.2f, 0.4f, 0.1f, 0.3f - }; + -0.2f, 0.4f, 0.1f, 0.3f}; // Past state: (1, 2, 4, 4) - initialized to small values std::vector past_state_data(batch_size * num_heads * head_dim_k * head_dim_v, 0.1f); @@ -400,26 +398,22 @@ TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_Gated_Basic) { std::vector query_data = { 0.5f, 0.3f, -0.2f, 0.1f, - -0.4f, 0.6f, 0.2f, -0.3f - }; + -0.4f, 0.6f, 0.2f, -0.3f}; std::vector key_data = { 0.1f, 0.2f, 0.3f, 0.4f, - 0.2f, -0.1f, 0.3f, 0.1f - }; + 0.2f, -0.1f, 0.3f, 0.1f}; std::vector value_data = { 0.4f, 0.3f, 0.2f, 0.1f, - -0.2f, 0.4f, 0.1f, 0.3f - }; + -0.2f, 0.4f, 0.1f, 0.3f}; std::vector past_state_data(batch_size * num_heads * head_dim_k * head_dim_v, 0.1f); // Decay: (1, 2, 1, 4) - negative values for decay std::vector decay_data = { -0.1f, -0.1f, -0.1f, -0.1f, - -0.2f, -0.2f, -0.2f, -0.2f - }; + -0.2f, -0.2f, -0.2f, -0.2f}; RunLinearAttentionRecurrentTests( query_data, key_data, value_data, past_state_data, @@ -436,19 +430,17 @@ TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_Delta_Basic) { std::vector query_data = { 0.5f, 0.3f, -0.2f, 0.1f, - -0.4f, 0.6f, 0.2f, -0.3f - }; + -0.4f, 0.6f, 0.2f, -0.3f}; // L2-normalized keys for delta rule std::vector key_data = { 0.1826f, 0.3651f, 0.5477f, 0.7303f, // normalized - 0.5345f, -0.2673f, 0.8018f, 0.2673f // normalized + 0.5345f, -0.2673f, 0.8018f, 0.2673f // normalized }; std::vector value_data = { 0.4f, 0.3f, 0.2f, 0.1f, - -0.2f, 0.4f, 0.1f, 0.3f - }; + -0.2f, 0.4f, 0.1f, 0.3f}; std::vector past_state_data(batch_size * num_heads * head_dim_k * head_dim_v, 0.1f); @@ -470,27 +462,23 @@ TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_GatedDelta_Basic) { std::vector query_data = { 0.5f, 0.3f, -0.2f, 0.1f, - -0.4f, 0.6f, 0.2f, -0.3f - }; + -0.4f, 0.6f, 0.2f, -0.3f}; // L2-normalized keys std::vector key_data = { 0.1826f, 0.3651f, 0.5477f, 0.7303f, - 0.5345f, -0.2673f, 0.8018f, 0.2673f - }; + 0.5345f, -0.2673f, 0.8018f, 0.2673f}; std::vector value_data = { 0.4f, 0.3f, 0.2f, 0.1f, - -0.2f, 0.4f, 0.1f, 0.3f - }; + -0.2f, 0.4f, 0.1f, 0.3f}; std::vector past_state_data(batch_size * num_heads * head_dim_k * head_dim_v, 0.1f); // Decay: (1, 2, 1, 4) std::vector decay_data = { -0.1f, -0.1f, -0.1f, -0.1f, - -0.2f, -0.2f, -0.2f, -0.2f - }; + -0.2f, -0.2f, -0.2f, -0.2f}; // Beta: (1, 2, 1, 1) std::vector beta_data = {0.5f, 0.7f}; From 8e09ff37f88d1559789b8c80b27c345a894b95cc Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 16 Mar 2026 21:49:21 -0700 Subject: [PATCH 12/27] webgpu LinearAttention --- .../webgpu/bert/linear_attention.cc | 1031 ++++----------- .../webgpu/bert/linear_attention.h | 118 +- .../webgpu/webgpu_contrib_kernels.cc | 6 +- .../core/graph/contrib_ops/bert_defs.cc | 133 +- onnxruntime/core/graph/contrib_ops/ms_opset.h | 6 +- .../core/providers/webgpu/tensor/transpose.cc | 11 +- .../webgpu/webgpu_execution_provider.cc | 6 +- .../core/providers/webgpu/webgpu_kernel.h | 22 + .../contrib_ops/linear_attention_op_test.cc | 1108 ++++++----------- 9 files changed, 764 insertions(+), 1677 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index 973bce0f5dda3..e94fba5041385 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -27,799 +27,300 @@ LinearAttentionUpdateRule ParseUpdateRule(const std::string& rule_str) { } // ============================================================================= -// LinearAttentionRecurrent Implementation +// LinearAttention Shader Implementation // ============================================================================= - -ONNX_OPERATOR_KERNEL_EX( - LinearAttentionRecurrent, - kMSDomain, - 1, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T", WebGpuSupportedFloatTypes()), - LinearAttentionRecurrent); - -LinearAttentionRecurrent::LinearAttentionRecurrent(const OpKernelInfo& info) - : WebGpuKernel(info) { - std::string update_rule_str = info.GetAttrOrDefault("update_rule", "gated_delta"); - update_rule_ = ParseUpdateRule(update_rule_str); - scale_ = info.GetAttrOrDefault("scale", 0.0f); - chunk_size_ = info.GetAttrOrDefault("chunk_size", 64); -} - -Status LinearAttentionRecurrentProgram::GenerateShaderCode(ShaderHelper& shader) const { - // Input tensors - with proper accessor methods and element type alias for scaling - const auto& query = shader.AddInput("query", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - const auto& key = shader.AddInput("key", ShaderUsage::UseUniform); - const auto& value = shader.AddInput("value", ShaderUsage::UseUniform); - const auto& past_state = shader.AddInput("past_state", ShaderUsage::UseUniform); - - // Optional inputs based on update rule - const ShaderVariableHelper* decay_ptr = nullptr; - const ShaderVariableHelper* beta_ptr = nullptr; +// +// Design overview: +// - Each workgroup handles one (batch, head, dv_tile) combination +// - Workgroup size = head_dim_k (dk): one thread per state row +// - Each thread maintains TILE_V columns of its state row in private memory +// - Tokens are processed sequentially; matrix ops are parallelized across threads +// - Reductions across dk (for S^T @ k and S^T @ q) use shared memory +// + +Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { + // Add inputs + shader.AddInput("query", ShaderUsage::UseUniform); + shader.AddInput("key", ShaderUsage::UseUniform); + shader.AddInput("value", ShaderUsage::UseUniform); + if (has_initial_state_) { + shader.AddInput("initial_state", ShaderUsage::UseUniform); + } if (has_decay_) { - decay_ptr = &shader.AddInput("decay", ShaderUsage::UseUniform); + shader.AddInput("decay", ShaderUsage::UseUniform); } if (has_beta_) { - beta_ptr = &shader.AddInput("beta", ShaderUsage::UseUniform); - } - - // Output tensors - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); - const auto& present_state = shader.AddOutput("present_state", ShaderUsage::UseUniform); - - // Each workgroup handles one (batch, head) pair - // Within the workgroup, we compute the state update and output - shader.MainFunctionBody() << R"SHADER( - let batch_idx = workgroup_id.x; - let head_idx = workgroup_id.y; - let local_k = local_id.x; - let local_v = local_id.y; - - // Bounds check - if (batch_idx >= uniforms.batch_size || head_idx >= uniforms.num_heads) { - return; - } - - let head_dim_k = uniforms.head_dim_k; - let head_dim_v = uniforms.head_dim_v; - // Cast scale factor to element type to match tensor data type - let scale_factor = query_element_t(select(1.0 / sqrt(f32(head_dim_k)), uniforms.scale, uniforms.scale != 0.0)); - - // Compute base offsets - let qkv_base = (batch_idx * uniforms.num_heads + head_idx) * head_dim_k; - let v_base = (batch_idx * uniforms.num_heads + head_idx) * head_dim_v; - let state_base = (batch_idx * uniforms.num_heads + head_idx) * head_dim_k * head_dim_v; - - // Process state update for this (k, v) element - if (local_k < head_dim_k && local_v < head_dim_v) { - let state_idx = state_base + local_k * head_dim_v + local_v; - - // Load current state value -)SHADER"; - - shader.MainFunctionBody() << " var state_val = " << past_state.GetByOffset("state_idx") << ";\n"; - - // Load k and v values - shader.MainFunctionBody() << " let k_val = " << key.GetByOffset("qkv_base + local_k") << ";\n"; - shader.MainFunctionBody() << " let v_val = " << value.GetByOffset("v_base + local_v") << ";\n"; - - // Apply decay if needed (gated or gated_delta) - if (update_rule_ == LinearAttentionUpdateRule::Gated || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { - shader.MainFunctionBody() << " // Load decay and compute exp(decay) - decay is in log space\n"; - shader.MainFunctionBody() << " let decay_val = " << decay_ptr->GetByOffset("qkv_base + local_k") << ";\n"; - shader.MainFunctionBody() << " let exp_decay = exp(decay_val);\n"; - shader.MainFunctionBody() << " state_val = state_val * exp_decay;\n"; - } - - // Compute the update delta based on update rule - if (update_rule_ == LinearAttentionUpdateRule::Linear) { - shader.MainFunctionBody() << R"SHADER( - // Linear update: S += k ⊗ v - let update = k_val * v_val; - state_val = state_val + update; -)SHADER"; - } else if (update_rule_ == LinearAttentionUpdateRule::Gated) { - shader.MainFunctionBody() << R"SHADER( - // Gated update: S = exp(g) * S + k ⊗ v (decay already applied above) - let update = k_val * v_val; - state_val = state_val + update; -)SHADER"; - } else if (update_rule_ == LinearAttentionUpdateRule::Delta) { - // Delta update requires computing retrieved = S^T @ k - shader.MainFunctionBody() << " // Delta update: S += β * k ⊗ (v - S^T k)\n"; - shader.MainFunctionBody() << " var retrieved = " << past_state.GetByOffset("state_base + 0u * head_dim_v + local_v") - << " * " << key.GetByOffset("qkv_base + 0u") << ";\n"; - shader.MainFunctionBody() << " for (var k_i: u32 = 1u; k_i < head_dim_k; k_i = k_i + 1u) {\n"; - shader.MainFunctionBody() << " let s_idx = state_base + k_i * head_dim_v + local_v;\n"; - shader.MainFunctionBody() << " retrieved = retrieved + " << past_state.GetByOffset("s_idx") - << " * " << key.GetByOffset("qkv_base + k_i") << ";\n"; - shader.MainFunctionBody() << " }\n"; - shader.MainFunctionBody() << " let beta_val = " << beta_ptr->GetByOffset("(batch_idx * uniforms.num_heads + head_idx)") << ";\n"; - shader.MainFunctionBody() << " let delta = beta_val * (v_val - retrieved);\n"; - shader.MainFunctionBody() << " let update = k_val * delta;\n"; - shader.MainFunctionBody() << " state_val = state_val + update;\n"; - } else { // GatedDelta - // Gated Delta update - shader.MainFunctionBody() << " // Gated Delta update: S = exp(g) * S + β * k ⊗ (v - exp(g) * S^T k)\n"; - shader.MainFunctionBody() << " var retrieved = " << past_state.GetByOffset("state_base + 0u * head_dim_v + local_v") - << " * exp(" << decay_ptr->GetByOffset("qkv_base + 0u") << ")" - << " * " << key.GetByOffset("qkv_base + 0u") << ";\n"; - shader.MainFunctionBody() << " for (var k_i: u32 = 1u; k_i < head_dim_k; k_i = k_i + 1u) {\n"; - shader.MainFunctionBody() << " let s_idx = state_base + k_i * head_dim_v + local_v;\n"; - shader.MainFunctionBody() << " let decay_k = " << decay_ptr->GetByOffset("qkv_base + k_i") << ";\n"; - shader.MainFunctionBody() << " retrieved = retrieved + " << past_state.GetByOffset("s_idx") - << " * exp(decay_k) * " << key.GetByOffset("qkv_base + k_i") << ";\n"; - shader.MainFunctionBody() << " }\n"; - shader.MainFunctionBody() << " let beta_val = " << beta_ptr->GetByOffset("(batch_idx * uniforms.num_heads + head_idx)") << ";\n"; - shader.MainFunctionBody() << " let delta = beta_val * (v_val - retrieved);\n"; - shader.MainFunctionBody() << " let update = k_val * delta;\n"; - shader.MainFunctionBody() << " state_val = state_val + update;\n"; + shader.AddInput("beta", ShaderUsage::UseUniform); } - // Write updated state and compute output - shader.MainFunctionBody() << " // Write updated state\n"; - shader.MainFunctionBody() << " " << present_state.SetByOffset("state_idx", "state_val") << "\n"; - shader.MainFunctionBody() << " }\n"; - - shader.MainFunctionBody() << R"SHADER( - // Synchronize before computing output - workgroupBarrier(); - - // Compute output: o = scale * q^T @ S - // Each thread computes one element of the output - if (local_k == 0u && local_v < head_dim_v) { -)SHADER"; - - shader.MainFunctionBody() << " var out_val = " << query.GetByOffset("qkv_base + 0u") - << " * " << present_state.GetByOffset("state_base + 0u * head_dim_v + local_v") << ";\n"; - shader.MainFunctionBody() << " for (var k_i: u32 = 1u; k_i < head_dim_k; k_i = k_i + 1u) {\n"; - shader.MainFunctionBody() << " let q_val = " << query.GetByOffset("qkv_base + k_i") << ";\n"; - shader.MainFunctionBody() << " let s_idx = state_base + k_i * head_dim_v + local_v;\n"; - shader.MainFunctionBody() << " out_val = out_val + q_val * " << present_state.GetByOffset("s_idx") << ";\n"; - shader.MainFunctionBody() << " }\n"; - shader.MainFunctionBody() << " " << output.SetByOffset("v_base + local_v", "out_val * scale_factor") << "\n"; - shader.MainFunctionBody() << " }\n"; - - return Status::OK(); -} - -Status LinearAttentionRecurrent::ComputeInternal(ComputeContext& context) const { - const auto* query = context.Input(0); - const auto* key = context.Input(1); - const auto* value = context.Input(2); - const auto* initial_state = context.Input(3); // past_state for recurrent, initial_state (optional) for chunk-parallel - const auto* decay = context.Input(4); // Optional - const auto* beta = context.Input(5); // Optional - - const auto& query_shape = query->Shape(); - ORT_ENFORCE(query_shape.NumDimensions() == 4, "Query must be 4D: (B, H, L, d_k)"); - - const auto batch_size = static_cast(query_shape[0]); - const auto num_heads = static_cast(query_shape[1]); - const auto seq_length = static_cast(query_shape[2]); - const auto head_dim_k = static_cast(query_shape[3]); - const auto head_dim_v = static_cast(value->Shape()[3]); - - bool has_initial_state = (initial_state != nullptr); - bool has_decay = (decay != nullptr); - bool has_beta = (beta != nullptr); + // Add outputs + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + shader.AddOutput("final_state", ShaderUsage::UseUniform); - // Validate decay and beta based on update rule + // Shared memory for parallel reduction across dk threads + // and for broadcasting delta values + // TILE_V is emitted as a compile-time constant (not overridable) because + // private address space arrays require fixed sizes in WGSL. + shader.AdditionalImplementation() + << "const TILE_V: u32 = " << tile_v_ << "u;\n" + << "var reduction_buf: array;\n" + << "var broadcast_buf: array;\n"; + + shader.MainFunctionBody() + // Identify which (batch, head, dv_tile) this workgroup handles + // workgroup_idx is already defined by the framework + << "let bh = workgroup_idx / uniforms.num_dv_tiles;\n" + << "let dv_tile_idx = workgroup_idx % uniforms.num_dv_tiles;\n" + << "let batch_idx = bh / uniforms.num_heads;\n" + << "let head_idx = bh % uniforms.num_heads;\n" + << "let dk_idx = local_idx; // thread index = row in state matrix\n" + << "let dv_start = dv_tile_idx * TILE_V;\n" + << "\n" + // Initialize state tile in private memory + << "var state: array;\n" + << "for (var j = 0u; j < TILE_V; j++) {\n" + << " state[j] = 0.0;\n" + << "}\n"; + + // Load initial state if provided + if (has_initial_state_) { + shader.MainFunctionBody() + << "// Load initial state: initial_state[batch, head, dk_idx, dv_start..dv_start+TILE_V]\n" + << "let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" + << "for (var j = 0u; j < TILE_V; j++) {\n" + << " if (dv_start + j < uniforms.head_dim_v) {\n" + << " state[j] = f32(initial_state[state_base + j]);\n" + << " }\n" + << "}\n"; + } + + // Main token processing loop + shader.MainFunctionBody() + << "\n// Process each token sequentially\n" + << "for (var t = 0u; t < uniforms.seq_length; t++) {\n" + // Load k and q for this thread's dk row + << " let qkv_bh_offset = (batch_idx * uniforms.num_heads + head_idx) * uniforms.seq_length;\n" + << " let k_base = (qkv_bh_offset + t) * uniforms.head_dim_k + dk_idx;\n" + << " let k_val = f32(key[k_base]);\n" + << " let q_val = f32(query[k_base]);\n"; + + // Step 1: Apply decay (for gated and gated_delta modes) if (update_rule_ == LinearAttentionUpdateRule::Gated || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { - ORT_ENFORCE(has_decay, "Decay input is required for gated and gated_delta update rules"); - } - if (update_rule_ == LinearAttentionUpdateRule::Delta || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { - ORT_ENFORCE(has_beta, "Beta input is required for delta and gated_delta update rules"); - } - - // seq_length == 1: single-step recurrent path - if (seq_length == 1) { - ORT_ENFORCE(has_initial_state, "past_state input is required for single-step recurrent mode"); - - TensorShape output_shape({static_cast(batch_size), static_cast(num_heads), 1, static_cast(head_dim_v)}); - auto* output = context.Output(0, output_shape); - auto* present_state = context.Output(1, initial_state->Shape()); - - LinearAttentionRecurrentProgram program{update_rule_, has_decay, has_beta}; - - program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, - {key, ProgramTensorMetadataDependency::TypeAndRank}, - {value, ProgramTensorMetadataDependency::TypeAndRank}, - {initial_state, ProgramTensorMetadataDependency::TypeAndRank}}); - - if (has_decay) { - program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); - } - if (has_beta) { - program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); - } - - program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, - {present_state, ProgramTensorMetadataDependency::TypeAndRank}}); - - const uint32_t workgroup_size_k = std::min(head_dim_k, 16u); - const uint32_t workgroup_size_v = std::min(head_dim_v, 16u); - - program.SetDispatchGroupSize(batch_size, num_heads, 1) - .SetWorkgroupSize(workgroup_size_k, workgroup_size_v, 1) - .AddUniformVariables({{batch_size}, - {num_heads}, - {head_dim_k}, - {head_dim_v}, - {scale_}}); - - return context.RunProgram(program); + shader.MainFunctionBody() + << "\n // Apply exponential decay: S *= exp(decay)\n" + << " let decay_base = (qkv_bh_offset + t) * uniforms.head_dim_k + dk_idx;\n" + << " let exp_g = exp(f32(decay[decay_base]));\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " state[j] *= exp_g;\n" + << " }\n"; } - // seq_length > 1: chunk-parallel path - const uint32_t chunk_size = static_cast(chunk_size_); - const uint32_t num_chunks = (seq_length + chunk_size - 1) / chunk_size; - - TensorShape output_shape({static_cast(batch_size), static_cast(num_heads), - static_cast(seq_length), static_cast(head_dim_v)}); - TensorShape state_shape({static_cast(batch_size), static_cast(num_heads), - static_cast(head_dim_k), static_cast(head_dim_v)}); - - auto* output = context.Output(0, output_shape); - auto* final_state = context.Output(1, state_shape); - - // For delta/gated_delta rules, use sequential computation. - // Chunk-parallel decomposition doesn't work because state updates depend on the - // running state through the S^T k term, making chunks non-independent. + // Step 2: For delta/gated_delta rules, compute retrieved = S^T @ k (reduction across dk) if (update_rule_ == LinearAttentionUpdateRule::Delta || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { - LinearAttentionFullSequentialProgram program{update_rule_, has_decay, has_beta, has_initial_state}; - - program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, - {key, ProgramTensorMetadataDependency::TypeAndRank}, - {value, ProgramTensorMetadataDependency::TypeAndRank}}); - - if (has_initial_state) { - program.AddInput({initial_state, ProgramTensorMetadataDependency::TypeAndRank}); - } - if (has_decay) { - program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); - } - if (has_beta) { - program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); - } - - program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, - {final_state, ProgramTensorMetadataDependency::TypeAndRank}}); - - program.SetDispatchGroupSize(batch_size, num_heads, 1) - .SetWorkgroupSize(1, 1, 1) - .AddUniformVariables({{batch_size}, - {num_heads}, - {seq_length}, - {head_dim_k}, - {head_dim_v}, - {scale_}}); - - return context.RunProgram(program); - } - - // Linear/Gated rules: Use two-phase chunk-parallel approach - TensorShape chunk_states_shape({static_cast(batch_size), static_cast(num_heads), - static_cast(num_chunks), static_cast(head_dim_k), - static_cast(head_dim_v)}); - - Tensor intra_output_tensor = context.CreateGPUTensor(query->DataType(), output_shape); - Tensor chunk_states_tensor = context.CreateGPUTensor(query->DataType(), chunk_states_shape); - - // Step 1: Compute intra-chunk attention and per-chunk states - { - LinearAttentionChunkIntraProgram intra_program{update_rule_, has_decay, has_beta}; - - intra_program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, - {key, ProgramTensorMetadataDependency::TypeAndRank}, - {value, ProgramTensorMetadataDependency::TypeAndRank}}); - - if (has_decay) { - intra_program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); - } - if (has_beta) { - intra_program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); - } - - intra_program.AddOutputs({{&intra_output_tensor, ProgramTensorMetadataDependency::TypeAndRank}, - {&chunk_states_tensor, ProgramTensorMetadataDependency::TypeAndRank}}); - - intra_program.SetDispatchGroupSize(batch_size, num_heads, num_chunks) - .SetWorkgroupSize(64, 1, 1) - .AddUniformVariables({{batch_size}, - {num_heads}, - {seq_length}, - {head_dim_k}, - {head_dim_v}, - {chunk_size}, - {num_chunks}, - {scale_}}); - - ORT_RETURN_IF_ERROR(context.RunProgram(intra_program)); - } + shader.MainFunctionBody() + << "\n // Compute retrieved = S^T @ k (parallel reduction over dk)\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * k_val;\n" + << " }\n" + << " workgroupBarrier();\n" + << " // Tree reduction\n" + << " for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) {\n" + << " if (dk_idx < stride) {\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride];\n" + << " }\n" + << " }\n" + << " workgroupBarrier();\n" + << " }\n" + // Thread 0 computes delta and broadcasts via shared memory + << " // Compute delta = beta * (v - retrieved) and broadcast\n" + << " let v_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.seq_length + t) * uniforms.head_dim_v + dv_start;\n" + << " let beta_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.seq_length + t);\n" + << " if (dk_idx == 0u) {\n" + << " let beta_val = f32(beta[beta_base]);\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " if (dv_start + j < uniforms.head_dim_v) {\n" + << " let retrieved_j = reduction_buf[j * workgroup_size_x];\n" + << " let v_val = f32(value[v_base + j]);\n" + << " broadcast_buf[j] = beta_val * (v_val - retrieved_j);\n" + << " }\n" + << " }\n" + << " }\n" + << " workgroupBarrier();\n" + // All threads update their state row using the broadcast delta + << " // Update state: S += k ⊗ delta\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " state[j] += k_val * broadcast_buf[j];\n" + << " }\n" + << " workgroupBarrier();\n"; + } else { + // For linear and gated modes: S += k ⊗ v (no delta rule) + shader.MainFunctionBody() + << "\n // Update state: S += k ⊗ v\n" + << " let v_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.seq_length + t) * uniforms.head_dim_v + dv_start;\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " if (dv_start + j < uniforms.head_dim_v) {\n" + << " let v_val = f32(value[v_base + j]);\n" + << " state[j] += k_val * v_val;\n" + << " }\n" + << " }\n"; + } + + // Step 3: Compute output = scale * S^T @ q (reduction across dk) + shader.MainFunctionBody() + << "\n // Compute output = scale * S^T @ q (parallel reduction over dk)\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * q_val;\n" + << " }\n" + << " workgroupBarrier();\n" + << " for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) {\n" + << " if (dk_idx < stride) {\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride];\n" + << " }\n" + << " }\n" + << " workgroupBarrier();\n" + << " }\n" + // Thread 0 writes the output for this token and dv_tile + << " if (dk_idx == 0u) {\n" + << " let out_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.seq_length + t) * uniforms.head_dim_v + dv_start;\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " if (dv_start + j < uniforms.head_dim_v) {\n" + << " output[out_base + j] = output_element_t(reduction_buf[j * workgroup_size_x] * uniforms.scale);\n" + << " }\n" + << " }\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; // end token loop - // Step 2: Inter-chunk state propagation and final output computation - { - LinearAttentionChunkInterProgram inter_program{update_rule_, has_decay, has_beta, has_initial_state}; - - inter_program.AddInputs({{&intra_output_tensor, ProgramTensorMetadataDependency::TypeAndRank}, - {&chunk_states_tensor, ProgramTensorMetadataDependency::TypeAndRank}, - {query, ProgramTensorMetadataDependency::TypeAndRank}}); - - if (has_initial_state) { - inter_program.AddInput({initial_state, ProgramTensorMetadataDependency::TypeAndRank}); - } - if (has_decay) { - inter_program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); - } - - inter_program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, - {final_state, ProgramTensorMetadataDependency::TypeAndRank}}); - - inter_program.SetDispatchGroupSize(batch_size, num_heads, 1) - .SetWorkgroupSize(256, 1, 1) - .AddUniformVariables({{batch_size}, - {num_heads}, - {seq_length}, - {head_dim_k}, - {head_dim_v}, - {chunk_size}, - {num_chunks}, - {scale_}}); - - ORT_RETURN_IF_ERROR(context.RunProgram(inter_program)); - } + // Write final state + shader.MainFunctionBody() + << "\n// Write final state\n" + << "let final_state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" + << "for (var j = 0u; j < TILE_V; j++) {\n" + << " if (dv_start + j < uniforms.head_dim_v) {\n" + << " final_state[final_state_base + j] = output_element_t(state[j]);\n" + << " }\n" + << "}\n"; return Status::OK(); } // ============================================================================= -// LinearAttentionChunkParallel Implementation +// LinearAttention Kernel Registration and Computation // ============================================================================= ONNX_OPERATOR_KERNEL_EX( - LinearAttentionChunkParallel, + LinearAttention, kMSDomain, 1, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", WebGpuSupportedFloatTypes()), - LinearAttentionChunkParallel); - -LinearAttentionChunkParallel::LinearAttentionChunkParallel(const OpKernelInfo& info) - : LinearAttentionRecurrent(info) { -} - -Status LinearAttentionChunkIntraProgram::GenerateShaderCode(ShaderHelper& shader) const { - // Inputs - referenced by name in WGSL shader - const auto& query = shader.AddInput("query", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - const auto& key = shader.AddInput("key", ShaderUsage::UseUniform); - const auto& value = shader.AddInput("value", ShaderUsage::UseUniform); - - std::string decay_name; - if (has_decay_) { - shader.AddInput("decay", ShaderUsage::UseUniform); - decay_name = "decay"; - } - if (has_beta_) { - shader.AddInput("beta", ShaderUsage::UseUniform); - } - - // Outputs - const auto& intra_output = shader.AddOutput("intra_output", ShaderUsage::UseUniform); - const auto& chunk_states = shader.AddOutput("chunk_states", ShaderUsage::UseUniform); - - // Compute intra-chunk causal attention - // For each position i in chunk, compute output using positions 0..i - shader.MainFunctionBody() << R"SHADER( - let batch_idx = workgroup_id.x; - let head_idx = workgroup_id.y; - let chunk_idx = workgroup_id.z; - - if (batch_idx >= uniforms.batch_size || head_idx >= uniforms.num_heads || chunk_idx >= uniforms.num_chunks) { - return; - } - - let head_dim_k = uniforms.head_dim_k; - let head_dim_v = uniforms.head_dim_v; - let chunk_size = uniforms.chunk_size; - let seq_len = uniforms.sequence_length; - let scale_factor = query_element_t(select(1.0 / sqrt(f32(head_dim_k)), uniforms.scale, uniforms.scale != 0.0)); - - // Chunk boundaries - let chunk_start = chunk_idx * chunk_size; - let chunk_end = min(chunk_start + chunk_size, seq_len); - let actual_chunk_size = chunk_end - chunk_start; - - // Base offsets - let bh_offset = batch_idx * uniforms.num_heads + head_idx; - - // Local thread handles one position in the chunk - let local_pos = local_id.x; - - if (local_pos < actual_chunk_size) { - let global_pos = chunk_start + local_pos; - - // Initialize local state for causal computation within chunk - // We need to accumulate state from positions 0..local_pos - let q_base = (bh_offset * seq_len + global_pos) * head_dim_k; - let out_base = (bh_offset * seq_len + global_pos) * head_dim_v; - - // Compute output for this position using causal mask within chunk - for (var v_i: u32 = 0u; v_i < head_dim_v; v_i = v_i + 1u) { - var out_val: query_element_t = query_element_t(0.0); - - // Accumulate contributions from positions 0 to local_pos (inclusive) - for (var src_pos: u32 = 0u; src_pos <= local_pos; src_pos = src_pos + 1u) { - let src_global = chunk_start + src_pos; - let k_base = (bh_offset * seq_len + src_global) * head_dim_k; - let v_base = (bh_offset * seq_len + src_global) * head_dim_v; - - // Compute q @ k^T for this position pair - var qk_dot: query_element_t = query_element_t(0.0); - for (var k_i: u32 = 0u; k_i < head_dim_k; k_i = k_i + 1u) { - qk_dot = qk_dot + )SHADER" << query.GetByOffset("q_base + k_i") << " * " << key.GetByOffset("k_base + k_i") << R"SHADER(; - } - - // For linear attention variants, we need to apply the appropriate weighting - let v_val = )SHADER" << value.GetByOffset("v_base + v_i") << R"SHADER(; -)SHADER"; - - // Apply decay-based weighting if needed - if (has_decay_) { - shader.MainFunctionBody() << R"SHADER( - // Compute cumulative decay from src_pos to local_pos - var cum_decay: query_element_t = query_element_t(0.0); - for (var d_pos: u32 = src_pos + 1u; d_pos <= local_pos; d_pos = d_pos + 1u) { - let d_global = chunk_start + d_pos; - // Average decay across k dimensions for simplicity - var avg_decay: query_element_t = query_element_t(0.0); - for (var k_i: u32 = 0u; k_i < head_dim_k; k_i = k_i + 1u) { - avg_decay = avg_decay + decay[(bh_offset * seq_len + d_global) * head_dim_k + k_i]; - } - cum_decay = cum_decay + avg_decay / query_element_t(head_dim_k); - } - let decay_weight = exp(cum_decay); - out_val = out_val + qk_dot * v_val * decay_weight; -)SHADER"; - } else { - shader.MainFunctionBody() << R"SHADER( - out_val = out_val + qk_dot * v_val; -)SHADER"; - } - - shader.MainFunctionBody() << R"SHADER( - } - - )SHADER" << intra_output.SetByOffset("out_base + v_i", "out_val * scale_factor") << R"SHADER(; - } - } + LinearAttention); - // Compute accumulated state at the end of this chunk - // Each thread contributes to building the chunk-end state - workgroupBarrier(); - - // Compute chunk-end state: accumulate k ⊗ v for all positions in chunk - let state_base = (bh_offset * uniforms.num_chunks + chunk_idx) * head_dim_k * head_dim_v; - - for (var k_i: u32 = local_id.x; k_i < head_dim_k; k_i = k_i + 64u) { - for (var v_i: u32 = 0u; v_i < head_dim_v; v_i = v_i + 1u) { - var state_val: query_element_t = query_element_t(0.0); - - for (var pos: u32 = 0u; pos < actual_chunk_size; pos = pos + 1u) { - let global_pos = chunk_start + pos; - let k_base = (bh_offset * seq_len + global_pos) * head_dim_k; - let v_base = (bh_offset * seq_len + global_pos) * head_dim_v; - - let k_val = )SHADER" << key.GetByOffset("k_base + k_i") << R"SHADER(; - let v_val = )SHADER" << value.GetByOffset("v_base + v_i") << R"SHADER(; -)SHADER"; - - if (has_decay_) { - shader.MainFunctionBody() << R"SHADER( - // Decay from this position to chunk end - var decay_to_end: query_element_t = query_element_t(0.0); - for (var d_pos: u32 = pos + 1u; d_pos < actual_chunk_size; d_pos = d_pos + 1u) { - let d_global = chunk_start + d_pos; - decay_to_end = decay_to_end + decay[(bh_offset * seq_len + d_global) * head_dim_k + k_i]; - } - state_val = state_val + k_val * v_val * exp(decay_to_end); -)SHADER"; - } else { - shader.MainFunctionBody() << R"SHADER( - state_val = state_val + k_val * v_val; -)SHADER"; - } - - shader.MainFunctionBody() << R"SHADER( - } - - let state_idx = state_base + k_i * head_dim_v + v_i; - )SHADER" << chunk_states.SetByOffset("state_idx", "state_val") << R"SHADER(; - } - } -)SHADER"; - - return Status::OK(); -} - -Status LinearAttentionChunkInterProgram::GenerateShaderCode(ShaderHelper& shader) const { - // Inputs - referenced by name in WGSL shader - shader.AddInput("intra_output", ShaderUsage::UseUniform); - shader.AddInput("chunk_states", ShaderUsage::UseUniform); - shader.AddInput("query", ShaderUsage::UseUniform); - - if (has_initial_state_) { - shader.AddInput("initial_state", ShaderUsage::UseUniform); - } - if (has_decay_) { - shader.AddInput("decay", ShaderUsage::UseUniform); - } - - // Outputs - referenced by name in WGSL shader - shader.AddOutput("output", ShaderUsage::UseUniform); - shader.AddOutput("final_state", ShaderUsage::UseUniform); - - // Propagate state between chunks and compute final output - shader.MainFunctionBody() << R"SHADER( - let batch_idx = workgroup_id.x; - let head_idx = workgroup_id.y; - - if (batch_idx >= uniforms.batch_size || head_idx >= uniforms.num_heads) { - return; - } - - let head_dim_k = uniforms.head_dim_k; - let head_dim_v = uniforms.head_dim_v; - let chunk_size = uniforms.chunk_size; - let num_chunks = uniforms.num_chunks; - let seq_len = uniforms.sequence_length; - let scale = select(1.0 / sqrt(f32(head_dim_k)), uniforms.scale, uniforms.scale != 0.0); - - let bh_offset = batch_idx * uniforms.num_heads + head_idx; - - // Process each sequence position - let pos = local_id.x; - if (pos < seq_len) { - let chunk_idx = pos / chunk_size; - let q_base = (bh_offset * seq_len + pos) * head_dim_k; - let out_base = (bh_offset * seq_len + pos) * head_dim_v; - - // Start with intra-chunk output - for (var v_i: u32 = 0u; v_i < head_dim_v; v_i = v_i + 1u) { - var out_val = intra_output[out_base + v_i]; - - // Add contribution from previous chunks' accumulated state - // This is q^T @ (sum of states from chunks 0 to chunk_idx-1) - for (var prev_chunk: u32 = 0u; prev_chunk < chunk_idx; prev_chunk = prev_chunk + 1u) { - let state_base = (bh_offset * num_chunks + prev_chunk) * head_dim_k * head_dim_v; - - for (var k_i: u32 = 0u; k_i < head_dim_k; k_i = k_i + 1u) { - let q_val = query[q_base + k_i]; - let state_val = chunk_states[state_base + k_i * head_dim_v + v_i]; -)SHADER"; - - if (has_decay_) { - shader.MainFunctionBody() << R"SHADER( - // Compute cumulative decay from end of prev_chunk to current position - var cum_decay: f32 = 0.0; - let prev_chunk_end = (prev_chunk + 1u) * chunk_size; - for (var d_pos: u32 = prev_chunk_end; d_pos <= pos; d_pos = d_pos + 1u) { - cum_decay = cum_decay + decay[(bh_offset * seq_len + d_pos) * head_dim_k + k_i]; - } - out_val = out_val + q_val * state_val * exp(cum_decay) * scale; -)SHADER"; - } else { - shader.MainFunctionBody() << R"SHADER( - out_val = out_val + q_val * state_val * scale; -)SHADER"; - } - - shader.MainFunctionBody() << R"SHADER( - } - } -)SHADER"; - - if (has_initial_state_) { - shader.MainFunctionBody() << R"SHADER( - // Add contribution from initial state - let init_state_base = bh_offset * head_dim_k * head_dim_v; - for (var k_i: u32 = 0u; k_i < head_dim_k; k_i = k_i + 1u) { - let q_val = query[q_base + k_i]; - let state_val = initial_state[init_state_base + k_i * head_dim_v + v_i]; -)SHADER"; - if (has_decay_) { - shader.MainFunctionBody() << R"SHADER( - // Decay from start to current position - var cum_decay: f32 = 0.0; - for (var d_pos: u32 = 0u; d_pos <= pos; d_pos = d_pos + 1u) { - cum_decay = cum_decay + decay[(bh_offset * seq_len + d_pos) * head_dim_k + k_i]; - } - out_val = out_val + q_val * state_val * exp(cum_decay) * scale; -)SHADER"; - } else { - shader.MainFunctionBody() << R"SHADER( - out_val = out_val + q_val * state_val * scale; -)SHADER"; - } - shader.MainFunctionBody() << R"SHADER( - } -)SHADER"; - } - - shader.MainFunctionBody() << R"SHADER( - output[out_base + v_i] = out_val; - } - } - - // Compute final state: sum all chunk states with appropriate decay - workgroupBarrier(); - - let final_state_base = bh_offset * head_dim_k * head_dim_v; - for (var idx: u32 = local_id.x; idx < head_dim_k * head_dim_v; idx = idx + 256u) { - let k_i = idx / head_dim_v; - let v_i = idx % head_dim_v; - - var state_val: f32 = 0.0; -)SHADER"; - - if (has_initial_state_) { - shader.MainFunctionBody() << R"SHADER( - // Start with initial state - let init_state_base = bh_offset * head_dim_k * head_dim_v; - state_val = initial_state[init_state_base + idx]; -)SHADER"; - if (has_decay_) { - shader.MainFunctionBody() << R"SHADER( - // Decay initial state through entire sequence - var total_decay: f32 = 0.0; - for (var d_pos: u32 = 0u; d_pos < seq_len; d_pos = d_pos + 1u) { - total_decay = total_decay + decay[(bh_offset * seq_len + d_pos) * head_dim_k + k_i]; - } - state_val = state_val * exp(total_decay); -)SHADER"; - } - } - - shader.MainFunctionBody() << R"SHADER( - // Accumulate all chunk states - for (var c: u32 = 0u; c < num_chunks; c = c + 1u) { - let chunk_state_base = (bh_offset * num_chunks + c) * head_dim_k * head_dim_v; - var chunk_val = chunk_states[chunk_state_base + idx]; -)SHADER"; - - if (has_decay_) { - shader.MainFunctionBody() << R"SHADER( - // Decay this chunk's state to end of sequence - let chunk_end = min((c + 1u) * chunk_size, seq_len); - var decay_to_end: f32 = 0.0; - for (var d_pos: u32 = chunk_end; d_pos < seq_len; d_pos = d_pos + 1u) { - decay_to_end = decay_to_end + decay[(bh_offset * seq_len + d_pos) * head_dim_k + k_i]; - } - chunk_val = chunk_val * exp(decay_to_end); -)SHADER"; - } - - shader.MainFunctionBody() << R"SHADER( - state_val = state_val + chunk_val; - } - - final_state[final_state_base + idx] = state_val; - } -)SHADER"; - - return Status::OK(); +LinearAttention::LinearAttention(const OpKernelInfo& info) + : WebGpuKernel(info) { + std::string update_rule_str = info.GetAttrOrDefault("update_rule", "gated_delta"); + update_rule_ = ParseUpdateRule(update_rule_str); + scale_ = info.GetAttrOrDefault("scale", 0.0f); + chunk_size_ = info.GetAttrOrDefault("chunk_size", 64); } -Status LinearAttentionFullSequentialProgram::GenerateShaderCode(ShaderHelper& shader) const { - // Full sequential computation for delta/gated_delta update rules. - // These rules have state updates that depend on the current state (S^T k term), - // making chunk-parallel decomposition incorrect. - shader.AddInput("query", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - shader.AddInput("key", ShaderUsage::UseUniform); - shader.AddInput("value", ShaderUsage::UseUniform); - - if (has_initial_state_) { - shader.AddInput("initial_state", ShaderUsage::UseUniform); - } - if (has_decay_) { - shader.AddInput("decay", ShaderUsage::UseUniform); - } - if (has_beta_) { - shader.AddInput("beta", ShaderUsage::UseUniform); - } - - shader.AddOutput("output", ShaderUsage::UseUniform); - shader.AddOutput("final_state", ShaderUsage::UseUniform); - - shader.MainFunctionBody() << R"SHADER( - let batch_idx = workgroup_id.x; - let head_idx = workgroup_id.y; - - if (batch_idx >= uniforms.batch_size || head_idx >= uniforms.num_heads) { - return; - } - - let dk = uniforms.head_dim_k; - let dv = uniforms.head_dim_v; - let seq_len = uniforms.sequence_length; - let scale_val = query_element_t(select(1.0 / sqrt(f32(dk)), uniforms.scale, uniforms.scale != 0.0)); - let bh = batch_idx * uniforms.num_heads + head_idx; - let state_size = dk * dv; - - // Initialize state array (supports up to 32x32 head dimensions) - var state: array; - for (var i = 0u; i < state_size; i = i + 1u) { - state[i] = query_element_t(0.0); - } -)SHADER"; - - if (has_initial_state_) { - shader.MainFunctionBody() << R"SHADER( - // Load initial state - let init_base = bh * state_size; - for (var i = 0u; i < state_size; i = i + 1u) { - state[i] = query_element_t(initial_state[init_base + i]); - } -)SHADER"; - } - - shader.MainFunctionBody() << R"SHADER( - // Process each timestep sequentially - for (var t = 0u; t < seq_len; t = t + 1u) { - let qk_base = (bh * seq_len + t) * dk; - let v_base = (bh * seq_len + t) * dv; -)SHADER"; - - if (has_decay_) { - shader.MainFunctionBody() << R"SHADER( - // Apply decay: state *= exp(decay) - for (var ki = 0u; ki < dk; ki = ki + 1u) { - let exp_g = query_element_t(exp(decay[qk_base + ki])); - for (var vi = 0u; vi < dv; vi = vi + 1u) { - state[ki * dv + vi] = state[ki * dv + vi] * exp_g; - } - } -)SHADER"; - } - - shader.MainFunctionBody() << R"SHADER( - // Delta update: S += beta * k \u2297 (v - S^T k) - let beta_val = query_element_t(beta[bh * seq_len + t]); - for (var vi = 0u; vi < dv; vi = vi + 1u) { - // Compute retrieved = S^T @ k for this v dimension - var retrieved = query_element_t(0.0); - for (var ki = 0u; ki < dk; ki = ki + 1u) { - retrieved = retrieved + state[ki * dv + vi] * query_element_t(key[qk_base + ki]); - } - let v_val = query_element_t(value[v_base + vi]); - let delta_val = beta_val * (v_val - retrieved); - - for (var ki = 0u; ki < dk; ki = ki + 1u) { - state[ki * dv + vi] = state[ki * dv + vi] + query_element_t(key[qk_base + ki]) * delta_val; - } - } - - // Compute output: o = scale * q^T @ state - let out_base = (bh * seq_len + t) * dv; - for (var vi = 0u; vi < dv; vi = vi + 1u) { - var out_val = query_element_t(0.0); - for (var ki = 0u; ki < dk; ki = ki + 1u) { - out_val = out_val + query_element_t(query[qk_base + ki]) * state[ki * dv + vi]; - } - output[out_base + vi] = out_val * scale_val; - } - } - - // Write final state - let final_base = bh * state_size; - for (var i = 0u; i < state_size; i = i + 1u) { - final_state[final_base + i] = state[i]; - } -)SHADER"; - - return Status::OK(); +Status LinearAttention::ComputeInternal(ComputeContext& context) const { + const Tensor* query = context.Input(0); + const Tensor* key = context.Input(1); + const Tensor* value = context.Input(2); + const Tensor* initial_state = context.Input(3); // optional + const Tensor* decay = context.Input(4); // optional + const Tensor* beta = context.Input(5); // optional + + // Validate inputs + const auto& q_shape = query->Shape(); + ORT_RETURN_IF(q_shape.NumDimensions() != 4, "query must be 4D (B, H, T, dk)"); + + const int batch_size = static_cast(q_shape[0]); + const int num_heads = static_cast(q_shape[1]); + const int seq_length = static_cast(q_shape[2]); + const int head_dim_k = static_cast(q_shape[3]); + const int head_dim_v = static_cast(value->Shape()[3]); + + // Validate update rule has required inputs + bool needs_decay = (update_rule_ == LinearAttentionUpdateRule::Gated || + update_rule_ == LinearAttentionUpdateRule::GatedDelta); + bool needs_beta = (update_rule_ == LinearAttentionUpdateRule::Delta || + update_rule_ == LinearAttentionUpdateRule::GatedDelta); + ORT_RETURN_IF(needs_decay && decay == nullptr, "decay input required for gated/gated_delta update rules"); + ORT_RETURN_IF(needs_beta && beta == nullptr, "beta input required for delta/gated_delta update rules"); + + // Compute scale + float scale = scale_; + if (scale == 0.0f) { + scale = 1.0f / std::sqrt(static_cast(head_dim_k)); + } + + // Allocate outputs + TensorShapeVector output_shape({batch_size, num_heads, seq_length, head_dim_v}); + Tensor* output = context.Output(0, output_shape); + + TensorShapeVector state_shape({batch_size, num_heads, head_dim_k, head_dim_v}); + Tensor* final_state = context.Output(1, state_shape); + + // Choose tile size: balance parallelism vs shared memory + // TILE_V * WORKGROUP_SIZE * 4 bytes must fit in shared memory (typically 16KB limit) + // E.g., TILE_V=4, WORKGROUP_SIZE=128: 4*128*4 = 2048 bytes + int tile_v = 4; + if (head_dim_v <= 4) { + tile_v = head_dim_v; + } + const int num_dv_tiles = (head_dim_v + tile_v - 1) / tile_v; + + // Workgroup size = head_dim_k (one thread per dk row) + // Ensure it's a power of 2 for tree reduction (round up) + uint32_t workgroup_size = 1; + while (workgroup_size < static_cast(head_dim_k)) { + workgroup_size *= 2; + } + // Cap at GPU limits + workgroup_size = std::min(workgroup_size, static_cast(256)); + + const uint32_t num_workgroups = batch_size * num_heads * num_dv_tiles; + + bool has_initial_state = initial_state != nullptr; + bool has_decay = decay != nullptr; + bool has_beta = beta != nullptr; + + LinearAttentionProgram program{update_rule_, has_initial_state, has_decay, has_beta, tile_v}; + + program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, + {key, ProgramTensorMetadataDependency::TypeAndRank}, + {value, ProgramTensorMetadataDependency::TypeAndRank}}); + if (has_initial_state) { + program.AddInput({initial_state, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_decay) { + program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_beta) { + program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, + {final_state, ProgramTensorMetadataDependency::TypeAndRank}}); + + program.SetDispatchGroupSize(num_workgroups) + .SetWorkgroupSize(workgroup_size) + .CacheHint(std::to_string(static_cast(update_rule_)), + has_initial_state, has_decay, has_beta, tile_v) + .AddUniformVariables({{static_cast(batch_size)}, + {static_cast(num_heads)}, + {static_cast(seq_length)}, + {static_cast(head_dim_k)}, + {static_cast(head_dim_v)}, + {scale}, + {static_cast(num_dv_tiles)}}); + + return context.RunProgram(program); } } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h index b807e729b6c0d..a55a4d1801deb 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h @@ -25,34 +25,43 @@ enum class LinearAttentionUpdateRule { LinearAttentionUpdateRule ParseUpdateRule(const std::string& rule_str); -// Program for LinearAttentionRecurrent (single-token decode) -class LinearAttentionRecurrentProgram final : public Program { +// WebGPU program for the fused linear attention kernel. +// Each workgroup processes one (batch, head, dv_tile) combination. +// Threads within a workgroup (one per dk row) cooperate on reductions. +class LinearAttentionProgram final : public Program { public: - LinearAttentionRecurrentProgram(LinearAttentionUpdateRule update_rule, bool has_decay, bool has_beta) - : Program{"LinearAttentionRecurrent"}, + LinearAttentionProgram(LinearAttentionUpdateRule update_rule, bool has_initial_state, + bool has_decay, bool has_beta, int tile_v) + : Program{"LinearAttention"}, update_rule_(update_rule), + has_initial_state_(has_initial_state), has_decay_(has_decay), - has_beta_(has_beta) {} + has_beta_(has_beta), + tile_v_(tile_v) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"batch_size", ProgramUniformVariableDataType::Uint32}, {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"seq_length", ProgramUniformVariableDataType::Uint32}, {"head_dim_k", ProgramUniformVariableDataType::Uint32}, {"head_dim_v", ProgramUniformVariableDataType::Uint32}, - {"scale", ProgramUniformVariableDataType::Float32}); + {"scale", ProgramUniformVariableDataType::Float32}, + {"num_dv_tiles", ProgramUniformVariableDataType::Uint32}); private: LinearAttentionUpdateRule update_rule_; + bool has_initial_state_; bool has_decay_; bool has_beta_; + int tile_v_; }; -// Kernel for LinearAttentionRecurrent -class LinearAttentionRecurrent : public WebGpuKernel { +// Kernel for LinearAttention +class LinearAttention : public WebGpuKernel { public: - LinearAttentionRecurrent(const OpKernelInfo& info); + LinearAttention(const OpKernelInfo& info); Status ComputeInternal(ComputeContext& context) const override; protected: @@ -61,97 +70,6 @@ class LinearAttentionRecurrent : public WebGpuKernel { int64_t chunk_size_; }; -// Program for intra-chunk attention computation -class LinearAttentionChunkIntraProgram final : public Program { - public: - LinearAttentionChunkIntraProgram(LinearAttentionUpdateRule update_rule, bool has_decay, bool has_beta) - : Program{"LinearAttentionChunkIntra"}, - update_rule_(update_rule), - has_decay_(has_decay), - has_beta_(has_beta) {} - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"batch_size", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}, - {"sequence_length", ProgramUniformVariableDataType::Uint32}, - {"head_dim_k", ProgramUniformVariableDataType::Uint32}, - {"head_dim_v", ProgramUniformVariableDataType::Uint32}, - {"chunk_size", ProgramUniformVariableDataType::Uint32}, - {"num_chunks", ProgramUniformVariableDataType::Uint32}, - {"scale", ProgramUniformVariableDataType::Float32}); - - private: - [[maybe_unused]] LinearAttentionUpdateRule update_rule_; - bool has_decay_; - bool has_beta_; -}; - -// Program for inter-chunk state propagation -class LinearAttentionChunkInterProgram final : public Program { - public: - LinearAttentionChunkInterProgram(LinearAttentionUpdateRule update_rule, bool has_decay, bool has_beta, bool has_initial_state) - : Program{"LinearAttentionChunkInter"}, - update_rule_(update_rule), - has_decay_(has_decay), - has_beta_(has_beta), - has_initial_state_(has_initial_state) {} - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"batch_size", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}, - {"sequence_length", ProgramUniformVariableDataType::Uint32}, - {"head_dim_k", ProgramUniformVariableDataType::Uint32}, - {"head_dim_v", ProgramUniformVariableDataType::Uint32}, - {"chunk_size", ProgramUniformVariableDataType::Uint32}, - {"num_chunks", ProgramUniformVariableDataType::Uint32}, - {"scale", ProgramUniformVariableDataType::Float32}); - - private: - [[maybe_unused]] LinearAttentionUpdateRule update_rule_; - bool has_decay_; - [[maybe_unused]] bool has_beta_; - bool has_initial_state_; -}; - -// Program for full sequential computation (used for delta/gated_delta update rules) -// Delta rules have non-linear state updates (S^T k term), so chunk-parallel decomposition -// doesn't produce correct results. This program processes the full sequence sequentially. -class LinearAttentionFullSequentialProgram final : public Program { - public: - LinearAttentionFullSequentialProgram(LinearAttentionUpdateRule update_rule, bool has_decay, bool has_beta, bool has_initial_state) - : Program{"LinearAttentionFullSequential"}, - update_rule_(update_rule), - has_decay_(has_decay), - has_beta_(has_beta), - has_initial_state_(has_initial_state) {} - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"batch_size", ProgramUniformVariableDataType::Uint32}, - {"num_heads", ProgramUniformVariableDataType::Uint32}, - {"sequence_length", ProgramUniformVariableDataType::Uint32}, - {"head_dim_k", ProgramUniformVariableDataType::Uint32}, - {"head_dim_v", ProgramUniformVariableDataType::Uint32}, - {"scale", ProgramUniformVariableDataType::Float32}); - - private: - [[maybe_unused]] LinearAttentionUpdateRule update_rule_; - bool has_decay_; - bool has_beta_; - bool has_initial_state_; -}; - -// Kernel for LinearAttentionChunkParallel -class LinearAttentionChunkParallel final : public LinearAttentionRecurrent { - public: - LinearAttentionChunkParallel(const OpKernelInfo& info); -}; - } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 24a88320da980..67b71ed85d69a 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -22,8 +22,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Fu class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GatherBlockQuantized); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, LinearAttentionRecurrent); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, LinearAttentionChunkParallel); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, LinearAttention); // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits); @@ -55,8 +54,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index f0b68b1ae98d1..3477b5e445135 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -2301,12 +2301,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( } })); -constexpr const char* LinearAttentionRecurrent_ver1_doc = R"DOC( -Linear Attention Recurrent operator for single-token decode step. +constexpr const char* LinearAttention_ver1_doc = R"DOC( +Linear Attention operator (chunk-parallel). -This is the core operation for recurrent linear attention mechanisms used in modern -hybrid LLMs (Qwen3.5, Jamba, RWKV-6, etc.). It performs a fused state update and -output computation, keeping the full state matrix in fast memory. +Processes a sequence of tokens using linear attention with a recurrent state matrix. +When sequence_length=1, this is equivalent to a single recurrent decode step. +When sequence_length>1, this efficiently processes the full sequence (e.g., for prefill). The update_rule attribute selects the recurrence type: - "linear": S_t = S_{t-1} + k_t ⊗ v_t; o_t = q_t^T S_t / sqrt(d_k) @@ -2315,12 +2315,15 @@ The update_rule attribute selects the recurrence type: - "gated_delta": S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t); o_t = q_t^T S_t / sqrt(d_k) where g_t is the decay (in log-space), β_t is the update rate, and ⊗ denotes outer product. + +Semantics: Equivalent to running the recurrent update sequentially for each token, +but may be implemented using chunk-parallel algorithms for GPU efficiency. )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( - LinearAttentionRecurrent, 1, + LinearAttention, 1, OpSchema() - .SetDoc(LinearAttentionRecurrent_ver1_doc) + .SetDoc(LinearAttention_ver1_doc) .Attr("update_rule", "The update rule for the linear attention recurrence. " "One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.", @@ -2330,97 +2333,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Output scaling factor. When 0.0 (default), uses 1/sqrt(d_k) where d_k is the key dimension.", AttributeProto::FLOAT, 0.0f) - .Input(0, - "query", - "Query vector with shape (batch_size, num_heads, 1, head_dim_k)", - "T") - .Input(1, - "key", - "Key vector with shape (batch_size, num_heads, 1, head_dim_k). " - "Should be L2-normalized for delta/gated_delta modes.", - "T") - .Input(2, - "value", - "Value vector with shape (batch_size, num_heads, 1, head_dim_v)", - "T") - .Input(3, - "past_state", - "Recurrent state from previous step with shape (batch_size, num_heads, head_dim_k, head_dim_v)", - "T") - .Input(4, - "decay", - "Exponential decay gate in log-space with shape broadcastable to (batch_size, num_heads, 1, head_dim_k). " - "Required for 'gated' and 'gated_delta' modes.", - "T", - OpSchema::Optional) - .Input(5, - "beta", - "Update rate (sigmoid output) with shape broadcastable to (batch_size, num_heads, 1, 1). " - "Required for 'delta' and 'gated_delta' modes.", - "T", - OpSchema::Optional) - .Output(0, - "output", - "Attention output with shape (batch_size, num_heads, 1, head_dim_v)", - "T") - .Output(1, - "present_state", - "Updated recurrent state with shape (batch_size, num_heads, head_dim_k, head_dim_v)", - "T") - .TypeConstraint("T", - {"tensor(float)", "tensor(float16)"}, - "Constrain input and output types to float tensors.") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - propagateElemTypeFromInputToOutput(ctx, 0, 1); - - // Output 0: same shape as query (batch_size, num_heads, 1, head_dim_v) - // but last dim comes from value - if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { - auto& query_shape = getInputShape(ctx, 0); - auto& value_shape = getInputShape(ctx, 2); - TensorShapeProto output_shape; - *output_shape.add_dim() = query_shape.dim(0); - *output_shape.add_dim() = query_shape.dim(1); - *output_shape.add_dim() = query_shape.dim(2); - *output_shape.add_dim() = value_shape.dim(3); - updateOutputShape(ctx, 0, output_shape); - } - - // Output 1: same shape as past_state - if (hasInputShape(ctx, 3)) { - propagateShapeFromInputToOutput(ctx, 3, 1); - } - })); - -constexpr const char* LinearAttentionChunkParallel_ver1_doc = R"DOC( -Linear Attention Chunk-Parallel operator for efficient prefill. - -Processes a long input sequence by splitting it into chunks, computing intra-chunk -attention in parallel, and propagating state between chunks. This is semantically -equivalent to running LinearAttentionRecurrent sequentially for each token, but -implemented using a chunk-parallel algorithm for GPU efficiency. - -The update_rule attribute has the same semantics as LinearAttentionRecurrent. -)DOC"; - -ONNX_MS_OPERATOR_SET_SCHEMA( - LinearAttentionChunkParallel, 1, - OpSchema() - .SetDoc(LinearAttentionChunkParallel_ver1_doc) - .Attr("update_rule", - "The update rule for the linear attention recurrence. " - "One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.", - AttributeProto::STRING, - std::string("gated_delta")) .Attr("chunk_size", - "Chunk size for parallel computation. Default is 64.", + "Chunk size for parallel computation. Only a hint for the implementation.", AttributeProto::INT, static_cast(64)) - .Attr("scale", - "Output scaling factor. When 0.0 (default), uses 1/sqrt(d_k) where d_k is the key dimension.", - AttributeProto::FLOAT, - 0.0f) .Input(0, "query", "Query vectors with shape (batch_size, num_heads, sequence_length, head_dim_k)", @@ -2436,31 +2352,31 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T") .Input(3, "initial_state", - "State from previous chunk/context with shape (batch_size, num_heads, head_dim_k, head_dim_v). " - "If not provided, initialized to zeros.", + "Initial recurrent state with shape (batch_size, num_heads, head_dim_k, head_dim_v). " + "If not provided, defaults to zeros.", "T", OpSchema::Optional) .Input(4, "decay", - "Per-token decay gates in log-space with shape broadcastable to " + "Exponential decay gate in log-space with shape broadcastable to " "(batch_size, num_heads, sequence_length, head_dim_k). " "Required for 'gated' and 'gated_delta' modes.", "T", OpSchema::Optional) .Input(5, "beta", - "Per-token update rates with shape broadcastable to " + "Update rate (sigmoid output) with shape broadcastable to " "(batch_size, num_heads, sequence_length, 1). " "Required for 'delta' and 'gated_delta' modes.", "T", OpSchema::Optional) .Output(0, "output", - "Attention output for all positions with shape (batch_size, num_heads, sequence_length, head_dim_v)", + "Attention output with shape (batch_size, num_heads, sequence_length, head_dim_v)", "T") .Output(1, "final_state", - "State after processing all tokens with shape (batch_size, num_heads, head_dim_k, head_dim_v)", + "Final recurrent state with shape (batch_size, num_heads, head_dim_k, head_dim_v)", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, @@ -2469,7 +2385,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( propagateElemTypeFromInputToOutput(ctx, 0, 0); propagateElemTypeFromInputToOutput(ctx, 0, 1); - // Output 0: (batch_size, num_heads, sequence_length, head_dim_v) + // Output 0: same shape as query but last dim from value if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { auto& query_shape = getInputShape(ctx, 0); auto& value_shape = getInputShape(ctx, 2); @@ -2481,18 +2397,21 @@ ONNX_MS_OPERATOR_SET_SCHEMA( updateOutputShape(ctx, 0, output_shape); } - // Output 1: (batch_size, num_heads, head_dim_k, head_dim_v) + // Output 1: final_state shape (B, H, dk, dv) if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { auto& query_shape = getInputShape(ctx, 0); auto& value_shape = getInputShape(ctx, 2); TensorShapeProto state_shape; - *state_shape.add_dim() = query_shape.dim(0); - *state_shape.add_dim() = query_shape.dim(1); - *state_shape.add_dim() = query_shape.dim(3); // head_dim_k - *state_shape.add_dim() = value_shape.dim(3); // head_dim_v + *state_shape.add_dim() = query_shape.dim(0); // batch + *state_shape.add_dim() = query_shape.dim(1); // heads + *state_shape.add_dim() = query_shape.dim(3); // dk + *state_shape.add_dim() = value_shape.dim(3); // dv updateOutputShape(ctx, 1, state_shape); + } else if (hasInputShape(ctx, 3)) { + propagateShapeFromInputToOutput(ctx, 3, 1); } })); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 2a8c8ee3521c6..b2b9bf8442692 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -88,8 +88,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QMoE); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttentionRecurrent); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttentionChunkParallel); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CausalConv1DWithState); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock); @@ -202,8 +201,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); - fn(GetOpSchema()); - fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 230d172d7404e..ed83b20b8019f 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -63,10 +63,19 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T", WebGpuSupportedNumberTypes()), Transpose); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 23, 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + ONNX_OPERATOR_KERNEL_EX( Transpose, kOnnxDomain, - 23, + 24, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", WebGpuSupportedNumberTypes()), diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index dcf4474d82d63..c6a78f26984e0 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -282,7 +282,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Transpose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 23, Transpose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 24, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace); @@ -642,7 +643,8 @@ std::unique_ptr RegisterKernels(bool enable_graph_capture = fals BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index 854b77ba4876b..42c31c7e1b82f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -6,6 +6,7 @@ #include "core/providers/webgpu/compute_context.h" #include "core/framework/op_kernel.h" +#include "core/providers/webgpu/numpy_io.h" namespace onnxruntime { @@ -23,6 +24,27 @@ class WebGpuKernel : public OpKernel { virtual Status ComputeInternal(ComputeContext& context) const = 0; + // call with + // NpyTensor(hidden_state, "/tmp/hidden_state.npy", context); + + template + void NpyTensor(const Tensor* t, std::string file, ComputeContext& context) const { + auto t_cpu = context.CreateCPUTensor(t->DataType(), t->Shape()); + ORT_THROW_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*t, t_cpu)); + + std::vector dims; + auto dims1 = t_cpu.Shape().GetDims(); + for (uint64_t i=0; i(dims); + for (int64_t i = 0; i < t_cpu.Shape().Size(); i++) { + a.data[i] = static_cast(t_cpu.Data()[i]); + } + numpy_io::write_numpy_array(file, a); + } + + // Overrides OpKernel::PrePack to handle constant tensor pre-processing for WebGPU kernels. // This method creates a ComputeContextBase and delegates to PrePackInternal. // diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc index 6b80a55c9ecec..af10d85330048 100644 --- a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc @@ -1,120 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include #include #include + #include "gtest/gtest.h" -#include "core/session/onnxruntime_cxx_api.h" -#include "test/common/tensor_op_test_utils.h" -#include "test/common/cuda_op_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" + +using namespace onnxruntime::test; namespace onnxruntime { namespace test { namespace { -enum class TensorType { - kFloat, - kFloat16 -}; -// Reference implementation for linear attention recurrent update -void LinearAttentionRecurrentReference( - const std::vector& query, - const std::vector& key, - const std::vector& value, - const std::vector& past_state, - const std::vector* decay, - const std::vector* beta, - std::vector& output, - std::vector& present_state, - int batch_size, - int num_heads, - int head_dim_k, - int head_dim_v, +// Reference implementation of the linear attention recurrence. +// Processes all tokens sequentially and returns output + final_state. +void LinearAttentionReference( const std::string& update_rule, - float scale) { - if (scale == 0.0f) { - scale = 1.0f / std::sqrt(static_cast(head_dim_k)); - } - - // Copy past_state to present_state first - present_state = past_state; - - output.resize(batch_size * num_heads * head_dim_v); - - for (int b = 0; b < batch_size; ++b) { - for (int h = 0; h < num_heads; ++h) { - int bh = b * num_heads + h; - int state_base = bh * head_dim_k * head_dim_v; - int qkv_base = bh * head_dim_k; - int v_base = bh * head_dim_v; - - // Apply decay if gated or gated_delta - if (update_rule == "gated" || update_rule == "gated_delta") { - for (int k = 0; k < head_dim_k; ++k) { - float g = (*decay)[qkv_base + k]; - float exp_g = std::exp(g); - for (int v = 0; v < head_dim_v; ++v) { - present_state[state_base + k * head_dim_v + v] *= exp_g; - } - } - } - - // Compute update - if (update_rule == "linear" || update_rule == "gated") { - // S += k ⊗ v - for (int k = 0; k < head_dim_k; ++k) { - float k_val = key[qkv_base + k]; - for (int v = 0; v < head_dim_v; ++v) { - float v_val = value[v_base + v]; - present_state[state_base + k * head_dim_v + v] += k_val * v_val; - } - } - } else if (update_rule == "delta" || update_rule == "gated_delta") { - // Compute retrieved = S^T @ k - std::vector retrieved(head_dim_v, 0.0f); - for (int v = 0; v < head_dim_v; ++v) { - for (int k = 0; k < head_dim_k; ++k) { - float k_val = key[qkv_base + k]; - // For gated_delta, retrieval uses decayed state (already applied above) - // For delta, uses original past_state - float s_val = (update_rule == "gated_delta") - ? present_state[state_base + k * head_dim_v + v] - : past_state[state_base + k * head_dim_v + v]; - retrieved[v] += s_val * k_val; - } - } - - // Compute delta and update - float beta_val = (*beta)[bh]; - for (int k = 0; k < head_dim_k; ++k) { - float k_val = key[qkv_base + k]; - for (int v = 0; v < head_dim_v; ++v) { - float v_val = value[v_base + v]; - float delta = beta_val * (v_val - retrieved[v]); - present_state[state_base + k * head_dim_v + v] += k_val * delta; - } - } - } - - // Compute output = scale * q^T @ S - for (int v = 0; v < head_dim_v; ++v) { - float out_val = 0.0f; - for (int k = 0; k < head_dim_k; ++k) { - float q_val = query[qkv_base + k]; - out_val += q_val * present_state[state_base + k * head_dim_v + v]; - } - output[v_base + v] = out_val * scale; - } - } - } -} - -// Reference implementation for linear attention chunk parallel (full sequence) -// This is the sequential version that processes one step at a time. -void LinearAttentionChunkParallelReference( + int batch_size, int num_heads, int seq_length, int head_dim_k, int head_dim_v, + float scale, const std::vector& query, const std::vector& key, const std::vector& value, @@ -122,683 +28,497 @@ void LinearAttentionChunkParallelReference( const std::vector* decay, const std::vector* beta, std::vector& output, - std::vector& final_state, - int batch_size, - int num_heads, - int seq_length, - int head_dim_k, - int head_dim_v, - const std::string& update_rule, - float scale) { - if (scale == 0.0f) { - scale = 1.0f / std::sqrt(static_cast(head_dim_k)); + std::vector& final_state) { + // State: (B, H, dk, dv) + final_state.resize(batch_size * num_heads * head_dim_k * head_dim_v, 0.0f); + output.resize(batch_size * num_heads * seq_length * head_dim_v, 0.0f); + + // Initialize state from initial_state if provided + if (initial_state != nullptr) { + final_state = *initial_state; } - output.resize(batch_size * num_heads * seq_length * head_dim_v); - final_state.resize(batch_size * num_heads * head_dim_k * head_dim_v); - - int state_size = head_dim_k * head_dim_v; - - for (int b = 0; b < batch_size; ++b) { - for (int h = 0; h < num_heads; ++h) { - int bh = b * num_heads + h; - - // Initialize state - std::vector state(state_size, 0.0f); - if (initial_state != nullptr) { - int init_base = bh * state_size; - for (int i = 0; i < state_size; ++i) { - state[i] = (*initial_state)[init_base + i]; + for (int b = 0; b < batch_size; b++) { + for (int h = 0; h < num_heads; h++) { + // State for this (b, h): dk x dv + auto state_offset = [&](int k, int v) { + return ((b * num_heads + h) * head_dim_k + k) * head_dim_v + v; + }; + + for (int t = 0; t < seq_length; t++) { + auto qkv_offset = [&](int dim) { + return ((b * num_heads + h) * seq_length + t) * dim; + }; + + // Load q, k for this token + std::vector q_vec(head_dim_k), k_vec(head_dim_k), v_vec(head_dim_v); + for (int i = 0; i < head_dim_k; i++) { + q_vec[i] = query[qkv_offset(head_dim_k) + i]; + k_vec[i] = key[qkv_offset(head_dim_k) + i]; + } + for (int i = 0; i < head_dim_v; i++) { + v_vec[i] = value[qkv_offset(head_dim_v) + i]; } - } - - // Process each timestep sequentially - for (int t = 0; t < seq_length; ++t) { - int qk_base = (bh * seq_length + t) * head_dim_k; - int v_base = (bh * seq_length + t) * head_dim_v; - // 1. Apply decay if gated or gated_delta + // Step 1: Apply decay (gated, gated_delta) if (update_rule == "gated" || update_rule == "gated_delta") { - for (int ki = 0; ki < head_dim_k; ++ki) { - float g = (*decay)[qk_base + ki]; - float exp_g = std::exp(g); - for (int vi = 0; vi < head_dim_v; ++vi) { - state[ki * head_dim_v + vi] *= exp_g; + for (int k = 0; k < head_dim_k; k++) { + int decay_idx = ((b * num_heads + h) * seq_length + t) * head_dim_k + k; + float exp_g = std::exp((*decay)[decay_idx]); + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + final_state[state_offset(k, v_idx)] *= exp_g; } } } - // 2. Update state - if (update_rule == "linear" || update_rule == "gated") { - // S += k ⊗ v - for (int ki = 0; ki < head_dim_k; ++ki) { - float k_val = key[qk_base + ki]; - for (int vi = 0; vi < head_dim_v; ++vi) { - float v_val = value[v_base + vi]; - state[ki * head_dim_v + vi] += k_val * v_val; - } - } - } else if (update_rule == "delta" || update_rule == "gated_delta") { - // Compute retrieved = S^T @ k + // Step 2: Compute state update + if (update_rule == "delta" || update_rule == "gated_delta") { + // retrieved = S^T @ k (for each v dimension) std::vector retrieved(head_dim_v, 0.0f); - for (int vi = 0; vi < head_dim_v; ++vi) { - for (int ki = 0; ki < head_dim_k; ++ki) { - retrieved[vi] += state[ki * head_dim_v + vi] * key[qk_base + ki]; + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + for (int k = 0; k < head_dim_k; k++) { + retrieved[v_idx] += final_state[state_offset(k, v_idx)] * k_vec[k]; } } - float beta_val = (*beta)[bh * seq_length + t]; - for (int ki = 0; ki < head_dim_k; ++ki) { - float k_val = key[qk_base + ki]; - for (int vi = 0; vi < head_dim_v; ++vi) { - float v_val = value[v_base + vi]; - float delta_val = beta_val * (v_val - retrieved[vi]); - state[ki * head_dim_v + vi] += k_val * delta_val; + // delta = beta * (v - retrieved) + int beta_idx = (b * num_heads + h) * seq_length + t; + float beta_val = (*beta)[beta_idx]; + std::vector delta(head_dim_v); + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + delta[v_idx] = beta_val * (v_vec[v_idx] - retrieved[v_idx]); + } + + // S += k ⊗ delta + for (int k = 0; k < head_dim_k; k++) { + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + final_state[state_offset(k, v_idx)] += k_vec[k] * delta[v_idx]; + } + } + } else { + // linear, gated: S += k ⊗ v + for (int k = 0; k < head_dim_k; k++) { + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + final_state[state_offset(k, v_idx)] += k_vec[k] * v_vec[v_idx]; } } } - // 3. Compute output = scale * q^T @ S - int out_base = (bh * seq_length + t) * head_dim_v; - for (int vi = 0; vi < head_dim_v; ++vi) { - float out_val = 0.0f; - for (int ki = 0; ki < head_dim_k; ++ki) { - out_val += query[qk_base + ki] * state[ki * head_dim_v + vi]; + // Step 3: Compute output = scale * S^T @ q + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + float sum = 0.0f; + for (int k = 0; k < head_dim_k; k++) { + sum += final_state[state_offset(k, v_idx)] * q_vec[k]; } - output[out_base + vi] = out_val * scale; + int out_idx = ((b * num_heads + h) * seq_length + t) * head_dim_v + v_idx; + output[out_idx] = scale * sum; } } - - // Copy final state - int final_base = bh * state_size; - for (int i = 0; i < state_size; ++i) { - final_state[final_base + i] = state[i]; - } } } } -} // anonymous namespace - -static void RunLinearAttentionRecurrentTest( - const std::vector& query_data, - const std::vector& key_data, - const std::vector& value_data, - const std::vector& past_state_data, - const std::vector* decay_data, - const std::vector* beta_data, - const std::vector& expected_output, - const std::vector& expected_state, - int batch_size, - int num_heads, - int head_dim_k, - int head_dim_v, +void RunLinearAttentionTest( const std::string& update_rule, + int batch_size, int num_heads, int seq_length, int head_dim_k, int head_dim_v, float scale, - TensorType tensor_type) { - std::vector query_shape = {batch_size, num_heads, 1, head_dim_k}; - std::vector key_shape = {batch_size, num_heads, 1, head_dim_k}; - std::vector value_shape = {batch_size, num_heads, 1, head_dim_v}; - std::vector state_shape = {batch_size, num_heads, head_dim_k, head_dim_v}; - std::vector decay_shape = {batch_size, num_heads, 1, head_dim_k}; - std::vector beta_shape = {batch_size, num_heads, 1, 1}; - std::vector output_shape = {batch_size, num_heads, 1, head_dim_v}; - - std::string op_type = "LinearAttentionRecurrent"; - std::vector> execution_providers; - - bool enable_webgpu = nullptr != DefaultWebGpuExecutionProvider().get(); + const std::vector& query, + const std::vector& key, + const std::vector& value, + const std::vector* initial_state, + const std::vector* decay, + const std::vector* beta_data) { + // Compute reference output + std::vector expected_output, expected_state; + LinearAttentionReference(update_rule, batch_size, num_heads, seq_length, + head_dim_k, head_dim_v, scale, + query, key, value, initial_state, decay, beta_data, + expected_output, expected_state); + + bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()); + if (!enable_webgpu) { + return; + } - if (enable_webgpu) { - execution_providers.push_back(DefaultWebGpuExecutionProvider()); + OpTester tester("LinearAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("update_rule", update_rule); + tester.AddAttribute("scale", scale); + + // Add required inputs + std::vector qk_dims = {batch_size, num_heads, seq_length, head_dim_k}; + std::vector v_dims = {batch_size, num_heads, seq_length, head_dim_v}; + tester.AddInput("query", qk_dims, query); + tester.AddInput("key", qk_dims, key); + tester.AddInput("value", v_dims, value); + + // Optional: initial_state + if (initial_state != nullptr) { + std::vector state_dims = {batch_size, num_heads, head_dim_k, head_dim_v}; + tester.AddInput("initial_state", state_dims, *initial_state); + } else { + tester.AddOptionalInputEdge(); } - if (execution_providers.empty()) { - // Skip if no providers available - return; + // Optional: decay + if (decay != nullptr) { + std::vector decay_dims = {batch_size, num_heads, seq_length, head_dim_k}; + tester.AddInput("decay", decay_dims, *decay); + } else { + tester.AddOptionalInputEdge(); } - for (auto& ep : execution_providers) { - OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); - test.AddAttribute("update_rule", update_rule); - test.AddAttribute("scale", scale); - - if (tensor_type == TensorType::kFloat) { - test.AddInput("query", query_shape, query_data); - test.AddInput("key", key_shape, key_data); - test.AddInput("value", value_shape, value_data); - test.AddInput("past_state", state_shape, past_state_data); - - if (decay_data != nullptr) { - test.AddInput("decay", decay_shape, *decay_data); - } else { - test.AddOptionalInputEdge(); - } + // Optional: beta + if (beta_data != nullptr) { + std::vector beta_dims = {batch_size, num_heads, seq_length, 1}; + tester.AddInput("beta", beta_dims, *beta_data); + } else { + tester.AddOptionalInputEdge(); + } - if (beta_data != nullptr) { - test.AddInput("beta", beta_shape, *beta_data); - } else { - test.AddOptionalInputEdge(); - } + // Add outputs + std::vector out_dims = {batch_size, num_heads, seq_length, head_dim_v}; + std::vector state_dims = {batch_size, num_heads, head_dim_k, head_dim_v}; + tester.AddOutput("output", out_dims, expected_output, false, 0.005f, 0.005f); + tester.AddOutput("final_state", state_dims, expected_state, false, 0.005f, 0.005f); - test.AddOutput("output", output_shape, expected_output); - test.AddOutput("present_state", state_shape, expected_state); - } else { - test.AddInput("query", query_shape, ToFloat16(query_data)); - test.AddInput("key", key_shape, ToFloat16(key_data)); - test.AddInput("value", value_shape, ToFloat16(value_data)); - test.AddInput("past_state", state_shape, ToFloat16(past_state_data)); - - if (decay_data != nullptr) { - test.AddInput("decay", decay_shape, ToFloat16(*decay_data)); - } else { - test.AddOptionalInputEdge(); - } + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} - if (beta_data != nullptr) { - test.AddInput("beta", beta_shape, ToFloat16(*beta_data)); - } else { - test.AddOptionalInputEdge(); - } +} // namespace - test.AddOutput("output", output_shape, ToFloat16(expected_output)); - test.AddOutput("present_state", state_shape, ToFloat16(expected_state)); - } +// =========================================================================== +// Test: Linear update rule (simplest - no decay, no beta) +// =========================================================================== +TEST(ContribOpLinearAttentionTest, LinearRule_SingleToken) { + const int B = 1, H = 1, T = 1, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); - test.SetOutputAbsErr("output", 0.01f); - test.SetOutputAbsErr("present_state", 0.01f); + std::vector query = {1.0f, 0.0f, 0.5f, -0.5f}; + std::vector key = {0.5f, 0.5f, 0.0f, 1.0f}; + std::vector value = {1.0f, 2.0f, 3.0f, 4.0f}; - std::vector> test_execution_providers; - test_execution_providers.push_back(std::move(ep)); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_execution_providers); - } + RunLinearAttentionTest("linear", B, H, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, nullptr); } -static void RunLinearAttentionRecurrentTests( - const std::vector& query_data, - const std::vector& key_data, - const std::vector& value_data, - const std::vector& past_state_data, - const std::vector* decay_data, - const std::vector* beta_data, - int batch_size, - int num_heads, - int head_dim_k, - int head_dim_v, - const std::string& update_rule, - float scale = 0.0f) { - // Compute expected output using reference implementation - std::vector expected_output; - std::vector expected_state; - LinearAttentionRecurrentReference( - query_data, key_data, value_data, past_state_data, - decay_data, beta_data, - expected_output, expected_state, - batch_size, num_heads, head_dim_k, head_dim_v, - update_rule, scale); - - // FP32 test - RunLinearAttentionRecurrentTest( - query_data, key_data, value_data, past_state_data, - decay_data, beta_data, - expected_output, expected_state, - batch_size, num_heads, head_dim_k, head_dim_v, - update_rule, scale, TensorType::kFloat); - - // FP16 test - RunLinearAttentionRecurrentTest( - query_data, key_data, value_data, past_state_data, - decay_data, beta_data, - expected_output, expected_state, - batch_size, num_heads, head_dim_k, head_dim_v, - update_rule, scale, TensorType::kFloat16); +TEST(ContribOpLinearAttentionTest, LinearRule_MultiToken) { + const int B = 1, H = 1, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f, + 0.0f, -1.0f, 1.0f, 0.5f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f, + -0.5f, 1.0f, 0.5f, 0.0f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f, + 3.0f, 0.0f, 1.0f, 2.0f}; + + RunLinearAttentionTest("linear", B, H, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, nullptr); } -// ============================================================================= -// LinearAttentionRecurrent Tests -// ============================================================================= - -TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_Linear_Basic) { - int batch_size = 1; - int num_heads = 2; - int head_dim_k = 4; - int head_dim_v = 4; - - // Query: (1, 2, 1, 4) - std::vector query_data = { - 0.5f, 0.3f, -0.2f, 0.1f, // head 0 - -0.4f, 0.6f, 0.2f, -0.3f // head 1 - }; - - // Key: (1, 2, 1, 4) - std::vector key_data = { - 0.1f, 0.2f, 0.3f, 0.4f, - 0.2f, -0.1f, 0.3f, 0.1f}; - - // Value: (1, 2, 1, 4) - std::vector value_data = { - 0.4f, 0.3f, 0.2f, 0.1f, - -0.2f, 0.4f, 0.1f, 0.3f}; - - // Past state: (1, 2, 4, 4) - initialized to small values - std::vector past_state_data(batch_size * num_heads * head_dim_k * head_dim_v, 0.1f); - - RunLinearAttentionRecurrentTests( - query_data, key_data, value_data, past_state_data, - nullptr, nullptr, - batch_size, num_heads, head_dim_k, head_dim_v, - "linear"); +TEST(ContribOpLinearAttentionTest, LinearRule_WithInitialState) { + const int B = 1, H = 1, T = 2, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f}; + + // Non-zero initial state + std::vector initial_state(dk * dv, 0.1f); + + RunLinearAttentionTest("linear", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, nullptr, nullptr); } -TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_Gated_Basic) { - int batch_size = 1; - int num_heads = 2; - int head_dim_k = 4; - int head_dim_v = 4; - - std::vector query_data = { - 0.5f, 0.3f, -0.2f, 0.1f, - -0.4f, 0.6f, 0.2f, -0.3f}; +// =========================================================================== +// Test: Gated update rule (decay, no beta) +// =========================================================================== +TEST(ContribOpLinearAttentionTest, GatedRule_SingleToken) { + const int B = 1, H = 1, T = 1, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); - std::vector key_data = { - 0.1f, 0.2f, 0.3f, 0.4f, - 0.2f, -0.1f, 0.3f, 0.1f}; + std::vector query = {1.0f, 0.0f, 0.5f, -0.5f}; + std::vector key = {0.5f, 0.5f, 0.0f, 1.0f}; + std::vector value = {1.0f, 2.0f, 3.0f, 4.0f}; - std::vector value_data = { - 0.4f, 0.3f, 0.2f, 0.1f, - -0.2f, 0.4f, 0.1f, 0.3f}; + // Decay in log-space (small negative values for slight decay) + std::vector decay = {-0.1f, -0.2f, -0.05f, -0.15f}; - std::vector past_state_data(batch_size * num_heads * head_dim_k * head_dim_v, 0.1f); + // Initial state (needed to see decay effect) + std::vector initial_state(dk * dv, 1.0f); - // Decay: (1, 2, 1, 4) - negative values for decay - std::vector decay_data = { - -0.1f, -0.1f, -0.1f, -0.1f, - -0.2f, -0.2f, -0.2f, -0.2f}; - - RunLinearAttentionRecurrentTests( - query_data, key_data, value_data, past_state_data, - &decay_data, nullptr, - batch_size, num_heads, head_dim_k, head_dim_v, - "gated"); + RunLinearAttentionTest("gated", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, nullptr); } -TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_Delta_Basic) { - int batch_size = 1; - int num_heads = 2; - int head_dim_k = 4; - int head_dim_v = 4; - - std::vector query_data = { - 0.5f, 0.3f, -0.2f, 0.1f, - -0.4f, 0.6f, 0.2f, -0.3f}; - - // L2-normalized keys for delta rule - std::vector key_data = { - 0.1826f, 0.3651f, 0.5477f, 0.7303f, // normalized - 0.5345f, -0.2673f, 0.8018f, 0.2673f // normalized - }; +TEST(ContribOpLinearAttentionTest, GatedRule_MultiToken) { + const int B = 1, H = 1, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f, + 0.0f, -1.0f, 1.0f, 0.5f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f, + -0.5f, 1.0f, 0.5f, 0.0f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f, + 3.0f, 0.0f, 1.0f, 2.0f}; + std::vector decay = { + -0.1f, -0.2f, -0.05f, -0.15f, + -0.2f, -0.1f, -0.3f, -0.05f, + -0.05f, -0.15f, -0.1f, -0.2f}; + + std::vector initial_state(dk * dv, 0.5f); + + RunLinearAttentionTest("gated", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, nullptr); +} - std::vector value_data = { - 0.4f, 0.3f, 0.2f, 0.1f, - -0.2f, 0.4f, 0.1f, 0.3f}; +// =========================================================================== +// Test: Delta update rule (no decay, uses beta) +// =========================================================================== +TEST(ContribOpLinearAttentionTest, DeltaRule_SingleToken) { + const int B = 1, H = 1, T = 1, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); - std::vector past_state_data(batch_size * num_heads * head_dim_k * head_dim_v, 0.1f); + std::vector query = {1.0f, 0.0f, 0.5f, -0.5f}; + std::vector key = {0.5f, 0.5f, 0.0f, 1.0f}; + std::vector value = {1.0f, 2.0f, 3.0f, 4.0f}; + std::vector beta = {0.8f}; // shape: (1,1,1,1) - // Beta: (1, 2, 1, 1) - std::vector beta_data = {0.5f, 0.7f}; + std::vector initial_state(dk * dv, 0.5f); - RunLinearAttentionRecurrentTests( - query_data, key_data, value_data, past_state_data, - nullptr, &beta_data, - batch_size, num_heads, head_dim_k, head_dim_v, - "delta"); + RunLinearAttentionTest("delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, nullptr, &beta); } -TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_GatedDelta_Basic) { - int batch_size = 1; - int num_heads = 2; - int head_dim_k = 4; - int head_dim_v = 4; - - std::vector query_data = { - 0.5f, 0.3f, -0.2f, 0.1f, - -0.4f, 0.6f, 0.2f, -0.3f}; - - // L2-normalized keys - std::vector key_data = { - 0.1826f, 0.3651f, 0.5477f, 0.7303f, - 0.5345f, -0.2673f, 0.8018f, 0.2673f}; +TEST(ContribOpLinearAttentionTest, DeltaRule_MultiToken) { + const int B = 1, H = 1, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f, + 0.0f, -1.0f, 1.0f, 0.5f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f, + -0.5f, 1.0f, 0.5f, 0.0f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f, + 3.0f, 0.0f, 1.0f, 2.0f}; + std::vector beta = {0.8f, 0.6f, 0.9f}; // shape: (1,1,3,1) + + RunLinearAttentionTest("delta", B, H, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, &beta); +} - std::vector value_data = { - 0.4f, 0.3f, 0.2f, 0.1f, - -0.2f, 0.4f, 0.1f, 0.3f}; +// =========================================================================== +// Test: GatedDelta update rule (full - decay + beta) +// =========================================================================== +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_SingleToken) { + const int B = 1, H = 1, T = 1, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); - std::vector past_state_data(batch_size * num_heads * head_dim_k * head_dim_v, 0.1f); + std::vector query = {1.0f, 0.0f, 0.5f, -0.5f}; + std::vector key = {0.5f, 0.5f, 0.0f, 1.0f}; + std::vector value = {1.0f, 2.0f, 3.0f, 4.0f}; + std::vector decay = {-0.1f, -0.2f, -0.05f, -0.15f}; + std::vector beta = {0.8f}; - // Decay: (1, 2, 1, 4) - std::vector decay_data = { - -0.1f, -0.1f, -0.1f, -0.1f, - -0.2f, -0.2f, -0.2f, -0.2f}; + std::vector initial_state(dk * dv, 1.0f); - // Beta: (1, 2, 1, 1) - std::vector beta_data = {0.5f, 0.7f}; + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} - RunLinearAttentionRecurrentTests( - query_data, key_data, value_data, past_state_data, - &decay_data, &beta_data, - batch_size, num_heads, head_dim_k, head_dim_v, - "gated_delta"); +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_MultiToken) { + const int B = 1, H = 1, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f, + 0.0f, -1.0f, 1.0f, 0.5f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f, + -0.5f, 1.0f, 0.5f, 0.0f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f, + 3.0f, 0.0f, 1.0f, 2.0f}; + std::vector decay = { + -0.1f, -0.2f, -0.05f, -0.15f, + -0.2f, -0.1f, -0.3f, -0.05f, + -0.05f, -0.15f, -0.1f, -0.2f}; + std::vector beta = {0.8f, 0.6f, 0.9f}; + + std::vector initial_state(dk * dv, 0.5f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); } -TEST(ContribOpLinearAttentionTest, LinearAttentionRecurrent_LargerBatch) { - int batch_size = 2; - int num_heads = 4; - int head_dim_k = 8; - int head_dim_v = 8; - - int qkv_size = batch_size * num_heads * head_dim_k; - int value_size = batch_size * num_heads * head_dim_v; - int state_size = batch_size * num_heads * head_dim_k * head_dim_v; - - // Generate random-ish data - std::vector query_data(qkv_size); - std::vector key_data(qkv_size); - std::vector value_data(value_size); - std::vector past_state_data(state_size); - std::vector decay_data(qkv_size); - std::vector beta_data(batch_size * num_heads); - - for (int i = 0; i < qkv_size; ++i) { - query_data[i] = 0.1f * (i % 10 - 5); - key_data[i] = 0.1f * ((i + 3) % 10 - 5); - decay_data[i] = -0.1f * ((i % 3) + 1); - } - for (int i = 0; i < value_size; ++i) { - value_data[i] = 0.1f * ((i + 7) % 10 - 5); +// =========================================================================== +// Test: Multi-batch, multi-head +// =========================================================================== +TEST(ContribOpLinearAttentionTest, LinearRule_MultiBatchMultiHead) { + const int B = 2, H = 2, T = 2, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + // Total: B*H*T*dk = 2*2*2*4 = 32 values for q/k, B*H*T*dv = 32 for v + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + + // Fill with deterministic pattern + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = std::sin(static_cast(i) * 0.3f); + key[i] = std::cos(static_cast(i) * 0.5f); } - for (int i = 0; i < state_size; ++i) { - past_state_data[i] = 0.05f * (i % 10 - 5); - } - for (int i = 0; i < batch_size * num_heads; ++i) { - beta_data[i] = 0.3f + 0.1f * (i % 5); + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = std::sin(static_cast(i) * 0.7f + 1.0f); } - RunLinearAttentionRecurrentTests( - query_data, key_data, value_data, past_state_data, - &decay_data, &beta_data, - batch_size, num_heads, head_dim_k, head_dim_v, - "gated_delta"); + RunLinearAttentionTest("linear", B, H, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, nullptr); } -// ============================================================================= -// LinearAttentionChunkParallel Tests -// ============================================================================= - -static void RunLinearAttentionChunkParallelTest( - const std::vector& query_data, - const std::vector& key_data, - const std::vector& value_data, - const std::vector* initial_state_data, - const std::vector* decay_data, - const std::vector* beta_data, - int batch_size, - int num_heads, - int seq_length, - int head_dim_k, - int head_dim_v, - const std::string& update_rule, - int64_t chunk_size, - float scale, - TensorType tensor_type) { - std::vector query_shape = {batch_size, num_heads, seq_length, head_dim_k}; - std::vector key_shape = {batch_size, num_heads, seq_length, head_dim_k}; - std::vector value_shape = {batch_size, num_heads, seq_length, head_dim_v}; - std::vector state_shape = {batch_size, num_heads, head_dim_k, head_dim_v}; - std::vector decay_shape = {batch_size, num_heads, seq_length, head_dim_k}; - std::vector beta_shape = {batch_size, num_heads, seq_length, 1}; - std::vector output_shape = {batch_size, num_heads, seq_length, head_dim_v}; - - // Compute reference expected output - std::vector expected_output; - std::vector expected_state; - LinearAttentionChunkParallelReference( - query_data, key_data, value_data, - initial_state_data, decay_data, beta_data, - expected_output, expected_state, - batch_size, num_heads, seq_length, head_dim_k, head_dim_v, - update_rule, scale); - - std::string op_type = "LinearAttentionChunkParallel"; - std::vector> execution_providers; +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_MultiBatchMultiHead) { + const int B = 2, H = 2, T = 2, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); - bool enable_webgpu = nullptr != DefaultWebGpuExecutionProvider().get(); + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + std::vector decay(B * H * T * dk); + std::vector beta(B * H * T); - if (enable_webgpu) { - execution_providers.push_back(DefaultWebGpuExecutionProvider()); + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = std::sin(static_cast(i) * 0.3f); + key[i] = std::cos(static_cast(i) * 0.5f); + decay[i] = -0.1f - 0.1f * std::sin(static_cast(i) * 0.2f); } - - if (execution_providers.empty()) { - return; + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = std::sin(static_cast(i) * 0.7f + 1.0f); } - - for (auto& ep : execution_providers) { - OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); - test.AddAttribute("update_rule", update_rule); - test.AddAttribute("chunk_size", chunk_size); - test.AddAttribute("scale", scale); - - if (tensor_type == TensorType::kFloat) { - test.AddInput("query", query_shape, query_data); - test.AddInput("key", key_shape, key_data); - test.AddInput("value", value_shape, value_data); - - if (initial_state_data != nullptr) { - test.AddInput("initial_state", state_shape, *initial_state_data); - } else { - test.AddOptionalInputEdge(); - } - - if (decay_data != nullptr) { - test.AddInput("decay", decay_shape, *decay_data); - } else { - test.AddOptionalInputEdge(); - } - - if (beta_data != nullptr) { - test.AddInput("beta", beta_shape, *beta_data); - } else { - test.AddOptionalInputEdge(); - } - - test.AddOutput("output", output_shape, expected_output); - test.AddOutput("final_state", state_shape, expected_state); - } else { - test.AddInput("query", query_shape, ToFloat16(query_data)); - test.AddInput("key", key_shape, ToFloat16(key_data)); - test.AddInput("value", value_shape, ToFloat16(value_data)); - - if (initial_state_data != nullptr) { - test.AddInput("initial_state", state_shape, ToFloat16(*initial_state_data)); - } else { - test.AddOptionalInputEdge(); - } - - if (decay_data != nullptr) { - test.AddInput("decay", decay_shape, ToFloat16(*decay_data)); - } else { - test.AddOptionalInputEdge(); - } - - if (beta_data != nullptr) { - test.AddInput("beta", beta_shape, ToFloat16(*beta_data)); - } else { - test.AddOptionalInputEdge(); - } - - test.AddOutput("output", output_shape, ToFloat16(expected_output)); - test.AddOutput("final_state", state_shape, ToFloat16(expected_state)); - } - - test.SetOutputAbsErr("output", 0.01f); - test.SetOutputAbsErr("final_state", 0.01f); - - std::vector> test_execution_providers; - test_execution_providers.push_back(std::move(ep)); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_execution_providers); + for (int i = 0; i < B * H * T; i++) { + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i)); } -} - -TEST(ContribOpLinearAttentionTest, LinearAttentionChunkParallel_Linear_Basic) { - int batch_size = 1; - int num_heads = 2; - int seq_length = 8; - int head_dim_k = 4; - int head_dim_v = 4; - int qkv_size = batch_size * num_heads * seq_length * head_dim_k; - int value_size = batch_size * num_heads * seq_length * head_dim_v; + std::vector initial_state(B * H * dk * dv, 0.1f); - std::vector query_data(qkv_size); - std::vector key_data(qkv_size); - std::vector value_data(value_size); - - for (int i = 0; i < qkv_size; ++i) { - query_data[i] = 0.1f * (i % 10 - 5); - key_data[i] = 0.1f * ((i + 3) % 10 - 5); - } - for (int i = 0; i < value_size; ++i) { - value_data[i] = 0.1f * ((i + 7) % 10 - 5); - } - - RunLinearAttentionChunkParallelTest( - query_data, key_data, value_data, - nullptr, nullptr, nullptr, - batch_size, num_heads, seq_length, head_dim_k, head_dim_v, - "linear", 4, 0.0f, TensorType::kFloat); + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); } -TEST(ContribOpLinearAttentionTest, LinearAttentionChunkParallel_Gated_Basic) { - int batch_size = 1; - int num_heads = 2; - int seq_length = 8; - int head_dim_k = 4; - int head_dim_v = 4; - - int qkv_size = batch_size * num_heads * seq_length * head_dim_k; - int value_size = batch_size * num_heads * seq_length * head_dim_v; - int decay_size = batch_size * num_heads * seq_length * head_dim_k; - - std::vector query_data(qkv_size); - std::vector key_data(qkv_size); - std::vector value_data(value_size); - std::vector decay_data(decay_size); - - for (int i = 0; i < qkv_size; ++i) { - query_data[i] = 0.1f * (i % 10 - 5); - key_data[i] = 0.1f * ((i + 3) % 10 - 5); - } - for (int i = 0; i < value_size; ++i) { - value_data[i] = 0.1f * ((i + 7) % 10 - 5); - } - for (int i = 0; i < decay_size; ++i) { - decay_data[i] = -0.1f * ((i % 3) + 1); +// =========================================================================== +// Test: Default scale (should use 1/sqrt(dk)) +// =========================================================================== +TEST(ContribOpLinearAttentionTest, LinearRule_DefaultScale) { + const int B = 1, H = 1, T = 1, dk = 4, dv = 4; + + std::vector query = {1.0f, 0.0f, 0.5f, -0.5f}; + std::vector key = {0.5f, 0.5f, 0.0f, 1.0f}; + std::vector value = {1.0f, 2.0f, 3.0f, 4.0f}; + + // Compute with explicit scale for reference + float actual_scale = 1.0f / std::sqrt(static_cast(dk)); + std::vector expected_output, expected_state; + LinearAttentionReference("linear", B, H, T, dk, dv, actual_scale, + query, key, value, nullptr, nullptr, nullptr, + expected_output, expected_state); + + bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()); + if (!enable_webgpu) { + return; } - RunLinearAttentionChunkParallelTest( - query_data, key_data, value_data, - nullptr, &decay_data, nullptr, - batch_size, num_heads, seq_length, head_dim_k, head_dim_v, - "gated", 4, 0.0f, TensorType::kFloat); -} + OpTester tester("LinearAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("update_rule", std::string("linear")); + // Don't set scale — use default (0.0 triggers 1/sqrt(dk)) -TEST(ContribOpLinearAttentionTest, LinearAttentionChunkParallel_GatedDelta_WithInitialState) { - int batch_size = 1; - int num_heads = 2; - int seq_length = 16; - int head_dim_k = 4; - int head_dim_v = 4; - - int qkv_size = batch_size * num_heads * seq_length * head_dim_k; - int value_size = batch_size * num_heads * seq_length * head_dim_v; - int state_size = batch_size * num_heads * head_dim_k * head_dim_v; - int decay_size = batch_size * num_heads * seq_length * head_dim_k; - int beta_size = batch_size * num_heads * seq_length; - - std::vector query_data(qkv_size); - std::vector key_data(qkv_size); - std::vector value_data(value_size); - std::vector initial_state_data(state_size); - std::vector decay_data(decay_size); - std::vector beta_data(beta_size); - - for (int i = 0; i < qkv_size; ++i) { - query_data[i] = 0.1f * (i % 10 - 5); - key_data[i] = 0.1f * ((i + 3) % 10 - 5); - } - for (int i = 0; i < value_size; ++i) { - value_data[i] = 0.1f * ((i + 7) % 10 - 5); - } - for (int i = 0; i < state_size; ++i) { - initial_state_data[i] = 0.05f; - } - for (int i = 0; i < decay_size; ++i) { - decay_data[i] = -0.1f * ((i % 3) + 1); - } - for (int i = 0; i < beta_size; ++i) { - beta_data[i] = 0.5f; - } + std::vector qk_dims = {B, H, T, dk}; + std::vector v_dims = {B, H, T, dv}; + tester.AddInput("query", qk_dims, query); + tester.AddInput("key", qk_dims, key); + tester.AddInput("value", v_dims, value); + tester.AddOptionalInputEdge(); // initial_state + tester.AddOptionalInputEdge(); // decay + tester.AddOptionalInputEdge(); // beta + + std::vector out_dims = {B, H, T, dv}; + std::vector state_dims = {B, H, dk, dv}; + tester.AddOutput("output", out_dims, expected_output, false, 0.005f, 0.005f); + tester.AddOutput("final_state", state_dims, expected_state, false, 0.005f, 0.005f); - RunLinearAttentionChunkParallelTest( - query_data, key_data, value_data, - &initial_state_data, &decay_data, &beta_data, - batch_size, num_heads, seq_length, head_dim_k, head_dim_v, - "gated_delta", 8, 0.0f, TensorType::kFloat); + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -TEST(ContribOpLinearAttentionTest, LinearAttentionChunkParallel_LargerSequence) { - int batch_size = 2; - int num_heads = 4; - int seq_length = 64; - int head_dim_k = 8; - int head_dim_v = 8; - - int qkv_size = batch_size * num_heads * seq_length * head_dim_k; - int value_size = batch_size * num_heads * seq_length * head_dim_v; - int decay_size = batch_size * num_heads * seq_length * head_dim_k; - int beta_size = batch_size * num_heads * seq_length; - - std::vector query_data(qkv_size); - std::vector key_data(qkv_size); - std::vector value_data(value_size); - std::vector decay_data(decay_size); - std::vector beta_data(beta_size); - - for (int i = 0; i < qkv_size; ++i) { - query_data[i] = 0.05f * (i % 20 - 10); - key_data[i] = 0.05f * ((i + 7) % 20 - 10); - } - for (int i = 0; i < value_size; ++i) { - value_data[i] = 0.05f * ((i + 13) % 20 - 10); +// =========================================================================== +// Test: Longer sequence +// =========================================================================== +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_LongerSequence) { + const int B = 1, H = 2, T = 16, dk = 8, dv = 8; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + std::vector decay(B * H * T * dk); + std::vector beta(B * H * T); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = 0.1f * std::sin(static_cast(i) * 0.13f); + key[i] = 0.1f * std::cos(static_cast(i) * 0.17f); + decay[i] = -0.05f - 0.05f * std::abs(std::sin(static_cast(i) * 0.07f)); } - for (int i = 0; i < decay_size; ++i) { - decay_data[i] = -0.05f * ((i % 5) + 1); + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = 0.1f * std::sin(static_cast(i) * 0.23f + 0.5f); } - for (int i = 0; i < beta_size; ++i) { - beta_data[i] = 0.3f + 0.1f * (i % 5); + for (int i = 0; i < B * H * T; i++) { + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); } - RunLinearAttentionChunkParallelTest( - query_data, key_data, value_data, - nullptr, &decay_data, &beta_data, - batch_size, num_heads, seq_length, head_dim_k, head_dim_v, - "gated_delta", 16, 0.0f, TensorType::kFloat); - - // Also test FP16 - RunLinearAttentionChunkParallelTest( - query_data, key_data, value_data, - nullptr, &decay_data, &beta_data, - batch_size, num_heads, seq_length, head_dim_k, head_dim_v, - "gated_delta", 16, 0.0f, TensorType::kFloat16); + std::vector initial_state(B * H * dk * dv, 0.01f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); } } // namespace test From 95400bee91279c5ffc2fa8d7a3dd865d61c0d4fa Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Wed, 18 Mar 2026 15:35:02 -0700 Subject: [PATCH 13/27] allow for decay [B,H,T] --- .../webgpu/bert/linear_attention.cc | 26 +++++- .../webgpu/bert/linear_attention.h | 4 +- .../contrib_ops/linear_attention_op_test.cc | 83 +++++++++++++++++-- 3 files changed, 103 insertions(+), 10 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index e94fba5041385..b31def514f714 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -107,8 +107,17 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // Step 1: Apply decay (for gated and gated_delta modes) if (update_rule_ == LinearAttentionUpdateRule::Gated || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { shader.MainFunctionBody() - << "\n // Apply exponential decay: S *= exp(decay)\n" - << " let decay_base = (qkv_bh_offset + t) * uniforms.head_dim_k + dk_idx;\n" + << "\n // Apply exponential decay: S *= exp(decay)\n"; + if (decay_broadcast_dk_) { + // Decay shape is (B, H, T) — same decay for all dk rows + shader.MainFunctionBody() + << " let decay_base = qkv_bh_offset + t;\n"; + } else { + // Decay shape is (B, H, T, dk) — per-dk decay + shader.MainFunctionBody() + << " let decay_base = (qkv_bh_offset + t) * uniforms.head_dim_k + dk_idx;\n"; + } + shader.MainFunctionBody() << " let exp_g = exp(f32(decay[decay_base]));\n" << " for (var j = 0u; j < TILE_V; j++) {\n" << " state[j] *= exp_g;\n" @@ -290,7 +299,16 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { bool has_decay = decay != nullptr; bool has_beta = beta != nullptr; - LinearAttentionProgram program{update_rule_, has_initial_state, has_decay, has_beta, tile_v}; + // Detect whether decay is (B,H,T) or (B,H,T,dk) + bool decay_broadcast_dk = false; + if (has_decay) { + const auto& decay_shape = decay->Shape(); + if (decay_shape.NumDimensions() == 3) { + decay_broadcast_dk = true; + } + } + + LinearAttentionProgram program{update_rule_, has_initial_state, has_decay, has_beta, decay_broadcast_dk, tile_v}; program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, {key, ProgramTensorMetadataDependency::TypeAndRank}, @@ -311,7 +329,7 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { program.SetDispatchGroupSize(num_workgroups) .SetWorkgroupSize(workgroup_size) .CacheHint(std::to_string(static_cast(update_rule_)), - has_initial_state, has_decay, has_beta, tile_v) + has_initial_state, has_decay, has_beta, decay_broadcast_dk, tile_v) .AddUniformVariables({{static_cast(batch_size)}, {static_cast(num_heads)}, {static_cast(seq_length)}, diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h index a55a4d1801deb..bdd86e9d3d759 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h @@ -31,12 +31,13 @@ LinearAttentionUpdateRule ParseUpdateRule(const std::string& rule_str); class LinearAttentionProgram final : public Program { public: LinearAttentionProgram(LinearAttentionUpdateRule update_rule, bool has_initial_state, - bool has_decay, bool has_beta, int tile_v) + bool has_decay, bool has_beta, bool decay_broadcast_dk, int tile_v) : Program{"LinearAttention"}, update_rule_(update_rule), has_initial_state_(has_initial_state), has_decay_(has_decay), has_beta_(has_beta), + decay_broadcast_dk_(decay_broadcast_dk), tile_v_(tile_v) {} Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -55,6 +56,7 @@ class LinearAttentionProgram final : public Program { bool has_initial_state_; bool has_decay_; bool has_beta_; + bool decay_broadcast_dk_; int tile_v_; }; diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc index af10d85330048..74529dc5c6a04 100644 --- a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc @@ -27,6 +27,7 @@ void LinearAttentionReference( const std::vector* initial_state, const std::vector* decay, const std::vector* beta, + bool decay_broadcast_dk, std::vector& output, std::vector& final_state) { // State: (B, H, dk, dv) @@ -63,8 +64,14 @@ void LinearAttentionReference( // Step 1: Apply decay (gated, gated_delta) if (update_rule == "gated" || update_rule == "gated_delta") { for (int k = 0; k < head_dim_k; k++) { - int decay_idx = ((b * num_heads + h) * seq_length + t) * head_dim_k + k; - float exp_g = std::exp((*decay)[decay_idx]); + float exp_g; + if (decay_broadcast_dk) { + int decay_idx = (b * num_heads + h) * seq_length + t; + exp_g = std::exp((*decay)[decay_idx]); + } else { + int decay_idx = ((b * num_heads + h) * seq_length + t) * head_dim_k + k; + exp_g = std::exp((*decay)[decay_idx]); + } for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { final_state[state_offset(k, v_idx)] *= exp_g; } @@ -127,12 +134,14 @@ void RunLinearAttentionTest( const std::vector& value, const std::vector* initial_state, const std::vector* decay, - const std::vector* beta_data) { + const std::vector* beta_data, + bool decay_broadcast_dk = false) { // Compute reference output std::vector expected_output, expected_state; LinearAttentionReference(update_rule, batch_size, num_heads, seq_length, head_dim_k, head_dim_v, scale, query, key, value, initial_state, decay, beta_data, + decay_broadcast_dk, expected_output, expected_state); bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()); @@ -161,8 +170,13 @@ void RunLinearAttentionTest( // Optional: decay if (decay != nullptr) { - std::vector decay_dims = {batch_size, num_heads, seq_length, head_dim_k}; - tester.AddInput("decay", decay_dims, *decay); + if (decay_broadcast_dk) { + std::vector decay_dims = {batch_size, num_heads, seq_length}; + tester.AddInput("decay", decay_dims, *decay); + } else { + std::vector decay_dims = {batch_size, num_heads, seq_length, head_dim_k}; + tester.AddInput("decay", decay_dims, *decay); + } } else { tester.AddOptionalInputEdge(); } @@ -389,6 +403,64 @@ TEST(ContribOpLinearAttentionTest, GatedDeltaRule_MultiToken) { &initial_state, &decay, &beta); } +// =========================================================================== +// Test: Gated rule with B,H,T decay (broadcast across dk) +// =========================================================================== +TEST(ContribOpLinearAttentionTest, GatedRule_BroadcastDecay) { + const int B = 1, H = 1, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f, + 0.0f, -1.0f, 1.0f, 0.5f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f, + -0.5f, 1.0f, 0.5f, 0.0f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f, + 3.0f, 0.0f, 1.0f, 2.0f}; + // Decay shape: (B, H, T) = (1, 1, 3) — one scalar per token + std::vector decay = {-0.1f, -0.2f, -0.05f}; + + std::vector initial_state(dk * dv, 0.5f); + + RunLinearAttentionTest("gated", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, nullptr, + /*decay_broadcast_dk=*/true); +} + +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_BroadcastDecay) { + const int B = 1, H = 1, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query = { + 1.0f, 0.0f, 0.5f, -0.5f, + 0.5f, 1.0f, -0.5f, 0.0f, + 0.0f, -1.0f, 1.0f, 0.5f}; + std::vector key = { + 0.5f, 0.5f, 0.0f, 1.0f, + 1.0f, 0.0f, 1.0f, 0.5f, + -0.5f, 1.0f, 0.5f, 0.0f}; + std::vector value = { + 1.0f, 2.0f, 3.0f, 4.0f, + 2.0f, 1.0f, 0.0f, 3.0f, + 3.0f, 0.0f, 1.0f, 2.0f}; + // Decay shape: (B, H, T) = (1, 1, 3) — one scalar per token + std::vector decay = {-0.1f, -0.2f, -0.05f}; + std::vector beta = {0.8f, 0.6f, 0.9f}; + + std::vector initial_state(dk * dv, 0.5f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta, + /*decay_broadcast_dk=*/true); +} + // =========================================================================== // Test: Multi-batch, multi-head // =========================================================================== @@ -459,6 +531,7 @@ TEST(ContribOpLinearAttentionTest, LinearRule_DefaultScale) { std::vector expected_output, expected_state; LinearAttentionReference("linear", B, H, T, dk, dv, actual_scale, query, key, value, nullptr, nullptr, nullptr, + false, expected_output, expected_state); bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()); From 6c0c736e642ad02e62d74ef0d82aa6366c5bb2ab Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 23 Mar 2026 22:35:20 -0700 Subject: [PATCH 14/27] guard for head-dim_k --- .../webgpu/bert/linear_attention.cc | 39 +- .../contrib_ops/linear_attention_op_test.cc | 339 +++++++++++++++++- 2 files changed, 354 insertions(+), 24 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index b31def514f714..58435d6658d59 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -86,10 +86,12 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_initial_state_) { shader.MainFunctionBody() << "// Load initial state: initial_state[batch, head, dk_idx, dv_start..dv_start+TILE_V]\n" - << "let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" - << "for (var j = 0u; j < TILE_V; j++) {\n" - << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " state[j] = f32(initial_state[state_base + j]);\n" + << "if (dk_idx < uniforms.head_dim_k) {\n" + << " let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " if (dv_start + j < uniforms.head_dim_v) {\n" + << " state[j] = f32(initial_state[state_base + j]);\n" + << " }\n" << " }\n" << "}\n"; } @@ -100,9 +102,13 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { << "for (var t = 0u; t < uniforms.seq_length; t++) {\n" // Load k and q for this thread's dk row << " let qkv_bh_offset = (batch_idx * uniforms.num_heads + head_idx) * uniforms.seq_length;\n" - << " let k_base = (qkv_bh_offset + t) * uniforms.head_dim_k + dk_idx;\n" - << " let k_val = f32(key[k_base]);\n" - << " let q_val = f32(query[k_base]);\n"; + << " var k_val: f32 = 0.0;\n" + << " var q_val: f32 = 0.0;\n" + << " if (dk_idx < uniforms.head_dim_k) {\n" + << " let k_base = (qkv_bh_offset + t) * uniforms.head_dim_k + dk_idx;\n" + << " k_val = f32(key[k_base]);\n" + << " q_val = f32(query[k_base]);\n" + << " }\n"; // Step 1: Apply decay (for gated and gated_delta modes) if (update_rule_ == LinearAttentionUpdateRule::Gated || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { @@ -111,14 +117,17 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { if (decay_broadcast_dk_) { // Decay shape is (B, H, T) — same decay for all dk rows shader.MainFunctionBody() - << " let decay_base = qkv_bh_offset + t;\n"; + << " let exp_g = exp(f32(decay[qkv_bh_offset + t]));\n"; } else { // Decay shape is (B, H, T, dk) — per-dk decay + // For padding threads (dk_idx >= head_dim_k), use 0.0 (exp(0)=1, no decay) shader.MainFunctionBody() - << " let decay_base = (qkv_bh_offset + t) * uniforms.head_dim_k + dk_idx;\n"; + << " var exp_g: f32 = 1.0;\n" + << " if (dk_idx < uniforms.head_dim_k) {\n" + << " exp_g = exp(f32(decay[(qkv_bh_offset + t) * uniforms.head_dim_k + dk_idx]));\n" + << " }\n"; } shader.MainFunctionBody() - << " let exp_g = exp(f32(decay[decay_base]));\n" << " for (var j = 0u; j < TILE_V; j++) {\n" << " state[j] *= exp_g;\n" << " }\n"; @@ -205,10 +214,12 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // Write final state shader.MainFunctionBody() << "\n// Write final state\n" - << "let final_state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" - << "for (var j = 0u; j < TILE_V; j++) {\n" - << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " final_state[final_state_base + j] = output_element_t(state[j]);\n" + << "if (dk_idx < uniforms.head_dim_k) {\n" + << " let final_state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " if (dv_start + j < uniforms.head_dim_v) {\n" + << " final_state[final_state_base + j] = output_element_t(state[j]);\n" + << " }\n" << " }\n" << "}\n"; diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc index 74529dc5c6a04..824922a0dc2a8 100644 --- a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc @@ -10,6 +10,238 @@ using namespace onnxruntime::test; +namespace linear_attention { + +enum class UpdateRule { + kLinear, // S = S + k^T v + kGated, // S = exp(g) * S + k^T v + kDelta, // S = S + k^T ((v - S^T k) * beta) + kGatedDelta, // S = exp(g) * S + k^T ((v - exp(g) * S^T k) * beta) +}; + +/// 4D tensor accessor with (d0, d1, d2, d3) layout (row-major). +template +struct Tensor4D { + T* data; + int d0, d1, d2, d3; + + Tensor4D(T* data, int d0, int d1, int d2, int d3) + : data(data), d0(d0), d1(d1), d2(d2), d3(d3) {} + + T& operator()(int i0, int i1, int i2, int i3) { + return data[((i0 * d1 + i1) * d2 + i2) * d3 + i3]; + } + const T& operator()(int i0, int i1, int i2, int i3) const { + return data[((i0 * d1 + i1) * d2 + i2) * d3 + i3]; + } + + int size() const { return d0 * d1 * d2 * d3; } +}; + +/// 3D tensor accessor with (d0, d1, d2) layout (row-major). +template +struct Tensor3D { + T* data; + int d0, d1, d2; + + Tensor3D(T* data, int d0, int d1, int d2) + : data(data), d0(d0), d1(d1), d2(d2) {} + + T& operator()(int i0, int i1, int i2) { + return data[(i0 * d1 + i1) * d2 + i2]; + } + const T& operator()(int i0, int i1, int i2) const { + return data[(i0 * d1 + i1) * d2 + i2]; + } + + int size() const { return d0 * d1 * d2; } +}; + +/// Compute one recurrence step for a single (batch, head) pair. +/// +/// state: (d_k, d_v) — carried across time steps (modified in place) +/// q_t: (d_k,) — query for this step +/// k_t: (d_k,) — key for this step +/// v_t: (d_v,) — value for this step +/// decay_t: scalar — log-space decay gate +/// beta_t: scalar — update rate +/// output_t: (d_v,) — output for this step (written) +template +inline void recurrence_step( + T* state, // (d_k, d_v), modified in place + const T* q_t, // (d_k,) + const T* k_t, // (d_k,) + const T* v_t, // (d_v,) + T decay_t, // scalar (log-space) + T beta_t, // scalar + T* output_t, // (d_v,), written + int d_k, + int d_v, + UpdateRule rule) { + const bool uses_decay = (rule == UpdateRule::kGated || rule == UpdateRule::kGatedDelta); + const bool uses_beta = (rule == UpdateRule::kDelta || rule == UpdateRule::kGatedDelta); + + // 1. State decay: state *= exp(decay) + T g_exp = T(1); + if (uses_decay) { + g_exp = std::exp(decay_t); + for (int i = 0; i < d_k * d_v; ++i) { + state[i] *= g_exp; + } + } + + // 2. Retrieval: retrieval[j] = sum_i k[i] * state[i, j] (k @ state) + // This computes k_row (1, d_k) @ state (d_k, d_v) -> (1, d_v) + std::vector retrieval(d_v, T(0)); + for (int i = 0; i < d_k; ++i) { + for (int j = 0; j < d_v; ++j) { + retrieval[j] += k_t[i] * state[i * d_v + j]; + } + } + + // 3. Compute delta + std::vector delta(d_v); + if (uses_beta) { + // delta = (v - retrieval) * beta + for (int j = 0; j < d_v; ++j) { + delta[j] = (v_t[j] - retrieval[j]) * beta_t; + } + } else { + // delta = v + std::copy(v_t, v_t + d_v, delta.data()); + } + + // 4. State update: state += k^T @ delta (outer product) + // state[i, j] += k[i] * delta[j] + for (int i = 0; i < d_k; ++i) { + for (int j = 0; j < d_v; ++j) { + state[i * d_v + j] += k_t[i] * delta[j]; + } + } + + // 5. Output: output[j] = sum_i q[i] * new_state[i, j] (q @ new_state) + std::fill(output_t, output_t + d_v, T(0)); + for (int i = 0; i < d_k; ++i) { + for (int j = 0; j < d_v; ++j) { + output_t[j] += q_t[i] * state[i * d_v + j]; + } + } +} + +/// Expand Q/K heads for GQA: repeat each head `ratio` times. +/// +/// src: (B, H_kv, T, d) +/// dst: (B, H, T, d) where H = H_kv * ratio +template +inline void expand_kv_heads( + const T* src, T* dst, + int B, int H_kv, int T_len, int d, int ratio) { + if (ratio == 1) { + std::memcpy(dst, src, B * H_kv * T_len * d * sizeof(T)); + return; + } + for (int b = 0; b < B; ++b) { + for (int h_kv = 0; h_kv < H_kv; ++h_kv) { + const T* src_head = src + ((b * H_kv + h_kv) * T_len) * d; + for (int r = 0; r < ratio; ++r) { + int h = h_kv * ratio + r; + T* dst_head = dst + ((b * (H_kv * ratio) + h) * T_len) * d; + std::memcpy(dst_head, src_head, T_len * d * sizeof(T)); + } + } + } +} + +/// Run the full LinearAttention operator. +/// +/// query: (B, H_kv, T, d_k) — pre-scaled by 1/sqrt(d_k) +/// key: (B, H_kv, T, d_k) — L2-normalized +/// value: (B, H, T, d_v) +/// past_state: (B, H, d_k, d_v) — recurrent state from previous chunk +/// decay: (B, H, T) — log-space decay gate +/// beta: (B, H, T) — sigmoid update rate +/// output: (B, H, T, d_v) — attention output [written] +/// present_state: (B, H, d_k, d_v) — updated state [written] +/// +/// H must be divisible by H_kv (GQA ratio = H / H_kv). +template +void linear_attention_forward( + const T* query, // (B, H_kv, T, d_k) + const T* key, // (B, H_kv, T, d_k) + const T* value, // (B, H, T, d_v) + const T* past_state, // (B, H, d_k, d_v) + const T* decay, // (B, H, T) + const T* beta, // (B, H, T) + T* output, // (B, H, T, d_v) + T* present_state, // (B, H, d_k, d_v) + int B, + int H_kv, + int H, + int T_len, + int d_k, + int d_v, + float scale, + UpdateRule rule = UpdateRule::kGatedDelta) { + assert(H % H_kv == 0 && "H must be divisible by H_kv for GQA"); + const int ratio = H / H_kv; + + // --- GQA: expand Q/K heads to match V head count --- + std::vector q_expanded(B * H * T_len * d_k); + std::vector k_expanded(B * H * T_len * d_k); + expand_kv_heads(query, q_expanded.data(), B, H_kv, T_len, d_k, ratio); + expand_kv_heads(key, k_expanded.data(), B, H_kv, T_len, d_k, ratio); + + // Accessors for expanded Q/K + Tensor4D Q(q_expanded.data(), B, H, T_len, d_k); + Tensor4D K(k_expanded.data(), B, H, T_len, d_k); + Tensor4D V(value, B, H, T_len, d_v); + Tensor3D D(decay, B, H, T_len); + Tensor4D O(output, B, H, T_len, d_v); + + // Copy past_state into present_state (we'll modify it in place) + const int state_size = B * H * d_k * d_v; + std::memcpy(present_state, past_state, state_size * sizeof(T)); + Tensor4D S(present_state, B, H, d_k, d_v); + + // --- Sequential recurrence over time --- + for (int b = 0; b < B; ++b) { + for (int h = 0; h < H; ++h) { + T* state_bh = &S(b, h, 0, 0); // (d_k, d_v) slice + + for (int t = 0; t < T_len; ++t) { + recurrence_step( + state_bh, + &Q(b, h, t, 0), + &K(b, h, t, 0), + &V(b, h, t, 0), + D(b, h, t), + beta ? beta[((b * H + h) * T_len) + t] : T(0), + &O(b, h, t, 0), + d_k, d_v, + rule); + } + } + } + // linear_attention_forward doesn't apply scale; apply it here. + int output_size = B * H * T_len * d_v; + for (int i = 0; i < output_size; ++i) { + output[i] *= scale; + } + +} + +/// Parse a string into an UpdateRule enum. +inline UpdateRule parse_update_rule(const std::string& s) { + if (s == "linear") return UpdateRule::kLinear; + if (s == "gated") return UpdateRule::kGated; + if (s == "delta") return UpdateRule::kDelta; + if (s == "gated_delta") return UpdateRule::kGatedDelta; + assert(false && "Unknown update_rule"); + return UpdateRule::kGatedDelta; +} + +} // namespace linear_attention + namespace onnxruntime { namespace test { @@ -27,10 +259,39 @@ void LinearAttentionReference( const std::vector* initial_state, const std::vector* decay, const std::vector* beta, - bool decay_broadcast_dk, std::vector& output, std::vector& final_state) { - // State: (B, H, dk, dv) + + int bht = batch_size * num_heads * seq_length; + bool decay_broadcast_dk = (decay != nullptr && static_cast(decay->size()) == bht); + + if (false && decay_broadcast_dk) { + output.resize(batch_size * num_heads * seq_length * head_dim_v, 0.0f); + final_state.resize(batch_size * num_heads * head_dim_k * head_dim_v, 0.0f); + if (initial_state != nullptr) { + final_state = *initial_state; + } + linear_attention::linear_attention_forward( + query.data(), // (B, H_kv, T, d_k) + key.data(), // (B, H_kv, T, d_k) + value.data(), // (B, H, T, d_v) + final_state.data(), // (B, H, d_k, d_v) + decay->data(), // (B, H, T) + beta ? beta->data() : nullptr, // (B, H, T) + output.data(), // (B, H, T, d_v) + final_state.data(), // (B, H, d_k, d_v) + batch_size, + num_heads, + num_heads, + seq_length, + head_dim_k, + head_dim_v, + scale, + linear_attention::parse_update_rule(update_rule)); + return; + } + + // State: (B, H, dk, dv) final_state.resize(batch_size * num_heads * head_dim_k * head_dim_v, 0.0f); output.resize(batch_size * num_heads * seq_length * head_dim_v, 0.0f); @@ -134,14 +395,12 @@ void RunLinearAttentionTest( const std::vector& value, const std::vector* initial_state, const std::vector* decay, - const std::vector* beta_data, - bool decay_broadcast_dk = false) { + const std::vector* beta_data) { // Compute reference output std::vector expected_output, expected_state; LinearAttentionReference(update_rule, batch_size, num_heads, seq_length, head_dim_k, head_dim_v, scale, query, key, value, initial_state, decay, beta_data, - decay_broadcast_dk, expected_output, expected_state); bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()); @@ -149,6 +408,8 @@ void RunLinearAttentionTest( return; } + int bht = batch_size * num_heads * seq_length; + bool decay_broadcast_dk = (decay != nullptr && static_cast(decay->size()) == bht); OpTester tester("LinearAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("update_rule", update_rule); tester.AddAttribute("scale", scale); @@ -429,8 +690,7 @@ TEST(ContribOpLinearAttentionTest, GatedRule_BroadcastDecay) { RunLinearAttentionTest("gated", B, H, T, dk, dv, scale, query, key, value, - &initial_state, &decay, nullptr, - /*decay_broadcast_dk=*/true); + &initial_state, &decay, nullptr); } TEST(ContribOpLinearAttentionTest, GatedDeltaRule_BroadcastDecay) { @@ -457,8 +717,7 @@ TEST(ContribOpLinearAttentionTest, GatedDeltaRule_BroadcastDecay) { RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, query, key, value, - &initial_state, &decay, &beta, - /*decay_broadcast_dk=*/true); + &initial_state, &decay, &beta); } // =========================================================================== @@ -531,7 +790,6 @@ TEST(ContribOpLinearAttentionTest, LinearRule_DefaultScale) { std::vector expected_output, expected_state; LinearAttentionReference("linear", B, H, T, dk, dv, actual_scale, query, key, value, nullptr, nullptr, nullptr, - false, expected_output, expected_state); bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()); @@ -594,5 +852,66 @@ TEST(ContribOpLinearAttentionTest, GatedDeltaRule_LongerSequence) { &initial_state, &decay, &beta); } +// Test with Qwen3.5-like dimensions: dk=128, dv=128, broadcast decay +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_Qwen35Like) { + const int B = 1, H = 2, T = 8, dk = 128, dv = 128; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + // Broadcast decay: (B, H, T) — one scalar per head per token, like Qwen3.5 + std::vector decay(B * H * T); + std::vector beta(B * H * T); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = 0.05f * std::sin(static_cast(i) * 0.013f); + key[i] = 0.05f * std::cos(static_cast(i) * 0.017f); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = 0.05f * std::sin(static_cast(i) * 0.023f + 0.5f); + } + for (int i = 0; i < B * H * T; i++) { + decay[i] = -0.1f - 0.05f * std::abs(std::sin(static_cast(i) * 0.3f)); + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * H * dk * dv, 0.01f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// Test with non-power-of-2 dk to trigger workgroup padding bug +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_NonPowerOf2DK) { + const int B = 1, H = 1, T = 3, dk = 3, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + std::vector decay(B * H * T); + std::vector beta(B * H * T); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = 0.5f * std::sin(static_cast(i) * 0.3f); + key[i] = 0.5f * std::cos(static_cast(i) * 0.5f); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = 0.5f * std::sin(static_cast(i) * 0.7f + 1.0f); + } + for (int i = 0; i < B * H * T; i++) { + decay[i] = -0.1f - 0.05f * std::abs(std::sin(static_cast(i) * 0.3f)); + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * H * dk * dv, 0.5f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + } // namespace test } // namespace onnxruntime From 28a3ee34e47c04a4c1ccfeddabc855dba4ef56c4 Mon Sep 17 00:00:00 2001 From: gs Date: Tue, 24 Mar 2026 12:11:59 -0700 Subject: [PATCH 15/27] qwen3.5 shows now correct results --- .../contrib_ops/webgpu/bert/linear_attention.cc | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index 58435d6658d59..10986879c6d00 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -247,6 +247,20 @@ LinearAttention::LinearAttention(const OpKernelInfo& info) chunk_size_ = info.GetAttrOrDefault("chunk_size", 64); } + +/* + Inputs: + query: (B, H_kv, T, d_k) — query (may have fewer heads than value for GQA; pre-scaled by 1/sqrt(d_k)) + key: (B, H_kv, T, d_k) — key (L2-normalized) + value: (B, H, T, d_v) — value (H >= H_kv) + initial_state: (B, H, d_k, d_v) — recurrent state + decay: (B, H, T) — exponential decay gate (log-space) + beta: (B, H, T) — update rate (sigmoid output) + + Outputs: + output: (B, H, T, d_v) — attention output + present_state: (B, H, d_k, d_v) — updated recurrent state +*/ Status LinearAttention::ComputeInternal(ComputeContext& context) const { const Tensor* query = context.Input(0); const Tensor* key = context.Input(1); @@ -276,7 +290,7 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { // Compute scale float scale = scale_; if (scale == 0.0f) { - scale = 1.0f / std::sqrt(static_cast(head_dim_k)); + scale = 1.0f; // TODO: should this come / std::sqrt(static_cast(head_dim_k)); } // Allocate outputs From 3f80587b99c7d26f94d8e68a69e1cb74d0b05c93 Mon Sep 17 00:00:00 2001 From: gs Date: Wed, 25 Mar 2026 12:14:50 -0700 Subject: [PATCH 16/27] remove chunk and group from signature --- .../contrib_ops/webgpu/bert/linear_attention.cc | 1 - .../contrib_ops/webgpu/bert/linear_attention.h | 1 - onnxruntime/core/graph/contrib_ops/bert_defs.cc | 14 +++----------- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index 10986879c6d00..ee7567ccad3f2 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -244,7 +244,6 @@ LinearAttention::LinearAttention(const OpKernelInfo& info) std::string update_rule_str = info.GetAttrOrDefault("update_rule", "gated_delta"); update_rule_ = ParseUpdateRule(update_rule_str); scale_ = info.GetAttrOrDefault("scale", 0.0f); - chunk_size_ = info.GetAttrOrDefault("chunk_size", 64); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h index bdd86e9d3d759..5ecb736ff1c1c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h @@ -69,7 +69,6 @@ class LinearAttention : public WebGpuKernel { protected: LinearAttentionUpdateRule update_rule_; float scale_; - int64_t chunk_size_; }; } // namespace webgpu diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 3477b5e445135..fb9922a1b68d7 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -2242,10 +2242,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Default is 'silu'.", AttributeProto::STRING, std::string("silu")) - .Attr("group", - "group for convolution. Default is 1, which means normal convolution. When group equals to input channels, it becomes depthwise convolution.", - AttributeProto::INT, - static_cast(1)) .Input(0, "input", "Input tensor with shape (batch_size, channels, length). Channels-first layout.", @@ -2302,7 +2298,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( })); constexpr const char* LinearAttention_ver1_doc = R"DOC( -Linear Attention operator (chunk-parallel). +Linear Attention operator. Processes a sequence of tokens using linear attention with a recurrent state matrix. When sequence_length=1, this is equivalent to a single recurrent decode step. @@ -2333,10 +2329,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Output scaling factor. When 0.0 (default), uses 1/sqrt(d_k) where d_k is the key dimension.", AttributeProto::FLOAT, 0.0f) - .Attr("chunk_size", - "Chunk size for parallel computation. Only a hint for the implementation.", - AttributeProto::INT, - static_cast(64)) .Input(0, "query", "Query vectors with shape (batch_size, num_heads, sequence_length, head_dim_k)", @@ -2359,14 +2351,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(4, "decay", "Exponential decay gate in log-space with shape broadcastable to " - "(batch_size, num_heads, sequence_length, head_dim_k). " + "(batch_size, num_heads, sequence_length, [head_dim_k]). " "Required for 'gated' and 'gated_delta' modes.", "T", OpSchema::Optional) .Input(5, "beta", "Update rate (sigmoid output) with shape broadcastable to " - "(batch_size, num_heads, sequence_length, 1). " + "(batch_size, num_heads, sequence_length, [1]). " "Required for 'delta' and 'gated_delta' modes.", "T", OpSchema::Optional) From 4111b954821a24d8b42759199f39c845874c1be3 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Sat, 28 Mar 2026 11:12:56 -0700 Subject: [PATCH 17/27] update to latest signature proposal --- .../webgpu/bert/causal_conv1d_with_state.cc | 20 +- .../webgpu/bert/causal_conv1d_with_state.h | 18 +- .../webgpu/bert/linear_attention.cc | 111 +++--- .../webgpu/bert/linear_attention.h | 2 + .../webgpu/webgpu_contrib_kernels.cc | 4 +- .../core/graph/contrib_ops/bert_defs.cc | 160 +++++--- onnxruntime/core/graph/contrib_ops/ms_opset.h | 4 +- .../causal_conv1d_with_state_op_test.cc | 94 ++--- .../contrib_ops/linear_attention_op_test.cc | 364 ++++-------------- 9 files changed, 318 insertions(+), 459 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc b/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc index fe66579a3b42e..a865d40b1c1ef 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc @@ -19,29 +19,29 @@ CausalConv1DActivation ParseCausalConv1DActivation(const std::string& activation } else if (activation_str == "none" || activation_str.empty()) { return CausalConv1DActivation::None; } - ORT_THROW("Unknown activation for CausalConv1DWithState: ", activation_str); + ORT_THROW("Unknown activation for CausalConvWithState: ", activation_str); } // ============================================================================= -// CausalConv1DWithState Implementation +// CausalConvWithState Implementation // ============================================================================= ONNX_OPERATOR_KERNEL_EX( - CausalConv1DWithState, + CausalConvWithState, kMSDomain, 1, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", WebGpuSupportedFloatTypes()), - CausalConv1DWithState); + CausalConvWithState); -CausalConv1DWithState::CausalConv1DWithState(const OpKernelInfo& info) +CausalConvWithState::CausalConvWithState(const OpKernelInfo& info) : WebGpuKernel(info) { - std::string activation_str = info.GetAttrOrDefault("activation", "silu"); + std::string activation_str = info.GetAttrOrDefault("activation", "none"); activation_ = ParseCausalConv1DActivation(activation_str); } -Status CausalConv1DWithStateProgram::GenerateShaderCode(ShaderHelper& shader) const { +Status CausalConvWithStateProgram::GenerateShaderCode(ShaderHelper& shader) const { // Input tensors const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); const auto& weight = shader.AddInput("weight", ShaderUsage::UseUniform); @@ -208,11 +208,11 @@ fn silu(x: input_element_t) -> input_element_t { return Status::OK(); } -Status CausalConv1DWithState::ComputeInternal(ComputeContext& context) const { +Status CausalConvWithState::ComputeInternal(ComputeContext& context) const { const Tensor* input = context.Input(0); // (B, D, L) const Tensor* weight = context.Input(1); // (D, 1, K) const Tensor* bias = context.Input(2); // optional (D,) - const Tensor* conv_state = context.Input(3); // optional (B, D, K-1) + const Tensor* conv_state = context.Input(3); // optional (B, D, K-1) — past_state ORT_RETURN_IF(input == nullptr, "Input tensor must not be null"); ORT_RETURN_IF(weight == nullptr, "Weight tensor must not be null"); @@ -270,7 +270,7 @@ Status CausalConv1DWithState::ComputeInternal(ComputeContext& context) const { } // Create and run the shader program - CausalConv1DWithStateProgram program{activation_, has_bias, has_conv_state, kernel_size}; + CausalConvWithStateProgram program{activation_, has_bias, has_conv_state, kernel_size}; uint32_t output_size = static_cast(batch_size * channels * input_length); diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.h b/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.h index ccbb22d9de7d4..820315f838e3d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.h +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.h @@ -15,7 +15,7 @@ namespace webgpu { using namespace onnxruntime::webgpu; using onnxruntime::webgpu::ComputeContext; -// Activation mode for CausalConv1DWithState +// Activation mode for CausalConvWithState enum class CausalConv1DActivation { None, Silu, @@ -23,12 +23,12 @@ enum class CausalConv1DActivation { CausalConv1DActivation ParseCausalConv1DActivation(const std::string& activation_str); -// Program for CausalConv1DWithState -class CausalConv1DWithStateProgram final : public Program { +// Program for CausalConvWithState +class CausalConvWithStateProgram final : public Program { public: - CausalConv1DWithStateProgram(CausalConv1DActivation activation, bool has_bias, bool has_conv_state, - int kernel_size) - : Program{"CausalConv1DWithState"}, + CausalConvWithStateProgram(CausalConv1DActivation activation, bool has_bias, bool has_conv_state, + int kernel_size) + : Program{"CausalConvWithState"}, activation_(activation), has_bias_(has_bias), has_conv_state_(has_conv_state), @@ -51,10 +51,10 @@ class CausalConv1DWithStateProgram final : public Program;\n" << "for (var j = 0u; j < TILE_V; j++) {\n" @@ -100,14 +104,14 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << "\n// Process each token sequentially\n" << "for (var t = 0u; t < uniforms.seq_length; t++) {\n" - // Load k and q for this thread's dk row - << " let qkv_bh_offset = (batch_idx * uniforms.num_heads + head_idx) * uniforms.seq_length;\n" + // 3D packed indexing: (B, T, H*D) — bt_offset indexes (batch, token) pair + << " let bt_offset = batch_idx * uniforms.seq_length + t;\n" << " var k_val: f32 = 0.0;\n" << " var q_val: f32 = 0.0;\n" << " if (dk_idx < uniforms.head_dim_k) {\n" - << " let k_base = (qkv_bh_offset + t) * uniforms.head_dim_k + dk_idx;\n" - << " k_val = f32(key[k_base]);\n" - << " q_val = f32(query[k_base]);\n" + << " let qk_idx = bt_offset * packed_dk + head_idx * uniforms.head_dim_k + dk_idx;\n" + << " k_val = f32(key[qk_idx]);\n" + << " q_val = f32(query[qk_idx]);\n" << " }\n"; // Step 1: Apply decay (for gated and gated_delta modes) @@ -115,16 +119,16 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << "\n // Apply exponential decay: S *= exp(decay)\n"; if (decay_broadcast_dk_) { - // Decay shape is (B, H, T) — same decay for all dk rows + // Decay shape is (B, T, H_kv) — same decay for all dk rows shader.MainFunctionBody() - << " let exp_g = exp(f32(decay[qkv_bh_offset + t]));\n"; + << " let exp_g = exp(f32(decay[bt_offset * uniforms.num_heads + head_idx]));\n"; } else { - // Decay shape is (B, H, T, dk) — per-dk decay + // Decay shape is (B, T, H_kv * dk) — per-dk decay // For padding threads (dk_idx >= head_dim_k), use 0.0 (exp(0)=1, no decay) shader.MainFunctionBody() << " var exp_g: f32 = 1.0;\n" << " if (dk_idx < uniforms.head_dim_k) {\n" - << " exp_g = exp(f32(decay[(qkv_bh_offset + t) * uniforms.head_dim_k + dk_idx]));\n" + << " exp_g = exp(f32(decay[bt_offset * packed_dk + head_idx * uniforms.head_dim_k + dk_idx]));\n" << " }\n"; } shader.MainFunctionBody() @@ -152,8 +156,8 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { << " }\n" // Thread 0 computes delta and broadcasts via shared memory << " // Compute delta = beta * (v - retrieved) and broadcast\n" - << " let v_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.seq_length + t) * uniforms.head_dim_v + dv_start;\n" - << " let beta_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.seq_length + t);\n" + << " let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start;\n" + << " let beta_base = bt_offset * uniforms.num_heads + head_idx;\n" << " if (dk_idx == 0u) {\n" << " let beta_val = f32(beta[beta_base]);\n" << " for (var j = 0u; j < TILE_V; j++) {\n" @@ -175,7 +179,7 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { // For linear and gated modes: S += k ⊗ v (no delta rule) shader.MainFunctionBody() << "\n // Update state: S += k ⊗ v\n" - << " let v_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.seq_length + t) * uniforms.head_dim_v + dv_start;\n" + << " let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start;\n" << " for (var j = 0u; j < TILE_V; j++) {\n" << " if (dv_start + j < uniforms.head_dim_v) {\n" << " let v_val = f32(value[v_base + j]);\n" @@ -201,7 +205,7 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { << " }\n" // Thread 0 writes the output for this token and dv_tile << " if (dk_idx == 0u) {\n" - << " let out_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.seq_length + t) * uniforms.head_dim_v + dv_start;\n" + << " let out_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start;\n" << " for (var j = 0u; j < TILE_V; j++) {\n" << " if (dv_start + j < uniforms.head_dim_v) {\n" << " output[out_base + j] = output_element_t(reduction_buf[j * workgroup_size_x] * uniforms.scale);\n" @@ -211,14 +215,14 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { << " workgroupBarrier();\n" << "}\n"; // end token loop - // Write final state + // Write final state (4D: B, H_kv, dk, dv) shader.MainFunctionBody() - << "\n// Write final state\n" + << "\n// Write present_state\n" << "if (dk_idx < uniforms.head_dim_k) {\n" - << " let final_state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" + << " let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" << " for (var j = 0u; j < TILE_V; j++) {\n" << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " final_state[final_state_base + j] = output_element_t(state[j]);\n" + << " present_state[state_base + j] = output_element_t(state[j]);\n" << " }\n" << " }\n" << "}\n"; @@ -244,39 +248,52 @@ LinearAttention::LinearAttention(const OpKernelInfo& info) std::string update_rule_str = info.GetAttrOrDefault("update_rule", "gated_delta"); update_rule_ = ParseUpdateRule(update_rule_str); scale_ = info.GetAttrOrDefault("scale", 0.0f); + q_num_heads_ = static_cast(info.GetAttr("q_num_heads")); + kv_num_heads_ = static_cast(info.GetAttr("kv_num_heads")); } /* - Inputs: - query: (B, H_kv, T, d_k) — query (may have fewer heads than value for GQA; pre-scaled by 1/sqrt(d_k)) - key: (B, H_kv, T, d_k) — key (L2-normalized) - value: (B, H, T, d_v) — value (H >= H_kv) - initial_state: (B, H, d_k, d_v) — recurrent state - decay: (B, H, T) — exponential decay gate (log-space) - beta: (B, H, T) — update rate (sigmoid output) + 3D packed inputs: + query: (B, T, H_q * d_k) — packed query + key: (B, T, H_kv * d_k) — packed key + value: (B, T, H_kv * d_v) — packed value + past_state: (B, H_kv, d_k, d_v) — recurrent state (4D) + decay: (B, T, H_kv * d_k) or (B, T, H_kv) — decay gate (3D) + beta: (B, T, H_kv) or (B, T, 1) — update rate (3D) Outputs: - output: (B, H, T, d_v) — attention output - present_state: (B, H, d_k, d_v) — updated recurrent state + output: (B, T, H_q * d_v) — packed attention output + present_state: (B, H_kv, d_k, d_v) — updated recurrent state (4D) */ Status LinearAttention::ComputeInternal(ComputeContext& context) const { const Tensor* query = context.Input(0); const Tensor* key = context.Input(1); const Tensor* value = context.Input(2); - const Tensor* initial_state = context.Input(3); // optional - const Tensor* decay = context.Input(4); // optional - const Tensor* beta = context.Input(5); // optional + const Tensor* past_state = context.Input(3); // optional + const Tensor* decay = context.Input(4); // optional + const Tensor* beta = context.Input(5); // optional - // Validate inputs + // Validate 3D packed inputs const auto& q_shape = query->Shape(); - ORT_RETURN_IF(q_shape.NumDimensions() != 4, "query must be 4D (B, H, T, dk)"); + ORT_RETURN_IF(q_shape.NumDimensions() != 3, "query must be 3D (B, T, H_q*d_k)"); const int batch_size = static_cast(q_shape[0]); - const int num_heads = static_cast(q_shape[1]); - const int seq_length = static_cast(q_shape[2]); - const int head_dim_k = static_cast(q_shape[3]); - const int head_dim_v = static_cast(value->Shape()[3]); + const int seq_length = static_cast(q_shape[1]); + const int q_packed_dim = static_cast(q_shape[2]); + const int num_heads = kv_num_heads_; + + ORT_RETURN_IF(q_num_heads_ != kv_num_heads_, + "GQA (q_num_heads != kv_num_heads) is not yet supported"); + + const int head_dim_k = q_packed_dim / q_num_heads_; + ORT_RETURN_IF(q_packed_dim != head_dim_k * q_num_heads_, + "query packed dim must be divisible by q_num_heads"); + + const int v_packed_dim = static_cast(value->Shape()[2]); + const int head_dim_v = v_packed_dim / kv_num_heads_; + ORT_RETURN_IF(v_packed_dim != head_dim_v * kv_num_heads_, + "value packed dim must be divisible by kv_num_heads"); // Validate update rule has required inputs bool needs_decay = (update_rule_ == LinearAttentionUpdateRule::Gated || @@ -286,18 +303,18 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { ORT_RETURN_IF(needs_decay && decay == nullptr, "decay input required for gated/gated_delta update rules"); ORT_RETURN_IF(needs_beta && beta == nullptr, "beta input required for delta/gated_delta update rules"); - // Compute scale + // Compute scale: 0.0 means derive from d_k float scale = scale_; if (scale == 0.0f) { - scale = 1.0f; // TODO: should this come / std::sqrt(static_cast(head_dim_k)); + scale = 1.0f / std::sqrt(static_cast(head_dim_k)); } - // Allocate outputs - TensorShapeVector output_shape({batch_size, num_heads, seq_length, head_dim_v}); + // Allocate outputs — output is 3D packed, state is 4D + TensorShapeVector output_shape({batch_size, seq_length, q_num_heads_ * head_dim_v}); Tensor* output = context.Output(0, output_shape); TensorShapeVector state_shape({batch_size, num_heads, head_dim_k, head_dim_v}); - Tensor* final_state = context.Output(1, state_shape); + Tensor* present_state = context.Output(1, state_shape); // Choose tile size: balance parallelism vs shared memory // TILE_V * WORKGROUP_SIZE * 4 bytes must fit in shared memory (typically 16KB limit) @@ -319,15 +336,17 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { const uint32_t num_workgroups = batch_size * num_heads * num_dv_tiles; - bool has_initial_state = initial_state != nullptr; + bool has_initial_state = past_state != nullptr; bool has_decay = decay != nullptr; bool has_beta = beta != nullptr; - // Detect whether decay is (B,H,T) or (B,H,T,dk) + // Detect whether decay is (B,T,H_kv) or (B,T,H_kv*dk) bool decay_broadcast_dk = false; if (has_decay) { const auto& decay_shape = decay->Shape(); - if (decay_shape.NumDimensions() == 3) { + // (B, T, H_kv) = 3D with last dim == num_heads + int decay_last_dim = static_cast(decay_shape[decay_shape.NumDimensions() - 1]); + if (decay_last_dim == num_heads) { decay_broadcast_dk = true; } } @@ -338,7 +357,7 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { {key, ProgramTensorMetadataDependency::TypeAndRank}, {value, ProgramTensorMetadataDependency::TypeAndRank}}); if (has_initial_state) { - program.AddInput({initial_state, ProgramTensorMetadataDependency::TypeAndRank}); + program.AddInput({past_state, ProgramTensorMetadataDependency::TypeAndRank}); } if (has_decay) { program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); @@ -348,7 +367,7 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { } program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, - {final_state, ProgramTensorMetadataDependency::TypeAndRank}}); + {present_state, ProgramTensorMetadataDependency::TypeAndRank}}); program.SetDispatchGroupSize(num_workgroups) .SetWorkgroupSize(workgroup_size) diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h index 5ecb736ff1c1c..2b02856440f39 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h @@ -69,6 +69,8 @@ class LinearAttention : public WebGpuKernel { protected: LinearAttentionUpdateRule update_rule_; float scale_; + int q_num_heads_; + int kv_num_heads_; }; } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 67b71ed85d69a..3f13b91adbefd 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -14,7 +14,7 @@ namespace webgpu { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, CausalConv1DWithState); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, CausalConvWithState); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu); @@ -47,7 +47,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index fb9922a1b68d7..2feed9b30a9c6 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -2217,58 +2217,67 @@ ONNX_MS_OPERATOR_SET_SCHEMA( } })); -constexpr const char* CausalConv1DWithState_ver1_doc = R"DOC( -Depthwise causal 1D convolution with carry state for incremental decoding. +constexpr const char* CausalConvWithState_ver1_doc = R"DOC( +Stateful causal depthwise convolution, generalized to N spatial dimensions. Used by Gated DeltaNet (Qwen3.5) and Mamba (Jamba, FalconMamba) as a preprocessing step. Replaces the 3-op pattern (Concat + Conv + Slice) with a single fused operation. -The convolution is causal (looks only at current and past positions) and depthwise -(each channel is convolved independently with its own kernel). +The convolution is causal (looks only at current and past positions along the last +spatial dimension) and depthwise (each channel is convolved independently with its own kernel). -Input layout is channels-first: (batch_size, channels, length). -Weight layout: (channels, 1, kernel_size) for depthwise convolution. -Conv state carries the last (kernel_size - 1) input values for incremental decode. +Input layout is channels-first: (batch_size, channels, ...). +Weight layout: (channels, 1, k_1, ...) for depthwise convolution. +The carry state stores the last (k-1) positions along the causal axis for incremental decode. + +The ndim attribute generalizes the op to 1D, 2D, or 3D spatial dimensions. Causality is +enforced on the last spatial dimension only. The optional activation attribute supports fused SiLU/Swish activation. )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( - CausalConv1DWithState, 1, + CausalConvWithState, 1, OpSchema() - .SetDoc(CausalConv1DWithState_ver1_doc) + .SetDoc(CausalConvWithState_ver1_doc) .Attr("activation", "Fused activation function. One of: 'silu', 'swish', 'none'. " - "Default is 'silu'.", + "Default is 'none'.", AttributeProto::STRING, - std::string("silu")) + std::string("none")) + .Attr("ndim", + "Spatial dimensionality: 1, 2, or 3. Default is 1.", + AttributeProto::INT, + static_cast(1)) .Input(0, "input", - "Input tensor with shape (batch_size, channels, length). Channels-first layout.", + "Input tensor with shape (batch_size, channels, ...). Channels-first layout. " + "Spatial dims: 1D: (L,); 2D: (H, W); 3D: (D, H, W).", "T") .Input(1, "weight", - "Depthwise convolution weights with shape (channels, 1, kernel_size).", + "Depthwise convolution kernel with shape (channels, 1, k_1, ...). " + "Spatial kernel sizes: (k_1, ..., k_ndim).", "T") .Input(2, "bias", - "Optional bias with shape (channels).", + "Optional per-channel bias with shape (channels).", "T", OpSchema::Optional) .Input(3, - "conv_state", - "Carry state from previous step with shape (batch_size, channels, kernel_size - 1). " + "past_state", + "Carry state from previous step. For ndim=1: (batch_size, channels, k_1 - 1). " "If not provided, padding is zero.", "T", OpSchema::Optional) .Output(0, "output", - "Convolution output with shape (batch_size, channels, length).", + "Convolution output with same shape as input.", "T") .Output(1, "present_state", - "Updated carry state with shape (batch_size, channels, kernel_size - 1). " - "Contains the last (kernel_size - 1) values from the virtual input.", + "Updated carry state. For ndim=1: (batch_size, channels, k_1 - 1). " + "Contains the last (k-1) values from the virtual input along the causal axis.", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, @@ -2277,19 +2286,20 @@ ONNX_MS_OPERATOR_SET_SCHEMA( propagateElemTypeFromInputToOutput(ctx, 0, 0); propagateElemTypeFromInputToOutput(ctx, 0, 1); - // Output 0: same shape as input (batch_size, channels, length) + // Output 0: same shape as input (batch_size, channels, ...) propagateShapeFromInputToOutput(ctx, 0, 0); - // Output 1: (batch_size, channels, kernel_size - 1) + // Output 1: (batch_size, channels, kernel_size - 1) for ndim=1 if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { auto& input_shape = getInputShape(ctx, 0); auto& weight_shape = getInputShape(ctx, 1); TensorShapeProto state_shape; *state_shape.add_dim() = input_shape.dim(0); // batch_size *state_shape.add_dim() = input_shape.dim(1); // channels - // kernel_size - 1 - if (weight_shape.dim(2).has_dim_value()) { - state_shape.add_dim()->set_dim_value(weight_shape.dim(2).dim_value() - 1); + // kernel_size - 1 (last kernel dimension for ndim=1) + int last_kernel_dim = weight_shape.dim_size() - 1; + if (weight_shape.dim(last_kernel_dim).has_dim_value()) { + state_shape.add_dim()->set_dim_value(weight_shape.dim(last_kernel_dim).dim_value() - 1); } else { state_shape.add_dim(); // unknown } @@ -2298,17 +2308,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA( })); constexpr const char* LinearAttention_ver1_doc = R"DOC( -Linear Attention operator. +Unified linear attention operator for autoregressive decoding (T=1) and prefill (T>1). -Processes a sequence of tokens using linear attention with a recurrent state matrix. -When sequence_length=1, this is equivalent to a single recurrent decode step. -When sequence_length>1, this efficiently processes the full sequence (e.g., for prefill). +All inputs use 3D packed format [B, T, H*D]; q_num_heads and kv_num_heads are always +required. The op internally unpacks to 4D for computation. The update_rule attribute selects the recurrence type: -- "linear": S_t = S_{t-1} + k_t ⊗ v_t; o_t = q_t^T S_t / sqrt(d_k) -- "gated": S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t; o_t = q_t^T S_t / sqrt(d_k) -- "delta": S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t); o_t = q_t^T S_t / sqrt(d_k) -- "gated_delta": S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t); o_t = q_t^T S_t / sqrt(d_k) +- "linear": S_t = S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t +- "gated": S_t = exp(g_t) * S_{t-1} + k_t ⊗ v_t; o_t = scale * q_t^T S_t +- "delta": S_t = S_{t-1} + β_t * k_t ⊗ (v_t - S_{t-1}^T k_t); o_t = scale * q_t^T S_t +- "gated_delta": S_t = exp(g_t) * S_{t-1} + β_t * k_t ⊗ (v_t - exp(g_t) * S_{t-1}^T k_t); o_t = scale * q_t^T S_t where g_t is the decay (in log-space), β_t is the update rate, and ⊗ denotes outer product. @@ -2326,49 +2335,63 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::STRING, std::string("gated_delta")) .Attr("scale", - "Output scaling factor. When 0.0 (default), uses 1/sqrt(d_k) where d_k is the key dimension.", + "Output scaling factor. When 0.0 (default), derives d_k = query.shape[-1] / q_num_heads " + "and uses 1/sqrt(d_k). Set explicitly to override.", AttributeProto::FLOAT, 0.0f) + .Attr("q_num_heads", + "Number of query heads. Always required.", + AttributeProto::INT) + .Attr("kv_num_heads", + "Number of key/value heads. Always required.", + AttributeProto::INT) + .Attr("chunk_size", + "Chunk size for the chunk-parallel WY decomposition during prefill (T>1). " + "Tuning hint; does not affect output correctness.", + AttributeProto::INT, + static_cast(64)) .Input(0, "query", - "Query vectors with shape (batch_size, num_heads, sequence_length, head_dim_k)", + "Query vectors with 3D packed shape (B, T, H_q * d_k). " + "Heads are packed into the last dimension.", "T") .Input(1, "key", - "Key vectors with shape (batch_size, num_heads, sequence_length, head_dim_k). " + "Key vectors with 3D packed shape (B, T, H_kv * d_k). " "Should be L2-normalized for delta/gated_delta modes.", "T") .Input(2, "value", - "Value vectors with shape (batch_size, num_heads, sequence_length, head_dim_v)", + "Value vectors with 3D packed shape (B, T, H_kv * d_v).", "T") .Input(3, - "initial_state", - "Initial recurrent state with shape (batch_size, num_heads, head_dim_k, head_dim_v). " - "If not provided, defaults to zeros.", + "past_state", + "Recurrent state from previous step with shape (B, H_kv, d_k, d_v). " + "Always 4D. If not provided, defaults to zeros.", "T", OpSchema::Optional) .Input(4, "decay", - "Exponential decay gate in log-space with shape broadcastable to " - "(batch_size, num_heads, sequence_length, [head_dim_k]). " + "Exponential decay gate in log-space. 3D packed shape: " + "(B, T, H_kv * d_k) for per-key-dimension decay (GLA/RWKV-6), or " + "(B, T, H_kv) for per-head scalar decay (DeltaNet/RetNet). " "Required for 'gated' and 'gated_delta' modes.", "T", OpSchema::Optional) .Input(5, "beta", - "Update rate (sigmoid output) with shape broadcastable to " - "(batch_size, num_heads, sequence_length, [1]). " + "Update rate (sigmoid output). 3D packed shape: " + "(B, T, H_kv) or (B, T, 1). " "Required for 'delta' and 'gated_delta' modes.", "T", OpSchema::Optional) .Output(0, "output", - "Attention output with shape (batch_size, num_heads, sequence_length, head_dim_v)", + "Attention output with 3D packed shape (B, T, H_q * d_v).", "T") .Output(1, - "final_state", - "Final recurrent state with shape (batch_size, num_heads, head_dim_k, head_dim_v)", + "present_state", + "Updated recurrent state with shape (B, H_kv, d_k, d_v). Always 4D.", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, @@ -2377,27 +2400,48 @@ ONNX_MS_OPERATOR_SET_SCHEMA( propagateElemTypeFromInputToOutput(ctx, 0, 0); propagateElemTypeFromInputToOutput(ctx, 0, 1); - // Output 0: same shape as query but last dim from value - if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { + // Read required attributes + auto* q_num_heads_attr = ctx.getAttribute("q_num_heads"); + auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads"); + int64_t q_num_heads = (q_num_heads_attr && q_num_heads_attr->has_i()) ? q_num_heads_attr->i() : 0; + int64_t kv_num_heads = (kv_num_heads_attr && kv_num_heads_attr->has_i()) ? kv_num_heads_attr->i() : 0; + + // Output 0: (B, T, H_q * d_v) — 3D packed + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2) && q_num_heads > 0 && kv_num_heads > 0) { auto& query_shape = getInputShape(ctx, 0); auto& value_shape = getInputShape(ctx, 2); TensorShapeProto output_shape; - *output_shape.add_dim() = query_shape.dim(0); - *output_shape.add_dim() = query_shape.dim(1); - *output_shape.add_dim() = query_shape.dim(2); - *output_shape.add_dim() = value_shape.dim(3); + *output_shape.add_dim() = query_shape.dim(0); // B + *output_shape.add_dim() = query_shape.dim(1); // T + // H_q * d_v: d_v = value.dim(2) / kv_num_heads, then H_q * d_v + if (value_shape.dim(2).has_dim_value()) { + int64_t d_v = value_shape.dim(2).dim_value() / kv_num_heads; + output_shape.add_dim()->set_dim_value(q_num_heads * d_v); + } else { + output_shape.add_dim(); // unknown + } updateOutputShape(ctx, 0, output_shape); } - // Output 1: final_state shape (B, H, dk, dv) - if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2)) { + // Output 1: present_state shape (B, H_kv, d_k, d_v) — 4D + if (hasInputShape(ctx, 0) && hasInputShape(ctx, 2) && q_num_heads > 0 && kv_num_heads > 0) { auto& query_shape = getInputShape(ctx, 0); auto& value_shape = getInputShape(ctx, 2); TensorShapeProto state_shape; - *state_shape.add_dim() = query_shape.dim(0); // batch - *state_shape.add_dim() = query_shape.dim(1); // heads - *state_shape.add_dim() = query_shape.dim(3); // dk - *state_shape.add_dim() = value_shape.dim(3); // dv + *state_shape.add_dim() = query_shape.dim(0); // B + state_shape.add_dim()->set_dim_value(kv_num_heads); // H_kv + // d_k = query.dim(2) / q_num_heads + if (query_shape.dim(2).has_dim_value()) { + state_shape.add_dim()->set_dim_value(query_shape.dim(2).dim_value() / q_num_heads); + } else { + state_shape.add_dim(); + } + // d_v = value.dim(2) / kv_num_heads + if (value_shape.dim(2).has_dim_value()) { + state_shape.add_dim()->set_dim_value(value_shape.dim(2).dim_value() / kv_num_heads); + } else { + state_shape.add_dim(); + } updateOutputShape(ctx, 1, state_shape); } else if (hasInputShape(ctx, 3)) { propagateShapeFromInputToOutput(ctx, 3, 1); diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index b2b9bf8442692..59f97c222ceb2 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -89,7 +89,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MultiHeadAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GroupQueryAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, PagedAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LinearAttention); -class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CausalConv1DWithState); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, CausalConvWithState); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MurmurHash3); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, NGramRepeatBlock); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Pad); @@ -202,7 +202,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); - fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc b/onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc index 7f66707c32453..76e878d9fff80 100644 --- a/onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc +++ b/onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc @@ -19,17 +19,17 @@ enum class TensorType { kFloat16 }; -// Reference implementation for CausalConv1DWithState +// Reference implementation for CausalConvWithState // Performs depthwise causal 1D convolution with optional state, bias, and activation. // // Input: (B, D, L) channels-first // Weight: (D, 1, K) depthwise // Bias: (D,) optional -// conv_state: (B, D, K-1) optional carry state +// past_state: (B, D, K-1) optional carry state // // Output: (B, D, L) convolution output (with optional activation) // present_state: (B, D, K-1) updated carry state -void CausalConv1DWithStateReference( +void CausalConvWithStateReference( const std::vector& input, const std::vector& weight, const std::vector* bias, @@ -92,7 +92,7 @@ void CausalConv1DWithStateReference( } // anonymous namespace -static void RunCausalConv1DWithStateTest( +static void RunCausalConvWithStateTest( const std::vector& input_data, const std::vector& weight_data, const std::vector* bias_data, @@ -126,7 +126,7 @@ static void RunCausalConv1DWithStateTest( } for (auto& ep : execution_providers) { - OpTester test("CausalConv1DWithState", 1, onnxruntime::kMSDomain); + OpTester test("CausalConvWithState", 1, onnxruntime::kMSDomain); test.AddAttribute("activation", activation); if (tensor_type == TensorType::kFloat) { @@ -140,7 +140,7 @@ static void RunCausalConv1DWithStateTest( } if (conv_state_data != nullptr) { - test.AddInput("conv_state", state_shape, *conv_state_data); + test.AddInput("past_state", state_shape, *conv_state_data); } else { test.AddOptionalInputEdge(); } @@ -158,7 +158,7 @@ static void RunCausalConv1DWithStateTest( } if (conv_state_data != nullptr) { - test.AddInput("conv_state", state_shape, ToFloat16(*conv_state_data)); + test.AddInput("past_state", state_shape, ToFloat16(*conv_state_data)); } else { test.AddOptionalInputEdge(); } @@ -176,7 +176,7 @@ static void RunCausalConv1DWithStateTest( } } -static void RunCausalConv1DWithStateTests( +static void RunCausalConvWithStateTests( const std::vector& input_data, const std::vector& weight_data, const std::vector* bias_data, @@ -189,20 +189,20 @@ static void RunCausalConv1DWithStateTests( // Compute expected output using reference implementation std::vector expected_output; std::vector expected_state; - CausalConv1DWithStateReference( + CausalConvWithStateReference( input_data, weight_data, bias_data, conv_state_data, expected_output, expected_state, batch_size, channels, input_length, kernel_size, activation); // FP32 test - RunCausalConv1DWithStateTest( + RunCausalConvWithStateTest( input_data, weight_data, bias_data, conv_state_data, expected_output, expected_state, batch_size, channels, input_length, kernel_size, activation, TensorType::kFloat); // FP16 test - RunCausalConv1DWithStateTest( + RunCausalConvWithStateTest( input_data, weight_data, bias_data, conv_state_data, expected_output, expected_state, batch_size, channels, input_length, kernel_size, activation, @@ -213,7 +213,7 @@ static void RunCausalConv1DWithStateTests( // Basic tests - simple cases // ============================================================================= -TEST(CausalConv1DWithStateTest, BasicNoStateNoBias) { +TEST(CausalConvWithStateTest, BasicNoStateNoBias) { // B=1, D=2, L=4, K=3, activation=none int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3; @@ -227,12 +227,12 @@ TEST(CausalConv1DWithStateTest, BasicNoStateNoBias) { 0.1f, 0.2f, 0.3f, // channel 0 kernel 0.4f, 0.5f, 0.6f}; // channel 1 kernel - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, nullptr, nullptr, batch_size, channels, input_length, kernel_size, "none"); } -TEST(CausalConv1DWithStateTest, BasicWithBias) { +TEST(CausalConvWithStateTest, BasicWithBias) { // B=1, D=2, L=4, K=3, activation=none int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3; @@ -244,12 +244,12 @@ TEST(CausalConv1DWithStateTest, BasicWithBias) { 0.4f, 0.5f, 0.6f}; std::vector bias_data = {0.1f, -0.2f}; - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, &bias_data, nullptr, batch_size, channels, input_length, kernel_size, "none"); } -TEST(CausalConv1DWithStateTest, BasicWithState) { +TEST(CausalConvWithStateTest, BasicWithState) { // B=1, D=2, L=3, K=3, activation=none int batch_size = 1, channels = 2, input_length = 3, kernel_size = 3; @@ -264,12 +264,12 @@ TEST(CausalConv1DWithStateTest, BasicWithState) { -1.0f, 0.5f, // channel 0 state 0.3f, -0.7f}; // channel 1 state - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, nullptr, &conv_state_data, batch_size, channels, input_length, kernel_size, "none"); } -TEST(CausalConv1DWithStateTest, WithStateAndBias) { +TEST(CausalConvWithStateTest, WithStateAndBias) { // B=1, D=2, L=3, K=3, activation=none int batch_size = 1, channels = 2, input_length = 3, kernel_size = 3; @@ -284,7 +284,7 @@ TEST(CausalConv1DWithStateTest, WithStateAndBias) { -1.0f, 0.5f, 0.3f, -0.7f}; - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, &bias_data, &conv_state_data, batch_size, channels, input_length, kernel_size, "none"); } @@ -293,7 +293,7 @@ TEST(CausalConv1DWithStateTest, WithStateAndBias) { // SiLU activation tests // ============================================================================= -TEST(CausalConv1DWithStateTest, SiluActivationNoState) { +TEST(CausalConvWithStateTest, SiluActivationNoState) { int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3; std::vector input_data = { @@ -303,12 +303,12 @@ TEST(CausalConv1DWithStateTest, SiluActivationNoState) { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f}; - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, nullptr, nullptr, batch_size, channels, input_length, kernel_size, "silu"); } -TEST(CausalConv1DWithStateTest, SiluActivationWithState) { +TEST(CausalConvWithStateTest, SiluActivationWithState) { int batch_size = 1, channels = 2, input_length = 3, kernel_size = 3; std::vector input_data = { @@ -321,12 +321,12 @@ TEST(CausalConv1DWithStateTest, SiluActivationWithState) { -1.0f, 0.5f, 0.3f, -0.7f}; - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, nullptr, &conv_state_data, batch_size, channels, input_length, kernel_size, "silu"); } -TEST(CausalConv1DWithStateTest, SiluActivationWithBiasAndState) { +TEST(CausalConvWithStateTest, SiluActivationWithBiasAndState) { int batch_size = 1, channels = 2, input_length = 4, kernel_size = 3; std::vector input_data = { @@ -340,7 +340,7 @@ TEST(CausalConv1DWithStateTest, SiluActivationWithBiasAndState) { -1.0f, 0.5f, 0.3f, -0.7f}; - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, &bias_data, &conv_state_data, batch_size, channels, input_length, kernel_size, "silu"); } @@ -349,7 +349,7 @@ TEST(CausalConv1DWithStateTest, SiluActivationWithBiasAndState) { // Kernel size variations // ============================================================================= -TEST(CausalConv1DWithStateTest, KernelSize2) { +TEST(CausalConvWithStateTest, KernelSize2) { int batch_size = 1, channels = 2, input_length = 4, kernel_size = 2; std::vector input_data = { @@ -361,12 +361,12 @@ TEST(CausalConv1DWithStateTest, KernelSize2) { // State: (1, 2, 1) - kernel_size - 1 = 1 std::vector conv_state_data = {0.5f, -0.3f}; - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, nullptr, &conv_state_data, batch_size, channels, input_length, kernel_size, "silu"); } -TEST(CausalConv1DWithStateTest, KernelSize4) { +TEST(CausalConvWithStateTest, KernelSize4) { int batch_size = 1, channels = 1, input_length = 5, kernel_size = 4; std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; @@ -374,7 +374,7 @@ TEST(CausalConv1DWithStateTest, KernelSize4) { // State: (1, 1, 3) std::vector conv_state_data = {-1.0f, 0.0f, 0.5f}; - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, nullptr, &conv_state_data, batch_size, channels, input_length, kernel_size, "none"); } @@ -383,7 +383,7 @@ TEST(CausalConv1DWithStateTest, KernelSize4) { // Batch size > 1 // ============================================================================= -TEST(CausalConv1DWithStateTest, MultiBatch) { +TEST(CausalConvWithStateTest, MultiBatch) { int batch_size = 2, channels = 2, input_length = 3, kernel_size = 3; // Input: (2, 2, 3) @@ -410,7 +410,7 @@ TEST(CausalConv1DWithStateTest, MultiBatch) { 0.1f, -0.1f, // ch 0 0.7f, 0.8f}; // ch 1 - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, &bias_data, &conv_state_data, batch_size, channels, input_length, kernel_size, "silu"); } @@ -419,7 +419,7 @@ TEST(CausalConv1DWithStateTest, MultiBatch) { // Single token decode (L=1) - the primary use case for incremental decoding // ============================================================================= -TEST(CausalConv1DWithStateTest, SingleTokenDecode) { +TEST(CausalConvWithStateTest, SingleTokenDecode) { int batch_size = 1, channels = 4, input_length = 1, kernel_size = 4; // Input: (1, 4, 1) @@ -441,12 +441,12 @@ TEST(CausalConv1DWithStateTest, SingleTokenDecode) { 0.5f, 0.5f, 0.5f, // ch 2 -0.2f, 0.4f, -0.6f}; // ch 3 - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, &bias_data, &conv_state_data, batch_size, channels, input_length, kernel_size, "silu"); } -TEST(CausalConv1DWithStateTest, SingleTokenDecodeMultiBatch) { +TEST(CausalConvWithStateTest, SingleTokenDecodeMultiBatch) { int batch_size = 2, channels = 2, input_length = 1, kernel_size = 3; // Input: (2, 2, 1) @@ -467,7 +467,7 @@ TEST(CausalConv1DWithStateTest, SingleTokenDecodeMultiBatch) { 0.5f, 0.5f, // B1, ch 0 -0.2f, 0.4f}; // B1, ch 1 - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, nullptr, &conv_state_data, batch_size, channels, input_length, kernel_size, "silu"); } @@ -477,7 +477,7 @@ TEST(CausalConv1DWithStateTest, SingleTokenDecodeMultiBatch) { // as conv_state for the next call (simulating autoregressive decode) // ============================================================================= -TEST(CausalConv1DWithStateTest, StateContinuity) { +TEST(CausalConvWithStateTest, StateContinuity) { // Process a sequence of single tokens and verify state propagation int batch_size = 1, channels = 1, kernel_size = 3; int input_length = 1; @@ -492,11 +492,11 @@ TEST(CausalConv1DWithStateTest, StateContinuity) { std::vector input1 = {1.0f}; std::vector expected_output1; std::vector expected_state1; - CausalConv1DWithStateReference(input1, weight_data, &bias_data, &conv_state, + CausalConvWithStateReference(input1, weight_data, &bias_data, &conv_state, expected_output1, expected_state1, batch_size, channels, input_length, kernel_size, "none"); - RunCausalConv1DWithStateTest(input1, weight_data, &bias_data, &conv_state, + RunCausalConvWithStateTest(input1, weight_data, &bias_data, &conv_state, expected_output1, expected_state1, batch_size, channels, input_length, kernel_size, "none", TensorType::kFloat); @@ -505,11 +505,11 @@ TEST(CausalConv1DWithStateTest, StateContinuity) { std::vector input2 = {2.0f}; std::vector expected_output2; std::vector expected_state2; - CausalConv1DWithStateReference(input2, weight_data, &bias_data, &expected_state1, + CausalConvWithStateReference(input2, weight_data, &bias_data, &expected_state1, expected_output2, expected_state2, batch_size, channels, input_length, kernel_size, "none"); - RunCausalConv1DWithStateTest(input2, weight_data, &bias_data, &expected_state1, + RunCausalConvWithStateTest(input2, weight_data, &bias_data, &expected_state1, expected_output2, expected_state2, batch_size, channels, input_length, kernel_size, "none", TensorType::kFloat); @@ -518,11 +518,11 @@ TEST(CausalConv1DWithStateTest, StateContinuity) { std::vector input3 = {3.0f}; std::vector expected_output3; std::vector expected_state3; - CausalConv1DWithStateReference(input3, weight_data, &bias_data, &expected_state2, + CausalConvWithStateReference(input3, weight_data, &bias_data, &expected_state2, expected_output3, expected_state3, batch_size, channels, input_length, kernel_size, "none"); - RunCausalConv1DWithStateTest(input3, weight_data, &bias_data, &expected_state2, + RunCausalConvWithStateTest(input3, weight_data, &bias_data, &expected_state2, expected_output3, expected_state3, batch_size, channels, input_length, kernel_size, "none", TensorType::kFloat); @@ -536,7 +536,7 @@ TEST(CausalConv1DWithStateTest, StateContinuity) { // Equivalence test: sequence processing should match token-by-token with state // ============================================================================= -TEST(CausalConv1DWithStateTest, SequenceVsTokenByToken) { +TEST(CausalConvWithStateTest, SequenceVsTokenByToken) { int batch_size = 1, channels = 2, kernel_size = 3; std::vector weight_data = { @@ -555,7 +555,7 @@ TEST(CausalConv1DWithStateTest, SequenceVsTokenByToken) { // Process full sequence at once std::vector full_output; std::vector full_final_state; - CausalConv1DWithStateReference(full_input, weight_data, &bias_data, &conv_state, + CausalConvWithStateReference(full_input, weight_data, &bias_data, &conv_state, full_output, full_final_state, batch_size, channels, 4, kernel_size, "none"); @@ -571,7 +571,7 @@ TEST(CausalConv1DWithStateTest, SequenceVsTokenByToken) { std::vector token_output; std::vector next_state; - CausalConv1DWithStateReference(token_input, weight_data, &bias_data, ¤t_state, + CausalConvWithStateReference(token_input, weight_data, &bias_data, ¤t_state, token_output, next_state, batch_size, channels, 1, kernel_size, "none"); @@ -607,7 +607,7 @@ TEST(CausalConv1DWithStateTest, SequenceVsTokenByToken) { // Larger dimension test with realistic sizes // ============================================================================= -TEST(CausalConv1DWithStateTest, LargerDimensions) { +TEST(CausalConvWithStateTest, LargerDimensions) { int batch_size = 2, channels = 8, input_length = 16, kernel_size = 4; // Generate test data with a simple pattern @@ -632,7 +632,7 @@ TEST(CausalConv1DWithStateTest, LargerDimensions) { conv_state_data[i] = std::sin(static_cast(i) * 0.3f) * 0.5f; } - RunCausalConv1DWithStateTests( + RunCausalConvWithStateTests( input_data, weight_data, &bias_data, &conv_state_data, batch_size, channels, input_length, kernel_size, "silu"); } diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc index 824922a0dc2a8..24bb77d4aa007 100644 --- a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc @@ -10,237 +10,6 @@ using namespace onnxruntime::test; -namespace linear_attention { - -enum class UpdateRule { - kLinear, // S = S + k^T v - kGated, // S = exp(g) * S + k^T v - kDelta, // S = S + k^T ((v - S^T k) * beta) - kGatedDelta, // S = exp(g) * S + k^T ((v - exp(g) * S^T k) * beta) -}; - -/// 4D tensor accessor with (d0, d1, d2, d3) layout (row-major). -template -struct Tensor4D { - T* data; - int d0, d1, d2, d3; - - Tensor4D(T* data, int d0, int d1, int d2, int d3) - : data(data), d0(d0), d1(d1), d2(d2), d3(d3) {} - - T& operator()(int i0, int i1, int i2, int i3) { - return data[((i0 * d1 + i1) * d2 + i2) * d3 + i3]; - } - const T& operator()(int i0, int i1, int i2, int i3) const { - return data[((i0 * d1 + i1) * d2 + i2) * d3 + i3]; - } - - int size() const { return d0 * d1 * d2 * d3; } -}; - -/// 3D tensor accessor with (d0, d1, d2) layout (row-major). -template -struct Tensor3D { - T* data; - int d0, d1, d2; - - Tensor3D(T* data, int d0, int d1, int d2) - : data(data), d0(d0), d1(d1), d2(d2) {} - - T& operator()(int i0, int i1, int i2) { - return data[(i0 * d1 + i1) * d2 + i2]; - } - const T& operator()(int i0, int i1, int i2) const { - return data[(i0 * d1 + i1) * d2 + i2]; - } - - int size() const { return d0 * d1 * d2; } -}; - -/// Compute one recurrence step for a single (batch, head) pair. -/// -/// state: (d_k, d_v) — carried across time steps (modified in place) -/// q_t: (d_k,) — query for this step -/// k_t: (d_k,) — key for this step -/// v_t: (d_v,) — value for this step -/// decay_t: scalar — log-space decay gate -/// beta_t: scalar — update rate -/// output_t: (d_v,) — output for this step (written) -template -inline void recurrence_step( - T* state, // (d_k, d_v), modified in place - const T* q_t, // (d_k,) - const T* k_t, // (d_k,) - const T* v_t, // (d_v,) - T decay_t, // scalar (log-space) - T beta_t, // scalar - T* output_t, // (d_v,), written - int d_k, - int d_v, - UpdateRule rule) { - const bool uses_decay = (rule == UpdateRule::kGated || rule == UpdateRule::kGatedDelta); - const bool uses_beta = (rule == UpdateRule::kDelta || rule == UpdateRule::kGatedDelta); - - // 1. State decay: state *= exp(decay) - T g_exp = T(1); - if (uses_decay) { - g_exp = std::exp(decay_t); - for (int i = 0; i < d_k * d_v; ++i) { - state[i] *= g_exp; - } - } - - // 2. Retrieval: retrieval[j] = sum_i k[i] * state[i, j] (k @ state) - // This computes k_row (1, d_k) @ state (d_k, d_v) -> (1, d_v) - std::vector retrieval(d_v, T(0)); - for (int i = 0; i < d_k; ++i) { - for (int j = 0; j < d_v; ++j) { - retrieval[j] += k_t[i] * state[i * d_v + j]; - } - } - - // 3. Compute delta - std::vector delta(d_v); - if (uses_beta) { - // delta = (v - retrieval) * beta - for (int j = 0; j < d_v; ++j) { - delta[j] = (v_t[j] - retrieval[j]) * beta_t; - } - } else { - // delta = v - std::copy(v_t, v_t + d_v, delta.data()); - } - - // 4. State update: state += k^T @ delta (outer product) - // state[i, j] += k[i] * delta[j] - for (int i = 0; i < d_k; ++i) { - for (int j = 0; j < d_v; ++j) { - state[i * d_v + j] += k_t[i] * delta[j]; - } - } - - // 5. Output: output[j] = sum_i q[i] * new_state[i, j] (q @ new_state) - std::fill(output_t, output_t + d_v, T(0)); - for (int i = 0; i < d_k; ++i) { - for (int j = 0; j < d_v; ++j) { - output_t[j] += q_t[i] * state[i * d_v + j]; - } - } -} - -/// Expand Q/K heads for GQA: repeat each head `ratio` times. -/// -/// src: (B, H_kv, T, d) -/// dst: (B, H, T, d) where H = H_kv * ratio -template -inline void expand_kv_heads( - const T* src, T* dst, - int B, int H_kv, int T_len, int d, int ratio) { - if (ratio == 1) { - std::memcpy(dst, src, B * H_kv * T_len * d * sizeof(T)); - return; - } - for (int b = 0; b < B; ++b) { - for (int h_kv = 0; h_kv < H_kv; ++h_kv) { - const T* src_head = src + ((b * H_kv + h_kv) * T_len) * d; - for (int r = 0; r < ratio; ++r) { - int h = h_kv * ratio + r; - T* dst_head = dst + ((b * (H_kv * ratio) + h) * T_len) * d; - std::memcpy(dst_head, src_head, T_len * d * sizeof(T)); - } - } - } -} - -/// Run the full LinearAttention operator. -/// -/// query: (B, H_kv, T, d_k) — pre-scaled by 1/sqrt(d_k) -/// key: (B, H_kv, T, d_k) — L2-normalized -/// value: (B, H, T, d_v) -/// past_state: (B, H, d_k, d_v) — recurrent state from previous chunk -/// decay: (B, H, T) — log-space decay gate -/// beta: (B, H, T) — sigmoid update rate -/// output: (B, H, T, d_v) — attention output [written] -/// present_state: (B, H, d_k, d_v) — updated state [written] -/// -/// H must be divisible by H_kv (GQA ratio = H / H_kv). -template -void linear_attention_forward( - const T* query, // (B, H_kv, T, d_k) - const T* key, // (B, H_kv, T, d_k) - const T* value, // (B, H, T, d_v) - const T* past_state, // (B, H, d_k, d_v) - const T* decay, // (B, H, T) - const T* beta, // (B, H, T) - T* output, // (B, H, T, d_v) - T* present_state, // (B, H, d_k, d_v) - int B, - int H_kv, - int H, - int T_len, - int d_k, - int d_v, - float scale, - UpdateRule rule = UpdateRule::kGatedDelta) { - assert(H % H_kv == 0 && "H must be divisible by H_kv for GQA"); - const int ratio = H / H_kv; - - // --- GQA: expand Q/K heads to match V head count --- - std::vector q_expanded(B * H * T_len * d_k); - std::vector k_expanded(B * H * T_len * d_k); - expand_kv_heads(query, q_expanded.data(), B, H_kv, T_len, d_k, ratio); - expand_kv_heads(key, k_expanded.data(), B, H_kv, T_len, d_k, ratio); - - // Accessors for expanded Q/K - Tensor4D Q(q_expanded.data(), B, H, T_len, d_k); - Tensor4D K(k_expanded.data(), B, H, T_len, d_k); - Tensor4D V(value, B, H, T_len, d_v); - Tensor3D D(decay, B, H, T_len); - Tensor4D O(output, B, H, T_len, d_v); - - // Copy past_state into present_state (we'll modify it in place) - const int state_size = B * H * d_k * d_v; - std::memcpy(present_state, past_state, state_size * sizeof(T)); - Tensor4D S(present_state, B, H, d_k, d_v); - - // --- Sequential recurrence over time --- - for (int b = 0; b < B; ++b) { - for (int h = 0; h < H; ++h) { - T* state_bh = &S(b, h, 0, 0); // (d_k, d_v) slice - - for (int t = 0; t < T_len; ++t) { - recurrence_step( - state_bh, - &Q(b, h, t, 0), - &K(b, h, t, 0), - &V(b, h, t, 0), - D(b, h, t), - beta ? beta[((b * H + h) * T_len) + t] : T(0), - &O(b, h, t, 0), - d_k, d_v, - rule); - } - } - } - // linear_attention_forward doesn't apply scale; apply it here. - int output_size = B * H * T_len * d_v; - for (int i = 0; i < output_size; ++i) { - output[i] *= scale; - } - -} - -/// Parse a string into an UpdateRule enum. -inline UpdateRule parse_update_rule(const std::string& s) { - if (s == "linear") return UpdateRule::kLinear; - if (s == "gated") return UpdateRule::kGated; - if (s == "delta") return UpdateRule::kDelta; - if (s == "gated_delta") return UpdateRule::kGatedDelta; - assert(false && "Unknown update_rule"); - return UpdateRule::kGatedDelta; -} - -} // namespace linear_attention namespace onnxruntime { namespace test { @@ -265,32 +34,6 @@ void LinearAttentionReference( int bht = batch_size * num_heads * seq_length; bool decay_broadcast_dk = (decay != nullptr && static_cast(decay->size()) == bht); - if (false && decay_broadcast_dk) { - output.resize(batch_size * num_heads * seq_length * head_dim_v, 0.0f); - final_state.resize(batch_size * num_heads * head_dim_k * head_dim_v, 0.0f); - if (initial_state != nullptr) { - final_state = *initial_state; - } - linear_attention::linear_attention_forward( - query.data(), // (B, H_kv, T, d_k) - key.data(), // (B, H_kv, T, d_k) - value.data(), // (B, H, T, d_v) - final_state.data(), // (B, H, d_k, d_v) - decay->data(), // (B, H, T) - beta ? beta->data() : nullptr, // (B, H, T) - output.data(), // (B, H, T, d_v) - final_state.data(), // (B, H, d_k, d_v) - batch_size, - num_heads, - num_heads, - seq_length, - head_dim_k, - head_dim_v, - scale, - linear_attention::parse_update_rule(update_rule)); - return; - } - // State: (B, H, dk, dv) final_state.resize(batch_size * num_heads * head_dim_k * head_dim_v, 0.0f); output.resize(batch_size * num_heads * seq_length * head_dim_v, 0.0f); @@ -386,6 +129,40 @@ void LinearAttentionReference( } } +// Convert data from 4D (B,H,T,D) layout to 3D packed (B,T,H*D) layout +std::vector PackBHTD_to_BTHD(const std::vector& data_4d, + int B, int H, int T, int D) { + std::vector packed(B * T * H * D); + for (int b = 0; b < B; b++) { + for (int h = 0; h < H; h++) { + for (int t = 0; t < T; t++) { + for (int d = 0; d < D; d++) { + int src_idx = ((b * H + h) * T + t) * D + d; + int dst_idx = (b * T + t) * (H * D) + h * D + d; + packed[dst_idx] = data_4d[src_idx]; + } + } + } + } + return packed; +} + +// Convert decay/beta from (B,H,T) layout to (B,T,H) layout +std::vector TransposeBHT_to_BTH(const std::vector& data, + int B, int H, int T) { + std::vector transposed(B * T * H); + for (int b = 0; b < B; b++) { + for (int h = 0; h < H; h++) { + for (int t = 0; t < T; t++) { + int src_idx = (b * H + h) * T + t; + int dst_idx = (b * T + t) * H + h; + transposed[dst_idx] = data[src_idx]; + } + } + } + return transposed; +} + void RunLinearAttentionTest( const std::string& update_rule, int batch_size, int num_heads, int seq_length, int head_dim_k, int head_dim_v, @@ -396,12 +173,12 @@ void RunLinearAttentionTest( const std::vector* initial_state, const std::vector* decay, const std::vector* beta_data) { - // Compute reference output - std::vector expected_output, expected_state; + // Compute reference output (reference works in 4D layout) + std::vector expected_output_4d, expected_state; LinearAttentionReference(update_rule, batch_size, num_heads, seq_length, head_dim_k, head_dim_v, scale, query, key, value, initial_state, decay, beta_data, - expected_output, expected_state); + expected_output_4d, expected_state); bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()); if (!enable_webgpu) { @@ -410,51 +187,65 @@ void RunLinearAttentionTest( int bht = batch_size * num_heads * seq_length; bool decay_broadcast_dk = (decay != nullptr && static_cast(decay->size()) == bht); + + // Convert from 4D (B,H,T,D) to 3D packed (B,T,H*D) for OpTester + auto query_3d = PackBHTD_to_BTHD(query, batch_size, num_heads, seq_length, head_dim_k); + auto key_3d = PackBHTD_to_BTHD(key, batch_size, num_heads, seq_length, head_dim_k); + auto value_3d = PackBHTD_to_BTHD(value, batch_size, num_heads, seq_length, head_dim_v); + auto output_3d = PackBHTD_to_BTHD(expected_output_4d, batch_size, num_heads, seq_length, head_dim_v); + OpTester tester("LinearAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("update_rule", update_rule); tester.AddAttribute("scale", scale); + tester.AddAttribute("q_num_heads", static_cast(num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(num_heads)); - // Add required inputs - std::vector qk_dims = {batch_size, num_heads, seq_length, head_dim_k}; - std::vector v_dims = {batch_size, num_heads, seq_length, head_dim_v}; - tester.AddInput("query", qk_dims, query); - tester.AddInput("key", qk_dims, key); - tester.AddInput("value", v_dims, value); + // Add required inputs — 3D packed (B, T, H*D) + std::vector qk_dims = {batch_size, seq_length, num_heads * head_dim_k}; + std::vector v_dims = {batch_size, seq_length, num_heads * head_dim_v}; + tester.AddInput("query", qk_dims, query_3d); + tester.AddInput("key", qk_dims, key_3d); + tester.AddInput("value", v_dims, value_3d); - // Optional: initial_state + // Optional: past_state (4D, same format as before) if (initial_state != nullptr) { std::vector state_dims = {batch_size, num_heads, head_dim_k, head_dim_v}; - tester.AddInput("initial_state", state_dims, *initial_state); + tester.AddInput("past_state", state_dims, *initial_state); } else { tester.AddOptionalInputEdge(); } - // Optional: decay + // Optional: decay — convert from (B,H,T[,dk]) to (B,T,H[*dk]) if (decay != nullptr) { if (decay_broadcast_dk) { - std::vector decay_dims = {batch_size, num_heads, seq_length}; - tester.AddInput("decay", decay_dims, *decay); + // (B,H,T) → (B,T,H) + auto decay_3d = TransposeBHT_to_BTH(*decay, batch_size, num_heads, seq_length); + std::vector decay_dims = {batch_size, seq_length, num_heads}; + tester.AddInput("decay", decay_dims, decay_3d); } else { - std::vector decay_dims = {batch_size, num_heads, seq_length, head_dim_k}; - tester.AddInput("decay", decay_dims, *decay); + // (B,H,T,dk) → (B,T,H*dk) + auto decay_3d = PackBHTD_to_BTHD(*decay, batch_size, num_heads, seq_length, head_dim_k); + std::vector decay_dims = {batch_size, seq_length, num_heads * head_dim_k}; + tester.AddInput("decay", decay_dims, decay_3d); } } else { tester.AddOptionalInputEdge(); } - // Optional: beta + // Optional: beta — convert from (B*H*T) flat to (B,T,H) if (beta_data != nullptr) { - std::vector beta_dims = {batch_size, num_heads, seq_length, 1}; - tester.AddInput("beta", beta_dims, *beta_data); + auto beta_3d = TransposeBHT_to_BTH(*beta_data, batch_size, num_heads, seq_length); + std::vector beta_dims = {batch_size, seq_length, num_heads}; + tester.AddInput("beta", beta_dims, beta_3d); } else { tester.AddOptionalInputEdge(); } - // Add outputs - std::vector out_dims = {batch_size, num_heads, seq_length, head_dim_v}; + // Add outputs — output is 3D packed, state is 4D + std::vector out_dims = {batch_size, seq_length, num_heads * head_dim_v}; std::vector state_dims = {batch_size, num_heads, head_dim_k, head_dim_v}; - tester.AddOutput("output", out_dims, expected_output, false, 0.005f, 0.005f); - tester.AddOutput("final_state", state_dims, expected_state, false, 0.005f, 0.005f); + tester.AddOutput("output", out_dims, output_3d, false, 0.005f, 0.005f); + tester.AddOutput("present_state", state_dims, expected_state, false, 0.005f, 0.005f); std::vector> execution_providers; execution_providers.push_back(DefaultWebGpuExecutionProvider()); @@ -799,21 +590,24 @@ TEST(ContribOpLinearAttentionTest, LinearRule_DefaultScale) { OpTester tester("LinearAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("update_rule", std::string("linear")); + tester.AddAttribute("q_num_heads", static_cast(H)); + tester.AddAttribute("kv_num_heads", static_cast(H)); // Don't set scale — use default (0.0 triggers 1/sqrt(dk)) - std::vector qk_dims = {B, H, T, dk}; - std::vector v_dims = {B, H, T, dv}; + // Convert to 3D packed for B=1, H=1 (flat data is identical) + std::vector qk_dims = {B, T, H * dk}; + std::vector v_dims = {B, T, H * dv}; tester.AddInput("query", qk_dims, query); tester.AddInput("key", qk_dims, key); tester.AddInput("value", v_dims, value); - tester.AddOptionalInputEdge(); // initial_state + tester.AddOptionalInputEdge(); // past_state tester.AddOptionalInputEdge(); // decay tester.AddOptionalInputEdge(); // beta - std::vector out_dims = {B, H, T, dv}; + std::vector out_dims = {B, T, H * dv}; std::vector state_dims = {B, H, dk, dv}; tester.AddOutput("output", out_dims, expected_output, false, 0.005f, 0.005f); - tester.AddOutput("final_state", state_dims, expected_state, false, 0.005f, 0.005f); + tester.AddOutput("present_state", state_dims, expected_state, false, 0.005f, 0.005f); std::vector> execution_providers; execution_providers.push_back(DefaultWebGpuExecutionProvider()); From f7711c4255a963e9b388e4077d95d0c9eb83160f Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Sat, 28 Mar 2026 17:21:11 -0700 Subject: [PATCH 18/27] opt: make use of vec4 --- ...ith_state.cc => causal_conv_with_state.cc} | 14 +- ..._with_state.h => causal_conv_with_state.h} | 10 +- .../webgpu/bert/linear_attention.cc | 304 +++++++++++------- .../webgpu/bert/linear_attention.h | 6 +- .../webgpu/webgpu_contrib_kernels.cc | 2 +- .../core/providers/webgpu/webgpu_kernel.h | 22 -- 6 files changed, 209 insertions(+), 149 deletions(-) rename onnxruntime/contrib_ops/webgpu/bert/{causal_conv1d_with_state.cc => causal_conv_with_state.cc} (96%) rename onnxruntime/contrib_ops/webgpu/bert/{causal_conv1d_with_state.h => causal_conv_with_state.h} (84%) diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc similarity index 96% rename from onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc rename to onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc index a865d40b1c1ef..f7a356618b8d6 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "contrib_ops/webgpu/bert/causal_conv1d_with_state.h" +#include "contrib_ops/webgpu/bert/causal_conv_with_state.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -13,11 +13,11 @@ namespace onnxruntime { namespace contrib { namespace webgpu { -CausalConv1DActivation ParseCausalConv1DActivation(const std::string& activation_str) { +CausalConvActivation ParseCausalConvActivation(const std::string& activation_str) { if (activation_str == "silu" || activation_str == "swish") { - return CausalConv1DActivation::Silu; + return CausalConvActivation::Silu; } else if (activation_str == "none" || activation_str.empty()) { - return CausalConv1DActivation::None; + return CausalConvActivation::None; } ORT_THROW("Unknown activation for CausalConvWithState: ", activation_str); } @@ -38,7 +38,7 @@ ONNX_OPERATOR_KERNEL_EX( CausalConvWithState::CausalConvWithState(const OpKernelInfo& info) : WebGpuKernel(info) { std::string activation_str = info.GetAttrOrDefault("activation", "none"); - activation_ = ParseCausalConv1DActivation(activation_str); + activation_ = ParseCausalConvActivation(activation_str); } Status CausalConvWithStateProgram::GenerateShaderCode(ShaderHelper& shader) const { @@ -61,7 +61,7 @@ Status CausalConvWithStateProgram::GenerateShaderCode(ShaderHelper& shader) cons const auto& present_state = shader.AddOutput("present_state", ShaderUsage::UseUniform); // Activation function implementation - if (activation_ == CausalConv1DActivation::Silu) { + if (activation_ == CausalConvActivation::Silu) { shader.AdditionalImplementation() << R"SHADER( fn silu(x: input_element_t) -> input_element_t { return x / (1.0 + exp(-x)); @@ -145,7 +145,7 @@ fn silu(x: input_element_t) -> input_element_t { } // Apply activation - if (activation_ == CausalConv1DActivation::Silu) { + if (activation_ == CausalConvActivation::Silu) { shader.MainFunctionBody() << " acc = silu(acc);\n"; } diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.h b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.h similarity index 84% rename from onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.h rename to onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.h index 820315f838e3d..b23e6249efb42 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/causal_conv1d_with_state.h +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.h @@ -16,17 +16,17 @@ using namespace onnxruntime::webgpu; using onnxruntime::webgpu::ComputeContext; // Activation mode for CausalConvWithState -enum class CausalConv1DActivation { +enum class CausalConvActivation { None, Silu, }; -CausalConv1DActivation ParseCausalConv1DActivation(const std::string& activation_str); +CausalConvActivation ParseCausalConvActivation(const std::string& activation_str); // Program for CausalConvWithState class CausalConvWithStateProgram final : public Program { public: - CausalConvWithStateProgram(CausalConv1DActivation activation, bool has_bias, bool has_conv_state, + CausalConvWithStateProgram(CausalConvActivation activation, bool has_bias, bool has_conv_state, int kernel_size) : Program{"CausalConvWithState"}, activation_(activation), @@ -45,7 +45,7 @@ class CausalConvWithStateProgram final : public Program reduction_buf: array;\n" - << "var broadcast_buf: array;\n"; + // and for broadcasting delta values. + // When use_vec4, each reduction_buf entry is a vec4 (4 dv values packed), + // eliminating the inner TILE_V loop and enabling native SIMD operations. + if (use_vec4) { + shader.AdditionalImplementation() + << "var reduction_buf: array, workgroup_size_x>;\n" + << "var broadcast_val: vec4;\n"; + } else { + // TILE_V is emitted as a compile-time constant (not overridable) because + // private address space arrays require fixed sizes in WGSL. + shader.AdditionalImplementation() + << "const TILE_V: u32 = " << tile_v_ << "u;\n" + << "var reduction_buf: array;\n" + << "var broadcast_buf: array;\n"; + } + // Identify which (batch, head, dv_tile) this workgroup handles shader.MainFunctionBody() - // Identify which (batch, head, dv_tile) this workgroup handles - // workgroup_idx is already defined by the framework << "let bh = workgroup_idx / uniforms.num_dv_tiles;\n" << "let dv_tile_idx = workgroup_idx % uniforms.num_dv_tiles;\n" << "let batch_idx = bh / uniforms.num_heads;\n" << "let head_idx = bh % uniforms.num_heads;\n" - << "let dk_idx = local_idx; // thread index = row in state matrix\n" - << "let dv_start = dv_tile_idx * TILE_V;\n" + << "let dk_idx = local_idx; // thread index = row in state matrix\n"; + if (!use_vec4) { + shader.MainFunctionBody() << "let dv_start = dv_tile_idx * TILE_V;\n"; + } + // Precompute packed strides for 3D packed inputs (B, T, H*D) + // When use_vec4, head_dim_v is already divided by 4 (vectorized), so + // packed_dv = num_heads * (head_dim_v/4) and dv_tile_idx indexes vec4 elements. + shader.MainFunctionBody() << "\n" - // Precompute packed strides for 3D packed inputs (B, T, H*D) << "let packed_dk = uniforms.num_heads * uniforms.head_dim_k;\n" << "let packed_dv = uniforms.num_heads * uniforms.head_dim_v;\n" - << "\n" - // Initialize state tile in private memory - << "var state: array;\n" - << "for (var j = 0u; j < TILE_V; j++) {\n" - << " state[j] = 0.0;\n" - << "}\n"; + << "\n"; - // Load initial state if provided - if (has_initial_state_) { + // Initialize state tile in private memory + if (use_vec4) { + shader.MainFunctionBody() << "var state = vec4(0.0);\n"; + } else { shader.MainFunctionBody() - << "// Load initial state: initial_state[batch, head, dk_idx, dv_start..dv_start+TILE_V]\n" - << "if (dk_idx < uniforms.head_dim_k) {\n" - << " let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " state[j] = f32(initial_state[state_base + j]);\n" - << " }\n" - << " }\n" + << "var state: array;\n" + << "for (var j = 0u; j < TILE_V; j++) {\n" + << " state[j] = 0.0;\n" << "}\n"; } + // Load initial state if provided + if (has_initial_state_) { + shader.MainFunctionBody() << "if (dk_idx < uniforms.head_dim_k) {\n"; + if (use_vec4) { + shader.MainFunctionBody() + << " let state_offset = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_tile_idx;\n" + << " state = vec4(initial_state[state_offset]);\n"; + } else { + shader.MainFunctionBody() + << " let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " if (dv_start + j < uniforms.head_dim_v) {\n" + << " state[j] = f32(initial_state[state_base + j]);\n" + << " }\n" + << " }\n"; + } + shader.MainFunctionBody() << "}\n"; + } + // Main token processing loop shader.MainFunctionBody() << "\n// Process each token sequentially\n" << "for (var t = 0u; t < uniforms.seq_length; t++) {\n" - // 3D packed indexing: (B, T, H*D) — bt_offset indexes (batch, token) pair << " let bt_offset = batch_idx * uniforms.seq_length + t;\n" << " var k_val: f32 = 0.0;\n" << " var q_val: f32 = 0.0;\n" @@ -119,33 +143,129 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.MainFunctionBody() << "\n // Apply exponential decay: S *= exp(decay)\n"; if (decay_broadcast_dk_) { - // Decay shape is (B, T, H_kv) — same decay for all dk rows shader.MainFunctionBody() << " let exp_g = exp(f32(decay[bt_offset * uniforms.num_heads + head_idx]));\n"; } else { - // Decay shape is (B, T, H_kv * dk) — per-dk decay - // For padding threads (dk_idx >= head_dim_k), use 0.0 (exp(0)=1, no decay) shader.MainFunctionBody() << " var exp_g: f32 = 1.0;\n" << " if (dk_idx < uniforms.head_dim_k) {\n" << " exp_g = exp(f32(decay[bt_offset * packed_dk + head_idx * uniforms.head_dim_k + dk_idx]));\n" << " }\n"; } - shader.MainFunctionBody() - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " state[j] *= exp_g;\n" - << " }\n"; + if (use_vec4) { + shader.MainFunctionBody() << " state *= exp_g;\n"; + } else { + shader.MainFunctionBody() + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " state[j] *= exp_g;\n" + << " }\n"; + } } // Step 2: For delta/gated_delta rules, compute retrieved = S^T @ k (reduction across dk) if (update_rule_ == LinearAttentionUpdateRule::Delta || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { + if (use_vec4) { + shader.MainFunctionBody() + << "\n // Compute retrieved = S^T @ k (parallel reduction over dk)\n" + << " reduction_buf[dk_idx] = state * k_val;\n" + << " workgroupBarrier();\n" + << " for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) {\n" + << " if (dk_idx < stride) {\n" + << " reduction_buf[dk_idx] += reduction_buf[dk_idx + stride];\n" + << " }\n" + << " workgroupBarrier();\n" + << " }\n" + << " // Compute delta = beta * (v - retrieved) and broadcast\n" + << " let v_idx = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_tile_idx;\n" + << " let beta_base = bt_offset * uniforms.num_heads + head_idx;\n" + << " if (dk_idx == 0u) {\n" + << " let beta_val = f32(beta[beta_base]);\n" + << " broadcast_val = beta_val * (vec4(value[v_idx]) - reduction_buf[0]);\n" + << " }\n" + << " workgroupBarrier();\n" + << " state += k_val * broadcast_val;\n" + << " workgroupBarrier();\n"; + } else { + shader.MainFunctionBody() + << "\n // Compute retrieved = S^T @ k (parallel reduction over dk)\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * k_val;\n" + << " }\n" + << " workgroupBarrier();\n" + << " // Tree reduction\n" + << " for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) {\n" + << " if (dk_idx < stride) {\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride];\n" + << " }\n" + << " }\n" + << " workgroupBarrier();\n" + << " }\n" + << " // Compute delta = beta * (v - retrieved) and broadcast\n" + << " let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start;\n" + << " let beta_base = bt_offset * uniforms.num_heads + head_idx;\n" + << " if (dk_idx == 0u) {\n" + << " let beta_val = f32(beta[beta_base]);\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " if (dv_start + j < uniforms.head_dim_v) {\n" + << " let retrieved_j = reduction_buf[j * workgroup_size_x];\n" + << " let v_val = f32(value[v_base + j]);\n" + << " broadcast_buf[j] = beta_val * (v_val - retrieved_j);\n" + << " }\n" + << " }\n" + << " }\n" + << " workgroupBarrier();\n" + << " // Update state: S += k ⊗ delta\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " state[j] += k_val * broadcast_buf[j];\n" + << " }\n" + << " workgroupBarrier();\n"; + } + } else { + // For linear and gated modes: S += k ⊗ v (no delta rule) + if (use_vec4) { + shader.MainFunctionBody() + << "\n // Update state: S += k ⊗ v\n" + << " let v_idx = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_tile_idx;\n" + << " state += k_val * vec4(value[v_idx]);\n"; + } else { + shader.MainFunctionBody() + << "\n // Update state: S += k ⊗ v\n" + << " let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start;\n" + << " for (var j = 0u; j < TILE_V; j++) {\n" + << " if (dv_start + j < uniforms.head_dim_v) {\n" + << " let v_val = f32(value[v_base + j]);\n" + << " state[j] += k_val * v_val;\n" + << " }\n" + << " }\n"; + } + } + + // Step 3: Compute output = scale * S^T @ q (reduction across dk) + if (use_vec4) { + shader.MainFunctionBody() + << "\n // Compute output = scale * S^T @ q (parallel reduction over dk)\n" + << " reduction_buf[dk_idx] = state * q_val;\n" + << " workgroupBarrier();\n" + << " for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) {\n" + << " if (dk_idx < stride) {\n" + << " reduction_buf[dk_idx] += reduction_buf[dk_idx + stride];\n" + << " }\n" + << " workgroupBarrier();\n" + << " }\n" + << " if (dk_idx == 0u) {\n" + << " let out_idx = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_tile_idx;\n" + << " output[out_idx] = output_value_t(reduction_buf[0] * uniforms.scale);\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; // end token loop + } else { shader.MainFunctionBody() - << "\n // Compute retrieved = S^T @ k (parallel reduction over dk)\n" + << "\n // Compute output = scale * S^T @ q (parallel reduction over dk)\n" << " for (var j = 0u; j < TILE_V; j++) {\n" - << " reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * k_val;\n" + << " reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * q_val;\n" << " }\n" << " workgroupBarrier();\n" - << " // Tree reduction\n" << " for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) {\n" << " if (dk_idx < stride) {\n" << " for (var j = 0u; j < TILE_V; j++) {\n" @@ -154,78 +274,36 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { << " }\n" << " workgroupBarrier();\n" << " }\n" - // Thread 0 computes delta and broadcasts via shared memory - << " // Compute delta = beta * (v - retrieved) and broadcast\n" - << " let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start;\n" - << " let beta_base = bt_offset * uniforms.num_heads + head_idx;\n" << " if (dk_idx == 0u) {\n" - << " let beta_val = f32(beta[beta_base]);\n" + << " let out_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start;\n" << " for (var j = 0u; j < TILE_V; j++) {\n" << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " let retrieved_j = reduction_buf[j * workgroup_size_x];\n" - << " let v_val = f32(value[v_base + j]);\n" - << " broadcast_buf[j] = beta_val * (v_val - retrieved_j);\n" + << " output[out_base + j] = output_element_t(reduction_buf[j * workgroup_size_x] * uniforms.scale);\n" << " }\n" << " }\n" << " }\n" << " workgroupBarrier();\n" - // All threads update their state row using the broadcast delta - << " // Update state: S += k ⊗ delta\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " state[j] += k_val * broadcast_buf[j];\n" - << " }\n" - << " workgroupBarrier();\n"; + << "}\n"; // end token loop + } + + // Write final state (4D: B, H_kv, dk, dv) + shader.MainFunctionBody() + << "\n// Write present_state\n" + << "if (dk_idx < uniforms.head_dim_k) {\n"; + if (use_vec4) { + shader.MainFunctionBody() + << " let state_offset = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_tile_idx;\n" + << " present_state[state_offset] = output_value_t(state);\n"; } else { - // For linear and gated modes: S += k ⊗ v (no delta rule) shader.MainFunctionBody() - << "\n // Update state: S += k ⊗ v\n" - << " let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start;\n" + << " let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" << " for (var j = 0u; j < TILE_V; j++) {\n" << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " let v_val = f32(value[v_base + j]);\n" - << " state[j] += k_val * v_val;\n" + << " present_state[state_base + j] = output_element_t(state[j]);\n" << " }\n" << " }\n"; } - - // Step 3: Compute output = scale * S^T @ q (reduction across dk) - shader.MainFunctionBody() - << "\n // Compute output = scale * S^T @ q (parallel reduction over dk)\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * q_val;\n" - << " }\n" - << " workgroupBarrier();\n" - << " for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) {\n" - << " if (dk_idx < stride) {\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride];\n" - << " }\n" - << " }\n" - << " workgroupBarrier();\n" - << " }\n" - // Thread 0 writes the output for this token and dv_tile - << " if (dk_idx == 0u) {\n" - << " let out_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start;\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " output[out_base + j] = output_element_t(reduction_buf[j * workgroup_size_x] * uniforms.scale);\n" - << " }\n" - << " }\n" - << " }\n" - << " workgroupBarrier();\n" - << "}\n"; // end token loop - - // Write final state (4D: B, H_kv, dk, dv) - shader.MainFunctionBody() - << "\n// Write present_state\n" - << "if (dk_idx < uniforms.head_dim_k) {\n" - << " let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " present_state[state_base + j] = output_element_t(state[j]);\n" - << " }\n" - << " }\n" - << "}\n"; + shader.MainFunctionBody() << "}\n"; return Status::OK(); } @@ -316,14 +394,16 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { TensorShapeVector state_shape({batch_size, num_heads, head_dim_k, head_dim_v}); Tensor* present_state = context.Output(1, state_shape); - // Choose tile size: balance parallelism vs shared memory - // TILE_V * WORKGROUP_SIZE * 4 bytes must fit in shared memory (typically 16KB limit) - // E.g., TILE_V=4, WORKGROUP_SIZE=128: 4*128*4 = 2048 bytes - int tile_v = 4; - if (head_dim_v <= 4) { + // Vectorization: when head_dim_v is divisible by 4, use vec4 to pack 4 dv values + // per element. This replaces scalar TILE_V loops with native vec4 SIMD operations, + // reduces shared memory access overhead, and enables coalesced memory reads/writes. + const int components = (head_dim_v % 4 == 0 && head_dim_v >= 4) ? 4 : 1; + int tile_v = (components == 4) ? 1 : 4; + if (components == 1 && head_dim_v <= 4) { tile_v = head_dim_v; } - const int num_dv_tiles = (head_dim_v + tile_v - 1) / tile_v; + const int head_dim_v_vectorized = head_dim_v / components; + const int num_dv_tiles = (head_dim_v_vectorized + tile_v - 1) / tile_v; // Workgroup size = head_dim_k (one thread per dk row) // Ensure it's a power of 2 for tree reduction (round up) @@ -351,13 +431,13 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { } } - LinearAttentionProgram program{update_rule_, has_initial_state, has_decay, has_beta, decay_broadcast_dk, tile_v}; + LinearAttentionProgram program{update_rule_, has_initial_state, has_decay, has_beta, decay_broadcast_dk, tile_v, components}; program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, {key, ProgramTensorMetadataDependency::TypeAndRank}, - {value, ProgramTensorMetadataDependency::TypeAndRank}}); + {value, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (has_initial_state) { - program.AddInput({past_state, ProgramTensorMetadataDependency::TypeAndRank}); + program.AddInput({past_state, ProgramTensorMetadataDependency::TypeAndRank, components}); } if (has_decay) { program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); @@ -366,18 +446,18 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); } - program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}, - {present_state, ProgramTensorMetadataDependency::TypeAndRank}}); + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}, + {present_state, ProgramTensorMetadataDependency::TypeAndRank, components}}); program.SetDispatchGroupSize(num_workgroups) .SetWorkgroupSize(workgroup_size) .CacheHint(std::to_string(static_cast(update_rule_)), - has_initial_state, has_decay, has_beta, decay_broadcast_dk, tile_v) + has_initial_state, has_decay, has_beta, decay_broadcast_dk, tile_v, components) .AddUniformVariables({{static_cast(batch_size)}, {static_cast(num_heads)}, {static_cast(seq_length)}, {static_cast(head_dim_k)}, - {static_cast(head_dim_v)}, + {static_cast(head_dim_v_vectorized)}, {scale}, {static_cast(num_dv_tiles)}}); diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h index 2b02856440f39..b55e3a68cfe6b 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h @@ -31,14 +31,15 @@ LinearAttentionUpdateRule ParseUpdateRule(const std::string& rule_str); class LinearAttentionProgram final : public Program { public: LinearAttentionProgram(LinearAttentionUpdateRule update_rule, bool has_initial_state, - bool has_decay, bool has_beta, bool decay_broadcast_dk, int tile_v) + bool has_decay, bool has_beta, bool decay_broadcast_dk, int tile_v, int components) : Program{"LinearAttention"}, update_rule_(update_rule), has_initial_state_(has_initial_state), has_decay_(has_decay), has_beta_(has_beta), decay_broadcast_dk_(decay_broadcast_dk), - tile_v_(tile_v) {} + tile_v_(tile_v), + components_(components) {} Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -58,6 +59,7 @@ class LinearAttentionProgram final : public Program { bool has_beta_; bool decay_broadcast_dk_; int tile_v_; + int components_; }; // Kernel for LinearAttention diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 3f13b91adbefd..2fe5f4d533b7d 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" -#include "contrib_ops/webgpu/bert/causal_conv1d_with_state.h" +#include "contrib_ops/webgpu/bert/causal_conv_with_state.h" #include "contrib_ops/webgpu/bert/group_query_attention.h" #include "contrib_ops/webgpu/bert/linear_attention.h" diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index 42c31c7e1b82f..854b77ba4876b 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -6,7 +6,6 @@ #include "core/providers/webgpu/compute_context.h" #include "core/framework/op_kernel.h" -#include "core/providers/webgpu/numpy_io.h" namespace onnxruntime { @@ -24,27 +23,6 @@ class WebGpuKernel : public OpKernel { virtual Status ComputeInternal(ComputeContext& context) const = 0; - // call with - // NpyTensor(hidden_state, "/tmp/hidden_state.npy", context); - - template - void NpyTensor(const Tensor* t, std::string file, ComputeContext& context) const { - auto t_cpu = context.CreateCPUTensor(t->DataType(), t->Shape()); - ORT_THROW_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*t, t_cpu)); - - std::vector dims; - auto dims1 = t_cpu.Shape().GetDims(); - for (uint64_t i=0; i(dims); - for (int64_t i = 0; i < t_cpu.Shape().Size(); i++) { - a.data[i] = static_cast(t_cpu.Data()[i]); - } - numpy_io::write_numpy_array(file, a); - } - - // Overrides OpKernel::PrePack to handle constant tensor pre-processing for WebGPU kernels. // This method creates a ComputeContextBase and delegates to PrePackInternal. // From 52cee104d9055bdb55723f201eaf00e36ce96a10 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Sun, 29 Mar 2026 08:48:19 -0700 Subject: [PATCH 19/27] rename to causal_conv_with_state_op_test --- ...1d_with_state_op_test.cc => causal_conv_with_state_op_test.cc} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename onnxruntime/test/contrib_ops/{causal_conv1d_with_state_op_test.cc => causal_conv_with_state_op_test.cc} (100%) diff --git a/onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc b/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc similarity index 100% rename from onnxruntime/test/contrib_ops/causal_conv1d_with_state_op_test.cc rename to onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc From a1e9827934ccf22572b48f0ad5ad44e60044316e Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Sun, 29 Mar 2026 10:12:07 -0700 Subject: [PATCH 20/27] ut looks for the registered ops --- .../causal_conv_with_state_op_test.cc | 49 +++++++++++++------ .../contrib_ops/linear_attention_op_test.cc | 48 +++++++++++++----- 2 files changed, 69 insertions(+), 28 deletions(-) diff --git a/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc b/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc index 76e878d9fff80..7016e55f518db 100644 --- a/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc +++ b/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc @@ -5,10 +5,13 @@ #include #include #include "gtest/gtest.h" +#include "core/common/logging/logging.h" +#include "core/framework/kernel_registry.h" #include "core/session/onnxruntime_cxx_api.h" #include "test/common/tensor_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" namespace onnxruntime { namespace test { @@ -90,6 +93,26 @@ void CausalConvWithStateReference( } } +// Returns a WebGPU EP if it is available and has the CausalConvWithState kernel registered, +// or nullptr otherwise. +std::unique_ptr TryGetEpWithCausalConvWithState() { + auto ep = DefaultWebGpuExecutionProvider(); + if (!ep) { + ep = DefaultCpuExecutionProvider(); + } + + auto kernel_registry = ep->GetKernelRegistry(); + if (kernel_registry) { + const KernelCreateInfo* info = nullptr; + KernelRegistry::TypeConstraintMap type_constraints; + auto status = kernel_registry->TryFindKernel( + ep->Type(), "CausalConvWithState", kMSDomain, 1, + type_constraints, DefaultLoggingManager().DefaultLogger(), &info); + if (!status.IsOK()) return nullptr; + } + return ep; +} + } // anonymous namespace static void RunCausalConvWithStateTest( @@ -105,6 +128,12 @@ static void RunCausalConvWithStateTest( int kernel_size, const std::string& activation, TensorType tensor_type) { + auto ep = TryGetEpWithCausalConvWithState(); + if (!ep) { + GTEST_SKIP() << "CausalConvWithState kernel not registered"; + return; + } + int state_length = kernel_size - 1; std::vector input_shape = {batch_size, channels, input_length}; @@ -113,19 +142,7 @@ static void RunCausalConvWithStateTest( std::vector state_shape = {batch_size, channels, state_length}; std::vector output_shape = {batch_size, channels, input_length}; - std::vector> execution_providers; - - bool enable_webgpu = nullptr != DefaultWebGpuExecutionProvider().get(); - if (enable_webgpu) { - execution_providers.push_back(DefaultWebGpuExecutionProvider()); - } - - if (execution_providers.empty()) { - // Skip if no providers available - return; - } - - for (auto& ep : execution_providers) { + { OpTester test("CausalConvWithState", 1, onnxruntime::kMSDomain); test.AddAttribute("activation", activation); @@ -170,9 +187,9 @@ static void RunCausalConvWithStateTest( test.SetOutputAbsErr("output", 0.01f); test.SetOutputAbsErr("present_state", 0.01f); - std::vector> test_execution_providers; - test_execution_providers.push_back(std::move(ep)); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &test_execution_providers); + std::vector> execution_providers; + execution_providers.push_back(std::move(ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc index 24bb77d4aa007..bdc293cc1fdcc 100644 --- a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc @@ -5,6 +5,8 @@ #include #include "gtest/gtest.h" +#include "core/common/logging/logging.h" +#include "core/framework/kernel_registry.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" @@ -163,6 +165,26 @@ std::vector TransposeBHT_to_BTH(const std::vector& data, return transposed; } +// Returns a WebGPU EP if it is available and has the LinearAttention kernel registered, +// or nullptr otherwise. +std::unique_ptr TryGetEpWithLinearAttention() { + auto ep = DefaultWebGpuExecutionProvider(); + if (!ep) { + ep = DefaultCpuExecutionProvider(); + } + + auto kernel_registry = ep->GetKernelRegistry(); + if (kernel_registry) { + const KernelCreateInfo* info = nullptr; + KernelRegistry::TypeConstraintMap type_constraints; + auto status = kernel_registry->TryFindKernel( + ep->Type(), "LinearAttention", kMSDomain, 1, + type_constraints, DefaultLoggingManager().DefaultLogger(), &info); + if (!status.IsOK()) return nullptr; + } + return ep; +} + void RunLinearAttentionTest( const std::string& update_rule, int batch_size, int num_heads, int seq_length, int head_dim_k, int head_dim_v, @@ -173,6 +195,12 @@ void RunLinearAttentionTest( const std::vector* initial_state, const std::vector* decay, const std::vector* beta_data) { + auto ep = TryGetEpWithLinearAttention(); + if (!ep) { + GTEST_SKIP() << "LinearAttention kernel not registered"; + return; + } + // Compute reference output (reference works in 4D layout) std::vector expected_output_4d, expected_state; LinearAttentionReference(update_rule, batch_size, num_heads, seq_length, @@ -180,11 +208,6 @@ void RunLinearAttentionTest( query, key, value, initial_state, decay, beta_data, expected_output_4d, expected_state); - bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()); - if (!enable_webgpu) { - return; - } - int bht = batch_size * num_heads * seq_length; bool decay_broadcast_dk = (decay != nullptr && static_cast(decay->size()) == bht); @@ -248,7 +271,7 @@ void RunLinearAttentionTest( tester.AddOutput("present_state", state_dims, expected_state, false, 0.005f, 0.005f); std::vector> execution_providers; - execution_providers.push_back(DefaultWebGpuExecutionProvider()); + execution_providers.push_back(std::move(ep)); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -570,6 +593,12 @@ TEST(ContribOpLinearAttentionTest, GatedDeltaRule_MultiBatchMultiHead) { // Test: Default scale (should use 1/sqrt(dk)) // =========================================================================== TEST(ContribOpLinearAttentionTest, LinearRule_DefaultScale) { + auto ep = TryGetEpWithLinearAttention(); + if (!ep) { + GTEST_SKIP() << "LinearAttention kernel not registered on WebGPU EP (or EP not available)"; + return; + } + const int B = 1, H = 1, T = 1, dk = 4, dv = 4; std::vector query = {1.0f, 0.0f, 0.5f, -0.5f}; @@ -583,11 +612,6 @@ TEST(ContribOpLinearAttentionTest, LinearRule_DefaultScale) { query, key, value, nullptr, nullptr, nullptr, expected_output, expected_state); - bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()); - if (!enable_webgpu) { - return; - } - OpTester tester("LinearAttention", 1, onnxruntime::kMSDomain); tester.AddAttribute("update_rule", std::string("linear")); tester.AddAttribute("q_num_heads", static_cast(H)); @@ -610,7 +634,7 @@ TEST(ContribOpLinearAttentionTest, LinearRule_DefaultScale) { tester.AddOutput("present_state", state_dims, expected_state, false, 0.005f, 0.005f); std::vector> execution_providers; - execution_providers.push_back(DefaultWebGpuExecutionProvider()); + execution_providers.push_back(std::move(ep)); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } From 5c80e30e76118490901a9ed6a31f90f01c101783 Mon Sep 17 00:00:00 2001 From: gs Date: Mon, 30 Mar 2026 10:05:31 -0700 Subject: [PATCH 21/27] lintrunner -a --- .../causal_conv_with_state_op_test.cc | 38 +++++++++---------- .../contrib_ops/linear_attention_op_test.cc | 12 +++--- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc b/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc index 7016e55f518db..2a7837dd1ce73 100644 --- a/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc +++ b/onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc @@ -510,39 +510,39 @@ TEST(CausalConvWithStateTest, StateContinuity) { std::vector expected_output1; std::vector expected_state1; CausalConvWithStateReference(input1, weight_data, &bias_data, &conv_state, - expected_output1, expected_state1, - batch_size, channels, input_length, kernel_size, "none"); + expected_output1, expected_state1, + batch_size, channels, input_length, kernel_size, "none"); RunCausalConvWithStateTest(input1, weight_data, &bias_data, &conv_state, - expected_output1, expected_state1, - batch_size, channels, input_length, kernel_size, "none", - TensorType::kFloat); + expected_output1, expected_state1, + batch_size, channels, input_length, kernel_size, "none", + TensorType::kFloat); // Second token, using present_state from first as conv_state std::vector input2 = {2.0f}; std::vector expected_output2; std::vector expected_state2; CausalConvWithStateReference(input2, weight_data, &bias_data, &expected_state1, - expected_output2, expected_state2, - batch_size, channels, input_length, kernel_size, "none"); + expected_output2, expected_state2, + batch_size, channels, input_length, kernel_size, "none"); RunCausalConvWithStateTest(input2, weight_data, &bias_data, &expected_state1, - expected_output2, expected_state2, - batch_size, channels, input_length, kernel_size, "none", - TensorType::kFloat); + expected_output2, expected_state2, + batch_size, channels, input_length, kernel_size, "none", + TensorType::kFloat); // Third token std::vector input3 = {3.0f}; std::vector expected_output3; std::vector expected_state3; CausalConvWithStateReference(input3, weight_data, &bias_data, &expected_state2, - expected_output3, expected_state3, - batch_size, channels, input_length, kernel_size, "none"); + expected_output3, expected_state3, + batch_size, channels, input_length, kernel_size, "none"); RunCausalConvWithStateTest(input3, weight_data, &bias_data, &expected_state2, - expected_output3, expected_state3, - batch_size, channels, input_length, kernel_size, "none", - TensorType::kFloat); + expected_output3, expected_state3, + batch_size, channels, input_length, kernel_size, "none", + TensorType::kFloat); // The present_state after processing [1, 2, 3] should be [2, 3] EXPECT_NEAR(expected_state3[0], 2.0f, 1e-5f); @@ -573,8 +573,8 @@ TEST(CausalConvWithStateTest, SequenceVsTokenByToken) { std::vector full_output; std::vector full_final_state; CausalConvWithStateReference(full_input, weight_data, &bias_data, &conv_state, - full_output, full_final_state, - batch_size, channels, 4, kernel_size, "none"); + full_output, full_final_state, + batch_size, channels, 4, kernel_size, "none"); // Process token by token std::vector current_state = conv_state; @@ -589,8 +589,8 @@ TEST(CausalConvWithStateTest, SequenceVsTokenByToken) { std::vector token_output; std::vector next_state; CausalConvWithStateReference(token_input, weight_data, &bias_data, ¤t_state, - token_output, next_state, - batch_size, channels, 1, kernel_size, "none"); + token_output, next_state, + batch_size, channels, 1, kernel_size, "none"); // Collect outputs for (int d = 0; d < channels; ++d) { diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc index bdc293cc1fdcc..84ba7a2c361a6 100644 --- a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc @@ -12,7 +12,6 @@ using namespace onnxruntime::test; - namespace onnxruntime { namespace test { @@ -32,11 +31,10 @@ void LinearAttentionReference( const std::vector* beta, std::vector& output, std::vector& final_state) { + int bht = batch_size * num_heads * seq_length; + bool decay_broadcast_dk = (decay != nullptr && static_cast(decay->size()) == bht); - int bht = batch_size * num_heads * seq_length; - bool decay_broadcast_dk = (decay != nullptr && static_cast(decay->size()) == bht); - - // State: (B, H, dk, dv) + // State: (B, H, dk, dv) final_state.resize(batch_size * num_heads * head_dim_k * head_dim_v, 0.0f); output.resize(batch_size * num_heads * seq_length * head_dim_v, 0.0f); @@ -133,7 +131,7 @@ void LinearAttentionReference( // Convert data from 4D (B,H,T,D) layout to 3D packed (B,T,H*D) layout std::vector PackBHTD_to_BTHD(const std::vector& data_4d, - int B, int H, int T, int D) { + int B, int H, int T, int D) { std::vector packed(B * T * H * D); for (int b = 0; b < B; b++) { for (int h = 0; h < H; h++) { @@ -151,7 +149,7 @@ std::vector PackBHTD_to_BTHD(const std::vector& data_4d, // Convert decay/beta from (B,H,T) layout to (B,T,H) layout std::vector TransposeBHT_to_BTH(const std::vector& data, - int B, int H, int T) { + int B, int H, int T) { std::vector transposed(B * T * H); for (int b = 0; b < B; b++) { for (int h = 0; h < H; h++) { From 81251ff3014fb041add345f9409542dc52e191c7 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 30 Mar 2026 22:02:08 -0700 Subject: [PATCH 22/27] move shader generation to .wgsl.template --- .../webgpu/bert/causal_conv_with_state.cc | 168 +---------- .../bert/causal_conv_with_state.wgsl.template | 118 ++++++++ .../webgpu/bert/linear_attention.cc | 278 ++---------------- .../bert/linear_attention.wgsl.template | 247 ++++++++++++++++ .../core/graph/contrib_ops/bert_defs.cc | 3 +- .../core/providers/webgpu/tensor/concat.cc | 30 +- 6 files changed, 419 insertions(+), 425 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.wgsl.template create mode 100644 onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc index f7a356618b8d6..527f3c7d3027f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc @@ -5,6 +5,7 @@ #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" +// #include "core/providers/webgpu/wgsl_templates/wgsl_gen.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" using namespace onnxruntime::webgpu; @@ -42,170 +43,23 @@ CausalConvWithState::CausalConvWithState(const OpKernelInfo& info) } Status CausalConvWithStateProgram::GenerateShaderCode(ShaderHelper& shader) const { - // Input tensors - const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - const auto& weight = shader.AddInput("weight", ShaderUsage::UseUniform); + shader.AddInput("input", ShaderUsage::UseElementTypeAlias); + shader.AddInput("weight", ShaderUsage::UseUniform); - // Optional inputs - const ShaderVariableHelper* bias_ptr = nullptr; - const ShaderVariableHelper* conv_state_ptr = nullptr; if (has_bias_) { - bias_ptr = &shader.AddInput("bias", ShaderUsage::UseUniform); + shader.AddInput("bias", ShaderUsage::UseUniform); } if (has_conv_state_) { - conv_state_ptr = &shader.AddInput("conv_state", ShaderUsage::UseUniform); + shader.AddInput("conv_state", ShaderUsage::UseUniform); } - // Output tensors - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); - const auto& present_state = shader.AddOutput("present_state", ShaderUsage::UseUniform); + shader.AddOutput("output", ShaderUsage::UseUniform); + shader.AddOutput("present_state", ShaderUsage::UseUniform); - // Activation function implementation - if (activation_ == CausalConvActivation::Silu) { - shader.AdditionalImplementation() << R"SHADER( -fn silu(x: input_element_t) -> input_element_t { - return x / (1.0 + exp(-x)); -} -)SHADER"; - } - - // Flatten to 1D dispatch: each thread handles one (batch, channel, pos) triple. - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") - << R"SHADER( - let batch_size = uniforms.batch_size; - let channels = uniforms.channels; - let input_length = uniforms.input_length; - let kernel_size = uniforms.kernel_size; - let state_length = uniforms.state_length; // = kernel_size - 1 - - let pos = global_idx % input_length; - let bc_idx = global_idx / input_length; - let batch_idx = bc_idx / channels; - let channel_idx = bc_idx % channels; - - // Perform depthwise causal convolution for this (batch, channel, pos). - // The convolution window looks back kernel_size-1 positions. - // With conv_state providing the history before position 0, the - // "virtual" input is: [conv_state[0..state_length-1], input[0..L-1]] - // - // For output position pos: - // output[pos] = sum_{j=0}^{kernel_size-1} weight[j] * virtual_input[pos + j] - // where virtual_input is state_length positions of conv_state - // followed by input_length positions of input. - - var acc: input_element_t = 0.0; - - // Weight layout: (D, 1, K) -> channel_idx * kernel_size + j - let weight_base = channel_idx * kernel_size; - - for (var j: u32 = 0; j < kernel_size; j = j + 1) { - // virtual_pos is the position in the concatenated [conv_state, input] - let virtual_pos = pos + j; - - var val: input_element_t = 0.0; -)SHADER"; - - if (has_conv_state_) { - shader.MainFunctionBody() << R"SHADER( - if (virtual_pos < state_length) { - // Read from conv_state: (B, D, state_length) - let state_idx = (batch_idx * channels + channel_idx) * state_length + virtual_pos; - val = )SHADER" - << conv_state_ptr->GetByOffset("state_idx") << R"SHADER(; - } else { - // Read from input: (B, D, L) - let input_pos = virtual_pos - state_length; - let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos; - val = )SHADER" - << input.GetByOffset("input_idx") << R"SHADER(; - } -)SHADER"; - } else { - // No conv_state: pad with zeros for positions before the input - shader.MainFunctionBody() << R"SHADER( - if (virtual_pos >= state_length) { - let input_pos = virtual_pos - state_length; - let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos; - val = )SHADER" - << input.GetByOffset("input_idx") << R"SHADER(; - } -)SHADER"; - } - - shader.MainFunctionBody() << R"SHADER( - let w = )SHADER" - << weight.GetByOffset("weight_base + j") << R"SHADER(; - acc = acc + val * w; - } -)SHADER"; - - // Add bias if present - if (has_bias_) { - shader.MainFunctionBody() << " acc = acc + " << bias_ptr->GetByOffset("channel_idx") << ";\n"; - } - - // Apply activation - if (activation_ == CausalConvActivation::Silu) { - shader.MainFunctionBody() << " acc = silu(acc);\n"; - } - - // Write output: (B, D, L) - shader.MainFunctionBody() << R"SHADER( - let out_idx = (batch_idx * channels + channel_idx) * input_length + pos; - )SHADER" << output.SetByOffset("out_idx", "acc") - << "\n"; - - // Write present_state: the last (kernel_size - 1) elements from the - // virtual input [conv_state, input]. The virtual input has total length - // state_length + input_length. We want positions from - // (state_length + input_length - state_length) to (state_length + input_length - 1), - // i.e. the last state_length positions of the virtual input, which are the - // last state_length positions of input (when input_length >= state_length). - // - // We only write present_state once per (batch, channel), using the thread - // at pos == 0 to write all state_length values. - shader.MainFunctionBody() << R"SHADER( - if (pos == 0u) { - for (var s: u32 = 0; s < state_length; s = s + 1) { - var state_val: input_element_t = 0.0; - // total_len = state_length + input_length - // We want virtual_input[total_len - state_length + s] = virtual_input[input_length + s] - let vp = input_length + s; -)SHADER"; - - if (has_conv_state_) { - shader.MainFunctionBody() << R"SHADER( - if (vp < state_length) { - let si = (batch_idx * channels + channel_idx) * state_length + vp; - state_val = )SHADER" - << conv_state_ptr->GetByOffset("si") << R"SHADER(; - } else { - let ip = vp - state_length; - let ii = (batch_idx * channels + channel_idx) * input_length + ip; - state_val = )SHADER" - << input.GetByOffset("ii") << R"SHADER(; - } -)SHADER"; - } else { - shader.MainFunctionBody() << R"SHADER( - if (vp >= state_length) { - let ip = vp - state_length; - let ii = (batch_idx * channels + channel_idx) * input_length + ip; - state_val = )SHADER" - << input.GetByOffset("ii") << R"SHADER(; - } -)SHADER"; - } - - shader.MainFunctionBody() << R"SHADER( - let ps_idx = (batch_idx * channels + channel_idx) * state_length + s; - )SHADER" << present_state.SetByOffset("ps_idx", "state_val") - << R"SHADER( - } - } -)SHADER"; - - return Status::OK(); + return WGSL_TEMPLATE_APPLY(shader, "bert/causal_conv_with_state.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_), + WGSL_TEMPLATE_PARAMETER(has_conv_state, has_conv_state_), + WGSL_TEMPLATE_PARAMETER(use_silu, activation_ == CausalConvActivation::Silu)); } Status CausalConvWithState::ComputeInternal(ComputeContext& context) const { diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.wgsl.template new file mode 100644 index 0000000000000..e109f167d27b1 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.wgsl.template @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#param has_bias +#param has_conv_state +#param use_silu + +#use guardAgainstOutOfBoundsWorkgroupSizes + +#if use_silu +fn silu(x: input_element_t) -> input_element_t { + return x / (1.0 + exp(-x)); +} +#endif + +$MAIN { + guardAgainstOutOfBoundsWorkgroupSizes(uniforms.output_size); + + let batch_size = uniforms.batch_size; + let channels = uniforms.channels; + let input_length = uniforms.input_length; + let kernel_size = uniforms.kernel_size; + let state_length = uniforms.state_length; // = kernel_size - 1 + + let pos = global_idx % input_length; + let bc_idx = global_idx / input_length; + let batch_idx = bc_idx / channels; + let channel_idx = bc_idx % channels; + + // Perform depthwise causal convolution for this (batch, channel, pos). + // The convolution window looks back kernel_size-1 positions. + // With conv_state providing the history before position 0, the + // "virtual" input is: [conv_state[0..state_length-1], input[0..L-1]] + // + // For output position pos: + // output[pos] = sum_{j=0}^{kernel_size-1} weight[j] * virtual_input[pos + j] + // where virtual_input is state_length positions of conv_state + // followed by input_length positions of input. + + var acc: input_element_t = 0.0; + + // Weight layout: (D, 1, K) -> channel_idx * kernel_size + j + let weight_base = channel_idx * kernel_size; + + for (var j: u32 = 0; j < kernel_size; j = j + 1) { + // virtual_pos is the position in the concatenated [conv_state, input] + let virtual_pos = pos + j; + + var val: input_element_t = 0.0; + +#if has_conv_state + if (virtual_pos < state_length) { + // Read from conv_state: (B, D, state_length) + let state_idx = (batch_idx * channels + channel_idx) * state_length + virtual_pos; + val = conv_state[state_idx]; + } else { + // Read from input: (B, D, L) + let input_pos = virtual_pos - state_length; + let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos; + val = input[input_idx]; + } +#else + // No conv_state: pad with zeros for positions before the input + if (virtual_pos >= state_length) { + let input_pos = virtual_pos - state_length; + let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos; + val = input[input_idx]; + } +#endif + + let w = weight[weight_base + j]; + acc = acc + val * w; + } + +#if has_bias + acc = acc + bias[channel_idx]; +#endif + +#if use_silu + acc = silu(acc); +#endif + + // Write output: (B, D, L) + let out_idx = (batch_idx * channels + channel_idx) * input_length + pos; + output[out_idx] = acc; + + // Write present_state: the last (kernel_size - 1) elements from the + // virtual input [conv_state, input]. We only write present_state once + // per (batch, channel), using the thread at pos == 0. + if (pos == 0u) { + for (var s: u32 = 0; s < state_length; s = s + 1) { + var state_val: input_element_t = 0.0; + // total_len = state_length + input_length + // We want virtual_input[total_len - state_length + s] = virtual_input[input_length + s] + let vp = input_length + s; + +#if has_conv_state + if (vp < state_length) { + let si = (batch_idx * channels + channel_idx) * state_length + vp; + state_val = conv_state[si]; + } else { + let ip = vp - state_length; + let ii = (batch_idx * channels + channel_idx) * input_length + ip; + state_val = input[ii]; + } +#else + if (vp >= state_length) { + let ip = vp - state_length; + let ii = (batch_idx * channels + channel_idx) * input_length + ip; + state_val = input[ii]; + } +#endif + + let ps_idx = (batch_idx * channels + channel_idx) * state_length + s; + present_state[ps_idx] = state_val; + } + } +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index bc04894457fa9..39bf907f9b741 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -5,6 +5,7 @@ #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" +// #include "core/providers/webgpu/wgsl_templates/wgsl_gen.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" using namespace onnxruntime::webgpu; @@ -41,6 +42,23 @@ LinearAttentionUpdateRule ParseUpdateRule(const std::string& rule_str) { Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { const bool use_vec4 = (components_ == 4); + // Map update rule to integer for template conditionals + int update_rule_int = 0; + switch (update_rule_) { + case LinearAttentionUpdateRule::Linear: + update_rule_int = 0; + break; + case LinearAttentionUpdateRule::Gated: + update_rule_int = 1; + break; + case LinearAttentionUpdateRule::Delta: + update_rule_int = 2; + break; + case LinearAttentionUpdateRule::GatedDelta: + update_rule_int = 3; + break; + } + // Add inputs shader.AddInput("query", ShaderUsage::UseUniform); shader.AddInput("key", ShaderUsage::UseUniform); @@ -59,253 +77,12 @@ Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); shader.AddOutput("present_state", ShaderUsage::UseUniform); - // Shared memory for parallel reduction across dk threads - // and for broadcasting delta values. - // When use_vec4, each reduction_buf entry is a vec4 (4 dv values packed), - // eliminating the inner TILE_V loop and enabling native SIMD operations. - if (use_vec4) { - shader.AdditionalImplementation() - << "var reduction_buf: array, workgroup_size_x>;\n" - << "var broadcast_val: vec4;\n"; - } else { - // TILE_V is emitted as a compile-time constant (not overridable) because - // private address space arrays require fixed sizes in WGSL. - shader.AdditionalImplementation() - << "const TILE_V: u32 = " << tile_v_ << "u;\n" - << "var reduction_buf: array;\n" - << "var broadcast_buf: array;\n"; - } - - // Identify which (batch, head, dv_tile) this workgroup handles - shader.MainFunctionBody() - << "let bh = workgroup_idx / uniforms.num_dv_tiles;\n" - << "let dv_tile_idx = workgroup_idx % uniforms.num_dv_tiles;\n" - << "let batch_idx = bh / uniforms.num_heads;\n" - << "let head_idx = bh % uniforms.num_heads;\n" - << "let dk_idx = local_idx; // thread index = row in state matrix\n"; - if (!use_vec4) { - shader.MainFunctionBody() << "let dv_start = dv_tile_idx * TILE_V;\n"; - } - // Precompute packed strides for 3D packed inputs (B, T, H*D) - // When use_vec4, head_dim_v is already divided by 4 (vectorized), so - // packed_dv = num_heads * (head_dim_v/4) and dv_tile_idx indexes vec4 elements. - shader.MainFunctionBody() - << "\n" - << "let packed_dk = uniforms.num_heads * uniforms.head_dim_k;\n" - << "let packed_dv = uniforms.num_heads * uniforms.head_dim_v;\n" - << "\n"; - - // Initialize state tile in private memory - if (use_vec4) { - shader.MainFunctionBody() << "var state = vec4(0.0);\n"; - } else { - shader.MainFunctionBody() - << "var state: array;\n" - << "for (var j = 0u; j < TILE_V; j++) {\n" - << " state[j] = 0.0;\n" - << "}\n"; - } - - // Load initial state if provided - if (has_initial_state_) { - shader.MainFunctionBody() << "if (dk_idx < uniforms.head_dim_k) {\n"; - if (use_vec4) { - shader.MainFunctionBody() - << " let state_offset = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_tile_idx;\n" - << " state = vec4(initial_state[state_offset]);\n"; - } else { - shader.MainFunctionBody() - << " let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " state[j] = f32(initial_state[state_base + j]);\n" - << " }\n" - << " }\n"; - } - shader.MainFunctionBody() << "}\n"; - } - - // Main token processing loop - shader.MainFunctionBody() - << "\n// Process each token sequentially\n" - << "for (var t = 0u; t < uniforms.seq_length; t++) {\n" - << " let bt_offset = batch_idx * uniforms.seq_length + t;\n" - << " var k_val: f32 = 0.0;\n" - << " var q_val: f32 = 0.0;\n" - << " if (dk_idx < uniforms.head_dim_k) {\n" - << " let qk_idx = bt_offset * packed_dk + head_idx * uniforms.head_dim_k + dk_idx;\n" - << " k_val = f32(key[qk_idx]);\n" - << " q_val = f32(query[qk_idx]);\n" - << " }\n"; - - // Step 1: Apply decay (for gated and gated_delta modes) - if (update_rule_ == LinearAttentionUpdateRule::Gated || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { - shader.MainFunctionBody() - << "\n // Apply exponential decay: S *= exp(decay)\n"; - if (decay_broadcast_dk_) { - shader.MainFunctionBody() - << " let exp_g = exp(f32(decay[bt_offset * uniforms.num_heads + head_idx]));\n"; - } else { - shader.MainFunctionBody() - << " var exp_g: f32 = 1.0;\n" - << " if (dk_idx < uniforms.head_dim_k) {\n" - << " exp_g = exp(f32(decay[bt_offset * packed_dk + head_idx * uniforms.head_dim_k + dk_idx]));\n" - << " }\n"; - } - if (use_vec4) { - shader.MainFunctionBody() << " state *= exp_g;\n"; - } else { - shader.MainFunctionBody() - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " state[j] *= exp_g;\n" - << " }\n"; - } - } - - // Step 2: For delta/gated_delta rules, compute retrieved = S^T @ k (reduction across dk) - if (update_rule_ == LinearAttentionUpdateRule::Delta || update_rule_ == LinearAttentionUpdateRule::GatedDelta) { - if (use_vec4) { - shader.MainFunctionBody() - << "\n // Compute retrieved = S^T @ k (parallel reduction over dk)\n" - << " reduction_buf[dk_idx] = state * k_val;\n" - << " workgroupBarrier();\n" - << " for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) {\n" - << " if (dk_idx < stride) {\n" - << " reduction_buf[dk_idx] += reduction_buf[dk_idx + stride];\n" - << " }\n" - << " workgroupBarrier();\n" - << " }\n" - << " // Compute delta = beta * (v - retrieved) and broadcast\n" - << " let v_idx = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_tile_idx;\n" - << " let beta_base = bt_offset * uniforms.num_heads + head_idx;\n" - << " if (dk_idx == 0u) {\n" - << " let beta_val = f32(beta[beta_base]);\n" - << " broadcast_val = beta_val * (vec4(value[v_idx]) - reduction_buf[0]);\n" - << " }\n" - << " workgroupBarrier();\n" - << " state += k_val * broadcast_val;\n" - << " workgroupBarrier();\n"; - } else { - shader.MainFunctionBody() - << "\n // Compute retrieved = S^T @ k (parallel reduction over dk)\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * k_val;\n" - << " }\n" - << " workgroupBarrier();\n" - << " // Tree reduction\n" - << " for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) {\n" - << " if (dk_idx < stride) {\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride];\n" - << " }\n" - << " }\n" - << " workgroupBarrier();\n" - << " }\n" - << " // Compute delta = beta * (v - retrieved) and broadcast\n" - << " let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start;\n" - << " let beta_base = bt_offset * uniforms.num_heads + head_idx;\n" - << " if (dk_idx == 0u) {\n" - << " let beta_val = f32(beta[beta_base]);\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " let retrieved_j = reduction_buf[j * workgroup_size_x];\n" - << " let v_val = f32(value[v_base + j]);\n" - << " broadcast_buf[j] = beta_val * (v_val - retrieved_j);\n" - << " }\n" - << " }\n" - << " }\n" - << " workgroupBarrier();\n" - << " // Update state: S += k ⊗ delta\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " state[j] += k_val * broadcast_buf[j];\n" - << " }\n" - << " workgroupBarrier();\n"; - } - } else { - // For linear and gated modes: S += k ⊗ v (no delta rule) - if (use_vec4) { - shader.MainFunctionBody() - << "\n // Update state: S += k ⊗ v\n" - << " let v_idx = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_tile_idx;\n" - << " state += k_val * vec4(value[v_idx]);\n"; - } else { - shader.MainFunctionBody() - << "\n // Update state: S += k ⊗ v\n" - << " let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start;\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " let v_val = f32(value[v_base + j]);\n" - << " state[j] += k_val * v_val;\n" - << " }\n" - << " }\n"; - } - } - - // Step 3: Compute output = scale * S^T @ q (reduction across dk) - if (use_vec4) { - shader.MainFunctionBody() - << "\n // Compute output = scale * S^T @ q (parallel reduction over dk)\n" - << " reduction_buf[dk_idx] = state * q_val;\n" - << " workgroupBarrier();\n" - << " for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) {\n" - << " if (dk_idx < stride) {\n" - << " reduction_buf[dk_idx] += reduction_buf[dk_idx + stride];\n" - << " }\n" - << " workgroupBarrier();\n" - << " }\n" - << " if (dk_idx == 0u) {\n" - << " let out_idx = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_tile_idx;\n" - << " output[out_idx] = output_value_t(reduction_buf[0] * uniforms.scale);\n" - << " }\n" - << " workgroupBarrier();\n" - << "}\n"; // end token loop - } else { - shader.MainFunctionBody() - << "\n // Compute output = scale * S^T @ q (parallel reduction over dk)\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * q_val;\n" - << " }\n" - << " workgroupBarrier();\n" - << " for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) {\n" - << " if (dk_idx < stride) {\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride];\n" - << " }\n" - << " }\n" - << " workgroupBarrier();\n" - << " }\n" - << " if (dk_idx == 0u) {\n" - << " let out_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start;\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " output[out_base + j] = output_element_t(reduction_buf[j * workgroup_size_x] * uniforms.scale);\n" - << " }\n" - << " }\n" - << " }\n" - << " workgroupBarrier();\n" - << "}\n"; // end token loop - } - - // Write final state (4D: B, H_kv, dk, dv) - shader.MainFunctionBody() - << "\n// Write present_state\n" - << "if (dk_idx < uniforms.head_dim_k) {\n"; - if (use_vec4) { - shader.MainFunctionBody() - << " let state_offset = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_tile_idx;\n" - << " present_state[state_offset] = output_value_t(state);\n"; - } else { - shader.MainFunctionBody() - << " let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start;\n" - << " for (var j = 0u; j < TILE_V; j++) {\n" - << " if (dv_start + j < uniforms.head_dim_v) {\n" - << " present_state[state_base + j] = output_element_t(state[j]);\n" - << " }\n" - << " }\n"; - } - shader.MainFunctionBody() << "}\n"; - - return Status::OK(); + return WGSL_TEMPLATE_APPLY(shader, "bert/linear_attention.wgsl.template", + WGSL_TEMPLATE_PARAMETER(decay_broadcast_dk, decay_broadcast_dk_), + WGSL_TEMPLATE_PARAMETER(has_initial_state, has_initial_state_), + WGSL_TEMPLATE_PARAMETER(tile_v, tile_v_), + WGSL_TEMPLATE_PARAMETER(update_rule, update_rule_int), + WGSL_TEMPLATE_PARAMETER(use_vec4, use_vec4)); } // ============================================================================= @@ -330,7 +107,6 @@ LinearAttention::LinearAttention(const OpKernelInfo& info) kv_num_heads_ = static_cast(info.GetAttr("kv_num_heads")); } - /* 3D packed inputs: query: (B, T, H_q * d_k) — packed query @@ -348,9 +124,9 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { const Tensor* query = context.Input(0); const Tensor* key = context.Input(1); const Tensor* value = context.Input(2); - const Tensor* past_state = context.Input(3); // optional - const Tensor* decay = context.Input(4); // optional - const Tensor* beta = context.Input(5); // optional + const Tensor* past_state = context.Input(3); // optional + const Tensor* decay = context.Input(4); // optional + const Tensor* beta = context.Input(5); // optional // Validate 3D packed inputs const auto& q_shape = query->Shape(); diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template new file mode 100644 index 0000000000000..338b70492e3a3 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template @@ -0,0 +1,247 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// LinearAttention shader +// +// Design overview: +// - Each workgroup handles one (batch, head, dv_tile) combination +// - Workgroup size = head_dim_k (dk): one thread per state row +// - Each thread maintains TILE_V columns of its state row in private memory +// - Tokens are processed sequentially; matrix ops are parallelized across threads +// - Reductions across dk (for S^T @ k and S^T @ q) use shared memory +// + +#param use_vec4 +#param has_initial_state +#param decay_broadcast_dk +#param tile_v + +// Update rule constants +#define UPDATE_LINEAR 0 +#define UPDATE_GATED 1 +#define UPDATE_DELTA 2 +#define UPDATE_GATED_DELTA 3 +#param update_rule + +// Shared memory for parallel reduction across dk threads +// and for broadcasting delta values. +#if use_vec4 +// When use_vec4, each reduction_buf entry is a vec4 (4 dv values packed), +// eliminating the inner TILE_V loop and enabling native SIMD operations. +var reduction_buf: array, workgroup_size_x>; +var broadcast_val: vec4; +#else +// TILE_V is emitted as a compile-time constant because +// private address space arrays require fixed sizes in WGSL. +const TILE_V: u32 = tile_v; +var reduction_buf: array; +var broadcast_buf: array; +#endif + +$MAIN { + // Identify which (batch, head, dv_tile) this workgroup handles + let bh = workgroup_idx / uniforms.num_dv_tiles; + let dv_tile_idx = workgroup_idx % uniforms.num_dv_tiles; + let batch_idx = bh / uniforms.num_heads; + let head_idx = bh % uniforms.num_heads; + let dk_idx = local_idx; // thread index = row in state matrix +#if !use_vec4 + let dv_start = dv_tile_idx * TILE_V; +#endif + + // Precompute packed strides for 3D packed inputs (B, T, H*D) + // When use_vec4, head_dim_v is already divided by 4 (vectorized), so + // packed_dv = num_heads * (head_dim_v/4) and dv_tile_idx indexes vec4 elements. + let packed_dk = uniforms.num_heads * uniforms.head_dim_k; + let packed_dv = uniforms.num_heads * uniforms.head_dim_v; + + // Initialize state tile in private memory +#if use_vec4 + var state = vec4(0.0); +#else + var state: array; + for (var j = 0u; j < TILE_V; j++) { + state[j] = 0.0; + } +#endif + + // Load initial state if provided +#if has_initial_state + if (dk_idx < uniforms.head_dim_k) { +#if use_vec4 + let state_offset = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_tile_idx; + state = vec4(initial_state[state_offset]); +#else + let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start; + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + state[j] = f32(initial_state[state_base + j]); + } + } +#endif + } +#endif + + // Process each token sequentially + for (var t = 0u; t < uniforms.seq_length; t++) { + let bt_offset = batch_idx * uniforms.seq_length + t; + var k_val: f32 = 0.0; + var q_val: f32 = 0.0; + if (dk_idx < uniforms.head_dim_k) { + let qk_idx = bt_offset * packed_dk + head_idx * uniforms.head_dim_k + dk_idx; + k_val = f32(key[qk_idx]); + q_val = f32(query[qk_idx]); + } + + // Step 1: Apply decay (for gated and gated_delta modes) +#if update_rule == UPDATE_GATED || update_rule == UPDATE_GATED_DELTA + // Apply exponential decay: S *= exp(decay) +#if decay_broadcast_dk + let exp_g = exp(f32(decay[bt_offset * uniforms.num_heads + head_idx])); +#else + var exp_g: f32 = 1.0; + if (dk_idx < uniforms.head_dim_k) { + exp_g = exp(f32(decay[bt_offset * packed_dk + head_idx * uniforms.head_dim_k + dk_idx])); + } +#endif +#if use_vec4 + state *= exp_g; +#else + for (var j = 0u; j < TILE_V; j++) { + state[j] *= exp_g; + } +#endif +#endif + + // Step 2: State update +#if update_rule == UPDATE_DELTA || update_rule == UPDATE_GATED_DELTA + // For delta/gated_delta rules, compute retrieved = S^T @ k (reduction across dk) +#if use_vec4 + // Compute retrieved = S^T @ k (parallel reduction over dk) + reduction_buf[dk_idx] = state * k_val; + workgroupBarrier(); + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { + if (dk_idx < stride) { + reduction_buf[dk_idx] += reduction_buf[dk_idx + stride]; + } + workgroupBarrier(); + } + // Compute delta = beta * (v - retrieved) and broadcast + let v_idx = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_tile_idx; + let beta_base = bt_offset * uniforms.num_heads + head_idx; + if (dk_idx == 0u) { + let beta_val = f32(beta[beta_base]); + broadcast_val = beta_val * (vec4(value[v_idx]) - reduction_buf[0]); + } + workgroupBarrier(); + state += k_val * broadcast_val; + workgroupBarrier(); +#else + // Compute retrieved = S^T @ k (parallel reduction over dk) + for (var j = 0u; j < TILE_V; j++) { + reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * k_val; + } + workgroupBarrier(); + // Tree reduction + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { + if (dk_idx < stride) { + for (var j = 0u; j < TILE_V; j++) { + reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride]; + } + } + workgroupBarrier(); + } + // Compute delta = beta * (v - retrieved) and broadcast + let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; + let beta_base = bt_offset * uniforms.num_heads + head_idx; + if (dk_idx == 0u) { + let beta_val = f32(beta[beta_base]); + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + let retrieved_j = reduction_buf[j * workgroup_size_x]; + let v_val = f32(value[v_base + j]); + broadcast_buf[j] = beta_val * (v_val - retrieved_j); + } + } + } + workgroupBarrier(); + // Update state: S += k ⊗ delta + for (var j = 0u; j < TILE_V; j++) { + state[j] += k_val * broadcast_buf[j]; + } + workgroupBarrier(); +#endif +#else + // For linear and gated modes: S += k ⊗ v (no delta rule) +#if use_vec4 + // Update state: S += k ⊗ v + let v_idx = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_tile_idx; + state += k_val * vec4(value[v_idx]); +#else + // Update state: S += k ⊗ v + let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + let v_val = f32(value[v_base + j]); + state[j] += k_val * v_val; + } + } +#endif +#endif + + // Step 3: Compute output = scale * S^T @ q (parallel reduction over dk) +#if use_vec4 + reduction_buf[dk_idx] = state * q_val; + workgroupBarrier(); + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { + if (dk_idx < stride) { + reduction_buf[dk_idx] += reduction_buf[dk_idx + stride]; + } + workgroupBarrier(); + } + if (dk_idx == 0u) { + let out_idx = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_tile_idx; + output[out_idx] = output_value_t(reduction_buf[0] * uniforms.scale); + } + workgroupBarrier(); +#else + for (var j = 0u; j < TILE_V; j++) { + reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * q_val; + } + workgroupBarrier(); + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { + if (dk_idx < stride) { + for (var j = 0u; j < TILE_V; j++) { + reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride]; + } + } + workgroupBarrier(); + } + if (dk_idx == 0u) { + let out_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + output[out_base + j] = output_element_t(reduction_buf[j * workgroup_size_x] * uniforms.scale); + } + } + } + workgroupBarrier(); +#endif + } // end token loop + + // Write present_state + if (dk_idx < uniforms.head_dim_k) { +#if use_vec4 + let state_offset = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_tile_idx; + present_state[state_offset] = output_value_t(state); +#else + let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start; + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + present_state[state_base + j] = output_element_t(state[j]); + } + } +#endif + } +} // MAIN diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 2feed9b30a9c6..83c1e79e5cb6d 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -2428,7 +2428,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( auto& query_shape = getInputShape(ctx, 0); auto& value_shape = getInputShape(ctx, 2); TensorShapeProto state_shape; - *state_shape.add_dim() = query_shape.dim(0); // B + *state_shape.add_dim() = query_shape.dim(0); // B state_shape.add_dim()->set_dim_value(kv_num_heads); // H_kv // d_k = query.dim(2) / q_num_heads if (query_shape.dim(2).has_dim_value()) { @@ -2448,6 +2448,5 @@ ONNX_MS_OPERATOR_SET_SCHEMA( } })); - } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 55f4e2c5d0e5f..d393a6ce1561b 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -12,24 +12,24 @@ namespace onnxruntime { namespace webgpu { -#define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ - Concat, \ - kOnnxDomain, \ - start, \ - end, \ - kWebGpuExecutionProvider, \ - (*KernelDefBuilder::Create()) \ +#define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + start, \ + end, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ Concat); -#define WEBGPU_CONCAT_KERNEL(version) \ - ONNX_OPERATOR_KERNEL_EX( \ - Concat, \ - kOnnxDomain, \ - version, \ - kWebGpuExecutionProvider, \ - (*KernelDefBuilder::Create()) \ +#define WEBGPU_CONCAT_KERNEL(version) \ + ONNX_OPERATOR_KERNEL_EX( \ + Concat, \ + kOnnxDomain, \ + version, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ .TypeConstraint("T", WebGpuSupportedNumberTypes()), \ Concat); From 9ebcd61d8b4ce6f5678ddcd04fbd4d97b5239c50 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 30 Mar 2026 23:04:47 -0700 Subject: [PATCH 23/27] optimize number of barriers --- .../webgpu/bert/linear_attention.cc | 3 +- .../bert/linear_attention.wgsl.template | 150 +++++++----------- .../contrib_ops/linear_attention_op_test.cc | 78 +++++++++ 3 files changed, 132 insertions(+), 99 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index 39bf907f9b741..803fe3e088908 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -5,7 +5,6 @@ #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" -// #include "core/providers/webgpu/wgsl_templates/wgsl_gen.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" using namespace onnxruntime::webgpu; @@ -179,7 +178,6 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { tile_v = head_dim_v; } const int head_dim_v_vectorized = head_dim_v / components; - const int num_dv_tiles = (head_dim_v_vectorized + tile_v - 1) / tile_v; // Workgroup size = head_dim_k (one thread per dk row) // Ensure it's a power of 2 for tree reduction (round up) @@ -190,6 +188,7 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { // Cap at GPU limits workgroup_size = std::min(workgroup_size, static_cast(256)); + const int num_dv_tiles = (head_dim_v_vectorized + tile_v - 1) / tile_v; const uint32_t num_workgroups = batch_size * num_heads * num_dv_tiles; bool has_initial_state = past_state != nullptr; diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template index 338b70492e3a3..6729a0a0abbfb 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template @@ -10,6 +10,9 @@ // - Each thread maintains TILE_V columns of its state row in private memory // - Tokens are processed sequentially; matrix ops are parallelized across threads // - Reductions across dk (for S^T @ k and S^T @ q) use shared memory +// - For delta/gated_delta: the S^T@k (retrieval) and S^T@q (output) reductions +// are fused into a single barrier tree, nearly halving barrier count per token. +// Key identity: output = scale * (S_old^T @ q + delta * (k^T @ q)) // #param use_vec4 @@ -24,19 +27,29 @@ #define UPDATE_GATED_DELTA 3 #param update_rule -// Shared memory for parallel reduction across dk threads -// and for broadcasting delta values. +// Type aliases: vtype is the element type for state and reductions. +// otype is the output storage type. #if use_vec4 -// When use_vec4, each reduction_buf entry is a vec4 (4 dv values packed), -// eliminating the inner TILE_V loop and enabling native SIMD operations. -var reduction_buf: array, workgroup_size_x>; -var broadcast_val: vec4; +alias vtype = vec4; +alias otype = output_value_t; #else -// TILE_V is emitted as a compile-time constant because -// private address space arrays require fixed sizes in WGSL. +alias vtype = f32; +alias otype = output_element_t; +#endif + const TILE_V: u32 = tile_v; -var reduction_buf: array; -var broadcast_buf: array; + +// Shared memory for parallel reduction across dk threads. +#if update_rule == UPDATE_DELTA || update_rule == UPDATE_GATED_DELTA +// Fused reduction: retrieved (S^T@k), pre_output (S^T@q), and kq_dot (k^T@q) +// are reduced in a single barrier tree. +var red_retrieved: array; +var red_preout: array; +var red_kq: array; +var broadcast_buf: array; +#else +// Output-only reduction for linear/gated. +var reduction_buf: array; #endif $MAIN { @@ -46,40 +59,27 @@ $MAIN { let batch_idx = bh / uniforms.num_heads; let head_idx = bh % uniforms.num_heads; let dk_idx = local_idx; // thread index = row in state matrix -#if !use_vec4 let dv_start = dv_tile_idx * TILE_V; -#endif // Precompute packed strides for 3D packed inputs (B, T, H*D) - // When use_vec4, head_dim_v is already divided by 4 (vectorized), so - // packed_dv = num_heads * (head_dim_v/4) and dv_tile_idx indexes vec4 elements. let packed_dk = uniforms.num_heads * uniforms.head_dim_k; let packed_dv = uniforms.num_heads * uniforms.head_dim_v; // Initialize state tile in private memory -#if use_vec4 - var state = vec4(0.0); -#else - var state: array; + var state: array; for (var j = 0u; j < TILE_V; j++) { - state[j] = 0.0; + state[j] = vtype(0.0); } -#endif // Load initial state if provided #if has_initial_state if (dk_idx < uniforms.head_dim_k) { -#if use_vec4 - let state_offset = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_tile_idx; - state = vec4(initial_state[state_offset]); -#else let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start; for (var j = 0u; j < TILE_V; j++) { if (dv_start + j < uniforms.head_dim_v) { - state[j] = f32(initial_state[state_base + j]); + state[j] = vtype(initial_state[state_base + j]); } } -#endif } #endif @@ -105,107 +105,68 @@ $MAIN { exp_g = exp(f32(decay[bt_offset * packed_dk + head_idx * uniforms.head_dim_k + dk_idx])); } #endif -#if use_vec4 - state *= exp_g; -#else for (var j = 0u; j < TILE_V; j++) { state[j] *= exp_g; } -#endif #endif - // Step 2: State update #if update_rule == UPDATE_DELTA || update_rule == UPDATE_GATED_DELTA - // For delta/gated_delta rules, compute retrieved = S^T @ k (reduction across dk) -#if use_vec4 - // Compute retrieved = S^T @ k (parallel reduction over dk) - reduction_buf[dk_idx] = state * k_val; - workgroupBarrier(); - for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { - if (dk_idx < stride) { - reduction_buf[dk_idx] += reduction_buf[dk_idx + stride]; - } - workgroupBarrier(); - } - // Compute delta = beta * (v - retrieved) and broadcast - let v_idx = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_tile_idx; - let beta_base = bt_offset * uniforms.num_heads + head_idx; - if (dk_idx == 0u) { - let beta_val = f32(beta[beta_base]); - broadcast_val = beta_val * (vec4(value[v_idx]) - reduction_buf[0]); - } - workgroupBarrier(); - state += k_val * broadcast_val; - workgroupBarrier(); -#else - // Compute retrieved = S^T @ k (parallel reduction over dk) + // Fused reduction: compute retrieved = S^T@k, pre_output = S^T@q, + // and kq_dot = k^T@q in a single barrier tree. + // Then: output = scale * (pre_output + delta * kq_dot) for (var j = 0u; j < TILE_V; j++) { - reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * k_val; + red_retrieved[j * workgroup_size_x + dk_idx] = state[j] * k_val; + red_preout[j * workgroup_size_x + dk_idx] = state[j] * q_val; } + red_kq[dk_idx] = k_val * q_val; workgroupBarrier(); - // Tree reduction + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { if (dk_idx < stride) { for (var j = 0u; j < TILE_V; j++) { - reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride]; + red_retrieved[j * workgroup_size_x + dk_idx] += red_retrieved[j * workgroup_size_x + dk_idx + stride]; + red_preout[j * workgroup_size_x + dk_idx] += red_preout[j * workgroup_size_x + dk_idx + stride]; } + red_kq[dk_idx] += red_kq[dk_idx + stride]; } workgroupBarrier(); } - // Compute delta = beta * (v - retrieved) and broadcast + + // Thread 0: compute delta, broadcast it, and write fused output. let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; let beta_base = bt_offset * uniforms.num_heads + head_idx; if (dk_idx == 0u) { let beta_val = f32(beta[beta_base]); + let kq_dot = red_kq[0]; for (var j = 0u; j < TILE_V; j++) { if (dv_start + j < uniforms.head_dim_v) { - let retrieved_j = reduction_buf[j * workgroup_size_x]; - let v_val = f32(value[v_base + j]); - broadcast_buf[j] = beta_val * (v_val - retrieved_j); + let retrieved = red_retrieved[j * workgroup_size_x]; + let pre_out = red_preout[j * workgroup_size_x]; + let v_val = vtype(value[v_base + j]); + let delta_j = beta_val * (v_val - retrieved); + broadcast_buf[j] = delta_j; + output[v_base + j] = otype((pre_out + delta_j * kq_dot) * uniforms.scale); } } } workgroupBarrier(); - // Update state: S += k ⊗ delta + + // All threads: update state with delta for (var j = 0u; j < TILE_V; j++) { state[j] += k_val * broadcast_buf[j]; } workgroupBarrier(); -#endif -#else - // For linear and gated modes: S += k ⊗ v (no delta rule) -#if use_vec4 - // Update state: S += k ⊗ v - let v_idx = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_tile_idx; - state += k_val * vec4(value[v_idx]); + #else - // Update state: S += k ⊗ v + // Linear/gated: S += k ⊗ v, then output = scale * S^T @ q let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; for (var j = 0u; j < TILE_V; j++) { if (dv_start + j < uniforms.head_dim_v) { - let v_val = f32(value[v_base + j]); - state[j] += k_val * v_val; + state[j] += k_val * vtype(value[v_base + j]); } } -#endif -#endif - // Step 3: Compute output = scale * S^T @ q (parallel reduction over dk) -#if use_vec4 - reduction_buf[dk_idx] = state * q_val; - workgroupBarrier(); - for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { - if (dk_idx < stride) { - reduction_buf[dk_idx] += reduction_buf[dk_idx + stride]; - } - workgroupBarrier(); - } - if (dk_idx == 0u) { - let out_idx = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_tile_idx; - output[out_idx] = output_value_t(reduction_buf[0] * uniforms.scale); - } - workgroupBarrier(); -#else + // Output = scale * S^T @ q (parallel reduction over dk) for (var j = 0u; j < TILE_V; j++) { reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * q_val; } @@ -222,7 +183,7 @@ $MAIN { let out_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; for (var j = 0u; j < TILE_V; j++) { if (dv_start + j < uniforms.head_dim_v) { - output[out_base + j] = output_element_t(reduction_buf[j * workgroup_size_x] * uniforms.scale); + output[out_base + j] = otype(reduction_buf[j * workgroup_size_x] * uniforms.scale); } } } @@ -232,16 +193,11 @@ $MAIN { // Write present_state if (dk_idx < uniforms.head_dim_k) { -#if use_vec4 - let state_offset = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_tile_idx; - present_state[state_offset] = output_value_t(state); -#else let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start; for (var j = 0u; j < TILE_V; j++) { if (dv_start + j < uniforms.head_dim_v) { - present_state[state_base + j] = output_element_t(state[j]); + present_state[state_base + j] = otype(state[j]); } } -#endif } } // MAIN diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc index 84ba7a2c361a6..a6c914afc0c3f 100644 --- a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc @@ -729,5 +729,83 @@ TEST(ContribOpLinearAttentionTest, GatedDeltaRule_NonPowerOf2DK) { &initial_state, &decay, &beta); } +// =========================================================================== +// Tests: Larger dimensions exercising multi-tile vec4 path (tile_v > 1) +// =========================================================================== +TEST(ContribOpLinearAttentionTest, LinearRule_LargerDims) { + const int B = 1, H = 2, T = 4, dk = 16, dv = 64; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = 0.1f * std::sin(static_cast(i) * 0.13f); + key[i] = 0.1f * std::cos(static_cast(i) * 0.17f); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = 0.1f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + + RunLinearAttentionTest("linear", B, H, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, nullptr); +} + +TEST(ContribOpLinearAttentionTest, GatedRule_LargerDims) { + const int B = 1, H = 2, T = 4, dk = 32, dv = 64; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + std::vector decay(B * H * T * dk); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = 0.1f * std::sin(static_cast(i) * 0.13f); + key[i] = 0.1f * std::cos(static_cast(i) * 0.17f); + decay[i] = -0.05f - 0.05f * std::abs(std::sin(static_cast(i) * 0.07f)); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = 0.1f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + + std::vector initial_state(B * H * dk * dv, 0.01f); + + RunLinearAttentionTest("gated", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, nullptr); +} + +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_LargerDims) { + const int B = 2, H = 2, T = 4, dk = 32, dv = 64; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * H * T * dk); + std::vector key(B * H * T * dk); + std::vector value(B * H * T * dv); + std::vector decay(B * H * T * dk); + std::vector beta(B * H * T); + + for (int i = 0; i < B * H * T * dk; i++) { + query[i] = 0.1f * std::sin(static_cast(i) * 0.13f); + key[i] = 0.1f * std::cos(static_cast(i) * 0.17f); + decay[i] = -0.05f - 0.05f * std::abs(std::sin(static_cast(i) * 0.07f)); + } + for (int i = 0; i < B * H * T * dv; i++) { + value[i] = 0.1f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + for (int i = 0; i < B * H * T; i++) { + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * H * dk * dv, 0.01f); + + RunLinearAttentionTest("gated_delta", B, H, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + } // namespace test } // namespace onnxruntime From 079a33bd38a59e728602cbd285127c70412016c3 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 31 Mar 2026 08:32:43 -0700 Subject: [PATCH 24/27] sync with #27842 --- .../contrib_ops/webgpu/bert/causal_conv_with_state.cc | 1 - onnxruntime/core/graph/contrib_ops/bert_defs.cc | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc index 527f3c7d3027f..74ac412ece474 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc @@ -5,7 +5,6 @@ #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" -// #include "core/providers/webgpu/wgsl_templates/wgsl_gen.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" using namespace onnxruntime::webgpu; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 83c1e79e5cb6d..bfc85a54b0f4c 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -2280,7 +2280,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Contains the last (k-1) values from the virtual input along the causal axis.", "T") .TypeConstraint("T", - {"tensor(float)", "tensor(float16)"}, + {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); @@ -2333,7 +2333,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The update rule for the linear attention recurrence. " "One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.", AttributeProto::STRING, - std::string("gated_delta")) + "gated_delta") .Attr("scale", "Output scaling factor. When 0.0 (default), derives d_k = query.shape[-1] / q_num_heads " "and uses 1/sqrt(d_k). Set explicitly to override.", @@ -2394,7 +2394,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Updated recurrent state with shape (B, H_kv, d_k, d_v). Always 4D.", "T") .TypeConstraint("T", - {"tensor(float)", "tensor(float16)"}, + {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); From 085b2747aed35b99da58b072f6af7f276f19a625 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 31 Mar 2026 09:19:44 -0700 Subject: [PATCH 25/27] update_rule need to stay std::string --- onnxruntime/core/graph/contrib_ops/bert_defs.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index bfc85a54b0f4c..fcc3d6c8d438b 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -2333,7 +2333,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The update rule for the linear attention recurrence. " "One of: 'linear', 'gated', 'delta', 'gated_delta'. Default is 'gated_delta'.", AttributeProto::STRING, - "gated_delta") + std::string("gated_delta")) .Attr("scale", "Output scaling factor. When 0.0 (default), derives d_k = query.shape[-1] / q_num_heads " "and uses 1/sqrt(d_k). Set explicitly to override.", From 6474e3e8adb80d6867b15227e983918bc8c182b6 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 2 Apr 2026 00:02:52 -0700 Subject: [PATCH 26/27] add support for inverse GQA, needed for Qwen3.5-4/9B --- .../webgpu/bert/linear_attention.cc | 38 +- .../webgpu/bert/linear_attention.h | 6 +- .../bert/linear_attention.wgsl.template | 149 +++++-- .../core/graph/contrib_ops/bert_defs.cc | 21 +- .../contrib_ops/linear_attention_op_test.cc | 396 +++++++++++++++++- 5 files changed, 565 insertions(+), 45 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc index 803fe3e088908..338b89b4ce201 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -130,16 +130,15 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { // Validate 3D packed inputs const auto& q_shape = query->Shape(); ORT_RETURN_IF(q_shape.NumDimensions() != 3, "query must be 3D (B, T, H_q*d_k)"); + const auto& k_shape = key->Shape(); const int batch_size = static_cast(q_shape[0]); const int seq_length = static_cast(q_shape[1]); const int q_packed_dim = static_cast(q_shape[2]); const int num_heads = kv_num_heads_; - - ORT_RETURN_IF(q_num_heads_ != kv_num_heads_, - "GQA (q_num_heads != kv_num_heads) is not yet supported"); - const int head_dim_k = q_packed_dim / q_num_heads_; + const int n_k_heads = static_cast(k_shape[2] / head_dim_k); + ORT_RETURN_IF(q_packed_dim != head_dim_k * q_num_heads_, "query packed dim must be divisible by q_num_heads"); @@ -148,6 +147,26 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { ORT_RETURN_IF(v_packed_dim != head_dim_v * kv_num_heads_, "value packed dim must be divisible by kv_num_heads"); + // ==== GQA head mapping ==== + // Standard GQA: q_num_heads >= kv_num_heads, multiple Q heads per KV group. + // Inverse GQA: q_num_heads < kv_num_heads (e.g., Qwen3.5 9B: n_k=16, n_kv=32). + // Also n_k_heads may differ from both (K has its own head count). + int heads_per_group; // Q heads per KV group (0 if inverse GQA) + if (q_num_heads_ >= kv_num_heads_) { + ORT_RETURN_IF_NOT(q_num_heads_ % kv_num_heads_ == 0, + "q_num_heads must be divisible by kv_num_heads"); + heads_per_group = q_num_heads_ / kv_num_heads_; + } else { + ORT_RETURN_IF_NOT(kv_num_heads_ % q_num_heads_ == 0, + "kv_num_heads must be divisible by q_num_heads (inverse GQA)"); + heads_per_group = 0; // signals inverse GQA + } + + // K-to-KV head mapping: when n_k < kv_num_heads, multiple KV heads share one K head + ORT_RETURN_IF_NOT(kv_num_heads_ % n_k_heads == 0, + "kv_num_heads must be divisible by n_k_heads"); + int kv_per_k_head = kv_num_heads_ / n_k_heads; + // Validate update rule has required inputs bool needs_decay = (update_rule_ == LinearAttentionUpdateRule::Gated || update_rule_ == LinearAttentionUpdateRule::GatedDelta); @@ -163,7 +182,10 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { } // Allocate outputs — output is 3D packed, state is 4D - TensorShapeVector output_shape({batch_size, seq_length, q_num_heads_ * head_dim_v}); + // Output uses kv_num_heads (matches schema inference: output_dim == V_dim). + // For inverse GQA (q < kv): each KV head writes its own output slot. + // For standard/MHA (q >= kv): q == kv with this schema, so equivalent. + TensorShapeVector output_shape({batch_size, seq_length, kv_num_heads_ * head_dim_v}); Tensor* output = context.Output(0, output_shape); TensorShapeVector state_shape({batch_size, num_heads, head_dim_k, head_dim_v}); @@ -234,7 +256,11 @@ Status LinearAttention::ComputeInternal(ComputeContext& context) const { {static_cast(head_dim_k)}, {static_cast(head_dim_v_vectorized)}, {scale}, - {static_cast(num_dv_tiles)}}); + {static_cast(num_dv_tiles)}, + {static_cast(heads_per_group)}, + {static_cast(kv_per_k_head)}, + {static_cast(q_num_heads_)}, + {static_cast(n_k_heads)}}); return context.RunProgram(program); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h index b55e3a68cfe6b..91ac39ae40b61 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h @@ -50,7 +50,11 @@ class LinearAttentionProgram final : public Program { {"head_dim_k", ProgramUniformVariableDataType::Uint32}, {"head_dim_v", ProgramUniformVariableDataType::Uint32}, {"scale", ProgramUniformVariableDataType::Float32}, - {"num_dv_tiles", ProgramUniformVariableDataType::Uint32}); + {"num_dv_tiles", ProgramUniformVariableDataType::Uint32}, + {"heads_per_group", ProgramUniformVariableDataType::Uint32}, + {"kv_per_k_head", ProgramUniformVariableDataType::Uint32}, + {"q_num_heads", ProgramUniformVariableDataType::Uint32}, + {"n_k_heads", ProgramUniformVariableDataType::Uint32}); private: LinearAttentionUpdateRule update_rule_; diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template index 6729a0a0abbfb..34b9ed51b35d3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template @@ -62,7 +62,11 @@ $MAIN { let dv_start = dv_tile_idx * TILE_V; // Precompute packed strides for 3D packed inputs (B, T, H*D) - let packed_dk = uniforms.num_heads * uniforms.head_dim_k; + // Q: (B, T, q_num_heads * dk), K: (B, T, n_k_heads * dk), + // V/output: (B, T, num_heads * dv) [schema: output_dim == V_dim] + let packed_dk_q = uniforms.q_num_heads * uniforms.head_dim_k; + let packed_dk_k = uniforms.n_k_heads * uniforms.head_dim_k; + let packed_dk_kv = uniforms.num_heads * uniforms.head_dim_k; let packed_dv = uniforms.num_heads * uniforms.head_dim_v; // Initialize state tile in private memory @@ -87,11 +91,10 @@ $MAIN { for (var t = 0u; t < uniforms.seq_length; t++) { let bt_offset = batch_idx * uniforms.seq_length + t; var k_val: f32 = 0.0; - var q_val: f32 = 0.0; if (dk_idx < uniforms.head_dim_k) { - let qk_idx = bt_offset * packed_dk + head_idx * uniforms.head_dim_k + dk_idx; - k_val = f32(key[qk_idx]); - q_val = f32(query[qk_idx]); + let k_head_idx = head_idx / uniforms.kv_per_k_head; + let k_idx = bt_offset * packed_dk_k + k_head_idx * uniforms.head_dim_k + dk_idx; + k_val = f32(key[k_idx]); } // Step 1: Apply decay (for gated and gated_delta modes) @@ -102,7 +105,7 @@ $MAIN { #else var exp_g: f32 = 1.0; if (dk_idx < uniforms.head_dim_k) { - exp_g = exp(f32(decay[bt_offset * packed_dk + head_idx * uniforms.head_dim_k + dk_idx])); + exp_g = exp(f32(decay[bt_offset * packed_dk_kv + head_idx * uniforms.head_dim_k + dk_idx])); } #endif for (var j = 0u; j < TILE_V; j++) { @@ -111,14 +114,32 @@ $MAIN { #endif #if update_rule == UPDATE_DELTA || update_rule == UPDATE_GATED_DELTA - // Fused reduction: compute retrieved = S^T@k, pre_output = S^T@q, - // and kq_dot = k^T@q in a single barrier tree. - // Then: output = scale * (pre_output + delta * kq_dot) + // Determine Q head and output head for this KV head. + // Standard GQA/MHA (heads_per_group > 0): Q indexed by q_head = kv_head * hpg. + // Inverse GQA (heads_per_group == 0): multiple KV heads share one Q head; + // output indexed by KV head (each KV head has its own output slot). + var q_head_0: u32; + var out_head_0: u32; + if (uniforms.heads_per_group > 0u) { + q_head_0 = head_idx * uniforms.heads_per_group; + out_head_0 = q_head_0; + } else { + q_head_0 = head_idx * uniforms.q_num_heads / uniforms.num_heads; + out_head_0 = head_idx; + } + var q0_val: f32 = 0.0; + if (dk_idx < uniforms.head_dim_k) { + q0_val = f32(query[bt_offset * packed_dk_q + q_head_0 * uniforms.head_dim_k + dk_idx]); + } + + // Fused reduction: compute retrieved = S^T@k, pre_output = S^T@q_0, + // and kq_dot = k^T@q_0 in a single barrier tree. + // Then: output_0 = scale * (pre_output + delta * kq_dot) for (var j = 0u; j < TILE_V; j++) { red_retrieved[j * workgroup_size_x + dk_idx] = state[j] * k_val; - red_preout[j * workgroup_size_x + dk_idx] = state[j] * q_val; + red_preout[j * workgroup_size_x + dk_idx] = state[j] * q0_val; } - red_kq[dk_idx] = k_val * q_val; + red_kq[dk_idx] = k_val * q0_val; workgroupBarrier(); for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { @@ -132,7 +153,7 @@ $MAIN { workgroupBarrier(); } - // Thread 0: compute delta, broadcast it, and write fused output. + // Thread 0: compute delta, broadcast it, and write output for out_head_0. let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; let beta_base = bt_offset * uniforms.num_heads + head_idx; if (dk_idx == 0u) { @@ -145,20 +166,50 @@ $MAIN { let v_val = vtype(value[v_base + j]); let delta_j = beta_val * (v_val - retrieved); broadcast_buf[j] = delta_j; - output[v_base + j] = otype((pre_out + delta_j * kq_dot) * uniforms.scale); + output[bt_offset * packed_dv + out_head_0 * uniforms.head_dim_v + dv_start + j] = otype((pre_out + delta_j * kq_dot) * uniforms.scale); } } } workgroupBarrier(); - // All threads: update state with delta + // All threads: update state with delta (S_new = S_old + k * delta) for (var j = 0u; j < TILE_V; j++) { state[j] += k_val * broadcast_buf[j]; } workgroupBarrier(); + // Standard GQA: additional Q heads — output_g = scale * S_new^T @ q_g + // (For inverse GQA, heads_per_group == 0 so this loop is skipped.) + for (var qg = 1u; qg < uniforms.heads_per_group; qg++) { + let q_head_g = head_idx * uniforms.heads_per_group + qg; + var qg_val: f32 = 0.0; + if (dk_idx < uniforms.head_dim_k) { + qg_val = f32(query[bt_offset * packed_dk_q + q_head_g * uniforms.head_dim_k + dk_idx]); + } + for (var j = 0u; j < TILE_V; j++) { + red_preout[j * workgroup_size_x + dk_idx] = state[j] * qg_val; + } + workgroupBarrier(); + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { + if (dk_idx < stride) { + for (var j = 0u; j < TILE_V; j++) { + red_preout[j * workgroup_size_x + dk_idx] += red_preout[j * workgroup_size_x + dk_idx + stride]; + } + } + workgroupBarrier(); + } + if (dk_idx == 0u) { + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + output[bt_offset * packed_dv + q_head_g * uniforms.head_dim_v + dv_start + j] = otype(red_preout[j * workgroup_size_x] * uniforms.scale); + } + } + } + workgroupBarrier(); + } + #else - // Linear/gated: S += k ⊗ v, then output = scale * S^T @ q + // Linear/gated: S += k ⊗ v let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; for (var j = 0u; j < TILE_V; j++) { if (dv_start + j < uniforms.head_dim_v) { @@ -166,28 +217,66 @@ $MAIN { } } - // Output = scale * S^T @ q (parallel reduction over dk) - for (var j = 0u; j < TILE_V; j++) { - reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * q_val; - } - workgroupBarrier(); - for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { - if (dk_idx < stride) { + // Output = scale * S^T @ q + if (uniforms.heads_per_group > 0u) { + // Standard GQA / MHA: one output per Q head in group + for (var qg = 0u; qg < uniforms.heads_per_group; qg++) { + let q_head_idx = head_idx * uniforms.heads_per_group + qg; + var q_val_g: f32 = 0.0; + if (dk_idx < uniforms.head_dim_k) { + q_val_g = f32(query[bt_offset * packed_dk_q + q_head_idx * uniforms.head_dim_k + dk_idx]); + } for (var j = 0u; j < TILE_V; j++) { - reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride]; + reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * q_val_g; + } + workgroupBarrier(); + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { + if (dk_idx < stride) { + for (var j = 0u; j < TILE_V; j++) { + reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride]; + } + } + workgroupBarrier(); } + if (dk_idx == 0u) { + let out_base = bt_offset * packed_dv + q_head_idx * uniforms.head_dim_v + dv_start; + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + output[out_base + j] = otype(reduction_buf[j * workgroup_size_x] * uniforms.scale); + } + } + } + workgroupBarrier(); + } + } else { + // Inverse GQA: one output per KV head, using shared Q + let q_head_inv = head_idx * uniforms.q_num_heads / uniforms.num_heads; + var q_val_inv: f32 = 0.0; + if (dk_idx < uniforms.head_dim_k) { + q_val_inv = f32(query[bt_offset * packed_dk_q + q_head_inv * uniforms.head_dim_k + dk_idx]); } - workgroupBarrier(); - } - if (dk_idx == 0u) { - let out_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; for (var j = 0u; j < TILE_V; j++) { - if (dv_start + j < uniforms.head_dim_v) { - output[out_base + j] = otype(reduction_buf[j * workgroup_size_x] * uniforms.scale); + reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * q_val_inv; + } + workgroupBarrier(); + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { + if (dk_idx < stride) { + for (var j = 0u; j < TILE_V; j++) { + reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride]; + } } + workgroupBarrier(); } + if (dk_idx == 0u) { + let out_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + output[out_base + j] = otype(reduction_buf[j * workgroup_size_x] * uniforms.scale); + } + } + } + workgroupBarrier(); } - workgroupBarrier(); #endif } // end token loop diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index fcc3d6c8d438b..1209446c6a367 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -2289,14 +2289,22 @@ ONNX_MS_OPERATOR_SET_SCHEMA( // Output 0: same shape as input (batch_size, channels, ...) propagateShapeFromInputToOutput(ctx, 0, 0); - // Output 1: (batch_size, channels, kernel_size - 1) for ndim=1 + // Output 1: state shape is (batch_size, channels, [non-causal spatial dims...], k_last - 1) + // For ndim=1: (B, C, k_1-1) + // For ndim=2: (B, C, input_H, k_2-1) + // For ndim=3: (B, C, input_D, input_H, k_3-1) if (hasInputShape(ctx, 0) && hasInputShape(ctx, 1)) { auto& input_shape = getInputShape(ctx, 0); auto& weight_shape = getInputShape(ctx, 1); + int64_t ndim = getAttribute(ctx, "ndim", 1); TensorShapeProto state_shape; *state_shape.add_dim() = input_shape.dim(0); // batch_size *state_shape.add_dim() = input_shape.dim(1); // channels - // kernel_size - 1 (last kernel dimension for ndim=1) + // Copy non-causal spatial dims from input (dims 2 .. 2+ndim-2) + for (int64_t i = 0; i < ndim - 1; ++i) { + *state_shape.add_dim() = input_shape.dim(static_cast(2 + i)); + } + // Causal (last) spatial dim: kernel_size - 1 int last_kernel_dim = weight_shape.dim_size() - 1; if (weight_shape.dim(last_kernel_dim).has_dim_value()) { state_shape.add_dim()->set_dim_value(weight_shape.dim(last_kernel_dim).dim_value() - 1); @@ -2368,7 +2376,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "past_state", "Recurrent state from previous step with shape (B, H_kv, d_k, d_v). " "Always 4D. If not provided, defaults to zeros.", - "T", + "S", OpSchema::Optional) .Input(4, "decay", @@ -2392,10 +2400,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Output(1, "present_state", "Updated recurrent state with shape (B, H_kv, d_k, d_v). Always 4D.", - "T") + "S") .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("S", + {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, + "Constrain state types to float tensors.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); propagateElemTypeFromInputToOutput(ctx, 0, 1); @@ -2416,7 +2427,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( // H_q * d_v: d_v = value.dim(2) / kv_num_heads, then H_q * d_v if (value_shape.dim(2).has_dim_value()) { int64_t d_v = value_shape.dim(2).dim_value() / kv_num_heads; - output_shape.add_dim()->set_dim_value(q_num_heads * d_v); + output_shape.add_dim()->set_dim_value(kv_num_heads * d_v); } else { output_shape.add_dim(); // unknown } diff --git a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc index a6c914afc0c3f..c6715d36f558b 100644 --- a/onnxruntime/test/contrib_ops/linear_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/linear_attention_op_test.cc @@ -129,6 +129,129 @@ void LinearAttentionReference( } } +// GQA-aware reference implementation. +// Q has q_num_heads heads, K has n_k_heads heads, V/state have kv_num_heads heads. +// Standard GQA: q_num_heads >= kv_num_heads, heads_per_group = q_num_heads / kv_num_heads. +// K-to-KV sharing: kv_per_k_head = kv_num_heads / n_k_heads. +void LinearAttentionGQAReference( + const std::string& update_rule, + int batch_size, int q_num_heads, int kv_num_heads, int n_k_heads, + int seq_length, int head_dim_k, int head_dim_v, + float scale, + const std::vector& query, // (B, q_num_heads, T, dk) + const std::vector& key, // (B, n_k_heads, T, dk) + const std::vector& value, // (B, kv_num_heads, T, dv) + const std::vector* initial_state, // (B, kv_num_heads, dk, dv) + const std::vector* decay, // (B, kv_num_heads, T[, dk]) + const std::vector* beta, // (B, kv_num_heads, T) + std::vector& output, // (B, kv_num_heads, T, dv) + std::vector& final_state) { // (B, kv_num_heads, dk, dv) + int bht_kv = batch_size * kv_num_heads * seq_length; + bool decay_broadcast_dk = (decay != nullptr && static_cast(decay->size()) == bht_kv); + int kv_per_k_head = kv_num_heads / n_k_heads; + bool inverse_gqa = q_num_heads < kv_num_heads; + int heads_per_group = inverse_gqa ? 0 : q_num_heads / kv_num_heads; + + final_state.resize(batch_size * kv_num_heads * head_dim_k * head_dim_v, 0.0f); + // Output always indexed by kv_num_heads (matches schema: output_dim == V_dim) + output.resize(batch_size * kv_num_heads * seq_length * head_dim_v, 0.0f); + + if (initial_state != nullptr) { + final_state = *initial_state; + } + + for (int b = 0; b < batch_size; b++) { + for (int kv_h = 0; kv_h < kv_num_heads; kv_h++) { + int k_head = kv_h / kv_per_k_head; + + auto state_offset = [&](int k, int v) { + return ((b * kv_num_heads + kv_h) * head_dim_k + k) * head_dim_v + v; + }; + + for (int t = 0; t < seq_length; t++) { + // Load k from the K-head that this KV-head maps to + std::vector k_vec(head_dim_k), v_vec(head_dim_v); + int k_base = ((b * n_k_heads + k_head) * seq_length + t) * head_dim_k; + for (int i = 0; i < head_dim_k; i++) k_vec[i] = key[k_base + i]; + int v_base = ((b * kv_num_heads + kv_h) * seq_length + t) * head_dim_v; + for (int i = 0; i < head_dim_v; i++) v_vec[i] = value[v_base + i]; + + // Step 1: Apply decay + if (update_rule == "gated" || update_rule == "gated_delta") { + for (int k = 0; k < head_dim_k; k++) { + float exp_g; + if (decay_broadcast_dk) { + exp_g = std::exp((*decay)[(b * kv_num_heads + kv_h) * seq_length + t]); + } else { + exp_g = std::exp((*decay)[((b * kv_num_heads + kv_h) * seq_length + t) * head_dim_k + k]); + } + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + final_state[state_offset(k, v_idx)] *= exp_g; + } + } + } + + // Step 2: Update state + if (update_rule == "delta" || update_rule == "gated_delta") { + std::vector retrieved(head_dim_v, 0.0f); + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + for (int k = 0; k < head_dim_k; k++) { + retrieved[v_idx] += final_state[state_offset(k, v_idx)] * k_vec[k]; + } + } + int beta_idx = (b * kv_num_heads + kv_h) * seq_length + t; + float beta_val = (*beta)[beta_idx]; + std::vector delta(head_dim_v); + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + delta[v_idx] = beta_val * (v_vec[v_idx] - retrieved[v_idx]); + } + for (int k = 0; k < head_dim_k; k++) { + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + final_state[state_offset(k, v_idx)] += k_vec[k] * delta[v_idx]; + } + } + } else { + for (int k = 0; k < head_dim_k; k++) { + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + final_state[state_offset(k, v_idx)] += k_vec[k] * v_vec[v_idx]; + } + } + } + + // Step 3: Compute output + if (!inverse_gqa) { + // Standard GQA/MHA: one output per Q head + for (int g = 0; g < heads_per_group; g++) { + int q_h = kv_h * heads_per_group + g; + int q_base = ((b * q_num_heads + q_h) * seq_length + t) * head_dim_k; + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + float sum = 0.0f; + for (int k = 0; k < head_dim_k; k++) { + sum += final_state[state_offset(k, v_idx)] * query[q_base + k]; + } + // For standard, output head == q head; since q==kv per schema, also == kv_h index + int out_idx = ((b * kv_num_heads + (kv_h * heads_per_group + g)) * seq_length + t) * head_dim_v + v_idx; + output[out_idx] = scale * sum; + } + } + } else { + // Inverse GQA: output indexed by kv_head, Q broadcast + int q_h = kv_h * q_num_heads / kv_num_heads; + int q_base = ((b * q_num_heads + q_h) * seq_length + t) * head_dim_k; + for (int v_idx = 0; v_idx < head_dim_v; v_idx++) { + float sum = 0.0f; + for (int k = 0; k < head_dim_k; k++) { + sum += final_state[state_offset(k, v_idx)] * query[q_base + k]; + } + int out_idx = ((b * kv_num_heads + kv_h) * seq_length + t) * head_dim_v + v_idx; + output[out_idx] = scale * sum; + } + } + } + } + } +} + // Convert data from 4D (B,H,T,D) layout to 3D packed (B,T,H*D) layout std::vector PackBHTD_to_BTHD(const std::vector& data_4d, int B, int H, int T, int D) { @@ -273,10 +396,87 @@ void RunLinearAttentionTest( tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } -} // namespace +// GQA-aware test harness. +// Q: (B, q_num_heads, T, dk), K: (B, n_k_heads, T, dk), V: (B, kv_num_heads, T, dv) +void RunLinearAttentionGQATest( + const std::string& update_rule, + int batch_size, int q_num_heads, int kv_num_heads, int n_k_heads, + int seq_length, int head_dim_k, int head_dim_v, + float scale, + const std::vector& query, + const std::vector& key, + const std::vector& value, + const std::vector* initial_state, + const std::vector* decay, + const std::vector* beta_data) { + auto ep = TryGetEpWithLinearAttention(); + if (!ep) { + GTEST_SKIP() << "LinearAttention kernel not registered"; + return; + } -// =========================================================================== -// Test: Linear update rule (simplest - no decay, no beta) + std::vector expected_output_4d, expected_state; + LinearAttentionGQAReference(update_rule, batch_size, q_num_heads, kv_num_heads, n_k_heads, + seq_length, head_dim_k, head_dim_v, scale, + query, key, value, initial_state, decay, beta_data, + expected_output_4d, expected_state); + + int bht_kv = batch_size * kv_num_heads * seq_length; + bool decay_broadcast_dk = (decay != nullptr && static_cast(decay->size()) == bht_kv); + + // Pack to 3D — each tensor uses its own head count + auto query_3d = PackBHTD_to_BTHD(query, batch_size, q_num_heads, seq_length, head_dim_k); + auto key_3d = PackBHTD_to_BTHD(key, batch_size, n_k_heads, seq_length, head_dim_k); + auto value_3d = PackBHTD_to_BTHD(value, batch_size, kv_num_heads, seq_length, head_dim_v); + // Output always indexed by kv_num_heads (matches schema) + auto output_3d = PackBHTD_to_BTHD(expected_output_4d, batch_size, kv_num_heads, seq_length, head_dim_v); + + OpTester tester("LinearAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("update_rule", update_rule); + tester.AddAttribute("scale", scale); + tester.AddAttribute("q_num_heads", static_cast(q_num_heads)); + tester.AddAttribute("kv_num_heads", static_cast(kv_num_heads)); + + tester.AddInput("query", {batch_size, seq_length, q_num_heads * head_dim_k}, query_3d); + tester.AddInput("key", {batch_size, seq_length, n_k_heads * head_dim_k}, key_3d); + tester.AddInput("value", {batch_size, seq_length, kv_num_heads * head_dim_v}, value_3d); + + if (initial_state != nullptr) { + tester.AddInput("past_state", {batch_size, kv_num_heads, head_dim_k, head_dim_v}, *initial_state); + } else { + tester.AddOptionalInputEdge(); + } + + if (decay != nullptr) { + if (decay_broadcast_dk) { + auto decay_3d = TransposeBHT_to_BTH(*decay, batch_size, kv_num_heads, seq_length); + tester.AddInput("decay", {batch_size, seq_length, kv_num_heads}, decay_3d); + } else { + auto decay_3d = PackBHTD_to_BTHD(*decay, batch_size, kv_num_heads, seq_length, head_dim_k); + tester.AddInput("decay", {batch_size, seq_length, kv_num_heads * head_dim_k}, decay_3d); + } + } else { + tester.AddOptionalInputEdge(); + } + + if (beta_data != nullptr) { + auto beta_3d = TransposeBHT_to_BTH(*beta_data, batch_size, kv_num_heads, seq_length); + tester.AddInput("beta", {batch_size, seq_length, kv_num_heads}, beta_3d); + } else { + tester.AddOptionalInputEdge(); + } + + tester.AddOutput("output", {batch_size, seq_length, kv_num_heads * head_dim_v}, + output_3d, false, 0.005f, 0.005f); + tester.AddOutput("present_state", {batch_size, kv_num_heads, head_dim_k, head_dim_v}, + expected_state, false, 0.005f, 0.005f); + + std::vector> execution_providers; + execution_providers.push_back(std::move(ep)); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +} // namespace // =========================================================================== TEST(ContribOpLinearAttentionTest, LinearRule_SingleToken) { const int B = 1, H = 1, T = 1, dk = 4, dv = 4; @@ -807,5 +1007,195 @@ TEST(ContribOpLinearAttentionTest, GatedDeltaRule_LargerDims) { &initial_state, &decay, &beta); } +// =========================================================================== +// Tests: GQA (Grouped Query Attention) — q_num_heads != kv_num_heads +// =========================================================================== +// Tests: GQA — K has fewer heads than KV (n_k < kv_num_heads) +// Schema requires q_num_heads == kv_num_heads; K head count is derived from +// the key tensor shape. Multiple KV heads share one K head via kv_per_k_head. +// =========================================================================== + +// Small K-GQA: q=kv=4, n_k=2 → each K head serves 2 KV heads +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_KGQA_Small) { + const int B = 1, q_H = 4, kv_H = 4, n_k = 2, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * q_H * T * dk); + std::vector key(B * n_k * T * dk); + std::vector value(B * kv_H * T * dv); + std::vector decay(B * kv_H * T); // broadcast + std::vector beta(B * kv_H * T); + + for (int i = 0; i < B * q_H * T * dk; i++) { + query[i] = 0.5f * std::sin(static_cast(i) * 0.13f); + } + for (int i = 0; i < B * n_k * T * dk; i++) { + key[i] = 0.5f * std::cos(static_cast(i) * 0.17f); + } + for (int i = 0; i < B * kv_H * T * dv; i++) { + value[i] = 0.5f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + for (int i = 0; i < B * kv_H * T; i++) { + decay[i] = -0.1f - 0.05f * std::abs(std::sin(static_cast(i) * 0.3f)); + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * kv_H * dk * dv, 0.1f); + + RunLinearAttentionGQATest("gated_delta", B, q_H, kv_H, n_k, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// Linear rule with K-GQA: q=kv=4, n_k=2 +TEST(ContribOpLinearAttentionTest, LinearRule_KGQA) { + const int B = 1, q_H = 4, kv_H = 4, n_k = 2, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * q_H * T * dk); + std::vector key(B * n_k * T * dk); + std::vector value(B * kv_H * T * dv); + + for (int i = 0; i < B * q_H * T * dk; i++) { + query[i] = 0.5f * std::sin(static_cast(i) * 0.13f); + } + for (int i = 0; i < B * n_k * T * dk; i++) { + key[i] = 0.5f * std::cos(static_cast(i) * 0.17f); + } + for (int i = 0; i < B * kv_H * T * dv; i++) { + value[i] = 0.5f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + + RunLinearAttentionGQATest("linear", B, q_H, kv_H, n_k, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, nullptr); +} + +// Qwen3.5 9B-like: q=kv=32, n_k=16 (K has half the heads), +// dk=128, dv=128, broadcast decay +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_Qwen35_KGQA) { + const int B = 1, q_H = 32, kv_H = 32, n_k = 16, T = 4, dk = 128, dv = 128; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * q_H * T * dk); + std::vector key(B * n_k * T * dk); + std::vector value(B * kv_H * T * dv); + std::vector decay(B * kv_H * T); // broadcast + std::vector beta(B * kv_H * T); + + for (int i = 0; i < B * q_H * T * dk; i++) { + query[i] = 0.05f * std::sin(static_cast(i) * 0.013f); + } + for (int i = 0; i < B * n_k * T * dk; i++) { + key[i] = 0.05f * std::cos(static_cast(i) * 0.017f); + } + for (int i = 0; i < B * kv_H * T * dv; i++) { + value[i] = 0.05f * std::sin(static_cast(i) * 0.023f + 0.5f); + } + for (int i = 0; i < B * kv_H * T; i++) { + decay[i] = -0.1f - 0.05f * std::abs(std::sin(static_cast(i) * 0.3f)); + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * kv_H * dk * dv, 0.01f); + + RunLinearAttentionGQATest("gated_delta", B, q_H, kv_H, n_k, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// =========================================================================== +// Tests: Inverse GQA — q_num_heads < kv_num_heads +// Each KV head has its own output slot; Q is broadcast across KV groups. +// =========================================================================== + +// Small inverse GQA: q=2, kv=4 → each Q head shared by 2 KV heads +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_InverseGQA_Small) { + const int B = 1, q_H = 2, kv_H = 4, n_k = 4, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * q_H * T * dk); + std::vector key(B * n_k * T * dk); + std::vector value(B * kv_H * T * dv); + std::vector decay(B * kv_H * T); // broadcast + std::vector beta(B * kv_H * T); + + for (int i = 0; i < B * q_H * T * dk; i++) { + query[i] = 0.5f * std::sin(static_cast(i) * 0.13f); + } + for (int i = 0; i < B * n_k * T * dk; i++) { + key[i] = 0.5f * std::cos(static_cast(i) * 0.17f); + } + for (int i = 0; i < B * kv_H * T * dv; i++) { + value[i] = 0.5f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + for (int i = 0; i < B * kv_H * T; i++) { + decay[i] = -0.1f - 0.05f * std::abs(std::sin(static_cast(i) * 0.3f)); + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * kv_H * dk * dv, 0.1f); + + RunLinearAttentionGQATest("gated_delta", B, q_H, kv_H, n_k, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + +// Linear rule with inverse GQA: q=2, kv=4 +TEST(ContribOpLinearAttentionTest, LinearRule_InverseGQA) { + const int B = 1, q_H = 2, kv_H = 4, n_k = 4, T = 3, dk = 4, dv = 4; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * q_H * T * dk); + std::vector key(B * n_k * T * dk); + std::vector value(B * kv_H * T * dv); + + for (int i = 0; i < B * q_H * T * dk; i++) { + query[i] = 0.5f * std::sin(static_cast(i) * 0.13f); + } + for (int i = 0; i < B * n_k * T * dk; i++) { + key[i] = 0.5f * std::cos(static_cast(i) * 0.17f); + } + for (int i = 0; i < B * kv_H * T * dv; i++) { + value[i] = 0.5f * std::sin(static_cast(i) * 0.23f + 0.5f); + } + + RunLinearAttentionGQATest("linear", B, q_H, kv_H, n_k, T, dk, dv, scale, + query, key, value, + nullptr, nullptr, nullptr); +} + +// Larger inverse GQA with K-head sharing: q=2, kv=8, n_k=4, dk=16, dv=64 +TEST(ContribOpLinearAttentionTest, GatedDeltaRule_InverseGQA_LargerDims) { + const int B = 1, q_H = 2, kv_H = 8, n_k = 4, T = 4, dk = 16, dv = 64; + float scale = 1.0f / std::sqrt(static_cast(dk)); + + std::vector query(B * q_H * T * dk); + std::vector key(B * n_k * T * dk); + std::vector value(B * kv_H * T * dv); + std::vector decay(B * kv_H * T); // broadcast + std::vector beta(B * kv_H * T); + + for (int i = 0; i < B * q_H * T * dk; i++) { + query[i] = 0.1f * std::sin(static_cast(i) * 0.013f); + } + for (int i = 0; i < B * n_k * T * dk; i++) { + key[i] = 0.1f * std::cos(static_cast(i) * 0.017f); + } + for (int i = 0; i < B * kv_H * T * dv; i++) { + value[i] = 0.1f * std::sin(static_cast(i) * 0.023f + 0.5f); + } + for (int i = 0; i < B * kv_H * T; i++) { + decay[i] = -0.1f - 0.05f * std::abs(std::sin(static_cast(i) * 0.3f)); + beta[i] = 0.5f + 0.3f * std::sin(static_cast(i) * 0.31f); + } + + std::vector initial_state(B * kv_H * dk * dv, 0.01f); + + RunLinearAttentionGQATest("gated_delta", B, q_H, kv_H, n_k, T, dk, dv, scale, + query, key, value, + &initial_state, &decay, &beta); +} + } // namespace test } // namespace onnxruntime From 15e4010fa823294310fd36cee3eed73b2e1a474b Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Sun, 5 Apr 2026 11:21:39 -0700 Subject: [PATCH 27/27] fix issue in Expand that shows with Qwen3.5 embeddings --- onnxruntime/core/providers/webgpu/tensor/expand.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 0dacd589cbba8..279f7dcdc2ade 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -21,7 +21,7 @@ Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { // The last dims of input shape and output shape are all divisible by 4. shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n" << " let input_offset = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n" - << output.SetByOffset("global_idx", input.GetByOffset("input_offset")); + << output.SetByOffset("global_idx", input.GetByOffset("input_offset / 4")); } else if (output_last_dim_divisible_by_4_) { // The last dim of output shape is divisible by 4, and the last dim of input shape is 1. shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n"