Skip to content

Commit c0fcd55

Browse files
Copilotev-br
andcommitted
Address code review feedback: improve regex and refactor comparison logic
- Updated regex to match multi-digit denominators (\d+) - Fixed inconsistent j stripping logic in parse_complex_value - Added helper function for component comparison to eliminate code duplication - Fixed capitalization in docstring examples Co-authored-by: ev-br <2133832+ev-br@users.noreply.github.com>
1 parent 0dcff21 commit c0fcd55

1 file changed

Lines changed: 18 additions & 22 deletions

File tree

array_api_tests/test_special_cases.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ def parse_complex_value(value_str: str) -> complex:
507507
>>> parse_complex_value('+infinity + 3πj/4')
508508
(inf+2.356194490192345j)
509509
510-
Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + NπJ/M"
510+
Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + Nπj/M"
511511
"""
512512
m = r_complex_value.match(value_str)
513513
if m is None:
@@ -528,7 +528,7 @@ def parse_complex_value(value_str: str) -> complex:
528528
else: # plain form
529529
imag_val_str_raw = m.group(5)
530530
# Strip trailing 'j' if present: "0j" -> "0"
531-
imag_val_str = imag_val_str_raw.rstrip('j') if imag_val_str_raw.endswith('j') else imag_val_str_raw
531+
imag_val_str = imag_val_str_raw[:-1] if imag_val_str_raw.endswith('j') else imag_val_str_raw
532532

533533
imag_val = parse_value(imag_sign + imag_val_str)
534534

@@ -595,6 +595,19 @@ def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]:
595595
return complex_cond, expr, complex_from_dtype
596596

597597

598+
def _check_component_with_tolerance(actual: float, expected: float, allow_any_sign: bool) -> bool:
599+
"""
600+
Helper to check if actual matches expected, with optional sign flexibility and tolerance.
601+
"""
602+
if allow_any_sign and not math.isnan(expected):
603+
return abs(actual) == abs(expected) or math.isclose(abs(actual), abs(expected), abs_tol=0.01)
604+
elif not math.isnan(expected):
605+
check_fn = make_strict_eq(expected) if expected == 0 or math.isinf(expected) else make_rough_eq(expected)
606+
return check_fn(actual)
607+
else:
608+
return math.isnan(actual)
609+
610+
598611
def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], str]:
599612
"""
600613
Parses a complex result string to return a checker and expression.
@@ -625,25 +638,8 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st
625638
if has_pi:
626639
# Use approximate equality for both real and imaginary parts if they involve π
627640
def check_result(z: complex) -> bool:
628-
real_match = True
629-
imag_match = True
630-
631-
if unspecified_real_sign and not math.isnan(expected.real):
632-
real_match = abs(z.real) == abs(expected.real) or math.isclose(abs(z.real), abs(expected.real), abs_tol=0.01)
633-
elif not math.isnan(expected.real):
634-
real_check = make_strict_eq(expected.real) if expected.real == 0 or math.isinf(expected.real) else make_rough_eq(expected.real)
635-
real_match = real_check(z.real)
636-
else:
637-
real_match = math.isnan(z.real)
638-
639-
if unspecified_imag_sign and not math.isnan(expected.imag):
640-
imag_match = abs(z.imag) == abs(expected.imag) or math.isclose(abs(z.imag), abs(expected.imag), abs_tol=0.01)
641-
elif not math.isnan(expected.imag):
642-
imag_check = make_strict_eq(expected.imag) if expected.imag == 0 or math.isinf(expected.imag) else make_rough_eq(expected.imag)
643-
imag_match = imag_check(z.imag)
644-
else:
645-
imag_match = math.isnan(z.imag)
646-
641+
real_match = _check_component_with_tolerance(z.real, expected.real, unspecified_real_sign)
642+
imag_match = _check_component_with_tolerance(z.imag, expected.imag, unspecified_imag_sign)
647643
return real_match and imag_match
648644
elif unspecified_real_sign and not math.isnan(expected.real):
649645
# Allow any sign for real part
@@ -734,7 +730,7 @@ class UnaryCase(Case):
734730
# Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j", "πj/2", "3πj/4"
735731
# Two formats: 1) πj/N expressions where j is part of the coefficient, 2) plain values followed by j
736732
r_complex_value = re.compile(
737-
r"([+-]?)([^\s]+)\s*([+-])\s*(?:(\d*πj(?:/\d)?)|([^\s]+))\s*j?"
733+
r"([+-]?)([^\s]+)\s*([+-])\s*(?:(\d*πj(?:/\d+)?)|([^\s]+))\s*j?"
738734
)
739735

740736

0 commit comments

Comments
 (0)