|
1 | 1 | """AsyncAgentMemory: Async version of AgentMemory using azure.cosmos.aio and AsyncAzureOpenAI.""" |
2 | 2 |
|
3 | | -import uuid |
| 3 | +import uuid, os |
4 | 4 | from datetime import datetime, timezone |
5 | 5 | from typing import Any, Optional |
6 | 6 |
|
@@ -28,6 +28,7 @@ def __init__( |
28 | 28 | ai_foundry_credential: Optional["TokenCredential"] = None, |
29 | 29 | ai_foundry_api_key: Optional[str] = None, |
30 | 30 | embedding_model: str = "text-embedding-3-large", |
| 31 | + embedding_dimensions: Optional[int] = None, |
31 | 32 | adf_endpoint: Optional[str] = None, |
32 | 33 | adf_key: Optional[str] = None, |
33 | 34 | use_default_credential: bool = True, |
@@ -57,6 +58,9 @@ def __init__( |
57 | 58 | self.ai_foundry_credential = ai_foundry_credential |
58 | 59 | self.ai_foundry_api_key = ai_foundry_api_key |
59 | 60 | self.embedding_model = embedding_model |
| 61 | + self.embedding_dimensions = embedding_dimensions or int( |
| 62 | + _os.environ.get("EMBEDDING_DIMENSIONS", "0") or "0" |
| 63 | + ) or None |
60 | 64 | self._embeddings_client = None |
61 | 65 |
|
62 | 66 | self.adf_endpoint = adf_endpoint |
@@ -292,9 +296,13 @@ async def _get_embedding(self, text: str) -> list[float]: |
292 | 296 | azure_ad_token_provider=token_provider, |
293 | 297 | ) |
294 | 298 |
|
295 | | - response = await self._embeddings_client.embeddings.create( |
296 | | - input=[text], model=self.embedding_model, |
297 | | - ) |
| 299 | + kwargs: dict[str, Any] = { |
| 300 | + "input": [text], |
| 301 | + "model": self.embedding_model, |
| 302 | + } |
| 303 | + if self.embedding_dimensions: |
| 304 | + kwargs["dimensions"] = self.embedding_dimensions |
| 305 | + response = await self._embeddings_client.embeddings.create(**kwargs) |
298 | 306 | return response.data[0].embedding |
299 | 307 |
|
300 | 308 | # ------------------------------------------------------------------ |
|
0 commit comments