Skip to content

Commit 268dd80

Browse files
committed
issue/497 - add a script to run all op tests
1 parent 932983b commit 268dd80

5 files changed

Lines changed: 287 additions & 8 deletions

File tree

test/infinicore/ops.py

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
import os
2+
import sys
3+
import subprocess
4+
import argparse
5+
from pathlib import Path
6+
7+
8+
def find_ops_directory(start_dir=None):
9+
"""
10+
Find the ops directory by searching from start_dir upwards.
11+
"""
12+
if start_dir is None:
13+
start_dir = Path(__file__).parent
14+
15+
# Check if we're already in the ops directory
16+
if (start_dir / "rms_norm.py").exists() and (start_dir / "add.py").exists():
17+
return start_dir
18+
19+
# Check if there's an ops subdirectory
20+
ops_dir = start_dir / "ops"
21+
if ops_dir.exists() and (ops_dir / "rms_norm.py").exists():
22+
return ops_dir
23+
24+
# Check parent directory
25+
parent_ops_dir = start_dir.parent / "ops"
26+
if parent_ops_dir.exists() and (parent_ops_dir / "rms_norm.py").exists():
27+
return parent_ops_dir
28+
29+
# If we can't find it, return the most likely candidate
30+
if ops_dir.exists():
31+
return ops_dir
32+
elif parent_ops_dir.exists():
33+
return parent_ops_dir
34+
else:
35+
return start_dir
36+
37+
38+
def run_all_op_tests(ops_dir=None, verbose=False, specific_ops=None, extra_args=None):
39+
"""
40+
Run all operator test scripts in the ops directory.
41+
42+
Args:
43+
ops_dir (str, optional): Path to the ops directory. If None, uses the current directory.
44+
verbose (bool): Whether to print detailed output.
45+
specific_ops (list, optional): List of specific operator names to test (e.g., ['add', 'matmul']).
46+
extra_args (list, optional): Extra command line arguments to pass to test scripts.
47+
48+
Returns:
49+
dict: Results dictionary with test names as keys and (success, return_code, output) as values.
50+
"""
51+
if ops_dir is None:
52+
ops_dir = find_ops_directory()
53+
else:
54+
ops_dir = Path(ops_dir)
55+
56+
if not ops_dir.exists():
57+
print(f"Error: Ops directory '{ops_dir}' does not exist.")
58+
return {}
59+
60+
print(f"Looking for test files in: {ops_dir}")
61+
62+
# Find all Python test files (looking for actual operator test files)
63+
test_files = list(ops_dir.glob("*.py"))
64+
65+
# Filter out this script itself and non-operator test files
66+
current_script = Path(__file__).name
67+
test_files = [f for f in test_files if f.name != current_script]
68+
69+
# Further filter to include only files that look like operator tests
70+
# (they typically import infinicore and BaseOperatorTest)
71+
operator_test_files = []
72+
for test_file in test_files:
73+
try:
74+
with open(test_file, "r", encoding="utf-8") as f:
75+
content = f.read()
76+
if "infinicore" in content and "BaseOperatorTest" in content:
77+
operator_test_files.append(test_file)
78+
elif verbose:
79+
print(f" Skipping {test_file.name}: not an operator test file")
80+
except Exception as e:
81+
if verbose:
82+
print(f" Could not read {test_file.name}: {e}")
83+
continue
84+
85+
if specific_ops:
86+
# Filter for specific operators (case insensitive)
87+
filtered_files = []
88+
for test_file in operator_test_files:
89+
test_name = test_file.stem.lower()
90+
if any(op.lower() in test_name for op in specific_ops):
91+
filtered_files.append(test_file)
92+
elif verbose:
93+
print(f" Filtered out {test_file.name}: not in specific_ops list")
94+
operator_test_files = filtered_files
95+
96+
if not operator_test_files:
97+
print(f"No operator test files found in {ops_dir}")
98+
print(f"Available Python files: {[f.name for f in test_files]}")
99+
print(f"Current directory: {Path.cwd()}")
100+
return {}
101+
102+
print(f"Found {len(operator_test_files)} operator test files:")
103+
for test_file in operator_test_files:
104+
print(f" - {test_file.name}")
105+
106+
results = {}
107+
108+
for test_file in operator_test_files:
109+
test_name = test_file.stem
110+
111+
try:
112+
# Run the test script
113+
cmd = [sys.executable, str(test_file)]
114+
115+
# Add extra arguments if provided
116+
if extra_args:
117+
cmd.extend(extra_args)
118+
119+
if verbose:
120+
print(f"Command: {' '.join(cmd)}")
121+
print(f"Working directory: {ops_dir}")
122+
123+
# Always capture output to display it
124+
result = subprocess.run(cmd, cwd=ops_dir, capture_output=True, text=True)
125+
126+
success = result.returncode == 0
127+
results[test_name] = (
128+
success,
129+
result.returncode,
130+
result.stdout,
131+
result.stderr,
132+
)
133+
134+
# Print the output from the test script
135+
if result.stdout:
136+
print(result.stdout)
137+
138+
if result.stderr:
139+
print("STDERR:")
140+
print(result.stderr)
141+
142+
if success:
143+
print(f"✅ {test_name}: PASSED (return code: {result.returncode})")
144+
else:
145+
print(f"❌ {test_name}: FAILED (return code: {result.returncode})")
146+
147+
except Exception as e:
148+
print(f"❌ {test_name}: ERROR - {str(e)}")
149+
results[test_name] = (False, -1, "", str(e))
150+
151+
return results
152+
153+
154+
def print_summary(results):
155+
"""Print a summary of test results."""
156+
print(f"\n{'='*80}")
157+
print("TEST SUMMARY")
158+
print(f"{'='*80}")
159+
160+
if not results:
161+
print("No tests were run.")
162+
return
163+
164+
passed = sum(1 for success, _, _, _ in results.values() if success)
165+
total = len(results)
166+
167+
print(f"Total tests: {total}")
168+
print(f"Passed: {passed}")
169+
print(f"Failed: {total - passed}")
170+
171+
if total > 0:
172+
print(f"Success rate: {passed/total*100:.1f}%")
173+
174+
if passed == total:
175+
print("\n🎉 All tests passed!")
176+
else:
177+
print("\nFailed tests:")
178+
for test_name, (success, returncode, stdout, stderr) in results.items():
179+
if not success:
180+
print(f" - {test_name} (return code: {returncode})")
181+
# Print brief error info for failed tests
182+
if stderr:
183+
error_lines = stderr.strip().split("\n")
184+
if error_lines:
185+
print(f" Error: {error_lines[0]}")
186+
187+
188+
def main():
189+
"""Main entry point with command line argument parsing."""
190+
parser = argparse.ArgumentParser(
191+
description="Run all operator tests in the ops directory", add_help=False
192+
)
193+
194+
# Our script's specific arguments
195+
parser.add_argument(
196+
"--ops-dir", type=str, help="Path to the ops directory (default: auto-detect)"
197+
)
198+
parser.add_argument(
199+
"-v",
200+
"--verbose",
201+
action="store_true",
202+
help="Print detailed command information for each test",
203+
)
204+
parser.add_argument(
205+
"--ops", nargs="+", help="Run specific operators only (e.g., --ops add matmul)"
206+
)
207+
parser.add_argument(
208+
"--list",
209+
action="store_true",
210+
help="List all available test files without running them",
211+
)
212+
parser.add_argument(
213+
"-h", "--help", action="store_true", help="Show this help message and exit"
214+
)
215+
216+
# Parse known args first, leave the rest for the test scripts
217+
args, unknown_args = parser.parse_known_args()
218+
219+
if args.help:
220+
parser.print_help()
221+
print("\nExtra arguments that will be passed to test scripts:")
222+
print(" --nvidia, --cpu, --bench, --debug, etc.")
223+
return
224+
225+
# Auto-detect ops directory if not provided
226+
if args.ops_dir is None:
227+
ops_dir = find_ops_directory()
228+
else:
229+
ops_dir = Path(args.ops_dir)
230+
231+
if args.list:
232+
# Just list available test files
233+
test_files = list(ops_dir.glob("*.py"))
234+
current_script = Path(__file__).name
235+
test_files = [f for f in test_files if f.name != current_script]
236+
237+
operator_test_files = []
238+
for test_file in test_files:
239+
try:
240+
with open(test_file, "r", encoding="utf-8") as f:
241+
content = f.read()
242+
if "infinicore" in content and "BaseOperatorTest" in content:
243+
operator_test_files.append(test_file)
244+
except:
245+
continue
246+
247+
if operator_test_files:
248+
print(f"Available operator test files in {ops_dir}:")
249+
for test_file in operator_test_files:
250+
print(f" - {test_file.name}")
251+
else:
252+
print(f"No operator test files found in {ops_dir}")
253+
print(f"Available Python files: {[f.name for f in test_files]}")
254+
return
255+
256+
# Show what extra arguments will be passed
257+
if unknown_args:
258+
print(f"Passing extra arguments to test scripts: {unknown_args}")
259+
260+
# Run all tests
261+
results = run_all_op_tests(
262+
ops_dir=ops_dir,
263+
verbose=args.verbose,
264+
specific_ops=args.ops,
265+
extra_args=unknown_args,
266+
)
267+
268+
print_summary(results)
269+
270+
# Exit with appropriate code
271+
if results and all(success for success, _, _, _ in results.values()):
272+
sys.exit(0)
273+
else:
274+
sys.exit(1)
275+
276+
277+
if __name__ == "__main__":
278+
main()

