Skip to content

Commit 6bfface

Browse files
committed
add support for Exp operator for Generic target
1 parent 52c1e73 commit 6bfface

11 files changed

Lines changed: 121 additions & 106 deletions

File tree

Deeploy/Targets/Generic/Bindings.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration
1414
from Deeploy.Targets.Generic.Templates import AddTemplate, BatchNormalizationTemplate, ConcatTemplate, ConvTemplate, \
1515
ConvTransposeTemplate, DebugPrintTemplate, DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, \
16-
FloatCeilTemplate, FloatClipTemplate, FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, \
16+
FloatCeilTemplate, FloatClipTemplate, FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, FloatExpTemplate, \
1717
FloatFloorTemplate, FloatGELUTemplate, FloatGemmTemplate, FloatLayernormTemplate, FloatMatMulTemplate, \
1818
FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, FloatPowTemplate, FloatReduceMeanTemplate, \
1919
FloatReluTemplate, FloatSoftmaxTemplate, FloatSqrtTemplate, FloatSubTemplate, GatherTemplate, GemmTemplate, \
@@ -357,3 +357,8 @@
357357
PointerClass(float32_t)], [PointerClass(float32_t)]), FloatClipTemplate.referenceTemplate,
358358
BasicTransformer),
359359
]
360+
361+
BasicExpBindings = [
362+
NodeBinding(DummyChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), FloatExpTemplate.referenceTemplate,
363+
BasicTransformer),
364+
]

Deeploy/Targets/Generic/Layers.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -715,18 +715,16 @@ def computeOps(self):
715715

716716

717717
class CeilLayer(ONNXLayer):
718-
719-
def __init__(self, maps: List[NodeMapper]):
720-
super().__init__(maps)
718+
pass
721719

722720

723721
class FloorLayer(ONNXLayer):
724-
725-
def __init__(self, maps: List[NodeMapper]):
726-
super().__init__(maps)
722+
pass
727723

728724

729725
class ClipLayer(ONNXLayer):
726+
pass
730727

731-
def __init__(self, maps: List[NodeMapper]):
732-
super().__init__(maps)
728+
729+
class ExpLayer(ONNXLayer):
730+
pass

Deeploy/Targets/Generic/Parsers.py

Lines changed: 37 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,23 @@
1111
from Deeploy.DeeployTypes import ConstantBuffer, NetworkContext, NodeParser, VariableBuffer
1212

1313

