Skip to content

Commit 8483963

Browse files
committed
Fix QLiteCNN, add support for Tiled Quant/Dequant nodes
1 parent 12f311a commit 8483963

11 files changed

Lines changed: 133 additions & 13 deletions

File tree

Deeploy/Targets/GAP9/Bindings.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
from Deeploy.Targets.GAP9.DMA.L3Dma import gap9L3DmaHack
2222
from Deeploy.Targets.GAP9.DMA.MchanDma import GAP9MchanDma
2323
# Import templates from PULPOpen and Generic
24-
from Deeploy.Targets.Generic.Templates import AddTemplate, ConcatTemplate, DequantTemplate, FloatReduceMeanTemplate, \
25-
FloatReduceSumTemplate, GatherTemplate, QuantTemplate, RQSiGELUTemplate, SliceTemplate, iHardswishTemplate, DebugPrintTemplate
24+
from Deeploy.Targets.Generic.Templates import AddTemplate, ConcatTemplate, FloatReduceMeanTemplate, \
25+
FloatReduceSumTemplate, GatherTemplate, RQSiGELUTemplate, SliceTemplate, iHardswishTemplate, DebugPrintTemplate
2626
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, ConcatChecker, ConvChecker, DequantChecker, \
2727
GatherChecker, GELUChecker, GEMMChecker, HardswishChecker, LayerNormChecker, MatMulChecker, MulChecker, \
2828
QuantChecker, ReduceMeanChecker, ReluChecker, ReshapeChecker, RQAddChecker, RQHardswishChecker, SGDChecker, \
@@ -40,7 +40,7 @@
4040
FloatMulTemplate, FloatReluTemplate, FloatSoftmaxTemplate, GEMMTemplate, MatrixVectorTemplate, MaxPoolTemplate, \
4141
MulTemplate, ReduceMeanTemplate, RequantShiftTemplate, ReshapeTemplate, RQAddTemplate, RQSiHardswishTemplate, \
4242
SGDTemplate, SoftmaxCrossEntropyLossTemplate, TallGEMMTemplate, TransposeTemplate, UniformRequantShiftTemplate, \
43-
iRMSNormTemplate, iSoftmaxTemplate, FloatInPlaceAccumulatorV2Template, \
43+
iRMSNormTemplate, iSoftmaxTemplate, FloatInPlaceAccumulatorV2Template, QuantTemplate, DequantTemplate, \
4444
FloatPerturbEggrollTemplate, FloatPerturbUniformTemplate, FloatPerturbNormalTemplate, \
4545
FloatPerturbRademacherTemplate, FloatPerturbTriangleTemplate
4646
from Deeploy.Targets.PULPOpen.TypeCheckers import PULPConvChecker, PULPLinearChecker, PULPMaxPoolChecker, \

