Skip to content

Commit 47d5c26

Browse files
authored
Merge pull request #62 from martipath/new_skills
Support Azure AD token auth in ada002 embedding skill
2 parents a6661eb + 54b8961 commit 47d5c26

2 files changed

Lines changed: 29 additions & 7 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@ authors = [
1010
requires-python = ">=3.11"
1111
dependencies = [
1212
"azure-ai-formrecognizer>=3.3.3",
13+
"azure-identity>=1.15.0",
1314
"azure-search-documents>=11.5.2",
1415
"azure-storage-blob>=12.24.1",
1516
"cerberus>=1.3.7",

src/docs2vecs/subcommands/indexer/skills/ada002_embedding_skill.py

Lines changed: 28 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
from typing import List, Optional
22

3+
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
34
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
45

56
from docs2vecs.subcommands.indexer.config import Config
@@ -15,12 +16,32 @@ def az_ada002_embeddings(self, content: str, chunk_id=None):
1516
self.logger.debug(
1617
f"Requesting embedding for chunk_id={chunk_id}, content_length={len(content)} chars"
1718
)
18-
embed_model = AzureOpenAIEmbedding(
19-
deployment_name=self._config["deployment_name"],
20-
api_key=self._config["api_key"],
21-
azure_endpoint=self._config["endpoint"],
22-
api_version=self._config["api_version"],
23-
)
19+
20+
api_key = self._config.get("api_key")
21+
if api_key:
22+
self.logger.debug("Using API key authentication")
23+
embed_model = AzureOpenAIEmbedding(
24+
deployment_name=self._config["deployment_name"],
25+
api_key=api_key,
26+
azure_endpoint=self._config["endpoint"],
27+
api_version=self._config["api_version"],
28+
)
29+
else:
30+
self.logger.debug(
31+
"No api_key provided, using Azure AD token authentication (DefaultAzureCredential)"
32+
)
33+
credential = DefaultAzureCredential()
34+
token_provider = get_bearer_token_provider(
35+
credential, "https://cognitiveservices.azure.com/.default"
36+
)
37+
embed_model = AzureOpenAIEmbedding(
38+
deployment_name=self._config["deployment_name"],
39+
azure_ad_token_provider=token_provider,
40+
azure_endpoint=self._config["endpoint"],
41+
api_version=self._config["api_version"],
42+
use_azure_ad=True,
43+
)
44+
2445
embedding = embed_model.get_query_embedding(content)
2546
self.logger.debug(
2647
f"Successfully received embedding for chunk_id={chunk_id}, embedding_dim={len(embedding) if embedding else 0}"
@@ -47,4 +68,4 @@ def run(self, input: Optional[List[Document]] = None) -> Optional[List[Document]
4768
chunk.content, chunk_id=chunk.chunk_id
4869
)
4970

50-
return input
71+
return input

0 commit comments

Comments
 (0)