Skip to content

Commit 7904d52

Browse files
fix(litellm): omit temperature for reasoning models
Share the OpenAI-compatible reasoning model detector between the native OpenAI-compatible provider and LiteLLM-backed providers. Skip forwarding explicit temperature values through LiteLLM for those models, including Azure GPT-5 deployments, o-series models, and DeepSeek reasoning routes. Limit the LiteLLM Router check to deployments reachable from the default entrypoint so unrelated model groups do not affect call parameters.
1 parent b0038e9 commit 7904d52

6 files changed

Lines changed: 165 additions & 10 deletions

File tree

hindsight-api-slim/hindsight_api/engine/providers/litellm_llm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from hindsight_api.config import DEFAULT_LLM_TIMEOUT, ENV_LLM_TIMEOUT
2525
from hindsight_api.engine.llm_interface import LLMInterface, OutputTooLongError
2626
from hindsight_api.engine.llm_trace import LLMResponseUsage, stash_response_usage
27+
from hindsight_api.engine.providers.model_capabilities import supports_openai_compatible_reasoning
2728
from hindsight_api.engine.response_models import LLMToolCall, LLMToolCallResult, TokenUsage
2829
from hindsight_api.metrics import get_metrics_collector
2930
from hindsight_api.worker.stage import set_stage
@@ -144,12 +145,18 @@ def _build_common_kwargs(
144145
for key, value in self._extra_body.items():
145146
kwargs.setdefault(key, value)
146147

148+
if self._should_omit_temperature():
149+
kwargs.pop("temperature", None)
150+
147151
# Bedrock service tier: flex (50% cheaper), priority, or reserved
148152
if self.model.startswith("bedrock/") and self.bedrock_service_tier is not None:
149153
kwargs["service_tier"] = self.bedrock_service_tier
150154

151155
return kwargs
152156

157+
def _should_omit_temperature(self) -> bool:
    """Report whether ``temperature`` should be dropped from the call kwargs.

    The decision is delegated to ``supports_openai_compatible_reasoning``
    applied to ``self.model``.
    """
    model_is_reasoning = supports_openai_compatible_reasoning(self.model)
    return model_is_reasoning
159+
153160
# ── per-model output-tokens cap (shared with Router subclass) ────────────
154161
# Hindsight's defaults (e.g. retain_max_completion_tokens=64000) target
155162
# high-capacity models. When a configured deployment supports fewer

hindsight-api-slim/hindsight_api/engine/providers/litellm_router_llm.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from typing import Any
3838

3939
from hindsight_api.engine.providers.litellm_llm import LiteLLMLLM
40+
from hindsight_api.engine.providers.model_capabilities import supports_openai_compatible_reasoning
4041

4142
logger = logging.getLogger(__name__)
4243

@@ -94,6 +95,7 @@ def __init__(
9495
# deployment Router picks. Uses LiteLLM's own per-model registry; unknown
9596
# models contribute no cap. See LiteLLMLLM._cap_max_completion_tokens.
9697
self._router_output_cap = self._compute_router_output_cap(config)
98+
self._router_omits_temperature = self._config_has_temperature_rejecting_model(config)
9799

98100
logger.info("LiteLLM Router initialized; entrypoint model_name=%r", _ENTRYPOINT_MODEL_NAME)
99101

@@ -130,6 +132,48 @@ def _resolve_completion_model(self, response: Any) -> str:
130132
def _get_model_output_cap(self) -> int | None:
131133
return self._router_output_cap
132134

135+
def _should_omit_temperature(self) -> bool:
136+
return bool(getattr(self, "_router_omits_temperature", False))
137+
138+
def _config_has_temperature_rejecting_model(self, config: dict[str, Any]) -> bool:
    """Return True when a reachable deployment targets a reasoning model.

    Only ``model_list`` entries whose ``model_name`` is in the set returned
    by ``_reachable_model_names`` are inspected. Malformed entries (non-dict
    deployments or non-dict ``litellm_params``) are ignored.
    """
    reachable = self._reachable_model_names(config)
    deployments = (config.get("model_list") or []) if isinstance(config, dict) else []

    for entry in deployments:
        if not isinstance(entry, dict) or entry.get("model_name") not in reachable:
            continue
        litellm_params = entry.get("litellm_params") or {}
        if not isinstance(litellm_params, dict):
            continue
        target_model = litellm_params.get("model")
        if target_model and supports_openai_compatible_reasoning(target_model):
            return True
    return False
151+
152+
def _reachable_model_names(self, config: dict[str, Any]) -> set[str]:
    """Collect the model group names reachable from the default entrypoint.

    Starts at ``_ENTRYPOINT_MODEL_NAME`` and follows fallback edges
    transitively. Edges are read from ``fallbacks``,
    ``context_window_fallbacks`` and ``content_policy_fallbacks`` (each a
    list of ``{source: target | [targets]}`` mappings) and from
    ``default_fallbacks`` (a plain list of targets).

    Args:
        config: Router configuration. Anything that is not a dict yields
            only the entrypoint name.

    Returns:
        The set of reachable model group names, always including the
        entrypoint.
    """
    reachable = {_ENTRYPOINT_MODEL_NAME}
    if not isinstance(config, dict):
        return reachable

    fallback_specs: list[Any] = []
    for key in ("fallbacks", "context_window_fallbacks", "content_policy_fallbacks"):
        value = config.get(key)
        if isinstance(value, list):
            fallback_specs.extend(value)

    # NOTE(review): LiteLLM Router treats ``default_fallbacks`` as a generic
    # ``{"*": [...]}`` fallback entry; this assumes the config dict forwards
    # that key to the Router unchanged - confirm against the config loader.
    default_fallbacks = config.get("default_fallbacks")
    if isinstance(default_fallbacks, list):
        fallback_specs.append({"*": default_fallbacks})

    pending = [_ENTRYPOINT_MODEL_NAME]
    while pending:
        source = pending.pop()
        for spec in fallback_specs:
            if not isinstance(spec, dict):
                continue
            # "*" is the generic fallback key. Following it for every source
            # over-approximates reachability, which errs on the side of
            # omitting temperature rather than sending a rejected parameter.
            for edge_key in (source, "*"):
                targets = spec.get(edge_key)
                if isinstance(targets, str):
                    targets = [targets]
                if not isinstance(targets, list):
                    continue
                for target in targets:
                    if isinstance(target, str) and target not in reachable:
                        reachable.add(target)
                        pending.append(target)
    return reachable
176+
133177
def _build_common_kwargs(
134178
self,
135179
messages: list[dict[str, Any]],
@@ -144,7 +188,7 @@ def _build_common_kwargs(
144188
}
145189
if max_completion_tokens is not None:
146190
kwargs["max_completion_tokens"] = self._cap_max_completion_tokens(max_completion_tokens)
147-
if temperature is not None:
191+
if temperature is not None and not self._should_omit_temperature():
148192
kwargs["temperature"] = temperature
149193
return kwargs
150194

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Shared provider model capability helpers."""
2+
3+
4+
def supports_openai_compatible_reasoning(model: str) -> bool:
    """Return True for OpenAI-compatible reasoning model names.

    Matching is a case-insensitive substring check, so provider prefixes
    and deployment suffixes (``azure/gpt-5.5``, ``openai/o3-mini``) are
    accepted. Because it is a substring heuristic, an unrelated name that
    happens to contain a marker also matches.

    Args:
        model: Model identifier. ``None`` or an empty string is treated as
            "not a reasoning model".

    Returns:
        True when the name carries one of the reasoning-model markers.
    """
    model_lower = (model or "").lower()
    if "deepseek" in model_lower:
        # DeepSeek v4-flash is the non-thinking route. Treating every
        # DeepSeek model as reasoning injects unsupported reasoning params.
        markers = ("v4-pro", "reasoner", "r1", "thinking")
    else:
        # OpenAI GPT-5 family plus the o-series (o1 / o3 / o4).
        markers = ("gpt-5", "o1", "o3", "o4")
    return any(marker in model_lower for marker in markers)

hindsight-api-slim/hindsight_api/engine/providers/openai_compatible_llm.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from hindsight_api.engine.bank_attribution import apply_bank_attribution
3939
from hindsight_api.engine.llm_interface import LLMInterface, OutputTooLongError, ProviderRateLimitResetError
4040
from hindsight_api.engine.llm_trace import LLMResponseUsage, stash_response_usage
41+
from hindsight_api.engine.providers.model_capabilities import supports_openai_compatible_reasoning
4142
from hindsight_api.engine.response_models import LLMToolCall, LLMToolCallResult, TokenUsage
4243
from hindsight_api.metrics import get_metrics_collector
4344
from hindsight_api.worker.stage import set_stage
@@ -594,13 +595,7 @@ async def verify_connection(self) -> None:
594595

595596
def _supports_reasoning_model(self) -> bool:
    """Check if the current model is a reasoning model (o1, o3, GPT-5, DeepSeek).

    The classification itself lives in the shared
    ``supports_openai_compatible_reasoning`` helper.
    """
    is_reasoning = supports_openai_compatible_reasoning(self.model)
    return is_reasoning
604599

605600
def _get_max_reasoning_tokens(self) -> int | None:
606601
"""Get max reasoning tokens for reasoning models."""

hindsight-api-slim/tests/test_llm_extra_body.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,15 +348,15 @@ class StructuredAnswer(BaseModel):
348348
# ─── LiteLLM ──────────────────────────────────────────────────────────────────
349349

350350

351-
def _make_litellm_provider(extra_body=None, model="gpt-4o"):
    """Build a ``LiteLLMLLM`` with fake credentials for the given model.

    Skips the calling test when ``litellm`` is not installed.
    """
    pytest.importorskip("litellm")
    from hindsight_api.engine.providers.litellm_llm import LiteLLMLLM

    provider_kwargs = {
        "provider": "litellm",
        "api_key": "fake-key",
        "base_url": "",
        "model": model,
        "extra_body": extra_body,
    }
    return LiteLLMLLM(**provider_kwargs)
362362

@@ -373,6 +373,18 @@ def _fake_litellm_response():
373373
return resp
374374

375375

376+
async def _call_litellm_provider(provider, temperature=0.1):
    """Run one ``provider.call`` against a mocked completion and return its kwargs.

    ``provider._acompletion`` is replaced with an ``AsyncMock`` and the
    metrics collector is patched out for the duration of the call. The
    return value is the keyword arguments the mock was called with.
    """
    provider._acompletion = AsyncMock(return_value=_fake_litellm_response())
    user_messages = [{"role": "user", "content": "hi"}]

    with patch("hindsight_api.engine.providers.litellm_llm.get_metrics_collector"):
        await provider.call(
            messages=user_messages,
            temperature=temperature,
            scope="test",
            max_retries=0,
        )

    recorded_call = provider._acompletion.call_args
    return recorded_call.kwargs
386+
387+
376388
def test_litellm_stores_extra_body():
377389
provider = _make_litellm_provider(extra_body=EXTRA_BODY)
378390
assert provider._extra_body == EXTRA_BODY
@@ -404,6 +416,33 @@ async def test_litellm_explicit_param_wins_over_extra_body():
404416
assert provider._acompletion.call_args.kwargs.get("temperature") == 0.9
405417

406418

419+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("model", "expected_temperature"),
    [
        ("azure/gpt-5.5", None),
        ("openai/o3-mini", None),
        ("deepseek/deepseek-reasoner", None),
        ("deepseek/deepseek-v4-flash", 0.1),
    ],
)
async def test_litellm_temperature_for_reasoning_models(model, expected_temperature):
    """Explicit temperature is dropped for reasoning models and kept otherwise.

    ``expected_temperature`` of ``None`` means the key must be absent from
    the kwargs handed to the completion call.
    """
    provider = _make_litellm_provider(model=model)
    kwargs = await _call_litellm_provider(provider)

    if expected_temperature is not None:
        assert kwargs["temperature"] == expected_temperature
        return
    assert "temperature" not in kwargs
436+
437+
438+
@pytest.mark.asyncio
async def test_litellm_gpt5_omits_extra_body_temperature():
    """A GPT-5 deployment sends no temperature even without an explicit value.

    Presumably ``EXTRA_BODY`` carries both ``temperature`` and ``top_p``
    (verify against its definition); only ``top_p`` should reach the call.
    """
    gpt5_provider = _make_litellm_provider(model="azure/gpt-5.5", extra_body=EXTRA_BODY)
    sent_kwargs = await _call_litellm_provider(gpt5_provider, temperature=None)

    assert "temperature" not in sent_kwargs
    assert sent_kwargs.get("top_p") == 0.9
444+
445+
407446
def test_litellm_router_forwards_extra_body():
408447
"""The Router subclass forwards extra_body through to the shared LiteLLM base."""
409448
pytest.importorskip("litellm")

hindsight-api-slim/tests/test_llm_router_provider.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,22 @@ def mock_router_response() -> MagicMock:
7070
return response
7171

7272

73+
def _router_config(
74+
default_model: str,
75+
*,
76+
extra_model: str | None = None,
77+
extra_name: str = "fallback",
78+
fallbacks: list[dict[str, list[str]]] | None = None,
79+
) -> dict[str, Any]:
80+
model_list = [{"model_name": "default", "litellm_params": {"model": default_model, "api_key": "sk"}}]
81+
if extra_model:
82+
model_list.append({"model_name": extra_name, "litellm_params": {"model": extra_model, "api_key": "sk"}})
83+
config: dict[str, Any] = {"model_list": model_list}
84+
if fallbacks:
85+
config["fallbacks"] = fallbacks
86+
return config
87+
88+
7389
# --- config parsing ----------------------------------------------------------
7490

7591

@@ -202,6 +218,20 @@ def _make_router_provider(config: dict[str, Any], mock_router: Any) -> LiteLLMRo
202218

203219

204220
class TestRouterCall:
221+
async def _call_with_temperature(self, config, mock_router_response):
    """Call a router provider with ``temperature=0.1`` and return the router kwargs.

    Builds the provider around a mocked Router whose ``acompletion``
    resolves to ``mock_router_response``, then returns the keyword
    arguments that ``acompletion`` was awaited with.
    """
    router = MagicMock()
    router.acompletion = AsyncMock(return_value=mock_router_response)
    provider = _make_router_provider(config, router)
    # Set the flag explicitly from the config under test; presumably
    # _make_router_provider does not run the real __init__ (verify).
    provider._router_omits_temperature = provider._config_has_temperature_rejecting_model(config)

    user_messages = [{"role": "user", "content": "hi"}]
    await provider.call(
        messages=user_messages,
        temperature=0.1,
        max_retries=0,
    )

    awaited_call = router.acompletion.await_args
    return awaited_call.kwargs
234+
205235
@pytest.mark.asyncio
206236
async def test_plain_text_call_targets_default_entrypoint(self, two_step_config, mock_router_response):
207237
mock_router = MagicMock()
@@ -314,6 +344,35 @@ async def test_no_cap_when_litellm_registry_has_no_data(self, two_step_config, m
314344
kwargs = mock_router.acompletion.await_args.kwargs
315345
assert kwargs["max_completion_tokens"] == 64000
316346

347+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("config", "expected_temperature"),
    [
        # Reasoning model is the default deployment itself.
        (_router_config("azure/gpt-5.5"), None),
        # Reasoning model exists but nothing routes to it.
        (_router_config("openai/gpt-4o-mini", extra_model="azure/gpt-5.5", extra_name="unused"), 0.1),
        # Reasoning model is reachable through a fallback from "default".
        (
            _router_config(
                "openai/gpt-4o-mini",
                extra_model="azure/gpt-5.5",
                fallbacks=[{"default": ["fallback"]}],
            ),
            None,
        ),
    ],
)
async def test_router_temperature_for_reachable_reasoning_models(
    self,
    config,
    expected_temperature,
    mock_router_response,
):
    """Temperature is omitted only when a reasoning model is reachable from "default".

    ``expected_temperature`` of ``None`` means the key must be absent from
    the kwargs passed to the Router.
    """
    sent_kwargs = await self._call_with_temperature(config, mock_router_response)

    assert sent_kwargs["model"] == "default"
    if expected_temperature is not None:
        assert sent_kwargs["temperature"] == expected_temperature
        return
    assert "temperature" not in sent_kwargs
375+
317376
@pytest.mark.asyncio
318377
async def test_call_with_tools(self, two_step_config):
319378
response = MagicMock()

0 commit comments

Comments
 (0)