Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions mellea/backends/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,10 @@ class LocalHFBackend(FormatterBackend, AdapterMixin):
Mellea `ModelOption` sentinel keys.
from_mellea_model_opts_map (dict): Mapping from Mellea sentinel keys to
HF-specific option names.

Raises:
OSError: If the model cannot be loaded from HuggingFace Hub (bad ID,
missing access, or local filesystem/cache error).
"""

_cached_blocks: dict[str, DynamicCache] = dict()
Expand Down Expand Up @@ -327,12 +331,21 @@ def __init__(
else "cpu"
)
# Get the model and tokenizer.
self._model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
self._model_id, device_map=str(self._device)
)
self._tokenizer: PreTrainedTokenizerBase = (
AutoTokenizer.from_pretrained(self._model_id)
)
try:
self._model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
self._model_id, device_map=str(self._device)
)
self._tokenizer: PreTrainedTokenizerBase = (
AutoTokenizer.from_pretrained(self._model_id)
)
except OSError as e:
Comment thread
AngeloDanducci marked this conversation as resolved.
raise OSError(
f"Model '{self._model_id}' could not be loaded from HuggingFace Hub. "
"Check that the model ID is correct and that you have access to it. "
"If the model is gated, set the HF_TOKEN environment variable. "
"To browse available models, visit https://huggingface.co/models "
f"(Original error: {e})"
) from e
Comment thread
AngeloDanducci marked this conversation as resolved.
case _:
self._tokenizer, self._model, self._device = custom_config

Expand Down
2 changes: 1 addition & 1 deletion mellea/backends/model_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ class ModelIdentifier:
)

META_LLAMA_GUARD3_1B = ModelIdentifier(
ollama_name="llama-guard3:1b", hf_model_name="unsloth/Llama-Guard-3-1B"
ollama_name="llama-guard3:1b", hf_model_name="meta-llama/Llama-Guard-3-1B"
)

META_LLAMA_3_2_1B = ModelIdentifier(
Expand Down
26 changes: 19 additions & 7 deletions mellea/backends/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class OllamaModelBackend(FormatterBackend):
to Mellea ``ModelOption`` sentinel keys.
from_mellea_model_opts_map (dict): Mapping from Mellea ``ModelOption``
sentinel keys to Ollama-specific option names.

Raises:
ValueError: If ``model_id`` is a ``ModelIdentifier`` with no ``ollama_name`` set.
ConnectionError: If the Ollama server is not running at ``base_url``.
OSError: If the model cannot be pulled from the Ollama library.
"""

def __init__(
Expand All @@ -98,13 +103,16 @@ def __init__(
),
model_options=model_options,
)
# Resolve to a concrete ollama model name; assertion fires if no ollama_name.
# Resolve to a concrete ollama model name; raises ValueError if no ollama_name is set.
ollama_model_id = (
model_id.ollama_name if isinstance(model_id, ModelIdentifier) else model_id
)
assert ollama_model_id is not None, (
"model_id is None: the ModelIdentifier has no ollama_name configured, or this model is not available in ollama."
)
if ollama_model_id is None or ollama_model_id == "":
Comment thread
AngeloDanducci marked this conversation as resolved.
raise ValueError(
Comment thread
AngeloDanducci marked this conversation as resolved.
"Cannot create OllamaModelBackend: the ModelIdentifier has no ollama_name set. "
"Check mellea/backends/model_ids.py and ensure the constant you are using "
"has an ollama_name value, or pass the Ollama model tag as a plain string."
)
self._model_id: str = ollama_model_id
self._provider: str = "ollama"

