add optional adaptive rmsnorm on text embed conditioning
lucidrains committed Jun 17, 2024
1 parent ac67dd5 commit b6d870c
Showing 6 changed files with 100 additions and 4 deletions.
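The feature: when coarse_adaptive_rmsnorm = True, the coarse transformer's RMSNorm layers are rescaled by gains predicted from the pooled text embedding, giving the text condition a second pathway into the decoder besides cross attention. Below is a minimal, self-contained sketch of that idea; the class and names are illustrative placeholders, not the x-transformers implementation this commit actually relies on.

import torch
from torch import nn
import torch.nn.functional as F

class AdaptiveRMSNorm(nn.Module):
    # illustrative only: per-channel gains are predicted from a condition
    # vector (here, a pooled text embedding) instead of being fixed parameters
    def __init__(self, dim, dim_condition):
        super().__init__()
        self.scale = dim ** 0.5
        self.to_gamma = nn.Linear(dim_condition, dim)
        nn.init.zeros_(self.to_gamma.weight)
        nn.init.ones_(self.to_gamma.bias)   # starts out as a plain RMSNorm

    def forward(self, x, condition):
        # x: (batch, seq, dim), condition: (batch, dim_condition)
        normed = F.normalize(x, dim = -1) * self.scale
        gamma = self.to_gamma(condition)
        return normed * gamma.unsqueeze(1)  # broadcast gains over the sequence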
21 changes: 21 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,21 @@
name: Pytest
on: [push, pull_request]

jobs:
  build:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.10
      uses: actions/setup-python@v5
      with:
        python-version: "3.10"
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install -e .[test]
    - name: Test with pytest
      run: |
        python -m pytest tests/
14 changes: 12 additions & 2 deletions meshgpt_pytorch/meshgpt_pytorch.py
@@ -1075,7 +1075,7 @@ def forward(
        return recon_faces, total_loss, loss_breakdown

@save_load(version = __version__)
class MeshTransformer(Module,PyTorchModelHubMixin):
class MeshTransformer(Module, PyTorchModelHubMixin):
    @typecheck
    def __init__(
        self,
@@ -1094,12 +1094,13 @@ def __init__(
        cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out text condition
        dropout = 0.,
        coarse_pre_gateloop_depth = 2,
        coarse_adaptive_rmsnorm = False,
        fine_pre_gateloop_depth = 2,
        gateloop_use_heinsen = False,
        fine_attn_depth = 2,
        fine_attn_dim_head = 32,
        fine_attn_heads = 8,
        fine_cross_attend_text = False,
        fine_cross_attend_text = False, # additional conditioning - fine transformer cross attention to text tokens
        pad_id = -1,
        num_sos_tokens = None,
        condition_on_text = False,
@@ -1177,6 +1178,8 @@ def __init__(
        # main autoregressive attention network
        # attending to a face token

        self.coarse_adaptive_rmsnorm = coarse_adaptive_rmsnorm

        self.decoder = Decoder(
            dim = dim,
            depth = attn_depth,
@@ -1185,6 +1188,8 @@ def __init__(
            attn_flash = flash_attn,
            attn_dropout = dropout,
            ff_dropout = dropout,
            use_adaptive_rmsnorm = coarse_adaptive_rmsnorm,
            dim_condition = dim_text,
            cross_attend = condition_on_text,
            cross_attn_dim_context = cross_attn_dim_context,
            cross_attn_num_mem_kv = cross_attn_num_mem_kv,
@@ -1458,6 +1463,11 @@ def forward_on_codes(
                context_mask = text_mask
            )

            if self.coarse_adaptive_rmsnorm:
                attn_context_kwargs.update(
                    condition = pooled_text_embed
                )

        # take care of codes that may be flattened

        if codes.ndim > 2:
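For reference, the wiring above can be exercised directly against x-transformers' Decoder (version >= 1.30.19, as pinned by this commit). The following is a sketch with illustrative dimensions and placeholder tensors, not code from the repository:

import torch
from x_transformers import Decoder

dim, dim_text = 512, 768

decoder = Decoder(
    dim = dim,
    depth = 6,
    use_adaptive_rmsnorm = True,    # norms are modulated by the condition below
    dim_condition = dim_text,
    cross_attend = True,
    cross_attn_dim_context = dim_text
)

tokens = torch.randn(2, 128, dim)             # stand-in for coarse face tokens
text_embed = torch.randn(2, 77, dim_text)     # stand-in for per-token text embeddings
pooled_text_embed = text_embed.mean(dim = 1)  # pooled condition vector

out = decoder(
    tokens,
    context = text_embed,                     # cross attention context
    condition = pooled_text_embed             # drives the adaptive rmsnorm gains
)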
2 changes: 1 addition & 1 deletion meshgpt_pytorch/version.py
@@ -1 +1 @@
__version__ = '1.2.22'
__version__ = '1.4.0'
7 changes: 7 additions & 0 deletions setup.cfg
@@ -0,0 +1,7 @@
[aliases]
test=pytest

[tool:pytest]
addopts = --verbose -s
python_files = tests/*.py
python_paths = "."
8 changes: 7 additions & 1 deletion setup.py
@@ -38,7 +38,13 @@
    'torch_geometric',
    'tqdm',
    'vector-quantize-pytorch>=1.14.22',
    'x-transformers>=1.30.6',
    'x-transformers>=1.30.19',
  ],
  setup_requires=[
    'pytest-runner',
  ],
  tests_require=[
    'pytest'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
52 changes: 52 additions & 0 deletions tests/test_meshgpt.py
@@ -0,0 +1,52 @@
import pytest
import torch

from meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
)

@pytest.mark.parametrize('adaptive_rmsnorm', (True, False))
def test_readme(adaptive_rmsnorm):

    autoencoder = MeshAutoencoder(
        num_discrete_coors = 128
    )

    # mock inputs

    vertices = torch.randn((2, 121, 3)) # (batch, num vertices, coor (3))
    faces = torch.randint(0, 121, (2, 64, 3)) # (batch, num faces, vertices (3))

    # forward in the faces

    loss = autoencoder(
        vertices = vertices,
        faces = faces
    )

    loss.backward()

    # after much training...
    # you can pass in the raw face data above to train a transformer to model this sequence of face vertices

    transformer = MeshTransformer(
        autoencoder,
        dim = 512,
        max_seq_len = 768,
        num_sos_tokens = 1,
        fine_cross_attend_text = True,
        text_cond_with_film = False,
        condition_on_text = True,
        coarse_adaptive_rmsnorm = adaptive_rmsnorm
    )

    loss = transformer(
        vertices = vertices,
        faces = faces,
        texts = ['a high chair', 'a small teapot']
    )

    loss.backward()

    faces_coordinates, face_mask = transformer.generate(texts = ['a small chair'], cond_scale = 3.)
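
A possible follow-up that could be appended to the test above, assuming the output layout described in the project README (a float tensor of face coordinates plus a mask over faces); these assertions are a sketch, not part of the committed test:

    assert faces_coordinates.ndim == 4                     # (batch, num faces, 3 vertices, 3 coords)
    assert face_mask.shape == faces_coordinates.shape[:2]  # one mask entry per generated face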
