|
15 | 15 | from typing import Any, Optional |
16 | 16 |
|
17 | 17 | from agent_memory_toolkit._query_builder import _QueryBuilder |
| 18 | +from agent_memory_toolkit._utils import ( |
| 19 | + _build_memory_query_builder, |
| 20 | + _container_policies, |
| 21 | + _validate_connection, |
| 22 | + _validate_hybrid_search, |
| 23 | +) |
18 | 24 | from agent_memory_toolkit.exceptions import ( |
19 | | - ConfigurationError, |
20 | 25 | CosmosNotConnectedError, |
21 | 26 | CosmosOperationError, |
22 | 27 | MemoryNotFoundError, |
|
26 | 31 | logger = logging.getLogger(__name__) |
27 | 32 |
|
28 | 33 |
|
29 | | -# --------------------------------------------------------------------------- |
30 | | -# Helpers |
31 | | -# --------------------------------------------------------------------------- |
32 | | - |
33 | | - |
34 | | -def _build_memory_query_builder( |
35 | | - *, |
36 | | - memory_id: Optional[str] = None, |
37 | | - user_id: Optional[str] = None, |
38 | | - thread_id: Optional[str] = None, |
39 | | - role: Optional[str] = None, |
40 | | - memory_type: Optional[str] = None, |
41 | | -) -> _QueryBuilder: |
42 | | - """Return a :class:`_QueryBuilder` pre-loaded with standard filters.""" |
43 | | - qb = _QueryBuilder() |
44 | | - qb.add_filter("c.id", "@memory_id", memory_id) |
45 | | - qb.add_filter("c.user_id", "@user_id", user_id) |
46 | | - qb.add_filter("c.thread_id", "@thread_id", thread_id) |
47 | | - qb.add_filter("c.role", "@role", role) |
48 | | - qb.add_filter("c.type", "@memory_type", memory_type) |
49 | | - return qb |
50 | | - |
51 | | - |
52 | 34 | # --------------------------------------------------------------------------- |
53 | 35 | # Async client |
54 | 36 | # --------------------------------------------------------------------------- |
@@ -102,19 +84,21 @@ async def connect(self) -> None: |
102 | 84 | CosmosOperationError |
103 | 85 | If the connection fails. |
104 | 86 | """ |
105 | | - if not self._endpoint: |
106 | | - raise ConfigurationError(parameter="endpoint") |
107 | | - if not self._credential: |
108 | | - raise ConfigurationError(parameter="credential") |
| 87 | + _validate_connection( |
| 88 | + self._endpoint, self._credential, self._database, self._container |
| 89 | + ) |
109 | 90 |
|
110 | 91 | try: |
111 | 92 | from azure.cosmos.aio import CosmosClient |
112 | 93 |
|
113 | | - self._cosmos_client = CosmosClient( |
| 94 | + client = CosmosClient( |
114 | 95 | self._endpoint, credential=self._credential |
115 | 96 | ) |
116 | | - db = self._cosmos_client.get_database_client(self._database) |
117 | | - self._container_client = db.get_container_client(self._container) |
| 97 | + db = client.get_database_client(self._database) |
| 98 | + container = db.get_container_client(self._container) |
| 99 | + |
| 100 | + self._cosmos_client = client |
| 101 | + self._container_client = container |
118 | 102 | except Exception as exc: |
119 | 103 | raise CosmosOperationError( |
120 | 104 | f"Failed to connect to Cosmos DB (async): {exc}" |
@@ -143,62 +127,44 @@ async def create_store( |
143 | 127 | * Full-text index on ``/content`` |
144 | 128 | * Autoscale throughput (max RU) |
145 | 129 | """ |
146 | | - if not self._endpoint: |
147 | | - raise ConfigurationError(parameter="endpoint") |
148 | | - if not self._credential: |
149 | | - raise ConfigurationError(parameter="credential") |
| 130 | + _validate_connection( |
| 131 | + self._endpoint, self._credential, self._database, self._container |
| 132 | + ) |
150 | 133 |
|
151 | 134 | try: |
152 | 135 | from azure.cosmos import PartitionKey, ThroughputProperties |
153 | 136 | from azure.cosmos.aio import CosmosClient |
154 | 137 |
|
155 | | - self._cosmos_client = CosmosClient( |
| 138 | + client = CosmosClient( |
156 | 139 | self._endpoint, credential=self._credential |
157 | 140 | ) |
158 | 141 |
|
159 | | - db = await self._cosmos_client.create_database_if_not_exists( |
| 142 | + db = await client.create_database_if_not_exists( |
160 | 143 | id=self._database |
161 | 144 | ) |
162 | 145 |
|
163 | 146 | partition_key = PartitionKey( |
164 | 147 | path=["/user_id", "/thread_id"], kind="MultiHash" |
165 | 148 | ) |
166 | 149 |
|
167 | | - vector_embedding_policy = { |
168 | | - "vectorEmbeddings": [ |
169 | | - { |
170 | | - "path": "/embedding", |
171 | | - "dataType": embedding_data_type, |
172 | | - "distanceFunction": distance_function, |
173 | | - "dimensions": embedding_dimensions, |
174 | | - } |
175 | | - ] |
176 | | - } |
177 | | - |
178 | | - indexing_policy = { |
179 | | - "includedPaths": [{"path": "/*"}], |
180 | | - "excludedPaths": [{"path": "/embedding/*"}], |
181 | | - "vectorIndexes": [{"path": "/embedding", "type": "quantizedFlat"}], |
182 | | - "fullTextIndexes": [{"path": "/content"}], |
183 | | - } |
184 | | - |
185 | | - full_text_policy = { |
186 | | - "defaultLanguage": full_text_language, |
187 | | - "fullTextPaths": [ |
188 | | - {"path": "/content", "language": full_text_language} |
189 | | - ], |
190 | | - } |
| 150 | + vec_policy, idx_policy, ft_policy = _container_policies( |
| 151 | + embedding_dimensions=embedding_dimensions, |
| 152 | + embedding_data_type=embedding_data_type, |
| 153 | + distance_function=distance_function, |
| 154 | + full_text_language=full_text_language, |
| 155 | + ) |
191 | 156 |
|
192 | 157 | container = await db.create_container_if_not_exists( |
193 | 158 | id=self._container, |
194 | 159 | partition_key=partition_key, |
195 | | - indexing_policy=indexing_policy, |
196 | | - vector_embedding_policy=vector_embedding_policy, |
197 | | - full_text_policy=full_text_policy, |
| 160 | + indexing_policy=idx_policy, |
| 161 | + vector_embedding_policy=vec_policy, |
| 162 | + full_text_policy=ft_policy, |
198 | 163 | offer_throughput=ThroughputProperties( |
199 | 164 | auto_scale_max_throughput=autoscale_max_ru, |
200 | 165 | ), |
201 | 166 | ) |
| 167 | + self._cosmos_client = client |
202 | 168 | self._container_client = container |
203 | 169 | except Exception as exc: |
204 | 170 | raise CosmosOperationError( |
@@ -478,6 +444,7 @@ async def vector_search( |
478 | 444 | Required when *hybrid_search* is ``True``. |
479 | 445 | """ |
480 | 446 | self._require_connected() |
| 447 | + _validate_hybrid_search(hybrid_search, search_terms) |
481 | 448 |
|
482 | 449 | qb = _build_memory_query_builder( |
483 | 450 | user_id=user_id, role=role, memory_type=memory_type, thread_id=thread_id |
|
0 commit comments