Skip to content

Commit 6229298

Browse files
committed
make dataset work
1 parent 480fe2a commit 6229298

2 files changed

Lines changed: 267 additions & 10 deletions

File tree

eval_protocol/cli_commands/create_rft.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
create_dataset_from_jsonl,
2121
create_reinforcement_fine_tuning_job,
2222
)
23+
from ..fireworks_rft import detect_dataset_builder, materialize_dataset_via_builder
2324
from .upload import _discover_tests, _normalize_evaluator_id, _prompt_select
2425

2526

@@ -316,6 +317,9 @@ def create_rft_command(args) -> int:
316317
non_interactive: bool = bool(getattr(args, "yes", False))
317318
dry_run: bool = bool(getattr(args, "dry_run", False))
318319
force: bool = bool(getattr(args, "force", False))
320+
# Track the specifically chosen test (if any) to aid dataset inference later
321+
selected_test_file_path: Optional[str] = None
322+
selected_test_func_name: Optional[str] = None
319323

320324
api_key = get_fireworks_api_key()
321325
if not api_key:
@@ -354,6 +358,8 @@ def create_rft_command(args) -> int:
354358
func_name = chosen.qualname.split(".")[-1]
355359
source_file_name = os.path.splitext(os.path.basename(chosen.file_path))[0]
356360
evaluator_id = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
361+
selected_test_file_path = chosen.file_path
362+
selected_test_func_name = func_name
357363
# Resolve evaluator resource name to fully-qualified format required by API
358364
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
359365

@@ -404,6 +410,8 @@ def create_rft_command(args) -> int:
404410
except Exception:
405411
rel = abs_path
406412
selected_entry = f"{rel}::{func_name}"
413+
selected_test_file_path = tests[0].file_path
414+
selected_test_func_name = func_name
407415
else:
408416
# Try to match evaluator_id to a discovered test's normalized ID
409417
for t in tests:
@@ -417,6 +425,8 @@ def create_rft_command(args) -> int:
417425
except Exception:
418426
rel = abs_path
419427
selected_entry = f"{rel}::{func_name}"
428+
selected_test_file_path = t.file_path
429+
selected_test_func_name = func_name
420430
break
421431
# If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators
422432
if selected_entry is None and len(tests) > 1:
@@ -480,30 +490,48 @@ def create_rft_command(args) -> int:
480490
dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow
481491

