Skip to content

Commit ac0a201

Browse files
authored
Arm backend: Add "normalized op name" -> "edge op name mapping" (pytorch#16825)
To produce exir operator test coverage reports for the Arm backend, it requires mapping from normalized op name to full edge op name. Add this to collect_testname_resources.py. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
1 parent 15ad846 commit ac0a201

1 file changed

Lines changed: 14 additions & 28 deletions

File tree

backends/arm/scripts/collect_testname_resources.py

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -132,34 +132,19 @@ def _collect_arm_models(models_md: pathlib.Path) -> set[str]:
132132
return models
133133

134134

135-
def _collect_arm_ops() -> set[str]:
136-
"""
137-
Returns a mapping from names on the form to be used in unittests to edge op:
138-
1. Names are in lowercase.
139-
2. Overload is ignored if 'default', otherwise it's appended with an underscore.
140-
3. Overly verbose name are shortened by removing certain prefixes/suffixes.
141-
142-
Examples:
143-
abs.default -> abs
144-
split_copy.Tensor -> split_tensor
145-
"""
146-
ops: set[str] = set()
147-
for edge_name in _ALL_EDGE_OPS:
148-
op, overload = edge_name.split(".")
149-
150-
# Normalize names
151-
op = op.lower()
152-
op = op.removeprefix("_")
153-
op = op.removesuffix("_copy")
154-
op = op.removesuffix("_with_indices")
155-
overload = overload.lower()
156-
157-
if overload == "default":
158-
ops.add(op)
159-
else:
160-
ops.add(f"{op}_{overload}")
135+
def _normalize_op_name(edge_name: str) -> str:
136+
op, overload = edge_name.split(".")
137+
138+
op = op.lower()
139+
op = op.removeprefix("_")
140+
op = op.removesuffix("_copy")
141+
op = op.removesuffix("_with_indices")
161142

162-
return ops
143+
overload = overload.lower()
144+
if overload == "default":
145+
return op
146+
else:
147+
return f"{op}_{overload}"
163148

164149

165150
def _split_model_entry(entry: str) -> tuple[str, str | None, bool]:
@@ -190,7 +175,8 @@ def _camel_to_snake(name: str) -> str:
190175
return _CAMEL_BOUNDARY.sub("_", name).lower()
191176

192177

193-
OP_LIST = sorted(_collect_arm_ops())
178+
OP_NAME_MAP = {_normalize_op_name(edge_name): edge_name for edge_name in _ALL_EDGE_OPS}
179+
OP_LIST = sorted({_normalize_op_name(edge_name) for edge_name in _ALL_EDGE_OPS})
194180
PASS_LIST = sorted(
195181
_collect_arm_passes(pathlib.Path("backends/arm/_passes/__init__.py"))
196182
)

0 commit comments

Comments
 (0)