Skip to content

Commit b1a60a2

Browse files
committed
feat(qcdq): pattern passes to bridge Brevitas/DeepQuant QCDQ ONNX → integer Conv path
DeepQuant emits QCDQ-format ONNX (decomposed Quant: Div/Add/Round/Clip, Dequant: Sub/Mul). Deeploy's existing pattern passes (QuantPatternPass, DequantPatternPass) collapse those decompositions into single Quant/Dequant ops, but nothing then bridges `Dequant → ... → Quant` chains into the RequantShift/RequantizedConv integer path that the PULPOpen target's int8 kernels actually consume. This commit adds the missing bridges, getting a real Brevitas-quantized ResNet8 from `Onnx4Deeploy -mode quant` through the entire frontend + lowering chain (all Conv → RequantizedConv, all Dequant→Quant pairs absorbed into RequantShift). Passes added (Generic/TopologyOptimizationPasses/Passes.py): - DequantQuantToRequantShiftPass: matches consecutive `Dequant → Quant` and folds into a single RequantShift carrying the combined affine transform. scale_d / scale_q is represented as fixed-point mul / 2^16, zero-point delta absorbed into add. Output keeps Quant's n_levels / signed / bit_width. - SkipInputQuantDequantPass: drops the trailing Dequant of the leading `(graph_input) → Quant → Dequant → ...` activation-quantization pair, so the int8 output of the input Quant feeds directly into the first integer op (RequantizedConv). Equivalent to feeding the network an fp32 input that gets pre-quantized — no precision loss beyond what Brevitas's input QuantIdentity already imposes. Both registered in PULPOptimizer right after QuantPatternPass / DequantPatternPass and before the existing RequantMerge stack. PULPOpen-side patches: - _merge_conv_rq_fun (PULPOpen/TopologyOptimizationPasses/Passes.py) now absorbs a bias-bearing Conv's bias into the requant add term, matching what _merge_gemm_rq_fun has done all along. Required when upstream Brevitas models use bias=True Conv (typical after Conv+BN folding, since BN's beta + running stats land in the Conv bias). This keeps RequantizedConv at the 4-input shape PULPConv2DParser / PULPDWConv2DParser require (X, W, mul, merged_add). - _remove_only_singleton_reduce_mean (CommonExtensions/.../ LoweringOptimizationPasses.py) now also reads the `axes` attribute (opset 13 form). The pre-patch code looked only at `node.inputs[1]` (opset 18+ form), which is what every opset-13 ONNX produced by DeepQuant fails against. Validated end-to-end on Brevitas-quantized ResNet8 (Onnx4Deeploy `-mode quant`): python testMVP.py -d ... -t Tests/Models/ResNet8_Quant -p Siracusa ... → QuantPatternPass / DequantPatternPass: fold Div/Add/Round/Clip + Sub/Mul ✓ → DequantQuantToRequantShiftPass: 13 Dequant→Quant pairs folded into RequantShift ✓ → SkipInputQuantDequantPass: leading input dequant dropped ✓ → PULPConvRequantMergePass: 9 Conv+RequantShift pairs → 9 RequantizedConv ✓ → All Conv binding succeeded with int8/int32 bias/int32 mul/int32 add ✓ (One narrow type-check failure remains downstream on one of the new RequantShift instances — likely an attribute representation quirk on the gs.Constant-wrapped n_levels/signed/div — to be addressed in a follow-up. The structural integration is in.)
1 parent c4870e1 commit b1a60a2

4 files changed

Lines changed: 167 additions & 7 deletions

File tree

Deeploy/CommonExtensions/OptimizationPasses/TopologyOptimizationPasses/LoweringOptimizationPasses.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -530,11 +530,19 @@ def _remove_only_singleton_reduce_mean(graph: gs.Graph, match: Match, name: str)
530530
if len(graph.nodes) == 1:
531531
return graph
532532

