Skip to content

Commit d73d0a2

Browse files
Aayush Kataria
authored and committed
Code refactoring, async code fixes and error handling
1 parent 6a68b59 commit d73d0a2

12 files changed

Lines changed: 494 additions & 267 deletions

File tree

.coverage

-52 KB
Binary file not shown.

.github/workflows/ci.yml

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Continuous integration: lint with ruff and run the unit tests with coverage.
name: CI

# Run on pushes to main and on pull requests that target main.
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

# The workflow token only gets read access to repository contents.
permissions:
  contents: read

jobs:
  # Static checks: ruff lint rules plus formatting verification.
  lint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version:
          - "3.11"
          - "3.12"
          - "3.13"
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: pip install ruff
      - name: Ruff check
        run: ruff check agent_memory_toolkit/ tests/
      - name: Ruff format check
        run: ruff format --check agent_memory_toolkit/ tests/

  # Unit tests with coverage, one run per Python version in the matrix.
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version:
          - "3.11"
          - "3.12"
          - "3.13"
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install package with dev dependencies
        run: pip install -e ".[dev]"
      - name: Run unit tests with coverage
        run: pytest tests/unit/ --cov=agent_memory_toolkit --cov-report=xml --cov-report=term-missing -v
      # always() keeps the upload step running even when the test step fails.
      - name: Upload coverage
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: coverage-report-${{ matrix.python-version }}
          path: coverage.xml

