From b3876c5ba1fb81f4eb06226aa7137de9d57d7265 Mon Sep 17 00:00:00 2001
From: Barbara Korycki
Date: Fri, 6 Mar 2026 16:10:39 -0800
Subject: [PATCH] basic llama factory

---
 src/modelgauge/sut_factory.py                 |  2 +
 src/modelgauge/suts/meta_llama_factory.py     | 60 +++++++++++++++++++
 .../sut_tests/test_meta_llama_factory.py      | 29 +++++++++
 3 files changed, 91 insertions(+)
 create mode 100644 src/modelgauge/suts/meta_llama_factory.py
 create mode 100644 tests/modelgauge_tests/sut_tests/test_meta_llama_factory.py

diff --git a/src/modelgauge/sut_factory.py b/src/modelgauge/sut_factory.py
index 2581cab0..99a7c03c 100644
--- a/src/modelgauge/sut_factory.py
+++ b/src/modelgauge/sut_factory.py
@@ -10,6 +10,7 @@
 from modelgauge.suts.google_sut_factory import GoogleSUTFactory
 from modelgauge.suts.huggingface_sut_factory import HuggingFaceSUTFactory
 from modelgauge.suts.indirect_sut import IndirectSUTFactory
+from modelgauge.suts.meta_llama_factory import LlamaSUTFactory
 from modelgauge.suts.mistral_sut_factory import MistralSUTFactory
 from modelgauge.suts.modelship_sut import ModelShipSUTFactory
 from modelgauge.suts.openai_sut_factory import OpenAICompatibleSUTFactory
@@ -36,6 +37,7 @@ class SUTType(Enum):
     "hf": HuggingFaceSUTFactory,
     "hfrelay": HuggingFaceSUTFactory,
     "indirect": IndirectSUTFactory,
+    "llama": LlamaSUTFactory,
     "openai": OpenAICompatibleSUTFactory,
     "mistral": MistralSUTFactory,
     "modelship": ModelShipSUTFactory,
diff --git a/src/modelgauge/suts/meta_llama_factory.py b/src/modelgauge/suts/meta_llama_factory.py
new file mode 100644
index 00000000..4bf6558b
--- /dev/null
+++ b/src/modelgauge/suts/meta_llama_factory.py
@@ -0,0 +1,60 @@
+from llama_api_client import LlamaAPIClient
+
+from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError
+from modelgauge.secret_values import InjectSecret, RawSecrets
+from modelgauge.sut import SUT
+from modelgauge.sut_definition import SUTDefinition
+from modelgauge.suts.meta_llama_client import MetaLlamaApiKey, MetaLlamaSUT
+
+
+class LlamaSUTFactory(DynamicSUTFactory):
+    """Builds MetaLlamaSUT instances for models listed by the Llama API."""
+
+    def __init__(self, raw_secrets: RawSecrets):
+        super().__init__(raw_secrets)
+        # Created lazily by the `client` property, so constructing the
+        # factory does not build an API client.
+        self._client = None
+
+    @property
+    def client(self):
+        """Return the LlamaAPIClient, creating it on first access."""
+        if self._client is None:
+            # The first injected secret is the API key; get_secrets() lists only MetaLlamaApiKey.
+            api_key = self.injected_secrets()[0].value
+            self._client = LlamaAPIClient(api_key=api_key)
+        return self._client
+
+    def get_secrets(self) -> list[InjectSecret]:
+        """Return the secrets this factory needs: the Meta Llama API key."""
+        api_key = InjectSecret(MetaLlamaApiKey)
+        return [api_key]
+
+    def _get_model_name(self, model) -> str | None:
+        """Return the id listed by the Llama API that matches `model`, or None.
+
+        Llama API model names are case sensitive, so the requested name is
+        compared case-insensitively against the listed model ids and the
+        listed id (with its original casing) is returned.
+        """
+        models = self.client.models.list()
+        for m in models:
+            if m.id.lower() == model.lower():
+                return m.id
+        return None
+
+    def make_sut(self, sut_definition: SUTDefinition) -> SUT:
+        """Create a MetaLlamaSUT for the model named by `sut_definition`.
+
+        Raises:
+            ModelNotSupportedError: if `_get_model_name` finds no listed
+                model matching the requested name.
+        """
+        requested_name = sut_definition.to_dynamic_sut_metadata().external_model_name()
+        # Keep the requested name separate from the lookup result so the error
+        # below reports what was asked for instead of "None".
+        model_name = self._get_model_name(requested_name)
+
+        if model_name is None:
+            raise ModelNotSupportedError(f"Model {requested_name} not found or not available via Llama API.")
+        return MetaLlamaSUT(sut_definition.dynamic_uid, model_name, *self.injected_secrets())
diff --git a/tests/modelgauge_tests/sut_tests/test_meta_llama_factory.py b/tests/modelgauge_tests/sut_tests/test_meta_llama_factory.py
new file mode 100644
index 00000000..2b63b472
--- /dev/null
+++ b/tests/modelgauge_tests/sut_tests/test_meta_llama_factory.py
@@ -0,0 +1,29 @@
+import pytest
+from unittest.mock import patch
+
+from modelgauge.dynamic_sut_factory import ModelNotSupportedError
+from modelgauge.sut_definition import SUTDefinition
+from modelgauge.suts.meta_llama_client import MetaLlamaSUT
+from modelgauge.suts.meta_llama_factory import LlamaSUTFactory
+
+
+@pytest.fixture
+def factory():
+    return LlamaSUTFactory({"meta_llama": {"api_key": "value"}})
+
+
+def test_make_sut(factory):
+    with patch("modelgauge.suts.meta_llama_factory.LlamaSUTFactory._get_model_name", return_value="Foo/Bar"):
+        sut_definition = SUTDefinition(model="bar", maker="foo", driver="llama")
+        sut = factory.make_sut(sut_definition)
+
+    assert isinstance(sut, MetaLlamaSUT)
+    assert sut.uid == "foo/bar:llama"
+    assert sut.model == "Foo/Bar"
+
+
+def test_make_sut_bad_model(factory):
+    sut_definition = SUTDefinition(model="bogus", maker="fake", driver="llama")
+    with patch("modelgauge.suts.meta_llama_factory.LlamaSUTFactory._get_model_name", return_value=None):
+        with pytest.raises(ModelNotSupportedError):
+            factory.make_sut(sut_definition)