Skip to content
This repository was archived by the owner on May 27, 2026. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions projects/pgai/pgai/semantic_catalog/vectorizer/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,24 @@
import os
from collections.abc import Sequence

from openai import AsyncClient
from openai import AsyncClient, AsyncAzureOpenAI

from pgai.semantic_catalog.vectorizer import OpenAIConfig
from pgai.semantic_catalog.vectorizer.models import EmbedRow


def create_openai_client(config: OpenAIConfig):
api_key: str | None = None
if config.api_key_name is not None:
api_key = os.getenv(config.api_key_name)
if config.model.startswith("azure:"):
return AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
api_key=api_key), config.model.replace("azure:", "")
return AsyncClient(api_key=api_key, base_url=config.base_url), config.model


async def embed_batch(config: OpenAIConfig, batch: list[EmbedRow]) -> None:
"""Generate embeddings for a batch of content using OpenAI.

Expand All @@ -26,13 +38,10 @@ async def embed_batch(config: OpenAIConfig, batch: list[EmbedRow]) -> None:
Raises:
RuntimeError: If the number of embeddings returned doesn't match the batch size.
"""
api_key: str | None = None
if config.api_key_name is not None:
api_key = os.getenv(config.api_key_name)
client = AsyncClient(api_key=api_key, base_url=config.base_url) # TODO: cache this?
client, model = create_openai_client(config) # TODO: cache this?
response = await client.embeddings.create(
input=[x.content for x in batch],
model=config.model,
model=model,
dimensions=config.dimensions,
)
if len(response.data) != len(batch):
Expand All @@ -58,13 +67,10 @@ async def embed_query(config: OpenAIConfig, query: str) -> Sequence[float]:
Raises:
RuntimeError: If the number of embeddings returned is not exactly 1.
"""
api_key: str | None = None
if config.api_key_name is not None:
api_key = os.getenv(config.api_key_name)
client = AsyncClient(api_key=api_key, base_url=config.base_url) # TODO: cache this?
client, model = create_openai_client(config) # TODO: cache this?
response = await client.embeddings.create(
input=query,
model=config.model,
model=model,
dimensions=config.dimensions,
)
if len(response.data) != 1:
Expand Down
Loading