Skip to content

Commit cee95a9

Browse files
authored
support dataset adapter (#395)
* support dataset adapter * remove multi select * no nul lvalues
1 parent 8a5d7da commit cee95a9

4 files changed

Lines changed: 243 additions & 65 deletions

File tree

eval_protocol/cli_commands/create_rft.py

Lines changed: 96 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
import os
66
import sys
77
import time
8-
from typing import Any, Dict, Optional
8+
from typing import Any, Callable, Dict, Optional
99
import inspect
1010
import requests
11+
import tempfile
1112
from pydantic import ValidationError
1213

1314
from ..auth import get_fireworks_api_base, get_fireworks_api_key
14-
from ..common_utils import get_user_agent
15+
from ..common_utils import get_user_agent, load_jsonl
1516
from ..fireworks_rft import (
16-
build_default_output_model,
1717
create_dataset_from_jsonl,
1818
detect_dataset_builder,
1919
materialize_dataset_via_builder,
@@ -31,12 +31,88 @@
3131
_normalize_evaluator_id,
3232
_print_links,
3333
_resolve_selected_test,
34+
load_module_from_file_path,
3435
)
3536
from .local_test import run_evaluator_test
3637

3738
from fireworks import Fireworks
3839

3940

41+
def _extract_dataset_adapter(
42+
test_file_path: str, test_func_name: str
43+
) -> Optional[Callable[[list[dict[str, Any]]], Any]]:
44+
"""Extract dataset_adapter from an @evaluation_test wrapper via __ep_params__."""
45+
try:
46+
module = load_module_from_file_path(test_file_path)
47+
wrapper = getattr(module, test_func_name, None)
48+
if wrapper is None:
49+
return None
50+
ep_params = getattr(wrapper, "__ep_params__", None)
51+
if ep_params is None:
52+
return None
53+
adapter = getattr(ep_params, "dataset_adapter", None)
54+
if callable(adapter):
55+
return adapter
56+
return None
57+
except Exception:
58+
return None
59+
60+
61+
def _maybe_transform_dataset_jsonl_via_adapter(
62+
project_root: str,
63+
dataset_jsonl: str,
64+
test_file_path: Optional[str],
65+
test_func_name: Optional[str],
66+
) -> str:
67+
"""Transform dataset_jsonl via the test's dataset_adapter (when available).
68+
69+
For RFT dataset uploads, we want the uploaded dataset to match what evaluation-time
70+
would run on. If the selected evaluation test provides a dataset_adapter, that
71+
adapter is treated as the source of truth for constructing EvaluationRows.
72+
"""
73+
if not dataset_jsonl:
74+
return dataset_jsonl
75+
76+
if not test_file_path or not test_func_name:
77+
return dataset_jsonl
78+
79+
adapter = _extract_dataset_adapter(test_file_path, test_func_name)
80+
if not adapter:
81+
return dataset_jsonl
82+
83+
raw_rows: list[dict[str, Any]] = load_jsonl(dataset_jsonl) # type: ignore[assignment]
84+
adapted = adapter(raw_rows)
85+
if not isinstance(adapted, list):
86+
raise ValueError("dataset_adapter must return a list of EvaluationRow (or dicts parseable as EvaluationRow).")
87+
88+
eval_rows: list[EvaluationRow] = []
89+
for item in adapted:
90+
if isinstance(item, EvaluationRow):
91+
eval_rows.append(item)
92+
else:
93+
eval_rows.append(EvaluationRow.model_validate(item))
94+
95+
output_dir = os.path.join(project_root, ".ep_tmp")
96+
os.makedirs(output_dir, exist_ok=True)
97+
with tempfile.NamedTemporaryFile(
98+
mode="w",
99+
encoding="utf-8",
100+
delete=False,
101+
suffix=".jsonl",
102+
prefix="ep_rft_dataset_",
103+
dir=output_dir,
104+
) as f:
105+
for row in eval_rows:
106+
f.write(json.dumps(row.model_dump(mode="json", exclude_none=True), ensure_ascii=False) + "\n")
107+
out_path = os.path.abspath(f.name)
108+
try:
109+
rel = os.path.relpath(out_path, project_root)
110+
except Exception:
111+
rel = out_path
112+
print(f"✓ Transformed dataset via dataset_adapter into EvaluationRow JSONL: {rel} ({len(eval_rows)} rows)")
113+
return out_path
114+
115+
40116
def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]:
41117
"""Import the test module and extract a JSONL path from data_loaders param if present.
42118
@@ -45,18 +121,10 @@ def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) ->
45121
relative to the directory of the test file.
46122
"""
47123
try:
48-
import importlib.util
49-
from pathlib import Path
50-
51-
spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path)
52-
if not spec or not spec.loader:
53-
return None
54-
module = importlib.util.module_from_spec(spec)
55-
sys.modules[spec.name] = module
56-
spec.loader.exec_module(module) # type: ignore[attr-defined]
57-
if not hasattr(module, test_func_name):
124+
module = load_module_from_file_path(test_file_path)
125+
wrapper = getattr(module, test_func_name, None)
126+
if wrapper is None:
58127
return None
59-
wrapper = getattr(module, test_func_name)
60128
marks = getattr(wrapper, "pytestmark", [])
61129
for m in marks:
62130
if getattr(m, "name", "") == "parametrize":
@@ -105,18 +173,10 @@ def _extract_jsonl_from_input_dataset(test_file_path: str, test_func_name: str)
105173
of the test file.
106174
"""
107175
try:
108-
import importlib.util
109-
from pathlib import Path
110-
111-
spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path)
112-
if not spec or not spec.loader:
113-
return None
114-
module = importlib.util.module_from_spec(spec)
115-
sys.modules[spec.name] = module
116-
spec.loader.exec_module(module) # type: ignore[attr-defined]
117-
if not hasattr(module, test_func_name):
176+
module = load_module_from_file_path(test_file_path)
177+
wrapper = getattr(module, test_func_name, None)
178+
if wrapper is None:
118179
return None
119-
wrapper = getattr(module, test_func_name)
120180
marks = getattr(wrapper, "pytestmark", [])
121181
for m in marks:
122182
if getattr(m, "name", "") == "parametrize":
@@ -320,27 +380,15 @@ def _resolve_evaluator(
320380
selected_tests = _discover_and_select_tests(project_root, non_interactive=non_interactive)
321381
if not selected_tests:
322382
return None, None, None, None
383+
323384
if len(selected_tests) != 1:
324385
if non_interactive and len(selected_tests) > 1:
325386
print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.")
326387
print(" Please pass --evaluator or --entry to disambiguate.")
327-
try:
328-
# Offer candidate evaluator ids for convenience
329-
tests = _discover_tests(project_root)
330-
if tests:
331-
print(" Candidate evaluator ids:")
332-
for t in tests:
333-
func = t.qualname.split(".")[-1]
334-
stem = os.path.splitext(os.path.basename(t.file_path))[0]
335-
cand = _normalize_evaluator_id(f"{stem}-{func}")
336-
print(f" - {cand}")
337-
except Exception:
338-
pass
339388
else:
340389
print("Error: Please select exactly one evaluation test for 'create rft'.")
341390
return None, None, None, None
342391

343-
# Derive evaluator_id from user's single selection
344392
chosen = selected_tests[0]
345393
func_name = chosen.qualname.split(".")[-1]
346394
source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0]
@@ -719,6 +767,16 @@ def create_rft_command(args) -> int:
719767
if dataset_jsonl is None and not dataset_id:
720768
return 1
721769

770+
# 2.5) If the selected evaluation test provides a dataset_adapter, always use it to
771+
# construct the EvaluationRow dataset that we upload for RFT.
772+
if dataset_jsonl is not None:
773+
dataset_jsonl = _maybe_transform_dataset_jsonl_via_adapter(
774+
project_root=project_root,
775+
dataset_jsonl=dataset_jsonl,
776+
test_file_path=selected_test_file_path,
777+
test_func_name=selected_test_func_name,
778+
)
779+
722780
# 3) Optional local validation
723781
if not skip_validation:
724782
# Dataset validation (JSONL must be EvaluationRow-compatible when present)

eval_protocol/cli_commands/upload.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import argparse
22
from eval_protocol.cli_commands.utils import DiscoveredTest
3-
import importlib.util
43
import os
54
import re
65
import sys
@@ -18,6 +17,7 @@
1817
_discover_tests,
1918
_ensure_account_id,
2019
_get_questionary_style,
20+
load_module_from_file_path,
2121
_normalize_evaluator_id,
2222
_prompt_select,
2323
)
@@ -120,13 +120,8 @@ def _resolve_entry_to_qual_and_source(entry: str, cwd: str) -> tuple[str, str]:
120120
source_file_path = os.path.join(cwd, dotted_as_path)
121121

122122
# Load the module from the file path
123-
spec = importlib.util.spec_from_file_location(Path(source_file_path).stem, source_file_path)
124-
if not spec or not spec.loader:
125-
raise ValueError(f"Unable to load module from path: {source_file_path}")
126-
module = importlib.util.module_from_spec(spec)
127-
sys.modules[spec.name] = module
128-
spec.loader.exec_module(module) # type: ignore[attr-defined]
129-
module_name = spec.name
123+
module = load_module_from_file_path(source_file_path)
124+
module_name = getattr(module, "__name__", Path(source_file_path).stem)
130125

131126
if not hasattr(module, func):
132127
raise ValueError(f"Function '{func}' not found in module '{module_name}'")

eval_protocol/cli_commands/utils.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from types import ModuleType
2+
3+
14
import os
25
import ast
36
import sys
@@ -6,23 +9,46 @@
69
import argparse
710
import typing
811
import types
12+
import importlib.util
913
from dataclasses import dataclass
1014
from pathlib import Path
11-
from typing import Any, List, Optional, is_typeddict
15+
from typing import Any, List, Optional
1216
import typing_extensions
1317
import inspect
1418
from collections.abc import Callable
1519
import pytest
1620

1721
from ..auth import (
18-
get_fireworks_account_id,
1922
get_fireworks_api_base,
2023
get_fireworks_api_key,
2124
verify_api_key_and_get_account_id,
2225
)
2326
from ..fireworks_rft import _map_api_host_to_app_host
2427

2528

29+
def load_module_from_file_path(source_file_path: str) -> ModuleType:
30+
"""Load a Python module from an absolute/relative filesystem path.
31+
32+
This mirrors the CLI behavior used by `upload.py` and `create_rft.py`:
33+
- module name is derived from the file stem (e.g. /a/b/foo.py -> foo)
34+
- the module is inserted into sys.modules under that name before exec
35+
"""
36+
abs_path = os.path.abspath(source_file_path)
37+
if not os.path.isfile(abs_path):
38+
raise ValueError(f"File not found: {abs_path}")
39+
if not abs_path.endswith(".py"):
40+
raise ValueError(f"Expected a .py file path, got: {abs_path}")
41+
42+
module_name = Path(abs_path).stem
43+
spec = importlib.util.spec_from_file_location(module_name, abs_path)
44+
if not spec or not spec.loader:
45+
raise ValueError(f"Unable to load module from path: {abs_path}")
46+
module = importlib.util.module_from_spec(spec)
47+
sys.modules[spec.name] = module
48+
spec.loader.exec_module(module) # type: ignore[attr-defined]
49+
return module
50+
51+
2652
def _get_questionary_style():
2753
"""Get the shared questionary style for CLI prompts - minimal and clean."""
2854
try:
@@ -252,7 +278,7 @@ def _format_test_choice(test: DiscoveredTest, idx: int) -> str:
252278

253279

254280
def _prompt_select_interactive(tests: list[DiscoveredTest]) -> list[DiscoveredTest]:
255-
"""Interactive selection with arrow keys using questionary."""
281+
"""Interactive single selection with arrow keys using questionary (Enter selects highlighted)."""
256282
try:
257283
import questionary
258284

@@ -263,35 +289,32 @@ def _prompt_select_interactive(tests: list[DiscoveredTest]) -> list[DiscoveredTe
263289
print(f"\nFound 1 test: {_format_test_choice(tests[0], 1)}")
264290
confirm = questionary.confirm("Select this test?", default=True, style=custom_style).ask()
265291
if confirm:
266-
return tests
292+
return [tests[0]]
267293
else:
268294
return []
269295

270-
# Build checkbox choices
296+
# Build single-select choices
271297
choices = []
272298
for idx, t in enumerate(tests, 1):
273299
choice_text = _format_test_choice(t, idx)
274-
choices.append(questionary.Choice(title=choice_text, value=idx - 1, checked=False))
300+
choices.append(questionary.Choice(title=choice_text, value=idx - 1))
275301

276302
print()
277-
selected_indices = questionary.checkbox(
278-
"Select evaluation tests to upload:",
303+
selected_index = questionary.select(
304+
"Select an evaluation test:",
279305
choices=choices,
280306
style=custom_style,
281307
pointer=">",
282-
instruction="(↑↓ move, space select, enter confirm)",
308+
instruction="(↑↓ move, enter confirm)",
283309
).ask()
284310

285-
if selected_indices is None: # Ctrl+C
311+
if selected_index is None: # Ctrl+C / Esc
286312
print("\nUpload cancelled.")
287313
return []
288314

289-
if not selected_indices:
290-
return []
291-
292-
selected_tests = [tests[i] for i in selected_indices]
293-
print(f"\n✓ Selected {len(selected_tests)} test(s)")
294-
return selected_tests
315+
chosen = tests[int(selected_index)]
316+
print("\n✓ Selected 1 test")
317+
return [chosen]
295318

296319
except ImportError:
297320
# Fallback to simpler implementation
@@ -346,9 +369,10 @@ def _prompt_select_fallback(tests: list[DiscoveredTest]) -> list[DiscoveredTest]
346369

347370

348371
def _prompt_select(tests: list[DiscoveredTest], non_interactive: bool) -> list[DiscoveredTest]:
349-
"""Prompt user to select tests to upload."""
372+
"""Prompt user to select exactly one test."""
350373
if non_interactive:
351-
return tests
374+
# In non-interactive mode, only proceed if unambiguous.
375+
return [tests[0]] if len(tests) == 1 else []
352376

353377
return _prompt_select_interactive(tests)
354378

@@ -375,7 +399,16 @@ def _discover_and_select_tests(project_root: str, non_interactive: bool) -> Opti
375399
return None
376400

377401
if not selected_tests:
378-
print("No tests selected.")
402+
if non_interactive and len(tests) > 1:
403+
print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.")
404+
print(" Please pass --evaluator or --entry to disambiguate.")
405+
else:
406+
print("No test selected.")
407+
return None
408+
409+
# Enforce single-select at the helper level.
410+
if len(selected_tests) != 1:
411+
print("Error: Please select exactly one evaluation test.")
379412
return None
380413

381414
return selected_tests

0 commit comments

Comments
 (0)