Skip to content

Commit 1313dab

Browse files
committed
Fixed computed output shape of matmul op
1 parent 756228c commit 1313dab

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

src/webnn/native/ops/Binary.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@ namespace webnn::native::op {
5454
}
5555
outputShape = {1};
5656
}
57-
if (rankA == 2 && rankB == 1) {
58-
if (inputShapeA[1] != inputShapeB[0]) {
57+
if (rankA >= 2 && rankB == 1) {
58+
if (inputShapeA[rankA - 1] != inputShapeB[0]) {
5959
return DAWN_VALIDATION_ERROR("The input shapes are incompatible.");
6060
}
61-
outputShape = {inputShapeA[0], 1};
61+
outputShape = std::move(inputShapeA);
62+
outputShape[rankA - 1] = 1;
6263
}
6364
if (rankA == 1 && rankB == 2) {
6465
if (inputShapeA[0] != inputShapeB[0]) {

0 commit comments

Comments
 (0)