Skip to content

Commit 0dcff21

Browse files
Copilotev-br
andcommitted
Fix regex and parsing for complex π expressions
- Updated r_complex_value regex to handle πj/N and NπJ/M patterns - Modified parse_complex_value to extract imaginary coefficient from both πj and plain formats - Updated parse_complex_result to use approximate equality for π-based values - Test failures reduced from 23 to 9 (14 failures fixed) Co-authored-by: ev-br <2133832+ev-br@users.noreply.github.com>
1 parent c87a8bb commit 0dcff21

1 file changed

Lines changed: 49 additions & 10 deletions

File tree

array_api_tests/test_special_cases.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -502,10 +502,13 @@ def parse_complex_value(value_str: str) -> complex:
502502
(nan+nanj)
503503
>>> parse_complex_value('0 + NaN j')
504504
nanj
505+
>>> parse_complex_value('+0 + πj/2')
506+
1.5707963267948966j
507+
>>> parse_complex_value('+infinity + 3πj/4')
508+
(inf+2.356194490192345j)
505509
506-
Handles both "0j" and "0 j" formats with optional spaces.
510+
Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + NπJ/M"
507511
"""
508-
# Handle the format like "+0 + 0j" or "NaN + NaN j"
509512
m = r_complex_value.match(value_str)
510513
if m is None:
511514
raise ParseError(value_str)
@@ -517,7 +520,16 @@ def parse_complex_value(value_str: str) -> complex:
517520

518521
# Parse imaginary part with its sign
519522
imag_sign = m.group(3)
520-
imag_val_str = m.group(4)
523+
# Group 4 is πj form (e.g., "πj/2"), group 5 is plain form (e.g., "NaN")
524+
if m.group(4): # πj form
525+
imag_val_str_raw = m.group(4)
526+
# Remove 'j' to get coefficient: "πj/2" -> "π/2"
527+
imag_val_str = imag_val_str_raw.replace('j', '')
528+
else: # plain form
529+
imag_val_str_raw = m.group(5)
530+
# Strip trailing 'j' if present: "0j" -> "0"
531+
imag_val_str = imag_val_str_raw.rstrip('j') if imag_val_str_raw.endswith('j') else imag_val_str_raw
532+
521533
imag_val = parse_value(imag_sign + imag_val_str)
522534

523535
return complex(real_val, imag_val)
@@ -589,10 +601,10 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st
589601
590602
Handles cases like:
591603
- "``+0 + 0j``" - exact complex value
592-
- "``0 + NaN j`` (sign of the real component is unspecified)" - allow any sign for real
593-
- "``NaN + NaN j``" - both parts NaN
604+
- "``0 + NaN j`` (sign of the real component is unspecified)"
605+
- "``+0 + πj/2``" - with π expressions (uses approximate equality)
594606
"""
595-
# Check for unspecified sign note
607+
# Check for unspecified sign notes
596608
unspecified_real_sign = "sign of the real component is unspecified" in result_str
597609
unspecified_imag_sign = "sign of the imaginary component is unspecified" in result_str
598610

@@ -601,13 +613,39 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st
601613
m = re.search(r"``([^`]+)``", result_str)
602614
if m:
603615
value_str = m.group(1)
616+
# Check if the value contains π expressions (for approximate comparison)
617+
has_pi = 'π' in value_str
618+
604619
try:
605620
expected = parse_complex_value(value_str)
606621
except ParseError:
607622
raise ParseError(result_str)
608623

609-
# Create checker based on whether signs are unspecified
610-
if unspecified_real_sign and not math.isnan(expected.real):
624+
# Create checker based on whether signs are unspecified and whether π is involved
625+
if has_pi:
626+
# Use approximate equality for both real and imaginary parts if they involve π
627+
def check_result(z: complex) -> bool:
628+
real_match = True
629+
imag_match = True
630+
631+
if unspecified_real_sign and not math.isnan(expected.real):
632+
real_match = abs(z.real) == abs(expected.real) or math.isclose(abs(z.real), abs(expected.real), abs_tol=0.01)
633+
elif not math.isnan(expected.real):
634+
real_check = make_strict_eq(expected.real) if expected.real == 0 or math.isinf(expected.real) else make_rough_eq(expected.real)
635+
real_match = real_check(z.real)
636+
else:
637+
real_match = math.isnan(z.real)
638+
639+
if unspecified_imag_sign and not math.isnan(expected.imag):
640+
imag_match = abs(z.imag) == abs(expected.imag) or math.isclose(abs(z.imag), abs(expected.imag), abs_tol=0.01)
641+
elif not math.isnan(expected.imag):
642+
imag_check = make_strict_eq(expected.imag) if expected.imag == 0 or math.isinf(expected.imag) else make_rough_eq(expected.imag)
643+
imag_match = imag_check(z.imag)
644+
else:
645+
imag_match = math.isnan(z.imag)
646+
647+
return real_match and imag_match
648+
elif unspecified_real_sign and not math.isnan(expected.real):
611649
# Allow any sign for real part
612650
def check_result(z: complex) -> bool:
613651
imag_check = make_strict_eq(expected.imag)
@@ -693,9 +731,10 @@ class UnaryCase(Case):
693731
r"For complex floating-point operands, let ``a = real\(x_i\)``, ``b = imag\(x_i\)``"
694732
)
695733
r_complex_case = re.compile(r"If ``a`` is (.+) and ``b`` is (.+), the result is (.+)")
696-
# Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j"
734+
# Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j", "πj/2", "3πj/4"
735+
# Two formats: 1) πj/N expressions where j is part of the coefficient, 2) plain values followed by j
697736
r_complex_value = re.compile(
698-
r"([+-]?)([^\s]+)\s*([+-])\s*([^\s]+)\s*j"
737+
r"([+-]?)([^\s]+)\s*([+-])\s*(?:(\d*πj(?:/\d)?)|([^\s]+))\s*j?"
699738
)
700739

701740

0 commit comments

Comments
 (0)