diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py index 5154651..d46c0e8 100644 --- a/pi_zero_pytorch/pi_zero.py +++ b/pi_zero_pytorch/pi_zero.py @@ -17,6 +17,8 @@ from scipy.optimize import linear_sum_assignment +from ema_pytorch import EMA + from rotary_embedding_torch import ( RotaryEmbedding, apply_rotary_emb @@ -587,6 +589,24 @@ def load_pretrained_vlm_weights_( ): raise NotImplementedError + def create_ema( + self, + beta = 0.99, + **ema_kwargs + ) -> EMA: + + ema_pi_zero = EMA( + self, + beta = beta, + include_online_model = False, + forward_method_names = ( + 'sample_actions', + ), + **ema_kwargs + ) + + return ema_pi_zero + @torch.inference_mode() def sample_actions( self, diff --git a/pyproject.toml b/pyproject.toml index 1f1ea8e..e62736f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pi-zero-pytorch" -version = "0.0.24" +version = "0.0.25" description = "π0 in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" } @@ -28,6 +28,7 @@ dependencies = [ "beartype", "einx>=0.3.0", "einops>=0.8.0", + "ema-pytorch>=0.7.3", "jaxtyping", "rotary-embedding-torch>=0.8.5", 'scipy',