Skip to content

Commit 22922df

Browse files
Martin LindströmMartin Lindström
authored andcommitted
Arm backend: Add test for evaluate_model.py
Add two tests that exercise evaluate_model.py with different command line arguments. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Change-Id: I47304ea270518703dc4c826c4c6672c7aca95228
1 parent 44186e8 commit 22922df

2 files changed

Lines changed: 73 additions & 0 deletions

File tree

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import json
8+
import sys
9+
from pathlib import Path
10+
11+
from executorch.backends.arm.scripts import evaluate_model
12+
13+
14+
def _run_evaluate_model(*args: str) -> None:
15+
previous_argv = sys.argv
16+
try:
17+
sys.argv = ["evaluate_model.py", *args]
18+
evaluate_model.main()
19+
finally:
20+
sys.argv = previous_argv
21+
22+
23+
def test_evaluate_model_tosa_INT(tmp_path: Path) -> None:
24+
intermediates = tmp_path / "test_evaluate_model_tosa_INT_intermediates"
25+
output = tmp_path / "test_evaluate_model_tosa_INT_metrics.json"
26+
27+
_run_evaluate_model(
28+
"--model_name",
29+
"add",
30+
"--target",
31+
"TOSA-1.0+INT",
32+
"--quant_mode",
33+
"int8",
34+
"--no_delegate",
35+
"--evaluators",
36+
"numerical",
37+
"--intermediates",
38+
str(intermediates),
39+
"--output",
40+
str(output),
41+
)
42+
43+
assert output.exists(), f"Metrics file not created at {output}"
44+
data = json.loads(output.read_text())
45+
assert data["name"] == "add"
46+
assert "metrics" in data
47+
assert "mean_absolute_error" in data["metrics"]
48+
49+
50+
def test_evaluate_model_tosa_FP(tmp_path: Path) -> None:
51+
intermediates = tmp_path / "test_evaluate_model_tosa_FP_intermediates"
52+
output = tmp_path / "test_evaluate_model_tosa_FP_metrics.json"
53+
54+
_run_evaluate_model(
55+
"--model_name",
56+
"add",
57+
"--target",
58+
"TOSA-1.0+FP",
59+
"--evaluators",
60+
"numerical",
61+
"--intermediates",
62+
str(intermediates),
63+
"--output",
64+
str(output),
65+
)
66+
67+
assert output.exists(), f"Metrics file not created at {output}"
68+
data = json.loads(output.read_text())
69+
assert data["name"] == "add"
70+
assert "metrics" in data
71+
assert "mean_absolute_error" in data["metrics"]
72+
assert "compression_ratio" in data["metrics"]

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def define_arm_tests():
4444
# Misc tests
4545
test_files += [
4646
"misc/test_compile_spec.py",
47+
"misc/test_evaluate_model.py",
4748
"misc/test_pass_pipeline_config.py",
4849
"misc/test_tosa_spec.py",
4950
"misc/test_bn_relu_folding_qat.py",

0 commit comments

Comments
 (0)