Skip to content

Commit ad0bf3f

Browse files
feat: add structured output retry mechanism and validation to LLMClient and LLMConfig
1 parent 3632a93 commit ad0bf3f

2 files changed

Lines changed: 134 additions & 20 deletions

File tree

app/llm_provider.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
import re
1414
from dataclasses import dataclass
1515
from typing import Optional, Dict, Any, List
16-
from pydantic import BaseModel, Field
16+
from pydantic import BaseModel, Field, ValidationError
1717
import litellm
1818

1919
logger = logging.getLogger(__name__)
2020
DEFAULT_LLM_TIMEOUT_SECONDS = 60
2121
DEFAULT_LLM_RETRIES = 2
22+
DEFAULT_STRUCTURED_OUTPUT_RETRIES = 1
2223

2324
# Suppress noisy logging from litellm/openai unless error/warning
2425
litellm.set_verbose = False
@@ -120,13 +121,16 @@ class LLMConfig:
120121
send_site_info: bool = True
121122
timeout_seconds: int = DEFAULT_LLM_TIMEOUT_SECONDS
122123
num_retries: int = DEFAULT_LLM_RETRIES
124+
structured_output_retries: int = DEFAULT_STRUCTURED_OUTPUT_RETRIES
123125

124126
def __post_init__(self):
    """Validate configuration after initialization.

    Raises:
        ValueError: If ``model`` is falsy, or if ``num_retries`` or
            ``structured_output_retries`` is negative.
    """
    if not self.model:
        raise ValueError("Model name is required")

    # Both retry counters share the same non-negativity rule; they are
    # checked in declaration order so the first offending field is reported.
    retry_fields = (
        ("num_retries", "retries"),
        ("structured_output_retries", "structured output retries"),
    )
    for attribute, label in retry_fields:
        if getattr(self, attribute) < 0:
            raise ValueError(f"Number of {label} cannot be negative")
130134

131135

132136
class LLMClient:
@@ -190,6 +194,31 @@ def _get_message_value(message: Any, key: str) -> str:
190194
value = getattr(message, key, None)
191195
return value if isinstance(value, str) else ""
192196

197+
def _completion_content(self, api_params: Dict[str, Any]) -> str:
    """Call ``litellm.completion`` and return the first choice's text.

    Args:
        api_params: Keyword arguments forwarded unchanged to
            ``litellm.completion``.

    Returns:
        The stripped ``content`` of the first choice's message. When that
        is empty and a non-empty ``reasoning_content`` is present, the
        stripped ``reasoning_content`` is returned instead.
    """
    first_message = litellm.completion(**api_params).choices[0].message

    # Read both fields up front; either may come back as an empty string.
    text = self._get_message_value(first_message, "content").strip()
    reasoning_text = self._get_message_value(
        first_message, "reasoning_content"
    ).strip()

    if text or not reasoning_text:
        return text

    # Only the reasoning field carries text, so use it as the answer.
    logger.info(
        "Content is empty but reasoning_content is present. "
        "Falling back to reasoning_content for structured output parsing."
    )
    return reasoning_text
214+
215+
def _parse_structured_response(
    self, content: str, response_model: type[BaseModel]
) -> BaseModel:
    """Turn raw model text into a validated ``response_model`` instance.

    Args:
        content: Text returned by the model.
        response_model: Pydantic model class used for validation.

    Returns:
        The instance produced by ``response_model.model_validate_json``.
    """
    # The text first goes through _coerce_structured_payload (defined
    # elsewhere in this class); its result is what gets validated.
    normalized_payload = self._coerce_structured_payload(content, response_model)
    validated = response_model.model_validate_json(normalized_payload)
    return validated
