1515import inspect
1616import math
1717import operator
18+ import os
1819import re
1920from dataclasses import dataclass , field
2021from decimal import ROUND_HALF_EVEN , Decimal
@@ -99,7 +100,7 @@ def or_(i: float) -> bool:
99100
100101def make_and (cond1 : UnaryCheck , cond2 : UnaryCheck ) -> UnaryCheck :
101102 def and_ (i : float ) -> bool :
102- return cond1 (i ) or cond2 (i )
103+ return cond1 (i ) and cond2 (i )
103104
104105 return and_
105106
@@ -492,6 +493,179 @@ def check_result(result: float) -> bool:
492493 return check_result , expr
493494
494495
496+ def parse_complex_value (value_str : str ) -> complex :
497+ """
498+ Parses a complex value string to return a complex number, e.g.
499+
500+ >>> parse_complex_value('+0 + 0j')
501+ 0j
502+ >>> parse_complex_value('NaN + NaN j')
503+ (nan+nanj)
504+ >>> parse_complex_value('0 + NaN j')
505+ nanj
506+ >>> parse_complex_value('+0 + πj/2')
507+ 1.5707963267948966j
508+ >>> parse_complex_value('+infinity + 3πj/4')
509+ (inf+2.356194490192345j)
510+
511+ Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + Nπj/M"
512+ """
513+ m = r_complex_value .match (value_str )
514+ if m is None :
515+ raise ParseError (value_str )
516+
517+ # Parse real part with its sign
518+ real_sign = m .group (1 ) if m .group (1 ) else "+"
519+ real_val_str = m .group (2 )
520+ real_val = parse_value (real_sign + real_val_str )
521+
522+ # Parse imaginary part with its sign
523+ imag_sign = m .group (3 )
524+ # Group 4 is πj form (e.g., "πj/2"), group 5 is plain form (e.g., "NaN")
525+ if m .group (4 ): # πj form
526+ imag_val_str_raw = m .group (4 )
527+ # Remove 'j' to get coefficient: "πj/2" -> "π/2"
528+ imag_val_str = imag_val_str_raw .replace ('j' , '' )
529+ else : # plain form
530+ imag_val_str_raw = m .group (5 )
531+ # Strip trailing 'j' if present: "0j" -> "0"
532+ imag_val_str = imag_val_str_raw [:- 1 ] if imag_val_str_raw .endswith ('j' ) else imag_val_str_raw
533+
534+ imag_val = parse_value (imag_sign + imag_val_str )
535+
536+ return complex (real_val , imag_val )
537+
538+
539+ def make_strict_eq_complex (v : complex ) -> Callable [[complex ], bool ]:
540+ """
541+ Creates a checker for complex values that respects sign of zero and NaN.
542+ """
543+ real_check = make_strict_eq (v .real )
544+ imag_check = make_strict_eq (v .imag )
545+
546+ def strict_eq_complex (z : complex ) -> bool :
547+ return real_check (z .real ) and imag_check (z .imag )
548+
549+ return strict_eq_complex
550+
551+
552+ def parse_complex_cond (
553+ a_cond_str : str , b_cond_str : str
554+ ) -> Tuple [Callable [[complex ], bool ], str , FromDtypeFunc ]:
555+ """
556+ Parses complex condition strings for real (a) and imaginary (b) parts.
557+
558+ Returns:
559+ - cond: Function that checks if a complex number meets the condition
560+ - expr: String expression for the condition
561+ - from_dtype: Strategy generator for complex numbers meeting the condition
562+ """
563+ # Parse conditions for real and imaginary parts separately
564+ a_cond , a_expr_template , a_from_dtype = parse_cond (a_cond_str )
565+ b_cond , b_expr_template , b_from_dtype = parse_cond (b_cond_str )
566+
567+ # Create compound condition
568+ def complex_cond (z : complex ) -> bool :
569+ return a_cond (z .real ) and b_cond (z .imag )
570+
571+ # Create expression
572+ a_expr = a_expr_template .replace ("{}" , "real(x_i)" )
573+ b_expr = b_expr_template .replace ("{}" , "imag(x_i)" )
574+ expr = f"{ a_expr } and { b_expr } "
575+
576+ # Create strategy that generates complex numbers
577+ def complex_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [complex ]:
578+ assert len (kw ) == 0 # sanity check
579+ # For complex dtype, we need to get the corresponding float dtype
580+ # complex64 -> float32, complex128 -> float64
581+ if hasattr (dtype , 'name' ):
582+ if 'complex64' in str (dtype ):
583+ float_dtype = xp .float32
584+ elif 'complex128' in str (dtype ):
585+ float_dtype = xp .float64
586+ else :
587+ # Fallback to float64
588+ float_dtype = xp .float64
589+ else :
590+ float_dtype = xp .float64
591+
592+ real_strat = a_from_dtype (float_dtype )
593+ imag_strat = b_from_dtype (float_dtype )
594+ return st .builds (complex , real_strat , imag_strat )
595+
596+ return complex_cond , expr , complex_from_dtype
597+
598+
599+ def _check_component_with_tolerance (actual : float , expected : float , allow_any_sign : bool ) -> bool :
600+ """
601+ Helper to check if actual matches expected, with optional sign flexibility and tolerance.
602+ """
603+ if allow_any_sign and not math .isnan (expected ):
604+ return abs (actual ) == abs (expected ) or math .isclose (abs (actual ), abs (expected ), abs_tol = 0.01 )
605+ elif not math .isnan (expected ):
606+ check_fn = make_strict_eq (expected ) if expected == 0 or math .isinf (expected ) else make_rough_eq (expected )
607+ return check_fn (actual )
608+ else :
609+ return math .isnan (actual )
610+
611+
612+ def parse_complex_result (result_str : str ) -> Tuple [Callable [[complex ], bool ], str ]:
613+ """
614+ Parses a complex result string to return a checker and expression.
615+
616+ Handles cases like:
617+ - "``+0 + 0j``" - exact complex value
618+ - "``0 + NaN j`` (sign of the real component is unspecified)"
619+ - "``+0 + πj/2``" - with π expressions (uses approximate equality)
620+ """
621+ # Check for unspecified sign notes
622+ unspecified_real_sign = "sign of the real component is unspecified" in result_str
623+ unspecified_imag_sign = "sign of the imaginary component is unspecified" in result_str
624+
625+ # Extract the complex value from backticks - need to handle spaces in complex values
626+ # Pattern: ``...`` where ... can contain spaces (for complex values like "0 + NaN j")
627+ m = re .search (r"``([^`]+)``" , result_str )
628+ if m :
629+ value_str = m .group (1 )
630+ # Check if the value contains π expressions (for approximate comparison)
631+ has_pi = 'π' in value_str
632+
633+ try :
634+ expected = parse_complex_value (value_str )
635+ except ParseError :
636+ raise ParseError (result_str )
637+
638+ # Create checker based on whether signs are unspecified and whether π is involved
639+ if has_pi :
640+ # Use approximate equality for both real and imaginary parts if they involve π
641+ def check_result (z : complex ) -> bool :
642+ real_match = _check_component_with_tolerance (z .real , expected .real , unspecified_real_sign )
643+ imag_match = _check_component_with_tolerance (z .imag , expected .imag , unspecified_imag_sign )
644+ return real_match and imag_match
645+ elif unspecified_real_sign and not math .isnan (expected .real ):
646+ # Allow any sign for real part
647+ def check_result (z : complex ) -> bool :
648+ imag_check = make_strict_eq (expected .imag )
649+ return abs (z .real ) == abs (expected .real ) and imag_check (z .imag )
650+ elif unspecified_imag_sign and not math .isnan (expected .imag ):
651+ # Allow any sign for imaginary part
652+ def check_result (z : complex ) -> bool :
653+ real_check = make_strict_eq (expected .real )
654+ return real_check (z .real ) and abs (z .imag ) == abs (expected .imag )
655+ elif unspecified_real_sign and unspecified_imag_sign :
656+ # Allow any sign for both parts
657+ def check_result (z : complex ) -> bool :
658+ return abs (z .real ) == abs (expected .real ) and abs (z .imag ) == abs (expected .imag )
659+ else :
660+ # Exact match including signs
661+ check_result = make_strict_eq_complex (expected )
662+
663+ expr = value_str
664+ return check_result , expr
665+ else :
666+ raise ParseError (result_str )
667+
668+
495669class Case (Protocol ):
496670 cond_expr : str
497671 result_expr : str
@@ -535,6 +709,7 @@ class UnaryCase(Case):
535709 cond : UnaryCheck
536710 check_result : UnaryResultCheck
537711 raw_case : Optional [str ] = field (default = None )
712+ is_complex : bool = field (default = False )
538713
539714
540715r_unary_case = re .compile ("If ``x_i`` is (.+), the result is (.+)" )
@@ -549,6 +724,16 @@ class UnaryCase(Case):
549724 "If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, "
550725 "the result is ``(.+)``"
551726)
727+ # Regex patterns for complex special cases
728+ r_complex_marker = re .compile (
729+ r"For complex floating-point operands, let ``a = real\(x_i\)``, ``b = imag\(x_i\)``"
730+ )
731+ r_complex_case = re .compile (r"If ``a`` is (.+) and ``b`` is (.+), the result is (.+)" )
732+ # Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j", "πj/2", "3πj/4"
733+ # Two formats: 1) πj/N expressions where j is part of the coefficient, 2) plain values followed by j
734+ r_complex_value = re .compile (
735+ r"([+-]?)([^\s]+)\s*([+-])\s*(?:(\d*πj(?:/\d+)?)|([^\s]+))\s*j?"
736+ )
552737
553738
554739def integers_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
@@ -630,7 +815,15 @@ def check_result(i: float, result: float) -> bool:
630815 return check_result
631816
632817
633- def parse_unary_case_block (case_block : str , func_name : str ) -> List [UnaryCase ]:
818+ def make_complex_unary_check_result (check_fn : Callable [[complex ], bool ]) -> UnaryResultCheck :
819+ """Wraps a complex check function for use in UnaryCase."""
820+ def check_result (in_value , out_value ):
821+ # in_value is complex, out_value is complex
822+ return check_fn (out_value )
823+ return check_result
824+
825+
826+ def parse_unary_case_block (case_block : str , func_name : str , record_list : Optional [List [str ]] = None ) -> List [UnaryCase ]:
634827 """
635828 Parses a Sphinx-formatted docstring of a unary function to return a list of
636829 codified unary cases, e.g.
@@ -677,8 +870,52 @@ def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
677870
678871 """
679872 cases = []
873+ # Check if the case block contains complex cases by looking for the marker
874+ in_complex_section = r_complex_marker .search (case_block ) is not None
875+
680876 for case_m in r_case .finditer (case_block ):
681877 case_str = case_m .group (1 )
878+
879+ # Record this special case if a record list is provided
880+ if record_list is not None :
881+ record_list .append (f"{ func_name } : { case_str } ." )
882+
883+
884+ # Try to parse complex cases if we're in the complex section
885+ if in_complex_section and (m := r_complex_case .search (case_str )):
886+ try :
887+ a_cond_str = m .group (1 )
888+ b_cond_str = m .group (2 )
889+ result_str = m .group (3 )
890+
891+ # Skip cases with complex expressions like "cis(b)"
892+ if "cis" in result_str or "*" in result_str :
893+ warn (f"case for { func_name } not machine-readable: '{ case_str } '" )
894+ continue
895+
896+ # Parse the complex condition and result
897+ complex_cond , cond_expr , complex_from_dtype = parse_complex_cond (
898+ a_cond_str , b_cond_str
899+ )
900+ _check_result , result_expr = parse_complex_result (result_str )
901+
902+ check_result = make_complex_unary_check_result (_check_result )
903+
904+ case = UnaryCase (
905+ cond_expr = cond_expr ,
906+ cond = complex_cond ,
907+ cond_from_dtype = complex_from_dtype ,
908+ result_expr = result_expr ,
909+ check_result = check_result ,
910+ raw_case = case_str ,
911+ is_complex = True ,
912+ )
913+ cases .append (case )
914+ except ParseError as e :
915+ warn (f"case for { func_name } not machine-readable: '{ e .value } '" )
916+ continue
917+
918+ # Parse regular (real-valued) cases
682919 if r_already_int_case .search (case_str ):
683920 cases .append (already_int_case )
684921 elif r_even_round_halves_case .search (case_str ):
@@ -1103,7 +1340,7 @@ def cond(i1: float, i2: float) -> bool:
11031340r_redundant_case = re .compile ("result.+determined by the rule already stated above" )
11041341
11051342
1106- def parse_binary_case_block (case_block : str , func_name : str ) -> List [BinaryCase ]:
1343+ def parse_binary_case_block (case_block : str , func_name : str , record_list : Optional [ List [ str ]] = None ) -> List [BinaryCase ]:
11071344 """
11081345 Parses a Sphinx-formatted docstring of a binary function to return a list of
11091346 codified binary cases, e.g.
@@ -1145,6 +1382,11 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
11451382 cases = []
11461383 for case_m in r_case .finditer (case_block ):
11471384 case_str = case_m .group (1 )
1385+
1386+ # Record this special case if a record list is provided
1387+ if record_list is not None :
1388+ record_list .append (f"{ func_name } : { case_str } ." )
1389+
11481390 if r_redundant_case .search (case_str ):
11491391 continue
11501392 if r_binary_case .match (case_str ):
@@ -1162,6 +1404,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
11621404unary_params = []
11631405binary_params = []
11641406iop_params = []
1407+ special_case_records = [] # List of "func_name: case_str" for all special cases
11651408func_to_op : Dict [str , str ] = {v : k for k , v in dh .op_to_func .items ()}
11661409for stub in category_to_funcs ["elementwise" ]:
11671410 func_name = stub .__name__
@@ -1186,7 +1429,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
11861429 warn (f"{ func = } has no parameters" )
11871430 continue
11881431 if param_names [0 ] == "x" :
1189- if cases := parse_unary_case_block (case_block , func_name ):
1432+ if cases := parse_unary_case_block (case_block , func_name , special_case_records ):
11901433 name_to_func = {func_name : func }
11911434 if func_name in func_to_op .keys ():
11921435 op_name = func_to_op [func_name ]
@@ -1204,7 +1447,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
12041447 warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
12051448 continue
12061449 if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
1207- if cases := parse_binary_case_block (case_block , func_name ):
1450+ if cases := parse_binary_case_block (case_block , func_name , special_case_records ):
12081451 name_to_func = {func_name : func }
12091452 if func_name in func_to_op .keys ():
12101453 op_name = func_to_op [func_name ]
@@ -1249,6 +1492,22 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
12491492assert len (iop_params ) != 0
12501493
12511494
1495+ @pytest .fixture (scope = "session" , autouse = True )
1496+ def emit_special_case_records ():
1497+ """Emit all special case records at the start of test session."""
1498+ # This runs once at the beginning of the test session
1499+ if os .environ .get ('ARRAY_API_TESTS_SPECIAL_CASES_VERBOSE' ) == '1' :
1500+ print ("\n " + "=" * 80 )
1501+ print ("SPECIAL CASE RECORDS" )
1502+ print ("=" * 80 )
1503+ for record in special_case_records :
1504+ print (record )
1505+ print ("=" * 80 )
1506+ print (f"Total special cases: { len (special_case_records )} " )
1507+ print ("=" * 80 + "\n " )
1508+ yield # Tests run after this point
1509+
1510+
12521511@pytest .mark .parametrize ("func_name, func, case" , unary_params )
12531512def test_unary (func_name , func , case ):
12541513 with catch_warnings ():
@@ -1257,10 +1516,24 @@ def test_unary(func_name, func, case):
12571516 # drawing multiple examples like a normal test, or just hard-coding a
12581517 # single example test case without using hypothesis.
12591518 filterwarnings ('ignore' , category = NonInteractiveExampleWarning )
1260- in_value = case .cond_from_dtype (xp .float64 ).example ()
1261- x = xp .asarray (in_value , dtype = xp .float64 )
1519+
1520+ # Use the is_complex flag to determine the appropriate dtype
1521+ if case .is_complex :
1522+ dtype = xp .complex128
1523+ in_value = case .cond_from_dtype (dtype ).example ()
1524+ else :
1525+ dtype = xp .float64
1526+ in_value = case .cond_from_dtype (dtype ).example ()
1527+
1528+ # Create array and compute result based on dtype
1529+ x = xp .asarray (in_value , dtype = dtype )
12621530 out = func (x )
1263- out_value = float (out )
1531+
1532+ if case .is_complex :
1533+ out_value = complex (out )
1534+ else :
1535+ out_value = float (out )
1536+
12641537 assert case .check_result (in_value , out_value ), (
12651538 f"out={ out_value } , but should be { case .result_expr } [{ func_name } ()]\n "
12661539 )
0 commit comments