From 0471c3ee3b0b2ec28c5dbf1c9c7ac19f0531ed80 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 13 Nov 2024 06:38:08 -0800 Subject: [PATCH] e2e test with vit-pytorch --- .github/workflows/test.yaml | 24 +++++++++++++++++++ pyproject.toml | 3 ++- tests/test_pi_zero.py | 46 +++++++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/test.yaml create mode 100644 tests/test_pi_zero.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..593e558 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,24 @@ +name: Tests the examples in README +on: push + +env: + TYPECHECK: True + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install dependencies + run: | + python -m pip install uv + python -m uv pip install --upgrade pip + python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu + python -m uv pip install -e .[test] + - name: Test with pytest + run: | + python -m pytest tests/ diff --git a/pyproject.toml b/pyproject.toml index 66978f8..408da96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,8 @@ Repository = "https://github.com/lucidrains/pi-zero-pytorch" [project.optional-dependencies] examples = [] test = [ - "pytest" + "pytest", + "vit-pytorch>=1.18.7" ] [tool.pytest.ini_options] diff --git a/tests/test_pi_zero.py b/tests/test_pi_zero.py new file mode 100644 index 0000000..8e103a4 --- /dev/null +++ b/tests/test_pi_zero.py @@ -0,0 +1,46 @@ +import torch +from pi_zero_pytorch import π0 + +def test_pi_zero_with_vit(): + from vit_pytorch import ViT + from vit_pytorch.extractor import Extractor + + v = ViT( + image_size = 256, + patch_size = 32, + num_classes = 1000, + dim = 1024, + depth = 6, + heads = 16, + mlp_dim = 2048, + dropout = 0.1, + emb_dropout = 0.1 + ) + + v = Extractor(v, return_embeddings_only = True) + + model = π0( + dim = 512, + vit = v, + vit_dim = 1024, + dim_action_input = 6, + dim_joint_state = 12, + num_tokens = 20_000 + ) + + vision = torch.randn(1, 1024, 512) + + images = torch.randn(1, 3, 2, 256, 256) + + commands = torch.randint(0, 20_000, (1, 1024)) + joint_state = torch.randn(1, 12) + actions = torch.randn(1, 32, 6) + + loss, _ = model(images, commands, joint_state, actions) + loss.backward() + + # after much training + + sampled_actions = model(images, commands, joint_state, trajectory_length = 32) # (1, 32, 6) + + assert sampled_actions.shape == (1, 32, 6)