-
Notifications
You must be signed in to change notification settings - Fork 971
Expand file tree
/
Copy pathsession_monitor.py
More file actions
360 lines (300 loc) · 12.4 KB
/
session_monitor.py
File metadata and controls
360 lines (300 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
#!/usr/bin/env python
"""session_monitor.py — behavioral consistency monitoring for long SDK sessions.

This example stays on the public SDK surface:

- `HookMatcher`-based `PreToolUse` / `PostToolUse` callbacks
- `ClaudeSDKClient.query()` + `receive_response()` for turns
- `ClaudeSDKClient.get_context_usage()` for context-window telemetry

Together, those are enough to build a lightweight monitor for long-running
sessions where context compaction or summarization may silently change the
agent's behavior.

Because a short fresh session will not reliably trigger compaction on demand,
the default runnable demo below uses a simulated token-usage boundary while the
live integration helpers keep the exact public SDK wiring you would use in a
real session.

Reference: https://github.com/anthropics/claude-agent-sdk-python/issues/772
"""
import asyncio
import json
import math
import os
import re
import time
from collections import Counter
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient, HookMatcher
from claude_agent_sdk.types import (
AssistantMessage,
HookContext,
HookJSONOutput,
PostToolUseHookInput,
PreToolUseHookInput,
ResultMessage,
TextBlock,
)
@dataclass
class BehavioralSnapshot:
    """What the agent looks like at one point in the session."""

    # Turn counter value when the snapshot was taken (SessionMonitor
    # increments its counter before building a snapshot, so turns start at 1).
    turn: int
    # Total token count supplied to SessionMonitor.record_turn for this turn.
    tokens: int
    # Wall-clock time of the snapshot, as returned by time.time().
    timestamp: float
    # Tool name -> number of calls observed for this turn.
    tool_counts: Counter = field(default_factory=Counter)
    # Lower-cased word tokens seen in the turn's text and tool output.
    vocabulary: set[str] = field(default_factory=set)
class SessionMonitor:
    """Track vocabulary and tool-use drift across a Claude SDK session.

    The first turn recorded with a positive token count becomes the baseline.
    Every later turn is compared against that baseline with a Context
    Consistency Score (CCS) that blends vocabulary overlap and tool-usage
    similarity. A sharp drop in the reported token total between consecutive
    turns is treated as a suspected compaction boundary.

    Args:
        compaction_drop_ratio: Fractional drop in total tokens between two
            consecutive turns beyond which compaction is suspected.
        drift_threshold: A drift event is reported when the CCS falls below
            ``1.0 - drift_threshold``.
        log_path: When set, events are appended to this file as one JSON
            object per line; otherwise they are printed to stdout.
    """

    # A "word" is 4+ word characters starting with a letter or underscore.
    # Compiled once and shared by message and tool-output tokenisation so the
    # two vocabularies are always extracted the same way.
    _WORD_PATTERN = re.compile(r"\b[a-zA-Z_]\w{3,}\b")

    def __init__(
        self,
        compaction_drop_ratio: float = 0.20,
        drift_threshold: float = 0.30,
        log_path: Optional[Path] = None,
    ) -> None:
        self.compaction_drop_ratio = compaction_drop_ratio
        self.drift_threshold = drift_threshold
        self.log_path = log_path
        self._baseline: Optional[BehavioralSnapshot] = None
        self._current: Optional[BehavioralSnapshot] = None
        self._turn = 0
        self._compaction_events: list[dict[str, Any]] = []
        self._drift_scores: list[float] = []
        # Hook callbacks accumulate here until record_turn() folds them into
        # the snapshot for the turn that just finished.
        self._pending_tool_counts: Counter = Counter()
        self._pending_vocabulary: set[str] = set()

    @classmethod
    def _extract_words(cls, text: str) -> set[str]:
        """Return the set of lower-cased word tokens found in *text*."""
        return set(cls._WORD_PATTERN.findall(text.lower()))

    async def on_pre_tool_use(
        self,
        input_data: PreToolUseHookInput,
        tool_use_id: Optional[str],
        context: HookContext,
    ) -> HookJSONOutput:
        """Record each tool call before execution.

        Increments the pending call count for ``input_data["tool_name"]`` and
        returns an empty hook output.
        """
        del tool_use_id, context
        self._pending_tool_counts[input_data["tool_name"]] += 1
        return {}

    async def on_post_tool_use(
        self,
        input_data: PostToolUseHookInput,
        tool_use_id: Optional[str],
        context: HookContext,
    ) -> HookJSONOutput:
        """Capture vocabulary emitted by tool results.

        The ``tool_response`` entry (empty string when absent) is stringified
        and tokenised; the words are added to the pending vocabulary. Returns
        an empty hook output.
        """
        del tool_use_id, context
        tool_response = str(input_data.get("tool_response", ""))
        self._pending_vocabulary.update(self._extract_words(tool_response))
        return {}

    def record_turn(self, message_text: str, total_tokens: int) -> Optional[dict[str, Any]]:
        """Record a completed turn and return any detected event.

        Args:
            message_text: Assistant text produced during the turn.
            total_tokens: Total token count reported after the turn.

        Returns:
            A ``post_compaction_drift`` or ``behavioral_drift`` event dict when
            compaction is suspected or the CCS falls below the threshold,
            otherwise ``None``. The turn that establishes the baseline always
            returns ``None``.
        """
        self._turn += 1
        words = self._extract_words(message_text)
        prev_tokens = self._current.tokens if self._current else 0

        self._current = BehavioralSnapshot(
            turn=self._turn,
            tokens=total_tokens,
            timestamp=time.time(),
            tool_counts=Counter(self._pending_tool_counts),
            vocabulary=words | self._pending_vocabulary,
        )
        # The snapshot holds copies, so the pending buffers can be reset.
        self._pending_tool_counts.clear()
        self._pending_vocabulary.clear()

        if self._baseline is None and total_tokens > 0:
            # First turn with real usage data becomes the reference point.
            self._baseline = BehavioralSnapshot(
                turn=self._turn,
                tokens=total_tokens,
                timestamp=self._current.timestamp,
                tool_counts=Counter(self._current.tool_counts),
                vocabulary=set(self._current.vocabulary),
            )
            return None

        if self._baseline is None:
            return None

        compaction_detected = False
        # prev_tokens > 0 also protects the division in drop_ratio below.
        if prev_tokens > 0 and total_tokens < prev_tokens * (1 - self.compaction_drop_ratio):
            compaction_detected = True
            event = {
                "event": "compaction_suspected",
                "turn": self._turn,
                "tokens_before": prev_tokens,
                "tokens_after": total_tokens,
                "drop_ratio": round(1.0 - total_tokens / prev_tokens, 3),
                "timestamp": self._current.timestamp,
            }
            self._compaction_events.append(event)
            self._log(event)

        ccs = self._compute_ccs()
        self._drift_scores.append(ccs)

        if ccs < (1.0 - self.drift_threshold) or compaction_detected:
            event = {
                "event": "post_compaction_drift" if compaction_detected else "behavioral_drift",
                "turn": self._turn,
                "ccs": round(ccs, 3),
                "compaction_at_this_turn": compaction_detected,
                "ghost_terms": self._ghost_terms(),
                "tool_shift": self._tool_shift_summary(),
            }
            self._log(event)
            return event

        return None

    def _compute_ccs(self) -> float:
        """Context Consistency Score: 1.0 means no behavioral change."""
        return 0.6 * self._vocab_overlap() + 0.4 * self._tool_consistency()

    def _vocab_overlap(self) -> float:
        """Return the Jaccard similarity of baseline and current vocabulary.

        Returns 1.0 when either side has no vocabulary to compare.
        """
        if not self._baseline or not self._baseline.vocabulary or not self._current:
            return 1.0
        if not self._current.vocabulary:
            return 1.0
        intersection = self._baseline.vocabulary & self._current.vocabulary
        union = self._baseline.vocabulary | self._current.vocabulary
        return len(intersection) / len(union) if union else 1.0

    def _ghost_terms(self) -> list[str]:
        """Return up to 20 baseline words (sorted) missing from the current turn."""
        if not self._baseline or not self._current:
            return []
        return sorted(self._baseline.vocabulary - self._current.vocabulary)[:20]

    def _tool_consistency(self) -> float:
        """Return tool-usage similarity between baseline and current turn.

        The score is ``1 - JSD`` where JSD is the Jensen-Shannon divergence of
        the two normalised tool-call distributions, computed with base-2 logs
        so that it lies in [0, 1]: 1.0 for identical distributions and 0.0 for
        completely disjoint tool usage. Returns 1.0 when either snapshot has
        no tool calls to compare.
        """
        if not self._baseline or not self._current:
            return 1.0
        if not self._baseline.tool_counts or not self._current.tool_counts:
            return 1.0

        all_tools = set(self._baseline.tool_counts) | set(self._current.tool_counts)
        baseline_total = sum(self._baseline.tool_counts.values()) or 1
        current_total = sum(self._current.tool_counts.values()) or 1

        baseline_distribution = {
            tool: self._baseline.tool_counts.get(tool, 0) / baseline_total
            for tool in all_tools
        }
        current_distribution = {
            tool: self._current.tool_counts.get(tool, 0) / current_total
            for tool in all_tools
        }
        midpoint = {
            tool: 0.5 * (baseline_distribution[tool] + current_distribution[tool])
            for tool in all_tools
        }

        def kl_divergence(lhs: dict[str, float], rhs: dict[str, float]) -> float:
            # rhs is always the midpoint, which is strictly positive wherever
            # lhs is positive, so no smoothing epsilon is required.
            return sum(
                lhs[tool] * math.log2(lhs[tool] / rhs[tool])
                for tool in all_tools
                if lhs[tool] > 0
            )

        jsd = 0.5 * kl_divergence(baseline_distribution, midpoint) + 0.5 * kl_divergence(
            current_distribution, midpoint
        )
        # Clamp both ends to absorb floating-point rounding.
        return min(1.0, max(0.0, 1.0 - jsd))

    def _tool_shift_summary(self) -> dict[str, dict[str, int]]:
        """Return per-tool call counts for the baseline and the current turn."""
        if not self._baseline or not self._current:
            return {}
        all_tools = set(self._baseline.tool_counts) | set(self._current.tool_counts)
        return {
            tool: {
                "baseline": self._baseline.tool_counts.get(tool, 0),
                "current": self._current.tool_counts.get(tool, 0),
            }
            for tool in all_tools
        }

    def summary(self) -> dict[str, Any]:
        """Return aggregate statistics for the session so far.

        ``avg_ccs`` and ``min_ccs`` are ``None`` until at least one turn has
        been scored. ``compaction_detail`` holds copies of the recorded
        events, so mutating the result does not affect the monitor.
        """
        scores = self._drift_scores
        return {
            "turns": self._turn,
            "compaction_events": len(self._compaction_events),
            "avg_ccs": round(sum(scores) / len(scores), 3) if scores else None,
            "min_ccs": round(min(scores), 3) if scores else None,
            "compaction_detail": [dict(event) for event in self._compaction_events],
        }

    def _log(self, event: dict[str, Any]) -> None:
        """Append *event* as a JSON line to ``log_path``, or print it if unset."""
        if self.log_path:
            with self.log_path.open("a", encoding="utf-8") as handle:
                handle.write(json.dumps(event) + "\n")
        else:
            print(f"[session_monitor] {json.dumps(event)}")
async def run_monitored_turn(
    client: ClaudeSDKClient,
    monitor: SessionMonitor,
    prompt: str,
) -> Optional[dict[str, Any]]:
    """Send one prompt, gather the assistant text, and score the turn.

    Text blocks from every assistant message in the response are joined with
    single spaces and handed to ``monitor.record_turn`` together with the
    ``totalTokens`` value from ``client.get_context_usage()`` (0 if absent).

    Raises:
        RuntimeError: If the response contains a result message flagged as an
            error.
    """
    await client.query(prompt)

    collected: list[str] = []
    async for msg in client.receive_response():
        if isinstance(msg, AssistantMessage):
            collected.extend(
                part.text for part in msg.content if isinstance(part, TextBlock)
            )
            continue
        if isinstance(msg, ResultMessage) and msg.is_error:
            raise RuntimeError(msg.result or "Claude SDK turn failed")

    context_usage = await client.get_context_usage()
    return monitor.record_turn(
        " ".join(collected),
        int(context_usage.get("totalTokens", 0)),
    )
def run_simulated_boundary_demo() -> None:
    """Feed three scripted turns through a monitor and print what it reports.

    The token totals rise and then fall sharply on the last turn, so the
    same scoring logic used for live sessions always sees a boundary.
    """
    monitor = SessionMonitor(compaction_drop_ratio=0.20, drift_threshold=0.30, log_path=None)

    # Parallel sequences: the assistant text for each turn and the token
    # total reported after it.
    turn_texts = (
        "Use jwt validation with bcrypt hashes, redis-backed sessions, and "
        "foreign_key-safe migrations for the auth schema.",
        "Keep jwt auth, bcrypt password storage, and redis session checks "
        "intact while you add the profile endpoint.",
        "Add PATCH /profile rate limiting with concise validation and 429 "
        "responses. Focus on the endpoint only.",
    )
    turn_tokens = (1200, 1480, 860)

    print("=== Deterministic session boundary demo ===")
    print("This uses simulated token snapshots so the monitor always shows a boundary event.\n")

    for text, total_tokens in zip(turn_texts, turn_tokens):
        event = monitor.record_turn(text, total_tokens)
        if not event:
            continue
        print(json.dumps(event, indent=2))

    print("\n=== Session summary ===")
    print(json.dumps(monitor.summary(), indent=2))
async def run_live_demo() -> None:
    """Drive a real SDK session through the monitor and print the results.

    The monitor's pre/post tool-use callbacks are registered for the Bash
    tool, three prompts are run in order, and any event returned for a turn
    is printed before the final session summary.
    """
    monitor = SessionMonitor(compaction_drop_ratio=0.20, drift_threshold=0.30, log_path=None)

    bash_hooks = {
        "PreToolUse": [HookMatcher(matcher="Bash", hooks=[monitor.on_pre_tool_use])],
        "PostToolUse": [HookMatcher(matcher="Bash", hooks=[monitor.on_post_tool_use])],
    }
    options = ClaudeAgentOptions(allowed_tools=["Bash"], hooks=bash_hooks)

    prompts = (
        "Use Bash to print 'jwt bcrypt redis', then explain how those terms fit together in a web auth stack.",
        "Use Bash to print 'id,name\\n1,Ada', then explain how pandas would load this CSV.",
        "Use Bash to print '[0 1 2]', then explain numpy arrays in one short paragraph.",
    )

    async with ClaudeSDKClient(options=options) as client:
        for prompt in prompts:
            event = await run_monitored_turn(client, monitor, prompt)
            if not event:
                continue
            print(f"\n[session_monitor] Behavioral event: {json.dumps(event, indent=2)}")

    print("\n=== Session summary ===")
    print(json.dumps(monitor.summary(), indent=2))

    closing_lines = (
        "",
        "Note: native OnCompaction / OnContextThreshold hooks would still be better.",
        "This sample shows the closest monitor you can build today with public hooks",
        "plus get_context_usage() as the compaction-boundary heuristic.",
    )
    for line in closing_lines:
        print(line)
async def main() -> None:
    """Run the simulated demo, then the live demo if it was requested.

    The live demo runs only when the ``CLAUDE_SESSION_MONITOR_LIVE``
    environment variable is exactly ``"1"``; otherwise a hint is printed.
    """
    run_simulated_boundary_demo()

    live_requested = os.getenv("CLAUDE_SESSION_MONITOR_LIVE") == "1"
    if not live_requested:
        print("\nSet CLAUDE_SESSION_MONITOR_LIVE=1 to also run the live SDK session demo.")
        return

    print("\n=== Live SDK session demo ===")
    await run_live_demo()


# Script entry point: run the async main() on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())