
Commit 7dc5809

More tests
1 parent a4f061f commit 7dc5809

7 files changed

Lines changed: 485 additions & 27 deletions


src/maxtext/eval/configs/mlperf.yml

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# MLPerf OpenOrca summarisation evaluation config.
+# MLPerf OpenOrca evaluation config.
 
 benchmark: "mlperf_openorca"
 max_tokens: 1024

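For reference, the config is small enough to sanity-check by hand. A minimal sketch of parsing it, assuming the file is plain YAML read with PyYAML (`load_eval_config` is a hypothetical helper, not part of this commit):

import yaml

def load_eval_config(path: str) -> dict:
  """Parse a benchmark config such as src/maxtext/eval/configs/mlperf.yml."""
  with open(path, encoding="utf-8") as f:
    return yaml.safe_load(f)

cfg = load_eval_config("src/maxtext/eval/configs/mlperf.yml")
assert cfg["benchmark"] == "mlperf_openorca"  # keys shown in the diff above
assert cfg["max_tokens"] == 1024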
src/maxtext/eval/datasets/registry.py

Lines changed: 1 addition & 4 deletions
@@ -14,10 +14,7 @@
 
 """Registry mapping benchmark names to BenchmarkDataset classes.
 
-MMLU, GPQA, MATH, and GSM-8K are handled by lm_eval_runner (loglikelihood)
-or evalchemy_runner (generation) and no longer have custom dataset loaders here.
-Custom loaders are retained only for benchmarks not covered by those runners
-(currently: MLPerf OpenOrca, which requires ROUGE scoring).
+This can be used to define custom dataset loaders for benchmarks not covered by lm_eval and evalchemy.
 """
 
 from __future__ import annotations

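The new docstring frames the registry as an extension point rather than a fixed list. A minimal sketch of the pattern it describes; `_REGISTRY`, `register`, and `MLPerfOpenOrcaDataset` are illustrative names and the real module's API may differ:

from typing import Callable, Dict

_REGISTRY: Dict[str, type] = {}

def register(name: str) -> Callable[[type], type]:
  """Class decorator mapping a benchmark name to its dataset class."""
  def deco(cls: type) -> type:
    _REGISTRY[name] = cls
    return cls
  return deco

@register("mlperf_openorca")
class MLPerfOpenOrcaDataset:  # would subclass BenchmarkDataset in practice
  ...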
src/maxtext/eval/runner/eval_runner.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
     --hf_path meta-llama/Llama-3.1-8B-Instruct
 
 HuggingFace safetensors mode:
-  Pass --hf_mode and point --hf_path at an existing HF model directory.
+  Use --hf_mode and point --hf_path to an existing HF model directory.
 
   python -m maxtext.eval.runner.eval_runner \
     --config src/maxtext/eval/configs/mlperf.yml \

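Putting the docstring's pieces together, a full safetensors-mode invocation might look like the following; the model directory is a placeholder, and it assumes --hf_mode is a boolean flag, as the docstring implies:

python -m maxtext.eval.runner.eval_runner \
  --config src/maxtext/eval/configs/mlperf.yml \
  --hf_mode \
  --hf_path /path/to/hf-model-dir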
tests/unit/eval/test_build_app.py

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for maxtext.eval.runner.server_manager._build_app."""

from __future__ import annotations

import unittest
from types import SimpleNamespace
from unittest.mock import MagicMock, patch


def _make_mock_output(generated_text="hello", prompt_token_ids=(1, 2, 3), generated_token_ids=(4, 5)):
  """Build a SimpleNamespace mimicking a vLLM RequestOutput object."""
  return SimpleNamespace(
      prompt_token_ids=list(prompt_token_ids),
      prompt_logprobs=None,
      outputs=[
          SimpleNamespace(
              text=generated_text,
              token_ids=list(generated_token_ids),
              logprobs=None,
              finish_reason="stop",
          )
      ],
  )


def _make_mock_llm(generated_text="hello", prompt_token_ids=(1, 2, 3), generated_token_ids=(4, 5)):
  """Return a mock vLLM LLM object whose generate() returns a single RequestOutput.

  The tokenizer returned by ``get_tokenizer()`` decodes each token ID to the
  string ``f"tok{tok_id}"``.
  """
  mock_output = _make_mock_output(
      generated_text=generated_text,
      prompt_token_ids=prompt_token_ids,
      generated_token_ids=generated_token_ids,
  )

  mock_tokenizer = MagicMock()
  mock_tokenizer.decode.side_effect = lambda ids: "".join(f"tok{i}" for i in ids)
  mock_tokenizer.apply_chat_template.return_value = "rendered_prompt"

  mock_llm = MagicMock()
  mock_llm.generate.return_value = [mock_output]
  mock_llm.get_tokenizer.return_value = mock_tokenizer
  return mock_llm


class TestBuildApp(unittest.TestCase):
  """Tests for the FastAPI app returned by _build_app(llm)."""

  def setUp(self):
    """Patch SamplingParams at the module level used by server_manager."""
    self.mock_llm = _make_mock_llm()
    self.mock_sampling_params_cls = MagicMock(return_value=MagicMock())

    # Patch at the import location used inside _build_app.
    self._sp_patcher = patch(
        "vllm.sampling_params.SamplingParams",
        self.mock_sampling_params_cls,
    )
    self._vllm_patcher = patch.dict(
        "sys.modules",
        {
            "vllm": MagicMock(),
            "vllm.sampling_params": MagicMock(SamplingParams=self.mock_sampling_params_cls),
        },
    )
    self._vllm_patcher.start()
    self._sp_patcher.start()

    from maxtext.eval.runner.server_manager import _build_app
    from starlette.testclient import TestClient

    self.app = _build_app(self.mock_llm)
    self.client = TestClient(self.app)

  def tearDown(self):
    self._sp_patcher.stop()
    self._vllm_patcher.stop()

  def test_health_endpoint(self):
    resp = self.client.get("/health")
    self.assertEqual(resp.status_code, 200)
    self.assertEqual(resp.json(), {"status": "ok"})

  def test_completions_basic(self):
    resp = self.client.post(
        "/v1/completions",
        json={"model": "m", "prompt": "hi", "max_tokens": 10},
    )
    self.assertEqual(resp.status_code, 200)
    data = resp.json()
    self.assertIn("choices", data)
    self.assertEqual(len(data["choices"]), 1)
    self.assertEqual(data["choices"][0]["text"], "hello")

  def test_completions_list_prompt(self):
    mock_llm = _make_mock_llm(generated_text="world")
    mock_llm.generate.return_value = [
        _make_mock_output(generated_text="alpha"),
        _make_mock_output(generated_text="beta"),
    ]
    mock_llm.get_tokenizer.return_value = self.mock_llm.get_tokenizer()

    from maxtext.eval.runner.server_manager import _build_app
    from starlette.testclient import TestClient

    app = _build_app(mock_llm)
    client = TestClient(app)

    resp = client.post(
        "/v1/completions",
        json={"model": "m", "prompt": ["first", "second"], "max_tokens": 5},
    )
    self.assertEqual(resp.status_code, 200)
    data = resp.json()
    self.assertEqual(len(data["choices"]), 2)
    self.assertEqual(data["choices"][0]["text"], "alpha")
    self.assertEqual(data["choices"][1]["text"], "beta")

  def test_completions_no_logprobs(self):
    resp = self.client.post(
        "/v1/completions",
        json={"model": "m", "prompt": "test", "max_tokens": 5},
    )
    data = resp.json()
    self.assertIsNone(data["choices"][0]["logprobs"])

  def test_completions_with_logprobs_echo_false(self):
    mock_output = _make_mock_output(
        generated_text="hi",
        prompt_token_ids=[1, 2],
        generated_token_ids=[4, 5],
    )
    mock_output.outputs[0].logprobs = [
        {4: SimpleNamespace(logprob=-0.5)},
        {5: SimpleNamespace(logprob=-1.2)},
    ]
    self.mock_llm.generate.return_value = [mock_output]

    resp = self.client.post(
        "/v1/completions",
        json={"model": "m", "prompt": "ab", "max_tokens": 5, "logprobs": 1},
    )
    self.assertEqual(resp.status_code, 200)
    data = resp.json()
    lp = data["choices"][0]["logprobs"]
    self.assertIsNotNone(lp)
    self.assertEqual(len(lp["tokens"]), 2)
    self.assertAlmostEqual(lp["token_logprobs"][0], -0.5, places=4)
    self.assertAlmostEqual(lp["token_logprobs"][1], -1.2, places=4)

  def test_completions_with_logprobs_echo_true(self):
    mock_output = _make_mock_output(
        generated_text=" world",
        prompt_token_ids=[1, 2, 3],
        generated_token_ids=[4, 5],
    )
    mock_output.prompt_logprobs = [
        None,
        {2: SimpleNamespace(logprob=-0.3)},
        {3: SimpleNamespace(logprob=-0.7)},
    ]
    mock_output.outputs[0].logprobs = [
        {4: SimpleNamespace(logprob=-0.9)},
        {5: SimpleNamespace(logprob=-1.1)},
    ]
    self.mock_llm.generate.return_value = [mock_output]

    resp = self.client.post(
        "/v1/completions",
        json={
            "model": "m",
            "prompt": "tok1tok2tok3",
            "max_tokens": 5,
            "logprobs": 1,
            "echo": True,
        },
    )
    self.assertEqual(resp.status_code, 200)
    data = resp.json()
    lp = data["choices"][0]["logprobs"]
    self.assertIsNotNone(lp)
    # echo=True → prompt tokens (3) + generated tokens (2) = 5 total.
    self.assertEqual(len(lp["tokens"]), 5)

  def test_chat_completions_basic(self):
    resp = self.client.post(
        "/v1/chat/completions",
        json={
            "model": "m",
            "messages": [{"role": "user", "content": "hello"}],
            "max_tokens": 20,
        },
    )
    self.assertEqual(resp.status_code, 200)
    data = resp.json()
    self.assertIn("choices", data)
    self.assertEqual(data["choices"][0]["message"]["role"], "assistant")
    self.assertEqual(data["choices"][0]["message"]["content"], "hello")

  def test_chat_completions_applies_template(self):
    resp = self.client.post(
        "/v1/chat/completions",
        json={
            "model": "m",
            "messages": [{"role": "user", "content": "ping"}],
            "max_tokens": 10,
        },
    )
    self.assertEqual(resp.status_code, 200)
    tokenizer = self.mock_llm.get_tokenizer()
    tokenizer.apply_chat_template.assert_called()
    call_args = tokenizer.apply_chat_template.call_args
    # The messages list should have been forwarded to apply_chat_template.
    passed_messages = call_args[0][0] if call_args[0] else call_args[1].get("conversation")
    self.assertIsNotNone(passed_messages)


if __name__ == "__main__":
  unittest.main()

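The suite is plain unittest, so it can likely be run directly, e.g. python -m unittest tests.unit.eval.test_build_app from the repo root (assuming the test directories are importable as packages), or via pytest tests/unit/eval/test_build_app.py.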