Skip to content

Commit 681cb72

Browse files
Add support to download rollouts from huggingface (#629)
* Add support to download rollouts from huggingface * Add check for rollout_path * Fix type error * Add test cases * Incorporate reviewer's suggestions * Change pkl to npz * Add input validation for rollout_type * Fix error * Update src/imitation/scripts/common/demonstrations.py Co-authored-by: Adam Gleave <adam@gleave.me> Co-authored-by: Adam Gleave <adam@gleave.me>
1 parent f7fecad commit 681cb72

3 files changed

Lines changed: 110 additions & 6 deletions

File tree

src/imitation/policies/serialize.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,14 @@ def _on_step(self) -> bool:
230230
output_dir = self.policy_dir / f"{self.num_timesteps:012d}"
231231
save_stable_model(output_dir, self.model)
232232
return True
233+
234+
235+
def load_rollouts_from_huggingface(
236+
algo_name: str,
237+
env_name: str,
238+
organization: str = "HumanCompatibleAI",
239+
) -> str:
240+
model_name = hfsb3.ModelName(algo_name, hfsb3.EnvironmentName(env_name))
241+
repo_id = hfsb3.ModelRepoId(organization, model_name)
242+
filename = hfsb3.load_from_hub(repo_id, "rollouts.npz")
243+
return filename

src/imitation/scripts/ingredients/demonstrations.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
"""Ingredient for scripts learning from demonstrations."""
22

33
import logging
4-
from typing import Optional, Sequence
4+
import pathlib
5+
import warnings
6+
from typing import Optional, Sequence, Union
57

68
import numpy as np
79
import sacred
810

911
from imitation.data import rollout, types
12+
from imitation.policies import serialize
1013
from imitation.scripts.ingredients import environment, expert
1114
from imitation.scripts.ingredients import logging as logging_ingredient
1215

@@ -23,7 +26,9 @@
2326

2427
@demonstrations_ingredient.config
2528
def config():
26-
# path to file containing rollouts. If None, they are sampled from the expert.
29+
rollout_type = "local"
30+
# path to file containing rollouts. If rollout_path is None
31+
# and rollout_type is local, they are sampled from the expert.
2732
rollout_path = None
2833
n_expert_demos = None # Num demos used or sampled. None loads every demo possible.
2934
locals() # quieten flake8
@@ -38,10 +43,39 @@ def fast():
3843

3944
@demonstrations_ingredient.capture
4045
def get_expert_trajectories(
46+
rollout_type: str,
4147
rollout_path: str,
4248
) -> Sequence[types.Trajectory]:
49+
"""Loads expert demonstrations.
50+
51+
Args:
52+
rollout_type: Can be either `local` to load rollouts from the disk or to
53+
generate them locally or of the format `{algo}-huggingface` to load
54+
from the huggingface hub of expert trained using `{algo}`.
55+
rollout_path: A path containing a pickled sequence of `types.Trajectory`.
56+
57+
Returns:
58+
The expert trajectories.
59+
60+
Raises:
61+
ValueError: if `rollout_type` is not "local" or of the form {algo}-huggingface.
62+
"""
63+
if rollout_type.endswith("-huggingface"):
64+
if rollout_path is not None:
65+
warnings.warn(
66+
"Ignoring `rollout_path` since `rollout_type` is set to download the "
67+
"rollouts from the huggingface-hub. If you want to load the rollouts "
68+
'from disk, set `rollout_type`="local" and the path in `rollout_path`.',
69+
RuntimeWarning,
70+
)
71+
rollout_path = _download_expert_rollouts(rollout_type)
72+
elif rollout_type != "local":
73+
raise ValueError(
74+
"`rollout_type` can either be `local` or of the form `{algo}-huggingface`.",
75+
)
76+
4377
if rollout_path is not None:
44-
return load_expert_trajs()
78+
return load_local_expert_trajs(rollout_path)
4579
else:
4680
return generate_expert_trajs()
4781

@@ -77,11 +111,11 @@ def generate_expert_trajs(
77111

78112

79113
@demonstrations_ingredient.capture
80-
def load_expert_trajs(
81-
rollout_path: str,
114+
def load_local_expert_trajs(
115+
rollout_path: Union[str, pathlib.Path],
82116
n_expert_demos: Optional[int],
83117
) -> Sequence[types.Trajectory]:
84-
"""Loads expert demonstrations.
118+
"""Loads expert demonstrations from a local path.
85119
86120
Args:
87121
rollout_path: A path containing a pickled sequence of `types.Trajectory`.
@@ -105,3 +139,14 @@ def load_expert_trajs(
105139
expert_trajs = expert_trajs[:n_expert_demos]
106140
logger.info(f"Truncated to {n_expert_demos} expert trajectories")
107141
return expert_trajs
142+
143+
144+
@demonstrations_ingredient.capture(prefix="expert")
145+
def _download_expert_rollouts(rollout_type, loader_kwargs):
146+
assert rollout_type.endswith("-huggingface")
147+
algo_name = rollout_type.split("-")[0]
148+
return serialize.load_rollouts_from_huggingface(
149+
algo_name,
150+
env_name=loader_kwargs["env_name"],
151+
organization=loader_kwargs["organization"],
152+
)

tests/scripts/test_scripts.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,54 @@ def test_train_bc_main_with_none_demonstrations_raises_value_error(tmpdir):
313313
)
314314

315315

316+
def test_train_bc_main_with_demonstrations_from_huggingface(tmpdir):
317+
train_imitation.train_imitation_ex.run(
318+
command_name="bc",
319+
named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
320+
config_updates=dict(
321+
logging=dict(log_root=tmpdir),
322+
demonstrations=dict(rollout_type="ppo-huggingface"),
323+
),
324+
)
325+
326+
327+
def test_train_bc_main_with_demonstrations_raises_error_on_wrong_huggingface_format(
328+
tmpdir,
329+
):
330+
with pytest.raises(
331+
ValueError,
332+
match="`rollout_type` can either be `local` or of the form .*-huggingface.S*",
333+
):
334+
train_imitation.train_imitation_ex.run(
335+
command_name="bc",
336+
named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
337+
config_updates=dict(
338+
logging=dict(log_root=tmpdir),
339+
demonstrations=dict(rollout_type="huggingface-ppo"),
340+
),
341+
)
342+
343+
344+
def test_train_bc_main_with_demonstrations_warns_setting_rollout_type(
345+
tmpdir,
346+
):
347+
with pytest.warns(
348+
RuntimeWarning,
349+
match="Ignoring `rollout_path` .*",
350+
):
351+
train_imitation.train_imitation_ex.run(
352+
command_name="bc",
353+
named_configs=["seals_cartpole"] + ALGO_FAST_CONFIGS["imitation"],
354+
config_updates=dict(
355+
logging=dict(log_root=tmpdir),
356+
demonstrations=dict(
357+
rollout_type="ppo-huggingface",
358+
rollout_path="path",
359+
),
360+
),
361+
)
362+
363+
316364
@pytest.fixture(
317365
params=[
318366
"expert_from_path",

0 commit comments

Comments
 (0)