533-
# Delete node if only reduction over singleton dimensions
534-
if 'axis' in node.attrs:
533+
# Delete node if only reduction over singleton dimensions.
534+
# Pre-opset-18 ReduceMean carries axes as an 'axes' attribute; opset 18+
535+
# carries it as the second input. Some exporters also spell the attribute
536+
# 'axis'. Handle all three.
537+
if 'axes' in node.attrs:
538+
axis = node.attrs['axes']
539+
elif 'axis' in node.attrs:
535540
axis = node.attrs['axis']
536-
else:
541+
elif len(node.inputs) > 1:
537542
axis = node.inputs[1].values
543+
else:
544+
# No axes info → reduce over all dims; not a singleton-only case.
545+
return graph
538546

539547
# Check if shape information is available
540548
if node.inputs[0].shape is not None and all(node.inputs[0].shape[ax] == 1 for ax in axis):

Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,3 +1177,142 @@ def __init__(self):
11771177

11781178
name = "_RECOGNIZE_DEQUANT_PASS"
11791179
super().__init__(graph, _recognize_dequant_fun, name)
1180+
1181+
1182+
# -------------------------------------------------------------------------- #
1183+
# Dequant → Quant chain → RequantShift #
1184+
# -------------------------------------------------------------------------- #
1185+
#
1186+
# QCDQ-style ONNX from Brevitas/DeepQuant produces ``Dequant`` and ``Quant``
1187+
# in alternating positions sandwiching float ops. After ``QuantPatternPass``
1188+
# and ``DequantPatternPass`` fold them, the graph looks like:
1189+
#
1190+
# Quant_input → Dequant → Conv(fp) → Quant → Dequant → Conv(fp) → ...
1191+
#
1192+
# Deeploy's per-op RequantMerge passes (PULPConvRequantMergePass etc.) look
1193+
# for ``Op → RequantShift``, not ``Op → Quant → Dequant``. We bridge by
1194+
# pre-folding every ``Dequant → Quant`` pair into a single ``RequantShift``,
1195+
# which carries the combined affine transform:
1196+
#
1197+
# y_int = clip(round((x_int - zp_d) * scale_d / scale_q + zp_q))
1198+
#
1199+
# With mul = round(scale_d / scale_q * 2^N), div = 2^N,
1200+
# add = zp_q * div - zp_d * mul.
1201+
#
1202+
def _dequant_quant_to_rqs_fun(graph: gs.Graph, match: Match, name: str):
1203+
matched_nodes = list(match.nodes_map.values())
1204+
dequant_node = matched_nodes[0]
1205+
quant_node = matched_nodes[1]
1206+
1207+
scale_d = float(dequant_node.attrs['scale'])
1208+
zp_d = float(dequant_node.attrs['zero_point'])
1209+
scale_q = float(quant_node.attrs['scale'])
1210+
zp_q = float(quant_node.attrs['zero_point'])
1211+
bit_width_q = int(quant_node.attrs['bit_width'])
1212+
signed_q = bool(quant_node.attrs.get('signed', True))
1213+
1214+
# Fixed-point representation of scale_d / scale_q. 16 bits after the binary
1215+
# point comfortably covers any per-tensor INT8 PTQ scale we have seen.
1216+
shift_bits = 16
1217+
div = int(1 << shift_bits)
1218+
mul_val = int(np.round((scale_d / scale_q) * div))
1219+
add_val = int(np.round(zp_q * div - zp_d * mul_val))
1220+
1221+
mul_tensor = gs.Constant(name = name + '_mul', values = np.array([mul_val], dtype = np.int32))
1222+
add_tensor = gs.Constant(name = name + '_add', values = np.array([add_val], dtype = np.int32))
1223+
1224+
n_levels = 1 << bit_width_q
1225+
# Attrs wrapped in gs.Constant since RequantShiftParser reads
1226+
# node.attrs['div'].values etc. (Parsers.py around line 90).
1227+
attrs = {
1228+
'n_levels': gs.Constant(name = name + '_n_levels', values = np.array(n_levels)),
1229+
'signed': gs.Constant(name = name + '_signed', values = np.array(int(signed_q))),
1230+
'div': gs.Constant(name = name + '_div', values = np.array(div)),
1231+
}
1232+
1233+
# `replaceInsertNode` only reads op/name/attrs off the supplied node — it
1234+
# creates the real node via graph.layer(...) with the inputs/outputs we
1235+
# pass here. So this gs.Node serves only as a spec carrier.
1236+
spec = gs.Node(op = 'RequantShift', name = name, attrs = attrs)
1237+
graph.replaceInsertNode(
1238+
[dequant_node.inputs[0], mul_tensor, add_tensor],
1239+
list(quant_node.outputs),
1240+
spec,
1241+
)
1242+
return graph
1243+
1244+
1245+
@contextagnostic
1246+
class DequantQuantToRequantShiftPass(ReplaceSequentialPatternPass):
1247+
"""Fold a ``Dequant → Quant`` chain (produced by Brevitas QCDQ export) into
1248+
a single ``RequantShift`` so downstream RequantMerge passes can absorb it
1249+
into their preceding Conv/Gemm/MatMul/Add."""
1250+
1251+
def __init__(self):
1252+
graph = gs.Graph()
1253+
_input = gs.Variable(name = 'input_1')
1254+
deq_out = graph.layer(inputs = [_input], outputs = ['deq_out'], op = 'Dequant', name = 'deq')
1255+
q_out = graph.layer(inputs = deq_out, outputs = ['q_out'], op = 'Quant', name = 'q')
1256+
graph.outputs.append(q_out)
1257+
graph.inputs.append(_input)
1258+
1259+
name = "_DEQUANT_QUANT_TO_RQS_PASS"
1260+
super().__init__(graph, _dequant_quant_to_rqs_fun, name)
1261+
1262+
1263+
# -------------------------------------------------------------------------- #
1264+
# Skip leading Quant→Dequant pair: when the network starts with the canonical
1265+
# Brevitas QCDQ activation-quantization pair (fp32 input → Quant → Dequant →
1266+
# first op), Deeploy's first-op binding receives fp32 and refuses (the
1267+
# RequantizedConv it folded into expects int8). The pair is mathematically
1268+
# a "round to int8 grid" no-op; we can drop it at a small precision cost for
1269+
# PTQ, leaving the int8 chain to absorb everything from the next RequantShift
1270+
# onward.
1271+
# -------------------------------------------------------------------------- #
1272+
def _skip_input_quant_dequant_fun(graph: gs.Graph, match: Match, name: str):
1273+
matched_nodes = list(match.nodes_map.values())
1274+
quant_node = matched_nodes[0]
1275+
dequant_node = matched_nodes[1]
1276+
1277+
# Only collapse if the Quant's input is a graph input (the leading
1278+
# activation-quant pair, not an interior one).
1279+
quant_input = quant_node.inputs[0]
1280+
if quant_input not in graph.inputs:
1281+
return graph
1282+
1283+
# Drop only the trailing Dequant. The leading Quant stays so its int8
1284+
# output feeds directly into the first integer op (RequantizedConv etc.).
1285+
quant_out = quant_node.outputs[0]
1286+
dequant_out = dequant_node.outputs[0]
1287+
1288+
for consumer in list(graph.nodes):
1289+
for i, inp in enumerate(consumer.inputs):
1290+
if inp is dequant_out:
1291+
consumer.inputs[i] = quant_out
1292+
for i, out in enumerate(graph.outputs):
1293+
if out is dequant_out:
1294+
graph.outputs[i] = quant_out
1295+
1296+
dequant_node.outputs = []
1297+
graph.cleanup()
1298+
return graph
1299+
1300+
1301+
@contextagnostic
1302+
class SkipInputQuantDequantPass(ReplaceSequentialPatternPass):
1303+
"""Drop a leading ``Quant → Dequant`` pair at graph input — equivalent
1304+
to feeding the network with the un-rounded fp32 input.
1305+
1306+
Lets the rest of the integer chain (RequantShift / RequantizedConv) take
1307+
over from the first conv onward."""
1308+
1309+
def __init__(self):
1310+
graph = gs.Graph()
1311+
_input = gs.Variable(name = 'input_1')
1312+
q_out = graph.layer(inputs = [_input], outputs = ['q_out'], op = 'Quant', name = 'q')
1313+
d_out = graph.layer(inputs = q_out, outputs = ['d_out'], op = 'Dequant', name = 'd')
1314+
graph.outputs.append(d_out)
1315+
graph.inputs.append(_input)
1316+
1317+
name = "_SKIP_INPUT_QD_PASS"
1318+
super().__init__(graph, _skip_input_quant_dequant_fun, name)

