Skip to content

Commit c7d1bb3

Browse files
KennethEnevoldsenCopilot
andauthored
fix: ensure that get_model_metas raise an error if model name is incorrect (#4560)
* fix: ensure that get_model_metas raise an error if model name is incorrect Currently it will just silently drop the model. Discovered when trying to get "mteb/baseline-bm25". Added test and fixes the issue. This should in practice also speed up the function quite a bit. Co-authored-by: Copilot <copilot@github.com> * format * fix typecheck --------- Co-authored-by: Copilot <copilot@github.com>
1 parent bfb8fe7 commit c7d1bb3

2 files changed

Lines changed: 15 additions & 1 deletion

File tree

mteb/models/get_model_meta.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@ def get_model_metas( # noqa: PLR0913, PLR0917
5858
model_types_set = set(model_types) if model_types is not None else None
5959
modalities_set = set(modalities) if modalities is not None else None
6060

61-
for model_meta in MODEL_REGISTRY.values():
61+
model_metas: Iterable[ModelMeta] = MODEL_REGISTRY.values()
62+
if model_names is not None:
63+
model_metas = [get_model_meta(name) for name in model_names]
64+
65+
for model_meta in model_metas:
6266
if (model_names is not None) and (model_meta.name not in model_names):
6367
continue
6468
if languages is not None:

tests/test_models/test_model_meta.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,16 @@ def test_fill_missing_parameter():
316316
assert meta_with_compute.memory_usage_mb is not None
317317

318318

319+
def test_raise_on_invalid_model_name():
320+
"""Test that an error is raised for invalid model names."""
321+
with pytest.raises(KeyError):
322+
mteb.get_model_metas(
323+
["mteb/baseline-bm25"]
324+
) # invalid (but plausible user input)
325+
mdls = mteb.get_model_metas(["mteb/baseline-bm25s"]) # valid
326+
assert len(mdls) == 1
327+
328+
319329
@pytest.mark.parametrize(
320330
"model_meta",
321331
[m for m in mteb.get_model_metas(open_weights=True) if "text" in m.modalities],

0 commit comments

Comments
 (0)