Skip to content

Commit b57cde9

Browse files
committed
Fixed computed output shape of matmul op
1 parent 756228c commit b57cde9

1 file changed

Lines changed: 8 additions & 6 deletions

File tree

src/webnn/native/ops/Binary.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,19 @@ namespace webnn::native::op {
5454
}
5555
outputShape = {1};
5656
}
57-
if (rankA == 2 && rankB == 1) {
58-
if (inputShapeA[1] != inputShapeB[0]) {
57+
if (rankA >= 2 && rankB == 1) {
58+
if (inputShapeA[rankA - 1] != inputShapeB[0]) {
5959
return DAWN_VALIDATION_ERROR("The input shapes are incompatible.");
6060
}
61-
outputShape = {inputShapeA[0], 1};
61+
outputShape = std::move(inputShapeA);
62+
outputShape[rankA - 1] = 1;
6263
}
63-
if (rankA == 1 && rankB == 2) {
64-
if (inputShapeA[0] != inputShapeB[0]) {
64+
if (rankA == 1 && rankB >= 2) {
65+
if (inputShapeA[0] != inputShapeB[rankB - 2]) {
6566
return DAWN_VALIDATION_ERROR("The input shapes are incompatible.");
6667
}
67-
outputShape = {1, inputShapeB[1]};
68+
outputShape = std::move(inputShapeB);
69+
outputShape[rankB - 2] = 1;
6870
}
6971
if (rankA >= 2 && rankB >= 2) {
7072
if (inputShapeA[rankA - 1] != inputShapeB[rankB - 2]) {

0 commit comments

Comments
 (0)