Skip to content

Commit a3b13f3

Browse files
authored
Cortex-M: Thread target CPU/ISA through the AOT pass manager (#19470)
### Summary

Introduce a CortexMTargetConfig dataclass (cpu + isa) that carries Cortex-M target information from the --target=cortex-m<variant> CLI string into CortexMPassManager. The full standard Cortex-M lineup is registered (M0, M0+, M3, M4, M7, M23, M33, M35P, M52, M55, M85), each with a sensible default ISA; the optional-DSP M33/M35P and optional-MVE M52/M55/M85 cases can be expressed via the isa= kwarg. No pass reads the config yet, so this change is purely plumbing, but it positions both the upcoming AOT scratch-buffer sizing work (#16580) and the CMSIS-NN scalar (#17646) / DSP (#17644) backend support to plug in without re-plumbing the call site. Actually building for the new variants still requires […]. CortexMTester gains an optional config kwarg, and the Pico 2 MLP example now constructs CortexMPassManager with cpu='cortex-m33' to match the RP2350 hardware it targets.

Authored with Claude.

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell
1 parent 8e8e957 commit a3b13f3

10 files changed

Lines changed: 366 additions & 35 deletions

File tree

.ci/scripts/test_cortex_m_e2e.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ MODEL=$1
1717
script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")")
1818
et_root_dir=$(realpath "${script_dir}/../..")
1919

20-
# Quantization is the default for the cortex-m55+int8 target; run.sh's
20+
# Quantization is the default for the cortex-m55 target; run.sh's
2121
# arg parser only recognizes --no_quantize, so we omit any explicit flag.
2222
bash "${et_root_dir}/examples/arm/run.sh" \
2323
--model_name="${MODEL}" \
24-
--target=cortex-m55+int8 \
24+
--target=cortex-m55 \
2525
--bundleio

backends/arm/scripts/aot_arm_compiler.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
ReplaceQuantNodesPass,
4040
)
4141
from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer
42+
from executorch.backends.cortex_m.target_config import CortexMTargetConfig
4243
from executorch.devtools import BundledProgram, generate_etrecord
4344
from executorch.devtools.backend_debug import get_delegation_info
4445
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
@@ -465,7 +466,16 @@ def forward(self, x):
465466
"TOSA-1.0+INT",
466467
"TOSA-1.0+FP",
467468
"TOSA-1.0+INT+int16",
468-
"cortex-m55+int8",
469+
"cortex-m0",
470+
"cortex-m0plus",
471+
"cortex-m3",
472+
"cortex-m4",
473+
"cortex-m7",
474+
"cortex-m23",
475+
"cortex-m33",
476+
"cortex-m35p",
477+
"cortex-m55",
478+
"cortex-m85",
469479
]
470480

471481

@@ -566,7 +576,7 @@ def _get_args():
566576
required=False,
567577
default="ethos-u55-128",
568578
choices=TARGETS,
569-
help=f"Target backend. For delegated models: Ethos-U/VGF/TOSA variants. For non-delegated: cortex-m55+int8 (CMSIS-NN portable kernels). Valid targets: {TARGETS}",
579+
help=f"Target backend. For delegated models: Ethos-U/VGF/TOSA variants. For non-delegated: cortex-m<variant> (CMSIS-NN portable kernels). Valid targets: {TARGETS}",
570580
)
571581
# TODO: Remove --evaluate and --evaluate_config completely after a suitable time.
572582
# They are deprecated and no longer functional in this script.
@@ -860,9 +870,13 @@ def _to_edge_cortex_m(
860870
model: GraphModule,
861871
example_inputs: Tuple[torch.Tensor],
862872
calibration_samples: Optional[List[Tuple[torch.Tensor, ...]]],
873+
target_config: CortexMTargetConfig,
863874
):
864875
"""Cortex-M/CMSIS-NN compilation path with no delegation."""
865-
logging.info("Using Cortex-M/CMSIS-NN compilation path (no delegation)")
876+
logging.info(
877+
f"Using Cortex-M/CMSIS-NN compilation path for cpu={target_config.cpu.name} "
878+
f"backend={target_config.backend.name}"
879+
)
866880

867881
def _to_channels_last(x):
868882
if isinstance(x, torch.Tensor):
@@ -915,7 +929,9 @@ def _to_channels_last(x):
915929
),
916930
)
917931

