Added train by epoch for Trainer and added support for texts #12

Open

MarcusLoppe wants to merge 150 commits into main
Commits (150 total; changes shown are from 32 commits)
4192b5f
added progress bar for trainer + accept texts args for encoder
MarcusLoppe Dec 13, 2023
0ae54e9
added epoch checkpoint saver for trainer
MarcusLoppe Dec 13, 2023
a09b2e8
setup.py forgot comma
MarcusLoppe Dec 13, 2023
83baaee
custom_collate can now accecpt texts
MarcusLoppe Dec 13, 2023
195d14e
bug fix - save every_epoch
MarcusLoppe Dec 13, 2023
1d4b705
bug fix - save every_epoch
MarcusLoppe Dec 13, 2023
b5d9b1e
bug fix - save every_epoch extra info
MarcusLoppe Dec 13, 2023
d20e3d8
bug fix - every_epoch
MarcusLoppe Dec 13, 2023
4085b70
final bug fix - every_epoch
MarcusLoppe Dec 13, 2023
3715e3a
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 13, 2023
99cfa22
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 13, 2023
57900a6
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 13, 2023
ee44068
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 14, 2023
fa515a6
removed grad_accum_every
MarcusLoppe Dec 14, 2023
fc49064
removed grad_accum_every
MarcusLoppe Dec 14, 2023
5a68772
fix error
MarcusLoppe Dec 14, 2023
cdb8b52
Revert "fix error"
MarcusLoppe Dec 14, 2023
3a46455
Revert "removed grad_accum_every"
MarcusLoppe Dec 14, 2023
0717547
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 14, 2023
7833b5b
Revert "removed grad_accum_every"
MarcusLoppe Dec 14, 2023
c8a82dc
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 14, 2023
60ea5ec
Merge branch 'main' into main
MarcusLoppe Dec 14, 2023
0b2b0ee
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 14, 2023
ea40170
return avg loss
MarcusLoppe Dec 14, 2023
82c39b0
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 14, 2023
999ffe7
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 14, 2023
b9204f6
forced model input
MarcusLoppe Dec 14, 2023
628d756
custom_collate fixed
MarcusLoppe Dec 15, 2023
7c7c61e
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 15, 2023
ca81558
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 15, 2023
abdb4c2
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 16, 2023
e57e6f1
added stop_at_loss
MarcusLoppe Dec 16, 2023
6dd684c
merge fix
MarcusLoppe Dec 17, 2023
824ba04
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 17, 2023
52ed27e
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 18, 2023
f012e95
added notebook
MarcusLoppe Dec 19, 2023
ccaef95
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe Dec 19, 2023
364b2a9
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 19, 2023
bd1b904
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe Dec 19, 2023
49dcaef
added notebook & training estimation time
MarcusLoppe Dec 19, 2023
0bd6c0c
notebook fixs
MarcusLoppe Dec 19, 2023
31c8e78
notebook fixs 2
MarcusLoppe Dec 19, 2023
68cd533
load fix
MarcusLoppe Dec 19, 2023
22b2b91
trainer load fix
MarcusLoppe Dec 19, 2023
aab84af
trainer warmup -ix
MarcusLoppe Dec 19, 2023
bdc2116
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 19, 2023
17a65e8
trainer revert fix
MarcusLoppe Dec 19, 2023
8936d65
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 23, 2023
b87e051
use **forward args
MarcusLoppe Dec 24, 2023
9078a53
skip data kwargs
MarcusLoppe Dec 24, 2023
0163070
readded data kwargs
MarcusLoppe Dec 24, 2023
673379b
implemented grad_accum_every
MarcusLoppe Dec 26, 2023
eb69248
implemented grad_accum_every
MarcusLoppe Dec 26, 2023
5c6f832
modifyed notebook
MarcusLoppe Dec 26, 2023
9766d41
added mesh_dataset
MarcusLoppe Dec 27, 2023
bf08333
fixed tdqm bug
MarcusLoppe Dec 28, 2023
66c3e1f
Merge branch 'lucidrains:main' into main
MarcusLoppe Dec 31, 2023
60face8
flash attention mispelling
MarcusLoppe Dec 31, 2023
f936da8
flash attention mispelling
MarcusLoppe Dec 31, 2023
a42052d
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe Dec 31, 2023
a0f9855
flash attention mispelling
MarcusLoppe Dec 31, 2023
d59a500
Merge branch 'lucidrains:main' into main
MarcusLoppe Jan 1, 2024
70ac14e
reimplemented grad_accum in transformer trainer
MarcusLoppe Jan 2, 2024
a5d6f53
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe Jan 2, 2024
7c86e36
cleaned up the trainer
MarcusLoppe Jan 3, 2024
ae057bb
cleaned up the trainer
MarcusLoppe Jan 3, 2024
04ba13f
Merge branch 'lucidrains:main' into main
MarcusLoppe Jan 5, 2024
45e628e
demo + update of meshdataset
MarcusLoppe Jan 5, 2024
f72b662
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe Jan 5, 2024
ab82a21
Merge branch 'lucidrains:main' into main
MarcusLoppe Jan 5, 2024
8364938
demo, missing csv import
MarcusLoppe Jan 5, 2024
4405d58
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe Jan 5, 2024
bf9d0a7
demo comments
MarcusLoppe Jan 6, 2024
36a17fc
Merge branch 'lucidrains:main' into main
MarcusLoppe Jan 9, 2024
7b467b4
Merge branch 'main' into main
MarcusLoppe Jan 11, 2024
4927fff
Merge branch 'lucidrains:main' into main
MarcusLoppe Jan 13, 2024
39a3b50
Merge branch 'lucidrains:main' into main
MarcusLoppe Jan 14, 2024
66b98a7
Merge branch 'lucidrains:main' into main
MarcusLoppe Jan 14, 2024
b9bcac4
Merge branch 'lucidrains:main' into main
MarcusLoppe Jan 19, 2024
51b81d4
Merge branch 'lucidrains:main' into main
MarcusLoppe Jan 23, 2024
afae738
Merge branch 'lucidrains:main' into main
MarcusLoppe Jan 27, 2024
25d51be
Merge branch 'lucidrains:main' into main
MarcusLoppe Feb 15, 2024
c5e2c14
shuffle as arg
MarcusLoppe Feb 15, 2024
5b42b8d
shuffle as arg
MarcusLoppe Feb 15, 2024
6e65de4
shuffle as arg
MarcusLoppe Feb 15, 2024
f8c087a
sageconv fix
MarcusLoppe Feb 17, 2024
60fb312
Optional Squeeze
MarcusLoppe Feb 19, 2024
d35228e
layernorm
MarcusLoppe Feb 29, 2024
9f987b6
Merge remote-tracking branch 'upstream/main'
MarcusLoppe Mar 8, 2024
1a8b0b1
block-forward fix
MarcusLoppe Mar 8, 2024
738a71c
demo update
MarcusLoppe Mar 8, 2024
0201134
labels
MarcusLoppe Mar 8, 2024
a3a2588
updated demo render functions
MarcusLoppe Mar 12, 2024
4f9ddd5
training stuff
MarcusLoppe Mar 12, 2024
e5db7d9
checkpoint name
MarcusLoppe Mar 12, 2024
e9c6bdf
checkpoint name
MarcusLoppe Mar 12, 2024
6cfc6f9
pad
MarcusLoppe Mar 13, 2024
89010a9
pad
MarcusLoppe Mar 13, 2024
6c197af
pad
MarcusLoppe Mar 13, 2024
8231e7b
pad
MarcusLoppe Mar 13, 2024
2940a46
Update README.md
MarcusLoppe Mar 14, 2024
201ce43
Update README.md
MarcusLoppe Mar 14, 2024
e1f7ffb
Update README.md
MarcusLoppe Mar 14, 2024
cdb0d7f
tuple error
MarcusLoppe Apr 2, 2024
0daebd3
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe Apr 2, 2024
e5afd9d
padding issue
MarcusLoppe Apr 15, 2024
fa71927
multi-gpu checkpoint fix
MarcusLoppe Apr 19, 2024
780431b
gateloop_use_heinsen fix
MarcusLoppe Apr 20, 2024
b4fcd8d
updated notebook with better instructions
MarcusLoppe May 7, 2024
989faa9
added fp16 training
MarcusLoppe May 7, 2024
8288354
Merge branch 'main' of https://github.com/lucidrains/meshgpt-pytorch
MarcusLoppe May 10, 2024
521f6f7
Merge branch 'main' of https://github.com/lucidrains/meshgpt-pytorch
MarcusLoppe May 10, 2024
19f616d
remove unused testing things
MarcusLoppe May 10, 2024
6e3909b
small comment
lucidrains May 10, 2024
391704b
Merge branch 'main' of https://github.com/lucidrains/meshgpt-pytorch
MarcusLoppe May 11, 2024
0e4a38e
Merge branch 'main' of https://github.com/lucidrains/meshgpt-pytorch
MarcusLoppe May 11, 2024
871f418
Merge branch 'lucidrains:main' into main
MarcusLoppe May 14, 2024
dd513a6
Merge branch 'lucidrains:main' into main
MarcusLoppe May 16, 2024
30c601e
Merge remote-tracking branch 'upstream/main'
MarcusLoppe May 20, 2024
674fa18
Merge branch 'lucidrains:main' into main
MarcusLoppe Jun 1, 2024
efb420f
Merge branch 'lucidrains:main' into main
MarcusLoppe Jun 1, 2024
14ad997
Merge branch 'lucidrains:main' into main
MarcusLoppe Jun 1, 2024
bdfeeea
Merge branch 'lucidrains:main' into main
MarcusLoppe Jun 5, 2024
c00e7c6
Merge branch 'lucidrains:main' into main
MarcusLoppe Jun 5, 2024
0d5cb51
huggingface implementation & updated rendering
MarcusLoppe Jun 7, 2024
b5e7467
init file
MarcusLoppe Jun 7, 2024
01ac2bb
Update mesh_dataset.py from entrys to entries
fire Jun 12, 2024
ecf72c7
Merge pull request #3 from fire/patch-1
MarcusLoppe Jun 12, 2024
f9a31d9
Merge remote-tracking branch 'upstream/main'
MarcusLoppe Jun 17, 2024
11bff2b
notebook updates
MarcusLoppe Jun 17, 2024
8e15172
Merge remote-tracking branch 'upstream/main'
MarcusLoppe Jun 17, 2024
b7a3611
bug fix
MarcusLoppe Jun 17, 2024
481428d
bug fix
MarcusLoppe Jun 17, 2024
bdfcade
bug fix
MarcusLoppe Jun 17, 2024
d166f2f
dataset improvements
MarcusLoppe Jun 17, 2024
2e45e7b
Merge branch 'lucidrains:main' into main
MarcusLoppe Jun 17, 2024
6cad755
dataset embed_texts - output message
MarcusLoppe Jun 18, 2024
1608da5
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe Jun 18, 2024
c7a1cbd
Merge branch 'lucidrains:main' into main
MarcusLoppe Jun 18, 2024
5f6d670
Update setup.py
MarcusLoppe Jun 18, 2024
14126a2
Merge branch 'lucidrains:main' into main
MarcusLoppe Jun 18, 2024
14a4470
Merge branch 'lucidrains:main' into main
MarcusLoppe Jun 20, 2024
d311798
Merge branch 'lucidrains:main' into main
MarcusLoppe Jun 20, 2024
3e7ed57
Update setup.py
MarcusLoppe Jun 21, 2024
9e1bb0d
bug fix
MarcusLoppe Jun 22, 2024
ea15adb
update
MarcusLoppe Jun 26, 2024
b99b081
Merge branch 'lucidrains:main' into main
MarcusLoppe Jul 2, 2024
c3f7c9f
Merge remote-tracking branch 'upstream/main'
MarcusLoppe Jul 2, 2024
a8880ad
Update setup.py
MarcusLoppe Jul 4, 2024
8985669
Merge branch 'lucidrains:main' into main
MarcusLoppe Jul 9, 2024
4 changes: 3 additions & 1 deletion meshgpt_pytorch/data.py
@@ -6,6 +6,7 @@
from torch.nn.utils.rnn import pad_sequence

from einops import rearrange, reduce
from torch import nn, Tensor

from beartype import beartype
from beartype.typing import Tuple, Union, Optional, Callable, Dict
@@ -114,6 +115,7 @@ def custom_collate(data, pad_id = -1):
datum = pad_sequence(datum, batch_first = True, padding_value = pad_id)
else:
datum = list(datum)
output.append(datum)

output.append(datum)

@@ -122,4 +124,4 @@ def custom_collate(data, pad_id = -1):
if is_dict:
output = dict(zip(keys, output))

return output
return output
1 change: 1 addition & 0 deletions meshgpt_pytorch/meshgpt_pytorch.py
@@ -793,6 +793,7 @@ def forward(
vertices: TensorType['b', 'nv', 3, float],
faces: TensorType['b', 'nf', 3, int],
face_edges: Optional[TensorType['b', 'e', 2, int]] = None,
texts: Optional[List[str]] = None,
Owner:
ah, so the text is actually only conditioned at the transformer stage, through cross attention

basically the autoencoder is given the job of only compressing meshes

Contributor Author:
Yes, I know :) But if you pass it a dict with texts, it will give an error since the arg doesn't exist.
So then you would need two dataset classes.

Either replace the model(**forward_args) so it uses the parameters directly:
model(vertices = data["vertices"], faces = data["faces"])

Or just implement a dummy texts arg :) There is probably a better solution.

Owner:
oh got it! yea, i can take care of that within the trainer class (just scrub out the text and text_embed keys)

Contributor Author:
I don't think that will work. I'm not 100% sure, since the dataloader might copy the data when it passes it along.

But it won't work if you access it without copying: the dataset returns the data without copying/cloning it, so calling del on a key removes it from the dataset itself.
So if you train the encoder and then want to train a transformer, you'll need to recreate the dataset, since the texts key has been removed.
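A minimal sketch of the non-mutating variant of that idea (not code from this PR; the 'texts' / 'text_embeds' key names are assumptions): build the forward kwargs from a filtered copy of the batch dict, so the dict held by the dataset is left untouched.

# hypothetical: drop text-related keys from a per-batch copy instead of del-ing them on the dataset
if isinstance(data, dict):
    forward_kwargs = {k: v for k, v in data.items() if k not in ('texts', 'text_embeds')}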

Comment:
I'd prefer if the dataset returns the text with each set of vertices and faces.

return_codes = False,
return_loss_breakdown = False,
return_recon_faces = False,
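Following up on that preference, a minimal hypothetical sketch (not part of this diff; class and field names are placeholders) of a dataset whose items carry the text alongside the vertices and faces, so custom_collate can batch all three:

from torch.utils.data import Dataset

class TextMeshDataset(Dataset):
    def __init__(self, meshes, texts):
        # meshes: list of (vertices, faces) tensor pairs; texts: list of strings of the same length
        self.meshes = meshes
        self.texts = texts

    def __len__(self):
        return len(self.meshes)

    def __getitem__(self, idx):
        vertices, faces = self.meshes[idx]
        # returning a dict lets the trainer forward the item as keyword arguments
        return dict(vertices = vertices, faces = faces, texts = self.texts[idx])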
136 changes: 133 additions & 3 deletions meshgpt_pytorch/trainer.py
@@ -23,7 +23,8 @@
from meshgpt_pytorch.data import custom_collate

from meshgpt_pytorch.version import __version__

import matplotlib.pyplot as plt
from tqdm import tqdm
from meshgpt_pytorch.meshgpt_pytorch import (
MeshAutoencoder,
MeshTransformer
@@ -126,6 +127,7 @@ def __init__(
accelerator_kwargs: dict = dict(),
optimizer_kwargs: dict = dict(),
checkpoint_every = 1000,
checkpoint_every_epoch: Optional[int] = None,
checkpoint_folder = './checkpoints',
data_kwargs: Tuple[str, ...] = ['vertices', 'faces', 'face_edges'],
warmup_steps = 1000,
@@ -204,6 +206,7 @@ def __init__(
self.num_train_steps = num_train_steps
self.register_buffer('step', torch.tensor(0))

self.checkpoint_every_epoch = checkpoint_every_epoch
self.checkpoint_every = checkpoint_every
self.checkpoint_folder = Path(checkpoint_folder)
self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
@@ -388,7 +391,70 @@ def forward(self):
self.wait()

self.print('training complete')
    def train(self, num_epochs, stop_at_loss = None, display_graph = False):
        epoch_losses = [] # Initialize a list to store epoch losses
        self.model.train()
        for epoch in range(num_epochs):
            total_loss = 0.0
            num_batches = 0

            progress_bar = tqdm(self.dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}')

            for data in progress_bar:

                if isinstance(data, tuple):
                    forward_kwargs = dict(zip(self.data_kwargs, data))

                elif isinstance(data, dict):
                    forward_kwargs = data

                with self.accelerator.autocast():
                    loss = self.model(vertices = forward_kwargs['vertices'], faces = forward_kwargs['faces'])
                    self.accelerator.backward(loss)

                if exists(self.max_grad_norm):
                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                self.optimizer.step()
                self.optimizer.zero_grad()

                if not self.accelerator.optimizer_step_was_skipped:
                    with self.warmup.dampening():
                        self.scheduler.step()

                current_loss = loss.item()
                total_loss += current_loss
                num_batches += 1
                progress_bar.set_postfix(loss=current_loss)

            avg_epoch_loss = total_loss / num_batches
            epoch_losses.append(avg_epoch_loss)
            self.print(f'Epoch {epoch + 1} average loss: {avg_epoch_loss}')
            self.wait()

            if self.checkpoint_every_epoch is not None and epoch != 0 and epoch % self.checkpoint_every_epoch == 0:
                self.save(self.checkpoint_folder / f'mesh-autoencoder.ckpt.epoch_{epoch}_avg_loss_{avg_epoch_loss:.3f}.pt')

            if stop_at_loss is not None and avg_epoch_loss < stop_at_loss:
                self.print(f'Stopping training at epoch {epoch} with average loss {avg_epoch_loss}')
                if self.checkpoint_every_epoch is not None:
                    self.save(self.checkpoint_folder / f'mesh-autoencoder.ckpt.stop_at_loss_avg_loss_{avg_epoch_loss:.3f}.pt')
                break

        self.print('Training complete')
        if display_graph:
Owner:
so i haven't documented this, but you can already use wandb.ai experiment tracker

you just have to do

trainer = Trainer(..., use_wandb_tracking = True)

with trainer.trackers('meshgpt', 'one-experiment-name'):
    trainer.train()

            plt.figure(figsize=(10, 5))
            plt.plot(range(1, num_epochs + 1), epoch_losses, marker='o')
            plt.title('Training Loss Over Epochs')
            plt.xlabel('Epoch')
            plt.ylabel('Average Loss')
            plt.grid(True)
            plt.show()
        return epoch_losses[-1]
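For orientation, a minimal usage sketch of the epoch-based API added above (not part of the diff; the constructor arguments shown are placeholders, check the trainer's __init__ for the exact required ones):

autoencoder_trainer = MeshAutoencoderTrainer(
    autoencoder,                  # a MeshAutoencoder instance
    dataset = dataset,            # items exposing 'vertices' and 'faces'
    batch_size = 16,
    num_train_steps = 100,        # still accepted by __init__; train() below iterates by epoch instead
    checkpoint_every_epoch = 5    # save a checkpoint under ./checkpoints every 5 epochs
)

# run up to 100 epochs, stop early once the average epoch loss drops below 0.2,
# and plot the per-epoch loss curve at the end
final_avg_loss = autoencoder_trainer.train(num_epochs = 100, stop_at_loss = 0.2, display_graph = True)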
# mesh transformer trainer

class MeshTransformerTrainer(Module):
@@ -408,7 +474,9 @@ def __init__(
ema_kwargs: dict = dict(),
accelerator_kwargs: dict = dict(),
optimizer_kwargs: dict = dict(),
checkpoint_every = 1000,

checkpoint_every = 1000,
checkpoint_every_epoch: Optional[int] = None,
checkpoint_folder = './checkpoints',
data_kwargs: Tuple[str, ...] = ['vertices', 'faces', 'face_edges'],
warmup_steps = 1000,
@@ -472,6 +540,7 @@ def __init__(
self.num_train_steps = num_train_steps
self.register_buffer('step', torch.tensor(0))

self.checkpoint_every_epoch = checkpoint_every_epoch
self.checkpoint_every = checkpoint_every
self.checkpoint_folder = Path(checkpoint_folder)
self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
@@ -597,4 +666,65 @@ def forward(self):

self.wait()

self.print('training complete')
self.print('training complete')

    def train(self, num_epochs, stop_at_loss = None, display_graph = False):
        epoch_losses = [] # Initialize a list to store epoch losses
        self.model.train()
        for epoch in range(num_epochs):
            total_loss = 0.0
            num_batches = 0

            progress_bar = tqdm(self.dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}')

            for data in progress_bar:

                if isinstance(data, tuple):
                    forward_kwargs = dict(zip(self.data_kwargs, data))

                elif isinstance(data, dict):
                    forward_kwargs = data

                with self.accelerator.autocast():
                    loss = self.model(**forward_kwargs)
                    self.accelerator.backward(loss / self.grad_accum_every)

                if exists(self.max_grad_norm):
                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                self.optimizer.step()
                self.optimizer.zero_grad()

                if not self.accelerator.optimizer_step_was_skipped:
                    with self.warmup.dampening():
                        self.scheduler.step()

                current_loss = loss.item()
                total_loss += current_loss
                num_batches += 1
                progress_bar.set_postfix(loss=current_loss)

            avg_epoch_loss = total_loss / num_batches
            epoch_losses.append(avg_epoch_loss)
            self.print(f'Epoch {epoch + 1} average loss: {avg_epoch_loss}')
            self.wait()

            if self.checkpoint_every_epoch is not None and epoch != 0 and epoch % self.checkpoint_every_epoch == 0:
                self.save(self.checkpoint_folder / f'mesh-transformer.ckpt.epoch_{epoch}_avg_loss_{avg_epoch_loss:.3f}.pt')

            if stop_at_loss is not None and avg_epoch_loss < stop_at_loss:
                self.print(f'Stopping training at epoch {epoch} with average loss {avg_epoch_loss}')
                if self.checkpoint_every_epoch is not None:
                    self.save(self.checkpoint_folder / f'mesh-transformer.ckpt.stop_at_loss_avg_loss_{avg_epoch_loss:.3f}.pt')
                break

        self.print('Training complete')
        if display_graph:
            plt.figure(figsize=(10, 5))
            plt.plot(range(1, num_epochs + 1), epoch_losses, marker='o')
            plt.title('Training Loss Over Epochs')
            plt.xlabel('Epoch')
            plt.ylabel('Average Loss')
            plt.grid(True)
            plt.show()
        return epoch_losses[-1]
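The corresponding hypothetical usage for the transformer trainer (again a sketch, not part of the diff, with placeholder constructor arguments):

transformer_trainer = MeshTransformerTrainer(
    transformer,                   # a MeshTransformer wrapping the trained autoencoder
    dataset = dataset,             # items exposing 'vertices', 'faces' and optionally 'texts'
    batch_size = 16,
    num_train_steps = 100,
    checkpoint_every_epoch = 10
)

final_avg_loss = transformer_trainer.train(num_epochs = 300, stop_at_loss = 0.005)

# with checkpoint_every_epoch set, checkpoints are written to ./checkpoints with names like
# mesh-transformer.ckpt.epoch_10_avg_loss_1.234.pt (see the f-strings in train() above)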