Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 14 additions & 83 deletions tests/test_feature_2_failure_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,63 +8,18 @@

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any
from unittest.mock import MagicMock

import pytest

from agent_debugger_sdk.core.events import EventType, TraceEvent

# Note: Feature 2 (failure_memory) not yet implemented
pytestmark = pytest.mark.skip(reason="Feature 2 (failure_memory) not yet implemented")


# =============================================================================
# Custom Exceptions
# =============================================================================


class EmbeddingGenerationError(Exception):
    """Raised when embedding generation fails."""


# =============================================================================
# Mock Data Classes for FailureMemory
# =============================================================================


@dataclass
class FailureSignature:
"""Signature extracted from a failure event for embedding."""

error_type: str
error_message: str
tool_name: str | None = None
session_id: str | None = None
additional_context: dict[str, Any] = field(default_factory=dict)

def to_text(self) -> str:
"""Convert signature to text for embedding."""
parts = [f"Error: {self.error_type}", f"Message: {self.error_message}"]
if self.tool_name:
parts.append(f"Tool: {self.tool_name}")
return " | ".join(parts)


@dataclass
class SimilarFailureMatch:
    """A match from the failure memory search."""

    failure_id: str
    # NOTE(review): presumably higher means more similar and the value lies in
    # [0, 1]; nothing in this class enforces a range -- confirm with the producer.
    similarity_score: float
    signature: FailureSignature
    fix_applied: str | None = None
    occurrence_count: int = 1
    session_id: str | None = None

from collector.failure_memory import (
EmbeddingGenerationError,
FailureMemory,
FailureSignature,
SimilarFailureMatch,
)

# =============================================================================
# Fixtures
Expand Down Expand Up @@ -168,7 +123,7 @@ def mock_embedding_model():
def mock_vector_db():
    """Mock vector database for failure memory storage.

    Returns:
        A ``MagicMock`` whose ``query`` returns an empty result set and whose
        ``add`` returns ``None``; tests override ``query.return_value`` or set
        ``side_effect`` as needed.
    """
    db = MagicMock()
    # Default: empty results (flat-list format)
    db.query.return_value = {"ids": [], "distances": [], "metadatas": []}
    db.add.return_value = None
    return db
Expand All @@ -185,8 +140,6 @@ class TestFailureMemoryHappyPath:
def test_remember_failure_stores_embedding(self, make_error_event, mock_embedding_model, mock_vector_db):
"""Storing a failure should generate an embedding and call vector_db.add."""
# Arrange
from collector.failure_memory import FailureMemory

