|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#include "core/providers/webgpu/webgpu_supported_types.h" |
| 5 | +#include "core/providers/webgpu/llm/rotary_embedding.h" |
| 6 | +#include "contrib_ops/webgpu/bert/rotary_embedding.h" |
| 7 | +#include "core/providers/webgpu/generator/range.h" |
| 8 | + |
| 9 | +namespace onnxruntime { |
| 10 | +namespace webgpu { |
| 11 | + |
| 12 | +ONNX_OPERATOR_KERNEL_EX( |
| 13 | + RotaryEmbedding, |
| 14 | + kOnnxDomain, |
| 15 | + 23, |
| 16 | + kWebGpuExecutionProvider, |
| 17 | + (*KernelDefBuilder::Create()) |
| 18 | + .TypeConstraint("T", WebGpuSupportedFloatTypes()) |
| 19 | + .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), |
| 20 | + RotaryEmbedding); |
| 21 | + |
| 22 | +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) { |
| 23 | + rotary_embedding_dim_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0)); |
| 24 | + num_heads_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0)); |
| 25 | + interleaved_ = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1); |
| 26 | +} |
| 27 | + |
| 28 | +Status RotaryEmbedding::ComputeInternal(ComputeContext& context) const { |
| 29 | + // ONNX inputs: X(0), cos_cache(1), sin_cache(2), position_ids(3, optional) |
| 30 | + const auto* input = context.Input<Tensor>(0); |
| 31 | + const auto* cos_cache = context.Input<Tensor>(1); |
| 32 | + const auto* sin_cache = context.Input<Tensor>(2); |
| 33 | + const auto* position_ids = context.Input<Tensor>(3); // optional |
| 34 | + |
| 35 | + const auto input_shape = input->Shape(); |
| 36 | + auto* output = context.Output(0, input_shape); |
| 37 | + |
| 38 | + const auto batch_size = onnxruntime::narrow<uint32_t>(input_shape[0]); |
| 39 | + const auto batch_stride = onnxruntime::narrow<uint32_t>(input_shape.SizeFromDimension(1)); |
| 40 | + const auto sequence_length = onnxruntime::narrow<uint32_t>(input_shape[input_shape.NumDimensions() - 2]); |
| 41 | + const auto hidden_size = batch_stride / sequence_length; |
| 42 | + const auto half_rotary_embedding_dim = onnxruntime::narrow<uint32_t>(cos_cache->Shape()[cos_cache->Shape().NumDimensions() - 1]); |
| 43 | + |
| 44 | + // Compute head_size: when rotary_embedding_dim is not set, head_size = rotary_dim (= 2 * half). |
| 45 | + // When rotary_embedding_dim is set, derive head_size from the 4D input shape or num_heads attribute. |
| 46 | + uint32_t head_size; |
| 47 | + if (rotary_embedding_dim_ == 0) { |
| 48 | + head_size = half_rotary_embedding_dim * 2; |
| 49 | + } else if (input_shape.NumDimensions() == 4) { |
| 50 | + // 4D input: [batch, num_heads, seq, head_size] |
| 51 | + head_size = onnxruntime::narrow<uint32_t>(input_shape[3]); |
| 52 | + } else { |
| 53 | + ORT_ENFORCE(num_heads_ > 0, |
| 54 | + "Attribute 'num_heads' must be provided when 'rotary_embedding_dim' is specified " |
| 55 | + "and input is not rank-4 (batch, num_heads, sequence, head)."); |
| 56 | + head_size = hidden_size / num_heads_; |
| 57 | + } |
| 58 | + |
| 59 | + const TensorShape global_shape({batch_size, |
| 60 | + sequence_length, |
| 61 | + hidden_size / head_size, |
| 62 | + head_size - half_rotary_embedding_dim}); |
| 63 | + |
| 64 | + const auto rank = global_shape.NumDimensions(); |
| 65 | + std::vector<uint32_t> global_dims(rank); |
| 66 | + std::vector<uint32_t> global_strides(rank); |
| 67 | + for (size_t j = 0; j < rank; ++j) { |
| 68 | + global_dims[j] = onnxruntime::narrow<uint32_t>(global_shape[j]); |
| 69 | + global_strides[j] = onnxruntime::narrow<uint32_t>(global_shape.SizeFromDimension(j + 1)); |
| 70 | + } |
| 71 | + |
| 72 | + const auto output_size = onnxruntime::narrow<const uint32_t>(global_shape.Size()); |
| 73 | + const auto input_output_strides = |
| 74 | + input_shape.NumDimensions() == 3 |
| 75 | + ? std::vector<uint32_t>({batch_stride, hidden_size, head_size, 1}) |
| 76 | + : (input_shape.NumDimensions() == 4 |
| 77 | + ? std::vector<uint32_t>({batch_stride, head_size, sequence_length * head_size, 1}) |
| 78 | + : std::vector<uint32_t>({})); |
| 79 | + |
| 80 | + // The contrib RotaryEmbeddingProgram expects inputs in order: |
| 81 | + // input(0), position_ids(1), cos_cache(2), sin_cache(3) |
| 82 | + // The ONNX op has: X(0), cos_cache(1), sin_cache(2), position_ids(3, optional) |
| 83 | + |
| 84 | + if (position_ids != nullptr) { |
| 85 | + // position_ids provided: cos/sin cache is 2D (max_pos, D/2) |
| 86 | + contrib::webgpu::RotaryEmbeddingProgram program{interleaved_}; |
| 87 | + program |
| 88 | + .CacheHint(interleaved_) |
| 89 | + .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, |
| 90 | + {position_ids, ProgramTensorMetadataDependency::Rank}, |
| 91 | + {cos_cache, ProgramTensorMetadataDependency::Rank}, |
| 92 | + {sin_cache, ProgramTensorMetadataDependency::Rank}}) |
| 93 | + .AddOutput({output, ProgramTensorMetadataDependency::None}) |
| 94 | + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) |
| 95 | + .AddUniformVariables({{1.0f}, |
| 96 | + {gsl::make_span(global_dims)}, |
| 97 | + {gsl::make_span(global_strides)}, |
| 98 | + {gsl::make_span(input_output_strides)}}) |
| 99 | + .AddIndices(TensorShape{1, 1}); |
| 100 | + return context.RunProgram(program); |
| 101 | + } |
| 102 | + |
| 103 | + // position_ids NOT provided: cos/sin cache is 3D (B, S, D/2) |
| 104 | + // Reshape to 2D (B*S, D/2) and generate sequential position_ids. |
| 105 | + const auto total_seq = batch_size * sequence_length; |
| 106 | + const TensorShape cache_2d_shape({static_cast<int64_t>(total_seq), |
| 107 | + static_cast<int64_t>(half_rotary_embedding_dim)}); |
| 108 | + |
| 109 | + // Generate position_ids [0, 1, ..., B*S-1] reshaped as (B, S) on GPU using RangeProgram |
| 110 | + const TensorShape pos_ids_shape({static_cast<int64_t>(batch_size), |
| 111 | + static_cast<int64_t>(sequence_length)}); |
| 112 | + Tensor pos_ids_tensor = context.CreateGPUTensor(DataTypeImpl::GetType<int64_t>(), pos_ids_shape); |
| 113 | + { |
| 114 | + RangeProgram range_program{ONNX_NAMESPACE::TensorProto_DataType_INT64}; |
| 115 | + int32_t start_i32 = 0; |
| 116 | + int32_t delta_i32 = 1; |
| 117 | + range_program |
| 118 | + .AddOutput({&pos_ids_tensor, ProgramTensorMetadataDependency::Type}) |
| 119 | + .SetDispatchGroupSize((total_seq + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) |
| 120 | + .AddUniformVariables({ |
| 121 | + total_seq, |
| 122 | + std::bit_cast<uint32_t>(start_i32), |
| 123 | + std::bit_cast<uint32_t>(delta_i32), |
| 124 | + }); |
| 125 | + ORT_RETURN_IF_ERROR(context.RunProgram(range_program)); |
| 126 | + } |
| 127 | + |
| 128 | + contrib::webgpu::RotaryEmbeddingProgram program{interleaved_}; |
| 129 | + program |
| 130 | + .CacheHint(interleaved_) |
| 131 | + .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, |
| 132 | + {&pos_ids_tensor, ProgramTensorMetadataDependency::Rank}, |
| 133 | + {cos_cache, ProgramTensorMetadataDependency::Rank, cache_2d_shape, 1}, |
| 134 | + {sin_cache, ProgramTensorMetadataDependency::Rank, cache_2d_shape, 1}}) |
| 135 | + .AddOutput({output, ProgramTensorMetadataDependency::None}) |
| 136 | + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) |
| 137 | + .AddUniformVariables({{1.0f}, |
| 138 | + {gsl::make_span(global_dims)}, |
| 139 | + {gsl::make_span(global_strides)}, |
| 140 | + {gsl::make_span(input_output_strides)}}) |
| 141 | + .AddIndices(TensorShape{1, 1}); |
| 142 | + return context.RunProgram(program); |
| 143 | +} |
| 144 | + |
| 145 | +} // namespace webgpu |
| 146 | +} // namespace onnxruntime |
0 commit comments