Skip to content

Commit 63317c6

Browse files
Copilotev-br
authored andcommitted
ENH: parse complex special cases from stubs of unary functions
"stub" docstrings include "special cases": > `if x_i is +infinity, sqrt(x_i) is +infinity` etc `test_special_cases` parses these statements from docstrings and makes them into tests cases. So far, parsing only worked for real value cases, and failed for complex-valued cases: > For complex floating-point operands, let a = real(x_i), b = imag(x_i), and > `If a is either +0 or -0 and b is +0, the result is +0 + 0j.` These stanzas simply generate "case for {func} is not machine-readable" `UserWarning`s. Quite a wall of them. Therefore, we update parsing and testing code to take these complex-valued cases into accout. For now, we only consider unary functions. The effect is: $ ARRAY_API_TESTS_MODULE=array_api_compat.torch pytest array_api_tests/test_special_cases.py::test_unary generates - "128 passed, 177 warnings in 0.78s" on master - "11 failed, 241 passed, 49 warnings in 1.82s" on this branch So that there are new failures (from new complex-valued cases) but we 128 less warnings.
1 parent b4038ce commit 63317c6

1 file changed

Lines changed: 281 additions & 8 deletions

File tree

array_api_tests/test_special_cases.py

Lines changed: 281 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import inspect
1616
import math
1717
import operator
18+
import os
1819
import re
1920
from dataclasses import dataclass, field
2021
from decimal import ROUND_HALF_EVEN, Decimal
@@ -99,7 +100,7 @@ def or_(i: float) -> bool:
99100

100101
def make_and(cond1: UnaryCheck, cond2: UnaryCheck) -> UnaryCheck:
101102
def and_(i: float) -> bool:
102-
return cond1(i) or cond2(i)
103+
return cond1(i) and cond2(i)
103104

104105
return and_
105106

@@ -492,6 +493,179 @@ def check_result(result: float) -> bool:
492493
return check_result, expr
493494

494495

