Skip to content

Commit a2b556e

Browse files
committed
Cortex-M backend: thread target CPU/ISA through the AOT pass manager
Introduce a CortexMCompileConfig dataclass (cpu + isa) that carries Cortex-M target information from the --target=cortex-m<variant>+int8 CLI string into CortexMPassManager. The full standard Cortex-M lineup is registered (M0, M0+, M3, M4, M7, M23, M33, M35P, M52, M55, M85), each with a sensible default ISA; the optional-DSP M33/M35P and optional-MVE M52/M55/M85 cases can be expressed via the isa= kwarg. No pass reads the config yet, so this change is purely plumbing — but it positions both the upcoming AOT scratch-buffer sizing work (#16580) and the M0+ (#17646) / M33 (#17644) backend support to plug in without re-plumbing the call site. Actually building for the new variants still requires Phase 2's MPS2 platform glue. CortexMTester gains an optional config kwarg, and the Pico 2 MLP example now constructs CortexMPassManager with cpu='cortex-m33' to match the RP2350 hardware it targets. Authored with Claude.
1 parent 02f9866 commit a2b556e

6 files changed

Lines changed: 255 additions & 13 deletions

File tree

backends/arm/scripts/aot_arm_compiler.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from executorch.backends.arm.util._factory import create_partitioner, create_quantizer
3434

3535
from executorch.backends.arm.vgf import VgfCompileSpec
36+
from executorch.backends.cortex_m.compile_config import CortexMCompileConfig
3637
from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
3738

3839
from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
@@ -465,7 +466,17 @@ def forward(self, x):
465466
"TOSA-1.0+INT",
466467
"TOSA-1.0+FP",
467468
"TOSA-1.0+INT+int16",
469+
"cortex-m0+int8",
470+
"cortex-m0plus+int8",
471+
"cortex-m3+int8",
472+
"cortex-m4+int8",
473+
"cortex-m7+int8",
474+
"cortex-m23+int8",
475+
"cortex-m33+int8",
476+
"cortex-m35p+int8",
477+
"cortex-m52+int8",
468478
"cortex-m55+int8",
479+
"cortex-m85+int8",
469480
]
470481

471482

@@ -566,7 +577,7 @@ def _get_args():
566577
required=False,
567578
default="ethos-u55-128",
568579
choices=TARGETS,
569-
help=f"Target backend. For delegated models: Ethos-U/VGF/TOSA variants. For non-delegated: cortex-m55+int8 (CMSIS-NN portable kernels). Valid targets: {TARGETS}",
580+
help=f"Target backend. For delegated models: Ethos-U/VGF/TOSA variants. For non-delegated: cortex-m<variant>+int8 (CMSIS-NN portable kernels). Valid targets: {TARGETS}",
570581
)
571582
# TODO: Remove --evaluate and --evaluate_config completely after a suitable time.
572583
# They are deprecated and no longer functional in this script.
@@ -860,9 +871,14 @@ def _to_edge_cortex_m(
860871
model: GraphModule,
861872
example_inputs: Tuple[torch.Tensor],
862873
calibration_samples: Optional[List[Tuple[torch.Tensor, ...]]],
874+
config: CortexMCompileConfig,
863875
):
864876
"""Cortex-M/CMSIS-NN compilation path with no delegation."""
865-
logging.info("Using Cortex-M/CMSIS-NN compilation path (no delegation)")
877+
logging.info(
878+
"Using Cortex-M/CMSIS-NN compilation path for cpu=%s isa=%s",
879+
config.cpu,
880+
config.isa,
881+
)
866882

867883
def _to_channels_last(x):
868884
if isinstance(x, torch.Tensor):
@@ -915,7 +931,7 @@ def _to_channels_last(x):
915931
),
916932
)
917933

918-
pass_manager = CortexMPassManager(edge.exported_program())
934+
pass_manager = CortexMPassManager(edge.exported_program(), config=config)
919935
edge._edge_programs["forward"] = pass_manager.transform()
920936

921937
return model_quant, edge
@@ -1007,12 +1023,14 @@ def main() -> None: # noqa: C901
10071023
else:
10081024
quant_mode = None
10091025

