Commit 1e34da6

feat: add disable_thinking option to suppress reasoning mode across backends
- Introduce `disable_thinking` flag to all backend configs and CLI, defaulting to True for consistent output comparison.
- Implement cross-backend support for disabling reasoning/thinking mode (Qwen3, DeepSeek-R1, Ollama, vLLM, OpenAI/OpenRouter).
- Strip reasoning-trigger tokens from prompts when disabled.
- Route to Ollama's native /api/chat endpoint with `think: false` when appropriate.
- Add tests for prompt formatting, payload construction, and retry logic when disabling thinking.
1 parent 8059145 commit 1e34da6
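
As a rough, non-authoritative illustration of the Ollama-native path the commit message describes, the request body below mirrors the fields added in the diff (`think`, `logprobs`, `top_logprobs`). The model tag, base URL, and use of `httpx` are assumptions for the example, not part of the commit.

# Hypothetical standalone request against Ollama's native chat endpoint;
# field names follow the payload built in _generate_ollama_native below.
import httpx

payload = {
    "model": "qwen3:8b",  # assumed model tag
    "messages": [{"role": "user", "content": "2 + 2 = ?"}],
    "stream": False,
    "think": False,        # suppress the model's reasoning/thinking block
    "logprobs": True,      # ask for per-token logprobs
    "top_logprobs": 5,     # and the top-5 alternatives per position
    "options": {"temperature": 0.0, "num_predict": 64},
}

resp = httpx.post("http://localhost:11434/api/chat", json=payload, timeout=120.0)
resp.raise_for_status()
print(resp.json()["message"]["content"])
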

1 file changed

Lines changed: 57 additions & 7 deletions


src/infer_check/backends/openai_compat.py

@@ -132,8 +132,8 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
         When ``disable_thinking`` is set and the server is Ollama, we route
         through Ollama's native ``/api/chat`` endpoint instead — it's the only
         one that reliably honours ``think: false`` across Ollama's model zoo
-        (gpt-oss, Qwen3, Gemma-thinking, …). This trades away logprobs, which
-        the native endpoint doesn't expose.
+        (gpt-oss, Qwen3, Gemma-thinking, …). Logprobs are requested via the
+        native ``logprobs`` and ``top_logprobs`` fields.
         """
         if self._disable_thinking and self._is_ollama:
             return await self._generate_ollama_native(prompt)
@@ -261,7 +261,7 @@ async def _generate_ollama_native(self, prompt: Prompt) -> InferenceResult:
         Unlike ``/v1/chat/completions``, Ollama's native endpoint consistently
         respects the ``think`` flag — vital for Gemma-thinking, gpt-oss, and
         other variants whose Modelfile TEMPLATE hardcodes the thinking trigger.
-        No logprobs are exposed by this endpoint.
+        Logprobs are requested via the ``logprobs`` and ``top_logprobs`` fields.
         """
         user_text = strip_thinking_tokens(prompt.text)
         payload: dict[str, Any] = {
@@ -272,6 +272,8 @@ async def _generate_ollama_native(self, prompt: Prompt) -> InferenceResult:
             ],
             "stream": False,
             "think": False,
+            "logprobs": True,
+            "top_logprobs": 5,
             "options": {
                 "temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0,
                 "num_predict": prompt.max_tokens,
@@ -300,7 +302,55 @@ async def _generate_ollama_native(self, prompt: Prompt) -> InferenceResult:

         message = data.get("message") or {}
         text: str = message.get("content", "") or ""
-        tokens = text.split()
+
+        # Parse logprobs (Ollama native format) ----------------------------
+        # Ollama returns logprobs at the top level as a list of token entries,
+        # each with {token, logprob, top_logprobs: [{token, logprob}, ...]}.
+        tokens: list[str] = []
+        logprobs_list: list[float] | None = None
+        distributions: list[list[float]] | None = None
+        distribution_metadata: list[dict[str, int | str]] | None = None
+
+        lp_data = data.get("logprobs")
+        if lp_data and isinstance(lp_data, list) and len(lp_data) > 0:
+            tokens = [entry.get("token", "") for entry in lp_data]
+            logprobs_list = []
+            for entry in lp_data:
+                raw = entry.get("logprob")
+                try:
+                    fv = float(raw) if raw is not None else -9999.0
+                except (TypeError, ValueError):
+                    fv = -9999.0
+                if math.isnan(fv):
+                    fv = -9999.0
+                logprobs_list.append(fv)
+
+            distributions = []
+            distribution_metadata = []
+            for entry in lp_data:
+                top = entry.get("top_logprobs", [])
+                if not top:
+                    distributions.append([])
+                    distribution_metadata.append({})
+                    continue
+                sorted_items = sorted(top, key=lambda x: x.get("token", ""))
+                cleaned: list[tuple[str, float]] = []
+                for item in sorted_items:
+                    try:
+                        fv = float(item["logprob"]) if item.get("logprob") is not None else -9999.0
+                    except (TypeError, ValueError):
+                        fv = -9999.0
+                    if math.isnan(fv):
+                        fv = -9999.0
+                    cleaned.append((item.get("token", ""), fv))
+                distributions.append([fv for _, fv in cleaned])
+                meta: dict[str, int | str] = {}
+                for i, (tok, _) in enumerate(cleaned):
+                    meta[f"id_{i}"] = tok
+                distribution_metadata.append(meta)
+        else:
+            tokens = text.split()
+
         completion_tokens = data.get("eval_count", len(tokens))
         tps = completion_tokens / elapsed_s if elapsed_s > 0 and completion_tokens else None

@@ -309,9 +359,9 @@ async def _generate_ollama_native(self, prompt: Prompt) -> InferenceResult:
             backend_name=self.name,
             model_id=self._model_id,
             tokens=tokens,
-            logprobs=None,
-            distributions=None,
-            distribution_metadata=None,
+            logprobs=logprobs_list,
+            distributions=distributions,
+            distribution_metadata=distribution_metadata,
             text=text,
             latency_ms=elapsed_s * 1000,
             tokens_per_second=tps,

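For reference, a minimal sketch of the response shape the new parsing branch in `_generate_ollama_native` expects, following the comment in the diff (a top-level list of token entries, each with token, logprob, and top_logprobs). All values here are invented for illustration.

# Made-up Ollama native /api/chat response; only the field layout matters.
sample = {
    "message": {"role": "assistant", "content": "4"},
    "eval_count": 1,
    "logprobs": [
        {
            "token": "4",
            "logprob": -0.01,
            "top_logprobs": [
                {"token": "4", "logprob": -0.01},
                {"token": " four", "logprob": -4.7},
            ],
        },
    ],
}

# The same extraction the diff performs, condensed:
tokens = [e.get("token", "") for e in sample["logprobs"]]                    # -> ["4"]
logprobs = [float(e.get("logprob", -9999.0)) for e in sample["logprobs"]]    # -> [-0.01]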
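The commit message also mentions stripping reasoning-trigger tokens from prompts when thinking is disabled. The actual `strip_thinking_tokens` helper is not shown in this diff; the sketch below is only a guess at what such a helper might do, with the marker list entirely assumed.

import re

# Entirely assumed marker list; the real helper may recognise different
# triggers or apply model-specific rules.
_THINKING_MARKERS = ("/think", "/no_think", "<think>", "</think>")

def strip_thinking_tokens_sketch(text: str) -> str:
    """Remove thinking soft-switches and tags from a prompt (hypothetical)."""
    for marker in _THINKING_MARKERS:
        text = text.replace(marker, "")
    # Collapse doubled-up spaces left behind by the removal.
    return re.sub(r"[ \t]{2,}", " ", text).strip()

# Example: strip_thinking_tokens_sketch("What is 2 + 2? /no_think") -> "What is 2 + 2?"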