From a7dfb46a45105f1cf986ca57b6219cf0eb37521d Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 7 Nov 2024 13:48:36 -0800 Subject: [PATCH] can condition on high rewards during inference --- pi_zero_pytorch/pi_zero.py | 2 ++ pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py index 5a404a6..7982dc3 100644 --- a/pi_zero_pytorch/pi_zero.py +++ b/pi_zero_pytorch/pi_zero.py @@ -420,6 +420,7 @@ def sample( token_ids, joint_states, trajectory_length: int, + reward_tokens = None, steps = 18, batch_size = 1, show_pbar = True @@ -442,6 +443,7 @@ def ode_fn(timestep, denoised_actions): joint_states, denoised_actions, times = timestep, + reward_tokens = reward_tokens, cached_state_keys_values = cached_state_kv, return_actions_flow = True, return_state_keys_values = True diff --git a/pyproject.toml b/pyproject.toml index eece4aa..658413c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pi-zero-pytorch" -version = "0.0.6" +version = "0.0.8" description = "π0 in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }