Skip to content

Commit 2459817

Browse files
committed
Fix and expand hyperparameter search space for PC.
1 parent d74e903 commit 2459817

1 file changed

Lines changed: 25 additions & 21 deletions

File tree

src/imitation/scripts/config/tuning.py

Lines changed: 25 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -188,38 +188,42 @@ def pc():
188188
parallel_run_config = dict(
189189
sacred_ex_name="train_preference_comparisons",
190190
run_name="pc_tuning",
191-
base_named_configs=["logging.wandb_logging"],
191+
base_named_configs=[],
192192
base_config_updates={
193193
"environment": {"num_vec": 1},
194-
"demonstrations": {"source": "huggingface"},
195194
"total_timesteps": 2e7,
196-
"total_comparisons": 5000,
197-
"query_schedule": "hyperbolic",
198-
"gatherer_kwargs": {"sample": True},
195+
"total_comparisons": 1000,
196+
"active_selection": True,
199197
},
200198
search_space={
201-
"named_configs": [
202-
["reward.normalize_output_disable"],
203-
],
199+
"named_configs": ["reward.reward_ensemble"],
204200
"config_updates": {
205-
"train": {
206-
"policy_kwargs": {
207-
"activation_fn": tune.choice(
208-
[
209-
nn.ReLU,
210-
],
211-
),
212-
},
201+
"active_selection_oversampling": tune.randint(1, 11),
202+
"comparison_queue_size": tune.randint(1, 1001), # upper bound determined by total_comparisons=1000
203+
"exploration_frac": tune.uniform(0.0, 0.5),
204+
"fragment_length": tune.randint(1, 1001), # trajectories are 1000 steps long
205+
"gatherer_kwargs": {
206+
"temperature": tune.uniform(0.0, 2.0),
207+
"discount_factor": tune.uniform(0.95, 1.0),
208+
"sample": tune.choice([True, False]),
213209
},
214-
"num_iterations": tune.choice([25, 50]),
215-
"initial_comparison_frac": tune.choice([0.1, 0.25]),
210+
"initial_comparison_frac": tune.uniform(0.01, 1.0),
211+
"num_iterations": tune.randint(1, 51),
212+
"preference_model_kwargs": {
213+
"noise_prob": tune.uniform(0.0, 0.1),
214+
"discount_factor": tune.uniform(0.95, 1.0),
215+
},
216+
"query_schedule": tune.choice(["hyperbolic", "constant", "inverse_quadratic"]),
217+
"trajectory_generator_kwargs": {
218+
"switch_prob": tune.uniform(0.1, 1),
219+
"random_prob": tune.uniform(0.1, 0.9),
220+
},
221+
"transition_oversampling": tune.uniform(0.9, 2.0),
216222
"reward_trainer_kwargs": {
217-
"epochs": tune.choice([1, 3, 6]),
223+
"epochs": tune.randint(1, 11),
218224
},
219225
"rl": {
220-
"batch_size": tune.choice([512, 2048, 8192]),
221226
"rl_kwargs": {
222-
"learning_rate": tune.loguniform(1e-5, 1e-2),
223227
"ent_coef": tune.loguniform(1e-7, 1e-3),
224228
},
225229
},

0 commit comments

Comments
 (0)