From 96aa21980eb7f015bcb645f5674675c590a58c65 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Fri, 8 Sep 2023 15:08:30 +0200 Subject: [PATCH 01/31] Add IndependentMeanMetric --- sheeprl/utils/metric.py | 81 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/sheeprl/utils/metric.py b/sheeprl/utils/metric.py index dbf98e00..63795444 100644 --- a/sheeprl/utils/metric.py +++ b/sheeprl/utils/metric.py @@ -1,8 +1,13 @@ from collections import deque -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Sequence, Union import torch +import torch.distributed as dist +from lightning.fabric.utilities.distributed import _distributed_available +from torch import Tensor +from torch.distributed.distributed_c10d import ProcessGroup from torchmetrics import Metric +from torchmetrics.aggregation import CatMetric class MetricAggregatorException(Exception): @@ -88,6 +93,80 @@ def compute(self) -> Dict[str, torch.Tensor]: return reduced_metrics +class IndependentMeanMetric: + def __init__( + self, + names: Sequence[str], + device: Union[str, torch.device] = "cpu", + process_group: Optional[ProcessGroup] = None, + ) -> None: + """Collection of N independent mean metrics, where `N` is given by the + length of the `names` parameter. This metric is useful when one wants to + maintain averages of some quantities, while still being able to broadcast them + to all the processes in a `torch.distributed` group. + + Args: + names (Sequence[str]): the names of the metrics. + device (Union[str, torch.device], optional): the device where the metrics reside. + Defaults to "cpu". + process_group (Optional[ProcessGroup], optional): the distributed process group. + Defaults to None. + """ + super().__init__() + if len(names) <= 0: + raise ValueError(f"`names` length must be greater than 0: got {len(names)}") + self._names = names + self._device = device + self._metrics: dict[str, Metric] = {} + for n in names: + m = CatMetric(sync_on_compute=False) + self._metrics[n] = m.to(self._device) + self._process_group = process_group if process_group is not None else torch.distributed.group.WORLD + + def update(self, value: float, name: str) -> None: + self._metrics[name].update(value) + + @torch.no_grad() + def compute(self) -> Dict[str, Tensor]: + """Compute the means, one for every metric. The metrics are first broadcasted + + Returns: + Dict[str, Tensor]: _description_ + """ + computed_metrics = {} + for k, v in self._metrics.items(): + computed_v = v.compute() + if not isinstance(computed_v, Tensor): + computed_metrics[k] = torch.tensor(computed_v, device=self._device) + else: + computed_metrics[k] = computed_v + if not _distributed_available(): + return computed_metrics + gathered_data = [None for _ in range(dist.get_world_size(self._process_group))] + dist.all_gather_object(gathered_data, computed_metrics, group=self._process_group) + return_data = gathered_data[0] + for rank in range(1, len(gathered_data)): + for k, rank_v in gathered_data[rank].items(): + if isinstance(rank_v, Tensor): + rank_v = torch.flatten(rank_v) + return_data[k] = torch.cat((return_data[k], rank_v)) + return {k: torch.mean(v) for k, v in return_data.items() if len(v)} + + def to(self, device: Union[str, torch.device] = "cpu") -> None: + """Move all metrics to the given device + + Args: + device (Union[str, torch.device], optional): Device to move the metrics to. Defaults to "cpu". 
+ """ + for k, v in self._metrics.items(): + self._metrics[k] = v.to(device) + + def reset(self) -> None: + """Reset the internal state of the metrics""" + for v in self._metrics.values(): + v.reset() + + class MovingAverageMetric(Metric): """Metric for tracking moving average of a value. From 590b9f19677c9f19ff567f67b3721cfbcd0d1663 Mon Sep 17 00:00:00 2001 From: belerico Date: Sat, 9 Sep 2023 16:36:44 +0200 Subject: [PATCH 02/31] Added RankIndependentMetricAggregator and plot per env cum reward --- pyproject.toml | 2 +- sheeprl/algos/dreamer_v1/dreamer_v1.py | 4 +- sheeprl/algos/dreamer_v2/dreamer_v2.py | 4 +- sheeprl/algos/dreamer_v3/dreamer_v3.py | 4 +- sheeprl/algos/droq/droq.py | 4 +- sheeprl/algos/p2e_dv1/p2e_dv1.py | 4 +- sheeprl/algos/p2e_dv2/p2e_dv2.py | 4 +- sheeprl/algos/ppo/ppo.py | 32 +++-- sheeprl/algos/ppo/ppo_decoupled.py | 4 +- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 4 +- sheeprl/algos/sac/sac.py | 4 +- sheeprl/algos/sac/sac_decoupled.py | 4 +- sheeprl/algos/sac_ae/sac_ae.py | 4 +- sheeprl/utils/metric.py | 119 +++++++++---------- 14 files changed, 99 insertions(+), 98 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c93876c0..eaa47404 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "tensorboard>=2.10", "python-dotenv>=1.0.0", "lightning==2.0.*", - "lightning-utilities<0.9", + "lightning-utilities<=0.9", "hydra-core==1.3.0", "torchmetrics==1.1.*" ] diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 213dc463..69166eb9 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -547,14 +547,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index c3ad542d..ae2e7867 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -580,14 +580,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." 
) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index fb59ec85..4f19d628 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -530,14 +530,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 3c7b8101..c8ba9877 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -229,14 +229,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index 1e88121c..baf4bd8c 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -608,14 +608,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." 
) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index 79ecd46a..ea888da6 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -749,14 +749,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 001a0874..7d7797df 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -16,7 +16,7 @@ from tensordict.tensordict import TensorDictBase from torch import nn from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, RunningMean from sheeprl.algos.ppo.agent import PPOAgent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss @@ -25,7 +25,7 @@ from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.env import make_dict_env from sheeprl.utils.logger import create_tensorboard_logger -from sheeprl.utils.metric import MetricAggregator +from sheeprl.utils.metric import MovingAverageMetric, RankIndependentMetricAggregator, MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay @@ -201,6 +201,12 @@ def main(cfg: DictConfig): "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), } ) + rank_aggregator = RankIndependentMetricAggregator( + { + f"Rewards/env_{rank * cfg.num_envs + i}": MovingAverageMetric(window=5, sync_on_compute=False) + for i in range(cfg.num_envs) + } + ) # Local data rb = ReplayBuffer( @@ -224,14 +230,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." 
) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." @@ -266,7 +272,7 @@ def main(cfg: DictConfig): normalized_obs = { k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys } - actions, logprobs, _, value = agent(normalized_obs) + actions, logprobs, _, value = agent.module(normalized_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() else: @@ -314,15 +320,19 @@ def main(cfg: DictConfig): fabric.print( f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"]) + aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"]) + rank_aggregator.update( + f"Rewards/env_{fabric.global_rank * cfg.num_envs + i}", + torch.from_numpy(agent_final_info["episode"]["r"]).to(device), + ) # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): normalized_obs = { k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys } - next_values = agent.get_value(normalized_obs) + next_values = agent.module.get_value(normalized_obs) returns, advantages = gae( rb["rewards"], rb["values"], @@ -373,8 +383,12 @@ def main(cfg: DictConfig): last_log = policy_step metrics_dict = aggregator.compute() fabric.log_dict(metrics_dict, policy_step) - fabric.log("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time)), policy_step) aggregator.reset() + rank_metrics = rank_aggregator.compute() + for rank_m in rank_metrics: + fabric.log_dict(rank_m, policy_step) + rank_aggregator.reset() + fabric.log("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time)), policy_step) # Checkpoint model if ( diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index b30ff180..eb7823a8 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -149,14 +149,14 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." 
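For illustration, here is a minimal single-process sketch of how the `RankIndependentMetricAggregator` added by this patch can be used. The metric names and values are made up, and plain `MeanMetric`s stand in for the running/moving means that the PPO script wraps around per-environment rewards; in the real training loop the keys are built from `fabric.global_rank` and `cfg.num_envs`.

```python
from torchmetrics import MeanMetric

from sheeprl.utils.metric import RankIndependentMetricAggregator

# Per-rank metrics: one mean reward per local environment (rank 0 with 2 envs here)
rank, num_envs = 0, 2
rank_aggregator = RankIndependentMetricAggregator(
    {f"Rewards/env_{rank * num_envs + i}": MeanMetric(sync_on_compute=False) for i in range(num_envs)}
)

# During the rollout: update the metric of the environment that finished an episode
rank_aggregator.update("Rewards/env_0", 1.0)
rank_aggregator.update("Rewards/env_1", 3.0)

# `compute()` returns one dict per rank (a single-element list when torch.distributed
# is not initialized), so every rank's values can be logged separately
for per_rank_metrics in rank_aggregator.compute():
    print(per_rank_metrics)  # e.g. {'Rewards/env_0': 1.0, 'Rewards/env_1': 3.0}
rank_aggregator.reset()
```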
diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 21aaef61..07a925ca 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -208,14 +208,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 883fc97a..069b6541 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -188,14 +188,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index fec93791..34d34afb 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -132,14 +132,14 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." 
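As an aside, the "nearest greater multiple" mentioned in these warnings is simply a ceiling to the next multiple of `policy_steps_per_update`, since the policy step only advances in chunks of that size. A small illustrative computation follows; the numbers are made up, and the actual `policy_steps_per_update` depends on the run (number of processes, environments, and so on).

```python
# Policy steps advance by `policy_steps_per_update` per loop iteration, so a
# `log_every` (or `checkpoint.every`) that is not a multiple of it effectively
# triggers at the next greater multiple. Illustrative values only.
policy_steps_per_update = 12  # how many policy steps one loop iteration advances (made up)
log_every = 1000              # cfg.metric.log_every

effective = -(-log_every // policy_steps_per_update) * policy_steps_per_update  # ceil to multiple
print(effective)  # 1008 -> metrics are first logged at policy step 1008, not 1000
```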
diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 5fc275a6..11a3b0ab 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -360,14 +360,14 @@ def main(cfg: DictConfig): # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( - f"The log every parameter ({cfg.metric.log_every}) is not a multiple of the " + f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the metrics will be logged at the nearest greater multiple of the " "policy_steps_per_update value." ) if cfg.checkpoint.every % policy_steps_per_update != 0: warnings.warn( - f"The checkpoint every parameter ({cfg.checkpoint.every}) is not a multiple of the " + f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " f"policy_steps_per_update value ({policy_steps_per_update}), so " "the checkpoint will be saved at the nearest greater multiple of the " "policy_steps_per_update value." diff --git a/sheeprl/utils/metric.py b/sheeprl/utils/metric.py index 63795444..ec41168e 100644 --- a/sheeprl/utils/metric.py +++ b/sheeprl/utils/metric.py @@ -1,5 +1,5 @@ from collections import deque -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.distributed as dist @@ -7,7 +7,7 @@ from torch import Tensor from torch.distributed.distributed_c10d import ProcessGroup from torchmetrics import Metric -from torchmetrics.aggregation import CatMetric +from torchmetrics.wrappers.running import Running class MetricAggregatorException(Exception): @@ -79,7 +79,7 @@ def to(self, device: Union[str, torch.device] = "cpu") -> None: self.metrics[k] = v.to(device) @torch.no_grad() - def compute(self) -> Dict[str, torch.Tensor]: + def compute(self) -> Dict[str, List]: """Reduce the metrics to a single value Returns: Reduced metrics @@ -87,70 +87,60 @@ def compute(self) -> Dict[str, torch.Tensor]: reduced_metrics = {} if self.metrics: for k, v in self.metrics.items(): - reduced = v.compute() - if v._update_called: + reduced: Tensor = v.compute() + if v.update_called or isinstance(v, Running): reduced_metrics[k] = reduced.tolist() return reduced_metrics -class IndependentMeanMetric: +class RankIndependentMetricAggregator: def __init__( self, - names: Sequence[str], - device: Union[str, torch.device] = "cpu", + metrics: Union[MetricAggregator, Dict[str, Metric]], process_group: Optional[ProcessGroup] = None, ) -> None: - """Collection of N independent mean metrics, where `N` is given by the - length of the `names` parameter. This metric is useful when one wants to - maintain averages of some quantities, while still being able to broadcast them - to all the processes in a `torch.distributed` group. + """Rank-independent MetricAggregator. + This metric is useful when one wants to maintain per-rank-independent metrics of some quantities, + while still being able to broadcast them to all the processes in a `torch.distributed` group. Args: - names (Sequence[str]): the names of the metrics. - device (Union[str, torch.device], optional): the device where the metrics reside. - Defaults to "cpu". + metrics (Union[MetricAggregator, Dict[str, Metric]]): the metrics to be aggregated. + If a dictionary of metrics is passed, then the aggregator is constructed from it. process_group (Optional[ProcessGroup], optional): the distributed process group. 
Defaults to None. """ super().__init__() - if len(names) <= 0: - raise ValueError(f"`names` length must be greater than 0: got {len(names)}") - self._names = names - self._device = device - self._metrics: dict[str, Metric] = {} - for n in names: - m = CatMetric(sync_on_compute=False) - self._metrics[n] = m.to(self._device) + if isinstance(metrics, dict): + aggregator = MetricAggregator(metrics) + self._aggregator: MetricAggregator = aggregator + for m in aggregator.metrics.values(): + m.sync_on_compute = False self._process_group = process_group if process_group is not None else torch.distributed.group.WORLD - def update(self, value: float, name: str) -> None: - self._metrics[name].update(value) + def update(self, key: str, value: Union[float, Tensor]) -> None: + """Update the metric specified by `name` with `value` + + Args: + key (str): the name of the metric to be updated. + value (Union[float, Tensor]): value to update the metric with. + """ + self._aggregator.update(key, value) @torch.no_grad() - def compute(self) -> Dict[str, Tensor]: - """Compute the means, one for every metric. The metrics are first broadcasted + def compute(self) -> List[Dict[str, List]]: + """Compute the metric independently for every rank and broadcast the result to all + the processes in the process group. Returns: - Dict[str, Tensor]: _description_ + List[Dict[str, List]]: the computed metrics, broadcasted from and to every processes. + The list of the data returned is equal to the number of processes in the process group. """ - computed_metrics = {} - for k, v in self._metrics.items(): - computed_v = v.compute() - if not isinstance(computed_v, Tensor): - computed_metrics[k] = torch.tensor(computed_v, device=self._device) - else: - computed_metrics[k] = computed_v + computed_metrics = self._aggregator.compute() if not _distributed_available(): - return computed_metrics + return [computed_metrics] gathered_data = [None for _ in range(dist.get_world_size(self._process_group))] dist.all_gather_object(gathered_data, computed_metrics, group=self._process_group) - return_data = gathered_data[0] - for rank in range(1, len(gathered_data)): - for k, rank_v in gathered_data[rank].items(): - if isinstance(rank_v, Tensor): - rank_v = torch.flatten(rank_v) - return_data[k] = torch.cat((return_data[k], rank_v)) - return {k: torch.mean(v) for k, v in return_data.items() if len(v)} + return gathered_data def to(self, device: Union[str, torch.device] = "cpu") -> None: """Move all metrics to the given device @@ -158,13 +148,11 @@ def to(self, device: Union[str, torch.device] = "cpu") -> None: Args: device (Union[str, torch.device], optional): Device to move the metrics to. Defaults to "cpu". 
""" - for k, v in self._metrics.items(): - self._metrics[k] = v.to(device) + self._aggregator.to(device) def reset(self) -> None: """Reset the internal state of the metrics""" - for v in self._metrics.values(): - v.reset() + self._aggregator.reset() class MovingAverageMetric(Metric): @@ -172,44 +160,43 @@ class MovingAverageMetric(Metric): Args: name (str): Name of the metric - window_size (int): Window size for computing moving average + window (int): Window size for computing moving average device (str): Device to store the metric """ - def __init__(self, name: str, window_size: int = 100, device: str = "cpu") -> None: - super().__init__(sync_on_compute=False) - self.window_size = window_size - self._values = deque(maxlen=window_size) - self._sum = torch.tensor(0.0, device=self._device) + sum_value: Tensor + + def __init__(self, window: int = 100, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._window = window + self._values = deque(maxlen=window) + self.add_state("sum_value", torch.tensor(0.0, device=self.device, dtype=torch.float32, requires_grad=False)) def update(self, value: Union[torch.Tensor, float]) -> None: """Update the moving average with a new value. Args: - value (Union[torch.Tensor, float]): New value to update the moving average + value (Union[torch.Tensor, float]): New value to update the moving average. """ if isinstance(value, torch.Tensor): value = value.item() - if len(self._values) == self.window_size: - self._sum -= self._values.popleft() - self._sum += value + if len(self._values) == self._window: + self.sum_value -= self._values.popleft() + self.sum_value += value self._values.append(value) - def compute(self) -> Dict: + def compute(self) -> Tensor: """Computes the moving average. Returns: - Dict: Dictionary with the moving average + Tensor: the moving average """ if len(self._values) == 0: - return None - average = self._sum / len(self._values) - std = torch.std(torch.tensor(self._values, device=self._device)) - torch.max(torch.tensor(self._values, device=self._device)) - torch.min(torch.tensor(self._values, device=self._device)) - return average, std.item() + return torch.nan + average = self.sum_value / len(self._values) + return average def reset(self) -> None: """Resets the moving average.""" + super().reset() self._values.clear() - self._sum = torch.tensor(0.0, device=self._device) From 4c5fd6362ccf3ec2e6ec4eeed7db2bae7fa078fa Mon Sep 17 00:00:00 2001 From: belerico_t Date: Mon, 11 Sep 2023 12:28:03 +0200 Subject: [PATCH 03/31] Add RunningMean to PPO --- sheeprl/algos/ppo/ppo.py | 57 +++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 001a0874..f5a1dc01 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -16,7 +16,7 @@ from tensordict.tensordict import TensorDictBase from torch import nn from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, RunningMean from sheeprl.algos.ppo.agent import PPOAgent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss @@ -25,7 +25,7 @@ from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.env import make_dict_env from sheeprl.utils.logger import create_tensorboard_logger -from sheeprl.utils.metric import MetricAggregator +from sheeprl.utils.metric import MetricAggregator, RankIndependentMetricAggregator from sheeprl.utils.registry import 
register_algorithm from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay @@ -190,17 +190,21 @@ def main(cfg: DictConfig): optimizer = fabric.setup_optimizers(optimizer) # Create a metric aggregator to log the metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(device) + rank_metrics = {f"Rewards/rew_env_{rank * cfg.num_envs + i}": RunningMean(window=10) for i in range(cfg.num_envs)} + rank_metrics.update( + {f"Game/ep_len_env_{rank * cfg.num_envs + i}": RunningMean(window=10) for i in range(cfg.num_envs)} + ) + rank_aggregator = RankIndependentMetricAggregator(rank_metrics).to(device) # Local data rb = ReplayBuffer( @@ -266,7 +270,7 @@ def main(cfg: DictConfig): normalized_obs = { k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys } - actions, logprobs, _, value = agent(normalized_obs) + actions, logprobs, _, value = agent.module(normalized_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() else: @@ -309,20 +313,22 @@ def main(cfg: DictConfig): next_done = done if "final_info" in info: - for i, agent_final_info in enumerate(info["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(info["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + rank_aggregator.update(f"Rewards/rew_env_{rank * cfg.num_envs + i}", ep_rew) + rank_aggregator.update(f"Game/ep_len_env_{rank * cfg.num_envs + i}", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[0]}") # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): normalized_obs = { k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys } - next_values = agent.get_value(normalized_obs) + next_values = agent.module.get_value(normalized_obs) returns, advantages = gae( rb["rewards"], rb["values"], @@ -369,12 +375,15 @@ def main(cfg: DictConfig): ) # Log metrics - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or 
cfg.dry_run: last_log = policy_step metrics_dict = aggregator.compute() fabric.log_dict(metrics_dict, policy_step) - fabric.log("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time)), policy_step) aggregator.reset() + rank_metrics = rank_aggregator.compute() + for m in rank_metrics: + fabric.log_dict(m, policy_step) + fabric.log("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time)), policy_step) # Checkpoint model if ( From 0698a7a3c9d280e33a716d3fcb18f27083d95ba8 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Mon, 11 Sep 2023 12:28:29 +0200 Subject: [PATCH 04/31] Call redue every time in MetricAggregator --- sheeprl/utils/metric.py | 139 +++++++++++++--------------------------- 1 file changed, 45 insertions(+), 94 deletions(-) diff --git a/sheeprl/utils/metric.py b/sheeprl/utils/metric.py index 63795444..5eb17de8 100644 --- a/sheeprl/utils/metric.py +++ b/sheeprl/utils/metric.py @@ -1,5 +1,5 @@ -from collections import deque -from typing import Any, Dict, Optional, Sequence, Union +import warnings +from typing import Any, Dict, List, Optional, Union import torch import torch.distributed as dist @@ -7,7 +7,6 @@ from torch import Tensor from torch.distributed.distributed_c10d import ProcessGroup from torchmetrics import Metric -from torchmetrics.aggregation import CatMetric class MetricAggregatorException(Exception): @@ -69,7 +68,7 @@ def reset(self): for metric in self.metrics.values(): metric.reset() - def to(self, device: Union[str, torch.device] = "cpu") -> None: + def to(self, device: Union[str, torch.device] = "cpu") -> "MetricAggregator": """Move all metrics to the given device Args: device (Union[str, torch.device], optional): Device to move the metrics to. Defaults to "cpu". @@ -77,6 +76,7 @@ def to(self, device: Union[str, torch.device] = "cpu") -> None: if self.metrics: for k, v in self.metrics.items(): self.metrics[k] = v.to(device) + return self @torch.no_grad() def compute(self) -> Dict[str, torch.Tensor]: @@ -88,16 +88,30 @@ def compute(self) -> Dict[str, torch.Tensor]: if self.metrics: for k, v in self.metrics.items(): reduced = v.compute() - if v._update_called: - reduced_metrics[k] = reduced.tolist() + is_tensor = torch.is_tensor(reduced) + if is_tensor and reduced.numel() == 1: + reduced_metrics[k] = reduced.item() + else: + if not is_tensor: + warnings.warn( + f"The reduced metric {k} is not a scalar tensor: type={type(reduced)}. " + "This may create problems during the logging phase.", + category=RuntimeWarning, + ) + else: + warnings.warn( + f"The reduced metric {k} is not a scalar: size={v.size()}. " + "This may create problems during the logging phase.", + category=RuntimeWarning, + ) + reduced_metrics[k] = reduced return reduced_metrics -class IndependentMeanMetric: +class RankIndependentMetricAggregator: def __init__( self, - names: Sequence[str], - device: Union[str, torch.device] = "cpu", + metrics: Union[Dict[str, Metric], MetricAggregator], process_group: Optional[ProcessGroup] = None, ) -> None: """Collection of N independent mean metrics, where `N` is given by the @@ -106,110 +120,47 @@ def __init__( to all the processes in a `torch.distributed` group. Args: - names (Sequence[str]): the names of the metrics. - device (Union[str, torch.device], optional): the device where the metrics reside. - Defaults to "cpu". + metrics (Sequence[str]): the metrics. process_group (Optional[ProcessGroup], optional): the distributed process group. Defaults to None. 
""" super().__init__() - if len(names) <= 0: - raise ValueError(f"`names` length must be greater than 0: got {len(names)}") - self._names = names - self._device = device - self._metrics: dict[str, Metric] = {} - for n in names: - m = CatMetric(sync_on_compute=False) - self._metrics[n] = m.to(self._device) + self._aggregator = metrics + if isinstance(metrics, dict): + self._aggregator = MetricAggregator(metrics) + for m in self._aggregator.metrics.values(): + m._to_sync = False + m.sync_on_compute = False self._process_group = process_group if process_group is not None else torch.distributed.group.WORLD + self._distributed_available = _distributed_available() + self._world_size = dist.get_world_size(self._process_group) if self._distributed_available else 1 - def update(self, value: float, name: str) -> None: - self._metrics[name].update(value) + def update(self, name: str, value: Union[float, Tensor]) -> None: + self._aggregator.update(name, value) @torch.no_grad() - def compute(self) -> Dict[str, Tensor]: + def compute(self) -> List[Dict[str, Tensor]]: """Compute the means, one for every metric. The metrics are first broadcasted Returns: Dict[str, Tensor]: _description_ """ - computed_metrics = {} - for k, v in self._metrics.items(): - computed_v = v.compute() - if not isinstance(computed_v, Tensor): - computed_metrics[k] = torch.tensor(computed_v, device=self._device) - else: - computed_metrics[k] = computed_v - if not _distributed_available(): - return computed_metrics - gathered_data = [None for _ in range(dist.get_world_size(self._process_group))] + computed_metrics = self._aggregator.compute() + if not self._distributed_available: + return [computed_metrics] + gathered_data = [None for _ in range(self._world_size)] dist.all_gather_object(gathered_data, computed_metrics, group=self._process_group) - return_data = gathered_data[0] - for rank in range(1, len(gathered_data)): - for k, rank_v in gathered_data[rank].items(): - if isinstance(rank_v, Tensor): - rank_v = torch.flatten(rank_v) - return_data[k] = torch.cat((return_data[k], rank_v)) - return {k: torch.mean(v) for k, v in return_data.items() if len(v)} - - def to(self, device: Union[str, torch.device] = "cpu") -> None: + return gathered_data + + def to(self, device: Union[str, torch.device] = "cpu") -> "RankIndependentMetricAggregator": """Move all metrics to the given device Args: device (Union[str, torch.device], optional): Device to move the metrics to. Defaults to "cpu". """ - for k, v in self._metrics.items(): - self._metrics[k] = v.to(device) + self._aggregator.to(device) + return self def reset(self) -> None: """Reset the internal state of the metrics""" - for v in self._metrics.values(): - v.reset() - - -class MovingAverageMetric(Metric): - """Metric for tracking moving average of a value. - - Args: - name (str): Name of the metric - window_size (int): Window size for computing moving average - device (str): Device to store the metric - """ - - def __init__(self, name: str, window_size: int = 100, device: str = "cpu") -> None: - super().__init__(sync_on_compute=False) - self.window_size = window_size - self._values = deque(maxlen=window_size) - self._sum = torch.tensor(0.0, device=self._device) - - def update(self, value: Union[torch.Tensor, float]) -> None: - """Update the moving average with a new value. 
- - Args: - value (Union[torch.Tensor, float]): New value to update the moving average - """ - if isinstance(value, torch.Tensor): - value = value.item() - if len(self._values) == self.window_size: - self._sum -= self._values.popleft() - self._sum += value - self._values.append(value) - - def compute(self) -> Dict: - """Computes the moving average. - - Returns: - Dict: Dictionary with the moving average - """ - if len(self._values) == 0: - return None - average = self._sum / len(self._values) - std = torch.std(torch.tensor(self._values, device=self._device)) - torch.max(torch.tensor(self._values, device=self._device)) - torch.min(torch.tensor(self._values, device=self._device)) - return average, std.item() - - def reset(self) -> None: - """Resets the moving average.""" - self._values.clear() - self._sum = torch.tensor(0.0, device=self._device) + self._aggregator.reset() From 26d5b4a3aaa20e3ebdefb87cea3a677a3bfd73fd Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Mon, 11 Sep 2023 15:00:13 +0200 Subject: [PATCH 05/31] docs: added howto/work_with_steps.md --- README.md | 9 +++++- howto/configs.md | 4 +-- howto/work_with_steps.md | 30 +++++++++++++++++++ sheeprl/algos/dreamer_v1/dreamer_v1.py | 4 +-- sheeprl/algos/dreamer_v2/README.md | 4 +-- sheeprl/algos/dreamer_v2/dreamer_v2.py | 16 ++++++---- sheeprl/algos/dreamer_v3/dreamer_v3.py | 14 +++++---- sheeprl/algos/dreamer_v3/utils.py | 2 +- sheeprl/algos/droq/droq.py | 4 ++- sheeprl/algos/p2e_dv1/p2e_dv1.py | 4 +-- sheeprl/algos/p2e_dv2/p2e_dv2.py | 14 +++++---- sheeprl/algos/sac/sac.py | 3 +- sheeprl/algos/sac/sac_decoupled.py | 4 +-- sheeprl/algos/sac_ae/sac_ae.py | 3 +- sheeprl/configs/algo/dreamer_v1.yaml | 2 +- sheeprl/configs/algo/dreamer_v2.yaml | 4 +-- sheeprl/configs/algo/dreamer_v3.yaml | 4 +-- sheeprl/configs/algo/droq.yaml | 2 +- sheeprl/configs/algo/sac.yaml | 2 +- sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml | 4 +-- tests/test_algos/test_algos.py | 22 +++++++------- tests/test_algos/test_cli.py | 2 +- 22 files changed, 104 insertions(+), 53 deletions(-) create mode 100644 howto/work_with_steps.md diff --git a/README.md b/README.md index b1e44d83..79d37639 100644 --- a/README.md +++ b/README.md @@ -168,7 +168,14 @@ That's all it takes to train an agent with SheepRL! 🎉 > **Note** > -> You can find more information about the observation space by checking [the related howto section](./howto/select_observations.md). +> Before you start using the SheepRL framework, it is **highly recommended** that you read the following instructional documents: +> +> 1. How to [run experiments](./howto/run_experiments.md) +> 2. How to [modify the default configs](./howto/configs.md) +> 3. How to [work with steps](./howto/work_with_steps.md) +> 4. How to [select observations](./howto/select_observations.md) +> +> Moreover, there are other useful documents in the [`howto` folder](./howto/), which containes some guidance on how to properly use the framework. 
### :chart_with_upwards_trend: Check your results diff --git a/howto/configs.md b/howto/configs.md index cc3c473e..ed0617f4 100644 --- a/howto/configs.md +++ b/howto/configs.md @@ -138,8 +138,8 @@ horizon: 15 # Training recipe learning_starts: 65536 -pretrain_steps: 1 -gradient_steps: 1 +per_rank_pretrain_steps: 1 +per_rank_gradient_steps: 1 train_every: 16 # Model related parameters diff --git a/howto/work_with_steps.md b/howto/work_with_steps.md new file mode 100644 index 00000000..f1821e0c --- /dev/null +++ b/howto/work_with_steps.md @@ -0,0 +1,30 @@ +# Work with steps +In this document we discuss the hyper-parameters which refer to the concept of step. +There are various ways to interpret it, so it is necessary to clearly specify which interpretation we give to that concept. + +## Policy steps +We start from the concept of *policy steps*: a policy step is the selection of an action that is then used to perform an environment step. In other words, a *policy step* is performed when the actor takes an observation as input and chooses an action, and the agent then performs an environment step. + +> **Note** +> +> The environment step is the step performed by the environment: the environment takes an action as input and computes the next observation and the next reward. + +Now that we have introduced the concept of policy step, it is necessary to clarify some aspects: +1. When there are multiple parallel environments, the policy steps are proportional to the number of parallel environments. E.g., if there are $m$ environments, then the agent has to choose $m$ actions and each environment performs an environment step: this means that **m policy steps** are performed. +2. When there are multiple parallel processes, the policy steps are proportional to the number of parallel processes. E.g., let us assume that there are $n$ processes, each one containing a single environment: the $n$ actors select an action and perform a step in the environment, so also in this case **n policy steps** are performed. + +To generalize, in a case with $n$ processes, each one with $m$ environments, the policy steps increase by $n \cdot m$ at each iteration. + +The hyper-parameters which refer to the *policy steps* are: +* `total_steps`: the total number of policy steps to perform in an experiment. +* `exploration_steps`: the number of policy steps in which the agent explores the environment in the P2E algorithms. +* `max_episode_steps`: the maximum number of policy steps an episode can last ($\text{max\_steps}$). This means that if you set an action repeat greater than one ($\text{act\_repeat} > 1$), then an episode can last up to a number of environment steps equal to: $\text{env\_steps} = \text{max\_steps} \cdot \text{act\_repeat}$. +* `learning_starts`: how many policy steps the agent has to perform before starting the training. +* `train_every`: how many policy steps the agent has to perform between one training and the next. + +## Gradient steps +A *gradient step* consists of an update of the parameters of the agent, i.e., a call of the *train* function. The gradient steps are proportional to the number of parallel processes: if there are $n$ parallel processes, each call of the *train* method increases the gradient steps by $n$. + +The hyper-parameters which refer to the *gradient steps* are: +* `algo.per_rank_gradient_steps`: the number of gradient steps per rank to perform in a single iteration. +* `algo.per_rank_pretrain_steps`: the number of gradient steps per rank to perform in the first iteration.
\ No newline at end of file diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 327ae68d..d28f5ad6 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -535,7 +535,7 @@ def main(cfg: DictConfig): learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -670,7 +670,7 @@ def main(cfg: DictConfig): local_data = rb.sample( cfg.per_rank_batch_size, sequence_length=cfg.per_rank_sequence_length, - n_samples=cfg.algo.gradient_steps, + n_samples=cfg.algo.per_rank_gradient_steps, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) for i in distributed_sampler: diff --git a/sheeprl/algos/dreamer_v2/README.md b/sheeprl/algos/dreamer_v2/README.md index 12e7648e..6e25d3cc 100644 --- a/sheeprl/algos/dreamer_v2/README.md +++ b/sheeprl/algos/dreamer_v2/README.md @@ -120,7 +120,7 @@ env.action_repeat=4 \ clip_rewards=True \ total_steps=200000000 \ algo.learning_starts=200000 \ -algo.pretrain_steps=1 \ +algo.per_rank_pretrain_steps=1 \ algo.train_every=4 \ algo.gamma=0.995 \ algo.world_model.kl_regularizer=0.1 \ @@ -162,7 +162,7 @@ env.action_repeat=2 \ clip_rewards=False \ total_steps=5000000 \ algo.learning_starts=1000 \ -algo.pretrain_steps=100 \ +algo.per_rank_pretrain_steps=100 \ algo.train_every=5 \ algo.gamma=0.99 \ algo.world_model.kl_regularizer=1.0 \ diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index e0cd0c16..6fcafb2b 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -568,7 +568,7 @@ def main(cfg: DictConfig): learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -616,7 +616,7 @@ def main(cfg: DictConfig): env_ep.append(step_data[i : i + 1][None, ...]) player.init_states() - gradient_steps = 0 + per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): # Sample an action given the observation received by the environment if ( @@ -733,17 +733,21 @@ def main(cfg: DictConfig): local_data = rb.sample( cfg.per_rank_batch_size, sequence_length=cfg.per_rank_sequence_length, - n_samples=cfg.algo.pretrain_steps if update == learning_starts else cfg.algo.gradient_steps, + n_samples=cfg.algo.per_rank_pretrain_steps + if update == learning_starts + else cfg.algo.per_rank_gradient_steps, ).to(device) else: local_data = rb.sample( cfg.per_rank_batch_size, - n_samples=cfg.algo.pretrain_steps if update == learning_starts else cfg.algo.gradient_steps, + n_samples=cfg.algo.per_rank_pretrain_steps + if update == learning_starts + else cfg.algo.per_rank_gradient_steps, 
prioritize_ends=cfg.buffer.prioritize_ends, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) for i in distributed_sampler: - if gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): tcp.data.copy_(cp.data) train( @@ -760,7 +764,7 @@ def main(cfg: DictConfig): cfg, actions_dim, ) - gradient_steps += 1 + per_rank_gradient_steps += 1 updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index c143b414..af3769e4 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -518,7 +518,7 @@ def main(cfg: DictConfig): learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -558,7 +558,7 @@ def main(cfg: DictConfig): step_data["is_first"] = torch.ones_like(step_data["dones"]).float() player.init_states() - gradient_steps = 0 + per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): # Sample an action given the observation received by the environment if ( @@ -675,12 +675,14 @@ def main(cfg: DictConfig): local_data = rb.sample( cfg.per_rank_batch_size, sequence_length=cfg.per_rank_sequence_length, - n_samples=cfg.algo.pretrain_steps if update == learning_starts else cfg.algo.gradient_steps, + n_samples=cfg.algo.per_rank_pretrain_steps + if update == learning_starts + else cfg.algo.per_rank_gradient_steps, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) for i in distributed_sampler: - if gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - tau = 1 if gradient_steps == 0 else cfg.algo.critic.tau + if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) train( @@ -699,7 +701,7 @@ def main(cfg: DictConfig): actions_dim, moments, ) - gradient_steps += 1 + per_rank_gradient_steps += 1 updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 diff --git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py index 98789a06..a0d1ac76 100644 --- a/sheeprl/algos/dreamer_v3/utils.py +++ b/sheeprl/algos/dreamer_v3/utils.py @@ -102,7 +102,7 @@ def test( real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) # Single environment step - next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape)) + next_obs, reward, done, truncated, _ = env.step(int(real_actions.reshape(env.action_space.shape))) for k in next_obs.keys(): next_obs[k] = torch.from_numpy(next_obs[k]).view(1, 
*next_obs[k].shape).float() done = done or truncated or cfg.dry_run diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index a8e92de7..2eb91434 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -41,7 +41,9 @@ def train( ): # Sample a minibatch in a distributed way: Line 5 - Algorithm 2 # We sample one time to reduce the communications between processes - sample = rb.sample(cfg.algo.gradient_steps * cfg.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs) + sample = rb.sample( + cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs + ) gathered_data = fabric.all_gather(sample.to_dict()) gathered_data = make_tensordict(gathered_data).view(-1) if fabric.world_size > 1: diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index a803bdcd..e5a4c262 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -596,7 +596,7 @@ def main(cfg: DictConfig): exploration_updates = min(num_updates, exploration_updates) if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -742,7 +742,7 @@ def main(cfg: DictConfig): local_data = rb.sample( cfg.per_rank_batch_size, sequence_length=cfg.per_rank_sequence_length, - n_samples=cfg.algo.gradient_steps, + n_samples=cfg.algo.per_rank_gradient_steps, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) for i in distributed_sampler: diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index 68622663..61fec5a6 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -733,7 +733,7 @@ def main(cfg: DictConfig): learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -785,7 +785,7 @@ def main(cfg: DictConfig): env_ep.append(step_data[i : i + 1][None, ...]) player.init_states() - gradient_steps = 0 + per_rank_gradient_steps = 0 is_exploring = True for update in range(start_step, num_updates + 1): if update == exploration_updates: @@ -906,17 +906,21 @@ def main(cfg: DictConfig): local_data = rb.sample( cfg.per_rank_batch_size, sequence_length=cfg.per_rank_sequence_length, - n_samples=cfg.algo.pretrain_steps if update == learning_starts else cfg.algo.gradient_steps, + n_samples=cfg.algo.per_rank_pretrain_steps + if update == learning_starts + else cfg.algo.per_rank_gradient_steps, ).to(device) else: local_data = rb.sample( cfg.per_rank_batch_size, - n_samples=cfg.algo.pretrain_steps if update == learning_starts else cfg.algo.gradient_steps, + n_samples=cfg.algo.per_rank_pretrain_steps + if update == learning_starts + else cfg.algo.per_rank_gradient_steps, prioritize_ends=cfg.buffer.prioritize_ends, ).to(device) 
distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) for i in distributed_sampler: - if gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): tcp.data.copy_(cp.data) for cp, tcp in zip(critic_exploration.module.parameters(), target_critic_exploration.parameters()): diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 4ccb62e0..04bc171b 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -256,7 +256,8 @@ def main(cfg: DictConfig): for _ in range(training_steps): # We sample one time to reduce the communications between processes sample = rb.sample( - cfg.algo.gradient_steps * cfg.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs + cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, + sample_next_obs=cfg.buffer.sample_next_obs, ) # [G*B, 1] gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] gathered_data = make_tensordict(gathered_data).view(-1) # [G*B*World] diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 05c51129..b0016b77 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -199,9 +199,9 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co if update >= learning_starts: training_steps = learning_starts if update == learning_starts else 1 chunks = rb.sample( - training_steps * cfg.algo.gradient_steps * cfg.per_rank_batch_size * (fabric.world_size - 1), + training_steps * cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size * (fabric.world_size - 1), sample_next_obs=cfg.buffer.sample_next_obs, - ).split(training_steps * cfg.algo.gradient_steps * cfg.per_rank_batch_size) + ).split(training_steps * cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size) world_collective.scatter_object_list([None], [None] + chunks, src=0) # Gather metrics from the trainers to be plotted diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index fae6deaf..df66d59d 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -450,7 +450,8 @@ def main(cfg: DictConfig): for _ in range(training_steps): # We sample one time to reduce the communications between processes sample = rb.sample( - cfg.algo.gradient_steps * cfg.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs + cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, + sample_next_obs=cfg.buffer.sample_next_obs, ) # [G*B, 1] gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] gathered_data = make_tensordict(gathered_data).view(-1) # [G*B*World] diff --git a/sheeprl/configs/algo/dreamer_v1.yaml b/sheeprl/configs/algo/dreamer_v1.yaml index 6d7ef1e8..3633fd2e 100644 --- a/sheeprl/configs/algo/dreamer_v1.yaml +++ b/sheeprl/configs/algo/dreamer_v1.yaml @@ -12,7 +12,7 @@ name: dreamer_v1 # Training recipe learning_starts: 5000 -gradient_steps: 100 +per_rank_gradient_steps: 100 train_every: 1000 # Model related parameters diff --git a/sheeprl/configs/algo/dreamer_v2.yaml b/sheeprl/configs/algo/dreamer_v2.yaml index b3035c54..80bbd82a 100644 --- a/sheeprl/configs/algo/dreamer_v2.yaml +++ b/sheeprl/configs/algo/dreamer_v2.yaml @@ -12,8 +12,8 @@ horizon: 15 # Training recipe learning_starts: 1000 -pretrain_steps: 100 -gradient_steps: 1 +per_rank_pretrain_steps: 100 
+per_rank_gradient_steps: 1 train_every: 5 # Model related parameters diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml index 18815b55..b2648932 100644 --- a/sheeprl/configs/algo/dreamer_v3.yaml +++ b/sheeprl/configs/algo/dreamer_v3.yaml @@ -12,8 +12,8 @@ horizon: 15 # Training recipe learning_starts: 65536 -pretrain_steps: 1 -gradient_steps: 1 +per_rank_pretrain_steps: 1 +per_rank_gradient_steps: 1 train_every: 16 # Model related parameters diff --git a/sheeprl/configs/algo/droq.yaml b/sheeprl/configs/algo/droq.yaml index 4ba945c2..82a88b2a 100644 --- a/sheeprl/configs/algo/droq.yaml +++ b/sheeprl/configs/algo/droq.yaml @@ -5,7 +5,7 @@ defaults: name: droq # Training recipe -gradient_steps: 20 +per_rank_gradient_steps: 20 # Override from `sac` config critic: diff --git a/sheeprl/configs/algo/sac.yaml b/sheeprl/configs/algo/sac.yaml index 739b04af..452f447e 100644 --- a/sheeprl/configs/algo/sac.yaml +++ b/sheeprl/configs/algo/sac.yaml @@ -12,7 +12,7 @@ hidden_size: 256 # Training recipe learning_starts: 100 -gradient_steps: 1 +per_rank_gradient_steps: 1 # Model related parameters # Actor diff --git a/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml b/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml index 1c99a32b..e784b2e3 100644 --- a/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml +++ b/sheeprl/configs/exp/dreamer_v2_ms_pacman.yaml @@ -31,8 +31,8 @@ buffer: algo: gamma: 0.995 train_every: 16 - pretrain_steps: 1 - gradient_steps: 1 + per_rank_pretrain_steps: 1 + per_rank_gradient_steps: 1 learning_starts: 200000 world_model: use_continues: True diff --git a/tests/test_algos/test_algos.py b/tests/test_algos/test_algos.py index d81ae5be..b5cb20bd 100644 --- a/tests/test_algos/test_algos.py +++ b/tests/test_algos/test_algos.py @@ -86,7 +86,7 @@ def test_droq(standard_args, checkpoint_buffer, start_time): "per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.gradient_steps=1", + "algo.per_rank_gradient_steps=1", f"root_dir={root_dir}", f"run_name={run_name}", f"buffer.checkpoint={checkpoint_buffer}", @@ -119,7 +119,7 @@ def test_sac(standard_args, checkpoint_buffer, start_time): "per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.gradient_steps=1", + "algo.per_rank_gradient_steps=1", f"root_dir={root_dir}", f"run_name={run_name}", f"buffer.checkpoint={checkpoint_buffer}", @@ -152,7 +152,7 @@ def test_sac_ae(standard_args, checkpoint_buffer, start_time): "per_rank_batch_size=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.gradient_steps=1", + "algo.per_rank_gradient_steps=1", f"root_dir={root_dir}", f"run_name={run_name}", "mlp_keys.encoder=[state]", @@ -203,7 +203,7 @@ def test_sac_decoupled(standard_args, checkpoint_buffer, start_time): "exp=sac", "per_rank_batch_size=1", "algo.learning_starts=0", - "algo.gradient_steps=1", + "algo.per_rank_gradient_steps=1", f"root_dir={root_dir}", f"run_name={run_name}", f"buffer.checkpoint={checkpoint_buffer}", @@ -359,7 +359,7 @@ def test_dreamer_v1(standard_args, env_id, checkpoint_buffer, start_time): "per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.gradient_steps=1", + "algo.per_rank_gradient_steps=1", "algo.horizon=2", f"env.id={env_id}", f"root_dir={root_dir}", @@ -413,7 +413,7 @@ def test_p2e_dv1(standard_args, env_id, checkpoint_buffer, start_time): "per_rank_sequence_length=2", 
f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.gradient_steps=1", + "algo.per_rank_gradient_steps=1", "algo.horizon=8", "env.id=" + env_id, f"root_dir={root_dir}", @@ -472,7 +472,7 @@ def test_p2e_dv2(standard_args, env_id, checkpoint_buffer, start_time): "per_rank_sequence_length=2", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.gradient_steps=1", + "algo.per_rank_gradient_steps=1", "algo.horizon=2", "env.id=" + env_id, f"root_dir={root_dir}", @@ -483,7 +483,7 @@ def test_p2e_dv2(standard_args, env_id, checkpoint_buffer, start_time): "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", "cnn_keys.encoder=[rgb]", - "algo.pretrain_steps=1", + "algo.per_rank_pretrain_steps=1", f"buffer.checkpoint={checkpoint_buffer}", "env.capture_video=False", ] @@ -537,7 +537,7 @@ def test_dreamer_v2(standard_args, env_id, checkpoint_buffer, start_time): "per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.gradient_steps=1", + "algo.per_rank_gradient_steps=1", "algo.horizon=8", "env.id=" + env_id, f"root_dir={root_dir}", @@ -548,7 +548,7 @@ def test_dreamer_v2(standard_args, env_id, checkpoint_buffer, start_time): "algo.world_model.representation_model.hidden_size=8", "algo.world_model.transition_model.hidden_size=8", "cnn_keys.encoder=[rgb]", - "algo.pretrain_steps=1", + "algo.per_rank_pretrain_steps=1", "algo.layer_norm=True", f"buffer.checkpoint={checkpoint_buffer}", "env.capture_video=False", @@ -596,7 +596,7 @@ def test_dreamer_v3(standard_args, env_id, checkpoint_buffer, start_time): "per_rank_sequence_length=1", f"buffer.size={int(os.environ['LT_DEVICES'])}", "algo.learning_starts=0", - "algo.gradient_steps=1", + "algo.per_rank_gradient_steps=1", "algo.horizon=8", "env.id=" + env_id, f"root_dir={root_dir}", diff --git a/tests/test_algos/test_cli.py b/tests/test_algos/test_cli.py index 68916b7c..1931ef4c 100644 --- a/tests/test_algos/test_cli.py +++ b/tests/test_algos/test_cli.py @@ -43,7 +43,7 @@ def test_resume_from_checkpoint(): sys.executable + " sheeprl.py dreamer_v3 exp=dreamer_v3 env=dummy dry_run=True " + "env.capture_video=False algo.dense_units=8 algo.horizon=8 " - + "algo.world_model.encoder.cnn_channels_multiplier=2 algo.gradient_steps=1 " + + "algo.world_model.encoder.cnn_channels_multiplier=2 algo.per_rank_gradient_steps=1 " + "algo.world_model.recurrent_model.recurrent_state_size=8 " + "algo.world_model.representation_model.hidden_size=8 algo.learning_starts=0 " + "algo.world_model.transition_model.hidden_size=8 buffer.size=10 " From fb4372adfc30e595b07cdbb2dc6d6674e4778352 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Mon, 11 Sep 2023 15:31:01 +0200 Subject: [PATCH 06/31] fix: env default configs --- sheeprl/configs/env/default.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sheeprl/configs/env/default.yaml b/sheeprl/configs/env/default.yaml index a63e9d36..6635a685 100644 --- a/sheeprl/configs/env/default.yaml +++ b/sheeprl/configs/env/default.yaml @@ -4,7 +4,7 @@ frame_stack: 1 screen_size: 64 action_repeat: 1 grayscale: False -clip_rewards: True +clip_rewards: False capture_video: True frame_stack_dilation: 1 max_episode_steps: null From 4cf672c697775a902284473ca10cd7fdca35e992 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Mon, 11 Sep 2023 15:41:50 +0200 Subject: [PATCH 07/31] fix: dreamer_v3 test --- sheeprl/algos/dreamer_v3/utils.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py index a0d1ac76..98789a06 100644 --- a/sheeprl/algos/dreamer_v3/utils.py +++ b/sheeprl/algos/dreamer_v3/utils.py @@ -102,7 +102,7 @@ def test( real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) # Single environment step - next_obs, reward, done, truncated, _ = env.step(int(real_actions.reshape(env.action_space.shape))) + next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape)) for k in next_obs.keys(): next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float() done = done or truncated or cfg.dry_run From b6f75ca059e61936a8865229351e9445d816fd34 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Mon, 11 Sep 2023 17:56:55 +0200 Subject: [PATCH 08/31] For PPO, add timer + change timer logging + add print config --- examples/architecture_template.py | 2 +- pyproject.toml | 3 +- sheeprl/algos/ppo/ppo.py | 79 ++++++--- sheeprl/algos/ppo/ppo_decoupled.py | 252 +++++++++++++++++------------ sheeprl/configs/hydra/default.yaml | 2 + sheeprl/utils/timer.py | 82 ++++++++++ sheeprl/utils/utils.py | 37 ++++- 7 files changed, 323 insertions(+), 134 deletions(-) create mode 100644 sheeprl/utils/timer.py diff --git a/examples/architecture_template.py b/examples/architecture_template.py index e4179dd1..bf3fdf37 100644 --- a/examples/architecture_template.py +++ b/examples/architecture_template.py @@ -139,7 +139,7 @@ def main(): if devices is None or devices in ("1", "2"): raise RuntimeError( "Please run the script with the number of devices greater than 2: " - "`lightning run model --devices=3 sheeprl.py ...`" + "`lightning run model --devices=3 examples/architecture_template.py ...`" ) world_collective = TorchCollective() diff --git a/pyproject.toml b/pyproject.toml index eaa47404..af2248b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,8 @@ dependencies = [ "lightning==2.0.*", "lightning-utilities<=0.9", "hydra-core==1.3.0", - "torchmetrics==1.1.*" + "torchmetrics==1.1.*", + "rich==13.5.*" ] dynamic = ["version"] diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index ba59d0b6..f6c257e7 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -1,6 +1,5 @@ import copy import os -import time import warnings from typing import Union @@ -27,7 +26,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator, RankIndependentMetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, print_config def train( @@ -108,6 +108,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + if "minedojo" in cfg.env.env._target_.lower(): raise ValueError( "MineDojo is not currently supported by PPO agent, since it does not take " @@ -194,15 +196,16 @@ def main(cfg: DictConfig): { "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/policy_loss": 
MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), } ).to(device) - rank_metrics = {f"Rewards/rew_env_{rank * cfg.num_envs + i}": RunningMean(window=10) for i in range(cfg.num_envs)} + rank_metrics = { + f"Rewards/rew_env_{rank * cfg.env.num_envs + i}": RunningMean(window=10) for i in range(cfg.env.num_envs) + } rank_metrics.update( - {f"Game/ep_len_env_{rank * cfg.num_envs + i}": RunningMean(window=10) for i in range(cfg.num_envs)} + {f"Game/ep_len_env_{rank * cfg.env.num_envs + i}": RunningMean(window=10) for i in range(cfg.env.num_envs)} ) rank_aggregator = RankIndependentMetricAggregator(rank_metrics).to(device) @@ -224,9 +227,9 @@ def main(cfg: DictConfig): # Global variables last_log = 0 + last_train = 0 policy_step = 0 last_checkpoint = 0 - start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 @@ -270,21 +273,24 @@ def main(cfg: DictConfig): for _ in range(0, cfg.algo.rollout_steps): policy_step += cfg.env.num_envs * world_size - with torch.no_grad(): - # Sample an action given the observation received by the environment - normalized_obs = { - k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys - } - actions, logprobs, _, value = agent.module(normalized_obs) - if is_continuous: - real_actions = torch.cat(actions, -1).cpu().numpy() - else: - real_actions = np.concatenate([act.argmax(dim=-1).cpu().numpy() for act in actions], axis=-1) - actions = torch.cat(actions, -1) - - # Single environment step - o, reward, done, truncated, info = envs.step(real_actions) - done = np.logical_or(done, truncated) + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/sps_env_interaction"): + with torch.no_grad(): + # Sample an action given the observation received by the environment + normalized_obs = { + k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys + } + actions, logprobs, _, value = agent.module(normalized_obs) + if is_continuous: + real_actions = torch.cat(actions, -1).cpu().numpy() + else: + real_actions = np.concatenate([act.argmax(dim=-1).cpu().numpy() for act in actions], axis=-1) + actions = torch.cat(actions, -1) + + # Single environment step + o, reward, done, truncated, info = envs.step(real_actions) + done = np.logical_or(done, truncated) with device: rewards = torch.tensor(reward, dtype=torch.float32).view(cfg.env.num_envs, -1) # [N_envs, 1] @@ -324,8 +330,8 @@ def main(cfg: DictConfig): ep_len = agent_ep_info["episode"]["l"] aggregator.update("Rewards/rew_avg", ep_rew) aggregator.update("Game/ep_len_avg", ep_len) - rank_aggregator.update(f"Rewards/rew_env_{rank * cfg.num_envs + i}", ep_rew) - rank_aggregator.update(f"Game/ep_len_env_{rank * cfg.num_envs + i}", ep_len) + rank_aggregator.update(f"Rewards/rew_env_{rank * cfg.env.num_envs + i}", ep_rew) + rank_aggregator.update(f"Game/ep_len_env_{rank * cfg.env.num_envs + i}", ep_len) fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[0]}") # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) @@ -359,7 +365,8 @@ def main(cfg: DictConfig): else: gathered_data = local_data - train(fabric, agent, optimizer, gathered_data, aggregator, cfg) + with timer("Time/sps_train"): + 
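# The `timer` context manager added by this patch accumulates the wall-clock
# duration of the wrapped block in a torchmetrics MeanMetric keyed by name.
# Rough usage sketch (illustrative, mirroring the API defined in
# sheeprl/utils/timer.py later in this patch):
#
#     with timer("Time/sps_train"):
#         ...                                          # timed work
#     elapsed = timer.compute()["Time/sps_train"]      # reduced elapsed seconds
#     timer.reset()
#
# The steps-per-second values logged further below divide step counts by these
# reduced timings.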
train(fabric, agent, optimizer, gathered_data, aggregator, cfg) if cfg.algo.anneal_lr: fabric.log("Info/learning_rate", scheduler.get_last_lr()[0], policy_step) @@ -381,14 +388,34 @@ def main(cfg: DictConfig): # Log metrics if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: - last_log = policy_step + # Sync distributed metrics metrics_dict = aggregator.compute() fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + + # Sync per-rank metrics rank_metrics = rank_aggregator.compute() for m in rank_metrics: fabric.log_dict(m, policy_step) - fabric.log("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time)), policy_step) + + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (update - last_train) / timer_metrics["Time/sps_train"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/sps_env_interaction"], + policy_step, + ) + timer.reset() + + # Reset counters + last_train = update + last_log = policy_step # Checkpoint model if ( diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index da5ffb2b..109b2ed2 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -19,7 +19,7 @@ from tensordict.tensordict import TensorDictBase, make_tensordict from torch.distributed.algorithms.join import Join from torch.utils.data import BatchSampler, RandomSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, RunningMean from sheeprl.algos.ppo.agent import PPOAgent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss @@ -29,11 +29,15 @@ from sheeprl.utils.env import make_dict_env from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, print_config @torch.no_grad() def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_collective: TorchCollective): + print_config(cfg) + + # Initialize logger root_dir = ( os.path.join("logs", "runs", cfg.root_dir) if cfg.root_dir is not None @@ -120,14 +124,17 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, list(agent.parameters())) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=False), - "Game/ep_len_avg": MeanMetric(sync_on_compute=False), - "Time/step_per_second": MeanMetric(sync_on_compute=False), - } - ) + metrics = { + "Rewards/rew_avg": MeanMetric(sync_on_compute=False), + "Game/ep_len_avg": MeanMetric(sync_on_compute=False), + } + metrics.update( + {f"Rewards/rew_env_{i}": RunningMean(window=10, sync_on_compute=False) for i in range(cfg.env.num_envs)} + ) + metrics.update( + {f"Game/ep_len_env_{i}": RunningMean(window=10, sync_on_compute=False) for i in range(cfg.env.num_envs)} + ) + aggregator = MetricAggregator(metrics).to(device) # Local data rb = ReplayBuffer( @@ -140,10 +147,10 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - policy_step = 0 last_log = 0 + last_train = 0 + policy_step = 0 
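# Worked example for the bookkeeping set up just below (hypothetical values, not
# from the patch): with env.num_envs=4 and algo.rollout_steps=128,
# policy_steps_per_update = 4 * 128 = 512, so total_steps=1_000_000 gives
# num_updates = 1_000_000 // 512 = 1953, while `policy_step` itself advances by
# 4 at every collected environment step.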
last_checkpoint = 0 - start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps) num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 # Warning for log and checkpoint every @@ -195,21 +202,24 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co for _ in range(0, cfg.algo.rollout_steps): policy_step += cfg.env.num_envs - with torch.no_grad(): - # Sample an action given the observation received by the environment - normalized_obs = { - k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys - } - actions, logprobs, _, value = agent(normalized_obs) - if is_continuous: - real_actions = torch.cat(actions, -1).cpu().numpy() - else: - real_actions = np.concatenate([act.argmax(dim=-1).cpu().numpy() for act in actions], axis=-1) - actions = torch.cat(actions, -1) + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/sps_env_interaction", sync_on_compute=False): + with torch.no_grad(): + # Sample an action given the observation received by the environment + normalized_obs = { + k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys + } + actions, logprobs, _, value = agent(normalized_obs) + if is_continuous: + real_actions = torch.cat(actions, -1).cpu().numpy() + else: + real_actions = np.concatenate([act.argmax(dim=-1).cpu().numpy() for act in actions], axis=-1) + actions = torch.cat(actions, -1) - # Single environment step - o, reward, done, truncated, info = envs.step(real_actions) - done = np.logical_or(done, truncated) + # Single environment step + o, reward, done, truncated, info = envs.step(real_actions) + done = np.logical_or(done, truncated) with device: rewards = torch.tensor(reward, dtype=torch.float32).view(cfg.env.num_envs, -1) # [N_envs, 1] @@ -242,13 +252,15 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co next_done = done if "final_info" in info: - for i, agent_final_info in enumerate(info["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(info["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + aggregator.update(f"Rewards/rew_env_{i}", ep_rew) + aggregator.update(f"Game/ep_len_env_{i}", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[0]}") # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) normalized_obs = {k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys} @@ -277,11 +289,6 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co chunks = local_data[perm].split(chunks_sizes) world_collective.scatter_object_list([None], [None] + chunks, src=0) - # Gather metrics from the trainers to be plotted - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - metrics = [None] - 
player_trainer_collective.broadcast_object_list(metrics, src=1) - # Wait the trainers to finish player_trainer_collective.broadcast(flattened_parameters, src=1) @@ -289,13 +296,36 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, list(agent.parameters())) # Log metrics - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(metrics[0], policy_step) + # Gather metrics from the trainers + metrics = [None] + player_trainer_collective.broadcast_object_list(metrics, src=1) + metrics = metrics[0] + sps_train = metrics.pop("Time/sps_train") + + # Log metrics + fabric.log_dict(metrics, policy_step) fabric.log_dict(aggregator.compute(), policy_step) aggregator.reset() + # Sync timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (update - last_train) / sps_train, + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) * cfg.env.action_repeat) / timer_metrics["Time/sps_env_interaction"], + policy_step, + ) + timer.reset() + + # Reset counters + last_train = update + last_log = policy_step + # Checkpoint model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or cfg.dry_run: last_checkpoint = policy_step @@ -399,11 +429,11 @@ def trainer( # Start training update = 0 - initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef) - initial_clip_coef = copy.deepcopy(cfg.algo.clip_coef) - policy_step = 0 last_log = 0 + policy_step = 0 last_checkpoint = 0 + initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef) + initial_clip_coef = copy.deepcopy(cfg.algo.clip_coef) while True: # Wait for data data = [None] @@ -424,68 +454,82 @@ def trainer( update += 1 policy_step += cfg.env.num_envs * cfg.algo.rollout_steps - # Prepare sampler - indexes = list(range(data.shape[0])) - sampler = BatchSampler(RandomSampler(indexes), batch_size=cfg.per_rank_batch_size, drop_last=False) + with timer("Time/sps_train", process_group=optimization_pg): + # Prepare sampler + indexes = list(range(data.shape[0])) + sampler = BatchSampler(RandomSampler(indexes), batch_size=cfg.per_rank_batch_size, drop_last=False) + + # The Join context is needed because there can be the possibility + # that some ranks receive less data + with Join([agent._forward_module]): + for _ in range(cfg.algo.update_epochs): + for batch_idxes in sampler: + batch = data[batch_idxes] + normalized_obs = { + k: batch[k] / 255 - 0.5 if k in agent.feature_extractor.cnn_keys else batch[k] + for k in cfg.cnn_keys.encoder + cfg.mlp_keys.encoder + } + _, logprobs, entropy, new_values = agent( + normalized_obs, torch.split(batch["actions"], agent.actions_dim, dim=-1) + ) + + if cfg.algo.normalize_advantages: + batch["advantages"] = normalize_tensor(batch["advantages"]) - # The Join context is needed because there can be the possibility - # that some ranks receive less data - with Join([agent._forward_module]): - for _ in range(cfg.algo.update_epochs): - for batch_idxes in sampler: - batch = data[batch_idxes] - normalized_obs = { - k: batch[k] / 255 - 0.5 if k in agent.feature_extractor.cnn_keys else batch[k] - for k in cfg.cnn_keys.encoder + cfg.mlp_keys.encoder - } - _, logprobs, entropy, new_values = agent( - normalized_obs, torch.split(batch["actions"], agent.actions_dim, dim=-1) - ) - - if cfg.algo.normalize_advantages: - 
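# `normalize_tensor` is assumed to standardize the minibatch advantages,
# roughly adv = (adv - adv.mean()) / (adv.std() + eps), keeping the scale of the
# policy-gradient term comparable across minibatches.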
batch["advantages"] = normalize_tensor(batch["advantages"]) - - # Policy loss - pg_loss = policy_loss( - logprobs, - batch["logprobs"], - batch["advantages"], - cfg.algo.clip_coef, - cfg.algo.loss_reduction, - ) - - # Value loss - v_loss = value_loss( - new_values, - batch["values"], - batch["returns"], - cfg.algo.clip_coef, - cfg.algo.clip_vloss, - cfg.algo.loss_reduction, - ) - - # Entropy loss - ent_loss = entropy_loss(entropy, cfg.algo.loss_reduction) - - # Equation (9) in the paper - loss = pg_loss + cfg.algo.vf_coef * v_loss + cfg.algo.ent_coef * ent_loss - - optimizer.zero_grad(set_to_none=True) - fabric.backward(loss) - if cfg.algo.max_grad_norm > 0.0: - fabric.clip_gradients(agent, optimizer, max_norm=cfg.algo.max_grad_norm) - optimizer.step() - - # Update metrics - aggregator.update("Loss/policy_loss", pg_loss.detach()) - aggregator.update("Loss/value_loss", v_loss.detach()) - aggregator.update("Loss/entropy_loss", ent_loss.detach()) + # Policy loss + pg_loss = policy_loss( + logprobs, + batch["logprobs"], + batch["advantages"], + cfg.algo.clip_coef, + cfg.algo.loss_reduction, + ) + + # Value loss + v_loss = value_loss( + new_values, + batch["values"], + batch["returns"], + cfg.algo.clip_coef, + cfg.algo.clip_vloss, + cfg.algo.loss_reduction, + ) + + # Entropy loss + ent_loss = entropy_loss(entropy, cfg.algo.loss_reduction) + + # Equation (9) in the paper + loss = pg_loss + cfg.algo.vf_coef * v_loss + cfg.algo.ent_coef * ent_loss + + optimizer.zero_grad(set_to_none=True) + fabric.backward(loss) + if cfg.algo.max_grad_norm > 0.0: + fabric.clip_gradients(agent, optimizer, max_norm=cfg.algo.max_grad_norm) + optimizer.step() + + # Update metrics + aggregator.update("Loss/policy_loss", pg_loss.detach()) + aggregator.update("Loss/value_loss", v_loss.detach()) + aggregator.update("Loss/entropy_loss", ent_loss.detach()) + + if global_rank == 1: + player_trainer_collective.broadcast( + torch.nn.utils.convert_parameters.parameters_to_vector(agent.parameters()), + src=1, + ) # Send updated weights to the player if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step + # Sync distributed metrics metrics = aggregator.compute() aggregator.reset() + + # Sync distributed timers + timers = timer.compute() + metrics.update(timers) + timer.reset() + + # Send metrics to the player if global_rank == 1: if cfg.algo.anneal_lr: metrics["Info/learning_rate"] = scheduler.get_last_lr()[0] @@ -496,11 +540,9 @@ def trainer( player_trainer_collective.broadcast_object_list( [metrics], src=1 ) # Broadcast metrics: fake send with object list between rank-0 and rank-1 - if global_rank == 1: - player_trainer_collective.broadcast( - torch.nn.utils.convert_parameters.parameters_to_vector(agent.parameters()), - src=1, - ) + + # Reset counters + last_log = policy_step if cfg.algo.anneal_lr: scheduler.step() diff --git a/sheeprl/configs/hydra/default.yaml b/sheeprl/configs/hydra/default.yaml index b0627fbe..afc6e6b4 100644 --- a/sheeprl/configs/hydra/default.yaml +++ b/sheeprl/configs/hydra/default.yaml @@ -1,2 +1,4 @@ run: dir: logs/runs/${root_dir}/${run_name} +job: + chdir: True diff --git a/sheeprl/utils/timer.py b/sheeprl/utils/timer.py new file mode 100644 index 00000000..41043178 --- /dev/null +++ b/sheeprl/utils/timer.py @@ -0,0 +1,82 @@ +# timer.py + +import time +from contextlib import ContextDecorator +from typing import Dict, Optional, Union + +import torch +from torchmetrics import MeanMetric + + +class TimerError(Exception): + """A custom exception used to report errors 
in use of timer class""" + + +class timer(ContextDecorator): + """A timer class to measure the time of a code block.""" + + disabled: bool = False + timers: Dict[str, MeanMetric] = {} + _start_time: Optional[float] = None + + def __init__(self, name: str, **metric_kwargs) -> None: + """Add timer to dict of timers after initialization""" + self.name = name + if not timer.disabled and self.name is not None and self.name not in self.timers.keys(): + self.timers.setdefault(self.name, MeanMetric(**metric_kwargs)) + + def start(self) -> None: + """Start a new timer""" + if self._start_time is not None: + raise TimerError("timer is running. Use .stop() to stop it") + + self._start_time = time.perf_counter() + + def stop(self) -> float: + """Stop the timer, and report the elapsed time""" + if self._start_time is None: + raise TimerError("timer is not running. Use .start() to start it") + + # Calculate elapsed time + elapsed_time = time.perf_counter() - self._start_time + self._start_time = None + + # Report elapsed time + if self.name: + self.timers[self.name].update(elapsed_time) + + return elapsed_time + + @classmethod + def to(cls, device: Union[str, torch.device] = "cpu") -> None: + """Create a new timer on a different device""" + if cls.timers: + for k, v in cls.timers.items(): + cls.timers[k] = v.to(device) + + @classmethod + def reset(cls) -> None: + """Reset all timers""" + for timer in cls.timers.values(): + timer.reset() + cls._start_time = None + + @classmethod + def compute(cls) -> Dict[str, torch.Tensor]: + """Reduce the timers to a single value""" + reduced_timers = {} + if cls.timers: + for k, v in cls.timers.items(): + reduced_timers[k] = v.compute().item() + return reduced_timers + + def __enter__(self): + """Start a new timer as a context manager""" + if not timer.disabled: + self.start() + return self + + def __exit__(self, *exc_info): + """Stop the context manager timer""" + if not timer.disabled: + self.stop() diff --git a/sheeprl/utils/utils.py b/sheeprl/utils/utils.py index 24bb1033..d808ce69 100644 --- a/sheeprl/utils/utils.py +++ b/sheeprl/utils/utils.py @@ -1,7 +1,12 @@ -from typing import Optional, Tuple +import os +from typing import Optional, Sequence, Tuple +import rich.syntax +import rich.tree import torch import torch.nn as nn +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning.utilities import rank_zero_only from torch import Tensor @@ -131,3 +136,33 @@ def symlog(x: Tensor) -> Tensor: def symexp(x: Tensor) -> Tensor: return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) + + +@rank_zero_only +def print_config( + config: DictConfig, + fields: Sequence[str] = ("algo", "buffer", "checkpoint", "env", "exp", "hydra", "metric", "optim"), + resolve: bool = True, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + + Args: + config: Configuration composed by Hydra. + fields: Determines which main fields from config will + be printed and in what order. + resolve: Whether to resolve reference fields of DictConfig. 
+ """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + for field in fields: + branch = tree.add(field, style=style, guide_style=style) + config_section = config.get(field) + branch_content = str(config_section) + if isinstance(config_section, DictConfig): + branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + rich.print(tree) + with open(os.path.join(os.getcwd(), "config_tree.txt"), "w") as fp: + rich.print(tree, file=fp) From 881c8777e483245a12e46047dc343ec232d36f4b Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 12 Sep 2023 11:05:48 +0200 Subject: [PATCH 09/31] Fix logging PPO recurrent --- sheeprl/algos/ppo/ppo.py | 20 +--- sheeprl/algos/ppo/ppo_decoupled.py | 12 +-- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 101 +++++++++++-------- sheeprl/configs/hydra/default.yaml | 2 +- sheeprl/utils/utils.py | 8 +- 5 files changed, 71 insertions(+), 72 deletions(-) diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index f6c257e7..6b60c58a 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -15,7 +15,7 @@ from tensordict.tensordict import TensorDictBase from torch import nn from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler -from torchmetrics import MeanMetric, RunningMean +from torchmetrics import MeanMetric from sheeprl.algos.ppo.agent import PPOAgent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss @@ -24,7 +24,7 @@ from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.env import make_dict_env from sheeprl.utils.logger import create_tensorboard_logger -from sheeprl.utils.metric import MetricAggregator, RankIndependentMetricAggregator +from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, print_config @@ -201,13 +201,6 @@ def main(cfg: DictConfig): "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), } ).to(device) - rank_metrics = { - f"Rewards/rew_env_{rank * cfg.env.num_envs + i}": RunningMean(window=10) for i in range(cfg.env.num_envs) - } - rank_metrics.update( - {f"Game/ep_len_env_{rank * cfg.env.num_envs + i}": RunningMean(window=10) for i in range(cfg.env.num_envs)} - ) - rank_aggregator = RankIndependentMetricAggregator(rank_metrics).to(device) # Local data if cfg.buffer.size < cfg.algo.rollout_steps: @@ -330,9 +323,7 @@ def main(cfg: DictConfig): ep_len = agent_ep_info["episode"]["l"] aggregator.update("Rewards/rew_avg", ep_rew) aggregator.update("Game/ep_len_avg", ep_len) - rank_aggregator.update(f"Rewards/rew_env_{rank * cfg.env.num_envs + i}", ep_rew) - rank_aggregator.update(f"Game/ep_len_env_{rank * cfg.env.num_envs + i}", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[0]}") + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): @@ -393,11 +384,6 @@ def main(cfg: DictConfig): fabric.log_dict(metrics_dict, policy_step) aggregator.reset() - # Sync per-rank metrics - rank_metrics = rank_aggregator.compute() - for m in rank_metrics: - fabric.log_dict(m, policy_step) - # Sync distributed timers timer_metrics = timer.compute() fabric.log( diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 
109b2ed2..4ac56c27 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -19,7 +19,7 @@ from tensordict.tensordict import TensorDictBase, make_tensordict from torch.distributed.algorithms.join import Join from torch.utils.data import BatchSampler, RandomSampler -from torchmetrics import MeanMetric, RunningMean +from torchmetrics import MeanMetric from sheeprl.algos.ppo.agent import PPOAgent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss @@ -128,12 +128,6 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co "Rewards/rew_avg": MeanMetric(sync_on_compute=False), "Game/ep_len_avg": MeanMetric(sync_on_compute=False), } - metrics.update( - {f"Rewards/rew_env_{i}": RunningMean(window=10, sync_on_compute=False) for i in range(cfg.env.num_envs)} - ) - metrics.update( - {f"Game/ep_len_env_{i}": RunningMean(window=10, sync_on_compute=False) for i in range(cfg.env.num_envs)} - ) aggregator = MetricAggregator(metrics).to(device) # Local data @@ -258,9 +252,7 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co ep_len = agent_ep_info["episode"]["l"] aggregator.update("Rewards/rew_avg", ep_rew) aggregator.update("Game/ep_len_avg", ep_len) - aggregator.update(f"Rewards/rew_env_{i}", ep_rew) - aggregator.update(f"Game/ep_len_env_{i}", ep_len) - fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[0]}") + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) normalized_obs = {k: next_obs[k] / 255 - 0.5 if k in cfg.cnn_keys.encoder else next_obs[k] for k in obs_keys} diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 8b0b528a..b5716578 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -1,7 +1,6 @@ import copy import itertools import os -import time import warnings from contextlib import nullcontext from math import prod @@ -30,7 +29,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import gae, normalize_tensor, polynomial_decay, print_config def train( @@ -109,6 +109,7 @@ def train( @register_algorithm(decoupled=True) @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef) initial_clip_coef = copy.deepcopy(cfg.algo.clip_coef) @@ -173,18 +174,16 @@ def main(cfg: DictConfig): ) optimizer = fabric.setup_optimizers(hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters())) - # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + # Create a metric aggregator to log the metrics + aggregator = 
MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(device) # Local data rb = ReplayBuffer( @@ -197,13 +196,12 @@ def main(cfg: DictConfig): step_data = TensorDict({}, batch_size=[1, cfg.env.num_envs], device=device) # Global variables - policy_step = 0 last_log = 0 + last_train = 0 + policy_step = 0 last_checkpoint = 0 - start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 - last_log = 0 # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: @@ -229,24 +227,27 @@ def main(cfg: DictConfig): with device: # Get the first environment observation and start the optimization - next_obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32).unsqueeze(0) # [1, N_envs, N_obs] - next_done = torch.zeros(1, cfg.env.num_envs, 1, dtype=torch.float32) # [1, N_envs, 1] next_state = agent.initial_states + next_done = torch.zeros(1, cfg.env.num_envs, 1, dtype=torch.float32) # [1, N_envs, 1] + next_obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32).unsqueeze(0) # [1, N_envs, N_obs] for update in range(1, num_updates + 1): for _ in range(0, cfg.algo.rollout_steps): policy_step += cfg.env.num_envs * world_size - with torch.no_grad(): - # Sample an action given the observation received by the environment - action_logits, values, state = agent.module(next_obs, state=next_state) - dist = Categorical(logits=action_logits.unsqueeze(-2)) - action = dist.sample() - logprob = dist.log_prob(action) + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/sps_env_interaction"): + with torch.no_grad(): + # Sample an action given the observation received by the environment + action_logits, values, state = agent.module(next_obs, state=next_state) + dist = Categorical(logits=action_logits.unsqueeze(-2)) + action = dist.sample() + logprob = dist.log_prob(action) - # Single environment step - obs, reward, done, truncated, info = envs.step(action.cpu().numpy().reshape(envs.action_space.shape)) - done = np.logical_or(done, truncated) + # Single environment step + obs, reward, done, truncated, info = envs.step(action.cpu().numpy().reshape(envs.action_space.shape)) + done = np.logical_or(done, truncated) with device: obs = torch.tensor(obs, dtype=torch.float32).unsqueeze(0) # [1, N_envs, N_obs] @@ -279,13 +280,13 @@ def main(cfg: DictConfig): next_state = state if "final_info" in info: - for i, agent_final_info in enumerate(info["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(info["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = 
agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.no_grad(): @@ -309,8 +310,6 @@ def main(cfg: DictConfig): local_data = rb.buffer # Train the agent - - # Prepare data # 1. Split data into episodes (for every environment) episodes: List[TensorDictBase] = [] for env_id in range(cfg.env.num_envs): @@ -333,7 +332,9 @@ def main(cfg: DictConfig): else: sequences = episodes padded_sequences = pad_sequence(sequences, batch_first=False, return_mask=True) # [Seq_len, Num_seq, *] - train(fabric, agent, optimizer, padded_sequences, aggregator, cfg) + + with timer("Time/sps_train"): + train(fabric, agent, optimizer, padded_sequences, aggregator, cfg) if cfg.algo.anneal_lr: fabric.log("Info/learning_rate", scheduler.get_last_lr()[0], policy_step) @@ -354,13 +355,31 @@ def main(cfg: DictConfig): ) # Log metrics - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics metrics_dict = aggregator.compute() - fabric.log("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time)), policy_step) fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (update - last_train) / timer_metrics["Time/sps_train"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/sps_env_interaction"], + policy_step, + ) + timer.reset() + + # Reset counters + last_train = update + last_log = policy_step + # Checkpoint model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) diff --git a/sheeprl/configs/hydra/default.yaml b/sheeprl/configs/hydra/default.yaml index afc6e6b4..a7eed38a 100644 --- a/sheeprl/configs/hydra/default.yaml +++ b/sheeprl/configs/hydra/default.yaml @@ -1,4 +1,4 @@ run: dir: logs/runs/${root_dir}/${run_name} job: - chdir: True + chdir: False diff --git a/sheeprl/utils/utils.py b/sheeprl/utils/utils.py index d808ce69..aeb72d04 100644 --- a/sheeprl/utils/utils.py +++ b/sheeprl/utils/utils.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple, Union import rich.syntax import rich.tree @@ -143,6 +143,7 @@ def print_config( config: DictConfig, fields: Sequence[str] = ("algo", "buffer", "checkpoint", "env", "exp", "hydra", "metric", "optim"), resolve: bool = True, + cfg_save_path: Optional[Union[str, os.PathLike]] = None, ) -> None: """Prints content of DictConfig using Rich library and its tree structure. 
@@ -164,5 +165,6 @@ def print_config( branch.add(rich.syntax.Syntax(branch_content, "yaml")) rich.print(tree) - with open(os.path.join(os.getcwd(), "config_tree.txt"), "w") as fp: - rich.print(tree, file=fp) + if cfg_save_path is not None: + with open(os.path.join(os.getcwd(), "config_tree.txt"), "w") as fp: + rich.print(tree, file=fp) From 211aef4faaf1a3dd10d843a7128fd837999c943d Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 12 Sep 2023 11:29:24 +0200 Subject: [PATCH 10/31] feat: added possibility of having a screen size that is a multiple of 2 --- sheeprl/algos/dreamer_v3/agent.py | 35 +++++--- sheeprl/algos/dreamer_v3/dreamer_v3.py | 3 +- sheeprl/configs/env/default.yaml | 2 +- .../configs/exp/dreamer_v3_100k_boxing.yaml | 45 ++++++++++ sheeprl/configs/exp/dreamer_v3_L_doapp.yaml | 1 + ..._v3_L_doapp_128px_gray_combo_discrete.yaml | 83 +++++++++++++++++++ 6 files changed, 154 insertions(+), 15 deletions(-) create mode 100644 sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml create mode 100644 sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index e95f300a..edb45b5a 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -46,6 +46,7 @@ class CNNEncoder(nn.Module): Defaults to True. activation (ModuleType, optional): the activation function. Defaults to nn.SiLU. + stages (int, optional): how many stages for the CNN. """ def __init__( @@ -56,6 +57,7 @@ def __init__( channels_multiplier: int, layer_norm: bool = True, activation: ModuleType = nn.SiLU, + stages: int = 4, ) -> None: super().__init__() self.keys = keys @@ -63,12 +65,12 @@ def __init__( self.model = nn.Sequential( CNN( input_channels=self.input_dim[0], - hidden_channels=(torch.tensor([1, 2, 4, 8]) * channels_multiplier).tolist(), + hidden_channels=(torch.tensor([2**i for i in range(stages)]) * channels_multiplier).tolist(), cnn_layer=nn.Conv2d, layer_args={"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm}, activation=activation, - norm_layer=[LayerNormChannelLast for _ in range(4)] if layer_norm else None, - norm_args=[{"normalized_shape": (2**i) * channels_multiplier, "eps": 1e-3} for i in range(4)] + norm_layer=[LayerNormChannelLast for _ in range(stages)] if layer_norm else None, + norm_args=[{"normalized_shape": (2**i) * channels_multiplier, "eps": 1e-3} for i in range(stages)] if layer_norm else None, ), @@ -156,6 +158,7 @@ class CNNDecoder(nn.Module): Defaults to nn.SiLU. layer_norm (bool, optional): whether to apply the layer normalization. Defaults to True. + stages (int): how many stages in the CNN decoder. 
""" def __init__( @@ -168,6 +171,7 @@ def __init__( image_size: Tuple[int, int], activation: nn.Module = nn.SiLU, layer_norm: bool = True, + stages: int = 4, ) -> None: super().__init__() self.keys = keys @@ -179,19 +183,21 @@ def __init__( nn.Linear(latent_state_size, cnn_encoder_output_dim), nn.Unflatten(1, (-1, 4, 4)), DeCNN( - input_channels=8 * channels_multiplier, - hidden_channels=(torch.tensor([4, 2, 1]) * channels_multiplier).tolist() + [self.output_dim[0]], + input_channels=(2 ** (stages - 1)) * channels_multiplier, + hidden_channels=( + torch.tensor([2**i for i in reversed(range(stages - 1))]) * channels_multiplier + ).tolist() + + [self.output_dim[0]], cnn_layer=nn.ConvTranspose2d, layer_args=[ - {"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm}, - {"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm}, - {"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm}, - {"kernel_size": 4, "stride": 2, "padding": 1}, - ], - activation=[activation, activation, activation, None], - norm_layer=[LayerNormChannelLast for _ in range(3)] + [None] if layer_norm else None, + {"kernel_size": 4, "stride": 2, "padding": 1, "bias": not layer_norm} for _ in range(stages - 1) + ] + + [{"kernel_size": 4, "stride": 2, "padding": 1}], + activation=[activation for _ in range(stages - 1)] + [None], + norm_layer=[LayerNormChannelLast for _ in range(stages - 1)] + [None] if layer_norm else None, norm_args=[ - {"normalized_shape": (2 ** (4 - i - 2)) * channels_multiplier, "eps": 1e-3} for i in range(3) + {"normalized_shape": (2 ** (stages - i - 2)) * channels_multiplier, "eps": 1e-3} + for i in range(stages - 1) ] + [None] if layer_norm @@ -850,6 +856,7 @@ def build_models( latent_state_size = stochastic_size + recurrent_state_size # Define models + cnn_stages = int(np.log2(cfg.env.screen_size) - np.log2(4)) cnn_encoder = ( CNNEncoder( keys=cfg.cnn_keys.encoder, @@ -858,6 +865,7 @@ def build_models( channels_multiplier=world_model_cfg.encoder.cnn_channels_multiplier, layer_norm=world_model_cfg.encoder.layer_norm, activation=eval(world_model_cfg.encoder.cnn_act), + stages=cnn_stages, ) if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0 else None @@ -918,6 +926,7 @@ def build_models( image_size=obs_space[cfg.cnn_keys.decoder[0]].shape[-2:], activation=eval(world_model_cfg.observation_model.cnn_act), layer_norm=world_model_cfg.observation_model.layer_norm, + stages=cnn_stages, ) if cfg.cnn_keys.decoder is not None and len(cfg.cnn_keys.decoder) > 0 else None diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index c143b414..4f601c5c 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -333,8 +333,9 @@ def train( @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): # These arguments cannot be changed - cfg.env.screen_size = 64 cfg.env.frame_stack = -1 + if 2 ** int(np.log2(cfg.env.screen_size)) != cfg.env.screen_size: + raise ValueError(f"The screen size must be a power of 2, got: {cfg.env.screen_size}") # Initialize Fabric fabric = Fabric(callbacks=[CheckpointCallback()]) diff --git a/sheeprl/configs/env/default.yaml b/sheeprl/configs/env/default.yaml index a63e9d36..6635a685 100644 --- a/sheeprl/configs/env/default.yaml +++ b/sheeprl/configs/env/default.yaml @@ -4,7 +4,7 @@ frame_stack: 1 screen_size: 64 action_repeat: 1 grayscale: False -clip_rewards: True +clip_rewards: False capture_video: True 
frame_stack_dilation: 1 max_episode_steps: null diff --git a/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml b/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml new file mode 100644 index 00000000..b090ce42 --- /dev/null +++ b/sheeprl/configs/exp/dreamer_v3_100k_boxing.yaml @@ -0,0 +1,45 @@ +# @package _global_ + +defaults: + - dreamer_v3 + - override /env: atari + - _self_ + +# Experiment +seed: 5 +total_steps: 100000 + +# Environment +env: + num_envs: 1 + max_episode_steps: 27000 + id: BoxingNoFrameskip-v4 + +# Checkpoint +checkpoint: + every: 10000 + +# Metric +metric: + log_every: 5000 + +# Buffer +buffer: + size: 100000 + checkpoint: True + +# Algorithm +algo: + learning_starts: 1024 + train_every: 1 + dense_units: 512 + mlp_layers: 2 + world_model: + encoder: + cnn_channels_multiplier: 32 + recurrent_model: + recurrent_state_size: 512 + transition_model: + hidden_size: 512 + representation_model: + hidden_size: 512 diff --git a/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml b/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml index a50e8c54..f4e844db 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_doapp.yaml @@ -13,6 +13,7 @@ total_steps: 5000000 env: id: doapp num_envs: 8 + frame_stack: 1 env: diambra_settings: characters: Kasumi diff --git a/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml b/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml new file mode 100644 index 00000000..d95e3af1 --- /dev/null +++ b/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml @@ -0,0 +1,83 @@ +# @package _global_ + +defaults: + - dreamer_v3 + - override /env: diambra + - _self_ + +# Experiment +seed: 0 +total_steps: 10000000 + +# Environment +env: + id: doapp + num_envs: 8 + grayscale: True + frame_stack: 1 + screen_size: 128 + reward_as_observation: True + env: + attack_but_combination: True + diambra_settings: + characters: Kasumi + difficulty: 4 + +# Checkpoint +checkpoint: + every: 100000 + +# Buffer +buffer: + checkpoint: True + +# The CNN and MLP keys of the decoder are the same as those of the encoder by default +cnn_keys: + encoder: + - frame +mlp_keys: + encoder: + - reward + - P1_actions_attack + - P1_actions_move + - P1_oppChar + - P1_oppHealth + - P1_oppSide + - P1_oppWins + - P1_ownChar + - P1_ownHealth + - P1_ownSide + - P1_ownWins + - stage + decoder: + - P1_actions_attack + - P1_actions_move + - P1_oppChar + - P1_oppHealth + - P1_oppSide + - P1_oppWins + - P1_ownChar + - P1_ownHealth + - P1_ownSide + - P1_ownWins + - stage + +# Algorithm +algo: + learning_starts: 65536 + train_every: 8 + dense_units: 768 + mlp_layers: 4 + world_model: + encoder: + cnn_channels_multiplier: 64 + recurrent_model: + recurrent_state_size: 2048 + transition_model: + hidden_size: 768 + representation_model: + hidden_size: 768 + +# Metric +metric: + log_every: 10000 \ No newline at end of file From e003389fe483203ec83c84d1dde8b555cf65b691 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 12 Sep 2023 14:49:43 +0200 Subject: [PATCH 11/31] Chose which metric used to track time --- sheeprl/algos/ppo/ppo.py | 10 +- sheeprl/algos/ppo/ppo_decoupled.py | 25 ++- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 10 +- sheeprl/algos/sac/sac.py | 143 +++++++++------- sheeprl/algos/sac/sac_decoupled.py | 161 ++++++++++++------- sheeprl/utils/timer.py | 8 +- 6 files changed, 214 insertions(+), 143 deletions(-) diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 6b60c58a..4559479e 100644 --- 
a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -15,7 +15,7 @@ from tensordict.tensordict import TensorDictBase from torch import nn from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.ppo.agent import PPOAgent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss @@ -268,7 +268,7 @@ def main(cfg: DictConfig): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/sps_env_interaction"): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): with torch.no_grad(): # Sample an action given the observation received by the environment normalized_obs = { @@ -356,7 +356,7 @@ def main(cfg: DictConfig): else: gathered_data = local_data - with timer("Time/sps_train"): + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): train(fabric, agent, optimizer, gathered_data, aggregator, cfg) if cfg.algo.anneal_lr: @@ -388,13 +388,13 @@ def main(cfg: DictConfig): timer_metrics = timer.compute() fabric.log( "Time/sps_train", - (update - last_train) / timer_metrics["Time/sps_train"], + (update - last_train) / timer_metrics["Time/train_time"], policy_step, ) fabric.log( "Time/sps_env_interaction", ((policy_step - last_log) / world_size * cfg.env.action_repeat) - / timer_metrics["Time/sps_env_interaction"], + / timer_metrics["Time/env_interaction_time"], policy_step, ) timer.reset() diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 4ac56c27..12079ac6 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -19,7 +19,7 @@ from tensordict.tensordict import TensorDictBase, make_tensordict from torch.distributed.algorithms.join import Join from torch.utils.data import BatchSampler, RandomSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.ppo.agent import PPOAgent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss @@ -142,11 +142,11 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Global variables last_log = 0 - last_train = 0 policy_step = 0 last_checkpoint = 0 policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps) num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + # Warning for log and checkpoint every if cfg.metric.log_every % policy_steps_per_update != 0: warnings.warn( @@ -198,7 +198,7 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/sps_env_interaction", sync_on_compute=False): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): with torch.no_grad(): # Sample an action given the observation received by the environment normalized_obs = { @@ -287,13 +287,11 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Convert back the parameters torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, list(agent.parameters())) - # Log metrics if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: # Gather metrics from the 
trainers metrics = [None] player_trainer_collective.broadcast_object_list(metrics, src=1) metrics = metrics[0] - sps_train = metrics.pop("Time/sps_train") # Log metrics fabric.log_dict(metrics, policy_step) @@ -302,20 +300,14 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Sync timers timer_metrics = timer.compute() - fabric.log( - "Time/sps_train", - (update - last_train) / sps_train, - policy_step, - ) fabric.log( "Time/sps_env_interaction", - ((policy_step - last_log) * cfg.env.action_repeat) / timer_metrics["Time/sps_env_interaction"], + ((policy_step - last_log) * cfg.env.action_repeat) / timer_metrics["Time/env_interaction_time"], policy_step, ) timer.reset() # Reset counters - last_train = update last_log = policy_step # Checkpoint model @@ -422,6 +414,7 @@ def trainer( # Start training update = 0 last_log = 0 + last_train = 0 policy_step = 0 last_checkpoint = 0 initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef) @@ -446,7 +439,9 @@ def trainer( update += 1 policy_step += cfg.env.num_envs * cfg.algo.rollout_steps - with timer("Time/sps_train", process_group=optimization_pg): + with timer( + "Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg) + ): # Prepare sampler indexes = list(range(data.shape[0])) sampler = BatchSampler(RandomSampler(indexes), batch_size=cfg.per_rank_batch_size, drop_last=False) @@ -510,7 +505,6 @@ def trainer( src=1, ) - # Send updated weights to the player if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: # Sync distributed metrics metrics = aggregator.compute() @@ -518,7 +512,7 @@ def trainer( # Sync distributed timers timers = timer.compute() - metrics.update(timers) + metrics.update({"Time/sps_train": (update - last_train) / timers["Time/train_time"]}) timer.reset() # Send metrics to the player @@ -534,6 +528,7 @@ def trainer( ) # Broadcast metrics: fake send with object list between rank-0 and rank-1 # Reset counters + last_train = update last_log = policy_step if cfg.algo.anneal_lr: diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index b5716578..352e3fa4 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -18,7 +18,7 @@ from torch.distributed.algorithms.join import Join from torch.distributions import Categorical from torch.utils.data.sampler import BatchSampler, RandomSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent @@ -237,7 +237,7 @@ def main(cfg: DictConfig): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/sps_env_interaction"): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): with torch.no_grad(): # Sample an action given the observation received by the environment action_logits, values, state = agent.module(next_obs, state=next_state) @@ -333,7 +333,7 @@ def main(cfg: DictConfig): sequences = episodes padded_sequences = pad_sequence(sequences, batch_first=False, return_mask=True) # [Seq_len, Num_seq, *] - with timer("Time/sps_train"): + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): train(fabric, agent, optimizer, padded_sequences, 
aggregator, cfg) if cfg.algo.anneal_lr: @@ -365,13 +365,13 @@ def main(cfg: DictConfig): timer_metrics = timer.compute() fabric.log( "Time/sps_train", - (update - last_train) / timer_metrics["Time/sps_train"], + (update - last_train) / timer_metrics["Time/train_time"], policy_step, ) fabric.log( "Time/sps_env_interaction", ((policy_step - last_log) / world_size * cfg.env.action_repeat) - / timer_metrics["Time/sps_env_interaction"], + / timer_metrics["Time/env_interaction_time"], policy_step, ) timer.reset() diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 0b2c6d5e..6cdf1a6f 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -1,5 +1,4 @@ import os -import time import warnings from math import prod from typing import Optional @@ -17,7 +16,7 @@ from torch.optim import Optimizer from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss @@ -28,6 +27,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import print_config def train( @@ -83,12 +84,15 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # Initialize Fabric fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -152,21 +156,19 @@ def main(cfg: DictConfig): hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha]), ) - # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + # Create a metric aggregator to log the metrics + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(device) # Local data - buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 1 + buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 1 rb = ReplayBuffer( buffer_size, cfg.env.num_envs, @@ -177,11 +179,12 @@ def main(cfg: DictConfig): step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - policy_step = 0 last_log = 0 + last_train = 0 + 
train_step = 0 + policy_step = 0 last_checkpoint = 0 - start_time = time.perf_counter() - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) + policy_steps_per_update = int(cfg.env.num_envs * world_size) num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 @@ -205,26 +208,29 @@ def main(cfg: DictConfig): obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32, device=device) # [N_envs, N_obs] for update in range(1, num_updates + 1): - if update < learning_starts: - actions = envs.action_space.sample() - else: - # Sample an action given the observation received by the environment - with torch.no_grad(): - actions, _ = actor.module(obs) - actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions) - dones = np.logical_or(dones, truncated) - - policy_step += cfg.env.num_envs * fabric.world_size + policy_step += cfg.env.num_envs * world_size + + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + if update <= learning_starts: + actions = envs.action_space.sample() + else: + # Sample an action given the observation received by the environment + with torch.no_grad(): + actions, _ = actor.module(obs) + actions = actions.cpu().numpy() + next_obs, rewards, dones, truncated, infos = envs.step(actions) + dones = np.logical_or(dones, truncated) if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = next_obs.copy() @@ -251,8 +257,8 @@ def main(cfg: DictConfig): obs = torch.tensor(next_obs, device=device) # Train the agent - if update >= learning_starts - 1: - training_steps = learning_starts if update == learning_starts - 1 else 1 + if update >= learning_starts: + training_steps = learning_starts if update == learning_starts else 1 for _ in range(training_steps): # We sample one time to reduce the communications between processes sample = rb.sample( @@ -260,10 +266,10 @@ def main(cfg: DictConfig): ) # [G*B, 1] gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] gathered_data = make_tensordict(gathered_data).view(-1) # [G*B*World] - if fabric.world_size > 1: + if world_size > 1: dist_sampler: DistributedSampler = DistributedSampler( range(len(gathered_data)), - num_replicas=fabric.world_size, + num_replicas=world_size, rank=fabric.global_rank, shuffle=True, seed=cfg.seed, @@ -277,24 +283,47 @@ def main(cfg: DictConfig): sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False ) for batch_idxes in sampler: - 
train( - fabric, - agent, - actor_optimizer, - qf_optimizer, - alpha_optimizer, - gathered_data[batch_idxes], - aggregator, - update, - cfg, - policy_steps_per_update, - ) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + train( + fabric, + agent, + actor_optimizer, + qf_optimizer, + alpha_optimizer, + gathered_data[batch_idxes], + aggregator, + update, + cfg, + policy_steps_per_update, + ) + train_step += 1 + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index a0faa124..a9162899 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -18,7 +18,7 @@ from tensordict import TensorDict, make_tensordict from tensordict.tensordict import TensorDictBase from torch.utils.data.sampler import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic from sheeprl.algos.sac.sac import train @@ -28,10 +28,15 @@ from sheeprl.utils.env import make_env from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import print_config @torch.no_grad() def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_collective: TorchCollective): + print_config(cfg) + + # Initialize logger root_dir = ( os.path.join("logs", "runs", cfg.root_dir) if cfg.root_dir is not None @@ -100,14 +105,12 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, actor.parameters()) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=False), - "Game/ep_len_avg": MeanMetric(sync_on_compute=False), - "Time/step_per_second": MeanMetric(sync_on_compute=False), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=False), + "Game/ep_len_avg": MeanMetric(sync_on_compute=False), + } + ).to(device) # Local data buffer_size = cfg.buffer.size // cfg.env.num_envs if not cfg.dry_run else 1 @@ -121,10 +124,9 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - policy_step = 0 last_log = 0 + policy_step = 0 last_checkpoint = 0 - start_time = time.perf_counter() 
policy_steps_per_update = int(cfg.env.num_envs) num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 @@ -150,26 +152,29 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32) # [N_envs, N_obs] for update in range(1, num_updates + 1): - if update < learning_starts: - actions = envs.action_space.sample() - else: - # Sample an action given the observation received by the environment - with torch.no_grad(): - actions, _ = actor(obs) - actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions) - dones = np.logical_or(dones, truncated) - policy_step += cfg.env.num_envs + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + if update <= learning_starts: + actions = envs.action_space.sample() + else: + # Sample an action given the observation received by the environment + with torch.no_grad(): + actions, _ = actor(obs) + actions = actions.cpu().numpy() + next_obs, rewards, dones, truncated, infos = envs.step(actions) + dones = np.logical_or(dones, truncated) + if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = next_obs.copy() @@ -197,6 +202,13 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Send data to the training agents if update >= learning_starts: + # Send local info to the trainers + if update == learning_starts: + world_collective.broadcast_object_list( + [{"update": update, "last_log": last_log, "last_checkpoint": last_checkpoint}], src=0 + ) + + # Sample data to be sent to the trainers training_steps = learning_starts if update == learning_starts else 1 chunks = rb.sample( training_steps * cfg.algo.gradient_steps * cfg.per_rank_batch_size * (fabric.world_size - 1), @@ -204,26 +216,38 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co ).split(training_steps * cfg.algo.gradient_steps * cfg.per_rank_batch_size) world_collective.scatter_object_list([None], [None] + chunks, src=0) - # Gather metrics from the trainers to be plotted - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - metrics = [None] - player_trainer_collective.broadcast_object_list(metrics, src=1) - # Wait the trainers to finish player_trainer_collective.broadcast(flattened_parameters, src=1) # Convert back the parameters torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, actor.parameters()) + # Logs trainers-only 
metrics if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: + # Gather metrics from the trainers + metrics = [None] + player_trainer_collective.broadcast_object_list(metrics, src=1) + + # Log metrics fabric.log_dict(metrics[0], policy_step) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) + # Logs player-only metrics if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step fabric.log_dict(aggregator.compute(), policy_step) aggregator.reset() + # Sync timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) * cfg.env.action_repeat) / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + # Checkpoint model if ( update >= learning_starts # otherwise the processes end up deadlocked @@ -343,13 +367,21 @@ def trainer( } ) + # Receive data from player reagrding the: + # * update + # * last_log + # * last_checkpoint + data = [None] + world_collective.broadcast_object_list(data, src=0) + update = data[0]["update"] + last_log = data[0]["last_log"] + last_checkpoint = data[0]["last_checkpoint"] + # Start training + train_step = 0 + last_train = 0 policy_steps_per_update = cfg.env.num_envs - learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - update = learning_starts policy_step = update * policy_steps_per_update - last_log = (policy_step // cfg.metric.log_every - 1) * cfg.metric.log_every - last_checkpoint = 0 while True: # Wait for data data = [None] @@ -370,34 +402,47 @@ def trainer( data = make_tensordict(data, device=device) sampler = BatchSampler(range(len(data)), batch_size=cfg.per_rank_batch_size, drop_last=False) for batch_idxes in sampler: - train( - fabric, - agent, - actor_optimizer, - qf_optimizer, - alpha_optimizer, - data[batch_idxes], - aggregator, - update, - cfg, - policy_steps_per_update, - group=optimization_pg, + with timer( + "Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg) + ): + train( + fabric, + agent, + actor_optimizer, + qf_optimizer, + alpha_optimizer, + data[batch_idxes], + aggregator, + update, + cfg, + policy_steps_per_update, + group=optimization_pg, + ) + train_step += 1 + + if global_rank == 1: + player_trainer_collective.broadcast( + torch.nn.utils.convert_parameters.parameters_to_vector(actor.parameters()), src=1 ) - # Send updated weights to the player if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step + # Sync distributed metrics metrics = aggregator.compute() aggregator.reset() + + # Sync distributed timers + timers = timer.compute() + metrics.update({"Time/sps_train": (train_step - last_train) / timers["Time/train_time"]}) + timer.reset() + if global_rank == 1: player_trainer_collective.broadcast_object_list( [metrics], src=1 ) # Broadcast metrics: fake send with object list between rank-0 and rank-1 - if global_rank == 1: - player_trainer_collective.broadcast( - torch.nn.utils.convert_parameters.parameters_to_vector(actor.parameters()), src=1 - ) + # Reset counters + last_log = policy_step + last_train = train_step # Checkpoint model on rank-0: send it everything if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or cfg.dry_run: @@ -411,6 +456,8 @@ def trainer( "update": update, } fabric.call("on_checkpoint_trainer", 
player_trainer_collective=player_trainer_collective, state=state) + + # Update counters update += 1 policy_step += policy_steps_per_update diff --git a/sheeprl/utils/timer.py b/sheeprl/utils/timer.py index 41043178..dcbd3d6a 100644 --- a/sheeprl/utils/timer.py +++ b/sheeprl/utils/timer.py @@ -5,7 +5,7 @@ from typing import Dict, Optional, Union import torch -from torchmetrics import MeanMetric +from torchmetrics import Metric, SumMetric class TimerError(Exception): @@ -16,14 +16,14 @@ class timer(ContextDecorator): """A timer class to measure the time of a code block.""" disabled: bool = False - timers: Dict[str, MeanMetric] = {} + timers: Dict[str, Metric] = {} _start_time: Optional[float] = None - def __init__(self, name: str, **metric_kwargs) -> None: + def __init__(self, name: str, metric: Optional[Metric] = None) -> None: """Add timer to dict of timers after initialization""" self.name = name if not timer.disabled and self.name is not None and self.name not in self.timers.keys(): - self.timers.setdefault(self.name, MeanMetric(**metric_kwargs)) + self.timers.setdefault(self.name, metric if metric is not None else SumMetric()) def start(self) -> None: """Start a new timer""" From d4072d28ebbc159b847b26e2d82914d3c9be4f6d Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 12 Sep 2023 15:10:52 +0200 Subject: [PATCH 12/31] Fix aggregator.to(device) + check when we measure train time --- sheeprl/algos/ppo/ppo_decoupled.py | 38 ++++++--------- sheeprl/algos/sac/sac.py | 77 ++++++++++++++++-------------- sheeprl/algos/sac/sac_decoupled.py | 36 ++++++-------- 3 files changed, 69 insertions(+), 82 deletions(-) diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 12079ac6..f8418161 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -124,11 +124,9 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co torch.nn.utils.convert_parameters.vector_to_parameters(flattened_parameters, list(agent.parameters())) # Metrics - metrics = { - "Rewards/rew_avg": MeanMetric(sync_on_compute=False), - "Game/ep_len_avg": MeanMetric(sync_on_compute=False), - } - aggregator = MetricAggregator(metrics).to(device) + aggregator = MetricAggregator( + {"Rewards/rew_avg": MeanMetric(sync_on_compute=False), "Game/ep_len_avg": MeanMetric(sync_on_compute=False)} + ).to(device) # Local data rb = ReplayBuffer( @@ -396,20 +394,13 @@ def trainer( scheduler = PolynomialLR(optimizer=optimizer, total_iters=num_updates, power=1.0) # Metrics - with fabric.device: - aggregator = MetricAggregator( - { - "Loss/value_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - "Loss/policy_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - "Loss/entropy_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - } - ) + aggregator = MetricAggregator( + { + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + "Loss/entropy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + } + ).to(device) # Start training update = 0 @@ -439,13 +430,14 @@ def trainer( update += 1 policy_step += cfg.env.num_envs * cfg.algo.rollout_steps + # Prepare sampler + indexes = list(range(data.shape[0])) + sampler = 
BatchSampler(RandomSampler(indexes), batch_size=cfg.per_rank_batch_size, drop_last=False) + + # Start training with timer( "Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg) ): - # Prepare sampler - indexes = list(range(data.shape[0])) - sampler = BatchSampler(RandomSampler(indexes), batch_size=cfg.per_rank_batch_size, drop_last=False) - # The Join context is needed because there can be the possibility # that some ranks receive less data with Join([agent._forward_module]): diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 6cdf1a6f..2ad2c97c 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -259,44 +259,47 @@ def main(cfg: DictConfig): # Train the agent if update >= learning_starts: training_steps = learning_starts if update == learning_starts else 1 - for _ in range(training_steps): - # We sample one time to reduce the communications between processes - sample = rb.sample( - cfg.algo.gradient_steps * cfg.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs - ) # [G*B, 1] - gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] - gathered_data = make_tensordict(gathered_data).view(-1) # [G*B*World] - if world_size > 1: - dist_sampler: DistributedSampler = DistributedSampler( - range(len(gathered_data)), - num_replicas=world_size, - rank=fabric.global_rank, - shuffle=True, - seed=cfg.seed, - drop_last=False, - ) - sampler: BatchSampler = BatchSampler( - sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False - ) - else: - sampler = BatchSampler( - sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False - ) + + # We sample one time to reduce the communications between processes + sample = rb.sample( + training_steps * cfg.algo.gradient_steps * cfg.per_rank_batch_size, + sample_next_obs=cfg.buffer.sample_next_obs, + ) # [G*B, 1] + gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] + gathered_data = make_tensordict(gathered_data).view(-1) # [G*B*World] + if world_size > 1: + dist_sampler: DistributedSampler = DistributedSampler( + range(len(gathered_data)), + num_replicas=world_size, + rank=fabric.global_rank, + shuffle=True, + seed=cfg.seed, + drop_last=False, + ) + sampler: BatchSampler = BatchSampler( + sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False + ) + else: + sampler = BatchSampler( + sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False + ) + + # Start training + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): for batch_idxes in sampler: - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): - train( - fabric, - agent, - actor_optimizer, - qf_optimizer, - alpha_optimizer, - gathered_data[batch_idxes], - aggregator, - update, - cfg, - policy_steps_per_update, - ) - train_step += 1 + train( + fabric, + agent, + actor_optimizer, + qf_optimizer, + alpha_optimizer, + gathered_data[batch_idxes], + aggregator, + update, + cfg, + policy_steps_per_update, + ) + train_step += 1 # Log metrics if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index a9162899..21a48254 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -106,10 +106,7 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # 
Metrics aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=False), - "Game/ep_len_avg": MeanMetric(sync_on_compute=False), - } + {"Rewards/rew_avg": MeanMetric(sync_on_compute=False), "Game/ep_len_avg": MeanMetric(sync_on_compute=False)} ).to(device) # Local data @@ -352,20 +349,13 @@ def trainer( ) # Metrics - with fabric.device: - aggregator = MetricAggregator( - { - "Loss/value_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - "Loss/policy_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - "Loss/alpha_loss": MeanMetric( - sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg - ), - } - ) + aggregator = MetricAggregator( + { + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg), + } + ).to(device) # Receive data from player reagrding the: # * update @@ -401,10 +391,12 @@ def trainer( return data = make_tensordict(data, device=device) sampler = BatchSampler(range(len(data)), batch_size=cfg.per_rank_batch_size, drop_last=False) - for batch_idxes in sampler: - with timer( - "Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg) - ): + + # Start training + with timer( + "Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute, process_group=optimization_pg) + ): + for batch_idxes in sampler: train( fabric, agent, From 051e8e5ad81b14ac107cf52929d38660c20528e9 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 12 Sep 2023 16:04:17 +0200 Subject: [PATCH 13/31] fix: configs --- sheeprl/configs/config.yaml | 2 +- .../exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sheeprl/configs/config.yaml b/sheeprl/configs/config.yaml index ec9ff788..607ca96e 100644 --- a/sheeprl/configs/config.yaml +++ b/sheeprl/configs/config.yaml @@ -7,9 +7,9 @@ defaults: - buffer: default.yaml - checkpoint: default.yaml - env: default.yaml + - metric: default.yaml - exp: null - hydra: default.yaml - - metric: default.yaml num_threads: 1 total_steps: ??? 
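[Editorial note, not part of the patch series] The timing changes above (PATCH 11-13) converge on one pattern: wall-clock time is accumulated into a `SumMetric` under a named `timer(...)` context, and the steps-per-second rates (`Time/sps_train`, `Time/sps_env_interaction`) are derived only at log time as a step delta divided by the accumulated time, after which timers and counters are reset. The sketch below illustrates that pattern under simplifying assumptions: the `toy_timer` class, the `time.sleep` stand-in and the three-update loop are invented for clarity, while the `Time/train_time` key and the `SumMetric` accumulation mirror the modified `sheeprl/utils/timer.py`.

```python
import time
from contextlib import ContextDecorator
from typing import Dict, Optional

from torchmetrics import Metric, SumMetric


class toy_timer(ContextDecorator):
    """Simplified stand-in for sheeprl.utils.timer.timer: accumulate the
    wall-clock time of a code block into a named metric (SumMetric by default)."""

    timers: Dict[str, Metric] = {}

    def __init__(self, name: str, metric: Optional[Metric] = None) -> None:
        self.name = name
        # Every distinct name gets its own accumulator (created once)
        self.timers.setdefault(name, metric if metric is not None else SumMetric())

    def __enter__(self) -> "toy_timer":
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc) -> bool:
        # Add the elapsed seconds to the named accumulator
        self.timers[self.name].update(time.perf_counter() - self._start)
        return False

    @classmethod
    def compute(cls) -> Dict[str, float]:
        return {k: float(v.compute()) for k, v in cls.timers.items()}

    @classmethod
    def reset(cls) -> None:
        for v in cls.timers.values():
            v.reset()


# Accumulate training time over a few updates, then derive steps-per-second at log time
train_step, last_train = 0, 0
for _ in range(3):
    with toy_timer("Time/train_time", SumMetric()):
        time.sleep(0.01)  # stand-in for one call to train(...)
    train_step += 1

timer_metrics = toy_timer.compute()
sps_train = (train_step - last_train) / timer_metrics["Time/train_time"]
print(f"Time/sps_train: {sps_train:.1f}")
toy_timer.reset()
last_train = train_step  # counters are reset after logging, as in the patches
```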
diff --git a/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml b/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml index d95e3af1..70905a27 100644 --- a/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml +++ b/sheeprl/configs/exp/dreamer_v3_L_doapp_128px_gray_combo_discrete.yaml @@ -8,11 +8,12 @@ defaults: # Experiment seed: 0 total_steps: 10000000 +per_rank_batch_size: 8 # Environment env: id: doapp - num_envs: 8 + num_envs: 4 grayscale: True frame_stack: 1 screen_size: 128 From fb1a0d9dde9fd2747c1fdea0cc9f52a23604a20c Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 12 Sep 2023 17:04:48 +0200 Subject: [PATCH 14/31] Fix SAC check if first info are sent to the trainers --- sheeprl/algos/sac/sac.py | 3 ++- sheeprl/algos/sac/sac_decoupled.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 2ad2c97c..2aced215 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -240,6 +240,7 @@ def main(cfg: DictConfig): real_next_obs[idx] = final_obs with device: + next_obs = torch.tensor(next_obs, dtype=torch.float32) real_next_obs = torch.tensor(real_next_obs, dtype=torch.float32) actions = torch.tensor(actions, dtype=torch.float32).view(cfg.env.num_envs, -1) rewards = torch.tensor(rewards, dtype=torch.float32).view(cfg.env.num_envs, -1) @@ -254,7 +255,7 @@ def main(cfg: DictConfig): rb.add(step_data.unsqueeze(0)) # next_obs becomes the new obs - obs = torch.tensor(next_obs, device=device) + obs = next_obs # Train the agent if update >= learning_starts: diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 21a48254..63247bc1 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -124,6 +124,7 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co last_log = 0 policy_step = 0 last_checkpoint = 0 + first_info_sent = False policy_steps_per_update = int(cfg.env.num_envs) num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 @@ -200,10 +201,11 @@ def player(cfg: DictConfig, world_collective: TorchCollective, player_trainer_co # Send data to the training agents if update >= learning_starts: # Send local info to the trainers - if update == learning_starts: + if not first_info_sent: world_collective.broadcast_object_list( [{"update": update, "last_log": last_log, "last_checkpoint": last_checkpoint}], src=0 ) + first_info_sent = True # Sample data to be sent to the trainers training_steps = learning_starts if update == learning_starts else 1 From 61bcfc06094444495b27d3db94dd293628b33ae6 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 12 Sep 2023 17:05:07 +0200 Subject: [PATCH 15/31] Update DroQ logging --- sheeprl/algos/droq/droq.py | 199 +++++++++++++++++++++---------------- 1 file changed, 116 insertions(+), 83 deletions(-) diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 018696c3..5e682d0b 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -1,5 +1,4 @@ import os -import time import warnings from math import prod @@ -15,7 +14,7 @@ from torch.optim import Optimizer from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from 
sheeprl.algos.droq.agent import DROQAgent, DROQCritic from sheeprl.algos.sac.agent import SACActor @@ -27,6 +26,7 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm +from sheeprl.utils.timer import timer def train( @@ -42,75 +42,83 @@ def train( # Sample a minibatch in a distributed way: Line 5 - Algorithm 2 # We sample one time to reduce the communications between processes sample = rb.sample(cfg.algo.gradient_steps * cfg.per_rank_batch_size, sample_next_obs=cfg.buffer.sample_next_obs) - gathered_data = fabric.all_gather(sample.to_dict()) - gathered_data = make_tensordict(gathered_data).view(-1) + critic_data = fabric.all_gather(sample.to_dict()) + critic_data = make_tensordict(critic_data).view(-1) if fabric.world_size > 1: dist_sampler: DistributedSampler = DistributedSampler( - range(len(gathered_data)), + range(len(critic_data)), num_replicas=fabric.world_size, rank=fabric.global_rank, shuffle=True, seed=cfg.seed, drop_last=False, ) - sampler: BatchSampler = BatchSampler(sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False) + critic_sampler: BatchSampler = BatchSampler( + sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False + ) else: - sampler = BatchSampler(sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False) - - # Update the soft-critic - for batch_idxes in sampler: - data = gathered_data[batch_idxes] - next_target_qf_value = agent.get_next_target_q_values( - data["next_observations"], - data["rewards"], - data["dones"], - cfg.algo.gamma, + critic_sampler = BatchSampler( + sampler=range(len(critic_data)), batch_size=cfg.per_rank_batch_size, drop_last=False ) - for qf_value_idx in range(agent.num_critics): - # Line 8 - Algorithm 2 - qf_loss = F.mse_loss( - agent.get_ith_q_value(data["observations"], data["actions"], qf_value_idx), next_target_qf_value - ) - qf_optimizer.zero_grad(set_to_none=True) - fabric.backward(qf_loss) - qf_optimizer.step() - aggregator.update("Loss/value_loss", qf_loss) - - # Update the target networks with EMA - agent.qfs_target_ema(critic_idx=qf_value_idx) # Sample a different minibatch in a distributed way to update actor and alpha parameter sample = rb.sample(cfg.per_rank_batch_size) - data = fabric.all_gather(sample.to_dict()) - data = make_tensordict(data).view(-1) + actor_data = fabric.all_gather(sample.to_dict()) + actor_data = make_tensordict(actor_data).view(-1) if fabric.world_size > 1: - sampler: DistributedSampler = DistributedSampler( - range(len(data)), + actor_sampler: DistributedSampler = DistributedSampler( + range(len(actor_data)), num_replicas=fabric.world_size, rank=fabric.global_rank, shuffle=True, seed=cfg.seed, drop_last=False, ) - data = data[next(iter(sampler))] - - # Update the actor - actions, logprobs = agent.get_actions_and_log_probs(data["observations"]) - qf_values = agent.get_q_values(data["observations"], actions) - min_qf_values = torch.mean(qf_values, dim=-1, keepdim=True) - actor_loss = policy_loss(agent.alpha, logprobs, min_qf_values) - actor_optimizer.zero_grad(set_to_none=True) - fabric.backward(actor_loss) - actor_optimizer.step() - aggregator.update("Loss/policy_loss", actor_loss) - - # Update the entropy value - alpha_loss = entropy_loss(agent.log_alpha, logprobs.detach(), agent.target_entropy) - alpha_optimizer.zero_grad(set_to_none=True) - fabric.backward(alpha_loss) - agent.log_alpha.grad = 
fabric.all_reduce(agent.log_alpha.grad) - alpha_optimizer.step() - aggregator.update("Loss/alpha_loss", alpha_loss) + actor_data = actor_data[next(iter(actor_sampler))] + + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + # Update the soft-critic + for batch_idxes in critic_sampler: + critic_batch_data = critic_data[batch_idxes] + next_target_qf_value = agent.get_next_target_q_values( + critic_batch_data["next_observations"], + critic_batch_data["rewards"], + critic_batch_data["dones"], + cfg.algo.gamma, + ) + for qf_value_idx in range(agent.num_critics): + # Line 8 - Algorithm 2 + qf_loss = F.mse_loss( + agent.get_ith_q_value( + critic_batch_data["observations"], critic_batch_data["actions"], qf_value_idx + ), + next_target_qf_value, + ) + qf_optimizer.zero_grad(set_to_none=True) + fabric.backward(qf_loss) + qf_optimizer.step() + aggregator.update("Loss/value_loss", qf_loss) + + # Update the target networks with EMA + agent.qfs_target_ema(critic_idx=qf_value_idx) + + # Update the actor + actions, logprobs = agent.get_actions_and_log_probs(actor_data["observations"]) + qf_values = agent.get_q_values(actor_data["observations"], actions) + min_qf_values = torch.mean(qf_values, dim=-1, keepdim=True) + actor_loss = policy_loss(agent.alpha, logprobs, min_qf_values) + actor_optimizer.zero_grad(set_to_none=True) + fabric.backward(actor_loss) + actor_optimizer.step() + aggregator.update("Loss/policy_loss", actor_loss) + + # Update the entropy value + alpha_loss = entropy_loss(agent.log_alpha, logprobs.detach(), agent.target_entropy) + alpha_optimizer.zero_grad(set_to_none=True) + fabric.backward(alpha_loss) + agent.log_alpha.grad = fabric.all_reduce(agent.log_alpha.grad) + alpha_optimizer.step() + aggregator.update("Loss/alpha_loss", alpha_loss) @register_algorithm() @@ -120,8 +128,9 @@ def main(cfg: DictConfig): fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -194,17 +203,16 @@ def main(cfg: DictConfig): ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(device) # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 1 @@ -218,10 +226,11 @@ def main(cfg: DictConfig): step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=device) # Global variables - policy_step 
= 0 last_log = 0 + last_train = 0 + train_step = 0 + policy_step = 0 last_checkpoint = 0 - start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 @@ -247,23 +256,25 @@ def main(cfg: DictConfig): obs = torch.tensor(envs.reset(seed=cfg.seed)[0], dtype=torch.float32) # [N_envs, N_obs] for update in range(1, num_updates + 1): - # Sample an action given the observation received by the environment - with torch.no_grad(): - actions, _ = actor.module(obs) - actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions) - dones = np.logical_or(dones, truncated) - policy_step += cfg.env.num_envs * fabric.world_size + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with torch.no_grad(): + actions, _ = actor.module(obs) + actions = actions.cpu().numpy() + next_obs, rewards, dones, truncated, infos = envs.step(actions) + dones = np.logical_or(dones, truncated) + if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = next_obs.copy() @@ -293,12 +304,34 @@ def main(cfg: DictConfig): # Train the agent if update > learning_starts: train(fabric, agent, actor_optimizer, qf_optimizer, alpha_optimizer, rb, aggregator, cfg) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + train_step += 1 + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) From 1c935326496e60a41ed38136d11c4182f6ccff54 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 12 Sep 2023 17:33:44 +0200 
Subject: [PATCH 16/31] [skip ci] Updated howto --- howto/work_with_steps.md | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/howto/work_with_steps.md b/howto/work_with_steps.md index f1821e0c..b17b9d0e 100644 --- a/howto/work_with_steps.md +++ b/howto/work_with_steps.md @@ -1,30 +1,32 @@ # Work with steps In this document we want to discuss about the hyper-parameters which refer to the concept of step. -There are various ways to interpret it, so it is necessary to clearly specify which iterpretation we give to that concept. +There are various ways to interpret it, so it is necessary to clearly specify how to interpret it. ## Policy steps -We start from the concept of *policy steps*: a policy step is the selection of an action to perform an environment step. In other words, it is called *policy step* when the actor takes in input an observation and choose an action and then the agent performs an environment step. +We start from the concept of *policy step*: a policy step is the particular step in which the policy selects the action to perform in the environment, given an observation received by it. > **Note** > > The environment step is the step performed by the environment: the environment takes in input an action and computes the next observation and the next reward. -Now we have introduced the concept of policy step, it is necessary to clarify some aspects: -1. When there are more parallel environments, the policy step is proportional with the number of parallel environments. E.g, if there are $m$ environments, then the agent has to choose $m$ actions and each environment performs an environment step: this mean that **m policy steps** are performed. -2. When there are more parallel processes, the policy step it is proportional with the number of parallel processes. E.g, let us assume that there are $n$ processes each one containing one single environment: the $n$ actors select an action and perform a step in the environment, so, also in this case **n policy steps** are performed. +Now that we have introduced the concept of policy step, it is necessary to clarify some aspects: -To generalize, in a case with $n$ processes, each one with $m$ environments, the policy steps increase by $n \cdot m$ at each iteration. +1. When there are multiple parallel environments, the policy step is proportional to the number of parallel environments. E.g., if there are $m$ environments, then the actor has to choose $m$ actions and each environment performs an environment step: this means that $\bold{m}$ **policy steps** are performed. +2. When there are multiple parallel processes (i.e. the script has been run with `lightning run model --devices>=2 ...`), the policy step it is proportional to the number of parallel processes. E.g., let us assume that there are $n$ processes each one containing one single environment: the $n$ actors select an action and a (per-process) step in the environment is performed. In this case $\bold{n}$ **policy steps** are performed. + +In general, if we have $n$ parallel processes, each one with $m$ independent environments, the policy step increases **globally** by $n \cdot m$ at each iteration. The hyper-parameters which refer to the *policy steps* are: -* `total_steps`: the total number of policy steps to perform in an experiment. + +* `total_steps`: the total number of policy steps to perform in an experiment. 
Effectively, this number will be divided in each process by $n \cdot m$ to obtain the number of training steps to be performed by each of them. * `exploration_steps`: the number of policy steps in which the agent explores the environment in the P2E algorithms. -* `max_episode_steps`: the maximum number of policy steps an episode can last ($\text{max\_steps}$). This means that if you decide to have an action repeat greater than one ($\text{act\_repeat} > 1$), then you can perform a numeber of environment steps equal to: $\text{env\_steps} = \text{max\_steps} \cdot \text{act\_repeat}$. +* `max_episode_steps`: the maximum number of policy steps an episode can last ($\text{max\_steps}$); when this number is reached a `terminated=True` is returned by the environment. This means that if you decide to have an action repeat greater than one ($\text{action\_repeat} > 1$), then the environment performs a maximum number of steps equal to: $\text{env\_steps} = \text{max\_steps} \cdot \text{action\_repeat}$. * `learning_starts`: how many policy steps the agent has to perform before starting the training. * `train_every`: how many policy steps the agent has to perform between one training and the next. ## Gradient steps -A *gradient step* consists of an update of the parameters of the agent, i.e., a call of the *train* function. The gradient step is proportional to the number of parallel processes, indeed, if there are $n$ parallel processes, the call of the *train* method will increase by $n$ the gradient step. +A *gradient step* consists of an update of the parameters of the agent, i.e., a call of the *train* function. The gradient step is proportional to the number of parallel processes, indeed, if there are $n$ parallel processes, $n \cdot \text{gradient\_steps}$ calls to the *train* method will be executed. -The hyper-parameters which refer to the *policy steps* are: +The hyper-parameters which refer to the *gradient steps* are: * `algo.per_rank_gradient_steps`: the number of gradient steps per rank to perform in a single iteration. * `algo.per_rank_pretrain_steps`: the number of gradient steps per rank to perform in the first iteration. 
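[Editorial note, not part of the patch] To make the arithmetic described above concrete: with $n$ processes and $m$ environments per process, the global policy-step counter grows by $n \cdot m$ per iteration, so each process runs `total_steps // (n * m)` updates, and an episode capped at `max_episode_steps` policy steps corresponds to `max_episode_steps * action_repeat` environment steps. The function below is a standalone illustration; its name and return dictionary are invented for the example, while the expressions mirror those used in the algorithms (e.g. `policy_steps_per_update = int(cfg.env.num_envs * world_size)`).

```python
from typing import Dict, Optional


def policy_step_budget(
    total_steps: int,
    num_envs: int,  # m: environments per process
    world_size: int,  # n: parallel processes
    action_repeat: int = 1,
    max_episode_steps: Optional[int] = None,
) -> Dict[str, int]:
    # Every iteration each process steps all of its environments once,
    # so the global policy-step counter increases by n * m
    policy_steps_per_update = num_envs * world_size
    # Number of iterations each process will run
    num_updates = total_steps // policy_steps_per_update
    budget = {
        "policy_steps_per_update": policy_steps_per_update,
        "num_updates": num_updates,
    }
    if max_episode_steps is not None:
        # A cap of max_episode_steps policy steps translates into
        # max_episode_steps * action_repeat low-level environment steps
        budget["env_steps_per_episode"] = max_episode_steps * action_repeat
    return budget


print(policy_step_budget(total_steps=100_000, num_envs=4, world_size=2,
                         action_repeat=4, max_episode_steps=27_000))
# {'policy_steps_per_update': 8, 'num_updates': 12500, 'env_steps_per_episode': 108000}
```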
\ No newline at end of file From 94bc646583a04d673eaeb889a32746c7b76f928c Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 12 Sep 2023 18:32:01 +0200 Subject: [PATCH 17/31] Update logging SACAE --- sheeprl/algos/droq/droq.py | 3 + sheeprl/algos/sac_ae/sac_ae.py | 155 ++++++++++++++++++++------------- 2 files changed, 96 insertions(+), 62 deletions(-) diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index a9b6c4fe..346eaa12 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -27,6 +27,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer +from sheeprl.utils.utils import print_config def train( @@ -126,6 +127,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # Initialize Fabric fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index ddcc7c05..1333dea7 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -1,6 +1,5 @@ import copy import os -import time import warnings from math import prod from typing import Optional, Union @@ -22,7 +21,7 @@ from torch.optim import Optimizer from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss from sheeprl.algos.sac_ae.agent import ( @@ -43,6 +42,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import print_config def train( @@ -133,6 +134,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + if "minedojo" in cfg.env.env._target_.lower(): raise ValueError( "MineDojo is not currently supported by SAC-AE agent, since it does not take " @@ -164,8 +167,9 @@ def main(cfg: DictConfig): fabric = Fabric(strategy=strategy, callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -323,18 +327,16 @@ def main(cfg: DictConfig): ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": 
MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(device) # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 1 @@ -349,10 +351,11 @@ def main(cfg: DictConfig): step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device=fabric.device if cfg.buffer.memmap else "cpu") # Global variables - policy_step = 0 last_log = 0 + last_train = 0 + train_step = 0 + policy_step = 0 last_checkpoint = 0 - start_time = time.time() policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 @@ -386,26 +389,29 @@ def main(cfg: DictConfig): obs[k] = torch_obs for update in range(1, num_updates + 1): - if update < learning_starts: - actions = envs.action_space.sample() - else: - with torch.no_grad(): - normalized_obs = {k: v / 255 if k in cfg.cnn_keys.encoder else v for k, v in obs.items()} - actions, _ = actor.module(normalized_obs) - actions = actions.cpu().numpy() - o, rewards, dones, truncated, infos = envs.step(actions) - dones = np.logical_or(dones, truncated) - policy_step += cfg.env.num_envs * fabric.world_size + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + if update < learning_starts: + actions = envs.action_space.sample() + else: + with torch.no_grad(): + normalized_obs = {k: v / 255 if k in cfg.cnn_keys.encoder else v for k, v in obs.items()} + actions, _ = actor.module(normalized_obs) + actions = actions.cpu().numpy() + o, rewards, dones, truncated, infos = envs.step(actions) + dones = np.logical_or(dones, truncated) + if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = copy.deepcopy(o) @@ -447,30 +453,33 @@ def main(cfg: DictConfig): # Train the agent if update >= learning_starts - 1: training_steps = learning_starts if update == learning_starts - 1 else 1 - for _ in range(training_steps): - # We sample one time to reduce the communications between processes - sample = rb.sample( - cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, - sample_next_obs=cfg.buffer.sample_next_obs, - ) # [G*B, 1] - gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] - gathered_data = make_tensordict(gathered_data).view(-1) # 
[G*B*World] - if fabric.world_size > 1: - dist_sampler: DistributedSampler = DistributedSampler( - range(len(gathered_data)), - num_replicas=fabric.world_size, - rank=fabric.global_rank, - shuffle=True, - seed=cfg.seed, - drop_last=False, - ) - sampler: BatchSampler = BatchSampler( - sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False - ) - else: - sampler = BatchSampler( - sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False - ) + + # We sample one time to reduce the communications between processes + sample = rb.sample( + training_steps * cfg.algo.per_rank_gradient_steps * cfg.per_rank_batch_size, + sample_next_obs=cfg.buffer.sample_next_obs, + ) # [G*B, 1] + gathered_data = fabric.all_gather(sample.to_dict()) # [G*B, World, 1] + gathered_data = make_tensordict(gathered_data).view(-1) # [G*B*World] + if fabric.world_size > 1: + dist_sampler: DistributedSampler = DistributedSampler( + range(len(gathered_data)), + num_replicas=fabric.world_size, + rank=fabric.global_rank, + shuffle=True, + seed=cfg.seed, + drop_last=False, + ) + sampler: BatchSampler = BatchSampler( + sampler=dist_sampler, batch_size=cfg.per_rank_batch_size, drop_last=False + ) + else: + sampler = BatchSampler( + sampler=range(len(gathered_data)), batch_size=cfg.per_rank_batch_size, drop_last=False + ) + + # Start training + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): for batch_idxes in sampler: train( fabric, @@ -488,12 +497,34 @@ def main(cfg: DictConfig): cfg, policy_steps_per_update, ) - aggregator.update("Time/step_per_second", int(policy_step / (time.time() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + train_step += 1 + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or cfg.dry_run: last_checkpoint >= policy_step From 14587afef0074569fce8ccfc8ba3fc7a16cd1b95 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Wed, 13 Sep 2023 09:10:15 +0000 Subject: [PATCH 18/31] config: added dreamer_v3 configs --- .../exp/dreamer_v3_dmc_walker_walk.yaml | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml diff --git a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml new file mode 100644 index 00000000..a683a28e --- /dev/null +++ b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml @@ -0,0 +1,55 @@ +# @package _global_ + +defaults: + - dreamer_v3 + - override /env: dmc + - _self_ + +# Experiment +seed: 5 +total_steps: 1000000 +cnn_keys: + encoder: + - rgb +mlp_keys: + encoder: + - state + +# Environment +env: + num_envs: 1 + max_episode_steps: 1000 + id: 
walker_walk + env: + from_vectors: True + from_pixels: True + +# Checkpoint +checkpoint: + every: 10000 + +# Buffer +buffer: + size: 100000 + checkpoint: True + memmap: True + +# Algorithm +algo: + learning_starts: 8000 + train_every: 2 + dense_units: 512 + mlp_layers: 2 + world_model: + encoder: + cnn_channels_multiplier: 32 + recurrent_model: + recurrent_state_size: 512 + transition_model: + hidden_size: 512 + representation_model: + hidden_size: 512 + +# Metric +metric: + log_every: 5000 \ No newline at end of file From b1f309cfce7bb6c196411e1792a5f616149b0efc Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Wed, 13 Sep 2023 09:14:48 +0000 Subject: [PATCH 19/31] fix: dependencies --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c93876c0..8d116d81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,8 @@ dependencies = [ "lightning==2.0.*", "lightning-utilities<0.9", "hydra-core==1.3.0", - "torchmetrics==1.1.*" + "torchmetrics==1.1.*", + "opencv-python==4.8.0.*" ] dynamic = ["version"] From 247dcbd1a2519415279781da281e5d63520a6197 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Wed, 13 Sep 2023 11:30:01 +0200 Subject: [PATCH 20/31] Add sps for Dreamer-V1 --- sheeprl/algos/dreamer_v1/dreamer_v1.py | 209 ++++++++++++++----------- 1 file changed, 118 insertions(+), 91 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 3a30d809..94d275f0 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -1,7 +1,6 @@ import copy import os import pathlib -import time import warnings from typing import Dict @@ -18,7 +17,7 @@ from tensordict.tensordict import TensorDictBase from torch.distributions import Bernoulli, Independent, Normal from torch.utils.data import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_models from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss @@ -29,6 +28,7 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm +from sheeprl.utils.timer import timer from sheeprl.utils.utils import compute_lambda_values, polynomial_decay # Decomment the following two lines if you cannot start an experiment with DMC environments @@ -367,8 +367,9 @@ def main(cfg: DictConfig): fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -379,7 +380,7 @@ def main(cfg: DictConfig): ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") cfg.checkpoint.resume_from = str(ckpt_path) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.per_rank_batch_size = state["batch_size"] // world_size cfg.root_dir = root_dir cfg.run_name = f"resume_from_checkpoint_{run_name}" @@ -479,32 +480,30 @@ def main(cfg: DictConfig): ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": 
MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/post_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/prior_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/actor": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/critic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) - aggregator.to(fabric.device) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/post_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/prior_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/actor": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/critic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(fabric.device) # Local data - buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 2 + buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 2 rb = AsyncReplayBuffer( buffer_size, cfg.env.num_envs, @@ -514,28 +513,29 @@ def main(cfg: DictConfig): sequential=True, ) if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: - if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] elif isinstance(state["rb"], AsyncReplayBuffer): rb = state["rb"] else: - raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") + raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") step_data = TensorDict({}, 
batch_size=[cfg.env.num_envs], device=fabric.device if cfg.buffer.memmap else "cpu") expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables - start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + train_step = 0 + last_train = 0 + start_step = state["update"] // world_size if cfg.checkpoint.resume_from else 1 policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - start_time = time.perf_counter() - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) + policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -576,47 +576,50 @@ def main(cfg: DictConfig): player.init_states() for update in range(start_step, num_updates + 1): - # Sample an action given the observation received by the environment - if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + policy_step += cfg.env.num_envs * world_size + + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + # Sample an action given the observation received by the environment + if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: + with torch.no_grad(): + preprocessed_obs = {} + for k, v in obs.items(): + if k in cfg.cnn_keys.encoder: + preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + else: + preprocessed_obs[k] = v[None, ...].to(device) + mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + if len(mask) == 0: + mask = None + real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + actions = torch.cat(actions, -1).cpu().numpy() + if is_continuous: + real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - 
preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} - if len(mask) == 0: - mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) - actions = torch.cat(actions, -1).cpu().numpy() - if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() - else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) - - policy_step += cfg.env.num_envs * fabric.world_size + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated) if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = copy.deepcopy(o) @@ -673,19 +676,22 @@ def main(cfg: DictConfig): n_samples=cfg.algo.per_rank_gradient_steps, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) - for i in distributed_sampler: - train( - fabric, - world_model, - actor, - critic, - world_optimizer, - actor_optimizer, - critic_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), - aggregator, - cfg, - ) + # Start training + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + for i in distributed_sampler: + train( + fabric, + world_model, + actor, + critic, + world_optimizer, + actor_optimizer, + critic_optimizer, + local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + aggregator, + cfg, + ) + train_step += 1 updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 @@ -696,12 +702,33 @@ def main(cfg: DictConfig): max_decay_steps=max_step_expl_decay, ) aggregator.update("Params/exploration_amout", player.expl_amount) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + 
"Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint Model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) @@ -717,8 +744,8 @@ def main(cfg: DictConfig): "actor_optimizer": actor_optimizer.state_dict(), "critic_optimizer": critic_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, - "update": update * fabric.world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "update": update * world_size, + "batch_size": cfg.per_rank_batch_size * world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, } From 0eb05bbc56c5c84ffe3ede37c04647cd830caa37 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Wed, 13 Sep 2023 11:40:36 +0200 Subject: [PATCH 21/31] From `reconstruction_loss` to `world_model_loss` in all Dreamers + add sps to Dreamer-V2 --- sheeprl/algos/dreamer_v1/dreamer_v1.py | 5 +- sheeprl/algos/dreamer_v2/dreamer_v2.py | 223 ++++++++++++++----------- sheeprl/algos/dreamer_v3/dreamer_v3.py | 4 +- sheeprl/algos/p2e_dv1/p2e_dv1.py | 4 +- sheeprl/algos/p2e_dv2/p2e_dv2.py | 4 +- 5 files changed, 133 insertions(+), 107 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 94d275f0..9a75200b 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -212,7 +212,7 @@ def train( ) world_optimizer.step() aggregator.update("Grads/world_model", world_model_grads.mean().detach()) - aggregator.update("Loss/reconstruction_loss", rec_loss.detach()) + aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) aggregator.update("Loss/reward_loss", reward_loss.detach()) aggregator.update("Loss/state_loss", state_loss.detach()) @@ -484,8 +484,7 @@ def main(cfg: DictConfig): { "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 81f51902..b72f57d3 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -24,7 +24,7 @@ from torch.distributions import Bernoulli, Distribution, Independent, Normal, OneHotCategorical from torch.optim import Optimizer from torch.utils.data import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_models from sheeprl.algos.dreamer_v2.loss import reconstruction_loss @@ -35,6 +35,7 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm +from sheeprl.utils.timer import timer from sheeprl.utils.utils 
import polynomial_decay # Decomment the following two lines if you cannot start an experiment with DMC environments @@ -207,7 +208,7 @@ def train( ) world_optimizer.step() aggregator.update("Grads/world_model", world_model_grads.mean().detach()) - aggregator.update("Loss/reconstruction_loss", rec_loss.detach()) + aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) aggregator.update("Loss/reward_loss", reward_loss.detach()) aggregator.update("Loss/state_loss", state_loss.detach()) @@ -387,8 +388,9 @@ def main(cfg: DictConfig): fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -399,7 +401,7 @@ def main(cfg: DictConfig): ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") cfg.checkpoint.resume_from = str(ckpt_path) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.per_rank_batch_size = state["batch_size"] // world_size cfg.root_dir = root_dir cfg.run_name = f"resume_from_checkpoint_{run_name}" @@ -500,32 +502,29 @@ def main(cfg: DictConfig): ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/post_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/prior_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/actor": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/critic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) - aggregator.to(fabric.device) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/continue_loss": 
MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/post_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/prior_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/actor": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/critic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(fabric.device) # Local data - buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 2 + buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 2 buffer_type = cfg.buffer.type.lower() if buffer_type == "sequential": rb = AsyncReplayBuffer( @@ -547,28 +546,30 @@ def main(cfg: DictConfig): else: raise ValueError(f"Unrecognized buffer type: must be one of `sequential` or `episode`, received: {buffer_type}") if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: - if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] elif isinstance(state["rb"], (AsyncReplayBuffer, EpisodeBuffer)): rb = state["rb"] else: - raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") + raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables - start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + train_step = 0 + last_train = 0 + start_step = state["update"] // world_size if cfg.checkpoint.resume_from else 1 policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - start_time = time.perf_counter() - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) + time.perf_counter() + policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -618,46 +619,49 @@ def main(cfg: DictConfig): per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): - # Sample an action given the observation received by the environment - if ( - update <= learning_starts - and cfg.checkpoint.resume_from is None - and "minedojo" not in cfg.algo.actor.cls.lower() - ): - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = 
np.concatenate( - [ - F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + policy_step += cfg.env.num_envs * world_size + + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + # Sample an action given the observation received by the environment + if ( + update <= learning_starts + and cfg.checkpoint.resume_from is None + and "minedojo" not in cfg.algo.actor.cls.lower() + ): + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: + with torch.no_grad(): + preprocessed_obs = {} + for k, v in obs.items(): + if k in cfg.cnn_keys.encoder: + preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + else: + preprocessed_obs[k] = v[None, ...].to(device) + mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + if len(mask) == 0: + mask = None + real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + actions = torch.cat(actions, -1).cpu().numpy() + if is_continuous: + real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} - if len(mask) == 0: - mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) - actions = torch.cat(actions, -1).cpu().numpy() - if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() - else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) - if cfg.dry_run and buffer_type == "episode": - dones = np.ones_like(dones) - - policy_step += cfg.env.num_envs * fabric.world_size + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + + step_data["is_first"] = copy.deepcopy(step_data["dones"]) + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated) + if cfg.dry_run and buffer_type == "episode": + dones = np.ones_like(dones) if "final_info" in infos: for i, agent_final_info in enumerate(infos["final_info"]): @@ -746,25 +750,27 @@ def main(cfg: DictConfig): prioritize_ends=cfg.buffer.prioritize_ends, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) - for i in distributed_sampler: - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): - tcp.data.copy_(cp.data) - train( - fabric, - world_model, - actor, - critic, - target_critic, - world_optimizer, - actor_optimizer, - critic_optimizer, - 
local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), - aggregator, - cfg, - actions_dim, - ) - per_rank_gradient_steps += 1 + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + for i in distributed_sampler: + if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): + tcp.data.copy_(cp.data) + train( + fabric, + world_model, + actor, + critic, + target_critic, + world_optimizer, + actor_optimizer, + critic_optimizer, + local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + aggregator, + cfg, + actions_dim, + ) + per_rank_gradient_steps += 1 + train_step += 1 updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 @@ -775,12 +781,33 @@ def main(cfg: DictConfig): max_decay_steps=max_step_expl_decay, ) aggregator.update("Params/exploration_amout", player.expl_amount) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint Model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) @@ -797,8 +824,8 @@ def main(cfg: DictConfig): "actor_optimizer": actor_optimizer.state_dict(), "critic_optimizer": critic_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, - "update": update * fabric.world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "update": update * world_size, + "batch_size": cfg.per_rank_batch_size * world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, } diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index eec48cd5..839bd66a 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -177,7 +177,7 @@ def train( ) world_optimizer.step() aggregator.update("Grads/world_model", world_model_grads.mean().detach()) - aggregator.update("Loss/reconstruction_loss", rec_loss.detach()) + aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) aggregator.update("Loss/reward_loss", reward_loss.detach()) aggregator.update("Loss/state_loss", state_loss.detach()) @@ -468,7 +468,7 @@ def main(cfg: DictConfig): "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/world_model_loss": 
MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index ced0ec92..0e301844 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -170,7 +170,7 @@ def train( ) aggregator.update("Grads/world_model", world_grad.detach()) world_optimizer.step() - aggregator.update("Loss/reconstruction_loss", rec_loss.detach()) + aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) aggregator.update("Loss/reward_loss", reward_loss.detach()) aggregator.update("Loss/state_loss", state_loss.detach()) @@ -535,7 +535,7 @@ def main(cfg: DictConfig): "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/value_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/policy_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/value_loss_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index 4746d8f1..8feeaa34 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -192,7 +192,7 @@ def train( error_if_nonfinite=False, ) world_optimizer.step() - aggregator.update("Loss/reconstruction_loss", rec_loss.detach()) + aggregator.update("Loss/world_model_loss", rec_loss.detach()) aggregator.update("Loss/observation_loss", observation_loss.detach()) aggregator.update("Loss/reward_loss", reward_loss.detach()) aggregator.update("Loss/state_loss", state_loss.detach()) @@ -662,7 +662,7 @@ def main(cfg: DictConfig): "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reconstruction_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/value_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/policy_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/value_loss_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), From 17ad8d9682c587bb8188eebfecff059eafa96693 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Wed, 13 Sep 2023 11:45:20 +0200 Subject: [PATCH 22/31] Add sps to Draemer-V3 algo --- sheeprl/algos/dreamer_v2/dreamer_v2.py | 2 - sheeprl/algos/dreamer_v3/dreamer_v3.py | 205 ++++++++++++++----------- 2 files changed, 115 insertions(+), 92 deletions(-) diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index b72f57d3..d4fa78c4 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -5,7 +5,6 @@ import copy import os import pathlib -import time import warnings from typing 
import Dict, Sequence @@ -562,7 +561,6 @@ def main(cfg: DictConfig): policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 839bd66a..8633a849 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -5,7 +5,6 @@ import copy import os import pathlib -import time import warnings from functools import partial from typing import Dict, Sequence @@ -25,7 +24,7 @@ from torch.distributions import Bernoulli, Distribution, Independent, OneHotCategorical from torch.optim import Optimizer from torch.utils.data import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_models from sheeprl.algos.dreamer_v3.loss import reconstruction_loss @@ -38,6 +37,7 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm +from sheeprl.utils.timer import timer from sheeprl.utils.utils import polynomial_decay # Decomment the following two lines if you cannot start an experiment with DMC environments @@ -340,8 +340,9 @@ def main(cfg: DictConfig): fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -462,29 +463,26 @@ def main(cfg: DictConfig): ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/post_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/prior_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/actor": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/critic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) - aggregator.to(fabric.device) + aggregator = 
MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/post_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/prior_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/actor": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/critic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(fabric.device) # Local data buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 2 @@ -507,11 +505,12 @@ def main(cfg: DictConfig): expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables + train_step = 0 + last_train = 0 start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - start_time = time.perf_counter() policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 @@ -560,46 +559,49 @@ def main(cfg: DictConfig): per_rank_gradient_steps = 0 for update in range(start_step, num_updates + 1): - # Sample an action given the observation received by the environment - if ( - update <= learning_starts - and cfg.checkpoint.resume_from is None - and "minedojo" not in cfg.algo.actor.cls.lower() - ): - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255.0 + policy_step += cfg.env.num_envs * world_size + + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + # Sample an action given the observation received by the environment + if ( + update <= learning_starts + and cfg.checkpoint.resume_from is None + and "minedojo" not in cfg.algo.actor.cls.lower() + ): + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + 
actions = np.concatenate( + [ + F.one_hot(torch.tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: + with torch.no_grad(): + preprocessed_obs = {} + for k, v in obs.items(): + if k in cfg.cnn_keys.encoder: + preprocessed_obs[k] = v[None, ...].to(device) / 255.0 + else: + preprocessed_obs[k] = v[None, ...].to(device) + mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + if len(mask) == 0: + mask = None + real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + actions = torch.cat(actions, -1).cpu().numpy() + if is_continuous: + real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} - if len(mask) == 0: - mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) - actions = torch.cat(actions, -1).cpu().numpy() - if is_continuous: - real_actions = torch.cat(real_actions, dim=-1).cpu().numpy() - else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - step_data["actions"] = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() - rb.add(step_data[None, ...]) + step_data["actions"] = torch.from_numpy(actions).view(cfg.env.num_envs, -1).float() + rb.add(step_data[None, ...]) - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) - - policy_step += cfg.env.num_envs * fabric.world_size + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated) step_data["is_first"] = torch.zeros_like(step_data["dones"]) if "restart_on_exception" in infos: @@ -680,28 +682,30 @@ def main(cfg: DictConfig): else cfg.algo.per_rank_gradient_steps, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) - for i in distributed_sampler: - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau - for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): - tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) - train( - fabric, - world_model, - actor, - critic, - target_critic, - world_optimizer, - actor_optimizer, - critic_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), - aggregator, - cfg, - is_continuous, - actions_dim, - moments, - ) - per_rank_gradient_steps += 1 + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + for i in distributed_sampler: + if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau + for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): + tcp.data.copy_(tau * cp.data + (1 - tau) * tcp.data) + train( + fabric, + world_model, + actor, + critic, + target_critic, + world_optimizer, + actor_optimizer, + critic_optimizer, + local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + aggregator, + cfg, + is_continuous, + actions_dim, + moments, + ) + per_rank_gradient_steps += 1 + train_step += 1 
updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 @@ -712,12 +716,33 @@ def main(cfg: DictConfig): max_decay_steps=max_step_expl_decay, ) aggregator.update("Params/exploration_amout", player.expl_amount) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint Model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) From bb55fa1e716691ad4a85a79e04f8fb762d563ab0 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Wed, 13 Sep 2023 12:04:55 +0200 Subject: [PATCH 23/31] Add print_config + Add sps Dreamer-V3 --- sheeprl/algos/dreamer_v1/dreamer_v1.py | 4 +- sheeprl/algos/dreamer_v2/dreamer_v2.py | 4 +- sheeprl/algos/dreamer_v3/dreamer_v3.py | 4 +- sheeprl/algos/p2e_dv1/p2e_dv1.py | 232 ++++++++++++++----------- sheeprl/algos/p2e_dv2/p2e_dv2.py | 198 ++++++++++++--------- 5 files changed, 254 insertions(+), 188 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 9a75200b..1938354c 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -29,7 +29,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import compute_lambda_values, polynomial_decay +from sheeprl.utils.utils import compute_lambda_values, polynomial_decay, print_config # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -359,6 +359,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = 1 diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index d4fa78c4..fc3f54f7 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -35,7 +35,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.utils import polynomial_decay, print_config # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -379,6 +379,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + 
print_config(cfg) + # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = 1 diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 8633a849..6108903c 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -38,7 +38,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.utils import polynomial_decay, print_config # Decomment the following two lines if you cannot start an experiment with DMC environments # os.environ["PYOPENGL_PLATFORM"] = "" @@ -332,6 +332,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = -1 diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index 0e301844..3d1a0149 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -1,7 +1,6 @@ import copy import os import pathlib -import time import warnings from typing import Dict @@ -20,7 +19,7 @@ from torch import nn from torch.distributions import Bernoulli, Independent, Normal from torch.utils.data import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss @@ -33,7 +32,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import compute_lambda_values, init_weights, polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import compute_lambda_values, init_weights, polynomial_decay, print_config # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -364,6 +364,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = 1 @@ -372,8 +374,9 @@ def main(cfg: DictConfig): fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -384,7 +387,7 @@ def main(cfg: DictConfig): ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") cfg.checkpoint.resume_from = str(ckpt_path) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.per_rank_batch_size = state["batch_size"] // world_size cfg.root_dir = root_dir cfg.run_name = f"resume_from_checkpoint_{run_name}" @@ -529,41 +532,38 @@ def main(cfg: DictConfig): ) # Metrics - with device: - aggregator = MetricAggregator( - { - "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": 
MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/value_loss_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/policy_loss_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Loss/ensemble_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/p_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "State/q_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Rewards/intrinsic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Values_exploration/predicted_values": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Values_exploration/lambda_values": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/actor_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/critic_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/actor_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/critic_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Grads/ensemble": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - } - ) - aggregator.to(device) + aggregator = MetricAggregator( + { + "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/value_loss_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/policy_loss_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/observation_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/reward_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/state_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/continue_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Loss/ensemble_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/kl": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/p_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "State/q_entropy": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Params/exploration_amout": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Rewards/intrinsic": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Values_exploration/predicted_values": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Values_exploration/lambda_values": 
MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/world_model": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/actor_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/critic_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/actor_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/critic_exploration": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + "Grads/ensemble": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), + } + ).to(device) # Local data - buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 4 + buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 4 rb = AsyncReplayBuffer( buffer_size, cfg.env.num_envs, @@ -573,22 +573,23 @@ def main(cfg: DictConfig): sequential=True, ) if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: - if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] elif isinstance(state["rb"], AsyncReplayBuffer): rb = state["rb"] else: - raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") + raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables - start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + train_step = 0 + last_train = 0 + start_step = state["update"] // world_size if cfg.checkpoint.resume_from else 1 policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - start_time = time.perf_counter() - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) + policy_steps_per_update = int(cfg.env.num_envs * world_size) updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = int(cfg.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 @@ -596,7 +597,7 @@ def main(cfg: DictConfig): exploration_updates = min(num_updates, exploration_updates) if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -638,6 +639,8 @@ def main(cfg: DictConfig): is_exploring = True for update in range(start_step, num_updates + 1): + policy_step += cfg.env.num_envs * world_size + if update == exploration_updates: is_exploring = False player.actor = actor_task.module @@ -645,39 +648,40 @@ def main(cfg: DictConfig): if fabric.is_global_zero: test(copy.deepcopy(player), fabric, cfg, "zero-shot") - # Sample an action given the observation received by the environment - if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: - 
real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + # Sample an action given the observation received by the environment + if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: + with torch.no_grad(): + preprocessed_obs = {} + for k, v in obs.items(): + if k in cfg.cnn_keys.encoder: + preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + else: + preprocessed_obs[k] = v[None, ...].to(device) + mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + if len(mask) == 0: + mask = None + real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + actions = torch.cat(actions, -1).cpu().numpy() + if is_continuous: + real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} - if len(mask) == 0: - mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) - actions = torch.cat(actions, -1).cpu().numpy() - if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() - else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - - o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) - - policy_step += cfg.env.num_envs * fabric.world_size + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated) if "final_info" in infos: for i, agent_final_info in enumerate(infos["final_info"]): @@ -745,26 +749,29 @@ def main(cfg: DictConfig): n_samples=cfg.algo.per_rank_gradient_steps, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) - for i in distributed_sampler: - train( - fabric, - world_model, - actor_task, - critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), - aggregator, - cfg, - ensembles=ensembles, - ensemble_optimizer=ensemble_optimizer, - actor_exploration=actor_exploration, - critic_exploration=critic_exploration, - actor_exploration_optimizer=actor_exploration_optimizer, - critic_exploration_optimizer=critic_exploration_optimizer, - is_exploring=is_exploring, - ) + # Start training + with timer("Time/train_time", 
SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + for i in distributed_sampler: + train( + fabric, + world_model, + actor_task, + critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + aggregator, + cfg, + ensembles=ensembles, + ensemble_optimizer=ensemble_optimizer, + actor_exploration=actor_exploration, + critic_exploration=critic_exploration, + actor_exploration_optimizer=actor_exploration_optimizer, + critic_exploration_optimizer=critic_exploration_optimizer, + is_exploring=is_exploring, + ) + train_step += 1 updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 @@ -775,12 +782,33 @@ def main(cfg: DictConfig): max_decay_steps=max_step_expl_decay, ) aggregator.update("Params/exploration_amout", player.expl_amount) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - fabric.log_dict(aggregator.compute(), policy_step) + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint Model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) @@ -798,8 +826,8 @@ def main(cfg: DictConfig): "critic_task_optimizer": critic_task_optimizer.state_dict(), "ensemble_optimizer": ensemble_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, - "update": update * fabric.world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "update": update * world_size, + "batch_size": cfg.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "critic_exploration": critic_exploration.state_dict(), "actor_exploration_optimizer": actor_exploration_optimizer.state_dict(), diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index 8feeaa34..1bdd7059 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -1,7 +1,6 @@ import copy import os import pathlib -import time import warnings from typing import Dict, Sequence @@ -20,7 +19,7 @@ from torch import Tensor, nn from torch.distributions import Bernoulli, Distribution, Independent, Normal, OneHotCategorical from torch.utils.data import BatchSampler -from torchmetrics import MeanMetric +from torchmetrics import MeanMetric, SumMetric from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel from sheeprl.algos.dreamer_v2.loss import reconstruction_loss @@ -33,7 +32,8 @@ from sheeprl.utils.logger import create_tensorboard_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm -from sheeprl.utils.utils import polynomial_decay +from sheeprl.utils.timer import timer +from sheeprl.utils.utils import polynomial_decay, 
print_config # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -466,6 +466,8 @@ def train( @register_algorithm() @hydra.main(version_base=None, config_path="../../configs", config_name="config") def main(cfg: DictConfig): + print_config(cfg) + # These arguments cannot be changed cfg.env.screen_size = 64 cfg.env.frame_stack = 1 @@ -474,8 +476,9 @@ def main(cfg: DictConfig): fabric = Fabric(callbacks=[CheckpointCallback()]) if not _is_using_cli(): fabric.launch() - rank = fabric.global_rank device = fabric.device + rank = fabric.global_rank + world_size = fabric.world_size fabric.seed_everything(cfg.seed) torch.backends.cudnn.deterministic = cfg.torch_deterministic @@ -486,7 +489,7 @@ def main(cfg: DictConfig): ckpt_path = pathlib.Path(cfg.checkpoint.resume_from) cfg = OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml") cfg.checkpoint.resume_from = str(ckpt_path) - cfg.per_rank_batch_size = state["batch_size"] // fabric.world_size + cfg.per_rank_batch_size = state["batch_size"] // world_size cfg.root_dir = root_dir cfg.run_name = f"resume_from_checkpoint_{run_name}" @@ -661,7 +664,6 @@ def main(cfg: DictConfig): { "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/world_model_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/value_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/policy_loss_task": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), @@ -690,7 +692,7 @@ def main(cfg: DictConfig): aggregator.to(device) # Local data - buffer_size = cfg.buffer.size // int(cfg.env.num_envs * fabric.world_size) if not cfg.dry_run else 4 + buffer_size = cfg.buffer.size // int(cfg.env.num_envs * world_size) if not cfg.dry_run else 4 buffer_type = cfg.buffer.type.lower() if buffer_type == "sequential": rb = AsyncReplayBuffer( @@ -712,28 +714,29 @@ def main(cfg: DictConfig): else: raise ValueError(f"Unrecognized buffer type: must be one of `sequential` or `episode`, received: {buffer_type}") if cfg.checkpoint.resume_from and cfg.buffer.checkpoint: - if isinstance(state["rb"], list) and fabric.world_size == len(state["rb"]): + if isinstance(state["rb"], list) and world_size == len(state["rb"]): rb = state["rb"][fabric.global_rank] elif isinstance(state["rb"], AsyncReplayBuffer): rb = state["rb"] else: - raise RuntimeError(f"Given {len(state['rb'])}, but {fabric.world_size} processes are instantiated") + raise RuntimeError(f"Given {len(state['rb'])}, but {world_size} processes are instantiated") step_data = TensorDict({}, batch_size=[cfg.env.num_envs], device="cpu") expl_decay_steps = state["expl_decay_steps"] if cfg.checkpoint.resume_from else 0 # Global variables - start_step = state["update"] // fabric.world_size if cfg.checkpoint.resume_from else 1 + train_step = 0 + last_train = 0 + start_step = state["update"] // world_size if cfg.checkpoint.resume_from else 1 policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - start_time = time.perf_counter() - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) + policy_steps_per_update = int(cfg.env.num_envs * world_size) 
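
# Illustrative sketch, not part of the patch: how the counters introduced above
# (policy_step, train_step, last_log, last_train) feed the two throughput metrics
# logged further down. Parameter names are stand-ins for the loop counters and the
# values produced by `timer.compute()`.
def sps_metrics(
    policy_step: int,   # global env steps so far (num_envs * world_size per update)
    last_log: int,      # policy_step at the previous logging point
    train_step: int,    # bumped on every call to train()
    last_train: int,    # train_step at the previous logging point
    world_size: int,
    action_repeat: int,
    train_time: float,  # seconds accumulated under "Time/train_time"
    env_time: float,    # seconds accumulated under "Time/env_interaction_time"
) -> dict:
    sps_train = (train_step - last_train) / train_time
    # Env throughput is per rank: divide the global step delta by world_size and
    # multiply by action_repeat to count the actual simulator frames.
    sps_env = ((policy_step - last_log) / world_size * action_repeat) / env_time
    return {"Time/sps_train": sps_train, "Time/sps_env_interaction": sps_env}
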
updates_before_training = cfg.algo.train_every // policy_steps_per_update if not cfg.dry_run else 0 num_updates = cfg.total_steps // policy_steps_per_update if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint: learning_starts += start_step - max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * fabric.world_size) + max_step_expl_decay = cfg.algo.player.max_step_expl_decay // (cfg.algo.per_rank_gradient_steps * world_size) if cfg.checkpoint.resume_from: player.expl_amount = polynomial_decay( expl_decay_steps, @@ -788,6 +791,8 @@ def main(cfg: DictConfig): per_rank_gradient_steps = 0 is_exploring = True for update in range(start_step, num_updates + 1): + policy_step += cfg.env.num_envs * world_size + if update == exploration_updates: is_exploring = False player.actor = actor_task.module @@ -795,42 +800,43 @@ def main(cfg: DictConfig): if fabric.is_global_zero: test(copy.deepcopy(player), fabric, cfg, "zero-shot") - # Sample an action given the observation received by the environment - if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: - real_actions = actions = np.array(envs.action_space.sample()) - if not is_continuous: - actions = np.concatenate( - [ - F.one_hot(torch.tensor(act), act_dim).numpy() - for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) - ], - axis=-1, - ) - else: - with torch.no_grad(): - preprocessed_obs = {} - for k, v in obs.items(): - if k in cfg.cnn_keys.encoder: - preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + # Measure environment interaction time: this considers both the model forward + # to get the action given the observation and the time taken into the environment + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + # Sample an action given the observation received by the environment + if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: + real_actions = actions = np.array(envs.action_space.sample()) + if not is_continuous: + actions = np.concatenate( + [ + F.one_hot(torch.tensor(act), act_dim).numpy() + for act, act_dim in zip(actions.reshape(len(actions_dim), -1), actions_dim) + ], + axis=-1, + ) + else: + with torch.no_grad(): + preprocessed_obs = {} + for k, v in obs.items(): + if k in cfg.cnn_keys.encoder: + preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5 + else: + preprocessed_obs[k] = v[None, ...].to(device) + mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + if len(mask) == 0: + mask = None + real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) + actions = torch.cat(actions, -1).cpu().numpy() + if is_continuous: + real_actions = torch.cat(real_actions, -1).cpu().numpy() else: - preprocessed_obs[k] = v[None, ...].to(device) - mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} - if len(mask) == 0: - mask = None - real_actions = actions = player.get_exploration_action(preprocessed_obs, is_continuous, mask) - actions = torch.cat(actions, -1).cpu().numpy() - if is_continuous: - real_actions = torch.cat(real_actions, -1).cpu().numpy() - else: - real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) - - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - o, rewards, 
dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) - if cfg.dry_run and buffer_type == "episode": - dones = np.ones_like(dones) - - policy_step += cfg.env.num_envs * fabric.world_size + real_actions = np.array([real_act.cpu().argmax(dim=-1).numpy() for real_act in real_actions]) + + step_data["is_first"] = copy.deepcopy(step_data["dones"]) + o, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) + dones = np.logical_or(dones, truncated) + if cfg.dry_run and buffer_type == "episode": + dones = np.ones_like(dones) if "final_info" in infos: for i, agent_final_info in enumerate(infos["final_info"]): @@ -919,35 +925,40 @@ def main(cfg: DictConfig): prioritize_ends=cfg.buffer.prioritize_ends, ).to(device) distributed_sampler = BatchSampler(range(local_data.shape[0]), batch_size=1, drop_last=False) - for i in distributed_sampler: - if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: - for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): - tcp.data.copy_(cp.data) - for cp, tcp in zip(critic_exploration.module.parameters(), target_critic_exploration.parameters()): - tcp.data.copy_(cp.data) - train( - fabric, - world_model, - actor_task, - critic_task, - target_critic_task, - world_optimizer, - actor_task_optimizer, - critic_task_optimizer, - local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), - aggregator, - cfg, - ensembles=ensembles, - ensemble_optimizer=ensemble_optimizer, - actor_exploration=actor_exploration, - critic_exploration=critic_exploration, - target_critic_exploration=target_critic_exploration, - actor_exploration_optimizer=actor_exploration_optimizer, - critic_exploration_optimizer=critic_exploration_optimizer, - is_continuous=is_continuous, - actions_dim=actions_dim, - is_exploring=is_exploring, - ) + # Start training + with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + for i in distributed_sampler: + if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: + for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): + tcp.data.copy_(cp.data) + for cp, tcp in zip( + critic_exploration.module.parameters(), target_critic_exploration.parameters() + ): + tcp.data.copy_(cp.data) + train( + fabric, + world_model, + actor_task, + critic_task, + target_critic_task, + world_optimizer, + actor_task_optimizer, + critic_task_optimizer, + local_data[i].view(cfg.per_rank_sequence_length, cfg.per_rank_batch_size), + aggregator, + cfg, + ensembles=ensembles, + ensemble_optimizer=ensemble_optimizer, + actor_exploration=actor_exploration, + critic_exploration=critic_exploration, + target_critic_exploration=target_critic_exploration, + actor_exploration_optimizer=actor_exploration_optimizer, + critic_exploration_optimizer=critic_exploration_optimizer, + is_continuous=is_continuous, + actions_dim=actions_dim, + is_exploring=is_exploring, + ) + train_step += 1 updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 @@ -958,12 +969,33 @@ def main(cfg: DictConfig): max_decay_steps=max_step_expl_decay, ) aggregator.update("Params/exploration_amout", player.expl_amount) - aggregator.update("Time/step_per_second", int(policy_step / (time.perf_counter() - start_time))) - if policy_step - last_log >= cfg.metric.log_every or cfg.dry_run: - last_log = policy_step - 
fabric.log_dict(aggregator.compute(), policy_step) + + # Log metrics + if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + # Sync distributed metrics + metrics_dict = aggregator.compute() + fabric.log_dict(metrics_dict, policy_step) aggregator.reset() + # Sync distributed timers + timer_metrics = timer.compute() + fabric.log( + "Time/sps_train", + (train_step - last_train) / timer_metrics["Time/train_time"], + policy_step, + ) + fabric.log( + "Time/sps_env_interaction", + ((policy_step - last_log) / world_size * cfg.env.action_repeat) + / timer_metrics["Time/env_interaction_time"], + policy_step, + ) + timer.reset() + + # Reset counters + last_log = policy_step + last_train = train_step + # Checkpoint Model if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) @@ -982,8 +1014,8 @@ def main(cfg: DictConfig): "critic_task_optimizer": critic_task_optimizer.state_dict(), "ensemble_optimizer": ensemble_optimizer.state_dict(), "expl_decay_steps": expl_decay_steps, - "update": update * fabric.world_size, - "batch_size": cfg.per_rank_batch_size * fabric.world_size, + "update": update * world_size, + "batch_size": cfg.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "critic_exploration": critic_exploration.state_dict(), "target_critic_exploration": target_critic_exploration.state_dict(), From 70fd28b042f1df9f397d07e7d5dd02639f0e9474 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Wed, 13 Sep 2023 12:41:29 +0200 Subject: [PATCH 24/31] [skip ci] Removed unused metric --- howto/register_new_algorithm.md | 2 -- sheeprl/algos/droq/droq.py | 1 - 2 files changed, 3 deletions(-) diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index 3ee1534c..d8f9e4ef 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -129,7 +129,6 @@ def main(cfg: DictConfig): { "Rewards/rew_avg": MeanMetric(), "Game/ep_len_avg": MeanMetric(), - "Time/step_per_second": MeanMetric(), "Loss/value_loss": MeanMetric(), "Loss/policy_loss": MeanMetric(), "Loss/entropy_loss": MeanMetric(), @@ -222,7 +221,6 @@ def main(cfg: DictConfig): # Log metrics metrics_dict = aggregator.compute() - fabric.log("Time/step_per_second", int(global_step / (time.perf_counter() - start_time)), global_step) fabric.log_dict(metrics_dict, global_step) aggregator.reset() diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 346eaa12..3ad3e221 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -212,7 +212,6 @@ def main(cfg: DictConfig): { "Rewards/rew_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Game/ep_len_avg": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), - "Time/step_per_second": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/value_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/policy_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), "Loss/alpha_loss": MeanMetric(sync_on_compute=cfg.metric.sync_on_compute), From c8db8fb716f7578e33497555e3fd932f5143699a Mon Sep 17 00:00:00 2001 From: belerico_t Date: Wed, 13 Sep 2023 17:26:26 +0200 Subject: [PATCH 25/31] Add env_done check on DiambraWrapper to fix evaluation --- sheeprl/configs/env/diambra.yaml | 2 +- sheeprl/envs/{diambra_wrapper.py => diambra.py} | 2 +- sheeprl/utils/env.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename sheeprl/envs/{diambra_wrapper.py => diambra.py} (97%) diff --git 
a/sheeprl/configs/env/diambra.yaml b/sheeprl/configs/env/diambra.yaml index 242d11e9..d13d87ef 100644 --- a/sheeprl/configs/env/diambra.yaml +++ b/sheeprl/configs/env/diambra.yaml @@ -8,7 +8,7 @@ frame_stack: 4 sync_env: True env: - _target_: sheeprl.envs.diambra_wrapper.DiambraWrapper + _target_: sheeprl.envs.diambra.DiambraWrapper id: ${env.id} action_space: discrete screen_size: ${env.screen_size} diff --git a/sheeprl/envs/diambra_wrapper.py b/sheeprl/envs/diambra.py similarity index 97% rename from sheeprl/envs/diambra_wrapper.py rename to sheeprl/envs/diambra.py index e35026a5..7aab7d6d 100644 --- a/sheeprl/envs/diambra_wrapper.py +++ b/sheeprl/envs/diambra.py @@ -105,7 +105,7 @@ def _convert_obs(self, obs: Dict[str, Union[int, np.ndarray]]) -> Dict[str, np.n def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]: obs, reward, done, infos = self._env.step(action) infos["env_domain"] = "DIAMBRA" - return self._convert_obs(obs), reward, done, False, infos + return self._convert_obs(obs), reward, done or infos.get("env_done", False), False, infos def render(self, mode: str = "rgb_array", **kwargs) -> Optional[Union[RenderFrame, List[RenderFrame]]]: return self._env.render("rgb_array") diff --git a/sheeprl/utils/env.py b/sheeprl/utils/env.py index 58e6434a..8cd07e5c 100644 --- a/sheeprl/utils/env.py +++ b/sheeprl/utils/env.py @@ -12,7 +12,7 @@ from sheeprl.utils.imports import _IS_DIAMBRA_ARENA_AVAILABLE, _IS_DIAMBRA_AVAILABLE, _IS_DMC_AVAILABLE if _IS_DIAMBRA_ARENA_AVAILABLE and _IS_DIAMBRA_AVAILABLE: - from sheeprl.envs.diambra_wrapper import DiambraWrapper + from sheeprl.envs.diambra import DiambraWrapper if _IS_DMC_AVAILABLE: from sheeprl.envs.dmc import DMCWrapper From 5a50c5b995a71f3772b39568bb596bd07bc24bad Mon Sep 17 00:00:00 2001 From: belerico_t Date: Thu, 14 Sep 2023 17:25:57 +0200 Subject: [PATCH 26/31] Update log cumulative rew on rank-0 --- sheeprl/algos/dreamer_v2/dreamer_v2.py | 14 +++++++------- sheeprl/algos/dreamer_v3/dreamer_v3.py | 14 +++++++------- sheeprl/algos/p2e_dv1/p2e_dv1.py | 14 +++++++------- sheeprl/algos/p2e_dv2/p2e_dv2.py | 14 +++++++------- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index fc3f54f7..026307f6 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -664,13 +664,13 @@ def main(cfg: DictConfig): dones = np.ones_like(dones) if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = copy.deepcopy(o) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 977396aa..f3c85e3f 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -618,13 +618,13 
@@ def main(cfg: DictConfig): step_data["is_first"][i] = torch.ones_like(step_data["is_first"][i]) if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = copy.deepcopy(o) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index 3d1a0149..2ab82df0 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -684,13 +684,13 @@ def main(cfg: DictConfig): dones = np.logical_or(dones, truncated) if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = copy.deepcopy(o) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index 1bdd7059..e8a76b61 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -839,13 +839,13 @@ def main(cfg: DictConfig): dones = np.ones_like(dones) if "final_info" in infos: - for i, agent_final_info in enumerate(infos["final_info"]): - if agent_final_info is not None and "episode" in agent_final_info: - fabric.print( - f"Rank-0: policy_step={policy_step}, reward_env_{i}={agent_final_info['episode']['r'][0]}" - ) - aggregator.update("Rewards/rew_avg", agent_final_info["episode"]["r"][0]) - aggregator.update("Game/ep_len_avg", agent_final_info["episode"]["l"][0]) + for i, agent_ep_info in enumerate(infos["final_info"]): + if agent_ep_info is not None: + ep_rew = agent_ep_info["episode"]["r"] + ep_len = agent_ep_info["episode"]["l"] + aggregator.update("Rewards/rew_avg", ep_rew) + aggregator.update("Game/ep_len_avg", ep_len) + fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}") # Save the real next observation real_next_obs = copy.deepcopy(o) From 7ab278784d93c421f866bff6b8e1dd0be486e028 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Fri, 15 Sep 2023 09:53:26 +0200 Subject: [PATCH 27/31] Fix sps: sps_train is computed globally while sps_env is computed locally --- sheeprl/algos/dreamer_v1/dreamer_v1.py | 4 ++-- sheeprl/algos/dreamer_v2/dreamer_v2.py | 4 ++-- sheeprl/algos/dreamer_v3/dreamer_v3.py | 4 ++-- sheeprl/algos/droq/droq.py | 4 
++-- sheeprl/algos/p2e_dv1/p2e_dv1.py | 4 ++-- sheeprl/algos/p2e_dv2/p2e_dv2.py | 4 ++-- sheeprl/algos/ppo/ppo.py | 8 +++++--- sheeprl/algos/ppo/ppo_decoupled.py | 7 +++++-- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 8 +++++--- sheeprl/algos/sac/sac.py | 4 ++-- sheeprl/algos/sac/sac_decoupled.py | 4 ++-- sheeprl/algos/sac_ae/sac_ae.py | 4 ++-- 12 files changed, 33 insertions(+), 26 deletions(-) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 1938354c..1eea1737 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -581,7 +581,7 @@ def main(cfg: DictConfig): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): # Sample an action given the observation received by the environment if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: real_actions = actions = np.array(envs.action_space.sample()) @@ -692,7 +692,7 @@ def main(cfg: DictConfig): aggregator, cfg, ) - train_step += 1 + train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 026307f6..0d7217eb 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -623,7 +623,7 @@ def main(cfg: DictConfig): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): # Sample an action given the observation received by the environment if ( update <= learning_starts @@ -770,7 +770,7 @@ def main(cfg: DictConfig): actions_dim, ) per_rank_gradient_steps += 1 - train_step += 1 + train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index f3c85e3f..e1abe5c4 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -566,7 +566,7 @@ def main(cfg: DictConfig): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): # Sample an action given the observation received by the environment if ( update <= learning_starts @@ -708,7 +708,7 @@ def main(cfg: DictConfig): moments, ) per_rank_gradient_steps += 1 - train_step += 1 + train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 3ad3e221..fdd263cd 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py 
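
# Illustrative approximation, not the actual `sheeprl.utils.timer` implementation:
# a registry that times a `with` block and accumulates the elapsed seconds into a
# torchmetrics aggregate, so that `timer.compute()` / `timer.reset()` behave as
# they are used in the hunks above and below.
import time
from contextlib import contextmanager
from typing import Dict, Generator

from torchmetrics import Metric


class _TimerRegistry:
    def __init__(self) -> None:
        self._timers: Dict[str, Metric] = {}

    @contextmanager
    def __call__(self, name: str, metric: Metric) -> Generator[None, None, None]:
        # Keep the first metric registered under a name, then time the wrapped block.
        self._timers.setdefault(name, metric)
        start = time.perf_counter()
        try:
            yield
        finally:
            self._timers[name].update(time.perf_counter() - start)

    def compute(self) -> Dict[str, float]:
        # With sync_on_compute=True the value is reduced across ranks at compute time;
        # with sync_on_compute=False (as for env-interaction time) it stays rank-local.
        return {name: metric.compute().item() for name, metric in self._timers.items()}

    def reset(self) -> None:
        for metric in self._timers.values():
            metric.reset()


timer = _TimerRegistry()
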
@@ -264,7 +264,7 @@ def main(cfg: DictConfig): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): with torch.no_grad(): actions, _ = actor.module(obs) actions = actions.cpu().numpy() @@ -308,7 +308,7 @@ def main(cfg: DictConfig): # Train the agent if update > learning_starts: train(fabric, agent, actor_optimizer, qf_optimizer, alpha_optimizer, rb, aggregator, cfg) - train_step += 1 + train_step += world_size # Log metrics if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1.py b/sheeprl/algos/p2e_dv1/p2e_dv1.py index 2ab82df0..f987d558 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1.py @@ -650,7 +650,7 @@ def main(cfg: DictConfig): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): # Sample an action given the observation received by the environment if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: real_actions = actions = np.array(envs.action_space.sample()) @@ -771,7 +771,7 @@ def main(cfg: DictConfig): critic_exploration_optimizer=critic_exploration_optimizer, is_exploring=is_exploring, ) - train_step += 1 + train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2.py b/sheeprl/algos/p2e_dv2/p2e_dv2.py index e8a76b61..d1d48a0b 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2.py @@ -802,7 +802,7 @@ def main(cfg: DictConfig): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): # Sample an action given the observation received by the environment if update <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.id: real_actions = actions = np.array(envs.action_space.sample()) @@ -958,7 +958,7 @@ def main(cfg: DictConfig): actions_dim=actions_dim, is_exploring=is_exploring, ) - train_step += 1 + train_step += world_size updates_before_training = cfg.algo.train_every // policy_steps_per_update if cfg.algo.player.expl_decay: expl_decay_steps += 1 diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 4559479e..632e8ba3 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -221,6 +221,7 @@ def main(cfg: DictConfig): # Global variables last_log = 0 last_train = 0 + train_step = 0 policy_step = 0 last_checkpoint = 0 policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) @@ -268,7 +269,7 @@ def main(cfg: DictConfig): # Measure environment interaction time: this considers both the model forward # to get the action given the observation 
and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): with torch.no_grad(): # Sample an action given the observation received by the environment normalized_obs = { @@ -358,6 +359,7 @@ def main(cfg: DictConfig): with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): train(fabric, agent, optimizer, gathered_data, aggregator, cfg) + train_step += world_size if cfg.algo.anneal_lr: fabric.log("Info/learning_rate", scheduler.get_last_lr()[0], policy_step) @@ -388,7 +390,7 @@ def main(cfg: DictConfig): timer_metrics = timer.compute() fabric.log( "Time/sps_train", - (update - last_train) / timer_metrics["Time/train_time"], + (train_step - last_train) / timer_metrics["Time/train_time"], policy_step, ) fabric.log( @@ -400,8 +402,8 @@ def main(cfg: DictConfig): timer.reset() # Reset counters - last_train = update last_log = policy_step + last_train = train_step # Checkpoint model if ( diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index f8418161..862ce9dc 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -349,6 +349,7 @@ def trainer( optimization_pg: CollectibleGroup, ): global_rank = world_collective.rank + group_world_size = world_collective.world_size - 1 # Receive (possibly updated, by the make_dict_env method for example) cfg from the player data = [None] @@ -406,6 +407,7 @@ def trainer( update = 0 last_log = 0 last_train = 0 + train_step = 0 policy_step = 0 last_checkpoint = 0 initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef) @@ -428,6 +430,7 @@ def trainer( return data = make_tensordict(data, device=device) update += 1 + train_step += group_world_size policy_step += cfg.env.num_envs * cfg.algo.rollout_steps # Prepare sampler @@ -504,7 +507,7 @@ def trainer( # Sync distributed timers timers = timer.compute() - metrics.update({"Time/sps_train": (update - last_train) / timers["Time/train_time"]}) + metrics.update({"Time/sps_train": (train_step - last_train) / timers["Time/train_time"]}) timer.reset() # Send metrics to the player @@ -520,8 +523,8 @@ def trainer( ) # Broadcast metrics: fake send with object list between rank-0 and rank-1 # Reset counters - last_train = update last_log = policy_step + last_train = train_step if cfg.algo.anneal_lr: scheduler.step() diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 352e3fa4..a37fd05b 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -198,6 +198,7 @@ def main(cfg: DictConfig): # Global variables last_log = 0 last_train = 0 + train_step = 0 policy_step = 0 last_checkpoint = 0 policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) @@ -237,7 +238,7 @@ def main(cfg: DictConfig): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): with torch.no_grad(): # Sample an action given the observation received by the environment action_logits, values, state = agent.module(next_obs, state=next_state) @@ -335,6 +336,7 @@ def main(cfg: DictConfig): with 
timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): train(fabric, agent, optimizer, padded_sequences, aggregator, cfg) + train_step += world_size if cfg.algo.anneal_lr: fabric.log("Info/learning_rate", scheduler.get_last_lr()[0], policy_step) @@ -365,7 +367,7 @@ def main(cfg: DictConfig): timer_metrics = timer.compute() fabric.log( "Time/sps_train", - (update - last_train) / timer_metrics["Time/train_time"], + (train_step - last_train) / timer_metrics["Time/train_time"], policy_step, ) fabric.log( @@ -377,8 +379,8 @@ def main(cfg: DictConfig): timer.reset() # Reset counters - last_train = update last_log = policy_step + last_train = train_step # Checkpoint model if ( diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 98ace69c..c7f10aea 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -212,7 +212,7 @@ def main(cfg: DictConfig): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): if update <= learning_starts: actions = envs.action_space.sample() else: @@ -300,7 +300,7 @@ def main(cfg: DictConfig): cfg, policy_steps_per_update, ) - train_step += 1 + train_step += world_size # Log metrics if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 2aeb4520..b4e21d29 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -296,7 +296,7 @@ def trainer( optimization_pg: CollectibleGroup, ): global_rank = world_collective.rank - global_rank - 1 + group_world_size = world_collective.world_size - 1 # Receive (possibly updated, by the make_dict_env method for example) cfg from the player data = [None] @@ -412,7 +412,7 @@ def trainer( policy_steps_per_update, group=optimization_pg, ) - train_step += 1 + train_step += group_world_size if global_rank == 1: player_trainer_collective.broadcast( diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 1333dea7..88a6cba2 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -393,7 +393,7 @@ def main(cfg: DictConfig): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): if update < learning_starts: actions = envs.action_space.sample() else: @@ -497,7 +497,7 @@ def main(cfg: DictConfig): cfg, policy_steps_per_update, ) - train_step += 1 + train_step += world_size # Log metrics if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: From 8793ae4b48e6f77996b6d4a56fe7ff8dd8dc956a Mon Sep 17 00:00:00 2001 From: Federico Belotti Date: Fri, 15 Sep 2023 10:17:44 +0200 Subject: [PATCH 28/31] Bump to v0.3.2 --- sheeprl/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sheeprl/__init__.py b/sheeprl/__init__.py index a53b76a4..9d81d0a6 100644 --- a/sheeprl/__init__.py +++ b/sheeprl/__init__.py @@ -31,4 +31,4 @@ np.int = np.int64 np.bool = bool -__version__ = "0.3.0" 
+__version__ = "0.3.2" From a6f3ff70df77f3441b57065e6715a4aab076a6b6 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Fri, 15 Sep 2023 11:18:30 +0200 Subject: [PATCH 29/31] fix: added close method to diambra wrapper --- sheeprl/envs/diambra.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sheeprl/envs/diambra.py b/sheeprl/envs/diambra.py index 7aab7d6d..96697a99 100644 --- a/sheeprl/envs/diambra.py +++ b/sheeprl/envs/diambra.py @@ -114,3 +114,6 @@ def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ) -> Tuple[Any, Dict[str, Any]]: return self._convert_obs(self._env.reset()), {"env_domain": "DIAMBRA"} + + def close(self) -> None: + self._env.close() From 0d1c80c3bba62426ff3b74c27f888b79607584a5 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Sat, 16 Sep 2023 12:52:53 +0200 Subject: [PATCH 30/31] Removed mujoco frameskip = 0 + removed frame_skip from dmc --- sheeprl/configs/env/dmc.yaml | 3 +-- sheeprl/envs/dmc.py | 49 +++++++++++++++++------------------- sheeprl/utils/env.py | 32 ++++++++++------------- 3 files changed, 37 insertions(+), 47 deletions(-) diff --git a/sheeprl/configs/env/dmc.yaml b/sheeprl/configs/env/dmc.yaml index 723ca036..2dcf3eb8 100644 --- a/sheeprl/configs/env/dmc.yaml +++ b/sheeprl/configs/env/dmc.yaml @@ -9,9 +9,8 @@ max_episode_steps: 1000 env: _target_: sheeprl.envs.dmc.DMCWrapper id: ${env.id} - height: ${env.screen_size} width: ${env.screen_size} - frame_skip: ${env.action_repeat} + height: ${env.screen_size} seed: null from_pixels: True from_vectors: False diff --git a/sheeprl/envs/dmc.py b/sheeprl/envs/dmc.py index 0d8b5150..5d6c910f 100644 --- a/sheeprl/envs/dmc.py +++ b/sheeprl/envs/dmc.py @@ -5,7 +5,7 @@ if not _IS_DMC_AVAILABLE: raise ModuleNotFoundError(_IS_DMC_AVAILABLE) -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union import numpy as np from dm_control import suite @@ -13,7 +13,7 @@ from gymnasium import core, spaces -def _spec_to_box(spec, dtype) -> spaces.Space: +def _spec_to_box(spec, dtype) -> spaces.Box: def extract_min_max(s): assert s.dtype == np.float64 or s.dtype == np.float32 dim = int(np.prod(s.shape)) @@ -54,7 +54,6 @@ def __init__( height: int = 84, width: int = 84, camera_id: int = 0, - frame_skip: int = 1, task_kwargs: Optional[Dict[Any, Any]] = None, environment_kwargs: Optional[Dict[Any, Any]] = None, channels_first: bool = True, @@ -93,9 +92,6 @@ def __init__( Defaults to 84. camera_id (int, optional): the id of the camera from where to take the image observation. Defaults to 0. - frame_skip (int, optional): action repeat value. Given an action, `frame_skip` steps will be performed - in the environment with the given action. - Defaults to 1. task_kwargs (Optional[Dict[Any, Any]], optional): Optional dict of keyword arguments for the task. Defaults to None. 
environment_kwargs (Optional[Dict[Any, Any]], optional): Optional dict specifying @@ -121,7 +117,6 @@ def __init__( self._height = height self._width = width self._camera_id = camera_id - self._frame_skip = frame_skip self._channels_first = channels_first # create task @@ -137,6 +132,10 @@ def __init__( self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32) self._norm_action_space = spaces.Box(low=-1.0, high=1.0, shape=self._true_action_space.shape, dtype=np.float32) + # set the reward range + reward_space = _spec_to_box([self._env.reward_spec()], np.float32) + self._reward_range = (reward_space.low.item(), reward_space.high.item()) + # create observation space if from_pixels: shape = (3, height, width) if channels_first else (height, width, 3) @@ -163,7 +162,7 @@ def __init__( def __getattr__(self, name): return getattr(self._env, name) - def _get_obs(self, time_step): + def _get_obs(self, time_step) -> Union[Dict[str, np.ndarray], np.ndarray]: if self._from_pixels: rgb_obs = self.render(camera_id=self._camera_id) if self._channels_first: @@ -177,7 +176,7 @@ def _get_obs(self, time_step): else: return rgb_obs - def _convert_action(self, action): + def _convert_action(self, action) -> np.ndarray: action = action.astype(np.float64) true_delta = self._true_action_space.high - self._true_action_space.low norm_delta = self._norm_action_space.high - self._norm_action_space.low @@ -187,20 +186,20 @@ def _convert_action(self, action): return action @property - def observation_space(self): + def observation_space(self) -> Union[spaces.Dict, spaces.Box]: return self._observation_space @property - def state_space(self): + def state_space(self) -> spaces.Box: return self._state_space @property - def action_space(self): + def action_space(self) -> spaces.Box: return self._norm_action_space @property - def reward_range(self): - return 0, self._frame_skip + def reward_range(self) -> Tuple[float, float]: + return self._reward_range @property def render_mode(self) -> str: @@ -211,33 +210,31 @@ def seed(self, seed: Optional[int] = None): self._norm_action_space.seed(seed) self._observation_space.seed(seed) - def step(self, action): + def step( + self, action: Any + ) -> Tuple[Union[Dict[str, np.ndarray], np.ndarray], SupportsFloat, bool, bool, Dict[str, Any]]: assert self._norm_action_space.contains(action) action = self._convert_action(action) assert self._true_action_space.contains(action) - reward = 0 - extra = {"internal_state": self._env.physics.get_state().copy()} - - for _ in range(self._frame_skip): - time_step = self._env.step(action) - reward += time_step.reward or 0 - done = time_step.last() - if done: - break + time_step = self._env.step(action) + reward = time_step.reward or 0.0 + done = time_step.last() obs = self._get_obs(time_step) self.current_state = _flatten_obs(time_step.observation) + extra = {} extra["discount"] = time_step.discount + extra["internal_state"] = self._env.physics.get_state().copy() return obs, reward, done, False, extra def reset( self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None - ) -> Tuple[np.ndarray, Dict[str, Any]]: + ) -> Tuple[Union[Dict[str, np.ndarray], np.ndarray], Dict[str, Any]]: time_step = self._env.reset() self.current_state = _flatten_obs(time_step.observation) obs = self._get_obs(time_step) return obs, {} - def render(self, camera_id: Optional[int] = None): + def render(self, camera_id: Optional[int] = None) -> np.ndarray: return self._env.physics.render(height=self._height, width=self._width, 
camera_id=camera_id or self._camera_id) def close(self): diff --git a/sheeprl/utils/env.py b/sheeprl/utils/env.py index 58e6434a..8a13cd67 100644 --- a/sheeprl/utils/env.py +++ b/sheeprl/utils/env.py @@ -14,7 +14,7 @@ if _IS_DIAMBRA_ARENA_AVAILABLE and _IS_DIAMBRA_AVAILABLE: from sheeprl.envs.diambra_wrapper import DiambraWrapper if _IS_DMC_AVAILABLE: - from sheeprl.envs.dmc import DMCWrapper + pass def make_env( @@ -86,14 +86,11 @@ def thunk() -> gym.Env: if "rank" in cfg.env.env: instantiate_kwargs["rank"] = rank + vector_env_idx env = hydra.utils.instantiate(cfg.env.env, **instantiate_kwargs) - if "mujoco" in env_spec: - env.frame_skip = 0 # action repeat if ( cfg.env.action_repeat > 1 and "atari" not in env_spec - and (not _IS_DMC_AVAILABLE or not isinstance(env, DMCWrapper)) and (not (_IS_DIAMBRA_ARENA_AVAILABLE and _IS_DIAMBRA_AVAILABLE) or not isinstance(env, DiambraWrapper)) ): env = ActionRepeat(env, cfg.env.action_repeat) @@ -154,41 +151,38 @@ def thunk() -> gym.Env: def transform_obs(obs: Dict[str, Any]): for k in cnn_keys: - shape = obs[k].shape + current_obs = obs[k] + shape = current_obs.shape is_3d = len(shape) == 3 is_grayscale = not is_3d or shape[0] == 1 or shape[-1] == 1 channel_first = not is_3d or shape[0] in (1, 3) # to 3D image if not is_3d: - obs.update({k: np.expand_dims(obs[k], axis=0)}) + current_obs = np.expand_dims(current_obs, axis=0) # channel last (opencv needs it) if channel_first: - obs.update({k: obs[k].transpose(1, 2, 0)}) + current_obs = np.transpose(current_obs, (1, 2, 0)) # resize - if obs[k].shape[:-1] != (cfg.env.screen_size, cfg.env.screen_size): - obs.update( - { - k: cv2.resize( - obs[k], (cfg.env.screen_size, cfg.env.screen_size), interpolation=cv2.INTER_AREA - ) - } + if current_obs.shape[:-1] != (cfg.env.screen_size, cfg.env.screen_size): + current_obs = cv2.resize( + current_obs, (cfg.env.screen_size, cfg.env.screen_size), interpolation=cv2.INTER_AREA ) # to grayscale if cfg.env.grayscale and not is_grayscale: - obs.update({k: cv2.cvtColor(obs[k], cv2.COLOR_RGB2GRAY)}) + current_obs = cv2.cvtColor(current_obs, cv2.COLOR_RGB2GRAY) # back to 3D - if len(obs[k].shape) == 2: - obs.update({k: np.expand_dims(obs[k], axis=-1)}) + if len(current_obs.shape) == 2: + current_obs = np.expand_dims(current_obs, axis=-1) if not cfg.env.grayscale: - obs.update({k: np.repeat(obs[k], 3, axis=-1)}) + current_obs = np.repeat(current_obs, 3, axis=-1) # channel first (PyTorch default) - obs.update({k: obs[k].transpose(2, 0, 1)}) + obs[k] = current_obs.transpose(2, 0, 1) return obs From bc16310b91064075e427c4c385e3b37ff4f0ad65 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Mon, 18 Sep 2023 09:29:19 +0200 Subject: [PATCH 31/31] Call super().close --- sheeprl/envs/diambra.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sheeprl/envs/diambra.py b/sheeprl/envs/diambra.py index 96697a99..f9c2f2ce 100644 --- a/sheeprl/envs/diambra.py +++ b/sheeprl/envs/diambra.py @@ -117,3 +117,4 @@ def reset( def close(self) -> None: self._env.close() + super().close()
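
As a closing illustration of the DMC changes above, the rescaling performed by `_convert_action` is a plain affine map from the normalized action Box in [-1, 1] onto the task's true action bounds. A self-contained sketch follows; the function name and test values are illustrative and not taken from the patch.

# Standalone sketch of the affine rescaling done by DMCWrapper._convert_action.
import numpy as np


def rescale_action(action: np.ndarray, true_low: np.ndarray, true_high: np.ndarray) -> np.ndarray:
    norm_low, norm_high = -1.0, 1.0
    action = action.astype(np.float64)
    action = (action - norm_low) / (norm_high - norm_low)  # map [-1, 1] -> [0, 1]
    return action * (true_high - true_low) + true_low      # map [0, 1] -> [true_low, true_high]


# Quick check: the midpoint of the normalized space lands on the midpoint of the true bounds.
low, high = np.array([0.0, -2.0]), np.array([1.0, 2.0])
assert np.allclose(rescale_action(np.zeros(2), low, high), (low + high) / 2)

Because the map is affine, the endpoints -1 and 1 land exactly on the true lower and upper bounds, which is why the wrapper can expose a uniform [-1, 1] action space regardless of the underlying task.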