Skip to content

Commit e6cbe86

Browse files
committed
add new test to verify dataset id and fix code
1 parent 64efd1c commit e6cbe86

2 files changed

Lines changed: 194 additions & 6 deletions

File tree

eval_protocol/cli_commands/create_rft.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -446,9 +446,6 @@ def create_rft_command(args) -> int:
446446
if not evaluator_id:
447447
print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
448448
return 1
449-
# Persist last selected/used evaluator for next runs
450-
_save_last_evaluator(project_root, evaluator_id)
451-
452449
# Resolve evaluator resource name to fully-qualified format required by API
453450
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
454451

@@ -515,6 +512,9 @@ def create_rft_command(args) -> int:
515512
print(f"📊 Please check the evaluator status at: {dashboard_url}")
516513
print(" Wait for it to become ACTIVE, then run 'eval-protocol create rft' again.")
517514
return 1
515+
else:
516+
# Only persist last-used evaluator after successful ensure + ACTIVE
517+
_save_last_evaluator(project_root, evaluator_id)
518518
else:
519519
print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.")
520520
except Exception as e:

tests/test_cli_create_rft_infer.py

Lines changed: 191 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,17 @@ def test_auto_select_falls_back_to_single_discovered_test(tmp_path, monkeypatch)
102102
project.mkdir()
103103

104104
# No traces; provide exactly one discovered test
105-
test_file = project / "metric" / "test_calendar.py"
105+
test_file = project / "metric" / "test_dummy.py"
106106
test_file.parent.mkdir(parents=True, exist_ok=True)
107107
test_file.write_text("# dummy", encoding="utf-8")
108108

109-
dummy = SimpleNamespace(qualname="calendar_agent.test_calendar_agent_evaluation", file_path=str(test_file))
109+
dummy = SimpleNamespace(qualname="dummy_module.test_dummy_evaluation", file_path=str(test_file))
110110
monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [dummy])
111111

112112
eid = cr._auto_select_evaluator_id(str(project))
113113
assert eid is not None
114114
# Should incorporate function name suffix
115-
assert "test_calendar_agent_evaluation".split("_")[-1] in eid or "test-calendar-agent-evaluation" in eid
115+
assert "test_dummy_evaluation".split("_")[-1] in eid or "test-dummy-evaluation" in eid
116116

117117

118118
def test_auto_select_returns_none_when_no_candidates(tmp_path, monkeypatch):
@@ -124,3 +124,191 @@ def test_auto_select_returns_none_when_no_candidates(tmp_path, monkeypatch):
124124
monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [])
125125
eid = cr._auto_select_evaluator_id(str(project))
126126
assert eid is None
127+
128+
129+
def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(tmp_path, monkeypatch):
    """Check that the generated dataset id follows the most recent evaluator trace.

    Two evaluator trace files with different mtimes are planted under
    ``.eval_protocol/evaluators`` and no ``evaluator_id`` is given on the
    command line. The dataset id handed to ``create_dataset_from_jsonl`` must
    then start with the id of the newer trace (``example-eval-2``).
    """
    # Function-scope import, mirroring the sibling multi-test case in this file.
    import argparse

    # Isolate HOME so expanduser paths remain inside tmp
    monkeypatch.setenv("HOME", str(tmp_path / "home"))

    # Create a fake project and chdir into it (create_rft uses os.getcwd())
    project = tmp_path / "proj"
    project.mkdir()
    monkeypatch.chdir(project)

    # Prepare two evaluator traces with different mtimes
    traces_dir = project / ".eval_protocol" / "evaluators"
    traces_dir.mkdir(parents=True, exist_ok=True)
    older = traces_dir / "example-eval-1.json"
    newer = traces_dir / "example-eval-2.json"
    older.write_text("{}", encoding="utf-8")
    newer.write_text("{}", encoding="utf-8")
    t0 = time.time() - 200
    os.utime(str(older), (t0, t0))
    t1 = time.time()
    os.utime(str(newer), (t1, t1))

    # Create a dummy dataset jsonl file
    ds_path = project / "evaluator" / "dummy_dataset.jsonl"
    ds_path.parent.mkdir(parents=True, exist_ok=True)
    ds_path.write_text('{"input":"x"}\n', encoding="utf-8")

    # Env required by create_rft_command
    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")

    # Stub out networked/subcommands used by create_rft
    # Patch upload command in its own module (create_rft imports it at call time)
    import eval_protocol.cli_commands.upload as upload_mod

    monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)

    captured = {"dataset_id": None}

    def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
        # Record the id the command derived so it can be asserted on below.
        captured["dataset_id"] = dataset_id
        return dataset_id, {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}

    monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl)
    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})

    # Build args: non_interactive (yes=True), no explicit evaluator_id, valid warm_start_from.
    # argparse.Namespace replaces the ad-hoc class + setattr chain and matches
    # how the sibling test builds the same argument set.
    args = argparse.Namespace(
        evaluator_id=None,
        yes=True,
        dry_run=False,
        force=False,
        env_file=None,
        dataset_id=None,
        dataset_jsonl=str(ds_path),
        dataset_display_name=None,
        dataset_builder=None,
        base_model=None,
        warm_start_from="accounts/acct123/models/ft-abc123",
        output_model=None,
        n=None,
        max_tokens=None,
        learning_rate=None,
        batch_size=None,
        epochs=None,
        lora_rank=None,
        max_context_length=None,
        chunk_size=None,
        eval_auto_carveout=None,
    )

    rc = cr.create_rft_command(args)
    assert rc == 0

    # Assert dataset id followed the most recent evaluator id ("example-eval-2")
    assert captured["dataset_id"] is not None
    assert captured["dataset_id"].startswith("example-eval-2-dataset-")