agent_memory_toolkit/_utils.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
"""Shared utilities for the Agent Memory Toolkit.
2+
3+
Houses helpers used by both the sync and async clients to avoid
4+
duplication and hidden cross-module coupling.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import os
10+
import uuid
11+
from datetime import datetime, timezone
12+
from typing import Any, Optional
13+
14+
from ._query_builder import _QueryBuilder
15+
from .exceptions import ConfigurationError, MemoryNotFoundError, ValidationError
16+
17+
# ---------------------------------------------------------------------------
18+
# Validation constants
19+
# ---------------------------------------------------------------------------
20+
21+
VALID_ROLES = {"agent", "user", "tool", "system"}
22+
VALID_TYPES = {"turn", "summary", "fact", "user_summary"}
23+
24+
25+
# ---------------------------------------------------------------------------
26+
# Memory factory
27+
# ---------------------------------------------------------------------------
28+
29+
30+
def _make_memory(
    user_id: str,
    role: str,
    content: str,
    memory_type: str = "turn",
    agent_id: Optional[str] = None,
    metadata: Optional[dict[str, Any]] = None,
    memory_id: Optional[str] = None,
    thread_id: Optional[str] = None,
) -> dict[str, Any]:
    """Build a memory record as a plain dict after validating role and type.

    Args:
        user_id: Stored under ``"user_id"``.
        role: Must be a member of ``VALID_ROLES``.
        content: Stored under ``"content"``.
        memory_type: Must be a member of ``VALID_TYPES``; stored under ``"type"``.
        agent_id: Stored under ``"agent_id"`` only when it is not ``None``.
        metadata: Stored under ``"metadata"``; a falsy value becomes ``{}``.
        memory_id: Stored under ``"id"``; a falsy value is replaced by a new UUID4 string.
        thread_id: Stored under ``"thread_id"``; a falsy value is replaced by a
            new UUID4 string.

    Returns:
        The memory dict, including a ``"created_at"`` timestamp taken from the
        current UTC time in ISO-8601 form.

    Raises:
        ValidationError: If ``role`` or ``memory_type`` is not an allowed value.
    """
    # Reject unknown enum-like values before any ids or timestamps are generated.
    if role not in VALID_ROLES:
        raise ValidationError(f"role must be one of {VALID_ROLES}, got '{role}'")
    if memory_type not in VALID_TYPES:
        raise ValidationError(f"type must be one of {VALID_TYPES}, got '{memory_type}'")

    resolved_id = memory_id if memory_id else str(uuid.uuid4())
    resolved_thread = thread_id if thread_id else str(uuid.uuid4())

    record: dict[str, Any] = {
        "id": resolved_id,
        "user_id": user_id,
        "thread_id": resolved_thread,
        "role": role,
        "type": memory_type,
        "content": content,
        "metadata": metadata if metadata else {},
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    # agent_id is optional: the key is omitted entirely rather than stored as None.
    if agent_id is None:
        return record

    record["agent_id"] = agent_id
    return record
61+
62+
63+
def _resolve_embedding_dimensions(val: Optional[int]) -> Optional[int]:
    """Resolve embedding dimensions from an explicit value or the environment.

    Args:
        val: Explicit dimension count. When not ``None`` it is returned
            unchanged and the environment is not consulted.

    Returns:
        ``val`` when given; otherwise the integer parsed from the
        ``EMBEDDING_DIMENSIONS`` environment variable, or ``None`` when the
        variable is unset, empty, whitespace-only, or ``0``.

    Raises:
        ValueError: If ``EMBEDDING_DIMENSIONS`` is set to something that is
            not an integer, or to a negative integer.
    """
    if val is not None:
        return val

    # An unset, empty, or blank variable all mean "not configured".
    raw = (os.environ.get("EMBEDDING_DIMENSIONS") or "").strip()
    if not raw:
        return None

    try:
        parsed = int(raw)
    except ValueError:
        # Re-raise with the variable name so a misconfiguration is traceable;
        # the bare int() message gives no hint about where the value came from.
        raise ValueError(
            f"EMBEDDING_DIMENSIONS must be an integer, got {raw!r}"
        ) from None

    if parsed < 0:
        raise ValueError(
            f"EMBEDDING_DIMENSIONS must not be negative, got {parsed}"
        )

    # Zero is the "not configured" sentinel, same as an unset variable.
    return parsed or None
70+
71+
72+
# ---------------------------------------------------------------------------
73+
# Connection / query helpers (shared by sync & async Cosmos clients)
74+
# ---------------------------------------------------------------------------
75+
76+
77+
def _validate_connection(
78+
endpoint: str | None,
79+
credential: Any,
80+
database: str,
81+
container: str,
82+
) -> None:
83+
"""Raise :class:`ConfigurationError` if any required field is missing."""
84+
if not endpoint:
85+
raise ConfigurationError(parameter="endpoint")
86+
if not credential:
87+
raise ConfigurationError(parameter="credential")
88+
if not database:
89+
raise ConfigurationError(parameter="database")
90+
if not container:
91+
raise ConfigurationError(parameter="container")
92+
93+
94+
def _build_memory_query_builder(
    *,
    memory_id: Optional[str] = None,
    user_id: Optional[str] = None,
    thread_id: Optional[str] = None,
    role: Optional[str] = None,
    memory_type: Optional[str] = None,
) -> _QueryBuilder:
    """Create a :class:`_QueryBuilder` and register the standard memory filters.

    Each keyword argument is handed to ``add_filter`` together with its
    document field (``c.id``, ``c.user_id``, ``c.thread_id``, ``c.role``,
    ``c.type``) and query parameter name, in that fixed order.
    """
    standard_filters = (
        ("c.id", "@memory_id", memory_id),
        ("c.user_id", "@user_id", user_id),
        ("c.thread_id", "@thread_id", thread_id),
        ("c.role", "@role", role),
        ("c.type", "@memory_type", memory_type),
    )

    builder = _QueryBuilder()
    for field_path, param_name, filter_value in standard_filters:
        builder.add_filter(field_path, param_name, filter_value)
    return builder
110+
111+
112+
def _container_policies(
    *,
    embedding_dimensions: int,
    embedding_data_type: str,
    distance_function: str,
    full_text_language: str,
) -> tuple[dict, dict, dict]:
    """Assemble the three policy dicts passed to container creation.

    Returns:
        A ``(vector_embedding_policy, indexing_policy, full_text_policy)``
        tuple. The vector policy describes the ``/embedding`` path, the
        indexing policy includes every path except ``/embedding/*`` and adds
        a ``quantizedFlat`` vector index plus a full-text index on
        ``/content``, and the full-text policy applies ``full_text_language``
        both as the default and to ``/content``.
    """
    embedding_path = "/embedding"
    content_path = "/content"

    embedding_spec = {
        "path": embedding_path,
        "dataType": embedding_data_type,
        "distanceFunction": distance_function,
        "dimensions": embedding_dimensions,
    }
    vector_policy = {"vectorEmbeddings": [embedding_spec]}

    index_policy = {
        "includedPaths": [{"path": "/*"}],
        # The raw vector is kept out of the regular index; it gets its own
        # entry under vectorIndexes below.
        "excludedPaths": [{"path": f"{embedding_path}/*"}],
        "vectorIndexes": [{"path": embedding_path, "type": "quantizedFlat"}],
        "fullTextIndexes": [{"path": content_path}],
    }

    text_policy = {
        "defaultLanguage": full_text_language,
        "fullTextPaths": [{"path": content_path, "language": full_text_language}],
    }

    return vector_policy, index_policy, text_policy
144+
145+
146+
def _validate_hybrid_search(
    hybrid_search: bool,
    search_terms: Optional[str],
) -> None:
    """Ensure search terms accompany a hybrid search request.

    Raises:
        ValidationError: If ``hybrid_search`` is truthy while ``search_terms``
            is ``None`` or empty.
    """
    # Nothing to validate unless hybrid search was actually requested.
    if not hybrid_search:
        return
    if search_terms:
        return
    raise ValidationError("search_terms is required when hybrid_search is True")

agent_memory_toolkit/aio/cosmos_memory_client.py

Lines changed: 31 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@
1515
from typing import Any, Optional
1616

1717
from agent_memory_toolkit._query_builder import _QueryBuilder
18+
from agent_memory_toolkit._utils import (
19+
_build_memory_query_builder,
20+
_container_policies,
21+
_validate_connection,
22+
_validate_hybrid_search,
23+
)
1824
from agent_memory_toolkit.exceptions import (
19-
ConfigurationError,
2025
CosmosNotConnectedError,
2126
CosmosOperationError,
2227
MemoryNotFoundError,
@@ -26,29 +31,6 @@
2631
logger = logging.getLogger(__name__)
2732

2833

29-
# ---------------------------------------------------------------------------
30-
# Helpers
31-
# ---------------------------------------------------------------------------
32-
33-
34-
def _build_memory_query_builder(
35-
*,
36-
memory_id: Optional[str] = None,
37-
user_id: Optional[str] = None,
38-
thread_id: Optional[str] = None,
39-
role: Optional[str] = None,
40-
memory_type: Optional[str] = None,
41-
) -> _QueryBuilder:
42-
"""Return a :class:`_QueryBuilder` pre-loaded with standard filters."""
43-
qb = _QueryBuilder()
44-
qb.add_filter("c.id", "@memory_id", memory_id)
45-
qb.add_filter("c.user_id", "@user_id", user_id)
46-
qb.add_filter("c.thread_id", "@thread_id", thread_id)
47-
qb.add_filter("c.role", "@role", role)
48-
qb.add_filter("c.type", "@memory_type", memory_type)
49-
return qb
50-
51-
5234
# ---------------------------------------------------------------------------
5335
# Async client
5436
# ---------------------------------------------------------------------------
@@ -102,19 +84,21 @@ async def connect(self) -> None:
10284
CosmosOperationError
10385
If the connection fails.
10486
"""
105-
if not self._endpoint:
106-
raise ConfigurationError(parameter="endpoint")
107-
if not self._credential:
108-
raise ConfigurationError(parameter="credential")
87+
_validate_connection(
88+
self._endpoint, self._credential, self._database, self._container
89+
)
10990

11091
try:
11192
from azure.cosmos.aio import CosmosClient
11293

113-
self._cosmos_client = CosmosClient(
94+
client = CosmosClient(
11495
self._endpoint, credential=self._credential
11596
)
116-
db = self._cosmos_client.get_database_client(self._database)
117-
self._container_client = db.get_container_client(self._container)
97+
db = client.get_database_client(self._database)
98+
container = db.get_container_client(self._container)
99+
100+
self._cosmos_client = client
101+
self._container_client = container
118102
except Exception as exc:
119103
raise CosmosOperationError(
120104
f"Failed to connect to Cosmos DB (async): {exc}"
@@ -143,62 +127,44 @@ async def create_store(
143127
* Full-text index on ``/content``
144128
* Autoscale throughput (max RU)
145129
"""
146-
if not self._endpoint:
147-
raise ConfigurationError(parameter="endpoint")
148-
if not self._credential:
149-
raise ConfigurationError(parameter="credential")
130+
_validate_connection(
131+
self._endpoint, self._credential, self._database, self._container
132+
)
150133

151134
try:
152135
from azure.cosmos import PartitionKey, ThroughputProperties
153136
from azure.cosmos.aio import CosmosClient
154137

155-
self._cosmos_client = CosmosClient(
138+
client = CosmosClient(
156139
self._endpoint, credential=self._credential
157140
)
158141

159-
db = await self._cosmos_client.create_database_if_not_exists(
142+
db = await client.create_database_if_not_exists(
160143
id=self._database
161144
)
162145

163146
partition_key = PartitionKey(
164147
path=["/user_id", "/thread_id"], kind="MultiHash"
165148
)
166149

167-
vector_embedding_policy = {
168-
"vectorEmbeddings": [
169-
{
170-
"path": "/embedding",
171-
"dataType": embedding_data_type,
172-
"distanceFunction": distance_function,
173-
"dimensions": embedding_dimensions,
174-
}
175-
]
176-
}
177-
178-
indexing_policy = {
179-
"includedPaths": [{"path": "/*"}],
180-
"excludedPaths": [{"path": "/embedding/*"}],
181-
"vectorIndexes": [{"path": "/embedding", "type": "quantizedFlat"}],
182-
"fullTextIndexes": [{"path": "/content"}],
183-
}
184-
185-
full_text_policy = {
186-
"defaultLanguage": full_text_language,
187-
"fullTextPaths": [
188-
{"path": "/content", "language": full_text_language}
189-
],
190-
}
150+
vec_policy, idx_policy, ft_policy = _container_policies(
151+
embedding_dimensions=embedding_dimensions,
152+
embedding_data_type=embedding_data_type,
153+
distance_function=distance_function,
154+
full_text_language=full_text_language,
155+
)
191156

192157
container = await db.create_container_if_not_exists(
193158
id=self._container,
194159
partition_key=partition_key,
195-
indexing_policy=indexing_policy,
196-
vector_embedding_policy=vector_embedding_policy,
197-
full_text_policy=full_text_policy,
160+
indexing_policy=idx_policy,
161+
vector_embedding_policy=vec_policy,
162+
full_text_policy=ft_policy,
198163
offer_throughput=ThroughputProperties(
199164
auto_scale_max_throughput=autoscale_max_ru,
200165
),
201166
)
167+
self._cosmos_client = client
202168
self._container_client = container
203169
except Exception as exc:
204170
raise CosmosOperationError(
@@ -478,6 +444,7 @@ async def vector_search(
478444
Required when *hybrid_search* is ``True``.
479445
"""
480446
self._require_connected()
447+
_validate_hybrid_search(hybrid_search, search_terms)
481448

482449
qb = _build_memory_query_builder(
483450
user_id=user_id, role=role, memory_type=memory_type, thread_id=thread_id

0 commit comments

Comments
 (0)