Skip to content

Commit 25b5f9d

Browse files
committed
examples: add session_monitor.py — behavioral consistency tracking via existing hooks
1 parent 566e41f commit 25b5f9d

File tree

1 file changed

+339
-0
lines changed

1 file changed

+339
-0
lines changed

examples/session_monitor.py

Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
#!/usr/bin/env python
2+
"""session_monitor.py — behavioral consistency monitoring using existing SDK hooks.
3+
4+
Demonstrates how to use the claude-agent-sdk-python hooks surface (PostToolUse,
5+
PreToolUse, SessionStart) to build a lightweight behavioral fingerprint that
6+
detects drift across long sessions.
7+
8+
Works with the current SDK surface today. The patterns here also motivate the
9+
OnCompaction + OnContextThreshold hooks proposed in Issue #772, which would allow
10+
earlier interception rather than inferring boundaries from token count changes.
11+
12+
Usage:
13+
python examples/session_monitor.py
14+
15+
What it shows:
16+
- Tracking tool call distribution across turns via PostToolUse
17+
- Detecting token-count drops between turns (heuristic compaction boundary)
18+
- Capturing a pre-session vocabulary baseline via SessionStart
19+
- Computing behavioral drift score: did the agent's output profile change?
20+
- Logging compaction-boundary events for downstream analysis
21+
22+
Context:
23+
Long-running agents hit context limits, triggering compaction/summarization.
24+
After compaction, the agent may lose task-specific vocabulary, shift its tool
25+
call mix, or change its response style — behavioral drift that is invisible to
26+
the user and often undetected by the agent itself. This example shows how to
27+
measure it using the hooks the SDK already has.
28+
29+
Reference: https://github.com/anthropics/claude-agent-sdk-python/issues/772
30+
"""
31+
32+
import asyncio
33+
import json
34+
import math
35+
import re
36+
import time
37+
from collections import Counter, defaultdict
38+
from dataclasses import dataclass, field
39+
from pathlib import Path
40+
from typing import Any, Optional
41+
42+
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
43+
from claude_agent_sdk.types import (
44+
AssistantMessage,
45+
HookContext,
46+
HookInput,
47+
HookJSONOutput,
48+
Message,
49+
ResultMessage,
50+
TextBlock,
51+
ToolUseBlock,
52+
)
53+
54+
55+
# ---------------------------------------------------------------------------
56+
# Behavioral snapshot — what the agent looks like at one point in time
57+
# ---------------------------------------------------------------------------
58+
59+
@dataclass
class BehavioralSnapshot:
    """Point-in-time behavioral profile of the agent.

    One snapshot is built per completed turn (see SessionMonitor.record_turn);
    a separate snapshot serves as the session baseline for drift comparison.
    """

    # Which turn this snapshot captures (0 = session start placeholder).
    turn: int
    # Cumulative token count observed at this turn; 0 until real data arrives.
    tokens: int
    # Wall-clock time (time.time()) when the snapshot was taken.
    timestamp: float
    # Tool name -> number of calls accumulated during the turn.
    tool_counts: Counter = field(default_factory=Counter)
    # NOTE(review): never populated anywhere in this file — candidate for
    # removal, kept to preserve the constructor interface.
    output_tokens: list[int] = field(default_factory=list)
    # Lower-cased word-like terms (4+ chars) seen in output and tool responses.
    vocabulary: set[str] = field(default_factory=set)
