|
1 | 1 | """Configuration for imitation.scripts.train_preference_comparisons.""" |
2 | 2 |
|
3 | 3 | import sacred |
| 4 | +import stable_baselines3 as sb3 |
4 | 5 |
|
5 | 6 | from imitation.algorithms import preference_comparisons |
| 7 | +from imitation.policies import base |
6 | 8 | from imitation.scripts.common import common, reward, rl, train |
7 | 9 |
|
8 | 10 | train_preference_comparisons_ex = sacred.Experiment( |
|
15 | 17 | ], |
16 | 18 | ) |
17 | 19 |
|
18 | | - |
19 | 20 | MUJOCO_SHARED_LOCALS = dict(rl=dict(rl_kwargs=dict(ent_coef=0.1))) |
20 | 21 | ANT_SHARED_LOCALS = dict( |
21 | 22 | total_timesteps=int(3e7), |
@@ -61,6 +62,26 @@ def train_defaults(): |
61 | 62 | query_schedule = "hyperbolic" |
62 | 63 |
|
63 | 64 |
|
@train_preference_comparisons_ex.named_config
def pebble():
    """Named config: SAC with SAC1024Policy on MountainCarContinuous-v0."""
    # share of total_timesteps used to train before preference gathering
    unsupervised_agent_pretrain_frac = 0.05
    pebble_nearest_neighbor_k = 5

    rl = dict(
        rl_cls=sb3.SAC,
        batch_size=256,  # batch size for the RL algorithm
        rl_kwargs=dict(batch_size=None),  # nested batch size must stay None
    )
    train = dict(policy_cls=base.SAC1024Policy)
    common = dict(env_name="MountainCarContinuous-v0")
    allow_variable_horizon = True

    # Referencing locals() marks every assignment above as used for flake8.
    locals()
| 83 | + |
| 84 | + |
64 | 85 | @train_preference_comparisons_ex.named_config |
65 | 86 | def cartpole(): |
66 | 87 | common = dict(env_name="CartPole-v1") |
@@ -121,6 +142,7 @@ def fast(): |
121 | 142 | total_timesteps = 50 |
122 | 143 | total_comparisons = 5 |
123 | 144 | initial_comparison_frac = 0.2 |
| 145 | + unsupervised_agent_pretrain_frac = 0.2 |
124 | 146 | num_iterations = 1 |
125 | 147 | fragment_length = 2 |
126 | 148 | reward_trainer_kwargs = { |
|
0 commit comments