206+
207+
208+
def test_create_rft_passes_matching_evaluator_id_and_entry_with_multiple_tests(tmp_path, monkeypatch):
    """Check what upload receives when two tests are discovered.

    The most recent evaluator trace is named after the normalized id for
    ``foo_eval-test_bar_evaluation``. With no explicit ``evaluator_id``, the
    stubbed upload command must receive that id together with an entry ending
    in ``foo_eval.py::test_bar_evaluation``, and the dataset id must be built
    from the same evaluator id.
    """
    import argparse

    # Keep expanduser() lookups inside the temporary tree.
    monkeypatch.setenv("HOME", str(tmp_path / "home"))

    # Fake project that also becomes the working directory.
    project = tmp_path / "proj"
    project.mkdir()
    monkeypatch.chdir(project)

    # Two evaluator traces; the target one gets the later mtime.
    trace_dir = project / ".eval_protocol" / "evaluators"
    trace_dir.mkdir(parents=True, exist_ok=True)
    stale_id = "example-svg-evaluation"
    # Id chosen so the normalization logic maps it to foo_eval.py::test_bar_evaluation.
    target_id = cr._normalize_evaluator_id("foo_eval-test_bar_evaluation")
    stale_trace = trace_dir / f"{stale_id}.json"
    fresh_trace = trace_dir / f"{target_id}.json"
    for trace in (stale_trace, fresh_trace):
        trace.write_text("{}", encoding="utf-8")
    stale_mtime = time.time() - 200
    os.utime(str(stale_trace), (stale_mtime, stale_mtime))
    fresh_mtime = time.time()
    os.utime(str(fresh_trace), (fresh_mtime, fresh_mtime))

    # Placeholder test modules backing the fake discovery results.
    eval_dir = project / "evaluator"
    eval_dir.mkdir(parents=True, exist_ok=True)
    foo_file = eval_dir / "foo_eval.py"
    bar_file = eval_dir / "bar_eval.py"
    foo_file.write_text("# foo", encoding="utf-8")
    bar_file.write_text("# bar", encoding="utf-8")

    # Discovery reports both tests: foo first, then bar.
    discovered = (
        SimpleNamespace(qualname="foo_eval.test_bar_evaluation", file_path=str(foo_file)),
        SimpleNamespace(qualname="bar_eval.test_baz_evaluation", file_path=str(bar_file)),
    )
    # Hand out a fresh list on every call, as the original inline list literal did.
    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: list(discovered))

    # Environment the CLI reads.
    for env_name, env_value in (
        ("FIREWORKS_API_KEY", "fw_dummy"),
        ("FIREWORKS_ACCOUNT_ID", "acct123"),
        ("FIREWORKS_API_BASE", "https://api.fireworks.ai"),
    ):
        monkeypatch.setenv(env_name, env_value)

    # Everything the stubs observe is collected here.
    seen = {"id": None, "entry": None, "dataset_id": None}

    # The upload command is patched in its own module (the function imports it inside).
    import eval_protocol.cli_commands.upload as upload_mod

    def _record_upload(ns):
        seen["id"] = getattr(ns, "id", None)
        seen["entry"] = getattr(ns, "entry", None)
        return 0

    monkeypatch.setattr(upload_mod, "upload_command", _record_upload)

    # No network: evaluator polling reports success immediately.
    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)

    def _record_dataset(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
        seen["dataset_id"] = dataset_id
        return dataset_id, {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}

    monkeypatch.setattr(cr, "create_dataset_from_jsonl", _record_dataset)
    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})

    # A dataset jsonl is supplied so the flow proceeds to dataset creation.
    ds_path = eval_dir / "dummy_dataset.jsonl"
    ds_path.write_text('{"input":"x"}\n', encoding="utf-8")

    # Non-interactive invocation without an explicit evaluator id: start with
    # every option unset, then fill in the few that carry a value.
    cli_options = dict.fromkeys(
        (
            "evaluator_id",
            "env_file",
            "dataset_id",
            "dataset_display_name",
            "dataset_builder",
            "base_model",
            "output_model",
            "n",
            "max_tokens",
            "learning_rate",
            "batch_size",
            "epochs",
            "lora_rank",
            "max_context_length",
            "chunk_size",
            "eval_auto_carveout",
        )
    )
    cli_options.update(
        yes=True,
        dry_run=False,
        force=False,
        dataset_jsonl=str(ds_path),
        warm_start_from="accounts/acct123/models/ft-abc123",
    )
    args = argparse.Namespace(**cli_options)

    assert cr.create_rft_command(args) == 0

    # Upload must receive the id of the most recent trace (the target).
    assert seen["id"] == target_id
    # The entry must point at the foo test (mapped because the id matches normalization).
    assert seen["entry"] is not None
    assert seen["entry"].endswith("foo_eval.py::test_bar_evaluation")
    # Dataset id comes from the same evaluator id (trimmed base + '-dataset-<timestamp>').
    assert seen["dataset_id"] is not None
    trimmed_base = cr._build_trimmed_dataset_id(target_id).partition("-dataset-")[0]
    assert seen["dataset_id"].startswith(f"{trimmed_base}-dataset-")

0 commit comments

Comments
 (0)