1010-
if args.target == "cortex-m55+int8":
1026+
if args.target.startswith("cortex-m"):
10111027
# Cortex-M path: CMSIS-NN portable kernels, no delegation
1028+
cortex_m_config = CortexMCompileConfig.from_target_string(args.target)
10121029
if args.delegate:
10131030
logging.warning(
1014-
"--delegate is ignored for target 'cortex-m55+int8' "
1015-
"(this target does not use delegated ops)."
1031+
"--delegate is ignored for target %r "
1032+
"(this target does not use delegated ops).",
1033+
args.target,
10161034
)
10171035
args.delegate = False
10181036
model_quant, edge = _to_edge_cortex_m(
@@ -1021,6 +1039,7 @@ def main() -> None: # noqa: C901
10211039
model,
10221040
example_inputs,
10231041
calibration_samples,
1042+
cortex_m_config,
10241043
)
10251044
elif args.delegate:
10261045
# As we can target multiple output encodings, one must
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
from dataclasses import dataclass
10+
from typing import Literal
11+
12+
# Closed set of Cortex-M core names accepted as a compile target.
Cpu = Literal[
    "cortex-m0",
    "cortex-m0plus",
    "cortex-m3",
    "cortex-m4",
    "cortex-m7",
    "cortex-m23",
    "cortex-m33",
    "cortex-m35p",
    "cortex-m52",
    "cortex-m55",
    "cortex-m85",
]
# Instruction-set variant selected for the target.
Isa = Literal["scalar", "dsp", "mve"]

# Runtime mirror of the `Isa` literal. `Literal` is only checked by static type
# checkers, so explicit `isa=` values are validated against this set instead.
_SUPPORTED_ISAS: frozenset[str] = frozenset({"scalar", "dsp", "mve"})

# Default ISA per CPU follows the most common configuration each core is
# shipped with. M33/M35P optionally lack DSP, and M52/M55/M85 optionally
# lack MVE; callers can pass `isa=` explicitly to override.
_CPU_DEFAULT_ISA: dict[str, str] = {
    "cortex-m0": "scalar",
    "cortex-m0plus": "scalar",
    "cortex-m3": "scalar",
    "cortex-m4": "dsp",
    "cortex-m7": "dsp",
    "cortex-m23": "scalar",
    "cortex-m33": "dsp",
    "cortex-m35p": "dsp",
    "cortex-m52": "mve",
    "cortex-m55": "mve",
    "cortex-m85": "mve",
}

# Feature suffixes accepted after the CPU name in a target string.
_SUPPORTED_FEATURES: frozenset[str] = frozenset({"int8"})


@dataclass(frozen=True)
class CortexMCompileConfig:
    """AOT compile configuration for the Cortex-M backend.

    Carries the target `cpu` and `isa`. The pair is intended for passes whose
    behavior must differ by target, most notably any future AOT scratch-buffer
    sizing.

    The current default matches pre-config behavior (M55 + MVE) so callers that
    don't opt in see no change.

    Attributes:
        cpu: Core name; must be a key of `_CPU_DEFAULT_ISA`.
        isa: Instruction-set variant. When left as `None` it is replaced by
            the default registered for `cpu`, so it is never `None` after
            construction.

    Raises:
        ValueError: If `cpu` is not a registered core, or if an explicit
            `isa` is not one of the supported ISA names.
    """

    cpu: Cpu = "cortex-m55"
    isa: Isa | None = None

    def __post_init__(self) -> None:
        if self.cpu not in _CPU_DEFAULT_ISA:
            raise ValueError(
                f"Unsupported Cortex-M CPU: {self.cpu!r}. "
                f"Supported: {sorted(_CPU_DEFAULT_ISA)}"
            )
        if self.isa is None:
            # frozen dataclass: use object.__setattr__ to fill default ISA.
            object.__setattr__(self, "isa", _CPU_DEFAULT_ISA[self.cpu])
        elif self.isa not in _SUPPORTED_ISAS:
            # Reject typos such as "MVE" or "neon" instead of storing them.
            raise ValueError(
                f"Unsupported Cortex-M ISA: {self.isa!r}. "
                f"Supported: {sorted(_SUPPORTED_ISAS)}"
            )

    @classmethod
    def from_target_string(cls, target: str) -> CortexMCompileConfig:
        """Parse `cortex-m<variant>+int8` strings used by `aot_arm_compiler.py`.

        Today only `+int8` is supported. The suffix is required so the target
        string remains explicit about the data type contract.

        Args:
            target: Target string of the form `<cpu>+int8`.

        Returns:
            A config for the parsed CPU, using that CPU's default ISA.

        Raises:
            ValueError: If the feature suffix is missing, names an unsupported
                feature, or the CPU is not registered.
        """
        # Split on the first '+' only; CPU names themselves never contain '+'
        # (Cortex-M0+ is spelled "cortex-m0plus").
        cpu, sep, features = target.partition("+")
        if not sep:
            raise ValueError(
                f"Cortex-M target string must include a feature suffix "
                f"(e.g. '+int8'), got: {target!r}"
            )
        feature_set = set(features.split("+"))
        unknown = feature_set - _SUPPORTED_FEATURES
        # A trailing '+' yields an empty feature name, which lands in `unknown`.
        if unknown or "int8" not in feature_set:
            raise ValueError(
                f"Cortex-M target string must be '<cpu>+int8' "
                f"(supported features: {sorted(_SUPPORTED_FEATURES)}), "
                f"got: {target!r}"
            )
        if cpu not in _CPU_DEFAULT_ISA:
            raise ValueError(
                f"Unsupported Cortex-M CPU in target string: {cpu!r}. "
                f"Supported: {sorted(_CPU_DEFAULT_ISA)}"
            )
        return cls(cpu=cpu)  # type: ignore[arg-type]

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
FoldAndAnnotateQParamsPass,
1212
ScalarsToAttributePass,
1313
)
14+
from executorch.backends.cortex_m.compile_config import CortexMCompileConfig
1415
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
1516
from executorch.backends.transforms.replace_scalar_with_tensor import (
1617
ReplaceScalarWithTensorArgPass,
@@ -57,14 +58,18 @@ class CortexMPassManager(PassManager):
5758
]
5859