memory = FailureMemory(embedding_model=mock_embedding_model, vector_db=mock_vector_db)
error_event = make_error_event(
error_type="TimeoutError",
Expand All @@ -204,17 +157,15 @@ def test_remember_failure_stores_embedding(self, make_error_event, mock_embeddin

mock_vector_db.add.assert_called_once()
add_args = mock_vector_db.add.call_args
assert add_args[1]["metadata"]["error_type"] == "TimeoutError"
assert add_args[1]["metadata"]["fix_applied"] == "Added retry logic"
assert add_args[1]["metadatas"][0]["error_type"] == "TimeoutError"
assert add_args[1]["metadatas"][0]["fix_applied"] == "Added retry logic"

def test_search_similar_returns_matches(self, make_error_event, mock_embedding_model, mock_vector_db):
"""Searching for similar failures should return a ranked list with scores."""
# Arrange
from collector.failure_memory import FailureMemory

memory = FailureMemory(embedding_model=mock_embedding_model, vector_db=mock_vector_db)

# Mock the query to return some matches
# Mock the query to return some matches (flat-list format)
mock_vector_db.query.return_value = {
"ids": ["fail-1", "fail-2"],
"distances": [0.15, 0.35], # Lower distance = higher similarity
Expand Down Expand Up @@ -254,8 +205,6 @@ def test_search_similar_returns_matches(self, make_error_event, mock_embedding_m
def test_search_includes_fix_information(self, make_error_event, mock_embedding_model, mock_vector_db):
"""Search results should include the fix that was previously applied."""
# Arrange
from collector.failure_memory import FailureMemory

memory = FailureMemory(embedding_model=mock_embedding_model, vector_db=mock_vector_db)

mock_vector_db.query.return_value = {
Expand Down Expand Up @@ -287,8 +236,6 @@ def test_search_includes_fix_information(self, make_error_event, mock_embedding_
def test_failure_signature_extracts_key_fields(self, make_error_event):
"""The failure signature should extract error type and message."""
# Arrange
from collector.failure_memory import FailureMemory

error_event = make_error_event(
error_type="KeyError",
error_message="Missing required key 'user_id'",
Expand All @@ -299,6 +246,7 @@ def test_failure_signature_extracts_key_fields(self, make_error_event):
signature = FailureMemory.extract_signature(error_event)

# Assert
assert isinstance(signature, FailureSignature)
assert signature.error_type == "KeyError"
assert "Missing required key" in signature.error_message
assert signature.tool_name == "get_user"
Expand All @@ -321,8 +269,6 @@ class TestFailureMemoryEdgeCases:
def test_empty_memory_returns_empty_list(self, make_error_event, mock_embedding_model, mock_vector_db):
"""An empty vector DB should return an empty list, not an error."""
# Arrange
from collector.failure_memory import FailureMemory

memory = FailureMemory(embedding_model=mock_embedding_model, vector_db=mock_vector_db)

# Mock empty response
Expand All @@ -344,8 +290,6 @@ def test_empty_memory_returns_empty_list(self, make_error_event, mock_embedding_
def test_low_similarity_excluded(self, make_error_event, mock_embedding_model, mock_vector_db):
"""Results below the similarity threshold should be excluded."""
# Arrange
from collector.failure_memory import FailureMemory

memory = FailureMemory(embedding_model=mock_embedding_model, vector_db=mock_vector_db)

# Mock results with varying distances
Expand All @@ -372,11 +316,9 @@ def test_low_similarity_excluded(self, make_error_event, mock_embedding_model, m
def test_duplicate_failures_update_existing(self, make_error_event, mock_embedding_model, mock_vector_db):
"""Storing the same failure again should update the occurrence count."""
# Arrange
from collector.failure_memory import FailureMemory

memory = FailureMemory(embedding_model=mock_embedding_model, vector_db=mock_vector_db)

# Mock that the failure already exists
# Mock that the failure already exists (flat-list format, single result)
mock_vector_db.query.return_value = {
"ids": ["existing-fail-1"],
"distances": [0.01], # Very similar
Expand All @@ -399,10 +341,9 @@ def test_duplicate_failures_update_existing(self, make_error_event, mock_embeddi
memory.remember_failure(error_event)

# Assert - should update existing, not add new
# The implementation should call update with incremented count
mock_vector_db.update.assert_called_once()
update_args = mock_vector_db.update.call_args
assert update_args[1]["metadata"]["occurrence_count"] == 3
assert update_args[1]["metadatas"][0]["occurrence_count"] == 3

def test_session_without_error_skipped(
self,
Expand All @@ -413,8 +354,6 @@ def test_session_without_error_skipped(
):
"""A session without an error event should not be stored in memory."""
# Arrange
from collector.failure_memory import FailureMemory

memory = FailureMemory(embedding_model=mock_embedding_model, vector_db=mock_vector_db)

decision = make_decision_event()
Expand All @@ -439,8 +378,6 @@ class TestFailureMemoryErrorHandling:
def test_embedding_failure_returns_graceful_error(self, make_error_event, mock_embedding_model, mock_vector_db):
"""If embedding generation fails, EmbeddingGenerationError should be raised."""
# Arrange
from collector.failure_memory import FailureMemory

mock_embedding_model.encode.side_effect = RuntimeError("Model not loaded")
memory = FailureMemory(embedding_model=mock_embedding_model, vector_db=mock_vector_db)

Expand All @@ -455,8 +392,6 @@ def test_embedding_failure_returns_graceful_error(self, make_error_event, mock_e
def test_vector_db_unavailable_returns_empty(self, make_error_event, mock_embedding_model, mock_vector_db):
"""If vector DB connection fails, search should return empty list."""
# Arrange
from collector.failure_memory import FailureMemory

mock_vector_db.query.side_effect = ConnectionError("Vector DB unavailable")
memory = FailureMemory(embedding_model=mock_embedding_model, vector_db=mock_vector_db)

Expand All @@ -471,8 +406,6 @@ def test_vector_db_unavailable_returns_empty(self, make_error_event, mock_embedd
def test_malformed_metadata_handled(self, make_error_event, mock_embedding_model, mock_vector_db):
"""Malformed or None metadata in results should not crash the search."""
# Arrange
from collector.failure_memory import FailureMemory

memory = FailureMemory(embedding_model=mock_embedding_model, vector_db=mock_vector_db)

# Mock results with malformed metadata
Expand All @@ -481,7 +414,7 @@ def test_malformed_metadata_handled(self, make_error_event, mock_embedding_model
"distances": [0.1, 0.2],
"metadatas": [
None, # Malformed: None instead of dict
{"error_type": "Error", "error_message": "Msg"}, # Missing fields
{"error_type": "Error", "error_message": "Msg"}, # Missing optional fields
],
}

Expand Down Expand Up @@ -515,8 +448,6 @@ def test_link_to_why_button_analysis(
):
"""Failure memory should be queryable from Why button results."""
# Arrange
from collector.failure_memory import FailureMemory

memory = FailureMemory(embedding_model=mock_embedding_model, vector_db=mock_vector_db)

# Set up a previous similar failure
Expand Down
Loading