Skip to content

Commit b09c33d

Browse files
committed
examples: add session monitoring example
1 parent 566e41f commit b09c33d

File tree

1 file changed

+360
-0
lines changed

1 file changed

+360
-0
lines changed

examples/session_monitor.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
#!/usr/bin/env python
2+
"""session_monitor.py — behavioral consistency monitoring for long SDK sessions.
3+
4+
This example stays on the public SDK surface:
5+
6+
- `HookMatcher`-based `PreToolUse` / `PostToolUse` callbacks
7+
- `ClaudeSDKClient.query()` + `receive_response()` for turns
8+
- `ClaudeSDKClient.get_context_usage()` for context-window telemetry
9+
10+
Together, those are enough to build a lightweight monitor for long-running
11+
sessions where context compaction or summarization may silently change the
12+
agent's behavior.
13+
14+
Because a short fresh session will not reliably trigger compaction on demand,
15+
the default runnable demo below uses a simulated token-usage boundary while the
16+
live integration helpers keep the exact public SDK wiring you would use in a
17+
real session.
18+
19+
Reference: https://github.com/anthropics/claude-agent-sdk-python/issues/772
20+
"""
21+
22+
import asyncio
23+
import json
24+
import math
25+
import os
26+
import re
27+
import time
28+
from collections import Counter
29+
from dataclasses import dataclass, field
30+
from pathlib import Path
31+
from typing import Any, Optional
32+
33+
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient, HookMatcher
34+
from claude_agent_sdk.types import (
35+
AssistantMessage,
36+
HookContext,
37+
HookJSONOutput,
38+
PostToolUseHookInput,
39+
PreToolUseHookInput,
40+
ResultMessage,
41+
TextBlock,
42+
)
43+
44+
45+
@dataclass
46+
class BehavioralSnapshot:
47+
"""What the agent looks like at one point in the session."""
48+
49+
turn: int
50+
tokens: int
51+
timestamp: float
52+
tool_counts: Counter = field(default_factory=Counter)
53+
vocabulary: set[str] = field(default_factory=set)
54+
55+
56+
class SessionMonitor:
57+
"""Track vocabulary and tool-use drift across a Claude SDK session."""
58+
59+
def __init__(
60+
self,
61+
compaction_drop_ratio: float = 0.20,
62+
drift_threshold: float = 0.30,
63+
log_path: Optional[Path] = None,
64+
) -> None:
65+
self.compaction_drop_ratio = compaction_drop_ratio
66+
self.drift_threshold = drift_threshold
67+
self.log_path = log_path
68+
69+
self._baseline: Optional[BehavioralSnapshot] = None
70+
self._current: Optional[BehavioralSnapshot] = None
71+
self._turn = 0
72+
self._compaction_events: list[dict[str, Any]] = []
73+
self._drift_scores: list[float] = []
74+
self._pending_tool_counts: Counter = Counter()
75+
self._pending_vocabulary: set[str] = set()
76+
77+
async def on_pre_tool_use(
78+
self,
79+
input_data: PreToolUseHookInput,
80+
tool_use_id: Optional[str],
81+
context: HookContext,
82+
) -> HookJSONOutput:
83+
"""Record each tool call before execution."""
84+
85+
del tool_use_id, context
86+
self._pending_tool_counts[input_data["tool_name"]] += 1
87+
return {}
88+
89+
async def on_post_tool_use(
90+
self,
91+
input_data: PostToolUseHookInput,
92+
tool_use_id: Optional[str],
93+
context: HookContext,
94+
) -> HookJSONOutput:
95+
"""Capture vocabulary emitted by tool results."""
96+
97+
del tool_use_id, context
98+
tool_response = str(input_data.get("tool_response", ""))
99+
words = set(re.findall(r"\b[a-zA-Z_]\w{3,}\b", tool_response.lower()))
100+
self._pending_vocabulary.update(words)
101+
return {}
102+
103+
def record_turn(self, message_text: str, total_tokens: int) -> Optional[dict[str, Any]]:
104+
"""Record a completed turn and return any detected event."""
105+
106+
self._turn += 1
107+
words = set(re.findall(r"\b[a-zA-Z_]\w{3,}\b", message_text.lower()))
108+
prev_tokens = self._current.tokens if self._current else 0
109+
110+
self._current = BehavioralSnapshot(
111+
turn=self._turn,
112+
tokens=total_tokens,
113+
timestamp=time.time(),
114+
tool_counts=Counter(self._pending_tool_counts),
115+
vocabulary=words | self._pending_vocabulary,
116+
)
117+
118+
self._pending_tool_counts.clear()
119+
self._pending_vocabulary.clear()
120+
121+
if self._baseline is None and total_tokens > 0:
122+
self._baseline = BehavioralSnapshot(
123+
turn=self._turn,
124+
tokens=total_tokens,
125+
timestamp=self._current.timestamp,
126+
tool_counts=Counter(self._current.tool_counts),
127+
vocabulary=set(self._current.vocabulary),
128+
)
129+
return None
130+
131+
if self._baseline is None:
132+
return None
133+
134+
compaction_detected = False
135+
if prev_tokens > 0 and total_tokens < prev_tokens * (1 - self.compaction_drop_ratio):
136+
compaction_detected = True
137+
event = {
138+
"event": "compaction_suspected",
139+
"turn": self._turn,
140+
"tokens_before": prev_tokens,
141+
"tokens_after": total_tokens,
142+
"drop_ratio": round(1.0 - total_tokens / prev_tokens, 3),
143+
"timestamp": self._current.timestamp,
144+
}
145+
self._compaction_events.append(event)
146+
self._log(event)
147+
148+
ccs = self._compute_ccs()
149+
self._drift_scores.append(ccs)
150+
151+
if ccs < (1.0 - self.drift_threshold) or compaction_detected:
152+
event = {
153+
"event": "post_compaction_drift" if compaction_detected else "behavioral_drift",
154+
"turn": self._turn,
155+
"ccs": round(ccs, 3),
156+
"compaction_at_this_turn": compaction_detected,
157+
"ghost_terms": self._ghost_terms(),
158+
"tool_shift": self._tool_shift_summary(),
159+
}
160+
self._log(event)
161+
return event
162+
163+
return None
164+
165+
def _compute_ccs(self) -> float:
166+
"""Context Consistency Score: 1.0 means no behavioral change."""
167+
168+
return 0.6 * self._vocab_overlap() + 0.4 * self._tool_consistency()
169+
170+
def _vocab_overlap(self) -> float:
171+
if not self._baseline or not self._baseline.vocabulary or not self._current:
172+
return 1.0
173+
if not self._current.vocabulary:
174+
return 1.0
175+
intersection = self._baseline.vocabulary & self._current.vocabulary
176+
union = self._baseline.vocabulary | self._current.vocabulary
177+
return len(intersection) / len(union) if union else 1.0
178+
179+
def _ghost_terms(self) -> list[str]:
180+
if not self._baseline or not self._current:
181+
return []
182+
return sorted(self._baseline.vocabulary - self._current.vocabulary)[:20]
183+
184+
def _tool_consistency(self) -> float:
185+
if not self._baseline or not self._current:
186+
return 1.0
187+
if not self._baseline.tool_counts or not self._current.tool_counts:
188+
return 1.0
189+
190+
all_tools = set(self._baseline.tool_counts) | set(self._current.tool_counts)
191+
baseline_total = sum(self._baseline.tool_counts.values()) or 1
192+
current_total = sum(self._current.tool_counts.values()) or 1
193+
baseline_distribution = {
194+
tool: self._baseline.tool_counts.get(tool, 0) / baseline_total
195+
for tool in all_tools
196+
}
197+
current_distribution = {
198+
tool: self._current.tool_counts.get(tool, 0) / current_total
199+
for tool in all_tools
200+
}
201+
midpoint = {
202+
tool: 0.5 * (baseline_distribution[tool] + current_distribution[tool])
203+
for tool in all_tools
204+
}
205+
206+
def kl_divergence(lhs: dict[str, float], rhs: dict[str, float]) -> float:
207+
return sum(
208+
lhs[tool] * math.log(lhs[tool] / rhs[tool] + 1e-10)
209+
for tool in all_tools
210+
if lhs[tool] > 0
211+
)
212+
213+
jsd = 0.5 * kl_divergence(baseline_distribution, midpoint) + 0.5 * kl_divergence(
214+
current_distribution, midpoint
215+
)
216+
return max(0.0, 1.0 - jsd)
217+
218+
def _tool_shift_summary(self) -> dict[str, dict[str, int]]:
219+
if not self._baseline or not self._current:
220+
return {}
221+
all_tools = set(self._baseline.tool_counts) | set(self._current.tool_counts)
222+
return {
223+
tool: {
224+
"baseline": self._baseline.tool_counts.get(tool, 0),
225+
"current": self._current.tool_counts.get(tool, 0),
226+
}
227+
for tool in all_tools
228+
}
229+
230+
def summary(self) -> dict[str, Any]:
231+
return {
232+
"turns": self._turn,
233+
"compaction_events": len(self._compaction_events),
234+
"avg_ccs": round(sum(self._drift_scores) / len(self._drift_scores), 3)
235+
if self._drift_scores
236+
else None,
237+
"min_ccs": round(min(self._drift_scores), 3) if self._drift_scores else None,
238+
"compaction_detail": self._compaction_events,
239+
}
240+
241+
def _log(self, event: dict[str, Any]) -> None:
242+
if self.log_path:
243+
with self.log_path.open("a", encoding="utf-8") as handle:
244+
handle.write(json.dumps(event) + "\n")
245+
else:
246+
print(f"[session_monitor] {json.dumps(event)}")
247+
248+
249+
async def run_monitored_turn(
250+
client: ClaudeSDKClient,
251+
monitor: SessionMonitor,
252+
prompt: str,
253+
) -> Optional[dict[str, Any]]:
254+
"""Run one SDK turn, then score it using public message + usage APIs."""
255+
256+
await client.query(prompt)
257+
258+
text_parts: list[str] = []
259+
async for message in client.receive_response():
260+
if isinstance(message, AssistantMessage):
261+
for block in message.content:
262+
if isinstance(block, TextBlock):
263+
text_parts.append(block.text)
264+
elif isinstance(message, ResultMessage) and message.is_error:
265+
raise RuntimeError(message.result or "Claude SDK turn failed")
266+
267+
usage = await client.get_context_usage()
268+
total_tokens = int(usage.get("totalTokens", 0))
269+
return monitor.record_turn(" ".join(text_parts), total_tokens)
270+
271+
272+
def run_simulated_boundary_demo() -> None:
273+
"""Run a deterministic boundary demo using the same scoring logic."""
274+
monitor = SessionMonitor(
275+
compaction_drop_ratio=0.20,
276+
drift_threshold=0.30,
277+
log_path=None,
278+
)
279+
280+
synthetic_turns = [
281+
(
282+
"Use jwt validation with bcrypt hashes, redis-backed sessions, and "
283+
"foreign_key-safe migrations for the auth schema.",
284+
1200,
285+
),
286+
(
287+
"Keep jwt auth, bcrypt password storage, and redis session checks "
288+
"intact while you add the profile endpoint.",
289+
1480,
290+
),
291+
(
292+
"Add PATCH /profile rate limiting with concise validation and 429 "
293+
"responses. Focus on the endpoint only.",
294+
860,
295+
),
296+
]
297+
298+
print("=== Deterministic session boundary demo ===")
299+
print("This uses simulated token snapshots so the monitor always shows a boundary event.\n")
300+
for text, total_tokens in synthetic_turns:
301+
event = monitor.record_turn(text, total_tokens)
302+
if event:
303+
print(json.dumps(event, indent=2))
304+
305+
print("\n=== Session summary ===")
306+
print(json.dumps(monitor.summary(), indent=2))
307+
308+
309+
async def run_live_demo() -> None:
310+
"""Optional live SDK demo using the same monitor."""
311+
monitor = SessionMonitor(
312+
compaction_drop_ratio=0.20,
313+
drift_threshold=0.30,
314+
log_path=None,
315+
)
316+
317+
options = ClaudeAgentOptions(
318+
allowed_tools=["Bash"],
319+
hooks={
320+
"PreToolUse": [
321+
HookMatcher(matcher="Bash", hooks=[monitor.on_pre_tool_use]),
322+
],
323+
"PostToolUse": [
324+
HookMatcher(matcher="Bash", hooks=[monitor.on_post_tool_use]),
325+
],
326+
},
327+
)
328+
329+
prompts = [
330+
"Use Bash to print 'jwt bcrypt redis', then explain how those terms fit together in a web auth stack.",
331+
"Use Bash to print 'id,name\\n1,Ada', then explain how pandas would load this CSV.",
332+
"Use Bash to print '[0 1 2]', then explain numpy arrays in one short paragraph.",
333+
]
334+
335+
async with ClaudeSDKClient(options=options) as client:
336+
for prompt in prompts:
337+
event = await run_monitored_turn(client, monitor, prompt)
338+
if event:
339+
print(f"\n[session_monitor] Behavioral event: {json.dumps(event, indent=2)}")
340+
341+
print("\n=== Session summary ===")
342+
print(json.dumps(monitor.summary(), indent=2))
343+
print()
344+
print("Note: native OnCompaction / OnContextThreshold hooks would still be better.")
345+
print("This sample shows the closest monitor you can build today with public hooks")
346+
print("plus get_context_usage() as the compaction-boundary heuristic.")
347+
348+
349+
async def main() -> None:
350+
run_simulated_boundary_demo()
351+
352+
if os.getenv("CLAUDE_SESSION_MONITOR_LIVE") == "1":
353+
print("\n=== Live SDK session demo ===")
354+
await run_live_demo()
355+
else:
356+
print("\nSet CLAUDE_SESSION_MONITOR_LIVE=1 to also run the live SDK session demo.")
357+
358+
359+
if __name__ == "__main__":
360+
asyncio.run(main())

0 commit comments

Comments
 (0)