Deeploy/Targets/GAP9/Platform.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
GAP9RQSTallGEMMTilingReadyBindings, GAP9RQSTilingReadyBindings, GAP9SGDTilingReadyBindings, \
2323
GAP9SoftmaxCrossEntropyGradTilingReadyBindings, GAP9SoftmaxCrossEntropyTilingReadyBindings, \
2424
GAP9SoftmaxGradTilingReadyBindings, GAP9SoftmaxTilingReadyBindings, GAP9TransposeTilingReadyBindings, \
25-
GAP9UniformRQSTilingReadyBindings, GAP9InPlaceAccumulatorV2TilingReadyBindings, GAP9PerturbNormalTilingReadyBindings, GAP9PerturbUniformTilingReadyBindings, \
25+
GAP9UniformRQSTilingReadyBindings, GAP9InPlaceAccumulatorV2TilingReadyBindings, GAP9PerturbNormalTilingReadyBindings, \
26+
GAP9PerturbUniformTilingReadyBindings, GAP9QuantTilingReadyBindings, GAP9DequantTilingReadyBindings, \
2627
GAP9PerturbEggrollTilingReadyBindings, GAP9PerturbRademacherTilingReadyBindings, GAP9PerturbTriangleTilingReadyBindings
2728
from Deeploy.Targets.Generic.Bindings import BasicGEMMBindings, BasicPad1DBindings, BasicPad2DBindings, \
2829
BasicRQIntegerDivBinding
@@ -41,12 +42,13 @@
4142
InPlaceAccumulatorV2Parser, GELUGradParser, ReluGradParser, DebugParser, \
4243
PerturbEggrollParser, PerturbNormalParser, PerturbRademacherParser, PerturbTriangleParser, PerturbUniformParser
4344
from Deeploy.Targets.Generic.Templates import AllocateTemplate as BasicAllocateTemplate
44-
from Deeploy.Targets.GAP9.Bindings import GAP9SoftmaxCrossEntropyLossDualOutputBindings, GAP9LayernormGradBinding, GAP9FloatGELUGradBinding, GAP9ReluGradBinding, GAP9BasicDebugPrintBindings
45+
from Deeploy.Targets.GAP9.Bindings import GAP9SoftmaxCrossEntropyLossDualOutputBindings, GAP9LayernormGradBinding, \
46+
GAP9FloatGELUGradBinding, GAP9ReluGradBinding, GAP9BasicDebugPrintBindings
4547
from Deeploy.Targets.PULPOpen.Bindings import BasicDequantBindings, BasicQuantBindings, PULPDMASliceBindings, \
4648
PULPDWConv1DBinding, PULPReduceMeanBindings, PULPRQSConv1DBindings, PULPSliceBindings
4749
from Deeploy.Targets.PULPOpen.Layers import PULPRQSConvLayer, PULPRQSGEMMLayer
4850
from Deeploy.Targets.PULPOpen.Parsers import PULPConv1DParser, PULPConv2DParser, PULPDWConv1DParser, \
49-
PULPDWConv2DParser, PULPFPConv2DParser, PULPFPDWConv2DParser, PULPGEMMParser, PULPMatrixVecParser, \
51+
PULPDWConv2DParser, PULPFPConv2DParser, PULPFPDWConv2DParser, PULPGEMMParser, PULPIntConv2DParser, PULPMatrixVecParser, \
5052
PULPTallGEMMParser
5153

5254
# Create GAP9-specific NodeMappers
@@ -96,8 +98,8 @@
9698
GAP9_SoftmaxCrossEntropyLossGradMapper = NodeMapper(SoftmaxCrossEntropyLossGradParser(),
9799
GAP9SoftmaxCrossEntropyGradTilingReadyBindings)
98100
GAP9_SGDMapper = NodeMapper(SGDParser(), GAP9SGDTilingReadyBindings)
99-
GAP9_QuantMapper = NodeMapper(QuantParser(), BasicQuantBindings)
100-
GAP9_DequantMapper = NodeMapper(DequantParser(), BasicDequantBindings)
101+
GAP9_QuantMapper = NodeMapper(QuantParser(), GAP9QuantTilingReadyBindings)
102+
GAP9_DequantMapper = NodeMapper(DequantParser(), GAP9DequantTilingReadyBindings)
101103
GAP9_GEMMDequantMapper = NodeMapper(PULPGEMMParser(), BasicGEMMBindings)
102104
GAP9InPlaceAccumulatorV2Mapper = NodeMapper(InPlaceAccumulatorV2Parser(), GAP9InPlaceAccumulatorV2TilingReadyBindings)
103105
GAP9SoftmaxCrossEntropyLossDualOutputMapper = NodeMapper(SoftmaxCrossEntropyLossParser(),

Deeploy/Targets/GAP9/Tiler.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
GAP9RQSiHardswishBindings, GAP9RQSMatrixVecBindings, GAP9RQSTallGEMMBindings, GAP9SGDBindings, \
1919
GAP9SoftmaxBindings, GAP9SoftmaxCrossEntropyLossBindings, GAP9SoftmaxCrossEntropyLossGradBindings, \
2020
GAP9SoftmaxGradBindings, GAP9TransposeBindings, GAP9UniformRQSBindings, GAP9InPlaceAccumulatorV2Bindings, GAP9PerturbNormalBindings, \
21-
GAP9PerturbUniformBindings, GAP9PerturbEggrollBindings, GAP9PerturbRademacherBindings, GAP9PerturbTriangleBindings
21+
GAP9PerturbUniformBindings, GAP9PerturbEggrollBindings, GAP9PerturbRademacherBindings, GAP9PerturbTriangleBindings, \
22+
GAP9QuantBindings, GAP9DequantBindings
2223
from Deeploy.Targets.Generic.TileConstraints.AddTileConstraint import AddTileConstraint
2324
from Deeploy.Targets.Generic.TileConstraints.ConcatTileConstraint import ConcatTileConstraint
2425
from Deeploy.Targets.Generic.TileConstraints.iHardswishTileConstraint import iHardswishTileConstraint
@@ -93,6 +94,12 @@
9394
GAP9MaxPool2DTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = GAP9MaxPool2DBindings,
9495
tileConstraint = MaxPoolCTileConstraint())
9596