5960
def __init__(
60-
self, exported_program, passes: Optional[list[PassClass]] = None
61+
self,
62+
exported_program,
63+
passes: Optional[list[PassClass]] = None,
64+
config: Optional[CortexMCompileConfig] = None,
6165
) -> None:
6266
super().__init__(passes=[])
6367
self.exported_program = exported_program
6468
# PassManager.passes is typed as callables; this manager stores pass classes which are initialized at transform time with the exported_program.
6569
self.passes: list[PassClass] = ( # type: ignore[assignment]
6670
passes if passes is not None else self.pass_list # type: ignore[assignment]
6771
)
72+
self.config: CortexMCompileConfig = config or CortexMCompileConfig()
6873

6974
def transform_for_annotation(self, model):
7075
passes = self.pass_list_transform_for_annotation
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from importlib.util import find_spec
8+
9+
import pytest
10+
11+
from executorch.backends.cortex_m.compile_config import CortexMCompileConfig
12+
13+
# True when a module spec for `cmsis_nn` can be found; used below to skip the
# pass-manager wiring tests when it cannot.
_HAS_CMSIS_NN = find_spec("cmsis_nn") is not None
14+
15+
16+
class TestCortexMCompileConfig:
    """Tests for `CortexMCompileConfig` defaults, parsing and immutability."""

    def test_default_is_m55_mve(self):
        # A no-argument config resolves to the Cortex-M55 / MVE pair.
        default = CortexMCompileConfig()
        assert default.cpu == "cortex-m55"
        assert default.isa == "mve"

    @pytest.mark.parametrize(
        ("target_string", "expected_cpu", "expected_isa"),
        [
            ("cortex-m0+int8", "cortex-m0", "scalar"),
            ("cortex-m0plus+int8", "cortex-m0plus", "scalar"),
            ("cortex-m3+int8", "cortex-m3", "scalar"),
            ("cortex-m4+int8", "cortex-m4", "dsp"),
            ("cortex-m7+int8", "cortex-m7", "dsp"),
            ("cortex-m23+int8", "cortex-m23", "scalar"),
            ("cortex-m33+int8", "cortex-m33", "dsp"),
            ("cortex-m35p+int8", "cortex-m35p", "dsp"),
            ("cortex-m52+int8", "cortex-m52", "mve"),
            ("cortex-m55+int8", "cortex-m55", "mve"),
            ("cortex-m85+int8", "cortex-m85", "mve"),
        ],
    )
    def test_from_target_string(self, target_string, expected_cpu, expected_isa):
        # Each registered target string yields its CPU and that CPU's default ISA.
        parsed = CortexMCompileConfig.from_target_string(target_string)
        assert parsed.cpu == expected_cpu
        assert parsed.isa == expected_isa

    def test_from_target_string_rejects_unknown_cpu(self):
        # The error message must name the offending CPU.
        with pytest.raises(ValueError, match="cortex-m999"):
            CortexMCompileConfig.from_target_string("cortex-m999+int8")

    @pytest.mark.parametrize(
        "target_string",
        [
            "cortex-m55",  # missing feature suffix
            "cortex-m55+int8+int16",  # unsupported extra feature
            "cortex-m55+",  # trailing plus
            "cortex-m55+fp16",  # unknown feature
        ],
    )
    def test_from_target_string_rejects_invalid_features(self, target_string):
        with pytest.raises(ValueError):
            CortexMCompileConfig.from_target_string(target_string)

    def test_default_matches_m55_target_string(self):
        # Regression guard: pre-Phase-1 behavior was M55+MVE; the default
        # constructor must remain equivalent to parsing the existing target.
        from_string = CortexMCompileConfig.from_target_string("cortex-m55+int8")
        assert CortexMCompileConfig() == from_string

    def test_is_hashable_and_frozen(self):
        from dataclasses import FrozenInstanceError

        config = CortexMCompileConfig(cpu="cortex-m33")
        twin = CortexMCompileConfig(cpu="cortex-m33")
        # Equal configs hash alike and collapse to one set member.
        assert hash(config) == hash(twin)
        assert {config, twin} == {config}
        with pytest.raises(FrozenInstanceError):
            config.cpu = "cortex-m55"  # type: ignore[misc]

    def test_explicit_isa_override(self):
        # An explicit ISA is kept instead of the CPU's default.
        overridden = CortexMCompileConfig(cpu="cortex-m33", isa="scalar")
        assert overridden.cpu == "cortex-m33"
        assert overridden.isa == "scalar"
