Skip to content

Commit 074270b

Browse files
committed
issue/497 - now requires test function definition
1 parent 558066f commit 074270b

4 files changed

Lines changed: 131 additions & 72 deletions

File tree

test/infinicore/framework/base.py

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import infinicore
33

44
from abc import ABC, abstractmethod
5-
from typing import List, Dict, Any, Tuple, Union
5+
from typing import List, Dict, Any, Tuple, Union, Callable, Optional
66

77
from .datatypes import to_torch_dtype, to_infinicore_dtype
88
from .devices import InfiniDeviceNames, torch_device_map
@@ -66,12 +66,12 @@ class TestCase:
6666

6767
def __init__(self, inputs, output=None, **kwargs):
6868
"""
69-
简化构造函数
69+
Simplified constructor
7070
Args:
71-
inputs: List[TensorSpec] 或简单的形状元组
72-
output: TensorSpec 或形状元组
71+
inputs: List[TensorSpec] or simple shape tuples
72+
output: TensorSpec or shape tuple
7373
"""
74-
# 标准化 inputs
74+
# Normalize inputs
7575
self.inputs = []
7676
for inp in inputs:
7777
if isinstance(inp, (list, tuple)):
@@ -81,7 +81,7 @@ def __init__(self, inputs, output=None, **kwargs):
8181
else:
8282
self.inputs.append(inp)
8383

84-
# 标准化 output
84+
# Normalize output
8585
if isinstance(output, (list, tuple)):
8686
self.output = TensorSpec.from_tensor(output)
8787
else:
@@ -142,11 +142,11 @@ def __init__(self, test_cases, test_config):
142142
self.config = test_config
143143
self.failed_tests = [] # Track failures
144144

145-
def run_tests(self, devices, test_func):
145+
def run_tests(self, devices, test_func, test_type="Test"):
146146
"""Run tests and track failures"""
147147
for device in devices:
148148
print(f"\n{'='*60}")
149-
print(f"Testing on {InfiniDeviceNames[device]}")
149+
print(f"Testing {test_type} on {InfiniDeviceNames[device]}")
150150
print(f"{'='*60}")
151151

152152
# Filter unsupported data types
@@ -213,15 +213,17 @@ def get_tolerance_map(self):
213213
"""Return tolerance configuration"""
214214
pass
215215

216-
@abstractmethod
217-
def torch_operator(self, *inputs, **kwargs):
218-
"""PyTorch operator implementation"""
219-
pass
216+
def has_out_of_place_test(self):
217+
"""Check if out-of-place test functions are defined"""
218+
return hasattr(self, "torch_operator_out_of_place") and hasattr(
219+
self, "infinicore_operator_out_of_place"
220+
)
220221

221-
@abstractmethod
222-
def infinicore_operator(self, *inputs, **kwargs):
223-
"""Infinicore operator implementation"""
224-
pass
222+
def has_inplace_test(self):
223+
"""Check if in-place test functions are defined"""
224+
return hasattr(self, "torch_operator_inplace") and hasattr(
225+
self, "infinicore_operator_inplace"
226+
)
225227

226228
def create_strided_tensor(self, shape, strides, dtype, device_str):
227229
"""Create a non-contiguous tensor with specific strides"""
@@ -274,16 +276,19 @@ def prepare_inputs(self, test_case, device_str, dtype):
274276

275277
return inputs, test_case.kwargs
276278

277-
def run_test(self, device, test_case, dtype, config):
278-
"""Generic test execution flow with flexible inputs - output is always contiguous"""
279+
def run_out_of_place_test(self, device, test_case, dtype, config):
280+
"""Generic out-of-place test execution flow"""
281+
if not self.has_out_of_place_test():
282+
raise NotImplementedError("Out-of-place test functions not defined")
283+
279284
device_str = torch_device_map[device]
280285

281286
# Prepare inputs
282287
inputs, kwargs = self.prepare_inputs(test_case, device_str, dtype)
283288

284-
# PyTorch reference result - output is always contiguous for out-of-place
289+
# PyTorch reference result
285290
def torch_op():
286-
return self.torch_operator(*inputs, **kwargs)
291+
return self.torch_operator_out_of_place(*inputs, **kwargs)
287292

288293
torch_result = torch_op()
289294

@@ -310,36 +315,39 @@ def torch_op():
310315
else:
311316
infini_inputs.append(inp)
312317

313-
# Infinicore result - output is always contiguous for out-of-place
318+
# Infinicore result
314319
def infini_op():
315-
return self.infinicore_operator(*infini_inputs, **kwargs)
320+
return self.infinicore_operator_out_of_place(*infini_inputs, **kwargs)
316321

317322
infini_result = infini_op()
318323

319324
# Result comparison
320325
compare_fn = create_test_comparator(config, dtype)
321326
is_valid = compare_fn(infini_result, torch_result)
322-
assert is_valid, f"{self.operator_name} test failed"
327+
assert is_valid, f"{self.operator_name} out-of-place test failed"
323328

