Skip to content

Commit f83550f

Browse files
committed
add support for AveragePool, GlobalAveragePool, and GlobalMaxPool operators for Generic target
1 parent cf7d2ca commit f83550f

27 files changed

Lines changed: 474 additions & 21 deletions

Deeploy/Targets/Generic/Bindings.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration
1414
from Deeploy.Targets.Generic.Templates import AddTemplate, BatchNormalizationTemplate, ConcatTemplate, ConvTemplate, \
1515
ConvTransposeTemplate, DebugPrintTemplate, DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, \
16-
FloatCeilTemplate, FloatClipTemplate, FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, FloatExpTemplate, \
17-
FloatFloorTemplate, FloatGELUTemplate, FloatGemmTemplate, FloatGroupNormTemplate, FloatHardSigmoidTemplate, \
16+
FloatAveragePoolTemplate, FloatCeilTemplate, FloatClipTemplate, FloatConvTemplate, FloatDivTemplate, \
17+
FloatDWConvTemplate, FloatExpTemplate, FloatFloorTemplate, FloatGELUTemplate, FloatGemmTemplate, \
18+
FloatGlobalAveragePoolTemplate, FloatGlobalMaxPoolTemplate, FloatGroupNormTemplate, FloatHardSigmoidTemplate, \
1819
FloatHardSwishTemplate, FloatInstanceNormTemplate, FloatLayernormTemplate, FloatMatMulTemplate, \
1920
FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, FloatPowTemplate, FloatReduceMeanTemplate, \
2021
FloatReluTemplate, FloatSigmoidTemplate, FloatSoftmaxTemplate, FloatSqrtTemplate, FloatSubTemplate, \
@@ -399,3 +400,23 @@
399400
PointerClass(float32_t)], [PointerClass(float32_t)]), FloatGroupNormTemplate.referenceTemplate,
400401
BasicTransformer),
401402
]
403+
404+
BasicAveragePool1DBindings = [
405+
NodeBinding(DummyChecker([PointerClass(float32_t)], [PointerClass(float32_t)]),
406+
FloatAveragePoolTemplate.referenceTemplate1d, BasicTransformer)
407+
]
408+
409+
BasicAveragePool2DBindings = [
410+
NodeBinding(DummyChecker([PointerClass(float32_t)], [PointerClass(float32_t)]),
411+
FloatAveragePoolTemplate.referenceTemplate2d, BasicTransformer)
412+
]
413+
414+
BasicGlobalAveragePoolBindings = [
415+
NodeBinding(DummyChecker([PointerClass(float32_t)], [PointerClass(float32_t)]),
416+
FloatGlobalAveragePoolTemplate.referenceTemplate, BasicTransformer)
417+
]
418+
419+
BasicGlobalMaxPoolBindings = [
420+
NodeBinding(DummyChecker([PointerClass(float32_t)], [PointerClass(float32_t)]),
421+
FloatGlobalMaxPoolTemplate.referenceTemplate, BasicTransformer)
422+
]

Deeploy/Targets/Generic/Layers.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,3 +752,15 @@ class InstanceNormLayer(ONNXLayer):
752752

753753
class GroupNormLayer(ONNXLayer):
754754
pass
755+
756+
757+
class AveragePoolLayer(ONNXLayer):
758+
pass
759+
760+
761+
class GlobalAveragePoolLayer(ONNXLayer):
762+
pass
763+
764+
765+
class GlobalMaxPoolLayer(ONNXLayer):
766+
pass

