Drop python 3.8 and add python 3.12 support #2041

Merged: 4 commits merged on Nov 18, 2024
7 changes: 3 additions & 4 deletions .github/workflows/ci.yml
@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
include:
# Default version
- gymnasium-version: "1.0.0"
@@ -48,7 +48,8 @@ jobs:
- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
# Only run for python 3.10, downgrade gym to 0.29.1
uv pip install --system "numpy<2"
# Only run for python 3.10, downgrade gym to 0.29.1, numpy<2
if: matrix.gymnasium-version != '1.0.0'
- name: Lint with ruff
run: |
@@ -62,8 +63,6 @@
- name: Type check
run: |
make type
# Do not run for python 3.8 (mypy internal error)
if: matrix.python-version != '3.8'
- name: Test with pytest
run: |
make pytest
4 changes: 2 additions & 2 deletions README.md
@@ -100,10 +100,10 @@ It provides a minimal number of features compared to SB3 but can be much faster

## Installation

**Note:** Stable-Baselines3 supports PyTorch >= 1.13
**Note:** Stable-Baselines3 supports PyTorch >= 2.3

### Prerequisites
Stable Baselines3 requires Python 3.8+.
Stable Baselines3 requires Python 3.9+.

#### Windows

2 changes: 1 addition & 1 deletion docs/conda_env.yml
@@ -12,7 +12,7 @@ dependencies:
- cloudpickle
- opencv-python-headless
- pandas
- numpy>=1.20,<2.0
- numpy>=1.20,<3.0
- matplotlib
- sphinx>=5,<9
- sphinx_rtd_theme>=1.3.0
3 changes: 1 addition & 2 deletions docs/conf.py
@@ -14,7 +14,6 @@
import datetime
import os
import sys
from typing import Dict

# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
# PyEnchant.
@@ -151,7 +150,7 @@ def setup(app):

# -- Options for LaTeX output ------------------------------------------------

latex_elements: Dict[str, str] = {
latex_elements: dict[str, str] = {
Collaborator: Why? (Out of curiosity)

@araffin (Member Author, Nov 18, 2024): It's Python 3.9 syntax: `from typing import Dict` is deprecated, you can use `dict[]` directly now (and with Python 3.10, there is the `|` syntax for unions).

I used `ruff check --select UP --fix --unsafe-fixes` to automatically upgrade everything.

There are some failures for Python 3.12, I will have a look.

Note: with Python 3.8, `dict[]` would raise an error.

Collaborator: Good to know, thanks!

# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
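To illustrate the typing upgrade discussed in the review thread above, here is a minimal before/after sketch; the function and parameter names are made up for illustration and do not come from the SB3 codebase:

```python
# Python 3.8 style: container generics must be imported from typing.
from typing import Dict, List, Optional


def mean_score_old(scores: Dict[str, List[float]]) -> Optional[float]:
    values = [v for vs in scores.values() for v in vs]
    return sum(values) / len(values) if values else None


# Python 3.9+ style (what `ruff check --select UP --fix --unsafe-fixes` produces):
# the built-in dict/list types are subscripted directly, so Dict/List imports go away.
def mean_score_new(scores: dict[str, list[float]]) -> Optional[float]:
    values = [v for vs in scores.values() for v in vs]
    return sum(values) / len(values) if values else None


# With Python 3.10+, unions can also be written with `|`, e.g. `float | None`
# instead of Optional[float]. Under Python 3.8, evaluating dict[str, list[float]]
# at runtime raises "TypeError: 'type' object is not subscriptable".
```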
2 changes: 1 addition & 1 deletion docs/guide/install.rst
@@ -7,7 +7,7 @@ Installation
Prerequisites
-------------

Stable-Baselines3 requires python 3.8+ and PyTorch >= 1.13
Stable-Baselines3 requires python 3.9+ and PyTorch >= 2.3

Windows
~~~~~~~
2 changes: 2 additions & 0 deletions docs/index.rst
@@ -20,6 +20,8 @@ RL Baselines3 Zoo provides a collection of pre-trained agents, scripts for train

SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib

SBX (SB3 + Jax): https://github.com/araffin/sbx


Main Features
--------------
35 changes: 35 additions & 0 deletions docs/misc/changelog.rst
@@ -3,6 +3,41 @@
Changelog
==========

Release 2.5.0a0 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Increased minimum required version of PyTorch to 2.3.0
- Removed support for Python 3.8

New Features:
^^^^^^^^^^^^^
- Added support for NumPy v2.0: ``VecNormalize`` now casts normalized rewards to float32; the bit flipping env was also updated to avoid overflow issues
- Added official support for Python 3.12

Bug Fixes:
^^^^^^^^^^

`SB3-Contrib`_
^^^^^^^^^^^^^^

`RL Zoo`_
^^^^^^^^^

`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^

Documentation:
^^^^^^^^^^^^^^


Release 2.4.0 (2024-11-18)
--------------------------

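As background for the NumPy v2.0 entry in the changelog above, here is a rough, hypothetical sketch of the dtype promotion issue that an explicit float32 cast avoids when normalizing rewards; it is not the actual `VecNormalize` implementation:

```python
import numpy as np


def normalize_reward(reward: np.ndarray, var: float, clip: float = 10.0, epsilon: float = 1e-8) -> np.ndarray:
    # Under NumPy 2 (NEP 50 promotion rules), dividing a float32 reward array by a
    # float64 scalar (e.g. a running-variance statistic) promotes the result to
    # float64, whereas NumPy 1.x kept it float32. Casting keeps the dtype stable.
    normalized = np.clip(reward / np.sqrt(var + epsilon), -clip, clip)
    return normalized.astype(np.float32)


rewards = np.array([1.0, -0.5, 2.0], dtype=np.float32)
print(normalize_reward(rewards, var=4.0).dtype)  # float32
```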
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,8 +1,8 @@
[tool.ruff]
# Same as Black.
line-length = 127
# Assume Python 3.8
target-version = "py38"
# Assume Python 3.9
target-version = "py39"

[tool.ruff.lint]
# See https://beta.ruff.rs/docs/rules/
8 changes: 4 additions & 4 deletions setup.py
@@ -77,8 +77,8 @@
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gymnasium>=0.29.1,<1.1.0",
"numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
"torch>=1.13",
"numpy>=1.20,<3.0",
"torch>=2.3,<3.0",
# For saving models
"cloudpickle",
# For reading logs
@@ -135,7 +135,7 @@
long_description=long_description,
long_description_content_type="text/markdown",
version=__version__,
python_requires=">=3.8",
python_requires=">=3.9",
# PyPI package information.
project_urls={
"Code": "https://github.com/DLR-RM/stable-baselines3",
@@ -147,10 +147,10 @@
},
classifiers=[
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
)

12 changes: 6 additions & 6 deletions stable_baselines3/a2c/a2c.py
@@ -1,4 +1,4 @@
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
from typing import Any, ClassVar, Optional, TypeVar, Union

import torch as th
from gymnasium import spaces
@@ -57,15 +57,15 @@ class A2C(OnPolicyAlgorithm):
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""

policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
"MlpPolicy": ActorCriticPolicy,
"CnnPolicy": ActorCriticCnnPolicy,
"MultiInputPolicy": MultiInputActorCriticPolicy,
}

def __init__(
self,
policy: Union[str, Type[ActorCriticPolicy]],
policy: Union[str, type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 7e-4,
n_steps: int = 5,
@@ -78,12 +78,12 @@ def __init__(
use_rms_prop: bool = True,
use_sde: bool = False,
sde_sample_freq: int = -1,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
normalize_advantage: bool = False,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
4 changes: 2 additions & 2 deletions stable_baselines3/common/atari_wrappers.py
@@ -1,4 +1,4 @@
from typing import Dict, SupportsFloat
from typing import SupportsFloat

import gymnasium as gym
import numpy as np
@@ -64,7 +64,7 @@ def reset(self, **kwargs) -> AtariResetReturn:
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
assert noops > 0
obs = np.zeros(0)
info: Dict = {}
info: dict = {}
for _ in range(noops):
obs, _, terminated, truncated, info = self.env.step(self.noop_action)
if terminated or truncated:
39 changes: 20 additions & 19 deletions stable_baselines3/common/base_class.py
@@ -6,7 +6,8 @@
import warnings
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
from collections.abc import Iterable
from typing import Any, ClassVar, Optional, TypeVar, Union

import gymnasium as gym
import numpy as np
@@ -94,7 +95,7 @@ class BaseAlgorithm(ABC):
"""

# Policy aliases (see _get_policy_from_name())
policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {}
policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {}
policy: BasePolicy
observation_space: spaces.Space
action_space: spaces.Space
@@ -104,10 +105,10 @@

def __init__(
self,
policy: Union[str, Type[BasePolicy]],
policy: Union[str, type[BasePolicy]],
env: Union[GymEnv, str, None],
learning_rate: Union[float, Schedule],
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
@@ -117,7 +118,7 @@ def __init__(
seed: Optional[int] = None,
use_sde: bool = False,
sde_sample_freq: int = -1,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
) -> None:
if isinstance(policy, str):
self.policy_class = self._get_policy_from_name(policy)
@@ -141,10 +142,10 @@ def __init__(
self.start_time = 0.0
self.learning_rate = learning_rate
self.tensorboard_log = tensorboard_log
self._last_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
self._last_obs = None # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
self._last_episode_starts = None # type: Optional[np.ndarray]
# When using VecNormalize:
self._last_original_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
self._last_original_obs = None # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
self._episode_num = 0
# Used for gSDE only
self.use_sde = use_sde
@@ -283,7 +284,7 @@ def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps
"""
self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)

def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
def _update_learning_rate(self, optimizers: Union[list[th.optim.Optimizer], th.optim.Optimizer]) -> None:
"""
Update the optimizers learning rate using the current learning rate schedule
and the current progress remaining (from 1 to 0).
@@ -299,7 +300,7 @@ def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.o
for optimizer in optimizers:
update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))

def _excluded_save_params(self) -> List[str]:
def _excluded_save_params(self) -> list[str]:
"""
Returns the names of the parameters that should be excluded from being
saved by pickling. E.g. replay buffers are skipped by default
@@ -320,7 +321,7 @@ def _excluded_save_params(self) -> List[str]:
"_custom_logger",
]

def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
def _get_policy_from_name(self, policy_name: str) -> type[BasePolicy]:
"""
Get a policy class from its name representation.

@@ -337,7 +338,7 @@ def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
else:
raise ValueError(f"Policy {policy_name} unknown")

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
"""
Get the name of the torch variables that will be saved with
PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
@@ -387,7 +388,7 @@ def _setup_learn(
reset_num_timesteps: bool = True,
tb_log_name: str = "run",
progress_bar: bool = False,
) -> Tuple[int, BaseCallback]:
) -> tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.

@@ -435,7 +436,7 @@

return total_timesteps, callback

def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
def _update_info_buffer(self, infos: list[dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
"""
Retrieve reward, episode length, episode success and update the buffer
if using Monitor wrapper or a GoalEnv.
@@ -535,11 +536,11 @@ def learn(

def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
observation: Union[np.ndarray, dict[str, np.ndarray]],
state: Optional[tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
@@ -640,11 +641,11 @@ def set_parameters(

@classmethod
def load( # noqa: C901
cls: Type[SelfBaseAlgorithm],
cls: type[SelfBaseAlgorithm],
path: Union[str, pathlib.Path, io.BufferedIOBase],
env: Optional[GymEnv] = None,
device: Union[th.device, str] = "auto",
custom_objects: Optional[Dict[str, Any]] = None,
custom_objects: Optional[dict[str, Any]] = None,
print_system_info: bool = False,
force_reset: bool = True,
**kwargs,
@@ -800,7 +801,7 @@ def load( # noqa: C901
model.policy.reset_noise() # type: ignore[operator]
return model

def get_parameters(self) -> Dict[str, Dict]:
def get_parameters(self) -> dict[str, dict]:
"""
Return the parameters of the agent. This includes parameters from different networks, e.g.
critics (value functions) and policies (pi functions).