diff --git a/lmms_eval/models/__init__.py b/lmms_eval/models/__init__.py
index a7c1f92e5..25d71a3b4 100644
--- a/lmms_eval/models/__init__.py
+++ b/lmms_eval/models/__init__.py
@@ -114,4 +114,5 @@ AVAILABLE_CHAT_TEMPLATE_MODELS = {
     "bagel_lmms_engine": "BagelLmmsEngine",
+    "ollama": "Ollama",
     "fastvideo": "FastVideo",
     "internvl_hf": "InternVLHf",
     "llava_hf": "LlavaHf",
diff --git a/lmms_eval/models/chat/ollama.py b/lmms_eval/models/chat/ollama.py
new file mode 100644
index 000000000..ae8ef86d3
--- /dev/null
+++ b/lmms_eval/models/chat/ollama.py
@@ -0,0 +1,72 @@
+"""Ollama chat backend.
+
+Ollama exposes an OpenAI-compatible API at http://localhost:11434/v1, so this
+backend inherits the OpenAI chat implementation for generation.
+
+Example usage::
+
+    python -m lmms_eval \\
+        --model ollama \\
+        --model_args model_version=llava \\
+        --tasks mme --limit 8
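+
+A remote Ollama server can be targeted through the ``host`` argument; the URL
+below is a placeholder, not a shipped default::
+
+    python -m lmms_eval \\
+        --model ollama \\
+        --model_args model_version=llava,host=http://myserver:11434/v1 \\
+        --tasks mme --limit 8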
+ """ + raise NotImplementedError("Ollama loglikelihood is not supported; use generate_until tasks instead.") diff --git a/test/models/test_ollama.py b/test/models/test_ollama.py new file mode 100644 index 000000000..c20d39ea1 --- /dev/null +++ b/test/models/test_ollama.py @@ -0,0 +1,90 @@ +"""Unit tests for the Ollama backend.""" + +from __future__ import annotations + +import sys +import types +import unittest +from types import SimpleNamespace +from unittest import mock + + +def _ensure_decord_stub() -> None: + """Register a fake decord module so optional_import resolves without the package.""" + if "decord" not in sys.modules: + mod = types.ModuleType("decord") + mod.VideoReader = mock.MagicMock() + mod.cpu = mock.MagicMock() + sys.modules["decord"] = mod + + +def _fake_accelerator() -> mock.MagicMock: + acc = mock.MagicMock() + acc.num_processes = 1 + acc.local_process_index = 0 + acc.device = "cpu" + return acc + + +def _make_ollama(model_version: str = "llava", **kwargs): + _ensure_decord_stub() + from lmms_eval.models.chat.ollama import Ollama + + with mock.patch("lmms_eval.models.simple.openai.OpenAI"), mock.patch("lmms_eval.models.simple.openai.Accelerator", return_value=_fake_accelerator()): + return Ollama(model_version=model_version, **kwargs) + + +def _make_instance(context: str, continuation: str) -> SimpleNamespace: + return SimpleNamespace(args=(context, continuation), rank=0) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestOllamaRegistration(unittest.TestCase): + def test_registered_as_chat_model(self) -> None: + from lmms_eval.models import MODEL_REGISTRY_V2 + + manifest = MODEL_REGISTRY_V2.get_manifest("ollama") + self.assertEqual(manifest.model_id, "ollama") + self.assertEqual(manifest.chat_class_path, "lmms_eval.models.chat.ollama.Ollama") + self.assertIsNone(manifest.simple_class_path) + + def test_is_simple_false(self) -> None: + _ensure_decord_stub() + from lmms_eval.models.chat.ollama import Ollama + + self.assertFalse(Ollama.is_simple) + + +class TestOllamaInit(unittest.TestCase): + def test_default_base_url(self) -> None: + m = _make_ollama() + self.assertEqual(m._ollama_api_base, "http://localhost:11434") + + def test_custom_host_strips_v1(self) -> None: + m = _make_ollama(host="http://myserver:11434/v1") + self.assertEqual(m._ollama_api_base, "http://myserver:11434") + + def test_model_version_stored(self) -> None: + m = _make_ollama(model_version="mistral") + self.assertEqual(m.model_version, "mistral") + + def test_num_concurrent_default(self) -> None: + m = _make_ollama() + self.assertEqual(m.num_concurrent, 4) + + +class TestOllamaLoglikelihood(unittest.TestCase): + def test_loglikelihood_is_explicitly_unsupported(self) -> None: + model = _make_ollama() + instance = _make_instance("The sky is", " blue") + + with self.assertRaisesRegex(NotImplementedError, "generate_until"): + model.loglikelihood([instance]) + + +if __name__ == "__main__": + unittest.main()