Skip to content

Commit 6ecacb8

Browse files
committed
add test on flags
1 parent 3b23406 commit 6ecacb8

1 file changed

Lines changed: 136 additions & 0 deletions

File tree

tests/test_cli_create_rft_infer.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,142 @@ def _write_json(path: str, data: dict) -> None:
1515
json.dump(data, f)
1616

1717

18+
def test_create_rft_passes_all_flags_into_request_body(tmp_path, monkeypatch):
19+
# Isolate HOME and CWD
20+
monkeypatch.setenv("HOME", str(tmp_path / "home"))
21+
project = tmp_path / "proj"
22+
project.mkdir()
23+
monkeypatch.chdir(project)
24+
25+
# Environment required by command
26+
monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
27+
monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
28+
monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
29+
30+
# Provide dataset via --dataset-jsonl
31+
ds_path = project / "dataset.jsonl"
32+
ds_path.write_text('{"input":"x"}\n', encoding="utf-8")
33+
34+
# Skip upload: pretend evaluator exists and is ACTIVE
35+
class _Resp:
36+
ok = True
37+
38+
def json(self):
39+
return {"state": "ACTIVE"}
40+
41+
def raise_for_status(self):
42+
return None
43+
44+
monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp())
45+
46+
# Capture dataset creation inputs but let it succeed
47+
monkeypatch.setattr(
48+
cr,
49+
"create_dataset_from_jsonl",
50+
lambda account_id, api_key, api_base, dataset_id, display_name, jsonl_path: (
51+
dataset_id,
52+
{"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"},
53+
),
54+
)
55+
56+
captured = {"body": None}
57+
58+
def _fake_create_job(account_id, api_key, api_base, body):
59+
captured["body"] = body
60+
return {"name": f"accounts/{account_id}/reinforcementFineTuningJobs/xyz"}
61+
62+
monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", _fake_create_job)
63+
64+
import argparse
65+
66+
args = argparse.Namespace(
67+
# Evaluator and dataset
68+
evaluator_id="my-evaluator",
69+
dataset_id=None,
70+
dataset_jsonl=str(ds_path),
71+
dataset_display_name="My Dataset",
72+
dataset_builder=None,
73+
# Modes
74+
yes=True,
75+
dry_run=False,
76+
force=False,
77+
env_file=None,
78+
# Model selection (exactly one)
79+
base_model="accounts/fireworks/models/llama-v3p1-8b-instruct",
80+
warm_start_from=None,
81+
output_model="my-output-model",
82+
# Training config
83+
epochs=3,
84+
batch_size=65536,
85+
learning_rate=5e-5,
86+
lora_rank=32,
87+
max_context_length=131072,
88+
accelerator_count=4,
89+
region="us-east4",
90+
# Inference params
91+
temperature=0.9,
92+
top_p=0.95,
93+
top_k=50,
94+
max_tokens=4096,
95+
n=6,
96+
inference_extra_body='{"foo":"bar"}',
97+
# Rollout chunking and eval carveout
98+
chunk_size=250,
99+
eval_auto_carveout=False, # explicitly disabled via --no-eval-auto-carveout
100+
evaluation_dataset="accounts/acct123/datasets/eval-ds",
101+
# W&B
102+
wandb_enabled=True,
103+
wandb_project="proj",
104+
wandb_entity="ent",
105+
wandb_run_id="run123",
106+
wandb_api_key="key123",
107+
# Unused in body but accepted by parser
108+
rft_job_id=None,
109+
display_name=None,
110+
)
111+
112+
rc = cr.create_rft_command(args)
113+
assert rc == 0
114+
assert captured["body"] is not None
115+
body = captured["body"]
116+
117+
# Top-level fields
118+
assert body["dataset"].endswith("/datasets/" + body["dataset"].split("/")[-1])
119+
assert body["evaluator"].endswith("/evaluators/my-evaluator")
120+
assert body["chunkSize"] == 250
121+
assert body["evalAutoCarveout"] is False
122+
assert body["evaluationDataset"] == "accounts/acct123/datasets/eval-ds"
123+
124+
# Training config mapping
125+
tc = body["trainingConfig"]
126+
assert tc["baseModel"] == "accounts/fireworks/models/llama-v3p1-8b-instruct"
127+
assert tc["outputModel"] == "accounts/acct123/models/my-output-model"
128+
assert tc["epochs"] == 3
129+
assert tc["batchSize"] == 65536
130+
assert abs(tc["learningRate"] - 5e-5) < 1e-12
131+
assert tc["loraRank"] == 32
132+
assert tc["maxContextLength"] == 131072
133+
assert tc["acceleratorCount"] == 4
134+
assert tc["region"] == "us-east4"
135+
136+
# Inference params mapping
137+
ip = body["inferenceParameters"]
138+
assert abs(ip["temperature"] - 0.9) < 1e-12
139+
assert abs(ip["topP"] - 0.95) < 1e-12
140+
assert ip["topK"] == 50
141+
assert ip["maxTokens"] == 4096
142+
assert ip["n"] == 6
143+
assert ip["extraBody"] == '{"foo":"bar"}'
144+
145+
# W&B mapping
146+
wb = body["wandbConfig"]
147+
assert wb["enabled"] is True
148+
assert wb["project"] == "proj"
149+
assert wb["entity"] == "ent"
150+
assert wb["runId"] == "run123"
151+
assert wb["apiKey"] == "key123"
152+
153+
18154
def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows(tmp_path, monkeypatch):
19155
# Isolate HOME so expanduser paths remain inside tmp
20156
monkeypatch.setenv("HOME", str(tmp_path / "home"))

0 commit comments

Comments
 (0)