From ceeab7c7760e29cd917edacea5941abf7c6f0855 Mon Sep 17 00:00:00 2001 From: Denis Steckelmacher Date: Fri, 2 Jul 2021 16:26:59 +0200 Subject: [PATCH 1/2] Add Bootstrapped Dual Policy Iteration (RL algo+unit tests+doc) --- README.md | 1 + docs/guide/algos.rst | 3 +- docs/guide/examples.rst | 13 ++ docs/index.rst | 1 + docs/modules/bdpi.rst | 136 +++++++++++++++ sb3_contrib/__init__.py | 1 + sb3_contrib/bdpi/__init__.py | 2 + sb3_contrib/bdpi/bdpi.py | 281 ++++++++++++++++++++++++++++++ sb3_contrib/bdpi/policies.py | 320 +++++++++++++++++++++++++++++++++++ setup.cfg | 1 + tests/test_cnn.py | 4 +- tests/test_dict_env.py | 21 ++- tests/test_run.py | 17 +- tests/test_save_load.py | 12 +- 14 files changed, 792 insertions(+), 21 deletions(-) create mode 100644 docs/modules/bdpi.rst create mode 100644 sb3_contrib/bdpi/__init__.py create mode 100644 sb3_contrib/bdpi/bdpi.py create mode 100644 sb3_contrib/bdpi/policies.py diff --git a/README.md b/README.md index 5d7e7746..5593d4a4 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ See documentation for the full list of included features. **RL Algorithms**: - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269) - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) +- [Bootstrapped Dual Policy Iteration](https://arxiv.org/abs/1903.04193) **Gym Wrappers**: - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378) diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index f1b0ed55..6dd4a4e5 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -10,11 +10,12 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Pr ============ =========== ============ ================= =============== ================ TQC ✔️ ❌ ❌ ❌ ❌ QR-DQN ️❌ ️✔️ ❌ ❌ ❌ +BDPI ❌ ✔️ ❌ ❌ ✔️ ============ =========== ============ ================= =============== ================ .. note:: - Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm. + Non-array spaces such as ``Dict`` or ``Tuple`` are only supported by BDPI, using ``MultiInputPolicy`` instead of ``MlpPolicy`` as the policy. Actions ``gym.spaces``: diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 12b7d718..d77e5003 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -30,6 +30,19 @@ Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment. model.learn(total_timesteps=10000, log_interval=4) model.save("qrdqn_cartpole") +BDPI +---- + +Train a Bootstrapped Dual Policy Iteration (BDPI) agent on the LunarLander environment + +.. code-block:: python + + from sb3_contrib import BDPI + + policy_kwargs = dict(n_critics=8) + model = BDPI("MlpPolicy", "LunarLander-v2", policy_kwargs=policy_kwargs, verbose=1) + model.learn(total_timesteps=50000, log_interval=4) + model.save("bdpi_lunarlander") .. PyBullet: Normalizing input features .. ------------------------------------ diff --git a/docs/index.rst b/docs/index.rst index f86c47ee..0b7aec85 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -33,6 +33,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d modules/tqc modules/qrdqn + modules/bdpi .. toctree:: :maxdepth: 1 diff --git a/docs/modules/bdpi.rst b/docs/modules/bdpi.rst new file mode 100644 index 00000000..941c99ab --- /dev/null +++ b/docs/modules/bdpi.rst @@ -0,0 +1,136 @@ +.. _bdpi: + +.. automodule:: sb3_contrib.bdpi + +BDPI +==== + +`Bootstrapped Dual Policy Iteration `_ is an actor-critic algorithm for +discrete action spaces. The distinctive components of BDPI are as follows: + +- Like Bootstrapped DQN, it uses several critics, with each critic having a Qa and Qb network + (like Clipped DQN). +- The BDPI critics, inspired from the DQN literature, are therefore off-policy. They don't know + about the actor, and do not use any form of off-policy corrections to evaluate the actor. They + instead directly approximate Q*, the optimal value function. +- The actor is trained with an equation inspired from Conservative Policy Iteration, instead of + Policy Gradient (as used by A2C, PPO, SAC, DDPG, etc). This use of Conservative Policy Iteration + is what allows the BDPI actor to be compatible with off-policy critics. + +As a result, BDPI can be configured to be highly sample-efficient, at the cost of compute efficiency. +The off-policy critics can learn aggressively (many samples, many gradient steps), as they don't have +to remain close to the actor. The actor then learns from a mixture of high-quality critics, leading +to good exploration even in challenging environments (see the Table environment described in the paper +linked above). + +Can I use? +---------- + +- Recurrent policies: ❌ +- Multi processing: ✔️ +- Gym spaces: + + +============= ====== =========== +Space Action Observation +============= ====== =========== +Discrete ✔️ ✔️ +Box ❌ ✔️ +MultiDiscrete ❌ ✔️ +MultiBinary ❌ ✔️ +Dict ❌ ✔️ +============= ====== =========== + +Example +------- + +Train a BDPI agent on ``LunarLander-v2``, with hyper-parameters tuned by Optuna in rl-baselines3-zoo: + +.. code-block:: python + + import gym + + from sb3_contrib import BDPI + + model = BDPI( + "MlpPolicy", + 'LunarLander-v2', + actor_lr=0.01, # How fast the actor pursues the greedy policy of the critics + critic_lr=0.234, # Q-Learning learning rate of the critics + batch_size=256, # 256 experiences sampled from the buffer every time-step, for every critic + buffer_size=100000, + gradient_steps=64, # Actor and critics fit for 64 gradient steps per time-step + learning_rate=0.00026, # Adam optimizer learning rate + policy_kwargs=dict(net_arch=[64, 64], n_critics=8), # 8 critics + verbose=1, + tensorboard_log='./tb_log' + ) + + model.learn(total_timesteps=50000) + model.save("bdpi_lunarlander") + + del model # remove to demonstrate saving and loading + + model = BDPI.load("bdpi_lunarlander") + + obs = env.reset() + while True: + action, _states = model.predict(obs) + obs, rewards, dones, info = env.step(action) + env.render() + + +Results +------- + +LunarLander +^^^^^^^^^^^ + +Results for BDPI are available in `this Github issue `_. + +How to replicate the results? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone the `rl-zoo repo `_: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above, and ``$N`` with +the number of CPU cores in your machine): + +.. code-block:: bash + + python train.py --algo bdpi --env $ENV_ID --eval-episodes 10 --eval-freq 10000 -params threads:$N + + +Plot the results (here for LunarLander only): + +.. code-block:: bash + + python scripts/all_plots.py -a bdpi -e LunarLander -f logs/ -o logs/bdpi_results + python scripts/plot_from_file.py -i logs/bdpi_results.pkl -latex -l BDPI + + +Parameters +---------- + +.. autoclass:: BDPI + :members: + :inherited-members: + + +BDPI Policies +------------- + +.. autoclass:: MlpPolicy + :members: + +.. autoclass:: CnnPolicy + :members: + +.. autoclass:: MultiInputPolicy + :members: diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 8f253e12..a3c83525 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -1,5 +1,6 @@ import os +from sb3_contrib.bdpi import BDPI from sb3_contrib.qrdqn import QRDQN from sb3_contrib.tqc import TQC diff --git a/sb3_contrib/bdpi/__init__.py b/sb3_contrib/bdpi/__init__.py new file mode 100644 index 00000000..7dc9b2ef --- /dev/null +++ b/sb3_contrib/bdpi/__init__.py @@ -0,0 +1,2 @@ +from sb3_contrib.bdpi.bdpi import BDPI +from sb3_contrib.bdpi.policies import CnnPolicy, MlpPolicy, MultiInputPolicy diff --git a/sb3_contrib/bdpi/bdpi.py b/sb3_contrib/bdpi/bdpi.py new file mode 100644 index 00000000..6ee4a160 --- /dev/null +++ b/sb3_contrib/bdpi/bdpi.py @@ -0,0 +1,281 @@ +import random +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import gym +import torch as th +import torch.multiprocessing as mp +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict + +from sb3_contrib.bdpi.policies import BDPIPolicy + +# Because BDPI uses many critics per agent, and each critic has 2 Q-Networks, sharing them with file descriptors +# exhausts the maximum number of open file descriptors on many Linux distributions. The file_system sharing method +# creates many small files in /dev/shm, that are then shared by file-name. This avoids reaching the maximum number +# of open file descriptors. +mp.set_sharing_strategy("file_system") + + +def train_model( + model: th.nn.Module, + inp: Union[th.Tensor, TensorDict], + outp: th.Tensor, + gradient_steps: int, +) -> float: + """Train a PyTorch module on inputs and outputs, minimizing the MSE loss for gradient_steps steps. + + :param model: PyTorch module to be trained. It must have a ".optimizer" attribute with an instance of Optimizer in it. + :param inp: Input tensor (or dictionary of tensors if model is a MultiInput model) + :param outp: Expected outputs tensor + :param gradient_steps: Number of gradient steps to execute when minimizing the MSE. + :return: MSE loss (with the 'sum' reduction) after the last gradient step, as a float. + """ + mse_loss = th.nn.MSELoss(reduction="sum") + + for i in range(gradient_steps): + predicted = model(inp) + loss = mse_loss(predicted, outp) + + model.optimizer.zero_grad() + loss.backward() + model.optimizer.step() + + return float(loss.item()) + + +class BDPI(OffPolicyAlgorithm): + """ + Bootstrapped Dual Policy Iteration + + Sample-efficient discrete-action RL algorithm, built on one actor trained + to imitate the greedy policy of several Q-Learning critics. + + Paper: https://arxiv.org/abs/1903.04193 + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: learning rate for adam optimizer, + the same learning rate will be used for all networks (Q-Values and Actor) + it can be a function of the current progress remaining (from 1 to 0) + :param actor_lr: Conservative Policy Iteration learning rate for the actor (used in a formula, not for Adam gradient steps) + :param critic_lr: Q-Learning "alpha" learning rate for the critics + :param buffer_size: size of the replay buffer + :param learning_starts: how many steps of the model to collect transitions for before learning starts + :param batch_size: Minibatch size for each gradient update + :param gamma: the discount factor + :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit + like ``(5, "step")`` or ``(2, "episode")``. + :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) + :param threads: Number of threads to use to train the actor and critics in parallel + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. + :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer + at a cost of more complexity. + See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: Whether or not to build the network at the creation of the instance. + """ + + def __init__( + self, + policy: Union[str, Type[BDPIPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 1e-4, + actor_lr: float = 0.05, + critic_lr: float = 0.2, + buffer_size: int = 1000000, # 1e6 + learning_starts: int = 256, + batch_size: int = 256, + gamma: float = 0.99, + train_freq: Union[int, Tuple[int, str]] = 1, + gradient_steps: int = 20, + threads: int = 1, + replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + optimize_memory_usage: bool = False, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Dict[str, Any] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + super(BDPI, self).__init__( + policy, + env, + BDPIPolicy, + learning_rate, + buffer_size, + learning_starts, + batch_size, + 0.0, + gamma, + train_freq, + gradient_steps, + None, + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + use_sde=False, + sde_sample_freq=1, + use_sde_at_warmup=False, + optimize_memory_usage=optimize_memory_usage, + supported_action_spaces=(gym.spaces.Discrete), + sde_support=False, + ) + + self.actor_lr = actor_lr + self.critic_lr = critic_lr + self.threads = threads + self.pool = mp.get_context("spawn").Pool(threads) + + if _init_setup_model: + self._setup_model() + + def _setup_model(self) -> None: + """Create the BDPI actor and critics, and make their memory shared across processes.""" + super(BDPI, self)._setup_model() + + self.actor = self.policy.actor + self.criticsA = self.policy.criticsA + self.criticsB = self.policy.criticsB + + self.actor.share_memory() + + for cA, cB in zip(self.criticsA, self.criticsB): + cA.share_memory() + cB.share_memory() + + def _excluded_save_params(self) -> List[str]: + """Process pools cannot be pickled, so exclude "self.pool" from the saved parameters of BDPI.""" + return super()._excluded_save_params() + ["pool"] + + def train(self, gradient_steps: int, batch_size: int = 64) -> None: + """BDPI Training procedure. + + This method is called every time-step (if train_freq=1, as in the original paper). + Every time this method is called, the following steps are performed: + + - Every critic, in random order, gets updated with the Clipped DQN equation on its own batch of experiences + - Every critic, just after being updated, computes its greedy policy and updates the actor towards it + - After every critic has been updated, their QA and QB networks are swapped. + + This method implements some basic multi-processing: + + - Every critic and the actor are PyTorch modules with share_memory() called on them + - A process pool is used to perform the neural network training operations (gradient descent steps) + + This approach only has a minimal impact on code, but does not scale very well: + + - On the plus side, the actor is trained concurrently by several workers, ala HOGWILD + - However, the predictions (getting Q(next state), computing updated Q-Values and the greedy policy) + all happen sequentially in the main process. With self.threads>8, the bottleneck therefore becomes + the main process, that has to perform all the updates and predictions. The worker processes only + fit neural networks. + """ + + # Update optimizers learning rate + optimizers = [self.actor.optimizer] + [c.optimizer for c in self.criticsA] + [c.optimizer for c in self.criticsB] + self._update_learning_rate(optimizers) + + # Update every critic (and the actor after each critic), in a random order + critic_losses = [] + actor_losses = [] + critics = list(zip(self.criticsA, self.criticsB)) + + random.shuffle(critics) + + for criticA, criticB in critics: + # Sample replay buffer + with th.no_grad(): + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) + + # Update the critic (code taken from DQN) + with th.no_grad(): + qvA = criticA(replay_data.next_observations) + qvB = criticB(replay_data.next_observations) + qv = th.min(qvA, qvB) + + QN = th.arange(replay_data.rewards.shape[0]) + next_q_values = qv[QN, qvA.argmax(1)].reshape(-1, 1) + + # 1-step TD target + target_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values + + # Make real supervised learning target Q-Values (even for non-taken actions) + target_q_values = criticA(replay_data.observations) + actions = replay_data.actions.long().flatten() + target_q_values[QN, actions] += self.critic_lr * (target_values.flatten() - target_q_values[QN, actions]) + + critic_losses.append( + self.pool.apply_async(train_model, (criticA, replay_data.observations, target_q_values, gradient_steps)) + ) + + self.logger.record("train/avg_q", float(target_q_values.mean())) + + # Update the actor + with th.no_grad(): + greedy_actions = target_q_values.argmax(1) + + train_probas = th.zeros_like(target_q_values) + train_probas[QN, greedy_actions] = 1.0 + + # Normalize the direction to be pursued + train_probas /= 1e-6 + train_probas.sum(1)[:, None] + actor_probas = self.actor(replay_data.observations) + + # Imitation learning (or distillation, or reward-penalty Pursuit, all these are the same thing) + alr = self.actor_lr + train_probas = (1.0 - alr) * actor_probas + alr * train_probas + train_probas /= train_probas.sum(-1, keepdim=True) + + actor_losses.append( + self.pool.apply_async(train_model, (self.actor, replay_data.observations, train_probas, gradient_steps)) + ) + + # Log losses + for aloss, closs in zip(actor_losses, critic_losses): + self.logger.record("train/critic_loss", closs.get()) + self.logger.record("train/actor_loss", aloss.get()) + + # Swap QA and QB + self.criticsA, self.criticsB = self.criticsB, self.criticsA + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "BDPI", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> OffPolicyAlgorithm: + + return super(BDPI, self).learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) diff --git a/sb3_contrib/bdpi/policies.py b/sb3_contrib/bdpi/policies.py new file mode 100644 index 00000000..e24f43b0 --- /dev/null +++ b/sb3_contrib/bdpi/policies.py @@ -0,0 +1,320 @@ +from typing import Any, Dict, List, Optional, Type, Union + +import gym +import torch as th +from stable_baselines3.common.policies import BasePolicy, register_policy +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + NatureCNN, + create_mlp, +) +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.dqn.policies import QNetwork +from torch import nn + + +class Actor(BasePolicy): + """ + Actor network (policy) for BDPI. + + :param observation_space: Obervation space + :param action_space: Action space + :param net_arch: Network architecture + :param features_extractor: Network to extract features + (a CNN when using images, a nn.Flatten() layer otherwise) + :param features_dim: Number of features + :param activation_fn: Activation function + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + net_arch: List[int], + features_extractor: nn.Module, + features_dim: int, + activation_fn: Type[nn.Module] = nn.ReLU, + normalize_images: bool = True, + ): + super(Actor, self).__init__( + observation_space, action_space, features_extractor=features_extractor, normalize_images=normalize_images + ) + + # Save arguments to re-create object at loading + self.net_arch = net_arch + self.features_dim = features_dim + self.activation_fn = activation_fn + + # One neural network in the policy, that maps a state to a discrete distribution over actions + actor_net = create_mlp(features_dim, self.action_space.n, net_arch, activation_fn) + actor_net.append(nn.Softmax(1)) + + self.actor = nn.Sequential(*actor_net) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + data.update( + dict( + net_arch=self.net_arch, + features_dim=self.features_dim, + activation_fn=self.activation_fn, + features_extractor=self.features_extractor, + ) + ) + return data + + def forward(self, obs: th.Tensor) -> th.Tensor: + """Output of the neural network (here: probabilities)""" + return self.actor(self.extract_features(obs)) + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + """Action to be executed by the actor""" + probas = self.forward(observation) + + if deterministic: + # Take the argmax + return th.max(probas, 1)[1].flatten() + else: + # Sample according to probabilities + return th.multinomial(probas, num_samples=1).flatten() + + +class BDPIPolicy(BasePolicy): + """ + Policy class for BDPI, with one actor and several critics (the critics are DQN QNetworks). + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param n_critics: Number of critic networks to create. + :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critics (this saves computation time) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 16, + share_features_extractor: bool = True, + ): + super(BDPIPolicy, self).__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + ) + + if net_arch is None: + if features_extractor_class == NatureCNN: + net_arch = [] + else: + net_arch = [256, 256] + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.net_args = { + "observation_space": self.observation_space, + "action_space": self.action_space, + "net_arch": net_arch, + "activation_fn": self.activation_fn, + "normalize_images": normalize_images, + } + + self.actor = None + self.criticsA = nn.ModuleList() + self.criticsB = nn.ModuleList() + self.share_features_extractor = share_features_extractor + self.n_critics = n_critics + + self._build(lr_schedule) + + def _build(self, lr_schedule: Schedule) -> None: + # Make the actor + self.actor = self.make_actor() + self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + # Make the critics + for i in range(self.n_critics * 2): + if self.share_features_extractor: + critic = self.make_critic(features_extractor=self.actor.features_extractor) + else: + # Create a separate features extractor for the critic + # this requires more memory and computation + critic = self.make_critic(features_extractor=None) + + critic.optimizer = self.optimizer_class(critic.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + if i < self.n_critics: + self.criticsA.append(critic) + else: + self.criticsB.append(critic) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + data.update( + dict( + net_arch=self.net_arch, + activation_fn=self.net_args["activation_fn"], + n_critics=self.n_critics, + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor: + actor_kwargs = self._update_features_extractor(self.net_args, features_extractor) + return Actor(**actor_kwargs).to(self.device) + + def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> QNetwork: + critic_kwargs = self._update_features_extractor(self.net_args, features_extractor) + return QNetwork(**critic_kwargs).to(self.device) + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self._predict(obs, deterministic=deterministic) + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self.actor._predict(observation, deterministic) + + +MlpPolicy = BDPIPolicy + + +class CnnPolicy(BDPIPolicy): + """ + Policy class for BDPI, with one actor and several critics (the critics are DQN QNetworks). CNN version + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param n_critics: Number of critic networks to create. + :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critics (this saves computation time) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 16, + share_features_extractor: bool = True, + ): + super(CnnPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_critics, + share_features_extractor, + ) + + +class MultiInputPolicy(BDPIPolicy): + """ + Policy class for BDPI, with one actor and several critics (the critics are DQN QNetworks). Multi-input version + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param n_critics: Number of critic networks to create. + :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critics (this saves computation time) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 16, + share_features_extractor: bool = True, + ): + super(MultiInputPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_critics, + share_features_extractor, + ) + + +register_policy("MlpPolicy", MlpPolicy) +register_policy("CnnPolicy", CnnPolicy) +register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/setup.cfg b/setup.cfg index 3ddbd033..72faf2a0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ per-file-ignores = ./sb3_contrib/__init__.py:F401 ./sb3_contrib/qrdqn/__init__.py:F401 ./sb3_contrib/tqc/__init__.py:F401 + ./sb3_contrib/bdpi/__init__.py:F401 ./sb3_contrib/common/wrappers/__init__.py:F401 exclude = # No need to traverse our git directory diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 6c277856..be93565d 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -8,10 +8,10 @@ from stable_baselines3.common.utils import zip_strict from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped -from sb3_contrib import QRDQN, TQC +from sb3_contrib import BDPI, QRDQN, TQC -@pytest.mark.parametrize("model_class", [TQC, QRDQN]) +@pytest.mark.parametrize("model_class", [TQC, QRDQN, BDPI]) def test_cnn(tmp_path, model_class): SAVE_NAME = "cnn_model.zip" # Fake grayscale with frameskip diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index fda27259..349f3976 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -6,7 +6,7 @@ from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize -from sb3_contrib import QRDQN, TQC +from sb3_contrib import BDPI, QRDQN, TQC class DummyDictEnv(gym.Env): @@ -78,13 +78,13 @@ def render(self, mode="human"): pass -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, BDPI]) def test_consistency(model_class): """ Make sure that dict obs with vector only vs using flatten obs is equivalent. This ensures notable that the network architectures are the same. """ - use_discrete_actions = model_class == QRDQN + use_discrete_actions = model_class in [QRDQN, BDPI] dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True) dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) @@ -124,7 +124,7 @@ def test_consistency(model_class): assert np.allclose(action_1, action_2) -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, BDPI]) @pytest.mark.parametrize("channel_last", [False, True]) def test_dict_spaces(model_class, channel_last): """ @@ -138,9 +138,8 @@ def test_dict_spaces(model_class, channel_last): kwargs = {} n_steps = 256 - if model_class in {}: + if model_class in [BDPI]: kwargs = dict( - n_steps=128, policy_kwargs=dict( net_arch=[32], features_extractor_kwargs=dict(cnn_output_dim=32), @@ -169,7 +168,7 @@ def test_dict_spaces(model_class, channel_last): evaluate_policy(model, env, n_eval_episodes=5, warn=False) -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, BDPI]) @pytest.mark.parametrize("channel_last", [False, True]) def test_dict_vec_framestack(model_class, channel_last): """ @@ -187,9 +186,8 @@ def test_dict_vec_framestack(model_class, channel_last): kwargs = {} n_steps = 256 - if model_class in {}: + if model_class in [BDPI]: kwargs = dict( - n_steps=128, policy_kwargs=dict( net_arch=[32], features_extractor_kwargs=dict(cnn_output_dim=32), @@ -218,13 +216,14 @@ def test_dict_vec_framestack(model_class, channel_last): evaluate_policy(model, env, n_eval_episodes=5, warn=False) -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, BDPI]) def test_vec_normalize(model_class): """ Additional tests to check observation space support for GoalEnv and VecNormalize using MultiInputPolicy. """ - env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=not (model_class == QRDQN))]) + use_discrete_actions = model_class not in [TQC] + env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=not use_discrete_actions)]) env = VecNormalize(env) kwargs = {} diff --git a/tests/test_run.py b/tests/test_run.py index 195d0114..86268b64 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,6 +1,6 @@ import pytest -from sb3_contrib import QRDQN, TQC +from sb3_contrib import BDPI, QRDQN, TQC @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"]) @@ -56,3 +56,18 @@ def test_qrdqn(): create_eval_env=True, ) model.learn(total_timesteps=500, eval_freq=250) + + +@pytest.mark.parametrize("n_critics", [1, 3]) +def test_bdpi(n_critics): + model = BDPI( + "MlpPolicy", + "CartPole-v1", + policy_kwargs=dict(n_critics=n_critics, net_arch=[64, 64]), + learning_starts=0, + buffer_size=500, + learning_rate=3e-4, + verbose=1, + create_eval_env=True, + ) + model.learn(total_timesteps=500, eval_freq=250) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index d2ee3a21..05edf20f 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -12,16 +12,16 @@ from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv -from sb3_contrib import QRDQN, TQC +from sb3_contrib import BDPI, QRDQN, TQC -MODEL_LIST = [TQC, QRDQN] +MODEL_LIST = [TQC, QRDQN, BDPI] def select_env(model_class: BaseAlgorithm) -> gym.Env: """ - Selects an environment with the correct action space as QRDQN only supports discrete action space + Selects an environment with the correct action space as QRDQN and BDPI only support discrete action space """ - if model_class == QRDQN: + if model_class in [QRDQN, BDPI]: return IdentityEnv(10) else: return IdentityEnvBox(10) @@ -228,7 +228,7 @@ def test_exclude_include_saved_params(tmp_path, model_class): os.remove(tmp_path / "test_save.zip") -@pytest.mark.parametrize("model_class", [TQC, QRDQN]) +@pytest.mark.parametrize("model_class", [TQC, QRDQN, BDPI]) def test_save_load_replay_buffer(tmp_path, model_class): path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl") path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning @@ -277,7 +277,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str): learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), ) - env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN) + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=(model_class in [QRDQN, BDPI])) # Reduce number of quantiles for faster tests if model_class in [TQC, QRDQN]: From 554ae3f068a0378e5bee38046607fe8aa0b77e4c Mon Sep 17 00:00:00 2001 From: Denis Steckelmacher Date: Fri, 2 Jul 2021 20:08:38 +0200 Subject: [PATCH 2/2] Reduce experience buffer size for BDPI in unit tests Large experience buffer sizes lead to warnings about memory allocations, and on the Github CI, to memory allocation failures. So, small experience buffers are important. --- tests/test_cnn.py | 12 +++++++++++- tests/test_dict_env.py | 3 +++ tests/test_save_load.py | 4 ++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index be93565d..58f40c13 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -24,7 +24,15 @@ def test_cnn(tmp_path, model_class): discrete=model_class not in {TQC}, ) kwargs = {} - if model_class in {TQC, QRDQN}: + + if model_class in [BDPI]: + kwargs = dict( + buffer_size=250, + policy_kwargs=dict( + features_extractor_kwargs=dict(features_dim=32), + ), + ) + elif model_class in [TQC, QRDQN]: # Avoid memory error when using replay buffer # Reduce the size of the features and the number of quantiles kwargs = dict( @@ -34,6 +42,7 @@ def test_cnn(tmp_path, model_class): features_extractor_kwargs=dict(features_dim=32), ), ) + model = model_class("CnnPolicy", env, **kwargs).learn(250) obs = env.reset() @@ -94,6 +103,7 @@ def test_feature_extractor_target_net(model_class, share_features_extractor): learning_starts=100, policy_kwargs=dict(n_quantiles=25, features_extractor_kwargs=dict(features_dim=32)), ) + if model_class != QRDQN: kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 349f3976..116667f2 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -140,6 +140,7 @@ def test_dict_spaces(model_class, channel_last): if model_class in [BDPI]: kwargs = dict( + buffer_size=250, policy_kwargs=dict( net_arch=[32], features_extractor_kwargs=dict(cnn_output_dim=32), @@ -188,6 +189,7 @@ def test_dict_vec_framestack(model_class, channel_last): if model_class in [BDPI]: kwargs = dict( + buffer_size=250, policy_kwargs=dict( net_arch=[32], features_extractor_kwargs=dict(cnn_output_dim=32), @@ -232,6 +234,7 @@ def test_vec_normalize(model_class): if model_class in {}: kwargs = dict( n_steps=128, + buffer_size=250, policy_kwargs=dict( net_arch=[32], ), diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 05edf20f..4eb6c9a4 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -269,7 +269,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str): if policy_str == "MlpPolicy": env = select_env(model_class) else: - if model_class in [TQC, QRDQN]: + if model_class in [TQC, QRDQN, BDPI]: # Avoid memory error when using replay buffer # Reduce the size of the features kwargs = dict( @@ -368,7 +368,7 @@ def test_save_load_q_net(tmp_path, model_class, policy_str): if policy_str == "MlpPolicy": env = select_env(model_class) else: - if model_class in [QRDQN]: + if model_class in [QRDQN, BDPI]: # Avoid memory error when using replay buffer # Reduce the size of the features kwargs = dict(