diff --git a/README.md b/README.md index 5d4498831..ccb570610 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,8 @@ These algorithms will make it easier for the research community and industry to ## Main Features -**The performance of each algorithm was tested** (see *Results* section in their respective page), -you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details. +**The performance of each algorithm was tested** (see *Results* section in their respective page). +You can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details. We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform. @@ -43,7 +43,7 @@ We also provide detailed logs and reports on the [OpenRL Benchmark](https://wand ### Planned features -Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*. +Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3; it is now *stable*. If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement). While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories: @@ -116,7 +116,7 @@ Install the Stable Baselines3 package: pip install 'stable-baselines3[extra]' ``` -This includes optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use: +This includes optional dependencies like Tensorboard, OpenCV, `ale-py` to train on atari games, as well as `pandas` and `matplotlib` for plotting and analyzing results. If you do not need those, you can use: ```sh pip install stable-baselines3 ``` @@ -163,7 +163,7 @@ model = PPO("MlpPolicy", "CartPole-v1").learn(10_000) Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more examples. -## Try it online with Colab Notebooks ! +## Try it online with Colab Notebooks! All the following examples can be executed online using Google Colab notebooks: @@ -201,7 +201,7 @@ All the following examples can be executed online using Google Colab notebooks: Actions `gymnasium.spaces`: * `Box`: A N-dimensional box that contains every point in the action space. - * `Discrete`: A list of possible actions, where each timestep only one of the actions can be used. + * `Discrete`: A list of possible actions, where only one action can be used per timestep. * `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used. * `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination. @@ -272,12 +272,12 @@ Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv ## How To Contribute -To any interested in making the baselines better, there is still some documentation that needs to be done. +For anyone interested in making the baselines better, there is still some documentation that needs to be done. If you want to contribute, please read [**CONTRIBUTING.md**](./CONTRIBUTING.md) guide first. ## Acknowledgments -The initial work to develop Stable Baselines3 was partially funded by the project *Reduced Complexity Models* from the *Helmholtz-Gemeinschaft Deutscher Forschungszentren*, and by the EU's Horizon 2020 Research and Innovation Programme under grant number 951992 ([VeriDream](https://www.veridream.eu/)). +The initial work to develop Stable Baselines3 was partially funded by the project *Reduced Complexity Models* from the *Helmholtz-Gemeinschaft Deutscher Forschungszentren*, and by the EU Horizon 2020 Research and Innovation Programme under grant number 951992 ([VeriDream](https://www.veridream.eu/)). The original version, Stable Baselines, was created in the [robotics lab U2IS](http://u2is.ensta-paristech.fr/index.php?lang=en) ([INRIA Flowers](https://flowers.inria.fr/) team) at [ENSTA ParisTech](http://www.ensta-paristech.fr/en). diff --git a/docs/guide/plotting.md b/docs/guide/plotting.md index 7a6b0dc31..af0ef0930 100644 --- a/docs/guide/plotting.md +++ b/docs/guide/plotting.md @@ -5,6 +5,17 @@ Stable Baselines3 provides utilities for plotting training results, allowing you to monitor and visualize your agent's learning progress. The plotting functionality is provided by the `results_plotter` module, which can load monitor files created during training and generate various plots. +:::{note} +Plotting requires `pandas` and `matplotlib`. Install them with: +```bash +pip install pandas matplotlib +``` +Or install the extra dependencies: +```bash +pip install 'stable-baselines3[extra]' +``` +::: + :::{note} We recommend using the [RL Baselines3 Zoo plotting scripts](https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html) diff --git a/docs/misc/changelog.md b/docs/misc/changelog.md index fb6662807..f417da89f 100644 --- a/docs/misc/changelog.md +++ b/docs/misc/changelog.md @@ -2,10 +2,12 @@ # Changelog -## Release 2.9.0a1 (WIP) +## Release 2.9.0a2 (WIP) ### Breaking Changes: -- Relax Gymnasium version range (from `"gymnasium>=0.29.1,<1.3.0"` to `"gymnasium>=0.29.1,<2.0"`) +- Relaxed Gymnasium version range (from `"gymnasium>=0.29.1,<1.3.0"` to `"gymnasium>=0.29.1,<2.0"`) +- `pandas` and `matplotlib` are no longer core dependencies; they are now optional and only required for loading results and plotting (moved to `stable-baselines3[extra]`). +- Moved `read_json` and `read_csv` helper functions to test files ### New Features: @@ -21,7 +23,7 @@ ### Others: -- Optimize tests (faster to run) +- Optimized tests (faster to run) ### Documentation: diff --git a/pyproject.toml b/pyproject.toml index 056ccb152..b26d86ec9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,4 +72,5 @@ exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:", + "if TYPE_CHECKING:", ] diff --git a/setup.py b/setup.py index 57eb07dde..077e1adfd 100644 --- a/setup.py +++ b/setup.py @@ -84,10 +84,6 @@ "torch>=2.3,<3.0", # For saving models "cloudpickle", - # For reading logs - "pandas", - # Plotting learning curves - "matplotlib", ], extras_require={ "tests": [ @@ -128,6 +124,9 @@ # For atari games, "ale-py>=0.9.0", "pillow", + # For plotting and loading results + "pandas", + "matplotlib", ], }, description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.", diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index c15a65df8..c2797be30 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -7,13 +7,14 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from io import TextIOBase -from typing import Any, TextIO +from typing import TYPE_CHECKING, Any, TextIO -import matplotlib.figure import numpy as np -import pandas import torch as th +if TYPE_CHECKING: + import matplotlib.figure + try: from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard.summary import hparams @@ -53,7 +54,7 @@ class Figure: :param close: if true, close the figure after logging it """ - def __init__(self, figure: matplotlib.figure.Figure, close: bool): + def __init__(self, figure: "matplotlib.figure.Figure", close: bool): self.figure = figure self.close = close @@ -665,32 +666,3 @@ def configure(folder: str | None = None, format_strings: list[str] | None = None if len(format_strings) > 0 and format_strings != ["stdout"]: logger.log(f"Logging to {folder}") return logger - - -# ================================================================ -# Readers -# ================================================================ - - -def read_json(filename: str) -> pandas.DataFrame: - """ - read a json file using pandas - - :param filename: the file path to read - :return: the data in the json - """ - data = [] - with open(filename) as file_handler: - for line in file_handler: - data.append(json.loads(line)) - return pandas.DataFrame(data) - - -def read_csv(filename: str) -> pandas.DataFrame: - """ - read a csv file using pandas - - :param filename: the file path to read - :return: the data in the csv - """ - return pandas.read_csv(filename, index_col=None, comment="#") diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index a482135dd..5eba694a4 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -5,12 +5,14 @@ import os import time from glob import glob -from typing import Any, SupportsFloat +from typing import TYPE_CHECKING, Any, SupportsFloat import gymnasium as gym -import pandas from gymnasium.core import ActType, ObsType +if TYPE_CHECKING: + import pandas + class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]): """ @@ -227,13 +229,21 @@ def get_monitor_files(path: str) -> list[str]: return glob(os.path.join(path, "*" + Monitor.EXT)) -def load_results(path: str) -> pandas.DataFrame: +def load_results(path: str) -> "pandas.DataFrame": """ Load all Monitor logs from a given directory path matching ``*monitor.csv`` :param path: the directory path containing the log file(s) :return: the logged data """ + try: + import pandas + except ImportError as e: + raise ImportError( + "pandas is required for loading results. " + "Install it with `pip install pandas` or install the extra dependencies with " + "`pip install 'stable-baselines3[extra]'`." + ) from e monitor_files = get_monitor_files(path) if len(monitor_files) == 0: raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}") diff --git a/stable_baselines3/common/results_plotter.py b/stable_baselines3/common/results_plotter.py index 58b489b42..c3f62e1de 100644 --- a/stable_baselines3/common/results_plotter.py +++ b/stable_baselines3/common/results_plotter.py @@ -1,11 +1,26 @@ from collections.abc import Callable import numpy as np -import pandas as pd + +try: + import pandas as pd +except ImportError as e: + raise ImportError( + "pandas is required for plotting functionality. " + "Install it with `pip install pandas` or install the extra dependencies with " + "`pip install 'stable-baselines3[extra]'`." + ) from e # import matplotlib # matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode -from matplotlib import pyplot as plt +try: + from matplotlib import pyplot as plt +except ImportError as e: + raise ImportError( + "matplotlib is required for plotting functionality. " + "Install it with `pip install matplotlib` or install the extra dependencies with " + "`pip install 'stable-baselines3[extra]'`." + ) from e from stable_baselines3.common.monitor import load_results diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 665a82640..84d4b1e63 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.9.0a1 +2.9.0a2 diff --git a/tests/test_logger.py b/tests/test_logger.py index beed40625..8e38cf556 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,4 +1,5 @@ import importlib.util +import json import os import sys import time @@ -8,6 +9,7 @@ import gymnasium as gym import numpy as np +import pandas import pytest import torch as th from gymnasium import spaces @@ -30,11 +32,34 @@ Video, configure, make_output_format, - read_csv, - read_json, ) from stable_baselines3.common.monitor import Monitor + +def read_csv(filename: str): + """ + read a csv file using pandas + + :param filename: the file path to read + :return: the data in the csv + """ + return pandas.read_csv(filename, index_col=None, comment="#") + + +def read_json(filename: str): + """ + read a json file using pandas + + :param filename: the file path to read + :return: the data in the json + """ + data = [] + with open(filename) as file_handler: + for line in file_handler: + data.append(json.loads(line)) + return pandas.DataFrame(data) + + KEY_VALUES = { "test": 1, "b": -3.14, @@ -634,3 +659,58 @@ def test_rollout_success_rate_onpolicy_algo(tmp_path): assert logger.name_to_value["rollout/success_rate"] == 0.5 model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1) assert logger.name_to_value["rollout/success_rate"] == 0.8 + + +def test_pandas_import_error(tmp_path): + """Test that a clear ImportError is raised when pandas is not available.""" + # Mock the import to simulate pandas not being installed + with mock.patch.dict("sys.modules", {"pandas": None}): + # First, remove the modules from cache if they exist + if "stable_baselines3.common.results_plotter" in sys.modules: + del sys.modules["stable_baselines3.common.results_plotter"] + if "stable_baselines3.common.monitor" in sys.modules: + del sys.modules["stable_baselines3.common.monitor"] + + # Test results_plotter raises ImportError at import time + with pytest.raises(ImportError, match="pandas is required for plotting"): + import stable_baselines3.common.results_plotter # noqa: F401 + + # Test load_results raises ImportError at call time + # monitor module can still be imported (pandas import is lazy) + from stable_baselines3.common.monitor import load_results + + with pytest.raises(ImportError, match="pandas is required for loading results"): + load_results(str(tmp_path)) + + +def test_matplotlib_import_error(): + """Test that a clear ImportError is raised when matplotlib is not available.""" + # Mock the import to simulate matplotlib not being installed + with mock.patch.dict("sys.modules", {"matplotlib": None, "matplotlib.pyplot": None}): + # First, remove the module from cache if it exists + if "stable_baselines3.common.results_plotter" in sys.modules: + del sys.modules["stable_baselines3.common.results_plotter"] + + # Test results_plotter raises ImportError at import time + with pytest.raises(ImportError, match="matplotlib is required for plotting"): + import stable_baselines3.common.results_plotter # noqa: F401 + + +def test_sb3_import_without_optional_deps(): + """Test that SB3 core can be imported without matplotlib and pandas.""" + # Mock the imports to simulate optional dependencies not being installed + with mock.patch.dict("sys.modules", {"pandas": None, "matplotlib": None, "matplotlib.pyplot": None}): + # First, remove the modules from cache if they exist + modules_to_remove = [key for key in sys.modules.keys() if key.startswith("stable_baselines3")] + for module in modules_to_remove: + del sys.modules[module] + + # Core SB3 should still be importable + from stable_baselines3 import A2C, DQN, PPO # noqa: F401 + + # Monitor should be importable (pandas import is lazy in load_results) + from stable_baselines3.common.monitor import Monitor # noqa: F401 + + # But plotting module should fail + with pytest.raises(ImportError): + import stable_baselines3.common.results_plotter # noqa: F401 diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 4c25821a5..5e6520bc4 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -142,3 +142,33 @@ def test_monitor_load_results(tmp_path): os.remove(monitor_file1) os.remove(monitor_file2) + + +def test_monitor_error_cases(): + """ + Test error cases in Monitor wrapper + """ + env = gym.make("CartPole-v1") + env.reset(seed=0) + + with pytest.raises(RuntimeError, match="Tried to reset an environment before done"): + monitor_env = Monitor(env, allow_early_resets=False) + monitor_env.reset() + monitor_env.step(monitor_env.action_space.sample()) + monitor_env.reset() + + env2 = gym.make("CartPole-v1") + env2.reset(seed=0) + with pytest.raises(ValueError, match="Expected you to pass keyword argument test_key into reset"): + monitor_env2 = Monitor(env2, reset_keywords=("test_key",)) + monitor_env2.reset() + + # Note: cannot use test_key because CartPole doesn't accept this keyword + monitor_env2 = Monitor(env2, reset_keywords=("options",)) + monitor_env2.reset(options={"ok": 1}) + + env3 = gym.make("CartPole-v1") + env3.reset(seed=0) + with pytest.raises(RuntimeError, match="Tried to step environment that needs reset"): + monitor_env3 = Monitor(env3) + monitor_env3.step(monitor_env3.action_space.sample())