482492
if not dataset_id:
483-
# Prefer explicit --dataset-jsonl, else attempt to extract from data loader or input_dataset of the single discovered test
493+
# Prefer explicit --dataset-jsonl, else attempt to extract from the selected test's data loader or input_dataset.
484494
if not dataset_jsonl:
485-
tests = _discover_tests(project_root)
486-
if len(tests) == 1:
487-
func_name = tests[0].qualname.split(".")[-1]
488-
# Try data_loaders first (existing behavior)
489-
dataset_jsonl = _extract_jsonl_from_dataloader(tests[0].file_path, func_name)
495+
# Use specifically selected test if available; else only infer when exactly one test exists
496+
test_file_for_infer = None
497+
func_for_infer = None
498+
if selected_test_file_path and selected_test_func_name:
499+
test_file_for_infer = selected_test_file_path
500+
func_for_infer = selected_test_func_name
501+
else:
502+
tests = _discover_tests(project_root)
503+
if len(tests) == 1:
504+
test_file_for_infer = tests[0].file_path
505+
func_for_infer = tests[0].qualname.split(".")[-1]
506+
if test_file_for_infer and func_for_infer:
507+
# Try data_loaders first
508+
dataset_jsonl = _extract_jsonl_from_dataloader(test_file_for_infer, func_for_infer)
490509
if dataset_jsonl:
491-
# Display relative path for readability
492510
try:
493511
rel = os.path.relpath(dataset_jsonl, project_root)
494512
except Exception:
495513
rel = dataset_jsonl
496514
print(f"✓ Using JSONL from data loader: {rel}")
497-
else:
515+
if not dataset_jsonl:
498516
# Fall back to input_dataset (dataset_path)
499-
dataset_jsonl = _extract_jsonl_from_input_dataset(tests[0].file_path, func_name)
517+
dataset_jsonl = _extract_jsonl_from_input_dataset(test_file_for_infer, func_for_infer)
500518
if dataset_jsonl:
501-
# Display relative path for readability
502519
try:
503520
rel = os.path.relpath(dataset_jsonl, project_root)
504521
except Exception:
505522
rel = dataset_jsonl
506523
print(f"✓ Using JSONL from input_dataset: {rel}")
524+
if not dataset_jsonl:
525+
# Last resort: attempt to detect and run a dataset builder in the test's directory
526+
metric_dir = os.path.dirname(test_file_for_infer)
527+
builder_spec = detect_dataset_builder(metric_dir)
528+
if builder_spec:
529+
try:
530+
tmp_jsonl, count = materialize_dataset_via_builder(builder_spec)
531+
dataset_jsonl = tmp_jsonl
532+
print(f"✓ Materialized {count} rows via dataset builder: {builder_spec}")
533+
except Exception as e:
534+
print(f"Warning: dataset builder failed: {e}")
507535
if not dataset_jsonl:
508536
print(
509537
"Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."

tests/test_cli_create_rft_infer.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,232 @@ def _raise(*a, **k):
386386

387387
rc = cr.create_rft_command(args)
388388
assert rc == 1
389+
390+
391+
def test_create_rft_fallback_to_dataset_builder(tmp_path, monkeypatch):
    """Check that a detected dataset builder supplies the JSONL as a last resort.

    The only discovered test module is a stub containing a single comment, and
    no dataset id or JSONL path is passed on the command line. Builder
    detection and materialization are stubbed, and the test asserts that the
    materialized file is the one handed to dataset creation.
    """
    import argparse

    # Project layout: <tmp>/proj/metric/test_builder.py
    project_dir = tmp_path / "proj"
    project_dir.mkdir()
    monkeypatch.chdir(project_dir)

    metric_dir = project_dir / "metric"
    metric_dir.mkdir(parents=True, exist_ok=True)
    test_path = metric_dir / "test_builder.py"
    test_path.write_text("# builder case", encoding="utf-8")

    # Discovery always reports exactly this one test.
    discovered = SimpleNamespace(qualname="metric.test_builder", file_path=str(test_path))
    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [discovered])

    # Fireworks credentials and endpoint come from the environment.
    fireworks_env = {
        "FIREWORKS_API_KEY": "fw_dummy",
        "FIREWORKS_ACCOUNT_ID": "acct123",
        "FIREWORKS_API_BASE": "https://api.fireworks.ai",
    }
    for env_name, env_value in fireworks_env.items():
        monkeypatch.setenv(env_name, env_value)

    # Replace the interactive selector, the evaluator upload and status polling.
    import eval_protocol.cli_commands.upload as upload_mod

    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
    monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)

    # The stubbed builder "materializes" a two-row JSONL file that already exists.
    builder_jsonl = metric_dir / "builder_out.jsonl"
    builder_jsonl.write_text('{"row":1}\n{"row":2}\n', encoding="utf-8")

    monkeypatch.setattr(cr, "detect_dataset_builder", lambda metric_dir: "builder.py::build_training_dataset")
    monkeypatch.setattr(cr, "materialize_dataset_via_builder", lambda spec: (str(builder_jsonl), 2))

    # Record what the command passes to dataset creation.
    seen = {"dataset_id": None, "jsonl_path": None}

    def _record_dataset_creation(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
        seen.update(dataset_id=dataset_id, jsonl_path=jsonl_path)
        resource = {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}
        return dataset_id, resource

    monkeypatch.setattr(cr, "create_dataset_from_jsonl", _record_dataset_creation)
    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})

    # No dataset inputs are given, so the command has to infer the dataset.
    unset_options = (
        "evaluator_id",
        "env_file",
        "dataset_id",
        "dataset_jsonl",
        "dataset_display_name",
        "dataset_builder",
        "base_model",
        "output_model",
        "n",
        "max_tokens",
        "learning_rate",
        "batch_size",
        "epochs",
        "lora_rank",
        "max_context_length",
        "chunk_size",
        "eval_auto_carveout",
    )
    cli_args = argparse.Namespace(
        yes=True,
        dry_run=False,
        force=False,
        warm_start_from="accounts/acct123/models/ft-abc123",
        **dict.fromkeys(unset_options),
    )

    exit_code = cr.create_rft_command(cli_args)
    assert exit_code == 0
    # Evaluator id derived from test_builder -> "test-builder-test-builder"
    assert seen["dataset_id"] is not None
    assert seen["dataset_id"].startswith("test-builder-test-builder-dataset-")
    # The materialized JSONL must be the file that was uploaded.
    assert seen["jsonl_path"] == str(builder_jsonl)
