Skip to content

Commit 92e68f8

Browse files
feat: add max_tokens configuration to LLMConfig and LLMClient with validation
1 parent ad0bf3f commit 92e68f8

3 files changed

Lines changed: 78 additions & 1 deletion

File tree

app/llm_provider.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
logger = logging.getLogger(__name__)
2020
DEFAULT_LLM_TIMEOUT_SECONDS = 60
21+
DEFAULT_LLM_MAX_TOKENS = 4096
2122
DEFAULT_LLM_RETRIES = 2
2223
DEFAULT_STRUCTURED_OUTPUT_RETRIES = 1
2324

@@ -120,13 +121,16 @@ class LLMConfig:
120121
site_name: Optional[str] = None
121122
send_site_info: bool = True
122123
timeout_seconds: int = DEFAULT_LLM_TIMEOUT_SECONDS
124+
max_tokens: int = DEFAULT_LLM_MAX_TOKENS
123125
num_retries: int = DEFAULT_LLM_RETRIES
124126
structured_output_retries: int = DEFAULT_STRUCTURED_OUTPUT_RETRIES
125127

126128
def __post_init__(self):
127129
"""Validate configuration after initialization."""
128130
if not self.model:
129131
raise ValueError("Model name is required")
132+
if self.max_tokens <= 0:
133+
raise ValueError("Max tokens must be greater than zero")
130134
if self.num_retries < 0:
131135
raise ValueError("Number of retries cannot be negative")
132136
if self.structured_output_retries < 0:
@@ -241,7 +245,7 @@ def chat_completion(
241245
"custom_llm_provider": self.config.provider,
242246
"messages": messages,
243247
"temperature": temperature,
244-
"max_tokens": kwargs.pop("max_tokens", 4096),
248+
"max_tokens": kwargs.pop("max_tokens", self.config.max_tokens),
245249
"timeout": kwargs.pop("timeout", self.config.timeout_seconds),
246250
"num_retries": kwargs.pop("num_retries", self.config.num_retries),
247251
**kwargs,

app/tests/test_translation.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,8 +656,70 @@ def test_llm_client_accepts_fenced_raw_string_batch_json(self):
656656
[("hello", "Hola"), ("goodbye", "Adiós")],
657657
)
658658
self.assertEqual(mock_completion.call_args.kwargs["timeout"], 60)
659+
self.assertEqual(mock_completion.call_args.kwargs["max_tokens"], 2048)
659660
self.assertEqual(mock_completion.call_args.kwargs["num_retries"], 2)
660661

662+
def test_llm_client_uses_configured_max_tokens(self):
663+
"""Configured max tokens should be used when no per-call override is provided."""
664+
665+
response = SimpleNamespace(
666+
choices=[
667+
SimpleNamespace(
668+
message=SimpleNamespace(
669+
content='{"translations": [{"key": "hello", "translation": "Hola"}]}',
670+
reasoning_content=None,
671+
)
672+
)
673+
]
674+
)
675+
llm_config = LLMConfig(
676+
provider="openrouter",
677+
model="openrouter/owl-alpha",
678+
max_tokens=1234,
679+
)
680+
681+
with patch(
682+
"llm_provider.litellm.completion", return_value=response
683+
) as mock_completion:
684+
LLMClient(llm_config).chat_completion(
685+
messages=[],
686+
response_model=StringBatchTranslation,
687+
temperature=0,
688+
)
689+
690+
self.assertEqual(mock_completion.call_args.kwargs["max_tokens"], 1234)
691+
692+
def test_llm_client_allows_max_tokens_override(self):
693+
"""Callers can override max tokens per request."""
694+
695+
response = SimpleNamespace(
696+
choices=[
697+
SimpleNamespace(
698+
message=SimpleNamespace(
699+
content='{"translations": [{"key": "hello", "translation": "Hola"}]}',
700+
reasoning_content=None,
701+
)
702+
)
703+
]
704+
)
705+
llm_config = LLMConfig(
706+
provider="openrouter",
707+
model="openrouter/owl-alpha",
708+
max_tokens=1234,
709+
)
710+
711+
with patch(
712+
"llm_provider.litellm.completion", return_value=response
713+
) as mock_completion:
714+
LLMClient(llm_config).chat_completion(
715+
messages=[],
716+
response_model=StringBatchTranslation,
717+
temperature=0,
718+
max_tokens=512,
719+
)
720+
721+
self.assertEqual(mock_completion.call_args.kwargs["max_tokens"], 512)
722+
661723
def test_llm_client_allows_retry_override(self):
662724
"""Callers can override the default LiteLLM retry count per request."""
663725

@@ -697,6 +759,16 @@ def test_llm_config_rejects_negative_retries(self):
697759
num_retries=-1,
698760
)
699761

762+
def test_llm_config_rejects_non_positive_max_tokens(self):
763+
"""Max tokens must be positive because providers reject invalid caps."""
764+
765+
with self.assertRaisesRegex(ValueError, "Max tokens must be greater than zero"):
766+
LLMConfig(
767+
provider="openrouter",
768+
model="openrouter/owl-alpha",
769+
max_tokens=0,
770+
)
771+
700772
def test_llm_client_retries_invalid_structured_output(self):
701773
"""Malformed model JSON should trigger a fresh structured-output attempt."""
702774

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ babel>=2.17.0
55
ruff>=0.8.0
66
pydantic>=2.0.0
77
litellm>=1.60.0
8+
tenacity>=8.0.0

0 commit comments

Comments
 (0)