Skip to content

Commit 5c4f121

Browse files
Fix pickle cache rewrite races (#364)
* Release v4.3.0
* Fix pickle cleanup rewrite race
* Fix shared pickle cache locking on Windows
* Keep pickle lockfiles out of cache dirs
* Remove version bump from pickle bugfix PR
* Harden pickle temp-file cleanup and lock fallback
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 98f137b commit 5c4f121

File tree

2 files changed

+236
-4
lines changed

2 files changed

+236
-4
lines changed

src/cachier/cores/pickle.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
# Licensed under the MIT license:
77
# http://www.opensource.org/licenses/MIT-license
88
# Copyright (c) 2016, Shay Palachy <shaypal5@gmail.com>
9+
import hashlib
910
import logging
1011
import os
1112
import pickle # for local caching
13+
import tempfile
1214
import time
1315
from contextlib import suppress
1416
from datetime import datetime, timedelta
@@ -28,6 +30,8 @@
2830
class _PickleCore(_BaseCore):
2931
"""The pickle core class for cachier."""
3032

33+
_SHARED_LOCK_SUFFIX = ".lock"
34+
3135
class CacheChangeHandler(PatternMatchingEventHandler):
3236
"""Handles cache-file modification events."""
3337

@@ -71,6 +75,10 @@ def on_modified(self, event) -> None:
7175
"""A Watchdog Event Handler method.""" # noqa: D401
7276
self._check_calculation()
7377

78+
def on_moved(self, event) -> None:
    """Re-check the calculation when a watched file is moved.

    Mirrors ``on_modified``: the event itself is not inspected.
    """
    # NOTE(review): presumably needed because the cache is now published via
    # ``os.replace`` in ``_save_cache``, which may surface as a move event
    # rather than a modification -- confirm against watchdog behaviour.
    self._check_calculation()
81+
7482
def __init__(
7583
self,
7684
hash_func: Optional[HashFunc],
@@ -97,6 +105,21 @@ def cache_fpath(self) -> str:
97105
os.makedirs(self.cache_dir, exist_ok=True)
98106
return os.path.abspath(os.path.join(os.path.realpath(self.cache_dir), self.cache_fname))
99107

108+
@property
def _shared_lock_fpath(self) -> str:
    """Return the path of the lock file that guards the shared cache file.

    The lock file name is the SHA-256 hex digest of ``cache_fpath`` followed
    by ``_SHARED_LOCK_SUFFIX``. Lock directories are tried in order:

    1. ``<system temp dir>/cachier-locks``
    2. ``.cachier-locks`` next to the cache file

    If neither directory can be created, a dot-prefixed lock file placed
    directly beside the cache file is used instead.

    Returns
    -------
    str
        Path of the lock file. Only the containing directory is created
        here, never the lock file itself.

    """
    # Resolve the cache path once: every read of ``cache_fpath`` repeats
    # ``os.makedirs`` and ``os.path.realpath`` on the cache directory.
    cache_fpath = self.cache_fpath
    cache_parent = os.path.dirname(cache_fpath)
    cache_hash = hashlib.sha256(cache_fpath.encode("utf-8")).hexdigest()
    lock_name = f"{cache_hash}{self._SHARED_LOCK_SUFFIX}"

    candidate_dirs = (
        os.path.join(tempfile.gettempdir(), "cachier-locks"),
        os.path.join(cache_parent, ".cachier-locks"),
    )
    for lock_dir in candidate_dirs:
        # Keep the ``try`` limited to the only call that can raise OSError.
        try:
            os.makedirs(lock_dir, exist_ok=True)
        except OSError:
            continue
        return os.path.join(lock_dir, lock_name)

    # Last resort: no lock directory could be created anywhere.
    return os.path.join(cache_parent, f".{lock_name}")
122+
100123
@staticmethod
101124
def _convert_legacy_cache_entry(
102125
entry: Union[dict, CacheEntry],
@@ -113,8 +136,8 @@ def _convert_legacy_cache_entry(
113136

114137
def _load_cache_dict(self) -> Dict[str, CacheEntry]:
115138
try:
116-
with portalocker.Lock(self.cache_fpath, mode="rb") as cf:
117-
cache = pickle.load(cast(IO[bytes], cf))
139+
with portalocker.Lock(self._shared_lock_fpath, mode="a+b"), open(self.cache_fpath, "rb") as cache_file:
140+
cache = pickle.load(cast(IO[bytes], cache_file))
118141
self._cache_used_fpath = str(self.cache_fpath)
119142
except (FileNotFoundError, EOFError):
120143
cache = {}
@@ -181,9 +204,29 @@ def _save_cache(
181204
fpath += f"_{separate_file_key}"
182205
elif hash_str is not None:
183206
fpath += f"_{hash_str}"
207+
parent_dir = os.path.dirname(fpath)
184208
with self.lock:
185-
with portalocker.Lock(fpath, mode="wb") as cf:
186-
pickle.dump(cache, cast(IO[bytes], cf), protocol=4)
209+
if isinstance(cache, CacheEntry):
210+
with portalocker.Lock(fpath, mode="wb") as cache_file:
211+
pickle.dump(cache, cast(IO[bytes], cache_file), protocol=4)
212+
else:
213+
with portalocker.Lock(self._shared_lock_fpath, mode="a+b"):
214+
temp_path = ""
215+
try:
216+
with tempfile.NamedTemporaryFile(
217+
mode="wb",
218+
dir=parent_dir,
219+
delete=False,
220+
) as temp_file:
221+
temp_path = temp_file.name
222+
pickle.dump(cache, cast(IO[bytes], temp_file), protocol=4)
223+
temp_file.flush()
224+
os.fsync(temp_file.fileno())
225+
os.replace(temp_path, fpath)
226+
finally:
227+
if temp_path:
228+
with suppress(FileNotFoundError):
229+
os.remove(temp_path)
187230
# the same as check for separate_file, but changed for typing
188231
if isinstance(cache, dict):
189232
self._cache_dict = cache
@@ -256,6 +299,7 @@ async def amark_entry_being_calculated(self, key: str) -> None:
256299
def mark_entry_not_calculated(self, key: str) -> None:
257300
if self.separate_files:
258301
self._mark_entry_not_calculated_separate_files(key)
302+
return # pragma: no cover
259303
with self.lock:
260304
cache = self.get_cache_dict()
261305
# that's ok, we don't need an entry in that case

tests/pickle_tests/test_pickle_core.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,130 @@ def mock_func():
709709
core._save_cache({"key": "value"}, separate_file_key="test_key")
710710

711711

712+
@pytest.mark.pickle
def test_save_cache_removes_temp_file_when_fsync_fails(tmp_path):
    """Check that a failing ``os.fsync`` leaves no temp file in the cache dir."""

    def mock_func():
        pass

    core = _PickleCore(
        hash_func=None,
        cache_dir=tmp_path,
        pickle_reload=False,
        wait_for_calc_timeout=10,
        separate_files=False,
    )
    core.set_func(mock_func)

    entry = CacheEntry(value="value", time=datetime.now(), stale=False, _processing=False)
    fsync_failure = patch("cachier.cores.pickle.os.fsync", side_effect=OSError("fsync failed"))

    # The error must propagate to the caller rather than being swallowed.
    with fsync_failure, pytest.raises(OSError, match="fsync failed"):
        core._save_cache({"key": entry})

    # Neither the temp file nor a partially written cache file may remain.
    leftovers = list(tmp_path.iterdir())
    assert leftovers == []
735+
736+
737+
@pytest.mark.pickle
def test_save_cache_propagates_tempfile_creation_failure_without_cleanup_error(tmp_path):
    """Check that a temp-file creation failure propagates and skips cleanup.

    When ``NamedTemporaryFile`` itself raises, no temp path exists yet, so
    ``_save_cache`` must call neither ``os.replace`` nor ``os.remove``.
    """

    def mock_func():
        pass

    core = _PickleCore(
        hash_func=None,
        cache_dir=tmp_path,
        pickle_reload=False,
        wait_for_calc_timeout=10,
        separate_files=False,
    )
    core.set_func(mock_func)

    entry = CacheEntry(value="value", time=datetime.now(), stale=False, _processing=False)
    tempfile_failure = patch(
        "cachier.cores.pickle.tempfile.NamedTemporaryFile",
        side_effect=OSError("tempfile failed"),
    )

    with (
        tempfile_failure,
        patch("cachier.cores.pickle.os.replace") as replace_spy,
        patch("cachier.cores.pickle.os.remove") as remove_spy,
        pytest.raises(OSError, match="tempfile failed"),
    ):
        core._save_cache({"key": entry})

    for spy in (replace_spy, remove_spy):
        spy.assert_not_called()
    assert list(tmp_path.iterdir()) == []
764+
765+
766+
@pytest.mark.pickle
def test_shared_lock_fpath_falls_back_to_cache_dir_when_temp_dir_unwritable(tmp_path):
    """Check the fallback to ``.cachier-locks`` beside the cache file.

    The system temp dir is faked as unwritable, so the lock path must land
    in the second candidate directory.
    """

    def mock_func():
        pass

    core = _PickleCore(
        hash_func=None,
        cache_dir=tmp_path,
        pickle_reload=False,
        wait_for_calc_timeout=10,
        separate_files=False,
    )
    core.set_func(mock_func)

    fake_temp_root = "/non-writable-temp"
    temp_lock_dir = os.path.join(fake_temp_root, "cachier-locks")
    fallback_lock_dir = os.path.join(core.cache_dir, ".cachier-locks")

    def mock_makedirs(path, exist_ok=False):
        # Only the temp-dir candidate fails; any unknown path is a test bug.
        if path == temp_lock_dir:
            raise PermissionError("temp dir not writable")
        if path not in (core.cache_dir, fallback_lock_dir):
            raise AssertionError(f"Unexpected os.makedirs path: {path}")
        return None

    with (
        patch("cachier.cores.pickle.tempfile.gettempdir", return_value=fake_temp_root),
        patch("cachier.cores.pickle.os.makedirs", side_effect=mock_makedirs),
    ):
        cache_hash = hashlib.sha256(core.cache_fpath.encode("utf-8")).hexdigest()
        expected_path = os.path.join(fallback_lock_dir, f"{cache_hash}{core._SHARED_LOCK_SUFFIX}")
        assert core._shared_lock_fpath == expected_path
800+
801+
802+
@pytest.mark.pickle
def test_shared_lock_fpath_uses_cache_dir_file_when_lock_dirs_unwritable(tmp_path):
    """Check the last-resort lock file placed directly in the cache dir.

    Both candidate lock directories are faked as unwritable, so the lock path
    must be a dot-prefixed file next to the cache file.
    """

    def mock_func():
        pass

    core = _PickleCore(
        hash_func=None,
        cache_dir=tmp_path,
        pickle_reload=False,
        wait_for_calc_timeout=10,
        separate_files=False,
    )
    core.set_func(mock_func)

    fake_temp_root = "/non-writable-temp"
    temp_lock_dir = os.path.join(fake_temp_root, "cachier-locks")
    fallback_lock_dir = os.path.join(core.cache_dir, ".cachier-locks")
    cache_hash = hashlib.sha256(core.cache_fpath.encode("utf-8")).hexdigest()

    def mock_makedirs(path, exist_ok=False):
        # Both lock-dir candidates fail; any unknown path is a test bug.
        if path in (temp_lock_dir, fallback_lock_dir):
            raise PermissionError("lock dir not writable")
        if path != core.cache_dir:
            raise AssertionError(f"Unexpected os.makedirs path: {path}")
        return None

    with (
        patch("cachier.cores.pickle.tempfile.gettempdir", return_value=fake_temp_root),
        patch("cachier.cores.pickle.os.makedirs", side_effect=mock_makedirs),
    ):
        expected_path = os.path.join(core.cache_dir, f".{cache_hash}{core._SHARED_LOCK_SUFFIX}")
        assert core._shared_lock_fpath == expected_path
834+
835+
712836
@pytest.mark.pickle
713837
def test_set_entry_should_not_store(tmp_path):
714838
"""Test set_entry when value should not be stored."""
@@ -1053,6 +1177,70 @@ def mock_get_cache_dict():
10531177
assert result == "result"
10541178

10551179

1180+
@pytest.mark.pickle
def test_save_cache_keeps_existing_file_readable_during_write(tmp_path, monkeypatch):
    """Test that cache rewrites do not expose a truncated file to plain readers.

    A writer thread is paused inside ``pickle.dump`` while it rewrites the
    cache. A plain ``open`` + ``pickle.load`` of the cache file at that moment
    must still return the previous contents; once the writer is released, the
    file must hold the updated contents.
    """
    core = _PickleCore(
        hash_func=None,
        cache_dir=tmp_path,
        pickle_reload=False,
        wait_for_calc_timeout=10,
        separate_files=False,
    )

    def mock_func():
        pass

    core.set_func(mock_func)

    initial_cache = {
        "key1": CacheEntry(
            value="result-1",
            time=datetime.now(),
            stale=False,
            _processing=False,
        )
    }
    updated_cache = {
        **initial_cache,
        "key2": CacheEntry(
            value="result-2",
            time=datetime.now(),
            stale=False,
            _processing=False,
        ),
    }
    core._save_cache(initial_cache)

    dump_started = threading.Event()
    allow_dump = threading.Event()
    # Capture the real function before patching: the patch target resolves to
    # the ``pickle`` module's own ``dump`` attribute.
    real_pickle_dump = pickle.dump

    def blocking_dump(obj, fh, protocol=None):
        # Pause only the rewrite under test; other dumps pass straight through.
        if obj is updated_cache:
            dump_started.set()
            assert allow_dump.wait(timeout=5)
        return real_pickle_dump(obj, fh, protocol=protocol)

    monkeypatch.setattr("cachier.cores.pickle.pickle.dump", blocking_dump)

    writer = threading.Thread(target=core._save_cache, args=(updated_cache,), daemon=True)
    writer.start()

    try:
        assert dump_started.wait(timeout=5)
        with open(core.cache_fpath, "rb") as cache_file:
            visible_cache = pickle.load(cache_file)
        assert visible_cache == initial_cache
    finally:
        # Always release the writer: if an assertion above fails, it would
        # otherwise stay blocked inside ``_save_cache`` while holding its locks.
        allow_dump.set()
        writer.join(timeout=5)
    assert not writer.is_alive()

    with open(core.cache_fpath, "rb") as cache_file:
        visible_cache = pickle.load(cache_file)
    assert visible_cache == updated_cache
1242+
1243+
10561244
@pytest.mark.pickle
10571245
def test_wait_with_polling_calls_timeout_check_when_processing(tmp_path):
10581246
"""Test _wait_with_polling checks timeout while entry is processing."""

0 commit comments

Comments
 (0)