16 changes: 8 additions & 8 deletions README.md
@@ -19,8 +19,8 @@ These algorithms will make it easier for the research community and industry to

## Main Features

**The performance of each algorithm was tested** (see *Results* section in their respective page),
you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.
**The performance of each algorithm was tested** (see *Results* section in their respective page).
You can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.

We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform.

@@ -43,7 +43,7 @@ We also provide detailed logs and reports on the [OpenRL Benchmark](https://wand

### Planned features

Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*.
Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3; it is now *stable*.
If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement).

While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories:
@@ -116,7 +116,7 @@ Install the Stable Baselines3 package:
pip install 'stable-baselines3[extra]'
```

This includes optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use:
This includes optional dependencies like Tensorboard, OpenCV, `ale-py` to train on atari games, as well as `pandas` and `matplotlib` for plotting and analyzing results. If you do not need those, you can use:
```sh
pip install stable-baselines3
```
@@ -163,7 +163,7 @@ model = PPO("MlpPolicy", "CartPole-v1").learn(10_000)
Please read the [documentation](https://stable-baselines3.readthedocs.io/) for more examples.


## Try it online with Colab Notebooks !
## Try it online with Colab Notebooks!

All the following examples can be executed online using Google Colab notebooks:

@@ -201,7 +201,7 @@ All the following examples can be executed online using Google Colab notebooks:

Actions `gymnasium.spaces`:
* `Box`: A N-dimensional box that contains every point in the action space.
* `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
* `Discrete`: A list of possible actions, where only one action can be used per timestep.
* `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
* `MultiBinary`: A list of possible actions, where each timestep any of the actions can be used in any combination.

@@ -272,12 +272,12 @@ Please post your question on the [RL Discord](https://discord.com/invite/xhfNqQv

## How To Contribute

To any interested in making the baselines better, there is still some documentation that needs to be done.
For anyone interested in making the baselines better, there is still some documentation that needs to be done.
If you want to contribute, please read [**CONTRIBUTING.md**](./CONTRIBUTING.md) guide first.

## Acknowledgments

The initial work to develop Stable Baselines3 was partially funded by the project *Reduced Complexity Models* from the *Helmholtz-Gemeinschaft Deutscher Forschungszentren*, and by the EU's Horizon 2020 Research and Innovation Programme under grant number 951992 ([VeriDream](https://www.veridream.eu/)).
The initial work to develop Stable Baselines3 was partially funded by the project *Reduced Complexity Models* from the *Helmholtz-Gemeinschaft Deutscher Forschungszentren*, and by the EU Horizon 2020 Research and Innovation Programme under grant number 951992 ([VeriDream](https://www.veridream.eu/)).

The original version, Stable Baselines, was created in the [robotics lab U2IS](http://u2is.ensta-paristech.fr/index.php?lang=en) ([INRIA Flowers](https://flowers.inria.fr/) team) at [ENSTA ParisTech](http://www.ensta-paristech.fr/en).

11 changes: 11 additions & 0 deletions docs/guide/plotting.md
@@ -5,6 +5,17 @@
Stable Baselines3 provides utilities for plotting training results, allowing you to monitor and visualize your agent's learning progress.
The plotting functionality is provided by the `results_plotter` module, which can load monitor files created during training and generate various plots.

:::{note}
Plotting requires `pandas` and `matplotlib`. Install them with:
```bash
pip install pandas matplotlib
```
Or install the extra dependencies:
```bash
pip install 'stable-baselines3[extra]'
```
:::

:::{note}
We recommend using the
[RL Baselines3 Zoo plotting scripts](https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html)
8 changes: 5 additions & 3 deletions docs/misc/changelog.md
@@ -2,10 +2,12 @@

# Changelog

## Release 2.9.0a1 (WIP)
## Release 2.9.0a2 (WIP)

### Breaking Changes:
- Relax Gymnasium version range (from `"gymnasium>=0.29.1,<1.3.0"` to `"gymnasium>=0.29.1,<2.0"`)
- Relaxed Gymnasium version range (from `"gymnasium>=0.29.1,<1.3.0"` to `"gymnasium>=0.29.1,<2.0"`)
- `pandas` and `matplotlib` are no longer core dependencies; they are now optional and only required for loading results and plotting (moved to `stable-baselines3[extra]`).
- Moved `read_json` and `read_csv` helper functions to test files

### New Features:

@@ -21,7 +23,7 @@

### Others:

- Optimize tests (faster to run)
- Optimized tests (faster to run)

### Documentation:

1 change: 1 addition & 0 deletions pyproject.toml
@@ -72,4 +72,5 @@ exclude_lines = [
    "pragma: no cover",
    "raise NotImplementedError()",
    "if typing.TYPE_CHECKING:",
    "if TYPE_CHECKING:",
]
7 changes: 3 additions & 4 deletions setup.py
@@ -84,10 +84,6 @@
        "torch>=2.3,<3.0",
        # For saving models
        "cloudpickle",
        # For reading logs
        "pandas",
        # Plotting learning curves
        "matplotlib",
    ],
    extras_require={
        "tests": [
@@ -128,6 +124,9 @@
            # For atari games,
            "ale-py>=0.9.0",
            "pillow",
            # For plotting and loading results
            "pandas",
            "matplotlib",
        ],
    },
    description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
38 changes: 5 additions & 33 deletions stable_baselines3/common/logger.py
@@ -7,13 +7,14 @@
from collections import defaultdict
from collections.abc import Mapping, Sequence
from io import TextIOBase
from typing import Any, TextIO
from typing import TYPE_CHECKING, Any, TextIO

import matplotlib.figure
import numpy as np
import pandas
import torch as th

if TYPE_CHECKING:
    import matplotlib.figure

try:
    from torch.utils.tensorboard import SummaryWriter
    from torch.utils.tensorboard.summary import hparams
@@ -53,7 +54,7 @@ class Figure:
    :param close: if true, close the figure after logging it
    """

    def __init__(self, figure: matplotlib.figure.Figure, close: bool):
    def __init__(self, figure: "matplotlib.figure.Figure", close: bool):
        self.figure = figure
        self.close = close
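
The `TYPE_CHECKING` guard used in this hunk can be illustrated in isolation. The following is a minimal sketch, not SB3's code: `FigureRecord` is a hypothetical stand-in for the `Figure` wrapper, and a plain `object()` stands in for a real matplotlib figure so the snippet runs without matplotlib installed.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime,
    # so the optional package does not have to be installed.
    import matplotlib.figure


class FigureRecord:  # hypothetical name, for illustration only
    # The annotation is a string, so Python never resolves
    # `matplotlib.figure.Figure` when the class is defined.
    def __init__(self, figure: "matplotlib.figure.Figure", close: bool):
        self.figure = figure
        self.close = close


# Any object can be passed at runtime; only a type checker would object.
record = FigureRecord(figure=object(), close=True)
```

The trade-off is that the annotation is no longer checked at runtime, which is acceptable here because the class only stores the figure.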

@@ -665,32 +666,3 @@ def configure(folder: str | None = None, format_strings: list[str] | None = None
    if len(format_strings) > 0 and format_strings != ["stdout"]:
        logger.log(f"Logging to {folder}")
    return logger


# ================================================================
# Readers
# ================================================================


def read_json(filename: str) -> pandas.DataFrame:
    """
    read a json file using pandas

    :param filename: the file path to read
    :return: the data in the json
    """
    data = []
    with open(filename) as file_handler:
        for line in file_handler:
            data.append(json.loads(line))
    return pandas.DataFrame(data)


def read_csv(filename: str) -> pandas.DataFrame:
    """
    read a csv file using pandas

    :param filename: the file path to read
    :return: the data in the csv
    """
    return pandas.read_csv(filename, index_col=None, comment="#")
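
For scripts that relied on the helpers removed above but should not depend on pandas, a stdlib-only approximation might look like the sketch below. The names `read_json_lines` and `read_csv_rows` are hypothetical and not part of SB3. The behavior is also not identical: the functions return plain dicts instead of a `DataFrame`, CSV values stay strings, and only lines that start with `#` are skipped (pandas' `comment="#"` additionally truncates trailing comments).

```python
import csv
import json


def read_json_lines(filename: str) -> list[dict]:
    """Read a file holding one JSON object per line into a list of dicts."""
    with open(filename) as file_handler:
        # Skip blank lines so a trailing newline does not break parsing
        return [json.loads(line) for line in file_handler if line.strip()]


def read_csv_rows(filename: str) -> list[dict[str, str]]:
    """Read a CSV file into a list of dicts, skipping lines that start with '#'."""
    with open(filename, newline="") as file_handler:
        data_lines = (line for line in file_handler if not line.startswith("#"))
        # The first non-comment line is used as the header row
        return list(csv.DictReader(data_lines))
```

Code that needs the original `DataFrame` return type can instead copy the two pandas-based helpers, as the test file in this change does.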
16 changes: 13 additions & 3 deletions stable_baselines3/common/monitor.py
@@ -5,12 +5,14 @@
import os
import time
from glob import glob
from typing import Any, SupportsFloat
from typing import TYPE_CHECKING, Any, SupportsFloat

import gymnasium as gym
import pandas
from gymnasium.core import ActType, ObsType

if TYPE_CHECKING:
    import pandas


class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
    """
@@ -227,13 +229,21 @@ def get_monitor_files(path: str) -> list[str]:
    return glob(os.path.join(path, "*" + Monitor.EXT))


def load_results(path: str) -> pandas.DataFrame:
def load_results(path: str) -> "pandas.DataFrame":
    """
    Load all Monitor logs from a given directory path matching ``*monitor.csv``

    :param path: the directory path containing the log file(s)
    :return: the logged data
    """
    try:
        import pandas
    except ImportError as e:
        raise ImportError(
            "pandas is required for loading results. "
            "Install it with `pip install pandas` or install the extra dependencies with "
            "`pip install 'stable-baselines3[extra]'`."
        ) from e
    monitor_files = get_monitor_files(path)
    if len(monitor_files) == 0:
        raise LoadMonitorResultsError(f"No monitor files of the form *{Monitor.EXT} found in {path}")
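
The lazy-import pattern in `load_results` is repeated with nearly identical wording in `results_plotter.py`. One way to factor it out is sketched below; `require_optional` is a hypothetical helper that does not exist in SB3, and the sketch assumes the pip package name matches the module name, which holds for `pandas` and `matplotlib` but not for every package.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, purpose: str) -> ModuleType:
    """Import an optional dependency, or raise an ImportError with install hints."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        # Chain the original error so the root cause stays visible in the traceback
        raise ImportError(
            f"{module_name} is required for {purpose}. "
            f"Install it with `pip install {module_name}` or install the extra dependencies with "
            "`pip install 'stable-baselines3[extra]'`."
        ) from e
```

A call such as `pandas = require_optional("pandas", "loading results")` inside the function body would keep the import deferred until the feature is actually used.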
19 changes: 17 additions & 2 deletions stable_baselines3/common/results_plotter.py
@@ -1,11 +1,26 @@
from collections.abc import Callable

import numpy as np
import pandas as pd

try:
    import pandas as pd
except ImportError as e:
    raise ImportError(
        "pandas is required for plotting functionality. "
        "Install it with `pip install pandas` or install the extra dependencies with "
        "`pip install 'stable-baselines3[extra]'`."
    ) from e

# import matplotlib
# matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
from matplotlib import pyplot as plt
try:
    from matplotlib import pyplot as plt
except ImportError as e:
    raise ImportError(
        "matplotlib is required for plotting functionality. "
        "Install it with `pip install matplotlib` or install the extra dependencies with "
        "`pip install 'stable-baselines3[extra]'`."
    ) from e

from stable_baselines3.common.monitor import load_results

2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
2.9.0a1
2.9.0a2
84 changes: 82 additions & 2 deletions tests/test_logger.py
@@ -1,4 +1,5 @@
import importlib.util
import json
import os
import sys
import time
@@ -8,6 +9,7 @@

import gymnasium as gym
import numpy as np
import pandas
import pytest
import torch as th
from gymnasium import spaces
@@ -30,11 +32,34 @@
    Video,
    configure,
    make_output_format,
    read_csv,
    read_json,
)
from stable_baselines3.common.monitor import Monitor


def read_csv(filename: str):
    """
    read a csv file using pandas

    :param filename: the file path to read
    :return: the data in the csv
    """
    return pandas.read_csv(filename, index_col=None, comment="#")


def read_json(filename: str):
    """
    read a json file using pandas

    :param filename: the file path to read
    :return: the data in the json
    """
    data = []
    with open(filename) as file_handler:
        for line in file_handler:
            data.append(json.loads(line))
    return pandas.DataFrame(data)


KEY_VALUES = {
    "test": 1,
    "b": -3.14,
@@ -634,3 +659,58 @@ def test_rollout_success_rate_onpolicy_algo(tmp_path):
    assert logger.name_to_value["rollout/success_rate"] == 0.5
    model.learn(total_timesteps=steps_per_log * ep_steps, log_interval=1)
    assert logger.name_to_value["rollout/success_rate"] == 0.8


def test_pandas_import_error(tmp_path):
    """Test that a clear ImportError is raised when pandas is not available."""
    # Mock the import to simulate pandas not being installed
    with mock.patch.dict("sys.modules", {"pandas": None}):
        # First, remove the modules from cache if they exist
        if "stable_baselines3.common.results_plotter" in sys.modules:
            del sys.modules["stable_baselines3.common.results_plotter"]
        if "stable_baselines3.common.monitor" in sys.modules:
            del sys.modules["stable_baselines3.common.monitor"]

        # Test results_plotter raises ImportError at import time
        with pytest.raises(ImportError, match="pandas is required for plotting"):
            import stable_baselines3.common.results_plotter  # noqa: F401

        # Test load_results raises ImportError at call time
        # monitor module can still be imported (pandas import is lazy)
        from stable_baselines3.common.monitor import load_results

        with pytest.raises(ImportError, match="pandas is required for loading results"):
            load_results(str(tmp_path))


def test_matplotlib_import_error():
    """Test that a clear ImportError is raised when matplotlib is not available."""
    # Mock the import to simulate matplotlib not being installed
    with mock.patch.dict("sys.modules", {"matplotlib": None, "matplotlib.pyplot": None}):
        # First, remove the module from cache if it exists
        if "stable_baselines3.common.results_plotter" in sys.modules:
            del sys.modules["stable_baselines3.common.results_plotter"]

        # Test results_plotter raises ImportError at import time
        with pytest.raises(ImportError, match="matplotlib is required for plotting"):
            import stable_baselines3.common.results_plotter  # noqa: F401


def test_sb3_import_without_optional_deps():
    """Test that SB3 core can be imported without matplotlib and pandas."""
    # Mock the imports to simulate optional dependencies not being installed
    with mock.patch.dict("sys.modules", {"pandas": None, "matplotlib": None, "matplotlib.pyplot": None}):
        # First, remove the modules from cache if they exist
        modules_to_remove = [key for key in sys.modules.keys() if key.startswith("stable_baselines3")]
        for module in modules_to_remove:
            del sys.modules[module]

        # Core SB3 should still be importable
        from stable_baselines3 import A2C, DQN, PPO  # noqa: F401

        # Monitor should be importable (pandas import is lazy in load_results)
        from stable_baselines3.common.monitor import Monitor  # noqa: F401

        # But plotting module should fail
        with pytest.raises(ImportError):
            import stable_baselines3.common.results_plotter  # noqa: F401
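
These tests depend on a documented behavior of the import system: when `sys.modules[name]` is `None`, importing `name` raises `ModuleNotFoundError`, a subclass of `ImportError`. The sketch below demonstrates the mechanism on its own. `import_fails_when_blocked` is a hypothetical helper, and the stdlib module `colorsys` is an arbitrary stand-in for an optional dependency such as pandas.

```python
import importlib
import sys
from unittest import mock


def import_fails_when_blocked(module_name: str) -> bool:
    """Return True if importing ``module_name`` raises ImportError while it is blocked."""
    # A None entry in sys.modules makes the import machinery refuse the import.
    # patch.dict restores the original sys.modules content on exit.
    with mock.patch.dict(sys.modules, {module_name: None}):
        try:
            importlib.import_module(module_name)
        except ImportError:
            return True
        return False


blocked = import_fails_when_blocked("colorsys")
```

Because `patch.dict` restores `sys.modules` when the block exits, the module can be imported normally again afterwards. This is also why the tests above delete the cached `stable_baselines3` modules first: a module already present in `sys.modules` would otherwise be returned without re-running its imports.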