Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions gigaevo/programs/stages/optimization/optuna/desubstitution.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _coerce_param_value(value: Any) -> Any:
if isinstance(value, str):
try:
parsed = ast.literal_eval(value)
except (ValueError, SyntaxError):
except (ValueError, SyntaxError, MemoryError, RecursionError, TypeError):
return value
# Coerce whole-number floats to int so range()/indexing works
if isinstance(parsed, float) and parsed == int(parsed):
Expand Down Expand Up @@ -229,16 +229,16 @@ def _clean_eval_in_source(code: str) -> str:
# Quoted string: eval('math.sqrt') — only strip if dotted name
try:
string_val = ast.literal_eval(inner)
except (ValueError, SyntaxError):
except (ValueError, SyntaxError, MemoryError, RecursionError, TypeError):
continue
if isinstance(string_val, str) and _DOTTED_NAME_RE.match(string_val):
replacement = string_val
else:
# Unquoted: eval([2, 3]), eval((1, 2)), eval(42), etc.
# Unquoted: eval([2, 3]), eval((1, 2)), eval(42), eval(-5), etc.
try:
ast.literal_eval(inner)
replacement = inner
except (ValueError, SyntaxError):
except (ValueError, SyntaxError, MemoryError, RecursionError, TypeError):
pass

if replacement is not None:
Expand Down Expand Up @@ -284,10 +284,15 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
return ast.copy_location(result, node)
return node

# eval(<literal>) — already a valid AST node, just strip eval()
if isinstance(arg, (ast.Constant, ast.List, ast.Tuple, ast.Set, ast.Dict)):
return ast.copy_location(arg, node)
return node
# eval(<literal>) — use ast.literal_eval as the predicate so this
# path matches `_clean_eval_in_source` exactly (it also handles
# ``UnaryOp(USub|UAdd, Constant)`` like ``eval(-5)``, which the
# previous isinstance(...) check silently missed).
try:
ast.literal_eval(arg)
except (ValueError, SyntaxError, MemoryError, RecursionError, TypeError):
return node
return ast.copy_location(arg, node)


# ---------------------------------------------------------------------------
Expand Down
52 changes: 52 additions & 0 deletions tests/stages/test_desubstitution_edge_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@

from __future__ import annotations

import ast
import textwrap
from unittest.mock import patch

from gigaevo.programs.stages.optimization.optuna.desubstitution import (
_build_line_offsets,
_clean_eval_in_source,
_coerce_param_value,
_EvalCleaner,
_find_matching_close_paren,
coerce_params,
desubstitute_params,
Expand Down Expand Up @@ -87,6 +90,55 @@ def test_invalid_string_stays(self) -> None:
"""Strings that cause SyntaxError in literal_eval stay as-is."""
assert _coerce_param_value("[invalid") == "[invalid"

def test_literal_eval_memory_error_falls_back(self) -> None:
    """A MemoryError from ast.literal_eval must not escape coercion.

    ``ast.literal_eval`` is patched to raise ``MemoryError``; the value
    handed to ``_coerce_param_value`` is expected to come back unchanged.
    """
    raw_value = "[1, 2, 3]"
    # Patch through the module under test so its ``ast`` reference is hit.
    with patch(
        "gigaevo.programs.stages.optimization.optuna.desubstitution.ast.literal_eval",
        side_effect=MemoryError,
    ):
        coerced = _coerce_param_value(raw_value)
        assert coerced == "[1, 2, 3]"

def test_literal_eval_recursion_error_falls_back(self) -> None:
    """A RecursionError from ast.literal_eval takes the same fallback.

    With ``ast.literal_eval`` patched to raise ``RecursionError``, the
    input string is expected to be returned as-is.
    """
    raw_value = "[1, 2, 3]"
    with patch(
        "gigaevo.programs.stages.optimization.optuna.desubstitution.ast.literal_eval",
        side_effect=RecursionError,
    ):
        coerced = _coerce_param_value(raw_value)
        assert coerced == "[1, 2, 3]"


class TestEvalCleanerASTPath:
    """Exercise _EvalCleaner directly, i.e. the AST-level eval() stripping."""

    @staticmethod
    def _strip(code: str) -> str:
        """Parse *code*, run it through _EvalCleaner, and unparse the result."""
        cleaned_tree = _EvalCleaner().visit(ast.parse(code))
        # Replaced nodes may lack position info; fill it in before unparsing.
        ast.fix_missing_locations(cleaned_tree)
        return ast.unparse(cleaned_tree)

    def test_negative_numeric_literal_is_stripped(self) -> None:
        """eval(-5) parses as UnaryOp(USub, Constant); the wrapper must go.

        Both an int and a float operand are checked.
        """
        expectations = (
            ("x = eval(-5)", "x = -5"),
            ("x = eval(-3.14)", "x = -3.14"),
        )
        for wrapped, unwrapped in expectations:
            assert self._strip(wrapped) == unwrapped

    def test_non_literal_unaryop_stays(self) -> None:
        """A negated Name is not a literal, so eval(-some_var) is kept."""
        source = "x = eval(-some_var)"
        assert self._strip(source) == "x = eval(-some_var)"

    def test_eval_of_call_stays(self) -> None:
        """A call argument is not a literal, so eval(foo()) is kept."""
        source = "x = eval(foo())"
        assert self._strip(source) == "x = eval(foo())"

    def test_ast_path_matches_source_path_on_negative_literal(self) -> None:
        """The AST path and the source path agree on eval(-N)."""
        snippet = "x = eval(-5)"
        via_ast = self._strip(snippet)
        via_source = _clean_eval_in_source(snippet)
        # Round-trip both through ast so only structure is compared,
        # not incidental formatting.
        normalized_ast = ast.unparse(ast.parse(via_ast))
        normalized_source = ast.unparse(ast.parse(via_source))
        assert normalized_ast == normalized_source


class TestCoerceParams:
def test_dict_coercion(self) -> None:
Expand Down