@@ -188,38 +188,42 @@ def pc():
188188 parallel_run_config = dict (
189189 sacred_ex_name = "train_preference_comparisons" ,
190190 run_name = "pc_tuning" ,
191- base_named_configs = ["logging.wandb_logging" ],
191+ base_named_configs = [],
192192 base_config_updates = {
193193 "environment" : {"num_vec" : 1 },
194- "demonstrations" : {"source" : "huggingface" },
195194 "total_timesteps" : 2e7 ,
196- "total_comparisons" : 5000 ,
197- "query_schedule" : "hyperbolic" ,
198- "gatherer_kwargs" : {"sample" : True },
195+ "total_comparisons" : 1000 ,
196+ "active_selection" : True ,
199197 },
200198 search_space = {
201- "named_configs" : [
202- ["reward.normalize_output_disable" ],
203- ],
199+ "named_configs" : ["reward.reward_ensemble" ],
204200 "config_updates" : {
205- "train " : {
206- "policy_kwargs" : {
207- "activation_fn " : tune .choice (
208- [
209- nn . ReLU ,
210- ] ,
211- ),
212- } ,
201+ "active_selection_oversampling " : tune . randint ( 1 , 11 ),
202+ "comparison_queue_size" : tune . randint ( 1 , 1001 ), # upper bound determined by total_comparisons=1000
203+ "exploration_frac " : tune .uniform ( 0.0 , 0.5 ),
204+ "fragment_length" : tune . randint ( 1 , 1001 ), # trajectories are 1000 steps long
205+ "gatherer_kwargs" : {
206+ "temperature" : tune . uniform ( 0.0 , 2.0 ) ,
207+ "discount_factor" : tune . uniform ( 0.95 , 1.0 ),
208+ "sample" : tune . choice ([ True , False ]) ,
213209 },
214- "num_iterations" : tune .choice ([25 , 50 ]),
215- "initial_comparison_frac" : tune .choice ([0.1 , 0.25 ]),
210+ "initial_comparison_frac" : tune .uniform (0.01 , 1.0 ),
211+ "num_iterations" : tune .randint (1 , 51 ),
212+ "preference_model_kwargs" : {
213+ "noise_prob" : tune .uniform (0.0 , 0.1 ),
214+ "discount_factor" : tune .uniform (0.95 , 1.0 ),
215+ },
216+ "query_schedule" : tune .choice (["hyperbolic" , "constant" , "inverse_quadratic" ]),
217+ "trajectory_generator_kwargs" : {
218+ "switch_prob" : tune .uniform (0.1 , 1 ),
219+ "random_prob" : tune .uniform (0.1 , 0.9 ),
220+ },
221+ "transition_oversampling" : tune .uniform (0.9 , 2.0 ),
216222 "reward_trainer_kwargs" : {
217- "epochs" : tune .choice ([ 1 , 3 , 6 ] ),
223+ "epochs" : tune .randint ( 1 , 11 ),
218224 },
219225 "rl" : {
220- "batch_size" : tune .choice ([512 , 2048 , 8192 ]),
221226 "rl_kwargs" : {
222- "learning_rate" : tune .loguniform (1e-5 , 1e-2 ),
223227 "ent_coef" : tune .loguniform (1e-7 , 1e-3 ),
224228 },
225229 },
0 commit comments