Skip to content

Commit 900ab8e

Browse files
Copilotev-br
andcommitted
Add special case recording and emission during test run
Co-authored-by: ev-br <2133832+ev-br@users.noreply.github.com>
1 parent e48216b commit 900ab8e

1 file changed

Lines changed: 31 additions & 4 deletions

File tree

array_api_tests/test_special_cases.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ def check_result(in_value, out_value):
822822
return check_result
823823

824824

825-
def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
825+
def parse_unary_case_block(case_block: str, func_name: str, record_list: Optional[List[str]] = None) -> List[UnaryCase]:
826826
"""
827827
Parses a Sphinx-formatted docstring of a unary function to return a list of
828828
codified unary cases, e.g.
@@ -875,6 +875,11 @@ def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
875875
for case_m in r_case.finditer(case_block):
876876
case_str = case_m.group(1)
877877

878+
# Record this special case if a record list is provided
879+
if record_list is not None:
880+
record_list.append(f"{func_name}: {case_str}.")
881+
882+
878883
# Try to parse complex cases if we're in the complex section
879884
if in_complex_section and (m := r_complex_case.search(case_str)):
880885
try:
@@ -1334,7 +1339,7 @@ def cond(i1: float, i2: float) -> bool:
13341339
r_redundant_case = re.compile("result.+determined by the rule already stated above")
13351340

13361341

1337-
def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]:
1342+
def parse_binary_case_block(case_block: str, func_name: str, record_list: Optional[List[str]] = None) -> List[BinaryCase]:
13381343
"""
13391344
Parses a Sphinx-formatted docstring of a binary function to return a list of
13401345
codified binary cases, e.g.
@@ -1376,6 +1381,11 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
13761381
cases = []
13771382
for case_m in r_case.finditer(case_block):
13781383
case_str = case_m.group(1)
1384+
1385+
# Record this special case if a record list is provided
1386+
if record_list is not None:
1387+
record_list.append(f"{func_name}: {case_str}.")
1388+
13791389
if r_redundant_case.search(case_str):
13801390
continue
13811391
if r_binary_case.match(case_str):
@@ -1393,6 +1403,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
13931403
unary_params = []
13941404
binary_params = []
13951405
iop_params = []
1406+
special_case_records = [] # List of "func_name: case_str" for all special cases
13961407
func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()}
13971408
for stub in category_to_funcs["elementwise"]:
13981409
func_name = stub.__name__
@@ -1417,7 +1428,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
14171428
warn(f"{func=} has no parameters")
14181429
continue
14191430
if param_names[0] == "x":
1420-
if cases := parse_unary_case_block(case_block, func_name):
1431+
if cases := parse_unary_case_block(case_block, func_name, special_case_records):
14211432
name_to_func = {func_name: func}
14221433
if func_name in func_to_op.keys():
14231434
op_name = func_to_op[func_name]
@@ -1435,7 +1446,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
14351446
warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
14361447
continue
14371448
if param_names[0] == "x1" and param_names[1] == "x2":
1438-
if cases := parse_binary_case_block(case_block, func_name):
1449+
if cases := parse_binary_case_block(case_block, func_name, special_case_records):
14391450
name_to_func = {func_name: func}
14401451
if func_name in func_to_op.keys():
14411452
op_name = func_to_op[func_name]
@@ -1480,6 +1491,22 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
14801491
assert len(iop_params) != 0
14811492

14821493

1494+
def emit_special_case_records():
1495+
"""Emit all special case records for debugging/tracking purposes."""
1496+
print("\n" + "="*80)
1497+
print("SPECIAL CASE RECORDS")
1498+
print("="*80)
1499+
for record in special_case_records:
1500+
print(record)
1501+
print("="*80)
1502+
print(f"Total special cases: {len(special_case_records)}")
1503+
print("="*80 + "\n")
1504+
1505+
1506+
# Emit special case records at module load time
1507+
emit_special_case_records()
1508+
1509+
14831510
@pytest.mark.parametrize("func_name, func, case", unary_params)
14841511
def test_unary(func_name, func, case):
14851512
with catch_warnings():

0 commit comments

Comments
 (0)