Skip to content

Commit ecdbe88

Browse files
committed
add support for Floor operator for Generic target
1 parent 3b58bf9 commit ecdbe88

12 files changed

Lines changed: 115 additions & 19 deletions

File tree

Deeploy/Targets/Generic/Bindings.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration
1414
from Deeploy.Targets.Generic.Templates import AddTemplate, BatchNormalizationTemplate, ConcatTemplate, ConvTemplate, \
1515
ConvTransposeTemplate, DebugPrintTemplate, DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, \
16-
FloatCeilTemplate, FloatClipTemplate, FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, FloatGELUTemplate, \
17-
FloatGemmTemplate, FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, FloatMulTemplate, \
18-
FloatPadTemplate, FloatPowTemplate, FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, \
19-
FloatSqrtTemplate, GatherTemplate, GemmTemplate, IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, \
20-
MatMulTemplate, MaxPoolTemplate, MulTemplate, PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, \
21-
RequantShiftTemplate, ReshapeTemplate, RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, \
22-
iGELUTemplate, iLayernormTemplate, iRMSNormTemplate, iSoftmaxTemplate
16+
FloatCeilTemplate, FloatClipTemplate, FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, \
17+
FloatFloorTemplate, FloatGELUTemplate, FloatGemmTemplate, FloatLayernormTemplate, FloatMatMulTemplate, \
18+
FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, FloatPowTemplate, FloatReduceMeanTemplate, \
19+
FloatReluTemplate, FloatSoftmaxTemplate, FloatSqrtTemplate, GatherTemplate, GemmTemplate, IntegerDivTemplate, \
20+
ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, MaxPoolTemplate, MulTemplate, PadTemplate, QuantTemplate, \
21+
ReduceMeanTemplate, ReduceSumTemplate, RequantShiftTemplate, ReshapeTemplate, RQIntegerDivTemplate, \
22+
RQSiGELUTemplate, SliceTemplate, TransposeTemplate, iGELUTemplate, iLayernormTemplate, iRMSNormTemplate, \
23+
iSoftmaxTemplate
2324
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, BatchNormChecker, ConcatChecker, ConvChecker, \
2425
DebugPrintChecker, DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, \
2526
LayerNormChecker, MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, \
@@ -333,6 +334,11 @@
333334
BasicTransformer),
334335
]
335336

