Skip to content

Commit 99be90d

Browse files
committed
issue/497 - add a script to run all op tests
1 parent 932983b commit 99be90d

5 files changed

Lines changed: 269 additions & 8 deletions

File tree

test/infinicore/ops.py

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
import os
2+
import sys
3+
import subprocess
4+
import argparse
5+
from pathlib import Path
6+
7+
8+
def find_ops_directory(start_dir=None):
9+
"""
10+
Find the ops directory by searching from start_dir upwards.
11+
"""
12+
if start_dir is None:
13+
start_dir = Path(__file__).parent
14+
15+
ops_dir = start_dir / "ops"
16+
if ops_dir.exists() and (ops_dir / "rms_norm.py").exists():
17+
return ops_dir
18+
19+
20+
def run_all_op_tests(ops_dir=None, verbose=False, specific_ops=None, extra_args=None):
21+
"""
22+
Run all operator test scripts in the ops directory.
23+
24+
Args:
25+
ops_dir (str, optional): Path to the ops directory. If None, uses the current directory.
26+
verbose (bool): Whether to print detailed output.
27+
specific_ops (list, optional): List of specific operator names to test (e.g., ['add', 'matmul']).
28+
extra_args (list, optional): Extra command line arguments to pass to test scripts.
29+
30+
Returns:
31+
dict: Results dictionary with test names as keys and (success, return_code, output) as values.
32+
"""
33+
if ops_dir is None:
34+
ops_dir = find_ops_directory()
35+
else:
36+
ops_dir = Path(ops_dir)
37+
38+
if not ops_dir.exists():
39+
print(f"Error: Ops directory '{ops_dir}' does not exist.")
40+
return {}
41+
42+
print(f"Looking for test files in: {ops_dir}")
43+
44+
# Find all Python test files (looking for actual operator test files)
45+
test_files = list(ops_dir.glob("*.py"))
46+
47+
# Filter out this script itself and non-operator test files
48+
current_script = Path(__file__).name
49+
test_files = [f for f in test_files if f.name != current_script]
50+
51+
# Further filter to include only files that look like operator tests
52+
# (they typically import infinicore and BaseOperatorTest)
53+
operator_test_files = []
54+
for test_file in test_files:
55+
try:
56+
with open(test_file, "r", encoding="utf-8") as f:
57+
content = f.read()
58+
if "infinicore" in content and "BaseOperatorTest" in content:
59+
operator_test_files.append(test_file)
60+
elif verbose:
61+
print(f" Skipping {test_file.name}: not an operator test file")
62+
except Exception as e:
63+
if verbose:
64+
print(f" Could not read {test_file.name}: {e}")
65+
continue
66+
67+
if specific_ops:
68+
# Filter for specific operators (case insensitive)
69+
filtered_files = []
70+
for test_file in operator_test_files:
71+
test_name = test_file.stem.lower()
72+
if any(op.lower() in test_name for op in specific_ops):
73+
filtered_files.append(test_file)
74+
elif verbose:
75+
print(f" Filtered out {test_file.name}: not in specific_ops list")
76+
operator_test_files = filtered_files
77+
78+
if not operator_test_files:
79+
print(f"No operator test files found in {ops_dir}")
80+
print(f"Available Python files: {[f.name for f in test_files]}")
81+
print(f"Current directory: {Path.cwd()}")
82+
return {}
83+
84+
print(f"Found {len(operator_test_files)} operator test files:")
85+
for test_file in operator_test_files:
86+
print(f" - {test_file.name}")
87+
88+
results = {}
89+
90+
for test_file in operator_test_files:
91+
test_name = test_file.stem
92+
93+
try:
94+
# Run the test script
95+
cmd = [sys.executable, str(test_file)]
96+
97+
# Add extra arguments if provided
98+
if extra_args:
99+
cmd.extend(extra_args)
100+
101+
if verbose:
102+
print(f"Command: {' '.join(cmd)}")
103+
print(f"Working directory: {ops_dir}")
104+
105+
# Always capture output to display it
106+
result = subprocess.run(cmd, cwd=ops_dir, capture_output=True, text=True)
107+
108+
success = result.returncode == 0
109+
results[test_name] = (
110+
success,
111+
result.returncode,
112+
result.stdout,
113+
result.stderr,
114+
)
115+
116+
# Print the output from the test script
117+
if result.stdout:
118+
print(result.stdout)
119+
120+
if result.stderr:
121+
print("STDERR:")
122+
print(result.stderr)
123+
124+
if success:
125+
print(f"✅ {test_name}: PASSED (return code: {result.returncode})")
126+
else:
127+
print(f"❌ {test_name}: FAILED (return code: {result.returncode})")
128+
129+
except Exception as e:
130+
print(f"❌ {test_name}: ERROR - {str(e)}")
131+
results[test_name] = (False, -1, "", str(e))
132+
133+
return results
134+
135+
136+
def print_summary(results):
137+
"""Print a summary of test results."""
138+
print(f"\n{'='*80}")
139+
print("TEST SUMMARY")
140+
print(f"{'='*80}")
141+
142+
if not results:
143+
print("No tests were run.")
144+
return
145+
146+
passed = sum(1 for success, _, _, _ in results.values() if success)
147+
total = len(results)
148+
149+
print(f"Total tests: {total}")
150+
print(f"Passed: {passed}")
151+
print(f"Failed: {total - passed}")
152+
153+
if total > 0:
154+
print(f"Success rate: {passed/total*100:.1f}%")
155+
156+
if passed == total:
157+
print("\n🎉 All tests passed!")
158+
else:
159+
print("\nFailed tests:")
160+
for test_name, (success, returncode, stdout, stderr) in results.items():
161+
if not success:
162+
print(f" - {test_name} (return code: {returncode})")
163+
# Print brief error info for failed tests
164+
if stderr:
165+
error_lines = stderr.strip().split("\n")
166+
if error_lines:
167+
print(f" Error: {error_lines[0]}")
168+
169+
170+
def main():
171+
"""Main entry point with command line argument parsing."""
172+
parser = argparse.ArgumentParser(
173+
description="Run all operator tests in the ops directory", add_help=False
174+
)
175+
176+
# Our script's specific arguments
177+
parser.add_argument(
178+
"--ops-dir", type=str, help="Path to the ops directory (default: auto-detect)"
179+
)
180+
parser.add_argument(
181+
"-v",
182+
"--verbose",
183+
action="store_true",
184+
help="Print detailed command information for each test",
185+
)
186+
parser.add_argument(
187+
"--ops", nargs="+", help="Run specific operators only (e.g., --ops add matmul)"
188+
)
189+
parser.add_argument(
190+
"--list",
191+
action="store_true",
192+
help="List all available test files without running them",
193+
)
194+
parser.add_argument(
195+
"-h", "--help", action="store_true", help="Show this help message and exit"
196+
)
197+
198+
# Parse known args first, leave the rest for the test scripts
199+
args, unknown_args = parser.parse_known_args()
200+
201+
if args.help:
202+
parser.print_help()
203+
print("\nExtra arguments that will be passed to test scripts:")
204+
print(" --nvidia, --cpu, --bench, --debug, etc.")
205+
return
206+
207+
# Auto-detect ops directory if not provided
208+
if args.ops_dir is None:
209+
ops_dir = find_ops_directory()
210+
else:
211+
ops_dir = Path(args.ops_dir)
212+
213+
if args.list:
214+
# Just list available test files
215+
test_files = list(ops_dir.glob("*.py"))
216+
current_script = Path(__file__).name
217+
test_files = [f for f in test_files if f.name != current_script]
218+
219+
operator_test_files = []
220+
for test_file in test_files:
221+
try:
222+
with open(test_file, "r", encoding="utf-8") as f:
223+
content = f.read()
224+
if "infinicore" in content and "BaseOperatorTest" in content:
225+
operator_test_files.append(test_file)
226+
except:
227+
continue
228+
229+
if operator_test_files:
230+
print(f"Available operator test files in {ops_dir}:")
231+
for test_file in operator_test_files:
232+
print(f" - {test_file.name}")
233+
else:
234+
print(f"No operator test files found in {ops_dir}")
235+
print(f"Available Python files: {[f.name for f in test_files]}")
236+
return
237+
238+
# Show what extra arguments will be passed
239+
if unknown_args:
240+
print(f"Passing extra arguments to test scripts: {unknown_args}")
241+
242+
# Run all tests
243+
results = run_all_op_tests(
244+
ops_dir=ops_dir,
245+
verbose=args.verbose,
246+
specific_ops=args.ops,
247+
extra_args=unknown_args,
248+
)
249+
250+
print_summary(results)
251+
252+
# Exit with appropriate code
253+
if results and all(success for success, _, _, _ in results.values()):
254+
sys.exit(0)
255+
else:
256+
sys.exit(1)
257+
258+
259+
if __name__ == "__main__":
260+
main()

