Skip to content

Commit 2a2a4b4

Browse files
committed
fix(hooks): atomic read-modify-write in SessionStats.flush() (#1493)
Replace separate _locked_read() + _locked_write() in flush() with a single _locked_modify() helper that holds LOCK_EX for the entire read-modify-write window. This prevents concurrent PostToolUse processes from clobbering each other's updates. Add TestConcurrentFlush regression test (8 procs × 100 calls = 800 expected) that fails without the fix.
1 parent 64168b1 commit 2a2a4b4

2 files changed

Lines changed: 120 additions & 16 deletions

File tree

packages/claude-code-plugin/hooks/lib/stats.py

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,24 +96,38 @@ def record_tool_call(self, tool_name: str, success: bool = True) -> None:
9696
self.flush()
9797

9898
def flush(self) -> None:
99-
"""Flush accumulated in-memory stats to disk."""
99+
"""Flush accumulated in-memory stats to disk.
100+
101+
Uses _locked_modify to perform atomic read-modify-write inside a
102+
single LOCK_EX window, preventing lost updates from concurrent
103+
processes (#1493).
104+
"""
100105
if self._pending_count == 0:
101106
return
102-
data = self._locked_read()
103-
data["tool_count"] = data.get("tool_count", 0) + self._mem_tool_count
104-
data["error_count"] = data.get("error_count", 0) + self._mem_error_count
105-
tool_names = data.get("tool_names", {})
106-
for name, count in self._mem_tool_names.items():
107-
tool_names[name] = tool_names.get(name, 0) + count
108-
data["tool_names"] = tool_names
109-
# Merge hook timings
110-
hook_timings = data.get("hook_timings", {})
111-
for name, times in self._mem_hook_timings.items():
112-
if name not in hook_timings:
113-
hook_timings[name] = []
114-
hook_timings[name].extend(times)
115-
data["hook_timings"] = hook_timings
116-
self._locked_write(data)
107+
108+
# Capture deltas before entering critical section
109+
delta_tool_count = self._mem_tool_count
110+
delta_error_count = self._mem_error_count
111+
delta_tool_names = dict(self._mem_tool_names)
112+
delta_hook_timings = {k: list(v) for k, v in self._mem_hook_timings.items()}
113+
114+
def apply_deltas(data: Dict[str, Any]) -> Dict[str, Any]:
115+
data["tool_count"] = data.get("tool_count", 0) + delta_tool_count
116+
data["error_count"] = data.get("error_count", 0) + delta_error_count
117+
tool_names = data.get("tool_names", {})
118+
for name, count in delta_tool_names.items():
119+
tool_names[name] = tool_names.get(name, 0) + count
120+
data["tool_names"] = tool_names
121+
hook_timings = data.get("hook_timings", {})
122+
for name, times in delta_hook_timings.items():
123+
if name not in hook_timings:
124+
hook_timings[name] = []
125+
hook_timings[name].extend(times)
126+
data["hook_timings"] = hook_timings
127+
return data
128+
129+
self._locked_modify(apply_deltas)
130+
117131
# Reset in-memory accumulators
118132
self._mem_tool_count = 0
119133
self._mem_error_count = 0
@@ -239,6 +253,47 @@ def cleanup_stale(data_dir: str, max_age_hours: int = 24) -> None:
239253
except OSError:
240254
pass
241255

256+
def _locked_modify(self, mutator: Any) -> None:
    """Atomic read-modify-write inside a single LOCK_EX window (#1493).

    Opens the stats file, takes an exclusive lock, reads the current
    data, applies *mutator(data) -> data*, then writes the result back,
    all without releasing the lock. This prevents the lost-update race
    where concurrent processes each read the same baseline.

    An empty, undecodable, corrupted or non-dict file is treated as a
    fresh seed *inside* the locked window, so recovery is as atomic as
    the normal path.

    Args:
        mutator: Callable (Dict -> Dict) that transforms the data
            dict in place or returns the updated dict. It may be
            invoked on a fresh seed dict when no on-disk data could be
            read, so it must not depend on being called with a
            particular dict instance.

    Note: When HAS_FCNTL is False (non-Unix platforms), locking is
    skipped entirely. Concurrent flushes on such platforms may lose
    updates. This is a known limitation documented here for
    visibility.
    """
    seed: Dict[str, Any] = {
        "session_id": self.session_id,
        "started_at": time.time(),
        "tool_count": 0,
        "error_count": 0,
        "tool_names": {},
        "hook_timings": {},
    }
    # Holds the mutated data once mutator() has run, so the fallback
    # below can persist it instead of discarding what was read from disk.
    merged: Any = None
    try:
        # Explicit 0o666 (subject to umask) matches what builtin open()
        # creates; os.open() would otherwise default to 0o777 and mark
        # the stats file executable.
        fd = os.open(self.stats_file, os.O_RDWR | os.O_CREAT, 0o666)
        with os.fdopen(fd, "r+", encoding="utf-8") as f:
            if HAS_FCNTL:
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
            try:
                raw = f.read()
                current = json.loads(raw) if raw else dict(seed)
            except ValueError:
                # Covers json.JSONDecodeError and UnicodeDecodeError:
                # the file is unreadable, so restart from the seed while
                # still holding the lock.
                current = dict(seed)
            if not isinstance(current, dict):
                # Valid JSON but not an object (e.g. "null" or "[]").
                current = dict(seed)
            merged = mutator(current)
            # Serialize before truncating so a failing dump cannot leave
            # an empty file behind.
            payload = json.dumps(merged)
            f.seek(0)
            f.truncate()
            f.write(payload)
    except OSError:
        # Could not open/read/write under the lock. Fall back to the
        # plain locked write, reusing already-merged data when we have
        # it so previously persisted counts are not thrown away.
        if merged is None:
            merged = mutator(dict(seed))
        self._locked_write(merged)
296+
242297
def _locked_read(self) -> Dict[str, Any]:
243298
"""Read stats file with file locking."""
244299
try:

packages/claude-code-plugin/tests/test_stats.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,55 @@ def test_format_summary_no_timing_when_empty(self, stats):
250250
assert "⏱" not in summary
251251

252252

253+
class TestConcurrentFlush:
    """Regression test for race condition in flush() (#1493).

    Multiple processes calling record_tool_call() + flush() against the
    same session/data_dir must not lose updates.
    """

    @staticmethod
    def _worker(data_dir: str, session_id: str, n: int) -> None:
        """Worker that records n tool calls and flushes each one."""
        import sys as _sys

        # Make hooks/lib importable inside the child process as well.
        _lib = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            "hooks",
            "lib",
        )
        if _lib not in _sys.path:
            _sys.path.insert(0, _lib)
        from stats import SessionStats as _SS

        s = _SS(session_id=session_id, data_dir=data_dir, flush_interval=10)
        for _ in range(n):
            s.record_tool_call("Bash")
            s.flush()

    def test_concurrent_flush_no_lost_updates(self, data_dir):
        """8 processes x 100 calls = 800 total. Final disk count must be 800."""
        import multiprocessing as mp

        session_id = "race-test"
        num_workers = 8
        calls_per_worker = 100
        expected = num_workers * calls_per_worker

        # Seed the stats file
        SessionStats(session_id=session_id, data_dir=data_dir)

        procs = [
            mp.Process(target=self._worker, args=(data_dir, session_id, calls_per_worker))
            for _ in range(num_workers)
        ]
        for p in procs:
            p.start()
        # Bounded join: a locking deadlock must fail the test, not hang it.
        for p in procs:
            p.join(timeout=60)

        hung = [p for p in procs if p.is_alive()]
        for p in hung:
            p.terminate()
            p.join()
        assert not hung, f"{len(hung)} worker(s) did not finish within the timeout"

        # A crashed worker would otherwise surface only as a misleading
        # "lost updates" count mismatch below.
        exit_codes = [p.exitcode for p in procs]
        assert all(code == 0 for code in exit_codes), (
            f"Worker processes failed, exit codes: {exit_codes}"
        )

        s = SessionStats(session_id=session_id, data_dir=data_dir)
        on_disk = s._locked_read()
        assert on_disk["tool_count"] == expected, (
            f"Expected {expected}, got {on_disk['tool_count']} — lost updates detected"
        )
        assert on_disk["tool_names"]["Bash"] == expected
299+
assert on_disk["tool_names"]["Bash"] == expected
300+
301+
253302
class TestCleanup:
254303
def test_cleanup_stale_removes_old_files(self, data_dir):
255304
# Create a stale file

0 commit comments

Comments
 (0)