Skip to content

Commit 08240cb

Browse files
committed
reshape the code structure of GenerateShaderCode
1 parent 9f3383c commit 08240cb

1 file changed

Lines changed: 30 additions & 34 deletions

File tree

onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc

Lines changed: 30 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -22,14 +22,13 @@ ONNX_OPERATOR_KERNEL_EX(
2222

2323
Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
2424
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform);
25+
const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform);
26+
const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform);
27+
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);
2528
const auto interleaved_str = interleaved_ ? "true" : "false";
26-
2729
if (use_position_offset_) {
2830
// Position offset path: inputs are [input, cos_cache, sin_cache].
2931
// Compute position_id = position_offset + sequence_index (no position_ids tensor).
30-
const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform);
31-
const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform);
32-
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);
3332
shader.MainFunctionBody() << " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n"
3433
" let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n"
3534
" let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n"
@@ -55,43 +54,40 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
5554
} else {
5655
// Original path: inputs are [input, position_ids, cos_cache, sin_cache].
5756
const auto& position_ids = shader.AddInput("position_ids", ShaderUsage::UseUniform);
58-
const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform);
59-
const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform);
60-
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);
6157
// TODO: remove output_indices.
6258
const auto& output_indices = shader.AddIndices("output_indices", ShaderUsage::None);
6359
shader.MainFunctionBody() << " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n"
6460
" let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n"
6561
" let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n"
6662
" if (global_idx >= size) { return; }\n"
6763
" if (bsnh[3] < half_rotary_emb_dim) {\n"
68-
<< " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) << ";\n"
69-
<< " let raw_pos = " << position_ids.GetByOffset("position_ids_idx") << ";\n"
70-
<< " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << ");\n"
71-
<< " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n"
72-
" let max_position = uniforms.cos_cache_shape[0];\n"
73-
// Bounds check: raw_pos < 0 catches negative position_ids (i32 from truncated int64).
74-
// After u32 conversion + offset, check >= max_position catches too-large values.
75-
// On OOB, pass through input unchanged (same as CUDA kernel behavior).
76-
" if (raw_pos < 0) {\n"
77-
<< " " << output.SetByOffset("i", input.GetByOffset("i")) << "\n"
78-
<< " " << output.SetByOffset("j", input.GetByOffset("j")) << "\n"
79-
" } else {\n"
80-
" let position_id = u32(raw_pos) + select(0, bsnh[1], position_ids_idx == 0);\n"
81-
" if (position_id >= max_position) {\n"
82-
<< " " << output.SetByOffset("i", input.GetByOffset("i")) << "\n"
83-
<< " " << output.SetByOffset("j", input.GetByOffset("j")) << "\n"
84-
" } else {\n"
85-
<< " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
86-
<< " " << output.SetByOffset("i", "re") << "\n"
87-
<< " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " + " << input.GetByOffset("j") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
88-
<< " " << output.SetByOffset("j", "im") << "\n"
89-
" }\n"
90-
" }\n"
91-
<< " } else { \n"
92-
" let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n"
93-
<< " " << output.SetByOffset("k", input.GetByOffset("k")) << "\n"
94-
<< " }";
64+
<< " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) << ";\n"
65+
<< " let raw_pos = " << position_ids.GetByOffset("position_ids_idx") << ";\n"
66+
<< " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << ");\n"
67+
<< " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n"
68+
" let max_position = uniforms.cos_cache_shape[0];\n"
69+
// Bounds check: raw_pos < 0 catches negative position_ids (i32 from truncated int64).
70+
// After u32 conversion + offset, check >= max_position catches too-large values.
71+
// On OOB, pass through input unchanged (same as CUDA kernel behavior).
72+
" if (raw_pos < 0) {\n"
73+
<< " " << output.SetByOffset("i", input.GetByOffset("i")) << "\n"
74+
<< " " << output.SetByOffset("j", input.GetByOffset("j")) << "\n"
75+
" } else {\n"
76+
" let position_id = u32(raw_pos) + select(0, bsnh[1], position_ids_idx == 0);\n"
77+
" if (position_id >= max_position) {\n"
78+
<< " " << output.SetByOffset("i", input.GetByOffset("i")) << "\n"
79+
<< " " << output.SetByOffset("j", input.GetByOffset("j")) << "\n"
80+
" } else {\n"
81+
<< " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
82+
<< " " << output.SetByOffset("i", "re") << "\n"
83+
<< " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " + " << input.GetByOffset("j") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
84+
<< " " << output.SetByOffset("j", "im") << "\n"
85+
" }\n"
86+
" }\n"
87+
<< " } else { \n"
88+
" let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n"
89+
<< " " << output.SetByOffset("k", input.GetByOffset("k")) << "\n"
90+
<< " }";
9591
}
9692

9793
return Status::OK();

0 commit comments

Comments
 (0)