67+
68+
69+
# ---------------------------------------------------------------------------
70+
# Session monitor — accumulates snapshots, detects compaction, scores drift
71+
# ---------------------------------------------------------------------------
72+
73+
class SessionMonitor:
    """Monitors behavioral consistency across a claude-agent-sdk session.

    Connects to the SDK via hook callbacks (SessionStart, PreToolUse,
    PostToolUse). The hooks fill per-turn accumulators; after each agent turn,
    record_turn() folds them into a BehavioralSnapshot, checks for a suspected
    compaction boundary (a large drop in token count), and computes a Context
    Consistency Score (CCS) against a baseline snapshot.

    Fixes vs. the original revision:
    - The "re-baseline after compaction" branch was dead code: it required
      ``baseline.tokens == 0``, but compaction detection requires
      ``prev_tokens > 0``, which implies the baseline had already been seeded
      with a nonzero token count on an earlier turn. Re-baselining now happens
      after the compaction turn has been scored against the pre-compaction
      baseline, so that turn's drift report is meaningful and subsequent turns
      are compared to the post-compaction profile.
    - Baseline seeding no longer depends on the SessionStart hook having fired
      (the original never seeded a baseline when ``_baseline`` was None, so
      drift was never scored in that case).
    """

    def __init__(
        self,
        compaction_drop_ratio: float = 0.20,  # token count drops > 20% → suspect compaction
        drift_threshold: float = 0.30,  # CCS below 1 - 0.30 = 0.70 → drift alert
        log_path: Optional[Path] = None,  # None → print events; Path → append JSONL
    ):
        self.compaction_drop_ratio = compaction_drop_ratio
        self.drift_threshold = drift_threshold
        self.log_path = log_path

        # Baseline profile to compare against; tokens == 0 until seeded by the
        # first turn that actually produced tokens.
        self._baseline: Optional[BehavioralSnapshot] = None
        # Most recently completed turn's snapshot.
        self._current: Optional[BehavioralSnapshot] = None
        self._turn = 0
        self._compaction_events: list[dict] = []
        self._drift_scores: list[float] = []
        # Per-turn accumulators, filled by hooks, consumed by record_turn().
        self._pending_tool_counts: Counter = Counter()
        self._pending_vocabulary: set[str] = set()

    # -----------------------------------------------------------------------
    # Hook callbacks — wire these into ClaudeAgentOptions.hooks
    # -----------------------------------------------------------------------

    async def on_session_start(
        self, input_data: HookInput, tool_use_id: Optional[str], context: HookContext
    ) -> HookJSONOutput:
        """Capture the session's initial (empty) state as placeholder snapshots."""
        now = time.time()
        self._baseline = BehavioralSnapshot(turn=0, tokens=0, timestamp=now)
        self._current = BehavioralSnapshot(turn=0, tokens=0, timestamp=now)
        return {}

    async def on_pre_tool_use(
        self, input_data: HookInput, tool_use_id: Optional[str], context: HookContext
    ) -> HookJSONOutput:
        """Count each tool call before it executes."""
        tool_name = input_data.get("tool_name", "unknown")
        self._pending_tool_counts[tool_name] += 1
        return {}

    async def on_post_tool_use(
        self, input_data: HookInput, tool_use_id: Optional[str], context: HookContext
    ) -> HookJSONOutput:
        """Accumulate tool-output vocabulary — used to detect forgotten context."""
        tool_response = str(input_data.get("tool_response", ""))
        words = set(re.findall(r"\b[a-zA-Z_]\w{3,}\b", tool_response.lower()))
        self._pending_vocabulary.update(words)
        return {}

    # -----------------------------------------------------------------------
    # Call this after each agent turn with the turn's AssistantMessage tokens
    # -----------------------------------------------------------------------

    def record_turn(self, message_text: str, total_tokens: int) -> Optional[dict]:
        """Record a completed turn; check for compaction boundary and drift.

        Args:
            message_text: Concatenated assistant text for the turn.
            total_tokens: Cumulative token count observed at this turn.

        Returns:
            A dict describing any detected event (compaction or drift), or None.
        """
        self._turn += 1
        words = set(re.findall(r"\b[a-zA-Z_]\w{3,}\b", message_text.lower()))

        prev_tokens = self._current.tokens if self._current else 0

        # Fold per-turn accumulators into the current snapshot, then reset them.
        self._current = BehavioralSnapshot(
            turn=self._turn,
            tokens=total_tokens,
            timestamp=time.time(),
            tool_counts=Counter(self._pending_tool_counts),
            vocabulary=words | self._pending_vocabulary,
        )
        self._pending_tool_counts.clear()
        self._pending_vocabulary.clear()

        compaction_detected = self._detect_compaction(prev_tokens, total_tokens)

        # Seed baseline from the first turn that produced tokens. Unlike the
        # original, this also works when on_session_start never fired.
        if self._baseline is None or self._baseline.tokens == 0:
            if total_tokens > 0:
                self._baseline = BehavioralSnapshot(
                    turn=self._turn,
                    tokens=total_tokens,
                    timestamp=self._current.timestamp,
                    tool_counts=Counter(self._current.tool_counts),
                    vocabulary=set(self._current.vocabulary),
                )
            return None  # Nothing to compare yet

        # Compute behavioral drift score (Context Consistency Score) against
        # the pre-compaction baseline.
        ccs = self._compute_ccs()
        self._drift_scores.append(ccs)

        result = None
        if ccs < (1.0 - self.drift_threshold) or compaction_detected:
            result = {
                "event": "behavioral_drift" if not compaction_detected else "post_compaction_drift",
                "turn": self._turn,
                "ccs": round(ccs, 3),
                "compaction_at_this_turn": compaction_detected,
                "ghost_terms": list(self._ghost_terms()),
                "tool_shift": self._tool_shift_summary(),
            }
            self._log(result)

        # Re-baseline AFTER scoring, so the compaction turn was compared to the
        # pre-compaction profile and later turns compare to the new one.
        if compaction_detected:
            self._baseline = self._current

        return result

    def _detect_compaction(self, prev_tokens: int, total_tokens: int) -> bool:
        """Heuristic compaction check: token count dropped beyond the ratio.

        Logs and records a ``compaction_suspected`` event when triggered.
        """
        if prev_tokens > 0 and total_tokens < prev_tokens * (1 - self.compaction_drop_ratio):
            event = {
                "event": "compaction_suspected",
                "turn": self._turn,
                "tokens_before": prev_tokens,
                "tokens_after": total_tokens,
                "drop_ratio": round(1.0 - total_tokens / prev_tokens, 3),
                "timestamp": self._current.timestamp,
            }
            self._compaction_events.append(event)
            self._log(event)
            return True
        return False

    # -----------------------------------------------------------------------
    # Scoring helpers
    # -----------------------------------------------------------------------

    def _compute_ccs(self) -> float:
        """Context Consistency Score: [0, 1] where 1.0 = no behavioral change.

        Weighted combination of:
        - Vocabulary overlap: Jaccard similarity vs baseline (weight 0.6)
        - Tool distribution shift: inverted Jensen-Shannon divergence (weight 0.4)
        """
        vocab_score = self._vocab_overlap()
        tool_score = self._tool_consistency()
        return 0.6 * vocab_score + 0.4 * tool_score

    def _vocab_overlap(self) -> float:
        """Jaccard similarity of baseline vs current vocabulary (1.0 if either empty)."""
        if not self._baseline.vocabulary or not self._current.vocabulary:
            return 1.0
        intersection = self._baseline.vocabulary & self._current.vocabulary
        union = self._baseline.vocabulary | self._current.vocabulary
        return len(intersection) / len(union) if union else 1.0

    def _ghost_terms(self) -> list[str]:
        """Terms present at baseline but absent from recent output — 'forgotten' vocabulary.

        Capped at the first 20 terms in sorted order.
        """
        if not self._baseline or not self._current:
            return []
        return sorted(self._baseline.vocabulary - self._current.vocabulary)[:20]

    def _tool_consistency(self) -> float:
        """Inverted Jensen-Shannon divergence: 1.0 = identical tool distribution.

        Uses natural log, so the raw JSD lies in [0, ln 2]; the score is
        clamped to [0, 1]. The 1e-10 epsilon guards the log argument.
        """
        if not self._baseline.tool_counts or not self._current.tool_counts:
            return 1.0
        all_tools = set(self._baseline.tool_counts) | set(self._current.tool_counts)
        base_total = sum(self._baseline.tool_counts.values()) or 1
        curr_total = sum(self._current.tool_counts.values()) or 1
        p = {t: self._baseline.tool_counts.get(t, 0) / base_total for t in all_tools}
        q = {t: self._current.tool_counts.get(t, 0) / curr_total for t in all_tools}
        m = {t: 0.5 * (p[t] + q[t]) for t in all_tools}

        def kl(a, b):
            # m[t] > 0 whenever a[t] > 0, so the division is safe.
            return sum(a[t] * math.log(a[t] / b[t] + 1e-10) for t in all_tools if a[t] > 0)

        jsd = 0.5 * kl(p, m) + 0.5 * kl(q, m)
        return max(0.0, 1.0 - jsd)

    def _tool_shift_summary(self) -> dict:
        """Per-tool baseline vs current call counts for drift reports."""
        if not self._baseline or not self._current:
            return {}
        all_tools = set(self._baseline.tool_counts) | set(self._current.tool_counts)
        return {
            t: {
                "baseline": self._baseline.tool_counts.get(t, 0),
                "current": self._current.tool_counts.get(t, 0),
            }
            for t in all_tools
        }

    def summary(self) -> dict:
        """Session-level rollup: turn count, compaction events, CCS statistics."""
        return {
            "turns": self._turn,
            "compaction_events": len(self._compaction_events),
            "avg_ccs": round(sum(self._drift_scores) / len(self._drift_scores), 3)
            if self._drift_scores else None,
            "min_ccs": round(min(self._drift_scores), 3) if self._drift_scores else None,
            "compaction_detail": self._compaction_events,
        }

    def _log(self, event: dict) -> None:
        """Append the event as one JSON line to log_path, or print it."""
        if self.log_path:
            with open(self.log_path, "a") as f:
                f.write(json.dumps(event) + "\n")
        else:
            print(f"[session_monitor] {json.dumps(event)}")
