diff --git a/gigaevo/programs/stages/optimization/optuna/desubstitution.py b/gigaevo/programs/stages/optimization/optuna/desubstitution.py index ae42457f..36dbff18 100644 --- a/gigaevo/programs/stages/optimization/optuna/desubstitution.py +++ b/gigaevo/programs/stages/optimization/optuna/desubstitution.py @@ -41,7 +41,7 @@ def _coerce_param_value(value: Any) -> Any: if isinstance(value, str): try: parsed = ast.literal_eval(value) - except (ValueError, SyntaxError): + except (ValueError, SyntaxError, MemoryError, RecursionError, TypeError): return value # Coerce whole-number floats to int so range()/indexing works if isinstance(parsed, float) and parsed == int(parsed): @@ -229,16 +229,16 @@ def _clean_eval_in_source(code: str) -> str: # Quoted string: eval('math.sqrt') — only strip if dotted name try: string_val = ast.literal_eval(inner) - except (ValueError, SyntaxError): + except (ValueError, SyntaxError, MemoryError, RecursionError, TypeError): continue if isinstance(string_val, str) and _DOTTED_NAME_RE.match(string_val): replacement = string_val else: - # Unquoted: eval([2, 3]), eval((1, 2)), eval(42), etc. + # Unquoted: eval([2, 3]), eval((1, 2)), eval(42), eval(-5), etc. try: ast.literal_eval(inner) replacement = inner - except (ValueError, SyntaxError): + except (ValueError, SyntaxError, MemoryError, RecursionError, TypeError): pass if replacement is not None: @@ -284,10 +284,15 @@ def visit_Call(self, node: ast.Call) -> ast.AST: return ast.copy_location(result, node) return node - # eval() — already a valid AST node, just strip eval() - if isinstance(arg, (ast.Constant, ast.List, ast.Tuple, ast.Set, ast.Dict)): - return ast.copy_location(arg, node) - return node + # eval() — use ast.literal_eval as the predicate so this + # path matches `_clean_eval_in_source` exactly (it also handles + # ``UnaryOp(USub|UAdd, Constant)`` like ``eval(-5)``, which the + # previous isinstance(...) check silently missed). + try: + ast.literal_eval(arg) + except (ValueError, SyntaxError, MemoryError, RecursionError, TypeError): + return node + return ast.copy_location(arg, node) # --------------------------------------------------------------------------- diff --git a/tests/stages/test_desubstitution_edge_cases.py b/tests/stages/test_desubstitution_edge_cases.py index b5aa8cc0..5555aa24 100644 --- a/tests/stages/test_desubstitution_edge_cases.py +++ b/tests/stages/test_desubstitution_edge_cases.py @@ -8,12 +8,15 @@ from __future__ import annotations +import ast import textwrap +from unittest.mock import patch from gigaevo.programs.stages.optimization.optuna.desubstitution import ( _build_line_offsets, _clean_eval_in_source, _coerce_param_value, + _EvalCleaner, _find_matching_close_paren, coerce_params, desubstitute_params, @@ -87,6 +90,55 @@ def test_invalid_string_stays(self) -> None: """Strings that cause SyntaxError in literal_eval stay as-is.""" assert _coerce_param_value("[invalid") == "[invalid" + def test_literal_eval_memory_error_falls_back(self) -> None: + """ast.literal_eval can raise MemoryError on huge literals; the + broader except must catch it and return the original string.""" + target = "gigaevo.programs.stages.optimization.optuna.desubstitution.ast.literal_eval" + with patch(target, side_effect=MemoryError): + assert _coerce_param_value("[1, 2, 3]") == "[1, 2, 3]" + + def test_literal_eval_recursion_error_falls_back(self) -> None: + """Deeply-nested literals can hit RecursionError; same fallback.""" + target = "gigaevo.programs.stages.optimization.optuna.desubstitution.ast.literal_eval" + with patch(target, side_effect=RecursionError): + assert _coerce_param_value("[1, 2, 3]") == "[1, 2, 3]" + + +class TestEvalCleanerASTPath: + """Direct tests for _EvalCleaner; AST-level eval() stripping.""" + + @staticmethod + def _strip(code: str) -> str: + tree = ast.parse(code) + new = _EvalCleaner().visit(tree) + ast.fix_missing_locations(new) + return ast.unparse(new) + + def test_negative_numeric_literal_is_stripped(self) -> None: + """eval(-5) — UnaryOp(USub, Constant) — used to survive because the + isinstance check only handled bare Constant/List/Tuple/Set/Dict. + Now matches the source-level path.""" + assert self._strip("x = eval(-5)") == "x = -5" + assert self._strip("x = eval(-3.14)") == "x = -3.14" + + def test_non_literal_unaryop_stays(self) -> None: + """eval(-some_var) is NOT a literal (operand is Name, not Constant) + and must stay wrapped.""" + assert self._strip("x = eval(-some_var)") == "x = eval(-some_var)" + + def test_eval_of_call_stays(self) -> None: + """eval(foo()) is not a literal; must stay wrapped.""" + assert self._strip("x = eval(foo())") == "x = eval(foo())" + + def test_ast_path_matches_source_path_on_negative_literal(self) -> None: + """Parity: AST-path and source-path now produce equivalent output + for the previously-divergent eval(-N) case.""" + code = "x = eval(-5)" + ast_out = self._strip(code) + src_out = _clean_eval_in_source(code) + # Normalize both via ast round-trip for fair textual comparison. + assert ast.unparse(ast.parse(ast_out)) == ast.unparse(ast.parse(src_out)) + class TestCoerceParams: def test_dict_coercion(self) -> None: