Commit ca35b99
Update
[ghstack-poisoned]
vmoens committed Dec 2, 2024
2 parents 58c4d6a + 583f553 commit ca35b99
Showing 4 changed files with 66 additions and 30 deletions.
53 changes: 36 additions & 17 deletions torchrl/objectives/ppo.py
@@ -18,6 +18,7 @@
TensorDictParams,
)
from tensordict.nn import (
CompositeDistribution,
dispatch,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
@@ -33,6 +34,7 @@
_clip_value_loss,
_GAMMA_LMBDA_DEPREC_ERROR,
_reduce,
_sum_td_features,
default_value_kwargs,
distance_loss,
ValueEstimators,
@@ -462,9 +464,13 @@ def reset(self) -> None:

def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
try:
entropy = dist.entropy()
if isinstance(dist, CompositeDistribution):
kwargs = {"aggregate_probabilities": False, "include_sum": False}
else:
kwargs = {}
entropy = dist.entropy(**kwargs)
if is_tensor_collection(entropy):
entropy = entropy.get(dist.entropy_key)
entropy = _sum_td_features(entropy)
except NotImplementedError:
x = dist.rsample((self.samples_mc_entropy,))
log_prob = dist.log_prob(x)
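The rest of the fallback branch is collapsed in this diff. As a rough, self-contained sketch of the Monte Carlo path (plain torch.distributions, with samples_mc_entropy taken to be 16; the -log_prob.mean(0) estimator at the end is an assumption about the collapsed lines, not something shown above):

import torch
from torch import distributions as d

# A tanh-transformed Normal has no closed-form entropy(), so the except branch
# is hit and the entropy bonus is estimated from reparameterized samples.
dist = d.TransformedDistribution(
    d.Normal(torch.zeros(3), torch.ones(3)), [d.TanhTransform()]
)
x = dist.rsample((16,))                  # plays the role of self.samples_mc_entropy
entropy_mc = -dist.log_prob(x).mean(0)   # Monte Carlo entropy estimate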
Expand Down Expand Up @@ -497,13 +503,20 @@ def _log_weight(
if isinstance(action, torch.Tensor):
log_prob = dist.log_prob(action)
else:
maybe_log_prob = dist.log_prob(tensordict)
if not isinstance(maybe_log_prob, torch.Tensor):
# In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
# be a tensor
log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
if isinstance(dist, CompositeDistribution):
is_composite = True
kwargs = {
"inplace": False,
"aggregate_probabilities": False,
"include_sum": False,
}
else:
log_prob = maybe_log_prob
is_composite = False
kwargs = {}
log_prob = dist.log_prob(tensordict, **kwargs)
if is_composite and not isinstance(prev_log_prob, TensorDict):
log_prob = _sum_td_features(log_prob)
log_prob.view_as(prev_log_prob)

log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
@@ -598,6 +611,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
advantage = (advantage - loc) / scale

log_weight, dist, kl_approx = self._log_weight(tensordict)
if is_tensor_collection(log_weight):
log_weight = _sum_td_features(log_weight)
log_weight = log_weight.view(advantage.shape)
neg_loss = log_weight.exp() * advantage
td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[])
if self.entropy_bonus:
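As a toy, self-contained illustration of the two hunks above (made-up values, plain tensors): the importance weight is exp(log_prob - prev_log_prob), kl_approx is the simple k1 Monte Carlo estimate of KL(old || new) since the actions were sampled under the old policy, and the unclipped objective stored in td_out is the negation of ratio * advantage.

import torch

prev_log_prob = torch.tensor([-1.2, -0.8, -2.0])
log_prob = torch.tensor([-1.0, -0.9, -1.9])
advantage = torch.tensor([[0.5], [-0.2], [1.0]])

log_weight = (log_prob - prev_log_prob).unsqueeze(-1)   # shape [3, 1]
kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)    # k1 estimate of KL(old || new)
neg_loss = log_weight.exp() * advantage                 # ratio * advantage
loss_objective = -neg_loss                               # what ends up in td_out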
@@ -1149,16 +1165,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)
except NotImplementedError:
x = previous_dist.sample((self.samples_mc_kl,))
previous_log_prob = previous_dist.log_prob(x)
current_log_prob = current_dist.log_prob(x)
if isinstance(previous_dist, CompositeDistribution):
kwargs = {
"aggregate_probabilities": False,
"inplace": False,
"include_sum": False,
}
else:
kwargs = {}
previous_log_prob = previous_dist.log_prob(x, **kwargs)
current_log_prob = current_dist.log_prob(x, **kwargs)
if is_tensor_collection(current_log_prob):
previous_log_prob = previous_log_prob.get(
self.tensor_keys.sample_log_prob
)
current_log_prob = current_log_prob.get(
self.tensor_keys.sample_log_prob
)

previous_log_prob = _sum_td_features(previous_log_prob)
current_log_prob = _sum_td_features(current_log_prob)
kl = (previous_log_prob - current_log_prob).mean(0)
kl = kl.unsqueeze(-1)
neg_loss = neg_loss - self.beta * kl
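A minimal stand-alone version of the Monte Carlo fallback above, using plain torch distributions (samples_mc_kl taken to be 16 for illustration; the composite kwargs are omitted since a Normal is not a CompositeDistribution):

import torch
from torch import distributions as d

previous_dist = d.Normal(torch.zeros(3), torch.ones(3))
current_dist = d.Normal(0.1 * torch.ones(3), torch.ones(3))
x = previous_dist.sample((16,))                                      # [16, 3]
kl_mc = (previous_dist.log_prob(x) - current_dist.log_prob(x)).mean(0)
kl_exact = d.kl.kl_divergence(previous_dist, current_dist)           # closed form, for comparison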
16 changes: 8 additions & 8 deletions torchrl/objectives/sac.py
@@ -126,10 +126,10 @@ class SACLoss(LossModule):
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
skip_done_states (bool, optional): whether the actor network should only be run on valid, non-terminating
next states. If ``True``, it is assumed that the done state can be broadcast to the shape of the
data and that masking the data results in a valid data structure. Among other things, this may not
be true in MARL settings or when using RNNs. Defaults to ``False``.
skip_done_states (bool, optional): whether the actor network used for value computation should only be run on
valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
shape of the data and that masking the data results in a valid data structure. Among other things, this may
not be true in MARL settings or when using RNNs. Defaults to ``False``.
Examples:
>>> import torch
@@ -891,10 +891,10 @@ class DiscreteSACLoss(LossModule):
``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied,
``"mean"``: the sum of the output will be divided by the number of
elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``.
skip_done_states (bool, optional): whether the actor network should only be run on valid, non-terminating
next states. If ``True``, it is assumed that the done state can be broadcast to the shape of the
data and that masking the data results in a valid data structure. Among other things, this may not
be true in MARL settings or when using RNNs. Defaults to ``False``.
skip_done_states (bool, optional): whether the actor network used for value computation should only be run on
valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
shape of the data and that masking the data results in a valid data structure. Among other things, this may
not be true in MARL settings or when using RNNs. Defaults to ``False``.
Examples:
>>> import torch
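The skip_done_states docstrings above hinge on two assumptions: the done flag broadcasts against the next-state data, and boolean-masking that data still yields a valid batch. A rough sketch of what that masking looks like (toy TensorDict; key names are illustrative and not taken from SACLoss):

import torch
from tensordict import TensorDict

next_td = TensorDict(
    {
        "observation": torch.randn(8, 4),
        "done": torch.zeros(8, 1, dtype=torch.bool),
    },
    batch_size=[8],
)
next_td["done"][::3] = True
mask = ~next_td["done"].squeeze(-1)
valid_next_td = next_td[mask]   # the actor would only be run on these entries
# With recurrent states or per-agent done flags (MARL), this indexing can break
# the expected batch structure, which is why the flag defaults to False.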
5 changes: 5 additions & 0 deletions torchrl/objectives/utils.py
@@ -615,3 +615,8 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimize
raise ValueError("Cannot group optimizers of different type.")
params.extend(optimizer.param_groups)
return cls(params)


def _sum_td_features(data: TensorDictBase) -> torch.Tensor:
# Sum all features and return a tensor
return data.sum(dim="feature", reduce=True)
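A quick usage sketch of the helper (this assumes tensordict's reduction semantics, where dim="feature" reduces over the non-batch dims of each leaf and reduce=True aggregates the leaves into a single tensor):

import torch
from tensordict import TensorDict

log_probs = TensorDict(
    {"head0": torch.randn(4, 2), "head1": torch.randn(4, 3)},
    batch_size=[4],
)
total = log_probs.sum(dim="feature", reduce=True)  # same as _sum_td_features(log_probs)
print(total.shape)                                  # torch.Size([4]): one scalar per batch element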
22 changes: 17 additions & 5 deletions torchrl/objectives/value/advantages.py
@@ -15,11 +15,14 @@
import torch
from tensordict import TensorDictBase
from tensordict.nn import (
CompositeDistribution,
dispatch,
ProbabilisticTensorDictModule,
set_skip_existing,
TensorDictModule,
TensorDictModuleBase,
)
from tensordict.nn.probabilistic import interaction_type
from tensordict.utils import NestedKey
from torch import Tensor

@@ -74,14 +77,22 @@ def new_func(self, *args, **kwargs):


def _call_actor_net(
actor_net: TensorDictModuleBase,
actor_net: ProbabilisticTensorDictModule,
data: TensorDictBase,
params: TensorDictBase,
log_prob_key: NestedKey,
):
# TODO: extend to handle time dimension (and vmap?)
log_pi = actor_net(data.select(*actor_net.in_keys, strict=False)).get(log_prob_key)
return log_pi
dist = actor_net.get_dist(data.select(*actor_net.in_keys, strict=False))
if isinstance(dist, CompositeDistribution):
kwargs = {
"aggregate_probabilities": True,
"inplace": False,
"include_sum": False,
}
else:
kwargs = {}
s = actor_net._dist_sample(dist, interaction_type=interaction_type())
return dist.log_prob(s, **kwargs)
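A toy sketch of the pattern _call_actor_net now relies on: build the distribution with get_dist() and evaluate log_prob on a sample, rather than reading a pre-computed log-prob entry from the module output. The module below is illustrative only (a plain Normal head, not a CompositeDistribution), and dist.sample() stands in for the private _dist_sample call used above.

import torch
from tensordict import TensorDict
from tensordict.nn import ProbabilisticTensorDictModule
from torch import distributions as d

actor_head = ProbabilisticTensorDictModule(
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=d.Normal,
    return_log_prob=False,
)
td = TensorDict({"loc": torch.zeros(5, 3), "scale": torch.ones(5, 3)}, batch_size=[5])
dist = actor_head.get_dist(td)            # builds the distribution, no sampling side effects on td
sample = dist.sample()
log_pi = dist.log_prob(sample).sum(-1)    # one log-probability per batch element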


class ValueEstimatorBase(TensorDictModuleBase):
@@ -1771,7 +1782,8 @@ def forward(
data=tensordict,
params=None,
log_prob_key=self.tensor_keys.sample_log_prob,
).view_as(value)
)
log_pi = log_pi.view_as(value)

# Compute the V-Trace correction
done = tensordict.get(("next", self.tensor_keys.done))
