Skip to content

Commit 2201d67

Browse files
committed
feat: handle inf and -inf values returned from AnyBody
1 parent 7faa692 commit 2201d67

2 files changed

Lines changed: 62 additions & 8 deletions

File tree

anypytools/tools.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -838,9 +838,17 @@ def _recursive_replace(iterable: Iterable, old: Any, new: Any):
838838
elif elem == old:
839839
iterable[i] = new
840840

841+
def _replace_nan_and_inf(iterable: Iterable):
842+
for i, elem in enumerate(iterable):
843+
if isinstance(elem, list):
844+
_replace_nan_and_inf(elem)
845+
elif elem in ("nan", "-nan", "inf", "-inf"):
846+
iterable[i] = float(elem)
847+
841848

842849
# Captures (group 1) each maximal run of characters that are not square
# brackets, double quotes, commas or whitespace.
TRIPEL_QUOTE_WRAP = re.compile(r'([^\[\]",\s]+)')
843850

851+
# Matches a bare ``nan`` or ``inf`` token, optionally preceded by ``-``, that
# is immediately followed by ``,`` or ``]`` (lookahead, not consumed).
# Group 1 is the whole token including the sign; group 2 is the word only.
QUOTE_INF_NAN = re.compile(r"([-]?(nan|inf))(?=[,\]])")
844852

845853
def _parse_data(val):
846854
"""Convert a str AnyBody data repr into Numpy array."""
@@ -852,15 +860,11 @@ def _parse_data(val):
852860
out = literal_eval(val)
853861
except (SyntaxError, ValueError):
854862
try:
855-
if "nan," in val or "nan]" in val:
856-
# handle the case where AnyBody has output 'nan' values
857-
val2 = val.replace("-nan", "nan")
858-
val2 = val2.replace("nan,", ' "nan",')
859-
val2 = val2.replace("nan]", ' "nan"]')
860-
out = literal_eval(val2)
861-
_recursive_replace(out, "nan", float("nan"))
862-
else:
863+
val2, n_replacements = QUOTE_INF_NAN.subn(r'"\1"', val )
864+
if not n_replacements:
863865
raise SyntaxError
866+
out = literal_eval(val2)
867+
_replace_nan_and_inf(out)
864868
except (SyntaxError, ValueError):
865869
val, _ = TRIPEL_QUOTE_WRAP.subn(r"'''\1'''", val)
866870
if val == "":

tests/test_tools.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212

1313
from anypytools.tools import (
14+
_parse_data,
1415
array2anyscript,
1516
get_anybodycon_path,
1617
define2str,
@@ -146,6 +147,55 @@ def test_get_anybodycon_path():
146147
assert os.path.exists(abc)
147148

148149

150+
@pytest.mark.parametrize(
    "str_val, expected, expected_dtype",
    [
        (
            """{{{0.9959166, nan, -nan}, {-0.08322453, 0.952727, -0.292207}, {0.0349831, 0.2958367, 0.9545977}},{{0.996183, -0.0689204, 0.05356734}, {0.07839193, 0.9763008, -0.2017211}, {-0.03839513, 0.2051504, 0.9779771}}}""",
            np.array(
                [
                    [
                        [0.9959166, np.nan, np.nan],
                        [-0.08322453, 0.952727, -0.292207],
                        [0.0349831, 0.2958367, 0.9545977],
                    ],
                    [
                        [0.996183, -0.0689204, 0.05356734],
                        [0.07839193, 0.9763008, -0.2017211],
                        [-0.03839513, 0.2051504, 0.9779771],
                    ],
                ]
            ),
            np.float64,
        ),
        (
            """{{inf, 4.0}, {3.0, -inf}}""",
            np.array([[np.inf, 4.0], [3.0, -np.inf]]),
            np.float64,
        ),
        (
            """{1.0, inf, 3.0}""",
            np.array([1.0, np.inf, 3.0]),
            np.float64,
        ),
    ],
)
def test_parse_anybodydata_arrays(str_val, expected, expected_dtype):
    """Check that array strings containing nan/inf parse to float arrays.

    The parsed result must have the expected dtype and match the expected
    values element-wise, with NaN entries treated as equal to each other.
    """
    parsed = _parse_data(str_val)
    assert parsed.dtype == expected_dtype
    assert np.allclose(parsed, expected, equal_nan=True)
174+
175+
176+
@pytest.mark.parametrize(
    "str_val, expected, expected_type",
    [
        ("1.0", 1.0, float),
        ('"some_str"', "some_str", str),
        ("1", 1, int),
    ],
)
def test_parse_anybodydata_scalars(str_val, expected, expected_type):
    """Check that scalar strings parse to the expected Python type and value."""
    parsed = _parse_data(str_val)
    assert isinstance(parsed, expected_type)
    assert parsed == expected
188+
189+
190+
191+
192+
193+
194+
195+
196+
197+
198+
149199
if __name__ == "__main__":
150200
os.chdir(Path(__file__).parent)
151201
pytest.main([str("test_tools.py::test_AnyPyProcessOutputList_to_dataframe")])

0 commit comments

Comments
 (0)