Skip to content

Commit f099c33

Browse files
authored
Reduce training time and improve expert loading code in the tutorials (#810)
* Use normal env name when loading policies in the tutorial notebooks. * Use load_policy from the utils to load a policy in the density tutorial notebook. * Reduce the number of train steps in the FAST mode of the AIRL tutorial notebook.
1 parent 7b8b4bf commit f099c33

5 files changed

Lines changed: 13 additions & 11 deletions

File tree

docs/tutorials/1_train_bc.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
"expert = load_policy(\n",
4848
" \"ppo-huggingface\",\n",
4949
" organization=\"HumanCompatibleAI\",\n",
50-
" env_name=\"seals-CartPole-v0\",\n",
50+
" env_name=\"seals/CartPole-v0\",\n",
5151
" venv=env,\n",
5252
")"
5353
]

docs/tutorials/2_train_dagger.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
"expert = load_policy(\n",
4040
" \"ppo-huggingface\",\n",
4141
" organization=\"HumanCompatibleAI\",\n",
42-
" env_name=\"seals-CartPole-v0\",\n",
42+
" env_name=\"seals/CartPole-v0\",\n",
4343
" venv=env,\n",
4444
")"
4545
]

docs/tutorials/3_train_gail.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"expert = load_policy(\n",
4545
" \"ppo-huggingface\",\n",
4646
" organization=\"HumanCompatibleAI\",\n",
47-
" env_name=\"seals:seals/CartPole-v0\",\n",
47+
" env_name=\"seals/CartPole-v0\",\n",
4848
" venv=env,\n",
4949
")"
5050
]

docs/tutorials/4_train_airl.ipynb

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
"metadata": {},
2424
"outputs": [],
2525
"source": [
26-
"import seals # noqa: F401 # needed to load \"seals/\" environments\n",
2726
"import numpy as np\n",
2827
"from imitation.policies.serialize import load_policy\n",
2928
"from imitation.util.util import make_vec_env\n",
@@ -34,12 +33,12 @@
3433
"FAST = True\n",
3534
"\n",
3635
"if FAST:\n",
37-
" N_RL_TRAIN_STEPS = 300_000\n",
36+
" N_RL_TRAIN_STEPS = 100_000\n",
3837
"else:\n",
3938
" N_RL_TRAIN_STEPS = 2_000_000\n",
4039
"\n",
4140
"venv = make_vec_env(\n",
42-
" \"seals/CartPole-v0\",\n",
41+
" \"seals:seals/CartPole-v0\",\n",
4342
" rng=np.random.default_rng(SEED),\n",
4443
" n_envs=8,\n",
4544
" post_wrappers=[\n",
@@ -49,7 +48,7 @@
4948
"expert = load_policy(\n",
5049
" \"ppo-huggingface\",\n",
5150
" organization=\"HumanCompatibleAI\",\n",
52-
" env_name=\"seals-CartPole-v0\",\n",
51+
" env_name=\"seals/CartPole-v0\",\n",
5352
" venv=venv,\n",
5453
")"
5554
]

docs/tutorials/7_train_density.ipynb

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@
5656
"metadata": {},
5757
"outputs": [],
5858
"source": [
59+
"from imitation.policies.serialize import load_policy\n",
5960
"from stable_baselines3.common.policies import ActorCriticPolicy\n",
6061
"from stable_baselines3 import PPO\n",
61-
"from huggingface_sb3 import load_from_hub\n",
6262
"from imitation.data import rollout\n",
6363
"from stable_baselines3.common.vec_env import DummyVecEnv\n",
6464
"from stable_baselines3.common.evaluation import evaluate_policy\n",
@@ -70,12 +70,15 @@
7070
"\n",
7171
"rng = np.random.default_rng(seed=SEED)\n",
7272
"env_name = \"Pendulum-v1\"\n",
73-
"expert = PPO.load(\n",
74-
" load_from_hub(\"HumanCompatibleAI/ppo-Pendulum-v1\", \"ppo-Pendulum-v1.zip\")\n",
75-
").policy\n",
7673
"rollout_env = DummyVecEnv(\n",
7774
" [lambda: RolloutInfoWrapper(gym.make(env_name)) for _ in range(N_VEC)]\n",
7875
")\n",
76+
"expert = load_policy(\n",
77+
" \"ppo-huggingface\",\n",
78+
" organization=\"HumanCompatibleAI\",\n",
79+
" env_name=env_name,\n",
80+
" venv=rollout_env,\n",
81+
")\n",
7982
"rollouts = rollout.rollout(\n",
8083
" expert,\n",
8184
" rollout_env,\n",

0 commit comments

Comments
 (0)