Skip to content

Commit 193574d

Browse files
rascaniclaude
andauthored
Cortex-M backend: dispatch quantized_linear AOT layout on target ISA (#19676)
### Summary CMSIS-NN's `arm_fully_connected_s8` has three runtime paths, gated by compile-time `ARM_MATH_MVEI` / `ARM_MATH_DSP`. They split the bias and input_offset×sum(weight) offset term between two inputs, in incompatible conventions: * MVE: reads `ctx.buf` as a precomputed kernel_sum that must already include `input_offset × sum(weight)` and the bias contribution. The `bias` argument is `(void)bias;` — ignored. * DSP / scalar: read the `bias` argument directly and fold the input_offset contribution at runtime. `ctx.buf` (kernel_sum) is `(void)kernel_sum;` — ignored. `ConvertToCortexMPass._get_linear_replacement` previously emitted only the MVE shape (kernel_sum populated, bias=None). On any non-MVE build the DSP/scalar path started the int32 accumulator at 0 instead of at `bias + input_offset × sum(weight)`, dropping both the bias and the offset contribution. The accumulator wound up much smaller than intended, requantization collapsed it to the output zero point, and every classifier with a deep, narrow tail produced essentially uniform near-zero outputs on non-MVE Cortex-M builds. Use the target-ISA plumbing added by the CortexMTargetConfig PR (#19470) to dispatch the right input shape at AOT time: on MVE targets emit kernel_sum with bias folded in (bias=None); on DSP and scalar targets emit the raw int32 bias directly (kernel_sum=None). The CMSIS-NN runtime then matches exactly what it expects. Update `quantized_linear_impl` in `operators.py` to mirror the same contract: dispatch off whichever of kernel_sum / bias is non-None. Threading happens automatically via `CortexMPassManager`'s signature injection of `target_config` into the pass's `__init__`. ### Test Plan Add `backends/cortex_m/test/misc/test_quantized_linear_small_magnitude.py` as a regression. A tiny `nn.Linear(512, 10)` on uniform[0, 0.002] input is the minimal reproducer for the small-magnitude regime where the missing offset terms dominate. The dialect test parametrizes over MVE/DSP/scalar target configs; the implementation test runs against whatever path the runner build matches. The DSP & Scalar tests will need #19520 for CI testing. Authored with Claude. --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b771dab commit 193574d

5 files changed

Lines changed: 231 additions & 49 deletions

File tree

backends/cortex_m/ops/operators.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -467,8 +467,8 @@ def quantized_linear_meta(
467467
def quantized_linear_impl(
468468
input: torch.Tensor,
469469
weights: torch.Tensor,
470-
bias: torch.Tensor,
471-
kernel_sum: torch.Tensor,
470+
bias: torch.Tensor | None,
471+
kernel_sum: torch.Tensor | None,
472472
input_offset: int,
473473
filter_offset: int,
474474
output_offset: int,
@@ -481,10 +481,11 @@ def quantized_linear_impl(
481481
Functional variant - creates output tensor and calls out variant
482482
"""
483483

484-
# Leaving both implementations for debugging purposes.
485-
compute_using_kernel_sum = True
486-
487-
if compute_using_kernel_sum:
484+
# Mirror CMSIS-NN's arm_fully_connected_s8 contract: the MVE path reads
485+
# kernel_sum (ctx.buf) and ignores bias; the DSP and scalar paths read
486+
# bias and ignore kernel_sum. The AOT pass populates exactly one of them
487+
# based on the target ISA, so dispatch off which one is present.
488+
if kernel_sum is not None:
488489
weights_int32 = weights.to(torch.int32)
489490

490491
input_int32 = input.to(torch.int32)

backends/cortex_m/passes/aten_to_cortex_m_pass.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from typing import cast
99

10+
import cmsis_nn # type: ignore[import-not-found, import-untyped]
1011
import executorch.backends.cortex_m.ops.operators # noqa
1112
import executorch.exir as exir
1213
import torch
@@ -146,7 +147,7 @@ def _has_qparams(node: Node) -> bool:
146147
@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.tanh.default)
147148
@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.silu.default)
148149
def _get_activation_replacement(
149-
node: Node, exported_program: ExportedProgram
150+
node: Node, dialect_pass: AtenToDialectPass
150151
) -> DialectNodeSpec | None:
151152
"""Lower a standalone quantized sigmoid / tanh / silu to a single
152153
cortex_m.quantized_activation call backed by an AoT-built 256-entry
@@ -156,6 +157,7 @@ def _get_activation_replacement(
156157
if not _has_qparams(node):
157158
return None
158159

160+
exported_program = dialect_pass.exported_program
159161
input_qparams = node.meta["input_qparams"][0]
160162
output_qparams = node.meta["output_qparams"][0]
161163
lut_tensor = build_activation_lut(
@@ -187,7 +189,7 @@ def _get_activation_replacement(
187189

188190
@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.linear.default)
189191
def _get_linear_replacement(
190-
node: Node, exported_program: ExportedProgram
192+
node: Node, dialect_pass: AtenToDialectPass
191193
) -> DialectNodeSpec | None:
192194
"""
193195
Let
@@ -209,6 +211,10 @@ def _get_linear_replacement(
209211
if not _has_qparams(node):
210212
return None
211213

214+
assert isinstance(dialect_pass, AtenToCortexMPass)
215+
exported_program = dialect_pass.exported_program
216+
target_config = dialect_pass.target_config
217+
212218
input_scale = node.meta["input_qparams"][0].scale
213219
input_zp = node.meta["input_qparams"][0].zp
214220
weight_scale = node.meta["input_qparams"][1].scale
@@ -218,37 +224,52 @@ def _get_linear_replacement(
218224
output_min = node.meta["output_qparams"][0].qmin
219225
output_max = node.meta["output_qparams"][0].qmax
220226

227+
if weight_zp != 0:
228+
raise NotImplementedError(
229+
f"cortex_m::quantized_linear assumes symmetric weight "
230+
f"quantization (weight_zp == 0); got weight_zp={weight_zp}"
231+
)
232+
221233
quantized_multiplier, quantized_shift = quantize_multiplier_aot(
222234
(input_scale * weight_scale) / output_scale
223235
)
224236

225-
# TODO: Add support for configuring the backend to support other extensions.
226-
# Kernel sum is only used in the CMSIS-NN implementation for the MVE extension,
227-
# so this should be optional.
237+
# CMSIS-NN's MVE `arm_fully_connected_s8` path reads a precomputed
238+
# kernel_sum (input_offset×sum(weight) + bias) from ctx.buf and
239+
# ignores the bias argument. The DSP and scalar paths do the opposite
240+
# — they read the bias argument at runtime and ignore ctx.buf
241+
# (see arm_nn_vec_mat_mult_t_s8.c). Pick the right input format here
242+
# based on the target ISA so the runtime gets exactly what it expects.
228243
linear_args = node.args
229244
weights = cast(Node, linear_args[1])
230245
weights_tensor = get_param_tensor(exported_program, weights)
231246
bias_node = cast(Node | None, linear_args[2]) if len(linear_args) > 2 else None
232247
bias_tensor = (
233248
get_param_tensor(exported_program, bias_node) if bias_node is not None else None
234249
)
235-
kernel_sum_tensor = _compute_kernel_sum(
236-
weights_tensor, bias_tensor, -input_zp, -weight_zp
237-
)
238-
with node.graph.inserting_after(weights):
239-
kernel_sum = create_constant_placeholder(
240-
exported_program,
241-
node.graph,
242-
node.name + "_kernel_sum",
243-
InputKind.PARAMETER,
244-
kernel_sum_tensor,
250+
251+
if target_config.backend == cmsis_nn.Backend.MVE:
252+
kernel_sum_tensor = _compute_kernel_sum(
253+
weights_tensor, bias_tensor, -input_zp, -weight_zp
245254
)
255+
with node.graph.inserting_after(weights):
256+
kernel_sum_arg = create_constant_placeholder(
257+
exported_program,
258+
node.graph,
259+
node.name + "_kernel_sum",
260+
InputKind.PARAMETER,
261+
kernel_sum_tensor,
262+
)
263+
bias_arg = None
264+
else:
265+
kernel_sum_arg = None
266+
bias_arg = bias_node
246267

247268
args = (
248269
linear_args[0],
249270
weights,
250-
None,
251-
kernel_sum,
271+
bias_arg,
272+
kernel_sum_arg,
252273
-input_zp,
253274
-weight_zp,
254275
output_zp,
@@ -263,11 +284,12 @@ def _get_linear_replacement(
263284

264285
@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.convolution.default)
265286
def _get_convolution_replacement(
266-
node: Node, exported_program: ExportedProgram
287+
node: Node, dialect_pass: AtenToDialectPass
267288
) -> DialectNodeSpec | None:
268289
if not _has_qparams(node):
269290
return None
270291

292+
exported_program = dialect_pass.exported_program
271293
conv_args = node.args
272294
(
273295
x,
@@ -292,7 +314,7 @@ def _get_convolution_replacement(
292314
)
293315

294316
if transposed:
295-
return _get_transpose_conv2d_replacement(node, exported_program)
317+
return _get_transpose_conv2d_replacement(node, dialect_pass)
296318

297319
input_scale = node.meta["input_qparams"][0].scale
298320
input_zero_point = node.meta["input_qparams"][0].zp
@@ -437,14 +459,15 @@ def _get_convolution_replacement(
437459

438460

439461
def _get_transpose_conv2d_replacement(
440-
node: Node, exported_program: ExportedProgram
462+
node: Node, dialect_pass: AtenToDialectPass
441463
) -> DialectNodeSpec | None:
442464
"""
443465
Transform aten.convolution with transposed=True to cortex_m.quantized_transpose_conv2d.
444466
"""
445467
if not _has_qparams(node):
446468
return None
447469

470+
exported_program = dialect_pass.exported_program
448471
conv_t_args = node.args
449472
(
450473
x,
@@ -562,11 +585,12 @@ def _get_transpose_conv2d_replacement(
562585

563586
@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.bmm.default)
564587
def _get_bmm_replacement(
565-
node: Node, exported_program: ExportedProgram
588+
node: Node, dialect_pass: AtenToDialectPass
566589
) -> DialectNodeSpec | None:
567590
if not _has_qparams(node):
568591
return None
569592

593+
exported_program = dialect_pass.exported_program
570594
lhs_scale = node.meta["input_qparams"][0].scale
571595
lhs_zp = node.meta["input_qparams"][0].zp
572596
rhs_scale = node.meta["input_qparams"][1].scale

backends/cortex_m/test/ops/test_linear.py

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
1-
# Copyright 2025 Arm Limited and/or its affiliates.
1+
# Copyright 2025-2026 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66

7+
from dataclasses import dataclass
8+
79
import torch
810
from executorch.backends.arm.test.common import parametrize
11+
from executorch.backends.cortex_m.target_config import CortexM, CortexMTargetConfig
912
from executorch.backends.cortex_m.test.tester import (
1013
CortexMTester,
1114
McuTestCase,
1215
ramp_tensor,
1316
)
17+
from executorch.backends.test.harness.stages import StageType
18+
from executorch.exir.dialects._ops import ops as exir_ops
1419

1520

1621
class CortexMLinear(torch.nn.Module):
@@ -128,3 +133,158 @@ def test_dialect_linear(test_case):
128133
def test_implementation_linear(test_case):
129134
tester = CortexMTester(test_case.model, test_case.example_inputs)
130135
tester.test_implementation(qtol=1)
136+
137+
138+
# ---------------------------------------------------------------------------
139+
# Regression: cortex_m::quantized_linear must pick the right CMSIS-NN input
140+
# convention based on the target ISA. `arm_fully_connected_s8` reads
141+
# kernel_sum (ctx.buf) on MVE/Helium and reads the bias argument on DSP/scalar
142+
# paths; the two are mutually exclusive. Previously the pass unconditionally
143+
# emitted the MVE shape, which silently dropped the bias and input-offset
144+
# terms on every non-MVE build. The regression only showed up when those
145+
# terms dominated the int32 accumulator -- i.e., on small-magnitude inputs.
146+
#
147+
# Coverage strategy: a single ISA-parametrized dialect test verifies the
148+
# numeric output against the float reference (catches the dropped-bias bug
149+
# directly), checks ops_after_transforms to confirm the linear lowered, and
150+
# asserts the post-pass node has the value in the slot the configured ISA
151+
# expects -- the structural guard against a regression that emits zero-valued
152+
# kernel_sum on a no-bias DSP path (numerically inert, but wrong shape).
153+
# An additional implementation test drives the default M55 MVE build path
154+
# through the simulator.
155+
# ---------------------------------------------------------------------------
156+
157+
158+
class _SmallMagnitudeLinear(torch.nn.Module):
159+
ops_before_transforms = {
160+
"executorch_exir_dialects_edge__ops_aten_linear_default": 1,
161+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
162+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 4,
163+
}
164+
ops_after_transforms = {
165+
"executorch_exir_dialects_edge__ops_cortex_m_quantized_linear_default": 1,
166+
"executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
167+
"executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
168+
}
169+
170+
def __init__(self, bias: bool = True):
171+
super().__init__()
172+
self.fc = torch.nn.Linear(512, 10, bias=bias)
173+
174+
def forward(self, x):
175+
return self.fc(x)
176+
177+
178+
class _SmallMagnitudeLinearNoBias(_SmallMagnitudeLinear):
179+
ops_before_transforms = {
180+
"executorch_exir_dialects_edge__ops_aten_linear_default": 1,
181+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
182+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
183+
}
184+
185+
def __init__(self):
186+
super().__init__(bias=False)
187+
188+
189+
def _small_magnitude_input():
190+
return torch.rand(1, 512) * 0.002
191+
192+
193+
_small_magnitude_calibration = [(_small_magnitude_input(),) for _ in range(8)]
194+
195+
196+
@dataclass(frozen=True)
197+
class _SmallMagnitudeVariant:
198+
case: McuTestCase
199+
target_config: CortexMTargetConfig
200+
uses_kernel_sum: bool
201+
has_bias: bool
202+
203+
204+
def _small_magnitude_variant(
205+
model_cls, cpu: CortexM, *, uses_kernel_sum: bool, has_bias: bool
206+
) -> _SmallMagnitudeVariant:
207+
return _SmallMagnitudeVariant(
208+
case=McuTestCase(
209+
model=model_cls().eval(),
210+
example_inputs=lambda: (_small_magnitude_input(),),
211+
),
212+
target_config=CortexMTargetConfig(cpu=cpu),
213+
uses_kernel_sum=uses_kernel_sum,
214+
has_bias=has_bias,
215+
)
216+
217+
218+
# bias=True covers the regression directly (the bug dropped the bias term);
219+
# bias=False covers the symmetric case where only the input-offset term is
220+
# missing on the non-MVE paths.
221+
small_magnitude_variants = {
222+
"mve_bias": _small_magnitude_variant(
223+
_SmallMagnitudeLinear, CortexM.M55, uses_kernel_sum=True, has_bias=True
224+
),
225+
"dsp_bias": _small_magnitude_variant(
226+
_SmallMagnitudeLinear, CortexM.M4, uses_kernel_sum=False, has_bias=True
227+
),
228+
"scalar_bias": _small_magnitude_variant(
229+
_SmallMagnitudeLinear, CortexM.M0PLUS, uses_kernel_sum=False, has_bias=True
230+
),
231+
"mve_nobias": _small_magnitude_variant(
232+
_SmallMagnitudeLinearNoBias, CortexM.M55, uses_kernel_sum=True, has_bias=False
233+
),
234+
"dsp_nobias": _small_magnitude_variant(
235+
_SmallMagnitudeLinearNoBias, CortexM.M4, uses_kernel_sum=False, has_bias=False
236+
),
237+
"scalar_nobias": _small_magnitude_variant(
238+
_SmallMagnitudeLinearNoBias,
239+
CortexM.M0PLUS,
240+
uses_kernel_sum=False,
241+
has_bias=False,
242+
),
243+
}
244+
245+
246+
@parametrize("variant", small_magnitude_variants)
247+
def test_dialect_linear_small_magnitude(variant: _SmallMagnitudeVariant):
248+
tester = CortexMTester(
249+
variant.case.model,
250+
variant.case.get_example_inputs(),
251+
target_config=variant.target_config,
252+
)
253+
tester.test_dialect(
254+
ops_before_transforms=variant.case.model.ops_before_transforms,
255+
ops_after_transforms=variant.case.model.ops_after_transforms,
256+
qtol=1,
257+
calibration_samples=_small_magnitude_calibration,
258+
)
259+
260+
# Structural guard: numeric divergence catches the original dropped-bias
261+
# bug, but a future regression that emits zero-valued kernel_sum on a
262+
# no-bias DSP/scalar path would be numerically inert. Assert the slot the
263+
# configured ISA actually consumes is populated and the unused one is None.
264+
module = tester.get_artifact(StageType.RUN_PASSES).exported_program().module()
265+
linear_target = exir_ops.edge.cortex_m.quantized_linear.default
266+
[linear_node] = [
267+
n
268+
for n in module.graph.nodes
269+
if n.op == "call_function" and n.target == linear_target
270+
]
271+
bias_arg, kernel_sum_arg = linear_node.args[2], linear_node.args[3]
272+
if variant.uses_kernel_sum:
273+
assert kernel_sum_arg is not None
274+
assert bias_arg is None
275+
else:
276+
assert kernel_sum_arg is None
277+
if variant.has_bias:
278+
assert bias_arg is not None
279+
else:
280+
assert bias_arg is None
281+
282+
283+
def test_implementation_linear_small_magnitude():
284+
"""Exercise the MVE kernel_sum codepath via the default M55 simulator build."""
285+
case = McuTestCase(
286+
model=_SmallMagnitudeLinear().eval(),
287+
example_inputs=lambda: (_small_magnitude_input(),),
288+
)
289+
tester = CortexMTester(case.model, case.get_example_inputs())
290+
tester.test_implementation(qtol=1, calibration_samples=_small_magnitude_calibration)

0 commit comments

Comments
 (0)