@@ -15,6 +15,142 @@ def _write_json(path: str, data: dict) -> None:
1515 json .dump (data , f )
1616
1717
18+ def test_create_rft_passes_all_flags_into_request_body (tmp_path , monkeypatch ):
19+ # Isolate HOME and CWD
20+ monkeypatch .setenv ("HOME" , str (tmp_path / "home" ))
21+ project = tmp_path / "proj"
22+ project .mkdir ()
23+ monkeypatch .chdir (project )
24+
25+ # Environment required by command
26+ monkeypatch .setenv ("FIREWORKS_API_KEY" , "fw_dummy" )
27+ monkeypatch .setenv ("FIREWORKS_ACCOUNT_ID" , "acct123" )
28+ monkeypatch .setenv ("FIREWORKS_API_BASE" , "https://api.fireworks.ai" )
29+
30+ # Provide dataset via --dataset-jsonl
31+ ds_path = project / "dataset.jsonl"
32+ ds_path .write_text ('{"input":"x"}\n ' , encoding = "utf-8" )
33+
34+ # Skip upload: pretend evaluator exists and is ACTIVE
35+ class _Resp :
36+ ok = True
37+
38+ def json (self ):
39+ return {"state" : "ACTIVE" }
40+
41+ def raise_for_status (self ):
42+ return None
43+
44+ monkeypatch .setattr (cr .requests , "get" , lambda * a , ** k : _Resp ())
45+
46+ # Capture dataset creation inputs but let it succeed
47+ monkeypatch .setattr (
48+ cr ,
49+ "create_dataset_from_jsonl" ,
50+ lambda account_id , api_key , api_base , dataset_id , display_name , jsonl_path : (
51+ dataset_id ,
52+ {"name" : f"accounts/{ account_id } /datasets/{ dataset_id } " , "state" : "UPLOADING" },
53+ ),
54+ )
55+
56+ captured = {"body" : None }
57+
58+ def _fake_create_job (account_id , api_key , api_base , body ):
59+ captured ["body" ] = body
60+ return {"name" : f"accounts/{ account_id } /reinforcementFineTuningJobs/xyz" }
61+
62+ monkeypatch .setattr (cr , "create_reinforcement_fine_tuning_job" , _fake_create_job )
63+
64+ import argparse
65+
66+ args = argparse .Namespace (
67+ # Evaluator and dataset
68+ evaluator_id = "my-evaluator" ,
69+ dataset_id = None ,
70+ dataset_jsonl = str (ds_path ),
71+ dataset_display_name = "My Dataset" ,
72+ dataset_builder = None ,
73+ # Modes
74+ yes = True ,
75+ dry_run = False ,
76+ force = False ,
77+ env_file = None ,
78+ # Model selection (exactly one)
79+ base_model = "accounts/fireworks/models/llama-v3p1-8b-instruct" ,
80+ warm_start_from = None ,
81+ output_model = "my-output-model" ,
82+ # Training config
83+ epochs = 3 ,
84+ batch_size = 65536 ,
85+ learning_rate = 5e-5 ,
86+ lora_rank = 32 ,
87+ max_context_length = 131072 ,
88+ accelerator_count = 4 ,
89+ region = "us-east4" ,
90+ # Inference params
91+ temperature = 0.9 ,
92+ top_p = 0.95 ,
93+ top_k = 50 ,
94+ max_tokens = 4096 ,
95+ n = 6 ,
96+ inference_extra_body = '{"foo":"bar"}' ,
97+ # Rollout chunking and eval carveout
98+ chunk_size = 250 ,
99+ eval_auto_carveout = False , # explicitly disabled via --no-eval-auto-carveout
100+ evaluation_dataset = "accounts/acct123/datasets/eval-ds" ,
101+ # W&B
102+ wandb_enabled = True ,
103+ wandb_project = "proj" ,
104+ wandb_entity = "ent" ,
105+ wandb_run_id = "run123" ,
106+ wandb_api_key = "key123" ,
107+ # Unused in body but accepted by parser
108+ rft_job_id = None ,
109+ display_name = None ,
110+ )
111+
112+ rc = cr .create_rft_command (args )
113+ assert rc == 0
114+ assert captured ["body" ] is not None
115+ body = captured ["body" ]
116+
117+ # Top-level fields
118+ assert body ["dataset" ].endswith ("/datasets/" + body ["dataset" ].split ("/" )[- 1 ])
119+ assert body ["evaluator" ].endswith ("/evaluators/my-evaluator" )
120+ assert body ["chunkSize" ] == 250
121+ assert body ["evalAutoCarveout" ] is False
122+ assert body ["evaluationDataset" ] == "accounts/acct123/datasets/eval-ds"
123+
124+ # Training config mapping
125+ tc = body ["trainingConfig" ]
126+ assert tc ["baseModel" ] == "accounts/fireworks/models/llama-v3p1-8b-instruct"
127+ assert tc ["outputModel" ] == "accounts/acct123/models/my-output-model"
128+ assert tc ["epochs" ] == 3
129+ assert tc ["batchSize" ] == 65536
130+ assert abs (tc ["learningRate" ] - 5e-5 ) < 1e-12
131+ assert tc ["loraRank" ] == 32
132+ assert tc ["maxContextLength" ] == 131072
133+ assert tc ["acceleratorCount" ] == 4
134+ assert tc ["region" ] == "us-east4"
135+
136+ # Inference params mapping
137+ ip = body ["inferenceParameters" ]
138+ assert abs (ip ["temperature" ] - 0.9 ) < 1e-12
139+ assert abs (ip ["topP" ] - 0.95 ) < 1e-12
140+ assert ip ["topK" ] == 50
141+ assert ip ["maxTokens" ] == 4096
142+ assert ip ["n" ] == 6
143+ assert ip ["extraBody" ] == '{"foo":"bar"}'
144+
145+ # W&B mapping
146+ wb = body ["wandbConfig" ]
147+ assert wb ["enabled" ] is True
148+ assert wb ["project" ] == "proj"
149+ assert wb ["entity" ] == "ent"
150+ assert wb ["runId" ] == "run123"
151+ assert wb ["apiKey" ] == "key123"
152+
153+
18154def test_create_rft_picks_most_recent_evaluator_and_dataset_id_follows (tmp_path , monkeypatch ):
19155 # Isolate HOME so expanduser paths remain inside tmp
20156 monkeypatch .setenv ("HOME" , str (tmp_path / "home" ))
0 commit comments