Skip to content

Commit 4f54cc0

Browse files
committed
fixed issues with embedding generation on client and ADF side
1 parent c4f3e56 commit 4f54cc0

4 files changed

Lines changed: 31 additions & 14 deletions

File tree

.env.template

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ COSMOS_DB_AUTOSCALE_MAX_RU=1000
77
AI_FOUNDRY_ENDPOINT=https://<your-account>.openai.azure.com/
88
AI_FOUNDRY_API_KEY=
99
EMBEDDING_MODEL=text-embedding-3-large
10-
EMBEDDING_DIMENSIONS=1535
10+
EMBEDDING_DIMENSIONS=1536
1111
EMBEDDING_DATA_TYPE=float32
1212
EMBEDDING_DISTANCE_FUNCTION=cosine
1313
FULL_TEXT_LANGUAGE=en-US

agent_memory_toolkit/async_memory.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""AsyncAgentMemory: Async version of AgentMemory using azure.cosmos.aio and AsyncAzureOpenAI."""
22

3-
import uuid
3+
import uuid, os
44
from datetime import datetime, timezone
55
from typing import Any, Optional
66

@@ -28,6 +28,7 @@ def __init__(
2828
ai_foundry_credential: Optional["TokenCredential"] = None,
2929
ai_foundry_api_key: Optional[str] = None,
3030
embedding_model: str = "text-embedding-3-large",
31+
embedding_dimensions: Optional[int] = None,
3132
adf_endpoint: Optional[str] = None,
3233
adf_key: Optional[str] = None,
3334
use_default_credential: bool = True,
@@ -57,6 +58,9 @@ def __init__(
5758
self.ai_foundry_credential = ai_foundry_credential
5859
self.ai_foundry_api_key = ai_foundry_api_key
5960
self.embedding_model = embedding_model
61+
self.embedding_dimensions = embedding_dimensions or int(
62+
_os.environ.get("EMBEDDING_DIMENSIONS", "0") or "0"
63+
) or None
6064
self._embeddings_client = None
6165

6266
self.adf_endpoint = adf_endpoint
@@ -292,9 +296,13 @@ async def _get_embedding(self, text: str) -> list[float]:
292296
azure_ad_token_provider=token_provider,
293297
)
294298

295-
response = await self._embeddings_client.embeddings.create(
296-
input=[text], model=self.embedding_model,
297-
)
299+
kwargs: dict[str, Any] = {
300+
"input": [text],
301+
"model": self.embedding_model,
302+
}
303+
if self.embedding_dimensions:
304+
kwargs["dimensions"] = self.embedding_dimensions
305+
response = await self._embeddings_client.embeddings.create(**kwargs)
298306
return response.data[0].embedding
299307

300308
# ------------------------------------------------------------------

agent_memory_toolkit/memory.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""AgentMemory: A class for managing agent memories locally and (eventually) in Cosmos DB."""
22

3-
import uuid
3+
import uuid, os
44
from datetime import datetime, timezone
55
from typing import Any, Optional
66

@@ -97,6 +97,7 @@ def __init__(
9797
ai_foundry_credential: Optional["TokenCredential"] = None,
9898
ai_foundry_api_key: Optional[str] = None,
9999
embedding_model: str = "text-embedding-3-large",
100+
embedding_dimensions: Optional[int] = None,
100101
adf_endpoint: Optional[str] = None,
101102
adf_key: Optional[str] = None,
102103
use_default_credential: bool = True,
@@ -131,6 +132,9 @@ def __init__(
131132
self.ai_foundry_credential = ai_foundry_credential
132133
self.ai_foundry_api_key = ai_foundry_api_key
133134
self.embedding_model = embedding_model
135+
self.embedding_dimensions = embedding_dimensions or int(
136+
os.environ.get("EMBEDDING_DIMENSIONS", "0") or "0"
137+
) or None
134138
self._embeddings_client = None
135139

136140
# Azure Durable Functions configuration
@@ -457,10 +461,13 @@ def _get_embedding(self, text: str) -> list[float]:
457461
azure_ad_token_provider=token_provider,
458462
)
459463

460-
response = self._embeddings_client.embeddings.create(
461-
input=[text],
462-
model=self.embedding_model,
463-
)
464+
kwargs: dict[str, Any] = {
465+
"input": [text],
466+
"model": self.embedding_model,
467+
}
468+
if self.embedding_dimensions:
469+
kwargs["dimensions"] = self.embedding_dimensions
470+
response = self._embeddings_client.embeddings.create(**kwargs)
464471
return response.data[0].embedding
465472

466473
# ------------------------------------------------------------------

azure_functions/activities.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,11 @@ def _get_embeddings_client():
5151

5252
endpoint = os.environ["AI_FOUNDRY_ENDPOINT"]
5353
api_key = os.environ.get("AI_FOUNDRY_API_KEY")
54+
api_version = os.environ.get("AI_FOUNDRY_API_VERSION", "2024-12-01-preview")
5455

5556
if api_key:
5657
_embeddings_client = AzureOpenAI(
57-
api_version="2024-12-01-preview",
58+
api_version=api_version,
5859
azure_endpoint=endpoint,
5960
api_key=api_key,
6061
)
@@ -66,7 +67,7 @@ def _get_embeddings_client():
6667
"https://cognitiveservices.azure.com/.default",
6768
)
6869
_embeddings_client = AzureOpenAI(
69-
api_version="2024-12-01-preview",
70+
api_version=api_version,
7071
azure_endpoint=endpoint,
7172
azure_ad_token_provider=token_provider,
7273
)
@@ -84,10 +85,11 @@ def _get_chat_client():
8485

8586
endpoint = os.environ["AI_FOUNDRY_ENDPOINT"]
8687
api_key = os.environ.get("AI_FOUNDRY_API_KEY")
88+
api_version = os.environ.get("AI_FOUNDRY_API_VERSION", "2024-12-01-preview")
8789

8890
if api_key:
8991
_chat_client = AzureOpenAI(
90-
api_version="2024-12-01-preview",
92+
api_version=api_version,
9193
azure_endpoint=endpoint,
9294
api_key=api_key,
9395
)
@@ -99,7 +101,7 @@ def _get_chat_client():
99101
"https://cognitiveservices.azure.com/.default",
100102
)
101103
_chat_client = AzureOpenAI(
102-
api_version="2024-12-01-preview",
104+
api_version=api_version,
103105
azure_endpoint=endpoint,
104106
azure_ad_token_provider=token_provider,
105107
)

0 commit comments

Comments
 (0)