diff --git a/src/llm/backends/openai.py b/src/llm/backends/openai.py index 1e01e78a2..f98e7e3a9 100644 --- a/src/llm/backends/openai.py +++ b/src/llm/backends/openai.py @@ -17,6 +17,16 @@ logger = logging.getLogger(__name__) +_STRUCTURED_FALLBACK_ERRORS = ( + BadRequestError, + json.JSONDecodeError, + ValidationError, + ValidationException, + AttributeError, + TypeError, + IndexError, +) + def _uses_max_completion_tokens(model: str) -> bool: """OpenAI reasoning models (gpt-5 family + o-series) require @@ -106,6 +116,21 @@ def extract_openai_cache_tokens(usage: Any) -> tuple[int, int]: return cache_creation, cache_read +def _first_openai_choice(response: Any) -> Any: + choices = getattr(response, "choices", None) + if not choices: + raise ValidationException("OpenAI response did not include any choices") + return choices[0] + + +def _first_openai_message(response: Any) -> tuple[Any, Any]: + choice = _first_openai_choice(response) + message = getattr(choice, "message", None) + if message is None: + raise ValidationException("OpenAI response choice did not include a message") + return choice, message + + class OpenAIBackend: """Provider backend wrapping AsyncOpenAI.""" @@ -151,7 +176,8 @@ async def complete( response = await self._client.chat.completions.parse(**params) except LengthFinishReasonError as exc: truncated = exc.completion - raw_content = truncated.choices[0].message.content or "" + _choice, message = _first_openai_message(truncated) + raw_content = getattr(message, "content", None) or "" content = repair_response_model_json( raw_content, response_format, @@ -161,22 +187,30 @@ async def complete( truncated, content_override=content, ) - except (BadRequestError, json.JSONDecodeError, ValidationError): - fallback_response = await self._create_structured_response( + except _STRUCTURED_FALLBACK_ERRORS: + fallback_response, content = await self._fallback_structured_response( params=params, response_format=response_format, + model=model, ) - content = self._parse_or_repair_structured_content( + return self._normalize_response( fallback_response, - response_format, - model, + content_override=content, + ) + try: + _choice, message = _first_openai_message(response) + parsed = getattr(message, "parsed", None) + raw_content = getattr(message, "content", None) or "" + except _STRUCTURED_FALLBACK_ERRORS: + fallback_response, content = await self._fallback_structured_response( + params=params, + response_format=response_format, + model=model, ) return self._normalize_response( fallback_response, content_override=content, ) - parsed = response.choices[0].message.parsed - raw_content = response.choices[0].message.content or "" if parsed is None and raw_content: content = repair_response_model_json( raw_content, @@ -185,13 +219,21 @@ async def complete( ) return self._normalize_response(response, content_override=content) if parsed is None: - refusal = getattr(response.choices[0].message, "refusal", None) + refusal = getattr(message, "refusal", None) if refusal: return self._normalize_response( response, content_override=refusal, ) - raise ValidationException("No parsed content in structured response") + fallback_response, content = await self._fallback_structured_response( + params=params, + response_format=response_format, + model=model, + ) + return self._normalize_response( + fallback_response, + content_override=content, + ) return self._normalize_response( response, content_override=validate_structured_output(parsed, response_format), @@ -327,10 +369,10 @@ def _normalize_response( *, content_override: Any | None = None, ) -> CompletionResult: - usage = response.usage - finish_reason = response.choices[0].finish_reason + choice, message = _first_openai_message(response) + usage = getattr(response, "usage", None) + finish_reason = getattr(choice, "finish_reason", None) tool_calls: list[ToolCallResult] = [] - message = response.choices[0].message if getattr(message, "tool_calls", None): for tool_call in message.tool_calls: tool_input: dict[str, Any] = {} @@ -359,9 +401,9 @@ def _normalize_response( return CompletionResult( content=content_override if content_override is not None - else (message.content or ""), - input_tokens=usage.prompt_tokens if usage else 0, - output_tokens=usage.completion_tokens if usage else 0, + else (getattr(message, "content", None) or ""), + input_tokens=(getattr(usage, "prompt_tokens", 0) or 0) if usage else 0, + output_tokens=(getattr(usage, "completion_tokens", 0) or 0) if usage else 0, cache_creation_input_tokens=cache_creation, cache_read_input_tokens=cache_read, finish_reason=finish_reason or "stop", @@ -387,16 +429,73 @@ async def _create_structured_response( } return await self._client.chat.completions.create(**structured_params) + async def _create_json_object_response( + self, + *, + params: dict[str, Any], + ) -> Any: + json_params = dict(params) + json_params["response_format"] = {"type": "json_object"} + return await self._client.chat.completions.create(**json_params) + + async def _create_plain_response( + self, + *, + params: dict[str, Any], + ) -> Any: + plain_params = dict(params) + plain_params.pop("response_format", None) + return await self._client.chat.completions.create(**plain_params) + + async def _fallback_structured_response( + self, + *, + params: dict[str, Any], + response_format: type[BaseModel], + model: str, + ) -> tuple[Any, BaseModel | str]: + last_error: Exception | None = None + for fallback_name, fallback_call in ( + ( + "json_schema", + lambda: self._create_structured_response( + params=params, + response_format=response_format, + ), + ), + ("json_object", lambda: self._create_json_object_response(params=params)), + ("plain", lambda: self._create_plain_response(params=params)), + ): + try: + response = await fallback_call() + content = self._parse_or_repair_structured_content( + response, + response_format, + model, + ) + return response, content + except _STRUCTURED_FALLBACK_ERRORS as exc: + last_error = exc + logger.warning( + "OpenAI structured output fallback %s failed: %s", + fallback_name, + exc.__class__.__name__, + ) + raise ValidationException( + "OpenAI structured response did not include usable content" + ) from last_error + @staticmethod def _parse_or_repair_structured_content( response: Any, response_format: type[BaseModel], model: str, ) -> BaseModel | str: - raw_content = response.choices[0].message.content or "" + _choice, message = _first_openai_message(response) + raw_content = getattr(message, "content", None) or "" if raw_content: return repair_response_model_json(raw_content, response_format, model) - refusal = getattr(response.choices[0].message, "refusal", None) + refusal = getattr(message, "refusal", None) if refusal: return refusal raise ValidationException( diff --git a/tests/llm/test_backends/test_openai.py b/tests/llm/test_backends/test_openai.py index 81838202e..c3c80651c 100644 --- a/tests/llm/test_backends/test_openai.py +++ b/tests/llm/test_backends/test_openai.py @@ -2,11 +2,16 @@ from unittest.mock import AsyncMock, Mock import pytest +from pydantic import BaseModel from src.exceptions import ValidationException from src.llm.backends.openai import OpenAIBackend +class StructuredAnswer(BaseModel): + answer: str + + @pytest.mark.asyncio async def test_openai_backend_uses_gpt5_params_and_extracts_reasoning() -> None: client = Mock() @@ -230,6 +235,216 @@ async def test_openai_backend_converts_anthropic_style_tools() -> None: assert call["tool_choice"] == "required" +@pytest.mark.asyncio +async def test_openai_backend_allows_missing_usage() -> None: + client = Mock() + client.chat.completions.create = AsyncMock( + return_value=SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason="stop", + message=SimpleNamespace(content="ok", tool_calls=[]), + ) + ], + ) + ) + + backend = OpenAIBackend(client) + result = await backend.complete( + model="gpt-4.1", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=100, + ) + + assert result.content == "ok" + assert result.input_tokens == 0 + assert result.output_tokens == 0 + + +@pytest.mark.asyncio +async def test_openai_backend_allows_none_usage_tokens() -> None: + client = Mock() + client.chat.completions.create = AsyncMock( + return_value=SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason="stop", + message=SimpleNamespace(content="ok", tool_calls=[]), + ) + ], + usage=SimpleNamespace(prompt_tokens=None, completion_tokens=None), + ) + ) + + backend = OpenAIBackend(client) + result = await backend.complete( + model="gpt-4.1", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=100, + ) + + assert result.content == "ok" + assert result.input_tokens == 0 + assert result.output_tokens == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "response", + [ + SimpleNamespace(choices=[], usage=None), + SimpleNamespace(choices=None, usage=None), + SimpleNamespace( + choices=[SimpleNamespace(finish_reason="stop", message=None)], + usage=None, + ), + ], +) +async def test_openai_backend_rejects_malformed_responses(response: object) -> None: + client = Mock() + client.chat.completions.create = AsyncMock(return_value=response) + + backend = OpenAIBackend(client) + + with pytest.raises(ValidationException, match="OpenAI response"): + await backend.complete( + model="gpt-4.1", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=100, + ) + + +@pytest.mark.asyncio +async def test_openai_structured_response_falls_back_to_json_object() -> None: + client = Mock() + client.chat.completions.parse = AsyncMock( + return_value=SimpleNamespace(choices=[], usage=None) + ) + client.chat.completions.create = AsyncMock( + side_effect=[ + SimpleNamespace(choices=[], usage=None), + SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason="stop", + message=SimpleNamespace( + content='{"answer": "json object worked"}', + tool_calls=[], + ), + ) + ], + usage=None, + ), + ] + ) + + backend = OpenAIBackend(client) + result = await backend.complete( + model="gpt-4.1", + messages=[{"role": "user", "content": "Give JSON"}], + max_tokens=100, + response_format=StructuredAnswer, + ) + + assert isinstance(result.content, StructuredAnswer) + assert result.content.answer == "json object worked" + assert client.chat.completions.create.await_count == 2 + first_call, second_call = client.chat.completions.create.await_args_list + assert first_call.kwargs["response_format"]["type"] == "json_schema" + assert second_call.kwargs["response_format"] == {"type": "json_object"} + + +@pytest.mark.asyncio +async def test_openai_structured_response_falls_back_when_parse_returns_no_content() -> None: + client = Mock() + client.chat.completions.parse = AsyncMock( + return_value=SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason="stop", + message=SimpleNamespace( + parsed=None, + content=None, + refusal=None, + tool_calls=[], + ), + ) + ], + usage=None, + ) + ) + client.chat.completions.create = AsyncMock( + return_value=SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason="stop", + message=SimpleNamespace( + content='{"answer": "fallback worked"}', + tool_calls=[], + ), + ) + ], + usage=None, + ) + ) + + backend = OpenAIBackend(client) + result = await backend.complete( + model="gpt-4.1", + messages=[{"role": "user", "content": "Give JSON"}], + max_tokens=100, + response_format=StructuredAnswer, + ) + + assert isinstance(result.content, StructuredAnswer) + assert result.content.answer == "fallback worked" + client.chat.completions.create.assert_awaited_once() + call = client.chat.completions.create.await_args + assert call.kwargs["response_format"]["type"] == "json_schema" + + +@pytest.mark.asyncio +async def test_openai_structured_response_falls_back_to_plain_create() -> None: + client = Mock() + client.chat.completions.parse = AsyncMock( + return_value=SimpleNamespace(choices=[], usage=None) + ) + client.chat.completions.create = AsyncMock( + side_effect=[ + SimpleNamespace(choices=[], usage=None), + SimpleNamespace(choices=[], usage=None), + SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason="stop", + message=SimpleNamespace( + content='{"answer": "plain worked"}', + tool_calls=[], + ), + ) + ], + usage=None, + ), + ] + ) + + backend = OpenAIBackend(client) + result = await backend.complete( + model="gpt-4.1", + messages=[{"role": "user", "content": "Give JSON"}], + max_tokens=100, + response_format=StructuredAnswer, + ) + + assert isinstance(result.content, StructuredAnswer) + assert result.content.answer == "plain worked" + assert client.chat.completions.create.await_count == 3 + first_call, second_call, third_call = client.chat.completions.create.await_args_list + assert first_call.kwargs["response_format"]["type"] == "json_schema" + assert second_call.kwargs["response_format"] == {"type": "json_object"} + assert "response_format" not in third_call.kwargs + + @pytest.mark.parametrize( "model", [