221+
193222
def chat_completion(
194223
self,
195224
messages: list,
@@ -200,6 +229,12 @@ def chat_completion(
200229
"""
201230
Send a chat completion request to the LLM API using LiteLLM.
202231
"""
232+
structured_output_retries = kwargs.pop(
233+
"structured_output_retries", self.config.structured_output_retries
234+
)
235+
if structured_output_retries < 0:
236+
raise ValueError("Number of structured output retries cannot be negative")
237+
203238
# Build payload parameters
204239
api_params = {
205240
"model": self.config.model,
@@ -236,25 +271,23 @@ def chat_completion(
236271
)
237272

238273
try:
239-
response = litellm.completion(**api_params)
240-
message = response.choices[0].message
241-
content = self._get_message_value(message, "content").strip()
242-
reasoning_content = self._get_message_value(
243-
message, "reasoning_content"
244-
).strip()
245-
246-
if not content and reasoning_content:
247-
logger.info(
248-
"Content is empty but reasoning_content is present. "
249-
"Falling back to reasoning_content for structured output parsing."
250-
)
251-
content = reasoning_content
252-
253-
if response_model:
254-
# Natively parse and validate the JSON string into the Pydantic model
255-
content = self._coerce_structured_payload(content, response_model)
256-
return response_model.model_validate_json(content)
257-
return content
274+
if not response_model:
275+
return self._completion_content(api_params)
276+
277+
validation_attempts = structured_output_retries + 1
278+
for attempt in range(1, validation_attempts + 1):
279+
content = self._completion_content(api_params)
280+
try:
281+
return self._parse_structured_response(content, response_model)
282+
except ValidationError as e:
283+
if attempt >= validation_attempts:
284+
raise
285+
logger.warning(
286+
"LLM returned invalid structured output on attempt "
287+
f"{attempt}/{validation_attempts}; retrying. Error: {e}"
288+
)
289+
290+
raise RuntimeError("Structured output retry loop exited unexpectedly")
258291

259292
except Exception as e:
260293
logger.error(f"Error during LLM API call: {e}")

app/tests/test_translation.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,87 @@ def test_llm_config_rejects_negative_retries(self):
697697
num_retries=-1,
698698
)
699699

700+
def test_llm_client_retries_invalid_structured_output(self):
    """A truncated JSON reply is followed by a second completion call."""

    def fake_response(text):
        # Mimic the attribute shape read by the client: choices[0].message.
        reply = SimpleNamespace(content=text, reasoning_content=None)
        return SimpleNamespace(choices=[SimpleNamespace(message=reply)])

    truncated_reply = fake_response('{\n "translations":')
    valid_reply = fake_response(
        '{"translations": [{"key": "hello", "translation": "Hola"}]}'
    )
    config = LLMConfig(provider="openrouter", model="openrouter/owl-alpha")

    # First call yields the truncated payload, second call the valid one.
    completion_patch = patch(
        "llm_provider.litellm.completion",
        side_effect=[truncated_reply, valid_reply],
    )
    with completion_patch as mock_completion:
        client = LLMClient(config)
        result = client.chat_completion(
            messages=[],
            response_model=StringBatchTranslation,
            temperature=0,
        )

    self.assertEqual(mock_completion.call_count, 2)
    translated_pairs = [
        (item.key, item.translation) for item in result.translations
    ]
    self.assertEqual(translated_pairs, [("hello", "Hola")])
740+
741+
def test_llm_client_allows_structured_output_retry_override(self):
    """Passing structured_output_retries=0 yields a single completion call."""

    # A truncated JSON document that can never validate.
    truncated_message = SimpleNamespace(
        content='{\n "translations":',
        reasoning_content=None,
    )
    truncated_reply = SimpleNamespace(
        choices=[SimpleNamespace(message=truncated_message)]
    )
    config = LLMConfig(provider="openrouter", model="openrouter/owl-alpha")

    completion_patch = patch(
        "llm_provider.litellm.completion", return_value=truncated_reply
    )
    with completion_patch as mock_completion, self.assertRaisesRegex(
        ValueError, "Invalid JSON"
    ):
        LLMClient(config).chat_completion(
            messages=[],
            response_model=StringBatchTranslation,
            temperature=0,
            structured_output_retries=0,
        )

    # With the per-request override at zero, no second attempt is made.
    self.assertEqual(mock_completion.call_count, 1)
768+
769+
def test_llm_config_rejects_negative_structured_output_retries(self):
    """LLMConfig raises ValueError for a negative structured_output_retries."""

    expected_message = "Number of structured output retries cannot be negative"
    config_kwargs = {
        "provider": "openrouter",
        "model": "openrouter/owl-alpha",
        "structured_output_retries": -1,
    }

    with self.assertRaisesRegex(ValueError, expected_message):
        LLMConfig(**config_kwargs)
780+
700781
def test_llm_client_accepts_dict_style_message(self):
701782
"""LiteLLM responses can expose message data with dict-style access."""
702783

0 commit comments

Comments
 (0)