Skip to content

Commit 97201e1

Browse files
committed
[tmva][sofie] Implement binary operator creation in C++
Re-implement the binary operator creation in the Keras parser code path in C++, so the Python side Keras parser implementation doesn't have to deal with SOFIE implementation details like the `EBasicBinaryOperator` or which binary operator types are supported. This also fixes a logic error where a binary operator with any name was falling back to multiplication.
1 parent 9d097fb commit 97201e1

3 files changed

Lines changed: 25 additions & 20 deletions

File tree

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,5 @@
11
def MakeKerasBinary(layer):
22
from ROOT.TMVA.Experimental import SOFIE
33

4-
input = layer["layerInput"]
5-
output = layer["layerOutput"]
6-
fLayerType = layer["layerType"]
7-
fLayerDType = layer["layerDType"]
8-
fX1 = input[0]
9-
fX2 = input[1]
10-
fY = output[0]
11-
op = None
12-
if SOFIE.ConvertStringToType(fLayerDType) == SOFIE.ETensorType.FLOAT:
13-
if fLayerType == "Add":
14-
op = SOFIE.ROperator_BasicBinary(float, SOFIE.EBasicBinaryOperator.Add)(fX1, fX2, fY)
15-
elif fLayerType == "Subtract":
16-
op = SOFIE.ROperator_BasicBinary(float, SOFIE.EBasicBinaryOperator.Sub)(fX1, fX2, fY)
17-
else:
18-
op = SOFIE.ROperator_BasicBinary(float, SOFIE.EBasicBinaryOperator.Mul)(fX1, fX2, fY)
19-
else:
20-
raise RuntimeError(
21-
"TMVA::SOFIE - Unsupported - Operator BasicBinary does not yet support input type " + fLayerDType
22-
)
23-
return op
4+
inpt = layer["layerInput"]
5+
return SOFIE.createBasicBinary(layer["layerDType"], layer["layerType"], inpt[0], inpt[1], layer["layerOutput"][0])

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_sofie/_parser/_keras/parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def move_operator(op):
9999
"""
100100
import ROOT
101101

102+
# If the object is already held by a smart pointer, just move it.
103+
smartptr = op.__smartptr__()
104+
if smartptr:
105+
return type(smartptr)(ROOT.std.move(smartptr))
106+
102107
ROOT.SetOwnership(op, False)
103108
return ROOT.std.unique_ptr[type(op)](op)
104109

tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,24 @@ public:
455455
}
456456
};
457457

458+
inline std::unique_ptr<ROperator> createBasicBinary(std::string layerDType, std::string layerType, std::string nameA,
459+
std::string nameB, std::string nameY)
460+
{
461+
if (ConvertStringToType(layerDType) != ETensorType::FLOAT) {
462+
throw std::runtime_error(
463+
("TMVA::SOFIE - Unsupported - Operator BasicBinary does not yet support input type " + layerDType).c_str());
464+
}
465+
if (layerType == "Add")
466+
return std::make_unique<ROperator_BasicBinary<float, EBasicBinaryOperator::Add>>(nameA, nameB, nameY);
467+
if (layerType == "Subtract")
468+
return std::make_unique<ROperator_BasicBinary<float, EBasicBinaryOperator::Sub>>(nameA, nameB, nameY);
469+
if (layerType == "Multiply")
470+
return std::make_unique<ROperator_BasicBinary<float, EBasicBinaryOperator::Mul>>(nameA, nameB, nameY);
471+
472+
throw std::runtime_error(
473+
("TMVA::SOFIE - Unsupported - Operator BasicBinary does not yet support layer type " + layerType).c_str());
474+
}
475+
458476
} // namespace SOFIE
459477
} // namespace Experimental
460478
} // namespace TMVA

0 commit comments

Comments
 (0)