Deeploy/Targets/Generic/Parsers.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3000,3 +3000,120 @@ def parseNode(self, node: gs.Node) -> bool:
30003000
return False
30013001
self.operatorRepresentation['num_groups'] = node.attrs['num_groups']
30023002
return True
3003+
3004+
3005+
class AveragePoolParser(NodeParser):
3006+
3007+
def parseNode(self, node: gs.Node) -> bool:
3008+
3009+
if not all([
3010+
node.op == 'AveragePool',
3011+
len(node.inputs) == 1,
3012+
len(node.outputs) == 1,
3013+
'kernel_shape' in node.attrs,
3014+
]):
3015+
return False
3016+
3017+
kernel_shape = node.attrs['kernel_shape']
3018+
spatial_ndim = len(kernel_shape)
3019+
3020+
auto_pad = node.attrs.get('auto_pad', 'NOTSET')
3021+
ceil_mode = node.attrs.get('ceil_mode', 0)
3022+
count_include_pad = node.attrs.get('count_include_pad ', 0)
3023+
dilations = node.attrs.get('dilations', (1,) * spatial_ndim)
3024+
strides = node.attrs.get('strides', (1,) * spatial_ndim)
3025+
pads = node.attrs.get('pads', (0,) * (2 * spatial_ndim))
3026+
3027+
if not all([
3028+
auto_pad == 'NOTSET', # TODO: implement other values
3029+
ceil_mode == 0, # TODO: implement other values
3030+
count_include_pad == 0, # TODO: implement other values
3031+
all([d == 1 for d in dilations]), # TODO: implement other values
3032+
len(dilations) == spatial_ndim,
3033+
len(strides) == spatial_ndim,
3034+
len(pads) == 2 * spatial_ndim,
3035+
all([s > 0 for s in strides]),
3036+
]):
3037+
return False
3038+
3039+
self.operatorRepresentation['kernel_shape'] = kernel_shape
3040+
self.operatorRepresentation['auto_pad'] = auto_pad
3041+
self.operatorRepresentation['ceil_mode'] = ceil_mode
3042+
self.operatorRepresentation['count_include_pad'] = count_include_pad
3043+
self.operatorRepresentation['dilations'] = dilations
3044+
self.operatorRepresentation['strides'] = strides
3045+
self.operatorRepresentation['pads'] = pads
3046+
3047+
return True
3048+
3049+
def parseNodeCtxt(self,
3050+
ctxt: NetworkContext,
3051+
node: gs.Node,
3052+
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
3053+
3054+
data_in = ctxt.lookup(node.inputs[0].name)
3055+
data_out = ctxt.lookup(node.outputs[0].name)
3056+
self.operatorRepresentation['data_in'] = data_in.name
3057+
self.operatorRepresentation['data_out'] = data_out.name
3058+
3059+
self.operatorRepresentation['batch_size'] = data_in.shape[0]
3060+
self.operatorRepresentation['num_channels'] = data_in.shape[1]
3061+
3062+
spatial_shape = data_in.shape[2:]
3063+
if len(self.operatorRepresentation['kernel_shape']) != len(spatial_shape):
3064+
return ctxt, False
3065+
3066+
if len(spatial_shape) == 1:
3067+
self.operatorRepresentation['length'] = spatial_shape[0]
3068+
elif len(spatial_shape) == 2:
3069+
self.operatorRepresentation['height'] = spatial_shape[0]
3070+
self.operatorRepresentation['width'] = spatial_shape[1]
3071+
else:
3072+
return ctxt, False
3073+
3074+
return ctxt, True
3075+
3076+
3077+
class AveragePool1DParser(AveragePoolParser):
3078+
3079+
def parseNode(self, node: gs.Node) -> bool:
3080+
return super().parseNode(node) and len(node.attrs['kernel_shape']) == 1
3081+
3082+
3083+
class AveragePool2DParser(AveragePoolParser):
3084+
3085+
def parseNode(self, node: gs.Node) -> bool:
3086+
return super().parseNode(node) and len(node.attrs['kernel_shape']) == 2
3087+
3088+
3089+
class GlobalPoolParser(NodeParser):
3090+
3091+
def parseNode(self, node: gs.Node) -> bool:
3092+
return len(node.inputs) == 1 and len(node.outputs) == 1
3093+
3094+
def parseNodeCtxt(self,
3095+
ctxt: NetworkContext,
3096+
node: gs.Node,
3097+
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
3098+
3099+
data_in = ctxt.lookup(node.inputs[0].name)
3100+
data_out = ctxt.lookup(node.outputs[0].name)
3101+
self.operatorRepresentation['data_in'] = data_in.name
3102+
self.operatorRepresentation['data_out'] = data_out.name
3103+
self.operatorRepresentation['batch_size'] = data_in.shape[0]
3104+
self.operatorRepresentation['num_channels'] = data_in.shape[1]
3105+
self.operatorRepresentation['spatial_size'] = np.prod(data_in.shape[2:])
3106+
3107+
return ctxt, True
3108+
3109+
3110+
class GlobalAveragePoolParser(GlobalPoolParser):
3111+
3112+
def parseNode(self, node: gs.Node) -> bool:
3113+
return super().parseNode(node) and node.op == 'GlobalAveragePool'
3114+
3115+
3116+
class GlobalMaxPoolParser(GlobalPoolParser):
3117+
3118+
def parseNode(self, node: gs.Node) -> bool:
3119+
return super().parseNode(node) and node.op == 'GlobalMaxPool'

Deeploy/Targets/Generic/Platform.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,32 +6,33 @@
66
RemoveEmptyConvBiasPass, RemoveOnlySingletonReduceMeanPass
77
from Deeploy.DeeployTypes import ConstantBuffer, DeploymentEngine, DeploymentPlatform, NodeMapper, NodeTemplate, \
88
StructBuffer, TopologyOptimizer, TransientBuffer, VariableBuffer
9-
from Deeploy.Targets.Generic.Bindings import BasicAddBindings, BasicBatchNormBindings, BasicCeilBindings, \
10-
BasicClipBindings, BasicConcatBindings, BasicConv1DBindings, BasicConv2DBindings, BasicConvTransposeBindings, \
11-
BasicDebugPrintBindings, BasicDequantBindings, BasicDivBindings, BasicDWConv1DBinding, BasicDWConv2DBindings, \
12-
BasicExpBindings, BasicFloorBindings, BasicGatherBindings, BasicGELUBindings, BasicGEMMBindings, \
9+
from Deeploy.Targets.Generic.Bindings import BasicAddBindings, BasicAveragePool1DBindings, BasicAveragePool2DBindings, \
10+
BasicBatchNormBindings, BasicCeilBindings, BasicClipBindings, BasicConcatBindings, BasicConv1DBindings, \
11+
BasicConv2DBindings, BasicConvTransposeBindings, BasicDebugPrintBindings, BasicDequantBindings, BasicDivBindings, \
12+
BasicDWConv1DBinding, BasicDWConv2DBindings, BasicExpBindings, BasicFloorBindings, BasicGatherBindings, \
13+
BasicGELUBindings, BasicGEMMBindings, BasicGlobalAveragePoolBindings, BasicGlobalMaxPoolBindings, \
1314
BasicGroupNormBindings, BasicHardSigmoidBindings, BasicHardSwishBindings, BasicInstanceNormBindings, \
1415
BasicITAPartialSoftmaxBinding, BasicITASoftmaxBinding, BasicLayerNormBindings, BasicMatMulBindings, \
1516
BasicMaxPool1DBindings, BasicMaxPool2DBindings, BasicMulBindings, BasicPad1DBindings, BasicPad2DBindings, \
1617
BasicPowBindings, BasicQuantBindings, BasicReduceMeanBindings, BasicReduceSumBindings, BasicReluBinding, \
1718
BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, BasicRQSGELUBinding, BasicSigmoidBindings, \
1819
BasicSliceBindings, BasicSoftmaxBindings, BasicSqrtBindings, BasicSubBindings, BasicSwishBindings, \
1920
BasicTransposeBindings, DummyBinding
20-
from Deeploy.Targets.Generic.Layers import AddLayer, BatchNormalizationLayer, CeilLayer, ClipLayer, ConcatLayer, \
21-
ConvLayer, ConvTransposeLayer, DebugPrintLayer, DequantLayer, DivLayer, ExpLayer, FloorLayer, GatherLayer, \
22-
GELULayer, GEMMLayer, GroupNormLayer, InstanceNormLayer, ITAMaxLayer, LayerNormLayer, MatMulLayer, MaxPoolLayer, \
23-
MulLayer, PadLayer, PowLayer, QuantLayer, ReduceMeanLayer, ReduceSumLayer, ReluLayer, RequantShiftLayer, \
24-
ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SigmoidLayer, SliceLayer, SoftmaxLayer, SqrtLayer, SubLayer, \
25-
SwishLayer, TransposeLayer
26-
from Deeploy.Targets.Generic.Parsers import AddParser, BatchNormParser, CeilParser, ClipParser, ConcatParser, \
27-
ConvTranspose1DParser, DebugParser, DequantParser, DivParser, DummyParser, ExpParser, FlattenParser, FloorParser, \
28-
GatherParser, GELUParser, GenericConv1DParser, GenericConv2DParser, GenericDWConv1DParser, GenericDWConv2DParser, \
29-
GenericGEMMParser, GenericMaxPool2DParser, GroupNormParser, HardSigmoidParser, HardSwishParser, \
30-
InstanceNormParser, IntegerDivParser, ITAMaxParser, ITAPartialMaxParser, LayerNormParser, MatMulParser, \
31-
MaxPool1DParser, MulParser, Pad1DParser, Pad2DParser, PowParser, QuantParser, ReduceMeanParser, ReduceSumParser, \
32-
ReluParser, RequantShiftParser, ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SigmoidParser, SliceParser, \
33-
SoftmaxParser, SqrtParser, SubParser, SwishParser, TransposeParser, UnsqueezeParser, iLayerNormParser, \
34-
iSoftmaxParser
21+
from Deeploy.Targets.Generic.Layers import AddLayer, AveragePoolLayer, BatchNormalizationLayer, CeilLayer, ClipLayer, \
22+
ConcatLayer, ConvLayer, ConvTransposeLayer, DebugPrintLayer, DequantLayer, DivLayer, ExpLayer, FloorLayer, \
23+
GatherLayer, GELULayer, GEMMLayer, GlobalAveragePoolLayer, GlobalMaxPoolLayer, GroupNormLayer, InstanceNormLayer, \
24+
ITAMaxLayer, LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, PowLayer, QuantLayer, ReduceMeanLayer, \
25+
ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SigmoidLayer, \
26+
SliceLayer, SoftmaxLayer, SqrtLayer, SubLayer, SwishLayer, TransposeLayer
27+
from Deeploy.Targets.Generic.Parsers import AddParser, AveragePool1DParser, AveragePool2DParser, BatchNormParser, \
28+
CeilParser, ClipParser, ConcatParser, ConvTranspose1DParser, DebugParser, DequantParser, DivParser, DummyParser, \
29+
ExpParser, FlattenParser, FloorParser, GatherParser, GELUParser, GenericConv1DParser, GenericConv2DParser, \
30+
GenericDWConv1DParser, GenericDWConv2DParser, GenericGEMMParser, GenericMaxPool2DParser, GlobalAveragePoolParser, \
31+
GlobalMaxPoolParser, GroupNormParser, HardSigmoidParser, HardSwishParser, InstanceNormParser, IntegerDivParser, \
32+
ITAMaxParser, ITAPartialMaxParser, LayerNormParser, MatMulParser, MaxPool1DParser, MulParser, Pad1DParser, \
33+
Pad2DParser, PowParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, RequantShiftParser, \
34+
ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SigmoidParser, SliceParser, SoftmaxParser, SqrtParser, \
35+
SubParser, SwishParser, TransposeParser, UnsqueezeParser, iLayerNormParser, iSoftmaxParser
3536
from Deeploy.Targets.Generic.Templates import AllocateTemplate, FreeTemplate
3637
from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import DequantPatternPass, ExtractPaddingFromConvPass, \
3738
ExtractPaddingFromPoolPass, MatMulAddMergePass, MergeConstAddAndRequantPass, QuantPatternPass, \
@@ -89,6 +90,10 @@
8990
HardSwishMapper = NodeMapper(HardSwishParser(), BasicHardSwishBindings)
9091
InstanceNormMapper = NodeMapper(InstanceNormParser(), BasicInstanceNormBindings)
9192
GroupNormMapper = NodeMapper(GroupNormParser(), BasicGroupNormBindings)
93+
AveragePool1DMapper = NodeMapper(AveragePool1DParser(), BasicAveragePool1DBindings)
94+
AveragePool2DMapper = NodeMapper(AveragePool2DParser(), BasicAveragePool2DBindings)
95+
GlobalAveragePoolMapper = NodeMapper(GlobalAveragePoolParser(), BasicGlobalAveragePoolBindings)
96+
GlobalMaxPoolMapper = NodeMapper(GlobalMaxPoolParser(), BasicGlobalMaxPoolBindings)
9297

9398
# Dummy nodes are intended for development purposes only!
9499
# They should always generate compiler errors to not accidentally end up in production code
@@ -146,6 +151,9 @@
146151
'HardSwish': SwishLayer([HardSwishMapper]),
147152
'InstanceNormalization': InstanceNormLayer([InstanceNormMapper]),
148153
'GroupNormalization': GroupNormLayer([GroupNormMapper]),
154+
'AveragePool': AveragePoolLayer([AveragePool1DMapper, AveragePool2DMapper]),
155+
'GlobalAveragePool': GlobalAveragePoolLayer([GlobalAveragePoolMapper]),
156+
'GlobalMaxPool': GlobalMaxPoolLayer([GlobalMaxPoolMapper]),
149157
# # For example, you can use the DummpyMapper, in case you want to test
150158
# # deployment or optimizations with GlobalAveragePool nodes but did not yet
151159
# # implement the corresponding kernel
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# SPDX-FileCopyrightText: 2023 ETH Zurich and University of Bologna
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation
6+
7+
8+
class _AveragePoolTemplate(NodeTemplate):
9+
10+
def alignToContext(self, ctxt: NetworkContext,
11+
operatorRepresentation: OperatorRepresentation) -> tuple[NetworkContext, dict, list[str]]:
12+
13+
data_in = ctxt.lookup(operatorRepresentation['data_in'])
14+
operatorRepresentation['type_width'] = data_in._type.referencedType.typeWidth
15+
return ctxt, operatorRepresentation, []
16+
17+
18+
referenceTemplate1d = _AveragePoolTemplate("""
19+
// Average Pool 1D (Name: ${nodeName}, Op: ${nodeOp})
20+
AveragePool1d_fp${type_width}_fp${type_width}(
21+
${data_in}, ${data_out}, ${batch_size}, ${num_channels}, ${length}, ${kernel_shape[0]},
22+
${strides[0]}, ${pads[0]}, ${pads[1]});
23+
""")
24+
25+
referenceTemplate2d = _AveragePoolTemplate("""
26+
// Average Pool 2D (Name: ${nodeName}, Op: ${nodeOp})
27+
AveragePool2d_fp${type_width}_fp${type_width}(
28+
${data_in}, ${data_out}, ${batch_size}, ${num_channels}, ${height}, ${width},
29+
${kernel_shape[0]}, ${kernel_shape[1]}, ${strides[0]}, ${strides[1]},
30+
${pads[0]}, ${pads[1]}, ${pads[2]}, ${pads[3]});
31+
""")
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# SPDX-FileCopyrightText: 2023 ETH Zurich and University of Bologna
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation
6+
7+
8+
class _GlobalAveragePoolTemplate(NodeTemplate):
9+
10+
def alignToContext(self, ctxt: NetworkContext,
11+
operatorRepresentation: OperatorRepresentation) -> tuple[NetworkContext, dict, list[str]]:
12+
13+
data_in = ctxt.lookup(operatorRepresentation['data_in'])
14+
operatorRepresentation['type_width'] = data_in._type.referencedType.typeWidth
15+
return ctxt, operatorRepresentation, []
16+
17+
18+
referenceTemplate = _GlobalAveragePoolTemplate("""
19+
// Global Average Pool 1D (Name: ${nodeName}, Op: ${nodeOp})
20+
GlobalAveragePool_fp${type_width}_fp${type_width}(
21+
${data_in}, ${data_out}, ${batch_size}, ${num_channels}, ${spatial_size});
22+
""")
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# SPDX-FileCopyrightText: 2023 ETH Zurich and University of Bologna
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation
6+
7+
8+
class _GlobalMaxPoolTemplate(NodeTemplate):
9+
10+
def alignToContext(self, ctxt: NetworkContext,
11+
operatorRepresentation: OperatorRepresentation) -> tuple[NetworkContext, dict, list[str]]:
12+
13+
data_in = ctxt.lookup(operatorRepresentation['data_in'])
14+
operatorRepresentation['type_width'] = data_in._type.referencedType.typeWidth
15+
return ctxt, operatorRepresentation, []
16+
17+
18+
referenceTemplate = _GlobalMaxPoolTemplate("""
19+
// Global Max Pool 1D (Name: ${nodeName}, Op: ${nodeOp})
20+
GlobalMaxPool_fp${type_width}_fp${type_width}(
21+
${data_in}, ${data_out}, ${batch_size}, ${num_channels}, ${spatial_size});
22+
""")
776 Bytes
Binary file not shown.
181 Bytes
Binary file not shown.
746 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)