Skip to content

Commit 6484891

Browse files
committed
Improve restore flow and seed data validation
1 parent 5a0e681 commit 6484891

6 files changed

Lines changed: 239 additions & 163 deletions

File tree

agent_debugger_sdk/core/context/session_manager.py

Lines changed: 76 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import uuid
66
from collections.abc import Awaitable, Callable
77
from datetime import datetime, timezone
8-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, Any
99

1010
import httpx
1111

@@ -21,6 +21,77 @@ class _CheckpointRestoreError(Exception):
2121
pass
2222

2323

24+
def _resolve_restore_server_url(server_url: str | None) -> str:
25+
"""Resolve the checkpoint restore server URL."""
26+
if server_url is not None:
27+
return server_url
28+
29+
from agent_debugger_sdk.config import get_config
30+
31+
config = get_config()
32+
return config.endpoint or "http://localhost:8000"
33+
34+
35+
async def _fetch_checkpoint_payload(
36+
client: httpx.AsyncClient,
37+
checkpoint_id: str,
38+
server_url: str,
39+
) -> dict[str, Any]:
40+
"""Fetch and decode checkpoint payload data from the server."""
41+
try:
42+
response = await client.get(f"{server_url}/api/checkpoints/{checkpoint_id}")
43+
response.raise_for_status()
44+
payload = response.json()
45+
except httpx.HTTPStatusError as e:
46+
raise _CheckpointRestoreError(
47+
f"Failed to restore checkpoint {checkpoint_id!r} from {server_url}: "
48+
f"{e.response.status_code} {e.response.reason_phrase}"
49+
) from e
50+
except httpx.RequestError as e:
51+
raise _CheckpointRestoreError(
52+
f"Network error while restoring checkpoint {checkpoint_id!r} from {server_url}: {e}"
53+
) from e
54+
except Exception as e:
55+
raise _CheckpointRestoreError(
56+
f"Unexpected error while restoring checkpoint {checkpoint_id!r} from {server_url}: {e}"
57+
) from e
58+
59+
if not isinstance(payload, dict):
60+
raise _CheckpointRestoreError(
61+
f"Unexpected checkpoint payload type for {checkpoint_id!r} from {server_url}: "
62+
f"{type(payload).__name__}"
63+
)
64+
65+
return payload
66+
67+
68+
def _build_restored_session(
69+
checkpoint_id: str,
70+
checkpoint_data: dict[str, Any],
71+
*,
72+
session_id: str | None = None,
73+
label: str = "",
74+
) -> tuple[Session, BaseCheckpointState | None]:
75+
"""Build a restored Session and validated checkpoint state from payload data."""
76+
from agent_debugger_sdk.checkpoints import validate_checkpoint_state
77+
78+
state_dict = checkpoint_data.get("state", {})
79+
original_session_id = checkpoint_data.get("session_id", "")
80+
81+
session = Session(
82+
id=session_id or str(uuid.uuid4()),
83+
agent_name=label or f"restored from {checkpoint_id[:8]}",
84+
framework=state_dict.get("framework", "custom"),
85+
config={
86+
"restored_from_checkpoint": checkpoint_id,
87+
"original_session_id": original_session_id,
88+
},
89+
)
90+
91+
restored_state = validate_checkpoint_state(state_dict)
92+
return session, restored_state
93+
94+
2495
class SessionManager:
2596
"""Manage session lifecycle for TraceContext.
2697
@@ -79,48 +150,9 @@ async def restore_from_checkpoint(
79150
Example:
80151
>>> session, state = await SessionManager.restore_from_checkpoint("ckpt_123")
81152
"""
82-
from agent_debugger_sdk.checkpoints import validate_checkpoint_state
83-
from agent_debugger_sdk.config import get_config
153+
resolved_server_url = _resolve_restore_server_url(server_url)
84154

85-
if server_url is None:
86-
config = get_config()
87-
server_url = config.endpoint or "http://localhost:8000"
88-
89-
# Fetch checkpoint data using a temporary client (avoids connection leaks)
90-
# All processing happens inside the context manager to ensure checkpoint_data is in scope
91155
async with httpx.AsyncClient() as client:
92-
try:
93-
response = await client.get(f"{server_url}/api/checkpoints/{checkpoint_id}")
94-
response.raise_for_status()
95-
checkpoint_data = response.json()
96-
except httpx.HTTPStatusError as e:
97-
raise _CheckpointRestoreError(
98-
f"Failed to restore checkpoint {checkpoint_id!r} from {server_url}: "
99-
f"{e.response.status_code} {e.response.reason_phrase}"
100-
) from e
101-
except httpx.RequestError as e:
102-
raise _CheckpointRestoreError(
103-
f"Network error while restoring checkpoint {checkpoint_id!r} from {server_url}: {e}"
104-
) from e
105-
except Exception as e:
106-
# Catch any other unexpected errors and wrap them
107-
raise _CheckpointRestoreError(
108-
f"Unexpected error while restoring checkpoint {checkpoint_id!r} from {server_url}: {e}"
109-
) from e
110-
111-
# Process the checkpoint data inside the context manager where checkpoint_data is in scope
112-
state_dict = checkpoint_data.get("state", {})
113-
original_session_id = checkpoint_data.get("session_id", "")
114-
115-
session = Session(
116-
id=session_id or str(uuid.uuid4()),
117-
agent_name=label or f"restored from {checkpoint_id[:8]}",
118-
framework=state_dict.get("framework", "custom"),
119-
config={
120-
"restored_from_checkpoint": checkpoint_id,
121-
"original_session_id": original_session_id,
122-
},
123-
)
124-
125-
restored_state = validate_checkpoint_state(state_dict)
126-
return session, restored_state
156+
checkpoint_data = await _fetch_checkpoint_payload(client, checkpoint_id, resolved_server_url)
157+
158+
return _build_restored_session(checkpoint_id, checkpoint_data, session_id=session_id, label=label)

