@@ -822,7 +822,7 @@ def check_result(in_value, out_value):
822822 return check_result
823823
824824
825- def parse_unary_case_block (case_block : str , func_name : str ) -> List [UnaryCase ]:
825+ def parse_unary_case_block (case_block : str , func_name : str , record_list : Optional [ List [ str ]] = None ) -> List [UnaryCase ]:
826826 """
827827 Parses a Sphinx-formatted docstring of a unary function to return a list of
828828 codified unary cases, e.g.
@@ -875,6 +875,11 @@ def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
875875 for case_m in r_case .finditer (case_block ):
876876 case_str = case_m .group (1 )
877877
878+ # Record this special case if a record list is provided
879+ if record_list is not None :
880+ record_list .append (f"{ func_name } : { case_str } ." )
881+
882+
878883 # Try to parse complex cases if we're in the complex section
879884 if in_complex_section and (m := r_complex_case .search (case_str )):
880885 try :
@@ -1334,7 +1339,7 @@ def cond(i1: float, i2: float) -> bool:
13341339r_redundant_case = re .compile ("result.+determined by the rule already stated above" )
13351340
13361341
1337- def parse_binary_case_block (case_block : str , func_name : str ) -> List [BinaryCase ]:
1342+ def parse_binary_case_block (case_block : str , func_name : str , record_list : Optional [ List [ str ]] = None ) -> List [BinaryCase ]:
13381343 """
13391344 Parses a Sphinx-formatted docstring of a binary function to return a list of
13401345 codified binary cases, e.g.
@@ -1376,6 +1381,11 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
13761381 cases = []
13771382 for case_m in r_case .finditer (case_block ):
13781383 case_str = case_m .group (1 )
1384+
1385+ # Record this special case if a record list is provided
1386+ if record_list is not None :
1387+ record_list .append (f"{ func_name } : { case_str } ." )
1388+
13791389 if r_redundant_case .search (case_str ):
13801390 continue
13811391 if r_binary_case .match (case_str ):
@@ -1393,6 +1403,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
13931403unary_params = []
13941404binary_params = []
13951405iop_params = []
1406+ special_case_records = [] # List of "func_name: case_str" for all special cases
13961407func_to_op : Dict [str , str ] = {v : k for k , v in dh .op_to_func .items ()}
13971408for stub in category_to_funcs ["elementwise" ]:
13981409 func_name = stub .__name__
@@ -1417,7 +1428,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
14171428 warn (f"{ func = } has no parameters" )
14181429 continue
14191430 if param_names [0 ] == "x" :
1420- if cases := parse_unary_case_block (case_block , func_name ):
1431+ if cases := parse_unary_case_block (case_block , func_name , special_case_records ):
14211432 name_to_func = {func_name : func }
14221433 if func_name in func_to_op .keys ():
14231434 op_name = func_to_op [func_name ]
@@ -1435,7 +1446,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
14351446 warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
14361447 continue
14371448 if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
1438- if cases := parse_binary_case_block (case_block , func_name ):
1449+ if cases := parse_binary_case_block (case_block , func_name , special_case_records ):
14391450 name_to_func = {func_name : func }
14401451 if func_name in func_to_op .keys ():
14411452 op_name = func_to_op [func_name ]
@@ -1480,6 +1491,22 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
14801491assert len (iop_params ) != 0
14811492
14821493
1494+ def emit_special_case_records ():
1495+ """Emit all special case records for debugging/tracking purposes."""
1496+ print ("\n " + "=" * 80 )
1497+ print ("SPECIAL CASE RECORDS" )
1498+ print ("=" * 80 )
1499+ for record in special_case_records :
1500+ print (record )
1501+ print ("=" * 80 )
1502+ print (f"Total special cases: { len (special_case_records )} " )
1503+ print ("=" * 80 + "\n " )
1504+
1505+
1506+ # Emit special case records at module load time
1507+ emit_special_case_records ()
1508+
1509+
14831510@pytest .mark .parametrize ("func_name, func, case" , unary_params )
14841511def test_unary (func_name , func , case ):
14851512 with catch_warnings ():
0 commit comments