80+
81+
82+
@pytest.mark.skipif(not _HAS_CMSIS_NN, reason="cortex_m passes require cmsis_nn")
class TestPassManagerConfigWiring:
    """Checks that `CortexMPassManager` exposes the compile config it is given."""

    @staticmethod
    def _pass_manager_cls():
        # Imported inside the test run rather than at module import; presumably
        # so this file can still be collected without cmsis_nn (confirm).
        from executorch.backends.cortex_m.passes.cortex_m_pass_manager import (
            CortexMPassManager,
        )

        return CortexMPassManager

    def test_default_config_is_m55(self):
        # With no config argument the manager falls back to M55 + MVE.
        manager = self._pass_manager_cls()(exported_program=None)
        assert manager.config.cpu == "cortex-m55"
        assert manager.config.isa == "mve"

    def test_explicit_config_threaded(self):
        # A supplied config is stored, with the ISA defaulted from its CPU.
        m33_config = CortexMCompileConfig(cpu="cortex-m33")
        manager = self._pass_manager_cls()(exported_program=None, config=m33_config)
        assert manager.config.cpu == "cortex-m33"
        assert manager.config.isa == "dsp"

backends/cortex_m/test/tester.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66

77
from collections.abc import Callable
88
from dataclasses import dataclass
9-
from typing import Any
9+
from functools import partial
10+
from typing import Any, Optional
1011

1112
import torch
1213
from executorch.backends.arm.test.common import get_u55_compile_spec
1314
from executorch.backends.arm.test.tester.arm_tester import Serialize
15+
from executorch.backends.cortex_m.compile_config import CortexMCompileConfig
1416
from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
1517
from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer
1618
from executorch.backends.test.harness import Tester as TesterBase
@@ -48,9 +50,12 @@ def __init__(self):
4850

4951

5052
class CortexMRunPasses(RunPasses):
51-
def __init__(self):
53+
def __init__(self, config: Optional[CortexMCompileConfig] = None):
54+
config = config or CortexMCompileConfig()
55+
# The base RunPasses constructs the pass manager as `cls(ep, pass_list)`.
56+
# Pre-bind the config so it flows through that 2-arg call.
5257
super().__init__(
53-
CortexMPassManager,
58+
partial(CortexMPassManager, config=config), # type: ignore[arg-type]
5459
CortexMPassManager.pass_list,
5560
)
5661

@@ -73,12 +78,20 @@ def __init__(self):
7378

7479

7580
class CortexMTester(TesterBase):
76-
def __init__(self, module, example_inputs):
81+
def __init__(
82+
self,
83+
module,
84+
example_inputs,
85+
config: Optional[CortexMCompileConfig] = None,
86+
):
7787
if callable(example_inputs):
7888
resolved_example_inputs = example_inputs()
7989
else:
8090
resolved_example_inputs = example_inputs
81-
super().__init__(module, resolved_example_inputs, cortex_m_stage_classes)
91+
config = config or CortexMCompileConfig()
92+
stage_classes = dict(cortex_m_stage_classes)
93+
stage_classes[StageType.RUN_PASSES] = lambda: CortexMRunPasses(config=config)
94+
super().__init__(module, resolved_example_inputs, stage_classes)
8295

8396
def test_dialect(
8497
self,

examples/raspberry_pi/pico2/export_mlp_mnist_cmsis.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import torch
2626

27+
from executorch.backends.cortex_m.compile_config import CortexMCompileConfig
2728
from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
2829
from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer
2930
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
@@ -94,7 +95,10 @@ def export_to_pte(quantized_model, example_input, output_path: str):
9495
logger.info("Edge program created")
9596

9697
logger.info("Applying Cortex-M optimization passes...")
97-
pass_manager = CortexMPassManager(edge_program.exported_program())
98+
pass_manager = CortexMPassManager(
99+
edge_program.exported_program(),
100+
config=CortexMCompileConfig(cpu="cortex-m33"),
101+
)
98102
transformed_ep = pass_manager.transform()
99103

100104
edge_program = to_edge(transformed_ep, compile_config=edge_config)

0 commit comments

Comments
 (0)