324329
# Performance testing
325330
if config.bench:
326331
profile_operation(
327-
f"PyTorch {self.operator_name}",
332+
f"PyTorch {self.operator_name} Out-of-place",
328333
torch_op,
329334
device_str,
330335
config.num_prerun,
331336
config.num_iterations,
332337
)
333338
profile_operation(
334-
f"Infinicore {self.operator_name}",
339+
f"Infinicore {self.operator_name} Out-of-place",
335340
infini_op,
336341
device_str,
337342
config.num_prerun,
338343
config.num_iterations,
339344
)
340345

341346
def run_inplace_test(self, device, test_case, dtype, config):
342-
"""Generic in-place operation test execution flow - supports strided output"""
347+
"""Generic in-place operation test execution flow"""
348+
if not self.has_inplace_test():
349+
raise NotImplementedError("In-place test functions not defined")
350+
343351
device_str = torch_device_map[device]
344352

345353
# Prepare inputs and output
@@ -370,7 +378,7 @@ def run_inplace_test(self, device, test_case, dtype, config):
370378
torch_preallocated.zero_()
371379

372380
def torch_op_inplace():
373-
self.torch_operator(*inputs, out=torch_preallocated, **kwargs)
381+
self.torch_operator_inplace(*inputs, out=torch_preallocated, **kwargs)
374382

375383
torch_op_inplace()
376384

@@ -392,19 +400,7 @@ def torch_op_inplace():
392400
else:
393401
infini_inputs.append(inp)
394402

395-
# # Create infinicore output tensor
396-
# if test_case.output.is_contiguous or test_case.output.strides is None:
397-
# infini_output = infinicore.empty(
398-
# output_shape, dtype=dtype, device=infinicore.device(device_str, 0)
399-
# )
400-
# else:
401-
# infini_output = infinicore.strided_empty(
402-
# output_shape,
403-
# test_case.output.strides,
404-
# dtype=dtype,
405-
# device=infinicore.device(device_str, 0),
406-
# )
407-
403+
# Create infinicore output tensor
408404
torch_dummy = torch.zeros(output_shape, dtype=output_dtype, device=device_str)
409405
if test_case.output.is_contiguous or test_case.output.strides is None:
410406
infini_output = create_infinicore_tensor(torch_dummy, device_str)
@@ -413,7 +409,9 @@ def torch_op_inplace():
413409
infini_output = create_strided_infinicore_tensor(torch_dummy, device_str)
414410

415411
def infini_op_inplace():
416-
self.infinicore_operator(*infini_inputs, out=infini_output, **kwargs)
412+
self.infinicore_operator_inplace(
413+
*infini_inputs, out=infini_output, **kwargs
414+
)
417415

418416
infini_op_inplace()
419417

test/infinicore/framework/runner.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,29 @@ def run(self):
3434
print(f"Starting {self.operator_test.operator_name} tests...")
3535
all_passed = True
3636

37-
# Run out-of-place tests
38-
print(f"\n--- Testing Out-of-place {self.operator_test.operator_name} ---")
39-
out_of_place_passed = runner.run_tests(devices, self.operator_test.run_test)
40-
all_passed = all_passed and out_of_place_passed
41-
42-
# Run in-place tests
43-
print(f"\n--- Testing In-place {self.operator_test.operator_name} ---")
44-
in_place_passed = runner.run_tests(devices, self.operator_test.run_inplace_test)
45-
all_passed = all_passed and in_place_passed
37+
# Run out-of-place tests if defined
38+
if self.operator_test.has_out_of_place_test():
39+
print(f"\n--- Testing Out-of-place {self.operator_test.operator_name} ---")
40+
out_of_place_passed = runner.run_tests(
41+
devices, self.operator_test.run_out_of_place_test, "Out-of-place"
42+
)
43+
all_passed = all_passed and out_of_place_passed
44+
else:
45+
print(
46+
f"\n--- Skipping Out-of-place {self.operator_test.operator_name} (not defined) ---"
47+
)
48+
49+
# Run in-place tests if defined
50+
if self.operator_test.has_inplace_test():
51+
print(f"\n--- Testing In-place {self.operator_test.operator_name} ---")
52+
in_place_passed = runner.run_tests(
53+
devices, self.operator_test.run_inplace_test, "In-place"
54+
)
55+
all_passed = all_passed and in_place_passed
56+
else:
57+
print(
58+
f"\n--- Skipping In-place {self.operator_test.operator_name} (not defined) ---"
59+
)
4660

4761
# Print summary
4862
summary_passed = runner.print_summary()

test/infinicore/ops/add.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,17 @@
33

44
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
55

6+
import torch
67
import infinicore
78
from framework import create_test_cases
8-
from framework.templates import BinaryOperatorTest
9+
from framework.base import BaseOperatorTest
910
from framework.runner import GenericTestRunner
1011

1112
# ==============================================================================
1213
# Operator-specific configuration
1314
# ==============================================================================
1415