496+
def parse_complex_value(value_str: str) -> complex:
497+
"""
498+
Parses a complex value string to return a complex number, e.g.
499+
500+
>>> parse_complex_value('+0 + 0j')
501+
0j
502+
>>> parse_complex_value('NaN + NaN j')
503+
(nan+nanj)
504+
>>> parse_complex_value('0 + NaN j')
505+
nanj
506+
>>> parse_complex_value('+0 + πj/2')
507+
1.5707963267948966j
508+
>>> parse_complex_value('+infinity + 3πj/4')
509+
(inf+2.356194490192345j)
510+
511+
Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + Nπj/M"
512+
"""
513+
m = r_complex_value.match(value_str)
514+
if m is None:
515+
raise ParseError(value_str)
516+
517+
# Parse real part with its sign
518+
real_sign = m.group(1) if m.group(1) else "+"
519+
real_val_str = m.group(2)
520+
real_val = parse_value(real_sign + real_val_str)
521+
522+
# Parse imaginary part with its sign
523+
imag_sign = m.group(3)
524+
# Group 4 is πj form (e.g., "πj/2"), group 5 is plain form (e.g., "NaN")
525+
if m.group(4): # πj form
526+
imag_val_str_raw = m.group(4)
527+
# Remove 'j' to get coefficient: "πj/2" -> "π/2"
528+
imag_val_str = imag_val_str_raw.replace('j', '')
529+
else: # plain form
530+
imag_val_str_raw = m.group(5)
531+
# Strip trailing 'j' if present: "0j" -> "0"
532+
imag_val_str = imag_val_str_raw[:-1] if imag_val_str_raw.endswith('j') else imag_val_str_raw
533+
534+
imag_val = parse_value(imag_sign + imag_val_str)
535+
536+
return complex(real_val, imag_val)
537+
538+
539+
def make_strict_eq_complex(v: complex) -> Callable[[complex], bool]:
540+
"""
541+
Creates a checker for complex values that respects sign of zero and NaN.
542+
"""
543+
real_check = make_strict_eq(v.real)
544+
imag_check = make_strict_eq(v.imag)
545+
546+
def strict_eq_complex(z: complex) -> bool:
547+
return real_check(z.real) and imag_check(z.imag)
548+
549+
return strict_eq_complex
550+
551+
552+
def parse_complex_cond(
553+
a_cond_str: str, b_cond_str: str
554+
) -> Tuple[Callable[[complex], bool], str, FromDtypeFunc]:
555+
"""
556+
Parses complex condition strings for real (a) and imaginary (b) parts.
557+
558+
Returns:
559+
- cond: Function that checks if a complex number meets the condition
560+
- expr: String expression for the condition
561+
- from_dtype: Strategy generator for complex numbers meeting the condition
562+
"""
563+
# Parse conditions for real and imaginary parts separately
564+
a_cond, a_expr_template, a_from_dtype = parse_cond(a_cond_str)
565+
b_cond, b_expr_template, b_from_dtype = parse_cond(b_cond_str)
566+
567+
# Create compound condition
568+
def complex_cond(z: complex) -> bool:
569+
return a_cond(z.real) and b_cond(z.imag)
570+
571+
# Create expression
572+
a_expr = a_expr_template.replace("{}", "real(x_i)")
573+
b_expr = b_expr_template.replace("{}", "imag(x_i)")
574+
expr = f"{a_expr} and {b_expr}"
575+
576+
# Create strategy that generates complex numbers
577+
def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]:
578+
assert len(kw) == 0 # sanity check
579+
# For complex dtype, we need to get the corresponding float dtype
580+
# complex64 -> float32, complex128 -> float64
581+
if hasattr(dtype, 'name'):
582+
if 'complex64' in str(dtype):
583+
float_dtype = xp.float32
584+
elif 'complex128' in str(dtype):
585+
float_dtype = xp.float64
586+
else:
587+
# Fallback to float64
588+
float_dtype = xp.float64
589+
else:
590+
float_dtype = xp.float64
591+
592+
real_strat = a_from_dtype(float_dtype)
593+
imag_strat = b_from_dtype(float_dtype)
594+
return st.builds(complex, real_strat, imag_strat)
595+
596+
return complex_cond, expr, complex_from_dtype
597+
598+
599+
def _check_component_with_tolerance(actual: float, expected: float, allow_any_sign: bool) -> bool:
600+
"""
601+
Helper to check if actual matches expected, with optional sign flexibility and tolerance.
602+
"""
603+
if allow_any_sign and not math.isnan(expected):
604+
return abs(actual) == abs(expected) or math.isclose(abs(actual), abs(expected), abs_tol=0.01)
605+
elif not math.isnan(expected):
606+
check_fn = make_strict_eq(expected) if expected == 0 or math.isinf(expected) else make_rough_eq(expected)
607+
return check_fn(actual)
608+
else:
609+
return math.isnan(actual)
610+
611+
612+
def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], str]:
613+
"""
614+
Parses a complex result string to return a checker and expression.
615+
616+
Handles cases like:
617+
- "``+0 + 0j``" - exact complex value
618+
- "``0 + NaN j`` (sign of the real component is unspecified)"
619+
- "``+0 + πj/2``" - with π expressions (uses approximate equality)
620+
"""
621+
# Check for unspecified sign notes
622+
unspecified_real_sign = "sign of the real component is unspecified" in result_str
623+
unspecified_imag_sign = "sign of the imaginary component is unspecified" in result_str
624+
625+
# Extract the complex value from backticks - need to handle spaces in complex values
626+
# Pattern: ``...`` where ... can contain spaces (for complex values like "0 + NaN j")
627+
m = re.search(r"``([^`]+)``", result_str)
628+
if m:
629+
value_str = m.group(1)
630+
# Check if the value contains π expressions (for approximate comparison)
631+
has_pi = 'π' in value_str
632+
633+
try:
634+
expected = parse_complex_value(value_str)
635+
except ParseError:
636+
raise ParseError(result_str)
637+
638+
# Create checker based on whether signs are unspecified and whether π is involved
639+
if has_pi:
640+
# Use approximate equality for both real and imaginary parts if they involve π
641+
def check_result(z: complex) -> bool:
642+
real_match = _check_component_with_tolerance(z.real, expected.real, unspecified_real_sign)
643+
imag_match = _check_component_with_tolerance(z.imag, expected.imag, unspecified_imag_sign)
644+
return real_match and imag_match
645+
elif unspecified_real_sign and not math.isnan(expected.real):
646+
# Allow any sign for real part
647+
def check_result(z: complex) -> bool:
648+
imag_check = make_strict_eq(expected.imag)
649+
return abs(z.real) == abs(expected.real) and imag_check(z.imag)
650+
elif unspecified_imag_sign and not math.isnan(expected.imag):
651+
# Allow any sign for imaginary part
652+
def check_result(z: complex) -> bool:
653+
real_check = make_strict_eq(expected.real)
654+
return real_check(z.real) and abs(z.imag) == abs(expected.imag)
655+
elif unspecified_real_sign and unspecified_imag_sign:
656+
# Allow any sign for both parts
657+
def check_result(z: complex) -> bool:
658+
return abs(z.real) == abs(expected.real) and abs(z.imag) == abs(expected.imag)
659+
else:
660+
# Exact match including signs
661+
check_result = make_strict_eq_complex(expected)
662+
663+
expr = value_str
664+
return check_result, expr
665+
else:
666+
raise ParseError(result_str)
667+
668+
495669
class Case(Protocol):
496670
cond_expr: str
497671
result_expr: str
@@ -535,6 +709,7 @@ class UnaryCase(Case):
535709
cond: UnaryCheck
536710
check_result: UnaryResultCheck
537711
raw_case: Optional[str] = field(default=None)
712+
is_complex: bool = field(default=False)
538713

