Skip to content

Commit ff8050a

Browse files
authored
Fix together dynamic factory (#1162)
* different api call * Fix unit test
1 parent d4d5081 commit ff8050a

2 files changed

Lines changed: 16 additions & 12 deletions

File tree

src/modelgauge/suts/together_sut_factory.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,21 @@ def get_secrets() -> list[InjectSecret]:
2828
return [api_key]
2929

3030
def _find(self, sut_metadata: DynamicSUTMetadata):
31-
found = None
31+
model = None
3232
try:
33-
model_list = self.client.Models.list()
34-
found = [
35-
model["id"] for model in model_list if model["id"].lower() == sut_metadata.external_model_name().lower()
36-
][0]
33+
model = sut_metadata.external_model_name().lower()
34+
self.client.chat.completions.create(
35+
model=model,
36+
messages=[
37+
{"role": "user", "content": "Anybody home?"},
38+
],
39+
)
3740
except Exception as e:
3841
raise ModelNotSupportedError(
3942
f"Model {sut_metadata.external_model_name()} not found or not available on together: {e}"
4043
)
4144

42-
return found
45+
return model
4346

4447
def make_sut(self, sut_metadata: DynamicSUTMetadata) -> TogetherChatSUT:
4548
model_name = self._find(sut_metadata)

tests/modelgauge_tests/sut_tests/test_together_sut_factory.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import patch
1+
from unittest.mock import patch, MagicMock
22

33
import pytest
44

@@ -8,7 +8,6 @@
88
from modelgauge.suts.together_client import TogetherChatSUT
99
from modelgauge.suts.together_sut_factory import TogetherSUTFactory
1010
from modelgauge_tests.utilities import expensive_tests
11-
from together import Together
1211

1312

1413
@pytest.fixture
@@ -34,16 +33,18 @@ def test_make_sut_bad_model(factory):
3433

3534

3635
def test_find(factory):
37-
with patch.object(Together, "Models", create=True) as mock_models:
38-
mock_models.list.return_value = [{"id": "google/gemma"}]
36+
mock_together = MagicMock()
37+
mock_together.return_value.chat.completions.create.return_value = {} # The method doesn't use the return value.
38+
with patch("modelgauge.suts.together_sut_factory.Together", mock_together):
3939
sut_metadata = DynamicSUTMetadata(model="gemma", maker="google", driver="together")
4040
assert factory._find(sut_metadata) == sut_metadata.external_model_name()
4141

4242

4343
def test_find_bad_model(factory):
4444
sut_metadata = DynamicSUTMetadata(model="any", maker="any", driver="together")
45-
with patch.object(Together, "Models", create=True) as mock_models:
46-
mock_models.list.return_value = None
45+
mock_together = MagicMock()
46+
mock_together.return_value.chat.completions.create.side_effect = Exception("Model not available")
47+
with patch("modelgauge.suts.together_sut_factory.Together", mock_together):
4748
with pytest.raises(ModelNotSupportedError):
4849
_ = factory._find(sut_metadata)
4950

0 commit comments

Comments
 (0)