Skip to content

Commit aca4c07

Browse files
authored
Update gymnasium dependency and render_mode in gym.make (#806)
* Update gymnasium version and render mode in eval policy * Fix error
1 parent e6d8886 commit aca4c07

2 files changed

Lines changed: 7 additions & 3 deletions

File tree

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
188188
# encode only known incompatibilities here. This prevents nasty dependency issues
189189
# for our users.
190190
install_requires=[
191-
"gymnasium[classic-control]~=0.28.1",
191+
"gymnasium[classic-control]~=0.29",
192192
"matplotlib",
193193
"numpy>=1.15",
194194
"torch>=1.4.0",
@@ -220,7 +220,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
220220
"docs": DOCS_REQUIRE,
221221
"parallel": PARALLEL_REQUIRE,
222222
"mujoco": [
223-
"gymnasium[classic-control,mujoco]~=0.28.1",
223+
"gymnasium[classic-control,mujoco]~=0.29",
224224
],
225225
"atari": ATARI_REQUIRE,
226226
},

src/imitation/scripts/eval_policy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,11 @@ def eval_policy(
9595
log_dir = logging_ingredient.make_log_dir()
9696
sample_until = rollout.make_sample_until(eval_n_timesteps, eval_n_episodes)
9797
post_wrappers = [video_wrapper_factory(log_dir, **video_kwargs)] if videos else None
98-
with environment.make_venv(post_wrappers=post_wrappers) as venv:
98+
render_mode = "rgb_array" if videos else None
99+
with environment.make_venv(
100+
post_wrappers=post_wrappers,
101+
env_make_kwargs=dict(render_mode=render_mode),
102+
) as venv:
99103
if render:
100104
venv = InteractiveRender(venv, render_fps)
101105

0 commit comments

Comments
 (0)