From cfde47ee43e5c3856f5afbf08fea113b7277d7c8 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Mon, 29 Jul 2019 10:28:16 +0200 Subject: [PATCH] Add missing `n_batch` attribute to `BasePolicy` (#418) * Add missing n_batch property * [skip ci] update changelog --- docs/misc/changelog.rst | 1 + stable_baselines/common/policies.py | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e84d1087e8..1453f6418f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -28,6 +28,7 @@ Bug Fixes: ^^^^^^^^^^ - fixed a bug in ``traj_segment_generator`` where the ``episode_starts`` was wrongly recorded, resulting in wrong calculation of Generalized Advantage Estimation (GAE), this affects TRPO, PPO1 and GAIL (thanks to @miguelrass for spotting the bug) +- add missing property ``n_batch`` in ``BasePolicy``. Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines/common/policies.py b/stable_baselines/common/policies.py index ba9c437631..b5bf4ed04b 100644 --- a/stable_baselines/common/policies.py +++ b/stable_baselines/common/policies.py @@ -111,6 +111,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=Fals obs_phs=None, add_action_ph=False): self.n_env = n_env self.n_steps = n_steps + self.n_batch = n_batch with tf.variable_scope("input", reuse=False): if obs_phs is None: self._obs_ph, self._processed_obs = observation_input(ob_space, n_batch, scale=scale)