From 4385102c49dabcf18df12cce1410f92ac4009153 Mon Sep 17 00:00:00 2001
From: Kaixhin
Date: Wed, 15 May 2019 21:12:45 +0100
Subject: [PATCH] Update hyperparameters to camera ready version

---
 main.py   | 8 ++++----
 models.py | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/main.py b/main.py
index 1fb2de6..db4fa1a 100644
--- a/main.py
+++ b/main.py
@@ -36,10 +36,10 @@
 parser.add_argument('--batch-size', type=int, default=50, metavar='B', help='Batch size')
 parser.add_argument('--chunk-size', type=int, default=50, metavar='L', help='Chunk size')
 parser.add_argument('--overshooting-distance', type=int, default=50, metavar='D', help='Latent overshooting distance/latent overshooting weight for t = 1')
-parser.add_argument('--overshooting-kl-beta', type=float, default=1, metavar='β>1', help='Latent overshooting KL weight for t > 1 (0 to disable)')
-parser.add_argument('--overshooting-reward-scale', type=float, default=1, metavar='R>1', help='Latent overshooting reward prediction weight for t > 1 (0 to disable)')
-parser.add_argument('--global-kl-beta', type=float, default=0.1, metavar='βg', help='Global KL weight (0 to disable)')
-parser.add_argument('--free-nats', type=float, default=2, metavar='F', help='Free nats')
+parser.add_argument('--overshooting-kl-beta', type=float, default=0, metavar='β>1', help='Latent overshooting KL weight for t > 1 (0 to disable)')
+parser.add_argument('--overshooting-reward-scale', type=float, default=0, metavar='R>1', help='Latent overshooting reward prediction weight for t > 1 (0 to disable)')
+parser.add_argument('--global-kl-beta', type=float, default=0, metavar='βg', help='Global KL weight (0 to disable)')
+parser.add_argument('--free-nats', type=float, default=3, metavar='F', help='Free nats')
 parser.add_argument('--learning-rate', type=float, default=1e-3, metavar='α', help='Learning rate') # TODO: Original has a linear learning rate decay, but it seems unlikely that this makes a significant difference
 parser.add_argument('--grad-clip-norm', type=float, default=1000, metavar='C', help='Gradient clipping norm')
 parser.add_argument('--planning-horizon', type=int, default=12, metavar='H', help='Planning horizon distance')
diff --git a/models.py b/models.py
index 189c07b..26454ad 100644
--- a/models.py
+++ b/models.py
@@ -15,7 +15,7 @@ def bottle(f, x_tuple):
 class TransitionModel(jit.ScriptModule):
   __constants__ = ['min_std_dev']
 
-  def __init__(self, belief_size, state_size, action_size, hidden_size, embedding_size, activation_function='relu', min_std_dev=1e-5):
+  def __init__(self, belief_size, state_size, action_size, hidden_size, embedding_size, activation_function='relu', min_std_dev=0.1):
     super().__init__()
     self.act_fn = getattr(F, activation_function)
     self.min_std_dev = min_std_dev