From 3d4084c4b2501496acc8f6b6d1fd12d5df866693 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Mon, 18 Nov 2024 10:43:04 -0800 Subject: [PATCH] cleanup --- pi_zero_pytorch/pi_zero.py | 19 +++++++++---------- pyproject.toml | 2 +- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py index 5b33e7d..63bc3d0 100644 --- a/pi_zero_pytorch/pi_zero.py +++ b/pi_zero_pytorch/pi_zero.py @@ -623,8 +623,6 @@ def ode_fn(timestep, denoised_actions): cond_scale = cond_scale, remove_parallel_component = remove_parallel_component, keep_parallel_frac = keep_parallel_frac, - return_actions_flow = True, - return_state_keys_values = True ) if cache_kv: @@ -664,19 +662,22 @@ def forward_with_reward_cfg( cond_scale = 0., remove_parallel_component = False, keep_parallel_frac = 0., - return_state_keys_values = True, - **kwargs ): - assert return_state_keys_values, 'cached key values must be turned on' + assert self.can_cfg, 'you need to train with reward token dropout' with_reward_cache, without_reward_cache = cached_state_keys_values + forward_kwargs = dict( + return_state_keys_values = True, + return_actions_flow = True, + ) + maybe_reward_out = self.forward( *args, reward_tokens = reward_tokens, cached_state_keys_values = with_reward_cache, - return_state_keys_values = return_state_keys_values, + **forward_kwargs, **kwargs ) @@ -685,15 +686,13 @@ def forward_with_reward_cfg( if not exists(reward_tokens) or cond_scale == 0.: return action_flow_with_reward, (with_reward_cache_kv, None) - no_reward_out = self.forward( + action_flow_without_reward, without_reward_cache_kv = self.forward( *args, cached_state_keys_values = without_reward_cache, - return_state_keys_values = return_state_keys_values, + **forward_kwargs, **kwargs ) - action_flow_without_reward, without_reward_cache_kv = no_reward_out - update = action_flow_with_reward - action_flow_without_reward if remove_parallel_component: diff --git a/pyproject.toml b/pyproject.toml index 0ac4772..29db90b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pi-zero-pytorch" -version = "0.0.18" +version = "0.0.19" description = "π0 in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }