diff --git a/README.md b/README.md
index c92e32144..087411176 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,8 @@ Read the ALF documentation [here](https://alf.readthedocs.io/).
|[QRSAC](alf/algorithms/qrsac_algorithm.py)|Off-policy RL|Dabney et al. "Distributional Reinforcement Learning with Quantile Regression" [arXiv:1710.10044](https://arxiv.org/abs/1710.10044)|
|[SAC](alf/algorithms/sac_algorithm.py)|Off-policy RL|Haarnoja et al. "Soft Actor-Critic Algorithms and Applications" [arXiv:1812.05905](https://arxiv.org/abs/1812.05905)|
|[OAC](alf/algorithms/oac_algorithm.py)|Off-policy RL|Ciosek et al. "Better Exploration with Optimistic Actor-Critic" [arXiv:1910.12807](https://arxiv.org/abs/1910.12807)|
-|[HER](https://github.com/HorizonRobotics/alf/blob/911d9573866df41e9e3adf6cdd94ee03016bf5a8/alf/algorithms/data_transformer.py#L672)|Off-policy RL|Andrychowicz et al. "Hindsight Experience Replay" [arXiv:1707.01495](https://arxiv.org/abs/1707.01495)|
+|[HER](alf/algorithms/data_transformer.py) (HindsightExperienceTransformer)|Off-policy RL|Andrychowicz et al. "Hindsight Experience Replay" [arXiv:1707.01495](https://arxiv.org/abs/1707.01495)|
+|[lbVT](alf/algorithms/td_loss.py) (LowerBoundedTDLoss)|Off-policy RL|Ciosek et al. "Faster Reinforcement Learning with Value Target Lower Bounding" [link](https://openreview.net/forum?id=bgAS1ZvveZ)|
|[TAAC](alf/algorithms/taac_algorithm.py)|Off-policy RL|Yu et al. "TAAC: Temporally Abstract Actor-Critic for Continuous Control" [arXiv:2104.06521](https://arxiv.org/abs/2104.06521)|
|[DIAYN](alf/algorithms/diayn_algorithm.py)|Intrinsic motivation/Exploration|Eysenbach et al. "Diversity is All You Need: Learning Diverse Skills without a Reward Function" [arXiv:1802.06070](https://arxiv.org/abs/1802.06070)|
|[ICM](alf/algorithms/icm_algorithm.py)|Intrinsic motivation/Exploration|Pathak et al. "Curiosity-driven Exploration by Self-supervised Prediction" [arXiv:1705.05363](https://arxiv.org/abs/1705.05363)|
@@ -145,6 +146,15 @@ All the examples below are trained on a single machine Intel(R) Core(TM) i9-7960
+### lbVT
+* [DDQN with lower-bounded value target on Atari](alf/examples/dqn_breakout_conf.py). Performance on the game "Q*Bert".
+
+
+
+* [SAC with lower-bounded value target on Atari](alf/examples/sac_breakout_conf.py). Performance on the game "Q*Bert".
+
+
+
### DDPG
* [FetchSlide (sparse rewards)](alf/examples/ddpg_fetchslide_conf.py). Need to install the [MuJoCo](https://www.roboti.us/index.html) simulator first. This example reproduces the performance of vanilla DDPG reported in the OpenAI's Robotics environment [paper](https://arxiv.org/pdf/1802.09464.pdf). Our implementation doesn't use MPI, but obtains (evaluation) performance on par with the original implementation. (*The original MPI implementation has 19 workers, each worker containing 2 environments for rollout and sampling a minibatch of size 256 from its replay buffer for computing gradients. All the workers' gradients will be summed together for a centralized optimizer step. Our implementation simply samples a minibatch of size 5000 from a common replay buffer per optimizer step.*) The training took about 1 hour with 38 (19*2) parallel environments on a single GPU.
diff --git a/alf/algorithms/data_transformer.py b/alf/algorithms/data_transformer.py
index 2fa6c8137..45af9f0b9 100644
--- a/alf/algorithms/data_transformer.py
+++ b/alf/algorithms/data_transformer.py
@@ -736,14 +736,25 @@ class HindsightExperienceTransformer(DataTransformer):
of the current timestep.
The exact field names can be provided via arguments to the class ``__init__``.
+    NOTE: The HindsightExperienceTransformer has to be applied before any transformer that
+    changes the reward or achieved_goal fields, e.g. observation normalizer, reward clipper, etc.
+ See `documentation <../../docs/notes/knowledge_base.rst#datatransformers>`_ for details.
+
To use this class, add it to any existing data transformers, e.g. use this config if
``ObservationNormalizer`` is an existing data transformer:
.. code-block:: python
- ReplayBuffer.keep_episodic_info=True
- HindsightExperienceTransformer.her_proportion=0.8
- TrainerConfig.data_transformer_ctor=[@HindsightExperienceTransformer, @ObservationNormalizer]
+ alf.config('ReplayBuffer', keep_episodic_info=True)
+ alf.config(
+ 'HindsightExperienceTransformer',
+ her_proportion=0.8
+ )
+ alf.config(
+ 'TrainerConfig',
+ data_transformer_ctor=[
+ HindsightExperienceTransformer, ObservationNormalizer
+ ])
See unit test for more details on behavior.
"""
@@ -820,9 +831,10 @@ def transform_experience(self, experience: Experience):
# relabel only these sampled indices
her_cond = torch.rand(batch_size) < her_proportion
(her_indices, ) = torch.where(her_cond)
+ has_her = torch.any(her_cond)
- last_step_pos = start_pos[her_indices] + batch_length - 1
- last_env_ids = env_ids[her_indices]
+ last_step_pos = start_pos + batch_length - 1
+ last_env_ids = env_ids
# Get x, y indices of LAST steps
dist = buffer.steps_to_episode_end(last_step_pos, last_env_ids)
if alf.summary.should_record_summaries():
@@ -831,22 +843,24 @@ def transform_experience(self, experience: Experience):
torch.mean(dist.type(torch.float32)))
# get random future state
- future_idx = last_step_pos + (torch.rand(*dist.shape) *
- (dist + 1)).to(torch.int64)
+ future_dist = (torch.rand(*dist.shape) * (dist + 1)).to(
+ torch.int64)
+ future_idx = last_step_pos + future_dist
future_ag = buffer.get_field(self._achieved_goal_field,
last_env_ids, future_idx).unsqueeze(1)
# relabel desired goal
result_desired_goal = alf.nest.get_field(result,
self._desired_goal_field)
- relabed_goal = result_desired_goal.clone()
+ relabeled_goal = result_desired_goal.clone()
her_batch_index_tuple = (her_indices.unsqueeze(1),
torch.arange(batch_length).unsqueeze(0))
- relabed_goal[her_batch_index_tuple] = future_ag
+ if has_her:
+ relabeled_goal[her_batch_index_tuple] = future_ag[her_indices]
# recompute rewards
result_ag = alf.nest.get_field(result, self._achieved_goal_field)
- relabeled_rewards = self._reward_fn(result_ag, relabed_goal)
+ relabeled_rewards = self._reward_fn(result_ag, relabeled_goal)
non_her_or_fst = ~her_cond.unsqueeze(1) & (result.step_type !=
StepType.FIRST)
@@ -876,21 +890,26 @@ def transform_experience(self, experience: Experience):
alf.summary.scalar(
"replayer/" + buffer._name + ".reward_mean_before_relabel",
torch.mean(result.reward[her_indices][:-1]))
- alf.summary.scalar(
- "replayer/" + buffer._name + ".reward_mean_after_relabel",
- torch.mean(relabeled_rewards[her_indices][:-1]))
+ if has_her:
+ alf.summary.scalar(
+ "replayer/" + buffer._name + ".reward_mean_after_relabel",
+ torch.mean(relabeled_rewards[her_indices][:-1]))
+ alf.summary.scalar("replayer/" + buffer._name + ".future_distance",
+ torch.mean(future_dist.float()))
result = alf.nest.transform_nest(
- result, self._desired_goal_field, lambda _: relabed_goal)
-
+ result, self._desired_goal_field, lambda _: relabeled_goal)
result = result.update_time_step_field('reward', relabeled_rewards)
-
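+        # ``derived`` carries relabeling metadata downstream: HerAlgo.preprocess_experience
+        # copies it into rollout_info, and e.g. LowerBoundedTDLoss reads it via
+        # get_derived_field() when computing the loss.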
+ derived = {"is_her": her_cond, "future_distance": future_dist}
if alf.get_default_device() != buffer.device:
for f in accessed_fields:
result = alf.nest.transform_nest(
result, f, lambda t: convert_device(t))
- result = alf.nest.transform_nest(
- result, "batch_info.replay_buffer", lambda _: buffer)
+ info = convert_device(info)
+ derived = convert_device(derived)
+ info = info._replace(replay_buffer=buffer)
+ info = info.set_derived(derived)
+ result = alf.data_structures.add_batch_info(result, info)
return result
diff --git a/alf/algorithms/ddpg_algorithm.py b/alf/algorithms/ddpg_algorithm.py
index 7c0678998..4d1a71b17 100644
--- a/alf/algorithms/ddpg_algorithm.py
+++ b/alf/algorithms/ddpg_algorithm.py
@@ -41,8 +41,14 @@
DdpgState = namedtuple("DdpgState", ['actor', 'critics'])
DdpgInfo = namedtuple(
"DdpgInfo", [
- "reward", "step_type", "discount", "action", "action_distribution",
- "actor_loss", "critic", "discounted_return"
+ "reward",
+ "step_type",
+ "discount",
+ "action",
+ "action_distribution",
+ "actor_loss",
+ "critic",
+ "discounted_return",
],
default_value=())
DdpgLossInfo = namedtuple('DdpgLossInfo', ('actor', 'critic'))
diff --git a/alf/algorithms/her_algorithms.py b/alf/algorithms/her_algorithms.py
new file mode 100644
index 000000000..ba11b61d4
--- /dev/null
+++ b/alf/algorithms/her_algorithms.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""HER Algorithms (Wrappers)."""
+"""Classes defined here are used to transfer relevant info about the
+sampled/replayed experience from HindsightDataTransformer all the way to
+algorithm.calc_loss and the loss class.
+
+Actual hindsight relabeling happens in HindsightDataTransformer.
+
+For usage, see alf/examples/her_fetchpush_conf.py.
+"""
+
+import alf
+from alf.algorithms.sac_algorithm import SacAlgorithm, SacInfo
+from alf.algorithms.ddpg_algorithm import DdpgAlgorithm, DdpgInfo
+from alf.data_structures import TimeStep
+from alf.utils import common
+
+
+def her_wrapper(alg_cls, alg_info):
+ """A helper function to construct HerAlgo based on the base (off-policy) algorithm.
+
+ We mainly do two things here:
+ 1. Create the new HerInfo namedtuple, containing a ``derived`` field together
+ with the existing fields of AlgInfo. The ``derived`` field is a dict, to be
+ populated with information derived from the Hindsight relabeling process.
+ This HerInfo structure stores training information collected from replay and
+ processed by the algorithm's train_step.
+
+ 2. Create a new HerAlgo child class of the input base algorithm.
+ The new class additionally handles passing derived fields along the pipeline
+    for the loss function (e.g. LowerBoundedTDLoss) to access.
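+
+    A minimal usage sketch (``MyAlgorithm`` / ``MyInfo`` stand for any off-policy
+    algorithm class and its AlgInfo namedtuple; see the bottom of this file for the
+    classes actually wrapped here):
+
+    .. code-block:: python
+
+        HerMyAlgorithm = her_wrapper(MyAlgorithm, MyInfo)
+        # e.g. in a conf file:
+        alf.config('Agent', rl_algorithm_cls=HerMyAlgorithm)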
+ """
+ HerClsName = "Her" + alg_cls.__name__
+ # HerAlgo class inherits the base RL algorithm class
+ HerCls = type(HerClsName, (alg_cls, ), {})
+ HerCls.counter = 0
+
+ HerInfoName = "Her" + alg_info.__name__
+    # Unfortunately, the user has to ensure that the default_value of HerAlgInfo is
+    # exactly the same as that of AlgInfo, otherwise there could be bugs.
+ HerInfoCls = alf.data_structures.namedtuple(
+ HerInfoName, alg_info._fields + ("derived", ), default_value=())
+ alg_info.__name__ = HerInfoName
+
+ # NOTE: replay_buffer.py has similar functions for handling BatchInfo namedtuple.
+
+ # New __new__ for AlgInfo, so every time AlgInfo is called to create an instance,
+ # an HerAlgInfo instance (with the additional ``derived`` dict) is created and
+ # returned instead. This allows us to wrap an algorithm's AlgInfo class without
+ # changing any code in the original AlgInfo class, keeping HER code separate.
+ @common.add_method(alg_info)
+ def __new__(info_cls, **kwargs):
+ assert info_cls == alg_info
+ her_info = HerInfoCls(**kwargs)
+ # Set default value, later code will check for this
+ her_info = her_info._replace(derived={})
+ return her_info
+
+ # New accessor methods for HerAlgInfo to access the ``derived`` dict.
+    @common.add_method(HerInfoCls)
+    def get_derived_field(self, field, **kwargs):
+        """Get ``derived[field]``; if absent, return ``kwargs["default"]`` when provided."""
+        if field in self.derived:
+            return self.derived[field]
+        assert "default" in kwargs, f"field {field} not in HerInfo.derived"
+        return kwargs["default"]
+
+ @common.add_method(HerInfoCls)
+ def get_derived(self):
+ return self.derived
+
+ @common.add_method(HerInfoCls)
+ def set_derived(self, new_dict):
+ assert self.derived == {}
+ return self._replace(derived=new_dict)
+
+ # New methods for HerAlg
+ @common.add_method(HerCls)
+ def __init__(self, **kwargs):
+ """
+ Args:
+ kwargs: arguments passed to the constructor of the underlying algorithm.
+ """
+        assert HerCls.counter == 0, f"{HerCls} should only be instantiated once"
+ super(HerCls, self).__init__(**kwargs)
+ HerCls.counter += 1
+
+ @common.add_method(HerCls)
+ def preprocess_experience(self, inputs: TimeStep, rollout_info: alg_info,
+ batch_info):
+ """Pass derived fields from batch_info into rollout_info"""
+ time_step, rollout_info = super(HerCls, self).preprocess_experience(
+ inputs, rollout_info, batch_info)
+ if hasattr(rollout_info, "derived") and batch_info.derived:
+ # Expand to the proper dimensions consistent with other experience fields
+ derived = alf.nest.map_structure(
+ lambda x: x.unsqueeze(1).expand(time_step.reward.shape[:2]),
+ batch_info.get_derived())
+ rollout_info = rollout_info.set_derived(derived)
+ return time_step, rollout_info
+
+ @common.add_method(HerCls)
+ def train_step(self, inputs: TimeStep, state, rollout_info: alg_info):
+ """Pass derived fields from rollout_info into alg_step.info"""
+ alg_step = super(HerCls, self).train_step(inputs, state, rollout_info)
+ return alg_step._replace(
+ info=alg_step.info.set_derived(rollout_info.get_derived()))
+
+ return HerCls # End of her_wrapper function
+
+
+# Create the actual wrapped HerAlgorithms
+HerSacAlgorithm = her_wrapper(SacAlgorithm, SacInfo)
+HerDdpgAlgorithm = her_wrapper(DdpgAlgorithm, DdpgInfo)
+"""To help understand what's going on, here is the detailed data flow:
+
+1. Replayer samples the experience with batch_info from replay_buffer.
+
+2. HindsightExperienceTransformer relabels a portion of the sampled experience and stores
+the derived info: ``is_her``: whether each sequence has been relabeled, and
+``future_distance``: the number of time steps to the future achieved goal used for relabeling.
+It finally returns the experience with ``experience.batch_info.derived`` containing this information.
+
+(NOTE: we cannot put HindsightExperienceTransformer into HerAlgo.preprocess_experience, as
+preprocessing happens after data transformations, but hindsight relabeling has to happen before
+other data transformations like observation normalization, because hindsight accesses
+replay_buffer data directly, which has not gone through the data transformers.
+Maybe we could invoke HindsightExperienceTransformer automatically, e.g. by prepending it to
+``TrainerConfig.data_transformer_ctor`` in this file, but that may be too magical and should
+probably be avoided.)
+
+3. HerAlgo.preprocess_experience copies ``batch_info.derived`` over to ``rollout_info.derived``.
+NOTE: We cannot copy from exp to rollout_info because the input to preprocess_experience is time_step,
+not exp in algorithm.py:
+
+.. code-block:: python
+
+ time_step, rollout_info = self.preprocess_experience(
+ experience.time_step, experience.rollout_info, batch_info)
+
+4. HerAlgo.train_step copies ``exp.rollout_info.derived`` over to ``policy_step.info.derived``.
+NOTE: we cannot just copy derived from exp into AlgInfo in train_step, because train_step accepts
+time_step instead of exp as input:
+
+.. code-block:: python
+
+ policy_step = self.train_step(exp.time_step, policy_state,
+ exp.rollout_info)
+
+5. BaseAlgo.calc_loss will call LowerBoundedTDLoss with HerBaseAlgoInfo.
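+
+For example, the loss reads the derived fields roughly like this (a sketch of what
+``LowerBoundedTDLoss.compute_td_target`` does):
+
+.. code-block:: python
+
+    her_cond = info.get_derived_field("is_her")
+    dist = info.get_derived_field("future_distance")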
+"""
diff --git a/alf/algorithms/her_algorithms_test.py b/alf/algorithms/her_algorithms_test.py
new file mode 100644
index 000000000..9d095968a
--- /dev/null
+++ b/alf/algorithms/her_algorithms_test.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import parameterized
+
+import alf
+from alf.algorithms.her_algorithms import HerSacAlgorithm, HerDdpgAlgorithm
+from alf.algorithms.sac_algorithm import SacAlgorithm, SacInfo
+from alf.algorithms.ddpg_algorithm import DdpgAlgorithm, DdpgInfo
+
+
+class HerAlgorithmsTest(parameterized.TestCase, alf.test.TestCase):
+ def test_her_algo_name(self):
+ self.assertEqual("HerSacAlgorithm", HerSacAlgorithm.__name__)
+ self.assertEqual("HerDdpgAlgorithm", HerDdpgAlgorithm.__name__)
+
+ @parameterized.parameters([
+ (SacInfo, ),
+ (DdpgInfo, ),
+ ])
+ def test_her_info(self, Info):
+ info = Info(reward=1)
+ self.assertEqual(1, info.reward)
+        # HerAlgInfo assumes the default field value is (), which needs to be consistent with AlgInfo
+ self.assertEqual((), info.action)
+ self.assertEqual({}, info.get_derived())
+ ret = info.set_derived({"a": 1, "b": 2})
+ # info is immutable
+ self.assertEqual({}, info.get_derived())
+ # ret is the new instance with field "derived" replaced
+ self.assertEqual(1, ret.get_derived_field("a"))
+ self.assertEqual(2, ret.get_derived_field("b"))
+ # get nonexistent field with and without default
+ self.assertEqual("none", ret.get_derived_field("x", default="none"))
+ self.assertRaises(AssertionError, ret.get_derived_field, "x")
diff --git a/alf/algorithms/sac_algorithm.py b/alf/algorithms/sac_algorithm.py
index 8235d0834..b8dc463e3 100644
--- a/alf/algorithms/sac_algorithm.py
+++ b/alf/algorithms/sac_algorithm.py
@@ -55,8 +55,16 @@
SacInfo = namedtuple(
"SacInfo", [
- "reward", "step_type", "discount", "action", "action_distribution",
- "actor", "critic", "alpha", "log_pi", "discounted_return"
+ "reward",
+ "step_type",
+ "discount",
+ "action",
+ "action_distribution",
+ "actor",
+ "critic",
+ "alpha",
+ "log_pi",
+ "discounted_return",
],
default_value=())
diff --git a/alf/algorithms/td_loss.py b/alf/algorithms/td_loss.py
index 80c2c0a93..6b45caa7b 100644
--- a/alf/algorithms/td_loss.py
+++ b/alf/algorithms/td_loss.py
@@ -106,59 +106,65 @@ def gamma(self):
"""
return self._gamma.clone()
- def compute_td_target(self, info: namedtuple, target_value: torch.Tensor):
+ def compute_td_target(self,
+ info: namedtuple,
+ value: torch.Tensor,
+ target_value: torch.Tensor,
+ qr: bool = False):
"""Calculate the td target.
The first dimension of all the tensors is time dimension and the second
dimesion is the batch dimension.
Args:
- info (namedtuple): experience collected from ``unroll()`` or
+ info (namedtuple): AlgInfo collected from ``unroll()`` or
a replay buffer. All tensors are time-major. ``info`` should
contain the following fields:
- reward:
- step_type:
- discount:
+ value (torch.Tensor): the time-major tensor for the value at
+                each time step. Some of its values can be overwritten and passed
+ back to the caller.
target_value (torch.Tensor): the time-major tensor for the value at
each time step. This is used to calculate return. ``target_value``
can be same as ``value``.
Returns:
- td_target
+ td_target, updated value, optional constraint_loss
"""
+ if not qr and info.reward.ndim == 3:
+ # Multi-dim reward, not quantile regression.
+ # [T, B, D] or [T, B, 1]
+ discounts = info.discount.unsqueeze(-1) * self._gamma
+ else:
+ # [T, B]
+ discounts = info.discount * self._gamma
+
if self._lambda == 1.0:
returns = value_ops.discounted_return(
rewards=info.reward,
values=target_value,
step_types=info.step_type,
- discounts=info.discount * self._gamma)
+ discounts=discounts)
elif self._lambda == 0.0:
returns = value_ops.one_step_discounted_return(
rewards=info.reward,
values=target_value,
step_types=info.step_type,
- discounts=info.discount * self._gamma)
+ discounts=discounts)
else:
advantages = value_ops.generalized_advantage_estimation(
rewards=info.reward,
values=target_value,
step_types=info.step_type,
- discounts=info.discount * self._gamma,
+ discounts=discounts,
td_lambda=self._lambda)
returns = advantages + target_value[:-1]
- disc_ret = ()
- if hasattr(info, "discounted_return"):
- disc_ret = info.discounted_return
- if disc_ret != ():
- with alf.summary.scope(self._name):
- episode_ended = disc_ret > self._default_return
- alf.summary.scalar("episodic_discounted_return_all",
- torch.mean(disc_ret[episode_ended]))
- alf.summary.scalar(
- "value_episode_ended_all",
- torch.mean(value[:-1][:, episode_ended[0, :]]))
+            returns = returns.detach()
- return returns
+ return returns, value, None
def forward(self, info: namedtuple, value: torch.Tensor,
target_value: torch.Tensor):
@@ -182,7 +188,8 @@ def forward(self, info: namedtuple, value: torch.Tensor,
Returns:
LossInfo: with the ``extra`` field same as ``loss``.
"""
- returns = self.compute_td_target(info, target_value)
+ returns, value, constraint_loss = self.compute_td_target(
+ info, value, target_value)
value = value[:-1]
if self._normalize_target:
@@ -230,6 +237,256 @@ def _summarize(v, r, td, suffix):
return LossInfo(loss=loss, extra=loss)
+@alf.configurable
+class LowerBoundedTDLoss(TDLoss):
+ """Temporal difference loss with value target lower bounding."""
+
+ def __init__(self,
+ gamma: Union[float, List[float]] = 0.99,
+ td_error_loss_fn: Callable = element_wise_squared_loss,
+ td_lambda: float = 0.95,
+ normalize_target: bool = False,
+ lb_target_q: float = 0.,
+ default_return: float = -1000.,
+ improve_w_goal_return: bool = False,
+ improve_w_nstep_bootstrap: bool = False,
+ improve_w_nstep_only: bool = False,
+ reward_multiplier: float = 1.,
+ positive_reward: bool = True,
+ debug_summaries: bool = False,
+ name: str = "LbTDLoss"):
+ r"""
+ Args:
+            gamma .. normalize_target: passed through to ``TDLoss``.
+ lb_target_q: between 0 and 1. When not zero, use this mixing rate for the
+ lower bounded value target. Only supports batch_length == 2, one step td.
+ Suppose the original one step bootstrapped TD target is :math:`G(s)`, (which
+ equals :math:`r(s) + \gamma Q(s', a')`), the discounted accumulated return
+ to episode end is :math:`G^e(s)`, then, the new lower bounded value target is
+
+ .. math::
+
+ G^{lb}(s) \coloneqq \max(G^e(s), G(s))
+
+            default_return: Keep it the same as ``replay_buffer.default_return`` so that
+                ``episodic_discounted_return`` is plotted to tensorboard only for the
+                timesteps whose episode has already ended.
+ improve_w_goal_return: Use return calculated from the distance to hindsight
+ goals. Only supports batch_length == 2, one step td.
+ Suppose the original one step bootstrapped TD target is :math:`G(s)`, the
+ number of steps to the relabeled goal state is :math:`d`, then, for an
+ episodic task with 0/1 sparse goal reward, the new lower bounded
+ value target is
+
+ .. math::
+
+ G^{lb}(s) \coloneqq \max(\gamma^d, G(s))
+
+ improve_w_nstep_bootstrap: Look ahead 2 to n steps, and take the largest
+ bootstrapped return to lower bound the value target of the 1st step.
+ Suppose the original one step bootstrapped TD target is :math:`G(s)`, the
+ n-step bootstrapped return is :math:`G_i(s)` where :math:`i \in [1, ..., n]`,
+ then, the new lower bounded value target is
+
+ .. math::
+
+ G^{lb}(s) \coloneqq \max(\max_{i \in [1, ..., n]}(G_i(s)), G(s))
+
+ improve_w_nstep_only: Only use the n-th step bootstrapped return as
+ value target lower bound.
+ The new lower bounded value target is
+
+ .. math::
+
+ G^{lb}(s) \coloneqq \max(G_n(s), G(s))
+
+ reward_multiplier: Weight on the hindsight goal return.
+ positive_reward: If True, assumes 0/1 goal reward in an episodic task,
+ otherwise, -1/0 in a continuing task.
+ debug_summaries: True if debug summaries should be created.
+ name: The name of this loss.
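+
+        A minimal config sketch (flag names follow ``alf/examples/sac_breakout_conf.py``;
+        the ``lb_target_q`` value below is only illustrative):
+
+        .. code-block:: python
+
+            alf.config('ReplayBuffer',
+                       keep_episodic_info=True,
+                       record_episodic_return=True)
+            critic_loss_ctor = functools.partial(
+                LowerBoundedTDLoss, td_lambda=0, lb_target_q=0.5)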
+ """
+ super().__init__(
+ gamma=gamma,
+ td_error_loss_fn=td_error_loss_fn,
+ td_lambda=td_lambda,
+ normalize_target=normalize_target,
+ name=name,
+ debug_summaries=debug_summaries)
+
+ self._lb_target_q = lb_target_q
+ self._default_return = default_return
+ self._improve_w_goal_return = improve_w_goal_return
+ self._improve_w_nstep_bootstrap = improve_w_nstep_bootstrap
+ self._improve_w_nstep_only = improve_w_nstep_only
+ self._reward_multiplier = reward_multiplier
+ self._positive_reward = positive_reward
+
+ def compute_td_target(self,
+ info: namedtuple,
+ value: torch.Tensor,
+ target_value: torch.Tensor,
+ qr: bool = False):
+ """Calculate the td target.
+
+ The first dimension of all the tensors is time dimension and the second
+ dimesion is the batch dimension.
+
+ Args:
+ info (namedtuple): AlgInfo collected from ``unroll()`` or
+ a replay buffer. All tensors are time-major. ``info`` should
+ contain the following fields:
+ - reward:
+ - step_type:
+ - discount:
+ value (torch.Tensor): the time-major tensor for the value at
+                each time step. Some of its values can be overwritten and passed
+ back to the caller.
+ target_value (torch.Tensor): the time-major tensor for the value at
+ each time step. This is used to calculate return. ``target_value``
+ can be same as ``value``, except for Retrace.
+ Returns:
+ td_target, updated value, optional constraint_loss
+ """
+ returns, value, _ = super().compute_td_target(info, value,
+ target_value, qr)
+
+ constraint_loss = None
+        if self._improve_w_nstep_bootstrap:
+            assert self._lambda == 1.0, (
+                "improve_w_nstep_bootstrap requires td_lambda == 1.0")
+            # ``discounts`` is not returned by the parent class, so recompute it
+            # here, mirroring TDLoss.compute_td_target.
+            if not qr and info.reward.ndim == 3:
+                discounts = info.discount.unsqueeze(-1) * self._gamma
+            else:
+                discounts = info.discount * self._gamma
+            future_returns = value_ops.first_step_future_discounted_returns(
+                rewards=info.reward,
+                values=target_value,
+                step_types=info.step_type,
+                discounts=discounts)
+ returns = value_ops.one_step_discounted_return(
+ rewards=info.reward,
+ values=target_value,
+ step_types=info.step_type,
+ discounts=discounts)
+ assert torch.all((returns[0] == future_returns[0]) | (
+ info.step_type[0] == alf.data_structures.StepType.LAST)), \
+ str(returns[0]) + " ne\n" + str(future_returns[0]) + \
+ '\nrwd: ' + str(info.reward[0:2]) + \
+ '\nlast: ' + str(info.step_type[0:2]) + \
+ '\ndisct: ' + str(discounts[0:2]) + \
+ '\nv: ' + str(target_value[0:2])
+ if self._improve_w_nstep_only:
+ future_returns = future_returns[
+ -1] # last is the n-step return
+ else:
+ future_returns = torch.max(future_returns, dim=0)[0]
+
+ with alf.summary.scope(self._name):
+ alf.summary.scalar(
+ "max_1_to_n_future_return_gt_td",
+ torch.mean((returns[0] < future_returns).float()))
+ alf.summary.scalar("first_step_discounted_return",
+ torch.mean(returns[0]))
+
+ returns[0] = torch.max(future_returns, returns[0]).detach()
+ returns[1:] = 0
+ value = value.clone()
+ value[1:] = 0
+
+ disc_ret = ()
+ if hasattr(info, "discounted_return"):
+ disc_ret = info.discounted_return
+ if disc_ret != ():
+ with alf.summary.scope(self._name):
+ episode_ended = disc_ret > self._default_return
+ alf.summary.scalar("episodic_discounted_return_all",
+ torch.mean(disc_ret[episode_ended]))
+ alf.summary.scalar(
+ "value_episode_ended_all",
+ torch.mean(value[:-1][:, episode_ended[0, :]]))
+
+ if self._lb_target_q > 0 and disc_ret != ():
+ if hasattr(info, "get_derived_field"):
+ her_cond = info.get_derived_field("is_her")
+ else:
+ her_cond = ()
+ mask = torch.ones(returns.shape, dtype=torch.bool)
+ if her_cond != () and torch.any(~her_cond):
+ mask = ~her_cond[:-1]
+            # disc_ret was expanded in Agent.preprocess_experience; revert that here.
+            disc_ret = disc_ret[1:]
+ assert returns.shape == disc_ret.shape, "%s %s" % (returns.shape,
+ disc_ret.shape)
+ with alf.summary.scope(self._name):
+ alf.summary.scalar(
+ "episodic_return_gt_td",
+ torch.mean((returns < disc_ret).float()[mask]))
+ alf.summary.scalar(
+ "episodic_discounted_return",
+ torch.mean(
+ disc_ret[mask & (disc_ret > self._default_return)]))
+ returns[mask] = (1 - self._lb_target_q) * returns[mask] + \
+ self._lb_target_q * torch.max(returns, disc_ret)[mask]
+
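+        # Lower bound the target with the return implied by the relabeled hindsight goal:
+        # with a 0/1 goal reward (episodic task), reaching the goal in d steps gives a
+        # discounted return of gamma^d; with a -1/0 reward (continuing task), it is
+        # -(1 - gamma^d) / (1 - gamma).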
+ if self._improve_w_goal_return:
+ batch_length, batch_size = returns.shape[:2]
+ her_cond = info.get_derived_field("is_her")
+ if her_cond != () and torch.any(her_cond):
+ dist = info.get_derived_field("future_distance")
+ if self._positive_reward:
+ goal_return = torch.pow(
+ self._gamma * torch.ones(her_cond.shape), dist)
+ else:
+ goal_return = -(1. - torch.pow(self._gamma, dist)) / (
+ 1. - self._gamma)
+ goal_return *= self._reward_multiplier
+ goal_return = goal_return[:-1]
+ returns_0 = returns
+ # Multi-dim reward:
+ if len(returns.shape) > 2:
+ returns_0 = returns[:, :, 0]
+ returns_0 = torch.where(her_cond[:-1],
+ torch.max(returns_0, goal_return),
+ returns_0)
+ with alf.summary.scope(self._name):
+ alf.summary.scalar(
+ "goal_return_gt_td",
+ torch.mean((returns_0 < goal_return).float()))
+ alf.summary.scalar("goal_return", torch.mean(goal_return))
+ if len(returns.shape) > 2:
+ returns[:, :, 0] = returns_0
+ else:
+ returns = returns_0
+
+ return returns, value, constraint_loss
+
+ def forward(self, info: namedtuple, value: torch.Tensor,
+ target_value: torch.Tensor):
+ """Calculate the loss.
+
+ The first dimension of all the tensors is time dimension and the second
+ dimesion is the batch dimension.
+
+ Args:
+ info: experience collected from ``unroll()`` or
+ a replay buffer. All tensors are time-major. ``info`` should
+ contain the following fields:
+ - reward:
+ - step_type:
+ - discount:
+ value: the time-major tensor for the value at each time
+ step. The loss is between this and the calculated return.
+ target_value: the time-major tensor for the value at
+ each time step. This is used to calculate return. ``target_value``
+ can be same as ``value``.
+ Returns:
+ LossInfo: with the ``extra`` field same as ``loss``.
+ """
+ loss_info = super().forward(info, value, target_value)
+ loss = loss_info.loss
+ if self._improve_w_nstep_bootstrap:
+ # Ignore 2nd to n-th step losses.
+ loss[1:] = 0
+
+ return LossInfo(loss=loss, extra=loss)
+
+
@alf.configurable
class TDQRLoss(TDLoss):
"""Temporal difference quantile regression loss.
@@ -301,7 +558,8 @@ def forward(self, info: namedtuple, value: torch.Tensor,
assert target_value.shape[-1] == self._num_quantiles, (
"The input target_value should have same num_quantiles as pre-defiend."
)
- returns = self.compute_td_target(info, target_value)
+ returns, value, constraint_loss = self.compute_td_target(
+ info, value, target_value, qr=True)
value = value[:-1]
# for quantile regression TD, the value and target both have shape
diff --git a/alf/algorithms/td_loss_test.py b/alf/algorithms/td_loss_test.py
new file mode 100644
index 000000000..2458fb89e
--- /dev/null
+++ b/alf/algorithms/td_loss_test.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2019 Horizon Robotics. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+
+import alf
+from alf.algorithms.td_loss import LowerBoundedTDLoss
+from alf.data_structures import TimeStep, StepType, namedtuple
+
+DataItem = namedtuple(
+ "DataItem", ["reward", "step_type", "discount"], default_value=())
+
+
+class LowerBoundedTDLossTest(unittest.TestCase):
+ """Tests for alf.algorithms.td_loss.LowerBoundedTDLoss
+ """
+
+ def _check(self, res, expected):
+ np.testing.assert_array_almost_equal(res, expected)
+
+ def test_compute_td_target_nstep_bootstrap_lowerbound(self):
+ loss = LowerBoundedTDLoss(
+ gamma=1., improve_w_nstep_bootstrap=True, td_lambda=1)
+ # Tensors are transposed to be time_major [T, B, ...]
+ step_types = torch.tensor([[StepType.MID] * 5],
+ dtype=torch.int64).transpose(0, 1)
+ rewards = torch.tensor([[2.] * 5], dtype=torch.float32).transpose(0, 1)
+ discounts = torch.tensor([[0.9] * 5], dtype=torch.float32).transpose(
+ 0, 1)
+ values = torch.tensor([[1.] * 5], dtype=torch.float32).transpose(0, 1)
+ info = DataItem(
+ reward=rewards, step_type=step_types, discount=discounts)
+ returns, value, _ = loss.compute_td_target(info, values, values)
+ expected_return = torch.tensor(
+ [[2 + 0.9 * (2 + 0.9 * (2 + 0.9 * (2 + 0.9))), 0, 0, 0]],
+ dtype=torch.float32).transpose(0, 1)
+ self._check(res=returns, expected=expected_return)
+
+ expected_value = torch.tensor([[1, 0, 0, 0, 0]],
+ dtype=torch.float32).transpose(0, 1)
+ self._check(res=value, expected=expected_value)
+
+ # n-step return is below 1-step
+ values[2:] = -10
+ expected_return[0] = 2 + 0.9
+ returns, value, _ = loss.compute_td_target(info, values, values)
+ self._check(res=returns, expected=expected_return)
+
+
+if __name__ == '__main__':
+ alf.test.main()
diff --git a/alf/examples/dqn_breakout_conf-lbtq-Qbert.png b/alf/examples/dqn_breakout_conf-lbtq-Qbert.png
new file mode 100644
index 000000000..782096eeb
Binary files /dev/null and b/alf/examples/dqn_breakout_conf-lbtq-Qbert.png differ
diff --git a/alf/examples/dqn_breakout_conf_Qbert.png b/alf/examples/dqn_breakout_conf_Qbert.png
index d6fc83c43..47dc2c10e 100644
Binary files a/alf/examples/dqn_breakout_conf_Qbert.png and b/alf/examples/dqn_breakout_conf_Qbert.png differ
diff --git a/alf/examples/her_fetchpush_conf.py b/alf/examples/her_fetchpush_conf.py
index e4a15ade0..04ac5e967 100644
--- a/alf/examples/her_fetchpush_conf.py
+++ b/alf/examples/her_fetchpush_conf.py
@@ -16,6 +16,7 @@
from alf.algorithms.data_transformer import HindsightExperienceTransformer, \
ObservationNormalizer
from alf.algorithms.ddpg_algorithm import DdpgAlgorithm
+from alf.algorithms.her_algorithms import HerDdpgAlgorithm
from alf.environments import suite_robotics
from alf.nest.utils import NestConcat
@@ -38,6 +39,8 @@
alf.config('DdpgAlgorithm', action_l2=0.05)
+alf.config('Agent', rl_algorithm_cls=HerDdpgAlgorithm)
+
# Finer grain tensorboard summaries plus local action distribution
# TrainerConfig.summarize_action_distributions=True
# TrainerConfig.summary_interval=1
diff --git a/alf/examples/sac_breakout_conf-lbtq-Qbert.png b/alf/examples/sac_breakout_conf-lbtq-Qbert.png
new file mode 100644
index 000000000..61536cebd
Binary files /dev/null and b/alf/examples/sac_breakout_conf-lbtq-Qbert.png differ
diff --git a/alf/examples/sac_breakout_conf.py b/alf/examples/sac_breakout_conf.py
index e6b163393..9a2effd41 100644
--- a/alf/examples/sac_breakout_conf.py
+++ b/alf/examples/sac_breakout_conf.py
@@ -12,10 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# NOTE: to use this conf on a different Atari game, add a flag like:
+# --conf_param='create_environment.env_name="QbertNoFrameskip-v4"'
+
+# NOTE: for lower-bounded value target improvement, add these flags:
+# --conf_param='ReplayBuffer.keep_episodic_info=True'
+# --conf_param='ReplayBuffer.record_episodic_return=True'
+# --conf_param='LowerBoundedTDLoss.lb_target_q=True'
+
import functools
import alf
-from alf.algorithms.td_loss import TDLoss
+from alf.algorithms.td_loss import LowerBoundedTDLoss
from alf.environments.alf_wrappers import AtariTerminalOnLifeLossWrapper
from alf.networks import QNetwork
from alf.optimizers import AdamTF
@@ -42,7 +50,7 @@ def define_config(name, default_value):
fc_layer_params=FC_LAYER_PARAMS,
conv_layer_params=CONV_LAYER_PARAMS)
-critic_loss_ctor = functools.partial(TDLoss, td_lambda=0.95)
+critic_loss_ctor = functools.partial(LowerBoundedTDLoss, td_lambda=0)
lr = define_config('lr', 5e-4)
critic_optimizer = AdamTF(lr=lr)
@@ -61,7 +69,7 @@ def define_config(name, default_value):
target_update_period=20)
gamma = define_config('gamma', 0.99)
-alf.config('OneStepTDLoss', gamma=gamma)
+alf.config('LowerBoundedTDLoss', gamma=gamma)
alf.config('ReplayBuffer', gamma=gamma, reward_clip=(-1, 1))
# training config
@@ -82,7 +90,8 @@ def define_config(name, default_value):
num_env_steps=12000000,
evaluate=True,
num_eval_episodes=100,
- num_evals=10,
+ num_evals=50,
+ num_eval_environments=20,
num_checkpoints=5,
num_summaries=100,
debug_summaries=True,
diff --git a/alf/experience_replayers/replay_buffer.py b/alf/experience_replayers/replay_buffer.py
index 6e95fe334..eda430f67 100644
--- a/alf/experience_replayers/replay_buffer.py
+++ b/alf/experience_replayers/replay_buffer.py
@@ -29,15 +29,71 @@
from .segment_tree import SumSegmentTree, MaxSegmentTree
-BatchInfo = namedtuple(
- "BatchInfo", [
- "env_ids",
- "positions",
- "importance_weights",
- "replay_buffer",
- "discounted_return",
- ],
- default_value=())
+
+class BatchInfo(
+ namedtuple(
+ "BatchInfo", [
+ "env_ids",
+ "positions",
+ "importance_weights",
+ "replay_buffer",
+ "discounted_return",
+ "derived",
+ ],
+ default_value=())):
+ """BatchInfo stores information related to a sampled experience batch of size B.
+ - env_ids: shape [B]: environment id for each sequence.
+ - positions: shape [B]: starting position in the replay buffer for each sequence.
+ - importance_weights: shape [B]: priority divided by the average of all
+ non-zero priorities in the buffer.
+ - replay_buffer: the replay buffer object. Data transformations like FrameStacker and
+ Hindsight relabeler may need access to other data not sampled by the replayer.
+ - discounted_return: shape [B]: the accumulated future discounted return of
+ the first step of each sequence.
+ - derived: A dictionary of fields derived from the experience, e.g.
+ hindsight relabeling may return the number of steps to the future achieved goal
+ used to relabel or whether the sequence has been relabeled.
+ NOTE: ``derived`` has to be accessed through the member functions
+        ``add_derived_field()``, ``get_derived()`` and ``set_derived()``,
+        which guard against field name conflicts.
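+
+    A minimal usage sketch (tensors omitted; ``data_transformer.py`` shows the real flow):
+
+    .. code-block:: python
+
+        info = info.set_derived({"is_her": her_cond, "future_distance": future_dist})
+        her_cond = info.get_derived()["is_her"]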
+ """
+
+ # NOTE: her_algorithms.py has similar functions for handling AlgInfo namedtuple.
+
+ def __new__(cls, *args, **kwargs):
+ info = super(BatchInfo, cls).__new__(cls, *args, **kwargs)
+ # Set default value, later code will check for this
+ info = info._replace(derived={})
+ return info
+
+ def add_derived_field(self, field, new_value):
+ """Add ``new_value`` to ``batch_info.derived[field]``.
+ Args:
+ field (str): indicate the field to be updated
+ new_value (any): the new value for the field
+ Returns:
+ BatchInfo: a structure the same as the original batch_info except
+ that the field ``field`` in the ``derived`` is set to ``new_value``.
+ """
+ assert field not in self.derived, f"field {field} already exists"
+ self.derived[field] = new_value
+ return self
+
+ def get_derived(self):
+ """Return dict ``batch_info.derived``.
+ """
+ return self.derived
+
+ def set_derived(self, new_dict):
+ """Set the ``batch_info.derived`` field to ``new_dict``.
+ Args:
+ new_dict (dict): the new value for ``batch_info.derived``
+ Returns:
+ BatchInfo: a structure the same as the original batch_info except
+ that the field ``derived`` is set to ``new_dict``.
+ """
+ assert self.derived == {}
+ return self._replace(derived=new_dict)
@alf.configurable
diff --git a/alf/utils/value_ops.py b/alf/utils/value_ops.py
index 8c36deff4..63154fb58 100644
--- a/alf/utils/value_ops.py
+++ b/alf/utils/value_ops.py
@@ -230,7 +230,6 @@ def generalized_advantage_estimation(rewards,
is that the accumulated_td is reset to 0 for is_last steps in this function.
Define abbreviations:
-
- B: batch size representing number of trajectories
- T: number of steps per trajectory
@@ -278,3 +277,66 @@ def generalized_advantage_estimation(rewards,
advs = advs.transpose(0, 1)
return advs.detach()
+
+
+def first_step_future_discounted_returns(rewards,
+ values,
+ step_types,
+ discounts,
+ time_major=True):
+ """Computes future 1 to n step discounted returns for the first step.
+
+ Define abbreviations:
+
+ - B: batch size representing number of trajectories
+ - T: number of steps per trajectory
+
+ Args:
+ rewards (Tensor): shape is [T, B] (or [T]) representing rewards.
+ values (Tensor): shape is [T,B] (or [T]) representing values.
+ step_types (Tensor): shape is [T,B] (or [T]) representing step types.
+ discounts (Tensor): shape is [T, B] (or [T]) representing discounts.
+ time_major (bool): Whether input tensors are time major.
+ False means input tensors have shape [B, T].
+
+ Returns:
+ A tensor with shape [T-1, B] (or [T-1]) representing the discounted
+ returns. Shape is [B, T-1] when time_major is false.
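+
+    Example (mirroring ``value_ops_test.py``; all MID steps, batch-major inputs)::
+
+        rewards = torch.tensor([[2.] * 5])
+        values = torch.tensor([[1.] * 5])
+        discounts = torch.tensor([[0.9] * 5])
+        step_types = torch.tensor([[StepType.MID] * 5], dtype=torch.int64)
+        rets = first_step_future_discounted_returns(
+            rewards, values, step_types, discounts, time_major=False)
+        # rets[0] == [2 + 0.9,
+        #             2 + 0.9 * (2 + 0.9),
+        #             2 + 0.9 * (2 + 0.9 * (2 + 0.9)),
+        #             2 + 0.9 * (2 + 0.9 * (2 + 0.9 * (2 + 0.9)))]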
+ """
+ if not time_major:
+ discounts = discounts.transpose(0, 1)
+ rewards = rewards.transpose(0, 1)
+ values = values.transpose(0, 1)
+ step_types = step_types.transpose(0, 1)
+
+ assert values.shape[0] >= 2, ("The sequence length needs to be "
+ "at least 2. Got {s}".format(
+ s=values.shape[0]))
+
+ is_lasts = (step_types == StepType.LAST).to(dtype=torch.float32)
+ is_lasts = common.expand_dims_as(is_lasts, values)
+ discounts = common.expand_dims_as(discounts, values)
+
+ accw = torch.ones_like(values)
+ accw[0] = (1 - is_lasts[0]) * discounts[1]
+ rets = torch.zeros_like(values)
+ rets[0] = rewards[1] * (1 - is_lasts[0]) + accw[0] * values[1]
+    # When step i is LAST, v[i+1] shouldn't be used in computing ret[i]; when
+    # disc[i] == 0, v[i] isn't used in computing ret[i-1].
+    # E.g. when the 2nd step is LAST:
+    #   ret[0] = r[1] + disc[1] * v[1]
+    #   ret[1] = ret[2] = r[1] + disc[1] * (r[2] + disc[2] * v[2])
+    # In general:
+    #   r[t] = (1 - is_last[t]) * reward[t + 1]
+    #   acc_return_to[t] = acc_return_to[t - 1] + r[t]
+    #   bootstrapped_return[t] = r[t] + (1 - is_last[t + 1]) * discounts[t + 1] * v[t + 1]
+ with torch.no_grad():
+ for t in range(rewards.shape[0] - 2):
+ accw[t + 1] = accw[t] * (1 - is_lasts[t + 1]) * discounts[t + 2]
+ rets[t + 1] = (
+ rets[t] + rewards[t + 2] * (1 - is_lasts[t + 1]) * accw[t] +
+ values[t + 2] * accw[t + 1] -
+ accw[t] * values[t + 1] * (1 - is_lasts[t + 1]))
+
+ rets = rets[:-1]
+
+ if not time_major:
+ rets = rets.transpose(0, 1)
+
+ return rets.detach()
diff --git a/alf/utils/value_ops_test.py b/alf/utils/value_ops_test.py
index ebd526127..6477edbb2 100644
--- a/alf/utils/value_ops_test.py
+++ b/alf/utils/value_ops_test.py
@@ -23,23 +23,46 @@ class DiscountedReturnTest(unittest.TestCase):
"""Tests for alf.utils.value_ops.discounted_return
"""
- def _check(self, rewards, values, step_types, discounts, expected):
- np.testing.assert_array_almost_equal(
- value_ops.discounted_return(
+ def _check(self,
+ rewards,
+ values,
+ step_types,
+ discounts,
+ expected,
+ future=False):
+ if future:
+ res = value_ops.first_step_future_discounted_returns(
rewards=rewards,
values=values,
step_types=step_types,
discounts=discounts,
- time_major=False), expected)
+ time_major=False)
+ else:
+ res = value_ops.discounted_return(
+ rewards=rewards,
+ values=values,
+ step_types=step_types,
+ discounts=discounts,
+ time_major=False)
- np.testing.assert_array_almost_equal(
- value_ops.discounted_return(
+ np.testing.assert_array_almost_equal(res, expected)
+
+ if future:
+ res = value_ops.first_step_future_discounted_returns(
rewards=torch.stack([rewards, 2 * rewards], dim=2),
values=torch.stack([values, 2 * values], dim=2),
step_types=step_types,
discounts=discounts,
- time_major=False), torch.stack([expected, 2 * expected],
- dim=2))
+ time_major=False)
+ else:
+ res = value_ops.discounted_return(
+ rewards=torch.stack([rewards, 2 * rewards], dim=2),
+ values=torch.stack([values, 2 * values], dim=2),
+ step_types=step_types,
+ discounts=discounts,
+ time_major=False)
+ np.testing.assert_array_almost_equal(
+ res, torch.stack([expected, 2 * expected], dim=2))
def test_discounted_return(self):
values = torch.tensor([[1.] * 5], dtype=torch.float32)
@@ -74,7 +97,7 @@ def test_discounted_return(self):
discounts=discounts,
expected=expected)
- # tow episodes, and end normal (discount=0)
+        # two episodes, and end normally (discount=0)
step_types = torch.tensor([[
StepType.MID, StepType.MID, StepType.LAST, StepType.MID,
StepType.MID
@@ -91,6 +114,100 @@ def test_discounted_return(self):
discounts=discounts,
expected=expected)
+ def test_first_step_future_discounted_returns(self):
+ values = torch.tensor([[1.] * 5], dtype=torch.float32)
+ step_types = torch.tensor([[StepType.MID] * 5], dtype=torch.int64)
+ rewards = torch.tensor([[2.] * 5], dtype=torch.float32)
+ discounts = torch.tensor([[0.9] * 5], dtype=torch.float32)
+ expected = torch.tensor([[
+ 2 + 0.9, 2 + 0.9 * (2 + 0.9), 2 + 0.9 * (2 + 0.9 * (2 + 0.9)),
+ 2 + 0.9 * (2 + 0.9 * (2 + 0.9 * (2 + 0.9)))
+ ]],
+ dtype=torch.float32)
+ self._check(
+ rewards=rewards,
+ values=values,
+ step_types=step_types,
+ discounts=discounts,
+ expected=expected,
+ future=True)
+
+        # two episodes; the first ends by exceeding the time limit (discount stays non-zero at LAST)
+ step_types = torch.tensor([[
+ StepType.MID, StepType.MID, StepType.LAST, StepType.MID,
+ StepType.MID
+ ]],
+ dtype=torch.int32)
+ expected = torch.tensor([[
+ 2 + 0.9, 2 + 0.9 * (2 + 0.9), 2 + 0.9 * (2 + 0.9),
+ 2 + 0.9 * (2 + 0.9)
+ ]],
+ dtype=torch.float32)
+ self._check(
+ rewards=rewards,
+ values=values,
+ step_types=step_types,
+ discounts=discounts,
+ expected=expected,
+ future=True)
+
+        # two episodes, and end normally (discount=0)
+ step_types = torch.tensor([[
+ StepType.MID, StepType.MID, StepType.LAST, StepType.MID,
+ StepType.MID
+ ]],
+ dtype=torch.int32)
+ discounts = torch.tensor([[0.9, 0.9, 0.0, 0.9, 0.9]])
+ expected = torch.tensor(
+ [[2 + 0.9, 2 + 0.9 * 2, 2 + 0.9 * 2, 2 + 0.9 * 2]],
+ dtype=torch.float32)
+
+ self._check(
+ rewards=rewards,
+ values=values,
+ step_types=step_types,
+ discounts=discounts,
+ expected=expected,
+ future=True)
+
+ # two episodes with discount 0 LAST.
+ values = torch.tensor([[1.] * 5], dtype=torch.float32)
+ step_types = torch.tensor([[
+ StepType.MID, StepType.LAST, StepType.LAST, StepType.MID,
+ StepType.MID
+ ]],
+ dtype=torch.int32)
+ rewards = torch.tensor([[2.] * 5], dtype=torch.float32)
+ discounts = torch.tensor([[0.9, 0.0, 0.0, 0.9, 0.9]])
+ expected = torch.tensor([[2, 2, 2, 2]], dtype=torch.float32)
+
+ self._check(
+ rewards=rewards,
+ values=values,
+ step_types=step_types,
+ discounts=discounts,
+ expected=expected,
+ future=True)
+
+        # the first step is already LAST with discount 0, so all future returns are 0
+ values = torch.tensor([[1.] * 5], dtype=torch.float32)
+ step_types = torch.tensor([[
+ StepType.LAST, StepType.LAST, StepType.LAST, StepType.MID,
+ StepType.MID
+ ]],
+ dtype=torch.int32)
+ rewards = torch.tensor([[2.] * 5], dtype=torch.float32)
+ discounts = torch.tensor([[0.0, 0.0, 0.0, 0.9, 0.9]])
+ expected = torch.tensor([[0, 0, 0, 0]], dtype=torch.float32)
+
+ self._check(
+ rewards=rewards,
+ values=values,
+ step_types=step_types,
+ discounts=discounts,
+ expected=expected,
+ future=True)
+
class GeneralizedAdvantageTest(unittest.TestCase):
"""Tests for alf.utils.value_ops.generalized_advantage_estimation