15-
# Test cases in flexible format:
16-
# - Single shape tuple: (13, 4) → automatically expands to ((13, 4), None, None, None)
17-
# - Nested single shape: ((13, 4),) → automatically expands to ((13, 4), None, None, None)
18-
# - Full format: ((13, 4), None, None, None) or ((13, 4), (10, 1), (10, 1), (10, 1))
16+
# Test cases in flexible format
1917
_TEST_CASES_DATA = [
2018
((13, 4)),
2119
((13, 4), (10, 1), (10, 1), (10, 1)),
@@ -32,16 +30,14 @@
3230
]
3331

3432
# Parameter mapping configuration for add operator
35-
# Format: (shape, a_stride, b_stride, c_stride)
36-
# Call signature: add(input, other)
3733
_ADD_PARAMETER_MAPPING = (
38-
"add", # operator_name
39-
"add(input, other)", # call_signature
40-
[ # input_configs
41-
{"shape": 0, "stride": 1}, # input: shape from index 0, stride from index 1
42-
{"shape": 0, "stride": 2}, # other: shape from index 0, stride from index 2
34+
"add",
35+
"add(input, other)",
36+
[
37+
{"shape": 0, "stride": 1},
38+
{"shape": 0, "stride": 2},
4339
],
44-
{"shape": 0, "stride": 3}, # output: shape from index 0, stride from index 3
40+
{"shape": 0, "stride": 3},
4541
)
4642

4743
# Parse test cases using add parameter mapping
@@ -58,15 +54,40 @@
5854
}
5955

6056
# ==============================================================================
61-
# Operator test class
57+
# Operator test class with specific test functions
6258
# ==============================================================================
6359

6460

65-
class AddTest(BinaryOperatorTest):
66-
"""Add test"""
61+
class AddTest(BaseOperatorTest):
62+
"""Add test with operator-specific test functions"""
6763

6864
def __init__(self):
69-
super().__init__("add", _TEST_CASES, _TENSOR_DTYPES, _TOLERANCE_MAP)
65+
super().__init__("add")
66+
67+
def get_test_cases(self):
68+
return _TEST_CASES
69+
70+
def get_tensor_dtypes(self):
71+
return _TENSOR_DTYPES
72+
73+
def get_tolerance_map(self):
74+
return _TOLERANCE_MAP
75+
76+
def torch_operator_inplace(self, a, b, out=None, **kwargs):
77+
"""PyTorch in-place add operation"""
78+
torch.add(a, b, out=out)
79+
80+
def infinicore_operator_inplace(self, a, b, out=None, **kwargs):
81+
"""Infinicore in-place add operation"""
82+
infinicore.add(a, b, out=out)
83+
84+
def torch_operator_out_of_place(self, a, b, **kwargs):
85+
"""PyTorch out-of-place add operation"""
86+
return torch.add(a, b)
87+
88+
def infinicore_operator_out_of_place(self, a, b, **kwargs):
89+
"""Infinicore out-of-place add operation"""
90+
return infinicore.add(a, b)
7091

7192

7293
# ==============================================================================

test/infinicore/ops/matmul.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
55

6+
import torch
67
import infinicore
78
from framework import create_test_cases
8-
from framework.templates import BinaryOperatorTest
9+
from framework.base import BaseOperatorTest
910
from framework.runner import GenericTestRunner
1011

1112
# ==============================================================================
@@ -50,15 +51,40 @@
5051
}
5152

5253
# ==============================================================================
53-
# Operator test class
54+
# Operator test class with specific test functions
5455
# ==============================================================================
5556

5657

57-
class MatmulTest(BinaryOperatorTest):
58-
"""Matmul test"""
58+
class MatmulTest(BaseOperatorTest):
59+
"""Matmul test with operator-specific test functions"""
5960

6061
def __init__(self):
61-
super().__init__("matmul", _TEST_CASES, _TENSOR_DTYPES, _TOLERANCE_MAP)
62+
super().__init__("matmul")
63+
64+
def get_test_cases(self):
65+
return _TEST_CASES
66+
67+
def get_tensor_dtypes(self):
68+
return _TENSOR_DTYPES
69+
70+
def get_tolerance_map(self):
71+
return _TOLERANCE_MAP
72+
73+
def torch_operator_inplace(self, a, b, out=None, **kwargs):
74+
"""PyTorch in-place matmul operation"""
75+
torch.matmul(a, b, out=out)
76+
77+
def infinicore_operator_inplace(self, a, b, out=None, **kwargs):
78+
"""Infinicore in-place matmul operation"""
79+
infinicore.matmul(a, b, out=out)
80+
81+
def torch_operator_out_of_place(self, a, b, **kwargs):
82+
"""PyTorch out-of-place matmul operation"""
83+
return torch.matmul(a, b)
84+
85+
def infinicore_operator_out_of_place(self, a, b, **kwargs):
86+
"""Infinicore out-of-place matmul operation"""
87+
return infinicore.matmul(a, b)
6288

6389

6490
# ==============================================================================

0 commit comments

Comments
 (0)