@@ -809,7 +809,114 @@ namespace webnn::native::nnapi {
809809 }
810810
811811 MaybeError Graph::AddGemm (const op::Gemm* gemm) {
812- DAWN_TRY (CheckStatusCode (ANEURALNETWORKS_OP_FAILED, " nnapi gemm" ));
812+ auto inputs = gemm->Inputs ();
813+ const GemmOptions* options = gemm->GetOptions ();
814+
815+ // inputs
816+ auto inputAOpIndex = mGraphNodeMap [gemm->Inputs ()[0 ].Get ()]; // A
817+ auto inputANodeInfo = mGraphOperandInfo [inputAOpIndex];
818+ auto inputBOpIndex = mGraphNodeMap [gemm->Inputs ()[1 ].Get ()]; // B
819+ auto inputBNodeInfo = mGraphOperandInfo [inputBOpIndex];
820+
821+ // output
822+ auto outputDims = gemm->PrimaryOutput ()->Shape ();
823+
824+ if (options->aTranspose ) {
825+ NodeInfo transposedNodeA;
826+ memInt32Vec.emplace_back (new int (2 ));
827+ int32_t * permute = memInt32Vec.back ().get ();
828+ permute[0 ] = 1 ;
829+ permute[1 ] = 0 ;
830+ DAWN_TRY (AddTransposeImpl (inputANodeInfo, transposedNodeA, permute, 2 ));
831+ mGraphOperandInfo [transposedNodeA.opIndex ] = transposedNodeA;
832+ inputANodeInfo = mGraphOperandInfo [transposedNodeA.opIndex ];
833+ }
834+ if (options->bTranspose ) {
835+ NodeInfo transposedNodeB;
836+ memInt32Vec.emplace_back (new int (2 ));
837+ int32_t * permute = memInt32Vec.back ().get ();
838+ permute[0 ] = 1 ;
839+ permute[1 ] = 0 ;
840+ DAWN_TRY (AddTransposeImpl (inputBNodeInfo, transposedNodeB, permute, 2 ));
841+ mGraphOperandInfo [transposedNodeB.opIndex ] = transposedNodeB;
842+ inputBNodeInfo = mGraphOperandInfo [transposedNodeB.opIndex ];
843+ }
844+
845+ // operation: gemm = alpha*A*B + beta*C
846+ // matMulNode = A*B
847+ NodeInfo matMulNode;
848+ matMulNode.type = inputANodeInfo.type ;
849+ matMulNode.dimensions .resize (outputDims.size ());
850+ DAWN_TRY (AddMatMulImpl (inputANodeInfo, inputBNodeInfo, matMulNode, outputDims,
851+ matMulNode.opIndex ));
852+
853+ uint32_t outputOpIndex = 0 ;
854+ int32_t fuseCode = ANEURALNETWORKS_FUSED_NONE;
855+ uint32_t fuseCodeOpIndex = 0 ;
856+ DAWN_TRY (mNnapiMgr ->CreateScalarOperand (ANEURALNETWORKS_INT32, &fuseCode, fuseCodeOpIndex));
857+
858+ if (options->alpha == 1 )
859+ outputOpIndex = matMulNode.opIndex ;
860+ else {
861+ float alpha = options->alpha ;
862+ std::vector<float > alphaVec (1 , alpha);
863+ NodeInfo alphaNode;
864+ alphaNode.type = ml::OperandType::Float32;
865+ alphaNode.dimensions = {1 };
866+ DAWN_TRY (mNnapiMgr ->CreateOperandAndSetMemory (" alpha" , &alphaNode, &alphaVec[0 ]));
867+
868+ // mulNode0 = alpha*matMulNode
869+ NodeInfo mulNode0;
870+ DAWN_TRY (CreateNode (mulNode0, matMulNode.type , gemm->PrimaryOutput ()->Shape ()));
871+ std::vector<uint32_t > inputList = {alphaNode.opIndex , matMulNode.opIndex ,
872+ fuseCodeOpIndex};
873+ DAWN_TRY (mNnapiMgr ->AddOperation (ANEURALNETWORKS_MUL, inputList.size (),
874+ inputList.data (), 1 , &mulNode0.opIndex ));
875+ mGraphOperandInfo [mulNode0.opIndex ] = mulNode0;
876+
877+ outputOpIndex = mulNode0.opIndex ;
878+ }
879+
880+ if (inputs.size () > 2 ) { // Check for C
881+ auto inputCOpIndex = mGraphNodeMap [gemm->Inputs ()[2 ].Get ()];
882+ auto inputCNodeInfo = mGraphOperandInfo [inputCOpIndex];
883+
884+ NodeInfo outputNode;
885+ DAWN_TRY (CreateNode (outputNode, inputANodeInfo.type , gemm->PrimaryOutput ()->Shape ()));
886+
887+ if (options->beta == 1 ) {
888+ // output = mulNode0 + NodeC
889+ std::vector<uint32_t > inputList4 = {outputOpIndex, inputCNodeInfo.opIndex ,
890+ fuseCodeOpIndex};
891+ DAWN_TRY (mNnapiMgr ->AddOperation (ANEURALNETWORKS_ADD, inputList4.size (),
892+ inputList4.data (), 1 , &outputNode.opIndex ));
893+ } else {
894+ // mulNode1 = beta*C
895+ float beta = options->beta ;
896+ std::vector<float > betaVec (1 , beta);
897+ NodeInfo betaNode;
898+ betaNode.type = ml::OperandType::Float32;
899+ betaNode.dimensions = {1 };
900+ DAWN_TRY (mNnapiMgr ->CreateOperandAndSetMemory (" beta" , &betaNode, &betaVec[0 ]));
901+
902+ NodeInfo mulNode1;
903+ DAWN_TRY (CreateNode (mulNode1, inputCNodeInfo.type , gemm->Inputs ()[2 ]->Shape ()));
904+ std::vector<uint32_t > inputList2 = {betaNode.opIndex , inputCNodeInfo.opIndex ,
905+ fuseCodeOpIndex};
906+ DAWN_TRY (mNnapiMgr ->AddOperation (ANEURALNETWORKS_MUL, inputList2.size (),
907+ inputList2.data (), 1 , &mulNode1.opIndex ));
908+
909+ // output = mulNode0 + mulNode1
910+ std::vector<uint32_t > inputList3 = {outputOpIndex, mulNode1.opIndex ,
911+ fuseCodeOpIndex};
912+ DAWN_TRY (mNnapiMgr ->AddOperation (ANEURALNETWORKS_ADD, inputList3.size (),
913+ inputList3.data (), 1 , &outputNode.opIndex ));
914+ }
915+ mGraphOperandInfo [outputNode.opIndex ] = outputNode;
916+ mGraphNodeMap [gemm->PrimaryOutput ()] = outputNode.opIndex ;
917+ } else {
918+ mGraphNodeMap [gemm->PrimaryOutput ()] = outputOpIndex;
919+ }
813920 return {};
814921 }
815922
0 commit comments