api/app_context.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,29 +28,35 @@ def init_app_context() -> None:
2828
_redaction_pipeline = RedactionPipeline.from_config()
2929

3030

31+
def _ensure_initialized() -> None:
32+
"""Initialize app context lazily when first accessed."""
33+
if engine is None or async_session_maker is None or trace_intelligence is None or _redaction_pipeline is None:
34+
init_app_context()
35+
36+
3137
def require_engine() -> AsyncEngine:
3238
"""Return the initialized database engine."""
33-
if engine is None:
34-
raise RuntimeError("API app context has not been initialized")
39+
_ensure_initialized()
40+
assert engine is not None
3541
return engine
3642

3743

3844
def require_session_maker() -> async_sessionmaker[AsyncSession]:
3945
"""Return the initialized async session maker."""
40-
if async_session_maker is None:
41-
raise RuntimeError("API app context has not been initialized")
46+
_ensure_initialized()
47+
assert async_session_maker is not None
4248
return async_session_maker
4349

4450

4551
def require_trace_intelligence() -> TraceIntelligence:
4652
"""Return the initialized trace intelligence service."""
47-
if trace_intelligence is None:
48-
raise RuntimeError("API app context has not been initialized")
53+
_ensure_initialized()
54+
assert trace_intelligence is not None
4955
return trace_intelligence
5056

5157

5258
def _get_redaction_pipeline() -> RedactionPipeline:
5359
"""Return the configured redaction pipeline."""
54-
if _redaction_pipeline is None:
55-
raise RuntimeError("API app context has not been initialized")
60+
_ensure_initialized()
61+
assert _redaction_pipeline is not None
5662
return _redaction_pipeline

scripts/seed_demo_sessions.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,18 @@
2727
# Session enrichment data: realistic values for demo sessions
2828
# Note: failure_count is computed in API layer (services.py) as errors count
2929
# behavior_alert_count is computed in API layer from AnomalyAlertModel records
30+
def validate_session_enrichment(session_id: str, enrichment: dict[str, object]) -> None:
31+
"""Validate curated enrichment metrics for demo seed sessions."""
32+
total_tokens = enrichment.get("total_tokens")
33+
total_cost_usd = enrichment.get("total_cost_usd")
34+
35+
if not isinstance(total_tokens, int) or total_tokens <= 0:
36+
raise ValueError(f"Seed enrichment for {session_id} must define positive total_tokens")
37+
38+
if not isinstance(total_cost_usd, (int, float)) or float(total_cost_usd) <= 0:
39+
raise ValueError(f"Seed enrichment for {session_id} must define positive total_cost_usd")
40+
41+
3042
SESSION_ENRICHMENT = {
3143
"seed-prompt-injection": {
3244
"total_tokens": 856,
@@ -95,19 +107,41 @@
95107
}
96108

97109

110+
def validate_session_metrics(total_tokens: int, total_cost_usd: float, *, context: str) -> None:
111+
"""Validate curated session metrics before persisting demo seed data."""
112+
if total_tokens < 0:
113+
raise ValueError(f"{context}: total_tokens must be non-negative, got {total_tokens}")
114+
if total_cost_usd < 0:
115+
raise ValueError(f"{context}: total_cost_usd must be non-negative, got {total_cost_usd}")
116+
117+
has_tokens = total_tokens > 0
118+
has_cost = total_cost_usd > 0
119+
if has_tokens != has_cost:
120+
raise ValueError(
121+
f"{context}: total_tokens and total_cost_usd must either both be zero or both be positive "
122+
f"(got total_tokens={total_tokens}, total_cost_usd={total_cost_usd})"
123+
)
124+
125+
98126
async def enrich_session(session_id: str, session_maker: async_sessionmaker[AsyncSession]) -> None:
99127
"""Enrich a session with realistic data fields and behavior alerts."""
100128
enrichment = SESSION_ENRICHMENT.get(session_id, {})
101129
if not enrichment:
102130
return
103131

132+
validate_session_enrichment(session_id, enrichment)
133+
134+
total_tokens = enrichment.get("total_tokens", 0)
135+
total_cost_usd = enrichment.get("total_cost_usd", 0.0)
136+
validate_session_metrics(total_tokens, total_cost_usd, context=f"seed enrichment for {session_id}")
137+
104138
async with session_maker() as db_session:
105139
repo = TraceRepository(db_session)
106140

107141
# Update session fields
108142
update_data = {
109-
"total_tokens": enrichment.get("total_tokens", 0),
110-
"total_cost_usd": enrichment.get("total_cost_usd", 0.0),
143+
"total_tokens": total_tokens,
144+
"total_cost_usd": total_cost_usd,
111145
"errors": enrichment.get("errors", 0),
112146
}
113147

tests/conftest.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,17 @@ def setup_test_db():
5656
from storage.engine import create_db_engine
5757

5858
async def _setup():
59-
# Reset app context globals so each xdist worker uses its own engine
60-
app_context.engine = None
61-
app_context.async_session_maker = None
59+
# Build schema with a short-lived engine bound to this temporary loop.
60+
# The app context is initialized lazily inside each test loop.
6261
engine = create_db_engine()
6362
async with engine.begin() as conn:
6463
await conn.run_sync(Base.metadata.create_all)
65-
app_context.init_app_context()
64+
await engine.dispose()
65+
66+
app_context.engine = None
67+
app_context.async_session_maker = None
68+
app_context.trace_intelligence = None
69+
app_context._redaction_pipeline = None
6670

6771
asyncio.run(_setup())
6872
yield

0 commit comments

Comments
 (0)