337+
BasicFloorBindings = [
338+
NodeBinding(DummyChecker([PointerClass(float32_t)], [PointerClass(float32_t)]),
339+
FloatFloorTemplate.referenceTemplate, BasicTransformer),
340+
]
341+
336342
BasicClipBindings = [
337343
NodeBinding(
338344
DummyChecker(

Deeploy/Targets/Generic/Layers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,12 @@ def __init__(self, maps: List[NodeMapper]):
717717
super().__init__(maps)
718718

719719

720+
class FloorLayer(ONNXLayer):
721+
722+
def __init__(self, maps: List[NodeMapper]):
723+
super().__init__(maps)
724+
725+
720726
class ClipLayer(ONNXLayer):
721727

722728
def __init__(self, maps: List[NodeMapper]):

Deeploy/Targets/Generic/Parsers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2910,6 +2910,28 @@ def parseNodeCtxt(self,
29102910
return ctxt, True
29112911

29122912

2913+
class FloorParser(NodeParser):
2914+
2915+
def __init__(self):
2916+
super().__init__()
2917+
2918+
def parseNode(self, node: gs.Node) -> bool:
2919+
return node.op == 'Floor' and len(node.inputs) == 1 and len(node.outputs) == 1
2920+
2921+
def parseNodeCtxt(self,
2922+
ctxt: NetworkContext,
2923+
node: gs.Node,
2924+
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
2925+
2926+
data_in = ctxt.lookup(node.inputs[0].name)
2927+
data_out = ctxt.lookup(node.outputs[0].name)
2928+
2929+
self.operatorRepresentation['data_in'] = data_in.name
2930+
self.operatorRepresentation['data_out'] = data_out.name
2931+
self.operatorRepresentation['size'] = int(np.prod(data_in.shape))
2932+
return ctxt, True
2933+
2934+
29132935
class ClipParser(NodeParser):
29142936

29152937
def __init__(self):

Deeploy/Targets/Generic/Platform.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@
99
from Deeploy.Targets.Generic.Bindings import BasicAddBindings, BasicBatchNormBindings, BasicCeilBindings, \
1010
BasicClipBindings, BasicConcatBindings, BasicConv1DBindings, BasicConv2DBindings, BasicConvTransposeBindings, \
1111
BasicDebugPrintBindings, BasicDequantBindings, BasicDivBindings, BasicDWConv1DBinding, BasicDWConv2DBindings, \
12-
BasicGatherBindings, BasicGELUBindings, BasicGEMMBindings, BasicITAPartialSoftmaxBinding, BasicITASoftmaxBinding, \
13-
BasicLayerNormBindings, BasicMatMulBindings, BasicMaxPool1DBindings, BasicMaxPool2DBindings, BasicMulBindings, \
14-
BasicPad1DBindings, BasicPad2DBindings, BasicPowBindings, BasicQuantBindings, BasicReduceMeanBindings, \
15-
BasicReduceSumBindings, BasicReluBinding, BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, \
16-
BasicRQSGELUBinding, BasicSliceBindings, BasicSoftmaxBindings, BasicSqrtBindings, BasicTransposeBindings, \
17-
DummyBinding
12+
BasicFloorBindings, BasicGatherBindings, BasicGELUBindings, BasicGEMMBindings, BasicITAPartialSoftmaxBinding, \
13+
BasicITASoftmaxBinding, BasicLayerNormBindings, BasicMatMulBindings, BasicMaxPool1DBindings, \
14+
BasicMaxPool2DBindings, BasicMulBindings, BasicPad1DBindings, BasicPad2DBindings, BasicPowBindings, \
15+
BasicQuantBindings, BasicReduceMeanBindings, BasicReduceSumBindings, BasicReluBinding, BasicReshapeBindings, \
16+
BasicRQIntegerDivBinding, BasicRQSBindings, BasicRQSGELUBinding, BasicSliceBindings, BasicSoftmaxBindings, \
17+
BasicSqrtBindings, BasicTransposeBindings, DummyBinding
1818
from Deeploy.Targets.Generic.Layers import AddLayer, BatchNormalizationLayer, CeilLayer, ClipLayer, ConcatLayer, \
19-
ConvLayer, ConvTransposeLayer, DebugPrintLayer, DequantLayer, DivLayer, GatherLayer, GELULayer, GEMMLayer, \
20-
ITAMaxLayer, LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, PowLayer, QuantLayer, ReduceMeanLayer, \
21-
ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SliceLayer, \
22-
SoftmaxLayer, SqrtLayer, TransposeLayer
19+
ConvLayer, ConvTransposeLayer, DebugPrintLayer, DequantLayer, DivLayer, FloorLayer, GatherLayer, GELULayer, \
20+
GEMMLayer, ITAMaxLayer, LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, PowLayer, QuantLayer, \
21+
ReduceMeanLayer, ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, \
22+
SliceLayer, SoftmaxLayer, SqrtLayer, TransposeLayer
2323
from Deeploy.Targets.Generic.Parsers import AddParser, BatchNormParser, CeilParser, ClipParser, ConcatParser, \
24-
ConvTranspose1DParser, DebugParser, DequantParser, DivParser, DummyParser, FlattenParser, GatherParser, \
25-
GELUParser, GenericConv1DParser, GenericConv2DParser, GenericDWConv1DParser, GenericDWConv2DParser, \
24+
ConvTranspose1DParser, DebugParser, DequantParser, DivParser, DummyParser, FlattenParser, FloorParser, \
25+
GatherParser, GELUParser, GenericConv1DParser, GenericConv2DParser, GenericDWConv1DParser, GenericDWConv2DParser, \
2626
GenericGEMMParser, GenericMaxPool2DParser, IntegerDivParser, ITAMaxParser, ITAPartialMaxParser, LayerNormParser, \
2727
MatMulParser, MaxPool1DParser, MulParser, Pad1DParser, Pad2DParser, PowParser, QuantParser, ReduceMeanParser, \
2828
ReduceSumParser, ReluParser, RequantShiftParser, ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SliceParser, \
@@ -74,6 +74,7 @@
7474
ConvTransposeMapper = NodeMapper(ConvTranspose1DParser(), BasicConvTransposeBindings)
7575
SliceMapper = NodeMapper(SliceParser(), BasicSliceBindings)
7676
CeilMapper = NodeMapper(CeilParser(), BasicCeilBindings)
77+
FloorMapper = NodeMapper(FloorParser(), BasicFloorBindings)
7778
ClipMapper = NodeMapper(ClipParser(), BasicClipBindings)
7879

7980
# Dummy nodes are intended for development purposes only!
@@ -122,6 +123,7 @@
122123
'BatchNormalization': BatchNormalizationLayer([BatchNormalizationMapper]),
123124
'ConvTranspose': ConvTransposeLayer([ConvTransposeMapper]),
124125
'Ceil': CeilLayer([CeilMapper]),
126+
'Floor': FloorLayer([FloorMapper]),
125127
'Clip': ClipLayer([ClipMapper]),
126128
# # For example, you can use the DummpyMapper, in case you want to test
127129
# # deployment or optimizations with GlobalAveragePool nodes but did not yet
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# SPDX-FileCopyrightText: 2021 ETH Zurich and University of Bologna
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import numpy as np
5+
6+
from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation
7+
8+
9+
class _FloorTemplate(NodeTemplate):
10+
11+
def alignToContext(self, ctxt: NetworkContext,
12+
operatorRepresentation: OperatorRepresentation) -> tuple[NetworkContext, dict, list[str]]:
13+
14+
data_in = ctxt.lookup(operatorRepresentation['data_in'])
15+
operatorRepresentation['size'] = int(np.prod(data_in.shape))
16+
operatorRepresentation['type_width'] = data_in._type.referencedType.typeWidth
17+
return ctxt, operatorRepresentation, []
18+
19+
20+
referenceTemplate = _FloorTemplate("""
21+
// Floor (Name: ${nodeName}, Op: ${nodeOp})
22+
Floor_fp${type_width}_fp${type_width}(${data_in}, ${data_out}, ${size});
23+
""")
776 Bytes
Binary file not shown.
122 Bytes
Binary file not shown.
778 Bytes
Binary file not shown.

DeeployTest/test_generic_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"Kernels/FP32/Conv/Regular_2D_NoBias",
1818
"Kernels/FP32/Conv/Regular_2D_ZeroValuedBias",
1919
"Kernels/FP32/Div",
20+
"Kernels/FP32/Floor",
2021
"Kernels/FP32/GEMM/Regular",
2122
"Kernels/FP32/MatMul",
2223
"Kernels/FP32/MaxPool/Regular_1D",

TargetLibraries/Generic/inc/DeeployBasicMath.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "kernel/Convolution.h"
4040
#include "kernel/DWConvolution.h"
4141
#include "kernel/Div.h"
42+
#include "kernel/Floor.h"
4243
#include "kernel/GELU.h"
4344
#include "kernel/Gemm.h"
4445
#include "kernel/Hardswish.h"

0 commit comments

Comments
 (0)