From 90c8e40f64bb76601d93a9416fa8723cd607ffe2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Dec 2024 16:24:13 +0000 Subject: [PATCH 01/12] [BugFix] Better account of composite distributions in PPO ghstack-source-id: 3d86f99bc5b20a53e4092d786e96a5f7e83405ac Pull Request resolved: https://github.com/pytorch/rl/pull/2622 --- torchrl/objectives/ppo.py | 53 +++++++++++++++++--------- torchrl/objectives/utils.py | 5 +++ torchrl/objectives/value/advantages.py | 22 ++++++++--- 3 files changed, 58 insertions(+), 22 deletions(-) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 8c64c1ba539..eb9a916dfc1 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -18,6 +18,7 @@ TensorDictParams, ) from tensordict.nn import ( + CompositeDistribution, dispatch, ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, @@ -33,6 +34,7 @@ _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, _reduce, + _sum_td_features, default_value_kwargs, distance_loss, ValueEstimators, @@ -462,9 +464,13 @@ def reset(self) -> None: def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: try: - entropy = dist.entropy() + if isinstance(dist, CompositeDistribution): + kwargs = {"aggregate_probabilities": False, "include_sum": False} + else: + kwargs = {} + entropy = dist.entropy(**kwargs) if is_tensor_collection(entropy): - entropy = entropy.get(dist.entropy_key) + entropy = _sum_td_features(entropy) except NotImplementedError: x = dist.rsample((self.samples_mc_entropy,)) log_prob = dist.log_prob(x) @@ -497,13 +503,20 @@ def _log_weight( if isinstance(action, torch.Tensor): log_prob = dist.log_prob(action) else: - maybe_log_prob = dist.log_prob(tensordict) - if not isinstance(maybe_log_prob, torch.Tensor): - # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not - # be a tensor - log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob) + if isinstance(dist, CompositeDistribution): + is_composite = True + kwargs = { + "inplace": False, + "aggregate_probabilities": False, + "include_sum": False, + } else: - log_prob = maybe_log_prob + is_composite = False + kwargs = {} + log_prob = dist.log_prob(tensordict, **kwargs) + if is_composite and not isinstance(prev_log_prob, TensorDict): + log_prob = _sum_td_features(log_prob) + log_prob.view_as(prev_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) @@ -598,6 +611,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: advantage = (advantage - loc) / scale log_weight, dist, kl_approx = self._log_weight(tensordict) + if is_tensor_collection(log_weight): + log_weight = _sum_td_features(log_weight) + log_weight = log_weight.view(advantage.shape) neg_loss = log_weight.exp() * advantage td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[]) if self.entropy_bonus: @@ -1149,16 +1165,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) except NotImplementedError: x = previous_dist.sample((self.samples_mc_kl,)) - previous_log_prob = previous_dist.log_prob(x) - current_log_prob = current_dist.log_prob(x) + if isinstance(previous_dist, CompositeDistribution): + kwargs = { + "aggregate_probabilities": False, + "inplace": False, + "include_sum": False, + } + else: + kwargs = {} + previous_log_prob = previous_dist.log_prob(x, **kwargs) + current_log_prob = current_dist.log_prob(x, **kwargs) if 
is_tensor_collection(current_log_prob): - previous_log_prob = previous_log_prob.get( - self.tensor_keys.sample_log_prob - ) - current_log_prob = current_log_prob.get( - self.tensor_keys.sample_log_prob - ) - + previous_log_prob = _sum_td_features(previous_log_prob) + current_log_prob = _sum_td_features(current_log_prob) kl = (previous_log_prob - current_log_prob).mean(0) kl = kl.unsqueeze(-1) neg_loss = neg_loss - self.beta * kl diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 4dfed60e5a9..9c46fc98262 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -615,3 +615,8 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimize raise ValueError("Cannot group optimizers of different type.") params.extend(optimizer.param_groups) return cls(params) + + +def _sum_td_features(data: TensorDictBase) -> torch.Tensor: + # Sum all features and return a tensor + return data.sum(dim="feature", reduce=True) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index fadfe932c50..bbd6a23bfdd 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -15,11 +15,14 @@ import torch from tensordict import TensorDictBase from tensordict.nn import ( + CompositeDistribution, dispatch, + ProbabilisticTensorDictModule, set_skip_existing, TensorDictModule, TensorDictModuleBase, ) +from tensordict.nn.probabilistic import interaction_type from tensordict.utils import NestedKey from torch import Tensor @@ -74,14 +77,22 @@ def new_func(self, *args, **kwargs): def _call_actor_net( - actor_net: TensorDictModuleBase, + actor_net: ProbabilisticTensorDictModule, data: TensorDictBase, params: TensorDictBase, log_prob_key: NestedKey, ): - # TODO: extend to handle time dimension (and vmap?) 
- log_pi = actor_net(data.select(*actor_net.in_keys, strict=False)).get(log_prob_key) - return log_pi + dist = actor_net.get_dist(data.select(*actor_net.in_keys, strict=False)) + if isinstance(dist, CompositeDistribution): + kwargs = { + "aggregate_probabilities": True, + "inplace": False, + "include_sum": False, + } + else: + kwargs = {} + s = actor_net._dist_sample(dist, interaction_type=interaction_type()) + return dist.log_prob(s, **kwargs) class ValueEstimatorBase(TensorDictModuleBase): @@ -1771,7 +1782,8 @@ def forward( data=tensordict, params=None, log_prob_key=self.tensor_keys.sample_log_prob, - ).view_as(value) + ) + log_pi = log_pi.view_as(value) # Compute the V-Trace correction done = tensordict.get(("next", self.tensor_keys.done)) From de61e4d5eefeb41cd0e69a3821ec1b8ebf34c8c8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Dec 2024 17:49:24 +0000 Subject: [PATCH 02/12] [BugFix] skip_done_states in SAC ghstack-source-id: 39d97360e3b0e45dd8c327487eac50ddafe2254d Pull Request resolved: https://github.com/pytorch/rl/pull/2613 --- test/test_cost.py | 2 + torchrl/objectives/sac.py | 151 ++++++++++++++++++++++++-------------- 2 files changed, 99 insertions(+), 54 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index c48b4a28b99..1f191e41db6 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -4493,6 +4493,7 @@ def test_sac_terminating( actor_network=actor, qvalue_network=qvalue, value_network=value, + skip_done_states=True, ) loss.set_keys( action=action_key, @@ -5204,6 +5205,7 @@ def test_discrete_sac_terminating( qvalue_network=qvalue, num_actions=actor.spec[action_key].space.n, action_space="one-hot", + skip_done_states=True, ) loss.set_keys( action=action_key, diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index cd7039c323d..dafff17011e 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -126,6 +126,10 @@ class SACLoss(LossModule): ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, ``"mean"``: the sum of the output will be divided by the number of elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. + skip_done_states (bool, optional): whether the actor network used for value computation should only be run on + valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the + shape of the data and that masking the data results in a valid data structure. Among other things, this may + not be true in MARL settings or when using RNNs. Defaults to ``False``. 
Examples: >>> import torch @@ -320,6 +324,7 @@ def __init__( priority_key: str = None, separate_losses: bool = False, reduction: str = None, + skip_done_states: bool = False, ) -> None: self._in_keys = None self._out_keys = None @@ -418,6 +423,7 @@ def __init__( raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._make_vmap() self.reduction = reduction + self.skip_done_states = skip_done_states def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( @@ -712,36 +718,44 @@ def _compute_target_v2(self, tensordict) -> Tensor: ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): next_tensordict = tensordict.get("next").copy() - # Check done state and avoid passing these to the actor - done = next_tensordict.get(self.tensor_keys.done) - if done is not None and done.any(): - next_tensordict_select = next_tensordict[~done.squeeze(-1)] - else: - next_tensordict_select = next_tensordict - next_dist = self.actor_network.get_dist(next_tensordict_select) - next_action = next_dist.rsample() - next_sample_log_prob = compute_log_prob( - next_dist, next_action, self.tensor_keys.log_prob - ) - if next_tensordict_select is not next_tensordict: - mask = ~done.squeeze(-1) - if mask.ndim < next_action.ndim: - mask = expand_right( - mask, (*mask.shape, *next_action.shape[mask.ndim :]) - ) - next_action = next_action.new_zeros(mask.shape).masked_scatter_( - mask, next_action + if self.skip_done_states: + # Check done state and avoid passing these to the actor + done = next_tensordict.get(self.tensor_keys.done) + if done is not None and done.any(): + next_tensordict_select = next_tensordict[~done.squeeze(-1)] + else: + next_tensordict_select = next_tensordict + next_dist = self.actor_network.get_dist(next_tensordict_select) + next_action = next_dist.rsample() + next_sample_log_prob = compute_log_prob( + next_dist, next_action, self.tensor_keys.log_prob ) - mask = ~done.squeeze(-1) - if mask.ndim < next_sample_log_prob.ndim: - mask = expand_right( - mask, - (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]), + if next_tensordict_select is not next_tensordict: + mask = ~done.squeeze(-1) + if mask.ndim < next_action.ndim: + mask = expand_right( + mask, (*mask.shape, *next_action.shape[mask.ndim :]) + ) + next_action = next_action.new_zeros(mask.shape).masked_scatter_( + mask, next_action ) - next_sample_log_prob = next_sample_log_prob.new_zeros( - mask.shape - ).masked_scatter_(mask, next_sample_log_prob) - next_tensordict.set(self.tensor_keys.action, next_action) + mask = ~done.squeeze(-1) + if mask.ndim < next_sample_log_prob.ndim: + mask = expand_right( + mask, + (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]), + ) + next_sample_log_prob = next_sample_log_prob.new_zeros( + mask.shape + ).masked_scatter_(mask, next_sample_log_prob) + next_tensordict.set(self.tensor_keys.action, next_action) + else: + next_dist = self.actor_network.get_dist(next_tensordict) + next_action = next_dist.rsample() + next_tensordict.set(self.tensor_keys.action, next_action) + next_sample_log_prob = compute_log_prob( + next_dist, next_action, self.tensor_keys.log_prob + ) # get q-values next_tensordict_expand = self._vmap_qnetworkN0( @@ -877,6 +891,10 @@ class DiscreteSACLoss(LossModule): ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, ``"mean"``: the sum of the output will be divided by the number of elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. 
+ skip_done_states (bool, optional): whether the actor network used for value computation should only be run on + valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the + shape of the data and that masking the data results in a valid data structure. Among other things, this may + not be true in MARL settings or when using RNNs. Defaults to ``False``. Examples: >>> import torch @@ -1051,6 +1069,7 @@ def __init__( priority_key: str = None, separate_losses: bool = False, reduction: str = None, + skip_done_states: bool = False, ): if reduction is None: reduction = "mean" @@ -1133,6 +1152,7 @@ def __init__( ) self._make_vmap() self.reduction = reduction + self.skip_done_states = skip_done_states def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( @@ -1218,35 +1238,58 @@ def _compute_target(self, tensordict) -> Tensor: with torch.no_grad(): next_tensordict = tensordict.get("next").clone(False) - done = next_tensordict.get(self.tensor_keys.done) - if done is not None and done.any(): - next_tensordict_select = next_tensordict[~done.squeeze(-1)] - else: - next_tensordict_select = next_tensordict + if self.skip_done_states: + done = next_tensordict.get(self.tensor_keys.done) + if done is not None and done.any(): + next_tensordict_select = next_tensordict[~done.squeeze(-1)] + else: + next_tensordict_select = next_tensordict - # get probs and log probs for actions computed from "next" - with self.actor_network_params.to_module(self.actor_network): - next_dist = self.actor_network.get_dist(next_tensordict_select) - next_log_prob = next_dist.logits - next_prob = next_log_prob.exp() + # get probs and log probs for actions computed from "next" + with self.actor_network_params.to_module(self.actor_network): + next_dist = self.actor_network.get_dist(next_tensordict_select) + next_log_prob = next_dist.logits + next_prob = next_log_prob.exp() - # get q-values for all actions - next_tensordict_expand = self._vmap_qnetworkN0( - next_tensordict_select, self.target_qvalue_network_params - ) - next_action_value = next_tensordict_expand.get( - self.tensor_keys.action_value - ) + # get q-values for all actions + next_tensordict_expand = self._vmap_qnetworkN0( + next_tensordict_select, self.target_qvalue_network_params + ) + next_action_value = next_tensordict_expand.get( + self.tensor_keys.action_value + ) - # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term - next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob - # unlike in continuous SAC, we can compute the exact expectation over all discrete actions - next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1) - if next_tensordict_select is not next_tensordict: - mask = ~done - next_state_value = next_state_value.new_zeros( - mask.shape - ).masked_scatter_(mask, next_state_value) + # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term + next_state_value = ( + next_action_value.min(0)[0] - self._alpha * next_log_prob + ) + # unlike in continuous SAC, we can compute the exact expectation over all discrete actions + next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1) + if next_tensordict_select is not next_tensordict: + mask = ~done + next_state_value = next_state_value.new_zeros( + mask.shape + ).masked_scatter_(mask, next_state_value) + else: + # get probs and log probs for actions computed from "next" + with 
self.actor_network_params.to_module(self.actor_network): + next_dist = self.actor_network.get_dist(next_tensordict) + next_prob = next_dist.probs + next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob)) + + # get q-values for all actions + next_tensordict_expand = self._vmap_qnetworkN0( + next_tensordict, self.target_qvalue_network_params + ) + next_action_value = next_tensordict_expand.get( + self.tensor_keys.action_value + ) + # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term + next_state_value = ( + next_action_value.min(0)[0] - self._alpha * next_log_prob + ) + # unlike in continuous SAC, we can compute the exact expectation over all discrete actions + next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1) tensordict.set( ("next", self.value_estimator.tensor_keys.value), next_state_value From 830f2f26ca91ec153f63e539c423223dddd95e21 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Dec 2024 17:46:24 +0000 Subject: [PATCH 03/12] [BugFix] ActionDiscretizer scalar integration ghstack-source-id: b22102f3730914b125ef0f813f4d2f22dec0b26e Pull Request resolved: https://github.com/pytorch/rl/pull/2619 --- test/mocking_classes.py | 69 +++++++++++++++++ test/test_transforms.py | 103 +++++++++++++++++++++----- torchrl/envs/transforms/transforms.py | 89 +++++++++++++++------- 3 files changed, 213 insertions(+), 48 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index d78e2f27184..bb902f879b1 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1927,3 +1927,72 @@ def _step( def _set_seed(self, seed: Optional[int]): self.manual_seed = seed return seed + + +class EnvWithScalarAction(EnvBase): + def __init__(self, singleton: bool = False, **kwargs): + super().__init__(**kwargs) + self.singleton = singleton + self.action_spec = Bounded( + -1, + 1, + shape=( + *self.batch_size, + 1, + ) + if self.singleton + else self.batch_size, + ) + self.observation_spec = Composite( + observation=Unbounded( + shape=( + *self.batch_size, + 3, + ) + ), + shape=self.batch_size, + ) + self.done_spec = Composite( + done=Unbounded(self.batch_size + (1,), dtype=torch.bool), + terminated=Unbounded(self.batch_size + (1,), dtype=torch.bool), + truncated=Unbounded(self.batch_size + (1,), dtype=torch.bool), + shape=self.batch_size, + ) + self.reward_spec = Unbounded( + shape=( + *self.batch_size, + 1, + ) + ) + + def _reset(self, td: TensorDict): + return TensorDict( + observation=torch.randn(*self.batch_size, 3, device=self.device), + done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device), + truncated=torch.zeros( + *self.batch_size, 1, dtype=torch.bool, device=self.device + ), + terminated=torch.zeros( + *self.batch_size, 1, dtype=torch.bool, device=self.device + ), + device=self.device, + ) + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + return TensorDict( + observation=torch.randn(*self.batch_size, 3, device=self.device), + reward=torch.zeros(1, device=self.device), + done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device), + truncated=torch.zeros( + *self.batch_size, 1, dtype=torch.bool, device=self.device + ), + terminated=torch.zeros( + *self.batch_size, 1, dtype=torch.bool, device=self.device + ), + ) + + def _set_seed(self, seed: Optional[int]): + ... 
diff --git a/test/test_transforms.py b/test/test_transforms.py index 8b2ada8c93a..ae428d35d97 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -41,6 +41,7 @@ CountingEnvCountPolicy, DiscreteActionConvMockEnv, DiscreteActionConvMockEnvNumpy, + EnvWithScalarAction, IncrementingEnv, MockBatchedLockedEnv, MockBatchedUnLockedEnv, @@ -66,6 +67,7 @@ CountingEnvCountPolicy, DiscreteActionConvMockEnv, DiscreteActionConvMockEnvNumpy, + EnvWithScalarAction, IncrementingEnv, MockBatchedLockedEnv, MockBatchedUnLockedEnv, @@ -11781,17 +11783,33 @@ def test_transform_inverse(self): class TestActionDiscretizer(TransformBase): @pytest.mark.parametrize("categorical", [True, False]) - def test_single_trans_env_check(self, categorical): - base_env = ContinuousActionVecMockEnv() + @pytest.mark.parametrize( + "env_cls", + [ + ContinuousActionVecMockEnv, + partial(EnvWithScalarAction, singleton=True), + partial(EnvWithScalarAction, singleton=False), + ], + ) + def test_single_trans_env_check(self, categorical, env_cls): + base_env = env_cls() env = base_env.append_transform( ActionDiscretizer(num_intervals=5, categorical=categorical) ) check_env_specs(env) @pytest.mark.parametrize("categorical", [True, False]) - def test_serial_trans_env_check(self, categorical): + @pytest.mark.parametrize( + "env_cls", + [ + ContinuousActionVecMockEnv, + partial(EnvWithScalarAction, singleton=True), + partial(EnvWithScalarAction, singleton=False), + ], + ) + def test_serial_trans_env_check(self, categorical, env_cls): def make_env(): - base_env = ContinuousActionVecMockEnv() + base_env = env_cls() return base_env.append_transform( ActionDiscretizer(num_intervals=5, categorical=categorical) ) @@ -11800,9 +11818,17 @@ def make_env(): check_env_specs(env) @pytest.mark.parametrize("categorical", [True, False]) - def test_parallel_trans_env_check(self, categorical): + @pytest.mark.parametrize( + "env_cls", + [ + ContinuousActionVecMockEnv, + partial(EnvWithScalarAction, singleton=True), + partial(EnvWithScalarAction, singleton=False), + ], + ) + def test_parallel_trans_env_check(self, categorical, env_cls): def make_env(): - base_env = ContinuousActionVecMockEnv() + base_env = env_cls() env = base_env.append_transform( ActionDiscretizer(num_intervals=5, categorical=categorical) ) @@ -11812,17 +11838,33 @@ def make_env(): check_env_specs(env) @pytest.mark.parametrize("categorical", [True, False]) - def test_trans_serial_env_check(self, categorical): - env = SerialEnv(2, ContinuousActionVecMockEnv).append_transform( + @pytest.mark.parametrize( + "env_cls", + [ + ContinuousActionVecMockEnv, + partial(EnvWithScalarAction, singleton=True), + partial(EnvWithScalarAction, singleton=False), + ], + ) + def test_trans_serial_env_check(self, categorical, env_cls): + env = SerialEnv(2, env_cls).append_transform( ActionDiscretizer(num_intervals=5, categorical=categorical) ) check_env_specs(env) @pytest.mark.parametrize("categorical", [True, False]) - def test_trans_parallel_env_check(self, categorical): - env = ParallelEnv( - 2, ContinuousActionVecMockEnv, mp_start_method=mp_ctx - ).append_transform(ActionDiscretizer(num_intervals=5, categorical=categorical)) + @pytest.mark.parametrize( + "env_cls", + [ + ContinuousActionVecMockEnv, + partial(EnvWithScalarAction, singleton=True), + partial(EnvWithScalarAction, singleton=False), + ], + ) + def test_trans_parallel_env_check(self, categorical, env_cls): + env = ParallelEnv(2, env_cls, mp_start_method=mp_ctx).append_transform( + ActionDiscretizer(num_intervals=5, 
categorical=categorical) + ) check_env_specs(env) def test_transform_no_env(self): @@ -11838,7 +11880,6 @@ def test_transform_compose(self): check_env_specs(env) @pytest.mark.skipif(not _has_gym, reason="gym required for this test") - @pytest.mark.parametrize("envname", ["cheetah", "pendulum"]) @pytest.mark.parametrize("interval_as_tensor", [False, True]) @pytest.mark.parametrize("categorical", [True, False]) @pytest.mark.parametrize( @@ -11851,15 +11892,37 @@ def test_transform_compose(self): ActionDiscretizer.SamplingStrategy.RANDOM, ], ) - def test_transform_env(self, envname, interval_as_tensor, categorical, sampling): + @pytest.mark.parametrize( + "env_cls", + [ + "cheetah", + "pendulum", + partial(EnvWithScalarAction, singleton=True), + partial(EnvWithScalarAction, singleton=False), + ], + ) + def test_transform_env(self, env_cls, interval_as_tensor, categorical, sampling): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - base_env = GymEnv( - HALFCHEETAH_VERSIONED() if envname == "cheetah" else PENDULUM_VERSIONED(), - device=device, - ) - if interval_as_tensor: - num_intervals = torch.arange(5, 11 if envname == "cheetah" else 6) + if env_cls == "cheetah": + base_env = GymEnv( + HALFCHEETAH_VERSIONED(), + device=device, + ) + num_intervals = torch.arange(5, 11) + elif env_cls == "pendulum": + base_env = GymEnv( + PENDULUM_VERSIONED(), + device=device, + ) + num_intervals = torch.arange(5, 6) else: + base_env = env_cls( + device=device, + ) + num_intervals = torch.arange(5, 6) + + if not interval_as_tensor: + # override num_intervals = 5 t = ActionDiscretizer( num_intervals=num_intervals, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7bdd25591cd..7ab5a2deb72 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -8585,24 +8585,32 @@ def _indent(s): def transform_input_spec(self, input_spec): try: - action_spec = input_spec["full_action_spec", self.in_keys_inv[0]] + action_spec = self.parent.full_action_spec_unbatched[self.in_keys_inv[0]] if not isinstance(action_spec, Bounded): raise TypeError( - f"action spec type {type(action_spec)} is not supported." + f"action spec type {type(action_spec)} is not supported. The action spec type must be Bounded." 
) n_act = action_spec.shape if not n_act: - n_act = 1 + n_act = () + empty_shape = True else: - n_act = n_act[-1] + n_act = (n_act[-1],) + empty_shape = False self.n_act = n_act self.dtype = action_spec.dtype - interval = (action_spec.high - action_spec.low).unsqueeze(-1) + interval = action_spec.high - action_spec.low num_intervals = self.num_intervals + if not empty_shape: + interval = interval.unsqueeze(-1) + elif isinstance(num_intervals, torch.Tensor): + num_intervals = int(num_intervals.squeeze()) + self.num_intervals = torch.as_tensor(num_intervals) + def custom_arange(nint): result = torch.arange( start=0.0, @@ -8625,11 +8633,13 @@ def custom_arange(nint): if isinstance(num_intervals, int): arange = ( - custom_arange(num_intervals).expand(n_act, num_intervals) * interval - ) - self.register_buffer( - "intervals", action_spec.low.unsqueeze(-1) + arange + custom_arange(num_intervals).expand((*n_act, num_intervals)) + * interval ) + low = action_spec.low + if not empty_shape: + low = low.unsqueeze(-1) + self.register_buffer("intervals", low + arange) else: arange = [ custom_arange(_num_intervals) * interval @@ -8644,12 +8654,6 @@ def custom_arange(nint): ) ] - cls = ( - functools.partial(MultiCategorical, remove_singleton=False) - if self.categorical - else MultiOneHot - ) - if not isinstance(num_intervals, torch.Tensor): nvec = torch.as_tensor(num_intervals, device=action_spec.device) else: @@ -8657,7 +8661,10 @@ def custom_arange(nint): if nvec.ndim > 1: raise RuntimeError(f"Cannot use num_intervals with shape {nvec.shape}") if nvec.ndim == 0 or nvec.numel() == 1: - nvec = nvec.expand(action_spec.shape[-1]) + if not empty_shape: + nvec = nvec.expand(action_spec.shape[-1]) + else: + nvec = nvec.squeeze() self.register_buffer("nvec", nvec) if self.sampling == self.SamplingStrategy.RANDOM: # compute jitters @@ -8667,7 +8674,22 @@ def custom_arange(nint): if self.categorical else (*action_spec.shape[:-1], nvec.sum()) ) - action_spec = cls(nvec=nvec, shape=shape, device=action_spec.device) + + if not empty_shape: + cls = ( + functools.partial(MultiCategorical, remove_singleton=False) + if self.categorical + else MultiOneHot + ) + action_spec = cls(nvec=nvec, shape=shape, device=action_spec.device) + + else: + cls = Categorical if self.categorical else OneHot + action_spec = cls(n=int(nvec), shape=shape, device=action_spec.device) + + batch_size = self.parent.batch_size + if batch_size: + action_spec = action_spec.expand(batch_size + action_spec.shape) input_spec["full_action_spec", self.out_keys_inv[0]] = action_spec if self.out_keys_inv[0] != self.in_keys_inv[0]: @@ -8705,6 +8727,8 @@ def _inv_call(self, tensordict): if self.categorical: action = action.unsqueeze(-1) if isinstance(intervals, torch.Tensor): + shape = action.shape[: -intervals.ndim] + intervals = intervals.expand(shape + intervals.shape) action = intervals.gather(index=action, dim=-1).squeeze(-1) else: action = torch.stack( @@ -8715,17 +8739,26 @@ def _inv_call(self, tensordict): -1, ) else: - nvec = self.nvec.tolist() - action = action.split(nvec, dim=-1) - if isinstance(intervals, torch.Tensor): - intervals = intervals.unbind(-2) - action = torch.stack( - [ - intervals[action].view(action.shape[:-1]) - for (intervals, action) in zip(intervals, action) - ], - -1, - ) + nvec = self.nvec + empty_shape = not nvec.ndim + if not empty_shape: + nvec = nvec.tolist() + if isinstance(intervals, torch.Tensor): + shape = action.shape[: (-intervals.ndim + 1)] + intervals = intervals.expand(shape + intervals.shape) + intervals = 
intervals.unbind(-2) + action = action.split(nvec, dim=-1) + action = torch.stack( + [ + intervals[action].view(action.shape[:-1]) + for (intervals, action) in zip(intervals, action) + ], + -1, + ) + else: + shape = action.shape[: -intervals.ndim] + intervals = intervals.expand(shape + intervals.shape) + action = intervals[action].squeeze(-1) if self.sampling == self.SamplingStrategy.RANDOM: action = action + self.jitters * torch.rand_like(self.jitters) From c72583f75ab220c7ef89e9bd2505045ea5898db4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Dec 2024 17:46:25 +0000 Subject: [PATCH 04/12] [Feature, Test] Adding tests for envs that have no specs ghstack-source-id: 4c75691baa1e70f417e518df15c4208cff189950 Pull Request resolved: https://github.com/pytorch/rl/pull/2621 --- test/mocking_classes.py | 14 ++++++++++++++ test/test_env.py | 30 ++++++++++++++++++++++++++++++ torchrl/envs/common.py | 8 ++++++-- torchrl/envs/utils.py | 2 ++ 4 files changed, 52 insertions(+), 2 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index bb902f879b1..3c30286c419 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1996,3 +1996,17 @@ def _step( def _set_seed(self, seed: Optional[int]): ... + + +class EnvThatDoesNothing(EnvBase): + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + return TensorDict(batch_size=self.batch_size, device=self.device) + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + return TensorDict(batch_size=self.batch_size, device=self.device) + + def _set_seed(self, seed): + ... diff --git a/test/test_env.py b/test/test_env.py index 81708b0b9a6..b48b1a1cf8f 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -44,6 +44,7 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, + EnvThatDoesNothing, EnvWithDynamicSpec, EnvWithMetadata, HeterogeneousCountingEnv, @@ -81,6 +82,7 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, + EnvThatDoesNothing, EnvWithDynamicSpec, EnvWithMetadata, HeterogeneousCountingEnv, @@ -3554,6 +3556,34 @@ def test_auto_spec(): env.check_env_specs(tensordict=td.copy()) +def test_env_that_does_nothing(): + env = EnvThatDoesNothing() + env.check_env_specs() + r = env.rollout(3) + r.exclude( + "done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True + ) + assert r.is_empty() + p_env = SerialEnv(2, EnvThatDoesNothing) + p_env.check_env_specs() + r = p_env.rollout(3) + r.exclude( + "done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True + ) + assert r.is_empty() + p_env = ParallelEnv(2, EnvThatDoesNothing) + try: + p_env.check_env_specs() + r = p_env.rollout(3) + r.exclude( + "done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True + ) + assert r.is_empty() + finally: + p_env.close() + del p_env + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index d5a062bc11e..bafe88b639a 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2434,8 +2434,12 @@ def _register_gym( # noqa: F811 apply_api_compatibility=apply_api_compatibility, ) - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - raise NotImplementedError("EnvBase.forward is not implemented") + def forward(self, *args, **kwargs): + raise NotImplementedError( + "EnvBase.forward is not 
implemented. If you ended here during a call to `ParallelEnv(...)`, please use " + "a constructor such as `ParallelEnv(num_env, lambda env=env: env)` instead. " + "Batched envs require constructors because environment instances may not always be serializable." + ) @abc.abstractmethod def _step( diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 7454bce99b3..209349878ec 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -287,6 +287,8 @@ def __call__(self, tensordict): if self.validate(tensordict): if self.keep_other: out = self._exclude(self.exclude_from_root, tensordict, out=None) + if out is None: + out = tensordict.empty() else: out = next_td.empty() self._grab_and_place( From b2e9f291ad2862e6b9d8d34e68d0e2607acc9295 Mon Sep 17 00:00:00 2001 From: Mana <57663038+0xMana-git@users.noreply.github.com> Date: Mon, 2 Dec 2024 23:15:18 -0800 Subject: [PATCH 05/12] [Doc] Fix typo in torchrl/modules/distributions/continuous.py (#2624) --- torchrl/modules/distributions/continuous.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index f32a3b0c6fa..eb9093dbcfe 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -554,7 +554,7 @@ def get_mode(self): def mean(self): raise NotImplementedError( f"{type(self).__name__} does not have a closed form formula for the average. " - "Am estimate of this value can be computed using dist.sample((N,)).mean(dim=0), " + "An estimate of this value can be computed using dist.sample((N,)).mean(dim=0), " "where N is a large number of samples." ) From 8257799353253ec481100b08592e65525a659690 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valter=20Sch=C3=BCtz?= Date: Tue, 3 Dec 2024 15:12:10 +0100 Subject: [PATCH 06/12] [Doc] actor docstrings (#2626) Co-authored-by: Valter Schutz --- torchrl/modules/tensordict_module/actors.py | 10 +++++----- torchrl/modules/tensordict_module/probabilistic.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 888729835b5..6175bc8bf0c 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -153,14 +153,14 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): issues. If this value is out of bounds, it is projected back onto the desired space using the :obj:`TensorSpec.project` method. Default is ``False``. - default_interaction_type (str, optional): keyword-only argument. + default_interaction_type (tensordict.nn.InteractionType, optional): keyword-only argument. Default method to be used to retrieve - the output value. Should be one of: 'InteractionType.MODE', 'InteractionType.DETERMINISTIC', - 'InteractionType.MEDIAN', 'InteractionType.MEAN' or - 'InteractionType.RANDOM' (in which case the value is sampled + the output value. Should be one of: ``InteractionType.MODE``, ``InteractionType.DETERMINISTIC``, + ``InteractionType.MEDIAN``, ``InteractionType.MEAN`` or + ``InteractionType.RANDOM`` (in which case the value is sampled randomly from the distribution). TorchRL's ``ExplorationType`` class is a proxy to ``InteractionType``. - Defaults to is 'InteractionType.DETERMINISTIC'. + Defaults to ``InteractionType.DETERMINISTIC``. .. 
note:: When a sample is drawn, the :class:`ProbabilisticActor` instance will first look for the interaction mode dictated by the diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 8bd5143d20f..5ea006b8d2f 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -68,12 +68,12 @@ class SafeProbabilisticModule(ProbabilisticTensorDictModule): returned by the input module. If the sample is out of bounds, it is projected back onto the desired space using the `TensorSpec.project` method. Default is ``False``. - default_interaction_type (str, optional): default method to be used to retrieve - the output value. Should be one of: 'mode', 'median', 'mean' or 'random' + default_interaction_type (tensordict.nn.InteractionType, optional): default method to be used to retrieve + the output value. Should be one of: ``InteractionType.MODE``, ``InteractionType.MEDIAN``, ``InteractionType.MEAN`` or ``InteractionType.RANDOM`` (in which case the value is sampled randomly from the distribution). Default - is 'mode'. + is ``InteractionType.MODE``. Note: When a sample is drawn, the :obj:`ProbabilisticTDModule` instance will - fist look for the interaction mode dictated by the `interaction_typ()` + fist look for the interaction mode dictated by the `interaction_type()` global function. If this returns `None` (its default value), then the `default_interaction_type` of the :class:`~.ProbabilisticTDModule` instance will be used. Note that DataCollector instances will use From d22266d05d7ae10f53e3b904d847d44743beba40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valter=20Sch=C3=BCtz?= Date: Tue, 3 Dec 2024 15:12:49 +0100 Subject: [PATCH 07/12] [Doc] Update docstring for TruncatedNormal with correct parameter names (#2625) Co-authored-by: Valter Schutz --- torchrl/modules/distributions/continuous.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index eb9093dbcfe..e34f1be8ff9 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -205,8 +205,8 @@ class TruncatedNormal(D.Independent): Default is 5.0 - min (torch.Tensor or number, optional): minimum value of the distribution. Default = -1.0; - max (torch.Tensor or number, optional): maximum value of the distribution. Default = 1.0; + low (torch.Tensor or number, optional): minimum value of the distribution. Default = -1.0; + high (torch.Tensor or number, optional): maximum value of the distribution. Default = 1.0; tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw value is kept. 
Default is ``False``; From 607ebc52dc083290b6bcce98864881358f94fd7a Mon Sep 17 00:00:00 2001 From: Goia Rares Dan Tiago <115428237+raresdan@users.noreply.github.com> Date: Tue, 3 Dec 2024 16:35:51 +0200 Subject: [PATCH 08/12] [Refactor] Rename Recorder and LogReward (#2616) --- docs/source/reference/trainers.rst | 8 +-- sota-implementations/redq/utils.py | 10 ++-- test/test_trainer.py | 27 +++++----- torchrl/trainers/__init__.py | 2 + torchrl/trainers/helpers/trainers.py | 10 ++-- torchrl/trainers/trainers.py | 61 +++++++++++++++++++++-- tutorials/sphinx-tutorials/coding_ddpg.py | 6 +-- tutorials/sphinx-tutorials/coding_dqn.py | 8 +-- 8 files changed, 94 insertions(+), 38 deletions(-) diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 11384bda0e6..8f6be633743 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -78,8 +78,8 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a constants update), data subsampling (:class:``~torchrl.trainers.BatchSubSampler``) and such. - **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger - some information retrieved from that data. Examples include the ``Recorder`` hook, the reward - logger (``LogReward``) and such. Hooks should return a dictionary (or a None value) containing the + some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward + logger (``LogScaler``) and such. Hooks should return a dictionary (or a None value) containing the data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value should be displayed on the progression bar printed on the training log. @@ -174,9 +174,9 @@ Trainer and hooks BatchSubSampler ClearCudaCache CountFramesLog - LogReward + LogScaler OptimizerHook - Recorder + LogValidationReward ReplayBufferTrainer RewardNormalizer SelectKeys diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index 9953fcb3112..fed4922b5a7 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -81,8 +81,8 @@ BatchSubSampler, ClearCudaCache, CountFramesLog, - LogReward, - Recorder, + LogScalar, + LogValidationReward, ReplayBufferTrainer, RewardNormalizer, Trainer, @@ -331,7 +331,7 @@ def make_trainer( if recorder is not None: # create recorder object - recorder_obj = Recorder( + recorder_obj = LogValidationReward( record_frames=cfg.logger.record_frames, frame_skip=cfg.env.frame_skip, policy_exploration=policy_exploration, @@ -347,7 +347,7 @@ def make_trainer( # call recorder - could be removed recorder_obj(None) # create explorative recorder - could be optional - recorder_obj_explore = Recorder( + recorder_obj_explore = LogValidationReward( record_frames=cfg.logger.record_frames, frame_skip=cfg.env.frame_skip, policy_exploration=policy_exploration, @@ -369,7 +369,7 @@ def make_trainer( "post_steps", UpdateWeights(collector, update_weights_interval=1) ) - trainer.register_op("pre_steps_log", LogReward()) + trainer.register_op("pre_steps_log", LogScalar()) trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.env.frame_skip)) return trainer diff --git a/test/test_trainer.py b/test/test_trainer.py index f7e4ccffdf5..caae5bbe178 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -35,14 +35,14 @@ TensorDictReplayBuffer, ) from torchrl.envs.libs.gym import _has_gym -from torchrl.trainers import Recorder, Trainer +from 
torchrl.trainers import LogValidationReward, Trainer from torchrl.trainers.helpers import transformed_env_constructor from torchrl.trainers.trainers import ( _has_tqdm, _has_ts, BatchSubSampler, CountFramesLog, - LogReward, + LogScalar, mask_batch, OptimizerHook, ReplayBufferTrainer, @@ -638,7 +638,7 @@ def test_log_reward(self, logname, pbar): trainer = mocking_trainer() trainer.collected_frames = 0 - log_reward = LogReward(logname, log_pbar=pbar) + log_reward = LogScalar(logname, log_pbar=pbar) trainer.register_op("pre_steps_log", log_reward) td = TensorDict({REWARD_KEY: torch.ones(3)}, [3]) trainer._pre_steps_log_hook(td) @@ -654,7 +654,7 @@ def test_log_reward_register(self, logname, pbar): trainer = mocking_trainer() trainer.collected_frames = 0 - log_reward = LogReward(logname, log_pbar=pbar) + log_reward = LogScalar(logname, log_pbar=pbar) log_reward.register(trainer) td = TensorDict({REWARD_KEY: torch.ones(3)}, [3]) trainer._pre_steps_log_hook(td) @@ -873,7 +873,7 @@ def test_recorder(self, N=8): logger=logger, )() - recorder = Recorder( + recorder = LogValidationReward( record_frames=args.record_frames, frame_skip=args.frame_skip, policy_exploration=None, @@ -919,13 +919,12 @@ def test_recorder_load(self, backend, N=8): os.environ["CKPT_BACKEND"] = backend state_dict_has_been_called = [False] load_state_dict_has_been_called = [False] - Recorder.state_dict, Recorder_state_dict = _fun_checker( - Recorder.state_dict, state_dict_has_been_called + LogValidationReward.state_dict, Recorder_state_dict = _fun_checker( + LogValidationReward.state_dict, state_dict_has_been_called + ) + (LogValidationReward.load_state_dict, Recorder_load_state_dict,) = _fun_checker( + LogValidationReward.load_state_dict, load_state_dict_has_been_called ) - ( - Recorder.load_state_dict, - Recorder_load_state_dict, - ) = _fun_checker(Recorder.load_state_dict, load_state_dict_has_been_called) args = self._get_args() @@ -948,7 +947,7 @@ def _make_recorder_and_trainer(tmpdirname): )() environment.rollout(2) - recorder = Recorder( + recorder = LogValidationReward( record_frames=args.record_frames, frame_skip=args.frame_skip, policy_exploration=None, @@ -969,8 +968,8 @@ def _make_recorder_and_trainer(tmpdirname): assert recorder2._count == 8 assert state_dict_has_been_called[0] assert load_state_dict_has_been_called[0] - Recorder.state_dict = Recorder_state_dict - Recorder.load_state_dict = Recorder_load_state_dict + LogValidationReward.state_dict = Recorder_state_dict + LogValidationReward.load_state_dict = Recorder_load_state_dict def test_updateweights(): diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 364c0dec725..9d593d64f17 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -8,6 +8,8 @@ ClearCudaCache, CountFramesLog, LogReward, + LogScalar, + LogValidationReward, mask_batch, OptimizerHook, Recorder, diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 207bcec0ffd..4819d9e07e8 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -25,8 +25,8 @@ BatchSubSampler, ClearCudaCache, CountFramesLog, - LogReward, - Recorder, + LogScalar, + LogValidationReward, ReplayBufferTrainer, RewardNormalizer, SelectKeys, @@ -259,7 +259,7 @@ def make_trainer( if recorder is not None: # create recorder object - recorder_obj = Recorder( + recorder_obj = LogValidationReward( record_frames=cfg.record_frames, frame_skip=cfg.frame_skip, policy_exploration=policy_exploration, @@ -275,7 +275,7 @@ 
def make_trainer( # call recorder - could be removed recorder_obj(None) # create explorative recorder - could be optional - recorder_obj_explore = Recorder( + recorder_obj_explore = LogValidationReward( record_frames=cfg.record_frames, frame_skip=cfg.frame_skip, policy_exploration=policy_exploration, @@ -297,7 +297,7 @@ def make_trainer( "post_steps", UpdateWeights(collector, update_weights_interval=1) ) - trainer.register_op("pre_steps_log", LogReward()) + trainer.register_op("pre_steps_log", LogScalar()) trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.frame_skip)) return trainer diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 7e28da45f52..83bd050ef96 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -822,7 +822,7 @@ def __call__(self, *args, **kwargs): torch.cuda.empty_cache() -class LogReward(TrainerHookBase): +class LogScalar(TrainerHookBase): """Reward logger hook. Args: @@ -833,7 +833,7 @@ class LogReward(TrainerHookBase): in the input batch. Defaults to ``("next", "reward")`` Examples: - >>> log_reward = LogReward(("next", "reward")) + >>> log_reward = LogScalar(("next", "reward")) >>> trainer.register_op("pre_steps_log", log_reward) """ @@ -870,6 +870,23 @@ def register(self, trainer: Trainer, name: str = "log_reward"): trainer.register_module(name, self) +class LogReward(LogScalar): + """Deprecated class. Use LogScalar instead.""" + + def __init__( + self, + logname="r_training", + log_pbar: bool = False, + reward_key: Union[str, tuple] = None, + ): + warnings.warn( + "The 'LogReward' class is deprecated and will be removed in v0.9. Please use 'LogScalar' instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(logname=logname, log_pbar=log_pbar, reward_key=reward_key) + + class RewardNormalizer(TrainerHookBase): """Reward normalizer hook. @@ -1127,7 +1144,7 @@ def register(self, trainer: Trainer, name: str = "batch_subsampler"): trainer.register_module(name, self) -class Recorder(TrainerHookBase): +class LogValidationReward(TrainerHookBase): """Recorder hook for :class:`~torchrl.trainers.Trainer`. Args: @@ -1264,6 +1281,44 @@ def register(self, trainer: Trainer, name: str = "recorder"): ) +class Recorder(LogValidationReward): + """Deprecated class. Use LogValidationReward instead.""" + + def __init__( + self, + *, + record_interval: int, + record_frames: int, + frame_skip: int = 1, + policy_exploration: TensorDictModule, + environment: EnvBase = None, + exploration_type: ExplorationType = ExplorationType.RANDOM, + log_keys: Optional[List[Union[str, Tuple[str]]]] = None, + out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None, + suffix: Optional[str] = None, + log_pbar: bool = False, + recorder: EnvBase = None, + ) -> None: + warnings.warn( + "The 'Recorder' class is deprecated and will be removed in v0.9. Please use 'LogValidationReward' instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__( + record_interval=record_interval, + record_frames=record_frames, + frame_skip=frame_skip, + policy_exploration=policy_exploration, + environment=environment, + exploration_type=exploration_type, + log_keys=log_keys, + out_keys=out_keys, + suffix=suffix, + log_pbar=log_pbar, + recorder=recorder, + ) + + class UpdateWeights(TrainerHookBase): """A collector weights update hook class. 
diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 906d162f181..70176f9de4a 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -883,12 +883,12 @@ def make_ddpg_actor( # # As the training data is obtained using some exploration strategy, the true # performance of our algorithm needs to be assessed in deterministic mode. We -# do this using a dedicated class, ``Recorder``, which executes the policy in +# do this using a dedicated class, ``LogValidationReward``, which executes the policy in # the environment at a given frequency and returns some statistics obtained # from these simulations. # # The following helper function builds this object: -from torchrl.trainers import Recorder +from torchrl.trainers import LogValidationReward def make_recorder(actor_model_explore, transform_state_dict, record_interval): @@ -899,7 +899,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval): ) # must be instantiated to load the state dict environment.transform[2].load_state_dict(transform_state_dict) - recorder_obj = Recorder( + recorder_obj = LogValidationReward( record_frames=1000, policy_exploration=actor_model_explore, environment=environment, diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 59188ad21f6..a10e8c1169a 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -140,8 +140,8 @@ from torchrl.objectives import DQNLoss, SoftUpdate from torchrl.record.loggers.csv import CSVLogger from torchrl.trainers import ( - LogReward, - Recorder, + LogScalar, + LogValidationReward, ReplayBufferTrainer, Trainer, UpdateWeights, @@ -666,7 +666,7 @@ def get_loss_module(actor, gamma): buffer_hook.register(trainer) weight_updater = UpdateWeights(collector, update_weights_interval=1) weight_updater.register(trainer) -recorder = Recorder( +recorder = LogValidationReward( record_interval=100, # log every 100 optimization steps record_frames=1000, # maximum number of frames in the record frame_skip=1, @@ -704,7 +704,7 @@ def get_loss_module(actor, gamma): # This will be reflected by the `total_rewards` value displayed in the # progress bar. 
# -log_reward = LogReward(log_pbar=True) +log_reward = LogScalar(log_pbar=True) log_reward.register(trainer) ############################################################################### From 3da76f0063aac5880312832a6a449f2adb9caf91 Mon Sep 17 00:00:00 2001 From: Oliver Slumbers <40644337+oslumbers@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:42:16 +0000 Subject: [PATCH 09/12] [Feature] ActionDiscretizer custom sampling (#2609) Co-authored-by: Oliver Slumbers --- torchrl/envs/transforms/transforms.py | 46 ++++++++++++++------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7ab5a2deb72..980273af96c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -8583,6 +8583,26 @@ def _indent(s): f"\n{_indent(out_action_key)},\n{_indent(sampling)},\n{_indent(categorical)})" ) + def _custom_arange(self, nint, device): + result = torch.arange( + start=0.0, + end=1.0, + step=1 / nint, + dtype=self.dtype, + device=device, + ) + result_ = result + if self.sampling in ( + self.SamplingStrategy.HIGH, + self.SamplingStrategy.MEDIAN, + ): + result_ = (1 - result).flip(0) + if self.sampling == self.SamplingStrategy.MEDIAN: + result = (result + result_) / 2 + else: + result = result_ + return result + def transform_input_spec(self, input_spec): try: action_spec = self.parent.full_action_spec_unbatched[self.in_keys_inv[0]] @@ -8611,29 +8631,11 @@ def transform_input_spec(self, input_spec): num_intervals = int(num_intervals.squeeze()) self.num_intervals = torch.as_tensor(num_intervals) - def custom_arange(nint): - result = torch.arange( - start=0.0, - end=1.0, - step=1 / nint, - dtype=self.dtype, - device=action_spec.device, - ) - result_ = result - if self.sampling in ( - self.SamplingStrategy.HIGH, - self.SamplingStrategy.MEDIAN, - ): - result_ = (1 - result).flip(0) - if self.sampling == self.SamplingStrategy.MEDIAN: - result = (result + result_) / 2 - else: - result = result_ - return result - if isinstance(num_intervals, int): arange = ( - custom_arange(num_intervals).expand((*n_act, num_intervals)) + self._custom_arange(num_intervals, action_spec.device).expand( + (*n_act, num_intervals) + ) * interval ) low = action_spec.low @@ -8642,7 +8644,7 @@ def custom_arange(nint): self.register_buffer("intervals", low + arange) else: arange = [ - custom_arange(_num_intervals) * interval + self._custom_arange(_num_intervals, action_spec.device) * interval for _num_intervals, interval in zip( num_intervals.tolist(), interval.unbind(-2) ) From aed03fda451e1abebad6f7310c974b1b372c4a61 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 3 Dec 2024 14:50:26 +0000 Subject: [PATCH 10/12] [CI] Fix dreamer run in SOTA tests ghstack-source-id: dfe3ab6fe0d29fcdcaf57f31f84d04e07e36bad3 Pull Request resolved: https://github.com/pytorch/rl/pull/2627 --- .github/unittest/linux_sota/scripts/test_sota.py | 4 ++-- sota-implementations/dreamer/dreamer.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/unittest/linux_sota/scripts/test_sota.py b/.github/unittest/linux_sota/scripts/test_sota.py index d42f96d5ee1..b7af381634c 100644 --- a/.github/unittest/linux_sota/scripts/test_sota.py +++ b/.github/unittest/linux_sota/scripts/test_sota.py @@ -190,12 +190,12 @@ logger.backend= """, "dreamer": """python sota-implementations/dreamer/dreamer.py \ - collector.total_frames=200 \ + collector.total_frames=600 \ collector.init_random_frames=10 \ 
collector.frames_per_batch=200 \ env.n_parallel_envs=1 \ optimization.optim_steps_per_batch=1 \ - logger.video=True \ + logger.video=False \ logger.backend=csv \ replay_buffer.buffer_size=120 \ replay_buffer.batch_size=24 \ diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index d97066b87c5..1b9823c1dd1 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -321,8 +321,10 @@ def compile_rssms(module): t_collect_init = time.time() - test_env.close() - train_env.close() + if not test_env.is_closed: + test_env.close() + if not train_env.is_closed: + train_env.close() collector.shutdown() del test_env From 1ca134cc3243b28295ce9c2e8bf363814fa8ce32 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 3 Dec 2024 15:05:44 +0000 Subject: [PATCH 11/12] [BugFix] Fix MARL PPO tutorial action_spec call ghstack-source-id: 1d9058c45b28c0f0279e4243a2a0f96c622a51d8 Pull Request resolved: https://github.com/pytorch/rl/pull/2628 --- tutorials/sphinx-tutorials/multiagent_ppo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index e2ca3f6ecd8..0e6cc51adf6 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -450,8 +450,8 @@ out_keys=[env.action_key], distribution_class=TanhNormal, distribution_kwargs={ - "low": env.action_spec_unbatched[env.action_key].space.low, - "high": env.action_spec_unbatched[env.action_key].space.high, + "low": env.full_action_spec_unbatched[env.action_key].space.low, + "high": env.full_action_spec_unbatched[env.action_key].space.high, }, return_log_prob=True, log_prob_key=("agents", "sample_log_prob"), From 1cffffee92a37d16df3ddaf94fc29cc4b3292d5a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 3 Dec 2024 15:05:46 +0000 Subject: [PATCH 12/12] [BugFix] Fix export aoti_compile_and_package API change ghstack-source-id: 07a0f063f8955815157c2a3eac02c6460a82f672 Pull Request resolved: https://github.com/pytorch/rl/pull/2629 --- tutorials/sphinx-tutorials/export.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tutorials/sphinx-tutorials/export.py b/tutorials/sphinx-tutorials/export.py index 48dd8723ffc..d40ef09ff8c 100644 --- a/tutorials/sphinx-tutorials/export.py +++ b/tutorials/sphinx-tutorials/export.py @@ -343,8 +343,6 @@ with torch.no_grad(): pkg_path = aoti_compile_and_package( exported_policy, - args=(), - kwargs={"pixels": pixels}, # Specify the generated shared library path package_path=path, )
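
Usage sketch (not part of the patches above): the snippet below illustrates the trainer-hook renames from PATCH 08 and their deprecation shims, assuming a torchrl build that includes these patches. `LogScalar` and `LogValidationReward` are the new names; the old `LogReward` and `Recorder` classes remain importable as thin subclasses that emit a `DeprecationWarning` before delegating to the new implementations, so downstream code can migrate incrementally. The `skip_done_states` flag added to `SACLoss`/`DiscreteSACLoss` in PATCH 02 is likewise opt-in and defaults to ``False``, as exercised by the updated `test_sac_terminating` test.

    import warnings

    from torchrl.trainers import LogReward, LogScalar, LogValidationReward, Recorder

    # New name: logs the batch reward under the "r_training" key and, with
    # log_pbar=True, displays it on the training progress bar.
    reward_logger = LogScalar(logname="r_training", log_pbar=True)

    # Old name still imports, but construction emits a DeprecationWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy_logger = LogReward(logname="r_training", log_pbar=True)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # The shims are subclasses, so isinstance/issubclass checks keep working
    # while code is migrated to the new names.
    assert isinstance(legacy_logger, LogScalar)
    assert issubclass(Recorder, LogValidationReward)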