Skip to content

Commit d7219c7

Browse files
Martin LindströmMartin Lindström
authored andcommitted
Arm backend: Add test for evaluate_model.py
Add two tests that exercise evaluate_model.py with different command line arguments. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Change-Id: I47304ea270518703dc4c826c4c6672c7aca95228
1 parent ae85eb9 commit d7219c7

1 file changed

Lines changed: 71 additions & 0 deletions

File tree

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import json
7+
import sys
8+
from pathlib import Path
9+
10+
from executorch.backends.arm.scripts import evaluate_model
11+
12+
13+
def _run_evaluate_model(*args: str) -> None:
14+
previous_argv = sys.argv
15+
try:
16+
sys.argv = ["evaluate_model.py", *args]
17+
evaluate_model.main()
18+
finally:
19+
sys.argv = previous_argv
20+
21+
22+
def test_evaluate_model_tosa_INT(tmp_path: Path) -> None:
23+
intermediates = tmp_path / "test_evaluate_model_tosa_INT_intermediates"
24+
output = tmp_path / "test_evaluate_model_tosa_INT_metrics.json"
25+
26+
_run_evaluate_model(
27+
"--model_name",
28+
"add",
29+
"--target",
30+
"TOSA-1.0+INT",
31+
"--quant_mode",
32+
"int8",
33+
"--no_delegate",
34+
"--evaluators",
35+
"numerical",
36+
"--intermediates",
37+
str(intermediates),
38+
"--output",
39+
str(output),
40+
)
41+
42+
assert output.exists(), f"Metrics file not created at {output}"
43+
data = json.loads(output.read_text())
44+
assert data["name"] == "add"
45+
assert "metrics" in data
46+
assert "mean_absolute_error" in data["metrics"]
47+
48+
49+
def test_evaluate_model_tosa_FP(tmp_path: Path) -> None:
50+
intermediates = tmp_path / "test_evaluate_model_tosa_FP_intermediates"
51+
output = tmp_path / "test_evaluate_model_tosa_FP_metrics.json"
52+
53+
_run_evaluate_model(
54+
"--model_name",
55+
"add",
56+
"--target",
57+
"TOSA-1.0+FP",
58+
"--evaluators",
59+
"numerical",
60+
"--intermediates",
61+
str(intermediates),
62+
"--output",
63+
str(output),
64+
)
65+
66+
assert output.exists(), f"Metrics file not created at {output}"
67+
data = json.loads(output.read_text())
68+
assert data["name"] == "add"
69+
assert "metrics" in data
70+
assert "mean_absolute_error" in data["metrics"]
71+
assert "compression_ratio" in data["metrics"]

0 commit comments

Comments
 (0)