diff --git a/docs/misc/changelog.md b/docs/misc/changelog.md index f417da89f..d58c42090 100644 --- a/docs/misc/changelog.md +++ b/docs/misc/changelog.md @@ -12,6 +12,7 @@ ### New Features: ### Bug Fixes: +- Fixed deprecated error Taxi-v3 from gymnasium v1.3.0 in tests ### [SB3-Contrib] diff --git a/pyproject.toml b/pyproject.toml index b26d86ec9..ae67ce959 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,17 +39,18 @@ exclude = """(?x)( env = ["PYTHONHASHSEED=0"] filterwarnings = [ + # multiprocessing and fork method, use of fork may lead to deadlocks + "ignore::DeprecationWarning:multiprocessing.popen_fork", # A2C/PPO on GPU "ignore:You are trying to run (PPO|A2C) on the GPU", # Tensorboard warnings "ignore::DeprecationWarning:tensorboard", # Gymnasium warnings "ignore::UserWarning:gymnasium", + # Taxi-v3 and CliffWalking-v0 + "ignore::DeprecationWarning:gymnasium.envs.registration", # tqdm warning about rich being experimental "ignore:rich is experimental", - # Pygame warnings about pkg_resources - "ignore:pkg_resources is deprecated", - "ignore:Deprecated call to `pkg_resources", ] markers = [ "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')", diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index c2797be30..88fcf5614 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -428,6 +428,7 @@ def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, . self.writer.add_image(key, value.image, step, dataformats=value.dataformats) if isinstance(value, HParam): + assert self.writer.file_writer is not None # we don't use `self.writer.add_hparams` to have control over the log_dir experiment, session_start_info, session_end_info = hparams(value.hparam_dict, metric_dict=value.metric_dict) self.writer.file_writer.add_summary(experiment) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 6ddf8a44f..e10e4093a 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -133,12 +133,19 @@ def test_features_extractor_target_net(model_class, share_features_extractor): if model_class == DQN and share_features_extractor: pytest.skip() - env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3}) + env = FakeImageEnv(screen_height=36, screen_width=36, n_channels=1, discrete=model_class not in {SAC, TD3}) # Avoid memory error when using replay buffer # Reduce the size of the features - kwargs = dict(buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))) + kwargs = dict( + buffer_size=100, + learning_starts=50, + policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=8)), + learning_rate=1e-3, + ) if model_class != DQN: kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor + else: + kwargs["target_update_interval"] = 10 # No delay for TD3 (changes when the actor and polyak update take place) if model_class == TD3: @@ -172,7 +179,7 @@ def test_features_extractor_target_net(model_class, share_features_extractor): if model_class == TD3: params_should_match(model.actor.parameters(), model.actor_target.parameters()) - model.learn(200) + model.learn(100) # Critic and target should differ params_should_differ(model.critic.parameters(), model.critic_target.parameters()) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index d6a178668..a3000af25 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -123,7 +123,8 @@ def test_sde_multi_dim(): @pytest.mark.parametrize("model_class", [A2C, PPO, DQN]) @pytest.mark.parametrize("env", ["Taxi-v3"]) def test_discrete_obs_space(model_class, env): - env = make_vec_env(env, n_envs=2, seed=0) + + env = _make_env_safe(env, n_envs=2, seed=0) kwargs = {} if model_class == DQN: kwargs = dict(buffer_size=1000, learning_starts=100) @@ -173,3 +174,17 @@ def test_multidim_binary_not_supported(): env = DummyEnv(BOX_SPACE_FLOAT32, spaces.MultiBinary([2, 3])) with pytest.raises(AssertionError, match=r"Multi-dimensional MultiBinary\(.*\) action space is not supported"): A2C("MlpPolicy", env) + + +def _make_env_safe(env_id, **kwargs): + try: + return make_vec_env(env_id, **kwargs) + except gym.error.DeprecatedEnv: + # map deprecated to new (better extendability for this) + fallback_map = { + # as of gymnasium 1.3.0 Taxi-v3 is deprecated + "Taxi-v3": "Taxi-v4", + } + if env_id in fallback_map: + return make_vec_env(fallback_map[env_id], **kwargs) + raise