Skip to content

Commit 070d529

Browse files
authored
Remove hack required by a bug that is now fixed (#1120)
* remove workaround that was required by a server-side bug that is now fixed * remove print * noop; lint * use fixtures
1 parent 76f0797 commit 070d529

2 files changed

Lines changed: 13 additions & 30 deletions

File tree

src/modelgauge/suts/meta_llama_client.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,16 @@
55
from httpx import Timeout
66
from llama_api_client import LlamaAPIClient
77
from llama_api_client.types import CreateChatCompletionResponse, MessageTextContentItem, ModerationCreateResponse
8-
from pydantic import BaseModel
9-
from requests.adapters import HTTPAdapter, Retry # type:ignore
108

119
from modelgauge.prompt import TextPrompt
1210
from modelgauge.retry_decorator import retry
1311
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
14-
from modelgauge.sut import (
15-
PromptResponseSUT,
16-
SUTOptions,
17-
SUTResponse,
18-
REFUSAL_RESPONSE,
19-
)
12+
from modelgauge.sut import PromptResponseSUT, REFUSAL_RESPONSE, SUTOptions, SUTResponse
2013
from modelgauge.sut_capabilities import AcceptsTextPrompt
2114
from modelgauge.sut_decorator import modelgauge_sut
2215
from modelgauge.sut_registry import SUTS
16+
from pydantic import BaseModel
17+
from requests.adapters import HTTPAdapter, Retry # type:ignore
2318

2419
logger = logging.getLogger(__name__)
2520

@@ -105,11 +100,6 @@ def evaluate(self, request: MetaLlamaChatRequest) -> MetaLlamaModeratedResponse:
105100
messages: list = kwargs.get("messages") # type: ignore
106101
messages.append(chat_response.completion_message)
107102
moderation_response = self.client.moderations.create(messages=messages)
108-
for r in moderation_response.results:
109-
if r.flagged_categories is None:
110-
# make objects comply with Pydantic definitions due to bug;
111-
# see https://github.com/meta-llama/llama-api-python/issues/33 for more
112-
r.flagged_categories = []
113103
return MetaLlamaModeratedResponse(sut_response=chat_response, moderation_response=moderation_response)
114104

115105
def translate_response(self, request: MetaLlamaChatRequest, response: MetaLlamaModeratedResponse) -> SUTResponse:

tests/modelgauge_tests/test_meta_llama.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
from unittest.mock import MagicMock
22

33
from llama_api_client.types import CreateChatCompletionResponse
4-
from requests import HTTPError # type:ignore
54

65
from modelgauge.prompt import TextPrompt
7-
from modelgauge.sut import SUTResponse, SUTOptions
8-
from modelgauge.suts.meta_llama_client import (
9-
MetaLlamaSUT,
10-
MetaLlamaChatRequest,
11-
InputMessage,
12-
MetaLlamaApiKey,
13-
)
6+
from modelgauge.sut import SUTOptions, SUTResponse
7+
from modelgauge.suts.meta_llama_client import InputMessage, MetaLlamaApiKey, MetaLlamaChatRequest, MetaLlamaSUT
8+
from pytest import fixture
9+
from requests import HTTPError # type:ignore
1410

1511
llama_chat_response_text = """
1612
{
@@ -43,12 +39,12 @@
4339
"""
4440

4541

46-
def make_sut():
42+
@fixture
43+
def sut():
4744
return MetaLlamaSUT("ignored", "a_model", MetaLlamaApiKey("whatever"))
4845

4946

50-
def test_translate_text_prompt():
51-
sut = make_sut()
47+
def test_translate_text_prompt(sut):
5248
sut_options = SUTOptions()
5349
result = sut.translate_text_prompt(TextPrompt(text="Why did the chicken cross the road?"), sut_options)
5450
assert result == MetaLlamaChatRequest(
@@ -58,8 +54,7 @@ def test_translate_text_prompt():
5854
)
5955

6056

61-
def test_translate_chat_response():
62-
sut = make_sut()
57+
def test_translate_chat_response(sut):
6358
request = MetaLlamaChatRequest(
6459
model="a_model",
6560
messages=[InputMessage(role="user", content="Why did the chicken cross the road?")],
@@ -71,18 +66,16 @@ def test_translate_chat_response():
7166
)
7267

7368

74-
def test_evaluate():
75-
sut = make_sut()
69+
def test_evaluate(sut):
7670
request = MetaLlamaChatRequest(
7771
model="a_model",
7872
messages=[InputMessage(role="user", content="Why did the chicken cross the road?")],
7973
max_completion_tokens=123,
8074
)
8175
sut.client = MagicMock()
82-
response = sut.evaluate(request)
76+
_ = sut.evaluate(request)
8377
assert sut.client.chat.completions.create.call_count == 1
8478
kwargs = sut.client.chat.completions.create.call_args.kwargs
85-
print(kwargs)
8679
assert kwargs["model"] == "a_model"
8780
assert kwargs["messages"][0]["role"] == "user"
8881
assert kwargs["messages"][0]["content"] == "Why did the chicken cross the road?"

0 commit comments

Comments
 (0)