Commit bf8cd7d

Enhance backend configurations with revision support and chat options (#16)
* fix: update Ollama backend default URL to remove the /v1 suffix and set the chat mode default to True

* feat: add revision support and unified chat template formatting for all backends
  - Add `hf_revision`/`revision` parameter to model config, resolve, and backend constructors
  - Parse and propagate revision from model spec (e.g. `repo/model@main`); a sketch of this parsing follows the summary below
  - Implement `format_prompt` utility to apply chat templates using the tokenizer or HuggingFace model
  - Use `format_prompt` in mlx-lm, llama-cpp, and openai-compat backends for consistent prompt formatting
  - Add `transformers` as a required dependency

* feat: add disable_thinking option to suppress reasoning mode across backends
  - Introduce `disable_thinking` flag in all backend configs and the CLI, defaulting to True for consistent output comparison
  - Implement cross-backend support for disabling reasoning/thinking mode (Qwen3, DeepSeek-R1, Ollama, vLLM, OpenAI/OpenRouter)
  - Strip reasoning-trigger tokens from prompts when disabled
  - Route to Ollama's native /api/chat endpoint with `think: false` when appropriate
  - Add tests for prompt formatting, payload construction, and retry logic when disabling thinking

* refactor: move transformers dependency to optional http extra and cache tokenizer loading

* feat: add revision support for vllm-mlx backend and propagate resolved revision in diff command

* feat: propagate hf_revision to diff command and clarify Ollama chat handling comments

* feat: add --chat option to CLI and propagate chat mode to backend configuration

* fix: refine model spec revision parsing and tighten HF repo detection for tokenizer loading

* fix: pass revision to model loader in mlx backend
1 parent 3289264 commit bf8cd7d

16 files changed

Lines changed: 747 additions & 70 deletions
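
The commit message mentions parsing a revision out of the model spec (e.g. `repo/model@main`). The resolver itself is not part of this excerpt, so the following is only a minimal sketch of how such parsing could look; the function name `split_model_spec` and the rule that an `@` only counts when the prefix looks like a HuggingFace repo are assumptions, not the repository's actual implementation.

def split_model_spec(spec: str) -> tuple[str, str | None]:
    """Hypothetical helper: split "repo/model@revision" into (model_id, revision).

    Assumption: a revision is only recognised after the final "@", and only
    when the part before it looks like a HuggingFace repo (contains "/"),
    so Ollama-style tags such as "llama3.1:8b" pass through untouched.
    """
    if "@" in spec:
        model_id, _, revision = spec.rpartition("@")
        if "/" in model_id and revision:
            return model_id, revision
    return spec, None


# Illustrative checks only:
assert split_model_spec("Qwen/Qwen3-8B@main") == ("Qwen/Qwen3-8B", "main")
assert split_model_spec("llama3.1:8b") == ("llama3.1:8b", None)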

docs/backends.md

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 | **mlx-lm** | In-process | (local) | Local Apple Silicon inference with logprobs |
 | **llama-cpp** | HTTP | `http://127.0.0.1:8080` | llama-server via `/completion` endpoint |
 | **vllm-mlx** | HTTP | `http://127.0.0.1:8000` | Continuous batching on Apple Silicon |
-| **openai-compat** | HTTP | `http://127.0.0.1:11434/v1` | Any OpenAI-compatible server (vLLM, SGLang, Ollama) |
+| **openai-compat** | HTTP | `http://127.0.0.1:11434` | Any OpenAI-compatible server (vLLM, SGLang, Ollama) |

 ## mlx-lm

@@ -109,7 +109,7 @@ Generic backend for any server that implements the OpenAI API format. Works with

 | Model source | Default URL |
 |-------------|-------------|
-| Ollama tags (e.g., `llama3.1:8b`) | `http://127.0.0.1:11434/v1` |
+| Ollama tags (e.g., `llama3.1:8b`) | `http://127.0.0.1:11434` |
 | Custom server | Use `--base-url` |

 **Example with Ollama**:
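
The default Ollama URL drops the `/v1` suffix because, per the commit message, requests can be routed to Ollama's native `/api/chat` endpoint with `think: false` when thinking is disabled. Below is a rough sketch of such a request; the payload the backend actually builds is not visible in this excerpt, so treat the field choices as an illustration of the Ollama chat API rather than the project's code.

import httpx

# Illustrative only: a non-thinking chat request to Ollama's native API.
payload = {
    "model": "qwen3:8b",
    "messages": [{"role": "user", "content": "Explain KV caching in one paragraph."}],
    "think": False,          # suppress the model's reasoning/thinking block
    "stream": False,
    "options": {"temperature": 0.0},
}
response = httpx.post("http://127.0.0.1:11434/api/chat", json=payload, timeout=120.0)
print(response.json()["message"]["content"])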

pyproject.toml

Lines changed: 4 additions & 1 deletion
@@ -37,9 +37,12 @@ dev = [
     "pytest-cov>=7.0.0",
     "ruff>=0.15.5",
 ]
-all = ["infer-check[mlx]", "infer-check[analysis]"]
+all = ["infer-check[mlx]", "infer-check[analysis]", "infer-check[http]"]
 mlx = ["mlx>=0.31.0", "mlx-lm<0.31.0"]
 analysis = ["matplotlib>=3.10.8", "scikit-learn>=1.8.0"]
+http = [
+    "transformers>=5.3.0",
+]

 [project.scripts]
 infer-check = "infer_check.cli:main"
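
`transformers` lives in an optional `http` extra because the HTTP backends only need it to apply chat templates in `format_prompt`, and the commit message says tokenizer loading is cached. `src/infer_check/utils.py` is not part of this excerpt, so the sketch below only shows one plausible shape for that utility; the caching decorator, the `enable_thinking` template kwarg, and the raw-prompt fallback are assumptions, not the repository's code.

from functools import lru_cache


@lru_cache(maxsize=8)
def _load_tokenizer(model_id: str, revision: str | None):
    # Cached so repeated prompts against the same model reuse one tokenizer.
    from transformers import AutoTokenizer  # provided by the optional "http" extra

    return AutoTokenizer.from_pretrained(model_id, revision=revision)


def format_prompt(
    text: str,
    *,
    model_id: str | None = None,
    tokenizer=None,
    revision: str | None = None,
    disable_thinking: bool = True,
) -> str:
    """Sketch: apply a chat template via a passed-in or downloaded tokenizer."""
    if tokenizer is None and model_id is not None and "/" in model_id:
        # Only treat slash-containing ids as HF repos (Ollama tags are skipped).
        tokenizer = _load_tokenizer(model_id, revision)
    if tokenizer is None or getattr(tokenizer, "chat_template", None) is None:
        return text  # no template available: fall back to the raw prompt
    messages = [{"role": "user", "content": text}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=not disable_thinking,  # honoured by e.g. Qwen3 templates
    )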

src/infer_check/backends/base.py

Lines changed: 22 additions & 4 deletions
@@ -41,7 +41,9 @@ class BackendConfig(BaseModel):
     backend_type: Literal["mlx-lm", "llama-cpp", "vllm-mlx", "openai-compat"]
     model_id: str
     quantization: str | None = None
+    hf_revision: str | None = None
     base_url: str | None = None
+    disable_thinking: bool = True
     extra: dict[str, Any] = Field(default_factory=dict)


@@ -53,33 +55,44 @@ def get_backend(config: BackendConfig) -> BackendAdapter:
         return MLXBackend(
             model_id=config.model_id,
             quantization=config.quantization,
+            revision=config.hf_revision,
+            disable_thinking=config.disable_thinking,
         )
     elif config.backend_type == "llama-cpp":
         from infer_check.backends.llama_cpp import LlamaCppBackend

         url = config.base_url or "http://127.0.0.1:8080"
-        return LlamaCppBackend(model_id=config.model_id, base_url=url)
+        return LlamaCppBackend(
+            model_id=config.model_id,
+            base_url=url,
+            revision=config.hf_revision,
+            disable_thinking=config.disable_thinking,
+        )
     elif config.backend_type == "vllm-mlx":
         from infer_check.backends.vllm_mlx import VLLMMLXBackend

         url = config.base_url or "http://127.0.0.1:8000"
         return VLLMMLXBackend(
             model_id=config.model_id,
             base_url=url,
-            chat=config.extra.get("chat", False),
+            chat=config.extra.get("chat", True),
+            revision=config.hf_revision,
+            disable_thinking=config.disable_thinking,
         )
     elif config.backend_type == "openai-compat":
         from infer_check.backends.openai_compat import OpenAICompatBackend

         if not config.base_url:
             raise ValueError(
-                "openai-compat backend requires --base-url. Example: --base-url http://127.0.0.1:11434/v1 (Ollama)"
+                "openai-compat backend requires --base-url. Example: --base-url http://127.0.0.1:11434 (Ollama)"
             )
         return OpenAICompatBackend(
             base_url=config.base_url,
             model_id=config.model_id,
             api_key=config.extra.get("api_key"),
-            chat=config.extra.get("chat", False),
+            chat=config.extra.get("chat", True),
+            revision=config.hf_revision,
+            disable_thinking=config.disable_thinking,
         )
     else:
         supported = ", ".join(["mlx-lm", "llama-cpp", "vllm-mlx", "openai-compat"])
@@ -91,6 +104,8 @@ def get_backend_for_model(
     backend_type: str | None = None,
     base_url: str | None = None,
     quantization: str | None = None,
+    disable_thinking: bool = True,
+    chat: bool = True,
 ) -> BackendAdapter:
     """Resolve model string to a backend and instantiate it.

@@ -106,6 +121,9 @@ def get_backend_for_model(
         model_id=resolved.model_id,
         base_url=base_url or resolved.base_url,
         quantization=quantization or resolved.label,
+        hf_revision=resolved.revision,
+        disable_thinking=disable_thinking,
+        extra={"chat": chat},
     )

     return get_backend(config)
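
For illustration, here is how the new fields might be exercised when constructing a backend by hand. The field names come straight from the diff above; the model and server values are only example inputs.

from infer_check.backends.base import BackendConfig, get_backend

# Example values only; any OpenAI-compatible server works here.
config = BackendConfig(
    backend_type="openai-compat",
    model_id="llama3.1:8b",
    base_url="http://127.0.0.1:11434",
    hf_revision=None,           # new: pin a HuggingFace revision for tokenizer/model resolution
    disable_thinking=True,      # new: suppress reasoning mode for comparable outputs
    extra={"chat": True},       # chat mode now defaults to True across HTTP backends
)
backend = get_backend(config)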

src/infer_check/backends/llama_cpp.py

Lines changed: 17 additions & 2 deletions
@@ -9,6 +9,7 @@
 import httpx

 from infer_check.types import InferenceResult, Prompt
+from infer_check.utils import format_prompt

 __all__ = ["LlamaCppBackend"]

@@ -19,9 +20,17 @@ class LlamaCppBackend:
     Communicates via the ``/completion`` endpoint.
     """

-    def __init__(self, model_id: str, base_url: str = "http://127.0.0.1:8080") -> None:
+    def __init__(
+        self,
+        model_id: str,
+        base_url: str = "http://127.0.0.1:8080",
+        revision: str | None = None,
+        disable_thinking: bool = True,
+    ) -> None:
         self._model_id = model_id
         self._base_url = base_url.rstrip("/")
+        self._revision = revision
+        self._disable_thinking = disable_thinking
         self._client = httpx.AsyncClient(base_url=self._base_url, timeout=120.0)

     # ------------------------------------------------------------------
@@ -34,9 +43,15 @@ def name(self) -> str:

     async def generate(self, prompt: Prompt) -> InferenceResult:
         """Send a completion request and parse the response."""
+        formatted = format_prompt(
+            prompt.text,
+            model_id=self._model_id,
+            revision=self._revision,
+            disable_thinking=self._disable_thinking,
+        )
         payload = {
             "model": self._model_id,
-            "prompt": prompt.text,
+            "prompt": formatted,
             "n_predict": prompt.max_tokens,
             "temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0,
             "n_probs": 10,  # Request top 10 probabilities for KL divergence

src/infer_check/backends/mlx_lm.py

Lines changed: 23 additions & 15 deletions
@@ -8,6 +8,7 @@
 from typing import Any, cast

 from infer_check.types import InferenceResult, Prompt
+from infer_check.utils import format_prompt

 __all__ = ["MLXBackend"]

@@ -19,9 +20,17 @@ class MLXBackend:
     importing this module alone never triggers a heavy download.
     """

-    def __init__(self, model_id: str, quantization: str | None = None) -> None:
+    def __init__(
+        self,
+        model_id: str,
+        quantization: str | None = None,
+        revision: str | None = None,
+        disable_thinking: bool = True,
+    ) -> None:
         self._model_id = model_id
         self._quantization = quantization
+        self._revision = revision
+        self._disable_thinking = disable_thinking
         self._model: Any = None
         self._tokenizer: Any = None

@@ -112,7 +121,7 @@ def _ensure_loaded(self) -> None:

         repo_or_path = str(model_path) if model_path.exists() else self._model_id
         try:
-            res = load(repo_or_path)
+            res = load(repo_or_path, revision=self._revision)
         except Exception as exc:
             msg = str(exc)
             if "404" in msg or "Repository Not Found" in msg:
@@ -132,25 +141,19 @@ def _ensure_loaded(self) -> None:
         self._model = res[0]
         self._tokenizer = res[1]

-    def _format_prompt(self, text: str) -> str:
-        """Apply chat template if the tokenizer has one (Instruct models).
-
-        Raw prompts sent to Instruct models produce undefined behavior that
-        varies across quantization levels, making comparisons meaningless.
-        """
-        if hasattr(self._tokenizer, "apply_chat_template") and self._tokenizer.chat_template is not None:
-            messages = [{"role": "user", "content": text}]
-            return str(self._tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
-        return text
-
     def _generate_simple(self, prompt: Prompt) -> InferenceResult:
         """Generate using the high-level ``mlx_lm.generate`` API."""
         from mlx_lm import generate as mlx_generate
         from mlx_lm.sample_utils import make_sampler

         temp = prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0
         sampler = make_sampler(temp=temp)
-        formatted = self._format_prompt(prompt.text)
+        formatted = format_prompt(
+            prompt.text,
+            tokenizer=self._tokenizer,
+            revision=self._revision,
+            disable_thinking=self._disable_thinking,
+        )
         start = time.perf_counter()
         text: str = mlx_generate(
             self._model,
@@ -184,7 +187,12 @@ def _generate_with_logprobs(self, prompt: Prompt) -> InferenceResult:

         temp = prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0
         sampler = make_sampler(temp=temp)
-        formatted = self._format_prompt(prompt.text)
+        formatted = format_prompt(
+            prompt.text,
+            tokenizer=self._tokenizer,
+            revision=self._revision,
+            disable_thinking=self._disable_thinking,
+        )
         input_ids = mx.array(self._tokenizer.encode(formatted))

         # Configurable top-K to avoid memory explosion. Default to 10.