539714

540715
r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)")
@@ -549,6 +724,16 @@ class UnaryCase(Case):
549724
"If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, "
550725
"the result is ``(.+)``"
551726
)
727+
# Regex patterns for complex special cases
728+
r_complex_marker = re.compile(
729+
r"For complex floating-point operands, let ``a = real\(x_i\)``, ``b = imag\(x_i\)``"
730+
)
731+
r_complex_case = re.compile(r"If ``a`` is (.+) and ``b`` is (.+), the result is (.+)")
732+
# Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j", "πj/2", "3πj/4"
733+
# Two formats: 1) πj/N expressions where j is part of the coefficient, 2) plain values followed by j
734+
r_complex_value = re.compile(
735+
r"([+-]?)([^\s]+)\s*([+-])\s*(?:(\d*πj(?:/\d+)?)|([^\s]+))\s*j?"
736+
)
552737

553738

554739
def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
@@ -630,7 +815,15 @@ def check_result(i: float, result: float) -> bool:
630815
return check_result
631816

632817

633-
def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
818+
def make_complex_unary_check_result(check_fn: Callable[[complex], bool]) -> UnaryResultCheck:
819+
"""Wraps a complex check function for use in UnaryCase."""
820+
def check_result(in_value, out_value):
821+
# in_value is complex, out_value is complex
822+
return check_fn(out_value)
823+
return check_result
824+
825+
826+
def parse_unary_case_block(case_block: str, func_name: str, record_list: Optional[List[str]] = None) -> List[UnaryCase]:
634827
"""
635828
Parses a Sphinx-formatted docstring of a unary function to return a list of
636829
codified unary cases, e.g.
@@ -677,8 +870,52 @@ def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
677870
678871
"""
679872
cases = []
873+
# Check if the case block contains complex cases by looking for the marker
874+
in_complex_section = r_complex_marker.search(case_block) is not None
875+
680876
for case_m in r_case.finditer(case_block):
681877
case_str = case_m.group(1)
878+
879+
# Record this special case if a record list is provided
880+
if record_list is not None:
881+
record_list.append(f"{func_name}: {case_str}.")
882+
883+
884+
# Try to parse complex cases if we're in the complex section
885+
if in_complex_section and (m := r_complex_case.search(case_str)):
886+
try:
887+
a_cond_str = m.group(1)
888+
b_cond_str = m.group(2)
889+
result_str = m.group(3)
890+
891+
# Skip cases with complex expressions like "cis(b)"
892+
if "cis" in result_str or "*" in result_str:
893+
warn(f"case for {func_name} not machine-readable: '{case_str}'")
894+
continue
895+
896+
# Parse the complex condition and result
897+
complex_cond, cond_expr, complex_from_dtype = parse_complex_cond(
898+
a_cond_str, b_cond_str
899+
)
900+
_check_result, result_expr = parse_complex_result(result_str)
901+
902+
check_result = make_complex_unary_check_result(_check_result)
903+
904+
case = UnaryCase(
905+
cond_expr=cond_expr,
906+
cond=complex_cond,
907+
cond_from_dtype=complex_from_dtype,
908+
result_expr=result_expr,
909+
check_result=check_result,
910+
raw_case=case_str,
911+
is_complex=True,
912+
)
913+
cases.append(case)
914+
except ParseError as e:
915+
warn(f"case for {func_name} not machine-readable: '{e.value}'")
916+
continue
917+
918+
# Parse regular (real-valued) cases
682919
if r_already_int_case.search(case_str):
683920
cases.append(already_int_case)
684921
elif r_even_round_halves_case.search(case_str):
@@ -1103,7 +1340,7 @@ def cond(i1: float, i2: float) -> bool:
11031340
r_redundant_case = re.compile("result.+determined by the rule already stated above")
11041341

