@@ -22,14 +22,13 @@ ONNX_OPERATOR_KERNEL_EX(
2222
2323Status RotaryEmbeddingProgram::GenerateShaderCode (ShaderHelper& shader) const {
2424 const auto & input = shader.AddInput (" input" , ShaderUsage::UseUniform);
25+ const auto & cos_cache = shader.AddInput (" cos_cache" , ShaderUsage::UseUniform);
26+ const auto & sin_cache = shader.AddInput (" sin_cache" , ShaderUsage::UseUniform);
27+ const auto & output = shader.AddOutput (" output" , ShaderUsage::UseUniform);
2528 const auto interleaved_str = interleaved_ ? " true" : " false" ;
26-
2729 if (use_position_offset_) {
2830 // Position offset path: inputs are [input, cos_cache, sin_cache].
2931 // Compute position_id = position_offset + sequence_index (no position_ids tensor).
30- const auto & cos_cache = shader.AddInput (" cos_cache" , ShaderUsage::UseUniform);
31- const auto & sin_cache = shader.AddInput (" sin_cache" , ShaderUsage::UseUniform);
32- const auto & output = shader.AddOutput (" output" , ShaderUsage::UseUniform);
3332 shader.MainFunctionBody () << " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n "
3433 " let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n "
3534 " let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n "
@@ -55,43 +54,40 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
5554 } else {
5655 // Original path: inputs are [input, position_ids, cos_cache, sin_cache].
5756 const auto & position_ids = shader.AddInput (" position_ids" , ShaderUsage::UseUniform);
58- const auto & cos_cache = shader.AddInput (" cos_cache" , ShaderUsage::UseUniform);
59- const auto & sin_cache = shader.AddInput (" sin_cache" , ShaderUsage::UseUniform);
60- const auto & output = shader.AddOutput (" output" , ShaderUsage::UseUniform);
6157 // TODO: remove output_indices.
6258 const auto & output_indices = shader.AddIndices (" output_indices" , ShaderUsage::None);
6359 shader.MainFunctionBody () << " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n "
6460 " let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n "
6561 " let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n "
6662 " if (global_idx >= size) { return; }\n "
6763 " if (bsnh[3] < half_rotary_emb_dim) {\n "
68- << " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset (" bsnh.xy" , output_indices) << " ;\n "
69- << " let raw_pos = " << position_ids.GetByOffset (" position_ids_idx" ) << " ;\n "
70- << " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << " );\n "
71- << " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << " );\n "
72- " let max_position = uniforms.cos_cache_shape[0];\n "
73- // Bounds check: raw_pos < 0 catches negative position_ids (i32 from truncated int64).
74- // After u32 conversion + offset, check >= max_position catches too-large values.
75- // On OOB, pass through input unchanged (same as CUDA kernel behavior).
76- " if (raw_pos < 0) {\n "
77- << " " << output.SetByOffset (" i" , input.GetByOffset (" i" )) << " \n "
78- << " " << output.SetByOffset (" j" , input.GetByOffset (" j" )) << " \n "
79- " } else {\n "
80- " let position_id = u32(raw_pos) + select(0, bsnh[1], position_ids_idx == 0);\n "
81- " if (position_id >= max_position) {\n "
82- << " " << output.SetByOffset (" i" , input.GetByOffset (" i" )) << " \n "
83- << " " << output.SetByOffset (" j" , input.GetByOffset (" j" )) << " \n "
84- " } else {\n "
85- << " let re = " << input.GetByOffset (" i" ) << " * " << cos_cache.GetByIndices (" vec2<u32>(position_id, bsnh[3])" ) << " - " << input.GetByOffset (" j" ) << " * " << sin_cache.GetByIndices (" vec2<u32>(position_id, bsnh[3])" ) << " ;\n "
86- << " " << output.SetByOffset (" i" , " re" ) << " \n "
87- << " let im = " << input.GetByOffset (" i" ) << " * " << sin_cache.GetByIndices (" vec2<u32>(position_id, bsnh[3])" ) << " + " << input.GetByOffset (" j" ) << " * " << cos_cache.GetByIndices (" vec2<u32>(position_id, bsnh[3])" ) << " ;\n "
88- << " " << output.SetByOffset (" j" , " im" ) << " \n "
89- " }\n "
90- " }\n "
91- << " } else { \n "
92- " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n "
93- << " " << output.SetByOffset (" k" , input.GetByOffset (" k" )) << " \n "
94- << " }" ;
64+ << " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset (" bsnh.xy" , output_indices) << " ;\n "
65+ << " let raw_pos = " << position_ids.GetByOffset (" position_ids_idx" ) << " ;\n "
66+ << " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << " );\n "
67+ << " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << " );\n "
68+ " let max_position = uniforms.cos_cache_shape[0];\n "
69+ // Bounds check: raw_pos < 0 catches negative position_ids (i32 from truncated int64).
70+ // After u32 conversion + offset, check >= max_position catches too-large values.
71+ // On OOB, pass through input unchanged (same as CUDA kernel behavior).
72+ " if (raw_pos < 0) {\n "
73+ << " " << output.SetByOffset (" i" , input.GetByOffset (" i" )) << " \n "
74+ << " " << output.SetByOffset (" j" , input.GetByOffset (" j" )) << " \n "
75+ " } else {\n "
76+ " let position_id = u32(raw_pos) + select(0, bsnh[1], position_ids_idx == 0);\n "
77+ " if (position_id >= max_position) {\n "
78+ << " " << output.SetByOffset (" i" , input.GetByOffset (" i" )) << " \n "
79+ << " " << output.SetByOffset (" j" , input.GetByOffset (" j" )) << " \n "
80+ " } else {\n "
81+ << " let re = " << input.GetByOffset (" i" ) << " * " << cos_cache.GetByIndices (" vec2<u32>(position_id, bsnh[3])" ) << " - " << input.GetByOffset (" j" ) << " * " << sin_cache.GetByIndices (" vec2<u32>(position_id, bsnh[3])" ) << " ;\n "
82+ << " " << output.SetByOffset (" i" , " re" ) << " \n "
83+ << " let im = " << input.GetByOffset (" i" ) << " * " << sin_cache.GetByIndices (" vec2<u32>(position_id, bsnh[3])" ) << " + " << input.GetByOffset (" j" ) << " * " << cos_cache.GetByIndices (" vec2<u32>(position_id, bsnh[3])" ) << " ;\n "
84+ << " " << output.SetByOffset (" j" , " im" ) << " \n "
85+ " }\n "
86+ " }\n "
87+ << " } else { \n "
88+ " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n "
89+ << " " << output.SetByOffset (" k" , input.GetByOffset (" k" )) << " \n "
90+ << " }" ;
9591 }
9692
9793 return Status::OK ();
0 commit comments