918-
pass_manager = CortexMPassManager(edge.exported_program())
932+
pass_manager = CortexMPassManager(
933+
edge.exported_program(), target_config=target_config
934+
)
919935
edge._edge_programs["forward"] = pass_manager.transform()
920936

921937
return model_quant, edge
@@ -1007,11 +1023,12 @@ def main() -> None: # noqa: C901
10071023
else:
10081024
quant_mode = None
10091025

1010-
if args.target == "cortex-m55+int8":
1026+
if args.target.startswith("cortex-m"):
10111027
# Cortex-M path: CMSIS-NN portable kernels, no delegation
1028+
target_config = CortexMTargetConfig.from_target_string(args.target)
10121029
if args.delegate:
10131030
logging.warning(
1014-
"--delegate is ignored for target 'cortex-m55+int8' "
1031+
f"--delegate is ignored for target {args.target!r} "
10151032
"(this target does not use delegated ops)."
10161033
)
10171034
args.delegate = False
@@ -1021,6 +1038,7 @@ def main() -> None: # noqa: C901
10211038
model,
10221039
example_inputs,
10231040
calibration_samples,
1041+
target_config,
10241042
)
10251043
elif args.delegate:
10261044
# As we can target multiple output encodings, one must

backends/cortex_m/passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def _ensure_cortex_m_dependencies() -> None:
3636
from .activation_fusion_pass import ActivationFusionPass # noqa
3737
from .clamp_hardswish_pass import ClampHardswishPass # noqa
3838
from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa
39+
from .cortex_m_pass import CortexMPass # noqa
3940
from .decompose_hardswish_pass import DecomposeHardswishPass # noqa
4041
from .decompose_mean_pass import DecomposeMeanPass # noqa
4142
from .quantized_clamp_activation_pass import QuantizedClampActivationPass # noqa
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from executorch.backends.cortex_m.target_config import CortexMTargetConfig
8+
from executorch.exir.pass_base import ExportPass
9+
from torch.export import ExportedProgram
10+
11+
12+
class CortexMPass(ExportPass):
    """Common base for Cortex-M passes that depend on the compile target.

    A subclass lists ``exported_program`` and ``target_config`` as
    constructor parameters; ``CortexMPassManager.transform()`` supplies both
    when it instantiates the pass. The two values are kept on private
    attributes and surfaced through read-only properties.
    """

    def __init__(
        self,
        exported_program: ExportedProgram,
        target_config: CortexMTargetConfig,
    ) -> None:
        """Store the program and target config handed in by the pass manager.

        Args:
            exported_program: Program the pass is being run against.
            target_config: Cortex-M CPU/ISA selection for this compile.
        """
        super().__init__()
        self._target_config = target_config
        self._exported_program = exported_program

    @property
    def target_config(self) -> CortexMTargetConfig:
        """The Cortex-M target configuration supplied at construction."""
        return self._target_config

    @property
    def exported_program(self) -> ExportedProgram:
        """The exported program supplied at construction."""
        return self._exported_program

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55

66

77
import inspect
8-
from typing import Callable, cast, Optional, Type
8+
from typing import Any, Optional, Type
99

1010
from executorch.backends.arm._passes import (
1111
FoldAndAnnotateQParamsPass,
1212
ScalarsToAttributePass,
1313
)
14+
from executorch.backends.cortex_m.target_config import CortexM, CortexMTargetConfig
1415
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
1516
from executorch.backends.transforms.replace_scalar_with_tensor import (
1617
ReplaceScalarWithTensorArgPass,
@@ -19,9 +20,6 @@
1920
from executorch.exir.pass_manager import PassManager
2021
from executorch.exir.program._program import _transform, lift_constant_tensor_pass
2122
from torch.export import ExportedProgram
22-
from torch.fx.passes.infra.pass_base import PassResult
23-
24-
from torch.nn import Module
2523

2624
from .activation_fusion_pass import ActivationFusionPass
2725
from .clamp_hardswish_pass import ClampHardswishPass
@@ -57,14 +55,33 @@ class CortexMPassManager(PassManager):
5755
]
5856

