diff --git a/projects/pgai/pgai/semantic_catalog/vectorizer/openai.py b/projects/pgai/pgai/semantic_catalog/vectorizer/openai.py index f61869e8f..40cf85193 100644 --- a/projects/pgai/pgai/semantic_catalog/vectorizer/openai.py +++ b/projects/pgai/pgai/semantic_catalog/vectorizer/openai.py @@ -7,12 +7,24 @@ import os from collections.abc import Sequence -from openai import AsyncClient +from openai import AsyncClient, AsyncAzureOpenAI from pgai.semantic_catalog.vectorizer import OpenAIConfig from pgai.semantic_catalog.vectorizer.models import EmbedRow +def create_openai_client(config: OpenAIConfig): + api_key: str | None = None + if config.api_key_name is not None: + api_key = os.getenv(config.api_key_name) + if config.model.startswith("azure:"): + return AsyncAzureOpenAI( + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + api_version=os.getenv("AZURE_OPENAI_API_VERSION"), + api_key=api_key), config.model.replace("azure:", "") + return AsyncClient(api_key=api_key, base_url=config.base_url), config.model + + async def embed_batch(config: OpenAIConfig, batch: list[EmbedRow]) -> None: """Generate embeddings for a batch of content using OpenAI. @@ -26,13 +38,10 @@ async def embed_batch(config: OpenAIConfig, batch: list[EmbedRow]) -> None: Raises: RuntimeError: If the number of embeddings returned doesn't match the batch size. """ - api_key: str | None = None - if config.api_key_name is not None: - api_key = os.getenv(config.api_key_name) - client = AsyncClient(api_key=api_key, base_url=config.base_url) # TODO: cache this? + client, model = create_openai_client(config) # TODO: cache this? response = await client.embeddings.create( input=[x.content for x in batch], - model=config.model, + model=model, dimensions=config.dimensions, ) if len(response.data) != len(batch): @@ -58,13 +67,10 @@ async def embed_query(config: OpenAIConfig, query: str) -> Sequence[float]: Raises: RuntimeError: If the number of embeddings returned is not exactly 1. """ - api_key: str | None = None - if config.api_key_name is not None: - api_key = os.getenv(config.api_key_name) - client = AsyncClient(api_key=api_key, base_url=config.base_url) # TODO: cache this? + client, model = create_openai_client(config) # TODO: cache this? response = await client.embeddings.create( input=query, - model=config.model, + model=model, dimensions=config.dimensions, ) if len(response.data) != 1: