Skip to content

Commit bbdcb29

Browse files
authored
Add support for HuggingFace Datasets (#677)
* Make the imitation.types.save/load functions store/retrieve from HuggingFace Datasets. * Add datasets dependency. * Add our own transform to numpy arrays to work around huggingface/datasets#5517 * Move serialization related code from imitation.data.types to imitation.data.serialize and encode infos using jsonpickle to support arbitrary infos structure. * Move numpy conversion logic into huggingface_datasets_conversion.py * Improve warnings and comments. * Fix convert_trajs.py script and its tests. * Add no cover pragma to convert_trajs main. * Add no cover pragma to the case of loading an unknown trajectory format. * Fix inconsistent imports. * Rename huggingface_datasets_conversion.py to huggingface_utils.py and fix the documentation of imitation.data.serialize.save. * Normalize imitation.data.serialize imports across the repo, move parse_path to the utils and load_rollouts_from_huggingface to data.serialize. * Remove now unneeded pytype error suppression. * Fix formatting in convert_trajs
1 parent 19e4b9b commit bbdcb29

29 files changed

Lines changed: 547 additions & 328 deletions

.gitattributes

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
*.ipynb linguist-vendored
2+
tests/testdata/pickle_format_rollout.pkl filter=lfs diff=lfs merge=lfs -text
3+
tests/testdata/npz_format_rollout.npz filter=lfs diff=lfs merge=lfs -text

docs/algorithms/density.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ Detailed example notebook: :doc:`../tutorials/7_train_density`
1919
from stable_baselines3.common.policies import ActorCriticPolicy
2020

2121
from imitation.algorithms import density as db
22-
from imitation.data import types
22+
from imitation.data import serialize
2323
from imitation.util import util
2424

2525
rng = np.random.default_rng(0)
2626

2727
env = util.make_vec_env("Pendulum-v1", rng=rng, n_envs=2)
28-
rollouts = types.load("../tests/testdata/expert_models/pendulum_0/rollouts/final.npz")
28+
rollouts = serialize.load("../tests/testdata/expert_models/pendulum_0/rollouts/final.npz")
2929

3030
imitation_trainer = PPO(ActorCriticPolicy, env)
3131
density_trainer = db.DensityAlgorithm(

experiments/convert_traj.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010

11-
from imitation.data import rollout, types
11+
from imitation.data import rollout, serialize, types
1212

1313

1414
def convert_trajs_to_sb(trajs: Sequence[types.TrajectoryWithRew]) -> dict:
@@ -32,7 +32,7 @@ def main():
3232
dst_path = Path(args.dst_path)
3333

3434
assert src_path.is_file()
35-
src_trajs = types.load_with_rewards(src_path)
35+
src_trajs = serialize.load_with_rewards(src_path)
3636
dst_trajs = convert_trajs_to_sb(src_trajs)
3737
os.makedirs(dst_path.parent, exist_ok=True)
3838
with open(dst_path, "wb") as f:

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
208208
"sacred>=0.8.4",
209209
"tensorboard>=1.14",
210210
"huggingface_sb3>=2.2.1",
211+
"datasets>=2.8.0",
211212
],
212213
tests_require=TESTS_REQUIRE,
213214
extras_require={

src/imitation/algorithms/adversarial/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def __init__(
202202
self.venv = venv
203203
self.gen_algo = gen_algo
204204
self._reward_net = reward_net.to(gen_algo.device)
205-
self._log_dir = types.parse_path(log_dir)
205+
self._log_dir = util.parse_path(log_dir)
206206

207207
# Create graph for optimising/recording stats on discriminator
208208
self._disc_opt_cls = disc_opt_cls

src/imitation/algorithms/bc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,4 +490,4 @@ def save_policy(self, policy_path: types.AnyPath) -> None:
490490
Args:
491491
policy_path: path to save policy to.
492492
"""
493-
th.save(self.policy, types.parse_path(policy_path))
493+
th.save(self.policy, util.parse_path(policy_path))

src/imitation/algorithms/dagger.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
from torch.utils import data as th_data
2121

2222
from imitation.algorithms import base, bc
23-
from imitation.data import rollout, types
23+
from imitation.data import rollout, serialize, types
2424
from imitation.util import logger as imit_logger
25+
from imitation.util import util
2526

2627

2728
class BetaSchedule(abc.ABC):
@@ -118,7 +119,7 @@ def reconstruct_trainer(
118119
A deserialized `DAggerTrainer`.
119120
"""
120121
custom_logger = custom_logger or imit_logger.configure()
121-
scratch_dir = types.parse_path(scratch_dir)
122+
scratch_dir = util.parse_path(scratch_dir)
122123
checkpoint_path = scratch_dir / "checkpoint-latest.pt"
123124
trainer = th.load(checkpoint_path, map_location=utils.get_device(device))
124125
trainer.venv = venv
@@ -133,7 +134,7 @@ def _save_dagger_demo(
133134
rng: np.random.Generator,
134135
prefix: str = "",
135136
) -> None:
136-
save_dir = types.parse_path(save_dir)
137+
save_dir = util.parse_path(save_dir)
137138
assert isinstance(trajectory, types.Trajectory)
138139
actual_prefix = f"{prefix}-" if prefix else ""
139140
randbits = int.from_bytes(rng.bytes(16), "big")
@@ -143,7 +144,7 @@ def _save_dagger_demo(
143144
assert (
144145
not npz_path.exists()
145146
), "The following DAgger demonstration path already exists: {0}".format(npz_path)
146-
types.save(npz_path, [trajectory])
147+
serialize.save(npz_path, [trajectory])
147148
logging.info(f"Saved demo at '{npz_path}'")
148149

149150

@@ -353,7 +354,7 @@ def __init__(
353354
if beta_schedule is None:
354355
beta_schedule = LinearBetaSchedule(15)
355356
self.beta_schedule = beta_schedule
356-
self.scratch_dir = types.parse_path(scratch_dir)
357+
self.scratch_dir = util.parse_path(scratch_dir)
357358
self.venv = venv
358359
self.round_num = 0
359360
self._last_loaded_round = -1
@@ -399,7 +400,7 @@ def _load_all_demos(self) -> Tuple[types.Transitions, List[int]]:
399400
for round_num in range(self._last_loaded_round + 1, self.round_num + 1):
400401
round_dir = self._demo_dir_path_for_round(round_num)
401402
demo_paths = self._get_demo_paths(round_dir)
402-
self._all_demos.extend(types.load(p)[0] for p in demo_paths)
403+
self._all_demos.extend(serialize.load(p)[0] for p in demo_paths)
403404
num_demos_by_round.append(len(demo_paths))
404405
logging.info(f"Loaded {len(self._all_demos)} total")
405406
demo_transitions = rollout.flatten_trajectories(self._all_demos)
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""Helpers to convert between Trajectories and HuggingFace's datasets library."""
2+
import functools
3+
from typing import Any, Dict, Iterable, Sequence, cast
4+
5+
import datasets
6+
import jsonpickle
7+
import numpy as np
8+
9+
from imitation.data import types
10+
11+
12+
class TrajectoryDatasetSequence(Sequence[types.Trajectory]):
    """Sequence view over a HuggingFace dataset that yields trajectories.

    Rows are turned into trajectory objects at access time rather than up front.
    """

    def __init__(self, dataset: datasets.Dataset):
        """Wrap `dataset` so that its rows can be read as trajectories."""
        # A dataset that carries a "rews" feature is materialized with rewards.
        if "rews" in dataset.features:
            self._trajectory_class = types.TrajectoryWithRew
        else:
            self._trajectory_class = types.Trajectory

        # TODO: this is just a temporary workaround for
        #  https://github.com/huggingface/datasets/issues/5517
        #  switch to .with_format("numpy") once it's fixed
        def numpy_transform(batch):
            return {column: np.asarray(values) for column, values in batch.items()}

        self._dataset = dataset.with_transform(numpy_transform)

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, idx):
        if not isinstance(idx, slice):
            # Single row: the dataset entry is used directly as constructor kwargs.
            fields = self._dataset[idx]
            # infos are stored jsonpickled; wrap them so decoding happens on access
            fields["infos"] = _LazyDecodedList(fields["infos"])
            return self._trajectory_class(**fields)

        # Slice: the dataset returns one batch keyed by column, so regroup the
        # columns into one kwargs dict per row.
        batch = self._dataset[idx]
        per_row_fields = []
        for row in range(len(batch["obs"])):
            fields = {column: batch[column][row] for column in batch}
            fields["infos"] = _LazyDecodedList(fields["infos"])
            per_row_fields.append(fields)

        return [self._trajectory_class(**fields) for fields in per_row_fields]
58+
59+
60+
class _LazyDecodedList(Sequence[Any]):
    """A wrapper to lazily decode a list of jsonpickled strings.

    Each element is decoded the first time it is read and the result is kept in a
    per-instance cache, so repeated reads of the same index decode only once and
    the cache is released together with the wrapper.

    This is used to decode the infos of a trajectory only when they are accessed.
    """

    def __init__(self, encoded_list: Sequence[str]):
        self._encoded_list = encoded_list
        # Maps an index (exactly as passed by the caller) to its decoded object.
        # A per-instance dict is used instead of `functools.lru_cache` on the
        # method: the decorator needs hashable arguments (slices are unhashable
        # before Python 3.12) and its shared cache would keep instances alive.
        self._decoded: Dict[Any, Any] = {}

    def __len__(self):
        return len(self._encoded_list)

    def _decode_at(self, idx):
        """Return the decoded element at `idx`, decoding it on first access."""
        if idx not in self._decoded:
            # Index first so an out-of-range `idx` raises IndexError before any
            # decoding is attempted.
            encoded = self._encoded_list[idx]
            # NOTE(security): jsonpickle can instantiate arbitrary objects while
            # decoding; only use this on data from a trusted source.
            self._decoded[idx] = jsonpickle.decode(encoded)
        return self._decoded[idx]

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            positions = range(*idx.indices(len(self._encoded_list)))
            return [self._decode_at(position) for position in positions]
        return self._decode_at(idx)
81+
82+
83+
def make_dict_from_trajectory(trajectory: types.Trajectory):
    """Build a plain dict from a single trajectory.

    The resulting dict holds these keys:

    * obs: `trajectory.obs`, passed through unchanged.
    * acts: `trajectory.acts`, passed through unchanged.
    * infos: a list with one jsonpickle-encoded string per info entry.
    * terminal: `trajectory.terminal`, passed through unchanged.
    * rews: `trajectory.rews`; only present for a `TrajectoryWithRew`.

    Args:
        trajectory: The trajectory to convert.

    Returns:
        A dict representing the trajectory.
    """
    # A trajectory without infos is treated as having one empty dict per step.
    raw_infos = trajectory.infos
    if raw_infos is None:
        raw_infos = [{}] * len(trajectory)
    step_infos = cast(Sequence[Dict[str, Any]], raw_infos)

    # infos may have arbitrary structure, so they are stored as jsonpickled strings
    result = {
        "obs": trajectory.obs,
        "acts": trajectory.acts,
        "infos": [jsonpickle.encode(step_info) for step_info in step_infos],
        "terminal": trajectory.terminal,
    }

    if isinstance(trajectory, types.TrajectoryWithRew):
        result["rews"] = trajectory.rews

    return result
120+
121+
122+
def trajectories_to_dict(
    trajectories: Sequence[types.Trajectory],
) -> Dict[str, Sequence[Any]]:
    """Gather a sequence of trajectories into a column-oriented dict.

    Every key maps to a list with one entry per trajectory:

    * obs: the `obs` attribute of each trajectory.
    * acts: the `acts` attribute of each trajectory.
    * infos: per trajectory, a list of jsonpickle-encoded info strings.
    * terminal: the `terminal` attribute of each trajectory.
    * rews: the `rews` attribute of each trajectory; only present when every
      trajectory is a `TrajectoryWithRew`.

    This dict can be used to construct a HuggingFace dataset.

    Args:
        trajectories: The trajectories to save.

    Raises:
        ValueError: If not all trajectories have the same type, i.e. some are
            `Trajectory` and others are `TrajectoryWithRew`.

    Returns:
        A dict representing the trajectories.
    """
    # Either every trajectory carries rewards or none does; a mix is rejected.
    reward_flags = [
        isinstance(traj, types.TrajectoryWithRew) for traj in trajectories
    ]
    every_traj_has_reward = all(reward_flags)
    if any(reward_flags) and not every_traj_has_reward:
        raise ValueError("Some trajectories have rewards but not all")

    # infos are jsonpickled; a missing infos attribute counts as empty dicts
    encoded_infos = []
    for traj in trajectories:
        step_infos = [{}] * len(traj) if traj.infos is None else traj.infos
        encoded_infos.append([jsonpickle.encode(info) for info in step_infos])

    columns: Dict[str, Sequence[Any]] = {
        "obs": [traj.obs for traj in trajectories],
        "acts": [traj.acts for traj in trajectories],
        "infos": encoded_infos,
        "terminal": [traj.terminal for traj in trajectories],
    }

    if every_traj_has_reward:
        columns["rews"] = [
            cast(types.TrajectoryWithRew, traj).rews for traj in trajectories
        ]
    return columns

src/imitation/data/serialize.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""Serialization utilities for trajectories."""
2+
import logging
3+
import os
4+
import warnings
5+
from typing import Mapping, Sequence, cast
6+
7+
import datasets
8+
import huggingface_sb3 as hfsb3
9+
import numpy as np
10+
11+
from imitation.data import huggingface_utils
12+
from imitation.data.types import AnyPath, Trajectory, TrajectoryWithRew
13+
from imitation.util import util
14+
15+
16+
def save(path: AnyPath, trajectories: Sequence[Trajectory]) -> None:
    """Write trajectories to disk as a HuggingFace dataset.

    Args:
        path: Destination of the dataset; it is normalized with `util.parse_path`.
        trajectories: The trajectories to save.
    """
    destination = util.parse_path(path)
    columns = huggingface_utils.trajectories_to_dict(trajectories)
    datasets.Dataset.from_dict(columns).save_to_disk(destination)
    logging.info(f"Dumped demonstrations to {destination}.")
27+
28+
29+
def load(path: AnyPath) -> Sequence[Trajectory]:
    """Load a sequence of trajectories from `path`.

    Three on-disk layouts are recognized:

    * a directory, which is read as a HuggingFace dataset (the layout written by
      `save()`);
    * a pickled sequence of trajectories (deprecated);
    * an ``.npz`` archive of concatenated arrays plus split indices (deprecated).

    Args:
        path: Location of the stored trajectories.

    Returns:
        The trajectories found at `path`.

    Raises:
        ValueError: If the directory does not contain a single `datasets.Dataset`,
            or if the file loads as something that is neither a sequence nor a
            mapping.
    """
    if os.path.isdir(path):  # huggingface datasets format
        dataset = datasets.load_from_disk(str(path))
        if not isinstance(dataset, datasets.Dataset):
            raise ValueError(
                f"Expected to load a `datasets.Dataset` but got {type(dataset)}",
            )

        return huggingface_utils.TrajectoryDatasetSequence(dataset)

    # Interestingly, np.load will just silently load a normal pickle file when you
    # set `allow_pickle=True`. So this call should succeed for both the old
    # compressed .npz format and the old pickle based format. To tell the
    # difference we need to look at the type of the resulting object. If it's the
    # .npz format, it should be a Mapping that we need to decode, whereas if it's
    # the pickle format it's just the sequence of trajectories, and we can return
    # it directly.
    # NOTE(security): unpickling can execute arbitrary code; only load files that
    # come from a trusted source.
    data = np.load(path, allow_pickle=True)  # works for both .npz and .pkl

    if isinstance(data, Sequence):  # pickle format
        warnings.warn("Loading old pickle version of Trajectories", DeprecationWarning)
        return data
    if isinstance(data, Mapping):  # .npz format
        warnings.warn("Loading old npz version of Trajectories", DeprecationWarning)
        try:
            # Read the split points once; each lookup on an npz archive re-reads
            # the array from the file.
            indices = data["indices"]
            num_trajs = len(indices)
            fields = [
                # Account for the extra obs in each trajectory
                np.split(data["obs"], indices + np.arange(num_trajs) + 1),
                np.split(data["acts"], indices),
                np.split(data["infos"], indices),
                data["terminal"],
            ]
            if "rews" in data:
                fields.append(np.split(data["rews"], indices))
                return [TrajectoryWithRew(*args) for args in zip(*fields)]
            return [Trajectory(*args) for args in zip(*fields)]
        finally:
            # An npz archive keeps its file handle open until closed; a plain
            # pickled mapping has nothing to close.
            if hasattr(data, "close"):
                data.close()

    raise ValueError(  # pragma: no cover
        f"Expected either an .npz file or a pickled sequence of trajectories; "
        f"got a pickled object of type {type(data).__name__}",
    )
75+
76+
77+
def load_with_rewards(path: AnyPath) -> Sequence[TrajectoryWithRew]:
    """Load trajectories from `path`, requiring every one to carry rewards.

    Args:
        path: Location handed to `load()`.

    Returns:
        The loaded trajectories, typed as `TrajectoryWithRew`.

    Raises:
        ValueError: If any loaded trajectory is not a `TrajectoryWithRew`; the
            message names the type of the first offender.
    """
    loaded = load(path)

    without_rewards = [
        traj for traj in loaded if not isinstance(traj, TrajectoryWithRew)
    ]
    if without_rewards:
        found_name = type(without_rewards[0]).__name__
        raise ValueError(
            "Expected all trajectories to be of type `TrajectoryWithRew`, "
            f"but found {found_name}",
        )

    return cast(Sequence[TrajectoryWithRew], loaded)
91+
92+
93+
def load_rollouts_from_huggingface(
    algo_name: str,
    env_name: str,
    organization: str = "HumanCompatibleAI",
) -> str:
    """Fetch the ``rollouts.npz`` file of a model repository via `huggingface_sb3`.

    The repository id is assembled from `organization` and a model name built
    from `algo_name` and `env_name`.

    Args:
        algo_name: Algorithm part of the model name.
        env_name: Environment part of the model name.
        organization: Organization the repository belongs to.

    Returns:
        Whatever `huggingface_sb3.load_from_hub` returns for ``rollouts.npz``
        (annotated here as a filename string).
    """
    environment = hfsb3.EnvironmentName(env_name)
    repo_id = hfsb3.ModelRepoId(organization, hfsb3.ModelName(algo_name, environment))
    return hfsb3.load_from_hub(repo_id, "rollouts.npz")

0 commit comments

Comments
 (0)