5957
def __init__(
60-
self, exported_program, passes: Optional[list[PassClass]] = None
58+
self,
59+
exported_program: ExportedProgram | None,
60+
passes: Optional[list[PassClass]] = None,
61+
target_config: Optional[CortexMTargetConfig] = None,
6162
) -> None:
63+
"""Initialize the Cortex-M pass manager.
64+
65+
Args:
66+
exported_program: The exported program to transform. Required
67+
before calling ``transform()``; may be ``None`` for callers
68+
that only use ``transform_for_annotation()``.
69+
passes: Optional override of the pass list. Defaults to
70+
``CortexMPassManager.pass_list``.
71+
target_config: Compilation target for passes that need it.
72+
Defaults to ``CortexMTargetConfig(cpu=CortexM.M55)``, which
73+
resolves through cmsis_nn to the MVE backend — matching the
74+
pre-config historical behaviour.
75+
"""
6276
super().__init__(passes=[])
6377
self.exported_program = exported_program
6478
# PassManager.passes is typed as callables; this manager stores pass classes which are initialized at transform time with the exported_program.
6579
self.passes: list[PassClass] = ( # type: ignore[assignment]
6680
passes if passes is not None else self.pass_list # type: ignore[assignment]
6781
)
82+
self.target_config: CortexMTargetConfig = target_config or CortexMTargetConfig(
83+
cpu=CortexM.M55
84+
)
6885

6986
def transform_for_annotation(self, model):
7087
passes = self.pass_list_transform_for_annotation
@@ -73,18 +90,31 @@ def transform_for_annotation(self, model):
7390
return model
7491

7592
def transform(self) -> ExportedProgram:
76-
ep = self.exported_program
93+
exported_program = self.exported_program
94+
if not isinstance(exported_program, ExportedProgram):
95+
raise ValueError(
96+
f"{type(self).__name__}.transform() needs a real ExportedProgram, "
97+
f"got {exported_program!r}"
98+
)
99+
77100
for pass_cls in self.passes:
101+
if not isinstance(pass_cls, type):
102+
raise ValueError(
103+
f"{type(self).__name__} expects pass classes, not instances; "
104+
f"got {pass_cls!r}"
105+
)
106+
78107
signature = inspect.signature(pass_cls)
108+
kwargs: dict[str, Any] = {}
79109
if "exported_program" in signature.parameters:
80-
ep_pass_ctor = cast(Callable[[ExportedProgram], ExportPass], pass_cls)
81-
transform_pass = ep_pass_ctor(ep)
82-
else:
83-
transform_pass = pass_cls()
84-
pass_callable = cast(Callable[[Module], PassResult], transform_pass)
85-
ep = _transform(ep, pass_callable)
110+
kwargs["exported_program"] = exported_program
111+
if "target_config" in signature.parameters:
112+
kwargs["target_config"] = self.target_config
113+
114+
transform_pass = pass_cls(**kwargs)
115+
exported_program = _transform(exported_program, transform_pass)
86116

87117
# All constant tensors should be lifted to buffers at this point, re-run
88-
# lift_constant_tensor_pass in case new ones have been introduced by the passes above.
89-
ep = lift_constant_tensor_pass(ep)
90-
return ep
118+
# lift_constant_tensor_pass in case new ones have been introduced.
119+
exported_program = lift_constant_tensor_pass(exported_program)
120+
return exported_program

backends/cortex_m/target_config.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
from dataclasses import dataclass
10+
from enum import auto, Enum
11+
from typing import Optional
12+
13+
import cmsis_nn # type: ignore[import-not-found, import-untyped]
14+
15+
16+
class CortexM(Enum):
17+
"""Cortex-M CPU variant. Names mirror cmsis_nn.CortexM so the cmsis_nn
18+
enum can be looked up by name."""
19+
20+
M0 = auto()
21+
M0PLUS = auto()
22+
M3 = auto()
23+
M4 = auto()
24+
M7 = auto()
25+
M23 = auto()
26+
M33 = auto()
27+
M35P = auto()
28+
M55 = auto()
29+
M85 = auto()
30+
31+
32+
# Which cmsis_nn backends each core is able to execute. Every core runs
# SCALAR; DSP requires the Armv7E-M or Armv8-M-Mainline DSP option; MVE
# requires Armv8.1-M Mainline with the MVE extension. The tiers nest
# (SCALAR < DSP < MVE): an MVE-capable core also runs DSP and scalar code,
# which is what makes a "M55 without MVE" -> DSP override legitimate.
_SCALAR_ONLY: frozenset[cmsis_nn.Backend] = frozenset({cmsis_nn.Backend.SCALAR})
_UP_TO_DSP: frozenset[cmsis_nn.Backend] = _SCALAR_ONLY | {cmsis_nn.Backend.DSP}
_UP_TO_MVE: frozenset[cmsis_nn.Backend] = _UP_TO_DSP | {cmsis_nn.Backend.MVE}