test/infinicore/ops/add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def parse_add_test_case(data):
7474
}
7575

7676

77-
class AddTest(BaseOperatorTest):
77+
class OpTest(BaseOperatorTest):
7878
"""Add test with simplified test case parsing"""
7979

8080
def __init__(self):
@@ -98,7 +98,7 @@ def infinicore_operator(self, a, b, out=None, **kwargs):
9898

9999
def main():
100100
"""Main entry point"""
101-
runner = GenericTestRunner(AddTest)
101+
runner = GenericTestRunner(OpTest)
102102
runner.run_and_exit()
103103

104104

test/infinicore/ops/attention_temp.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
This is for framework validation
33
"""
4+
45
import sys
56
import os
67

@@ -229,7 +230,7 @@ def parse_attention_test_case(data):
229230
}
230231

231232

232-
class AttentionTest(BaseOperatorTest):
233+
class OpTest(BaseOperatorTest):
233234
"""Attention test with simplified test case parsing"""
234235

235236
def __init__(self):
@@ -259,7 +260,7 @@ def infinicore_operator(self, q, k, v, k_cache, v_cache, pos, out=None, **kwargs
259260

260261
def main():
261262
"""Main entry point"""
262-
runner = GenericTestRunner(AttentionTest)
263+
runner = GenericTestRunner(OpTest)
263264
runner.run_and_exit()
264265

265266

test/infinicore/ops/matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def parse_matmul_test_case(data):
9090
}
9191

9292

93-
class MatmulTest(BaseOperatorTest):
93+
class OpTest(BaseOperatorTest):
9494
"""Matmul test with simplified test case parsing"""
9595

9696
def __init__(self):
@@ -114,7 +114,7 @@ def infinicore_operator(self, a, b, out=None, **kwargs):
114114

115115
def main():
116116
"""Main entry point"""
117-
runner = GenericTestRunner(MatmulTest)
117+
runner = GenericTestRunner(OpTest)
118118
runner.run_and_exit()
119119

120120

test/infinicore/ops/rms_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def parse_rms_norm_test_case(data):
8888
_EPSILON = 1e-5
8989

9090

91-
class RMSNormTest(BaseOperatorTest):
91+
class OpTest(BaseOperatorTest):
9292
"""RMSNorm test with simplified test case parsing"""
9393

9494
def __init__(self):
@@ -124,7 +124,7 @@ def infinicore_operator(self, x, weight, out=None, **kwargs):
124124

125125
def main():
126126
"""Main entry point"""
127-
runner = GenericTestRunner(RMSNormTest)
127+
runner = GenericTestRunner(OpTest)
128128
runner.run_and_exit()
129129

130130

0 commit comments

Comments
 (0)