fix another error thanks to @Wonder1905 and start appreciation section
lucidrains committed Dec 4, 2024
1 parent cf77236 commit aa6f5e2
Showing 3 changed files with 13 additions and 7 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -6,6 +6,16 @@ Implementation of <a href="https://www.physicalintelligence.company/blog/pi0">π

 Summary of this work would be that it is a simplified <a href="https://github.com/lucidrains/transfusion-pytorch">Transfusion</a> (Zhou et al.) with influence from <a href="https://arxiv.org/abs/2403.03206">Stable Diffusion 3</a> (Esser et al.), mainly the adoption of flow matching instead of diffusion for policy generation, as well as the separation of parameters (<a href="https://github.com/lucidrains/mmdit/blob/main/mmdit/mmdit_pytorch.py#L43">Joint Attention</a> from mmDIT). They build on top of a pretrained vision language model, PaliGemma 2B.

+### Appreciation
+
+- [Einops](https://github.com/arogozhnikov/einops) for the amazing [pack and unpack](https://einops.rocks/4-pack-and-unpack/), used extensively here for managing various token sets
+
+- [Flex Attention](https://pytorch.org/blog/flexattention/) for allowing for easy mixture of autoregressive and bidirectional attention
+
+- [@Wonder1905](https://github.com/Wonder1905) for the code review and identifying issues
+
+- You? A PhD student who wants to contribute to the latest SOTA architecture for behavioral cloning?
+
 ### Install

 ```bash
8 changes: 2 additions & 6 deletions pi_zero_pytorch/pi_zero.py
@@ -1127,10 +1127,8 @@ def forward(

 actions_value_residual = default(actions_value_residual, action_values)

-action_tokens = attn_ada_layerscale(action_tokens, time_cond)
-
 state_tokens = state_tokens + state_attn_out
-action_tokens = action_tokens + actions_attn_out
+action_tokens = action_tokens + attn_ada_layerscale(actions_attn_out, time_cond)

 state_tokens = state_ff(state_tokens) + state_tokens

@@ -1155,9 +1153,7 @@ def forward(

 actions_value_residual = default(actions_value_residual, action_values)

-action_tokens = attn_ada_layerscale(action_tokens, time_cond)
-
-action_tokens = action_tokens + actions_attn_out
+action_tokens = action_tokens + attn_ada_layerscale(actions_attn_out, time_cond)

 action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
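
Reviewer note: both hunks in `pi_zero.py` move `attn_ada_layerscale` off the residual stream and onto the attention branch output. The sketch below illustrates the difference in plain Python on lists of floats. It is not the repository's code: the `gate` function, its sigmoid form, and its default `weight`/`bias` values are assumptions made up for illustration, and the real `attn_ada_layerscale` module may compute its scale differently.

```python
import math

def gate(time_cond: float, weight: float = 0.0, bias: float = -2.0) -> float:
    """Hypothetical stand-in for the learned, time-conditioned scale.

    Assumed form: a sigmoid of a linear function of the conditioning,
    so the result always lies strictly between 0 and 1.
    """
    return 1.0 / (1.0 + math.exp(-(weight * time_cond + bias)))

def residual_before(tokens, attn_out, time_cond):
    # Ordering removed by this commit: the gate rescales the residual
    # stream itself, then the raw attention output is added on top.
    g = gate(time_cond)
    return [g * t + a for t, a in zip(tokens, attn_out)]

def residual_after(tokens, attn_out, time_cond):
    # Ordering introduced by this commit: the residual stream passes
    # through untouched and only the attention branch is gated.
    g = gate(time_cond)
    return [t + g * a for t, a in zip(tokens, attn_out)]

tokens = [1.0, -2.0, 0.5]
no_update = [0.0, 0.0, 0.0]

# With a zero attention output, the new ordering is an exact identity,
# while the old ordering shrinks every token by the gate value.
print(residual_after(tokens, no_update, time_cond=0.3))
print(residual_before(tokens, no_update, time_cond=0.3))
```

Under these assumptions, the old ordering attenuates the action tokens at every layer regardless of what attention produced, whereas the new ordering only controls how much of the attention update is admitted.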

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "pi-zero-pytorch"
-version = "0.0.39"
+version = "0.0.40"
 description = "π0 in Pytorch"
 authors = [
 { name = "Phil Wang", email = "[email protected]" }