282+
283+
284+
# ---------------------------------------------------------------------------
285+
# Demo: run a short session and monitor behavioral consistency
286+
# ---------------------------------------------------------------------------
287+
288+
async def main():
    """Demo: run a short multi-turn session and monitor behavioral consistency."""
    monitor = SessionMonitor(
        compaction_drop_ratio=0.20,
        drift_threshold=0.30,
        log_path=None,  # set to Path("session_monitor.jsonl") to persist
    )

    # NOTE(review): this registers bare callables per hook name; verify against
    # the installed SDK version — the hooks surface may expect matcher objects
    # wrapping the callbacks rather than plain lists of coroutines.
    options = ClaudeAgentOptions(
        hooks={
            "SessionStart": [monitor.on_session_start],
            "PreToolUse": [monitor.on_pre_tool_use],
            "PostToolUse": [monitor.on_post_tool_use],
        }
    )

    async with ClaudeSDKClient(options=options) as client:
        # Example: run a short multi-turn session
        prompts = [
            "What Python libraries are good for data analysis?",
            "How do I read a CSV with pandas?",
            "Now forget everything about pandas. Tell me about numpy arrays.",
        ]

        total_tokens = 0

        # NOTE(review): confirm `process_query` exists on this client class in
        # the SDK version in use, and whether passing `options` again here is
        # needed given it was already supplied to the constructor.
        async for message in client.process_query(
            "\n\n".join(prompts),
            options=options,
        ):
            text = ""
            if isinstance(message, AssistantMessage):
                for block in message.content:
                    if isinstance(block, TextBlock):
                        text += block.text
            elif isinstance(message, ResultMessage):
                # ResultMessage carries cumulative token usage
                # NOTE(review): if the ResultMessage arrives only after the
                # turn's AssistantMessage, record_turn below sees the previous
                # turn's token count — confirm message ordering in the SDK.
                total_tokens = getattr(message, "usage", {}).get("output_tokens", total_tokens)

            # One record_turn call per assistant message that produced text.
            if text:
                event = monitor.record_turn(text, total_tokens)
                if event:
                    print(f"\n⚠ Behavioral event: {json.dumps(event, indent=2)}")

    print("\n=== Session summary ===")
    print(json.dumps(monitor.summary(), indent=2))
    print()
    print("Note: OnCompaction + OnContextThreshold hooks (Issue #772) would allow")
    print("exact compaction-boundary capture instead of the token-drop heuristic above.")
336+
337+
338+
# Script entry point: run the demo session under asyncio.
if __name__ == "__main__":
    asyncio.run(main())

0 commit comments

Comments
 (0)