Commit
add wip
QuantuMope committed Sep 5, 2024
1 parent f95bbf3 commit ca336d9
Showing 1 changed file with 94 additions and 7 deletions.
101 changes: 94 additions & 7 deletions alf/algorithms/sac_algorithm.py
@@ -90,6 +90,67 @@ def _set_target_entropy(name, target_entropy, flat_action_spec):
return target_entropy


def extract_action_distribution_at_slice(
        orig_dist: td.TransformedDistribution,
        action_slice: slice) -> td.TransformedDistribution:
    """Helper function for extracting a slice of an action distribution.

    This is only necessary when the agent's action spec differs from the
    expert's action spec. In that case we need to extract the part of the
    distribution that matches the expert's action distribution in order to
    compute the KLD loss.

    Args:
        orig_dist: The action distribution to slice.
        action_slice: The slice of action dimensions to keep. The returned
            distribution covers the dimensions ``[..., action_slice]``.

    Returns:
        A new ``TransformedDistribution`` over the sliced action dimensions.
    """
    action_dim = orig_dist.event_shape[0]
    start, stop, step = action_slice.indices(action_dim)

    assert 0 <= start < action_dim and 0 < stop <= action_dim, \
        f"Invalid slice {action_slice} for action_dim of {action_dim}."

    base_dist = orig_dist.base_dist
    distribution_type = type(base_dist)
    assert isinstance(base_dist, (dist_utils.DiagMultivariateNormal,
                                  dist_utils.DiagMultivariateBeta)), \
        f"extract_action_distribution_at_slice is not implemented for " \
        f"{distribution_type}."

    params = dist_utils.distributions_to_params(base_dist)
    new_params = alf.nest.map_structure(lambda x: x[..., action_slice], params)

    new_base_dist = distribution_type(**new_params)

    # Currently the transforms are only either a StableTanh or an
    # AffineTransform. If new transforms are used in the future, this function
    # will have to be modified accordingly.
    # NOTE: we have to hardcode this since alf does not yet support returning
    # the params of transforms.
    extract_params = lambda x: x[..., action_slice]
    new_transforms = []
    for tf in orig_dist.transforms:
        if isinstance(tf, dist_utils.StableTanh):
            # StableTanh is elementwise and parameter-free, so it can simply
            # be reused for the sliced distribution.
            new_transforms.append(tf)
        elif isinstance(tf, dist_utils.AffineTransform):
            # AffineTransforms carry per-dimension loc and scale, so they
            # need to be sliced as well.
            new_aft_loc = extract_params(tf.loc)
            new_aft_scale = extract_params(tf.scale)
            new_transforms.append(
                dist_utils.AffineTransform(
                    loc=new_aft_loc, scale=new_aft_scale))
        else:
            raise NotImplementedError(
                f"extract_action_distribution_at_slice is not implemented "
                f"for {type(tf)}.")

    new_act_dist = td.TransformedDistribution(new_base_dist, new_transforms)
    return new_act_dist
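

# Hypothetical usage sketch (illustration only, not introduced by this commit):
# how the helper above could be exercised when the agent's action is
# 10-dimensional but the expert's action spec only covers the first 7 dims.
# `agent_dist` and the DiagMultivariateNormal/StableTanh construction below
# are assumptions for the example, not part of this change.
#
#     loc, scale = torch.zeros(4, 10), torch.ones(4, 10)   # batch of 4
#     agent_dist = td.TransformedDistribution(
#         dist_utils.DiagMultivariateNormal(loc, scale),
#         [dist_utils.StableTanh()])
#     expert_dist = extract_action_distribution_at_slice(
#         agent_dist, action_slice=slice(0, 7))
#     # expert_dist.event_shape == torch.Size([7]), so
#     # expert_dist.log_prob(expert_action) lines up with the expert's action
#     # spec when computing a KLD-style imitation loss.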


@alf.configurable
class SacAlgorithm(OffPolicyAlgorithm):
r"""Soft Actor Critic algorithm, described in:
@@ -789,8 +850,14 @@ def _compute_critics(self,
# continuous: critics shape [B, replicas, reward_dim]
return critics, critics_state

def _actor_train_step(self, observation, state, action, critics, log_pi,
action_distribution):
def _actor_train_step(self,
observation,
state,
action,
critics,
log_pi,
action_distribution,
action_slice: Optional[slice] = None):
neg_entropy = sum(nest.flatten(log_pi))

if self._act_type == ActionType.Discrete:
@@ -814,6 +881,10 @@ def _actor_train_step(self, observation, state, action, critics, log_pi,
# This sum() will reduce all dims so q_value can be any rank
dqda = nest_utils.grad(action, q_value.sum())

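            # If only a sub-slice of the action is shared with the expert,
            # restrict both the Q-gradient and the action to that slice so the
            # actor loss is computed over the shared action dimensions only.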
if action_slice is not None:
dqda = dqda[..., action_slice]
action = action[..., action_slice]

def actor_loss_fn(dqda, action):
if self._dqda_clipping:
dqda = torch.clamp(dqda, -self._dqda_clipping,
@@ -899,8 +970,11 @@ def _alpha_train_step(self, log_pi):
self._target_entropy)
return sum(nest.flatten(alpha_loss))

def train_step(self, inputs: TimeStep, state: SacState,
rollout_info: SacInfo):
def train_step(self,
inputs: TimeStep,
state: SacState,
rollout_info: SacInfo,
action_slice: Optional[slice] = None):
assert not self._is_eval
self._training_started = True
if self._target_repr_alg is not None:
@@ -922,8 +996,16 @@ def train_step(self, inputs: TimeStep, state: SacState,
action_state) = self._predict_action(
observation, state=state.action)

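        # When an action_slice is given, compute log_pi from the sliced
        # distribution and the matching slice of the sampled action (see
        # extract_action_distribution_at_slice above).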
if action_slice is not None:
new_act_dist = extract_action_distribution_at_slice(
action_distribution, action_slice=action_slice)
new_action = action[..., action_slice]
else:
new_act_dist = action_distribution
new_action = action

log_pi = nest.map_structure(lambda dist, a: dist.log_prob(a),
action_distribution, action)
new_act_dist, new_action)

if self._act_type == ActionType.Mixed:
# For mixed type, add log_pi separately
@@ -939,8 +1021,13 @@ def train_step(self, inputs: TimeStep, state: SacState,
log_pi = log_pi - log_prior

actor_state, actor_loss = self._actor_train_step(
observation, state.actor, action, critics, log_pi,
action_distribution)
observation,
state.actor,
action,
critics,
log_pi,
action_distribution,
action_slice=action_slice)
critic_state, critic_info = self._critic_train_step(
observation, target_observation, state.critic, rollout_info,
action, action_distribution)

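A hypothetical call site (illustration only, not shown in this commit) would pass the slice when invoking train_step; expert_action_dim below is an assumed name for the size of the expert's action spec:

    alg_step = sac.train_step(
        inputs, state, rollout_info,
        action_slice=slice(0, expert_action_dim))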