File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -54,17 +54,19 @@ namespace webnn::native::op {
5454 }
5555 outputShape = {1 };
5656 }
57- if (rankA = = 2 && rankB == 1 ) {
58- if (inputShapeA[1 ] != inputShapeB[0 ]) {
57+ if (rankA > = 2 && rankB == 1 ) {
58+ if (inputShapeA[rankA - 1 ] != inputShapeB[0 ]) {
5959 return DAWN_VALIDATION_ERROR (" The input shapes are incompatible." );
6060 }
61- outputShape = {inputShapeA[0 ], 1 };
61+ outputShape = std::move (inputShapeA);
62+ outputShape[rankA - 1 ] = 1 ;
6263 }
63- if (rankA == 1 && rankB = = 2 ) {
64- if (inputShapeA[0 ] != inputShapeB[0 ]) {
64+ if (rankA == 1 && rankB > = 2 ) {
65+ if (inputShapeA[0 ] != inputShapeB[rankB - 2 ]) {
6566 return DAWN_VALIDATION_ERROR (" The input shapes are incompatible." );
6667 }
67- outputShape = {1 , inputShapeB[1 ]};
68+ outputShape = std::move (inputShapeB);
69+ outputShape[rankB - 2 ] = 1 ;
6870 }
6971 if (rankA >= 2 && rankB >= 2 ) {
7072 if (inputShapeA[rankA - 1 ] != inputShapeB[rankB - 2 ]) {
You can’t perform that action at this time.
0 commit comments