Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/modelgauge/sut_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from modelgauge.suts.google_sut_factory import GoogleSUTFactory
from modelgauge.suts.huggingface_sut_factory import HuggingFaceSUTFactory
from modelgauge.suts.indirect_sut import IndirectSUTFactory
from modelgauge.suts.meta_llama_factory import LlamaSUTFactory
from modelgauge.suts.mistral_sut_factory import MistralSUTFactory
from modelgauge.suts.modelship_sut import ModelShipSUTFactory
from modelgauge.suts.openai_sut_factory import OpenAICompatibleSUTFactory
Expand All @@ -36,6 +37,7 @@ class SUTType(Enum):
"hf": HuggingFaceSUTFactory,
"hfrelay": HuggingFaceSUTFactory,
"indirect": IndirectSUTFactory,
"llama": LlamaSUTFactory,
"openai": OpenAICompatibleSUTFactory,
"mistral": MistralSUTFactory,
"modelship": ModelShipSUTFactory,
Expand Down
40 changes: 40 additions & 0 deletions src/modelgauge/suts/meta_llama_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from llama_api_client import LlamaAPIClient

from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError
from modelgauge.secret_values import InjectSecret, RawSecrets
from modelgauge.sut import SUT
from modelgauge.sut_definition import SUTDefinition
from modelgauge.suts.meta_llama_client import MetaLlamaApiKey, MetaLlamaSUT


class LlamaSUTFactory(DynamicSUTFactory):
    """Dynamic SUT factory that builds ``MetaLlamaSUT`` objects for Llama API models."""

    def __init__(self, raw_secrets: RawSecrets):
        super().__init__(raw_secrets)
        # The API client is built on first use (see ``client``), not at construction time.
        self._client = None

    @property
    def client(self):
        """Return the ``LlamaAPIClient``, creating and caching it on first access."""
        if self._client is None:
            # get_secrets() declares a single secret, so index 0 is presumably the
            # MetaLlamaApiKey -- this relies on injected_secrets() preserving that order.
            api_key = self.injected_secrets()[0].value
            self._client = LlamaAPIClient(api_key=api_key)
        return self._client

    def get_secrets(self) -> list[InjectSecret]:
        """Return the secrets this factory needs: the Meta Llama API key."""
        return [InjectSecret(MetaLlamaApiKey)]

    def _get_model_name(self, model) -> str | None:
        """Resolve ``model`` to the exact model id listed by the Llama API.

        Llama API model names are case sensitive, so the requested name is
        matched case-insensitively against the ids returned by the client and
        the id is returned with the API's own casing.

        Returns:
            The matching model id, or ``None`` when no listed model matches.
        """
        wanted = model.lower()
        for candidate in self.client.models.list():
            if candidate.id.lower() == wanted:
                return candidate.id
        return None

    def make_sut(self, sut_definition: SUTDefinition) -> SUT:
        """Build a ``MetaLlamaSUT`` for the model described by ``sut_definition``.

        Raises:
            ModelNotSupportedError: if the requested model is not among the
                models listed by the Llama API.
        """
        requested_name = sut_definition.to_dynamic_sut_metadata().external_model_name()
        # Keep the requested name separate from the lookup result so the error
        # message can report what was asked for rather than "None".
        model_name = self._get_model_name(requested_name)

        if model_name is None:
            raise ModelNotSupportedError(f"Model {requested_name} not found or not available via Llama API.")
        return MetaLlamaSUT(sut_definition.dynamic_uid, model_name, *self.injected_secrets())
29 changes: 29 additions & 0 deletions tests/modelgauge_tests/sut_tests/test_meta_llama_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
from unittest.mock import patch

from modelgauge.dynamic_sut_factory import ModelNotSupportedError
from modelgauge.sut_definition import SUTDefinition
from modelgauge.suts.meta_llama_client import MetaLlamaSUT
from modelgauge.suts.meta_llama_factory import LlamaSUTFactory


@pytest.fixture
def factory():
    """Provide a LlamaSUTFactory built from a placeholder Meta Llama API key."""
    raw_secrets = {"meta_llama": {"api_key": "value"}}
    return LlamaSUTFactory(raw_secrets)


def test_make_sut(factory):
    """make_sut returns a MetaLlamaSUT that uses the model name resolved by the factory."""
    definition = SUTDefinition(model="bar", maker="foo", driver="llama")
    lookup_target = "modelgauge.suts.meta_llama_factory.LlamaSUTFactory._get_model_name"

    # Stub the model lookup so the test never calls the Llama API.
    with patch(lookup_target, return_value="Foo/Bar"):
        built = factory.make_sut(definition)

    assert isinstance(built, MetaLlamaSUT)
    assert built.uid == "foo/bar:llama"
    assert built.model == "Foo/Bar"


def test_make_sut_bad_model(factory):
    """make_sut raises ModelNotSupportedError when the model lookup finds nothing."""
    lookup_target = "modelgauge.suts.meta_llama_factory.LlamaSUTFactory._get_model_name"
    definition = SUTDefinition(model="bogus", maker="fake", driver="llama")

    # A lookup result of None stands in for a model the API does not list.
    with patch(lookup_target, return_value=None), pytest.raises(ModelNotSupportedError):
        factory.make_sut(definition)
Loading