Skip to content

Commit 40da841

Browse files
Fix Taxi-v3 deprecation in tests (gymnasium 1.3.0 compatibility) (#2247)
* Fix Taxi-v3 deprecation in tests (gymnasium 1.3.0 compatibility) * Update changelog with bug fixes Removed it from deprecations * Fix mypy error with newer pytorch * Update warnings filters * Update warning filter * Optimize test_features_extractor_target_net --------- Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
1 parent 08d984c commit 40da841

5 files changed

Lines changed: 32 additions & 7 deletions

File tree

docs/misc/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
### New Features:
1313

1414
### Bug Fixes:
15+
- Fixed deprecated error Taxi-v3 from gymnasium v1.3.0 in tests
1516

1617
### [SB3-Contrib]
1718

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,18 @@ exclude = """(?x)(
3939
env = ["PYTHONHASHSEED=0"]
4040

4141
filterwarnings = [
42+
# multiprocessing and fork method, use of fork may lead to deadlocks
43+
"ignore::DeprecationWarning:multiprocessing.popen_fork",
4244
# A2C/PPO on GPU
4345
"ignore:You are trying to run (PPO|A2C) on the GPU",
4446
# Tensorboard warnings
4547
"ignore::DeprecationWarning:tensorboard",
4648
# Gymnasium warnings
4749
"ignore::UserWarning:gymnasium",
50+
# Taxi-v3 and CliffWalking-v0
51+
"ignore::DeprecationWarning:gymnasium.envs.registration",
4852
# tqdm warning about rich being experimental
4953
"ignore:rich is experimental",
50-
# Pygame warnings about pkg_resources
51-
"ignore:pkg_resources is deprecated",
52-
"ignore:Deprecated call to `pkg_resources",
5354
]
5455
markers = [
5556
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",

stable_baselines3/common/logger.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, .
428428
self.writer.add_image(key, value.image, step, dataformats=value.dataformats)
429429

430430
if isinstance(value, HParam):
431+
assert self.writer.file_writer is not None
431432
# we don't use `self.writer.add_hparams` to have control over the log_dir
432433
experiment, session_start_info, session_end_info = hparams(value.hparam_dict, metric_dict=value.metric_dict)
433434
self.writer.file_writer.add_summary(experiment)

tests/test_cnn.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,19 @@ def test_features_extractor_target_net(model_class, share_features_extractor):
133133
if model_class == DQN and share_features_extractor:
134134
pytest.skip()
135135

136-
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3})
136+
env = FakeImageEnv(screen_height=36, screen_width=36, n_channels=1, discrete=model_class not in {SAC, TD3})
137137
# Avoid memory error when using replay buffer
138138
# Reduce the size of the features
139-
kwargs = dict(buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
139+
kwargs = dict(
140+
buffer_size=100,
141+
learning_starts=50,
142+
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=8)),
143+
learning_rate=1e-3,
144+
)
140145
if model_class != DQN:
141146
kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor
147+
else:
148+
kwargs["target_update_interval"] = 10
142149

143150
# No delay for TD3 (changes when the actor and polyak update take place)
144151
if model_class == TD3:
@@ -172,7 +179,7 @@ def test_features_extractor_target_net(model_class, share_features_extractor):
172179
if model_class == TD3:
173180
params_should_match(model.actor.parameters(), model.actor_target.parameters())
174181

175-
model.learn(200)
182+
model.learn(100)
176183

177184
# Critic and target should differ
178185
params_should_differ(model.critic.parameters(), model.critic_target.parameters())

tests/test_spaces.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def test_sde_multi_dim():
123123
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
124124
@pytest.mark.parametrize("env", ["Taxi-v3"])
125125
def test_discrete_obs_space(model_class, env):
126-
env = make_vec_env(env, n_envs=2, seed=0)
126+
127+
env = _make_env_safe(env, n_envs=2, seed=0)
127128
kwargs = {}
128129
if model_class == DQN:
129130
kwargs = dict(buffer_size=1000, learning_starts=100)
@@ -173,3 +174,17 @@ def test_multidim_binary_not_supported():
173174
env = DummyEnv(BOX_SPACE_FLOAT32, spaces.MultiBinary([2, 3]))
174175
with pytest.raises(AssertionError, match=r"Multi-dimensional MultiBinary\(.*\) action space is not supported"):
175176
A2C("MlpPolicy", env)
177+
178+
179+
def _make_env_safe(env_id, **kwargs):
180+
try:
181+
return make_vec_env(env_id, **kwargs)
182+
except gym.error.DeprecatedEnv:
183+
# map deprecated to new (better extendability for this)
184+
fallback_map = {
185+
# as of gymnasium 1.3.0 Taxi-v3 is deprecated
186+
"Taxi-v3": "Taxi-v4",
187+
}
188+
if env_id in fallback_map:
189+
return make_vec_env(fallback_map[env_id], **kwargs)
190+
raise

0 commit comments

Comments
 (0)