|
56 | 56 | "metadata": {}, |
57 | 57 | "outputs": [], |
58 | 58 | "source": [ |
| 59 | + "from imitation.policies.serialize import load_policy\n", |
59 | 60 | "from stable_baselines3.common.policies import ActorCriticPolicy\n", |
60 | 61 | "from stable_baselines3 import PPO\n", |
61 | | - "from huggingface_sb3 import load_from_hub\n", |
62 | 62 | "from imitation.data import rollout\n", |
63 | 63 | "from stable_baselines3.common.vec_env import DummyVecEnv\n", |
64 | 64 | "from stable_baselines3.common.evaluation import evaluate_policy\n", |
|
70 | 70 | "\n", |
71 | 71 | "rng = np.random.default_rng(seed=SEED)\n", |
72 | 72 | "env_name = \"Pendulum-v1\"\n", |
73 | | - "expert = PPO.load(\n", |
74 | | - " load_from_hub(\"HumanCompatibleAI/ppo-Pendulum-v1\", \"ppo-Pendulum-v1.zip\")\n", |
75 | | - ").policy\n", |
76 | 73 | "rollout_env = DummyVecEnv(\n", |
77 | 74 | " [lambda: RolloutInfoWrapper(gym.make(env_name)) for _ in range(N_VEC)]\n", |
78 | 75 | ")\n", |
| 76 | + "expert = load_policy(\n", |
| 77 | + " \"ppo-huggingface\",\n", |
| 78 | + " organization=\"HumanCompatibleAI\",\n", |
| 79 | + " env_name=env_name,\n", |
| 80 | + " venv=rollout_env,\n", |
| 81 | + ")\n", |
79 | 82 | "rollouts = rollout.rollout(\n", |
80 | 83 | " expert,\n", |
81 | 84 | " rollout_env,\n", |
|
0 commit comments