Skip to content

Commit fb3d2bf

Browse files
committed
feat: Add a new eval helper function
1 parent a975d69 commit fb3d2bf

2 files changed

Lines changed: 301 additions & 0 deletions

File tree

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#!/usr/bin/env python3
2+
3+
import ast
4+
5+
6+
def saferEval(obj_str, max_len=2048):
7+
"""This function adds an extra length check around literal_eval.
8+
On python3.11 and above (which has a recursion guard), this should
9+
be safe enough for use on general authenticated user input.
10+
11+
Note: This doesn't handle all of the cases of eval, such as
12+
datetime as those are technically executing code to
13+
instantiate the non-base objects.
14+
"""
15+
# Ensure input is a string
16+
obj_str = str(obj_str)
17+
if len(obj_str) > max_len:
18+
raise ValueError(f"Object string is too long (>{max_len} bytes)")
19+
try:
20+
return ast.literal_eval(obj_str)
21+
except (ValueError, TypeError, SyntaxError):
22+
# This covers all of the cases where the string is wrong (unclosed brackets...)
23+
# or contains disallowed items like function calls or non-expression.
24+
raise ValueError("Syntax error processing object expression")
25+
except (MemoryError, RecursionError):
26+
# This is encountered if the object is nested too deeply and other structures
27+
# that are probably malicious.
28+
raise ValueError("Object expression too large")
29+
except Exception:
30+
# There are no other possible exceptions at the time of writing,
31+
# this is to catch any added in future python versions.
32+
raise ValueError("Unknown error processing object expression")
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
"""Test saferEval function from DIRAC.Core.Utilities.saferEval"""
2+
3+
4+
import time
5+
import unittest
6+
7+
from DIRAC.Core.Utilities.SaferEval import saferEval
8+
9+
10+
class Test_saferEval(unittest.TestCase):
11+
# --- Normal Python literal inputs ---
12+
13+
def test_none(self):
14+
self.assertEqual(saferEval("None"), None)
15+
16+
def test_bool_true(self):
17+
self.assertEqual(saferEval("True"), True)
18+
19+
def test_bool_false(self):
20+
self.assertEqual(saferEval("False"), False)
21+
22+
def test_int_zero(self):
23+
self.assertEqual(saferEval("0"), 0)
24+
25+
def test_int_positive(self):
26+
self.assertEqual(saferEval("42"), 42)
27+
28+
def test_int_negative(self):
29+
self.assertEqual(saferEval("-17"), -17)
30+
31+
def test_int_hex(self):
32+
self.assertEqual(saferEval("0xFF"), 0xFF)
33+
34+
def test_int_octal(self):
35+
self.assertEqual(saferEval("0o77"), 0o77)
36+
37+
def test_int_binary(self):
38+
self.assertEqual(saferEval("0b1010"), 0b1010)
39+
40+
def test_float(self):
41+
self.assertEqual(saferEval("3.14"), 3.14)
42+
43+
def test_float_scientific(self):
44+
self.assertEqual(saferEval("1e10"), 1e10)
45+
46+
def test_complex(self):
47+
self.assertEqual(saferEval("1j"), 1j)
48+
49+
def test_str_double_quoted(self):
50+
self.assertEqual(saferEval('"hello"'), "hello")
51+
52+
def test_str_single_quoted(self):
53+
self.assertEqual(saferEval("'hello'"), "hello")
54+
55+
def test_str_escaped(self):
56+
self.assertEqual(saferEval("'a\\nb'"), "a\nb")
57+
58+
def test_str_raw(self):
59+
self.assertEqual(saferEval(r"r'\n'"), r"\n")
60+
61+
def test_str_unicode(self):
62+
self.assertEqual(saferEval('"hello 🌍"'), "hello 🌍")
63+
64+
def test_bytes(self):
65+
self.assertEqual(saferEval("b'bytes'"), b"bytes")
66+
67+
def test_bytes_escape(self):
68+
self.assertEqual(saferEval("b'\\xff'"), b"\xff")
69+
70+
def test_empty_list(self):
71+
self.assertEqual(saferEval("[]"), [])
72+
73+
def test_list_mixed(self):
74+
self.assertEqual(saferEval("[1, 'two', True, None]"), [1, "two", True, None])
75+
76+
def test_empty_tuple(self):
77+
self.assertEqual(saferEval("()"), ())
78+
79+
def test_singleton_tuple(self):
80+
self.assertEqual(saferEval("(1,)"), (1,))
81+
82+
def test_tuple(self):
83+
self.assertEqual(saferEval("(1, 2, 3)"), (1, 2, 3))
84+
85+
def test_empty_dict(self):
86+
self.assertEqual(saferEval("{}"), {})
87+
88+
def test_dict(self):
89+
self.assertEqual(saferEval("{'a': 1, 'b': 2}"), {"a": 1, "b": 2})
90+
91+
def test_set(self):
92+
self.assertEqual(saferEval("{1, 2, 3}"), {1, 2, 3})
93+
94+
def test_nested_list(self):
95+
self.assertEqual(saferEval("[[1, 2], [3, 4]]"), [[1, 2], [3, 4]])
96+
97+
def test_nested_dict(self):
98+
result = saferEval("{'a': {'b': {'c': [1, 2]}}}")
99+
self.assertEqual(result, {"a": {"b": {"c": [1, 2]}}})
100+
101+
# --- Invalid inputs (should raise ValueError) ---
102+
103+
def test_code_execution(self):
104+
self.assertRaises(ValueError, saferEval, "__import__('os').system('id')")
105+
106+
def test_code_execution_open(self):
107+
self.assertRaises(ValueError, saferEval, "open('/etc/passwd').read()")
108+
109+
def test_function_call(self):
110+
self.assertRaises(ValueError, saferEval, "list()")
111+
112+
def test_variable_reference(self):
113+
self.assertRaises(ValueError, saferEval, "foo")
114+
115+
def test_class_instantiation(self):
116+
self.assertRaises(ValueError, saferEval, "datetime.datetime.now()")
117+
118+
def test_lambda(self):
119+
self.assertRaises(ValueError, saferEval, "lambda x: x")
120+
121+
def test_dict_comprehension(self):
122+
self.assertRaises(ValueError, saferEval, "{k: v for k, v in []}")
123+
124+
def test_generator_expression(self):
125+
self.assertRaises(ValueError, saferEval, "(x for x in [])")
126+
127+
def test_name_comparison(self):
128+
self.assertRaises(ValueError, saferEval, "x == y")
129+
130+
def test_binary_ops(self):
131+
self.assertRaises(ValueError, saferEval, "1 + 2")
132+
133+
def test_attribute_access(self):
134+
self.assertRaises(ValueError, saferEval, "().__class__")
135+
136+
def test_subscript(self):
137+
self.assertRaises(ValueError, saferEval, "x[0]")
138+
139+
def test_slice(self):
140+
self.assertRaises(ValueError, saferEval, "[1,2][1:]")
141+
142+
def test_starred(self):
143+
self.assertRaises(ValueError, saferEval, "*1")
144+
145+
# --- Security edge cases ---
146+
147+
def test_builtin_name(self):
148+
self.assertRaises(ValueError, saferEval, "builtins.open")
149+
150+
def test_class_constructor(self):
151+
self.assertRaises(ValueError, saferEval, "object()")
152+
153+
def test_custom_class(self):
154+
self.assertRaises(ValueError, saferEval, "MyList()")
155+
156+
def test_fstring(self):
157+
self.assertRaises(ValueError, saferEval, "f'{1+2}'")
158+
159+
def test_decorator(self):
160+
self.assertRaises(ValueError, saferEval, "@decorator")
161+
162+
def test_assert_statement(self):
163+
self.assertRaises(ValueError, saferEval, "assert True")
164+
165+
def test_return_statement(self):
166+
self.assertRaises(ValueError, saferEval, "return 42")
167+
168+
def test_augmented_assignment(self):
169+
self.assertRaises(ValueError, saferEval, "x += 1")
170+
171+
def test_with_statement(self):
172+
self.assertRaises(ValueError, saferEval, "with open('x') as f: pass")
173+
174+
def test_for_loop(self):
175+
self.assertRaises(ValueError, saferEval, "for x in []: pass")
176+
177+
def test_try_except(self):
178+
self.assertRaises(ValueError, saferEval, "try: pass\nexcept: pass")
179+
180+
def test_import(self):
181+
self.assertRaises(ValueError, saferEval, "import os")
182+
183+
def test_from_import(self):
184+
self.assertRaises(ValueError, saferEval, "from os import path")
185+
186+
def test_del_statement(self):
187+
self.assertRaises(ValueError, saferEval, "del x")
188+
189+
def test_raise_statement(self):
190+
self.assertRaises(ValueError, saferEval, "raise ValueError('x')")
191+
192+
def test_yield(self):
193+
self.assertRaises(ValueError, saferEval, "yield 1")
194+
195+
def test_await(self):
196+
self.assertRaises(ValueError, saferEval, "await something")
197+
198+
def test_named_expression(self):
199+
self.assertRaises(ValueError, saferEval, "(x := 1)")
200+
201+
def test_positional_only_lambda(self):
202+
self.assertRaises(ValueError, saferEval, "(lambda x, /: x)(1)")
203+
204+
def test_power_arithmetic(self):
205+
# 10**200 is an expression, not a literal
206+
self.assertRaises(ValueError, saferEval, "10**200")
207+
208+
# --- Recursion limit ---
209+
210+
def test_deep_nesting(self):
211+
depth = 2000
212+
s = "[" * depth + "1" + "]" * depth
213+
try:
214+
result = saferEval(s)
215+
self.assertEqual(result, 1)
216+
except (ValueError, RecursionError):
217+
# Either ast.literal_eval's recursion guard (3.11+) or Python's
218+
# built-in recursion limit catches it — both are safe.
219+
pass
220+
221+
def test_deep_nesting_within_max_len(self):
222+
depth = 500
223+
s = "[" * depth + "1" + "]" * depth
224+
try:
225+
saferEval(s)
226+
except (ValueError, RecursionError):
227+
pass
228+
229+
# --- Max length ---
230+
231+
def test_max_len_boundary(self):
232+
self.assertEqual(saferEval("42"), 42)
233+
234+
def test_max_len_exceeded(self):
235+
s = "1" * 2049
236+
self.assertRaises(ValueError, saferEval, s, 2048)
237+
238+
def test_max_len_custom(self):
239+
self.assertRaises(ValueError, saferEval, "[1, 2, 3]", 5)
240+
241+
def test_max_len_under_custom(self):
242+
self.assertEqual(saferEval("[1, 2, 3]", 10), [1, 2, 3])
243+
244+
# --- Memory explosion prevention ---
245+
246+
def test_large_string_literal(self):
247+
s = "'" + "a" * 3000 + "'"
248+
self.assertRaises(ValueError, saferEval, s, 2048)
249+
250+
def test_large_list(self):
251+
self.assertRaises(ValueError, saferEval, str([1] * 3000), 2048)
252+
253+
# --- Performance ---
254+
255+
def test_normal_dict_performance(self):
256+
parts = ['"key": 0'] + ['"k{0}": {0}'.format(i) for i in range(50)]
257+
s = "{" + ", ".join(parts) + "}"
258+
start = time.time()
259+
saferEval(s, 2048)
260+
self.assertLess(time.time() - start, 0.1)
261+
262+
def test_normal_list_performance(self):
263+
s = "[" + ",".join(str(i) for i in range(50)) + "]"
264+
start = time.time()
265+
saferEval(s, 2048)
266+
self.assertLess(time.time() - start, 0.1)
267+
268+
if __name__ == "__main__":
269+
unittest.main()

0 commit comments

Comments
 (0)