Skip to content

Commit 3ac7a3e

Browse files
committed
Update unit and integration tests
1 parent 88a0fe5 commit 3ac7a3e

6 files changed

Lines changed: 1318 additions & 14 deletions

File tree

packages/nemo-evaluator/tests/integration/byob/test_byob_e2e.py

Lines changed: 205 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,33 @@
1919
- TruthfulQA scorer logic with mocked judge
2020
- Compiler integration (compile + install)
2121
- Runner E2E with mock server subprocess calls
22+
- Multiple-choice loglikelihood E2E with an in-process mock /v1/completions
2223
"""
2324

25+
import hashlib
2426
import importlib.util
2527
import json
2628
import subprocess
2729
import sys
30+
import threading
31+
from http.server import BaseHTTPRequestHandler, HTTPServer
2832
from pathlib import Path
2933
from unittest.mock import patch
3034

3135
import pytest
3236

37+
from nemo_evaluator.contrib.byob.aggregation import aggregate_scores
3338
from nemo_evaluator.contrib.byob.compiler import compile_benchmark, install_benchmark
34-
from nemo_evaluator.contrib.byob.decorators import ScorerInput
39+
from nemo_evaluator.contrib.byob.decorators import (
40+
BenchmarkDefinition,
41+
ScorerInput,
42+
)
43+
from nemo_evaluator.contrib.byob.eval_logic import run_eval_loop
44+
from nemo_evaluator.contrib.byob.runner import (
45+
_create_session_model_call_fn,
46+
create_session,
47+
)
48+
from nemo_evaluator.contrib.byob.scorers import multiple_choice_acc
3549

3650

3751
def _get_truthfulqa_benchmark_path():
@@ -299,3 +313,193 @@ def test_runner_cli_help(self):
299313
"Help output should mention --benchmark-module"
300314
)
301315
assert "--model-url" in result.stdout, "Help output should mention --model-url"
316+
317+
318+
# ---------------------------------------------------------------------------
319+
# Multiple-choice loglikelihood E2E
320+
#
321+
# Spins up an in-process OpenAI-compatible /v1/completions server that
322+
# returns deterministic ``echo+logprobs`` payloads, then runs
323+
# ``MultipleChoiceStrategy`` over a 4-question synthetic MMLU-style
324+
# dataset. Verifies the wire payload (``echo=true, logprobs=1, max_tokens=0``),
325+
# aggregated metrics (``acc``, ``acc_norm``, ``acc_greedy``), and
326+
# per-prediction diagnostics.
327+
# ---------------------------------------------------------------------------
328+
329+
330+
class _LogprobHandler(BaseHTTPRequestHandler):
331+
"""Deterministic mock that returns echo+logprobs payloads.
332+
333+
Tokenizes the incoming prompt by whitespace, assigns a fixed
334+
log-prob to each token (``-len(token) * 0.5``), then patches the
335+
log-prob of the token whose lowercase form matches ``GOLD_TOKEN``
336+
to ``-0.1`` so a known choice wins argmax.
337+
"""
338+
339+
GOLD_TOKEN = "b"
340+
requests_log: list = []
341+
342+
def do_POST(self):
343+
length = int(self.headers.get("Content-Length", 0))
344+
body = json.loads(self.rfile.read(length))
345+
type(self).requests_log.append({"path": self.path, "body": body})
346+
347+
if self.path != "/completions":
348+
self.send_response(404)
349+
self.end_headers()
350+
return
351+
352+
prompt = body.get("prompt", "")
353+
tokens, offsets = [], []
354+
i = 0
355+
n = len(prompt)
356+
while i < n:
357+
start = i
358+
while i < n and prompt[i].isspace() and i == start:
359+
i += 1
360+
while i < n and not prompt[i].isspace():
361+
i += 1
362+
tok = prompt[start:i]
363+
if tok:
364+
tokens.append(tok)
365+
offsets.append(start)
366+
367+
token_logprobs = [None]
368+
top_logprobs = [None]
369+
for tok in tokens[1:]:
370+
stripped = tok.strip().lower()
371+
if stripped == self.GOLD_TOKEN:
372+
lp = -0.1
373+
else:
374+
h = hashlib.md5(stripped.encode()).hexdigest()
375+
lp = -0.5 - (int(h[:4], 16) % 100) / 100.0
376+
token_logprobs.append(lp)
377+
top_logprobs.append({tok: lp})
378+
379+
resp = {
380+
"choices": [
381+
{
382+
"text": "",
383+
"logprobs": {
384+
"tokens": tokens,
385+
"token_logprobs": token_logprobs,
386+
"text_offset": offsets,
387+
"top_logprobs": top_logprobs,
388+
},
389+
}
390+
]
391+
}
392+
self.send_response(200)
393+
self.send_header("Content-Type", "application/json")
394+
self.end_headers()
395+
self.wfile.write(json.dumps(resp).encode())
396+
397+
def log_message(self, *_args, **_kwargs): # silence
398+
pass
399+
400+
401+
class _LogprobServer:
402+
def __init__(self):
403+
self.server = HTTPServer(("localhost", 0), _LogprobHandler)
404+
self.port = self.server.server_address[1]
405+
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
406+
407+
def __enter__(self):
408+
_LogprobHandler.requests_log = []
409+
self.thread.start()
410+
return self
411+
412+
def __exit__(self, *_):
413+
self.server.shutdown()
414+
self.thread.join(timeout=5)
415+
416+
@property
417+
def url(self):
418+
return f"http://localhost:{self.port}"
419+
420+
421+
@pytest.fixture
422+
def loglikelihood_server():
423+
with _LogprobServer() as s:
424+
yield s
425+
426+
427+
def _make_logprob_args(server_url: str):
428+
"""Build a minimal Namespace shaped like argparse output."""
429+
import argparse
430+
431+
return argparse.Namespace(
432+
model_url=server_url,
433+
model_id="mock",
434+
api_key_name=None,
435+
temperature=0.0,
436+
max_tokens=0,
437+
top_p=None,
438+
request_timeout=None,
439+
timeout_per_sample=30.0,
440+
)
441+
442+
443+
class TestMultipleChoiceLogprobE2E:
444+
"""End-to-end loglikelihood evaluation against an in-process server."""
445+
446+
def test_multiple_choice_loglikelihood_e2e(self, loglikelihood_server):
447+
bench = BenchmarkDefinition(
448+
name="mock-mmlu",
449+
normalized_name="mock_mmlu",
450+
dataset="x",
451+
prompt="Q: {q} Answer:",
452+
scorer_fn=multiple_choice_acc,
453+
target_field="answer",
454+
endpoint_type="completions_logprob",
455+
# Leading space so each choice tokenizes as a single token in
456+
# the mock server's whitespace-based tokenizer.
457+
choices=[" A", " B", " C", " D"],
458+
)
459+
460+
# Target stored as a verbatim choice so multiple_choice_acc
461+
# resolves the gold index by string match.
462+
dataset = [
463+
{"q": "what?", "answer": " B"},
464+
{"q": "where?", "answer": " B"},
465+
{"q": "when?", "answer": " B"},
466+
{"q": "why?", "answer": " B"},
467+
]
468+
469+
args = _make_logprob_args(loglikelihood_server.url)
470+
session = create_session(max_retries=1, backoff_factor=0.0)
471+
model_call_fn = _create_session_model_call_fn(args, None, session)
472+
473+
scores, predictions = run_eval_loop(
474+
bench=bench,
475+
dataset=dataset,
476+
model_call_fn=model_call_fn,
477+
endpoint_type="completions_logprob",
478+
save_predictions=True,
479+
)
480+
481+
# Rigged: token "B" gets the highest logprob so all 4 samples score 1.
482+
assert len(scores) == 4
483+
assert all(s["acc"] == 1.0 for s in scores), scores
484+
assert all(s["acc_norm"] == 1.0 for s in scores), scores
485+
486+
# Each sample triggers exactly 4 server calls (one per choice).
487+
assert len(_LogprobHandler.requests_log) == 16
488+
payload = _LogprobHandler.requests_log[0]["body"]
489+
assert payload["max_tokens"] == 0
490+
assert payload["echo"] is True
491+
assert payload["logprobs"] == 1
492+
assert payload["temperature"] == 0.0
493+
494+
aggregated = aggregate_scores(scores, "mock_mmlu")
495+
out_scores = aggregated["tasks"]["mock_mmlu"]["metrics"]["pass@1"]["scores"]
496+
assert "acc" in out_scores
497+
assert "acc_norm" in out_scores
498+
assert "acc_greedy" in out_scores
499+
assert out_scores["acc"]["value"] == 1.0
500+
assert out_scores["acc_norm"]["value"] == 1.0
501+
502+
# All predictions captured per-choice diagnostic metadata.
503+
for pred in predictions:
504+
assert pred.metadata["_choices"] == [" A", " B", " C", " D"]
505+
assert len(pred.metadata["_choices_logprobs"]) == 4

0 commit comments

Comments
 (0)