1 change: 1 addition & 0 deletions lmms_eval/models/__init__.py
@@ -114,6 +114,7 @@

AVAILABLE_CHAT_TEMPLATE_MODELS = {
"bagel_lmms_engine": "BagelLmmsEngine",
"ollama": "Ollama",
"fastvideo": "FastVideo",
"internvl_hf": "InternVLHf",
"llava_hf": "LlavaHf",
64 changes: 64 additions & 0 deletions lmms_eval/models/chat/ollama.py
@@ -0,0 +1,64 @@
"""Ollama chat backend.

Ollama exposes an OpenAI-compatible API at http://localhost:11434/v1, so this
backend inherits the OpenAI chat implementation for generation.

Example usage::

python -m lmms_eval \\
--model ollama \\
--model_args model_version=llava \\
--tasks mme --limit 8
"""

from __future__ import annotations

from typing import Any, List, Optional, Tuple

from lmms_eval.api.instance import Instance
from lmms_eval.api.registry import register_model
from lmms_eval.models.chat.openai import OpenAICompatible as OpenAICompatibleChatBase

_OLLAMA_DEFAULT_BASE_URL = "http://localhost:11434/v1"
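# Placeholder key: the OpenAI client requires a non-empty api_key, which Ollama ignores.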
_OLLAMA_NO_KEY = "ollama"


@register_model("ollama")
class Ollama(OpenAICompatibleChatBase):
"""Ollama local inference backend (OpenAI-compatible /v1 API)."""

is_simple = False

def __init__(
self,
model_version: str = "llava",
model: Optional[str] = None,
host: str = _OLLAMA_DEFAULT_BASE_URL,
base_url: Optional[str] = None,
api_key: str = _OLLAMA_NO_KEY,
num_concurrent: int = 4,
**kwargs: Any,
) -> None:
resolved_base_url = base_url or host
        # Keep the Ollama native API root (without /v1) in case callers need the native /api endpoints.
self._ollama_api_base = resolved_base_url.rstrip("/").removesuffix("/v1")
super().__init__(
model_version=model_version,
model=model,
base_url=resolved_base_url,
api_key=api_key,
num_concurrent=num_concurrent,
azure_openai=False,
**kwargs,
)

def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
"""Ollama does not expose prompt-token log-likelihoods.

Ollama's native ``POST /api/generate`` can return ``logprobs`` for
generated tokens, but lmms-eval's loglikelihood API needs the
likelihood of a provided continuation under a fixed context. Returning
a fabricated score would make multiple-choice likelihood tasks look
valid while producing misleading metrics.
"""
raise NotImplementedError("Ollama loglikelihood is not supported; use generate_until tasks instead.")
90 changes: 90 additions & 0 deletions test/models/test_ollama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Unit tests for the Ollama backend."""

from __future__ import annotations

import sys
import types
import unittest
from types import SimpleNamespace
from unittest import mock


def _ensure_decord_stub() -> None:
"""Register a fake decord module so optional_import resolves without the package."""
if "decord" not in sys.modules:
mod = types.ModuleType("decord")
mod.VideoReader = mock.MagicMock()
mod.cpu = mock.MagicMock()
sys.modules["decord"] = mod


def _fake_accelerator() -> mock.MagicMock:
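    # Mimic a single-process Accelerator so model construction skips any distributed setup.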
acc = mock.MagicMock()
acc.num_processes = 1
acc.local_process_index = 0
acc.device = "cpu"
return acc


def _make_ollama(model_version: str = "llava", **kwargs):
_ensure_decord_stub()
from lmms_eval.models.chat.ollama import Ollama
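    # The Ollama backend builds on the OpenAI-compatible client; patch that client and
    # Accelerator below so constructing the model needs no live server or GPU.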

with mock.patch("lmms_eval.models.simple.openai.OpenAI"), mock.patch("lmms_eval.models.simple.openai.Accelerator", return_value=_fake_accelerator()):
return Ollama(model_version=model_version, **kwargs)


def _make_instance(context: str, continuation: str) -> SimpleNamespace:
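    # Minimal stand-in for lmms_eval's Instance: just the (context, continuation) args and a rank.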
return SimpleNamespace(args=(context, continuation), rank=0)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestOllamaRegistration(unittest.TestCase):
def test_registered_as_chat_model(self) -> None:
from lmms_eval.models import MODEL_REGISTRY_V2

manifest = MODEL_REGISTRY_V2.get_manifest("ollama")
self.assertEqual(manifest.model_id, "ollama")
self.assertEqual(manifest.chat_class_path, "lmms_eval.models.chat.ollama.Ollama")
self.assertIsNone(manifest.simple_class_path)

def test_is_simple_false(self) -> None:
_ensure_decord_stub()
from lmms_eval.models.chat.ollama import Ollama

self.assertFalse(Ollama.is_simple)


class TestOllamaInit(unittest.TestCase):
def test_default_base_url(self) -> None:
m = _make_ollama()
self.assertEqual(m._ollama_api_base, "http://localhost:11434")

def test_custom_host_strips_v1(self) -> None:
m = _make_ollama(host="http://myserver:11434/v1")
self.assertEqual(m._ollama_api_base, "http://myserver:11434")

def test_model_version_stored(self) -> None:
m = _make_ollama(model_version="mistral")
self.assertEqual(m.model_version, "mistral")

def test_num_concurrent_default(self) -> None:
m = _make_ollama()
self.assertEqual(m.num_concurrent, 4)


class TestOllamaLoglikelihood(unittest.TestCase):
def test_loglikelihood_is_explicitly_unsupported(self) -> None:
model = _make_ollama()
instance = _make_instance("The sky is", " blue")

with self.assertRaisesRegex(NotImplementedError, "generate_until"):
model.loglikelihood([instance])


if __name__ == "__main__":
unittest.main()