11051342

1106-
def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]:
1343+
def parse_binary_case_block(case_block: str, func_name: str, record_list: Optional[List[str]] = None) -> List[BinaryCase]:
11071344
"""
11081345
Parses a Sphinx-formatted docstring of a binary function to return a list of
11091346
codified binary cases, e.g.
@@ -1145,6 +1382,11 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
11451382
cases = []
11461383
for case_m in r_case.finditer(case_block):
11471384
case_str = case_m.group(1)
1385+
1386+
# Record this special case if a record list is provided
1387+
if record_list is not None:
1388+
record_list.append(f"{func_name}: {case_str}.")
1389+
11481390
if r_redundant_case.search(case_str):
11491391
continue
11501392
if r_binary_case.match(case_str):
@@ -1162,6 +1404,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
11621404
unary_params = []
11631405
binary_params = []
11641406
iop_params = []
1407+
special_case_records = [] # List of "func_name: case_str" for all special cases
11651408
func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()}
11661409
for stub in category_to_funcs["elementwise"]:
11671410
func_name = stub.__name__
@@ -1186,7 +1429,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
11861429
warn(f"{func=} has no parameters")
11871430
continue
11881431
if param_names[0] == "x":
1189-
if cases := parse_unary_case_block(case_block, func_name):
1432+
if cases := parse_unary_case_block(case_block, func_name, special_case_records):
11901433
name_to_func = {func_name: func}
11911434
if func_name in func_to_op.keys():
11921435
op_name = func_to_op[func_name]
@@ -1204,7 +1447,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
12041447
warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
12051448
continue
12061449
if param_names[0] == "x1" and param_names[1] == "x2":
1207-
if cases := parse_binary_case_block(case_block, func_name):
1450+
if cases := parse_binary_case_block(case_block, func_name, special_case_records):
12081451
name_to_func = {func_name: func}
12091452
if func_name in func_to_op.keys():
12101453
op_name = func_to_op[func_name]
@@ -1249,6 +1492,22 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
12491492
assert len(iop_params) != 0
12501493

12511494

1495+
@pytest.fixture(scope="session", autouse=True)
1496+
def emit_special_case_records():
1497+
"""Emit all special case records at the start of test session."""
1498+
# This runs once at the beginning of the test session
1499+
if os.environ.get('ARRAY_API_TESTS_SPECIAL_CASES_VERBOSE') == '1':
1500+
print("\n" + "="*80)
1501+
print("SPECIAL CASE RECORDS")
1502+
print("="*80)
1503+
for record in special_case_records:
1504+
print(record)
1505+
print("="*80)
1506+
print(f"Total special cases: {len(special_case_records)}")
1507+
print("="*80 + "\n")
1508+
yield # Tests run after this point
1509+
1510+
12521511
@pytest.mark.parametrize("func_name, func, case", unary_params)
12531512
def test_unary(func_name, func, case):
12541513
with catch_warnings():
@@ -1257,10 +1516,24 @@ def test_unary(func_name, func, case):
12571516
# drawing multiple examples like a normal test, or just hard-coding a
12581517
# single example test case without using hypothesis.
12591518
filterwarnings('ignore', category=NonInteractiveExampleWarning)
1260-
in_value = case.cond_from_dtype(xp.float64).example()
1261-
x = xp.asarray(in_value, dtype=xp.float64)
1519+
1520+
# Use the is_complex flag to determine the appropriate dtype
1521+
if case.is_complex:
1522+
dtype = xp.complex128
1523+
in_value = case.cond_from_dtype(dtype).example()
1524+
else:
1525+
dtype = xp.float64
1526+
in_value = case.cond_from_dtype(dtype).example()
1527+
1528+
# Create array and compute result based on dtype
1529+
x = xp.asarray(in_value, dtype=dtype)
12621530
out = func(x)
1263-
out_value = float(out)
1531+
1532+
if case.is_complex:
1533+
out_value = complex(out)
1534+
else:
1535+
out_value = float(out)
1536+
12641537
assert case.check_result(in_value, out_value), (
12651538
f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n"
12661539
)

0 commit comments

Comments
 (0)