|
1 | 1 | """Unit tests for pure functions in memory.py.""" |
2 | 2 |
|
| 3 | +import hashlib |
| 4 | +from unittest.mock import MagicMock, patch |
| 5 | + |
3 | 6 | import pytest |
4 | 7 |
|
5 | | -from memory import _validate_repo |
| 8 | +from memory import ( |
| 9 | + _SCHEMA_VERSION, |
| 10 | + MEMORY_SOURCE_TYPES, |
| 11 | + _validate_repo, |
| 12 | + write_repo_learnings, |
| 13 | + write_task_episode, |
| 14 | +) |
| 15 | +from sanitization import sanitize_external_content |
6 | 16 |
|
7 | 17 |
|
8 | 18 | class TestValidateRepo: |
@@ -34,3 +44,83 @@ def test_invalid_spaces(self): |
34 | 44 | def test_invalid_empty(self): |
35 | 45 | with pytest.raises(ValueError, match="does not match"): |
36 | 46 | _validate_repo("") |
| 47 | + |
| 48 | + |
| 49 | +class TestSchemaVersion: |
| 50 | + def test_schema_version_is_3(self): |
| 51 | + assert _SCHEMA_VERSION == "3" |
| 52 | + |
| 53 | + |
| 54 | +class TestMemorySourceTypes: |
| 55 | + def test_contains_expected_values(self): |
| 56 | + assert {"agent_episode", "agent_learning", "orchestrator_fallback"} == MEMORY_SOURCE_TYPES |
| 57 | + |
| 58 | + def test_is_frozen(self): |
| 59 | + assert isinstance(MEMORY_SOURCE_TYPES, frozenset) |
| 60 | + |
| 61 | + |
| 62 | +class TestWriteTaskEpisode: |
| 63 | + @patch("memory._get_client") |
| 64 | + def test_includes_source_type_in_metadata(self, mock_get_client): |
| 65 | + mock_client = MagicMock() |
| 66 | + mock_get_client.return_value = mock_client |
| 67 | + |
| 68 | + write_task_episode("mem-1", "owner/repo", "task-1", "COMPLETED") |
| 69 | + |
| 70 | + call_kwargs = mock_client.create_event.call_args[1] |
| 71 | + metadata = call_kwargs["metadata"] |
| 72 | + assert metadata["source_type"] == {"stringValue": "agent_episode"} |
| 73 | + assert metadata["source_type"]["stringValue"] in MEMORY_SOURCE_TYPES |
| 74 | + assert metadata["schema_version"] == {"stringValue": "3"} |
| 75 | + |
| 76 | + @patch("memory._get_client") |
| 77 | + def test_content_sha256_matches_sanitized_content(self, mock_get_client): |
| 78 | + mock_client = MagicMock() |
| 79 | + mock_get_client.return_value = mock_client |
| 80 | + |
| 81 | + write_task_episode("mem-1", "owner/repo", "task-1", "COMPLETED") |
| 82 | + |
| 83 | + call_kwargs = mock_client.create_event.call_args[1] |
| 84 | + metadata = call_kwargs["metadata"] |
| 85 | + assert "content_sha256" in metadata |
| 86 | + hash_value = metadata["content_sha256"]["stringValue"] |
| 87 | + assert len(hash_value) == 64 |
| 88 | + |
| 89 | + # Verify hash matches the sanitized content that was actually stored |
| 90 | + content = call_kwargs["payload"][0]["conversational"]["content"]["text"] |
| 91 | + sanitized = sanitize_external_content(content) |
| 92 | + expected = hashlib.sha256(sanitized.encode("utf-8")).hexdigest() |
| 93 | + assert hash_value == expected |
| 94 | + |
| 95 | + |
| 96 | +class TestWriteRepoLearnings: |
| 97 | + @patch("memory._get_client") |
| 98 | + def test_includes_source_type_in_metadata(self, mock_get_client): |
| 99 | + mock_client = MagicMock() |
| 100 | + mock_get_client.return_value = mock_client |
| 101 | + |
| 102 | + write_repo_learnings("mem-1", "owner/repo", "task-1", "Use Jest for tests") |
| 103 | + |
| 104 | + call_kwargs = mock_client.create_event.call_args[1] |
| 105 | + metadata = call_kwargs["metadata"] |
| 106 | + assert metadata["source_type"] == {"stringValue": "agent_learning"} |
| 107 | + assert metadata["source_type"]["stringValue"] in MEMORY_SOURCE_TYPES |
| 108 | + assert metadata["schema_version"] == {"stringValue": "3"} |
| 109 | + |
| 110 | + @patch("memory._get_client") |
| 111 | + def test_content_sha256_matches_sanitized_content(self, mock_get_client): |
| 112 | + mock_client = MagicMock() |
| 113 | + mock_get_client.return_value = mock_client |
| 114 | + |
| 115 | + write_repo_learnings("mem-1", "owner/repo", "task-1", "Use Jest for tests") |
| 116 | + |
| 117 | + call_kwargs = mock_client.create_event.call_args[1] |
| 118 | + metadata = call_kwargs["metadata"] |
| 119 | + assert "content_sha256" in metadata |
| 120 | + hash_value = metadata["content_sha256"]["stringValue"] |
| 121 | + assert len(hash_value) == 64 |
| 122 | + |
| 123 | + content = call_kwargs["payload"][0]["conversational"]["content"]["text"] |
| 124 | + sanitized = sanitize_external_content(content) |
| 125 | + expected = hashlib.sha256(sanitized.encode("utf-8")).hexdigest() |
| 126 | + assert hash_value == expected |
0 commit comments