Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 117 additions & 18 deletions src/llm/backends/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@

logger = logging.getLogger(__name__)

_STRUCTURED_FALLBACK_ERRORS = (
BadRequestError,
json.JSONDecodeError,
ValidationError,
ValidationException,
AttributeError,
TypeError,
IndexError,
)


def _uses_max_completion_tokens(model: str) -> bool:
"""OpenAI reasoning models (gpt-5 family + o-series) require
Expand Down Expand Up @@ -106,6 +116,21 @@ def extract_openai_cache_tokens(usage: Any) -> tuple[int, int]:
return cache_creation, cache_read


def _first_openai_choice(response: Any) -> Any:
choices = getattr(response, "choices", None)
if not choices:
raise ValidationException("OpenAI response did not include any choices")
return choices[0]


def _first_openai_message(response: Any) -> tuple[Any, Any]:
choice = _first_openai_choice(response)
message = getattr(choice, "message", None)
if message is None:
raise ValidationException("OpenAI response choice did not include a message")
return choice, message


class OpenAIBackend:
"""Provider backend wrapping AsyncOpenAI."""

Expand Down Expand Up @@ -151,7 +176,8 @@ async def complete(
response = await self._client.chat.completions.parse(**params)
except LengthFinishReasonError as exc:
truncated = exc.completion
raw_content = truncated.choices[0].message.content or ""
_choice, message = _first_openai_message(truncated)
raw_content = getattr(message, "content", None) or ""
content = repair_response_model_json(
raw_content,
response_format,
Expand All @@ -161,22 +187,30 @@ async def complete(
truncated,
content_override=content,
)
except (BadRequestError, json.JSONDecodeError, ValidationError):
fallback_response = await self._create_structured_response(
except _STRUCTURED_FALLBACK_ERRORS:
fallback_response, content = await self._fallback_structured_response(
params=params,
response_format=response_format,
model=model,
)
content = self._parse_or_repair_structured_content(
return self._normalize_response(
fallback_response,
response_format,
model,
content_override=content,
)
try:
_choice, message = _first_openai_message(response)
parsed = getattr(message, "parsed", None)
raw_content = getattr(message, "content", None) or ""
except _STRUCTURED_FALLBACK_ERRORS:
fallback_response, content = await self._fallback_structured_response(
params=params,
response_format=response_format,
model=model,
)
return self._normalize_response(
fallback_response,
content_override=content,
)
parsed = response.choices[0].message.parsed
raw_content = response.choices[0].message.content or ""
if parsed is None and raw_content:
content = repair_response_model_json(
raw_content,
Expand All @@ -185,13 +219,21 @@ async def complete(
)
return self._normalize_response(response, content_override=content)
if parsed is None:
refusal = getattr(response.choices[0].message, "refusal", None)
refusal = getattr(message, "refusal", None)
if refusal:
return self._normalize_response(
response,
content_override=refusal,
)
raise ValidationException("No parsed content in structured response")
fallback_response, content = await self._fallback_structured_response(
params=params,
response_format=response_format,
model=model,
)
return self._normalize_response(
fallback_response,
content_override=content,
)
return self._normalize_response(
response,
content_override=validate_structured_output(parsed, response_format),
Expand Down Expand Up @@ -327,10 +369,10 @@ def _normalize_response(
*,
content_override: Any | None = None,
) -> CompletionResult:
usage = response.usage
finish_reason = response.choices[0].finish_reason
choice, message = _first_openai_message(response)
usage = getattr(response, "usage", None)
finish_reason = getattr(choice, "finish_reason", None)
tool_calls: list[ToolCallResult] = []
message = response.choices[0].message
if getattr(message, "tool_calls", None):
for tool_call in message.tool_calls:
tool_input: dict[str, Any] = {}
Expand Down Expand Up @@ -359,9 +401,9 @@ def _normalize_response(
return CompletionResult(
content=content_override
if content_override is not None
else (message.content or ""),
input_tokens=usage.prompt_tokens if usage else 0,
output_tokens=usage.completion_tokens if usage else 0,
else (getattr(message, "content", None) or ""),
input_tokens=(getattr(usage, "prompt_tokens", 0) or 0) if usage else 0,
output_tokens=(getattr(usage, "completion_tokens", 0) or 0) if usage else 0,
cache_creation_input_tokens=cache_creation,
cache_read_input_tokens=cache_read,
finish_reason=finish_reason or "stop",
Expand All @@ -387,16 +429,73 @@ async def _create_structured_response(
}
return await self._client.chat.completions.create(**structured_params)

async def _create_json_object_response(
self,
*,
params: dict[str, Any],
) -> Any:
json_params = dict(params)
json_params["response_format"] = {"type": "json_object"}
return await self._client.chat.completions.create(**json_params)

async def _create_plain_response(
self,
*,
params: dict[str, Any],
) -> Any:
plain_params = dict(params)
plain_params.pop("response_format", None)
return await self._client.chat.completions.create(**plain_params)

async def _fallback_structured_response(
self,
*,
params: dict[str, Any],
response_format: type[BaseModel],
model: str,
) -> tuple[Any, BaseModel | str]:
last_error: Exception | None = None
for fallback_name, fallback_call in (
(
"json_schema",
lambda: self._create_structured_response(
params=params,
response_format=response_format,
),
),
("json_object", lambda: self._create_json_object_response(params=params)),
("plain", lambda: self._create_plain_response(params=params)),
):
try:
response = await fallback_call()
content = self._parse_or_repair_structured_content(
response,
response_format,
model,
)
return response, content
except _STRUCTURED_FALLBACK_ERRORS as exc:
last_error = exc
logger.warning(
"OpenAI structured output fallback %s failed: %s",
fallback_name,
exc.__class__.__name__,
)
raise ValidationException(
"OpenAI structured response did not include usable content"
) from last_error

@staticmethod
def _parse_or_repair_structured_content(
response: Any,
response_format: type[BaseModel],
model: str,
) -> BaseModel | str:
raw_content = response.choices[0].message.content or ""
_choice, message = _first_openai_message(response)
raw_content = getattr(message, "content", None) or ""
if raw_content:
return repair_response_model_json(raw_content, response_format, model)
refusal = getattr(response.choices[0].message, "refusal", None)
refusal = getattr(message, "refusal", None)
if refusal:
return refusal
raise ValidationException(
Expand Down
Loading