4 changes: 2 additions & 2 deletions docs/backends.md
@@ -9,7 +9,7 @@
| **mlx-lm** | In-process | (local) | Local Apple Silicon inference with logprobs |
| **llama-cpp** | HTTP | `http://127.0.0.1:8080` | llama-server via `/completion` endpoint |
| **vllm-mlx** | HTTP | `http://127.0.0.1:8000` | Continuous batching on Apple Silicon |
- | **openai-compat** | HTTP | `http://127.0.0.1:11434/v1` | Any OpenAI-compatible server (vLLM, SGLang, Ollama) |
+ | **openai-compat** | HTTP | `http://127.0.0.1:11434` | Any OpenAI-compatible server (vLLM, SGLang, Ollama) |

## mlx-lm

@@ -109,7 +109,7 @@ Generic backend for any server that implements the OpenAI API format. Works with

| Model source | Default URL |
|-------------|-------------|
- | Ollama tags (e.g., `llama3.1:8b`) | `http://127.0.0.1:11434/v1` |
+ | Ollama tags (e.g., `llama3.1:8b`) | `http://127.0.0.1:11434` |
| Custom server | Use `--base-url` |

**Example with Ollama**:
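The Ollama example itself is collapsed out of this hunk. As a hedged stand-in, here is a minimal Python-level sketch of what the new default implies, using the names added in `src/infer_check/backends/base.py` below; the model tag is illustrative, and both `Prompt`'s constructor and the async `generate()` interface are assumptions based on the call sites in this PR:

```python
import asyncio

from infer_check.backends.base import BackendConfig, get_backend
from infer_check.types import Prompt

# openai-compat still requires an explicit base_url; the Ollama default documented
# above is now the server root, with no /v1 suffix.
config = BackendConfig(
    backend_type="openai-compat",
    model_id="llama3.1:8b",              # illustrative Ollama tag from the docs table
    base_url="http://127.0.0.1:11434",   # was http://127.0.0.1:11434/v1 before this PR
)
backend = get_backend(config)

# Assumes the HTTP backends share LlamaCppBackend's async generate() interface.
result = asyncio.run(backend.generate(Prompt(text="Hello", max_tokens=32)))
print(result)
```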
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -37,9 +37,12 @@ dev = [
"pytest-cov>=7.0.0",
"ruff>=0.15.5",
]
all = ["infer-check[mlx]", "infer-check[analysis]"]
all = ["infer-check[mlx]", "infer-check[analysis]", "infer-check[http]"]
mlx = ["mlx>=0.31.0", "mlx-lm<0.31.0"]
analysis = ["matplotlib>=3.10.8", "scikit-learn>=1.8.0"]
+ http = [
+     "transformers>=5.3.0",
+ ]

[project.scripts]
infer-check = "infer_check.cli:main"
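The new `http` extra pins `transformers`, presumably so the HTTP backends can apply chat templates client-side (see the `format_prompt` call sites below). A bare `pip install infer-check` will not pull it in; `pip install "infer-check[http]"` or `"infer-check[all]"` should.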
26 changes: 22 additions & 4 deletions src/infer_check/backends/base.py
@@ -41,7 +41,9 @@ class BackendConfig(BaseModel):
backend_type: Literal["mlx-lm", "llama-cpp", "vllm-mlx", "openai-compat"]
model_id: str
quantization: str | None = None
+ hf_revision: str | None = None
base_url: str | None = None
+ disable_thinking: bool = True
extra: dict[str, Any] = Field(default_factory=dict)


@@ -53,33 +55,44 @@ def get_backend(config: BackendConfig) -> BackendAdapter:
return MLXBackend(
model_id=config.model_id,
quantization=config.quantization,
+ revision=config.hf_revision,
+ disable_thinking=config.disable_thinking,
)
elif config.backend_type == "llama-cpp":
from infer_check.backends.llama_cpp import LlamaCppBackend

url = config.base_url or "http://127.0.0.1:8080"
- return LlamaCppBackend(model_id=config.model_id, base_url=url)
+ return LlamaCppBackend(
+     model_id=config.model_id,
+     base_url=url,
+     revision=config.hf_revision,
+     disable_thinking=config.disable_thinking,
+ )
elif config.backend_type == "vllm-mlx":
from infer_check.backends.vllm_mlx import VLLMMLXBackend

url = config.base_url or "http://127.0.0.1:8000"
return VLLMMLXBackend(
model_id=config.model_id,
base_url=url,
chat=config.extra.get("chat", False),
chat=config.extra.get("chat", True),
+ revision=config.hf_revision,
+ disable_thinking=config.disable_thinking,
)
elif config.backend_type == "openai-compat":
from infer_check.backends.openai_compat import OpenAICompatBackend

if not config.base_url:
raise ValueError(
"openai-compat backend requires --base-url. Example: --base-url http://127.0.0.1:11434/v1 (Ollama)"
"openai-compat backend requires --base-url. Example: --base-url http://127.0.0.1:11434 (Ollama)"
)
return OpenAICompatBackend(
base_url=config.base_url,
model_id=config.model_id,
api_key=config.extra.get("api_key"),
chat=config.extra.get("chat", False),
chat=config.extra.get("chat", True),
+ revision=config.hf_revision,
+ disable_thinking=config.disable_thinking,
)
else:
supported = ", ".join(["mlx-lm", "llama-cpp", "vllm-mlx", "openai-compat"])
@@ -91,6 +104,8 @@ def get_backend_for_model(
backend_type: str | None = None,
base_url: str | None = None,
quantization: str | None = None,
+ disable_thinking: bool = True,
+ chat: bool = True,
) -> BackendAdapter:
"""Resolve model string to a backend and instantiate it.

@@ -106,6 +121,9 @@
model_id=resolved.model_id,
base_url=base_url or resolved.base_url,
quantization=quantization or resolved.label,
+ hf_revision=resolved.revision,
+ disable_thinking=disable_thinking,
+ extra={"chat": chat},
)

return get_backend(config)
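A hedged sketch of the convenience path with the two new keyword arguments. The first positional parameter of `get_backend_for_model` is collapsed out of this hunk, so the model-spec string below, and the model id itself, are assumptions:

```python
from infer_check.backends.base import get_backend_for_model

backend = get_backend_for_model(
    "mlx-community/Qwen2.5-7B-Instruct-4bit",  # assumed model-spec positional argument
    backend_type="mlx-lm",
    quantization="4bit",
    disable_thinking=True,  # new kwarg, stored as BackendConfig.disable_thinking
    chat=True,              # new kwarg, forwarded to the HTTP backends via extra={"chat": chat}
)
```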
19 changes: 17 additions & 2 deletions src/infer_check/backends/llama_cpp.py
@@ -9,6 +9,7 @@
import httpx

from infer_check.types import InferenceResult, Prompt
+ from infer_check.utils import format_prompt

__all__ = ["LlamaCppBackend"]

@@ -19,9 +20,17 @@ class LlamaCppBackend:
Communicates via the ``/completion`` endpoint.
"""

- def __init__(self, model_id: str, base_url: str = "http://127.0.0.1:8080") -> None:
+ def __init__(
+     self,
+     model_id: str,
+     base_url: str = "http://127.0.0.1:8080",
+     revision: str | None = None,
+     disable_thinking: bool = True,
+ ) -> None:
self._model_id = model_id
self._base_url = base_url.rstrip("/")
+ self._revision = revision
+ self._disable_thinking = disable_thinking
self._client = httpx.AsyncClient(base_url=self._base_url, timeout=120.0)

# ------------------------------------------------------------------
@@ -34,9 +43,15 @@ def name(self) -> str:

async def generate(self, prompt: Prompt) -> InferenceResult:
"""Send a completion request and parse the response."""
+ formatted = format_prompt(
+     prompt.text,
+     model_id=self._model_id,
+     revision=self._revision,
+     disable_thinking=self._disable_thinking,
+ )
payload = {
"model": self._model_id,
"prompt": prompt.text,
"prompt": formatted,
"n_predict": prompt.max_tokens,
"temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0,
"n_probs": 10, # Request top 10 probabilities for KL divergence
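`format_prompt` lives in `infer_check.utils` and is not part of this diff. Judging from the two call signatures (`model_id=` here, `tokenizer=` in `mlx_lm.py`) and the new `transformers` pin in the `http` extra, it plausibly looks something like the sketch below; every detail of the body is an assumption about the real helper, including the `enable_thinking` template variable (Qwen-style templates use it, others simply ignore unused variables):

```python
from __future__ import annotations

from functools import lru_cache

from transformers import AutoTokenizer


@lru_cache(maxsize=8)
def _load_tokenizer(model_id: str, revision: str | None):
    # HTTP backends have no local tokenizer, so one is fetched from the Hub.
    return AutoTokenizer.from_pretrained(model_id, revision=revision)


def format_prompt(
    text: str,
    *,
    model_id: str | None = None,
    tokenizer=None,
    revision: str | None = None,
    disable_thinking: bool = True,
) -> str:
    """Apply the model's chat template when one exists; otherwise return text unchanged."""
    if tokenizer is None and model_id is not None:
        tokenizer = _load_tokenizer(model_id, revision)
    if tokenizer is None or getattr(tokenizer, "chat_template", None) is None:
        return text
    kwargs = {"tokenize": False, "add_generation_prompt": True}
    if disable_thinking:
        kwargs["enable_thinking"] = False  # extra kwargs are exposed to the Jinja template
    messages = [{"role": "user", "content": text}]
    return str(tokenizer.apply_chat_template(messages, **kwargs))
```

If this guess is close, pinning the same `revision` on every backend keeps the chat template identical across the comparison, which appears to be the point of threading `hf_revision` through `BackendConfig`.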
38 changes: 23 additions & 15 deletions src/infer_check/backends/mlx_lm.py
@@ -8,6 +8,7 @@
from typing import Any, cast

from infer_check.types import InferenceResult, Prompt
+ from infer_check.utils import format_prompt

__all__ = ["MLXBackend"]

@@ -19,9 +20,17 @@ class MLXBackend:
importing this module alone never triggers a heavy download.
"""

- def __init__(self, model_id: str, quantization: str | None = None) -> None:
+ def __init__(
+     self,
+     model_id: str,
+     quantization: str | None = None,
+     revision: str | None = None,
+     disable_thinking: bool = True,
+ ) -> None:
self._model_id = model_id
self._quantization = quantization
+ self._revision = revision
+ self._disable_thinking = disable_thinking
self._model: Any = None
self._tokenizer: Any = None

@@ -112,7 +121,7 @@ def _ensure_loaded(self) -> None:

repo_or_path = str(model_path) if model_path.exists() else self._model_id
try:
- res = load(repo_or_path)
+ res = load(repo_or_path, revision=self._revision)
except Exception as exc:
msg = str(exc)
if "404" in msg or "Repository Not Found" in msg:
@@ -132,25 +141,19 @@ def _ensure_loaded(self) -> None:
self._model = res[0]
self._tokenizer = res[1]

- def _format_prompt(self, text: str) -> str:
-     """Apply chat template if the tokenizer has one (Instruct models).
-
-     Raw prompts sent to Instruct models produce undefined behavior that
-     varies across quantization levels, making comparisons meaningless.
-     """
-     if hasattr(self._tokenizer, "apply_chat_template") and self._tokenizer.chat_template is not None:
-         messages = [{"role": "user", "content": text}]
-         return str(self._tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
-     return text

def _generate_simple(self, prompt: Prompt) -> InferenceResult:
"""Generate using the high-level ``mlx_lm.generate`` API."""
from mlx_lm import generate as mlx_generate
from mlx_lm.sample_utils import make_sampler

temp = prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0
sampler = make_sampler(temp=temp)
- formatted = self._format_prompt(prompt.text)
+ formatted = format_prompt(
+     prompt.text,
+     tokenizer=self._tokenizer,
+     revision=self._revision,
+     disable_thinking=self._disable_thinking,
+ )
start = time.perf_counter()
text: str = mlx_generate(
self._model,
@@ -184,7 +187,12 @@ def _generate_with_logprobs(self, prompt: Prompt) -> InferenceResult:

temp = prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0
sampler = make_sampler(temp=temp)
- formatted = self._format_prompt(prompt.text)
+ formatted = format_prompt(
+     prompt.text,
+     tokenizer=self._tokenizer,
+     revision=self._revision,
+     disable_thinking=self._disable_thinking,
+ )
input_ids = mx.array(self._tokenizer.encode(formatted))

# Configurable top-K to avoid memory explosion. Default to 10.
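Direct construction picks up the same parameters; a minimal hedged sketch with a placeholder model id and revision:

```python
from infer_check.backends.mlx_lm import MLXBackend

backend = MLXBackend(
    model_id="mlx-community/Qwen2.5-7B-Instruct-4bit",  # placeholder
    quantization="4bit",
    revision="main",        # now forwarded to mlx_lm.load() on first load
    disable_thinking=True,  # forwarded to format_prompt() before generation
)
```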