Skip to content

Commit

Permalink
e2e test with vit-pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 13, 2024
1 parent 4737924 commit 0471c3e
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 1 deletion.
24 changes: 24 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# CI workflow: runs the README example tests on every push.
name: Tests the examples in README
on: push

env:
  # picked up by the library's runtime type-checking (jaxtyping/beartype style) — TODO confirm
  TYPECHECK: True

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Install Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install dependencies
        run: |
          python -m pip install uv
          python -m uv pip install --upgrade pip
          # CPU-only nightly torch keeps the CI image small and avoids CUDA setup
          python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
          # editable install with the [test] extra (pytest, vit-pytorch)
          python -m uv pip install -e .[test]
      - name: Test with pytest
        run: |
          python -m pytest tests/
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ Repository = "https://github.com/lucidrains/pi-zero-pytorch"
# Optional dependency groups, installed as e.g. `pip install -e .[test]`.
[project.optional-dependencies]
examples = []
test = [
"pytest",
# supplies ViT + Extractor used by the end-to-end test in tests/test_pi_zero.py
"vit-pytorch>=1.18.7"
]

[tool.pytest.ini_options]
Expand Down
46 changes: 46 additions & 0 deletions tests/test_pi_zero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import torch
from pi_zero_pytorch import π0

def test_pi_zero_with_vit():
    """End-to-end smoke test of π0 driven by a ViT feature extractor.

    Builds a small ViT, wraps it with ``Extractor`` so it returns patch
    embeddings instead of class logits, runs one training forward/backward
    pass through π0 (images + token commands + joint state + actions), then
    runs the sampling path and checks the sampled trajectory shape.
    """
    # imported lazily so collecting other tests doesn't require vit-pytorch
    from vit_pytorch import ViT
    from vit_pytorch.extractor import Extractor

    v = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    # return_embeddings_only = True -> forward yields patch embeddings, not logits
    v = Extractor(v, return_embeddings_only = True)

    model = π0(
        dim = 512,
        vit = v,
        vit_dim = 1024,       # must match the ViT `dim` above
        dim_action_input = 6,
        dim_joint_state = 12,
        num_tokens = 20_000
    )

    # NOTE: the original test also built an unused `vision` tensor of shape
    # (1, 1024, 512); it was never passed to the model, so it is removed here.

    # presumably (batch, channels, frames, height, width) — TODO confirm against π0's forward
    images = torch.randn(1, 3, 2, 256, 256)

    commands = torch.randint(0, 20_000, (1, 1024))  # token ids, bounded by num_tokens
    joint_state = torch.randn(1, 12)
    actions = torch.randn(1, 32, 6)

    # training mode: providing actions makes forward return (loss, ...)
    loss, _ = model(images, commands, joint_state, actions)
    loss.backward()

    # after much training

    # sampling mode: omitting actions and giving trajectory_length returns actions
    sampled_actions = model(images, commands, joint_state, trajectory_length = 32)  # (1, 32, 6)

    assert sampled_actions.shape == (1, 32, 6)

0 comments on commit 0471c3e

Please sign in to comment.