From e1709f0a1ab515b0ae91e2dce021565bde568ac2 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 1 Aug 2024 17:34:41 +0200 Subject: [PATCH 01/33] TrainingStats: fix for zero-len sequences, fixed an optional type Fixup --- tianshou/data/stats.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tianshou/data/stats.py b/tianshou/data/stats.py index b77318602..4685f5730 100644 --- a/tianshou/data/stats.py +++ b/tianshou/data/stats.py @@ -22,6 +22,8 @@ class SequenceSummaryStats(DataclassPPrintMixin): @classmethod def from_sequence(cls, sequence: Sequence[float | int] | np.ndarray) -> "SequenceSummaryStats": + if len(sequence) == 0: + return cls(mean=0.0, std=0.0, max=0.0, min=0.0) return cls( mean=float(np.mean(sequence)), std=float(np.std(sequence)), From e41decacac43e9c8f32e1deb4c497b7586e51510 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 1 Aug 2024 18:02:45 +0200 Subject: [PATCH 02/33] Rliable eval: multiple extensions 1. Support for evaluating training runs 2. Improved handling of figures and axes 3. Allow passing max_env_step 4. Use min len of all experiments (bugfix, previously it would crash if experiments had different lengths) --- tianshou/evaluation/rliable_evaluation_hl.py | 189 ++++++++++++++++--- 1 file changed, 158 insertions(+), 31 deletions(-) diff --git a/tianshou/evaluation/rliable_evaluation_hl.py b/tianshou/evaluation/rliable_evaluation_hl.py index 884176bd1..953fbf01d 100644 --- a/tianshou/evaluation/rliable_evaluation_hl.py +++ b/tianshou/evaluation/rliable_evaluation_hl.py @@ -3,7 +3,8 @@ """ import os -from dataclasses import asdict, dataclass, fields +from dataclasses import dataclass, fields +from typing import Literal import matplotlib.pyplot as plt import numpy as np @@ -12,7 +13,7 @@ from rliable import plot_utils from tianshou.highlevel.experiment import Experiment -from tianshou.utils import logging +from tianshou.utils import TensorboardLogger, logging from tianshou.utils.logger.base import DataScope log = logging.getLogger(__name__) @@ -61,17 +62,31 @@ class RLiableExperimentResult: test_episode_returns_RE: np.ndarray """The test episodes for each run of the experiment where each row corresponds to one run.""" + train_episode_returns_RE: np.ndarray + """The training episodes for each run of the experiment where each row corresponds to one run.""" + env_steps_E: np.ndarray """The number of environment steps at which the test episodes were evaluated.""" + env_steps_train_E: np.ndarray + """The number of environment steps at which the training episodes were evaluated.""" + @classmethod - def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult": + def load_from_disk( + cls, + exp_dir: str, + max_env_step: int | None = None, + ) -> "RLiableExperimentResult": """Load the experiment result from disk. :param exp_dir: The directory from where the experiment results are restored. + :param max_env_step: The maximum number of environment steps to consider. If None, all data is considered. + Note: if the experiments have different numbers of steps, the minimum number is used. """ test_episode_returns = [] + train_episode_returns = [] env_step_at_test = None + env_step_at_train = None # TODO: env_step_at_test should not be defined in a loop and overwritten at each iteration # just for retrieving them. We might need a cleaner directory structure. @@ -79,43 +94,101 @@ def load_from_disk(cls, exp_dir: str) -> "RLiableExperimentResult": if entry.name.startswith(".") or not entry.is_dir(): continue - exp = Experiment.from_directory(entry.path) - logger = exp.logger_factory.create_logger( - entry.path, - entry.name, - None, - asdict(exp.config), - ) - data = logger.restore_logged_data(entry.path) + try: + # TODO: fix + logger_factory = Experiment.from_directory(entry.path).logger_factory + logger_cls = type(logger_factory.create_logger(entry.path, entry.name, None)) + # Usually this means from low-level API + except FileNotFoundError: + log.info( + f"Could not find persisted experiment in {entry.path}, using default logger.", + ) + logger_cls = TensorboardLogger + + data = logger_cls.restore_logged_data(entry.path) + # TODO: align low-level and high-level dir structure. This is a hack! + if not data: + dirs = [ + d for d in os.listdir(entry.path) if os.path.isdir(os.path.join(entry.path, d)) + ] + if len(dirs) != 1: + raise ValueError( + f"Could not restore data from {entry.path}, " + f"expected either events or exactly one subdirectory, ", + ) + data = logger_cls.restore_logged_data(os.path.join(entry.path, dirs[0])) + if not data: + raise ValueError(f"Could not restore data from {entry.path}.") if DataScope.TEST.value not in data or not data[DataScope.TEST.value]: continue restored_test_data = data[DataScope.TEST.value] - if not isinstance(restored_test_data, dict): - raise RuntimeError( - f"Expected entry with key {DataScope.TEST.value} data to be a dictionary, " - f"but got {restored_test_data=}.", - ) + restored_train_data = data[DataScope.TRAIN.value] + for restored_data, scope in zip( + [restored_test_data, restored_train_data], + [DataScope.TEST, DataScope.TRAIN], + strict=True, + ): + if not isinstance(restored_data, dict): + raise RuntimeError( + f"Expected entry with key {scope.value} data to be a dictionary, " + f"but got {restored_data=}.", + ) test_data = LoggedCollectStats.from_data_dict(restored_test_data) + train_data = LoggedCollectStats.from_data_dict(restored_train_data) - if test_data.returns_stat is None: - continue - test_episode_returns.append(test_data.returns_stat.mean) - env_step_at_test = test_data.env_step + if test_data.returns_stat is not None: + test_episode_returns.append(test_data.returns_stat.mean) + env_step_at_test = test_data.env_step + + if train_data.returns_stat is not None: + train_episode_returns.append(train_data.returns_stat.mean) + env_step_at_train = train_data.env_step + test_data_found = True + train_data_found = True if not test_episode_returns or env_step_at_test is None: - raise ValueError(f"No experiment data found in {exp_dir}.") + log.warning(f"No test experiment data found in {exp_dir}.") + test_data_found = False + if not train_episode_returns or env_step_at_train is None: + log.warning(f"No train experiment data found in {exp_dir}.") + train_data_found = False + + if not test_data_found and not train_data_found: + raise RuntimeError(f"No test or train data found in {exp_dir}.") + + min_train_len = min([len(arr) for arr in train_episode_returns]) + if max_env_step is not None: + min_train_len = min(min_train_len, max_env_step) + min_test_len = min([len(arr) for arr in test_episode_returns]) + if max_env_step is not None: + min_test_len = min(min_test_len, max_env_step) + + env_step_at_test = env_step_at_test[:min_test_len] + env_step_at_train = env_step_at_train[:min_train_len] + if max_env_step: + # find the index at which the maximum env step is reached with searchsorted + min_test_len = np.searchsorted(env_step_at_test, max_env_step) + min_train_len = np.searchsorted(env_step_at_train, max_env_step) + env_step_at_test = env_step_at_test[:min_test_len] + env_step_at_train = env_step_at_train[:min_train_len] + + test_episode_returns = np.array([arr[:min_test_len] for arr in test_episode_returns]) + train_episode_returns = np.array([arr[:min_train_len] for arr in train_episode_returns]) return cls( - test_episode_returns_RE=np.array(test_episode_returns), - env_steps_E=np.array(env_step_at_test), + test_episode_returns_RE=test_episode_returns, + env_steps_E=env_step_at_test, exp_dir=exp_dir, + train_episode_returns_RE=train_episode_returns, + env_steps_train_E=env_step_at_train, ) def _get_rliable_data( self, algo_name: str | None = None, score_thresholds: np.ndarray | None = None, + scope: DataScope | Literal["train", "test"] = DataScope.TEST, ) -> tuple[dict, np.ndarray, np.ndarray]: """Return the data in the format expected by the rliable library. @@ -126,19 +199,27 @@ def _get_rliable_data( :return: A tuple score_dict, env_steps, and score_thresholds. """ + if isinstance(scope, DataScope): + scope = scope.value + if scope == DataScope.TEST.value: + env_steps, returns = self.env_steps_E, self.test_episode_returns_RE + elif scope == DataScope.TRAIN.value: + env_steps, returns = self.env_steps_train_E, self.train_episode_returns_RE + else: + raise ValueError(f"Invalid scope {scope}, should be either 'TEST' or 'TRAIN'.") if score_thresholds is None: score_thresholds = np.linspace( - np.min(self.test_episode_returns_RE), - np.max(self.test_episode_returns_RE), + np.min(returns), + np.max(returns), 101, ) if algo_name is None: algo_name = os.path.basename(self.exp_dir) - score_dict = {algo_name: self.test_episode_returns_RE} + score_dict = {algo_name: returns} - return score_dict, self.env_steps_E, score_thresholds + return score_dict, env_steps, score_thresholds def eval_results( self, @@ -146,6 +227,10 @@ def eval_results( score_thresholds: np.ndarray | None = None, save_plots: bool = False, show_plots: bool = True, + scope: DataScope | Literal["train", "test"] = DataScope.TEST, + ax_iqm: plt.Axes | None = None, + ax_profile: plt.Axes | None = None, + algo2color: dict[str, str] | None = None, ) -> tuple[plt.Figure, plt.Axes, plt.Figure, plt.Axes]: """Evaluate the results of an experiment and create a sample efficiency curve and a performance profile. @@ -155,19 +240,30 @@ def eval_results( from the minimum and maximum test episode returns. :param save_plots: If True, the figures are saved to the experiment directory. :param show_plots: If True, the figures are shown. - - :return: The created figures and axes. + :param scope: The scope of the evaluation, either 'TEST' or 'TRAIN'. + :param ax_iqm: The axis to plot the IQM sample efficiency curve on. If None, a new figure is created. + :param ax_profile: The axis to plot the performance profile on. If None, a new figure is created. + :param algo2color: A dictionary mapping algorithm names to colors. Useful for plotting + the evaluations of multiple algorithms in the same figure, e.g., by first creating an ax_iqm and ax_profile + with one evaluation and then passing them into the other evaluation. Same as the `colors` + kwarg in the rliable plotting utils. + + :return: The created figures and axes in the order: fig_iqm, ax_iqm, fig_profile, ax_profile. """ score_dict, env_steps, score_thresholds = self._get_rliable_data( algo_name, score_thresholds, + scope, ) iqm = lambda scores: sst.trim_mean(scores, proportiontocut=0.25, axis=0) iqm_scores, iqm_cis = rly.get_interval_estimates(score_dict, iqm) # Plot IQM sample efficiency curve - fig_iqm, ax_iqm = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True) + if ax_iqm is None: + fig_iqm, ax_iqm = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True) + else: + fig_iqm = ax_iqm.get_figure() plot_utils.plot_sample_efficiency_curve( env_steps, iqm_scores, @@ -176,6 +272,7 @@ def eval_results( xlabel="env step", ylabel="IQM episode return", ax=ax_iqm, + colors=algo2color, ) if show_plots: plt.show(block=False) @@ -197,7 +294,10 @@ def eval_results( ) # Plot score distributions - fig_profile, ax_profile = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True) + if ax_profile is None: + fig_profile, ax_profile = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True) + else: + fig_profile = ax_profile.get_figure() plot_utils.plot_performance_profiles( score_distributions, score_thresholds, @@ -216,3 +316,30 @@ def eval_results( plt.show(block=False) return fig_iqm, ax_iqm, fig_profile, ax_profile + + +def load_and_eval_experiments( + log_dir: str, + show_plots: bool = True, + save_plots: bool = True, + scope: DataScope | Literal["train", "test", "both"] = DataScope.TEST, + max_env_step: int | None = None, +) -> RLiableExperimentResult: + """Evaluate the experiments in the given log directory using the rliable API and return the loaded results object. + + If neither `show_plots` nor `save_plots` is set to `True`, this is equivalent to just loading the results from disk. + + :param log_dir: The directory containing the experiment results. + :param show_plots: whether to display plots. + :param save_plots: whether to save plots to the `log_dir`. + :param scope: The scope of the evaluation, either 'TEST' or 'TRAIN'. + :param max_env_step: The maximum number of environment steps to consider. If None, all data is considered. + Note: if the experiments have different numbers of steps, the minimum number is used. + """ + rliable_result = RLiableExperimentResult.load_from_disk(log_dir, max_env_step=max_env_step) + if scope == "both": + for scope in [DataScope.TEST, DataScope.TRAIN]: + rliable_result.eval_results(show_plots=True, save_plots=True, scope=scope) + else: + rliable_result.eval_results(show_plots=show_plots, save_plots=save_plots, scope=scope) + return rliable_result From 9ceb041fccdf18604d23a046aada0d184df3c263 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 1 Aug 2024 18:17:55 +0200 Subject: [PATCH 03/33] Added WandbLoggerFactory, made config_dict optional --- tianshou/highlevel/logger.py | 41 ++++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py index 2223d81f7..13da2aae7 100644 --- a/tianshou/highlevel/logger.py +++ b/tianshou/highlevel/logger.py @@ -22,7 +22,7 @@ def create_logger( """Creates the logger. :param log_dir: path to the directory in which log data is to be stored - :param experiment_name: the name of the job, which may contain `os.path.sep` + :param experiment_name: the name of the job, which may contain `os.path.delimiter` :param run_id: a unique name, which, depending on the logging framework, may be used to identify the logger :param config_dict: a dictionary with data that is to be logged :return: the logger @@ -45,7 +45,7 @@ def create_logger( log_dir: str, experiment_name: str, run_id: str | None, - config_dict: dict, + config_dict: dict | None = None, ) -> TLogger: if self.logger_type in ["wandb", "tensorboard"]: writer = SummaryWriter(log_dir) @@ -74,3 +74,40 @@ def create_logger( return TensorboardLogger(writer) case _: raise ValueError(f"Unknown logger type '{self.logger_type}'") + + +class WandbLoggerFactory(LoggerFactory): + def __init__( + self, + wandb_project: str, + group: str | None = None, + job_type: str | None = None, + save_interval: int = 1, + ): + self.wandb_project = wandb_project + self.group = group + self.job_type = job_type + self.save_interval = save_interval + + def create_logger( + self, + log_dir: str, + experiment_name: str, + run_id: str | None, + config_dict: dict | None, + ) -> TLogger: + logger = WandbLogger( + save_interval=self.save_interval, + name=experiment_name.replace(os.path.sep, "__"), + run_id=run_id, + config=config_dict, + project=self.wandb_project, + # entity= + group=self.group, + job_type=self.job_type, + log_dir=log_dir, + ) + + writer = SummaryWriter(log_dir) + logger.load(writer) + return logger From c492765a173d2e74352ec4e14e6948783d6eaa92 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 1 Aug 2024 19:04:14 +0200 Subject: [PATCH 04/33] Logging: made restore_logged_data static. Eval: better use of DataScope Minor improvements in typing --- tianshou/evaluation/rliable_evaluation_hl.py | 36 +++++++++++--------- tianshou/highlevel/logger.py | 8 ++--- tianshou/utils/logger/base.py | 17 ++++----- tianshou/utils/logger/tensorboard.py | 2 +- tianshou/utils/logger/wandb.py | 18 +++++----- 5 files changed, 42 insertions(+), 39 deletions(-) diff --git a/tianshou/evaluation/rliable_evaluation_hl.py b/tianshou/evaluation/rliable_evaluation_hl.py index 953fbf01d..28b6b25cc 100644 --- a/tianshou/evaluation/rliable_evaluation_hl.py +++ b/tianshou/evaluation/rliable_evaluation_hl.py @@ -4,7 +4,6 @@ import os from dataclasses import dataclass, fields -from typing import Literal import matplotlib.pyplot as plt import numpy as np @@ -120,10 +119,14 @@ def load_from_disk( if not data: raise ValueError(f"Could not restore data from {entry.path}.") - if DataScope.TEST.value not in data or not data[DataScope.TEST.value]: + if DataScope.TEST not in data or not data[DataScope.TEST]: continue - restored_test_data = data[DataScope.TEST.value] - restored_train_data = data[DataScope.TRAIN.value] + restored_test_data = data[DataScope.TEST] + restored_train_data = data[DataScope.TRAIN] + + assert isinstance(restored_test_data, dict) + assert isinstance(restored_train_data, dict) + for restored_data, scope in zip( [restored_test_data, restored_train_data], [DataScope.TEST, DataScope.TRAIN], @@ -131,7 +134,7 @@ def load_from_disk( ): if not isinstance(restored_data, dict): raise RuntimeError( - f"Expected entry with key {scope.value} data to be a dictionary, " + f"Expected entry with key {scope} data to be a dictionary, " f"but got {restored_data=}.", ) test_data = LoggedCollectStats.from_data_dict(restored_test_data) @@ -164,12 +167,15 @@ def load_from_disk( if max_env_step is not None: min_test_len = min(min_test_len, max_env_step) + assert env_step_at_test is not None + assert env_step_at_train is not None + env_step_at_test = env_step_at_test[:min_test_len] env_step_at_train = env_step_at_train[:min_train_len] if max_env_step: # find the index at which the maximum env step is reached with searchsorted - min_test_len = np.searchsorted(env_step_at_test, max_env_step) - min_train_len = np.searchsorted(env_step_at_train, max_env_step) + min_test_len = int(np.searchsorted(env_step_at_test, max_env_step)) + min_train_len = int(np.searchsorted(env_step_at_train, max_env_step)) env_step_at_test = env_step_at_test[:min_test_len] env_step_at_train = env_step_at_train[:min_train_len] @@ -188,7 +194,7 @@ def _get_rliable_data( self, algo_name: str | None = None, score_thresholds: np.ndarray | None = None, - scope: DataScope | Literal["train", "test"] = DataScope.TEST, + scope: DataScope = DataScope.TEST, ) -> tuple[dict, np.ndarray, np.ndarray]: """Return the data in the format expected by the rliable library. @@ -199,11 +205,9 @@ def _get_rliable_data( :return: A tuple score_dict, env_steps, and score_thresholds. """ - if isinstance(scope, DataScope): - scope = scope.value - if scope == DataScope.TEST.value: + if scope == DataScope.TEST: env_steps, returns = self.env_steps_E, self.test_episode_returns_RE - elif scope == DataScope.TRAIN.value: + elif scope == DataScope.TRAIN: env_steps, returns = self.env_steps_train_E, self.train_episode_returns_RE else: raise ValueError(f"Invalid scope {scope}, should be either 'TEST' or 'TRAIN'.") @@ -227,7 +231,7 @@ def eval_results( score_thresholds: np.ndarray | None = None, save_plots: bool = False, show_plots: bool = True, - scope: DataScope | Literal["train", "test"] = DataScope.TEST, + scope: DataScope = DataScope.TEST, ax_iqm: plt.Axes | None = None, ax_profile: plt.Axes | None = None, algo2color: dict[str, str] | None = None, @@ -263,7 +267,7 @@ def eval_results( if ax_iqm is None: fig_iqm, ax_iqm = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True) else: - fig_iqm = ax_iqm.get_figure() + fig_iqm = ax_iqm.get_figure() # type: ignore plot_utils.plot_sample_efficiency_curve( env_steps, iqm_scores, @@ -297,7 +301,7 @@ def eval_results( if ax_profile is None: fig_profile, ax_profile = plt.subplots(ncols=1, figsize=(7, 5), constrained_layout=True) else: - fig_profile = ax_profile.get_figure() + fig_profile = ax_profile.get_figure() # type: ignore plot_utils.plot_performance_profiles( score_distributions, score_thresholds, @@ -322,7 +326,7 @@ def load_and_eval_experiments( log_dir: str, show_plots: bool = True, save_plots: bool = True, - scope: DataScope | Literal["train", "test", "both"] = DataScope.TEST, + scope: DataScope = DataScope.TEST, max_env_step: int | None = None, ) -> RLiableExperimentResult: """Evaluate the experiments in the given log directory using the rliable API and return the loaded results object. diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py index 13da2aae7..f3405db66 100644 --- a/tianshou/highlevel/logger.py +++ b/tianshou/highlevel/logger.py @@ -17,7 +17,7 @@ def create_logger( log_dir: str, experiment_name: str, run_id: str | None, - config_dict: dict, + config_dict: dict | None = None, ) -> TLogger: """Creates the logger. @@ -94,7 +94,7 @@ def create_logger( log_dir: str, experiment_name: str, run_id: str | None, - config_dict: dict | None, + config_dict: dict | None = None, ) -> TLogger: logger = WandbLogger( save_interval=self.save_interval, @@ -102,10 +102,6 @@ def create_logger( run_id=run_id, config=config_dict, project=self.wandb_project, - # entity= - group=self.group, - job_type=self.job_type, - log_dir=log_dir, ) writer = SummaryWriter(log_dir) diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py index c1bd73795..cbe22aca7 100644 --- a/tianshou/utils/logger/base.py +++ b/tianshou/utils/logger/base.py @@ -1,7 +1,7 @@ import typing from abc import ABC, abstractmethod from collections.abc import Callable -from enum import Enum +from enum import StrEnum from numbers import Number import numpy as np @@ -13,7 +13,7 @@ TRestoredData = dict[str, np.ndarray | dict[str, "TRestoredData"]] -class DataScope(Enum): +class DataScope(StrEnum): TRAIN = "train" TEST = "test" UPDATE = "update" @@ -76,7 +76,7 @@ def log_train_data(self, log_data: dict, step: int) -> None: # TODO: move interval check to calling method if step - self.last_log_train_step >= self.train_interval: log_data = self.prepare_dict_for_logging(log_data) - self.write(f"{DataScope.TRAIN.value}/env_step", step, log_data) + self.write(f"{DataScope.TRAIN}/env_step", step, log_data) self.last_log_train_step = step def log_test_data(self, log_data: dict, step: int) -> None: @@ -88,7 +88,7 @@ def log_test_data(self, log_data: dict, step: int) -> None: # TODO: move interval check to calling method (stupid because log_test_data is only called from function in utils.py, not from BaseTrainer) if step - self.last_log_test_step >= self.test_interval: log_data = self.prepare_dict_for_logging(log_data) - self.write(f"{DataScope.TEST.value}/env_step", step, log_data) + self.write(f"{DataScope.TEST}/env_step", step, log_data) self.last_log_test_step = step def log_update_data(self, log_data: dict, step: int) -> None: @@ -100,7 +100,7 @@ def log_update_data(self, log_data: dict, step: int) -> None: # TODO: move interval check to calling method if step - self.last_log_update_step >= self.update_interval: log_data = self.prepare_dict_for_logging(log_data) - self.write(f"{DataScope.UPDATE.value}/gradient_step", step, log_data) + self.write(f"{DataScope.UPDATE}/gradient_step", step, log_data) self.last_log_update_step = step def log_info_data(self, log_data: dict, step: int) -> None: @@ -113,7 +113,7 @@ def log_info_data(self, log_data: dict, step: int) -> None: step - self.last_log_info_step >= self.info_interval ): # TODO: move interval check to calling method log_data = self.prepare_dict_for_logging(log_data) - self.write(f"{DataScope.INFO.value}/epoch", step, log_data) + self.write(f"{DataScope.INFO}/epoch", step, log_data) self.last_log_info_step = step @abstractmethod @@ -143,9 +143,9 @@ def restore_data(self) -> tuple[int, int, int]: :return: epoch, env_step, gradient_step. """ + @staticmethod @abstractmethod def restore_logged_data( - self, log_path: str, ) -> TRestoredData: """Load the logged data from disk for post-processing. @@ -181,5 +181,6 @@ def save_data( def restore_data(self) -> tuple[int, int, int]: return 0, 0, 0 - def restore_logged_data(self, log_path: str) -> dict: + @staticmethod + def restore_logged_data(log_path: str) -> dict: return {} diff --git a/tianshou/utils/logger/tensorboard.py b/tianshou/utils/logger/tensorboard.py index d824d862d..17b4ec278 100644 --- a/tianshou/utils/logger/tensorboard.py +++ b/tianshou/utils/logger/tensorboard.py @@ -136,8 +136,8 @@ def restore_data(self) -> tuple[int, int, int]: return epoch, env_step, gradient_step + @staticmethod def restore_logged_data( - self, log_path: str, ) -> TRestoredData: """Restores the logged data from the tensorboard log directory. diff --git a/tianshou/utils/logger/wandb.py b/tianshou/utils/logger/wandb.py index 74d844fa9..63eccff66 100644 --- a/tianshou/utils/logger/wandb.py +++ b/tianshou/utils/logger/wandb.py @@ -1,5 +1,6 @@ import argparse import contextlib +import logging import os from collections.abc import Callable @@ -11,6 +12,8 @@ with contextlib.suppress(ImportError): import wandb +log = logging.getLogger(__name__) + class WandbLogger(BaseLogger): """Weights and Biases logger that sends data to https://wandb.ai/. @@ -167,11 +170,10 @@ def restore_data(self) -> tuple[int, int, int]: env_step = 0 return epoch, env_step, gradient_step - def restore_logged_data(self, log_path: str) -> TRestoredData: - if self.tensorboard_logger is None: - raise NotImplementedError( - "Restoring logged data directly from W&B is not yet implemented." - "Try instantiating the internal TensorboardLogger by calling something" - "like `logger.load(SummaryWriter(log_path))`", - ) - return self.tensorboard_logger.restore_logged_data(log_path) + @staticmethod + def restore_logged_data(log_path: str) -> TRestoredData: + log.warning( + "Logging data directly from W&B is not yet implemented, will use the " + "TensorboardLogger to restore it from disc instead.", + ) + return TensorboardLogger.restore_logged_data(log_path) From 547b62619f408a876a5c8bebe325f8df0a0ef666 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 1 Aug 2024 19:28:15 +0200 Subject: [PATCH 05/33] Spelling [ci skip] --- docs/spelling_wordlist.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index c30b9f2cb..3a855c63a 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -271,3 +271,5 @@ v_s_ obs obs_next dtype +iqm +kwarg From f18f4a4eab77bfdfee178f23a9358973c9cbf250 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 5 Aug 2024 12:08:10 +0200 Subject: [PATCH 06/33] Minor typing and docstrings [ci skip] --- tianshou/evaluation/rliable_evaluation_hl.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tianshou/evaluation/rliable_evaluation_hl.py b/tianshou/evaluation/rliable_evaluation_hl.py index 28b6b25cc..b3324f70e 100644 --- a/tianshou/evaluation/rliable_evaluation_hl.py +++ b/tianshou/evaluation/rliable_evaluation_hl.py @@ -4,6 +4,7 @@ import os from dataclasses import dataclass, fields +from typing import Literal import matplotlib.pyplot as plt import numpy as np @@ -85,7 +86,11 @@ def load_from_disk( test_episode_returns = [] train_episode_returns = [] env_step_at_test = None + """The number of steps of the test run, + will try extracting it either from the loaded stats or from loaded arrays.""" env_step_at_train = None + """The number of steps of the training run, + will try extracting it from the loaded stats or from loaded arrays.""" # TODO: env_step_at_test should not be defined in a loop and overwritten at each iteration # just for retrieving them. We might need a cleaner directory structure. @@ -326,7 +331,7 @@ def load_and_eval_experiments( log_dir: str, show_plots: bool = True, save_plots: bool = True, - scope: DataScope = DataScope.TEST, + scope: DataScope | Literal["both"] = DataScope.TEST, max_env_step: int | None = None, ) -> RLiableExperimentResult: """Evaluate the experiments in the given log directory using the rliable API and return the loaded results object. @@ -336,7 +341,7 @@ def load_and_eval_experiments( :param log_dir: The directory containing the experiment results. :param show_plots: whether to display plots. :param save_plots: whether to save plots to the `log_dir`. - :param scope: The scope of the evaluation, either 'TEST' or 'TRAIN'. + :param scope: The scope of the evaluation, either 'test', 'train' or 'both'. :param max_env_step: The maximum number of environment steps to consider. If None, all data is considered. Note: if the experiments have different numbers of steps, the minimum number is used. """ From 434606dd9161a24f1cbdf19a992eb61f1b747954 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 7 Aug 2024 11:43:14 +0200 Subject: [PATCH 07/33] New util: create_uniform_action_dist --- test/base/test_utils.py | 109 +++++++++++++++++++++++++++++++++- tianshou/utils/torch_utils.py | 39 +++++++++++- 2 files changed, 146 insertions(+), 2 deletions(-) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index f8e5938cb..23fe0b337 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1,12 +1,15 @@ import numpy as np +import pytest import torch +import torch.distributions as dist +from gymnasium import spaces from torch import nn +from utils.torch_utils import create_uniform_action_dist, torch_train_mode from tianshou.exploration import GaussianNoise, OUNoise from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic -from tianshou.utils.torch_utils import torch_train_mode def test_noise() -> None: @@ -148,3 +151,107 @@ def test_in_train_mode() -> None: with torch_train_mode(module): assert module.training assert not module.training + + +class TestCreateActionDistribution: + @classmethod + def setup_class(cls): + # Set random seeds for reproducibility + torch.manual_seed(0) + np.random.seed(0) + + @pytest.mark.parametrize( + "action_space, batch_size", + [ + (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 1), + (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 5), + (spaces.Discrete(5), 1), + (spaces.Discrete(5), 5), + ], + ) + def test_distribution_properties( + self, action_space: spaces.Box | spaces.Discrete, batch_size: int, + ) -> None: + distribution = create_uniform_action_dist(action_space, batch_size) + + # Correct distribution type + if isinstance(action_space, spaces.Box): + assert isinstance(distribution, dist.Uniform) + elif isinstance(action_space, spaces.Discrete): + assert isinstance(distribution, dist.Categorical) + + # Samples are within correct range + samples = distribution.sample() + if isinstance(action_space, spaces.Box): + low = torch.tensor(action_space.low, dtype=torch.float32) + high = torch.tensor(action_space.high, dtype=torch.float32) + assert torch.all(samples >= low) + assert torch.all(samples <= high) + elif isinstance(action_space, spaces.Discrete): + assert torch.all(samples >= 0) + assert torch.all(samples < action_space.n) + + @pytest.mark.parametrize( + "action_space, batch_size", + [ + (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 1), + (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 5), + (spaces.Discrete(5), 1), + (spaces.Discrete(5), 5), + ], + ) + def test_distribution_uniformity( + self, action_space: spaces.Box | spaces.Discrete, batch_size: int, + ) -> None: + distribution = create_uniform_action_dist(action_space, batch_size) + + # Test 7: Uniform distribution (statistical test) + large_sample = distribution.sample((10000,)) + if isinstance(action_space, spaces.Box): + # For Box, check if mean is close to 0 and std is close to 1/sqrt(3) + assert torch.allclose(large_sample.mean(), torch.tensor(0.0), atol=0.1) + assert torch.allclose(large_sample.std(), torch.tensor(1 / 3**0.5), atol=0.1) + elif isinstance(action_space, spaces.Discrete): + # For Discrete, check if all actions are roughly equally likely + counts = torch.bincount(large_sample.flatten(), minlength=action_space.n).float() + expected_count = 10000 * batch_size / action_space.n + assert torch.allclose(counts, torch.tensor(expected_count).float(), rtol=0.1) + + def test_unsupported_space(self) -> None: + # Test 6: Raises ValueError for unsupported space + with pytest.raises(ValueError): + create_uniform_action_dist(spaces.MultiBinary(5)) + + @pytest.mark.parametrize( + "space, batch_size, expected_shape, distribution_type", + [ + (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 1, (1, 3), dist.Uniform), + (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 5, (5, 3), dist.Uniform), + (spaces.Box(low=-1.0, high=1.0, shape=(3,)), 10, (10, 3), dist.Uniform), + (spaces.Discrete(5), 1, (1,), dist.Categorical), + (spaces.Discrete(5), 5, (5,), dist.Categorical), + (spaces.Discrete(5), 10, (10,), dist.Categorical), + ], + ) + def test_batch_sizes( + self, + space: spaces.Box | spaces.Discrete, + batch_size: int, + expected_shape: tuple[int, ...], + distribution_type: type[dist.Distribution], + ) -> None: + distribution = create_uniform_action_dist(space, batch_size) + + # Check distribution type + assert isinstance(distribution, distribution_type) + + # Check sample shape + samples = distribution.sample() + assert samples.shape == expected_shape + + # Check internal distribution shapes + if isinstance(space, spaces.Box): + assert distribution.low.shape == expected_shape + assert distribution.high.shape == expected_shape + elif isinstance(space, spaces.Discrete): + assert distribution.probs.shape == (batch_size, space.n) diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index 430d174e7..44d5f7668 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -1,7 +1,10 @@ from collections.abc import Iterator from contextlib import contextmanager -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, overload +import torch +import torch.distributions as dist +from gymnasium import spaces from torch import nn if TYPE_CHECKING: @@ -37,3 +40,37 @@ def policy_within_training_step(policy: "BasePolicy", enabled: bool = True) -> I yield finally: policy.is_within_training_step = original_mode + + +@overload +def create_uniform_action_dist(action_space: spaces.Box, batch_size: int = 1) -> dist.Uniform: + ... + + +@overload +def create_uniform_action_dist( + action_space: spaces.Discrete, batch_size: int = 1, +) -> dist.Categorical: + ... + + +def create_uniform_action_dist( + action_space: spaces.Box | spaces.Discrete, + batch_size: int = 1, +) -> dist.Uniform | dist.Categorical: + """Create a Distribution such that sampling from it is equivalent to sampling a batch with `action_space.sample()`. + + :param action_space: The action space of the environment. + :param batch_size: The number of environments or batch size for sampling. + :return: A PyTorch distribution for sampling actions. + """ + if isinstance(action_space, spaces.Box): + low = torch.FloatTensor(action_space.low).unsqueeze(0).repeat(batch_size, 1) + high = torch.FloatTensor(action_space.high).unsqueeze(0).repeat(batch_size, 1) + return dist.Uniform(low, high) + + elif isinstance(action_space, spaces.Discrete): + return dist.Categorical(torch.ones(batch_size, action_space.n)) + + else: + raise ValueError(f"Unsupported action space type: {type(action_space)}") From 0a3fc25a8dfa83684dfb9cf70046835c45dbb1f0 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 7 Aug 2024 11:45:32 +0200 Subject: [PATCH 08/33] HL, config: option to collect n_episodes, post-init validation, docstrings Note: the new config option will be used in follow-up commits --- tianshou/highlevel/config.py | 56 +++++++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index 951f2f3af..ec702e373 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -1,8 +1,11 @@ import multiprocessing from dataclasses import dataclass +from tianshou.utils import logging from tianshou.utils.string import ToStringMixin +log = logging.getLogger(__name__) + @dataclass class SamplingConfig(ToStringMixin): @@ -67,7 +70,7 @@ class SamplingConfig(ToStringMixin): """the total size of the sample/replay buffer, in which environment steps (transitions) are stored""" - step_per_collect: int = 2048 + step_per_collect: int | None = 2048 """ the number of environment steps/transitions to collect in each collection step before the network update within each training step. @@ -81,11 +84,18 @@ class SamplingConfig(ToStringMixin): collected during training. """ + episode_per_collect: int | None = None + """ + the number of episodes to collect in each collection step before the network update within + each training step. If this is set, the number of environment steps collected in each + collection step is the sum of the lengths of the episodes collected. + """ + repeat_per_collect: int | None = 1 """ controls, within one gradient update step of an on-policy algorithm, the number of times an actual gradient update is applied using the full collected dataset, i.e. if the parameter is - 5, then the collected data shall be used five times to update the policy within the same + `n`, then the collected data shall be used five times to update the policy within the same training step. The parameter is ignored and may be set to None for off-policy and offline algorithms. @@ -116,14 +126,22 @@ class SamplingConfig(ToStringMixin): """ replay_buffer_ignore_obs_next: bool = False + """whether to ignore the `obs_next` field in the collected samples when storing them in the + buffer and instead use the one-in-the-future of `obs` as the next observation. + This can be useful for very large observations, like for Atari, in order to save RAM. + + However, setting this to True **may introduce an error** at the last steps of episodes! Should + only be used in exceptional cases and only when you know what you are doing. + Currently only used in Atari examples and may be removed in the future! + """ replay_buffer_save_only_last_obs: bool = False - """if True, for the case where the environment outputs stacked frames (e.g. because it - is using a `FrameStack` wrapper), save only the most recent frame so as not to duplicate - observations in buffer memory. Specifically, if the environment outputs observations `obs` with - shape (N, ...), only obs[-1] of shape (...) will be stored. - Frame stacking with a fixed number of frames can then be recreated at the buffer level by setting - :attr:`replay_buffer_stack_num`. + """if True, only the most recent frame is saved when appending to experiences rather than the + full stacked frames. This avoids duplicating observations in buffer memory. Set to False to + save stacked frames in full. + + Note: typically used together with `replay_buffer_stack_num`, see documentation there. + Currently only used in Atari examples and may be removed in the future! """ replay_buffer_stack_num: int = 1 @@ -132,8 +150,8 @@ class SamplingConfig(ToStringMixin): to the agent for each time step. Setting this to a value greater than 1 can help agents learn temporal aspects (e.g. velocities of moving objects for which only positions are observed). - If the environment already stacks frames (e.g. using a `FrameStack` wrapper), this should either not - be used or should be used in conjunction with :attr:`replay_buffer_save_only_last_obs`. + Note: it is recommended to do this stacking on the environment level by using something like + gymnasium's `FrameStack` instead. Currently only used in Atari examples and may be removed in the future! """ @property @@ -143,3 +161,21 @@ def test_seed(self) -> int: def __post_init__(self) -> None: if self.num_train_envs == -1: self.num_train_envs = multiprocessing.cpu_count() + + if self.num_test_episodes == 0 and self.num_test_envs != 0: + log.warning( + f"Number of test episodes is set to 0, " + f"but number of test environments is ({self.num_test_envs}). " + f"This can cause unnecessary memory usage.", + ) + + if self.num_test_episodes != 0 and self.num_test_episodes % self.num_test_envs != 0: + log.warning( + f"Number of test episodes ({self.num_test_episodes} " + f"is not divisible by the number of test environments ({self.num_test_envs}). " + f"This can cause unnecessary memory usage, it is recommended to adjust this.", + ) + + assert ( + sum([self.step_per_collect is not None, self.episode_per_collect is not None]) == 1 + ), ("Only one of `step_per_collect` and `episode_per_collect` can be set.",) From a95a8b166db9eff0d050954c465ed1d0ba41b88a Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Wed, 7 Aug 2024 23:26:59 +0200 Subject: [PATCH 09/33] n-step-return: better variable names, more docstrings and comments --- tianshou/policy/base.py | 167 ++++++++++++++++++++++++++++++---------- 1 file changed, 128 insertions(+), 39 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index d886180a5..1db0fce77 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -660,26 +660,60 @@ def compute_nstep_return( if len(indices) != len(batch): raise ValueError(f"Batch size {len(batch)} and indices size {len(indices)} mismatch.") - rew = buffer.rew - bsz = len(indices) - indices = [indices] - for _ in range(n_step - 1): - indices.append(buffer.next(indices[-1])) - indices = np.stack(indices) - # terminal indicates buffer indexes nstep after 'indices', - # and are truncated at the end of each episode - terminal = indices[-1] + # naming convention + # I = number of indices + # B = size of the replay buffer + # N = n_step + # A = the output dimension of target_q_fn for a single index. Presumably + # this is the number of actions in the discrete case, or something like that. + # 1 = 1 extra dimension + # TODO: it's very weird that this is not always one! + # We set the n-step-return for a single index to be the same shape as the target_q_fn. + # I don't understand how a non-scalar value would make sense there, but such cases are covered by tests + + # support in following naming convention + I = len(indices) + N = n_step + + _indices_to_stack = [indices] + for _ in range(N - 1): + next_indices = buffer.next(_indices_to_stack[-1]) + _indices_to_stack.append(next_indices) + stacked_indices_NI = np.stack(_indices_to_stack) + """The stacked indices represent a 2d array of shape `IxN` of the type + [ + [i_1, i_2,...], + [i_(next(1)), i_(next(2)), ...], + [i_(next(next(1)), ... + ... + ] + where `next` is the subsequent transition in the buffer. + """ + indices_after_n_steps_I = stacked_indices_NI[-1] + """Indicates indexes of transitions in buffer that occur N steps after the user provided 'indices'; + they are truncated at the end of each episode""" + with torch.no_grad(): - target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?) - target_q = to_numpy(target_q_torch.reshape(bsz, -1)) - target_q = target_q * BasePolicy.value_mask(buffer, terminal).reshape(-1, 1) - end_flag = buffer.done.copy() - end_flag[buffer.unfinished_index()] = True - target_q = _nstep_return(rew, end_flag, target_q, indices, gamma, n_step) - - batch.returns = to_torch_as(target_q, target_q_torch) - if hasattr(batch, "weight"): # prio buffer update - batch.weight = to_torch_as(batch.weight, target_q_torch) + target_q_torch_IA = target_q_fn(buffer, indices_after_n_steps_I) + target_q_IA = to_numpy(target_q_torch_IA.reshape(I, -1)) + """Represents the Q-values (one for each action) of the transition after N steps.""" + + target_q_IA = target_q_IA * BasePolicy.value_mask(buffer, indices_after_n_steps_I).reshape( + -1, 1, + ) + end_flag_B = buffer.done.copy() + end_flag_B[buffer.unfinished_index()] = True + n_step_return_IA = _nstep_return( + buffer.rew, end_flag_B, target_q_IA, stacked_indices_NI, gamma, n_step, + ) + """The n-step return plus the last Q-values, see method's docstring""" + + batch.returns = to_torch_as(n_step_return_IA, target_q_torch_IA) + + # TODO: this is simply casting to a certain type. Why is this necessary, and why is it happening here? + if hasattr(batch, "weight"): + batch.weight = to_torch_as(batch.weight, target_q_torch_IA) + return cast(BatchWithReturnsProtocol, batch) @staticmethod @@ -743,28 +777,83 @@ def _gae_return( return returns +@njit +def episode_mc_return_to_go(rewards: np.ndarray, gamma: float = 0.99) -> np.ndarray: + """Calculates discounted monte-carlo returns to go from rewards of a single episode. + + :param rewards: rewards of a single episode. Assumed to be a 1-dim array from reset till the end of the episode. + :param gamma: discount factor + :return: a numpy array of shape (len(rewards), ). + """ + len_episode = len(rewards) + ret2go = np.zeros(len_episode) + ret2go[-1] = rewards[-1] + + for j in range(len_episode - 2, -1, -1): + ret2go[j] = rewards[j] + gamma * ret2go[j + 1] + return ret2go + + @njit def _nstep_return( - rew: np.ndarray, - end_flag: np.ndarray, - target_q: np.ndarray, - indices: np.ndarray, + rew_B: np.ndarray, + end_flag_B: np.ndarray, + target_q_IA: np.ndarray, + stacked_indices_NI: np.ndarray, gamma: float, n_step: int, ) -> np.ndarray: - gamma_buffer = np.ones(n_step + 1) - for i in range(1, n_step + 1): - gamma_buffer[i] = gamma_buffer[i - 1] * gamma - target_shape = target_q.shape - bsz = target_shape[0] - # change target_q to 2d array - target_q = target_q.reshape(bsz, -1) - returns = np.zeros(target_q.shape) - gammas = np.full(indices[0].shape, n_step) - for n in range(n_step - 1, -1, -1): - now = indices[n] - gammas[end_flag[now] > 0] = n + 1 - returns[end_flag[now] > 0] = 0.0 - returns = rew[now].reshape(bsz, 1) + gamma * returns - target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns - return target_q.reshape(target_shape) + """Computes n-step returns starting at the transitions at the selected indices in the buffer. + Importantly, this is not a pure MC n-step return but it also uses the Q-values of the + obs-action pair after the n-step transition to compute the return. + + Thus, it computes `n_step_return + gamma^(n) * Q(s_{t+n}, a_{t+n})` where + `n_step_return = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1}`. + See the docstring of `compute_nstep_return` for more details. + + The target_q_B should be the array of `Q(s_{t+n}, a_{t+n})` corresponding to + the batch of rewards that started at t=0. + + Notation: + I = number of indices + B = size of the replay buffer + N = n_step + A = the output dimension of target_q_fn for a single index. Presumably, + this is the number of actions in the discrete case, or something like that. + See comments in the method `compute_nstep_return` for more details. + 1 = 1 extra dimension + + :param rew_B: rewards of the entire replay buffer + :param end_flag_B: end flags (where done=True) of the entire replay buffer + :param target_q_IA: Q-values of the transitions after n steps. Passed as a 2d array of shape (I, A) + :param stacked_indices_NI: indices of the transitions in the buffer of the structure + [ + [i_1, i_2,...], + [i_(next(1)), i_(next(2)), ...], + [i_(next(next(1)), ... + ... + ] + where `next` is the subsequent transition in the buffer. + """ + N = n_step + I, A = target_q_IA.shape + gamma_buffer_N = np.ones(N + 1) + for i in range(1, N + 1): + gamma_buffer_N[i] = gamma_buffer_N[i - 1] * gamma + target_q_IA = target_q_IA.reshape(I, -1) + """Make sure tarqet_q_I has an empty extra dimension, usually already passed with the + right shape, hence the input param name""" + n_step_mc_returns_IA = np.zeros(target_q_IA.shape) + """Will hold the n_step MC return part of the final n_step + Q-value return. + """ + gammas_IN = np.full(I, N) + for n in range(N - 1, -1, -1): + now = stacked_indices_NI[n] + gammas_IN[end_flag_B[now] > 0] = n + 1 + n_step_mc_returns_IA[end_flag_B[now] > 0] = 0.0 + n_step_mc_returns_IA = rew_B[now].reshape(I, 1) + gamma * n_step_mc_returns_IA + + n_step_return_with_Q_IA = ( + target_q_IA * gamma_buffer_N[gammas_IN].reshape(I, 1) + n_step_mc_returns_IA + ) + return n_step_return_with_Q_IA.reshape((I, A)) From 29bc77adad0bc0658740308f5f7a42a81979a879 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 8 Aug 2024 16:39:21 +0200 Subject: [PATCH 10/33] Fixup returns --- tianshou/policy/base.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 1db0fce77..9ecf252d1 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -698,13 +698,16 @@ def compute_nstep_return( target_q_IA = to_numpy(target_q_torch_IA.reshape(I, -1)) """Represents the Q-values (one for each action) of the transition after N steps.""" - target_q_IA = target_q_IA * BasePolicy.value_mask(buffer, indices_after_n_steps_I).reshape( - -1, 1, - ) + target_q_IA *= BasePolicy.value_mask(buffer, indices_after_n_steps_I).reshape(-1, 1) end_flag_B = buffer.done.copy() end_flag_B[buffer.unfinished_index()] = True n_step_return_IA = _nstep_return( - buffer.rew, end_flag_B, target_q_IA, stacked_indices_NI, gamma, n_step, + buffer.rew, + end_flag_B, + target_q_IA, + stacked_indices_NI, + gamma, + n_step, ) """The n-step return plus the last Q-values, see method's docstring""" From 6abfefc57b8cfb33a1d5e6123ede383acc4eb131 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 8 Aug 2024 17:17:46 +0200 Subject: [PATCH 11/33] Test: minor comment --- test/base/env.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/base/env.py b/test/base/env.py index 02f76ad2d..ec81554b0 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -212,8 +212,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0 ), "dict_state / recurse_state not supported" super().__init__(*args, **kwargs) - obs, _ = super().reset(options={"state": 0}) + super().reset(options={"state": 0}) + + # will result in obs=1, I guess, so the goal is to reach the max size by moving right obs, _, _, _, _ = super().step(1) + self._goal = obs * self.size super_obsv = self.observation_space self.observation_space = gym.spaces.Dict( From 8dac5b158e67f35201a33611749b7d4bd23cda20 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 8 Aug 2024 17:34:08 +0200 Subject: [PATCH 12/33] Major commit to `ReplayBuffer` and related classes Extensions: - new property `subbuffer_edges` in normal and vectorized buffer - forwarded set_array_at_key, hasnull, isnull and dropnull from Batch - added last_index creation to init of `BufferManager` Breaking: - Better input validation and checks for malformed buffer Non-functional: Many renamings, comments, docstrings and TODOs --- test/base/test_buffer.py | 150 +++++++++++++---------- tianshou/data/buffer/base.py | 204 +++++++++++++++++++++++--------- tianshou/data/buffer/her.py | 2 +- tianshou/data/buffer/manager.py | 59 +++++++-- tianshou/data/buffer/prio.py | 5 +- 5 files changed, 285 insertions(+), 135 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 0996f2436..781b7e044 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -35,17 +35,14 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: for i, act in enumerate(action_list): obs_next, rew, terminated, truncated, info = env.step(act) buf.add( - cast( - RolloutBatchProtocol, - Batch( - obs=obs, - act=[act], - rew=rew, - terminated=terminated, - truncated=truncated, - obs_next=obs_next, - info=info, - ), + Batch( + obs=obs, + act=[act], + rew=rew, + terminated=terminated, + truncated=truncated, + obs_next=obs_next, + info=info, ), ) obs = obs_next @@ -62,36 +59,33 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: assert (data.terminated <= 1).all() assert (data.truncated >= 0).all() assert (data.truncated <= 1).all() - replay_buffer = ReplayBuffer(size=10) + b = ReplayBuffer(size=10) # neg bsz should return empty index - assert replay_buffer.sample_indices(-1).tolist() == [] - ptr, ep_rew, ep_len, ep_idx = replay_buffer.add( - cast( - RolloutBatchProtocol, - Batch( - obs=1, - act=1, - rew=1, - terminated=1, - truncated=0, - obs_next="str", - info={"a": 3, "b": {"c": 5.0}}, - ), + assert b.sample_indices(-1).tolist() == [] + ptr, ep_rew, ep_len, ep_idx = b.add( + Batch( + obs=1, + act=1, + rew=1, + terminated=1, + truncated=0, + obs_next="str", + info={"a": 3, "b": {"c": 5.0}}, ), ) - assert replay_buffer.obs[0] == 1 - assert replay_buffer.done[0] - assert replay_buffer.terminated[0] - assert not replay_buffer.truncated[0] - assert replay_buffer.obs_next[0] == "str" - assert np.all(replay_buffer.obs[1:] == 0) - assert np.all(replay_buffer.obs_next[1:] == np.array(None)) - assert replay_buffer.info.a[0] == 3 - assert replay_buffer.info.a.dtype == int - assert np.all(replay_buffer.info.a[1:] == 0) - assert replay_buffer.info.b.c[0] == 5.0 - assert replay_buffer.info.b.c.dtype == float - assert np.all(replay_buffer.info.b.c[1:] == 0.0) + assert b.obs[0] == 1 + assert b.done[0] + assert b.terminated[0] + assert not b.truncated[0] + assert b.obs_next[0] == "str" + assert np.all(b.obs[1:] == 0) + assert np.all(b.obs_next[1:] == np.array(None)) + assert b.info.a[0] == 3 + assert b.info.a.dtype == int + assert np.all(b.info.a[1:] == 0) + assert b.info.b.c[0] == 5.0 + assert b.info.b.c.dtype == float + assert np.all(b.info.b.c[1:] == 0.0) assert ptr.shape == (1,) assert ptr[0] == 0 assert ep_rew.shape == (1,) @@ -113,20 +107,19 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: info={"a": 4, "d": {"e": -np.inf}}, ), ) - replay_buffer.add(batch) + b.add(batch) info_keys = ["a", "b", "d"] - assert set(replay_buffer.info.keys()) == set(info_keys) - assert replay_buffer.info.a[1] == 4 - assert replay_buffer.info.b.c[1] == 0 - assert replay_buffer.info.d.e[1] == -np.inf + assert set(b.info.keys()) == set(info_keys) + assert b.info.a[1] == 4 + assert b.info.b.c[1] == 0 + assert b.info.d.e[1] == -np.inf # test batch-style adding method, where len(batch) == 1 batch.done = [1] - batch.terminated = [0] # type: ignore[assignment] - batch.truncated = [1] # type: ignore[assignment] - assert isinstance(batch.info, Batch) + batch.terminated = [0] + batch.truncated = [1] batch.info.e = np.zeros([1, 4]) - batch = Batch.stack([batch]) - ptr, ep_rew, ep_len, ep_idx = replay_buffer.add(batch, buffer_ids=[0]) + batch: RolloutBatchProtocol = Batch.stack([batch]) + ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0]) assert ptr.shape == (1,) assert ptr[0] == 2 assert ep_rew.shape == (1,) @@ -135,17 +128,17 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: assert ep_len[0] == 2 assert ep_idx.shape == (1,) assert ep_idx[0] == 1 - assert set(replay_buffer.info.keys()) == {*info_keys, "e"} - assert replay_buffer.info.e.shape == (replay_buffer.maxsize, 1, 4) + assert set(b.info.keys()) == {*info_keys, "e"} + assert b.info.e.shape == (b.maxsize, 1, 4) with pytest.raises(IndexError): - replay_buffer[22] + b[22] # test prev / next - assert np.all(replay_buffer.prev(np.array([0, 1, 2])) == [0, 1, 1]) - assert np.all(replay_buffer.next(np.array([0, 1, 2])) == [0, 2, 2]) + assert np.all(b.prev(np.array([0, 1, 2])) == [0, 1, 1]) + assert np.all(b.next(np.array([0, 1, 2])) == [0, 2, 2]) batch.done = [0] - replay_buffer.add(batch, buffer_ids=[0]) - assert np.all(replay_buffer.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3]) - assert np.all(replay_buffer.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3]) + b.add(batch, buffer_ids=[0]) + assert np.all(b.prev(np.array([0, 1, 2, 3])) == [0, 1, 1, 3]) + assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3]) def test_ignore_obs_next(size: int = 10) -> None: @@ -308,7 +301,7 @@ def test_stack(size: int = 5, bufsize: int = 9, stack_num: int = 4, cached_num: buf[bufsize * 2] -def test_priortized_replaybuffer(size: int = 32, bufsize: int = 15) -> None: +def test_prioritized_replaybuffer(size: int = 32, bufsize: int = 15) -> None: env = MoveToRightEnv(size) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) @@ -329,7 +322,7 @@ def test_priortized_replaybuffer(size: int = 32, bufsize: int = 15) -> None: policy=np.random.randn() - 0.5, ), ) - batch_stack = Batch.stack([batch, batch, batch]) + batch_stack: RolloutBatchProtocol = Batch.stack([batch, batch, batch]) buf.add(Batch.stack([batch]), buffer_ids=[0]) buf2.add(batch_stack, buffer_ids=[0, 1, 2]) obs = obs_next @@ -467,10 +460,9 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: # Test handling cycled indices env_size = size bufsize = 15 - env = MyGoalEnv(env_size, array_state=False) + env = MyGoalEnv(size=env_size, array_state=False) buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8) - buf._index = 5 # shifted start index buf.future_p = 1 for ep_len in [5, 10]: obs, _ = env.reset() @@ -494,7 +486,7 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: batch_sample, indices = buf.sample(0) assert np.all(buf.obs.desired_goal[:5] == buf.obs.desired_goal[0]) assert np.all(buf.obs.desired_goal[5:10] == buf.obs.desired_goal[5]) - assert np.all(buf.obs.desired_goal[10:] == buf.obs.desired_goal[0]) # (same ep) + assert np.all(buf.obs.desired_goal[5:] == buf.obs.desired_goal[14]) # (same ep) assert np.all(buf.obs.desired_goal[0] != buf.obs.desired_goal[5]) # (diff ep) # Another test case for cycled indices @@ -726,7 +718,7 @@ def test_hdf5() -> None: assert _buffers[k].maxsize == buffers[k].maxsize assert np.all(_buffers[k]._indices == buffers[k]._indices) for k in ["array", "prioritized"]: - assert _buffers[k]._index == buffers[k]._index + assert _buffers[k]._insertion_idx == buffers[k]._insertion_idx assert isinstance(buffers[k].get(0, "info"), Batch) assert isinstance(_buffers[k].get(0, "info"), Batch) for k in ["array"]: @@ -1032,6 +1024,8 @@ def test_cachedbuffer() -> None: # used in test_collector buf = CachedReplayBuffer(ReplayBuffer(0, sample_avail=True), 4, 5) data = np.zeros(4) + # TODO: this doesn't make any sense - why a matrix reward?! + # See error message in ReplayBuffer._update_state_pre_add rew = np.ones([4, 4]) buf.add( cast( @@ -1477,3 +1471,35 @@ def test_custom_key() -> None: ): assert len(batch.__dict__[key].get_keys()) == 0 assert len(sampled_batch.__dict__[key].get_keys()) == 0 + + +def test_buffer_dropnull() -> None: + size = 10 + buf = ReplayBuffer(size, ignore_obs_next=True) + for i in range(4): + buf.add( + cast( + RolloutBatchProtocol, + Batch( + obs={ + "mask1": i + 1, + "mask2": i + 4, + "mask": i, + }, + act={"act_id": i, "position_id": i + 3}, + rew=i, + terminated=i % 3 == 0, + truncated=False, + info={"if": i}, + ), + ), + ) + + assert len(buf[:3]) == 3 + + buf.set_array_at_key(np.array([1, 2, 3], float), "newkey", [0, 1, 2]) + assert np.array_equal(buf.newkey[:3], np.array([1, 2, 3], float)) + assert buf.hasnull() + buf.dropnull() + assert len(buf[:3]) == 3 + assert not buf.hasnull() diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index 2699d92d7..e03a5d602 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -4,13 +4,22 @@ import numpy as np from tianshou.data import Batch -from tianshou.data.batch import alloc_by_keys_diff, create_value +from tianshou.data.batch import ( + IndexType, + alloc_by_keys_diff, + create_value, + log, +) from tianshou.data.types import RolloutBatchProtocol from tianshou.data.utils.converter import from_hdf5, to_hdf5 TBuffer = TypeVar("TBuffer", bound="ReplayBuffer") +class MalformedBufferError(RuntimeError): + pass + + class ReplayBuffer: """:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. @@ -23,11 +32,11 @@ class ReplayBuffer: :param size: the maximum size of replay buffer. :param stack_num: the frame-stack sampling argument, should be greater than or equal to 1. Default to 1 (no stacking). - :param ignore_obs_next: whether to not store obs_next. Default to False. + :param ignore_obs_next: whether to not store obs_next. :param save_only_last_obs: only save the last obs/obs_next when it has a shape - of (timestep, ...) because of temporal stacking. Default to False. - :param sample_avail: the parameter indicating sampling only available index - when using frame-stack sampling method. Default to False. + of (timestep, ...) because of temporal stacking. + :param sample_avail: whether to sample only available indices + when using the frame-stack sampling method. """ _reserved_keys = ( @@ -61,6 +70,7 @@ def __init__( sample_avail: bool = False, **kwargs: Any, # otherwise PrioritizedVectorReplayBuffer will cause TypeError ) -> None: + # TODO: why do we need this? Just for readout? self.options: dict[str, Any] = { "stack_num": stack_num, "ignore_obs_next": ignore_obs_next, @@ -72,38 +82,44 @@ def __init__( assert stack_num > 0, "stack_num should be greater than 0" self.stack_num = stack_num self._indices = np.arange(size) + # TODO: remove double negation and different name self._save_obs_next = not ignore_obs_next self._save_only_last_obs = save_only_last_obs self._sample_avail = sample_avail self._meta = cast(RolloutBatchProtocol, Batch()) - self._ep_rew: float | np.ndarray - self.reset() + + # Keep in sync with reset! + self.last_index = np.array([0]) + self._insertion_idx = self._size = 0 + self._ep_return, self._ep_len, self._ep_start_idx = 0.0, 0, 0 + + @property + def subbuffer_edges(self) -> np.ndarray: + """Edges of contained buffers, mostly needed as part of the VectorReplayBuffer interface. + + For the standard ReplayBuffer it is always [0, maxsize]. Transitions can be added + to the buffer indefinitely, and one episode can "go over the edge". Having the edges + available is useful for fishing out whole episodes from the buffer and for input validation. + """ + return np.array([0, self.maxsize], dtype=int) def __len__(self) -> int: - """Return len(self).""" return self._size def __repr__(self) -> str: - """Return str(self).""" - return self.__class__.__name__ + self._meta.__repr__()[5:] + wrapped_batch_repr = self._meta.__repr__()[len(self._meta.__class__.__name__) :] + return self.__class__.__name__ + wrapped_batch_repr def __getattr__(self, key: str) -> Any: - """Return self.key.""" try: return self._meta[key] except KeyError as exception: raise AttributeError from exception def __setstate__(self, state: dict[str, Any]) -> None: - """Unpickling interface. - - We need it because pickling buffer does not work out-of-the-box - ("buffer.__getattr__" is customized). - """ self.__dict__.update(state) def __setattr__(self, key: str, value: Any) -> None: - """Set self.key = value.""" assert key not in self._reserved_keys, f"key '{key}' is reserved and cannot be assigned" super().__setattr__(key, value) @@ -154,37 +170,39 @@ def from_data( def reset(self, keep_statistics: bool = False) -> None: """Clear all the data in replay buffer and episode statistics.""" + # Keep in sync with init! self.last_index = np.array([0]) - self._index = self._size = 0 + self._insertion_idx = self._size = self._ep_start_idx = 0 if not keep_statistics: - self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, 0 + self._ep_return, self._ep_len = 0.0, 0 + # TODO: is this method really necessary? It's kinda dangerous, can accidentally + # remove all references to collected data def set_batch(self, batch: RolloutBatchProtocol) -> None: """Manually choose the batch you want the ReplayBuffer to manage.""" - assert len(batch) == self.maxsize and set(batch.keys()).issubset( + assert len(batch) == self.maxsize and set(batch.get_keys()).issubset( self._reserved_keys, ), "Input batch doesn't meet ReplayBuffer's data form requirement." self._meta = batch def unfinished_index(self) -> np.ndarray: """Return the index of unfinished episode.""" - last = (self._index - 1) % self._size if self._size else 0 + last = (self._insertion_idx - 1) % self._size if self._size else 0 return np.array([last] if not self.done[last] and self._size else [], int) def prev(self, index: int | np.ndarray) -> np.ndarray: - """Return the index of preceding step within the same episode if it exists. - If it does not exist (because it is the first index within the episode), - the index remains unmodified. + """Return the index of previous transition. + + The index won't be modified if it is the beginning of an episode. """ - index = (index - 1) % self._size # compute preceding index with wrap-around - # end_flag will be 1 if the previous index is the last step of an episode or - # if it is the very last index of the buffer (wrap-around case), and 0 otherwise + index = (index - 1) % self._size end_flag = self.done[index] | (index == self.last_index[0]) return (index + end_flag) % self._size def next(self, index: int | np.ndarray) -> np.ndarray: - """Return the index of next step if there is a next step within the episode. - If there isn't a next step, the index remains unmodified. + """Return the index of next transition. + + The index won't be modified if it is the end of an episode. """ end_flag = self.done[index] | (index == self.last_index[0]) return (index + (1 - end_flag)) % self._size @@ -203,9 +221,9 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: return np.array([], int) to_indices = [] for _ in range(len(from_indices)): - to_indices.append(self._index) - self.last_index[0] = self._index - self._index = (self._index + 1) % self.maxsize + to_indices.append(self._insertion_idx) + self.last_index[0] = self._insertion_idx + self._insertion_idx = (self._insertion_idx + 1) % self.maxsize self._size = min(self._size + 1, self.maxsize) to_indices = np.array(to_indices) if len(self._meta.get_keys()) == 0: @@ -213,28 +231,62 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: self._meta[to_indices] = buffer._meta[from_indices] return to_indices - def _add_index( + def _update_state_pre_add( self, rew: float | np.ndarray, done: bool, - ) -> tuple[int, float | np.ndarray, int, int]: - """Maintain the buffer's state after adding one data batch. - - Return (index_to_be_modified, episode_reward, episode_length, - episode_start_index). + ) -> tuple[int, float, int, int]: + """Update the buffer's state before adding one data batch. + + Updates the `_size` and `_insertion_idx`, adds the reward and len + internally maintained `_ep_len` and `_ep_return`. If `done` is `True`, + will reset `_ep_len` and `_ep_return` to zero, and set `_ep_start_idx` to + `_insertion_idx` + + Returns a tuple with: + 0. the index at which to insert the next transition, + 1. the episode len (if done=True, otherwise 0) + 2. the episode return (if done=True, otherwise 0) + 3. the episode start index. """ - self.last_index[0] = ptr = self._index + self.last_index[0] = cur_insertion_idx = self._insertion_idx self._size = min(self._size + 1, self.maxsize) - self._index = (self._index + 1) % self.maxsize + self._insertion_idx = (self._insertion_idx + 1) % self.maxsize - self._ep_rew += rew + self._ep_return += rew # type: ignore self._ep_len += 1 + if self._ep_start_idx > len(self): + raise MalformedBufferError( + f"Encountered a starting index {self._ep_start_idx} that is outside " + f"the currently available samples {len(self)=}. " + f"The buffer is malformed. This might be caused by a bug or by manual modifications of the buffer " + f"by users.", + ) + + # return 0 for unfinished episodes + if done: + ep_return = self._ep_return + ep_len = self._ep_len + else: + if isinstance(self._ep_return, np.ndarray): # type: ignore[unreachable] + # TODO: fix this! + log.error( # type: ignore[unreachable] + f"ep_return should be a scalar but is a numpy array: {self._ep_return.shape=}. " + "This doesn't make sense for a ReplayBuffer, but currently tests of CachedReplayBuffer require" + "this behavior for some reason. Should be fixed ASAP! " + "Returning an array of zeros instead of a scalar zero.", + ) + ep_return = np.zeros_like(self._ep_return) # type: ignore + ep_len = 0 + + result = cur_insertion_idx, ep_return, ep_len, self._ep_start_idx + if done: - result = ptr, self._ep_rew, self._ep_len, self._ep_idx - self._ep_rew, self._ep_len, self._ep_idx = 0.0, 0, self._index - return result - return ptr, self._ep_rew * 0.0, 0, self._ep_idx + # prepare for next episode collection + # set return and len to zero, set start idx to next insertion idx + self._ep_return, self._ep_len, self._ep_start_idx = 0.0, 0, self._insertion_idx + return result def add( self, @@ -275,9 +327,11 @@ def add( rew, done = batch.rew[0], batch.done[0] else: rew, done = batch.rew, batch.done - ptr, ep_rew, ep_len, ep_idx = (np.array([x]) for x in self._add_index(rew, done)) + insertion_idx, ep_return, ep_len, ep_start_idx = ( + np.array([x]) for x in self._update_state_pre_add(rew, done) + ) try: - self._meta[ptr] = batch + self._meta[insertion_idx] = batch except ValueError: stack = not stacked_batch batch.rew = batch.rew.astype(float) @@ -288,8 +342,8 @@ def add( self._meta = create_value(batch, self.maxsize, stack) # type: ignore else: # dynamic key pops up in batch alloc_by_keys_diff(self._meta, batch, self.maxsize, stack) - self._meta[ptr] = batch - return ptr, ep_rew, ep_len, ep_idx + self._meta[insertion_idx] = batch + return insertion_idx, ep_return, ep_len, ep_start_idx def sample_indices(self, batch_size: int | None) -> np.ndarray: """Get a random sample of index with size = batch_size. @@ -308,7 +362,9 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray: return np.random.choice(self._size, batch_size) # TODO: is this behavior really desired? if batch_size == 0: # construct current available indices - return np.concatenate([np.arange(self._index, self._size), np.arange(self._index)]) + return np.concatenate( + [np.arange(self._insertion_idx, self._size), np.arange(self._insertion_idx)], + ) return np.array([], int) # TODO: raise error on negative batch_size instead? if batch_size < 0: @@ -318,7 +374,7 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray: # It is also not clear whether this is really necessary - frame stacking usually is handled # by environment wrappers (e.g. FrameStack) and not by the replay buffer. all_indices = prev_indices = np.concatenate( - [np.arange(self._index, self._size), np.arange(self._index)], + [np.arange(self._insertion_idx, self._size), np.arange(self._insertion_idx)], ) for _ in range(self.stack_num - 2): prev_indices = self.prev(prev_indices) @@ -342,6 +398,10 @@ def get( index: int | list[int] | np.ndarray, key: str, default_value: Any = None, + # TODO 1: this is only here because of atari, it should never be needed (can be solved with index) + # and should be removed + # TODO 2: does something entirely different from getitem + # TODO 3: key should not be required stack_num: int | None = None, ) -> Batch | np.ndarray: """Return the stacked result. @@ -350,7 +410,7 @@ def get( stacked result as ``[obs[t-3], obs[t-2], obs[t-1], obs[t]]``. :param index: the index for getting stacked data. - :param str key: the key to get, should be one of the reserved_keys. + :param key: the key to get, should be one of the reserved_keys. :param default_value: if the given key's data is not found and default_value is set, return this default_value. :param stack_num: Default to self.stack_num. @@ -377,16 +437,19 @@ def get( return np.stack(stack, axis=indices.ndim) except IndexError as exception: - if not (isinstance(val, Batch) and len(val.get_keys()) == 0): + if not (isinstance(val, Batch) and len(val.keys()) == 0): raise exception # val != Batch() return Batch() - def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> RolloutBatchProtocol: + def __getitem__(self, index: IndexType) -> RolloutBatchProtocol: """Return a data batch: self[index]. If stack_num is larger than 1, return the stacked obs and obs_next with shape (batch, len, ...). """ + # TODO: this is a seriously problematic hack leading to + # buffer[slice] != buffer[np.arange(slice.start, slice.stop)] + # Fix asap, high priority!!! if isinstance(index, slice): # change slice to np array # buffer[:] will get all available data indices = ( @@ -402,7 +465,9 @@ def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> RolloutBat if self._save_obs_next: obs_next = self.get(indices, "obs_next", Batch()) else: - obs_next = self.get(self.next(indices), "obs", Batch()) + obs_next_indices = self.next(indices) + obs_next = self.get(obs_next_indices, "obs", Batch()) + # TODO: don't do this batch_dict = { "obs": obs, "act": self.act[indices], @@ -415,7 +480,30 @@ def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> RolloutBat # TODO: what's the use of this key? "policy": self.get(indices, "policy", Batch()), } - for key in self._meta.__dict__: - if key not in self._input_keys: - batch_dict[key] = self._meta[key][indices] + # TODO: don't do this, reduce complexity. Why such a big difference between what is returned + # and sub-batches of self._meta? + missing_keys = set(self._meta.get_keys()) - set(self._input_keys) + for key in missing_keys: + batch_dict[key] = self._meta[key][indices] return cast(RolloutBatchProtocol, Batch(batch_dict)) + + def set_array_at_key( + self, + seq: np.ndarray, + key: str, + index: IndexType | None = None, + default_value: float | None = None, + ) -> None: + self._meta.set_array_at_key(seq, key, index, default_value) + + def hasnull(self) -> bool: + return self[:].hasnull() + + def isnull(self) -> RolloutBatchProtocol: + return self[:].isnull() + + def dropnull(self) -> None: + # TODO: may fail, needs more testing with VectorBuffers + self._meta = self._meta.dropnull() + self._size = len(self._meta) + self._insertion_idx = len(self._meta) diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py index 1ae1e8f23..087f8d0b0 100644 --- a/tianshou/data/buffer/her.py +++ b/tianshou/data/buffer/her.py @@ -110,7 +110,7 @@ def rewrite_transitions(self, indices: np.ndarray) -> None: return # Sort indices keeping chronological order - indices[indices < self._index] += self.maxsize + indices[indices < self._insertion_idx] += self.maxsize indices = np.sort(indices) indices[indices >= self.maxsize] -= self.maxsize diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index 38db7d120..e8176aa8c 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -1,8 +1,9 @@ from collections.abc import Sequence -from typing import Union +from typing import Union, cast import numpy as np from numba import njit +from overrides import override from tianshou.data import Batch, HERReplayBuffer, PrioritizedReplayBuffer, ReplayBuffer from tianshou.data.batch import alloc_by_keys_diff, create_value @@ -25,22 +26,37 @@ class ReplayBufferManager(ReplayBuffer): def __init__(self, buffer_list: list[ReplayBuffer] | list[HERReplayBuffer]) -> None: self.buffer_num = len(buffer_list) self.buffers = np.array(buffer_list, dtype=object) + last_index: list[int] = [] offset, size = [], 0 buffer_type = type(self.buffers[0]) kwargs = self.buffers[0].options for buf in self.buffers: + buf = cast(ReplayBuffer, buf) assert len(buf._meta.get_keys()) == 0 assert isinstance(buf, buffer_type) assert buf.options == kwargs offset.append(size) + if len(buf.last_index) != 1: + raise ValueError( + f"{self.__class__.__name__} only supports buffers with a single index " + f"(non-vector buffers), but got {last_index=}. " + f"Did you try to use a {self.__class__.__name__} within a {self.__class__.__name__}?", + ) + last_index.append(size + buf.last_index[0]) size += buf.maxsize + super().__init__(size=size, **kwargs) self._offset = np.array(offset) self._extend_offset = np.array([*offset, size]) self._lengths = np.zeros_like(offset) - super().__init__(size=size, **kwargs) + self.last_index = np.array(last_index) self._compile() self._meta: RolloutBatchProtocol + @property + @override + def subbuffer_edges(self) -> np.ndarray: + return self._extend_offset + def _compile(self) -> None: lens = last = index = np.array([0]) offset = np.array([0, 1]) @@ -52,6 +68,7 @@ def __len__(self) -> int: return int(self._lengths.sum()) def reset(self, keep_statistics: bool = False) -> None: + # keep in sync with init! self.last_index = self._offset.copy() self._lengths = np.zeros_like(self._offset) for buf in self.buffers: @@ -141,21 +158,27 @@ def add( # get index if buffer_ids is None: buffer_ids = np.arange(self.buffer_num) - ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], [] + insertion_indxS, ep_lens, ep_returns, ep_idxs = [], [], [], [] for batch_idx, buffer_id in enumerate(buffer_ids): - ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index( + # TODO: don't access private method! + insertion_index, ep_return, ep_len, ep_start_idx = self.buffers[ + buffer_id + ]._update_state_pre_add( batch.rew[batch_idx], batch.done[batch_idx], ) - ptrs.append(ptr + self._offset[buffer_id]) + offset_insertion_idx = insertion_index + self._offset[buffer_id] + offset_ep_start_idx = ep_start_idx + self._offset[buffer_id] + insertion_indxS.append(offset_insertion_idx) ep_lens.append(ep_len) - ep_rews.append(ep_rew) - ep_idxs.append(ep_idx + self._offset[buffer_id]) - self.last_index[buffer_id] = ptr + self._offset[buffer_id] + ep_returns.append(ep_return) + ep_idxs.append(offset_ep_start_idx) + self.last_index[buffer_id] = insertion_index + self._offset[buffer_id] self._lengths[buffer_id] = len(self.buffers[buffer_id]) - ptrs = np.array(ptrs) + insertion_indxS = np.array(insertion_indxS) try: - self._meta[ptrs] = batch + self._meta[insertion_indxS] = batch + # TODO: don't do this! except ValueError: batch.rew = batch.rew.astype(float) batch.done = batch.done.astype(bool) @@ -166,8 +189,8 @@ def add( else: # dynamic key pops up in batch alloc_by_keys_diff(self._meta, batch, self.maxsize, False) self._set_batch_for_children() - self._meta[ptrs] = batch - return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs) + self._meta[insertion_indxS] = batch + return insertion_indxS, np.array(ep_returns), np.array(ep_lens), np.array(ep_idxs) def sample_indices(self, batch_size: int | None) -> np.ndarray: # TODO: simplify this code @@ -206,6 +229,8 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray: ) +# TODO: unintuitively, the order of inheritance has to stay this way for tests to pass +# As also described in the todo below, this is a bad design and should be refactored class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager): """PrioritizedReplayBufferManager contains a list of PrioritizedReplayBuffer with exactly the same configuration. @@ -221,11 +246,21 @@ class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManage def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: ReplayBufferManager.__init__(self, buffer_list) # type: ignore + # last_index = copy(self.last_index) kwargs = buffer_list[0].options + last_index_from_buffer_manager = self.last_index + for buf in buffer_list: del buf.weight PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs) + # TODO: the line below is needed since we now set the last_index of the manager in init + # (previously it was only set in reset), and it clashes with multiple inheritance + # Initializing the ReplayBufferManager after the PrioritizedReplayBuffer would be a better solution, + # but it currently leads to infinite recursion. This kind of multiple inheritance with overlapping + # interfaces is evil and we should get rid of it + self.last_index = last_index_from_buffer_manager + class HERReplayBufferManager(ReplayBufferManager): """HERReplayBufferManager contains a list of HERReplayBuffer with exactly the same configuration. diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py index 406e39afd..3936bd641 100644 --- a/tianshou/data/buffer/prio.py +++ b/tianshou/data/buffer/prio.py @@ -4,6 +4,7 @@ import torch from tianshou.data import ReplayBuffer, SegmentTree, to_numpy +from tianshou.data.batch import IndexType from tianshou.data.types import PrioBatchProtocol, RolloutBatchProtocol @@ -87,7 +88,7 @@ def update_weight(self, index: np.ndarray, new_weight: np.ndarray | torch.Tensor self._max_prio = max(self._max_prio, weight.max()) self._min_prio = min(self._min_prio, weight.min()) - def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> PrioBatchProtocol: + def __getitem__(self, index: IndexType) -> PrioBatchProtocol: if isinstance(index, slice): # change slice to np array # buffer[:] will get all available data indices = ( @@ -96,7 +97,7 @@ def __getitem__(self, index: slice | int | list[int] | np.ndarray) -> PrioBatchP else self._indices[: len(self)][index] ) else: - indices = index # type: ignore + indices = index batch = super().__getitem__(indices) weight = self.get_weight(indices) # ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154 From d9842e8b24ef84cedf57492fe35c510422ba5528 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 8 Aug 2024 17:34:20 +0200 Subject: [PATCH 13/33] Minor formatting, typing --- test/base/test_utils.py | 21 +++++++++++++++------ tianshou/highlevel/params/lr_scheduler.py | 1 + tianshou/utils/torch_utils.py | 5 +++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 23fe0b337..5d85556fe 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1,15 +1,17 @@ +from typing import cast + import numpy as np import pytest import torch import torch.distributions as dist from gymnasium import spaces from torch import nn -from utils.torch_utils import create_uniform_action_dist, torch_train_mode from tianshou.exploration import GaussianNoise, OUNoise from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic +from tianshou.utils.torch_utils import create_uniform_action_dist, torch_train_mode def test_noise() -> None: @@ -155,7 +157,7 @@ def test_in_train_mode() -> None: class TestCreateActionDistribution: @classmethod - def setup_class(cls): + def setup_class(cls) -> None: # Set random seeds for reproducibility torch.manual_seed(0) np.random.seed(0) @@ -170,7 +172,9 @@ def setup_class(cls): ], ) def test_distribution_properties( - self, action_space: spaces.Box | spaces.Discrete, batch_size: int, + self, + action_space: spaces.Box | spaces.Discrete, + batch_size: int, ) -> None: distribution = create_uniform_action_dist(action_space, batch_size) @@ -201,7 +205,9 @@ def test_distribution_properties( ], ) def test_distribution_uniformity( - self, action_space: spaces.Box | spaces.Discrete, batch_size: int, + self, + action_space: spaces.Box | spaces.Discrete, + batch_size: int, ) -> None: distribution = create_uniform_action_dist(action_space, batch_size) @@ -213,8 +219,9 @@ def test_distribution_uniformity( assert torch.allclose(large_sample.std(), torch.tensor(1 / 3**0.5), atol=0.1) elif isinstance(action_space, spaces.Discrete): # For Discrete, check if all actions are roughly equally likely - counts = torch.bincount(large_sample.flatten(), minlength=action_space.n).float() - expected_count = 10000 * batch_size / action_space.n + n_actions = cast(int, action_space.n) + counts = torch.bincount(large_sample.flatten(), minlength=n_actions).float() + expected_count = 10000 * batch_size / n_actions assert torch.allclose(counts, torch.tensor(expected_count).float(), rtol=0.1) def test_unsupported_space(self) -> None: @@ -251,7 +258,9 @@ def test_batch_sizes( # Check internal distribution shapes if isinstance(space, spaces.Box): + distribution = cast(dist.Uniform, distribution) assert distribution.low.shape == expected_shape assert distribution.high.shape == expected_shape elif isinstance(space, spaces.Discrete): + distribution = cast(dist.Categorical, distribution) assert distribution.probs.shape == (batch_size, space.n) diff --git a/tianshou/highlevel/params/lr_scheduler.py b/tianshou/highlevel/params/lr_scheduler.py index 0b0cf359a..883afe61c 100644 --- a/tianshou/highlevel/params/lr_scheduler.py +++ b/tianshou/highlevel/params/lr_scheduler.py @@ -25,6 +25,7 @@ def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: class _LRLambda: def __init__(self, sampling_config: SamplingConfig): + assert sampling_config.step_per_collect is not None self.max_update_num = ( np.ceil(sampling_config.step_per_epoch / sampling_config.step_per_collect) * sampling_config.num_epochs diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py index 44d5f7668..1ffb9fcd8 100644 --- a/tianshou/utils/torch_utils.py +++ b/tianshou/utils/torch_utils.py @@ -49,7 +49,8 @@ def create_uniform_action_dist(action_space: spaces.Box, batch_size: int = 1) -> @overload def create_uniform_action_dist( - action_space: spaces.Discrete, batch_size: int = 1, + action_space: spaces.Discrete, + batch_size: int = 1, ) -> dist.Categorical: ... @@ -70,7 +71,7 @@ def create_uniform_action_dist( return dist.Uniform(low, high) elif isinstance(action_space, spaces.Discrete): - return dist.Categorical(torch.ones(batch_size, action_space.n)) + return dist.Categorical(torch.ones(batch_size, int(action_space.n))) else: raise ValueError(f"Unsupported action space type: {type(action_space)}") From 3193d617ff12faedb4a239ef3c155d1c37516cc4 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 8 Aug 2024 17:38:32 +0200 Subject: [PATCH 14/33] PPO, Changed implementation detail! Important! Previously the advantages were normalized multiple times if `repeat` is set to more than 1 Also minor improvement and extension of PPOTrainingStats --- tianshou/policy/modelfree/ppo.py | 57 +++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 1933c7d54..a4694b57b 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeVar +from typing import Any, Generic, Literal, Self, TypeVar import gymnasium as gym import numpy as np @@ -23,6 +24,25 @@ class PPOTrainingStats(TrainingStats): clip_loss: SequenceSummaryStats vf_loss: SequenceSummaryStats ent_loss: SequenceSummaryStats + gradient_steps: int = 0 + + @classmethod + def from_sequences( + cls, + *, + losses: Sequence[float], + clip_losses: Sequence[float], + vf_losses: Sequence[float], + ent_losses: Sequence[float], + gradient_steps: int = 0, + ) -> Self: + return cls( + loss=SequenceSummaryStats.from_sequence(losses), + clip_loss=SequenceSummaryStats.from_sequence(clip_losses), + vf_loss=SequenceSummaryStats.from_sequence(vf_losses), + ent_loss=SequenceSummaryStats.from_sequence(ent_losses), + gradient_steps=gradient_steps, + ) TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats) @@ -155,24 +175,27 @@ def learn( # type: ignore **kwargs: Any, ) -> TPPOTrainingStats: losses, clip_losses, vf_losses, ent_losses = [], [], [], [] + gradient_steps = 0 split_batch_size = batch_size or -1 for step in range(repeat): if self.recompute_adv and step > 0: batch = self._compute_returns(batch, self._buffer, self._indices) for minibatch in batch.split(split_batch_size, merge_last=True): + gradient_steps += 1 # calculate loss for actor + advantages = minibatch.adv dist = self(minibatch).dist if self.norm_adv: - mean, std = minibatch.adv.mean(), minibatch.adv.std() - minibatch.adv = (minibatch.adv - mean) / (std + self._eps) # per-batch norm - ratio = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() - ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) - surr1 = ratio * minibatch.adv - surr2 = ratio.clamp(1.0 - self.eps_clip, 1.0 + self.eps_clip) * minibatch.adv + mean, std = advantages.mean(), advantages.std() + advantages = (advantages - mean) / (std + self._eps) # per-batch norm + ratios = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float() + ratios = ratios.reshape(ratios.size(0), -1).transpose(0, 1) + surr1 = ratios * advantages + surr2 = ratios.clamp(1.0 - self.eps_clip, 1.0 + self.eps_clip) * advantages if self.dual_clip: clip1 = torch.min(surr1, surr2) - clip2 = torch.max(clip1, self.dual_clip * minibatch.adv) - clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean() + clip2 = torch.max(clip1, self.dual_clip * advantages) + clip_loss = -torch.where(advantages < 0, clip2, clip1).mean() else: clip_loss = -torch.min(surr1, surr2).mean() # calculate loss for critic @@ -203,14 +226,10 @@ def learn( # type: ignore ent_losses.append(ent_loss.item()) losses.append(loss.item()) - losses_summary = SequenceSummaryStats.from_sequence(losses) - clip_losses_summary = SequenceSummaryStats.from_sequence(clip_losses) - vf_losses_summary = SequenceSummaryStats.from_sequence(vf_losses) - ent_losses_summary = SequenceSummaryStats.from_sequence(ent_losses) - - return PPOTrainingStats( # type: ignore[return-value] - loss=losses_summary, - clip_loss=clip_losses_summary, - vf_loss=vf_losses_summary, - ent_loss=ent_losses_summary, + return PPOTrainingStats.from_sequences( # type: ignore[return-value] + losses=losses, + clip_losses=clip_losses, + vf_losses=vf_losses, + ent_losses=ent_losses, + gradient_steps=gradient_steps, ) From dfac7ad3d1a36d0d5ffdcc9838859a7a32720701 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 8 Aug 2024 17:54:06 +0200 Subject: [PATCH 15/33] SAC: minor refactoring (extracted correct_log_prob_gaussian_tanh) --- tianshou/policy/modelfree/sac.py | 33 ++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index a5a05c0fd..2348890d3 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -21,6 +21,25 @@ from tianshou.utils.optim import clone_optimizer +def correct_log_prob_gaussian_tanh( + log_prob: torch.Tensor, + squashed_action: torch.Tensor, + eps: float = np.finfo(np.float32).eps.item(), +) -> torch.Tensor: + """Apply correction for Tanh squashing when computing logprob from Gaussian. + + You can check out the original SAC paper (arXiv 1801.01290): Eq 21. + in appendix C to get some understanding of this equation. + :param log_prob: log probability of the action + :param squashed_action: tanh-squashed action + :param eps: epsilon for numerical stability + """ + return log_prob - torch.log((1 - squashed_action.pow(2)) + eps).sum( + -1, + keepdim=True, + ) + + @dataclass(kw_only=True) class SACTrainingStats(TrainingStats): actor_loss: float @@ -63,6 +82,8 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t :param action_bound_method: method to bound action to range [-1, 1], can be either "clip" (for simply clipping the action) or empty string for no bounding. Only used if the action_space is continuous. + This parameter is ignored in SAC, which used tanh squashing after sampling + unbounded from the gaussian policy (as in (arXiv 1801.01290): Eq 21.). :param observation_space: Env's observation space. :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update() @@ -90,8 +111,6 @@ def __init__( exploration_noise: BaseNoise | Literal["default"] | None = None, deterministic_eval: bool = True, action_scaling: bool = True, - # TODO: some papers claim that tanh is crucial for SAC, yet DDPG will raise an - # error if tanh is used. Should be investigated. action_bound_method: Literal["clip"] | None = "clip", observation_space: gym.Space | None = None, lr_scheduler: TLearningRateScheduler | None = None, @@ -117,7 +136,6 @@ def __init__( self.critic2_old.eval() self.critic2_optim = critic2_optim self.deterministic_eval = deterministic_eval - self.__eps = np.finfo(np.float32).eps.item() self.alpha: float | torch.Tensor self._is_auto_alpha = not isinstance(alpha, float) @@ -180,14 +198,9 @@ def forward( # type: ignore else: act_B = dist.rsample() log_prob = dist.log_prob(act_B).unsqueeze(-1) - # apply correction for Tanh squashing when computing logprob from Gaussian - # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. - # in appendix C to get some understanding of this equation. + squashed_action = torch.tanh(act_B) - log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( - -1, - keepdim=True, - ) + log_prob = correct_log_prob_gaussian_tanh(log_prob, squashed_action) result = Batch( logits=(loc_B, scale_B), act=squashed_action, From e9d1b6852c87ae2632c1ff20c11d9f07bee28a67 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 8 Aug 2024 17:58:42 +0200 Subject: [PATCH 16/33] Trainer: multiple small enhancements 1. improved logging 2. extended resetting possibilities 3. collect stats for n_episodes 4. raise error on NaNs in buffer Added some comments and TODOs --- tianshou/trainer/base.py | 77 ++++++++++++++++++++++++++++++++-------- 1 file changed, 63 insertions(+), 14 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 242f2b028..76805c0b2 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -4,9 +4,12 @@ from collections import defaultdict, deque from collections.abc import Callable from dataclasses import asdict +from functools import partial import numpy as np import tqdm +from data.buffer.base import MalformedBufferError +from utils.torch_utils import policy_within_training_step from tianshou.data import ( AsyncCollector, @@ -22,13 +25,10 @@ from tianshou.trainer.utils import gather_info, test_episode from tianshou.utils import ( BaseLogger, - DummyTqdm, LazyLogger, MovAvg, - tqdm_config, ) from tianshou.utils.logging import set_numerical_fields_to_precision -from tianshou.utils.torch_utils import policy_within_training_step log = logging.getLogger(__name__) @@ -168,11 +168,12 @@ def __init__( save_checkpoint_fn: Callable[[int, int, int], str] | None = None, resume_from_log: bool = False, reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, - logger: BaseLogger = LazyLogger(), + logger: BaseLogger | None = None, verbose: bool = True, show_progress: bool = True, test_in_train: bool = True, ): + logger = logger or LazyLogger() self.policy = policy if buffer is not None: @@ -192,6 +193,7 @@ def __init__( # of the trainers. I believe it would be better to remove self._gradient_step = 0 self.env_step = 0 + self.env_episode = 0 self.policy_update_time = 0.0 self.max_epoch = max_epoch self.step_per_epoch = step_per_epoch @@ -227,6 +229,16 @@ def __init__( self.stop_fn_flag = False self.iter_num = 0 + @property + def _pbar(self) -> type[tqdm.tqdm]: + """Use as context manager or iterator, i.e., `with self._pbar(...) as t:` or `for _ in self._pbar(...):`.""" + return partial( + tqdm.tqdm, + dynamic_ncols=True, + ascii=True, + disable=not self.show_progress, + ) # type: ignore[return-value] + def _reset_collectors(self, reset_buffer: bool = False) -> None: if self.train_collector is not None: self.train_collector.reset(reset_buffer=reset_buffer) @@ -298,10 +310,8 @@ def __next__(self) -> EpochStats: if self.stop_fn_flag: raise StopIteration - progress = tqdm.tqdm if self.show_progress else DummyTqdm - # perform n step_per_epoch - with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t: + with self._pbar(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", position=1) as t: train_stat: CollectStatsBase while t.n < t.total and not self.stop_fn_flag: train_stat, update_stat, self.stop_fn_flag = self.training_step() @@ -309,14 +319,23 @@ def __next__(self) -> EpochStats: if isinstance(train_stat, CollectStats): pbar_data_dict = { "env_step": str(self.env_step), + "env_episode": str(self.env_episode), "rew": f"{self.last_rew:.2f}", "len": str(int(self.last_len)), "n/ep": str(train_stat.n_collected_episodes), "n/st": str(train_stat.n_collected_steps), } t.update(train_stat.n_collected_steps) + if self.stop_fn_flag: + t.set_postfix(**pbar_data_dict) else: + # TODO: there is no iteration happening here, it's the offline case + # Code should be restructured! pbar_data_dict = {} + assert self.buffer, "No train_collector or buffer specified" + train_stat = CollectStatsBase( + n_collected_steps=len(self.buffer), + ) t.update() pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict) @@ -449,7 +468,20 @@ def _collect_training_data(self) -> CollectStats: n_episode=self.episode_per_collect, ) + if self.train_collector.buffer.hasnull(): + from tianshou.data.collector import EpisodeRolloutHook + from tianshou.env import DummyVectorEnv + + raise MalformedBufferError( + f"Encountered NaNs in buffer after {self.env_step} steps." + f"Such errors are usually caused by either a bug in the environment or by " + f"problematic implementations {EpisodeRolloutHook.__class__.__name__}. " + f"For debugging such issues it is recommended to run the training in a single process, " + f"e.g., by using {DummyVectorEnv.__class__.__name__}.", + ) + self.env_step += collect_stats.n_collected_steps + self.env_episode += collect_stats.n_collected_episodes if collect_stats.n_collected_episodes > 0: assert collect_stats.returns_stat is not None # for mypy @@ -462,7 +494,6 @@ def _collect_training_data(self) -> CollectStats: collect_stats.returns_stat = SequenceSummaryStats.from_sequence(rew) self.logger.log_train_data(asdict(collect_stats), self.env_step) - return collect_stats # TODO (maybe): separate out side effect, simplify name? @@ -547,14 +578,18 @@ def policy_update_fn( stats of the whole dataset """ - def run(self, reset_prior_to_run: bool = True) -> InfoStats: + def run(self, reset_prior_to_run: bool = True, reset_buffer: bool = False) -> InfoStats: """Consume iterator. See itertools - recipes. Use functions that consume iterators at C speed (feed the entire iterator into a zero-length deque). + :param reset_prior_to_run: whether to reset collectors prior to run + :param reset_buffer: only has effect if `reset_prior_to_run` is True. + Then it will also reset the buffer. This is usually not necessary, use + with caution. """ if reset_prior_to_run: - self.reset() + self.reset(reset_buffer=reset_buffer) try: self.is_run = True deque(self, maxlen=0) # feed the entire iterator into a zero-length deque @@ -635,10 +670,14 @@ def policy_update_fn( f"n_gradient_steps is 0, n_collected_steps={n_collected_steps}, " f"update_per_step={self.update_per_step}", ) - for _ in range(n_gradient_steps): - update_stat = self._sample_and_update(self.train_collector.buffer) - # logging + for _ in self._pbar( + range(n_gradient_steps), + desc="Offpolicy gradient update", + position=0, + leave=False, + ): + update_stat = self._sample_and_update(self.train_collector.buffer) self.policy_update_time += update_stat.train_time # TODO: only the last update_stat is returned, should be improved return update_stat @@ -661,6 +700,11 @@ def policy_update_fn( ) -> TrainingStats: """Perform one on-policy update by passing the entire buffer to the policy's update method.""" assert self.train_collector is not None + # TODO: add logging like in off-policy. Iteration over minibatches currently happens in the learn implementation of + # on-policy algos like PG or PPO + log.info( + f"Performing on-policy update on buffer of length {len(self.train_collector.buffer)}", + ) training_stat = self.policy.update( sample_size=0, buffer=self.train_collector.buffer, @@ -682,10 +726,15 @@ def policy_update_fn( elif self.batch_size > 0: self._gradient_step += int((len(self.train_collector.buffer) - 0.1) // self.batch_size) - # Note: this is the main difference to the off-policy trainer! + # Note 1: this is the main difference to the off-policy trainer! # The second difference is that batches of data are sampled without replacement # during training, whereas in off-policy or offline training, the batches are # sampled with replacement (and potentially custom prioritization). + # Note 2: in the policy-update we modify the buffer, which is not very clean. + # currently the modification will erase previous samples but keep things like + # _ep_rew and _ep_len. This means that such quantities can no longer be computed + # from samples still contained in the buffer, which is also not clean + # TODO: improve this situation self.train_collector.reset_buffer(keep_statistics=True) # The step is the number of mini-batches used for the update, so essentially From da3a2998d388165d3d054a65af46263989a62c66 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 8 Aug 2024 18:21:49 +0200 Subject: [PATCH 17/33] Collector: major extension and some refactoring 1. Added support for Step and Episode hooks. The latter are particularly tricky. 2. Refactoring: collect results of action computation in batch instead of in tuple 3. Enhanced input validation 4. Better variable names Also a bunch of comments and todos --- test/base/test_collector.py | 14 +- test/base/test_policy.py | 8 + tianshou/data/collector.py | 440 +++++++++++++++++++++++++++++++----- tianshou/utils/array.py | 15 ++ 4 files changed, 413 insertions(+), 64 deletions(-) create mode 100644 tianshou/utils/array.py diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 95b604905..d03a54df7 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Sequence from test.base.env import MoveToRightEnv, NXEnv -from typing import Any, cast +from typing import Any import gymnasium as gym import numpy as np @@ -17,11 +17,7 @@ VectorReplayBuffer, ) from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ( - ActStateBatchProtocol, - ObsBatchProtocol, - RolloutBatchProtocol, -) +from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import BasePolicy, TrainingStats @@ -58,7 +54,7 @@ def forward( batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, - ) -> ActStateBatchProtocol: + ) -> Batch: if self.need_state: if state is None: state = np.zeros((len(batch.obs), 2)) @@ -73,9 +69,9 @@ def forward( action_shape = len(batch.obs["index"]) else: action_shape = len(batch.obs) - return cast(ActStateBatchProtocol, Batch(act=np.ones(action_shape), state=state)) + return Batch(act=np.ones(action_shape), state=state) action_shape = self.action_shape if self.action_shape else len(batch.obs) - return cast(ActStateBatchProtocol, Batch(act=np.ones(action_shape), state=state)) + return Batch(act=np.ones(action_shape), state=state) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: raise NotImplementedError diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 4d26905c3..b222fd3f8 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -5,6 +5,7 @@ from torch.distributions import Categorical, Distribution, Independent, Normal from tianshou.policy import BasePolicy, PPOPolicy +from tianshou.policy.base import episode_mc_return_to_go from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic from tianshou.utils.net.discrete import Actor @@ -16,6 +17,13 @@ def _to_hashable(x: np.ndarray | int) -> int | tuple[list]: return x if isinstance(x, int) else tuple(x.tolist()) +def test_calculate_discounted_returns() -> None: + assert np.all( + episode_mc_return_to_go([1, 1, 1], 0.9) == np.array([0.9**2 + 0.9 + 1, 0.9 + 1, 1]), + ) + assert episode_mc_return_to_go([1, 2, 3], 0.5)[0] == 1 + 0.5 * (2 + 0.5 * 3) + + @pytest.fixture(params=["continuous", "discrete"]) def policy(request: pytest.FixtureRequest) -> PPOPolicy: action_type = request.param diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 6773a6383..a813d8bbc 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -2,14 +2,16 @@ import time import warnings from abc import ABC, abstractmethod +from collections.abc import Sequence from copy import copy from dataclasses import dataclass -from typing import Any, Self, TypeVar, cast +from typing import Any, Optional, Protocol, Self, TypeVar, cast import gymnasium as gym import numpy as np import torch from overrides import override +from torch.distributions import Distribution from tianshou.data import ( Batch, @@ -20,17 +22,40 @@ VectorReplayBuffer, to_numpy, ) +from tianshou.data.buffer.base import MalformedBufferError from tianshou.data.types import ( + ActBatchProtocol, + DistBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol, ) from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy +from tianshou.policy.base import episode_mc_return_to_go +from tianshou.utils.array import bisect_left, bisect_right from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import torch_train_mode log = logging.getLogger(__name__) +_TArrLike = TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None") + + +def _get_start_stop_tuples_around_edges( + edges: Sequence[int], + start: int, + stop: int, +) -> tuple[tuple[int, int], tuple[int, int]]: + """We assume that stop is smaller than start and that `edges` is a sorted array of integers. + + Then it will return the two tuples containing (start, stop) where we go from start to the next edge, + and from the previous edge to stop. + :return: (start, upper_edge), (lower_edge, stop) + """ + upper_edge = int(bisect_right(edges, start)) + lower_edge = int(bisect_left(edges, stop)) + return (start, upper_edge), (lower_edge, stop) + @dataclass(kw_only=True) class CollectStatsBase(DataclassPPrintMixin): @@ -58,6 +83,10 @@ class CollectStats(CollectStatsBase): """The collected episode lengths.""" lens_stat: SequenceSummaryStats | None # can be None if no episode ends during the collect step """Stats of the collected episode lengths.""" + std_array: np.ndarray | None = None + """The standard deviations of the predicted distributions.""" + std_array_stat: SequenceSummaryStats | None = None + """Stats of the standard deviations of the predicted distributions.""" @classmethod def with_autogenerated_stats( @@ -84,7 +113,18 @@ def with_autogenerated_stats( ) -_TArrLike = TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None") +class CollectActionBatchProtocol(Protocol): + """A protocol for results of computing actions within a single collect step. + + All fields all have length R (the dist is a Distribution of batch size R), + where R is the number of ready envs. + """ + + act: np.ndarray | torch.Tensor + act_normalized: np.ndarray | torch.Tensor + policy_entry: Batch + dist: Distribution | None + hidden_state: np.ndarray | torch.Tensor | Batch | None def _nullable_slice(obj: _TArrLike, indices: np.ndarray) -> _TArrLike: @@ -156,7 +196,7 @@ def __init__( if buffer is None: buffer = VectorReplayBuffer(len(env), len(env)) - self.buffer: ReplayBuffer = buffer + self.buffer: ReplayBuffer | ReplayBufferManager = buffer self.policy = policy self.env = cast(BaseVectorEnv, env) self.exploration_noise = exploration_noise @@ -167,6 +207,27 @@ def __init__( self._validate_buffer() + @property + def _subbuffer_edges(self) -> np.ndarray: + return self.buffer.subbuffer_edges + + def _get_start_stop_tuples_for_edge_crossing_interval( + self, + start: int, + stop: int, + ) -> tuple[tuple[int, int], tuple[int, int]]: + """:return: (start, upper_edge), (lower_edge, stop)""" + log.debug( + "Received an edge-crossing episode: {start=}, {stop=}, {self._subbuffer_edges=}", + ) + if stop >= start: + raise ValueError( + f"Expected stop < start, but got {start=}, {stop=}. " + f"For stop larger than start this should never be used, and stop=start should never occur.", + ) + subbuffer_edges = cast(Sequence[int], self._subbuffer_edges) + return _get_start_stop_tuples_around_edges(subbuffer_edges, start, stop) + def _validate_buffer(self) -> None: buf = self.buffer # TODO: a bit weird but true - all VectorReplayBuffers inherit from ReplayBufferManager. @@ -237,10 +298,7 @@ def reset_env( self, gym_reset_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: - """Reset the environments and the initial obs, info, and hidden state of the collector. - - :return: The initial observation and info from the (vectorized) environment. - """ + """Reset the environments and the initial obs, info, and hidden state of the collector.""" gym_reset_kwargs = gym_reset_kwargs or {} obs_NO, info_N = self.env.reset(**gym_reset_kwargs) # TODO: hack, wrap envpool envs such that they don't return a dict @@ -262,7 +320,6 @@ def _collect( ) -> CollectStats: pass - @torch.no_grad() def collect( self, n_step: int | None = None, @@ -300,7 +357,7 @@ def collect( if reset_before_collect: self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) - with torch_train_mode(self.policy, False): + with torch_train_mode(self.policy, enabled=False): return self._collect( n_step=n_step, n_episode=n_episode, @@ -352,6 +409,9 @@ def __init__( env: gym.Env | BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, + on_episode_done_hook: Optional["EpisodeRolloutHookProtocol"] = None, + on_step_hook: Optional["StepHookProtocol"] = None, + raise_on_nan_in_buffer: bool = True, ) -> None: """:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. :param env: a ``gym.Env`` environment or an instance of the @@ -363,6 +423,26 @@ def __init__( with the corresponding policy's exploration noise. If so, "policy. exploration_noise(act, batch)" will be called automatically to add the exploration noise into action. Default to False. + :param on_episode_done_hook: if passed, will be executed when an episode is done. + The input to the hook will be a `RolloutBatch` that contains the entire episode (and nothing else). + The dict returned by the hook will be used to add new entries to the buffer + for the episode that just ended. The hook should return arrays with floats + which should be of the same length as the input rollout batch. + If you have multiple hooks, you can use the `CombinedRolloutHook` class to combine them. + A typical example of a hook is the `EpisodeRolloutHookMCReturn` which adds the Monte Carlo return + as a field to the buffer. + + Care must be taken when using such hook, as for unfinished episodes one can easily end + up with NaNs in the buffer. It is recommended to use the hooks only with the `n_episode` option + in `collect`, or to strip the buffer of NaNs after the collection. + :param on_step_hook: if passed, will be executed after each step of the collection but before the + rollout batch is resulting added to the buffer. The inputs to the hook will be + the action distributions computed from the previous observations (following the + :class:`CollectActionBatchProtocol`) using the policy, and the resulting + rollout batch (following the :class:`RolloutBatchProtocol`). + :param raise_on_nan_in_buffer: whether to raise a Runtime if NaNs are found in the buffer after + a collection step. Especially useful when using episode-level hooks. Consider setting to False if + the NaN-check becomes a bottleneck. """ super().__init__(policy, env, buffer, exploration_noise=exploration_noise) self._pre_collect_obs_RO: np.ndarray | None = None @@ -370,19 +450,48 @@ def __init__( self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None self._is_closed = False + self.on_episode_done_hook = on_episode_done_hook + self.on_step_hook = on_step_hook self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 - @override def close(self) -> None: super().close() self._pre_collect_obs_RO = None self._pre_collect_info_R = None - @override + def run_on_episode_done( + self, + episode_batch: RolloutBatchProtocol, + ) -> dict[str, np.ndarray] | None: + """Executes the `on_episode_done_hook` that was passed on init. + + One of the main uses of this public method is to allow users to override it in custom + subclasses of the `Collector`. This way, they can override the init to no longer accept + the `on_episode_done` provider. + """ + if self.on_episode_done_hook is not None: + return self.on_episode_done_hook(episode_batch) + return None + + def run_on_step_hook( + self, + action_batch: CollectActionBatchProtocol, + rollout_batch: RolloutBatchProtocol, + ) -> None: + """Executes the instance's `on_step_hook`. + + One of the main uses of this public method is to allow users to override it in custom + subclasses of the `Collector`. This way, they can override the init to no longer accept + the `on_step_hook` provider. + """ + if self.on_step_hook is not None: + self.on_step_hook(action_batch, rollout_batch) + def reset_env( self, gym_reset_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: + """Reset the environments and the initial obs, info, and hidden state of the collector.""" obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs) # We assume that R = N when reset is called. # TODO: there is currently no mechanism that ensures this and it's a public method! @@ -398,7 +507,7 @@ def _compute_action_policy_hidden( last_obs_RO: np.ndarray, last_info_R: np.ndarray, last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None, - ) -> tuple[np.ndarray, np.ndarray, Batch, np.ndarray | torch.Tensor | Batch | None]: + ) -> CollectActionBatchProtocol: """Returns the action, the normalized action, a "policy" entry, and the hidden state.""" if random: try: @@ -411,12 +520,14 @@ def _compute_action_policy_hidden( act_RA = self.policy.map_action_inverse(np.array(act_normalized_RA)) policy_R = Batch() hidden_state_RH = None + # TODO: instead use a (uniform) Distribution instance that corresponds to sampling from action_space + action_dist_R = None else: info_batch = _HACKY_create_info_batch(last_info_R) obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch)) - act_batch_RA = self.policy( + act_batch_RA: ActBatchProtocol | DistBatchProtocol = self.policy( obs_batch_R, last_hidden_state_RH, ) @@ -440,10 +551,22 @@ def _compute_action_policy_hidden( policy_R.hidden_state = ( hidden_state_RH # save state into buffer through policy attr ) - return act_RA, act_normalized_RA, policy_R, hidden_state_RH + # can't use act_batch_RA.dist directly as act_batch_RA might not have that attribute + action_dist_R = act_batch_RA.get("dist") + + return cast( + CollectActionBatchProtocol, + Batch( + act=act_RA, + act_normalized=act_normalized_RA, + policy_entry=policy_R, + dist=action_dist_R, + hidden_state=hidden_state_RH, + ), + ) # TODO: reduce complexity, remove the noqa - def _collect( + def _collect( # noqa: C901 self, n_step: int | None = None, n_episode: int | None = None, @@ -461,9 +584,15 @@ def _collect( if n_step is not None: ready_env_ids_R = np.arange(self.env_num) elif n_episode is not None: + if self.env_num > n_episode: + log.warning( + f"Number of episodes ({n_episode}) is smaller than the number of environments " + f"({self.env_num}). This means that {self.env_num - n_episode} " + f"environments (or, equivalently, parallel workers) will not be used!", + ) ready_env_ids_R = np.arange(min(self.env_num, n_episode)) else: - raise ValueError("Either n_step or n_episode should be set.") + raise RuntimeError("Input validation failed, this is a bug and shouldn't have happened") start_time = time.time() if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: @@ -498,13 +627,8 @@ def _collect( # ) # restore the state: if the last state is None, it won't store - # get the next action - ( - act_RA, - act_normalized_RA, - policy_R, - hidden_state_RH, - ) = self._compute_action_policy_hidden( + # get the next action and related stats from the previous observation + collect_action_computation_batch_R = self._compute_action_policy_hidden( random=random, ready_env_ids_R=ready_env_ids_R, last_obs_RO=last_obs_RO, @@ -513,7 +637,7 @@ def _collect( ) obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( - act_normalized_RA, + collect_action_computation_batch_R.act_normalized, ready_env_ids_R, ) if isinstance(info_R, dict): # type: ignore[unreachable] @@ -521,12 +645,12 @@ def _collect( info_R = _dict_of_arr_to_arr_of_dicts(info_R) # type: ignore[unreachable] done_R = np.logical_or(terminated_R, truncated_R) - current_iteration_batch = cast( + current_iteration_batch_R = cast( RolloutBatchProtocol, Batch( obs=last_obs_RO, - act=act_RA, - policy=policy_R, + act=collect_action_computation_batch_R.act, + policy=collect_action_computation_batch_R.policy_entry, obs_next=obs_next_RO, rew=rew_R, terminated=terminated_R, @@ -543,9 +667,13 @@ def _collect( if not np.isclose(render, 0): time.sleep(render) + self.run_on_step_hook( + collect_action_computation_batch_R, + current_iteration_batch_R, + ) # add data into the buffer - ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add( - current_iteration_batch, + insertion_idx_R, ep_return_R, ep_len_R, ep_start_idx_R = self.buffer.add( + current_iteration_batch_R, buffer_ids=ready_env_ids_R, ) @@ -559,35 +687,59 @@ def _collect( # so we copy to not affect the data in the buffer last_obs_RO = copy(obs_next_RO) last_info_R = copy(info_R) - last_hidden_state_RH = copy(hidden_state_RH) + last_hidden_state_RH = copy(collect_action_computation_batch_R.hidden_state) # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration # Resetting envs that reached done, or removing some of them from the collection if needed (see below) if num_episodes_done_this_iter > 0: # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays # D - number of envs that reached done in the rollout above - env_ind_local_D = np.where(done_R)[0] - env_ind_global_D = ready_env_ids_R[env_ind_local_D] - episode_lens.extend(ep_len_R[env_ind_local_D]) - episode_returns.extend(ep_rew_R[env_ind_local_D]) - episode_start_indices.extend(ep_idx_R[env_ind_local_D]) + env_done_local_idx_D = np.where(done_R)[0] + episode_lens.extend(ep_len_R[env_done_local_idx_D]) + episode_returns.extend(ep_return_R[env_done_local_idx_D]) + episode_start_indices.extend(ep_start_idx_R[env_done_local_idx_D]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. gym_reset_kwargs = gym_reset_kwargs or {} + + # The index env_done_idx_D was based on 0, ..., R + # However, each env has an index in the context of the vectorized env and buffer. So the env 0 being done means + # that some env of the corresponding "global" index was done. The mapping between "local" index in + # 0,...,R and this global index is maintained by the ready_env_ids_R array + env_done_global_idx_D = ready_env_ids_R[env_done_local_idx_D] obs_reset_DO, info_reset_D = self.env.reset( - env_id=env_ind_global_D, + env_id=env_done_global_idx_D, **gym_reset_kwargs, ) # Set the hidden state to zero or None for the envs that reached done # TODO: does it have to be so complicated? We should have a single clear type for hidden_state instead of # this complex logic - self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) + self._reset_hidden_state_based_on_type(env_done_local_idx_D, last_hidden_state_RH) + + # execute episode hooks for those envs which emitted 'done' + for local_done_idx in env_done_local_idx_D: + cur_ep_index_slice = slice( + ep_start_idx_R[local_done_idx], + insertion_idx_R[local_done_idx] + 1, + ) + + cur_ep_index_array, ep_rollout_batch = self._get_buffer_index_and_entries( + cur_ep_index_slice, + ) + episode_hook_additions = self.run_on_episode_done(ep_rollout_batch) + if episode_hook_additions is not None: + for key, array in episode_hook_additions.items(): + self.buffer.set_array_at_key( + array, + key, + index=cur_ep_index_array, + ) # preparing for the next iteration - last_obs_RO[env_ind_local_D] = obs_reset_DO - last_info_R[env_ind_local_D] = info_reset_D + last_obs_RO[env_done_local_idx_D] = obs_reset_DO + last_info_R[env_done_local_idx_D] = info_reset_D # Handling the case when we have more ready envs than desired and are not done yet # @@ -613,14 +765,14 @@ def _collect( # step and we still need to collect the remaining episodes to reach the breaking condition. # creating the mask - env_to_be_ignored_ind_local_S = env_ind_local_D[:surplus_env_num] + env_to_be_ignored_ind_local_S = env_done_local_idx_D[:surplus_env_num] env_should_remain_R = np.ones_like(ready_env_ids_R, dtype=bool) env_should_remain_R[env_to_be_ignored_ind_local_S] = False # stripping the "idle" indices, shortening the relevant quantities from R to R-S ready_env_ids_R = ready_env_ids_R[env_should_remain_R] last_obs_RO = last_obs_RO[env_should_remain_R] last_info_R = last_info_R[env_should_remain_R] - if hidden_state_RH is not None: + if collect_action_computation_batch_R.hidden_state is not None: last_hidden_state_RH = last_hidden_state_RH[env_should_remain_R] # type: ignore[index] if (n_step and step_count >= n_step) or ( @@ -643,6 +795,15 @@ def _collect( # reset envs and the _pre_collect fields self.reset_env(gym_reset_kwargs) # todo still necessary? + if self.buffer.hasnull(): + nan_batch = self.buffer.isnull().apply_array_func(np.sum) + + raise MalformedBufferError( + "NaN detected in the buffer. You can drop them with `buffer.dropnull()`. " + "Here an overview of the number of NaNs per field: \n" + f"{nan_batch}", + ) + return CollectStats.with_autogenerated_stats( returns=np.array(episode_returns), lens=np.array(episode_lens), @@ -652,6 +813,38 @@ def _collect( collect_speed=step_count / collect_time, ) + # TODO: move to buffer + def _get_buffer_index_and_entries( + self, + entries_slice: slice, + ) -> tuple[np.ndarray, RolloutBatchProtocol]: + """ + :param entries_slice: a slice object that selects the entries from the buffer. + `stop` can be smaller than `start`, meaning that a sub-buffer edge is to be crossed + :return: The indices of the entries in the buffer and the corresponding batch of entries. + """ + start, stop = entries_slice.start, entries_slice.stop + if stop > start: + cur_ep_index_array = np.arange( + entries_slice.start, + entries_slice.stop, + dtype=int, + ) + else: + (start, upper_edge), ( + lower_edge, + stop, + ) = self._get_start_stop_tuples_for_edge_crossing_interval( + start, + stop, + ) + cur_ep_index_array = np.concatenate( + (np.arange(start, upper_edge, dtype=int), np.arange(lower_edge, stop, dtype=int)), + ) + log.debug(f"{start=}, {upper_edge=}, {lower_edge=}, {stop=}") + ep_rollout_batch = self.buffer[cur_ep_index_array] + return cur_ep_index_array, ep_rollout_batch + @staticmethod def _reset_hidden_state_based_on_type( env_ind_local_D: np.ndarray, @@ -707,13 +900,21 @@ def __init__( self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num) self._current_policy_in_all_envs_E: Batch | None = None - @override def reset( self, reset_buffer: bool = True, reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, ) -> tuple[np.ndarray, np.ndarray]: + """Reset the environment, statistics, and data needed to start the collection. + + :param reset_buffer: if true, reset the replay buffer attached + to the collector. + :param reset_stats: if true, reset the statistics attached to the collector. + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Defaults to None (extra keyword arguments) + :return: The initial observation and info from the environment. + """ # This sets the _pre_collect attrs result = super().reset( reset_buffer=reset_buffer, @@ -798,12 +999,7 @@ def _collect( ) # get the next action - ( - act_RA, - act_normalized_RA, - policy_R, - hidden_state_RH, - ) = self._compute_action_policy_hidden( + collect_batch_R = self._compute_action_policy_hidden( random=random, ready_env_ids_R=ready_env_ids_R, last_obs_RO=last_obs_RO, @@ -812,12 +1008,12 @@ def _collect( ) # save act_RA/policy_R/ hidden_state_RH before env.step - self._current_action_in_all_envs_EA[ready_env_ids_R] = act_RA + self._current_action_in_all_envs_EA[ready_env_ids_R] = collect_batch_R.act if self._current_policy_in_all_envs_E: - self._current_policy_in_all_envs_E[ready_env_ids_R] = policy_R + self._current_policy_in_all_envs_E[ready_env_ids_R] = collect_batch_R.policy_entry else: - self._current_policy_in_all_envs_E = policy_R # first iteration - if hidden_state_RH is not None: + self._current_policy_in_all_envs_E = collect_batch_R.policy_entry # first iteration + if collect_batch_R.hidden_state is not None: if self._current_hidden_state_in_all_envs_EH is not None: # Need to cast since if it's a Tensor, the assignment might in fact fail if hidden_state_RH is not # a tensor as well. This is hard to express with proper typing, even using @overload, so we cheat @@ -826,13 +1022,15 @@ def _collect( np.ndarray | Batch, self._current_hidden_state_in_all_envs_EH, ) - self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] = hidden_state_RH + self._current_hidden_state_in_all_envs_EH[ + ready_env_ids_R + ] = collect_batch_R.hidden_state else: - self._current_hidden_state_in_all_envs_EH = hidden_state_RH + self._current_hidden_state_in_all_envs_EH = collect_batch_R.hidden_state # step in env obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( - act_normalized_RA, + collect_batch_R.act_normalized, ready_env_ids_R, ) done_R = np.logical_or(terminated_R, truncated_R) @@ -941,3 +1139,135 @@ def _collect( collect_time=collect_time, collect_speed=step_count / collect_time, ) + + +class StepHookProtocol(Protocol): + """A protocol for step hooks.""" + + def __call__( + self, + action_batch: CollectActionBatchProtocol, + rollout_batch: RolloutBatchProtocol, + ) -> None: + """The function to call when the hook is executed.""" + ... + + +class StepHook(StepHookProtocol, ABC): + """Marker interface for step hooks. + + All step hooks in Tianshou will inherit from it, but only the corresponding protocol will be + used in type hints. This makes it possible to discover all hooks in the codebase by looking up + the hierarchy of this class (or the protocol itself) while still allowing the user to pass + something like a lambda function as a hook. + """ + + @abstractmethod + def __call__( + self, + action_batch: CollectActionBatchProtocol, + rollout_batch: RolloutBatchProtocol, + ) -> None: + ... + + +class StepHookAddActionDistribution(StepHook): + """Adds the action distribution to the collected rollout batch under the field "action_dist". + + The field is also accessible as class variable `ACTION_DIST_KEY`. + This hook be useful for algorithms that need the previously taken actions for training, like variants of + imitation learning or DAGGER. + """ + + ACTION_DIST_KEY = "action_dist" + + def __call__( + self, + action_batch: CollectActionBatchProtocol, + rollout_batch: RolloutBatchProtocol, + ) -> None: + rollout_batch[self.ACTION_DIST_KEY] = action_batch.dist + + +class EpisodeRolloutHookProtocol(Protocol): + """A protocol for hooks (functions) that act on an entire collected episode.""" + + def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray]: + """Compute new entries for the rollout batch and return them as a dictionary. + + The new entries will be added to the episode batch inside the buffer. + """ + ... + + +class EpisodeRolloutHook(EpisodeRolloutHookProtocol, ABC): + """Marker interface for episode hooks. + + All episode hooks in Tianshou will inherit from it, but only the corresponding protocol will be + used in type hints. This makes it possible to discover all hooks in the codebase by looking up + the hierarchy of this class (or the protocol itself) while still allowing the user to pass + something like a lambda function as a hook. + """ + + @abstractmethod + def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray]: + ... + + +class EpisodeRolloutHookMCReturn(EpisodeRolloutHook): + """Adds the MC return to go as well as the full episode MC return to the transitions in the buffer. + + The latter will be constant for all transitions in the same episode and simply corresponds to + the initial MC return to go. Useful for algorithms that rely on the monte carlo returns during training. + """ + + MC_RETURN_TO_GO_KEY = "mc_return_to_go" + FULL_EPISODE_MC_RETURN_KEY = "full_episode_mc_return" + + def __init__(self, gamma: float = 0.99): + if not 0 <= gamma <= 1: + raise ValueError(f"Expected 0 <= gamma <= 1, but got {gamma=}.") + self.gamma = gamma + + def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray]: + mc_return_to_go = episode_mc_return_to_go(rollout_batch.rew, self.gamma) + full_episode_mc_return = mc_return_to_go[0] + return { + self.MC_RETURN_TO_GO_KEY: mc_return_to_go, + self.FULL_EPISODE_MC_RETURN_KEY: np.full_like( + rollout_batch.rew, + full_episode_mc_return, + ), + } + + +class EpisodeRolloutHookMerged(EpisodeRolloutHook): + """Combines multiple episode hooks into a single one.""" + + def __init__( + self, + *rollout_hooks: EpisodeRolloutHookProtocol, + check_overlapping_keys: bool = True, + ): + """:param rollout_hooks: the hooks to combine + :param check_overlapping_keys: whether to check for overlapping keys in the output of the hooks and + raise a `KeyError` if any are found. Set to `False` to disable this check (can be useful + if this becomes a performance bottleneck). + """ + self.rollout_hooks = rollout_hooks + self.check_overlapping_keys = check_overlapping_keys + + def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray]: + result: dict[str, np.ndarray] = {} + for rollout_hook in self.rollout_hooks: + new_entries_dict = rollout_hook(rollout_batch) + if self.check_overlapping_keys and ( + duplicated_entries := set(new_entries_dict).difference(result) + ): + raise KeyError( + f"Combined rollout hook {rollout_hook} leads to previously " + f"computed entries that would be overwritten: {duplicated_entries=}. " + f"Consider combining hooks which will deliver non-overlapping entries to solve this.", + ) + result.update(new_entries_dict) + return result diff --git a/tianshou/utils/array.py b/tianshou/utils/array.py new file mode 100644 index 000000000..621979666 --- /dev/null +++ b/tianshou/utils/array.py @@ -0,0 +1,15 @@ +from collections.abc import Sequence + +import numpy as np + + +def bisect_left(arr: Sequence[float], x: float) -> float: + """Assuming arr is sorted, return the largest element el of arr s.t. el < x.""" + el_index = int(np.searchsorted(arr, x, side="left")) - 1 + return arr[el_index] + + +def bisect_right(arr: Sequence[float], x: float) -> float: + """Assuming arr is sorted, return the smallest element el of arr s.t. el > x.""" + el_index = int(np.searchsorted(arr, x, side="right")) + return arr[el_index] From 404134bc6910fe7832f8d7ca09cb617361b8c166 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 8 Aug 2024 18:22:01 +0200 Subject: [PATCH 18/33] Minor, typing --- test/base/test_buffer.py | 8 ++++---- test/base/test_utils.py | 4 ++-- tianshou/data/buffer/prio.py | 4 +++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 781b7e044..fe4301bea 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -114,10 +114,10 @@ def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: assert b.info.b.c[1] == 0 assert b.info.d.e[1] == -np.inf # test batch-style adding method, where len(batch) == 1 - batch.done = [1] - batch.terminated = [0] - batch.truncated = [1] - batch.info.e = np.zeros([1, 4]) + batch.done = np.array([True]) + batch.terminated = np.array([False]) + batch.truncated = np.array([True]) + batch.info.e = np.zeros([1, 4]) # type: ignore batch: RolloutBatchProtocol = Batch.stack([batch]) ptr, ep_rew, ep_len, ep_idx = b.add(batch, buffer_ids=[0]) assert ptr.shape == (1,) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 5d85556fe..8e44ad57b 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -212,7 +212,7 @@ def test_distribution_uniformity( distribution = create_uniform_action_dist(action_space, batch_size) # Test 7: Uniform distribution (statistical test) - large_sample = distribution.sample((10000,)) + large_sample = distribution.sample(torch.Size((10000,))) if isinstance(action_space, spaces.Box): # For Box, check if mean is close to 0 and std is close to 1/sqrt(3) assert torch.allclose(large_sample.mean(), torch.tensor(0.0), atol=0.1) @@ -227,7 +227,7 @@ def test_distribution_uniformity( def test_unsupported_space(self) -> None: # Test 6: Raises ValueError for unsupported space with pytest.raises(ValueError): - create_uniform_action_dist(spaces.MultiBinary(5)) + create_uniform_action_dist(spaces.MultiBinary(5)) # type: ignore @pytest.mark.parametrize( "space, batch_size, expected_shape, distribution_type", diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py index 3936bd641..ac0220fe9 100644 --- a/tianshou/data/buffer/prio.py +++ b/tianshou/data/buffer/prio.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import Any, cast import numpy as np @@ -89,6 +90,7 @@ def update_weight(self, index: np.ndarray, new_weight: np.ndarray | torch.Tensor self._min_prio = min(self._min_prio, weight.min()) def __getitem__(self, index: IndexType) -> PrioBatchProtocol: + indices: Sequence[int] | np.ndarray if isinstance(index, slice): # change slice to np array # buffer[:] will get all available data indices = ( @@ -97,7 +99,7 @@ def __getitem__(self, index: IndexType) -> PrioBatchProtocol: else self._indices[: len(self)][index] ) else: - indices = index + indices = cast(np.ndarray, index) batch = super().__getitem__(indices) weight = self.get_weight(indices) # ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154 From d6e3d0a7cbc010e758186d6967a106c5b175f974 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 8 Aug 2024 18:40:20 +0200 Subject: [PATCH 19/33] Imports, docs --- docs/spelling_wordlist.txt | 4 ++++ tianshou/policy/modelfree/sac.py | 6 +++--- tianshou/trainer/base.py | 5 +++-- tianshou/utils/array.py | 4 ++-- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index a758c4769..fa5a0066d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -282,3 +282,7 @@ autocompletion codebase indexable sliceable +gaussian +logprob +monte +carlo diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 2348890d3..6ec3cfe9a 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -28,8 +28,8 @@ def correct_log_prob_gaussian_tanh( ) -> torch.Tensor: """Apply correction for Tanh squashing when computing logprob from Gaussian. - You can check out the original SAC paper (arXiv 1801.01290): Eq 21. - in appendix C to get some understanding of this equation. + See the original SAC paper (arXiv 1801.01290): Equation 21. + :param log_prob: log probability of the action :param squashed_action: tanh-squashed action :param eps: epsilon for numerical stability @@ -83,7 +83,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t can be either "clip" (for simply clipping the action) or empty string for no bounding. Only used if the action_space is continuous. This parameter is ignored in SAC, which used tanh squashing after sampling - unbounded from the gaussian policy (as in (arXiv 1801.01290): Eq 21.). + unbounded from the gaussian policy (as in (arXiv 1801.01290): Equation 21.). :param observation_space: Env's observation space. :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in optimizer in each policy.update() diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 76805c0b2..49990f855 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -8,8 +8,6 @@ import numpy as np import tqdm -from data.buffer.base import MalformedBufferError -from utils.torch_utils import policy_within_training_step from tianshou.data import ( AsyncCollector, @@ -19,6 +17,7 @@ ReplayBuffer, SequenceSummaryStats, ) +from tianshou.data.buffer.base import MalformedBufferError from tianshou.data.collector import BaseCollector, CollectStatsBase from tianshou.policy import BasePolicy from tianshou.policy.base import TrainingStats @@ -29,6 +28,7 @@ MovAvg, ) from tianshou.utils.logging import set_numerical_fields_to_precision +from tianshou.utils.torch_utils import policy_within_training_step log = logging.getLogger(__name__) @@ -583,6 +583,7 @@ def run(self, reset_prior_to_run: bool = True, reset_buffer: bool = False) -> In See itertools - recipes. Use functions that consume iterators at C speed (feed the entire iterator into a zero-length deque). + :param reset_prior_to_run: whether to reset collectors prior to run :param reset_buffer: only has effect if `reset_prior_to_run` is True. Then it will also reset the buffer. This is usually not necessary, use diff --git a/tianshou/utils/array.py b/tianshou/utils/array.py index 621979666..ee93369b5 100644 --- a/tianshou/utils/array.py +++ b/tianshou/utils/array.py @@ -4,12 +4,12 @@ def bisect_left(arr: Sequence[float], x: float) -> float: - """Assuming arr is sorted, return the largest element el of arr s.t. el < x.""" + """Assuming arr is sorted, return the largest element `el` of arr s.t. `el < x`.""" el_index = int(np.searchsorted(arr, x, side="left")) - 1 return arr[el_index] def bisect_right(arr: Sequence[float], x: float) -> float: - """Assuming arr is sorted, return the smallest element el of arr s.t. el > x.""" + """Assuming arr is sorted, return the smallest element `el` of arr s.t. `el > x`.""" el_index = int(np.searchsorted(arr, x, side="right")) return arr[el_index] From 9bd17be35377493ead0d48ea699a2cb6ce272b90 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 9 Aug 2024 18:49:52 +0200 Subject: [PATCH 20/33] Trainer: don't rely on progress bar for terminating loops --- tianshou/trainer/base.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 49990f855..3fdd2a9d4 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -196,7 +196,10 @@ def __init__( self.env_episode = 0 self.policy_update_time = 0.0 self.max_epoch = max_epoch - self.step_per_epoch = step_per_epoch + assert ( + step_per_epoch is not None + ), "The trainer requires step_per_epoch to be set, sorry for the wrong type hint" + self.step_per_epoch: int = step_per_epoch # either on of these two self.step_per_collect = step_per_collect @@ -311,9 +314,10 @@ def __next__(self) -> EpochStats: raise StopIteration # perform n step_per_epoch + steps_done_in_this_epoch = 0 with self._pbar(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", position=1) as t: train_stat: CollectStatsBase - while t.n < t.total and not self.stop_fn_flag: + while steps_done_in_this_epoch < self.step_per_epoch and not self.stop_fn_flag: train_stat, update_stat, self.stop_fn_flag = self.training_step() if isinstance(train_stat, CollectStats): @@ -325,7 +329,11 @@ def __next__(self) -> EpochStats: "n/ep": str(train_stat.n_collected_episodes), "n/st": str(train_stat.n_collected_steps), } + + # t might be disabled, we track the steps manually t.update(train_stat.n_collected_steps) + steps_done_in_this_epoch += train_stat.n_collected_steps + if self.stop_fn_flag: t.set_postfix(**pbar_data_dict) else: @@ -336,7 +344,10 @@ def __next__(self) -> EpochStats: train_stat = CollectStatsBase( n_collected_steps=len(self.buffer), ) + + # t might be disabled, we track the steps manually t.update() + steps_done_in_this_epoch += 1 pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict) pbar_data_dict["gradient_step"] = str(self._gradient_step) @@ -345,8 +356,10 @@ def __next__(self) -> EpochStats: if self.stop_fn_flag: break - if t.n <= t.total and not self.stop_fn_flag: + if steps_done_in_this_epoch <= self.step_per_epoch and not self.stop_fn_flag: + # t might be disabled, we track the steps manually t.update() + steps_done_in_this_epoch += 1 # for offline RL if self.train_collector is None: From 0bb8e9c527c40bb054de70e22bd067f2a51785c8 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 10 Aug 2024 15:53:39 +0200 Subject: [PATCH 21/33] Collector: improved rollout hooks interfaces and docstrings Also removed some unnecessary indirections --- tianshou/data/collector.py | 144 +++++++++++++++++++++++++------------ 1 file changed, 97 insertions(+), 47 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index a813d8bbc..2b439b4c5 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -5,7 +5,7 @@ from collections.abc import Sequence from copy import copy from dataclasses import dataclass -from typing import Any, Optional, Protocol, Self, TypeVar, cast +from typing import Any, Optional, Protocol, Self, TypedDict, TypeVar, cast import gymnasium as gym import numpy as np @@ -32,7 +32,6 @@ from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy from tianshou.policy.base import episode_mc_return_to_go -from tianshou.utils.array import bisect_left, bisect_right from tianshou.utils.print import DataclassPPrintMixin from tianshou.utils.torch_utils import torch_train_mode @@ -41,22 +40,6 @@ _TArrLike = TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None") -def _get_start_stop_tuples_around_edges( - edges: Sequence[int], - start: int, - stop: int, -) -> tuple[tuple[int, int], tuple[int, int]]: - """We assume that stop is smaller than start and that `edges` is a sorted array of integers. - - Then it will return the two tuples containing (start, stop) where we go from start to the next edge, - and from the previous edge to stop. - :return: (start, upper_edge), (lower_edge, stop) - """ - upper_edge = int(bisect_right(edges, start)) - lower_edge = int(bisect_left(edges, stop)) - return (start, upper_edge), (lower_edge, stop) - - @dataclass(kw_only=True) class CollectStatsBase(DataclassPPrintMixin): """The most basic stats, often used for offline learning.""" @@ -216,17 +199,51 @@ def _get_start_stop_tuples_for_edge_crossing_interval( start: int, stop: int, ) -> tuple[tuple[int, int], tuple[int, int]]: - """:return: (start, upper_edge), (lower_edge, stop)""" + """Assumes that stop < start and retrieves tuples corresponding to the two + slices that determine the interval within the buffer. + + Example: + ------- + >>> list(self._subbuffer_edges) == [0, 5, 10] + >>> start = 4 + >>> stop = 2 + >>> self._get_start_stop_tuples_for_edge_crossing_interval(start, stop) + ((4, 5), (0, 2)) + + The buffer sliced from 4 to 5 and then from 0 to 2 will contain the transitions + corresponding to the provided start and stop values. + """ log.debug( - "Received an edge-crossing episode: {start=}, {stop=}, {self._subbuffer_edges=}", + f"Received an edge-crossing episode: {start=}, {stop=}, {self._subbuffer_edges=}", ) if stop >= start: raise ValueError( f"Expected stop < start, but got {start=}, {stop=}. " - f"For stop larger than start this should never be used, and stop=start should never occur.", + f"For stop larger than start this method should never be called, " + f"and stop=start should never occur. This can occur either due to an implementation error, " + f"or due a bad configuration of the buffer that resulted in a single episode being so long that " + f"it completely filled a subbuffer (of size len(buffer)/degree_of_vectorization). " + f"Consider either shortening the episode, increasing the size of the buffer, or decreasing the " + f"degree of vectorization.", ) subbuffer_edges = cast(Sequence[int], self._subbuffer_edges) - return _get_start_stop_tuples_around_edges(subbuffer_edges, start, stop) + + edge_after_start_idx = int(np.searchsorted(subbuffer_edges, start, side="left")) + """This is the crossed edge""" + + if edge_after_start_idx == 0: + raise ValueError( + f"The start value should be larger than the first edge, but got {start=}, {subbuffer_edges[1]=}.", + ) + edge_after_start = subbuffer_edges[edge_after_start_idx] + edge_before_stop = subbuffer_edges[edge_after_start_idx - 1] + """It's the edge before the crossed edge""" + + if edge_before_stop >= stop: + raise ValueError( + f"The edge before the crossed edge should be smaller than the stop, but got {edge_before_stop=}, {stop=}.", + ) + return (start, edge_after_start), (edge_before_stop, stop) def _validate_buffer(self) -> None: buf = self.buffer @@ -425,7 +442,7 @@ def __init__( exploration noise into action. Default to False. :param on_episode_done_hook: if passed, will be executed when an episode is done. The input to the hook will be a `RolloutBatch` that contains the entire episode (and nothing else). - The dict returned by the hook will be used to add new entries to the buffer + If a dict is returned by the hook will be used to add new entries to the buffer for the episode that just ended. The hook should return arrays with floats which should be of the same length as the input rollout batch. If you have multiple hooks, you can use the `CombinedRolloutHook` class to combine them. @@ -730,9 +747,16 @@ def _collect( # noqa: C901 ) episode_hook_additions = self.run_on_episode_done(ep_rollout_batch) if episode_hook_additions is not None: - for key, array in episode_hook_additions.items(): + if n_episode is not None: + raise ValueError( + "An on_episode_done_hook with non-empty returns is not supported for n_step collection." + "Such hooks should only be used when collecting full episodes. Got a on_episode_done_hook " + f"that would add the following fields to the buffer: {list(episode_hook_additions)}.", + ) + + for key, episode_addition in episode_hook_additions.items(): self.buffer.set_array_at_key( - array, + episode_addition, key, index=cur_ep_index_array, ) @@ -988,8 +1012,8 @@ def _collect( ) if ( not len(self._current_obs_in_all_envs_EO) - == len(self._current_action_in_all_envs_EA) - == self.env_num + == len(self._current_action_in_all_envs_EA) + == self.env_num ): # major difference raise RuntimeError( f"{len(self._current_obs_in_all_envs_EO)=} and" @@ -1190,12 +1214,22 @@ def __call__( class EpisodeRolloutHookProtocol(Protocol): - """A protocol for hooks (functions) that act on an entire collected episode.""" + """A protocol for hooks (functions) that act on an entire collected episode. - def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray]: - """Compute new entries for the rollout batch and return them as a dictionary. + Can be used to add values to the buffer that are only known after the episode is finished. + A prime example is something like the MC return to go. + """ - The new entries will be added to the episode batch inside the buffer. + def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray] | None: + """Will be called by the collector when an episode is finished. + + If a dictionary is returned, the key-value pairs will be interpreted as new entries + to be added to the episode batch (inside the buffer). In that case, + the values should be arrays of the same length as the input `rollout_batch`. + + :param rollout_batch: the batch of transitions that belong to the episode. + :return: an optional dictionary containing new entries (of same len as `rollout_batch`) + to be added to the buffer. """ ... @@ -1210,7 +1244,7 @@ class EpisodeRolloutHook(EpisodeRolloutHookProtocol, ABC): """ @abstractmethod - def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray]: + def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray] | None: ... @@ -1224,32 +1258,43 @@ class EpisodeRolloutHookMCReturn(EpisodeRolloutHook): MC_RETURN_TO_GO_KEY = "mc_return_to_go" FULL_EPISODE_MC_RETURN_KEY = "full_episode_mc_return" + + class OutputDict(TypedDict): + mc_return_to_go: np.ndarray + full_episode_mc_return: np.ndarray + + def __init__(self, gamma: float = 0.99): if not 0 <= gamma <= 1: raise ValueError(f"Expected 0 <= gamma <= 1, but got {gamma=}.") self.gamma = gamma - def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray]: + def __call__( # type: ignore[override] + self, + rollout_batch: RolloutBatchProtocol, + ) -> "EpisodeRolloutHookMCReturn.OutputDict": mc_return_to_go = episode_mc_return_to_go(rollout_batch.rew, self.gamma) - full_episode_mc_return = mc_return_to_go[0] - return { - self.MC_RETURN_TO_GO_KEY: mc_return_to_go, - self.FULL_EPISODE_MC_RETURN_KEY: np.full_like( - rollout_batch.rew, - full_episode_mc_return, - ), - } + full_episode_mc_return = np.full_like(mc_return_to_go, mc_return_to_go[0]) + + return self.OutputDict( + mc_return_to_go=mc_return_to_go, + full_episode_mc_return=full_episode_mc_return, + ) class EpisodeRolloutHookMerged(EpisodeRolloutHook): - """Combines multiple episode hooks into a single one.""" + """Combines multiple episode hooks into a single one. + + If all hooks return `None`, this hook will also return `None`. + """ def __init__( self, *rollout_hooks: EpisodeRolloutHookProtocol, check_overlapping_keys: bool = True, ): - """:param rollout_hooks: the hooks to combine + """ + :param rollout_hooks: the hooks to combine :param check_overlapping_keys: whether to check for overlapping keys in the output of the hooks and raise a `KeyError` if any are found. Set to `False` to disable this check (can be useful if this becomes a performance bottleneck). @@ -1257,17 +1302,22 @@ def __init__( self.rollout_hooks = rollout_hooks self.check_overlapping_keys = check_overlapping_keys - def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray]: + def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray] | None: result: dict[str, np.ndarray] = {} for rollout_hook in self.rollout_hooks: - new_entries_dict = rollout_hook(rollout_batch) + new_entries = rollout_hook(rollout_batch) + if new_entries is None: + continue + if self.check_overlapping_keys and ( - duplicated_entries := set(new_entries_dict).difference(result) + duplicated_entries := set(new_entries).difference(result) ): raise KeyError( f"Combined rollout hook {rollout_hook} leads to previously " f"computed entries that would be overwritten: {duplicated_entries=}. " f"Consider combining hooks which will deliver non-overlapping entries to solve this.", ) - result.update(new_entries_dict) + result.update(new_entries) + if not result: + return None return result From 4d815208b8d38c48ddd2235700b0ae3ce1539e7e Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 10 Aug 2024 15:53:52 +0200 Subject: [PATCH 22/33] Removed no longer needed array module --- tianshou/utils/array.py | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 tianshou/utils/array.py diff --git a/tianshou/utils/array.py b/tianshou/utils/array.py deleted file mode 100644 index ee93369b5..000000000 --- a/tianshou/utils/array.py +++ /dev/null @@ -1,15 +0,0 @@ -from collections.abc import Sequence - -import numpy as np - - -def bisect_left(arr: Sequence[float], x: float) -> float: - """Assuming arr is sorted, return the largest element `el` of arr s.t. `el < x`.""" - el_index = int(np.searchsorted(arr, x, side="left")) - 1 - return arr[el_index] - - -def bisect_right(arr: Sequence[float], x: float) -> float: - """Assuming arr is sorted, return the smallest element `el` of arr s.t. `el > x`.""" - el_index = int(np.searchsorted(arr, x, side="right")) - return arr[el_index] From 90d395959296c866ae5e66227366be9f0db4554c Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sat, 10 Aug 2024 15:54:15 +0200 Subject: [PATCH 23/33] Aesthetic: docstrings, var names --- tianshou/highlevel/config.py | 25 ++++++++++++++++++------- tianshou/policy/modelfree/sac.py | 14 ++++++-------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index ec702e373..287800d0c 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -74,6 +74,9 @@ class SamplingConfig(ToStringMixin): """ the number of environment steps/transitions to collect in each collection step before the network update within each training step. + + This is mutually exclusive with :attr:`episode_per_collect`, and one of the two must be set. + Note that the exact number can be reached only if this is a multiple of the number of training environments being used, as each training environment will produce the same (non-zero) number of transitions. @@ -89,13 +92,15 @@ class SamplingConfig(ToStringMixin): the number of episodes to collect in each collection step before the network update within each training step. If this is set, the number of environment steps collected in each collection step is the sum of the lengths of the episodes collected. + + This is mutually exclusive with :attr:`step_per_collect`, and one of the two must be set. """ repeat_per_collect: int | None = 1 """ controls, within one gradient update step of an on-policy algorithm, the number of times an actual gradient update is applied using the full collected dataset, i.e. if the parameter is - `n`, then the collected data shall be used five times to update the policy within the same + 5, then the collected data shall be used five times to update the policy within the same training step. The parameter is ignored and may be set to None for off-policy and offline algorithms. @@ -136,12 +141,14 @@ class SamplingConfig(ToStringMixin): """ replay_buffer_save_only_last_obs: bool = False - """if True, only the most recent frame is saved when appending to experiences rather than the - full stacked frames. This avoids duplicating observations in buffer memory. Set to False to - save stacked frames in full. + """if True, for the case where the environment outputs stacked frames (e.g. because it + is using a `FrameStack` wrapper), save only the most recent frame so as not to duplicate + observations in buffer memory. Specifically, if the environment outputs observations `obs` with + shape (N, ...), only obs[-1] of shape (...) will be stored. + Frame stacking with a fixed number of frames can then be recreated at the buffer level by setting + :attr:`replay_buffer_stack_num`. - Note: typically used together with `replay_buffer_stack_num`, see documentation there. - Currently only used in Atari examples and may be removed in the future! + Note: Currently only used in Atari examples and may be removed in the future! """ replay_buffer_stack_num: int = 1 @@ -151,7 +158,11 @@ class SamplingConfig(ToStringMixin): temporal aspects (e.g. velocities of moving objects for which only positions are observed). Note: it is recommended to do this stacking on the environment level by using something like - gymnasium's `FrameStack` instead. Currently only used in Atari examples and may be removed in the future! + gymnasium's `FrameStack` instead. Setting this to larger than one in conjunction + with :attr:`replay_buffer_save_only_last_obs` means that + stacking will be recreated at the buffer level, which is more memory-efficient. + + Currently only used in Atari examples and may be removed in the future! """ @property diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 6ec3cfe9a..c1f19eff7 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -23,21 +23,19 @@ def correct_log_prob_gaussian_tanh( log_prob: torch.Tensor, - squashed_action: torch.Tensor, + tanh_squashed_action: torch.Tensor, eps: float = np.finfo(np.float32).eps.item(), ) -> torch.Tensor: - """Apply correction for Tanh squashing when computing logprob from Gaussian. + """Apply correction for Tanh squashing when computing `log_prob` from Gaussian. - See the original SAC paper (arXiv 1801.01290): Equation 21. + See equation 21 in the original `SAC paper `_. :param log_prob: log probability of the action - :param squashed_action: tanh-squashed action + :param tanh_squashed_action: action squashed to values in (-1, 1) range by tanh :param eps: epsilon for numerical stability """ - return log_prob - torch.log((1 - squashed_action.pow(2)) + eps).sum( - -1, - keepdim=True, - ) + log_prob_correction = torch.log(1 - tanh_squashed_action.pow(2) + eps).sum(-1, keepdim=True) + return log_prob - log_prob_correction @dataclass(kw_only=True) From f6134d25303e2fa705f4244b95839b2c811f8ad7 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 18 Aug 2024 12:00:37 +0200 Subject: [PATCH 24/33] Collector: use a larger default buffer size (previous default didn't make sense) --- tianshou/data/collector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 9f4ec1952..d565a3e0b 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -37,6 +37,8 @@ log = logging.getLogger(__name__) +DEFAULT_BUFFER_MAXSIZE = int(1e4) + _TArrLike = TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None") @@ -177,7 +179,7 @@ def __init__( env = DummyVectorEnv([lambda: env]) # type: ignore if buffer is None: - buffer = VectorReplayBuffer(len(env), len(env)) + buffer = VectorReplayBuffer(DEFAULT_BUFFER_MAXSIZE * len(env), len(env)) self.buffer: ReplayBuffer | ReplayBufferManager = buffer self.policy = policy From 872183e420b22f96503d1ef45e04e4a8282458cd Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 18 Aug 2024 14:40:19 +0200 Subject: [PATCH 25/33] Batch: possibility to get len of batches with dist. Added a test --- test/base/test_batch.py | 15 +++++++++++++++ tianshou/data/batch.py | 19 ++++++++++++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 8839fe482..7e250e881 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -881,3 +881,18 @@ def test_getitem_with_slice_gives_subslice(index: IndexType) -> None: batch_sliced = batch[index] assert (batch_sliced.a == batch.a[index]).all() assert (batch_sliced.b.c == batch.b.c[index]).all() + + @staticmethod + def test_len_batch_with_dist() -> None: + batch_with_dist = Batch(a=[1, 2, 3], dist=Categorical(torch.ones((3, 3))), b=None) + batch_with_dist_sliced = batch_with_dist[:2] + assert batch_with_dist_sliced.b is None + assert len(batch_with_dist_sliced) == 2 + assert np.array_equal(batch_with_dist_sliced.a, np.array([1, 2])) + assert torch.allclose( + batch_with_dist_sliced.dist.probs, Categorical(torch.ones(2, 3)).probs + ) + + with pytest.raises(TypeError): + # scalar batches have no len + len(batch_with_dist[0]) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 38c7d0465..079da5ab0 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -278,6 +278,13 @@ def get_sliced_dist(dist: TDistribution, index: IndexType) -> TDistribution: raise NotImplementedError(f"Unsupported distribution for slicing: {dist}") +def get_len_of_dist(dist: Distribution) -> int: + """Return the length (typically batch size) of a distribution object.""" + if len(dist.batch_shape) == 0: + raise TypeError(f"scalar Distribution has no length: {dist=}") + return dist.batch_shape[0] + + # Note: This is implemented as a protocol because the interface # of Batch is always extended by adding new fields. Having a hierarchy of # protocols building off this one allows for type safety and IDE support despite @@ -1141,17 +1148,23 @@ def update(self, batch: dict | Self | None = None, **kwargs: Any) -> None: self.update(kwargs) def __len__(self) -> int: + """Raises `TypeError` if any value in the batch has no len(), typically meaning it's a batch of scalars.""" lens = [] - for obj in self.__dict__.values(): + for key, obj in self.__dict__.items(): # TODO: causes inconsistent behavior to batch with empty batches # and batch with empty sequences of other type. Remove, but only after # Buffer and Collectors have been improved to no longer rely on this if isinstance(obj, Batch) and len(obj) == 0: continue + if obj is None: + continue if hasattr(obj, "__len__") and (isinstance(obj, Batch) or obj.ndim > 0): lens.append(len(obj)) - else: - raise TypeError(f"Object {obj} in {self} has no len()") + continue + if isinstance(obj, Distribution): + lens.append(get_len_of_dist(obj)) + continue + raise TypeError(f"Entry for {key} in {self} is {obj}has no len()") if not lens: return 0 return min(lens) From 8a0f536dc3b1dfd6834e265bd778eff86b011a47 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 18 Aug 2024 14:40:45 +0200 Subject: [PATCH 26/33] Test env: aesthetic --- test/base/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/base/env.py b/test/base/env.py index ec81554b0..618252eb0 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -168,7 +168,7 @@ def step(self, action: np.ndarray | int): # type: ignore[no-untyped-def] # cf. False, info_dict, ) - return None + raise ValueError(f"Invalid action {action}") class NXEnv(gym.Env): From 9997f8cea261db2205187ddbf4f96388eb44aa21 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 18 Aug 2024 14:42:44 +0200 Subject: [PATCH 27/33] Buffer: better names for vars (no functional change) --- tianshou/data/buffer/base.py | 12 ++++++------ tianshou/data/buffer/cached.py | 25 ++++++++++++++----------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index e03a5d602..1ccf956e0 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -219,17 +219,17 @@ def update(self, buffer: "ReplayBuffer") -> np.ndarray: buffer.stack_num = stack_num if len(from_indices) == 0: return np.array([], int) - to_indices = [] + updated_indices = [] for _ in range(len(from_indices)): - to_indices.append(self._insertion_idx) + updated_indices.append(self._insertion_idx) self.last_index[0] = self._insertion_idx self._insertion_idx = (self._insertion_idx + 1) % self.maxsize self._size = min(self._size + 1, self.maxsize) - to_indices = np.array(to_indices) + updated_indices = np.array(updated_indices) if len(self._meta.get_keys()) == 0: self._meta = create_value(buffer._meta, self.maxsize, stack=False) # type: ignore - self._meta[to_indices] = buffer._meta[from_indices] - return to_indices + self._meta[updated_indices] = buffer._meta[from_indices] + return updated_indices def _update_state_pre_add( self, @@ -300,7 +300,7 @@ def add( :param buffer_ids: to make consistent with other buffer's add function; if it is not None, we assume the input batch's first dimension is always 1. - Return (current_index, episode_reward, episode_length, episode_start_index). If + Return (current_index, episode_return, episode_length, episode_start_index). If the episode is not finished, the return value of episode_length and episode_reward is 0. """ diff --git a/tianshou/data/buffer/cached.py b/tianshou/data/buffer/cached.py index 97e0a8054..6c7a6e174 100644 --- a/tianshou/data/buffer/cached.py +++ b/tianshou/data/buffer/cached.py @@ -59,24 +59,27 @@ def add( cached_buffer_ids[i]th cached buffer's corresponding episode result. """ if buffer_ids is None: - buf_arr = np.arange(1, 1 + self.cached_buffer_num) - else: # make sure it is np.ndarray - buf_arr = np.asarray(buffer_ids) + 1 - ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids=buf_arr) + cached_buffer_ids = np.arange(1, 1 + self.cached_buffer_num) + else: # make sure it is np.ndarray, +1 means it's never the main buffer + cached_buffer_ids = np.asarray(buffer_ids) + 1 + insertion_idx, ep_return, ep_len, ep_start_idx = super().add( + batch, + buffer_ids=cached_buffer_ids, + ) # find the terminated episode, move data from cached buf to main buf - updated_ptr, updated_ep_idx = [], [] + updated_insertion_idx, updated_ep_start_idx = [], [] done = np.logical_or(batch.terminated, batch.truncated) - for buffer_idx in buf_arr[done]: + for buffer_idx in cached_buffer_ids[done]: index = self.main_buffer.update(self.buffers[buffer_idx]) if len(index) == 0: # unsuccessful move, replace with -1 index = [-1] - updated_ep_idx.append(index[0]) - updated_ptr.append(index[-1]) + updated_ep_start_idx.append(index[0]) + updated_insertion_idx.append(index[-1]) self.buffers[buffer_idx].reset() self._lengths[0] = len(self.main_buffer) self._lengths[buffer_idx] = 0 self.last_index[0] = index[-1] self.last_index[buffer_idx] = self._offset[buffer_idx] - ptr[done] = updated_ptr - ep_idx[done] = updated_ep_idx - return ptr, ep_rew, ep_len, ep_idx + insertion_idx[done] = updated_insertion_idx + ep_start_idx[done] = updated_ep_start_idx + return insertion_idx, ep_return, ep_len, ep_start_idx From 64cdf149a6c8d155ac9428539c78d66727d68e38 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 18 Aug 2024 17:04:30 +0200 Subject: [PATCH 28/33] Collector: major extension that allows customizing how collection stats are computed The Collector had to become generic to enable this Moved some stats computation to BaseCollector Extended the hook setting/getting and added tests for hooks Added a lot of documentation on the collect method --- test/base/test_collector.py | 137 ++++++++--- tianshou/data/collector.py | 462 +++++++++++++++++++++++++++--------- 2 files changed, 450 insertions(+), 149 deletions(-) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index d03a54df7..6355d8bfc 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -12,14 +12,21 @@ Batch, CachedReplayBuffer, Collector, + CollectStats, PrioritizedReplayBuffer, ReplayBuffer, VectorReplayBuffer, ) from tianshou.data.batch import BatchProtocol +from tianshou.data.collector import ( + CollectActionBatchProtocol, + EpisodeRolloutHookMCReturn, + StepHook, +) from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import BasePolicy, TrainingStats +from tianshou.policy.base import episode_mc_return_to_go try: import envpool @@ -77,6 +84,16 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> Train raise NotImplementedError +@pytest.fixture() +def collector_with_single_env() -> Collector[CollectStats]: + """The env will be a MoveToRightEnv with size 5, sleep 0.""" + env = MoveToRightEnv(size=5, sleep=0) + policy = MaxActionPolicy() + collector = Collector[CollectStats](policy, env, ReplayBuffer(size=100)) + collector.reset() + return collector + + def test_collector() -> None: env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [2, 3, 4, 5]] @@ -84,7 +101,7 @@ def test_collector() -> None: dummy_venv_4_envs = DummyVectorEnv(env_fns) policy = MaxActionPolicy() single_env = env_fns[0]() - c_single_env = Collector( + c_single_env = Collector[CollectStats]( policy, single_env, ReplayBuffer(size=100), @@ -135,7 +152,7 @@ def test_collector() -> None: assert np.allclose(c_single_env.buffer.rew[:8], [0, 1, 0, 1, 0, 1, 0, 1]) c_single_env.collect(n_step=3, random=True) - c_subproc_venv_4_envs = Collector( + c_subproc_venv_4_envs = Collector[CollectStats]( policy, subproc_venv_4_envs, VectorReplayBuffer(total_size=100, buffer_num=4), @@ -188,7 +205,7 @@ def test_collector() -> None: assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews) c_subproc_venv_4_envs.collect(n_episode=4, random=True) - c_dummy_venv_4_envs = Collector( + c_dummy_venv_4_envs = Collector[CollectStats]( policy, dummy_venv_4_envs, VectorReplayBuffer(total_size=100, buffer_num=4), @@ -219,9 +236,9 @@ def test_collector() -> None: # test corner case with pytest.raises(ValueError): - Collector(policy, dummy_venv_4_envs, ReplayBuffer(10)) + Collector[CollectStats](policy, dummy_venv_4_envs, ReplayBuffer(10)) with pytest.raises(ValueError): - Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5)) + Collector[CollectStats](policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5)) with pytest.raises(ValueError): c_dummy_venv_4_envs.collect() @@ -231,7 +248,11 @@ def get_env_factory(i: int, t: str) -> Callable[[], NXEnv]: # test NXEnv for obs_type in ["array", "object"]: envs = SubprocVectorEnv([get_env_factory(i=i, t=obs_type) for i in [5, 10, 15, 20]]) - c_suproc_new = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) + c_suproc_new = Collector[CollectStats]( + policy, + envs, + VectorReplayBuffer(total_size=100, buffer_num=4), + ) c_suproc_new.reset() c_suproc_new.collect(n_step=6) assert c_suproc_new.buffer.obs.dtype == object @@ -373,7 +394,7 @@ def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collecti def test_collector_with_dict_state() -> None: env = MoveToRightEnv(size=5, sleep=0, dict_state=True) policy = MaxActionPolicy(dict_state=True) - c0 = Collector(policy, env, ReplayBuffer(size=100)) + c0 = Collector[CollectStats](policy, env, ReplayBuffer(size=100)) c0.reset() c0.collect(n_step=3) c0.collect(n_episode=2) @@ -383,7 +404,7 @@ def test_collector_with_dict_state() -> None: envs.seed(666) obs, info = envs.reset() assert not np.isclose(obs[0]["rand"], obs[1]["rand"]) - c1 = Collector( + c1 = Collector[CollectStats]( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), @@ -501,7 +522,7 @@ def test_collector_with_dict_state() -> None: 4, ], ), cur_obs.index[..., 0] - c2 = Collector( + c2 = Collector[CollectStats]( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), @@ -514,7 +535,7 @@ def test_collector_with_dict_state() -> None: def test_collector_with_multi_agent() -> None: multi_agent_env = MoveToRightEnv(size=5, sleep=0, ma_rew=4) policy = MaxActionPolicy() - c_single_env = Collector(policy, multi_agent_env, ReplayBuffer(size=100)) + c_single_env = Collector[CollectStats](policy, multi_agent_env, ReplayBuffer(size=100)) c_single_env.reset() multi_env_returns = c_single_env.collect(n_step=3).returns # c_single_env has length 3 @@ -528,7 +549,7 @@ def test_collector_with_multi_agent() -> None: env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) - c_multi_env_ma = Collector( + c_multi_env_ma = Collector[CollectStats]( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), @@ -641,7 +662,7 @@ def test_collector_with_multi_agent() -> None: ) assert np.all(c_single_env.buffer[:].rew == [[x] * 4 for x in multi_env_returns]) assert np.all(c_single_env.buffer[:].done == multi_env_returns) - c2 = Collector( + c2 = Collector[CollectStats]( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), @@ -664,7 +685,7 @@ def test_collector_with_atari_setting() -> None: # atari single buffer env = MoveToRightEnv(size=5, sleep=0, array_state=True) policy = MaxActionPolicy() - c0 = Collector(policy, env, ReplayBuffer(size=100)) + c0 = Collector[CollectStats](policy, env, ReplayBuffer(size=100)) c0.reset() c0.collect(n_step=6) c0.collect(n_episode=2) @@ -675,14 +696,14 @@ def test_collector_with_atari_setting() -> None: obs[np.arange(15)] = reference_obs[np.arange(15) % 5] assert np.all(obs == c0.buffer.obs) - c1 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True)) + c1 = Collector[CollectStats](policy, env, ReplayBuffer(size=100, ignore_obs_next=True)) c1.collect(n_episode=3, reset_before_collect=True) assert np.allclose(c0.buffer.obs, c1.buffer.obs) with pytest.raises(AttributeError): c1.buffer.obs_next # noqa: B018 assert np.all(reference_obs[[1, 2, 3, 4, 4] * 3] == c1.buffer[:].obs_next) - c2 = Collector( + c2 = Collector[CollectStats]( policy, env, ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True), @@ -700,12 +721,12 @@ def test_collector_with_atari_setting() -> None: # atari multi buffer env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) - c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) + c3 = Collector[CollectStats](policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c3.reset() c3.collect(n_step=12) - result = c3.collect(n_episode=9) - assert result.n_collected_episodes == 9 - assert result.n_collected_steps == 23 + result_cached_buffer_collect_9_episodes = c3.collect(n_episode=9) + assert result_cached_buffer_collect_9_episodes.n_collected_episodes == 9 + assert result_cached_buffer_collect_9_episodes.n_collected_steps == 23 assert c3.buffer.obs.shape == (100, 4, 84, 84) obs = np.zeros_like(c3.buffer.obs) obs[np.arange(8)] = reference_obs[[0, 1, 0, 1, 0, 1, 0, 1]] @@ -719,7 +740,7 @@ def test_collector_with_atari_setting() -> None: obs_next[np.arange(50, 58)] = reference_obs[[1, 2, 3, 4, 1, 2, 3, 4]] obs_next[np.arange(75, 85)] = reference_obs[[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]] assert np.all(obs_next == c3.buffer.obs_next) - c4 = Collector( + c4 = Collector[CollectStats]( policy, envs, VectorReplayBuffer( @@ -732,9 +753,9 @@ def test_collector_with_atari_setting() -> None: ) c4.reset() c4.collect(n_step=12) - result = c4.collect(n_episode=9) - assert result.n_collected_episodes == 9 - assert result.n_collected_steps == 23 + result_cached_buffer_collect_9_episodes = c4.collect(n_episode=9) + assert result_cached_buffer_collect_9_episodes.n_collected_episodes == 9 + assert result_cached_buffer_collect_9_episodes.n_collected_steps == 23 assert c4.buffer.obs.shape == (100, 84, 84) obs = np.zeros_like(c4.buffer.obs) slice_obs = reference_obs[:, -1] @@ -796,14 +817,14 @@ def test_collector_with_atari_setting() -> None: assert np.all(obs_next == c4.buffer[:].obs_next) buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True) - c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10)) - c5.reset() - result_ = c5.collect(n_step=12) + collector_cached_buffer = Collector[CollectStats](policy, envs, CachedReplayBuffer(buf, 4, 10)) + collector_cached_buffer.reset() + result_cached_buffer_collect_12_steps = collector_cached_buffer.collect(n_step=12) assert len(buf) == 5 - assert len(c5.buffer) == 12 - result = c5.collect(n_episode=9) - assert result.n_collected_episodes == 9 - assert result.n_collected_steps == 23 + assert len(collector_cached_buffer.buffer) == 12 + result_cached_buffer_collect_9_episodes = collector_cached_buffer.collect(n_episode=9) + assert result_cached_buffer_collect_9_episodes.n_collected_episodes == 9 + assert result_cached_buffer_collect_9_episodes.n_collected_steps == 23 assert len(buf) == 35 assert np.all( buf.obs[: len(buf)] @@ -889,17 +910,23 @@ def test_collector_with_atari_setting() -> None: ] ], ) - assert len(buf) == len(c5.buffer) + assert len(buf) == len(collector_cached_buffer.buffer) # test buffer=None - c6 = Collector(policy, envs) - c6.reset() - result1 = c6.collect(n_step=12) + collector_default_buffer = Collector[CollectStats](policy, envs) + collector_default_buffer.reset() + result_default_buffer_collect_12_steps = collector_default_buffer.collect(n_step=12) for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]: - assert np.allclose(getattr(result1, key), getattr(result_, key)) - result2 = c6.collect(n_episode=9) + assert np.allclose( + getattr(result_default_buffer_collect_12_steps, key), + getattr(result_cached_buffer_collect_12_steps, key), + ) + result2 = collector_default_buffer.collect(n_episode=9) for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]: - assert np.allclose(getattr(result2, key), getattr(result, key)) + assert np.allclose( + getattr(result2, key), + getattr(result_cached_buffer_collect_9_episodes, key), + ) @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") @@ -907,7 +934,7 @@ def test_collector_envpool_gym_reset_return_info() -> None: envs = envpool.make_gymnasium("Pendulum-v1", num_envs=4, gym_reset_return_info=True) policy = MaxActionPolicy(action_shape=(len(envs), 1)) - c0 = Collector( + c0 = Collector[CollectStats]( policy, envs, VectorReplayBuffer(len(envs) * 10, len(envs)), @@ -926,7 +953,7 @@ def test_collector_with_vector_env() -> None: dum = DummyVectorEnv(env_fns) policy = MaxActionPolicy() - c2 = Collector( + c2 = Collector[CollectStats]( policy, dum, VectorReplayBuffer(total_size=100, buffer_num=4), @@ -959,3 +986,35 @@ def test_async_collector_with_vector_env() -> None: assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9]), c1r.lens) c2r = c1.collect(n_step=20) assert np.array_equal(np.array([1, 10, 1, 1, 1, 1]), c2r.lens) + + +class StepHookAddFieldToBatch(StepHook): + def __call__( + self, + action_batch: CollectActionBatchProtocol, + rollout_batch: RolloutBatchProtocol, + ) -> None: + rollout_batch.set_array_at_key(np.array([1]), "added_by_hook") + + +class TestCollectStatsAndHooks: + @staticmethod + def test_on_step_hook(collector_with_single_env: Collector) -> None: + collector_with_single_env.set_on_step_hook(StepHookAddFieldToBatch()) + collect_stats = collector_with_single_env.collect(n_step=3) + assert collect_stats.n_collected_steps == 3 + # a was added by the hook + assert np.array_equal( + collector_with_single_env.buffer[:].added_by_hook, + np.array([1, 1, 1]), + ) + + @staticmethod + def test_episode_mc_hook(collector_with_single_env: Collector) -> None: + collector_with_single_env.set_on_episode_done_hook(EpisodeRolloutHookMCReturn()) + collector_with_single_env.collect(n_episode=1) + collected_batch = collector_with_single_env.buffer[:] + return_to_go = collected_batch.get(EpisodeRolloutHookMCReturn.MC_RETURN_TO_GO_KEY) + full_return = collected_batch.get(EpisodeRolloutHookMCReturn.FULL_EPISODE_MC_RETURN_KEY) + assert np.array_equal(return_to_go, episode_mc_return_to_go(collected_batch.rew)) + assert np.array_equal(full_return, np.ones(5) * return_to_go[0]) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index d565a3e0b..7615e725d 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -4,8 +4,8 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from copy import copy -from dataclasses import dataclass -from typing import Any, Optional, Protocol, Self, TypedDict, TypeVar, cast +from dataclasses import dataclass, field +from typing import Any, Generic, Optional, Protocol, Self, TypedDict, TypeVar, cast import gymnasium as gym import numpy as np @@ -42,6 +42,39 @@ _TArrLike = TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None") +class CollectActionBatchProtocol(Protocol): + """A protocol for results of computing actions from a batch of observations within a single collect step. + + All fields all have length R (the dist is a Distribution of batch size R), + where R is the number of ready envs. + """ + + act: np.ndarray | torch.Tensor + act_normalized: np.ndarray | torch.Tensor + policy_entry: Batch + dist: Distribution | None + hidden_state: np.ndarray | torch.Tensor | Batch | None + + +class CollectStepBatchProtocol(RolloutBatchProtocol): + """A batch of steps collected from a single collect step from multiple envs in parallel. + + All fields have length R (the dist is a Distribution of batch size R), where R is the number of ready envs. + This is essentially the response of the vectorized environment to making a step + with :class:`CollectActionBatchProtocol`. + """ + + dist: Distribution | None + + +class EpisodeBatchProtocol(RolloutBatchProtocol): + """Marker interface for a batch containing a single episode. + + Instances are created by retrieving an episode from the buffer when the :class:`Collector` encounters + `done=True`. + """ + + @dataclass(kw_only=True) class CollectStatsBase(DataclassPPrintMixin): """The most basic stats, often used for offline learning.""" @@ -54,19 +87,33 @@ class CollectStatsBase(DataclassPPrintMixin): @dataclass(kw_only=True) class CollectStats(CollectStatsBase): - """A data structure for storing the statistics of rollouts.""" + """A data structure for storing the statistics of rollouts. + + Custom stats collection logic can be implemented by subclassing this class and + overriding the `update_` methods. + + Ideally, it is instantiated once with correct values and then never modified. + However, during the collection process instances of modified + using the `update_` methods. Then the arrays and their corresponding `_stats` fields + may become out of sync (we don't update the stats after each update for performance reasons, + only at the end of the collection). The same for the `collect_time` and `collect_speed`. + In the `Collector`, :meth:`refresh_sequence_stats` and :meth:`set_collect_time` are + is called at the end of the collection to update the stats. But for other use cases, + the users should keep in mind to call this method manually if using the `update_` + methods. + """ collect_time: float = 0.0 """The time for collecting transitions.""" collect_speed: float = 0.0 """The speed of collecting (env_step per second).""" - returns: np.ndarray + returns: np.ndarray = field(default_factory=lambda: np.array([], dtype=float)) """The collected episode returns.""" - returns_stat: SequenceSummaryStats | None # can be None if no episode ends during the collect step + returns_stat: SequenceSummaryStats | None = None """Stats of the collected returns.""" - lens: np.ndarray + lens: np.ndarray = field(default_factory=lambda: np.array([], dtype=int)) """The collected episode lengths.""" - lens_stat: SequenceSummaryStats | None # can be None if no episode ends during the collect step + lens_stat: SequenceSummaryStats | None = None """Stats of the collected episode lengths.""" std_array: np.ndarray | None = None """The standard deviations of the predicted distributions.""" @@ -97,19 +144,82 @@ def with_autogenerated_stats( lens_stat=lens_stat, ) + def update_at_step_batch( + self, + step_batch: CollectStepBatchProtocol, + refresh_sequence_stats: bool = False, + ) -> None: + self.n_collected_steps += len(step_batch) + action_std = step_batch.dist.stddev if step_batch.dist is not None else None + if action_std is not None: + if self.std_array is None: + self.std_array = to_numpy(action_std) + else: + self.std_array = np.concatenate((self.std_array, to_numpy(action_std))) + if refresh_sequence_stats: + self.refresh_std_array_stats() -class CollectActionBatchProtocol(Protocol): - """A protocol for results of computing actions within a single collect step. + def update_at_episode_done( + self, + episode_batch: EpisodeBatchProtocol, + # NOTE: in the MARL setting this is not actually a float but rather an array or list, see todo below + episode_return: float, + refresh_sequence_stats: bool = False, + ) -> None: + self.lens = np.concatenate((self.lens, [len(episode_batch)]), dtype=int) # type: ignore + self.n_collected_episodes += 1 + if self.returns.size == 0: + # TODO: needed for non-1dim arrays returns that happen in the MARL setting + # There are multiple places that assume the returns to be 1dim, so this is a hack + # Since MARL support is currently not a priority, we should either raise an error or + # implement proper support for it. At the moment tests like `test_collector_with_multi_agent` fail + # when assuming 1d returns + self.returns = np.array([episode_return], dtype=float) + else: + self.returns = np.concatenate((self.returns, [episode_return]), dtype=float) # type: ignore + if refresh_sequence_stats: + self.refresh_return_stats() + self.refresh_len_stats() + + def set_collect_time(self, collect_time: float, update_collect_speed: bool = True) -> None: + if collect_time < 0: + raise ValueError(f"Collect time should be non-negative, but got {collect_time=}.") + + self.collect_time = collect_time + if update_collect_speed: + if collect_time == 0: + log.error( + "Collect time is 0, setting collect speed to 0. Did you make a rounding error?", + ) + self.collect_speed = 0.0 + else: + self.collect_speed = self.n_collected_steps / collect_time - All fields all have length R (the dist is a Distribution of batch size R), - where R is the number of ready envs. - """ + def refresh_return_stats(self) -> None: + if self.returns.size > 0: + self.returns_stat = SequenceSummaryStats.from_sequence(self.returns) + else: + self.returns_stat = None - act: np.ndarray | torch.Tensor - act_normalized: np.ndarray | torch.Tensor - policy_entry: Batch - dist: Distribution | None - hidden_state: np.ndarray | torch.Tensor | Batch | None + def refresh_len_stats(self) -> None: + if self.lens.size > 0: + self.lens_stat = SequenceSummaryStats.from_sequence(self.lens) + else: + self.lens_stat = None + + def refresh_std_array_stats(self) -> None: + if self.std_array is not None and self.std_array.size > 0: + self.std_array_stat = SequenceSummaryStats.from_sequence(self.std_array) + else: + self.std_array_stat = None + + def refresh_all_sequence_stats(self) -> None: + self.refresh_return_stats() + self.refresh_len_stats() + self.refresh_std_array_stats() + + +TCollectStats = TypeVar("TCollectStats", bound=CollectStats) def _nullable_slice(obj: _TArrLike, indices: np.ndarray) -> _TArrLike: @@ -152,7 +262,7 @@ def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch: return result_batch_parent.info -class BaseCollector(ABC): +class BaseCollector(Generic[TCollectStats], ABC): """Used to collect data from a vector environment into a buffer using a given policy. .. note:: @@ -172,6 +282,8 @@ def __init__( env: BaseVectorEnv | gym.Env, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, + # The typing is correct, there's a bug in mypy, see https://github.com/python/mypy/issues/3737 + collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] ) -> None: if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") @@ -191,6 +303,7 @@ def __init__( self._is_closed = False self._validate_buffer() + self.collect_stats_class = collect_stats_class @property def _subbuffer_edges(self) -> np.ndarray: @@ -215,9 +328,6 @@ def _get_start_stop_tuples_for_edge_crossing_interval( The buffer sliced from 4 to 5 and then from 0 to 2 will contain the transitions corresponding to the provided start and stop values. """ - log.debug( - f"Received an edge-crossing episode: {start=}, {stop=}, {self._subbuffer_edges=}", - ) if stop >= start: raise ValueError( f"Expected stop < start, but got {start=}, {stop=}. " @@ -336,7 +446,7 @@ def _collect( random: bool = False, render: float | None = None, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> CollectStats: + ) -> TCollectStats: pass def collect( @@ -347,26 +457,27 @@ def collect( render: float | None = None, reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> CollectStats: - """Collect a specified number of steps or episodes. + ) -> TCollectStats: + """Collect the specified number of steps or episodes to the buffer. + + .. note:: + + One and only one collection specification is permitted, either + ``n_step`` or ``n_episode``. - To ensure an unbiased sampling result with the n_episode option, this function will + To ensure an unbiased sampling result with the `n_episode` option, this function will first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` episodes, they will be collected evenly from each env. - :param n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect. - :param random: whether to use random policy for collecting data. + :param n_step: how many steps to collect. + :param n_episode: how many episodes to collect. + :param random: whether to sample randomly from the action space instead of using the policy for collecting data. :param render: the sleep time between rendering consecutive frames. :param reset_before_collect: whether to reset the environment before collecting data. - (The collector needs the initial obs and info to function properly.) + (The collector needs the initial `obs` and `info` to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Only used if reset_before_collect is True. - .. note:: - - One and only one collection number specification is permitted, either - ``n_step`` or ``n_episode``. :return: The collected stats """ @@ -376,14 +487,19 @@ def collect( if reset_before_collect: self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) + pre_collect_time = time.time() with torch_train_mode(self.policy, enabled=False): - return self._collect( + collect_stats = self._collect( n_step=n_step, n_episode=n_episode, random=random, render=render, gym_reset_kwargs=gym_reset_kwargs, ) + collect_time = time.time() - pre_collect_time + collect_stats.set_collect_time(collect_time, update_collect_speed=True) + collect_stats.refresh_all_sequence_stats() + return collect_stats def _validate_n_step_n_episode(self, n_episode: int | None, n_step: int | None) -> None: if not n_step and not n_episode: @@ -407,7 +523,9 @@ def _validate_n_step_n_episode(self, n_episode: int | None, n_step: int | None) ) -class Collector(BaseCollector): +class Collector(BaseCollector[TCollectStats], Generic[TCollectStats]): + """Collects transitions from a vectorized env by computing and applying actions batch-wise.""" + # NAMING CONVENTION (mostly suffixes): # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed, # the corresponding env is either reset or removed from the ready envs. @@ -420,8 +538,20 @@ class Collector(BaseCollector): # A - dimension(s) of actions # H - dimension(s) of hidden state # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. - # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. + # S - number of surplus envs, i.e., envs that are ready but won't be used in the next iteration. # Only used in n_episode case. Then, R becomes R-S. + # local_index - selecting from the locally available environments. In more details: + # Each env is associated to an number in [0,..., N-1]. At any moment there are R ready envs, + # but they are not necessarily equal to [0, ..., R-1]. Let the R corresponding indices be + # [r_0, ..., r_(R-1)] (each r_i is in [0, ... N-1]). If the local index is + # [0, 1, 2], it means that we want to select envs [r_0, r_1, r_2]. + # We will usually select from the ready envs by slicing like `ready_env_idx_R[local_index]` + # global_index - the index in [0, ..., N-1]. Slicing the an `_R` index by a local_index produces the + # corresponding global index. In the example above: + # 1. _R index is [r_0, ..., r_(R-1)] + # 2. local_index is [0, 1, 2] + # 3. global_index is [r_0, r_1, r_2] and can be used to select from an array of length N + # def __init__( self, policy: BasePolicy, @@ -431,18 +561,26 @@ def __init__( on_episode_done_hook: Optional["EpisodeRolloutHookProtocol"] = None, on_step_hook: Optional["StepHookProtocol"] = None, raise_on_nan_in_buffer: bool = True, + collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] ) -> None: - """:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param env: a ``gym.Env`` environment or an instance of the - :class:`~tianshou.env.BaseVectorEnv` class. + """ + :param policy: a tianshou policy, each :class:`BasePolocy` is capable of computing a batch + of actions from a batch of observations. + :param env: a ``gymnasium.Env`` environment or a vectorized instance of the + :class:`~tianshou.env.BaseVectorEnv` class. The latter is strongly recommended, as with + a gymnasium env the collection will not happen in parallel (a `DummyVectorEnv` + will be constructed internally from the passed env) :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + of size :data:`DEFAULT_BUFFER_MAXSIZE` * (number of envs) as the default buffer. :param exploration_noise: determine whether the action needs to be modified with the corresponding policy's exploration noise. If so, "policy. exploration_noise(act, batch)" will be called automatically to add the exploration noise into action. Default to False. - :param on_episode_done_hook: if passed, will be executed when an episode is done. + :param collect_stats_class: the class to use for collecting statistics. Allows customizing + the stats collection logic by passing a subclass of :class:`CollectStats`. + :param on_episode_done_hook: if passed will be executed when an episode is done. The input to the hook will be a `RolloutBatch` that contains the entire episode (and nothing else). If a dict is returned by the hook will be used to add new entries to the buffer for the episode that just ended. The hook should return arrays with floats @@ -454,25 +592,44 @@ def __init__( Care must be taken when using such hook, as for unfinished episodes one can easily end up with NaNs in the buffer. It is recommended to use the hooks only with the `n_episode` option in `collect`, or to strip the buffer of NaNs after the collection. - :param on_step_hook: if passed, will be executed after each step of the collection but before the - rollout batch is resulting added to the buffer. The inputs to the hook will be + :param on_step_hook: if passed will be executed after each step of the collection but before the + resulting rollout batch is added to the buffer. The inputs to the hook will be the action distributions computed from the previous observations (following the :class:`CollectActionBatchProtocol`) using the policy, and the resulting - rollout batch (following the :class:`RolloutBatchProtocol`). + rollout batch (following the :class:`RolloutBatchProtocol`). **Note** that modifying + the rollout batch with this hook also modifies the data that is collected to the buffer! :param raise_on_nan_in_buffer: whether to raise a Runtime if NaNs are found in the buffer after a collection step. Especially useful when using episode-level hooks. Consider setting to False if the NaN-check becomes a bottleneck. """ - super().__init__(policy, env, buffer, exploration_noise=exploration_noise) + super().__init__( + policy, + env, + buffer, + exploration_noise=exploration_noise, + collect_stats_class=collect_stats_class, + ) self._pre_collect_obs_RO: np.ndarray | None = None self._pre_collect_info_R: np.ndarray | None = None self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None self._is_closed = False - self.on_episode_done_hook = on_episode_done_hook - self.on_step_hook = on_step_hook + self._on_episode_done_hook = on_episode_done_hook + self._on_step_hook = on_step_hook self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + def set_on_episode_done_hook(self, hook: Optional["EpisodeRolloutHookProtocol"]) -> None: + self._on_episode_done_hook = hook + + def set_on_step_hook(self, hook: Optional["StepHookProtocol"]) -> None: + self._on_step_hook = hook + + def get_on_episode_done_hook(self) -> Optional["EpisodeRolloutHookProtocol"]: + return self._on_episode_done_hook + + def get_on_step_hook(self) -> Optional["StepHookProtocol"]: + return self._on_step_hook + def close(self) -> None: super().close() self._pre_collect_obs_RO = None @@ -480,16 +637,16 @@ def close(self) -> None: def run_on_episode_done( self, - episode_batch: RolloutBatchProtocol, + episode_batch: EpisodeBatchProtocol, ) -> dict[str, np.ndarray] | None: """Executes the `on_episode_done_hook` that was passed on init. One of the main uses of this public method is to allow users to override it in custom - subclasses of the `Collector`. This way, they can override the init to no longer accept + subclasses of :class:`Collector`. This way, they can override the init to no longer accept the `on_episode_done` provider. """ - if self.on_episode_done_hook is not None: - return self.on_episode_done_hook(episode_batch) + if self._on_episode_done_hook is not None: + return self._on_episode_done_hook(episode_batch) return None def run_on_step_hook( @@ -503,8 +660,8 @@ def run_on_step_hook( subclasses of the `Collector`. This way, they can override the init to no longer accept the `on_step_hook` provider. """ - if self.on_step_hook is not None: - self.on_step_hook(action_batch, rollout_batch) + if self._on_step_hook is not None: + self._on_step_hook(action_batch, rollout_batch) def reset_env( self, @@ -592,7 +749,34 @@ def _collect( # noqa: C901 random: bool = False, render: float | None = None, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> CollectStats: + ) -> TCollectStats: + """This method is currently very complex, but it's difficult to break it down into smaller chunks. + + Please read the block-comment of the class to undestand the notation + in the implementation. + + It does the collection by executing the following logic: + + 0. Keep track of n_step and n_episode for being able to stop the collection. + 1. Create a CollectStats instance to store the statistics of the collection. + 2. Compute actions (with policy or sampling from action space) for the R currently active envs. + 3. Perform a step in these R envs. + 4. Perform on-step hooks on the result + 5. Update the CollectStats (using `update_at_step_batch`) and the internal counters after the step + 6. Add the resulting R transitions to the buffer + 7. Find the D envs that reached done in the current iteration + 8. Reset the envs that reached done + 9. Extract episodes for the envs that reached done from the buffer + 10. Perform on-episode-done hooks, modify the transitions belonging to the episodes inside the buffer inplace + 11. Update the CollectStats instance with the new episodes using `update_on_episode_done` + 12. Prepare next step in while loop by saving the last observations and infos + 13. Remove surplus envs from collection mechanism, thereby reducing R, to increase performance + 14. Check whether we added NaN's to the buffer and raise error if so + 15. Update instance-level collection counters (contrary to counters with a lifetime of the method call) + 16. Prepare for the next call of collect (save last observations and info to collector state) + + You can search for Step to find the place where it happens + """ # TODO: can't do it init since AsyncCollector is currently a subclass of Collector if self.env.is_async: raise ValueError( @@ -613,13 +797,13 @@ def _collect( # noqa: C901 else: raise RuntimeError("Input validation failed, this is a bug and shouldn't have happened") - start_time = time.time() if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: raise ValueError( "Initial obs and info should not be None. " "Either reset the collector (using reset or reset_env) or pass reset_before_collect=True to collect.", ) + # Step 0 # get the first obs to be the current obs in the n_step case as # episodes as a new call to collect does not restart trajectories # (which we also really don't want) @@ -629,6 +813,9 @@ def _collect( # noqa: C901 episode_lens: list[int] = [] episode_start_indices: list[int] = [] + # Step 1 + collect_stats = self.collect_stats_class() + # in case we select fewer episodes than envs, we run only some of them last_obs_RO = _nullable_slice(self._pre_collect_obs_RO, ready_env_ids_R) last_info_R = _nullable_slice(self._pre_collect_info_R, ready_env_ids_R) @@ -646,6 +833,7 @@ def _collect( # noqa: C901 # ) # restore the state: if the last state is None, it won't store + # Step 2 # get the next action and related stats from the previous observation collect_action_computation_batch_R = self._compute_action_policy_hidden( random=random, @@ -655,6 +843,7 @@ def _collect( # noqa: C901 last_hidden_state_RH=last_hidden_state_RH, ) + # Step 3 obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( collect_action_computation_batch_R.act_normalized, ready_env_ids_R, @@ -664,10 +853,11 @@ def _collect( # noqa: C901 info_R = _dict_of_arr_to_arr_of_dicts(info_R) # type: ignore[unreachable] done_R = np.logical_or(terminated_R, truncated_R) - current_iteration_batch_R = cast( - RolloutBatchProtocol, + current_step_batch_R = cast( + CollectStepBatchProtocol, Batch( obs=last_obs_RO, + dist=collect_action_computation_batch_R.dist, act=collect_action_computation_batch_R.act, policy=collect_action_computation_batch_R.policy_entry, obs_next=obs_next_RO, @@ -686,21 +876,31 @@ def _collect( # noqa: C901 if not np.isclose(render, 0): time.sleep(render) + # Step 4 self.run_on_step_hook( collect_action_computation_batch_R, - current_iteration_batch_R, - ) - # add data into the buffer - insertion_idx_R, ep_return_R, ep_len_R, ep_start_idx_R = self.buffer.add( - current_iteration_batch_R, - buffer_ids=ready_env_ids_R, + current_step_batch_R, ) - # collect statistics + # Step 5, collect statistics + collect_stats.update_at_step_batch(current_step_batch_R) num_episodes_done_this_iter = np.sum(done_R) num_collected_episodes += num_episodes_done_this_iter step_count += len(ready_env_ids_R) + # Step 6 + # add data into the buffer. Since the buffer is essentially an array, we don't want + # to add the dist. One should not have arrays of dists but rather a single, batch-wise dist. + # Tianshou already implements slicing of dists, but we don't yet implement merging multiple + # dists into one, which would be necessary to make a buffer with dists work properly + batch_to_add_R = copy(current_step_batch_R) + batch_to_add_R.pop("dist") + batch_to_add_R = cast(RolloutBatchProtocol, batch_to_add_R) + insertion_idx_R, ep_return_R, ep_len_R, ep_start_idx_R = self.buffer.add( + batch_to_add_R, + buffer_ids=ready_env_ids_R, + ) + # preparing for the next iteration # obs_next, info and hidden_state will be modified inplace in the code below, # so we copy to not affect the data in the buffer @@ -713,19 +913,26 @@ def _collect( # noqa: C901 if num_episodes_done_this_iter > 0: # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays # D - number of envs that reached done in the rollout above + # local_idx - see block comment on class level + # Step 7 env_done_local_idx_D = np.where(done_R)[0] - episode_lens.extend(ep_len_R[env_done_local_idx_D]) - episode_returns.extend(ep_return_R[env_done_local_idx_D]) - episode_start_indices.extend(ep_start_idx_R[env_done_local_idx_D]) + episode_lens_D = ep_len_R[env_done_local_idx_D] + episode_returns_D = ep_return_R[env_done_local_idx_D] + episode_start_indices_D = ep_start_idx_R[env_done_local_idx_D] + + episode_lens.extend(episode_lens_D) + episode_returns.extend(episode_returns_D) + episode_start_indices.extend(episode_start_indices_D) + + # Step 8 # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. - gym_reset_kwargs = gym_reset_kwargs or {} - # The index env_done_idx_D was based on 0, ..., R # However, each env has an index in the context of the vectorized env and buffer. So the env 0 being done means # that some env of the corresponding "global" index was done. The mapping between "local" index in - # 0,...,R and this global index is maintained by the ready_env_ids_R array + # 0,...,R and this global index is maintained by the ready_env_ids_R array. + # See the class block comment for more details env_done_global_idx_D = ready_env_ids_R[env_done_local_idx_D] obs_reset_DO, info_reset_D = self.env.reset( env_id=env_done_global_idx_D, @@ -737,19 +944,30 @@ def _collect( # noqa: C901 # this complex logic self._reset_hidden_state_based_on_type(env_done_local_idx_D, last_hidden_state_RH) + # Step 9 # execute episode hooks for those envs which emitted 'done' - for local_done_idx in env_done_local_idx_D: + for local_done_idx, cur_ep_return in zip( + env_done_local_idx_D, + episode_returns_D, + strict=True, + ): cur_ep_index_slice = slice( ep_start_idx_R[local_done_idx], insertion_idx_R[local_done_idx] + 1, ) - cur_ep_index_array, ep_rollout_batch = self._get_buffer_index_and_entries( + ( + cur_ep_index_array, + cur_ep_batch, + ) = self._get_buffer_index_and_entries_for_episode_from_slice( cur_ep_index_slice, ) - episode_hook_additions = self.run_on_episode_done(ep_rollout_batch) + cur_ep_batch = cast(EpisodeBatchProtocol, cur_ep_batch) + + # Step 10 + episode_hook_additions = self.run_on_episode_done(cur_ep_batch) if episode_hook_additions is not None: - if n_episode is not None: + if n_episode is None: raise ValueError( "An on_episode_done_hook with non-empty returns is not supported for n_step collection." "Such hooks should only be used when collecting full episodes. Got a on_episode_done_hook " @@ -762,11 +980,25 @@ def _collect( # noqa: C901 key, index=cur_ep_index_array, ) + # executing the same logic in the episode-batch since stats computation + # may depend on the presence of additional fields + cur_ep_batch.set_array_at_key( + episode_addition, + key, + ) + # Step 11 + # Finally, update the stats + collect_stats.update_at_episode_done( + episode_batch=cur_ep_batch, + episode_return=cur_ep_return, + ) + # Step 12 # preparing for the next iteration last_obs_RO[env_done_local_idx_D] = obs_reset_DO last_info_R[env_done_local_idx_D] = info_reset_D + # Step 13 # Handling the case when we have more ready envs than desired and are not done yet # # This can only happen if we are collecting a fixed number of episodes @@ -806,12 +1038,26 @@ def _collect( # noqa: C901 ): break - # generate statistics + # Step 14 + # Check if we screwed up somewhere + if self.buffer.hasnull(): + nan_batch = self.buffer.isnull().apply_values_transform(np.sum) + + raise MalformedBufferError( + "NaN detected in the buffer. You can drop them with `buffer.dropnull()`. " + "This error is most often caused by an incorrect use of `EpisodeRolloutHooks`" + "together with the `n_steps` (instead of `n_episodes`) option, or by " + "an incorrect implementation of `StepHook`." + "Here an overview of the number of NaNs per field: \n" + f"{nan_batch}", + ) + + # Step 15 + # update instance-lifetime counters, different from collect_stats self.collect_step += step_count self.collect_episode += num_collected_episodes - collect_time = max(time.time() - start_time, 1e-9) - self.collect_time += collect_time + # Step 16 if n_step: # persist for future collect iterations self._pre_collect_obs_RO = last_obs_RO @@ -820,27 +1066,10 @@ def _collect( # noqa: C901 elif n_episode: # reset envs and the _pre_collect fields self.reset_env(gym_reset_kwargs) # todo still necessary? - - if self.buffer.hasnull(): - nan_batch = self.buffer.isnull().apply_array_func(np.sum) - - raise MalformedBufferError( - "NaN detected in the buffer. You can drop them with `buffer.dropnull()`. " - "Here an overview of the number of NaNs per field: \n" - f"{nan_batch}", - ) - - return CollectStats.with_autogenerated_stats( - returns=np.array(episode_returns), - lens=np.array(episode_lens), - n_collected_episodes=num_collected_episodes, - n_collected_steps=step_count, - collect_time=collect_time, - collect_speed=step_count / collect_time, - ) + return collect_stats # TODO: move to buffer - def _get_buffer_index_and_entries( + def _get_buffer_index_and_entries_for_episode_from_slice( self, entries_slice: slice, ) -> tuple[np.ndarray, RolloutBatchProtocol]: @@ -850,6 +1079,17 @@ def _get_buffer_index_and_entries( :return: The indices of the entries in the buffer and the corresponding batch of entries. """ start, stop = entries_slice.start, entries_slice.stop + + # if isinstance(self.buffer, CachedReplayBuffer): + # # Accounting for the very special behavior of the CachedReplayBuffer, where once an episode is + # # finished, it is moved to the main buffer and the corresponding subbuffer is reset. + # # This means, that retrieving a slice corresponding to a finished episode should always happen + # # from the main buffer, whereas slices for unfinished episodes should always be retrieved from + # # the corresponding subbuffer + # # TODO: fix this behavior in CachedReplayBuffer, remove the special sauce here + # start = start % self.buffer.main_buffer.maxsize + # stop = stop % self.buffer.main_buffer.maxsize + if stop > start: cur_ep_index_array = np.arange( entries_slice.start, @@ -857,6 +1097,9 @@ def _get_buffer_index_and_entries( dtype=int, ) else: + # stop < start means that to retrieve the slice we have to cross an edge of the buffer + # We have to split the slice into two parts and concatenate the results + log.debug(f"Received an edge-crossing slice with {stop=} < {start=}") (start, upper_edge), ( lower_edge, stop, @@ -887,7 +1130,7 @@ def _reset_hidden_state_based_on_type( # todo is this inplace magic and just working? -class AsyncCollector(Collector): +class AsyncCollector(Collector[CollectStats]): """Async Collector handles async vector environment. Please refer to :class:`~tianshou.data.Collector` for a more detailed explanation. @@ -913,6 +1156,7 @@ def __init__( env, buffer, exploration_noise, + collect_stats_class=CollectStats, ) # E denotes the number of parallel environments: self.env_num # At init, E=R but during collection R <= E @@ -1004,7 +1248,7 @@ def _collect( ) # Each iteration of the AsyncCollector is only stepping a subset of the # envs. The last observation/ hidden state of the ones not included in - # the current iteration has to be retained. + # the current iteration has to be retained. This is done by copying the while True: # todo do we need this? # todo extend to all current attributes but some could be None at init @@ -1162,8 +1406,6 @@ def _collect( lens=np.array(episode_lens), n_collected_episodes=num_collected_episodes, n_collected_steps=step_count, - collect_time=collect_time, - collect_speed=step_count / collect_time, ) @@ -1222,14 +1464,14 @@ class EpisodeRolloutHookProtocol(Protocol): A prime example is something like the MC return to go. """ - def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray] | None: + def __call__(self, episode_batch: EpisodeBatchProtocol) -> dict[str, np.ndarray] | None: """Will be called by the collector when an episode is finished. If a dictionary is returned, the key-value pairs will be interpreted as new entries to be added to the episode batch (inside the buffer). In that case, the values should be arrays of the same length as the input `rollout_batch`. - :param rollout_batch: the batch of transitions that belong to the episode. + :param episode_batch: the batch of transitions that belong to the episode. :return: an optional dictionary containing new entries (of same len as `rollout_batch`) to be added to the buffer. """ @@ -1246,7 +1488,7 @@ class EpisodeRolloutHook(EpisodeRolloutHookProtocol, ABC): """ @abstractmethod - def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray] | None: + def __call__(self, episode_batch: EpisodeBatchProtocol) -> dict[str, np.ndarray] | None: ... @@ -1271,9 +1513,9 @@ def __init__(self, gamma: float = 0.99): def __call__( # type: ignore[override] self, - rollout_batch: RolloutBatchProtocol, + episode_batch: RolloutBatchProtocol, ) -> "EpisodeRolloutHookMCReturn.OutputDict": - mc_return_to_go = episode_mc_return_to_go(rollout_batch.rew, self.gamma) + mc_return_to_go = episode_mc_return_to_go(episode_batch.rew, self.gamma) full_episode_mc_return = np.full_like(mc_return_to_go, mc_return_to_go[0]) return self.OutputDict( @@ -1290,22 +1532,22 @@ class EpisodeRolloutHookMerged(EpisodeRolloutHook): def __init__( self, - *rollout_hooks: EpisodeRolloutHookProtocol, + *episode_rollout_hooks: EpisodeRolloutHookProtocol, check_overlapping_keys: bool = True, ): """ - :param rollout_hooks: the hooks to combine + :param episode_rollout_hooks: the hooks to combine :param check_overlapping_keys: whether to check for overlapping keys in the output of the hooks and raise a `KeyError` if any are found. Set to `False` to disable this check (can be useful if this becomes a performance bottleneck). """ - self.rollout_hooks = rollout_hooks + self.episode_rollout_hooks = episode_rollout_hooks self.check_overlapping_keys = check_overlapping_keys - def __call__(self, rollout_batch: RolloutBatchProtocol) -> dict[str, np.ndarray] | None: + def __call__(self, episode_batch: EpisodeBatchProtocol) -> dict[str, np.ndarray] | None: result: dict[str, np.ndarray] = {} - for rollout_hook in self.rollout_hooks: - new_entries = rollout_hook(rollout_batch) + for rollout_hook in self.episode_rollout_hooks: + new_entries = rollout_hook(episode_batch) if new_entries is None: continue From 58879cb5c98b3f558bd3820acb3392731641a4e8 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 18 Aug 2024 17:05:15 +0200 Subject: [PATCH 29/33] Typing: since Collector generic, CollectStats need to be passed at init for mypi --- docs/02_notebooks/L0_overview.ipynb | 19 ++++++--------- docs/02_notebooks/L5_Collector.ipynb | 6 ++--- docs/02_notebooks/L6_Trainer.ipynb | 31 +++++++++--------------- docs/02_notebooks/L7_Experiment.ipynb | 6 ++--- docs/spelling_wordlist.txt | 2 ++ examples/atari/atari_c51.py | 8 +++--- examples/atari/atari_dqn.py | 8 +++--- examples/atari/atari_fqf.py | 8 +++--- examples/atari/atari_iqn.py | 8 +++--- examples/atari/atari_ppo.py | 8 +++--- examples/atari/atari_qrdqn.py | 8 +++--- examples/atari/atari_rainbow.py | 13 +++++++--- examples/atari/atari_sac.py | 8 +++--- examples/box2d/acrobot_dualdqn.py | 6 ++--- examples/box2d/bipedal_bdq.py | 6 ++--- examples/box2d/bipedal_hardcore_sac.py | 6 ++--- examples/box2d/lunarlander_dqn.py | 6 ++--- examples/box2d/mcc_sac.py | 6 ++--- examples/discrete/discrete_dqn.py | 7 +++--- examples/inverse/irl_gail.py | 12 ++++++--- examples/mujoco/fetch_her_ddpg.py | 5 ++-- examples/mujoco/mujoco_a2c.py | 6 ++--- examples/mujoco/mujoco_ddpg.py | 6 ++--- examples/mujoco/mujoco_npg.py | 6 ++--- examples/mujoco/mujoco_ppo.py | 6 ++--- examples/mujoco/mujoco_redq.py | 6 ++--- examples/mujoco/mujoco_reinforce.py | 6 ++--- examples/mujoco/mujoco_sac.py | 6 ++--- examples/mujoco/mujoco_td3.py | 6 ++--- examples/mujoco/mujoco_trpo.py | 6 ++--- examples/offline/atari_bcq.py | 4 +-- examples/offline/atari_cql.py | 4 +-- examples/offline/atari_crr.py | 4 +-- examples/offline/atari_il.py | 4 +-- examples/offline/d4rl_bcq.py | 6 ++--- examples/offline/d4rl_cql.py | 6 ++--- examples/offline/d4rl_il.py | 6 ++--- examples/offline/d4rl_td3_bc.py | 6 ++--- examples/vizdoom/vizdoom_c51.py | 8 +++--- examples/vizdoom/vizdoom_ppo.py | 8 +++--- test/base/test_batch.py | 3 ++- test/base/test_env_finite.py | 6 ++--- test/continuous/test_ddpg.py | 6 ++--- test/continuous/test_npg.py | 6 ++--- test/continuous/test_ppo.py | 6 ++--- test/continuous/test_redq.py | 6 ++--- test/continuous/test_sac_with_il.py | 8 +++--- test/continuous/test_td3.py | 6 ++--- test/continuous/test_trpo.py | 6 ++--- test/discrete/test_a2c_with_il.py | 8 +++--- test/discrete/test_bdq.py | 6 ++--- test/discrete/test_c51.py | 5 ++-- test/discrete/test_dqn.py | 5 ++-- test/discrete/test_drqn.py | 6 ++--- test/discrete/test_fqf.py | 5 ++-- test/discrete/test_iqn.py | 5 ++-- test/discrete/test_pg.py | 6 ++--- test/discrete/test_ppo.py | 6 ++--- test/discrete/test_qrdqn.py | 11 ++++++--- test/discrete/test_rainbow.py | 11 ++++++--- test/discrete/test_sac.py | 6 ++--- test/modelbased/test_dqn_icm.py | 11 ++++++--- test/modelbased/test_ppo_icm.py | 6 ++--- test/modelbased/test_psrl.py | 6 ++--- test/offline/gather_cartpole_data.py | 13 +++++++--- test/offline/gather_pendulum_data.py | 6 ++--- test/offline/test_bcq.py | 8 +++--- test/offline/test_cql.py | 6 ++--- test/offline/test_discrete_bcq.py | 9 +++++-- test/offline/test_discrete_cql.py | 9 +++++-- test/offline/test_discrete_crr.py | 9 +++++-- test/offline/test_gail.py | 6 ++--- test/offline/test_td3_bc.py | 6 ++--- test/pettingzoo/pistonball.py | 8 +++--- test/pettingzoo/pistonball_continuous.py | 8 +++--- test/pettingzoo/tic_tac_toe.py | 8 +++--- tianshou/highlevel/agent.py | 11 ++++++--- tianshou/highlevel/experiment.py | 4 +-- 78 files changed, 305 insertions(+), 259 deletions(-) diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index 0ce6df154..8e3616acd 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -15,15 +15,6 @@ "Before we get started, we must first install Tianshou's library and Gym environment by running the commands below. This tutorials will always keep up with the latest version of Tianshou since they also serve as a test for the latest version. If you are using an older version of Tianshou, please refer to the [documentation](https://tianshou.readthedocs.io/en/latest/) of your version.\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# !pip install tianshou gym" - ] - }, { "cell_type": "markdown", "metadata": { @@ -67,7 +58,7 @@ "import gymnasium as gym\n", "import torch\n", "\n", - "from tianshou.data import Collector, VectorReplayBuffer\n", + "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", "from tianshou.policy import PPOPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", @@ -114,8 +105,12 @@ ")\n", "\n", "# collector\n", - "train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n", - "test_collector = Collector(policy, test_envs)\n", + "train_collector = Collector[CollectStats](\n", + " policy,\n", + " train_envs,\n", + " VectorReplayBuffer(20000, len(train_envs)),\n", + ")\n", + "test_collector = Collector[CollectStats](policy, test_envs)\n", "\n", "# trainer\n", "train_result = OnpolicyTrainer(\n", diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb index 7da98a5cf..41d12c923 100644 --- a/docs/02_notebooks/L5_Collector.ipynb +++ b/docs/02_notebooks/L5_Collector.ipynb @@ -58,7 +58,7 @@ "import gymnasium as gym\n", "import torch\n", "\n", - "from tianshou.data import Collector, VectorReplayBuffer\n", + "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", "from tianshou.policy import PGPolicy\n", "from tianshou.utils.net.common import Net\n", @@ -94,7 +94,7 @@ " action_space=env.action_space,\n", " action_scaling=False,\n", ")\n", - "test_collector = Collector(policy, test_envs)" + "test_collector = Collector[CollectStats](policy, test_envs)" ] }, { @@ -187,7 +187,7 @@ "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n", "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", "\n", - "train_collector = Collector(policy, train_envs, replayBuffer)" + "train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)" ] }, { diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index d5423bd01..d7023553f 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -54,6 +54,7 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "editable": true, "id": "do-xZ-8B7nVH", @@ -63,37 +64,29 @@ "tags": [ "hide-cell", "remove-output" - ], - "ExecuteTime": { - "end_time": "2024-05-06T15:34:02.969675Z", - "start_time": "2024-05-06T15:34:00.747309Z" - } + ] }, + "outputs": [], "source": [ "%%capture\n", "\n", "import gymnasium as gym\n", "import torch\n", "\n", - "from tianshou.data import Collector, VectorReplayBuffer\n", + "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", "from tianshou.policy import PGPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import Net\n", "from tianshou.utils.net.discrete import Actor\n", "from tianshou.utils.torch_utils import policy_within_training_step, torch_train_mode" - ], - "outputs": [], - "execution_count": 1 + ] }, { "cell_type": "code", - "metadata": { - "ExecuteTime": { - "end_time": "2024-05-06T15:34:07.536452Z", - "start_time": "2024-05-06T15:34:03.636670Z" - } - }, + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "train_env_num = 4\n", "buffer_size = (\n", @@ -129,11 +122,9 @@ "\n", "# Create the replay buffer and the collector\n", "replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n", - "test_collector = Collector(policy, test_envs)\n", - "train_collector = Collector(policy, train_envs, replayBuffer)" - ], - "outputs": [], - "execution_count": 2 + "test_collector = Collector[CollectStats](policy, test_envs)\n", + "train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)" + ] }, { "cell_type": "markdown", diff --git a/docs/02_notebooks/L7_Experiment.ipynb b/docs/02_notebooks/L7_Experiment.ipynb index 0a6675cb7..47e4cb0c9 100644 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ b/docs/02_notebooks/L7_Experiment.ipynb @@ -71,7 +71,7 @@ "import gymnasium as gym\n", "import torch\n", "\n", - "from tianshou.data import Collector, VectorReplayBuffer\n", + "from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", "from tianshou.policy import PPOPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", @@ -202,12 +202,12 @@ }, "outputs": [], "source": [ - "train_collector = Collector(\n", + "train_collector = Collector[CollectStats](\n", " policy=policy,\n", " env=train_envs,\n", " buffer=VectorReplayBuffer(20000, len(train_envs)),\n", ")\n", - "test_collector = Collector(policy=policy, env=test_envs)" + "test_collector = Collector[CollectStats](policy=policy, env=test_envs)" ] }, { diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index fa5a0066d..25eaa526c 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -286,3 +286,5 @@ gaussian logprob monte carlo +subclass +subclassing diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index d611ab196..587e1b19b 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -9,7 +9,7 @@ from atari_network import C51 from atari_wrapper import make_atari_env -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51Policy from tianshou.policy.base import BasePolicy @@ -112,8 +112,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -173,7 +173,7 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector(policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index eeb9bccce..8d0a2fdfe 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -9,7 +9,7 @@ from atari_network import DQN from atari_wrapper import make_atari_env -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DQNPolicy from tianshou.policy.base import BasePolicy @@ -148,8 +148,8 @@ def main(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -215,7 +215,7 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector(policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 58aff46ac..0a544eeac 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -9,7 +9,7 @@ from atari_network import DQN from atari_wrapper import make_atari_env -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import FQFPolicy from tianshou.policy.base import BasePolicy @@ -125,8 +125,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -186,7 +186,7 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector(policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index c6090523d..5fa3638b2 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -9,7 +9,7 @@ from atari_network import DQN from atari_wrapper import make_atari_env -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import IQNPolicy from tianshou.policy.base import BasePolicy @@ -122,8 +122,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -183,7 +183,7 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector(policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index f5e585c79..b1a5ae308 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -11,7 +11,7 @@ from torch.distributions import Categorical from torch.optim.lr_scheduler import LambdaLR -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ICMPolicy, PPOPolicy from tianshou.policy.base import BasePolicy @@ -190,8 +190,8 @@ def dist(logits: torch.Tensor) -> Categorical: stack_num=args.frames_stack, ) # collector - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -243,7 +243,7 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector(policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index b9731316e..9a451f403 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -9,7 +9,7 @@ from atari_network import QRDQN from atari_wrapper import make_atari_env -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import QRDQNPolicy from tianshou.policy.base import BasePolicy @@ -116,8 +116,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -177,7 +177,7 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector(policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 952d35f07..ab5ede3cf 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -9,7 +9,12 @@ from atari_network import Rainbow from atari_wrapper import make_atari_env -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Collector, + CollectStats, + PrioritizedVectorReplayBuffer, + VectorReplayBuffer, +) from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51Policy, RainbowPolicy from tianshou.policy.base import BasePolicy @@ -142,8 +147,8 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: weight_norm=not args.no_weight_norm, ) # collector - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -213,7 +218,7 @@ def watch() -> None: alpha=args.alpha, beta=args.beta, ) - collector = Collector(policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 4d01a88aa..cb589d83e 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -9,7 +9,7 @@ from atari_network import DQN from atari_wrapper import make_atari_env -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteSACPolicy, ICMPolicy from tianshou.policy.base import BasePolicy @@ -173,8 +173,8 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -226,7 +226,7 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector(policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 365c073fa..b25f35c15 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -7,7 +7,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy from tianshou.policy.base import BasePolicy @@ -84,13 +84,13 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: target_update_freq=args.target_update_freq, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index c817831b1..d88379b23 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -8,7 +8,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv from tianshou.policy import BranchingDQNPolicy from tianshou.policy.base import BasePolicy @@ -109,13 +109,13 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: target_update_freq=args.target_update_freq, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector(policy, test_envs, exploration_noise=False) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=False) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 66e5f316d..b377d7bb1 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -9,7 +9,7 @@ from gymnasium.core import WrapperActType, WrapperObsType from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import SubprocVectorEnv from tianshou.policy import SACPolicy from tianshou.policy.base import BasePolicy @@ -163,13 +163,13 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: print("Loaded agent from: ", args.resume_path) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index f9bbd6fa6..347da2cf9 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -7,7 +7,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import DQNPolicy from tianshou.policy.base import BasePolicy @@ -86,13 +86,13 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: target_update_freq=args.target_update_freq, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 7617b7b43..452eb02d6 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -7,7 +7,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import OUNoise from tianshou.policy import SACPolicy @@ -109,13 +109,13 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: action_space=env.action_space, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 3ba22a40c..4e52f4ce2 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -3,6 +3,7 @@ from torch.utils.tensorboard import SummaryWriter import tianshou as ts +from tianshou.data import CollectStats from tianshou.utils.space_info import SpaceInfo @@ -42,13 +43,13 @@ def main() -> None: estimation_step=n_step, target_update_freq=target_freq, ) - train_collector = ts.data.Collector( + train_collector = ts.data.Collector[CollectStats]( policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True, ) - test_collector = ts.data.Collector( + test_collector = ts.data.Collector[CollectStats]( policy, test_envs, exploration_noise=True, @@ -81,7 +82,7 @@ def stop_fn(mean_rewards: float) -> bool: # watch performance policy.set_eps(eps_test) - collector = ts.data.Collector(policy, env, exploration_noise=True) + collector = ts.data.Collector[CollectStats](policy, env, exploration_noise=True) collector.collect(n_episode=100, render=1 / 35) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index e327fd490..815060d1c 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -15,7 +15,13 @@ from torch.optim.lr_scheduler import LambdaLR from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Batch, + Collector, + CollectStats, + ReplayBuffer, + VectorReplayBuffer, +) from tianshou.data.types import RolloutBatchProtocol from tianshou.env import SubprocVectorEnv, VectorEnvNormObs from tianshou.policy import GAILPolicy @@ -236,8 +242,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_gail' diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index bbf68c2fa..ee9b76e75 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -13,6 +13,7 @@ from tianshou.data import ( Collector, + CollectStats, HERReplayBuffer, HERVectorReplayBuffer, ReplayBuffer, @@ -211,8 +212,8 @@ def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: horizon=args.her_horizon, future_k=args.her_future_k, ) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 194d9b5de..817da079a 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -12,7 +12,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import A2CPolicy from tianshou.policy.base import BasePolicy @@ -172,8 +172,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index db90babb0..d85a14427 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -9,7 +9,7 @@ import torch from mujoco_env import make_mujoco_env -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DDPGPolicy @@ -120,8 +120,8 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 4d8530a53..9416376a1 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -12,7 +12,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import NPGPolicy from tianshou.policy.base import BasePolicy @@ -169,8 +169,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 7c3f268c8..965ec7739 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -12,7 +12,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PPOPolicy from tianshou.policy.base import BasePolicy @@ -177,8 +177,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 8951b03ac..61d85ae1c 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -9,7 +9,7 @@ import torch from mujoco_env import make_mujoco_env -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import REDQPolicy from tianshou.policy.base import BasePolicy @@ -148,8 +148,8 @@ def linear(x: int, y: int) -> EnsembleLinear: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index ff7e34099..95391e1ea 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -12,7 +12,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PGPolicy from tianshou.policy.base import BasePolicy @@ -149,8 +149,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index af1398380..151237580 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -9,7 +9,7 @@ import torch from mujoco_env import make_mujoco_env -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import SACPolicy from tianshou.policy.base import BasePolicy @@ -142,8 +142,8 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 6cc8bb212..a5e3e8cf6 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -9,7 +9,7 @@ import torch from mujoco_env import make_mujoco_env -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import TD3Policy @@ -140,8 +140,8 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index eefdfcc65..9405b2440 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -12,7 +12,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, ReplayBuffer, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import TRPOPolicy from tianshou.policy.base import BasePolicy @@ -174,8 +174,8 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: buffer = ReplayBuffer(args.buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 3af40cc7f..62c205076 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -14,7 +14,7 @@ from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteBCQPolicy from tianshou.policy.base import BasePolicy @@ -155,7 +155,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: print("Replay buffer size:", len(buffer), flush=True) # collector - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index b2c0c8705..436145c90 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -15,7 +15,7 @@ from examples.atari.atari_network import QRDQN from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteCQLPolicy from tianshou.policy.base import BasePolicy @@ -139,7 +139,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: print("Replay buffer size:", len(buffer), flush=True) # collector - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 8b6320a79..565908559 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -14,7 +14,7 @@ from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteCRRPolicy from tianshou.policy.base import BasePolicy @@ -156,7 +156,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: print("Replay buffer size:", len(buffer), flush=True) # collector - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 712e25042..16b7cdec3 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -13,7 +13,7 @@ from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env from examples.offline.utils import load_buffer -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ImitationPolicy from tianshou.policy.base import BasePolicy @@ -113,7 +113,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: print("Replay buffer size:", len(buffer), flush=True) # collector - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 9ed18262a..08b38ded6 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -11,7 +11,7 @@ from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl -from tianshou.data import Collector +from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv from tianshou.policy import BCQPolicy from tianshou.policy.base import BasePolicy @@ -174,7 +174,7 @@ def test_bcq() -> None: print("Loaded agent from: ", args.resume_path) # collector - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -206,7 +206,7 @@ def watch() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - collector = Collector(policy, env) + collector = Collector[CollectStats](policy, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 90d6b159c..5b68edf9e 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -11,7 +11,7 @@ from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl -from tianshou.data import Collector +from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv from tianshou.policy import CQLPolicy from tianshou.policy.base import BasePolicy @@ -312,7 +312,7 @@ def test_cql() -> None: print("Loaded agent from: ", args.resume_path) # collector - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -344,7 +344,7 @@ def watch() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - collector = Collector(policy, env) + collector = Collector[CollectStats](policy, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index e03deed80..e1e71fd82 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -11,7 +11,7 @@ from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl -from tianshou.data import Collector +from tianshou.data import Collector, CollectStats from tianshou.env import SubprocVectorEnv from tianshou.policy import ImitationPolicy from tianshou.policy.base import BasePolicy @@ -110,7 +110,7 @@ def test_il() -> None: print("Loaded agent from: ", args.resume_path) # collector - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -142,7 +142,7 @@ def watch() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - collector = Collector(policy, env) + collector = Collector[CollectStats](policy, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 6b448b320..f4e8b38c2 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -11,7 +11,7 @@ from torch.utils.tensorboard import SummaryWriter from examples.offline.utils import load_buffer_d4rl, normalize_all_obs_in_replay_buffer -from tianshou.data import Collector +from tianshou.data import Collector, CollectStats from tianshou.env import BaseVectorEnv, SubprocVectorEnv, VectorEnvNormObs from tianshou.exploration import GaussianNoise from tianshou.policy import TD3BCPolicy @@ -159,7 +159,7 @@ def test_td3_bc() -> None: print("Loaded agent from: ", args.resume_path) # collector - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -191,7 +191,7 @@ def watch() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - collector = Collector(policy, env) + collector = Collector[CollectStats](policy, env) collector.collect(n_episode=1, render=1 / 35) if not args.watch: diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 25ad80487..072421bc6 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -9,7 +9,7 @@ from env import make_vizdoom_env from network import C51 -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51Policy from tianshou.policy.base import BasePolicy @@ -120,8 +120,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) # collector - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -179,7 +179,7 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector(policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 0d72cc619..915bb20a2 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -11,7 +11,7 @@ from torch.distributions import Categorical from torch.optim.lr_scheduler import LambdaLR -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ICMPolicy, PPOPolicy from tianshou.policy.base import BasePolicy @@ -200,8 +200,8 @@ def dist(logits: torch.Tensor) -> Categorical: stack_num=args.frames_stack, ) # collector - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # log now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") @@ -245,7 +245,7 @@ def watch() -> None: save_only_last_obs=True, stack_num=args.frames_stack, ) - collector = Collector(policy, test_envs, buffer, exploration_noise=True) + collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True) result = collector.collect(n_step=args.buffer_size) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 7e250e881..9021c9d40 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -890,7 +890,8 @@ def test_len_batch_with_dist() -> None: assert len(batch_with_dist_sliced) == 2 assert np.array_equal(batch_with_dist_sliced.a, np.array([1, 2])) assert torch.allclose( - batch_with_dist_sliced.dist.probs, Categorical(torch.ones(2, 3)).probs + batch_with_dist_sliced.dist.probs, + Categorical(torch.ones(2, 3)).probs, ) with pytest.raises(TypeError): diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 287e79677..39cf9d3ea 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -11,7 +11,7 @@ from gymnasium.spaces import Box from torch.utils.data import DataLoader, Dataset, DistributedSampler -from tianshou.data import Batch, Collector +from tianshou.data import Batch, Collector, CollectStats from tianshou.data.types import ( ActBatchProtocol, BatchProtocol, @@ -248,7 +248,7 @@ def test_finite_dummy_vector_env() -> None: dataset = DummyDataset(100) envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) policy = AnyPolicy() - test_collector = Collector(policy, envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, envs, exploration_noise=True) test_collector.reset() for _ in range(3): @@ -264,7 +264,7 @@ def test_finite_subproc_vector_env() -> None: dataset = DummyDataset(100) envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) policy = AnyPolicy() - test_collector = Collector(policy, envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, envs, exploration_noise=True) test_collector.reset() for _ in range(3): diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 1aedadabf..ce7998eff 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -6,7 +6,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.policy import DDPGPolicy @@ -98,13 +98,13 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: action_space=env.action_space, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ddpg") writer = SummaryWriter(log_path) diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index 98803a9d2..d853e2186 100644 --- a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -8,7 +8,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import NPGPolicy from tianshou.policy.base import BasePolicy @@ -121,12 +121,12 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: deterministic_eval=True, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, "npg") writer = SummaryWriter(log_path) diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 15d834096..4b56bd630 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -7,7 +7,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import PPOPolicy from tianshou.policy.base import BasePolicy @@ -122,12 +122,12 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: action_space=env.action_space, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ppo") writer = SummaryWriter(log_path) diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index f627f7e4f..82c8f0637 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -7,7 +7,7 @@ import torch.nn as nn from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import REDQPolicy from tianshou.policy.base import BasePolicy @@ -127,13 +127,13 @@ def linear(x: int, y: int) -> nn.Module: action_space=env.action_space, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) # log diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 77a403359..09fc3ca45 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -6,7 +6,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import ImitationPolicy, SACPolicy from tianshou.policy.base import BasePolicy @@ -124,13 +124,13 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: action_space=env.action_space, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") @@ -184,7 +184,7 @@ def stop_fn(mean_rewards: float) -> bool: ) il_test_env = gym.make(args.task) il_test_env.reset(seed=args.seed + args.training_num + args.test_num) - il_test_collector = Collector( + il_test_collector = Collector[CollectStats]( il_policy, # envpool.make_gymnasium(args.task, num_envs=args.test_num, seed=args.seed), il_test_env, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 21a2cf40d..6c59ea25a 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -7,7 +7,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.policy import TD3Policy @@ -115,13 +115,13 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: action_space=env.action_space, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "td3") diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 8841891bf..91e215116 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -8,7 +8,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import TRPOPolicy from tianshou.policy.base import BasePolicy @@ -121,12 +121,12 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: max_backtracks=args.max_backtracks, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, "trpo") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 2fd41aff8..192f24c24 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -7,7 +7,7 @@ from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import A2CPolicy, ImitationPolicy from tianshou.policy.base import BasePolicy @@ -110,13 +110,13 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: action_space=env.action_space, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) train_collector.reset() - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) test_collector.reset() # log log_path = os.path.join(args.logdir, args.task, "a2c") @@ -171,7 +171,7 @@ def stop_fn(mean_rewards: float) -> bool: ) il_env.seed(args.seed) - il_test_collector = Collector( + il_test_collector = Collector[CollectStats]( il_policy, il_env, ) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 91c66bac0..16042f622 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -4,7 +4,7 @@ import numpy as np import torch -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import ContinuousToDiscrete, DummyVectorEnv from tianshou.policy import BranchingDQNPolicy from tianshou.trainer import OffpolicyTrainer @@ -106,13 +106,13 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: target_update_freq=args.target_update_freq, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, args.training_num), exploration_noise=True, ) - test_collector = Collector(policy, test_envs, exploration_noise=False) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=False) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 8b34ddb4b..41d6a0260 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -9,6 +9,7 @@ from tianshou.data import ( Collector, + CollectStats, PrioritizedVectorReplayBuffer, ReplayBuffer, VectorReplayBuffer, @@ -115,8 +116,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index 773004f2c..f82aca1f6 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -8,6 +8,7 @@ from tianshou.data import ( Collector, + CollectStats, PrioritizedVectorReplayBuffer, ReplayBuffer, VectorReplayBuffer, @@ -106,8 +107,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 193179097..4cc0b6bd0 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -6,7 +6,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy from tianshou.policy.base import BasePolicy @@ -89,9 +89,9 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: stack_num=args.stack_num, ignore_obs_next=True, ) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) # the stack_num is for RNN training: sample framestack obs - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 743293be0..e0899d315 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -8,6 +8,7 @@ from tianshou.data import ( Collector, + CollectStats, PrioritizedVectorReplayBuffer, ReplayBuffer, VectorReplayBuffer, @@ -123,8 +124,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index f7ea67adb..08f545b11 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -8,6 +8,7 @@ from tianshou.data import ( Collector, + CollectStats, PrioritizedVectorReplayBuffer, ReplayBuffer, VectorReplayBuffer, @@ -119,8 +120,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 60d0eb469..8a681583d 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -7,7 +7,7 @@ from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import PGPolicy from tianshou.policy.base import BasePolicy @@ -90,12 +90,12 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, "pg") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 27fe6f517..7e541fffb 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -8,7 +8,7 @@ from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import PPOPolicy from tianshou.policy.base import BasePolicy @@ -117,12 +117,12 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: recompute_advantage=args.recompute_adv, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ppo") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 76d7d429d..5aa543fb5 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -6,7 +6,12 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Collector, + CollectStats, + PrioritizedVectorReplayBuffer, + VectorReplayBuffer, +) from tianshou.env import DummyVectorEnv from tianshou.policy import QRDQNPolicy from tianshou.policy.base import BasePolicy @@ -108,8 +113,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index 0a73d4b77..d7d4b15b1 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -7,7 +7,12 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Collector, + CollectStats, + PrioritizedVectorReplayBuffer, + VectorReplayBuffer, +) from tianshou.env import DummyVectorEnv from tianshou.policy import RainbowPolicy from tianshou.policy.base import BasePolicy @@ -123,8 +128,8 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index f16e59daf..3409dab0a 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -6,7 +6,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import DiscreteSACPolicy from tianshou.policy.base import BasePolicy @@ -106,12 +106,12 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: estimation_step=args.n_step, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "discrete_sac") diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 5ef0bba65..c108e7c0f 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -6,7 +6,12 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Collector, + CollectStats, + PrioritizedVectorReplayBuffer, + VectorReplayBuffer, +) from tianshou.env import DummyVectorEnv from tianshou.policy import DQNPolicy, ICMPolicy from tianshou.policy.base import BasePolicy @@ -149,8 +154,8 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf, exploration_noise=True) - test_collector = Collector(policy, test_envs, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 77f9a40e1..7d3780960 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -7,7 +7,7 @@ from gymnasium.spaces import Box from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import ICMPolicy, PPOPolicy from tianshou.policy.base import BasePolicy @@ -155,12 +155,12 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: forward_loss_weight=args.forward_loss_weight, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, "ppo_icm") writer = SummaryWriter(log_path) diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 995aef698..e79977381 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -6,7 +6,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.policy import PSRLPolicy from tianshou.trainer import OnpolicyTrainer from tianshou.utils import LazyLogger, TensorboardLogger, WandbLogger @@ -77,14 +77,14 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None: add_done_loop=args.add_done_loop, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) train_collector.reset() - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) test_collector.reset() # Logger log_path = os.path.join(args.logdir, args.task, "psrl") diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 19ba653e5..bee9063ea 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -7,7 +7,12 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Collector, + CollectStats, + PrioritizedVectorReplayBuffer, + VectorReplayBuffer, +) from tianshou.env import DummyVectorEnv from tianshou.policy import QRDQNPolicy from tianshou.policy.base import BasePolicy @@ -113,9 +118,9 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: else: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector - train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + train_collector = Collector[CollectStats](policy, train_envs, buf, exploration_noise=True) train_collector.reset() - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) test_collector.reset() # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) @@ -165,7 +170,7 @@ def test_fn(epoch: int, env_step: int | None) -> None: # save buffer in pickle format, for imitation learning unittest buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) policy.set_eps(0.2) - collector = Collector(policy, test_envs, buf, exploration_noise=True) + collector = Collector[CollectStats](policy, test_envs, buf, exploration_noise=True) collector.reset() collector_stats = collector.collect(n_step=args.buffer_size) if args.save_buffer_name.endswith(".hdf5"): diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index bc46ce4da..614ee388f 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -7,7 +7,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import SACPolicy from tianshou.policy.base import BasePolicy @@ -129,8 +129,8 @@ def gather_data() -> VectorReplayBuffer: ) # collector buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log log_path = os.path.join(args.logdir, args.task, "sac") diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 8b31c1969..2ed910902 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -9,7 +9,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import BasePolicy, BCQPolicy from tianshou.policy.imitation.bcq import BCQTrainingStats @@ -164,8 +164,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: # collector # buffer has been gathered - # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + # train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_bcq' @@ -184,7 +184,7 @@ def watch() -> None: policy.load_state_dict( torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")), ) - collector = Collector(policy, env) + collector = Collector[CollectStats](policy, env) collector.collect(n_episode=1, render=1 / 35) # trainer diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 41d67151a..bd84098ba 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -10,7 +10,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import BasePolicy, CQLPolicy from tianshou.policy.imitation.cql import CQLTrainingStats @@ -165,8 +165,8 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: # collector # buffer has been gathered - # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + # train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql' diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 6e8e8784a..e69e0a1fa 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -8,7 +8,12 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Collector, + CollectStats, + PrioritizedVectorReplayBuffer, + VectorReplayBuffer, +) from tianshou.env import DummyVectorEnv from tianshou.policy import BasePolicy, DiscreteBCQPolicy from tianshou.trainer import OfflineTrainer @@ -107,7 +112,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: buffer = gather_data() # collector - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, "discrete_bcq") writer = SummaryWriter(log_path) diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index f2a60e00c..97766d494 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -8,7 +8,12 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Collector, + CollectStats, + PrioritizedVectorReplayBuffer, + VectorReplayBuffer, +) from tianshou.env import DummyVectorEnv from tianshou.policy import BasePolicy, DiscreteCQLPolicy from tianshou.trainer import OfflineTrainer @@ -96,7 +101,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: buffer = gather_data() # collector - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, "discrete_cql") writer = SummaryWriter(log_path) diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index bc54dd9d0..bf9a833a9 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -8,7 +8,12 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.data import ( + Collector, + CollectStats, + PrioritizedVectorReplayBuffer, + VectorReplayBuffer, +) from tianshou.env import DummyVectorEnv from tianshou.policy import BasePolicy, DiscreteCRRPolicy from tianshou.trainer import OfflineTrainer @@ -100,7 +105,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: buffer = gather_data() # collector - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) log_path = os.path.join(args.logdir, args.task, "discrete_crr") writer = SummaryWriter(log_path) diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index ea13f484c..c7f183587 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -9,7 +9,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.policy import BasePolicy, GAILPolicy from tianshou.trainer import OnpolicyTrainer @@ -160,12 +160,12 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: action_space=env.action_space, ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # log log_path = os.path.join(args.logdir, args.task, "gail") writer = SummaryWriter(log_path) diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index fa01444ab..17c3afb06 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -9,7 +9,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.exploration import GaussianNoise from tianshou.policy import TD3BCPolicy @@ -152,8 +152,8 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None: # collector # buffer has been gathered - # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) + # train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs) # log t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_td3_bc' diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index c57522df0..0cf269d4f 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -8,7 +8,7 @@ from pettingzoo.butterfly import pistonball_v6 from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, InfoStats, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, InfoStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager @@ -128,13 +128,13 @@ def train_agent( policy, optim, agents = get_agents(args, agents=agents, optims=optims) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) # log @@ -189,6 +189,6 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non ) policy, _, _ = get_agents(args) [agent.set_eps(args.eps_test) for agent in policy.policies.values()] - collector = Collector(policy, env, exploration_noise=True) + collector = Collector[CollectStats](policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) result.pprint_asdict() diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 38de81173..7beb92fde 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -11,7 +11,7 @@ from torch.distributions import Distribution, Independent, Normal from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv @@ -234,13 +234,13 @@ def train_agent( policy, optim, agents = get_agents(args, agents=agents, optims=optims) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=False, # True ) - test_collector = Collector(policy, test_envs) + test_collector = Collector[CollectStats](policy, test_envs) # train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, "pistonball", "dqn") @@ -284,6 +284,6 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non "watching random agents, as loading pre-trained policies is currently not supported", ) policy, _, _ = get_agents(args) - collector = Collector(policy, env) + collector = Collector[CollectStats](policy, env) collector_result = collector.collect(n_episode=1, render=args.render) collector_result.pprint_asdict() diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 966c9e04c..085ad7517 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -9,7 +9,7 @@ from pettingzoo.classic import tictactoe_v3 from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, CollectStats, VectorReplayBuffer from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv @@ -163,13 +163,13 @@ def train_agent( ) # collector - train_collector = Collector( + train_collector = Collector[CollectStats]( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) - test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.reset() train_collector.collect(n_step=args.batch_size * args.training_num) @@ -229,6 +229,6 @@ def watch( env = DummyVectorEnv([partial(get_env, render_mode="human")]) policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent) policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) - collector = Collector(policy, env, exploration_noise=True) + collector = Collector[CollectStats](policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) result.pprint_asdict() diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 81141a8a6..c9d1cf58f 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -7,7 +7,7 @@ from sensai.util.string import ToStringMixin from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer -from tianshou.data.collector import BaseCollector +from tianshou.data.collector import BaseCollector, CollectStats from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import Environments from tianshou.highlevel.module.actor import ( @@ -120,8 +120,13 @@ def create_train_test_collector( save_only_last_obs=self.sampling_config.replay_buffer_save_only_last_obs, ignore_obs_next=self.sampling_config.replay_buffer_ignore_obs_next, ) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, envs.test_envs) + train_collector = Collector[CollectStats]( + policy, + train_envs, + buffer, + exploration_noise=True, + ) + test_collector = Collector[CollectStats](policy, envs.test_envs) if reset_collectors: train_collector.reset() test_collector.reset() diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 74413997b..43698228f 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -36,7 +36,7 @@ from sensai.util.logging import datetime_tag from sensai.util.string import ToStringMixin -from tianshou.data import Collector, InfoStats +from tianshou.data import Collector, CollectStats, InfoStats from tianshou.env import BaseVectorEnv from tianshou.highlevel.agent import ( A2CAgentFactory, @@ -451,7 +451,7 @@ def _watch_agent( env: BaseVectorEnv, render: float, ) -> None: - collector = Collector(policy, env) + collector = Collector[CollectStats](policy, env) collector.reset() result = collector.collect(n_episode=num_episodes, render=render) assert result.returns_stat is not None # for mypy From 5098d321bcff349acaadf6e24e6a43b97254e5af Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 18 Aug 2024 17:34:51 +0200 Subject: [PATCH 30/33] nb-clean --- docs/02_notebooks/L6_Trainer.ipynb | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index dfcbd0f3c..cc5b664ed 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -56,10 +56,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "ExecuteTime": { - "end_time": "2024-05-06T15:34:02.969675Z", - "start_time": "2024-05-06T15:34:00.747309Z" - }, "editable": true, "id": "do-xZ-8B7nVH", "slideshow": { From 419f3c52e594314e161f0f25dbb9bd534770d6a4 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Sun, 18 Aug 2024 17:51:30 +0200 Subject: [PATCH 31/33] Typos [ci skip] --- tianshou/data/collector.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 7615e725d..15737ce6b 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -541,12 +541,12 @@ class Collector(BaseCollector[TCollectStats], Generic[TCollectStats]): # S - number of surplus envs, i.e., envs that are ready but won't be used in the next iteration. # Only used in n_episode case. Then, R becomes R-S. # local_index - selecting from the locally available environments. In more details: - # Each env is associated to an number in [0,..., N-1]. At any moment there are R ready envs, + # Each env is associated to a number in [0,..., N-1]. At any moment there are R ready envs, # but they are not necessarily equal to [0, ..., R-1]. Let the R corresponding indices be # [r_0, ..., r_(R-1)] (each r_i is in [0, ... N-1]). If the local index is # [0, 1, 2], it means that we want to select envs [r_0, r_1, r_2]. # We will usually select from the ready envs by slicing like `ready_env_idx_R[local_index]` - # global_index - the index in [0, ..., N-1]. Slicing the an `_R` index by a local_index produces the + # global_index - the index in [0, ..., N-1]. Slicing the a `_R` index by a local_index produces the # corresponding global index. In the example above: # 1. _R index is [r_0, ..., r_(R-1)] # 2. local_index is [0, 1, 2] @@ -564,7 +564,7 @@ def __init__( collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] ) -> None: """ - :param policy: a tianshou policy, each :class:`BasePolocy` is capable of computing a batch + :param policy: a tianshou policy, each :class:`BasePolicy` is capable of computing a batch of actions from a batch of observations. :param env: a ``gymnasium.Env`` environment or a vectorized instance of the :class:`~tianshou.env.BaseVectorEnv` class. The latter is strongly recommended, as with @@ -758,24 +758,24 @@ def _collect( # noqa: C901 It does the collection by executing the following logic: 0. Keep track of n_step and n_episode for being able to stop the collection. - 1. Create a CollectStats instance to store the statistics of the collection. - 2. Compute actions (with policy or sampling from action space) for the R currently active envs. - 3. Perform a step in these R envs. - 4. Perform on-step hooks on the result - 5. Update the CollectStats (using `update_at_step_batch`) and the internal counters after the step - 6. Add the resulting R transitions to the buffer - 7. Find the D envs that reached done in the current iteration - 8. Reset the envs that reached done - 9. Extract episodes for the envs that reached done from the buffer - 10. Perform on-episode-done hooks, modify the transitions belonging to the episodes inside the buffer inplace - 11. Update the CollectStats instance with the new episodes using `update_on_episode_done` + 1. Create a CollectStats instance to store the statistics of the collection. + 2. Compute actions (with policy or sampling from action space) for the R currently active envs. + 3. Perform a step in these R envs. + 4. Perform on-step hook on the result + 5. Update the CollectStats (using `update_at_step_batch`) and the internal counters after the step + 6. Add the resulting R transitions to the buffer + 7. Find the D envs that reached done in the current iteration + 8. Reset the envs that reached done + 9. Extract episodes for the envs that reached done from the buffer + 10. Perform on-episode-done hook. If it has a return, modify the transitions belonging to the episodes inside the buffer inplace + 11. Update the CollectStats instance with the episodes from 9. by using `update_on_episode_done` 12. Prepare next step in while loop by saving the last observations and infos - 13. Remove surplus envs from collection mechanism, thereby reducing R, to increase performance + 13. Remove S surplus envs from collection mechanism, thereby reducing R to R-S, to increase performance 14. Check whether we added NaN's to the buffer and raise error if so - 15. Update instance-level collection counters (contrary to counters with a lifetime of the method call) + 15. Update instance-level collection counters (contrary to counters with a lifetime of the collect execution) 16. Prepare for the next call of collect (save last observations and info to collector state) - You can search for Step to find the place where it happens + You can search for Step to find where it happens """ # TODO: can't do it init since AsyncCollector is currently a subclass of Collector if self.env.is_async: From 0a655523aaf8024222d325535e5cd26426585ee1 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Tue, 20 Aug 2024 14:32:55 +0200 Subject: [PATCH 32/33] Collector, fixed omission: make use of `raise_on_nan_in_buffer` --- tianshou/data/collector.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 15737ce6b..49243cfee 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -284,6 +284,7 @@ def __init__( exploration_noise: bool = False, # The typing is correct, there's a bug in mypy, see https://github.com/python/mypy/issues/3737 collect_stats_class: type[TCollectStats] = CollectStats, # type: ignore[assignment] + raise_on_nan_in_buffer: bool = True, ) -> None: if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") @@ -294,6 +295,7 @@ def __init__( buffer = VectorReplayBuffer(DEFAULT_BUFFER_MAXSIZE * len(env), len(env)) self.buffer: ReplayBuffer | ReplayBufferManager = buffer + self.raise_on_nan_in_buffer = raise_on_nan_in_buffer self.policy = policy self.env = cast(BaseVectorEnv, env) self.exploration_noise = exploration_noise @@ -499,6 +501,19 @@ def collect( collect_time = time.time() - pre_collect_time collect_stats.set_collect_time(collect_time, update_collect_speed=True) collect_stats.refresh_all_sequence_stats() + + if self.raise_on_nan_in_buffer and self.buffer.hasnull(): + nan_batch = self.buffer.isnull().apply_values_transform(np.sum) + + raise MalformedBufferError( + "NaN detected in the buffer. You can drop them with `buffer.dropnull()`. " + f"This error is most often caused by an incorrect use of {EpisodeRolloutHook.__name__}" + "together with the `n_steps` (instead of `n_episodes`) option, or by " + f"an incorrect implementation of {StepHook.__name__}." + "Here an overview of the number of NaNs per field: \n" + f"{nan_batch}", + ) + return collect_stats def _validate_n_step_n_episode(self, n_episode: int | None, n_step: int | None) -> None: @@ -608,7 +623,9 @@ def __init__( buffer, exploration_noise=exploration_noise, collect_stats_class=collect_stats_class, + raise_on_nan_in_buffer=raise_on_nan_in_buffer, ) + self._pre_collect_obs_RO: np.ndarray | None = None self._pre_collect_info_R: np.ndarray | None = None self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None @@ -1142,6 +1159,7 @@ def __init__( env: BaseVectorEnv, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, + raise_on_nan_in_buffer: bool = True, ) -> None: if not env.is_async: # TODO: raise an exception? @@ -1157,6 +1175,7 @@ def __init__( buffer, exploration_noise, collect_stats_class=CollectStats, + raise_on_nan_in_buffer=raise_on_nan_in_buffer, ) # E denotes the number of parallel environments: self.env_num # At init, E=R but during collection R <= E From bd58581437f7637e2bd9f1c509fd0e66b7276326 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Tue, 20 Aug 2024 14:34:41 +0200 Subject: [PATCH 33/33] Block comment [ci skip] --- tianshou/data/collector.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 49243cfee..2c615923d 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -788,9 +788,8 @@ def _collect( # noqa: C901 11. Update the CollectStats instance with the episodes from 9. by using `update_on_episode_done` 12. Prepare next step in while loop by saving the last observations and infos 13. Remove S surplus envs from collection mechanism, thereby reducing R to R-S, to increase performance - 14. Check whether we added NaN's to the buffer and raise error if so - 15. Update instance-level collection counters (contrary to counters with a lifetime of the collect execution) - 16. Prepare for the next call of collect (save last observations and info to collector state) + 14. Update instance-level collection counters (contrary to counters with a lifetime of the collect execution) + 15. Prepare for the next call of collect (save last observations and info to collector state) You can search for Step to find where it happens """ @@ -1055,7 +1054,6 @@ def _collect( # noqa: C901 ): break - # Step 14 # Check if we screwed up somewhere if self.buffer.hasnull(): nan_batch = self.buffer.isnull().apply_values_transform(np.sum) @@ -1069,12 +1067,12 @@ def _collect( # noqa: C901 f"{nan_batch}", ) - # Step 15 + # Step 14 # update instance-lifetime counters, different from collect_stats self.collect_step += step_count self.collect_episode += num_collected_episodes - # Step 16 + # Step 15 if n_step: # persist for future collect iterations self._pre_collect_obs_RO = last_obs_RO