Skip to content

Commit ad2e565

Browse files
authored
[webgpu] Fix opset-12 softmax nhwc issue (microsoft#24227)
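### Description Store the node's opset version in a new `Softmax::opset_` member (the constructor previously assigned it to a local variable) and use it in `Softmax::ComputeInternal`: for opsets lower than 13 the transpose is skipped and `cols` is computed as `input_shape.SizeFromDimension(axis)`. The WebGPU EP is also re-enabled in `SoftmaxOperator.GH15949_regression_test`. ### Motivation and Context Before opset 13, the `axis` attribute of Softmax describes where the input is coerced to 2D rather than a single axis to normalize over. The WebGPU Softmax kernel did not handle this case correctly (e.g. `axis=0`), so the regression test had been disabled for the WebGPU EP.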
1 parent 22787ae commit ad2e565

3 files changed

Lines changed: 9 additions & 5 deletions

File tree

onnxruntime/core/providers/webgpu/math/softmax.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
156156

157157
// normalize axis
158158
size_t axis = static_cast<size_t>(HandleNegativeAxis(axis_, input_rank));
159-
bool is_transpose_required = axis < input_rank - 1;
159+
// In opsets lower than 13, the `axis` attribute describes the axis at which the input is coerced to 2D,
160+
// the 0th axis most likely describes the batch_size, so transpose is not required on old opset versions.
161+
bool is_transpose_required = axis < input_rank - 1 && opset_ >= 13;
160162

161163
TensorShape transposed_input_shape;
162164
Tensor transposed_input_tensor;
@@ -179,7 +181,9 @@ Status Softmax::ComputeInternal(ComputeContext& context) const {
179181
intermediate_output = context.CreateGPUTensor(output_tensor->DataType(), transposed_input_shape);
180182
}
181183

182-
const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : input_shape[input_rank - 1];
184+
// In opsets lower than 13, the `axis` attribute splits the input tensor's dimensions into two parts:
185+
// one part is treated as the batch size, and Softmax is computed over the other part.
186+
const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : (opset_ >= 13 ? input_shape[input_rank - 1] : input_shape.SizeFromDimension(axis));
183187
const int64_t rows = input_shape.Size() / cols;
184188
const int64_t components = GetMaxComponents(cols);
185189
const auto packed_cols = cols / components;

onnxruntime/core/providers/webgpu/math/softmax.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace webgpu {
1414
class Softmax final : public WebGpuKernel {
1515
public:
1616
Softmax(const OpKernelInfo& info) : WebGpuKernel{info} {
17-
int opset_ = info.node().SinceVersion();
17+
opset_ = info.node().SinceVersion();
1818
int64_t axis;
1919
Status status = info.GetAttr<int64_t>("axis", &axis);
2020

@@ -33,6 +33,7 @@ class Softmax final : public WebGpuKernel {
3333

3434
private:
3535
int64_t axis_;
36+
int opset_;
3637
};
3738

3839
class SoftmaxProgram final : public Program<SoftmaxProgram> {

onnxruntime/test/providers/cpu/math/softmax_test.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,7 @@ TEST(SoftmaxOperator, GH15949_regression_test) {
422422
{0.00032932f, 0.01798029f, 0.9816904f});
423423

424424
// disable TRT as it does not support axis=0 as used by the model
425-
// TODO: Fix the Softmax operator of WebGPU EP.
426-
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kWebGpuExecutionProvider});
425+
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
427426
}
428427

429428
} // namespace test

0 commit comments

Comments
 (0)