468+
469+
470+
def test_create_rft_uses_dataloader_jsonl_when_available(tmp_path, monkeypatch):
    """Check that a JSONL path reported by the data-loader extractor is used.

    The data-loader extractor is stubbed to return a path, while the
    input_dataset extractor and builder detection are stubbed to return
    nothing. The test asserts that the data-loader path is the one handed to
    dataset creation.
    """
    import argparse

    # Project layout: <tmp>/proj/metric/test_loader.py
    project_dir = tmp_path / "proj"
    project_dir.mkdir()
    monkeypatch.chdir(project_dir)

    metric_dir = project_dir / "metric"
    metric_dir.mkdir(parents=True, exist_ok=True)
    test_path = metric_dir / "test_loader.py"
    test_path.write_text("# loader case", encoding="utf-8")

    # Discovery always reports exactly this one test.
    discovered = SimpleNamespace(qualname="metric.test_loader", file_path=str(test_path))
    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [discovered])

    # Fireworks credentials and endpoint come from the environment.
    fireworks_env = {
        "FIREWORKS_API_KEY": "fw_dummy",
        "FIREWORKS_ACCOUNT_ID": "acct123",
        "FIREWORKS_API_BASE": "https://api.fireworks.ai",
    }
    for env_name, env_value in fireworks_env.items():
        monkeypatch.setenv(env_name, env_value)

    # Replace the interactive selector, the evaluator upload and status polling.
    import eval_protocol.cli_commands.upload as upload_mod

    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
    monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)

    # Only the data-loader extractor yields a JSONL path.
    loader_jsonl = metric_dir / "loader_out.jsonl"
    loader_jsonl.write_text('{"a":1}\n', encoding="utf-8")
    monkeypatch.setattr(cr, "_extract_jsonl_from_dataloader", lambda f, fn: str(loader_jsonl))
    monkeypatch.setattr(cr, "_extract_jsonl_from_input_dataset", lambda f, fn: None)
    monkeypatch.setattr(cr, "detect_dataset_builder", lambda metric_dir: None)

    # Record what the command passes to dataset creation.
    seen = {"dataset_id": None, "jsonl_path": None}

    def _record_dataset_creation(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
        seen.update(dataset_id=dataset_id, jsonl_path=jsonl_path)
        resource = {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}
        return dataset_id, resource

    monkeypatch.setattr(cr, "create_dataset_from_jsonl", _record_dataset_creation)
    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})

    # No dataset inputs are given, so the command has to infer the dataset.
    unset_options = (
        "evaluator_id",
        "env_file",
        "dataset_id",
        "dataset_jsonl",
        "dataset_display_name",
        "dataset_builder",
        "base_model",
        "output_model",
        "n",
        "max_tokens",
        "learning_rate",
        "batch_size",
        "epochs",
        "lora_rank",
        "max_context_length",
        "chunk_size",
        "eval_auto_carveout",
    )
    cli_args = argparse.Namespace(
        yes=True,
        dry_run=False,
        force=False,
        warm_start_from="accounts/acct123/models/ft-abc123",
        **dict.fromkeys(unset_options),
    )

    exit_code = cr.create_rft_command(cli_args)
    assert exit_code == 0
    assert seen["dataset_id"] is not None
    assert seen["dataset_id"].startswith("test-loader-test-loader-dataset-")
    assert seen["jsonl_path"] == str(loader_jsonl)
