Skip to content

Commit e9add4f

Browse files
committed
Refine replay UX and similar failure search
1 parent 687086b commit e9add4f

55 files changed

Lines changed: 4268 additions & 439 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

api/entity_routes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class EntityListResponse(BaseModel):
4040
class EntitySummaryResponse(BaseModel):
4141
"""Response schema for entity summary statistics."""
4242

43+
agent_name_count: int
4344
tool_name_count: int
4445
error_type_count: int
4546
model_count: int
@@ -156,6 +157,7 @@ async def get_entity_summary(
156157
summary = await repo.get_entity_summary()
157158

158159
return EntitySummaryResponse(
160+
agent_name_count=summary.get(EntityType.AGENT_NAME, 0),
159161
tool_name_count=summary.get(EntityType.TOOL_NAME, 0),
160162
error_type_count=summary.get(EntityType.ERROR_TYPE, 0),
161163
model_count=summary.get(EntityType.MODEL, 0),

api/replay_routes.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
router = APIRouter(tags=["replay"])
2727

2828

29+
def _split_csv_param(value: str | None) -> set[str]:
30+
return {item.strip() for item in (value or "").split(",") if item.strip()}
31+
32+
2933
@router.get("/api/sessions/{session_id}/replay", response_model=ReplayResponse)
3034
async def replay_session(
3135
session_id: str,
@@ -46,6 +50,8 @@ async def replay_session(
4650
# Extract the default value when the raw Query object is passed through.
4751
if hasattr(collapse_threshold, "default"):
4852
collapse_threshold = float(collapse_threshold.default)
53+
if hasattr(stop_at_breakpoint, "default"):
54+
stop_at_breakpoint = bool(stop_at_breakpoint.default)
4955

5056
# Record analytics event (fire-and-forget)
5157
record_event("replay_started", session_id=session_id, properties={"mode": mode})
@@ -72,10 +78,10 @@ async def replay_session(
7278
checkpoints,
7379
mode=mode,
7480
focus_event_id=focus_event_id,
75-
breakpoint_event_types={item for item in (breakpoint_event_types or "").split(",") if item},
76-
breakpoint_tool_names={item for item in (breakpoint_tool_names or "").split(",") if item},
81+
breakpoint_event_types=_split_csv_param(breakpoint_event_types),
82+
breakpoint_tool_names=_split_csv_param(breakpoint_tool_names),
7783
breakpoint_confidence_below=breakpoint_confidence_below,
78-
breakpoint_safety_outcomes={item for item in (breakpoint_safety_outcomes or "").split(",") if item},
84+
breakpoint_safety_outcomes=_split_csv_param(breakpoint_safety_outcomes),
7985
)
8086

8187
# Handle segment collapsing for highlights mode
@@ -98,8 +104,10 @@ async def replay_session(
98104
stopped_at_breakpoint = True
99105
# Build O(1) event_id -> index map for efficient breakpoint lookup
100106
event_id_to_index = {event.get("id"): i for i, event in enumerate(replay_data["events"])}
101-
first_breakpoint_id = replay_data["breakpoints"][0].get("id")
102-
stopped_at_index = event_id_to_index.get(first_breakpoint_id)
107+
for breakpoint_event in replay_data["breakpoints"]:
108+
stopped_at_index = event_id_to_index.get(breakpoint_event.get("id"))
109+
if stopped_at_index is not None:
110+
break
103111

104112
return ReplayResponse(
105113
session_id=session_id,

api/schemas.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ class TraceEventSchema(BaseModel):
7070
outcome: SafetyOutcome | None = None
7171
risk_level: RiskLevel | None = None
7272
rationale: str | None = None
73+
attempted_fix: str | None = None
74+
validation_result: str | None = None
75+
repair_outcome: str | None = None
76+
repair_sequence_id: str | None = None
77+
repair_diff: str | None = None
7378
blocked_action: str | None = None
7479
reason: str | None = None
7580
safe_alternative: str | None = None

api/services.py

Lines changed: 109 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
from typing import Any
99

10+
from sqlalchemy import String, cast, or_, select
1011
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
1112

1213
from agent_debugger_sdk.core.events import Checkpoint, EventType, Session, SessionStatus, TraceEvent
@@ -17,9 +18,12 @@
1718
from collector.intelligence.facade import TraceIntelligence
1819
from redaction.pipeline import RedactionPipeline
1920
from storage import TraceRepository
21+
from storage.converters import orm_to_event, orm_to_session
22+
from storage.models import EventModel, SessionModel
2023

2124
logger = logging.getLogger(__name__)
2225
SESSION_ANALYSIS_CAP = 100
26+
FAILURE_SIMILARITY_THRESHOLD = 0.5
2327

2428

2529
def normalize_session(
@@ -400,69 +404,97 @@ async def find_similar_failures(
400404
# Get the failure event
401405
failure_event = await repo.get_event(failure_event_id)
402406
if not failure_event:
403-
return []
407+
raise NotFoundError(f"Failure event {failure_event_id} not found")
408+
if failure_event.session_id != session_id:
409+
raise NotFoundError(
410+
f"Failure event {failure_event_id} was not found in session {session_id}"
411+
)
404412

405413
# Determine failure characteristics
406-
error_text = failure_event.error or failure_event.error_message or failure_event.name or ""
407-
error_type = failure_event.error_type or ""
408-
409-
# Get all sessions with failures
410-
all_sessions = await repo.list_sessions(limit=500, offset=0, sort_by="started_at")
411-
412-
similar_failures: list[dict[str, Any]] = []
413-
414-
for session in all_sessions:
415-
# Skip the current session
416-
if session.id == session_id:
417-
continue
418-
419-
# Skip sessions without errors
420-
if session.errors == 0:
421-
continue
422-
423-
# Get events from this session to find matching failures
424-
try:
425-
session_events = await repo.list_events(session.id, limit=1000)
426-
except Exception:
414+
error_text = _event_error_text(failure_event)
415+
error_type = _event_error_type(failure_event)
416+
candidate_failures = await _load_candidate_failure_events(repo, failure_event, session_id)
417+
418+
best_match_by_session: dict[str, dict[str, Any]] = {}
419+
420+
for event, session in candidate_failures:
421+
similarity = _calculate_failure_similarity(
422+
failure_event,
423+
event,
424+
error_text,
425+
error_type,
426+
)
427+
if similarity < FAILURE_SIMILARITY_THRESHOLD:
427428
continue
428429

429-
# Find failure events in this session
430-
for event in session_events:
431-
if not _is_failure_event(event):
432-
continue
433-
434-
# Calculate similarity score
435-
similarity = _calculate_failure_similarity(
436-
failure_event,
437-
event,
438-
error_text,
439-
error_type,
440-
)
441-
442-
# Only include reasonably similar failures
443-
if similarity >= 0.3:
444-
# Derive failure mode and root cause
445-
failure_mode = _derive_failure_mode(event)
446-
root_cause = _derive_root_cause(event)
447-
448-
similar_failures.append({
449-
"session_id": session.id,
450-
"agent_name": session.agent_name,
451-
"framework": session.framework,
452-
"started_at": session.started_at,
453-
"failure_type": str(event.event_type),
454-
"failure_mode": failure_mode,
455-
"root_cause": root_cause,
456-
"similarity": similarity,
457-
"fix_note": session.fix_note,
458-
})
459-
break # Only add one failure per session
430+
failure_summary = {
431+
"session_id": session.id,
432+
"agent_name": session.agent_name,
433+
"framework": session.framework,
434+
"started_at": session.started_at,
435+
"failure_type": str(event.event_type),
436+
"failure_mode": _derive_failure_mode(event),
437+
"root_cause": _derive_root_cause(event),
438+
"similarity": similarity,
439+
"fix_note": session.fix_note,
440+
}
441+
existing = best_match_by_session.get(session.id)
442+
if existing is None or failure_summary["similarity"] > existing["similarity"]:
443+
best_match_by_session[session.id] = failure_summary
460444

461445
# Sort by similarity and limit
446+
similar_failures = list(best_match_by_session.values())
462447
similar_failures.sort(key=lambda x: x["similarity"], reverse=True)
463448
return similar_failures[:limit]
464449

465450

451+
async def _load_candidate_failure_events(
452+
repo: TraceRepository,
453+
failure_event: TraceEvent,
454+
session_id: str,
455+
) -> list[tuple[TraceEvent, Session]]:
456+
"""Load tenant-scoped failure candidates without per-session N+1 queries."""
457+
failure_event_types = [
458+
str(EventType.ERROR),
459+
str(EventType.REFUSAL),
460+
str(EventType.POLICY_VIOLATION),
461+
str(EventType.BEHAVIOR_ALERT),
462+
str(EventType.TOOL_RESULT),
463+
str(EventType.SAFETY_CHECK),
464+
]
465+
466+
source_clues = [EventModel.event_type == str(failure_event.event_type)]
467+
source_error_type = _event_error_type(failure_event)
468+
if source_error_type:
469+
source_clues.append(cast(EventModel.data, String).ilike(f"%{source_error_type}%"))
470+
source_tool_name = getattr(failure_event, "tool_name", None)
471+
if source_tool_name:
472+
source_clues.append(cast(EventModel.data, String).ilike(f"%{source_tool_name}%"))
473+
474+
stmt = (
475+
select(EventModel, SessionModel)
476+
.join(SessionModel, EventModel.session_id == SessionModel.id)
477+
.where(
478+
SessionModel.tenant_id == repo.tenant_id,
479+
EventModel.tenant_id == repo.tenant_id,
480+
SessionModel.id != session_id,
481+
SessionModel.errors > 0,
482+
EventModel.event_type.in_(failure_event_types),
483+
or_(*source_clues),
484+
)
485+
.order_by(SessionModel.started_at.desc(), EventModel.timestamp.desc())
486+
)
487+
result = await repo.session.execute(stmt)
488+
489+
candidates: list[tuple[TraceEvent, Session]] = []
490+
for db_event, db_session in result.all():
491+
event = orm_to_event(db_event)
492+
if not _is_failure_event(event):
493+
continue
494+
candidates.append((event, orm_to_session(db_session)))
495+
return candidates
496+
497+
466498
def _is_failure_event(event: TraceEvent) -> bool:
467499
"""Check if an event represents a failure."""
468500
return (
@@ -475,6 +507,28 @@ def _is_failure_event(event: TraceEvent) -> bool:
475507
)
476508

477509

510+
def _event_error_text(event: TraceEvent) -> str:
511+
"""Return the most useful error-like text available on an event."""
512+
return (
513+
getattr(event, "error", None)
514+
or getattr(event, "error_message", None)
515+
or getattr(event, "reason", None)
516+
or event.name
517+
or ""
518+
)
519+
520+
521+
def _event_error_type(event: TraceEvent) -> str:
522+
"""Return the most useful error-like type available on an event."""
523+
return (
524+
getattr(event, "error_type", None)
525+
or getattr(event, "violation_type", None)
526+
or getattr(event, "alert_type", None)
527+
or ""
528+
)
529+
530+
531+
478532
def _calculate_failure_similarity(
479533
source_event: TraceEvent,
480534
candidate_event: TraceEvent,
@@ -492,13 +546,13 @@ def _calculate_failure_similarity(
492546
score += 0.4
493547

494548
# Error type match
495-
candidate_error_type = candidate_event.error_type or ""
549+
candidate_error_type = _event_error_type(candidate_event)
496550
if source_error_type and candidate_error_type:
497551
if source_error_type.lower() == candidate_error_type.lower():
498552
score += 0.3
499553

500554
# Error text similarity (simple keyword overlap)
501-
candidate_error_text = candidate_event.error or candidate_event.error_message or candidate_event.name or ""
555+
candidate_error_text = _event_error_text(candidate_event)
502556
if source_error_text and candidate_error_text:
503557
source_words = set(source_error_text.lower().split())
504558
candidate_words = set(candidate_error_text.lower().split())

collector/replay.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,10 @@ def build_replay(
243243
breakpoint_tool_names = breakpoint_tool_names or set()
244244
breakpoint_safety_outcomes = breakpoint_safety_outcomes or set()
245245

246+
breakpoint_source_events = replay_events if mode == "focus" else replay_window_events
246247
breakpoints = [
247248
event.to_dict()
248-
for event in replay_window_events
249+
for event in breakpoint_source_events
249250
if matches_breakpoint(
250251
event,
251252
event_types=breakpoint_event_types,

0 commit comments

Comments
 (0)