Expand All @@ -125,11 +133,15 @@ def __init__(
if not self._check_ollama_server():
err = f"could not create OllamaModelBackend: ollama server not running at {base_url}"
MelleaLogger.get_logger().error(err)
raise Exception(err)
raise ConnectionError(err)
if not self._pull_ollama_model():
err = f"could not create OllamaModelBackend: {self._model_id} could not be pulled from ollama library"
err = (
f"Model '{self._model_id}' could not be pulled from the Ollama library. "
f"Check that the model name is correct (run 'ollama list' to see locally "
f"available models, or 'ollama pull {self._model_id}' to fetch it manually)."
)
MelleaLogger.get_logger().error(err)
raise Exception(err)
raise OSError(err)

# A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent.
# These are usually values that must be extracted beforehand or that are common among backend providers.
Expand Down
109 changes: 109 additions & 0 deletions test/backends/test_model_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Integration tests verifying that model IDs in model_ids.py resolve to real models.

HuggingFace tests check the Hub API (no model download required).
Ollama tests require a running Ollama server and are skipped otherwise.
"""

import inspect

import pytest

import mellea.backends.model_ids as model_ids
from mellea.backends.model_ids import ModelIdentifier

# Every module-level attribute of model_ids that is a ModelIdentifier, as
# (constant_name, identifier) pairs. inspect.getmembers returns them sorted
# by name, which keeps the parametrized test ids in a stable order.
_ALL_IDS: list[tuple[str, ModelIdentifier]] = inspect.getmembers(
    model_ids, lambda member: isinstance(member, ModelIdentifier)
)

# (constant_name, hf_model_name) for identifiers with a truthy HuggingFace name.
_HF_IDS = [
    (const, ident.hf_model_name) for const, ident in _ALL_IDS if ident.hf_model_name
]

# (constant_name, ollama_name) for identifiers with a truthy Ollama name;
# the truthiness test drops both None and "".
_OLLAMA_IDS = [
    (const, ident.ollama_name) for const, ident in _ALL_IDS if ident.ollama_name
]


@pytest.mark.integration
@pytest.mark.slow
@pytest.mark.parametrize(
    "const_name,hf_name", _HF_IDS, ids=[pair[0] for pair in _HF_IDS]
)
def test_hf_model_names_exist(const_name: str, hf_name: str) -> None:
    """Check that an ``hf_model_name`` from model_ids.py resolves on the HuggingFace Hub.

    The Hub is queried anonymously (``token=False``) through ``model_info``, so
    nothing is downloaded. The test is skipped when ``huggingface_hub`` is not
    installed.

    Args:
        const_name: Name of the ``ModelIdentifier`` constant being checked.
        hf_name: The constant's ``hf_model_name`` value.
    """
    pytest.importorskip("huggingface_hub", reason="huggingface_hub not installed")
    from huggingface_hub import model_info
    from huggingface_hub.errors import GatedRepoError, RepositoryNotFoundError

    # With token=False the Hub reports gated repos as RepositoryNotFoundError so
    # their existence is not leaked. An anonymous "not found" is therefore
    # inconclusive: skip instead of failing so CI without credentials does not
    # false-fail; someone with HF_TOKEN can verify the ID by hand.
    inconclusive_reason = (
        f"{const_name}.hf_model_name={hf_name!r} not found on HuggingFace Hub "
        "(may be gated — re-run with HF_TOKEN to confirm it exists)."
    )

    try:
        model_info(hf_name, token=False)
    except GatedRepoError:
        # The repo exists but needs authentication; existence is all we check.
        return
    except RepositoryNotFoundError:
        pytest.skip(inconclusive_reason)
Comment thread
AngeloDanducci marked this conversation as resolved.


def _ollama_registry_manifest_url(ollama_name: str) -> str:
    """Build the Ollama registry manifest URL for a model reference.

    Args:
        ollama_name: Model reference in ``[namespace/]model[:tag]`` form.

    Returns:
        The ``/v2/<repository>/manifests/<tag>`` URL on registry.ollama.ai. A
        missing or empty tag becomes ``latest``; a model with no namespace is
        placed under ``library/``, the namespace of official Ollama models.
    """
    repository, _, tag = ollama_name.partition(":")
    if "/" not in repository:
        # Official models ("llama3.2") are published as "library/llama3.2".
        repository = f"library/{repository}"
    return f"https://registry.ollama.ai/v2/{repository}/manifests/{tag or 'latest'}"


@pytest.mark.integration
@pytest.mark.slow
@pytest.mark.ollama
@pytest.mark.parametrize(
    "const_name,ollama_name", _OLLAMA_IDS, ids=[n for n, _ in _OLLAMA_IDS]
)
def test_ollama_model_names_exist(const_name: str, ollama_name: str) -> None:
    """Every ollama_name in model_ids.py must exist in the Ollama library.

    The test first checks whether the model is already present locally via
    show(). If it is not, it sends a HEAD request to the Ollama registry
    manifest endpoint directly — no pull is initiated, so no data is
    downloaded.

    The test is skipped (not failed) when the Ollama server is unreachable or
    returns an unexpected error, and when the registry cannot be reached or
    answers with an unexpected status, since those conditions reflect
    environment problems rather than a bad model ID. It fails only when the
    registry answers 404 or 401 for the manifest.

    Args:
        const_name: Name of the ``ModelIdentifier`` constant being checked.
        ollama_name: The constant's ``ollama_name`` value.
    """
    import urllib.error
    import urllib.request

    import ollama as ollama_sdk

    client = ollama_sdk.Client()
    try:
        client.show(ollama_name)
        # Model is present locally — name is valid.
        return
    except ollama_sdk.ResponseError as e:
        detail = str(e)
        if "not found" not in detail.lower() and "404" not in detail:
            pytest.skip(
                f"Ollama server returned an unexpected error for show({ollama_name!r}): {e}"
            )
        # Model not cached locally; fall through to registry check.
    except ConnectionError as e:
        pytest.skip(f"Ollama server unreachable: {e}")

    # Check the Ollama registry manifest without downloading anything.
    registry_url = _ollama_registry_manifest_url(ollama_name)
    req = urllib.request.Request(registry_url, method="HEAD")
    try:
        # Only the status matters; the context manager closes the connection.
        with urllib.request.urlopen(req, timeout=10):
            pass
    except urllib.error.HTTPError as e:
        if e.code in (404, 401):
            pytest.fail(
                f"{const_name}.ollama_name={ollama_name!r} does not exist in the "
                "Ollama library. Update or remove the model ID in "
                "mellea/backends/model_ids.py."
            )
        else:
            pytest.skip(
                f"Ollama registry returned unexpected HTTP {e.code} for {ollama_name!r}"
            )
    except OSError as e:
        pytest.skip(f"Could not reach Ollama registry: {e}")
Loading