Skip to content

Commit 080c865

Browse files
authored
basic llama factory (#1494)
1 parent a80d274 commit 080c865

3 files changed

Lines changed: 71 additions & 0 deletions

File tree

src/modelgauge/sut_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from modelgauge.suts.google_sut_factory import GoogleSUTFactory
1111
from modelgauge.suts.huggingface_sut_factory import HuggingFaceSUTFactory
1212
from modelgauge.suts.indirect_sut import IndirectSUTFactory
13+
from modelgauge.suts.meta_llama_factory import LlamaSUTFactory
1314
from modelgauge.suts.mistral_sut_factory import MistralSUTFactory
1415
from modelgauge.suts.modelship_sut import ModelShipSUTFactory
1516
from modelgauge.suts.openai_sut_factory import OpenAICompatibleSUTFactory
@@ -36,6 +37,7 @@ class SUTType(Enum):
3637
"hf": HuggingFaceSUTFactory,
3738
"hfrelay": HuggingFaceSUTFactory,
3839
"indirect": IndirectSUTFactory,
40+
"llama": LlamaSUTFactory,
3941
"openai": OpenAICompatibleSUTFactory,
4042
"mistral": MistralSUTFactory,
4143
"modelship": ModelShipSUTFactory,
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from llama_api_client import LlamaAPIClient
2+
3+
from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError
4+
from modelgauge.secret_values import InjectSecret, RawSecrets
5+
from modelgauge.sut import SUT
6+
from modelgauge.sut_definition import SUTDefinition
7+
from modelgauge.suts.meta_llama_client import MetaLlamaApiKey, MetaLlamaSUT
8+
9+
10+
class LlamaSUTFactory(DynamicSUTFactory):
11+
def __init__(self, raw_secrets: RawSecrets):
12+
super().__init__(raw_secrets)
13+
self._client = None
14+
15+
@property
16+
def client(self):
17+
if self._client is None:
18+
api_key = self.injected_secrets()[0].value
19+
self._client = LlamaAPIClient(api_key=api_key)
20+
return self._client
21+
22+
def get_secrets(self) -> list[InjectSecret]:
23+
api_key = InjectSecret(MetaLlamaApiKey)
24+
return [api_key]
25+
26+
def _get_model_name(self, model) -> str | None:
27+
"""Llama API model names are case sensitive."""
28+
models = self.client.models.list()
29+
for m in models:
30+
if m.id.lower() == model.lower():
31+
return m.id
32+
return None
33+
34+
def make_sut(self, sut_definition: SUTDefinition) -> SUT:
35+
model_name = sut_definition.to_dynamic_sut_metadata().external_model_name()
36+
model_name = self._get_model_name(model_name)
37+
38+
if model_name is None:
39+
raise ModelNotSupportedError(f"Model {model_name} not found or not available via Llama API.")
40+
return MetaLlamaSUT(sut_definition.dynamic_uid, model_name, *self.injected_secrets())
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import pytest
2+
from unittest.mock import patch
3+
4+
from modelgauge.dynamic_sut_factory import ModelNotSupportedError
5+
from modelgauge.sut_definition import SUTDefinition
6+
from modelgauge.suts.meta_llama_client import MetaLlamaSUT
7+
from modelgauge.suts.meta_llama_factory import LlamaSUTFactory
8+
9+
10+
@pytest.fixture
11+
def factory():
12+
return LlamaSUTFactory({"meta_llama": {"api_key": "value"}})
13+
14+
15+
def test_make_sut(factory):
16+
with patch("modelgauge.suts.meta_llama_factory.LlamaSUTFactory._get_model_name", return_value="Foo/Bar"):
17+
sut_definition = SUTDefinition(model="bar", maker="foo", driver="llama")
18+
sut = factory.make_sut(sut_definition)
19+
20+
assert isinstance(sut, MetaLlamaSUT)
21+
assert sut.uid == "foo/bar:llama"
22+
assert sut.model == "Foo/Bar"
23+
24+
25+
def test_make_sut_bad_model(factory):
26+
sut_definition = SUTDefinition(model="bogus", maker="fake", driver="llama")
27+
with patch("modelgauge.suts.meta_llama_factory.LlamaSUTFactory._get_model_name", return_value=None):
28+
with pytest.raises(ModelNotSupportedError):
29+
factory.make_sut(sut_definition)

0 commit comments

Comments
 (0)