Deeploy/Targets/PULPOpen/Platform.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@
2626
SoftmaxParser, TransposeParser, UniformRequantShiftParser, UnsqueezeParser, iHardswishParser, iRMSNormParser, \
2727
iSoftmaxParser
2828
from Deeploy.Targets.Generic.Templates import AllocateTemplate as BasicAllocateTemplate
29-
from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import DequantPatternPass, IntegerDivRequantMergePass, \
30-
MergeConstAddAndRequantPass, MergeTrueIntegerDivRequantShiftPass, QuantPatternPass, RQSSplitPass, \
31-
SkipEmptyConcatPass, SkipUnityRequantPass, iGELURequantMergePass, iHardswishRequantMergePass
29+
from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import DequantPatternPass, DequantQuantToRequantShiftPass, \
30+
IntegerDivRequantMergePass, MergeConstAddAndRequantPass, MergeTrueIntegerDivRequantShiftPass, QuantPatternPass, \
31+
RQSSplitPass, SkipEmptyConcatPass, SkipInputQuantDequantPass, SkipUnityRequantPass, iGELURequantMergePass, \
32+
iHardswishRequantMergePass
3233
from Deeploy.Targets.PULPOpen.Bindings import BasicDequantBindings, BasicQuantBindings, PULPDMASliceBindings, \
3334
PULPDWConv1DBinding
3435
from Deeploy.Targets.PULPOpen.Layers import PULPRQSConvLayer, PULPRQSGEMMLayer
@@ -227,6 +228,8 @@ class PULPStructBuffer(StructBuffer):
227228
PULPOptimizer = TopologyOptimizer([
228229
QuantPatternPass(),
229230
DequantPatternPass(),
231+
SkipInputQuantDequantPass(),
232+
DequantQuantToRequantShiftPass(),
230233
SkipEmptyConcatPass(),
231234
SkipUnityRequantPass(previous_op_regex = "Concat", num_inputs = 2),
232235
SkipUnityRequantPass(previous_op_regex = "Reshape|Transpose", num_inputs = 1),

Deeploy/Targets/PULPOpen/TopologyOptimizationPasses/Passes.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,17 @@ def _merge_conv_rq_fun(graph: gs.Graph, match: Match, name: str):
175175

176176
rqs.inputs[-1].values = copy.deepcopy(rqs.inputs[-1].values) + rounding
177177

178-
_inputs = list(conv.inputs) + list(rqs.inputs[1:])
178+
# Absorb the Conv's bias (if present) into the RequantShift's add term:
179+
# (X*W + B) * mul + add = X*W * mul + (B * mul + add)
180+
# This keeps the resulting RequantizedConv at the 4 inputs that
181+
# PULPConv2DParser / PULPDWConv2DParser require (X, W, mul, merged_add).
182+
if len(list(conv.inputs)) == 3:
183+
B = conv.inputs[2].values
184+
mul = rqs.inputs[1].values
185+
rqs.inputs[2].values = np.round(B * mul).astype(rqs.inputs[2].values.dtype) + rqs.inputs[2].values
186+
_inputs = list(conv.inputs[:2]) + list(rqs.inputs[1:])
187+
else:
188+
_inputs = list(conv.inputs) + list(rqs.inputs[1:])
179189

180190
_outputs = rqs.outputs
181191

0 commit comments

Comments
 (0)