_SUPPORTED_BACKENDS: dict[CortexM, frozenset[cmsis_nn.Backend]] = {
    # Scalar-only cores.
    CortexM.M0: _SCALAR_ONLY,
    CortexM.M0PLUS: _SCALAR_ONLY,
    CortexM.M3: _SCALAR_ONLY,
    CortexM.M23: _SCALAR_ONLY,
    # Cores that may additionally run the DSP backend.
    CortexM.M4: _UP_TO_DSP,
    CortexM.M7: _UP_TO_DSP,
    CortexM.M33: _UP_TO_DSP,
    CortexM.M35P: _UP_TO_DSP,
    # Cores that may additionally run the MVE backend.
    CortexM.M55: _UP_TO_MVE,
    CortexM.M85: _UP_TO_MVE,
}
54+
55+
56+
@dataclass(frozen=True)
class CortexMTargetConfig:
    """Immutable AOT compile target description for the Cortex-M backend.

    ``cpu`` picks the CPU variant. ``isa`` is an optional override for the
    cmsis_nn backend that would otherwise be derived from ``cpu``; it exists
    for cores whose ISA extensions are optional (M55 without MVE, M33
    without DSP, etc.). An override is checked against the CPU's entry in
    ``_SUPPORTED_BACKENDS`` at construction time, so e.g. forcing MVE on an
    M0 raises ``ValueError``.
    """

    # CPU variant being compiled for.
    cpu: CortexM
    # Explicit cmsis_nn backend; None means "derive it from `cpu`".
    isa: Optional[cmsis_nn.Backend] = None

    def __post_init__(self) -> None:
        """Reject an ``isa`` override that the selected CPU cannot execute.

        Raises:
            ValueError: If ``isa`` is set and either ``cpu`` has no entry in
                ``_SUPPORTED_BACKENDS`` or ``isa`` is not in that entry.
        """
        if self.isa is None:
            # No override: nothing to validate here; `backend` resolves the
            # default lazily.
            return

        capabilities = _SUPPORTED_BACKENDS.get(self.cpu)
        if capabilities is not None and self.isa in capabilities:
            return

        allowed = sorted(b.name for b in capabilities) if capabilities else []
        raise ValueError(
            f"Backend {self.isa.name} is not supported on "
            f"{self.cpu.name}; supported: {allowed}"
        )

    @property
    def backend(self) -> cmsis_nn.Backend:
        """The cmsis_nn backend to compile for.

        Returns the ``isa`` override when one was given. Otherwise looks up
        the ``cmsis_nn.CortexM`` member with the same name as ``cpu`` and
        lets ``cmsis_nn.resolve_backend`` choose the backend.

        Raises:
            ValueError: If no override is set and ``cmsis_nn.CortexM`` has no
                member named like ``cpu``.
        """
        if self.isa is not None:
            return self.isa
        try:
            cmsis_cpu = getattr(cmsis_nn.CortexM, self.cpu.name)
        except AttributeError as e:
            raise ValueError(
                f"cmsis_nn does not yet support {self.cpu.name}; pass an "
                f"explicit `isa=` override or wait for upstream support."
            ) from e
        else:
            return cmsis_nn.resolve_backend(cmsis_cpu)

    @classmethod
    def from_target_string(cls, target: str) -> CortexMTargetConfig:
        """Build a config from a ``cortex-m<variant>`` target string.

        The text after ``cortex-m`` is upper-cased and prefixed with ``M`` to
        form the ``CortexM`` member name (``cortex-m0plus`` -> ``M0PLUS``).
        The resulting config has no ``isa`` override.

        Raises:
            ValueError: If ``target`` does not start with ``cortex-m`` or
                does not name a ``CortexM`` member.
        """
        prefix = "cortex-m"
        if not target.startswith(prefix):
            raise ValueError(
                f"Cortex-M target string must start with 'cortex-m', "
                f"got: {target!r}"
            )

        variant = target[len(prefix) :]
        enum_name = "M" + variant.upper()
        try:
            cpu = CortexM[enum_name]
        except KeyError as e:
            raise ValueError(
                f"Unsupported Cortex-M target string: {target!r}. "
                f"Supported: {sorted('cortex-m' + m.name[1:].lower() for m in CortexM)}"
            ) from e
        return cls(cpu=cpu)

0 commit comments

Comments
 (0)