Skip to content

Commit b2e5e60

Browse files
committed
Added support for Gemm op - Supports Gemm operations
Signed-off-by: Vijeetkumar Benni <vijeetkumar.benni@intel.com>
1 parent 3e3bbfc commit b2e5e60

1 file changed

Lines changed: 108 additions & 1 deletion

File tree

src/webnn/native/nnapi/GraphNnapi.cpp

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,114 @@ namespace webnn::native::nnapi {
809809
}
810810

811811
// Lowers a Gemm op (gemm = alpha*A*B + beta*C) to a sequence of NNAPI operations:
//   [transpose A] -> [transpose B] -> matmul -> [MUL by alpha] -> [MUL C by beta] -> [ADD C]
// Steps in brackets are emitted only when the matching option/input calls for them (see the
// branches below). The operand index of the final result is stored in mGraphNodeMap under
// gemm->PrimaryOutput(). Helper failures are propagated through DAWN_TRY; on success an
// empty MaybeError is returned.
MaybeError Graph::AddGemm(const op::Gemm* gemm) {
812-
DAWN_TRY(CheckStatusCode(ANEURALNETWORKS_OP_FAILED, "nnapi gemm"));
812+
// Only used further down to test whether the optional third input (C) is present.
auto inputs = gemm->Inputs();
813+
const GemmOptions* options = gemm->GetOptions();
814+
815+
// inputs
816+
// Fetch the operand index and the recorded operand info for A and B.
// NOTE(review): these are operator[] lookups; if the containers are std::map-like, a missing
// key would be default-inserted instead of reported — confirm inputs are always registered.
auto inputAOpIndex = mGraphNodeMap[gemm->Inputs()[0].Get()]; // A
817+
auto inputANodeInfo = mGraphOperandInfo[inputAOpIndex];
818+
auto inputBOpIndex = mGraphNodeMap[gemm->Inputs()[1].Get()]; // B
819+
auto inputBNodeInfo = mGraphOperandInfo[inputBOpIndex];
820+
821+
// output
822+
auto outputDims = gemm->PrimaryOutput()->Shape();
823+
824+
// Optional transpose of A: build the permutation {1, 0}, add a transpose, and continue with
// the transposed operand in place of A.
if (options->aTranspose) {
825+
NodeInfo transposedNodeA;
826+
// NOTE(review): `new int(2)` allocates ONE int initialised to the value 2, not an array of
// two ints, yet permute[0] and permute[1] are both written below — assuming .get() returns
// the pointer just allocated, permute[1] is out of bounds. A 2-element array was probably
// intended; the right fix depends on memInt32Vec's element type (declared outside this view).
memInt32Vec.emplace_back(new int(2));
827+
int32_t* permute = memInt32Vec.back().get();
828+
permute[0] = 1;
829+
permute[1] = 0;
830+
DAWN_TRY(AddTransposeImpl(inputANodeInfo, transposedNodeA, permute, 2));
831+
// Record the transposed operand, then read it back as the new A.
mGraphOperandInfo[transposedNodeA.opIndex] = transposedNodeA;
832+
inputANodeInfo = mGraphOperandInfo[transposedNodeA.opIndex];
833+
}
834+
// Optional transpose of B: same steps as for A.
if (options->bTranspose) {
835+
NodeInfo transposedNodeB;
836+
// NOTE(review): same single-int allocation concern as in the A branch above.
memInt32Vec.emplace_back(new int(2));
837+
int32_t* permute = memInt32Vec.back().get();
838+
permute[0] = 1;
839+
permute[1] = 0;
840+
DAWN_TRY(AddTransposeImpl(inputBNodeInfo, transposedNodeB, permute, 2));
841+
mGraphOperandInfo[transposedNodeB.opIndex] = transposedNodeB;
842+
inputBNodeInfo = mGraphOperandInfo[transposedNodeB.opIndex];
843+
}
844+
845+
// operation: gemm = alpha*A*B + beta*C
846+
// matMulNode = A*B
847+
NodeInfo matMulNode;
848+
// The product takes A's operand type.
matMulNode.type = inputANodeInfo.type;
849+
// Only the rank is set here; presumably AddMatMulImpl fills in the values from outputDims
// — TODO confirm.
matMulNode.dimensions.resize(outputDims.size());
850+
DAWN_TRY(AddMatMulImpl(inputANodeInfo, inputBNodeInfo, matMulNode, outputDims,
851+
matMulNode.opIndex));
852+
853+
// Tracks the operand holding alpha*A*B (or plain A*B when no scaling is needed).
uint32_t outputOpIndex = 0;
854+
// One FUSED_NONE scalar is created unconditionally and passed as the third input of every
// MUL/ADD added below.
int32_t fuseCode = ANEURALNETWORKS_FUSED_NONE;
855+
uint32_t fuseCodeOpIndex = 0;
856+
DAWN_TRY(mNnapiMgr->CreateScalarOperand(ANEURALNETWORKS_INT32, &fuseCode, fuseCodeOpIndex));
857+
858+
// alpha exactly equal to 1 (exact comparison): skip the scaling MUL and use the matmul
// result directly.
if (options->alpha == 1)
859+
outputOpIndex = matMulNode.opIndex;
860+
else {
861+
float alpha = options->alpha;
862+
std::vector<float> alphaVec(1, alpha);
863+
// alpha is exposed as a Float32 operand with dimensions {1}.
NodeInfo alphaNode;
864+
alphaNode.type = ml::OperandType::Float32;
865+
alphaNode.dimensions = {1};
866+
// NOTE(review): alphaVec is a local; this is only safe if CreateOperandAndSetMemory copies
// the data before returning — confirm.
DAWN_TRY(mNnapiMgr->CreateOperandAndSetMemory("alpha", &alphaNode, &alphaVec[0]));
867+
868+
// mulNode0 = alpha*matMulNode
869+
NodeInfo mulNode0;
870+
// mulNode0 is created with the matmul's type and the op's primary output shape.
DAWN_TRY(CreateNode(mulNode0, matMulNode.type, gemm->PrimaryOutput()->Shape()));
871+
std::vector<uint32_t> inputList = {alphaNode.opIndex, matMulNode.opIndex,
872+
fuseCodeOpIndex};
873+
DAWN_TRY(mNnapiMgr->AddOperation(ANEURALNETWORKS_MUL, inputList.size(),
874+
inputList.data(), 1, &mulNode0.opIndex));
875+
// Record the scaled result in mGraphOperandInfo.
mGraphOperandInfo[mulNode0.opIndex] = mulNode0;
876+
877+
outputOpIndex = mulNode0.opIndex;
878+
}
879+
880+
// With a third input the final result is outputOpIndex + beta*C; without it, outputOpIndex
// already is the result.
if (inputs.size() > 2) { // Check for C
881+
auto inputCOpIndex = mGraphNodeMap[gemm->Inputs()[2].Get()];
882+
auto inputCNodeInfo = mGraphOperandInfo[inputCOpIndex];
883+
884+
// The final operand uses A's type and the op's primary output shape.
NodeInfo outputNode;
885+
DAWN_TRY(CreateNode(outputNode, inputANodeInfo.type, gemm->PrimaryOutput()->Shape()));
886+
887+
// beta exactly equal to 1: add C as-is, no scaling MUL. Note that outputOpIndex refers to
// matMulNode rather than mulNode0 when alpha was 1.
if (options->beta == 1) {
888+
// output = mulNode0 + NodeC
889+
std::vector<uint32_t> inputList4 = {outputOpIndex, inputCNodeInfo.opIndex,
890+
fuseCodeOpIndex};
891+
DAWN_TRY(mNnapiMgr->AddOperation(ANEURALNETWORKS_ADD, inputList4.size(),
892+
inputList4.data(), 1, &outputNode.opIndex));
893+
} else {
894+
// mulNode1 = beta*C
895+
float beta = options->beta;
896+
std::vector<float> betaVec(1, beta);
897+
NodeInfo betaNode;
898+
betaNode.type = ml::OperandType::Float32;
899+
betaNode.dimensions = {1};
900+
// NOTE(review): betaVec is a local, same lifetime question as alphaVec — confirm that
// CreateOperandAndSetMemory copies the data.
DAWN_TRY(mNnapiMgr->CreateOperandAndSetMemory("beta", &betaNode, &betaVec[0]));
901+
902+
NodeInfo mulNode1;
903+
// The scaled C keeps C's own type and shape.
// NOTE(review): unlike mulNode0, mulNode1 is never stored in mGraphOperandInfo — confirm
// whether that is intentional.
DAWN_TRY(CreateNode(mulNode1, inputCNodeInfo.type, gemm->Inputs()[2]->Shape()));
904+
std::vector<uint32_t> inputList2 = {betaNode.opIndex, inputCNodeInfo.opIndex,
905+
fuseCodeOpIndex};
906+
DAWN_TRY(mNnapiMgr->AddOperation(ANEURALNETWORKS_MUL, inputList2.size(),
907+
inputList2.data(), 1, &mulNode1.opIndex));
908+
909+
// output = mulNode0 + mulNode1
910+
std::vector<uint32_t> inputList3 = {outputOpIndex, mulNode1.opIndex,
911+
fuseCodeOpIndex};
912+
DAWN_TRY(mNnapiMgr->AddOperation(ANEURALNETWORKS_ADD, inputList3.size(),
913+
inputList3.data(), 1, &outputNode.opIndex));
914+
}
915+
// Record the final operand and map the op's primary output to it.
mGraphOperandInfo[outputNode.opIndex] = outputNode;
916+
mGraphNodeMap[gemm->PrimaryOutput()] = outputNode.opIndex;
917+
} else {
918+
// No C: the (optionally alpha-scaled) matmul result is the op's output.
mGraphNodeMap[gemm->PrimaryOutput()] = outputOpIndex;
919+
}
813920
return {};
814921
}
815922

0 commit comments

Comments
 (0)