543+
544+
545+
def test_create_rft_uses_input_dataset_jsonl_when_available(tmp_path, monkeypatch):
    """Check that the input_dataset extractor's JSONL path is used as a fallback.

    The data-loader extractor and builder detection are stubbed to return
    nothing, while the input_dataset extractor is stubbed to return a path.
    The test asserts that this path is the one handed to dataset creation.
    """
    import argparse

    # Project layout: <tmp>/proj/metric/test_input_ds.py
    project_dir = tmp_path / "proj"
    project_dir.mkdir()
    monkeypatch.chdir(project_dir)

    metric_dir = project_dir / "metric"
    metric_dir.mkdir(parents=True, exist_ok=True)
    test_path = metric_dir / "test_input_ds.py"
    test_path.write_text("# input_dataset case", encoding="utf-8")

    # Discovery always reports exactly this one test.
    discovered = SimpleNamespace(qualname="metric.test_input_ds", file_path=str(test_path))
    monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [discovered])

    # Fireworks credentials and endpoint come from the environment.
    fireworks_env = {
        "FIREWORKS_API_KEY": "fw_dummy",
        "FIREWORKS_ACCOUNT_ID": "acct123",
        "FIREWORKS_API_BASE": "https://api.fireworks.ai",
    }
    for env_name, env_value in fireworks_env.items():
        monkeypatch.setenv(env_name, env_value)

    # Replace the interactive selector, the evaluator upload and status polling.
    import eval_protocol.cli_commands.upload as upload_mod

    monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
    monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
    monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)

    # Only the input_dataset extractor yields a JSONL path.
    input_ds_jsonl = metric_dir / "input_ds_out.jsonl"
    input_ds_jsonl.write_text('{"b":2}\n', encoding="utf-8")
    monkeypatch.setattr(cr, "_extract_jsonl_from_dataloader", lambda f, fn: None)
    monkeypatch.setattr(cr, "_extract_jsonl_from_input_dataset", lambda f, fn: str(input_ds_jsonl))
    monkeypatch.setattr(cr, "detect_dataset_builder", lambda metric_dir: None)

    # Record what the command passes to dataset creation.
    seen = {"dataset_id": None, "jsonl_path": None}

    def _record_dataset_creation(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
        seen.update(dataset_id=dataset_id, jsonl_path=jsonl_path)
        resource = {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}
        return dataset_id, resource

    monkeypatch.setattr(cr, "create_dataset_from_jsonl", _record_dataset_creation)
    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})

    # No dataset inputs are given, so the command has to infer the dataset.
    unset_options = (
        "evaluator_id",
        "env_file",
        "dataset_id",
        "dataset_jsonl",
        "dataset_display_name",
        "dataset_builder",
        "base_model",
        "output_model",
        "n",
        "max_tokens",
        "learning_rate",
        "batch_size",
        "epochs",
        "lora_rank",
        "max_context_length",
        "chunk_size",
        "eval_auto_carveout",
    )
    cli_args = argparse.Namespace(
        yes=True,
        dry_run=False,
        force=False,
        warm_start_from="accounts/acct123/models/ft-abc123",
        **dict.fromkeys(unset_options),
    )

    exit_code = cr.create_rft_command(cli_args)
    assert exit_code == 0
    assert seen["dataset_id"] is not None
    assert seen["dataset_id"].startswith("test-input-ds-test-input-ds-dataset-")
    assert seen["jsonl_path"] == str(input_ds_jsonl)

0 commit comments

Comments
 (0)