97+
GAP9QuantTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = GAP9QuantBindings,
98+
tileConstraint = UnaryTileConstraint())
99+
100+
GAP9DequantTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = GAP9DequantBindings,
101+
tileConstraint = UnaryTileConstraint())
102+
96103
GAP9RQSTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = GAP9RQSBindings,
97104
tileConstraint = RequantShiftTileConstraint())
98105

Deeploy/Targets/PULPOpen/Parsers.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,38 @@ def parseNodeCtxt(self,
6161

6262
return ctxt, False
6363

64+
class PULPIntConv2DParser(Conv2DParser):
65+
66+
def __init__(self, noBiasHoisting = True):
67+
super().__init__(noBiasHoisting)
68+
69+
def parseNode(self, node: gs.Node) -> (bool):
70+
71+
wellFormed = super().parseNode(node)
72+
if wellFormed:
73+
ret = all([
74+
# Current PULP kernel only supports grouping of 1
75+
self.operatorRepresentation['group'] == 1,
76+
77+
# Make sure padding is symmetric (left==right, top==bottom)
78+
# but top/bottom can differ from left/right
79+
self.operatorRepresentation['pads'][0] == self.operatorRepresentation['pads'][2], # top == bottom
80+
self.operatorRepresentation['pads'][1] == self.operatorRepresentation['pads'][3], # left == right
81+
82+
# Check number of inputs
83+
# 2 inputs if no bias, 3 if layer has bias
84+
len(node.inputs) in [2, 3],
85+
])
86+
87+
# Extract additional attributes
88+
self.operatorRepresentation['padding_y_top'] = int(self.operatorRepresentation['pads'][0])
89+
self.operatorRepresentation['padding_x_left'] = int(self.operatorRepresentation['pads'][1])
90+
self.operatorRepresentation['padding_y_bottom'] = int(self.operatorRepresentation['pads'][2])
91+
self.operatorRepresentation['padding_x_right'] = int(self.operatorRepresentation['pads'][3])
92+
93+
return ret
94+
return False
95+
6496

6597
class PULPFPConv2DParser(Conv2DParser):
6698

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from Deeploy.DeeployTypes import NodeTemplate
6+
7+
8+
class PULPDequantTemplate(NodeTemplate):
9+
10+
def __init__(self, templateStr):
11+
super().__init__(templateStr)
12+
13+
14+
referenceTemplate = PULPDequantTemplate("""
15+
// Dequantization (Name: ${nodeName}, Op: ${nodeOp})
16+
uint8_t ${nodeName}_core_id = (uint8_t) pi_core_id();
17+
uint8_t ${nodeName}_log2Core = (uint8_t) log2(NUM_CORES);
18+
uint32_t ${nodeName}_chunk = (${size} >> ${nodeName}_log2Core) + ((${size} & (NUM_CORES-1))!=0);
19+
uint32_t ${nodeName}_chunk_start = (uint32_t) MIN(${nodeName}_chunk*${nodeName}_core_id, (uint32_t) ${size});
20+
uint32_t ${nodeName}_chunk_stop = (uint32_t) MIN(${nodeName}_chunk_start + ${nodeName}_chunk, (uint32_t) ${size});
21+
22+
23+
for (uint32_t i=${nodeName}_chunk_start; i<${nodeName}_chunk_stop; i++) {
24+
int32_t quantized = (int32_t)${data_in}[i];
25+
float32_t shifted_val = quantized - ${zero_point};
26+
float32_t dequantized = shifted_val * ${scale};
27+
28+
${data_out}[i] = (${data_out_type.referencedType.typeName})dequantized;
29+
}
30+
""")
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from Deeploy.DeeployTypes import NodeTemplate
6+
7+
8+
class PULPQuantTemplate(NodeTemplate):
9+
10+
def __init__(self, templateStr):
11+
super().__init__(templateStr)
12+
13+
14+
referenceTemplate = PULPQuantTemplate("""
15+
// Quantization (Name: ${nodeName}, Op: ${nodeOp})
16+
uint8_t ${nodeName}_core_id = (uint8_t) pi_core_id();
17+
uint8_t ${nodeName}_log2Core = (uint8_t) log2(NUM_CORES);
18+
uint32_t ${nodeName}_chunk = (${size} >> ${nodeName}_log2Core) + ((${size} & (NUM_CORES-1))!=0);
19+
uint32_t ${nodeName}_chunk_start = (uint32_t) MIN(${nodeName}_chunk*${nodeName}_core_id, (uint32_t) ${size});
20+
uint32_t ${nodeName}_chunk_stop = (uint32_t) MIN(${nodeName}_chunk_start + ${nodeName}_chunk, (uint32_t) ${size});
21+
22+
for (uint32_t i=${nodeName}_chunk_start; i<${nodeName}_chunk_stop; i++) {
23+
// quantization formula
24+
float32_t input_val = ${data_in}[i];
25+
float32_t scaled_val = input_val * ${scale}; // Multiply instead of divide
26+
float32_t shifted_val = scaled_val + ${zero_point};
27+
28+
// Round to nearest integer
29+
int32_t quantized = (int32_t)(shifted_val + 0.5f * (shifted_val >= 0 ? 1 : -1));
30+
31+
// Clamp the value
32+
if (quantized < ${min_val}) quantized = ${min_val};
33+
if (quantized > ${max_val}) quantized = ${max_val};
34+
35+
// Assign directly with explicit cast
36+
${data_out}[i] = (${data_out_type.referencedType.typeName})quantized;
37+
38+
}
39+
""")

Deeploy/Targets/PULPOpen/TopologyOptimizationPasses/Passes.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,17 @@ def _merge_conv_rq_fun(graph: gs.Graph, match: Match, name: str):
168168
conv = matched_nodes[0]
169169
rqs = matched_nodes[1]
170170

171-
totalShift = int(np.log2(rqs.attrs['div'].values))
171+
div_val = rqs.attrs['div']
172+
if hasattr(div_val, 'values'):
173+
div_val = div_val.values.item()
172174

175+
totalShift = int(np.log2(div_val))
176+
print(f"total shift: {totalShift}")
173177
# Artificially add half the shift division value to implement rounding
174178
rounding = 2**(totalShift - 1) if totalShift > 0 else 0
175179

176-
rqs.inputs[-1].values = copy.deepcopy(rqs.inputs[-1].values) + rounding
180+
# JANSNO: this can't be right, so it is commented out.
181+
#rqs.inputs[-1].values = copy.deepcopy(rqs.inputs[-1].values) + rounding
177182

178183
_inputs = list(conv.inputs) + list(rqs.inputs[1:])
179184

@@ -205,9 +210,14 @@ def _merge_gemm_rq_fun(graph: gs.Graph, match: Match, name: str):
205210
gemm = matched_nodes[0]
206211
rqs = matched_nodes[1]
207212

208-
totalShift = int(np.log2(rqs.attrs['div'].values))
213+
div_val = rqs.attrs['div']
214+
if hasattr(div_val, 'values'):
215+
div_val = div_val.values.item()
209216

210-
rqs.inputs[-1].values = copy.deepcopy(rqs.inputs[-1].values) + 2**(totalShift - 1)
217+
totalShift = int(np.log2(div_val))
218+
219+
# JANSNO: rounding here is not valid
220+
#rqs.inputs[-1].values = copy.deepcopy(rqs.inputs[-1].values) + 2**(totalShift - 1)
211221

212222
# GEMM has add
213223
if len(list(gemm.inputs)) == 3:
0 Bytes
Binary file not shown.
-4.34 KB
Binary file not shown.
-16.4 KB
Binary file not shown.

0 commit comments

Comments
 (0)