diff --git a/docs/backends.md b/docs/backends.md index 0d82caf..5e43c33 100644 --- a/docs/backends.md +++ b/docs/backends.md @@ -9,7 +9,7 @@ | **mlx-lm** | In-process | (local) | Local Apple Silicon inference with logprobs | | **llama-cpp** | HTTP | `http://127.0.0.1:8080` | llama-server via `/completion` endpoint | | **vllm-mlx** | HTTP | `http://127.0.0.1:8000` | Continuous batching on Apple Silicon | -| **openai-compat** | HTTP | `http://127.0.0.1:11434/v1` | Any OpenAI-compatible server (vLLM, SGLang, Ollama) | +| **openai-compat** | HTTP | `http://127.0.0.1:11434` | Any OpenAI-compatible server (vLLM, SGLang, Ollama) | ## mlx-lm @@ -109,7 +109,7 @@ Generic backend for any server that implements the OpenAI API format. Works with | Model source | Default URL | |-------------|-------------| -| Ollama tags (e.g., `llama3.1:8b`) | `http://127.0.0.1:11434/v1` | +| Ollama tags (e.g., `llama3.1:8b`) | `http://127.0.0.1:11434` | | Custom server | Use `--base-url` | **Example with Ollama**: diff --git a/pyproject.toml b/pyproject.toml index b23ded4..e714900 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,9 +37,12 @@ dev = [ "pytest-cov>=7.0.0", "ruff>=0.15.5", ] -all = ["infer-check[mlx]", "infer-check[analysis]"] +all = ["infer-check[mlx]", "infer-check[analysis]", "infer-check[http]"] mlx = ["mlx>=0.31.0", "mlx-lm<0.31.0"] analysis = ["matplotlib>=3.10.8", "scikit-learn>=1.8.0"] +http = [ + "transformers>=5.3.0", +] [project.scripts] infer-check = "infer_check.cli:main" diff --git a/src/infer_check/backends/base.py b/src/infer_check/backends/base.py index 1d8a358..9283c80 100644 --- a/src/infer_check/backends/base.py +++ b/src/infer_check/backends/base.py @@ -41,7 +41,9 @@ class BackendConfig(BaseModel): backend_type: Literal["mlx-lm", "llama-cpp", "vllm-mlx", "openai-compat"] model_id: str quantization: str | None = None + hf_revision: str | None = None base_url: str | None = None + disable_thinking: bool = True extra: dict[str, Any] = Field(default_factory=dict) @@ -53,12 +55,19 @@ def get_backend(config: BackendConfig) -> BackendAdapter: return MLXBackend( model_id=config.model_id, quantization=config.quantization, + revision=config.hf_revision, + disable_thinking=config.disable_thinking, ) elif config.backend_type == "llama-cpp": from infer_check.backends.llama_cpp import LlamaCppBackend url = config.base_url or "http://127.0.0.1:8080" - return LlamaCppBackend(model_id=config.model_id, base_url=url) + return LlamaCppBackend( + model_id=config.model_id, + base_url=url, + revision=config.hf_revision, + disable_thinking=config.disable_thinking, + ) elif config.backend_type == "vllm-mlx": from infer_check.backends.vllm_mlx import VLLMMLXBackend @@ -66,20 +75,24 @@ def get_backend(config: BackendConfig) -> BackendAdapter: return VLLMMLXBackend( model_id=config.model_id, base_url=url, - chat=config.extra.get("chat", False), + chat=config.extra.get("chat", True), + revision=config.hf_revision, + disable_thinking=config.disable_thinking, ) elif config.backend_type == "openai-compat": from infer_check.backends.openai_compat import OpenAICompatBackend if not config.base_url: raise ValueError( - "openai-compat backend requires --base-url. Example: --base-url http://127.0.0.1:11434/v1 (Ollama)" + "openai-compat backend requires --base-url. 
Example: --base-url http://127.0.0.1:11434 (Ollama)" ) return OpenAICompatBackend( base_url=config.base_url, model_id=config.model_id, api_key=config.extra.get("api_key"), - chat=config.extra.get("chat", False), + chat=config.extra.get("chat", True), + revision=config.hf_revision, + disable_thinking=config.disable_thinking, ) else: supported = ", ".join(["mlx-lm", "llama-cpp", "vllm-mlx", "openai-compat"]) @@ -91,6 +104,8 @@ def get_backend_for_model( backend_type: str | None = None, base_url: str | None = None, quantization: str | None = None, + disable_thinking: bool = True, + chat: bool = True, ) -> BackendAdapter: """Resolve model string to a backend and instantiate it. @@ -106,6 +121,9 @@ def get_backend_for_model( model_id=resolved.model_id, base_url=base_url or resolved.base_url, quantization=quantization or resolved.label, + hf_revision=resolved.revision, + disable_thinking=disable_thinking, + extra={"chat": chat}, ) return get_backend(config) diff --git a/src/infer_check/backends/llama_cpp.py b/src/infer_check/backends/llama_cpp.py index cc8a338..f66b98f 100644 --- a/src/infer_check/backends/llama_cpp.py +++ b/src/infer_check/backends/llama_cpp.py @@ -9,6 +9,7 @@ import httpx from infer_check.types import InferenceResult, Prompt +from infer_check.utils import format_prompt __all__ = ["LlamaCppBackend"] @@ -19,9 +20,17 @@ class LlamaCppBackend: Communicates via the ``/completion`` endpoint. """ - def __init__(self, model_id: str, base_url: str = "http://127.0.0.1:8080") -> None: + def __init__( + self, + model_id: str, + base_url: str = "http://127.0.0.1:8080", + revision: str | None = None, + disable_thinking: bool = True, + ) -> None: self._model_id = model_id self._base_url = base_url.rstrip("/") + self._revision = revision + self._disable_thinking = disable_thinking self._client = httpx.AsyncClient(base_url=self._base_url, timeout=120.0) # ------------------------------------------------------------------ @@ -34,9 +43,15 @@ def name(self) -> str: async def generate(self, prompt: Prompt) -> InferenceResult: """Send a completion request and parse the response.""" + formatted = format_prompt( + prompt.text, + model_id=self._model_id, + revision=self._revision, + disable_thinking=self._disable_thinking, + ) payload = { "model": self._model_id, - "prompt": prompt.text, + "prompt": formatted, "n_predict": prompt.max_tokens, "temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0, "n_probs": 10, # Request top 10 probabilities for KL divergence diff --git a/src/infer_check/backends/mlx_lm.py b/src/infer_check/backends/mlx_lm.py index 444d2ed..c455dc6 100644 --- a/src/infer_check/backends/mlx_lm.py +++ b/src/infer_check/backends/mlx_lm.py @@ -8,6 +8,7 @@ from typing import Any, cast from infer_check.types import InferenceResult, Prompt +from infer_check.utils import format_prompt __all__ = ["MLXBackend"] @@ -19,9 +20,17 @@ class MLXBackend: importing this module alone never triggers a heavy download. 
""" - def __init__(self, model_id: str, quantization: str | None = None) -> None: + def __init__( + self, + model_id: str, + quantization: str | None = None, + revision: str | None = None, + disable_thinking: bool = True, + ) -> None: self._model_id = model_id self._quantization = quantization + self._revision = revision + self._disable_thinking = disable_thinking self._model: Any = None self._tokenizer: Any = None @@ -112,7 +121,7 @@ def _ensure_loaded(self) -> None: repo_or_path = str(model_path) if model_path.exists() else self._model_id try: - res = load(repo_or_path) + res = load(repo_or_path, revision=self._revision) except Exception as exc: msg = str(exc) if "404" in msg or "Repository Not Found" in msg: @@ -132,17 +141,6 @@ def _ensure_loaded(self) -> None: self._model = res[0] self._tokenizer = res[1] - def _format_prompt(self, text: str) -> str: - """Apply chat template if the tokenizer has one (Instruct models). - - Raw prompts sent to Instruct models produce undefined behavior that - varies across quantization levels, making comparisons meaningless. - """ - if hasattr(self._tokenizer, "apply_chat_template") and self._tokenizer.chat_template is not None: - messages = [{"role": "user", "content": text}] - return str(self._tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)) - return text - def _generate_simple(self, prompt: Prompt) -> InferenceResult: """Generate using the high-level ``mlx_lm.generate`` API.""" from mlx_lm import generate as mlx_generate @@ -150,7 +148,12 @@ def _generate_simple(self, prompt: Prompt) -> InferenceResult: temp = prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0 sampler = make_sampler(temp=temp) - formatted = self._format_prompt(prompt.text) + formatted = format_prompt( + prompt.text, + tokenizer=self._tokenizer, + revision=self._revision, + disable_thinking=self._disable_thinking, + ) start = time.perf_counter() text: str = mlx_generate( self._model, @@ -184,7 +187,12 @@ def _generate_with_logprobs(self, prompt: Prompt) -> InferenceResult: temp = prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0 sampler = make_sampler(temp=temp) - formatted = self._format_prompt(prompt.text) + formatted = format_prompt( + prompt.text, + tokenizer=self._tokenizer, + revision=self._revision, + disable_thinking=self._disable_thinking, + ) input_ids = mx.array(self._tokenizer.encode(formatted)) # Configurable top-K to avoid memory explosion. Default to 10. diff --git a/src/infer_check/backends/openai_compat.py b/src/infer_check/backends/openai_compat.py index 7c26d2f..4142eec 100644 --- a/src/infer_check/backends/openai_compat.py +++ b/src/infer_check/backends/openai_compat.py @@ -14,6 +14,7 @@ import httpx from infer_check.types import InferenceResult, Prompt +from infer_check.utils import format_prompt, strip_thinking_tokens __all__ = ["OpenAICompatBackend"] @@ -42,11 +43,20 @@ def __init__( model_id: str, api_key: str | None = None, chat: bool = False, + revision: str | None = None, + disable_thinking: bool = True, ) -> None: self._base_url = base_url.rstrip("/") self._model_id = model_id self._api_key = api_key self._chat = chat + self._revision = revision + self._disable_thinking = disable_thinking + # Ollama listens on :11434 by default. Track it so later request + # handling can apply Ollama-specific chat behavior when thinking is + # disabled (for example, using request flags and stripping think + # tokens from responses) rather than relying on prompt rewriting. 
+ self._is_ollama = ":11434" in self._base_url headers: dict[str, str] = {} if api_key: @@ -60,6 +70,9 @@ def __init__( # Logprobs support: assume yes until a server rejects it. self._chat_logprobs_supported: bool = True + # Thinking-disable keys are opportunistic: most servers accept them, + # some reject unknown params with 400/422. We drop them on first failure. + self._thinking_keys_supported: bool = True # ------------------------------------------------------------------ # BackendAdapter protocol @@ -115,25 +128,59 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult: request fails with 400 or 422 (unsupported parameter), the backend automatically retries without logprobs and disables them for all subsequent requests. + + When ``disable_thinking`` is set and the server is Ollama, we route + through Ollama's native ``/api/chat`` endpoint instead — it's the only + one that reliably honours ``think: false`` across Ollama's model zoo + (gpt-oss, Qwen3, Gemma-thinking, …). Logprobs are requested via the + native ``logprobs`` and ``top_logprobs`` fields. """ + if self._disable_thinking and self._is_ollama: + return await self._generate_ollama_native(prompt) + + user_text = strip_thinking_tokens(prompt.text) if self._disable_thinking else prompt.text + messages: list[dict[str, str]] = [] + if self._disable_thinking: + # Empty system message overrides any server-side SYSTEM default + # (Ollama Modelfile SYSTEM, hosted-template system prompts, …) + # that might re-inject a thinking trigger token. + messages.append({"role": "system", "content": ""}) + messages.append({"role": "user", "content": user_text}) payload: dict[str, object] = { "model": self._model_id, - "messages": [{"role": "user", "content": prompt.text}], + "messages": messages, "max_tokens": prompt.max_tokens, "temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0, } if self._chat_logprobs_supported: payload["logprobs"] = True payload["top_logprobs"] = 5 + if self._disable_thinking and self._thinking_keys_supported: + # Cross-backend hints for disabling reasoning/thinking mode: + # chat_template_kwargs.enable_thinking — vLLM / SGLang / Qwen3 family + # chat_template_kwargs.thinking — some DeepSeek / HunYuan templates + # think — Ollama native flag + # reasoning.enabled / reasoning_effort — OpenRouter / OpenAI-style + payload["chat_template_kwargs"] = {"enable_thinking": False, "thinking": False} + payload["think"] = False + payload["reasoning"] = {"enabled": False} + payload["reasoning_effort"] = "minimal" try: elapsed_s, data = await self._post_chat(payload) except _ServerHTTPError as exc: - # Retry without logprobs only on 400/422 (unsupported parameter). - if self._chat_logprobs_supported and exc.status_code in (400, 422): - self._chat_logprobs_supported = False - payload.pop("logprobs", None) - payload.pop("top_logprobs", None) + # Retry shedding unsupported params only on 400/422. 
+ if exc.status_code in (400, 422) and (self._chat_logprobs_supported or self._thinking_keys_supported): + if self._chat_logprobs_supported: + self._chat_logprobs_supported = False + payload.pop("logprobs", None) + payload.pop("top_logprobs", None) + if self._thinking_keys_supported: + self._thinking_keys_supported = False + payload.pop("chat_template_kwargs", None) + payload.pop("think", None) + payload.pop("reasoning", None) + payload.pop("reasoning_effort", None) elapsed_s, data = await self._post_chat(payload) else: raise @@ -208,11 +255,129 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult: tokens_per_second=tps, ) + async def _generate_ollama_native(self, prompt: Prompt) -> InferenceResult: + """POST to Ollama's native ``/api/chat`` with ``think: false``. + + Unlike ``/v1/chat/completions``, Ollama's native endpoint consistently + respects the ``think`` flag — vital for Gemma-thinking, gpt-oss, and + other variants whose Modelfile TEMPLATE hardcodes the thinking trigger. + Logprobs are requested via the ``logprobs`` and ``top_logprobs`` fields. + """ + user_text = strip_thinking_tokens(prompt.text) + payload: dict[str, Any] = { + "model": self._model_id, + "messages": [ + {"role": "system", "content": ""}, + {"role": "user", "content": user_text}, + ], + "stream": False, + "think": False, + "logprobs": True, + "top_logprobs": 5, + "options": { + "temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0, + "num_predict": prompt.max_tokens, + }, + } + + start = time.perf_counter() + try: + response = await self._client.post("/api/chat", json=payload) + response.raise_for_status() + except httpx.ConnectError as exc: + raise RuntimeError( + f"Cannot connect to Ollama at {self._base_url}. Ensure `ollama serve` is running." + ) from exc + except httpx.TimeoutException as exc: + raise RuntimeError(f"Request to {self._base_url}/api/chat timed out after 120s.") from exc + except httpx.HTTPStatusError as exc: + raise RuntimeError(f"Ollama returned HTTP {exc.response.status_code}: {exc.response.text[:500]}") from exc + + elapsed_s = time.perf_counter() - start + + try: + data = response.json() + except Exception as exc: + raise RuntimeError(f"Ollama returned non-JSON response: {response.text[:200]}") from exc + + message = data.get("message") or {} + text: str = message.get("content", "") or "" + + # Parse logprobs (Ollama native format) ---------------------------- + # Ollama returns logprobs at the top level as a list of token entries, + # each with {token, logprob, top_logprobs: [{token, logprob}, ...]}. 
+ tokens: list[str] = [] + logprobs_list: list[float] | None = None + distributions: list[list[float]] | None = None + distribution_metadata: list[dict[str, int | str]] | None = None + + lp_data = data.get("logprobs") + if lp_data and isinstance(lp_data, list) and len(lp_data) > 0: + tokens = [entry.get("token", "") for entry in lp_data] + logprobs_list = [] + for entry in lp_data: + raw = entry.get("logprob") + try: + fv = float(raw) if raw is not None else -9999.0 + except (TypeError, ValueError): + fv = -9999.0 + if math.isnan(fv): + fv = -9999.0 + logprobs_list.append(fv) + + distributions = [] + distribution_metadata = [] + for entry in lp_data: + top = entry.get("top_logprobs", []) + if not top: + distributions.append([]) + distribution_metadata.append({}) + continue + sorted_items = sorted(top, key=lambda x: x.get("token", "")) + cleaned: list[tuple[str, float]] = [] + for item in sorted_items: + try: + fv = float(item["logprob"]) if item.get("logprob") is not None else -9999.0 + except (TypeError, ValueError): + fv = -9999.0 + if math.isnan(fv): + fv = -9999.0 + cleaned.append((item.get("token", ""), fv)) + distributions.append([fv for _, fv in cleaned]) + meta: dict[str, int | str] = {} + for i, (tok, _) in enumerate(cleaned): + meta[f"id_{i}"] = tok + distribution_metadata.append(meta) + else: + tokens = text.split() + + completion_tokens = data.get("eval_count", len(tokens)) + tps = completion_tokens / elapsed_s if elapsed_s > 0 and completion_tokens else None + + return InferenceResult( + prompt_id=prompt.id, + backend_name=self.name, + model_id=self._model_id, + tokens=tokens, + logprobs=logprobs_list, + distributions=distributions, + distribution_metadata=distribution_metadata, + text=text, + latency_ms=elapsed_s * 1000, + tokens_per_second=tps, + ) + async def _generate_completions(self, prompt: Prompt) -> InferenceResult: """Use the legacy ``/v1/completions`` endpoint for raw logprobs.""" + formatted = format_prompt( + prompt.text, + model_id=self._model_id, + revision=self._revision, + disable_thinking=self._disable_thinking, + ) payload = { "model": self._model_id, - "prompt": prompt.text, + "prompt": formatted, "max_tokens": prompt.max_tokens, "temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0, "logprobs": 5, diff --git a/src/infer_check/backends/vllm_mlx.py b/src/infer_check/backends/vllm_mlx.py index 0d044ff..6d95bbc 100644 --- a/src/infer_check/backends/vllm_mlx.py +++ b/src/infer_check/backends/vllm_mlx.py @@ -23,12 +23,16 @@ def __init__( base_url: str = "http://127.0.0.1:8000", api_key: str | None = None, chat: bool = False, + revision: str | None = None, + disable_thinking: bool = True, ) -> None: super().__init__( base_url=base_url, model_id=model_id, api_key=api_key, chat=chat, + revision=revision, + disable_thinking=disable_thinking, ) @property @@ -53,6 +57,7 @@ def from_model( cls, model_id: str, quantization: str | None = None, + revision: str | None = None, base_url: str = "http://127.0.0.1:8000", ) -> VLLMMLXBackend: """Create a backend for *model_id*. @@ -65,11 +70,13 @@ def from_model( python -m vllm.entrypoints.openai.api_server \\ --model \\ --quantization \\ + --revision \\ --port 8000 Args: model_id: HuggingFace model identifier. quantization: Optional quantization string (e.g. ``"4bit"``). + revision: Optional HuggingFace revision (e.g. ``"main"``). base_url: Server URL (default ``http://127.0.0.1:8000``). 
Returns: @@ -78,4 +85,5 @@ def from_model( return cls( model_id=model_id, base_url=base_url, + revision=revision, ) diff --git a/src/infer_check/cli.py b/src/infer_check/cli.py index b5bfc89..efdf0c7 100644 --- a/src/infer_check/cli.py +++ b/src/infer_check/cli.py @@ -33,6 +33,29 @@ def common_options(f: F) -> F: type=click.IntRange(min=1, clamp=True), help="Limit number of prompts to use.", ), + click.option( + "--disable-thinking/--enable-thinking", + "disable_thinking", + default=True, + show_default=True, + help=( + "Suppress reasoning/thinking mode on models that support it " + "(Qwen3, DeepSeek-R1, Ollama think, vLLM chat_template_kwargs, " + "OpenAI/OpenRouter reasoning). Models without a thinking mode " + "are unaffected. Defaults to disabled so outputs are directly " + "comparable across runs; pass --enable-thinking to restore it." + ), + ), + click.option( + "--chat/--no-chat", + default=True, + show_default=True, + help=( + "Use /v1/chat/completions for HTTP backends (applies chat " + "template server-side). Pass --no-chat to use raw " + "/v1/completions instead. Ignored for mlx-lm." + ), + ), ] for option in reversed(options): f = option(f) @@ -137,6 +160,8 @@ def sweep( base_url: str | None, max_tokens: int | None, num_prompts: int | None, + disable_thinking: bool, + chat: bool, ) -> None: """Run a quantization sweep: compare pre-quantized models against a baseline. @@ -192,6 +217,8 @@ def sweep( backend_type=backend, base_url=base_url, quantization=label, + disable_thinking=disable_thinking, + chat=chat, ) runner = TestRunner() @@ -335,6 +362,8 @@ def compare( report: bool, max_tokens: int | None, num_prompts: int | None, + disable_thinking: bool, + chat: bool, ) -> None: """Compare two quantizations of the same model. @@ -383,15 +412,19 @@ def compare( backend_type=resolved_a.backend, model_id=resolved_a.model_id, quantization=resolved_a.label, + hf_revision=resolved_a.revision, base_url=resolved_a.base_url, - extra={"chat": False}, + disable_thinking=disable_thinking, + extra={"chat": chat}, ) config_b = BackendConfig( backend_type=resolved_b.backend, model_id=resolved_b.model_id, quantization=resolved_b.label, + hf_revision=resolved_b.revision, base_url=resolved_b.base_url, - extra={"chat": False}, + disable_thinking=disable_thinking, + extra={"chat": chat}, ) backend_a = get_backend(config_a) backend_b = get_backend(config_b) @@ -573,12 +606,6 @@ def compare( default=None, help="Comma-separated base URLs for HTTP backends (positionally matched to --backends).", ) -@click.option( - "--chat/--no-chat", - default=True, - show_default=True, - help="Use /v1/chat/completions for HTTP backends (applies chat template server-side).", -) @common_options @click.pass_context def diff( @@ -589,31 +616,40 @@ def diff( output: Path, quant: str | None, base_urls: str | None, - chat: bool, max_tokens: int | None, num_prompts: int | None, + disable_thinking: bool, + chat: bool, ) -> None: """Compare outputs across different backends for the same model and prompts.""" from infer_check.backends.base import BackendConfig, get_backend + from infer_check.resolve import resolve_model from infer_check.runner import TestRunner prompt_list = _load_prompts(ctx, prompts, max_tokens, num_prompts) + # Resolve the model to handle @revision and ensure correct base URL + resolved = resolve_model(model) + backend_names = [b.strip() for b in backends.split(",") if b.strip()] url_list: list[str | None] = [u.strip() for u in base_urls.split(",")] if base_urls else [None] * len(backend_names) # Pad url_list if 
shorter than backend_names while len(url_list) < len(backend_names): url_list.append(None) - console.print(f"[bold cyan]diff[/bold cyan] model={model} backends={backend_names} quant={quant}") + console.print(f"[bold cyan]diff[/bold cyan] model={resolved.model_id} backends={backend_names} quant={quant}") + if resolved.revision: + console.print(f" revision: {resolved.revision}") backend_instances = [] for name, url in zip(backend_names, url_list, strict=True): config = BackendConfig( backend_type=name, # type: ignore[arg-type] - model_id=model, + model_id=resolved.model_id, quantization=quant, - base_url=url, + hf_revision=resolved.revision, + base_url=url or (resolved.base_url if name == resolved.backend else None), + disable_thinking=disable_thinking, extra={"chat": chat} if name in ("openai-compat", "vllm-mlx") else {}, ) backend_instances.append(get_backend(config)) @@ -633,7 +669,7 @@ def diff( # Persist results output.mkdir(parents=True, exist_ok=True) ts = int(datetime.now(UTC).timestamp()) - out_path = output / f"diff_{model.replace('/', '_')}_{ts}.json" + out_path = output / f"diff_{resolved.model_id.replace('/', '_')}_{ts}.json" out_path.write_text( json.dumps( [c.model_dump(mode="json") for c in comparisons], @@ -713,6 +749,8 @@ def stress( base_url: str | None, max_tokens: int | None, num_prompts: int | None, + disable_thinking: bool, + chat: bool, ) -> None: """Stress-test a backend with varying concurrency levels.""" from infer_check.backends.base import get_backend_for_model @@ -726,6 +764,8 @@ def stress( model_str=model, backend_type=backend, base_url=base_url, + disable_thinking=disable_thinking, + chat=chat, ) console.print( @@ -802,6 +842,8 @@ def determinism( base_url: str | None, max_tokens: int | None, num_prompts: int | None, + disable_thinking: bool, + chat: bool, ) -> None: """Test whether a backend produces identical outputs across repeated runs at temperature=0.""" from infer_check.backends.base import get_backend_for_model @@ -813,6 +855,8 @@ def determinism( model_str=model, backend_type=backend, base_url=base_url, + disable_thinking=disable_thinking, + chat=chat, ) console.print(f"[bold cyan]determinism[/bold cyan] model={model} backend={backend_instance.name} runs={runs}") diff --git a/src/infer_check/resolve.py b/src/infer_check/resolve.py index 9bff4c0..e6f6c59 100644 --- a/src/infer_check/resolve.py +++ b/src/infer_check/resolve.py @@ -34,7 +34,7 @@ # Default base URLs per backend (can be overridden via CLI). _DEFAULT_URLS: dict[BackendType, str] = { - "openai-compat": "http://127.0.0.1:11434/v1", # Ollama + "openai-compat": "http://127.0.0.1:11434", # Ollama (backend adds /v1/... paths) "llama-cpp": "http://127.0.0.1:8080", "vllm-mlx": "http://127.0.0.1:8000", } @@ -48,9 +48,13 @@ class ResolvedModel: model_id: str base_url: str | None label: str # short human-readable label for tables / reports + revision: str | None = None def __str__(self) -> str: - return f"{self.label} ({self.backend})" + res = f"{self.label} ({self.backend})" + if self.revision: + res += f" @ {self.revision}" + return res def _make_label(model_id: str) -> str: @@ -81,23 +85,9 @@ def resolve_model( Args: spec: Model identifier. Can be prefixed (``ollama:model``, ``mlx:repo/model``, ``gguf:/path/to/file.gguf``) or bare. + Optionally includes a revision after '@' (e.g. ``repo/model@main``). base_url: Override the default base URL for HTTP backends. label: Override the auto-derived label. - - Returns: - A ``ResolvedModel`` with backend, model_id, base_url, and label. 
- - Raises: - ValueError: If the spec is empty or cannot be resolved. - - Examples: - >>> r = resolve_model("ollama:llama3.1:8b-instruct-q4_K_M") - >>> r.backend - 'openai-compat' - - >>> r = resolve_model("mlx-community/Llama-3.1-8B-Instruct-4bit") - >>> r.backend - 'mlx-lm' """ spec = spec.strip() if not spec: @@ -108,14 +98,22 @@ def resolve_model( pattern = f"^{re.escape(prefix)}:" if re.match(pattern, spec, re.IGNORECASE): model_id = spec[len(prefix) + 1 :] + + # Revision is allowed for explicit prefixes + actual_revision = None + if "@" in model_id and not model_id.startswith("@"): + model_id, actual_revision = model_id.rsplit("@", 1) + return ResolvedModel( backend=backend, model_id=model_id, base_url=base_url or _DEFAULT_URLS.get(backend), label=label or _make_label(model_id), + revision=actual_revision, ) # ── 2. Local .gguf file path ───────────────────────────────────── + # If it's a local .gguf path, we don't treat @ as a revision delimiter. local_path = Path(spec) if local_path.suffix.lower() == ".gguf": if local_path.exists(): @@ -124,6 +122,7 @@ def resolve_model( model_id=str(local_path.resolve()), base_url=base_url or _DEFAULT_URLS["llama-cpp"], label=label or local_path.stem, + revision=None, ) # Even if it doesn't exist yet, honour the extension. return ResolvedModel( @@ -131,9 +130,15 @@ def resolve_model( model_id=spec, base_url=base_url or _DEFAULT_URLS["llama-cpp"], label=label or local_path.stem, + revision=None, ) # ── 3. HuggingFace repo heuristics ────────────────────────────── + # HF repos CAN have @revision. + actual_revision = None + if "@" in spec and not spec.startswith("@"): + spec, actual_revision = spec.rsplit("@", 1) + spec_lower = spec.lower() # MLX repos (mlx-community org or -mlx suffix). @@ -148,6 +153,7 @@ def resolve_model( model_id=spec, base_url=None, # mlx-lm loads locally, no URL label=label or _make_label(spec), + revision=actual_revision, ) # GGUF repos (typically served via Ollama or llama-cpp). @@ -159,6 +165,7 @@ def resolve_model( model_id=spec, base_url=base_url or _DEFAULT_URLS["llama-cpp"], label=label or _make_label(spec), + revision=actual_revision, ) # ── 4. Ollama-style tags (contain colon but no slash) ──────────── @@ -169,6 +176,7 @@ def resolve_model( model_id=spec, base_url=base_url or _DEFAULT_URLS["openai-compat"], label=label or _make_label(spec), + revision=actual_revision, ) # ── 5. Fallback — assume mlx-lm (Mac-first user base) ─────────── @@ -177,4 +185,5 @@ def resolve_model( model_id=spec, base_url=None, label=label or _make_label(spec), + revision=actual_revision, ) diff --git a/src/infer_check/utils.py b/src/infer_check/utils.py index 36da254..271f80b 100644 --- a/src/infer_check/utils.py +++ b/src/infer_check/utils.py @@ -2,7 +2,31 @@ from __future__ import annotations +import contextlib import re +from functools import lru_cache +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from transformers import SentencePieceBackend, TokenizersBackend + +# Tokens that trigger reasoning mode on specific model/runner combos. When +# ``disable_thinking`` is set we strip them from the prompt so that a stray +# token in user input can't re-enable thinking. ``<|think|>`` is Ollama's +# system-prompt trigger for gpt-oss-style models; ```` is the +# DeepSeek-R1 / Qwen reasoning wrapper. 
+_THINKING_TOKEN_PATTERNS = (
+    re.compile(r"<\|think\|>", re.IGNORECASE),
+    re.compile(r"<think>.*?</think>", re.IGNORECASE | re.DOTALL),
+    re.compile(r"</think>", re.IGNORECASE),
+)
+
+
+def strip_thinking_tokens(text: str) -> str:
+    """Remove reasoning-trigger tokens from ``text``."""
+    for pattern in _THINKING_TOKEN_PATTERNS:
+        text = pattern.sub("", text)
+    return text
 
 
 def sanitize_filename(label: str) -> str:
@@ -43,3 +67,69 @@ def sanitize_filename(label: str) -> str:
 
     # Ensure we have something left
     return safe if safe else "model"
+
+
+@lru_cache(maxsize=8)
+def _get_tokenizer(model_id: str, revision: str | None = None) -> Any:
+    """Helper to load and cache HuggingFace tokenizers."""
+    from transformers import AutoTokenizer
+
+    # We use local_files_only=True to ensure that we don't hang on network
+    # calls if the model isn't actually a HF repo (or if we're offline).
+    # This matches the tightened is_hf_id heuristic in format_prompt.
+    return AutoTokenizer.from_pretrained(
+        model_id,
+        revision=revision,
+        local_files_only=True,
+        trust_remote_code=False,
+    )
+
+
+def format_prompt(
+    text: str,
+    tokenizer: TokenizersBackend | SentencePieceBackend | None = None,
+    model_id: str | None = None,
+    revision: str | None = None,
+    disable_thinking: bool = False,
+) -> str:
+    """Apply chat template client-side.
+
+    Uses an existing tokenizer if provided (mlx-lm path),
+    or loads one from HuggingFace by model_id (HTTP backend path).
+
+    When ``disable_thinking`` is True, attempt to turn off reasoning/thinking
+    mode via the chat template. This works across model families that expose a
+    template flag (Qwen3 and derivatives use ``enable_thinking``; some DeepSeek
+    and HunYuan variants use ``thinking``). Templates that don't know the flag
+    ignore it; templates that reject unknown kwargs trigger a graceful fallback
+    to normal rendering, so non-thinking models keep working unchanged.
+    """
+    if disable_thinking:
+        text = strip_thinking_tokens(text)
+
+    if tokenizer is None and model_id:
+        # Only attempt to load from HF if it looks like a HF repo (owner/repo)
+        # or an absolute/relative path. Ollama tags (name:tag) or local GGUF
+        # files should be skipped as they'll fail or hang from_pretrained.
+ is_hf_id = "/" in model_id + if is_hf_id: + with contextlib.suppress(Exception): + tokenizer = _get_tokenizer(model_id, revision) + + if tokenizer and hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template: + messages = [{"role": "user", "content": text}] + if disable_thinking: + for kwargs in ({"enable_thinking": False}, {"thinking": False}): + try: + return str( + tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + **cast(dict[str, Any], kwargs), + ) + ) + except TypeError: + continue + return str(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)) + return text diff --git a/tests/unit/test_llama_cpp_fallback.py b/tests/unit/test_llama_cpp_fallback.py index 25c670a..96729b0 100644 --- a/tests/unit/test_llama_cpp_fallback.py +++ b/tests/unit/test_llama_cpp_fallback.py @@ -21,7 +21,10 @@ async def test_llama_cpp_model_id_fallback() -> None: ) try: - with patch("httpx.AsyncClient.post", return_value=mock_response): + with ( + patch("infer_check.backends.llama_cpp.format_prompt", return_value="Hello"), + patch("httpx.AsyncClient.post", return_value=mock_response), + ): res = await backend.generate(prompt) # Verify it falls back to backend's model_id instead of "unknown" diff --git a/tests/unit/test_llama_cpp_payload.py b/tests/unit/test_llama_cpp_payload.py index 085c4fc..6c9667c 100644 --- a/tests/unit/test_llama_cpp_payload.py +++ b/tests/unit/test_llama_cpp_payload.py @@ -20,7 +20,10 @@ async def test_llama_cpp_includes_model_in_payload() -> None: ) try: - with patch("httpx.AsyncClient.post", return_value=mock_response) as mock_post: + with ( + patch("infer_check.backends.llama_cpp.format_prompt", return_value="Hello"), + patch("httpx.AsyncClient.post", return_value=mock_response) as mock_post, + ): res = await backend.generate(prompt) # Verify the call to post diff --git a/tests/unit/test_mlx_backend.py b/tests/unit/test_mlx_backend.py index 1fb3357..3fddfdf 100644 --- a/tests/unit/test_mlx_backend.py +++ b/tests/unit/test_mlx_backend.py @@ -6,6 +6,7 @@ import pytest from infer_check.backends.mlx_lm import MLXBackend +from infer_check.utils import format_prompt @pytest.fixture @@ -35,7 +36,7 @@ def test_mlx_chat_template(mock_mlx: tuple[MagicMock, MagicMock, MagicMock]) -> backend._model = mock_mlx[1] backend._tokenizer = mock_tokenizer - formatted = backend._format_prompt("hello") + formatted = format_prompt("hello", tokenizer=backend._tokenizer) assert formatted == "hello" mock_tokenizer.apply_chat_template.assert_called_once() diff --git a/tests/unit/test_openai_compat.py b/tests/unit/test_openai_compat.py index 2068f91..aea304f 100644 --- a/tests/unit/test_openai_compat.py +++ b/tests/unit/test_openai_compat.py @@ -50,9 +50,12 @@ def test_generate_completions_404() -> None: prompt = Prompt(id="p1", text="Hello", max_tokens=10) mock_response = httpx.Response(404, request=httpx.Request("POST", "http://127.0.0.1:8000/v1/completions")) - with patch( - "httpx.AsyncClient.post", - side_effect=httpx.HTTPStatusError("404 Not Found", request=mock_response.request, response=mock_response), + with ( + patch("infer_check.backends.openai_compat.format_prompt", return_value="Hello"), + patch( + "httpx.AsyncClient.post", + side_effect=httpx.HTTPStatusError("404 Not Found", request=mock_response.request, response=mock_response), + ), ): with pytest.raises(RuntimeError) as exc: asyncio.run(backend.generate(prompt)) @@ -164,6 +167,226 @@ def test_generate_chat_with_logprobs() -> None: assert 
res.distribution_metadata == [{"id_0": "earth", "id_1": "world"}] +def test_generate_chat_disable_thinking_payload() -> None: + """With disable_thinking, payload carries cross-backend reasoning-off hints.""" + backend = OpenAICompatBackend( + base_url="http://127.0.0.1:8000", + model_id="dummy", + chat=True, + disable_thinking=True, + ) + prompt = Prompt(id="p1", text="Hello", max_tokens=10) + + mock_response = httpx.Response( + 200, + json={"choices": [{"message": {"content": "ok"}}], "usage": {"completion_tokens": 1}}, + request=httpx.Request("POST", "http://127.0.0.1:8000/v1/chat/completions"), + ) + + with patch("httpx.AsyncClient.post", return_value=mock_response) as post: + asyncio.run(backend.generate(prompt)) + sent = post.call_args.kwargs["json"] + assert sent["chat_template_kwargs"] == {"enable_thinking": False, "thinking": False} + assert sent["think"] is False + assert sent["reasoning"] == {"enabled": False} + assert sent["reasoning_effort"] == "minimal" + + +def test_generate_chat_enable_thinking_omits_keys() -> None: + """When disable_thinking=False, none of the reasoning-off hints are sent.""" + backend = OpenAICompatBackend( + base_url="http://127.0.0.1:8000", + model_id="dummy", + chat=True, + disable_thinking=False, + ) + prompt = Prompt(id="p1", text="Hello", max_tokens=10) + mock_response = httpx.Response( + 200, + json={"choices": [{"message": {"content": "ok"}}], "usage": {"completion_tokens": 1}}, + request=httpx.Request("POST", "http://127.0.0.1:8000/v1/chat/completions"), + ) + with patch("httpx.AsyncClient.post", return_value=mock_response) as post: + asyncio.run(backend.generate(prompt)) + sent = post.call_args.kwargs["json"] + assert "chat_template_kwargs" not in sent + assert "think" not in sent + assert "reasoning" not in sent + assert "reasoning_effort" not in sent + + +def test_generate_chat_disable_thinking_strips_think_token_from_message() -> None: + """Ollama <|think|> trigger is stripped from user content when disabled.""" + backend = OpenAICompatBackend( + base_url="http://127.0.0.1:8000", + model_id="dummy", + chat=True, + disable_thinking=True, + ) + prompt = Prompt(id="p1", text="<|think|>what is 2+2?", max_tokens=10) + mock_response = httpx.Response( + 200, + json={"choices": [{"message": {"content": "4"}}], "usage": {"completion_tokens": 1}}, + request=httpx.Request("POST", "http://127.0.0.1:8000/v1/chat/completions"), + ) + with patch("httpx.AsyncClient.post", return_value=mock_response) as post: + asyncio.run(backend.generate(prompt)) + sent = post.call_args.kwargs["json"] + assert sent["messages"] == [ + {"role": "system", "content": ""}, + {"role": "user", "content": "what is 2+2?"}, + ] + + +def test_generate_chat_enable_thinking_preserves_think_token() -> None: + backend = OpenAICompatBackend( + base_url="http://127.0.0.1:8000", + model_id="dummy", + chat=True, + disable_thinking=False, + ) + prompt = Prompt(id="p1", text="<|think|>what is 2+2?", max_tokens=10) + mock_response = httpx.Response( + 200, + json={"choices": [{"message": {"content": "4"}}], "usage": {"completion_tokens": 1}}, + request=httpx.Request("POST", "http://127.0.0.1:8000/v1/chat/completions"), + ) + with patch("httpx.AsyncClient.post", return_value=mock_response) as post: + asyncio.run(backend.generate(prompt)) + sent = post.call_args.kwargs["json"] + assert sent["messages"] == [{"role": "user", "content": "<|think|>what is 2+2?"}] + + +def test_generate_chat_disable_thinking_prepends_empty_system() -> None: + """Empty system message overrides server-side SYSTEM defaults 
(e.g. Modelfile).""" + backend = OpenAICompatBackend( + base_url="http://127.0.0.1:8000", # not Ollama port → /v1/chat path + model_id="dummy", + chat=True, + disable_thinking=True, + ) + prompt = Prompt(id="p1", text="hi", max_tokens=5) + mock_response = httpx.Response( + 200, + json={"choices": [{"message": {"content": "hey"}}], "usage": {"completion_tokens": 1}}, + request=httpx.Request("POST", "http://127.0.0.1:8000/v1/chat/completions"), + ) + with patch("httpx.AsyncClient.post", return_value=mock_response) as post: + asyncio.run(backend.generate(prompt)) + sent = post.call_args.kwargs["json"] + assert sent["messages"] == [ + {"role": "system", "content": ""}, + {"role": "user", "content": "hi"}, + ] + + +def test_generate_chat_enable_thinking_no_system_message() -> None: + backend = OpenAICompatBackend( + base_url="http://127.0.0.1:8000", + model_id="dummy", + chat=True, + disable_thinking=False, + ) + prompt = Prompt(id="p1", text="hi", max_tokens=5) + mock_response = httpx.Response( + 200, + json={"choices": [{"message": {"content": "hey"}}], "usage": {"completion_tokens": 1}}, + request=httpx.Request("POST", "http://127.0.0.1:8000/v1/chat/completions"), + ) + with patch("httpx.AsyncClient.post", return_value=mock_response) as post: + asyncio.run(backend.generate(prompt)) + sent = post.call_args.kwargs["json"] + assert sent["messages"] == [{"role": "user", "content": "hi"}] + + +def test_ollama_native_route_on_disable_thinking() -> None: + """When talking to Ollama with thinking disabled, route to /api/chat.""" + backend = OpenAICompatBackend( + base_url="http://127.0.0.1:11434", + model_id="gemma3:4b", + chat=True, + disable_thinking=True, + ) + prompt = Prompt(id="p1", text="<|think|>hello", max_tokens=5) + + native_response = httpx.Response( + 200, + json={ + "message": {"content": "hi there"}, + "eval_count": 2, + }, + request=httpx.Request("POST", "http://127.0.0.1:11434/api/chat"), + ) + + with patch("httpx.AsyncClient.post", return_value=native_response) as post: + res = asyncio.run(backend.generate(prompt)) + assert post.call_args.args[0] == "/api/chat" + sent = post.call_args.kwargs["json"] + assert sent["think"] is False + assert sent["stream"] is False + assert sent["options"]["num_predict"] == 5 + assert sent["messages"] == [ + {"role": "system", "content": ""}, + {"role": "user", "content": "hello"}, + ] + assert res.text == "hi there" + assert res.logprobs is None + + +def test_ollama_port_with_thinking_enabled_uses_openai_route() -> None: + """Ollama detection only flips routing when disable_thinking is set.""" + backend = OpenAICompatBackend( + base_url="http://127.0.0.1:11434", + model_id="gemma3:4b", + chat=True, + disable_thinking=False, + ) + prompt = Prompt(id="p1", text="hi", max_tokens=5) + mock_response = httpx.Response( + 200, + json={"choices": [{"message": {"content": "hey"}}], "usage": {"completion_tokens": 1}}, + request=httpx.Request("POST", "http://127.0.0.1:11434/v1/chat/completions"), + ) + with patch("httpx.AsyncClient.post", return_value=mock_response) as post: + asyncio.run(backend.generate(prompt)) + assert post.call_args.args[0] == "/v1/chat/completions" + + +def test_generate_chat_disable_thinking_retry_on_400() -> None: + """Server rejects unknown params; backend retries once without them.""" + backend = OpenAICompatBackend( + base_url="http://127.0.0.1:8000", + model_id="dummy", + chat=True, + disable_thinking=True, + ) + prompt = Prompt(id="p1", text="Hello", max_tokens=10) + + bad_request = httpx.Request("POST", 
"http://127.0.0.1:8000/v1/chat/completions") + bad_response = httpx.Response(400, text="unknown field", request=bad_request) + good_response = httpx.Response( + 200, + json={"choices": [{"message": {"content": "ok"}}], "usage": {"completion_tokens": 1}}, + request=bad_request, + ) + + call_count = {"n": 0} + + def fake_post(*_args, **kwargs): # type: ignore[no-untyped-def] + call_count["n"] += 1 + if call_count["n"] == 1: + # First call raises via raise_for_status + return bad_response + return good_response + + with patch("httpx.AsyncClient.post", side_effect=fake_post): + res = asyncio.run(backend.generate(prompt)) + assert res.text == "ok" + assert call_count["n"] == 2 + # State should be sticky — future requests won't include the keys. + assert backend._thinking_keys_supported is False + + def test_generate_chat_with_logprobs_nan_and_missing() -> None: backend = OpenAICompatBackend(base_url="http://127.0.0.1:8000", model_id="dummy", chat=True) prompt = Prompt(id="p1", text="Hello", max_tokens=10) diff --git a/tests/unit/test_resolve.py b/tests/unit/test_resolve.py index 5d4c04c..a23d28a 100644 --- a/tests/unit/test_resolve.py +++ b/tests/unit/test_resolve.py @@ -14,7 +14,7 @@ def test_ollama_prefix(self) -> None: r = resolve_model("ollama:llama3.1:8b-instruct-q4_K_M") assert r.backend == "openai-compat" assert r.model_id == "llama3.1:8b-instruct-q4_K_M" - assert r.base_url == "http://127.0.0.1:11434/v1" + assert r.base_url == "http://127.0.0.1:11434" assert r.label == "llama3.1:8b-instruct-q4_K_M" def test_mlx_prefix(self) -> None: @@ -67,7 +67,7 @@ def test_mlx_keyword_heuristic(self) -> None: def test_ollama_style_tag(self) -> None: r = resolve_model("llama3.1:8b-instruct-q4_K_M") assert r.backend == "openai-compat" - assert r.base_url == "http://127.0.0.1:11434/v1" + assert r.base_url == "http://127.0.0.1:11434" def test_local_gguf_path(self, tmp_path: Path) -> None: gguf_file = tmp_path / "model-q4.gguf" diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 05b299c..497211c 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,6 @@ -from infer_check.utils import sanitize_filename +from unittest.mock import MagicMock + +from infer_check.utils import format_prompt, sanitize_filename, strip_thinking_tokens def test_sanitize_filename_basic() -> None: @@ -44,3 +46,88 @@ def test_sanitize_filename_windows_reserved_names() -> None: def test_sanitize_filename_combination() -> None: assert sanitize_filename(" __model. ") == "model" assert sanitize_filename("my.model.name") == "my.model.name" # dots in middle are fine + + +def _make_fake_tokenizer(accepted_kwargs: set[str] | None) -> MagicMock: + """Build a tokenizer mock whose apply_chat_template accepts only selected kwargs. + + ``accepted_kwargs=None`` means the template rejects *every* extra kwarg + (simulating a non-thinking model's template). 
+    """
+    tok = MagicMock()
+    tok.chat_template = "stub"
+
+    def apply(messages, **kwargs):  # type: ignore[no-untyped-def]
+        for key in ("enable_thinking", "thinking"):
+            if key in kwargs:
+                if accepted_kwargs is None or key not in accepted_kwargs:
+                    raise TypeError(f"got unexpected kwarg {key!r}")
+                return f"templated[{key}={kwargs[key]}]"
+        return "templated"
+
+    tok.apply_chat_template = MagicMock(side_effect=apply)
+    return tok
+
+
+def test_format_prompt_disable_thinking_qwen3_style() -> None:
+    tok = _make_fake_tokenizer(accepted_kwargs={"enable_thinking"})
+    out = format_prompt("hi", tokenizer=tok, disable_thinking=True)
+    assert out == "templated[enable_thinking=False]"
+
+
+def test_format_prompt_disable_thinking_deepseek_style() -> None:
+    tok = _make_fake_tokenizer(accepted_kwargs={"thinking"})
+    out = format_prompt("hi", tokenizer=tok, disable_thinking=True)
+    assert out == "templated[thinking=False]"
+
+
+def test_format_prompt_disable_thinking_non_thinking_model_falls_back() -> None:
+    # Template rejects both flags — we still render a normal prompt.
+    tok = _make_fake_tokenizer(accepted_kwargs=None)
+    out = format_prompt("hi", tokenizer=tok, disable_thinking=True)
+    assert out == "templated"
+
+
+def test_format_prompt_default_does_not_pass_thinking_kwargs() -> None:
+    tok = _make_fake_tokenizer(accepted_kwargs={"enable_thinking"})
+    out = format_prompt("hi", tokenizer=tok)
+    assert out == "templated"
+
+
+def test_format_prompt_no_chat_template_returns_text() -> None:
+    tok = MagicMock()
+    tok.chat_template = None
+    assert format_prompt("raw text", tokenizer=tok, disable_thinking=True) == "raw text"
+
+
+def test_strip_thinking_tokens_ollama_trigger() -> None:
+    # Ollama's gpt-oss uses <|think|> as a system-prompt trigger.
+    assert strip_thinking_tokens("<|think|>solve x") == "solve x"
+    assert strip_thinking_tokens("<|THINK|>hi") == "hi"
+
+
+def test_strip_thinking_tokens_deepseek_wrapper() -> None:
+    assert strip_thinking_tokens("<think>reasoning here</think>answer") == "answer"
+    # Cross-line reasoning traces.
+    assert strip_thinking_tokens("before<think>\nstep1\nstep2\n</think>after") == "beforeafter"
+    # Stray unbalanced tags are also removed.
+    assert strip_thinking_tokens("stray</think> tag") == "stray tag"
+
+
+def test_strip_thinking_tokens_leaves_normal_text_untouched() -> None:
+    assert strip_thinking_tokens("Hello, world!") == "Hello, world!"
+
+
+def test_format_prompt_strips_thinking_tokens_when_disabled() -> None:
+    tok = _make_fake_tokenizer(accepted_kwargs={"enable_thinking"})
+    format_prompt("<|think|>hi", tokenizer=tok, disable_thinking=True)
+    # The message forwarded to the template must not carry the trigger token.
+    args, _ = tok.apply_chat_template.call_args
+    assert args[0] == [{"role": "user", "content": "hi"}]
+
+
+def test_format_prompt_keeps_thinking_tokens_when_enabled() -> None:
+    tok = _make_fake_tokenizer(accepted_kwargs={"enable_thinking"})
+    format_prompt("<|think|>hi", tokenizer=tok, disable_thinking=False)
+    args, _ = tok.apply_chat_template.call_args
+    assert args[0] == [{"role": "user", "content": "<|think|>hi"}]
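A minimal usage sketch of the plumbing added in this patch, not part of the diff itself: it assumes a local Ollama server on the default port, and the model tag `qwen3:8b` is a placeholder. Only `BackendConfig`, `get_backend`, and `Prompt` come from the code above.

```python
# Sketch only: assumes `ollama serve` is running locally and that the
# placeholder tag "qwen3:8b" (or any other model) has been pulled.
import asyncio

from infer_check.backends.base import BackendConfig, get_backend
from infer_check.types import Prompt


async def main() -> None:
    config = BackendConfig(
        backend_type="openai-compat",
        model_id="qwen3:8b",                # placeholder Ollama tag
        base_url="http://127.0.0.1:11434",  # Ollama default; backend adds /v1 or /api paths
        disable_thinking=True,              # default: Ollama is routed via /api/chat with think=False
        extra={"chat": True},
    )
    backend = get_backend(config)
    result = await backend.generate(Prompt(id="demo", text="What is 2+2?", max_tokens=32))
    print(result.text)


if __name__ == "__main__":
    asyncio.run(main())
```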