|
| 1 | +from llama_api_client import LlamaAPIClient |
| 2 | + |
| 3 | +from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError |
| 4 | +from modelgauge.secret_values import InjectSecret, RawSecrets |
| 5 | +from modelgauge.sut import SUT |
| 6 | +from modelgauge.sut_definition import SUTDefinition |
| 7 | +from modelgauge.suts.meta_llama_client import MetaLlamaApiKey, MetaLlamaSUT |
| 8 | + |
| 9 | + |
| 10 | +class LlamaSUTFactory(DynamicSUTFactory): |
| 11 | + def __init__(self, raw_secrets: RawSecrets): |
| 12 | + super().__init__(raw_secrets) |
| 13 | + self._client = None |
| 14 | + |
| 15 | + @property |
| 16 | + def client(self): |
| 17 | + if self._client is None: |
| 18 | + api_key = self.injected_secrets()[0].value |
| 19 | + self._client = LlamaAPIClient(api_key=api_key) |
| 20 | + return self._client |
| 21 | + |
| 22 | + def get_secrets(self) -> list[InjectSecret]: |
| 23 | + api_key = InjectSecret(MetaLlamaApiKey) |
| 24 | + return [api_key] |
| 25 | + |
| 26 | + def _get_model_name(self, model) -> str | None: |
| 27 | + """Llama API model names are case sensitive.""" |
| 28 | + models = self.client.models.list() |
| 29 | + for m in models: |
| 30 | + if m.id.lower() == model.lower(): |
| 31 | + return m.id |
| 32 | + return None |
| 33 | + |
| 34 | + def make_sut(self, sut_definition: SUTDefinition) -> SUT: |
| 35 | + model_name = sut_definition.to_dynamic_sut_metadata().external_model_name() |
| 36 | + model_name = self._get_model_name(model_name) |
| 37 | + |
| 38 | + if model_name is None: |
| 39 | + raise ModelNotSupportedError(f"Model {model_name} not found or not available via Llama API.") |
| 40 | + return MetaLlamaSUT(sut_definition.dynamic_uid, model_name, *self.injected_secrets()) |
0 commit comments