Skip to content

Commit 1f547b6

Browse files
fix(litellm): omit temperature for reasoning models
Share the OpenAI-compatible reasoning model detector between the native OpenAI-compatible provider and LiteLLM-backed providers. Skip forwarding explicit temperature values through LiteLLM for those models, including Azure GPT-5 deployments, o-series models, and DeepSeek reasoning routes. Limit the LiteLLM Router check to deployments reachable from the default entrypoint so unrelated model groups do not affect call parameters.
1 parent ef2e8ab commit 1f547b6

6 files changed

Lines changed: 168 additions & 12 deletions

File tree

hindsight-api-slim/hindsight_api/engine/providers/litellm_llm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from hindsight_api.config import DEFAULT_LLM_TIMEOUT, ENV_LLM_TIMEOUT
2525
from hindsight_api.engine.llm_interface import LLMInterface, OutputTooLongError
2626
from hindsight_api.engine.llm_trace import LLMResponseUsage, stash_response_usage
27+
from hindsight_api.engine.providers.model_capabilities import supports_openai_compatible_reasoning
2728
from hindsight_api.engine.response_models import LLMToolCall, LLMToolCallResult, TokenUsage
2829
from hindsight_api.metrics import get_metrics_collector
2930
from hindsight_api.worker.stage import set_stage
@@ -159,12 +160,18 @@ def _build_common_kwargs(
159160
if self._default_headers:
160161
kwargs.setdefault("extra_headers", dict(self._default_headers))
161162

163+
if self._should_omit_temperature():
164+
kwargs.pop("temperature", None)
165+
162166
# Bedrock service tier: flex (50% cheaper), priority, or reserved
163167
if self.model.startswith("bedrock/") and self.bedrock_service_tier is not None:
164168
kwargs["service_tier"] = self.bedrock_service_tier
165169

166170
return kwargs
167171

172+
def _should_omit_temperature(self) -> bool:
    """Return True when ``temperature`` must not be forwarded for ``self.model``.

    The decision is delegated to the shared reasoning-model detector, so this
    provider agrees with every other caller of that helper.
    """
    is_reasoning_model = supports_openai_compatible_reasoning(self.model)
    return is_reasoning_model
174+
168175
# ── per-model output-tokens cap (shared with Router subclass) ────────────
169176
# Hindsight's defaults (e.g. retain_max_completion_tokens=64000) target
170177
# high-capacity models. When a configured deployment supports fewer

hindsight-api-slim/hindsight_api/engine/providers/litellm_router_llm.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from typing import Any
3838

3939
from hindsight_api.engine.providers.litellm_llm import LiteLLMLLM
40+
from hindsight_api.engine.providers.model_capabilities import supports_openai_compatible_reasoning
4041

4142
logger = logging.getLogger(__name__)
4243

@@ -94,6 +95,7 @@ def __init__(
9495
# deployment Router picks. Uses LiteLLM's own per-model registry; unknown
9596
# models contribute no cap. See LiteLLMLLM._cap_max_completion_tokens.
9697
self._router_output_cap = self._compute_router_output_cap(config)
98+
self._router_omits_temperature = self._config_has_temperature_rejecting_model(config)
9799

98100
logger.info("LiteLLM Router initialized; entrypoint model_name=%r", _ENTRYPOINT_MODEL_NAME)
99101

@@ -130,6 +132,48 @@ def _resolve_completion_model(self, response: Any) -> str:
130132
def _get_model_output_cap(self) -> int | None:
131133
return self._router_output_cap
132134

135+
def _should_omit_temperature(self) -> bool:
136+
return bool(getattr(self, "_router_omits_temperature", False))
137+
138+
def _config_has_temperature_rejecting_model(self, config: dict[str, Any]) -> bool:
    """Return True if any reachable deployment targets a reasoning model.

    Only ``model_list`` entries whose ``model_name`` is in the set returned by
    ``_reachable_model_names`` are inspected; each one's
    ``litellm_params["model"]`` string is passed to the shared reasoning-model
    detector. Malformed entries (non-dict deployments or params, missing or
    empty model strings) are ignored rather than raising.
    """
    reachable = self._reachable_model_names(config)
    deployments = (config.get("model_list") or []) if isinstance(config, dict) else []

    def _reachable_model_strings():
        # Yield the underlying model string of every well-formed, reachable deployment.
        for entry in deployments:
            if not isinstance(entry, dict) or entry.get("model_name") not in reachable:
                continue
            litellm_params = entry.get("litellm_params") or {}
            if isinstance(litellm_params, dict):
                yield litellm_params.get("model")

    return any(
        candidate and supports_openai_compatible_reasoning(candidate)
        for candidate in _reachable_model_strings()
    )
151+
152+
def _reachable_model_names(self, config: dict[str, Any]) -> set[str]:
    """Return every ``model_name`` the Router may route to from the entrypoint.

    Starts at ``_ENTRYPOINT_MODEL_NAME`` and follows fallback edges
    transitively. Edges are read from these top-level config keys:

    * ``fallbacks``, ``context_window_fallbacks`` and
      ``content_policy_fallbacks`` -- lists of ``{source: target | [targets]}``
      mappings;
    * ``default_fallbacks`` -- a list of targets that applies to every source.

    A ``"*"`` source key inside any mapping is treated as applying to every
    source as well. The result deliberately over-approximates: a name is
    included if *any* of these routes could reach it, because wrongly treating
    a deployment as unreachable would forward ``temperature`` to a model that
    may reject it.

    Args:
        config: Router configuration. Non-dict configs and malformed fallback
            entries are ignored rather than raising.

    Returns:
        The set of reachable model names; always contains the entrypoint name.
    """
    reachable = {_ENTRYPOINT_MODEL_NAME}
    if not isinstance(config, dict):
        return reachable

    fallback_specs: list[dict[str, Any]] = []
    for key in ("fallbacks", "context_window_fallbacks", "content_policy_fallbacks"):
        value = config.get(key)
        if isinstance(value, list):
            fallback_specs.extend(spec for spec in value if isinstance(spec, dict))

    # Router-wide default fallbacks behave like a wildcard-source mapping.
    default_fallbacks = config.get("default_fallbacks")
    if isinstance(default_fallbacks, (str, list)):
        fallback_specs.append({"*": default_fallbacks})

    pending = [_ENTRYPOINT_MODEL_NAME]
    while pending:
        source = pending.pop()
        for spec in fallback_specs:
            # Follow both the edges declared for this source and the generic ones.
            for source_key in (source, "*"):
                targets = spec.get(source_key)
                if isinstance(targets, str):
                    targets = [targets]
                if not isinstance(targets, list):
                    continue
                for target in targets:
                    if isinstance(target, str) and target not in reachable:
                        reachable.add(target)
                        pending.append(target)
    return reachable
176+
133177
def _build_common_kwargs(
134178
self,
135179
messages: list[dict[str, Any]],
@@ -144,7 +188,7 @@ def _build_common_kwargs(
144188
}
145189
if max_completion_tokens is not None:
146190
kwargs["max_completion_tokens"] = self._cap_max_completion_tokens(max_completion_tokens)
147-
if temperature is not None:
191+
if temperature is not None and not self._should_omit_temperature():
148192
kwargs["temperature"] = temperature
149193

150194
# Forward operator-configured default headers as ``extra_headers`` so they
@@ -154,8 +198,9 @@ def _build_common_kwargs(
154198
# concern, so we inject them here too — mirroring the base provider.
155199
# ``setdefault`` keeps any explicit per-call ``extra_headers`` authoritative;
156200
# a per-call copy prevents LiteLLM/downstream from mutating the stored dict.
157-
if self._default_headers:
158-
kwargs.setdefault("extra_headers", dict(self._default_headers))
201+
default_headers = getattr(self, "_default_headers", {})
202+
if default_headers:
203+
kwargs.setdefault("extra_headers", dict(default_headers))
159204
return kwargs
160205

161206
async def verify_connection(self) -> None:
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Shared provider model capability helpers."""
2+
3+
4+
def supports_openai_compatible_reasoning(model: str) -> bool:
    """Return True for OpenAI-compatible reasoning model names.

    Matching is a case-insensitive substring test on the full model string, so
    provider prefixes such as ``azure/`` or ``openai/`` do not matter.

    Args:
        model: Model identifier. ``None`` or an empty string is treated as
            "not a reasoning model".

    Returns:
        True when the name contains one of the known reasoning-model markers.
    """
    model_lower = (model or "").lower()
    if "deepseek" in model_lower:
        # DeepSeek v4-flash is the non-thinking route. Treating every
        # DeepSeek model as reasoning injects unsupported reasoning params.
        markers = ("v4-pro", "reasoner", "r1", "thinking")
    else:
        # "o4" covers o4-mini, which belongs to the same o-series family as
        # o1/o3 and would otherwise be treated as a regular chat model.
        markers = ("gpt-5", "o1", "o3", "o4")
    return any(marker in model_lower for marker in markers)

hindsight-api-slim/hindsight_api/engine/providers/openai_compatible_llm.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from hindsight_api.engine.bank_attribution import apply_bank_attribution
3939
from hindsight_api.engine.llm_interface import LLMInterface, OutputTooLongError, ProviderRateLimitResetError
4040
from hindsight_api.engine.llm_trace import LLMResponseUsage, stash_response_usage
41+
from hindsight_api.engine.providers.model_capabilities import supports_openai_compatible_reasoning
4142
from hindsight_api.engine.response_models import LLMToolCall, LLMToolCallResult, TokenUsage
4243
from hindsight_api.metrics import get_metrics_collector
4344
from hindsight_api.worker.stage import set_stage
@@ -594,13 +595,7 @@ async def verify_connection(self) -> None:
594595

595596
def _supports_reasoning_model(self) -> bool:
596597
"""Check if the current model is a reasoning model (o1, o3, GPT-5, DeepSeek)."""
597-
model_lower = self.model.lower()
598-
if "deepseek" in model_lower:
599-
# DeepSeek v4-flash is the non-thinking route. Treating every
600-
# DeepSeek model as a reasoning model injects reasoning_effort,
601-
# which conflicts with thinking-disabled flash calls.
602-
return any(x in model_lower for x in ["v4-pro", "reasoner", "r1", "thinking"])
603-
return any(x in model_lower for x in ["gpt-5", "o1", "o3"])
598+
return supports_openai_compatible_reasoning(self.model)
604599

605600
def _get_max_reasoning_tokens(self) -> int | None:
606601
"""Get max reasoning tokens for reasoning models."""

hindsight-api-slim/tests/test_llm_extra_body.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,15 +349,15 @@ class StructuredAnswer(BaseModel):
349349
# ─── LiteLLM ──────────────────────────────────────────────────────────────────
350350

351351

352-
def _make_litellm_provider(extra_body=None, default_headers=None):
352+
def _make_litellm_provider(extra_body=None, default_headers=None, model="gpt-4o"):
353353
pytest.importorskip("litellm")
354354
from hindsight_api.engine.providers.litellm_llm import LiteLLMLLM
355355

356356
return LiteLLMLLM(
357357
provider="litellm",
358358
api_key="fake-key",
359359
base_url="",
360-
model="gpt-4o",
360+
model=model,
361361
extra_body=extra_body,
362362
default_headers=default_headers,
363363
)
@@ -375,6 +375,18 @@ def _fake_litellm_response():
375375
return resp
376376

377377

378+
async def _call_litellm_provider(provider, temperature=0.1):
    """Run one ``provider.call`` against a mocked completion and return its kwargs.

    ``provider._acompletion`` is replaced with an ``AsyncMock`` and the metrics
    collector lookup is patched for the duration of the call. The return value
    is the keyword arguments the mocked completion received.
    """
    completion_mock = AsyncMock(return_value=_fake_litellm_response())
    provider._acompletion = completion_mock

    user_messages = [{"role": "user", "content": "hi"}]
    with patch("hindsight_api.engine.providers.litellm_llm.get_metrics_collector"):
        await provider.call(
            messages=user_messages,
            temperature=temperature,
            scope="test",
            max_retries=0,
        )

    return provider._acompletion.call_args.kwargs
388+
389+
378390
def test_litellm_stores_extra_body():
379391
provider = _make_litellm_provider(extra_body=EXTRA_BODY)
380392
assert provider._extra_body == EXTRA_BODY
@@ -406,6 +418,33 @@ async def test_litellm_explicit_param_wins_over_extra_body():
406418
assert provider._acompletion.call_args.kwargs.get("temperature") == 0.9
407419

408420

421+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("model", "expected_temperature"),
    [
        ("azure/gpt-5.5", None),
        ("openai/o3-mini", None),
        ("deepseek/deepseek-reasoner", None),
        ("deepseek/deepseek-v4-flash", 0.1),
    ],
)
async def test_litellm_temperature_for_reasoning_models(model, expected_temperature):
    """Reasoning models drop ``temperature``; other models keep the explicit value.

    An ``expected_temperature`` of ``None`` means the key must be absent from
    the completion kwargs altogether.
    """
    provider = _make_litellm_provider(model=model)
    kwargs = await _call_litellm_provider(provider)

    if expected_temperature is not None:
        assert kwargs["temperature"] == expected_temperature
        return
    assert "temperature" not in kwargs
438+
439+
440+
@pytest.mark.asyncio
async def test_litellm_gpt5_omits_extra_body_temperature():
    """A GPT-5 deployment sends no ``temperature`` while ``top_p`` still arrives.

    NOTE(review): presumably ``EXTRA_BODY`` supplies both ``temperature`` and
    ``top_p``; confirm against its definition earlier in this module.
    """
    gpt5_provider = _make_litellm_provider(model="azure/gpt-5.5", extra_body=EXTRA_BODY)
    sent_kwargs = await _call_litellm_provider(gpt5_provider, temperature=None)

    assert "temperature" not in sent_kwargs
    assert sent_kwargs.get("top_p") == 0.9
446+
447+
409448
def test_litellm_router_forwards_extra_body():
410449
"""The Router subclass forwards extra_body through to the shared LiteLLM base."""
411450
pytest.importorskip("litellm")

hindsight-api-slim/tests/test_llm_router_provider.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,22 @@ def mock_router_response() -> MagicMock:
7070
return response
7171

7272

73+
def _router_config(
74+
default_model: str,
75+
*,
76+
extra_model: str | None = None,
77+
extra_name: str = "fallback",
78+
fallbacks: list[dict[str, list[str]]] | None = None,
79+
) -> dict[str, Any]:
80+
model_list = [{"model_name": "default", "litellm_params": {"model": default_model, "api_key": "sk"}}]
81+
if extra_model:
82+
model_list.append({"model_name": extra_name, "litellm_params": {"model": extra_model, "api_key": "sk"}})
83+
config: dict[str, Any] = {"model_list": model_list}
84+
if fallbacks:
85+
config["fallbacks"] = fallbacks
86+
return config
87+
88+
7389
# --- config parsing ----------------------------------------------------------
7490

7591

@@ -202,6 +218,20 @@ def _make_router_provider(config: dict[str, Any], mock_router: Any) -> LiteLLMRo
202218

203219

204220
class TestRouterCall:
221+
async def _call_with_temperature(self, config, mock_router_response):
    """Call a Router provider with ``temperature=0.1`` and return the forwarded kwargs.

    The omit-temperature flag is assigned on the provider explicitly from
    ``config`` before the call, so the result reflects that config.
    """
    router_mock = MagicMock()
    router_mock.acompletion = AsyncMock(return_value=mock_router_response)

    provider = _make_router_provider(config, router_mock)
    omit_temperature = provider._config_has_temperature_rejecting_model(config)
    provider._router_omits_temperature = omit_temperature

    user_messages = [{"role": "user", "content": "hi"}]
    await provider.call(messages=user_messages, temperature=0.1, max_retries=0)

    return router_mock.acompletion.await_args.kwargs
234+
205235
@pytest.mark.asyncio
206236
async def test_plain_text_call_targets_default_entrypoint(self, two_step_config, mock_router_response):
207237
mock_router = MagicMock()
@@ -314,6 +344,35 @@ async def test_no_cap_when_litellm_registry_has_no_data(self, two_step_config, m
314344
kwargs = mock_router.acompletion.await_args.kwargs
315345
assert kwargs["max_completion_tokens"] == 64000
316346

347+
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("config", "expected_temperature"),
    [
        # Reasoning model is the default deployment itself.
        (_router_config("azure/gpt-5.5"), None),
        # Reasoning model is configured but no fallback leads to it.
        (_router_config("openai/gpt-4o-mini", extra_model="azure/gpt-5.5", extra_name="unused"), 0.1),
        # Reasoning model is reachable through a fallback from the default.
        (
            _router_config(
                "openai/gpt-4o-mini",
                extra_model="azure/gpt-5.5",
                fallbacks=[{"default": ["fallback"]}],
            ),
            None,
        ),
    ],
)
async def test_router_temperature_for_reachable_reasoning_models(
    self,
    config,
    expected_temperature,
    mock_router_response,
):
    """Temperature is dropped only when a reasoning deployment is reachable.

    An ``expected_temperature`` of ``None`` means the key must be absent from
    the kwargs passed to the Router.
    """
    kwargs = await self._call_with_temperature(config, mock_router_response)

    assert kwargs["model"] == "default"
    if expected_temperature is not None:
        assert kwargs["temperature"] == expected_temperature
        return
    assert "temperature" not in kwargs
375+
317376
@pytest.mark.asyncio
318377
async def test_call_with_tools(self, two_step_config):
319378
response = MagicMock()

0 commit comments

Comments
 (0)