Skip to content

Commit b2ba0ba

Browse files
Copilot and shaypal5 authored
Add recursion depth limit to prevent stack overflow in hash function (#342)
* Initial plan * Add stack overflow protection to _update_hash_for_value with max depth check Co-authored-by: shaypal5 <917954+shaypal5@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: shaypal5 <917954+shaypal5@users.noreply.github.com>
1 parent c33c6be commit b2ba0ba

File tree

2 files changed

+170
-5
lines changed

2 files changed

+170
-5
lines changed

src/cachier/config.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ def _hash_numpy_array(hasher: "hashlib._Hash", value: Any) -> None:
5656
hasher.update(value.tobytes(order="C"))
5757

5858

59-
def _update_hash_for_value(hasher: "hashlib._Hash", value: Any) -> None:
59+
def _update_hash_for_value(
60+
hasher: "hashlib._Hash", value: Any, depth: int = 0, max_depth: int = 100
61+
) -> None:
6062
"""Update hasher with a stable representation of a Python value.
6163
6264
Parameters
@@ -65,29 +67,45 @@ def _update_hash_for_value(hasher: "hashlib._Hash", value: Any) -> None:
6567
The hasher to update.
6668
value : Any
6769
Value to encode.
70+
depth : int, optional
71+
Current recursion depth (internal use only).
72+
max_depth : int, optional
73+
Maximum allowed recursion depth to prevent stack overflow.
74+
75+
Raises
76+
------
77+
RecursionError
78+
If the recursion depth exceeds max_depth.
6879
6980
"""
81+
if depth > max_depth:
82+
raise RecursionError(
83+
f"Maximum recursion depth ({max_depth}) exceeded while hashing nested "
84+
f"data structure. Consider flattening your data or using a custom "
85+
f"hash_func parameter."
86+
)
87+
7088
if _is_numpy_array(value):
7189
_hash_numpy_array(hasher, value)
7290
return
7391

7492
if isinstance(value, tuple):
7593
hasher.update(b"tuple")
7694
for item in value:
77-
_update_hash_for_value(hasher, item)
95+
_update_hash_for_value(hasher, item, depth + 1, max_depth)
7896
return
7997

8098
if isinstance(value, list):
8199
hasher.update(b"list")
82100
for item in value:
83-
_update_hash_for_value(hasher, item)
101+
_update_hash_for_value(hasher, item, depth + 1, max_depth)
84102
return
85103

86104
if isinstance(value, dict):
87105
hasher.update(b"dict")
88106
for dict_key in sorted(value):
89-
_update_hash_for_value(hasher, dict_key)
90-
_update_hash_for_value(hasher, value[dict_key])
107+
_update_hash_for_value(hasher, dict_key, depth + 1, max_depth)
108+
_update_hash_for_value(hasher, value[dict_key], depth + 1, max_depth)
91109
return
92110

93111
if isinstance(value, (set, frozenset)):

tests/test_recursion_depth.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""Tests for recursion depth protection in hash function."""
2+
3+
from datetime import timedelta
4+
5+
import pytest
6+
7+
from cachier import cachier
8+
9+
10+
@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("memory", marks=pytest.mark.memory),
        pytest.param("pickle", marks=pytest.mark.pickle),
    ],
)
def test_moderately_nested_structures_work(backend, tmp_path):
    """Structures nested below the 100-level limit hash and cache normally."""
    invocations = 0

    kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
    if backend == "pickle":
        kwargs["cache_dir"] = tmp_path

    @cachier(**kwargs)
    def process_nested(data):
        nonlocal invocations
        invocations += 1
        return "processed"

    # 51 nested lists built inside-out (50 wraps around the leaf list),
    # comfortably below the 100-level hashing limit.
    payload = ["leaf"]
    for _ in range(50):
        payload = [payload]

    # First call computes and stores the value.
    assert process_nested(payload) == "processed"
    assert invocations == 1

    # Second call with the same argument must be served from the cache.
    assert process_nested(payload) == "processed"
    assert invocations == 1

    process_nested.clear_cache()
51+
52+
53+
@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("memory", marks=pytest.mark.memory),
        pytest.param("pickle", marks=pytest.mark.pickle),
    ],
)
def test_deeply_nested_structures_raise_error(backend, tmp_path):
    """Nesting past the 100-level limit fails fast with RecursionError."""
    kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
    if backend == "pickle":
        kwargs["cache_dir"] = tmp_path

    @cachier(**kwargs)
    def process_nested(data):
        return "processed"

    # 151 nested lists (150 wraps around the leaf list) — past the limit.
    payload = ["leaf"]
    for _ in range(150):
        payload = [payload]

    # Hashing the argument should abort with an explanatory RecursionError.
    expected = r"Maximum recursion depth \(100\) exceeded while hashing nested"
    with pytest.raises(RecursionError, match=expected):
        process_nested(payload)
85+
86+
87+
@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("memory", marks=pytest.mark.memory),
        pytest.param("pickle", marks=pytest.mark.pickle),
    ],
)
def test_nested_dicts_respect_depth_limit(backend, tmp_path):
    """The recursion-depth limit applies to nested dictionaries as well."""
    kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
    if backend == "pickle":
        kwargs["cache_dir"] = tmp_path

    @cachier(**kwargs)
    def process_dict(data):
        return "processed"

    # Wrap the leaf mapping in 150 dict levels, built inside-out so the
    # outermost key ends up being "level_0", matching an outside-in build.
    payload = {"leaf": "value"}
    for level in reversed(range(150)):
        payload = {f"level_{level}": payload}

    expected = r"Maximum recursion depth \(100\) exceeded while hashing nested"
    with pytest.raises(RecursionError, match=expected):
        process_dict(payload)
118+
119+
120+
@pytest.mark.parametrize(
    "backend",
    [
        pytest.param("memory", marks=pytest.mark.memory),
        pytest.param("pickle", marks=pytest.mark.pickle),
    ],
)
def test_nested_tuples_respect_depth_limit(backend, tmp_path):
    """The recursion-depth limit applies to nested tuples as well."""
    kwargs = {"backend": backend, "stale_after": timedelta(seconds=120)}
    if backend == "pickle":
        kwargs["cache_dir"] = tmp_path

    @cachier(**kwargs)
    def process_tuple(data):
        return "processed"

    # Wrap a one-element tuple 150 times — well past the 100-deep limit.
    payload = ("leaf",)
    for _ in range(150):
        payload = (payload,)

    expected = r"Maximum recursion depth \(100\) exceeded while hashing nested"
    with pytest.raises(RecursionError, match=expected):
        process_tuple(payload)

0 commit comments

Comments
 (0)