
Commit 7f14a9d

Author: Github Executorch (committed)
Summary: MV2 CortexM PassManager changes for Alif E8
Test Plan:

python3 -m examples.arm.aot_arm_compiler -m mv2 --target=cortex-m --quantize --enable_qdq_fusion_pass --intermediates=./mv2_intermediates --output=./mv2_cortex_m.pte
cat ./mv2_intermediates/delegation_info.txt

Delegation info:
Total delegated subgraphs: 0
Number of delegated nodes: 0
Number of non-delegated nodes: 72

Delegation table:

|    | op_type                                     | occurrences_in_delegated_graphs | occurrences_in_non_delegated_graphs |
|----|---------------------------------------------|---------------------------------|-------------------------------------|
|  0 | aten_as_strided_copy_default                | 0                               | 1                                   |
|  1 | aten_mean_dim                               | 0                               | 1                                   |
|  2 | aten_view_copy_default                      | 0                               | 1                                   |
|  3 | cortex_m_dequantize_per_tensor_default      | 0                               | 2                                   |
|  4 | cortex_m_quantize_per_tensor_default        | 0                               | 2                                   |
|  5 | cortex_m_quantized_add_default              | 0                               | 10                                  |
|  6 | cortex_m_quantized_conv2d_default           | 0                               | 35                                  |
|  7 | cortex_m_quantized_depthwise_conv2d_default | 0                               | 17                                  |
|  8 | cortex_m_quantized_linear_default           | 0                               | 1                                   |
|  9 | dim_order_ops__clone_dim_order_default      | 0                               | 1                                   |
| 10 | Total                                       | 0                               | 71                                  |

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent ec4c462 commit 7f14a9d

5 files changed

Lines changed: 376 additions & 18 deletions


backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 64 additions & 0 deletions
@@ -33,6 +33,14 @@
 from torch.fx import GraphModule, Node


+# Passthrough ops that preserve quantization parameters from input to output.
+# These ops should be foldable even without explicit annotation metadata.
+PASSTHROUGH_OPS = {
+    exir_ops.edge.aten.hardtanh.default,
+    exir_ops.edge.aten.relu.default,
+    exir_ops.edge.aten.clamp.default,
+}
+
 def _get_special_dtype(qspec: QuantArgs) -> TosaSpecialDtype | None:
     if qspec.dtype == torch.int8:
         if qspec.qmax == 7 and qspec.qmin == -7:
@@ -248,6 +256,26 @@ def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
                 submodule.graph.erase_node(node_to_remove)
         return

+    @staticmethod
+    def _has_dq_input_and_q_output(node: Node) -> bool:
+        """
+        Check if a node has dequantize input(s) and quantize output(s).
+        This indicates the node is part of a quantized computation path.
+        """
+        # Check if any input is from a dequantize op
+        has_dq_input = any(
+            isinstance(arg, Node) and arg.target in DQ_OPS
+            for arg in node.args
+            if isinstance(arg, Node)
+        )
+
+        # Check if any output goes to a quantize op
+        has_q_output = any(
+            user.target in Q_OPS
+            for user in node.users
+        )
+        return has_dq_input and has_q_output
+
     @staticmethod
     def is_foldable(node: Node) -> bool:
         if node.op != "call_function":
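Note (illustrative): to make the new predicate concrete, here is a minimal, self-contained sketch that applies the same dq-input / q-output check to a toy torch.fx graph. The `dq`/`q` stand-ins and the tiny graph are hypothetical; the real pass matches against the edge-dialect targets in `DQ_OPS`/`Q_OPS`.

```python
import torch
from torch import fx

# Hypothetical stand-ins for the edge-dialect quantize/dequantize targets.
def dq(x): return x
def q(x): return x

DQ_OPS, Q_OPS = {dq}, {q}

# Build dq -> relu -> q, the pattern the pass now treats as foldable.
g = fx.Graph()
x = g.placeholder("x")
d = g.call_function(dq, (x,))
r = g.call_function(torch.relu, (d,))
out = g.call_function(q, (r,))
g.output(out)

def has_dq_input_and_q_output(node: fx.Node) -> bool:
    has_dq = any(isinstance(a, fx.Node) and a.target in DQ_OPS for a in node.args)
    has_q = any(u.target in Q_OPS for u in node.users)
    return has_dq and has_q

print(has_dq_input_and_q_output(r))  # True: relu sits between a dq and a q
```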
@@ -263,6 +291,13 @@ def is_foldable(node: Node) -> bool:
         ):
             return True

+        # Passthrough ops (hardtanh, relu, clamp) that have dq inputs and q outputs
+        # should be foldable even without explicit annotation. These ops preserve
+        # quantization parameters and are common in quantized models like MobileNetV2.
+        if node.target in PASSTHROUGH_OPS:
+            if FoldAndAnnotateQParamsPass._has_dq_input_and_q_output(node):
+                return True
+
         # We should not fold q-dq nodes into non-quantized nodes.
         if not (
             ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
@@ -335,6 +370,35 @@ def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901
             ):
                 self._handle_control_flow_node(n, graph_module)

+        # Second pass: Propagate qparams through passthrough ops.
+        # For ops like hardtanh that share qparams with their input, we need to:
+        # 1. Copy output_qparams from the passthrough op to its input node
+        # 2. Set input_qparams on the passthrough op
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.target not in PASSTHROUGH_OPS:
+                continue
+
+            # Check if this passthrough op has output_qparams but missing input_qparams
+            has_output = "output_qparams" in n.meta and len(n.meta.get("output_qparams", {})) > 0
+            has_input = "input_qparams" in n.meta and len(n.meta.get("input_qparams", {})) > 0
+
+            if not has_output or has_input:
+                continue
+
+            # Get the input node
+            if len(n.args) == 0 or not isinstance(n.args[0], Node):
+                continue
+
+            input_node = n.args[0]
+
+            # Propagate: For passthrough ops, output qparams equal input qparams
+            if "output_qparams" not in input_node.meta:
+                input_node.meta["output_qparams"] = n.meta["output_qparams"]
+
+            # Set input_qparams from output_qparams (same for passthrough ops)
+            n.meta["input_qparams"] = {0: n.meta["output_qparams"][0]}
+
         # retrace the graph to update the fake tensor types
         graph_module = super().call(graph_module).graph_module
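Note (illustrative): a short numeric example of why sharing qparams across a passthrough activation is safe. The values are made up, not taken from MobileNetV2: when the output scale/zero-point already encode a 0..6 range, hardtanh(0, 6) reduces to the saturation that int8 quantization performs anyway.

```python
# Illustrative values only: an int8 quantization chosen to cover the range [0, 6].
scale, zp = 6.0 / 255.0, -128

def quantize(x: float) -> int:
    return max(-128, min(127, round(x / scale) + zp))

# hardtanh(0, 6) in the float domain maps to clamping q-values to [q(0), q(6)],
# which is exactly the int8 range, so input and output qparams can be identical.
print(quantize(0.0), quantize(6.0))  # -128 127
```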

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 162 additions & 1 deletion
@@ -69,9 +69,164 @@ def _get_batch_size_from_conv(self, conv_node: torch.fx.Node):
                 pass
         return None

+    def _get_addmm_replacement(self, node):
+        """
+        Handle aten.addmm (decomposed linear):
+        addmm(bias, input, weight.T) = input @ weight.T + bias
+
+        In the graph, weight is already transposed via cortex_m.transpose or aten.t
+        so we need to trace back to find the original weight placeholder.
+
+        input_qparams indices for addmm:
+        [0] = bias (int32)
+        [1] = input activation (int8)
+        [2] = weight (int8) - often missing because weight goes through transpose
+        """
+        # addmm args: (bias, input, weight_transposed)
+        bias_node = node.args[0]
+        input_node = node.args[1]
+        weights_node = node.args[2]  # This is the transposed weight
+
+        # Get input qparams - use index 1 for input activation (index 0 is bias!)
+        input_scale = node.meta["input_qparams"][1].scale
+        input_zp = node.meta["input_qparams"][1].zp
+
+        # Get output qparams
+        output_scale = node.meta["output_qparams"][0].scale
+        output_zp = node.meta["output_qparams"][0].zp
+        output_min = node.meta["output_qparams"][0].qmin
+        output_max = node.meta["output_qparams"][0].qmax
+
+        # Trace back through graph to find original weight placeholder and its qparams
+        current_node = weights_node
+        max_depth = 10
+        found_transpose = False
+        original_weight_node = None
+        weight_qparams = None
+
+        # Check if weights_node (transpose) has qparams in its metadata
+        if "input_qparams" in weights_node.meta:
+            if 0 in weights_node.meta["input_qparams"]:
+                weight_qparams = weights_node.meta["input_qparams"][0]
+        if "output_qparams" in weights_node.meta:
+            if weight_qparams is None and 0 in weights_node.meta["output_qparams"]:
+                weight_qparams = weights_node.meta["output_qparams"][0]
+
+        # Trace back to find original weight placeholder
+        for depth in range(max_depth):
+            # Check for qparams in current node
+            if weight_qparams is None and "output_qparams" in current_node.meta:
+                oq = current_node.meta.get("output_qparams", {})
+                if 0 in oq:
+                    weight_qparams = oq[0]
+
+            if current_node.op == "placeholder":
+                original_weight_node = current_node
+                if "val" in original_weight_node.meta:
+                    val = original_weight_node.meta["val"]
+                # Check placeholder for output_qparams
+                if weight_qparams is None and "output_qparams" in original_weight_node.meta:
+                    oq = original_weight_node.meta.get("output_qparams", {})
+                    if 0 in oq:
+                        weight_qparams = oq[0]
+                break
+            elif current_node.op == "call_function":
+                target_name = str(current_node.target)
+                if ".t." in target_name or "transpose" in target_name.lower():
+                    found_transpose = True
+                if len(current_node.args) > 0:
+                    current_node = current_node.args[0]
+                else:
+                    break
+            else:
+                break
+
+        if original_weight_node is None:
+            raise RuntimeError(f"Could not find original weight placeholder for addmm node {node.name}")
+
+        # Get the weight tensor from the original placeholder
+        weights_tensor = get_param_tensor(self.exported_program, original_weight_node)
+
+        # If transpose found, original weights are [out_feat, in_feat]
+        # CMSIS-NN expects [out_feat, in_feat], so use original directly
+        if found_transpose:
+            final_weights = weights_tensor.contiguous()
+        else:
+            final_weights = weights_tensor.T.contiguous()
+
+        # Extract weight scale and zero_point
+        if weight_qparams is not None:
+            weight_scale = weight_qparams.scale
+            weight_zp = weight_qparams.zp
+        elif 2 in node.meta.get("input_qparams", {}):
+            # Fallback: check if weight qparams are at index 2
+            weight_scale = node.meta["input_qparams"][2].scale
+            weight_zp = node.meta["input_qparams"][2].zp
+        else:
+            # Derive weight scale from bias scale!
+            # For quantized linear: bias_scale = input_scale * weight_scale
+            # Therefore: weight_scale = bias_scale / input_scale
+            if 0 in node.meta.get("input_qparams", {}):
+                bias_scale = node.meta["input_qparams"][0].scale
+                weight_scale = bias_scale / input_scale
+                weight_zp = 0  # Symmetric quantization
+            else:
+                # Last resort: derive from weight tensor (symmetric quantization assumed)
+                weight_min = final_weights.min().item()
+                weight_max = final_weights.max().item()
+                weight_absmax = max(abs(weight_min), abs(weight_max))
+                weight_scale = weight_absmax / 127.0 if weight_absmax > 0 else 1.0
+                weight_zp = 0
+
+        # Calculate quantization multiplier and shift
+        quantized_multiplier, quantized_shift = quantize_multiplier_aot(
+            (input_scale * weight_scale) / output_scale
+        )
+
+        # Compute kernel_sum WITHOUT bias (pass None)
+        # Pass bias separately to C++ operator
+        kernel_sum_tensor = self._compute_kernel_sum(
+            final_weights, None, -input_zp, -weight_zp
+        )
+
+        # Create placeholders
+        with node.graph.inserting_after(original_weight_node):
+            weights_placeholder = create_constant_placeholder(
+                self.exported_program,
+                node.graph,
+                node.name + "_weights_correct",
+                InputKind.PARAMETER,
+                final_weights,
+            )
+
+            kernel_sum = create_constant_placeholder(
+                self.exported_program,
+                node.graph,
+                node.name + "_kernel_sum",
+                InputKind.PARAMETER,
+                kernel_sum_tensor,
+            )
+
+        # CMSIS-NN shift convention: use the shift as-is (not negated)
+        args = (
+            input_node,
+            weights_placeholder,
+            bias_node,  # Pass original bias (kernel_sum doesn't include it)
+            kernel_sum,
+            -input_zp,
+            -weight_zp,
+            output_zp,
+            [quantized_multiplier],
+            [quantized_shift],  # Use shift as-is
+            output_max,
+            output_min,
+        )
+
+        return exir_ops.edge.cortex_m.quantized_linear.default, args
+
     def _get_linear_replacement(self, node):
         """
-        Let 
+        Let
         - yi be the output activations (y1, ... yn)
         - xj be the input activations (x1, ... xm)
         - wij be the weights (w11, ... wnm)
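Note (illustrative, not part of the commit): a quick, self-contained check of the addmm identity that the new handler relies on.

```python
import torch

x = torch.randn(2, 4)
linear = torch.nn.Linear(4, 3)

# addmm(bias, input, weight.T) == input @ weight.T + bias == Linear(input)
out_addmm = torch.addmm(linear.bias, x, linear.weight.T)
assert torch.allclose(out_addmm, linear(x), atol=1e-6)
```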
@@ -175,6 +330,7 @@ def _get_convolution_replacement(self, node) -> int:

         weight_tensor = get_param_tensor(self.exported_program, weight)

+
         # Detect depthwise convolution:
         # Depthwise means groups == in_channels, out_channels == K * in_channels
         # Weight shape is [out_ch, in_ch_per_group, H, W]
@@ -386,6 +542,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             match node.target:
                 case exir_ops.edge.aten.linear.default:
                     op, args = self._get_linear_replacement(node)
+                case exir_ops.edge.aten.addmm.default:
+                    result = self._get_addmm_replacement(node)
+                    if result is None:
+                        continue
+                    op, args = result
                 case exir_ops.edge.aten.convolution.default:
                     # Check if it's transposed convolution (arg index 6)
                     transposed = node.args[6] if len(node.args) > 6 else False
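Note (illustrative): for orientation, a minimal sketch of the ahead-of-time arithmetic behind `quantize_multiplier_aot` and `_compute_kernel_sum`, assuming the usual CMSIS-NN/TFLite fixed-point convention; the helpers in this commit may differ in details such as rounding or whether the weight offset is folded in. All scale values below are hypothetical.

```python
import math
import torch

def quantize_multiplier(real_multiplier: float):
    """Decompose M into a Q31 mantissa and a power-of-two shift: M = q * 2**shift / 2**31."""
    if real_multiplier == 0.0:
        return 0, 0
    mantissa, shift = math.frexp(real_multiplier)   # mantissa in [0.5, 1)
    q = round(mantissa * (1 << 31))
    if q == (1 << 31):                              # rounding pushed mantissa up to 1.0
        q //= 2
        shift += 1
    return q, shift

# Hypothetical per-tensor scales for one addmm node.
input_scale, weight_scale, output_scale = 0.02, 0.005, 0.1
mult, shift = quantize_multiplier(input_scale * weight_scale / output_scale)

# Sketch of the kernel-sum term: per output row, the weight sum times the input
# offset can be precomputed AOT so the runtime kernel adds one constant per output.
w_q = torch.tensor([[10, -3, 7], [0, 5, -8]], dtype=torch.int32)  # [out_features, in_features]
input_offset = 12                                                 # == -input_zp
kernel_sum = w_q.sum(dim=1) * input_offset                        # bias is passed separately here

print(mult, shift, kernel_sum.tolist())
```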

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 16 additions & 2 deletions
@@ -13,6 +13,9 @@
 from executorch.backends.transforms.replace_scalar_with_tensor import (
     ReplaceScalarWithTensorArgPass,
 )
+from executorch.backends.arm._passes.decompose_adaptive_avg_pool2d_pass import (
+    DecomposeAdaptiveAvgPool2dPass,
+)
 from executorch.exir.pass_base import ExportPass
 from executorch.exir.pass_manager import PassManager
 from executorch.exir.program._program import _transform
@@ -33,6 +36,7 @@ class CortexMPassManager(PassManager):
         ReplaceScalarWithTensorArgPass,
         ReplaceQuantNodesPass,
         ActivationFusionPass,
+        DecomposeAdaptiveAvgPool2dPass,
         DecomposeHardswishPass,
         QuantizedOpFusionPass,
         ConvertToCortexMPass,
@@ -44,12 +48,22 @@ class CortexMPassManager(PassManager):
         ClampHardswishPass,
     ]

-    def __init__(self, exported_program, passes=None):
+    def __init__(self, exported_program, passes=None, skip_passes=None):
+        """
+        Initialize CortexMPassManager.
+
+        Args:
+            exported_program: The ExportedProgram to transform.
+            passes: Optional custom pass list. Uses default pass_list if None.
+            skip_passes: Optional list of pass classes to skip.
+        """
         self.exported_program = exported_program
         if passes is not None:
             self.passes = passes
         else:
-            self.passes = self.pass_list
+            self.passes = list(self.pass_list)
+            if skip_passes:
+                self.passes = [p for p in self.passes if p not in skip_passes]

     def transform_for_annotation(self, model):
         passes = self.pass_list_transform_for_annotation
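Note (illustrative): a hypothetical usage of the new skip_passes argument. The import path is assumed from this diff's file layout, and exported_program stands for an ExportedProgram produced earlier (e.g. by torch.export).

```python
from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager

exported_program = ...  # an ExportedProgram from torch.export.export(...)

# Build the default pipeline once, then rebuild it without the hardswish passes.
default_pm = CortexMPassManager(exported_program)
hardswish_passes = [p for p in default_pm.passes if "Hardswish" in p.__name__]

pm = CortexMPassManager(exported_program, skip_passes=hardswish_passes)
assert all("Hardswish" not in p.__name__ for p in pm.passes)
```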

backends/cortex_m/quantizer/quantizer.py

Lines changed: 5 additions & 0 deletions
@@ -448,6 +448,11 @@ class SharedQspecQuantizer(Quantizer):
         torch.ops.aten._unsafe_view.default,
         torch.ops.aten.unflatten.int,
         torch.ops.aten.flatten.using_ints,
+        # Additional passthrough ops for MobileNetV2 and similar architectures
+        torch.ops.aten.hardtanh.default,
+        torch.ops.aten.hardtanh_.default,
+        torch.ops.aten.max_pool2d.default,
+        torch.ops.aten.dropout.default,
     ]

     def __init__(self, targets: Optional[List[OpOverload]] = None) -> None:
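Note (illustrative, not part of the commit): the MobileNetV2 motivation in the comment can be seen directly by exporting a ReLU6 block; depending on the capture path it appears as aten.hardtanh.default or the in-place aten.hardtanh_.default, which is why both overloads join the shared-qspec list.

```python
import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU6(inplace=True))
ep = torch.export.export(model, (torch.randn(1, 3, 32, 32),))

# Expect a conv followed by hardtanh(0, 6) (or hardtanh_ before functionalization).
print([n.target for n in ep.graph.nodes if n.op == "call_function"])
```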