14+
class UnaryElementWiseParser(NodeParser):
15+
16+
def parseNode(self, node: gs.Node) -> bool:
17+
return len(node.inputs) == 1 and len(node.outputs) == 1
18+
19+
def parseNodeCtxt(self,
20+
ctxt: NetworkContext,
21+
node: gs.Node,
22+
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
23+
data_in = ctxt.lookup(node.inputs[0].name)
24+
data_out = ctxt.lookup(node.outputs[0].name)
25+
self.operatorRepresentation['data_in'] = data_in.name
26+
self.operatorRepresentation['data_out'] = data_out.name
27+
self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
28+
return ctxt, True
29+
30+
1431
class ConcatParser(NodeParser):
1532

1633
def __init__(self):
@@ -1095,29 +1112,10 @@ def parseNodeCtxt(self,
10951112
return ctxt, True
10961113

10971114

1098-
class ReluParser(NodeParser):
1099-
1100-
def __init__(self):
1101-
super().__init__()
1102-
1103-
def parseNode(self, node: gs.Node) -> (bool):
1104-
1105-
ret = all([len(node.inputs) == 1, len(node.outputs) == 1])
1106-
1107-
return ret
1108-
1109-
def parseNodeCtxt(self,
1110-
ctxt: NetworkContext,
1111-
node: gs.Node,
1112-
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
1113-
1114-
data_in = ctxt.lookup(node.inputs[0].name)
1115-
data_out = ctxt.lookup(node.outputs[0].name)
1116-
self.operatorRepresentation['data_in'] = data_in.name
1117-
self.operatorRepresentation['data_out'] = data_out.name
1118-
self.operatorRepresentation['size'] = np.prod(data_in.shape)
1115+
class ReluParser(UnaryElementWiseParser):
11191116

1120-
return ctxt, True
1117+
def parseNode(self, node: gs.Node) -> bool:
1118+
return super().parseNode(node) and node.op == 'Relu'
11211119

11221120

11231121
class ReshapeParser(NodeParser):
@@ -2868,79 +2866,28 @@ def parseNodeCtxt(self,
28682866
return ctxt, False
28692867

28702868

2871-
class SqrtParser(NodeParser):
2872-
2873-
def __init__(self):
2874-
super().__init__()
2869+
class SqrtParser(UnaryElementWiseParser):
28752870

28762871
def parseNode(self, node: gs.Node) -> bool:
2877-
return node.op == 'Sqrt' and len(node.inputs) == 1 and len(node.outputs) == 1
2878-
2879-
def parseNodeCtxt(self,
2880-
ctxt: NetworkContext,
2881-
node: gs.Node,
2882-
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
2872+
return super().parseNode(node) and node.op == 'Sqrt'
28832873

2884-
data_in = ctxt.lookup(node.inputs[0].name)
2885-
data_out = ctxt.lookup(node.outputs[0].name)
28862874

2887-
self.operatorRepresentation['data_in'] = data_in.name
2888-
self.operatorRepresentation['data_out'] = data_out.name
2889-
self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
2890-
2891-
return ctxt, True
2892-
2893-
2894-
class CeilParser(NodeParser):
2895-
2896-
def __init__(self):
2897-
super().__init__()
2875+
class CeilParser(UnaryElementWiseParser):
28982876

28992877
def parseNode(self, node: gs.Node) -> bool:
2900-
return node.op == 'Ceil' and len(node.inputs) == 1 and len(node.outputs) == 1
2878+
return super().parseNode(node) and node.op == 'Ceil'
29012879

2902-
def parseNodeCtxt(self,
2903-
ctxt: NetworkContext,
2904-
node: gs.Node,
2905-
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
29062880

2907-
data_in = ctxt.lookup(node.inputs[0].name)
2908-
data_out = ctxt.lookup(node.outputs[0].name)
2909-
2910-
self.operatorRepresentation['data_in'] = data_in.name
2911-
self.operatorRepresentation['data_out'] = data_out.name
2912-
self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
2913-
return ctxt, True
2914-
2915-
2916-
class FloorParser(NodeParser):
2917-
2918-
def __init__(self):
2919-
super().__init__()
2881+
class FloorParser(UnaryElementWiseParser):
29202882

29212883
def parseNode(self, node: gs.Node) -> bool:
2922-
return node.op == 'Floor' and len(node.inputs) == 1 and len(node.outputs) == 1
2884+
return super().parseNode(node) and node.op == 'Floor'
29232885

2924-
def parseNodeCtxt(self,
2925-
ctxt: NetworkContext,
2926-
node: gs.Node,
2927-
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
2928-
2929-
data_in = ctxt.lookup(node.inputs[0].name)
2930-
data_out = ctxt.lookup(node.outputs[0].name)
2931-
2932-
self.operatorRepresentation['data_in'] = data_in.name
2933-
self.operatorRepresentation['data_out'] = data_out.name
2934-
self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
2935-
return ctxt, True
2936-
2937-
2938-
class ClipParser(NodeParser):
29392886

2940-
def __init__(self):
2941-
super().__init__()
2887+
class ClipParser(UnaryElementWiseParser):
29422888

29432889
def parseNode(self, node: gs.Node) -> bool:
2890+
# Clip allows 1–3 inputs (optional min/max constants), so we can't use super()
29442891
if node.op != 'Clip' \
29452892
or len(node.outputs) != 1 \
29462893
or (not (1 <= len(node.inputs) <= 3)):
@@ -2952,11 +2899,9 @@ def parseNodeCtxt(self,
29522899
node: gs.Node,
29532900
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
29542901

2955-
data_in = ctxt.lookup(node.inputs[0].name)
2956-
data_out = ctxt.lookup(node.outputs[0].name)
2957-
self.operatorRepresentation['data_in'] = data_in.name
2958-
self.operatorRepresentation['data_out'] = data_out.name
2959-
self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
2902+
ctxt, ok = super().parseNodeCtxt(ctxt, node, channels_first)
2903+
if not ok:
2904+
return ctxt, False
29602905

29612906
# min_val and max_val only handled as constants
29622907
# Defaults: full float32 range
@@ -2969,3 +2914,9 @@ def parseNodeCtxt(self,
29692914
self.operatorRepresentation['max_val'] = float(node.inputs[2].values.item())
29702915

29712916
return ctxt, True
2917+
2918+
2919+
class ExpParser(UnaryElementWiseParser):
2920+
2921+
def parseNode(self, node: gs.Node) -> bool:
2922+
return super().parseNode(node) and node.op == 'Exp'

Deeploy/Targets/Generic/Platform.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,19 @@
99
from Deeploy.Targets.Generic.Bindings import BasicAddBindings, BasicBatchNormBindings, BasicCeilBindings, \
1010
BasicClipBindings, BasicConcatBindings, BasicConv1DBindings, BasicConv2DBindings, BasicConvTransposeBindings, \
1111
BasicDebugPrintBindings, BasicDequantBindings, BasicDivBindings, BasicDWConv1DBinding, BasicDWConv2DBindings, \
12-
BasicFloorBindings, BasicGatherBindings, BasicGELUBindings, BasicGEMMBindings, BasicITAPartialSoftmaxBinding, \
13-
BasicITASoftmaxBinding, BasicLayerNormBindings, BasicMatMulBindings, BasicMaxPool1DBindings, \
14-
BasicMaxPool2DBindings, BasicMulBindings, BasicPad1DBindings, BasicPad2DBindings, BasicPowBindings, \
15-
BasicQuantBindings, BasicReduceMeanBindings, BasicReduceSumBindings, BasicReluBinding, BasicReshapeBindings, \
16-
BasicRQIntegerDivBinding, BasicRQSBindings, BasicRQSGELUBinding, BasicSliceBindings, BasicSoftmaxBindings, \
17-
BasicSqrtBindings, BasicSubBindings, BasicTransposeBindings, DummyBinding
12+
BasicExpBindings, BasicFloorBindings, BasicGatherBindings, BasicGELUBindings, BasicGEMMBindings, \
13+
BasicITAPartialSoftmaxBinding, BasicITASoftmaxBinding, BasicLayerNormBindings, BasicMatMulBindings, \
14+
BasicMaxPool1DBindings, BasicMaxPool2DBindings, BasicMulBindings, BasicPad1DBindings, BasicPad2DBindings, \
15+
BasicPowBindings, BasicQuantBindings, BasicReduceMeanBindings, BasicReduceSumBindings, BasicReluBinding, \
16+
BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, BasicRQSGELUBinding, BasicSliceBindings, \
17+
BasicSoftmaxBindings, BasicSqrtBindings, BasicSubBindings, BasicTransposeBindings, DummyBinding
1818
from Deeploy.Targets.Generic.Layers import AddLayer, BatchNormalizationLayer, CeilLayer, ClipLayer, ConcatLayer, \
19-
ConvLayer, ConvTransposeLayer, DebugPrintLayer, DequantLayer, DivLayer, FloorLayer, GatherLayer, GELULayer, \
20-
GEMMLayer, ITAMaxLayer, LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, PowLayer, QuantLayer, \
21-
ReduceMeanLayer, ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, \
22-
SliceLayer, SoftmaxLayer, SqrtLayer, SubLayer, TransposeLayer
19+
ConvLayer, ConvTransposeLayer, DebugPrintLayer, DequantLayer, DivLayer, ExpLayer, FloorLayer, GatherLayer, \
20+
GELULayer, GEMMLayer, ITAMaxLayer, LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, PowLayer, \
21+
QuantLayer, ReduceMeanLayer, ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, \
22+
RQSiGELULayer, SliceLayer, SoftmaxLayer, SqrtLayer, SubLayer, TransposeLayer
2323
from Deeploy.Targets.Generic.Parsers import AddParser, BatchNormParser, CeilParser, ClipParser, ConcatParser, \
24-
ConvTranspose1DParser, DebugParser, DequantParser, DivParser, DummyParser, FlattenParser, FloorParser, \
24+
ConvTranspose1DParser, DebugParser, DequantParser, DivParser, DummyParser, ExpParser, FlattenParser, FloorParser, \
2525
GatherParser, GELUParser, GenericConv1DParser, GenericConv2DParser, GenericDWConv1DParser, GenericDWConv2DParser, \
2626
GenericGEMMParser, GenericMaxPool2DParser, IntegerDivParser, ITAMaxParser, ITAPartialMaxParser, LayerNormParser, \
2727
MatMulParser, MaxPool1DParser, MulParser, Pad1DParser, Pad2DParser, PowParser, QuantParser, ReduceMeanParser, \
@@ -77,6 +77,7 @@
7777
CeilMapper = NodeMapper(CeilParser(), BasicCeilBindings)
7878
FloorMapper = NodeMapper(FloorParser(), BasicFloorBindings)
7979
ClipMapper = NodeMapper(ClipParser(), BasicClipBindings)
80+
ExpMapper = NodeMapper(ExpParser(), BasicExpBindings)
8081

8182
# Dummy nodes are intended for development purposes only!
8283
# They should always generate compiler errors to not accidentally end up in production code
@@ -127,6 +128,7 @@
127128
'Ceil': CeilLayer([CeilMapper]),
128129
'Floor': FloorLayer([FloorMapper]),
129130
'Clip': ClipLayer([ClipMapper]),
131+
'Exp': ExpLayer([ExpMapper]),
130132
# # For example, you can use the DummpyMapper, in case you want to test
131133
# # deployment or optimizations with GlobalAveragePool nodes but did not yet
132134
# # implement the corresponding kernel
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# SPDX-FileCopyrightText: 2021 ETH Zurich and University of Bologna
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import numpy as np
5+
6+
from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation
7+
8+
9+
class _ExpTemplate(NodeTemplate):
10+
11+
def alignToContext(self, ctxt: NetworkContext,
12+
operatorRepresentation: OperatorRepresentation) -> tuple[NetworkContext, dict, list[str]]:
13+
14+
data_in = ctxt.lookup(operatorRepresentation['data_in'])
15+
operatorRepresentation['size'] = int(np.prod(data_in.shape))
16+
operatorRepresentation['type_width'] = data_in._type.referencedType.typeWidth
17+
return ctxt, operatorRepresentation, []
18+
19+
20+
referenceTemplate = _ExpTemplate("""
21+
// Exp (Name: ${nodeName}, Op: ${nodeOp})
22+
Exp_fp${type_width}_fp${type_width}(${data_in}, ${data_out}, ${size});
23+
""")
776 Bytes
Binary file not shown.
120 Bytes
Binary file not shown.
778 Bytes
Binary file not shown.

TargetLibraries/Generic/inc/DeeployBasicMath.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "kernel/Convolution.h"
4040
#include "kernel/DWConvolution.h"
4141
#include "kernel/Div.h"
42+
#include "kernel/Exp.h"
4243
#include "kernel/Floor.h"
4344
#include "kernel/GELU.h"
4445
#include "kernel/Gemm.h"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* SPDX-FileCopyrightText: 2020 ETH Zurich and University of Bologna
3+
*
4+
* SPDX-License-Identifier: Apache-2.0
5+
*/
6+
7+
#ifndef __DEEPLOY_BASIC_MATH_EXP_KERNEL_HEADER_
8+
#define __DEEPLOY_BASIC_MATH_EXP_KERNEL_HEADER_
9+
10+
#include "DeeployBasicMath.h"
11+
12+
/*
13+
* element wise exponential
14+
*/
15+
16+
/******************************************************************************/
17+
/* Exp */
18+
/******************************************************************************/
19+
void Exp_fp32_fp32(float32_t *data_in, float32_t *data_out, int32_t size);
20+
21+
#endif //__DEEPLOY_BASIC_MATH_EXP_KERNEL_HEADER_

0 commit comments

Comments
 (0)