test/infinicore/ops/add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def parse_add_test_case(data):
7474
}
7575

7676

77-
class AddTest(BaseOperatorTest):
77+
class OpTest(BaseOperatorTest):
7878
"""Add test with simplified test case parsing"""
7979

8080
def __init__(self):
@@ -98,7 +98,7 @@ def infinicore_operator(self, a, b, out=None, **kwargs):
9898

9999
def main():
100100
"""Main entry point"""
101-
runner = GenericTestRunner(AddTest)
101+
runner = GenericTestRunner(OpTest)
102102
runner.run_and_exit()
103103

104104

test/infinicore/ops/attention_temp.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
This is for framework validation
33
"""
4+
45
import sys
56
import os
67

@@ -229,7 +230,7 @@ def parse_attention_test_case(data):
229230
}
230231

231232

232-
class AttentionTest(BaseOperatorTest):
233+
class OpTest(BaseOperatorTest):
233234
"""Attention test with simplified test case parsing"""
234235

235236
def __init__(self):
@@ -259,7 +260,7 @@ def infinicore_operator(self, q, k, v, k_cache, v_cache, pos, out=None, **kwargs
259260

260261
def main():
261262
"""Main entry point"""
262-
runner = GenericTestRunner(AttentionTest)
263+
runner = GenericTestRunner(OpTest)
263264
runner.run_and_exit()
264265

265266

test/infinicore/ops/matmul.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def parse_matmul_test_case(data):
9090
}
9191

9292

93-
class MatmulTest(BaseOperatorTest):
93+
class OpTest(BaseOperatorTest):
9494
"""Matmul test with simplified test case parsing"""
9595

9696
def __init__(self):
@@ -114,7 +114,7 @@ def infinicore_operator(self, a, b, out=None, **kwargs):
114114

115115
def main():
116116
"""Main entry point"""
117-
runner = GenericTestRunner(MatmulTest)
117+
runner = GenericTestRunner(OpTest)
118118
runner.run_and_exit()
119119

120120

test/infinicore/ops/rms_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def parse_rms_norm_test_case(data):
8888
_EPSILON = 1e-5
8989

9090

91-
class RMSNormTest(BaseOperatorTest):
91+
class OpTest(BaseOperatorTest):
9292
"""RMSNorm test with simplified test case parsing"""
9393

9494
def __init__(self):
@@ -124,7 +124,7 @@ def infinicore_operator(self, x, weight, out=None, **kwargs):
124124

125125
def main():
126126
"""Main entry point"""
127-
runner = GenericTestRunner(RMSNormTest)
127+
runner = GenericTestRunner(OpTest)
128128
runner.run_and_exit()
129129

130130

0 commit comments

Comments
 (0)