Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/misc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
### New Features:

### Bug Fixes:
- Fixed deprecated error Taxi-v3 from gymnasium v1.3.0 in tests

### [SB3-Contrib]

Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,18 @@ exclude = """(?x)(
env = ["PYTHONHASHSEED=0"]

filterwarnings = [
# multiprocessing and fork method, use of fork may lead to deadlocks
"ignore::DeprecationWarning:multiprocessing.popen_fork",
# A2C/PPO on GPU
"ignore:You are trying to run (PPO|A2C) on the GPU",
# Tensorboard warnings
"ignore::DeprecationWarning:tensorboard",
# Gymnasium warnings
"ignore::UserWarning:gymnasium",
# Taxi-v3 and CliffWalking-v0
"ignore::DeprecationWarning:gymnasium.envs.registration",
# tqdm warning about rich being experimental
"ignore:rich is experimental",
# Pygame warnings about pkg_resources
"ignore:pkg_resources is deprecated",
"ignore:Deprecated call to `pkg_resources",
]
markers = [
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",
Expand Down
1 change: 1 addition & 0 deletions stable_baselines3/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, .
self.writer.add_image(key, value.image, step, dataformats=value.dataformats)

if isinstance(value, HParam):
assert self.writer.file_writer is not None
# we don't use `self.writer.add_hparams` to have control over the log_dir
experiment, session_start_info, session_end_info = hparams(value.hparam_dict, metric_dict=value.metric_dict)
self.writer.file_writer.add_summary(experiment)
Expand Down
13 changes: 10 additions & 3 deletions tests/test_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,19 @@ def test_features_extractor_target_net(model_class, share_features_extractor):
if model_class == DQN and share_features_extractor:
pytest.skip()

env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3})
env = FakeImageEnv(screen_height=36, screen_width=36, n_channels=1, discrete=model_class not in {SAC, TD3})
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
kwargs = dict(
buffer_size=100,
learning_starts=50,
policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=8)),
learning_rate=1e-3,
)
if model_class != DQN:
kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor
else:
kwargs["target_update_interval"] = 10

# No delay for TD3 (changes when the actor and polyak update take place)
if model_class == TD3:
Expand Down Expand Up @@ -172,7 +179,7 @@ def test_features_extractor_target_net(model_class, share_features_extractor):
if model_class == TD3:
params_should_match(model.actor.parameters(), model.actor_target.parameters())

model.learn(200)
model.learn(100)

# Critic and target should differ
params_should_differ(model.critic.parameters(), model.critic_target.parameters())
Expand Down
17 changes: 16 additions & 1 deletion tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def test_sde_multi_dim():
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
@pytest.mark.parametrize("env", ["Taxi-v3"])
def test_discrete_obs_space(model_class, env):
env = make_vec_env(env, n_envs=2, seed=0)

env = _make_env_safe(env, n_envs=2, seed=0)
kwargs = {}
if model_class == DQN:
kwargs = dict(buffer_size=1000, learning_starts=100)
Expand Down Expand Up @@ -173,3 +174,17 @@ def test_multidim_binary_not_supported():
env = DummyEnv(BOX_SPACE_FLOAT32, spaces.MultiBinary([2, 3]))
with pytest.raises(AssertionError, match=r"Multi-dimensional MultiBinary\(.*\) action space is not supported"):
A2C("MlpPolicy", env)


def _make_env_safe(env_id, **kwargs):
try:
return make_vec_env(env_id, **kwargs)
except gym.error.DeprecatedEnv:
# map deprecated to new (better extendability for this)
fallback_map = {
# as of gymnasium 1.3.0 Taxi-v3 is deprecated
"Taxi-v3": "Taxi-v4",
}
if env_id in fallback_map:
return make_vec_env(fallback_map[env_id], **kwargs)
raise