diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index d34a93c9a..a0077a367 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -20,7 +20,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
         include:
           # Default version
           - gymnasium-version: "1.0.0"
@@ -48,7 +48,8 @@ jobs:
       - name: Install specific version of gym
         run: |
           uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
-        # Only run for python 3.10, downgrade gym to 0.29.1
+          uv pip install --system "numpy<2"
+        # Only run for python 3.10, downgrade gym to 0.29.1 and numpy<2
         if: matrix.gymnasium-version != '1.0.0'
       - name: Lint with ruff
         run: |
@@ -62,8 +63,6 @@ jobs:
       - name: Type check
         run: |
           make type
-        # Do not run for python 3.8 (mypy internal error)
-        if: matrix.python-version != '3.8'
      - name: Test with pytest
        run: |
          make pytest
diff --git a/README.md b/README.md
index 5d25781d9..9ae78b239 100644
--- a/README.md
+++ b/README.md
@@ -100,10 +100,10 @@ It provides a minimal number of features compared to SB3 but can be much faster
 
 ## Installation
 
-**Note:** Stable-Baselines3 supports PyTorch >= 1.13
+**Note:** Stable-Baselines3 supports PyTorch >= 2.3
 
 ### Prerequisites
-Stable Baselines3 requires Python 3.8+.
+Stable Baselines3 requires Python 3.9+.
 
 #### Windows
diff --git a/docs/conda_env.yml b/docs/conda_env.yml
index c9b1392d2..ee491017b 100644
--- a/docs/conda_env.yml
+++ b/docs/conda_env.yml
@@ -12,7 +12,7 @@ dependencies:
   - cloudpickle
   - opencv-python-headless
   - pandas
-  - numpy>=1.20,<2.0
+  - numpy>=1.20,<3.0
   - matplotlib
   - sphinx>=5,<9
   - sphinx_rtd_theme>=1.3.0
diff --git a/docs/conf.py b/docs/conf.py
index bd6365701..7e0555e57 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -14,7 +14,6 @@
 import datetime
 import os
 import sys
-from typing import Dict
 
 # We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
 # PyEnchant.
@@ -151,7 +150,7 @@ def setup(app):
 
 # -- Options for LaTeX output ------------------------------------------------
 
-latex_elements: Dict[str, str] = {
+latex_elements: dict[str, str] = {
     # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',
diff --git a/docs/guide/install.rst b/docs/guide/install.rst
index 587234b00..4bdd0a007 100644
--- a/docs/guide/install.rst
+++ b/docs/guide/install.rst
@@ -7,7 +7,7 @@ Installation
 Prerequisites
 -------------
 
-Stable-Baselines3 requires python 3.8+ and PyTorch >= 1.13
+Stable-Baselines3 requires Python 3.9+ and PyTorch >= 2.3
 
 Windows
 ~~~~~~~
diff --git a/docs/index.rst b/docs/index.rst
index d74120c41..6b6018b42 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -20,6 +20,8 @@ RL Baselines3 Zoo provides a collection of pre-trained agents, scripts for train
 
 SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
 
+SBX (SB3 + Jax): https://github.com/araffin/sbx
+
 Main Features
 --------------
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 43ba33e48..e9f3416b1 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,6 +3,41 @@
 Changelog
 ==========
 
+Release 2.5.0a0 (WIP)
+--------------------------
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+- Increased minimum required version of PyTorch to 2.3.0
+- Removed support for Python 3.8
+
+New Features:
+^^^^^^^^^^^^^
+- Added support for NumPy v2.0: ``VecNormalize`` now casts normalized rewards to float32; the bit flipping env was also updated to avoid overflow issues
+- Added official support for Python 3.12
+
+Bug Fixes:
+^^^^^^^^^^
+
+`SB3-Contrib`_
+^^^^^^^^^^^^^^
+
+`RL Zoo`_
+^^^^^^^^^
+
+`SBX`_ (SB3 + Jax)
+^^^^^^^^^^^^^^^^^^
+
+Deprecations:
+^^^^^^^^^^^^^
+
+Others:
+^^^^^^^
+
+Documentation:
+^^^^^^^^^^^^^^
+
+
 Release 2.4.0 (2024-11-18)
 --------------------------
diff --git a/pyproject.toml b/pyproject.toml
index 1fd1a1890..89af5a67f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,8 +1,8 @@
 [tool.ruff]
 # Same as Black.
 line-length = 127
-# Assume Python 3.8
-target-version = "py38"
+# Assume Python 3.9
+target-version = "py39"
 
 [tool.ruff.lint]
 # See https://beta.ruff.rs/docs/rules/
diff --git a/setup.py b/setup.py
index 15c02cb57..fa24fc8a3 100644
--- a/setup.py
+++ b/setup.py
@@ -77,8 +77,8 @@
     package_data={"stable_baselines3": ["py.typed", "version.txt"]},
     install_requires=[
         "gymnasium>=0.29.1,<1.1.0",
-        "numpy>=1.20,<2.0",  # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
-        "torch>=1.13",
+        "numpy>=1.20,<3.0",
+        "torch>=2.3,<3.0",
         # For saving models
         "cloudpickle",
         # For reading logs
@@ -135,7 +135,7 @@
     long_description=long_description,
     long_description_content_type="text/markdown",
     version=__version__,
-    python_requires=">=3.8",
+    python_requires=">=3.9",
     # PyPI package information.
    project_urls={
        "Code": "https://github.com/DLR-RM/stable-baselines3",
@@ -147,10 +147,10 @@
     },
     classifiers=[
         "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
     ],
 )
diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py
index 718571f0c..a125aaef6 100644
--- a/stable_baselines3/a2c/a2c.py
+++ b/stable_baselines3/a2c/a2c.py
@@ -1,4 +1,4 @@
-from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
+from typing import Any, ClassVar, Optional, TypeVar, Union
 
 import torch as th
 from gymnasium import spaces
@@ -57,7 +57,7 @@ class A2C(OnPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": ActorCriticPolicy,
         "CnnPolicy": ActorCriticCnnPolicy,
         "MultiInputPolicy": MultiInputActorCriticPolicy,
@@ -65,7 +65,7 @@ class A2C(OnPolicyAlgorithm):
 
     def __init__(
         self,
-        policy: Union[str, Type[ActorCriticPolicy]],
+        policy: Union[str, type[ActorCriticPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 7e-4,
         n_steps: int = 5,
@@ -78,12 +78,12 @@ def __init__(
         use_rms_prop: bool = True,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
-        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
-        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
+        rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
         normalize_advantage: bool = False,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py
index bbdba9a3d..83a64a5c7 100644
--- a/stable_baselines3/common/atari_wrappers.py
+++ b/stable_baselines3/common/atari_wrappers.py
@@ -1,4 +1,4 @@
-from typing import Dict, SupportsFloat
+from typing import SupportsFloat
 
 import gymnasium as gym
 import numpy as np
@@ -64,7 +64,7 @@ def reset(self, **kwargs) -> AtariResetReturn:
             noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
         assert noops > 0
         obs = np.zeros(0)
-        info: Dict = {}
+        info: dict = {}
         for _ in range(noops):
             obs, _, terminated, truncated, info = self.env.step(self.noop_action)
             if terminated or truncated:
diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index e43955f94..412f9dda2 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -6,7 +6,8 @@
 import warnings
 from abc import ABC, abstractmethod
 from collections import deque
-from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
+from collections.abc import Iterable
+from typing import Any, ClassVar, Optional, TypeVar, Union
 
 import gymnasium as gym
 import numpy as np
@@ -94,7 +95,7 @@ class BaseAlgorithm(ABC):
     """
 
     # Policy aliases (see _get_policy_from_name())
-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {}
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {}
     policy: BasePolicy
     observation_space: spaces.Space
     action_space: spaces.Space
@@ -104,10 +105,10 @@ class BaseAlgorithm(ABC):
 
     def __init__(
         self,
-        policy: Union[str, Type[BasePolicy]],
+        policy: Union[str, type[BasePolicy]],
         env: Union[GymEnv, str, None],
         learning_rate: Union[float, Schedule],
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
         verbose: int = 0,
@@ -117,7 +118,7 @@ def __init__(
         seed: Optional[int] = None,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
-        supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
+        supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
     ) -> None:
         if isinstance(policy, str):
             self.policy_class = self._get_policy_from_name(policy)
@@ -141,10 +142,10 @@ def __init__(
         self.start_time = 0.0
         self.learning_rate = learning_rate
         self.tensorboard_log = tensorboard_log
-        self._last_obs = None  # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
+        self._last_obs = None  # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
         self._last_episode_starts = None  # type: Optional[np.ndarray]
         # When using VecNormalize:
-        self._last_original_obs = None  # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
+        self._last_original_obs = None  # type: Optional[Union[np.ndarray, dict[str, np.ndarray]]]
         self._episode_num = 0
         # Used for gSDE only
         self.use_sde = use_sde
@@ -283,7 +284,7 @@ def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps
         """
         self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
 
-    def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
+    def _update_learning_rate(self, optimizers: Union[list[th.optim.Optimizer], th.optim.Optimizer]) -> None:
         """
         Update the optimizers learning rate using the current learning rate schedule
         and the current progress remaining (from 1 to 0).
@@ -299,7 +300,7 @@ def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.o
         for optimizer in optimizers:
             update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))
 
-    def _excluded_save_params(self) -> List[str]:
+    def _excluded_save_params(self) -> list[str]:
         """
         Returns the names of the parameters that should be excluded from being
         saved by pickling. E.g. replay buffers are skipped by default
@@ -320,7 +321,7 @@ def _excluded_save_params(self) -> List[str]:
             "_custom_logger",
         ]
 
-    def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
+    def _get_policy_from_name(self, policy_name: str) -> type[BasePolicy]:
         """
         Get a policy class from its name representation.
 
@@ -337,7 +338,7 @@ def _get_policy_from_name(self, policy_name: str) -> Type[BasePolicy]:
         else:
             raise ValueError(f"Policy {policy_name} unknown")
 
-    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
         """
         Get the name of the torch variables that will be saved with
         PyTorch ``th.save``, ``th.load`` and ``state_dicts`` instead of the default
@@ -387,7 +388,7 @@ def _setup_learn(
         reset_num_timesteps: bool = True,
         tb_log_name: str = "run",
         progress_bar: bool = False,
-    ) -> Tuple[int, BaseCallback]:
+    ) -> tuple[int, BaseCallback]:
         """
         Initialize different variables needed for training.
@@ -435,7 +436,7 @@ def _setup_learn(
 
         return total_timesteps, callback
 
-    def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
+    def _update_info_buffer(self, infos: list[dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
         """
         Retrieve reward, episode length, episode success and update the buffer
         if using Monitor wrapper or a GoalEnv.
@@ -535,11 +536,11 @@ def learn(
 
     def predict(
         self,
-        observation: Union[np.ndarray, Dict[str, np.ndarray]],
-        state: Optional[Tuple[np.ndarray, ...]] = None,
+        observation: Union[np.ndarray, dict[str, np.ndarray]],
+        state: Optional[tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
         """
         Get the policy action from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
@@ -640,11 +641,11 @@ def set_parameters(
 
     @classmethod
     def load(  # noqa: C901
-        cls: Type[SelfBaseAlgorithm],
+        cls: type[SelfBaseAlgorithm],
         path: Union[str, pathlib.Path, io.BufferedIOBase],
         env: Optional[GymEnv] = None,
         device: Union[th.device, str] = "auto",
-        custom_objects: Optional[Dict[str, Any]] = None,
+        custom_objects: Optional[dict[str, Any]] = None,
         print_system_info: bool = False,
         force_reset: bool = True,
         **kwargs,
@@ -800,7 +801,7 @@ def load(  # noqa: C901
                 model.policy.reset_noise()  # type: ignore[operator]
         return model
 
-    def get_parameters(self) -> Dict[str, Dict]:
+    def get_parameters(self) -> dict[str, dict]:
         """
         Return the parameters of the agent. This includes parameters from different networks, e.g. critics
         (value functions) and policies (pi functions).
diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index b2fc5a710..004adae90 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -1,6 +1,7 @@
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from collections.abc import Generator
+from typing import Any, Optional, Union
 
 import numpy as np
 import torch as th
@@ -36,7 +37,7 @@ class BaseBuffer(ABC):
     """
 
     observation_space: spaces.Space
-    obs_shape: Tuple[int, ...]
+    obs_shape: tuple[int, ...]
 
     def __init__(
         self,
@@ -140,9 +141,9 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
 
     @staticmethod
     def _normalize_obs(
-        obs: Union[np.ndarray, Dict[str, np.ndarray]],
+        obs: Union[np.ndarray, dict[str, np.ndarray]],
         env: Optional[VecNormalize] = None,
-    ) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    ) -> Union[np.ndarray, dict[str, np.ndarray]]:
         if env is not None:
             return env.normalize_obs(obs)
         return obs
@@ -250,7 +251,7 @@ def add(
         action: np.ndarray,
         reward: np.ndarray,
         done: np.ndarray,
-        infos: List[Dict[str, Any]],
+        infos: list[dict[str, Any]],
     ) -> None:
         # Reshape needed when using multiple envs with discrete observations
         # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
@@ -538,9 +539,9 @@ class DictReplayBuffer(ReplayBuffer):
     """
 
     observation_space: spaces.Dict
-    obs_shape: Dict[str, Tuple[int, ...]]  # type: ignore[assignment]
-    observations: Dict[str, np.ndarray]  # type: ignore[assignment]
-    next_observations: Dict[str, np.ndarray]  # type: ignore[assignment]
+    obs_shape: dict[str, tuple[int, ...]]  # type: ignore[assignment]
+    observations: dict[str, np.ndarray]  # type: ignore[assignment]
+    next_observations: dict[str, np.ndarray]  # type: ignore[assignment]
 
     def __init__(
         self,
@@ -609,12 +610,12 @@ def __init__(
 
     def add(  # type: ignore[override]
         self,
-        obs: Dict[str, np.ndarray],
-        next_obs: Dict[str, np.ndarray],
+        obs: dict[str, np.ndarray],
+        next_obs: dict[str, np.ndarray],
         action: np.ndarray,
         reward: np.ndarray,
         done: np.ndarray,
-        infos: List[Dict[str, Any]],
+        infos: list[dict[str, Any]],
     ) -> None:
         # Copy to avoid modification by reference
         for key in self.observations.keys():
@@ -718,8 +719,8 @@ class DictRolloutBuffer(RolloutBuffer):
     """
 
     observation_space: spaces.Dict
-    obs_shape: Dict[str, Tuple[int, ...]]  # type: ignore[assignment]
-    observations: Dict[str, np.ndarray]  # type: ignore[assignment]
+    obs_shape: dict[str, tuple[int, ...]]  # type: ignore[assignment]
+    observations: dict[str, np.ndarray]  # type: ignore[assignment]
 
     def __init__(
         self,
@@ -757,7 +758,7 @@ def reset(self) -> None:
 
     def add(  # type: ignore[override]
         self,
-        obs: Dict[str, np.ndarray],
+        obs: dict[str, np.ndarray],
         action: np.ndarray,
         reward: np.ndarray,
         episode_start: np.ndarray,
diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py
index c7841866b..31c3a24a7 100644
--- a/stable_baselines3/common/callbacks.py
+++ b/stable_baselines3/common/callbacks.py
@@ -1,7 +1,7 @@
 import os
 import warnings
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import gymnasium as gym
 import numpy as np
@@ -45,8 +45,8 @@ def __init__(self, verbose: int = 0):
         # n_envs * n times env.step() was called
         self.num_timesteps = 0  # type: int
         self.verbose = verbose
-        self.locals: Dict[str, Any] = {}
-        self.globals: Dict[str, Any] = {}
+        self.locals: dict[str, Any] = {}
+        self.globals: dict[str, Any] = {}
         # Sometimes, for event callback, it is useful
         # to have access to the parent object
         self.parent = None  # type: Optional[BaseCallback]
@@ -75,7 +75,7 @@ def init_callback(self, model: "base_class.BaseAlgorithm") -> None:
     def _init_callback(self) -> None:
         pass
 
-    def on_training_start(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
+    def on_training_start(self, locals_: dict[str, Any], globals_: dict[str, Any]) -> None:
         # Those are reference and will be updated automatically
         self.locals = locals_
        self.globals = globals_
@@ -125,7 +125,7 @@ def on_rollout_end(self) -> None:
     def _on_rollout_end(self) -> None:
         pass
 
-    def update_locals(self, locals_: Dict[str, Any]) -> None:
+    def update_locals(self, locals_: dict[str, Any]) -> None:
         """
         Update the references to the local variables.
 
@@ -134,7 +134,7 @@ def update_locals(self, locals_: Dict[str, Any]) -> None:
         self.locals.update(locals_)
         self.update_child_locals(locals_)
 
-    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
+    def update_child_locals(self, locals_: dict[str, Any]) -> None:
         """
         Update the references to the local variables on sub callbacks.
 
@@ -177,7 +177,7 @@ def _on_event(self) -> bool:
     def _on_step(self) -> bool:
         return True
 
-    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
+    def update_child_locals(self, locals_: dict[str, Any]) -> None:
         """
         Update the references to the local variables.
 
@@ -195,7 +195,7 @@ class CallbackList(BaseCallback):
     sequentially.
     """
 
-    def __init__(self, callbacks: List[BaseCallback]):
+    def __init__(self, callbacks: list[BaseCallback]):
         super().__init__()
         assert isinstance(callbacks, list)
         self.callbacks = callbacks
@@ -231,7 +231,7 @@ def _on_training_end(self) -> None:
         for callback in self.callbacks:
             callback.on_training_end()
 
-    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
+    def update_child_locals(self, locals_: dict[str, Any]) -> None:
         """
         Update the references to the local variables.
 
@@ -328,7 +328,7 @@ class ConvertCallback(BaseCallback):
     :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
     """
 
-    def __init__(self, callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]], verbose: int = 0):
+    def __init__(self, callback: Optional[Callable[[dict[str, Any], dict[str, Any]], bool]], verbose: int = 0):
         super().__init__(verbose)
         self.callback = callback
 
@@ -405,12 +405,12 @@ def __init__(
         if log_path is not None:
             log_path = os.path.join(log_path, "evaluations")
         self.log_path = log_path
-        self.evaluations_results: List[List[float]] = []
-        self.evaluations_timesteps: List[int] = []
-        self.evaluations_length: List[List[int]] = []
+        self.evaluations_results: list[list[float]] = []
+        self.evaluations_timesteps: list[int] = []
+        self.evaluations_length: list[list[int]] = []
         # For computing success rate
-        self._is_success_buffer: List[bool] = []
-        self.evaluations_successes: List[List[bool]] = []
+        self._is_success_buffer: list[bool] = []
+        self.evaluations_successes: list[list[bool]] = []
 
     def _init_callback(self) -> None:
         # Does not work in some corner cases, where the wrapper is not the same
@@ -427,7 +427,7 @@ def _init_callback(self) -> None:
         if self.callback_on_new_best is not None:
             self.callback_on_new_best.init_callback(self.model)
 
-    def _log_success_callback(self, locals_: Dict[str, Any], globals_: Dict[str, Any]) -> None:
+    def _log_success_callback(self, locals_: dict[str, Any], globals_: dict[str, Any]) -> None:
         """
         Callback passed to the ``evaluate_policy`` function
         in order to log the success rate (when applicable),
@@ -530,7 +530,7 @@ def _on_step(self) -> bool:
 
         return continue_training
 
-    def update_child_locals(self, locals_: Dict[str, Any]) -> None:
+    def update_child_locals(self, locals_: dict[str, Any]) -> None:
         """
         Update the references to the local variables.
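
The typing changes above (and in the remaining files below) are one mechanical rewrite: with Python 3.8 dropped, the deprecated `typing.Dict`/`List`/`Tuple`/`Type` aliases can be replaced by the builtin generics of PEP 585, and ABCs such as `Iterable`/`Generator`/`Mapping`/`Sequence` move to `collections.abc`. A minimal sketch of the pattern, with hypothetical helper names that are not part of this diff:

```python
# Requires Python 3.9+ (PEP 585): builtin types are subscriptable at runtime.
from collections.abc import Iterable
from typing import Any, Optional, Union

# Before (3.8-compatible): from typing import Dict, List, Tuple, Type
# After: annotate with the builtins directly.
def count_keys(infos: list[dict[str, Any]], keys: Optional[tuple[str, ...]] = None) -> dict[str, int]:
    """Count how often each requested key appears in a list of info dicts."""
    counts: dict[str, int] = {}
    for info in infos:
        for key in info:
            if keys is None or key in keys:
                counts[key] = counts.get(key, 0) + 1
    return counts

def first_value(values: Iterable[Union[int, float]]) -> Union[int, float]:
    # collections.abc.Iterable replaces the deprecated typing.Iterable
    return next(iter(values))
```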
diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py
index 132a35348..380898a50 100644
--- a/stable_baselines3/common/distributions.py
+++ b/stable_baselines3/common/distributions.py
@@ -1,7 +1,7 @@
 """Probability distributions."""
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union
 
 import numpy as np
 import torch as th
@@ -30,7 +30,7 @@ def __init__(self):
         self.distribution = None
 
     @abstractmethod
-    def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, Tuple[nn.Module, nn.Parameter]]:
+    def proba_distribution_net(self, *args, **kwargs) -> Union[nn.Module, tuple[nn.Module, nn.Parameter]]:
        """Create the layers and parameters that represent the distribution.

        Subclasses must define this, but the arguments and return type vary between
@@ -98,7 +98,7 @@ def actions_from_params(self, *args, **kwargs) -> th.Tensor:
         """
 
     @abstractmethod
-    def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
+    def log_prob_from_params(self, *args, **kwargs) -> tuple[th.Tensor, th.Tensor]:
         """
         Returns samples and the associated log probabilities
         from the probability distribution given its parameters.
@@ -135,7 +135,7 @@ def __init__(self, action_dim: int):
         self.mean_actions = None
         self.log_std = None
 
-    def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
+    def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> tuple[nn.Module, nn.Parameter]:
         """
         Create the layers and parameter that represent the distribution:
         one output will be the mean of the Gaussian, the other parameter will be the
@@ -190,7 +190,7 @@ def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deter
         self.proba_distribution(mean_actions, log_std)
         return self.get_actions(deterministic=deterministic)
 
-    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
         """
         Compute the log probability of taking an action
         given the distribution parameters.
@@ -254,7 +254,7 @@ def mode(self) -> th.Tensor:
         # Squash the output
         return th.tanh(self.gaussian_actions)
 
-    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+    def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
         action = self.actions_from_params(mean_actions, log_std)
         log_prob = self.log_prob(action, self.gaussian_actions)
         return action, log_prob
@@ -305,7 +305,7 @@ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = Fa
         self.proba_distribution(action_logits)
         return self.get_actions(deterministic=deterministic)
 
-    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+    def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
         actions = self.actions_from_params(action_logits)
         log_prob = self.log_prob(actions)
         return actions, log_prob
@@ -318,7 +318,7 @@ class MultiCategoricalDistribution(Distribution):
     :param action_dims: List of sizes of discrete action spaces
     """
 
-    def __init__(self, action_dims: List[int]):
+    def __init__(self, action_dims: list[int]):
         super().__init__()
         self.action_dims = action_dims
 
@@ -362,7 +362,7 @@ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = Fa
         self.proba_distribution(action_logits)
         return self.get_actions(deterministic=deterministic)
 
-    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+    def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
         actions = self.actions_from_params(action_logits)
         log_prob = self.log_prob(actions)
         return actions, log_prob
@@ -412,7 +412,7 @@ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = Fa
         self.proba_distribution(action_logits)
         return self.get_actions(deterministic=deterministic)
 
-    def log_prob_from_params(self, action_logits: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+    def log_prob_from_params(self, action_logits: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
         actions = self.actions_from_params(action_logits)
         log_prob = self.log_prob(actions)
         return actions, log_prob
@@ -513,7 +513,7 @@ def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None:
 
     def proba_distribution_net(
         self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: Optional[int] = None
-    ) -> Tuple[nn.Module, nn.Parameter]:
+    ) -> tuple[nn.Module, nn.Parameter]:
         """
         Create the layers and parameter that represent the distribution:
         one output will be the deterministic action, the other parameter will be the
@@ -611,7 +611,7 @@ def actions_from_params(
 
     def log_prob_from_params(
         self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
-    ) -> Tuple[th.Tensor, th.Tensor]:
+    ) -> tuple[th.Tensor, th.Tensor]:
         actions = self.actions_from_params(mean_actions, log_std, latent_sde)
         log_prob = self.log_prob(actions)
         return actions, log_prob
@@ -661,7 +661,7 @@ def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
 
 
 def make_proba_distribution(
-    action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
+    action_space: spaces.Space, use_sde: bool = False, dist_kwargs: Optional[dict[str, Any]] = None
 ) -> Distribution:
     """
     Return an instance of Distribution for the correct type of action space
diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py
index e47dd123a..0310bcfe7 100644
--- a/stable_baselines3/common/env_checker.py
+++ b/stable_baselines3/common/env_checker.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Dict, Union
+from typing import Any, Union
 
 import gymnasium as gym
 import numpy as np
@@ -172,10 +172,10 @@ def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name:
 
 def _check_goal_env_compute_reward(
-    obs: Dict[str, Union[np.ndarray, int]],
+    obs: dict[str, Union[np.ndarray, int]],
     env: gym.Env,
     reward: float,
-    info: Dict[str, Any],
+    info: dict[str, Any],
 ) -> None:
     """
     Check that reward is computed with `compute_reward`
diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py
index 0132c32f8..bbe281f27 100644
--- a/stable_baselines3/common/env_util.py
+++ b/stable_baselines3/common/env_util.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Callable, Dict, Optional, Type, Union
+from typing import Any, Callable, Optional, Union
 
 import gymnasium as gym
 
@@ -9,7 +9,7 @@
 from stable_baselines3.common.vec_env.patch_gym import _patch_env
 
 
-def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[gym.Wrapper]:
+def unwrap_wrapper(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> Optional[gym.Wrapper]:
     """
     Retrieve a ``VecEnvWrapper`` object by recursively searching.
 
@@ -25,7 +25,7 @@ def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[g
     return None
 
 
-def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool:
+def is_wrapped(env: gym.Env, wrapper_class: type[gym.Wrapper]) -> bool:
     """
     Check if a given environment has been wrapped with a given wrapper.
 
@@ -43,11 +43,11 @@ def make_vec_env(
     start_index: int = 0,
     monitor_dir: Optional[str] = None,
     wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
-    env_kwargs: Optional[Dict[str, Any]] = None,
-    vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
-    vec_env_kwargs: Optional[Dict[str, Any]] = None,
-    monitor_kwargs: Optional[Dict[str, Any]] = None,
-    wrapper_kwargs: Optional[Dict[str, Any]] = None,
+    env_kwargs: Optional[dict[str, Any]] = None,
+    vec_env_cls: Optional[type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
+    vec_env_kwargs: Optional[dict[str, Any]] = None,
+    monitor_kwargs: Optional[dict[str, Any]] = None,
+    wrapper_kwargs: Optional[dict[str, Any]] = None,
 ) -> VecEnv:
     """
     Create a wrapped, monitored ``VecEnv``.
@@ -134,11 +134,11 @@ def make_atari_env(
     seed: Optional[int] = None,
     start_index: int = 0,
     monitor_dir: Optional[str] = None,
-    wrapper_kwargs: Optional[Dict[str, Any]] = None,
-    env_kwargs: Optional[Dict[str, Any]] = None,
-    vec_env_cls: Optional[Union[Type[DummyVecEnv], Type[SubprocVecEnv]]] = None,
-    vec_env_kwargs: Optional[Dict[str, Any]] = None,
-    monitor_kwargs: Optional[Dict[str, Any]] = None,
+    wrapper_kwargs: Optional[dict[str, Any]] = None,
+    env_kwargs: Optional[dict[str, Any]] = None,
+    vec_env_cls: Optional[Union[type[DummyVecEnv], type[SubprocVecEnv]]] = None,
+    vec_env_kwargs: Optional[dict[str, Any]] = None,
+    monitor_kwargs: Optional[dict[str, Any]] = None,
 ) -> VecEnv:
     """
     Create a wrapped, monitored VecEnv for Atari.
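
For context, the `make_vec_env` signature retyped above is typically used as below; this is a usage sketch, and the env id and keyword values are illustrative only:

```python
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

# The now plain-`dict` keyword arguments are forwarded to the env constructor,
# the Monitor wrapper, and the VecEnv class respectively.
vec_env = make_vec_env(
    "CartPole-v1",
    n_envs=4,
    seed=0,
    env_kwargs={},  # passed to gym.make()
    vec_env_cls=SubprocVecEnv,
    monitor_kwargs={"allow_early_resets": True},
)
obs = vec_env.reset()
```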
diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py
index 3ea0c7bb0..4d99313ea 100644
--- a/stable_baselines3/common/envs/bit_flipping_env.py
+++ b/stable_baselines3/common/envs/bit_flipping_env.py
@@ -1,5 +1,5 @@
 from collections import OrderedDict
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 import numpy as np
 from gymnasium import Env, spaces
@@ -75,14 +75,17 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
         :param state:
         :return:
         """
+        if self.discrete_obs_space:
+            # Convert from int8 to int32 for NumPy 2.0
+            state = state.astype(np.int32)
             # The internal state is the binary representation of the
             # observed one
             return int(sum(state[i] * 2**i for i in range(len(state))))
 
         if self.image_obs_space:
             size = np.prod(self.image_shape)
-            image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8)))
+            image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
             return image.reshape(self.image_shape).astype(np.uint8)
         return state
@@ -163,7 +166,7 @@ def _make_observation_space(self, discrete_obs_space: bool, image_obs_space: boo
             }
         )
 
-    def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
+    def _get_obs(self) -> dict[str, Union[int, np.ndarray]]:
         """
         Helper to create the observation.
 
@@ -178,8 +181,8 @@ def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]:
         )
 
     def reset(
-        self, *, seed: Optional[int] = None, options: Optional[Dict] = None
-    ) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]:
+        self, *, seed: Optional[int] = None, options: Optional[dict] = None
+    ) -> tuple[dict[str, Union[int, np.ndarray]], dict]:
         if seed is not None:
             self._obs_space.seed(seed)
         self.current_step = 0
@@ -207,7 +210,7 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
         return obs, reward, terminated, truncated, info
 
     def compute_reward(
-        self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]]
+        self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[dict[str, Any]]
     ) -> np.float32:
         # As we are using a vectorized version, we need to keep track of the `batch_size`
         if isinstance(achieved_goal, int):
diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py
index 99a664999..0c5610446 100644
--- a/stable_baselines3/common/envs/identity_env.py
+++ b/stable_baselines3/common/envs/identity_env.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union
+from typing import Any, Generic, Optional, TypeVar, Union
 
 import gymnasium as gym
 import numpy as np
@@ -34,7 +34,7 @@ def __init__(self, dim: Optional[int] = None, space: Optional[spaces.Space] = No
         self.num_resets = -1  # Becomes 0 after __init__ exits.
        self.reset()
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[T, Dict]:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[T, dict]:
         if seed is not None:
             super().reset(seed=seed)
         self.current_step = 0
@@ -42,7 +42,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -
         self._choose_next_state()
         return self.state, {}
 
-    def step(self, action: T) -> Tuple[T, float, bool, bool, Dict[str, Any]]:
+    def step(self, action: T) -> tuple[T, float, bool, bool, dict[str, Any]]:
         reward = self._get_reward(action)
         self._choose_next_state()
         self.current_step += 1
@@ -74,7 +74,7 @@ def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_l
         super().__init__(ep_length=ep_length, space=space)
         self.eps = eps
 
-    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
+    def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
         reward = self._get_reward(action)
         self._choose_next_state()
         self.current_step += 1
@@ -142,7 +142,7 @@ def __init__(
         self.ep_length = 10
         self.current_step = 0
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[np.ndarray, dict]:
         if seed is not None:
             super().reset(seed=seed)
         self.current_step = 0
diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py
index f34d13b7c..8749f82f3 100644
--- a/stable_baselines3/common/envs/multi_input_envs.py
+++ b/stable_baselines3/common/envs/multi_input_envs.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import gymnasium as gym
 import numpy as np
@@ -73,7 +73,7 @@ def __init__(
         self.init_possible_transitions()
 
         self.num_col = num_col
-        self.state_mapping: List[Dict[str, np.ndarray]] = []
+        self.state_mapping: list[dict[str, np.ndarray]] = []
         self.init_state_mapping(num_col, num_row)
 
         self.max_state = len(self.state_mapping) - 1
@@ -94,7 +94,7 @@ def init_state_mapping(self, num_col: int, num_row: int) -> None:
             for j in range(num_row):
                 self.state_mapping.append({"vec": col_vecs[i], "img": row_imgs[j].reshape(self.img_size)})
 
-    def get_state_mapping(self) -> Dict[str, np.ndarray]:
+    def get_state_mapping(self) -> dict[str, np.ndarray]:
         """
         Uses the state to get the observation mapping.
 
@@ -166,7 +166,7 @@ def render(self, mode: str = "human") -> None:
         """
         print(self.log)
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[Dict[str, np.ndarray], Dict]:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[dict[str, np.ndarray], dict]:
         """
         Resets the environment state and step count and returns reset observation.
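
The two `astype` casts in `bit_flipping_env.py` work around NumPy 2's stricter NEP 50 promotion rules: a Python int operand no longer upcasts a small-dtype array, so `int8 * 2**i` (and `uint8 * 255`) can raise or wrap where NumPy 1.x silently promoted. A small demonstration of the failure mode being fixed, assuming NumPy >= 2.0 is installed:

```python
import numpy as np

state = np.ones(10, dtype=np.int8)  # 10 bits, as in BitFlippingEnv

# Under NumPy 2 (NEP 50), int8 * <python int> stays int8, so 2**7 = 128 is
# already out of bounds for int8 and raises instead of upcasting:
try:
    sum(state[i] * 2**i for i in range(len(state)))
except OverflowError as exc:
    print(f"without the cast: {exc}")

# Casting first keeps the intermediate products exact, as the patch does:
state32 = state.astype(np.int32)
print(int(sum(state32[i] * 2**i for i in range(len(state32)))))  # 1023
```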
diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py
index c9253a899..e66448b51 100644
--- a/stable_baselines3/common/evaluation.py
+++ b/stable_baselines3/common/evaluation.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 
 import gymnasium as gym
 import numpy as np
@@ -14,11 +14,11 @@ def evaluate_policy(
     n_eval_episodes: int = 10,
     deterministic: bool = True,
     render: bool = False,
-    callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], None]] = None,
+    callback: Optional[Callable[[dict[str, Any], dict[str, Any]], None]] = None,
     reward_threshold: Optional[float] = None,
     return_episode_rewards: bool = False,
     warn: bool = True,
-) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
+) -> Union[tuple[float, float], tuple[list[float], list[int]]]:
     """
     Runs policy for ``n_eval_episodes`` episodes and returns average reward.
     If a vector env is passed in, this divides the episodes to evaluate onto the
diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py
index 8ceda71ed..8d707cba5 100644
--- a/stable_baselines3/common/logger.py
+++ b/stable_baselines3/common/logger.py
@@ -5,8 +5,9 @@
 import tempfile
 import warnings
 from collections import defaultdict
+from collections.abc import Mapping, Sequence
 from io import TextIOBase
-from typing import Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple, Union
+from typing import Any, Optional, TextIO, Union
 
 import matplotlib.figure
 import numpy as np
@@ -114,7 +115,7 @@ class KVWriter:
     Key Value writer
     """
 
-    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
+    def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
         """
         Write a dictionary to file
 
@@ -136,7 +137,7 @@ class SeqWriter:
     sequence writer
     """
 
-    def write_sequence(self, sequence: List[str]) -> None:
+    def write_sequence(self, sequence: list[str]) -> None:
         """
         write_sequence an array to file
 
@@ -172,7 +173,7 @@ def __init__(self, filename_or_file: Union[str, TextIO], max_length: int = 36):
         else:
             raise ValueError(f"Expected file or str, got {filename_or_file}")
 
-    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
+    def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
         # Create strings for printing
         key2str = {}
         tag = ""
@@ -244,7 +245,7 @@ def _truncate(self, string: str) -> str:
             string = string[: self.max_length - 3] + "..."
        return string
 
-    def write_sequence(self, sequence: List[str]) -> None:
+    def write_sequence(self, sequence: list[str]) -> None:
         for i, elem in enumerate(sequence):
             self.file.write(elem)
             if i < len(sequence) - 1:  # add space unless this is the last one
@@ -260,7 +261,7 @@ def close(self) -> None:
         self.file.close()
 
 
-def filter_excluded_keys(key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], _format: str) -> Dict[str, Any]:
+def filter_excluded_keys(key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], _format: str) -> dict[str, Any]:
     """
     Filters the keys specified by ``key_exclude`` for the specified format
 
@@ -286,7 +287,7 @@ class JSONOutputFormat(KVWriter):
     def __init__(self, filename: str):
         self.file = open(filename, "w")
 
-    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
+    def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
         def cast_to_json_serializable(value: Any):
             if isinstance(value, Video):
                 raise FormatUnsupportedError(["json"], "video")
@@ -328,12 +329,12 @@ class CSVOutputFormat(KVWriter):
     """
 
     def __init__(self, filename: str):
-        self.file = open(filename, "w+t")
-        self.keys: List[str] = []
+        self.file = open(filename, "w+")
+        self.keys: list[str] = []
         self.separator = ","
         self.quotechar = '"'
 
-    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
+    def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
         # Add our current row to the history
         key_values = filter_excluded_keys(key_values, key_excluded, "csv")
         extra_keys = key_values.keys() - self.keys
@@ -399,7 +400,7 @@ def __init__(self, folder: str):
         self.writer = SummaryWriter(log_dir=folder)
         self._is_closed = False
 
-    def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Tuple[str, ...]], step: int = 0) -> None:
+    def write(self, key_values: dict[str, Any], key_excluded: dict[str, tuple[str, ...]], step: int = 0) -> None:
         assert not self._is_closed, "The SummaryWriter was closed, please re-create one."
         for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
             if excluded is not None and "tensorboard" in excluded:
@@ -481,16 +482,16 @@ class Logger:
     :param output_formats: the list of output formats
     """
 
-    def __init__(self, folder: Optional[str], output_formats: List[KVWriter]):
-        self.name_to_value: Dict[str, float] = defaultdict(float)  # values this iteration
-        self.name_to_count: Dict[str, int] = defaultdict(int)
-        self.name_to_excluded: Dict[str, Tuple[str, ...]] = {}
+    def __init__(self, folder: Optional[str], output_formats: list[KVWriter]):
+        self.name_to_value: dict[str, float] = defaultdict(float)  # values this iteration
+        self.name_to_count: dict[str, int] = defaultdict(int)
+        self.name_to_excluded: dict[str, tuple[str, ...]] = {}
         self.level = INFO
         self.dir = folder
         self.output_formats = output_formats
 
     @staticmethod
-    def to_tuple(string_or_tuple: Optional[Union[str, Tuple[str, ...]]]) -> Tuple[str, ...]:
+    def to_tuple(string_or_tuple: Optional[Union[str, tuple[str, ...]]]) -> tuple[str, ...]:
         """
         Helper function to convert str to tuple of str.
""" @@ -500,7 +501,7 @@ def to_tuple(string_or_tuple: Optional[Union[str, Tuple[str, ...]]]) -> Tuple[st return string_or_tuple return (string_or_tuple,) - def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + def record(self, key: str, value: Any, exclude: Optional[Union[str, tuple[str, ...]]] = None) -> None: """ Log a value of some diagnostic Call this once for each diagnostic quantity, each iteration @@ -513,7 +514,7 @@ def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, . self.name_to_value[key] = value self.name_to_excluded[key] = self.to_tuple(exclude) - def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + def record_mean(self, key: str, value: Optional[float], exclude: Optional[Union[str, tuple[str, ...]]] = None) -> None: """ The same as record(), but if called many times, values averaged. @@ -624,7 +625,7 @@ def close(self) -> None: # Misc # ---------------------------------------- - def _do_log(self, args: Tuple[Any, ...]) -> None: + def _do_log(self, args: tuple[Any, ...]) -> None: """ log to the requested format outputs @@ -635,7 +636,7 @@ def _do_log(self, args: Tuple[Any, ...]) -> None: _format.write_sequence(list(map(str, args))) -def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> Logger: +def configure(folder: Optional[str] = None, format_strings: Optional[list[str]] = None) -> Logger: """ Configure the current logger. diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index fb8ce33c6..80dfd4668 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -5,7 +5,7 @@ import os import time from glob import glob -from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union +from typing import Any, Optional, SupportsFloat, Union import gymnasium as gym import pandas @@ -33,8 +33,8 @@ def __init__( env: gym.Env, filename: Optional[str] = None, allow_early_resets: bool = True, - reset_keywords: Tuple[str, ...] = (), - info_keywords: Tuple[str, ...] = (), + reset_keywords: tuple[str, ...] = (), + info_keywords: tuple[str, ...] = (), override_existing: bool = True, ): super().__init__(env=env) @@ -52,16 +52,16 @@ def __init__( self.reset_keywords = reset_keywords self.info_keywords = info_keywords self.allow_early_resets = allow_early_resets - self.rewards: List[float] = [] + self.rewards: list[float] = [] self.needs_reset = True - self.episode_returns: List[float] = [] - self.episode_lengths: List[int] = [] - self.episode_times: List[float] = [] + self.episode_returns: list[float] = [] + self.episode_lengths: list[int] = [] + self.episode_times: list[float] = [] self.total_steps = 0 # extra info about the current episode, that was passed in during reset() - self.current_reset_info: Dict[str, Any] = {} + self.current_reset_info: dict[str, Any] = {} - def reset(self, **kwargs) -> Tuple[ObsType, Dict[str, Any]]: + def reset(self, **kwargs) -> tuple[ObsType, dict[str, Any]]: """ Calls the Gym environment reset. 
        Can only be called if the environment is over, or if allow_early_resets is True
@@ -82,7 +82,7 @@ def reset(self, **kwargs) -> Tuple[ObsType, Dict[str, Any]]:
                 self.current_reset_info[key] = value
         return self.env.reset(**kwargs)
 
-    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
+    def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
         """
         Step the environment with the given action
 
@@ -126,7 +126,7 @@ def get_total_steps(self) -> int:
         """
         return self.total_steps
 
-    def get_episode_rewards(self) -> List[float]:
+    def get_episode_rewards(self) -> list[float]:
         """
         Returns the rewards of all the episodes
 
@@ -134,7 +134,7 @@ def get_episode_rewards(self) -> List[float]:
         """
         return self.episode_returns
 
-    def get_episode_lengths(self) -> List[int]:
+    def get_episode_lengths(self) -> list[int]:
         """
         Returns the number of timesteps of all the episodes
 
@@ -142,7 +142,7 @@ def get_episode_lengths(self) -> List[int]:
         """
         return self.episode_lengths
 
-    def get_episode_times(self) -> List[float]:
+    def get_episode_times(self) -> list[float]:
         """
         Returns the runtime in seconds of all the episodes
 
@@ -175,8 +175,8 @@ class ResultsWriter:
     def __init__(
         self,
         filename: str = "",
-        header: Optional[Dict[str, Union[float, str]]] = None,
-        extra_keys: Tuple[str, ...] = (),
+        header: Optional[dict[str, Union[float, str]]] = None,
+        extra_keys: tuple[str, ...] = (),
         override_existing: bool = True,
     ):
         if header is None:
@@ -200,7 +200,7 @@ def __init__(
 
         self.file_handler.flush()
 
-    def write_row(self, epinfo: Dict[str, float]) -> None:
+    def write_row(self, epinfo: dict[str, float]) -> None:
         """
         Write row of monitor data to csv log file.
 
@@ -217,7 +217,7 @@ def close(self) -> None:
         self.file_handler.close()
 
 
-def get_monitor_files(path: str) -> List[str]:
+def get_monitor_files(path: str) -> list[str]:
     """
     get all the monitor files in the given path
 
diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py
index 01670e6e4..991cd23ea 100644
--- a/stable_baselines3/common/noise.py
+++ b/stable_baselines3/common/noise.py
@@ -1,6 +1,7 @@
 import copy
 from abc import ABC, abstractmethod
-from typing import Iterable, List, Optional
+from collections.abc import Iterable
+from typing import Optional
 
 import numpy as np
 from numpy.typing import DTypeLike
@@ -153,11 +154,11 @@ def base_noise(self, base_noise: ActionNoise) -> None:
         self._base_noise = base_noise
 
     @property
-    def noises(self) -> List[ActionNoise]:
+    def noises(self) -> list[ActionNoise]:
         return self._noises
 
     @noises.setter
-    def noises(self, noises: List[ActionNoise]) -> None:
+    def noises(self, noises: list[ActionNoise]) -> None:
         noises = list(noises)  # raises TypeError if not iterable
         assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."
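
As a usage note, `VectorizedActionNoise` wraps a single `ActionNoise` into one copy per environment; the `noises` setter above accepts any iterable but validates the length. A brief sketch, with illustrative shapes and values:

```python
import numpy as np

from stable_baselines3.common.noise import NormalActionNoise, VectorizedActionNoise

n_actions, n_envs = 2, 4
base_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
vec_noise = VectorizedActionNoise(base_noise, n_envs=n_envs)

samples = vec_noise()  # one noise sample per env
assert samples.shape == (n_envs, n_actions)
```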
diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index c460d0236..6a043e7ac 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -4,7 +4,7 @@
 import time
 import warnings
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union
 
 import numpy as np
 import torch as th
@@ -79,7 +79,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
 
     def __init__(
         self,
-        policy: Union[str, Type[BasePolicy]],
+        policy: Union[str, type[BasePolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule],
         buffer_size: int = 1_000_000,  # 1e6
@@ -87,13 +87,13 @@ def __init__(
         batch_size: int = 256,
         tau: float = 0.005,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = (1, "step"),
+        train_freq: Union[int, tuple[int, str]] = (1, "step"),
         gradient_steps: int = 1,
         action_noise: Optional[ActionNoise] = None,
-        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
         verbose: int = 0,
@@ -105,7 +105,7 @@ def __init__(
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
         sde_support: bool = True,
-        supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
+        supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
     ):
         super().__init__(
             policy=policy,
@@ -256,7 +256,7 @@ def _setup_learn(
         reset_num_timesteps: bool = True,
         tb_log_name: str = "run",
         progress_bar: bool = False,
-    ) -> Tuple[int, BaseCallback]:
+    ) -> tuple[int, BaseCallback]:
         """
         cf `BaseAlgorithm`.
         """
@@ -362,7 +362,7 @@ def _sample_action(
         learning_starts: int,
         action_noise: Optional[ActionNoise] = None,
         n_envs: int = 1,
-    ) -> Tuple[np.ndarray, np.ndarray]:
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
         Sample an action according to the exploration policy.
         This is either done by sampling the probability distribution of the policy,
@@ -442,10 +442,10 @@ def _store_transition(
         self,
         replay_buffer: ReplayBuffer,
         buffer_action: np.ndarray,
-        new_obs: Union[np.ndarray, Dict[str, np.ndarray]],
+        new_obs: Union[np.ndarray, dict[str, np.ndarray]],
         reward: np.ndarray,
         dones: np.ndarray,
-        infos: List[Dict[str, Any]],
+        infos: list[dict[str, Any]],
     ) -> None:
         """
         Store transition in the replay buffer.
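
The `train_freq: Union[int, tuple[int, str]]` annotation covers both call styles accepted by the off-policy algorithms; a short sketch with SAC, where the env id and timestep budget are placeholders:

```python
from stable_baselines3 import SAC

# Train every step (the default) ...
model = SAC("MlpPolicy", "Pendulum-v1", train_freq=1)

# ... or once per episode, using the (frequency, unit) tuple form,
# where the unit is either "step" or "episode".
model = SAC("MlpPolicy", "Pendulum-v1", train_freq=(1, "episode"))
model.learn(total_timesteps=1_000)
```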
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index dc885242e..ac4c0970c 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -1,7 +1,7 @@
 import sys
 import time
 import warnings
-from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union
 
 import numpy as np
 import torch as th
@@ -60,7 +60,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
 
     def __init__(
         self,
-        policy: Union[str, Type[ActorCriticPolicy]],
+        policy: Union[str, type[ActorCriticPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule],
         n_steps: int,
@@ -71,17 +71,17 @@ def __init__(
         max_grad_norm: float,
         use_sde: bool,
         sde_sample_freq: int,
-        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
-        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
+        rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
         monitor_wrapper: bool = True,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
-        supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
+        supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
     ):
         super().__init__(
             policy=policy,
@@ -339,7 +339,7 @@ def learn(
 
         return self
 
-    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
         state_dicts = ["policy", "policy.optimizer"]
 
         return state_dicts, []
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index f9c4285dc..e20256f0c 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -5,7 +5,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union
 
 import numpy as np
 import torch as th
@@ -64,12 +64,12 @@ def __init__(
         self,
         observation_space: spaces.Space,
         action_space: spaces.Space,
-        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         features_extractor: Optional[BaseFeaturesExtractor] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ):
         super().__init__()
 
@@ -95,9 +95,9 @@ def __init__(
 
     def _update_features_extractor(
         self,
-        net_kwargs: Dict[str, Any],
+        net_kwargs: dict[str, Any],
         features_extractor: Optional[BaseFeaturesExtractor] = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Update the network keyword arguments and create a new features extractor object if needed.
         If a ``features_extractor`` object is passed, then it will be shared.
@@ -130,7 +130,7 @@ def extract_features(self, obs: PyTorchObs, features_extractor: BaseFeaturesExtr
         preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
         return features_extractor(preprocessed_obs)
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         """
         Get data that need to be saved in order to re-create the model when loading it from disk.
 
@@ -164,7 +164,7 @@ def save(self, path: str) -> None:
         th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path)
 
     @classmethod
-    def load(cls: Type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
+    def load(cls: type[SelfBaseModel], path: str, device: Union[th.device, str] = "auto") -> SelfBaseModel:
         """
         Load model from path.
 
@@ -210,7 +210,7 @@ def set_training_mode(self, mode: bool) -> None:
         """
         self.train(mode)
 
-    def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> bool:
+    def is_vectorized_observation(self, observation: Union[np.ndarray, dict[str, np.ndarray]]) -> bool:
         """
         Check whether or not the observation is vectorized,
         apply transposition to image (so that they are channel-first) if needed.
@@ -233,7 +233,7 @@ def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.
         )
         return vectorized_env
 
-    def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[PyTorchObs, bool]:
+    def obs_to_tensor(self, observation: Union[np.ndarray, dict[str, np.ndarray]]) -> tuple[PyTorchObs, bool]:
         """
         Convert an input observation to a PyTorch tensor that can be fed to a model.
         Includes sugar-coating to handle different observations (e.g. normalizing images).
@@ -330,11 +330,11 @@ def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.T
 
     def predict(
         self,
-        observation: Union[np.ndarray, Dict[str, np.ndarray]],
-        state: Optional[Tuple[np.ndarray, ...]] = None,
+        observation: Union[np.ndarray, dict[str, np.ndarray]],
+        state: Optional[tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
         """
         Get the policy action from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
@@ -450,20 +450,20 @@ def __init__(
         observation_space: spaces.Space,
         action_space: spaces.Space,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.Tanh,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.Tanh,
         ortho_init: bool = True,
         use_sde: bool = False,
         log_std_init: float = 0.0,
         full_std: bool = True,
         use_expln: bool = False,
         squash_output: bool = False,
-        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         share_features_extractor: bool = True,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ):
         if optimizer_kwargs is None:
             optimizer_kwargs = {}
@@ -534,7 +534,7 @@ def __init__(
 
         self._build(lr_schedule)
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()
 
         default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)  # type: ignore[arg-type, return-value]
@@ -633,7 +633,7 @@ def _build(self, lr_schedule: Schedule) -> None:
         # Setup optimizer with initial learning rate
         self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)  # type: ignore[call-arg]
 
-    def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+    def forward(self, obs: th.Tensor, deterministic: bool = False) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
         """
         Forward pass in all the networks (actor and critic)
 
@@ -659,7 +659,7 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso
 
     def extract_features(  # type: ignore[override]
         self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
-    ) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
+    ) -> Union[th.Tensor, tuple[th.Tensor, th.Tensor]]:
         """
         Preprocess the observation if needed and extract features.
 
@@ -716,7 +716,7 @@ def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.T
         """
         return self.get_distribution(observation).get_actions(deterministic=deterministic)
 
-    def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
+    def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
         """
         Evaluate actions according to the current policy,
         given the observations.
@@ -800,20 +800,20 @@ def __init__(
         observation_space: spaces.Space,
         action_space: spaces.Space,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.Tanh,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.Tanh,
         ortho_init: bool = True,
         use_sde: bool = False,
         log_std_init: float = 0.0,
         full_std: bool = True,
         use_expln: bool = False,
         squash_output: bool = False,
-        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         share_features_extractor: bool = True,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ):
         super().__init__(
             observation_space,
@@ -873,20 +873,20 @@ def __init__(
         observation_space: spaces.Dict,
         action_space: spaces.Space,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.Tanh,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.Tanh,
         ortho_init: bool = True,
         use_sde: bool = False,
         log_std_init: float = 0.0,
         full_std: bool = True,
         use_expln: bool = False,
         squash_output: bool = False,
-        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         share_features_extractor: bool = True,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ):
         super().__init__(
             observation_space,
@@ -942,10 +942,10 @@ def __init__(
         self,
         observation_space: spaces.Space,
         action_space: spaces.Box,
-        net_arch: List[int],
+        net_arch: list[int],
         features_extractor: BaseFeaturesExtractor,
         features_dim: int,
-        activation_fn: Type[nn.Module] = nn.ReLU,
+        activation_fn: type[nn.Module] = nn.ReLU,
         normalize_images: bool = True,
         n_critics: int = 2,
         share_features_extractor: bool = True,
@@ -961,14 +961,14 @@ def __init__(
         self.share_features_extractor = share_features_extractor
         self.n_critics = n_critics
 
-        self.q_networks: List[nn.Module] = []
+        self.q_networks: list[nn.Module] = []
         for idx in range(n_critics):
             q_net_list = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
             q_net = nn.Sequential(*q_net_list)
             self.add_module(f"qf{idx}", q_net)
             self.q_networks.append(q_net)
 
-    def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
+    def forward(self, obs: th.Tensor, actions: th.Tensor) -> tuple[th.Tensor, ...]:
         # Learn the features extractor using the policy loss only
         # when the features_extractor is shared with the actor
         with th.set_grad_enabled(not self.share_features_extractor):
diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py
index d0bfbcd1e..a35f8b76f 100644
--- a/stable_baselines3/common/preprocessing.py
+++ b/stable_baselines3/common/preprocessing.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Dict, Tuple, Union
+from typing import Union
 
 import numpy as np
 import torch as th
@@ -90,10 +90,10 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) ->
 
 
 def preprocess_obs(
-    obs: Union[th.Tensor, Dict[str, th.Tensor]],
+    obs: Union[th.Tensor, dict[str, th.Tensor]],
     observation_space: spaces.Space,
     normalize_images: bool = True,
-) -> Union[th.Tensor, Dict[str, th.Tensor]]:
+) -> Union[th.Tensor, dict[str, th.Tensor]]:
     """
     Preprocess observation to be to a neural network.
     For images, it normalizes the values by dividing them by 255 (to have values in [0, 1])
@@ -107,7 +107,7 @@ def preprocess_obs(
     """
     if isinstance(observation_space, spaces.Dict):
         # Do not modify by reference the original observation
-        assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
+        assert isinstance(obs, dict), f"Expected dict, got {type(obs)}"
         preprocessed_obs = {}
         for key, _obs in obs.items():
             preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
@@ -142,7 +142,7 @@ def preprocess_obs(
 
 def get_obs_shape(
     observation_space: spaces.Space,
-) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]:
+) -> Union[tuple[int, ...], dict[str, tuple[int, ...]]]:
     """
     Get the shape of the observation (useful for the buffers).
 
diff --git a/stable_baselines3/common/results_plotter.py b/stable_baselines3/common/results_plotter.py
index f4c1a7a05..f09a54e58 100644
--- a/stable_baselines3/common/results_plotter.py
+++ b/stable_baselines3/common/results_plotter.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Optional
 
 import numpy as np
 import pandas as pd
@@ -29,7 +29,7 @@ def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
     return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides)
 
 
-def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]:
+def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> tuple[np.ndarray, np.ndarray]:
     """
     Apply a function to the rolling window of 2 arrays
 
@@ -44,7 +44,7 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callabl
     return var_1[window - 1 :], function_on_var2
 
 
-def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
+def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> tuple[np.ndarray, np.ndarray]:
     """
     Decompose a data frame variable to x and ys
 
@@ -69,7 +69,7 @@ def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray
 
 
 def plot_curves(
-    xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2)
+    xy_list: list[tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: tuple[int, int] = (8, 2)
 ) -> None:
     """
     plot the curves
@@ -99,7 +99,7 @@ def plot_curves(
 
 
 def plot_results(
-    dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)
+    dirs: list[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: tuple[int, int] = (8, 2)
 ) -> None:
     """
     Plot the results using csv files from ``Monitor`` wrapper.
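For context on the pattern these hunks apply everywhere: since PEP 585, the builtin containers are subscriptable at runtime on Python 3.9+, so the `Dict`/`List`/`Tuple` imports from `typing` can be dropped wholesale. A minimal standalone sketch (not part of the patch; `describe_obs` is a hypothetical name used only for illustration):

# Sketch: builtin generics work at runtime on Python 3.9+, no typing.Dict/Tuple needed.
from typing import Union

import numpy as np


def describe_obs(obs: Union[np.ndarray, dict[str, np.ndarray]]) -> dict[str, tuple[int, ...]]:
    """Map each observation key to its shape, using builtin generics only."""
    if isinstance(obs, dict):
        return {key: value.shape for key, value in obs.items()}
    return {"obs": obs.shape}


print(describe_obs({"pixels": np.zeros((4, 84, 84))}))  # {'pixels': (4, 84, 84)}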
diff --git a/stable_baselines3/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py
index ac3538c50..c8f03b212 100644
--- a/stable_baselines3/common/running_mean_std.py
+++ b/stable_baselines3/common/running_mean_std.py
@@ -1,10 +1,8 @@
-from typing import Tuple
-
 import numpy as np
 
 
 class RunningMeanStd:
-    def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
+    def __init__(self, epsilon: float = 1e-4, shape: tuple[int, ...] = ()):
         """
         Calculates the running mean and std of a data stream
         https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py
index a85c9c2ec..8b545f898 100644
--- a/stable_baselines3/common/save_util.py
+++ b/stable_baselines3/common/save_util.py
@@ -12,7 +12,7 @@
 import pickle
 import warnings
 import zipfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 import cloudpickle
 import torch as th
@@ -73,7 +73,7 @@ def is_json_serializable(item: Any) -> bool:
     return json_serializable
 
 
-def data_to_json(data: Dict[str, Any]) -> str:
+def data_to_json(data: dict[str, Any]) -> str:
     """
     Turn data (class parameters) into a JSON string for storing
 
@@ -128,7 +128,7 @@ def data_to_json(data: Dict[str, Any]) -> str:
     return json_string
 
 
-def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+def json_to_data(json_string: str, custom_objects: Optional[dict[str, Any]] = None) -> dict[str, Any]:
     """
     Turn JSON serialization of class-parameters back into dictionary.
 
@@ -293,9 +293,9 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O
 
 def save_to_zip_file(
     save_path: Union[str, pathlib.Path, io.BufferedIOBase],
-    data: Optional[Dict[str, Any]] = None,
-    params: Optional[Dict[str, Any]] = None,
-    pytorch_variables: Optional[Dict[str, Any]] = None,
+    data: Optional[dict[str, Any]] = None,
+    params: Optional[dict[str, Any]] = None,
+    pytorch_variables: Optional[dict[str, Any]] = None,
     verbose: int = 0,
 ) -> None:
     """
@@ -376,11 +376,11 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: in
 def load_from_zip_file(
     load_path: Union[str, pathlib.Path, io.BufferedIOBase],
     load_data: bool = True,
-    custom_objects: Optional[Dict[str, Any]] = None,
+    custom_objects: Optional[dict[str, Any]] = None,
     device: Union[th.device, str] = "auto",
     verbose: int = 0,
     print_system_info: bool = False,
-) -> Tuple[Optional[Dict[str, Any]], TensorDict, Optional[TensorDict]]:
+) -> tuple[Optional[dict[str, Any]], TensorDict, Optional[TensorDict]]:
     """
     Load model data from a .zip archive
 
diff --git a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
index 25f0a6f96..036958460 100644
--- a/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
+++ b/stable_baselines3/common/sb2_compat/rmsprop_tf_like.py
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Dict, Iterable, Optional
+from collections.abc import Iterable
+from typing import Any, Callable, Optional
 
 import torch
 from torch.optim import Optimizer
@@ -67,7 +68,7 @@ def __init__(
         defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
         super().__init__(params, defaults)
 
-    def __setstate__(self, state: Dict[str, Any]) -> None:
+    def __setstate__(self, state: dict[str, Any]) -> None:
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault("momentum", 0)
diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py
index 234b91551..6c6aa2ddd 100644
--- a/stable_baselines3/common/torch_layers.py
+++ b/stable_baselines3/common/torch_layers.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import Optional, Union
 
 import gymnasium as gym
 import torch as th
@@ -110,13 +110,13 @@ def forward(self, observations: th.Tensor) -> th.Tensor:
 def create_mlp(
     input_dim: int,
     output_dim: int,
-    net_arch: List[int],
-    activation_fn: Type[nn.Module] = nn.ReLU,
+    net_arch: list[int],
+    activation_fn: type[nn.Module] = nn.ReLU,
     squash_output: bool = False,
     with_bias: bool = True,
-    pre_linear_modules: Optional[List[Type[nn.Module]]] = None,
-    post_linear_modules: Optional[List[Type[nn.Module]]] = None,
-) -> List[nn.Module]:
+    pre_linear_modules: Optional[list[type[nn.Module]]] = None,
+    post_linear_modules: Optional[list[type[nn.Module]]] = None,
+) -> list[nn.Module]:
     """
     Create a multi layer perceptron (MLP), which is
     a collection of fully-connected layers each followed by an activation function.
@@ -211,14 +211,14 @@ class MlpExtractor(nn.Module):
     def __init__(
         self,
         feature_dim: int,
-        net_arch: Union[List[int], Dict[str, List[int]]],
-        activation_fn: Type[nn.Module],
+        net_arch: Union[list[int], dict[str, list[int]]],
+        activation_fn: type[nn.Module],
         device: Union[th.device, str] = "auto",
     ) -> None:
         super().__init__()
         device = get_device(device)
-        policy_net: List[nn.Module] = []
-        value_net: List[nn.Module] = []
+        policy_net: list[nn.Module] = []
+        value_net: list[nn.Module] = []
         last_layer_dim_pi = feature_dim
         last_layer_dim_vf = feature_dim
 
@@ -249,7 +249,7 @@ def __init__(
         self.policy_net = nn.Sequential(*policy_net).to(device)
         self.value_net = nn.Sequential(*value_net).to(device)
 
-    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+    def forward(self, features: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
         """
         :return: latent_policy, latent_value of the specified network.
             If all layers are shared, then ``latent_policy == latent_value``
@@ -288,7 +288,7 @@ def __init__(
         # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty!
         super().__init__(observation_space, features_dim=1)
 
-        extractors: Dict[str, nn.Module] = {}
+        extractors: dict[str, nn.Module] = {}
 
         total_concat_size = 0
         for key, subspace in observation_space.spaces.items():
@@ -313,7 +313,7 @@ def forward(self, observations: TensorDict) -> th.Tensor:
         return th.cat(encoded_tensor_list, dim=1)
 
 
-def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]:
+def get_actor_critic_arch(net_arch: Union[list[int], dict[str, list[int]]]) -> tuple[list[int], list[int]]:
     """
     Get the actor and critic network architectures for
     off-policy actor-critic algorithms (SAC, TD3, DDPG).
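A short usage sketch of the ``create_mlp`` helper whose signature is rewritten above (dimensions chosen arbitrarily for illustration): it returns a plain ``list[nn.Module]``, left to the caller to assemble.

# Sketch: create_mlp returns list[nn.Module]; wrap it in nn.Sequential yourself.
import torch as th
from torch import nn

from stable_baselines3.common.torch_layers import create_mlp

layers: list[nn.Module] = create_mlp(input_dim=8, output_dim=2, net_arch=[64, 64], activation_fn=nn.ReLU)
mlp = nn.Sequential(*layers)
print(mlp(th.zeros(1, 8)).shape)  # torch.Size([1, 2])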
diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py
index 042c66f9c..b7c578ac0 100644
--- a/stable_baselines3/common/type_aliases.py
+++ b/stable_baselines3/common/type_aliases.py
@@ -1,7 +1,7 @@
 """Common aliases for type hints"""
 
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Protocol, SupportsFloat, Union
 
 import gymnasium as gym
 import numpy as np
@@ -13,14 +13,14 @@
     from stable_baselines3.common.vec_env import VecEnv
 
 GymEnv = Union[gym.Env, "VecEnv"]
-GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
-GymResetReturn = Tuple[GymObs, Dict]
-AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]]
-GymStepReturn = Tuple[GymObs, float, bool, bool, Dict]
-AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]]
-TensorDict = Dict[str, th.Tensor]
-OptimizerStateDict = Dict[str, Any]
-MaybeCallback = Union[None, Callable, List["BaseCallback"], "BaseCallback"]
+GymObs = Union[tuple, dict[str, Any], np.ndarray, int]
+GymResetReturn = tuple[GymObs, dict]
+AtariResetReturn = tuple[np.ndarray, dict[str, Any]]
+GymStepReturn = tuple[GymObs, float, bool, bool, dict]
+AtariStepReturn = tuple[np.ndarray, SupportsFloat, bool, bool, dict[str, Any]]
+TensorDict = dict[str, th.Tensor]
+OptimizerStateDict = dict[str, Any]
+MaybeCallback = Union[None, Callable, list["BaseCallback"], "BaseCallback"]
 PyTorchObs = Union[th.Tensor, TensorDict]
 
 # A schedule takes the remaining progress as input
@@ -81,11 +81,11 @@ class TrainFreq(NamedTuple):
 class PolicyPredictor(Protocol):
     def predict(
         self,
-        observation: Union[np.ndarray, Dict[str, np.ndarray]],
-        state: Optional[Tuple[np.ndarray, ...]] = None,
+        observation: Union[np.ndarray, dict[str, np.ndarray]],
+        state: Optional[tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
         """
         Get the policy action from an observation (and optional hidden state).
         Includes sugar-coating to handle different observations (e.g. normalizing images).
diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py
index 4e9fbc2db..7caef0501 100644
--- a/stable_baselines3/common/utils.py
+++ b/stable_baselines3/common/utils.py
@@ -4,8 +4,9 @@
 import random
 import re
 from collections import deque
+from collections.abc import Iterable
 from itertools import zip_longest
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import cloudpickle
 import gymnasium as gym
@@ -415,7 +416,7 @@ def safe_mean(arr: Union[np.ndarray, list, deque]) -> float:
     return np.nan if len(arr) == 0 else float(np.mean(arr))  # type: ignore[arg-type]
 
 
-def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]:
+def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> list[th.Tensor]:
     """
     Extract parameters from the state dict of ``model``
     if the name contains one of the strings in ``included_names``.
@@ -473,7 +474,7 @@ def polyak_update(
             th.add(target_param.data, param.data, alpha=tau, out=target_param.data)
 
 
-def obs_as_tensor(obs: Union[np.ndarray, Dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]:
+def obs_as_tensor(obs: Union[np.ndarray, dict[str, np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]:
     """
     Moves the observation to the given device.
 
@@ -517,7 +518,7 @@ def should_collect_more_steps(
         )
 
 
-def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]:
+def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
     """
     Retrieve system and python env info for the current system.
 
diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py
index 5f73d3978..9a60c07dc 100644
--- a/stable_baselines3/common/vec_env/__init__.py
+++ b/stable_baselines3/common/vec_env/__init__.py
@@ -1,5 +1,5 @@
 from copy import deepcopy
-from typing import Optional, Type, TypeVar
+from typing import Optional, TypeVar
 
 from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
 from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
@@ -16,7 +16,7 @@
 VecEnvWrapperT = TypeVar("VecEnvWrapperT", bound=VecEnvWrapper)
 
 
-def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
+def unwrap_vec_wrapper(env: VecEnv, vec_wrapper_class: type[VecEnvWrapperT]) -> Optional[VecEnvWrapperT]:
     """
     Retrieve a ``VecEnvWrapper`` object by recursively searching.
 
@@ -42,7 +42,7 @@ def unwrap_vec_normalize(env: VecEnv) -> Optional[VecNormalize]:
     return unwrap_vec_wrapper(env, VecNormalize)
 
 
-def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: Type[VecEnvWrapper]) -> bool:
+def is_vecenv_wrapped(env: VecEnv, vec_wrapper_class: type[VecEnvWrapper]) -> bool:
     """
     Check if an environment is already wrapped in a given ``VecEnvWrapper``.
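A usage sketch for the helper rewritten just above (the environment choice is arbitrary): ``unwrap_vec_wrapper`` takes the wrapper class itself, now annotated as ``type[VecEnvWrapperT]``, so the return type narrows to the requested wrapper.

# Sketch: passing a class object where the parameter is typed `type[...]`.
import gymnasium as gym

from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, unwrap_vec_wrapper

venv = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
# Returns Optional[VecNormalize] thanks to the TypeVar bound on the class argument
normalize = unwrap_vec_wrapper(venv, VecNormalize)
assert normalize is venv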
diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py
index 8e0c8cc69..b85c1cf88 100644
--- a/stable_baselines3/common/vec_env/base_vec_env.py
+++ b/stable_baselines3/common/vec_env/base_vec_env.py
@@ -1,8 +1,9 @@
 import inspect
 import warnings
 from abc import ABC, abstractmethod
+from collections.abc import Iterable, Sequence
 from copy import deepcopy
-from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Optional, Union
 
 import cloudpickle
 import gymnasium as gym
@@ -14,10 +15,10 @@
 VecEnvIndices = Union[None, int, Iterable[int]]
 # VecEnvObs is what is returned by the reset() method
 # it contains the observation for each env
-VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]]
+VecEnvObs = Union[np.ndarray, dict[str, np.ndarray], tuple[np.ndarray, ...]]
 # VecEnvStepReturn is what is returned by the step() method
 # it contains the observation, reward, done, info for each env
-VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]
+VecEnvStepReturn = tuple[VecEnvObs, np.ndarray, np.ndarray, list[dict]]
 
 
 def tile_images(images_nhwc: Sequence[np.ndarray]) -> np.ndarray:  # pragma: no cover
@@ -65,11 +66,11 @@ def __init__(
         self.observation_space = observation_space
         self.action_space = action_space
         # store info returned by the reset method
-        self.reset_infos: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
+        self.reset_infos: list[dict[str, Any]] = [{} for _ in range(num_envs)]
         # seeds to be used in the next call to env.reset()
-        self._seeds: List[Optional[int]] = [None for _ in range(num_envs)]
+        self._seeds: list[Optional[int]] = [None for _ in range(num_envs)]
         # options to be used in the next call to env.reset()
-        self._options: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
+        self._options: list[dict[str, Any]] = [{} for _ in range(num_envs)]
 
         try:
             render_modes = self.get_attr("render_mode")
@@ -147,7 +148,7 @@ def close(self) -> None:
         raise NotImplementedError()
 
     @abstractmethod
-    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
         """
         Return attribute from vectorized environment.
 
@@ -170,7 +171,7 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) ->
         raise NotImplementedError()
 
     @abstractmethod
-    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
         """
         Call instance methods of vectorized environments.
 
@@ -183,7 +184,7 @@ def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = No
         raise NotImplementedError()
 
     @abstractmethod
-    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+    def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
         """
         Check if environments are wrapped with a given wrapper.
 
@@ -292,7 +293,7 @@ def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
             self._seeds = [seed + idx for idx in range(self.num_envs)]
         return self._seeds
 
-    def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
+    def set_options(self, options: Optional[Union[list[dict], dict]] = None) -> None:
         """
         Set environment options for all environments.
         If a dict is passed instead of a list, the same options will be used for all environments.
@@ -379,7 +380,7 @@ def step_wait(self) -> VecEnvStepReturn:
     def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
         return self.venv.seed(seed)
 
-    def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
+    def set_options(self, options: Optional[Union[list[dict], dict]] = None) -> None:
         return self.venv.set_options(options)
 
     def close(self) -> None:
@@ -391,16 +392,16 @@ def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
     def get_images(self) -> Sequence[Optional[np.ndarray]]:
         return self.venv.get_images()
 
-    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
         return self.venv.get_attr(attr_name, indices)
 
     def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
         return self.venv.set_attr(attr_name, value, indices)
 
-    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
         return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
 
-    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+    def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
         return self.venv.env_is_wrapped(wrapper_class, indices=indices)
 
     def __getattr__(self, name: str) -> Any:
@@ -419,7 +420,7 @@ def __getattr__(self, name: str) -> Any:
 
         return self.getattr_recursive(name)
 
-    def _get_all_attributes(self) -> Dict[str, Any]:
+    def _get_all_attributes(self) -> dict[str, Any]:
         """Get all (inherited) instance and class attributes
 
         :return: all_attributes
diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py
index 5625e2453..4069356d2 100644
--- a/stable_baselines3/common/vec_env/dummy_vec_env.py
+++ b/stable_baselines3/common/vec_env/dummy_vec_env.py
@@ -1,7 +1,8 @@
 import warnings
 from collections import OrderedDict
+from collections.abc import Sequence
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type
+from typing import Any, Callable, Optional
 
 import gymnasium as gym
 import numpy as np
@@ -26,7 +27,7 @@ class DummyVecEnv(VecEnv):
 
     actions: np.ndarray
 
-    def __init__(self, env_fns: List[Callable[[], gym.Env]]):
+    def __init__(self, env_fns: list[Callable[[], gym.Env]]):
         self.envs = [_patch_env(fn()) for fn in env_fns]
         if len(set([id(env.unwrapped) for env in self.envs])) != len(self.envs):
             raise ValueError(
@@ -46,7 +47,7 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]]):
         self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs, *tuple(shapes[k])), dtype=dtypes[k])) for k in self.keys])
         self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
         self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
-        self.buf_infos: List[Dict[str, Any]] = [{} for _ in range(self.num_envs)]
+        self.buf_infos: list[dict[str, Any]] = [{} for _ in range(self.num_envs)]
         self.metadata = env.metadata
 
     def step_async(self, actions: np.ndarray) -> None:
@@ -112,7 +113,7 @@ def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
     def _obs_from_buf(self) -> VecEnvObs:
         return dict_to_obs(self.observation_space, deepcopy(self.buf_obs))
 
-    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
         """Return attribute from vectorized environment (see base class)."""
         target_envs = self._get_target_envs(indices)
         return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs]
@@ -123,12 +124,12 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) ->
         for env_i in target_envs:
             setattr(env_i, attr_name, value)
 
-    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
         """Call instance methods of vectorized environments."""
         target_envs = self._get_target_envs(indices)
         return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs]
 
-    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+    def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
         """Check if worker environments are wrapped with a given wrapper"""
         target_envs = self._get_target_envs(indices)
         # Import here to avoid a circular import
@@ -136,6 +137,6 @@ def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndice
 
         return [env_util.is_wrapped(env_i, wrapper_class) for env_i in target_envs]
 
-    def _get_target_envs(self, indices: VecEnvIndices) -> List[gym.Env]:
+    def _get_target_envs(self, indices: VecEnvIndices) -> list[gym.Env]:
         indices = self._get_indices(indices)
         return [self.envs[i] for i in indices]
diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py
index b6a759f30..d1b3ad298 100644
--- a/stable_baselines3/common/vec_env/stacked_observations.py
+++ b/stable_baselines3/common/vec_env/stacked_observations.py
@@ -1,12 +1,13 @@
 import warnings
-from typing import Any, Dict, Generic, List, Mapping, Optional, Tuple, TypeVar, Union
+from collections.abc import Mapping
+from typing import Any, Generic, Optional, TypeVar, Union
 
 import numpy as np
 from gymnasium import spaces
 
 from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
 
-TObs = TypeVar("TObs", np.ndarray, Dict[str, np.ndarray])
+TObs = TypeVar("TObs", np.ndarray, dict[str, np.ndarray])
 
 
 class StackedObservations(Generic[TObs]):
@@ -66,7 +67,7 @@ def __init__(
     @staticmethod
     def compute_stacking(
         n_stack: int, observation_space: spaces.Box, channels_order: Optional[str] = None
-    ) -> Tuple[bool, int, Tuple[int, ...], int]:
+    ) -> tuple[bool, int, tuple[int, ...], int]:
         """
         Calculates the parameters in order to stack observations
 
@@ -119,8 +120,8 @@ def update(
         self,
         observations: TObs,
         dones: np.ndarray,
-        infos: List[Dict[str, Any]],
-    ) -> Tuple[TObs, List[Dict[str, Any]]]:
+        infos: list[dict[str, Any]],
+    ) -> tuple[TObs, list[dict[str, Any]]]:
         """
         Add the observations to the stack and use the dones to update the infos.
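The ``TObs`` rewrite above relies on TypeVar constraints accepting builtin generics on Python 3.9+; a standalone sketch of the same pattern (class and variable names here are illustrative, not from the codebase):

# Sketch: a TypeVar constrained to ndarray or dict-of-ndarray, as with TObs above.
from typing import Generic, TypeVar

import numpy as np

TObs = TypeVar("TObs", np.ndarray, dict[str, np.ndarray])


class LastObsCache(Generic[TObs]):
    """Toy generic container: whatever observation type goes in comes back out."""

    def __init__(self, obs: TObs) -> None:
        self.obs = obs

    def get(self) -> TObs:
        return self.obs


cache = LastObsCache({"pixels": np.zeros((84, 84))})
print(type(cache.get()))  # <class 'dict'>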
diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py
index a606a7cb9..225eadd79 100644
--- a/stable_baselines3/common/vec_env/subproc_vec_env.py
+++ b/stable_baselines3/common/vec_env/subproc_vec_env.py
@@ -1,6 +1,7 @@
 import multiprocessing as mp
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, Union
 
 import gymnasium as gym
 import numpy as np
@@ -26,7 +27,7 @@ def _worker(
     parent_remote.close()
     env = _patch_env(env_fn_wrapper.var())
-    reset_info: Optional[Dict[str, Any]] = {}
+    reset_info: Optional[dict[str, Any]] = {}
     while True:
         try:
             cmd, data = remote.recv()
@@ -91,7 +92,7 @@ class SubprocVecEnv(VecEnv):
         Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
     """
 
-    def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[str] = None):
+    def __init__(self, env_fns: list[Callable[[], gym.Env]], start_method: Optional[str] = None):
         self.waiting = False
         self.closed = False
         n_envs = len(env_fns)
@@ -164,7 +165,7 @@ def get_images(self) -> Sequence[Optional[np.ndarray]]:
         outputs = [pipe.recv() for pipe in self.remotes]
         return outputs
 
-    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
+    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> list[Any]:
         """Return attribute from vectorized environment (see base class)."""
         target_remotes = self._get_target_remotes(indices)
         for remote in target_remotes:
@@ -179,21 +180,21 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) ->
         for remote in target_remotes:
             remote.recv()
 
-    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
+    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> list[Any]:
         """Call instance methods of vectorized environments."""
         target_remotes = self._get_target_remotes(indices)
         for remote in target_remotes:
             remote.send(("env_method", (method_name, method_args, method_kwargs)))
         return [remote.recv() for remote in target_remotes]
 
-    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
+    def env_is_wrapped(self, wrapper_class: type[gym.Wrapper], indices: VecEnvIndices = None) -> list[bool]:
         """Check if worker environments are wrapped with a given wrapper"""
         target_remotes = self._get_target_remotes(indices)
         for remote in target_remotes:
             remote.send(("is_wrapped", wrapper_class))
         return [remote.recv() for remote in target_remotes]
 
-    def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
+    def _get_target_remotes(self, indices: VecEnvIndices) -> list[Any]:
         """
         Get the connection object needed to communicate with the wanted
         envs that are in subprocesses.
@@ -205,7 +206,7 @@ def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
         return [self.remotes[i] for i in indices]
 
 
-def _stack_obs(obs_list: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
+def _stack_obs(obs_list: Union[list[VecEnvObs], tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
     """
     Stack observations (convert from a list of single env obs to a stack of obs),
     depending on the observation space.
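What ``_stack_obs`` does, in reduced form (a hand-rolled sketch under the assumption of Box and Dict spaces only; the real helper also handles Tuple spaces, and `stack_obs_sketch` is a hypothetical name):

# Sketch: stack per-env observations into batched arrays, keyed for Dict spaces.
from typing import Union

import numpy as np


def stack_obs_sketch(obs_list: list[Union[np.ndarray, dict[str, np.ndarray]]]):
    first = obs_list[0]
    if isinstance(first, dict):
        return {key: np.stack([obs[key] for obs in obs_list]) for key in first}
    return np.stack(obs_list)


batched = stack_obs_sketch([{"pos": np.ones(3)}, {"pos": np.zeros(3)}])
print(batched["pos"].shape)  # (2, 3)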
diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py
index 6ea04f6ab..c1babd87b 100644
--- a/stable_baselines3/common/vec_env/util.py
+++ b/stable_baselines3/common/vec_env/util.py
@@ -2,7 +2,7 @@
 Helpers for dealing with vectorized environments.
 """
 
-from typing import Any, Dict, List, Tuple
+from typing import Any
 
 import numpy as np
 from gymnasium import spaces
@@ -11,7 +11,7 @@
 from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs
 
 
-def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
+def dict_to_obs(obs_space: spaces.Space, obs_dict: dict[Any, np.ndarray]) -> VecEnvObs:
     """
     Convert an internal representation raw_obs into the appropriate type
     specified by space.
 
@@ -32,7 +32,7 @@ def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> Vec
         return obs_dict[None]
 
 
-def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[int, ...]], Dict[Any, np.dtype]]:
+def obs_space_info(obs_space: spaces.Space) -> tuple[list[str], dict[Any, tuple[int, ...]], dict[Any, np.dtype]]:
     """
     Get dict-structured information about a gym.Space.
 
diff --git a/stable_baselines3/common/vec_env/vec_check_nan.py b/stable_baselines3/common/vec_env/vec_check_nan.py
index 170f36ec8..1d775aad5 100644
--- a/stable_baselines3/common/vec_env/vec_check_nan.py
+++ b/stable_baselines3/common/vec_env/vec_check_nan.py
@@ -1,5 +1,4 @@
 import warnings
-from typing import List, Tuple
 
 import numpy as np
 from gymnasium import spaces
@@ -48,7 +47,7 @@ def reset(self) -> VecEnvObs:
         self._observations = observations
         return observations
 
-    def check_array_value(self, name: str, value: np.ndarray) -> List[Tuple[str, str]]:
+    def check_array_value(self, name: str, value: np.ndarray) -> list[tuple[str, str]]:
         """
         Check for inf and NaN for a single numpy array.
 
diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py
index daa2b365c..2142bcb9e 100644
--- a/stable_baselines3/common/vec_env/vec_frame_stack.py
+++ b/stable_baselines3/common/vec_env/vec_frame_stack.py
@@ -1,4 +1,5 @@
-from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
+from collections.abc import Mapping
+from typing import Any, Optional, Union
 
 import numpy as np
 from gymnasium import spaces
@@ -29,17 +30,17 @@ def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[st
 
     def step_wait(
         self,
-    ) -> Tuple[
-        Union[np.ndarray, Dict[str, np.ndarray]],
+    ) -> tuple[
+        Union[np.ndarray, dict[str, np.ndarray]],
         np.ndarray,
         np.ndarray,
-        List[Dict[str, Any]],
+        list[dict[str, Any]],
     ]:
         observations, rewards, dones, infos = self.venv.step_wait()
         observations, infos = self.stacked_obs.update(observations, dones, infos)  # type: ignore[arg-type]
         return observations, rewards, dones, infos
 
-    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def reset(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
         """
         Reset all environments
         """
diff --git a/stable_baselines3/common/vec_env/vec_monitor.py b/stable_baselines3/common/vec_env/vec_monitor.py
index 0d7f18a5e..4aa9325f6 100644
--- a/stable_baselines3/common/vec_env/vec_monitor.py
+++ b/stable_baselines3/common/vec_env/vec_monitor.py
@@ -1,6 +1,6 @@
 import time
 import warnings
-from typing import Optional, Tuple
+from typing import Optional
 
 import numpy as np
 
@@ -27,7 +27,7 @@ def __init__(
         self,
         venv: VecEnv,
         filename: Optional[str] = None,
-        info_keywords: Tuple[str, ...] = (),
+        info_keywords: tuple[str, ...] = (),
     ):
         # Avoid circular import
         from stable_baselines3.common.monitor import Monitor, ResultsWriter
diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py
index ab1d8403a..5f0ee1c25 100644
--- a/stable_baselines3/common/vec_env/vec_normalize.py
+++ b/stable_baselines3/common/vec_env/vec_normalize.py
@@ -1,7 +1,7 @@
 import inspect
 import pickle
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 import numpy as np
 from gymnasium import spaces
@@ -29,8 +29,8 @@ class VecNormalize(VecEnvWrapper):
         If not specified, all keys will be normalized.
     """
 
-    obs_spaces: Dict[str, spaces.Space]
-    old_obs: Union[np.ndarray, Dict[str, np.ndarray]]
+    obs_spaces: dict[str, spaces.Space]
+    old_obs: Union[np.ndarray, dict[str, np.ndarray]]
 
     def __init__(
         self,
@@ -42,7 +42,7 @@ def __init__(
         clip_reward: float = 10.0,
         gamma: float = 0.99,
         epsilon: float = 1e-8,
-        norm_obs_keys: Optional[List[str]] = None,
+        norm_obs_keys: Optional[list[str]] = None,
     ):
         VecEnvWrapper.__init__(self, venv)
 
@@ -125,7 +125,7 @@ def _sanity_checks(self) -> None:
                 f"not {self.observation_space}"
             )
 
-    def __getstate__(self) -> Dict[str, Any]:
+    def __getstate__(self) -> dict[str, Any]:
         """
         Gets state for pickling.
 
@@ -138,7 +138,7 @@ def __getstate__(self) -> Dict[str, Any]:
         del state["returns"]
         return state
 
-    def __setstate__(self, state: Dict[str, Any]) -> None:
+    def __setstate__(self, state: dict[str, Any]) -> None:
         """
         Restores pickled state.
 
@@ -229,7 +229,7 @@ def _unnormalize_obs(self, obs: np.ndarray, obs_rms: RunningMeanStd) -> np.ndarr
         """
         return (obs * np.sqrt(obs_rms.var + self.epsilon)) + obs_rms.mean
 
-    def normalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def normalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Union[np.ndarray, dict[str, np.ndarray]]:
         """
         Normalize observations using this VecNormalize's observations statistics.
         Calling this method does not update statistics.
@@ -254,9 +254,11 @@ def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
         """
         if self.norm_reward:
             reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
-        return reward
+        # Note: we cast to float32 as it corresponds to the Python default float type
+        # This cast is needed because `RunningMeanStd` keeps stats in float64
+        return reward.astype(np.float32)
 
-    def unnormalize_obs(self, obs: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def unnormalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Union[np.ndarray, dict[str, np.ndarray]]:
         # Avoid modifying by reference the original object
         obs_ = deepcopy(obs)
         if self.norm_obs:
@@ -274,7 +276,7 @@ def unnormalize_reward(self, reward: np.ndarray) -> np.ndarray:
             return reward * np.sqrt(self.ret_rms.var + self.epsilon)
         return reward
 
-    def get_original_obs(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def get_original_obs(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
         """
         Returns an unnormalized version of the observations from the most recent
         step or reset.
@@ -287,7 +289,7 @@ def get_original_reward(self) -> np.ndarray:
         """
         return self.old_reward.copy()
 
-    def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+    def reset(self) -> Union[np.ndarray, dict[str, np.ndarray]]:
         """
         Reset all environments
         :return: first observation of the episode
diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py
index 487bd8c07..3fade64d1 100644
--- a/stable_baselines3/common/vec_env/vec_transpose.py
+++ b/stable_baselines3/common/vec_env/vec_transpose.py
@@ -1,5 +1,5 @@
 from copy import deepcopy
-from typing import Dict, Union
+from typing import Union
 
 import numpy as np
 from gymnasium import spaces
@@ -73,7 +73,7 @@ def transpose_image(image: np.ndarray) -> np.ndarray:
             return np.transpose(image, (2, 0, 1))
         return np.transpose(image, (0, 3, 1, 2))
 
-    def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]:
+    def transpose_observations(self, observations: Union[np.ndarray, dict]) -> Union[np.ndarray, dict]:
         """
         Transpose (if needed) and return new observations.
 
@@ -106,7 +106,7 @@ def step_wait(self) -> VecEnvStepReturn:
         assert isinstance(observations, (np.ndarray, dict))
         return self.transpose_observations(observations), rewards, dones, infos
 
-    def reset(self) -> Union[np.ndarray, Dict]:
+    def reset(self) -> Union[np.ndarray, dict]:
         """
         Reset all environments
         """
diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py
index e586f94ab..add3846b6 100644
--- a/stable_baselines3/common/vec_env/vec_video_recorder.py
+++ b/stable_baselines3/common/vec_env/vec_video_recorder.py
@@ -1,6 +1,6 @@
 import os
 import os.path
-from typing import Callable, List
+from typing import Callable
 
 import numpy as np
 from gymnasium import error, logger
@@ -109,7 +109,7 @@ def _capture_frame(self) -> None:
         assert self.recording, "Cannot capture a frame, recording wasn't started."
         frame = self.env.render()
-        if isinstance(frame, List):
+        if isinstance(frame, list):
             frame = frame[-1]
 
         if isinstance(frame, np.ndarray):
diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py
index 2fe2fdfc4..d94fa1812 100644
--- a/stable_baselines3/ddpg/ddpg.py
+++ b/stable_baselines3/ddpg/ddpg.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union
 
 import torch as th
 
@@ -55,7 +55,7 @@ class DDPG(TD3):
 
     def __init__(
         self,
-        policy: Union[str, Type[TD3Policy]],
+        policy: Union[str, type[TD3Policy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 1e-3,
         buffer_size: int = 1_000_000,  # 1e6
@@ -63,14 +63,14 @@ def __init__(
         batch_size: int = 256,
         tau: float = 0.005,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = 1,
+        train_freq: Union[int, tuple[int, str]] = 1,
         gradient_steps: int = 1,
         action_noise: Optional[ActionNoise] = None,
-        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index 894ed9f04..a3f200e59 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, ClassVar, Optional, TypeVar, Union
 
 import numpy as np
 import torch as th
@@ -62,7 +62,7 @@ class DQN(OffPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": MlpPolicy,
         "CnnPolicy": CnnPolicy,
         "MultiInputPolicy": MultiInputPolicy,
@@ -75,7 +75,7 @@ class DQN(OffPolicyAlgorithm):
 
     def __init__(
         self,
-        policy: Union[str, Type[DQNPolicy]],
+        policy: Union[str, type[DQNPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 1e-4,
         buffer_size: int = 1_000_000,  # 1e6
@@ -83,10 +83,10 @@ def __init__(
         batch_size: int = 32,
         tau: float = 1.0,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = 4,
+        train_freq: Union[int, tuple[int, str]] = 4,
         gradient_steps: int = 1,
-        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
         target_update_interval: int = 10000,
         exploration_fraction: float = 0.1,
@@ -95,7 +95,7 @@ def __init__(
         max_grad_norm: float = 10,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
@@ -227,11 +227,11 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
 
     def predict(
         self,
-        observation: Union[np.ndarray, Dict[str, np.ndarray]],
-        state: Optional[Tuple[np.ndarray, ...]] = None,
+        observation: Union[np.ndarray, dict[str, np.ndarray]],
+        state: Optional[tuple[np.ndarray, ...]] = None,
         episode_start: Optional[np.ndarray] = None,
         deterministic: bool = False,
-    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+    ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]:
         """
         Overrides the base_class predict function to include epsilon-greedy exploration.
 
@@ -273,10 +273,10 @@ def learn(
             progress_bar=progress_bar,
         )
 
-    def _excluded_save_params(self) -> List[str]:
+    def _excluded_save_params(self) -> list[str]:
         return [*super()._excluded_save_params(), "q_net", "q_net_target"]
 
-    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
         state_dicts = ["policy", "policy.optimizer"]
 
         return state_dicts, []
diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py
index bfefc8137..95f05d8ca 100644
--- a/stable_baselines3/dqn/policies.py
+++ b/stable_baselines3/dqn/policies.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Optional
 
 import torch as th
 from gymnasium import spaces
@@ -35,8 +35,8 @@ def __init__(
         action_space: spaces.Discrete,
         features_extractor: BaseFeaturesExtractor,
         features_dim: int,
-        net_arch: Optional[List[int]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
+        net_arch: Optional[list[int]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
         normalize_images: bool = True,
     ) -> None:
         super().__init__(
@@ -71,7 +71,7 @@ def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Te
         action = q_values.argmax(dim=1).reshape(-1)
         return action
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()
 
         data.update(
@@ -113,13 +113,13 @@ def __init__(
         observation_space: spaces.Space,
         action_space: spaces.Discrete,
         lr_schedule: Schedule,
-        net_arch: Optional[List[int]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
-        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        net_arch: Optional[list[int]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         super().__init__(
             observation_space,
@@ -183,7 +183,7 @@ def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
     def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
         return self.q_net._predict(obs, deterministic=deterministic)
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()
 
         data.update(
@@ -237,13 +237,13 @@ def __init__(
         observation_space: spaces.Space,
         action_space: spaces.Discrete,
         lr_schedule: Schedule,
-        net_arch: Optional[List[int]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
-        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        net_arch: Optional[list[int]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         super().__init__(
             observation_space,
@@ -282,13 +282,13 @@ def __init__(
         observation_space: spaces.Dict,
         action_space: spaces.Discrete,
         lr_schedule: Schedule,
-        net_arch: Optional[List[int]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
-        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        net_arch: Optional[list[int]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         super().__init__(
             observation_space,
diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py
index 20214e72c..956aabc92 100644
--- a/stable_baselines3/her/her_replay_buffer.py
+++ b/stable_baselines3/her/her_replay_buffer.py
@@ -1,6 +1,6 @@
 import copy
 import warnings
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 import numpy as np
 import torch as th
@@ -98,7 +98,7 @@ def __init__(
         self.ep_length = np.zeros((self.buffer_size, self.n_envs), dtype=np.int64)
         self._current_ep_start = np.zeros(self.n_envs, dtype=np.int64)
 
-    def __getstate__(self) -> Dict[str, Any]:
+    def __getstate__(self) -> dict[str, Any]:
         """
         Gets state for pickling.
 
@@ -109,7 +109,7 @@ def __getstate__(self) -> Dict[str, Any]:
         del state["env"]
         return state
 
-    def __setstate__(self, state: Dict[str, Any]) -> None:
+    def __setstate__(self, state: dict[str, Any]) -> None:
         """
         Restores pickled state.
 
@@ -134,12 +134,12 @@ def set_env(self, env: VecEnv) -> None:
 
     def add(  # type: ignore[override]
         self,
-        obs: Dict[str, np.ndarray],
-        next_obs: Dict[str, np.ndarray],
+        obs: dict[str, np.ndarray],
+        next_obs: dict[str, np.ndarray],
         action: np.ndarray,
         reward: np.ndarray,
         done: np.ndarray,
-        infos: List[Dict[str, Any]],
+        infos: list[dict[str, Any]],
     ) -> None:
         # When the buffer is full, we rewrite on old episodes. When we start to
         # rewrite on an old episodes, we want the whole old episode to be deleted
diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py
index 52ee2eb64..03cbc2464 100644
--- a/stable_baselines3/ppo/ppo.py
+++ b/stable_baselines3/ppo/ppo.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
+from typing import Any, ClassVar, Optional, TypeVar, Union
 
 import numpy as np
 import torch as th
@@ -71,7 +71,7 @@ class PPO(OnPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """
 
-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": ActorCriticPolicy,
         "CnnPolicy": ActorCriticCnnPolicy,
         "MultiInputPolicy": MultiInputActorCriticPolicy,
@@ -79,7 +79,7 @@ class PPO(OnPolicyAlgorithm):
 
     def __init__(
         self,
-        policy: Union[str, Type[ActorCriticPolicy]],
+        policy: Union[str, type[ActorCriticPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 3e-4,
         n_steps: int = 2048,
@@ -95,12 +95,12 @@ def __init__(
         max_grad_norm: float = 0.5,
         use_sde: bool = False,
         sde_sample_freq: int = -1,
-        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
-        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        rollout_buffer_class: Optional[type[RolloutBuffer]] = None,
+        rollout_buffer_kwargs: Optional[dict[str, Any]] = None,
         target_kl: Optional[float] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py
index 6185e2992..330467727 100644
--- a/stable_baselines3/sac/policies.py
+++ b/stable_baselines3/sac/policies.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Optional, Union
 
 import torch as th
 from gymnasium import spaces
@@ -51,10 +51,10 @@ def __init__(
         self,
         observation_space: spaces.Space,
         action_space: spaces.Box,
-        net_arch: List[int],
+        net_arch: list[int],
         features_extractor: nn.Module,
         features_dim: int,
-        activation_fn: Type[nn.Module] = nn.ReLU,
+        activation_fn: type[nn.Module] = nn.ReLU,
         use_sde: bool = False,
         log_std_init: float = -3,
         full_std: bool = True,
@@ -102,7 +102,7 @@ def __init__(
             self.mu = nn.Linear(last_layer_dim, action_dim)
             self.log_std = nn.Linear(last_layer_dim, action_dim)  # type: ignore[assignment]
 
-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()
 
         data.update(
@@ -144,7 +144,7 @@ def reset_noise(self, batch_size: int = 1) -> None:
         assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
         self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
 
-    def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
+    def get_action_dist_params(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor, dict[str, th.Tensor]]:
         """
         Get the parameters for the action distribution.
@@ -169,7 +169,7 @@ def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor: # Note: the action is squashed return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs) - def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]: + def action_log_prob(self, obs: PyTorchObs) -> tuple[th.Tensor, th.Tensor]: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) # return action and associated log prob return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs) @@ -216,17 +216,17 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.ReLU, use_sde: bool = False, log_std_init: float = -3, use_expln: bool = False, clip_mean: float = 2.0, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, ): @@ -309,7 +309,7 @@ def _build(self, lr_schedule: Schedule) -> None: # Target networks should always be in eval mode self.critic_target.set_training_mode(False) - def _get_constructor_parameters(self) -> Dict[str, Any]: + def _get_constructor_parameters(self) -> dict[str, Any]: data = super()._get_constructor_parameters() data.update( @@ -400,17 +400,17 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.ReLU, use_sde: bool = False, log_std_init: float = -3, use_expln: bool = False, clip_mean: float = 2.0, - features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_class: type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, ): @@ -466,17 +466,17 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, - activation_fn: Type[nn.Module] = nn.ReLU, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + activation_fn: type[nn.Module] = nn.ReLU, use_sde: bool = False, log_std_init: float = -3, use_expln: bool = False, clip_mean: float = 2.0, - features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = 
+        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
     ):
diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py
index bf0fa5028..8cb2ae53d 100644
--- a/stable_baselines3/sac/sac.py
+++ b/stable_baselines3/sac/sac.py
@@ -1,4 +1,4 @@
-from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, ClassVar, Optional, TypeVar, Union

 import numpy as np
 import torch as th
@@ -77,7 +77,7 @@ class SAC(OffPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """

-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": MlpPolicy,
         "CnnPolicy": CnnPolicy,
         "MultiInputPolicy": MultiInputPolicy,
@@ -89,7 +89,7 @@ class SAC(OffPolicyAlgorithm):

     def __init__(
         self,
-        policy: Union[str, Type[SACPolicy]],
+        policy: Union[str, type[SACPolicy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 3e-4,
         buffer_size: int = 1_000_000,  # 1e6
@@ -97,11 +97,11 @@ def __init__(
         batch_size: int = 256,
         tau: float = 0.005,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = 1,
+        train_freq: Union[int, tuple[int, str]] = 1,
         gradient_steps: int = 1,
         action_noise: Optional[ActionNoise] = None,
-        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
         ent_coef: Union[str, float] = "auto",
         target_update_interval: int = 1,
@@ -111,7 +111,7 @@ def __init__(
         use_sde_at_warmup: bool = False,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
@@ -313,10 +313,10 @@ def learn(
             progress_bar=progress_bar,
         )

-    def _excluded_save_params(self) -> List[str]:
+    def _excluded_save_params(self) -> list[str]:
         return super()._excluded_save_params() + ["actor", "critic", "critic_target"]  # noqa: RUF005

-    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
         state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
         if self.ent_coef_optimizer is not None:
             saved_pytorch_variables = ["log_ent_coef"]
diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py
index a15be0396..aa7ea8069 100644
--- a/stable_baselines3/td3/policies.py
+++ b/stable_baselines3/td3/policies.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Optional, Union

 import torch as th
 from gymnasium import spaces
@@ -36,10 +36,10 @@ def __init__(
         self,
         observation_space: spaces.Space,
         action_space: spaces.Box,
-        net_arch: List[int],
+        net_arch: list[int],
         features_extractor: nn.Module,
         features_dim: int,
-        activation_fn: Type[nn.Module] = nn.ReLU,
+        activation_fn: type[nn.Module] = nn.ReLU,
         normalize_images: bool = True,
     ):
         super().__init__(
@@ -59,7 +59,7 @@ def __init__(
         # Deterministic action
         self.mu = nn.Sequential(*actor_net)

-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()

         data.update(
@@ -116,13 +116,13 @@ def __init__(
         observation_space: spaces.Space,
         action_space: spaces.Box,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
-        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
     ):
@@ -207,7 +207,7 @@ def _build(self, lr_schedule: Schedule) -> None:
         self.actor_target.set_training_mode(False)
         self.critic_target.set_training_mode(False)

-    def _get_constructor_parameters(self) -> Dict[str, Any]:
+    def _get_constructor_parameters(self) -> dict[str, Any]:
         data = super()._get_constructor_parameters()

         data.update(
@@ -285,13 +285,13 @@ def __init__(
         observation_space: spaces.Space,
         action_space: spaces.Box,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
-        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        features_extractor_class: type[BaseFeaturesExtractor] = NatureCNN,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
     ):
@@ -339,13 +339,13 @@ def __init__(
         observation_space: spaces.Dict,
         action_space: spaces.Box,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
-        activation_fn: Type[nn.Module] = nn.ReLU,
-        features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        activation_fn: type[nn.Module] = nn.ReLU,
+        features_extractor_class: type[BaseFeaturesExtractor] = CombinedExtractor,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
-        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
     ):
diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py
index a61d954bc..affb9c9f8 100644
--- a/stable_baselines3/td3/td3.py
+++ b/stable_baselines3/td3/td3.py
@@ -1,4 +1,4 @@
-from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, ClassVar, Optional, TypeVar, Union

 import numpy as np
 import torch as th
@@ -65,7 +65,7 @@ class TD3(OffPolicyAlgorithm):
     :param _init_setup_model: Whether or not to build the network at the creation of the instance
     """

-    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
+    policy_aliases: ClassVar[dict[str, type[BasePolicy]]] = {
         "MlpPolicy": MlpPolicy,
         "CnnPolicy": CnnPolicy,
         "MultiInputPolicy": MultiInputPolicy,
@@ -78,7 +78,7 @@ class TD3(OffPolicyAlgorithm):

     def __init__(
         self,
-        policy: Union[str, Type[TD3Policy]],
+        policy: Union[str, type[TD3Policy]],
         env: Union[GymEnv, str],
         learning_rate: Union[float, Schedule] = 1e-3,
         buffer_size: int = 1_000_000,  # 1e6
@@ -86,18 +86,18 @@ def __init__(
         batch_size: int = 256,
         tau: float = 0.005,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = 1,
+        train_freq: Union[int, tuple[int, str]] = 1,
         gradient_steps: int = 1,
         action_noise: Optional[ActionNoise] = None,
-        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
         policy_delay: int = 2,
         target_policy_noise: float = 0.2,
         target_noise_clip: float = 0.5,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
@@ -228,9 +228,9 @@ def learn(
             progress_bar=progress_bar,
         )

-    def _excluded_save_params(self) -> List[str]:
+    def _excluded_save_params(self) -> list[str]:
         return super()._excluded_save_params() + ["actor", "critic", "actor_target", "critic_target"]  # noqa: RUF005

-    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
+    def _get_torch_save_params(self) -> tuple[list[str], list[str]]:
         state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
         return state_dicts, []
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index 197c4d5c2..b8feefb94 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-2.4.0
+2.5.0a0
diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py
index 8049c6887..305f56fc2 100644
--- a/tests/test_dict_env.py
+++ b/tests/test_dict_env.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional

 import gymnasium as gym
 import numpy as np
@@ -72,7 +72,7 @@ def step(self, action):
         terminated = truncated = False
         return self.observation_space.sample(), reward, terminated, truncated, {}

-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             self.observation_space.seed(seed)
         return self.observation_space.sample(), {}
diff --git a/tests/test_distributions.py b/tests/test_distributions.py
index 48eae12d0..0f8ae6253 100644
--- a/tests/test_distributions.py
+++ b/tests/test_distributions.py
@@ -1,5 +1,4 @@
 from copy import deepcopy
-from typing import Tuple

 import gymnasium as gym
 import numpy as np
@@ -55,7 +54,7 @@ def test_squashed_gaussian(model_class):

 @pytest.fixture()
-def dummy_model_distribution_obs_and_actions() -> Tuple[A2C, np.ndarray, np.ndarray]:
+def dummy_model_distribution_obs_and_actions() -> tuple[A2C, np.ndarray, np.ndarray]:
     """
     Fixture creating a Pendulum-v1 gym env, an A2C model and sampling 10 random observations and actions from the env

     :return: A2C model, random observations, random actions
diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py
index 62dd6ffa6..c1598cad4 100644
--- a/tests/test_env_checker.py
+++ b/tests/test_env_checker.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional

 import gymnasium as gym
 import numpy as np
@@ -135,7 +135,7 @@ def test_check_env_detailed_error(obs_tuple, method):
     class TestEnv(gym.Env):
         action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

-        def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+        def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
            return wrong_obs if method == "reset" else good_obs, {}

         def step(self, action):
@@ -162,7 +162,7 @@ def __init__(self, steps_before_termination: int = 1):
         self._steps_called = 0
         self._terminated = False

-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[int, Dict]:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[int, dict]:
         super().reset(seed=seed)

         self._steps_called = 0
@@ -170,7 +170,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -
         return 0, {}

-    def step(self, action: np.ndarray) -> Tuple[int, float, bool, bool, Dict[str, Any]]:
+    def step(self, action: np.ndarray) -> tuple[int, float, bool, bool, dict[str, Any]]:
         self._steps_called += 1
         assert not self._terminated
diff --git a/tests/test_gae.py b/tests/test_gae.py
index bb674cffa..6e32d0f87 100644
--- a/tests/test_gae.py
+++ b/tests/test_gae.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional

 import gymnasium as gym
 import numpy as np
@@ -23,7 +23,7 @@ def __init__(self, max_steps=8):
     def seed(self, seed):
         self.observation_space.seed(seed)

-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             self.observation_space.seed(seed)
         self.n_steps = 0
@@ -53,7 +53,7 @@ def __init__(self, n_states=4):
         self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
         self.current_state = 0

-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             super().reset(seed=seed)
diff --git a/tests/test_logger.py b/tests/test_logger.py
index 02d36b306..039e3f4ac 100644
--- a/tests/test_logger.py
+++ b/tests/test_logger.py
@@ -2,8 +2,8 @@
 import os
 import sys
 import time
+from collections.abc import Sequence
 from io import TextIOBase
-from typing import Sequence
 from unittest import mock

 import gymnasium as gym
diff --git a/tests/test_spaces.py b/tests/test_spaces.py
index e006c1f96..cd38e1ecd 100644
--- a/tests/test_spaces.py
+++ b/tests/test_spaces.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Optional

 import gymnasium as gym
 import numpy as np
@@ -24,7 +24,7 @@ class DummyEnv(gym.Env):
     def step(self, action):
         return self.observation_space.sample(), 0.0, False, False, {}

-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             super().reset(seed=seed)
         return self.observation_space.sample(), {}
diff --git a/tests/test_tensorboard.py b/tests/test_tensorboard.py
index eee0ec0aa..b81586367 100644
--- a/tests/test_tensorboard.py
+++ b/tests/test_tensorboard.py
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, Union
+from typing import Union

 import pytest
@@ -24,7 +24,7 @@ class HParamCallback(BaseCallback):
     """

     def _on_training_start(self) -> None:
-        hparam_dict: Dict[str, Union[str, float]] = {
+        hparam_dict: dict[str, Union[str, float]] = {
             "algorithm": self.model.__class__.__name__,
             # Ignore type checking for gamma, see https://github.com/DLR-RM/stable-baselines3/pull/1194/files#r1035006458
             "gamma": self.model.gamma,  # type: ignore[attr-defined]
@@ -33,7 +33,7 @@ def _on_training_start(self) -> None:
             hparam_dict["learning rate"] = self.model.learning_rate
         # define the metrics that will appear in the `HPARAMS` Tensorboard tab by referencing their tag
         # Tensorboard will find & display metrics from the `SCALARS` tab
-        metric_dict: Dict[str, float] = {
+        metric_dict: dict[str, float] = {
             "rollout/ep_len_mean": 0,
         }
         self.logger.record(
diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py
index 3aa52762d..93185fa3c 100644
--- a/tests/test_vec_envs.py
+++ b/tests/test_vec_envs.py
@@ -4,7 +4,7 @@
 import multiprocessing
 import os
 import warnings
-from typing import Dict, Optional
+from typing import Optional

 import gymnasium as gym
 import numpy as np
@@ -30,9 +30,9 @@ def __init__(self, space, render_mode: str = "rgb_array"):
         self.current_step = 0
         self.ep_length = 4
         self.render_mode = render_mode
-        self.current_options: Optional[Dict] = None
+        self.current_options: Optional[dict] = None

-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             self.seed(seed)
         self.current_step = 0
@@ -193,7 +193,7 @@ def __init__(self, max_steps):
         self.max_steps = max_steps
         self.current_step = 0

-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         self.current_step = 0
         return np.array([self.current_step], dtype="int"), {}
diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py
index b7d71b748..3db5fcd47 100644
--- a/tests/test_vec_normalize.py
+++ b/tests/test_vec_normalize.py
@@ -1,5 +1,5 @@
 import operator
-from typing import Any, Dict, Optional
+from typing import Any, Optional

 import gymnasium as gym
 import numpy as np
@@ -22,7 +22,7 @@
 class DummyRewardEnv(gym.Env):
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}

     def __init__(self, return_reward_idx=0):
         self.action_space = spaces.Discrete(2)
@@ -39,7 +39,7 @@ def step(self, action):
         truncated = self.t == len(self.returned_rewards)
         return np.array([returned_value]), returned_value, terminated, truncated, {}

-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             super().reset(seed=seed)
         self.t = 0
@@ -62,7 +62,7 @@ def __init__(self):
         )
         self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)

-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             super().reset(seed=seed)
         return self.observation_space.sample(), {}
@@ -94,7 +94,7 @@ def __init__(self):
         )
         self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32)

-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             super().reset(seed=seed)
         return self.observation_space.sample(), {}
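Note on the typing changes above: they all follow one pattern, made possible by dropping Python 3.8. PEP 585 built-in generics (dict, list, tuple, type) replace their deprecated typing aliases, and abstract types such as Sequence now come from collections.abc instead of typing. A minimal sketch of the before/after style (the helper function below is hypothetical, for illustration only, not part of this diff):

    # Before (Python 3.8 era):
    #   from typing import Dict, List, Optional, Tuple, Type
    #   def summarize_arch(net_arch: Optional[List[int]], kwargs: Optional[Dict[str, Any]]) -> Tuple[int, str]: ...
    # After (Python 3.9+, as applied throughout this diff):
    from collections.abc import Sequence
    from typing import Any, Optional

    def summarize_arch(net_arch: Optional[list[int]], kwargs: Optional[dict[str, Any]]) -> tuple[int, str]:
        # Hypothetical helper: returns the depth and a printable form of a network architecture.
        arch: Sequence[int] = net_arch if net_arch is not None else [64, 64]
        return len(arch), "x".join(str(n) for n in arch)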
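The HParamCallback touched in tests/test_tensorboard.py records a dict of hyperparameters once at training start so they appear in TensorBoard's HPARAMS tab. A usage sketch under the same assumptions as that test (A2C on Pendulum-v1 with a tensorboard_log directory; paths and step counts are illustrative):

    from stable_baselines3 import A2C

    # Any algorithm with tensorboard_log set will do; A2C/Pendulum-v1 mirrors the test setup.
    model = A2C("MlpPolicy", "Pendulum-v1", tensorboard_log="./runs", verbose=0)
    # HParamCallback (defined in the hunk above) records the hparam_dict/metric_dict pair
    # once, in _on_training_start(); the listed metrics then populate the HPARAMS tab.
    model.learn(total_timesteps=500, callback=HParamCallback())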