Added train by epoch for Trainer and added support for texts #12
Open
MarcusLoppe wants to merge 150 commits into lucidrains:main from MarcusLoppe:main
Commits (150)
4192b5f
added progress bar for trainer + accept texts args for encoder
MarcusLoppe 0ae54e9
added epoch checkpoint saver for trainer
MarcusLoppe a09b2e8
setup.py forgot comma
MarcusLoppe 83baaee
custom_collate can now accecpt texts
MarcusLoppe 195d14e
bug fix - save every_epoch
MarcusLoppe 1d4b705
bug fix - save every_epoch
MarcusLoppe b5d9b1e
bug fix - save every_epoch extra info
MarcusLoppe d20e3d8
bug fix - every_epoch
MarcusLoppe 4085b70
final bug fix - every_epoch
MarcusLoppe 3715e3a
Merge branch 'lucidrains:main' into main
MarcusLoppe 99cfa22
Merge branch 'lucidrains:main' into main
MarcusLoppe 57900a6
Merge branch 'lucidrains:main' into main
MarcusLoppe ee44068
Merge branch 'lucidrains:main' into main
MarcusLoppe fa515a6
removed grad_accum_every
MarcusLoppe fc49064
removed grad_accum_every
MarcusLoppe 5a68772
fix error
MarcusLoppe cdb8b52
Revert "fix error"
MarcusLoppe 3a46455
Revert "removed grad_accum_every"
MarcusLoppe 0717547
Merge branch 'lucidrains:main' into main
MarcusLoppe 7833b5b
Revert "removed grad_accum_every"
MarcusLoppe c8a82dc
Merge branch 'lucidrains:main' into main
MarcusLoppe 60ea5ec
Merge branch 'main' into main
MarcusLoppe 0b2b0ee
Merge branch 'lucidrains:main' into main
MarcusLoppe ea40170
return avg loss
MarcusLoppe 82c39b0
Merge branch 'lucidrains:main' into main
MarcusLoppe 999ffe7
Merge branch 'lucidrains:main' into main
MarcusLoppe b9204f6
forced model input
MarcusLoppe 628d756
custom_collate fixed
MarcusLoppe 7c7c61e
Merge branch 'lucidrains:main' into main
MarcusLoppe ca81558
Merge branch 'lucidrains:main' into main
MarcusLoppe abdb4c2
Merge branch 'lucidrains:main' into main
MarcusLoppe e57e6f1
added stop_at_loss
MarcusLoppe 6dd684c
merge fix
MarcusLoppe 824ba04
Merge branch 'lucidrains:main' into main
MarcusLoppe 52ed27e
Merge branch 'lucidrains:main' into main
MarcusLoppe f012e95
added notebook
MarcusLoppe ccaef95
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe 364b2a9
Merge branch 'lucidrains:main' into main
MarcusLoppe bd1b904
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe 49dcaef
added notebook & training estimation time
MarcusLoppe 0bd6c0c
notebook fixs
MarcusLoppe 31c8e78
notebook fixs 2
MarcusLoppe 68cd533
load fix
MarcusLoppe 22b2b91
trainer load fix
MarcusLoppe aab84af
trainer warmup -ix
MarcusLoppe bdc2116
Merge branch 'lucidrains:main' into main
MarcusLoppe 17a65e8
trainer revert fix
MarcusLoppe 8936d65
Merge branch 'lucidrains:main' into main
MarcusLoppe b87e051
use **forward args
MarcusLoppe 9078a53
skip data kwargs
MarcusLoppe 0163070
readded data kwargs
MarcusLoppe 673379b
implemented grad_accum_every
MarcusLoppe eb69248
implemented grad_accum_every
MarcusLoppe 5c6f832
modifyed notebook
MarcusLoppe 9766d41
added mesh_dataset
MarcusLoppe bf08333
fixed tdqm bug
MarcusLoppe 66c3e1f
Merge branch 'lucidrains:main' into main
MarcusLoppe 60face8
flash attention mispelling
MarcusLoppe f936da8
flash attention mispelling
MarcusLoppe a42052d
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe a0f9855
flash attention mispelling
MarcusLoppe d59a500
Merge branch 'lucidrains:main' into main
MarcusLoppe 70ac14e
reimplemented grad_accum in transformer trainer
MarcusLoppe a5d6f53
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe 7c86e36
cleaned up the trainer
MarcusLoppe ae057bb
cleaned up the trainer
MarcusLoppe 04ba13f
Merge branch 'lucidrains:main' into main
MarcusLoppe 45e628e
demo + update of meshdataset
MarcusLoppe f72b662
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe ab82a21
Merge branch 'lucidrains:main' into main
MarcusLoppe 8364938
demo, missing csv import
MarcusLoppe 4405d58
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe bf9d0a7
demo comments
MarcusLoppe 36a17fc
Merge branch 'lucidrains:main' into main
MarcusLoppe 7b467b4
Merge branch 'main' into main
MarcusLoppe 4927fff
Merge branch 'lucidrains:main' into main
MarcusLoppe 39a3b50
Merge branch 'lucidrains:main' into main
MarcusLoppe 66b98a7
Merge branch 'lucidrains:main' into main
MarcusLoppe b9bcac4
Merge branch 'lucidrains:main' into main
MarcusLoppe 51b81d4
Merge branch 'lucidrains:main' into main
MarcusLoppe afae738
Merge branch 'lucidrains:main' into main
MarcusLoppe 25d51be
Merge branch 'lucidrains:main' into main
MarcusLoppe c5e2c14
shuffle as arg
MarcusLoppe 5b42b8d
shuffle as arg
MarcusLoppe 6e65de4
shuffle as arg
MarcusLoppe f8c087a
sageconv fix
MarcusLoppe 60fb312
Optional Squeeze
MarcusLoppe d35228e
layernorm
MarcusLoppe 9f987b6
Merge remote-tracking branch 'upstream/main'
MarcusLoppe 1a8b0b1
block-forward fix
MarcusLoppe 738a71c
demo update
MarcusLoppe 0201134
labels
MarcusLoppe a3a2588
updated demo render functions
MarcusLoppe 4f9ddd5
training stuff
MarcusLoppe e5db7d9
checkpoint name
MarcusLoppe e9c6bdf
checkpoint name
MarcusLoppe 6cfc6f9
pad
MarcusLoppe 89010a9
pad
MarcusLoppe 6c197af
pad
MarcusLoppe 8231e7b
pad
MarcusLoppe 2940a46
Update README.md
MarcusLoppe 201ce43
Update README.md
MarcusLoppe e1f7ffb
Update README.md
MarcusLoppe cdb0d7f
tuple error
MarcusLoppe 0daebd3
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe e5afd9d
padding issue
MarcusLoppe fa71927
multi-gpu checkpoint fix
MarcusLoppe 780431b
gateloop_use_heinsen fix
MarcusLoppe b4fcd8d
updated notebook with better instructions
MarcusLoppe 989faa9
added fp16 training
MarcusLoppe 8288354
Merge branch 'main' of https://github.com/lucidrains/meshgpt-pytorch
MarcusLoppe 521f6f7
Merge branch 'main' of https://github.com/lucidrains/meshgpt-pytorch
MarcusLoppe 19f616d
remove unused testing things
MarcusLoppe 6e3909b
small comment
lucidrains 391704b
Merge branch 'main' of https://github.com/lucidrains/meshgpt-pytorch
MarcusLoppe 0e4a38e
Merge branch 'main' of https://github.com/lucidrains/meshgpt-pytorch
MarcusLoppe 871f418
Merge branch 'lucidrains:main' into main
MarcusLoppe dd513a6
Merge branch 'lucidrains:main' into main
MarcusLoppe 30c601e
Merge remote-tracking branch 'upstream/main'
MarcusLoppe 674fa18
Merge branch 'lucidrains:main' into main
MarcusLoppe efb420f
Merge branch 'lucidrains:main' into main
MarcusLoppe 14ad997
Merge branch 'lucidrains:main' into main
MarcusLoppe bdfeeea
Merge branch 'lucidrains:main' into main
MarcusLoppe c00e7c6
Merge branch 'lucidrains:main' into main
MarcusLoppe 0d5cb51
huggingface implementation & updated rendering
MarcusLoppe b5e7467
init file
MarcusLoppe 01ac2bb
Update mesh_dataset.py from entrys to entries
fire ecf72c7
Merge pull request #3 from fire/patch-1
MarcusLoppe f9a31d9
Merge remote-tracking branch 'upstream/main'
MarcusLoppe 11bff2b
notebook updates
MarcusLoppe 8e15172
Merge remote-tracking branch 'upstream/main'
MarcusLoppe b7a3611
bug fix
MarcusLoppe 481428d
bug fix
MarcusLoppe bdfcade
bug fix
MarcusLoppe d166f2f
dataset improvements
MarcusLoppe 2e45e7b
Merge branch 'lucidrains:main' into main
MarcusLoppe 6cad755
dataset embed_texts - output message
MarcusLoppe 1608da5
Merge branch 'main' of https://github.com/MarcusLoppe/meshgpt-pytorch
MarcusLoppe c7a1cbd
Merge branch 'lucidrains:main' into main
MarcusLoppe 5f6d670
Update setup.py
MarcusLoppe 14126a2
Merge branch 'lucidrains:main' into main
MarcusLoppe 14a4470
Merge branch 'lucidrains:main' into main
MarcusLoppe d311798
Merge branch 'lucidrains:main' into main
MarcusLoppe 3e7ed57
Update setup.py
MarcusLoppe 9e1bb0d
bug fix
MarcusLoppe ea15adb
update
MarcusLoppe b99b081
Merge branch 'lucidrains:main' into main
MarcusLoppe c3f7c9f
Merge remote-tracking branch 'upstream/main'
MarcusLoppe a8880ad
Update setup.py
MarcusLoppe 8985669
Merge branch 'lucidrains:main' into main
MarcusLoppe
@@ -23,7 +23,8 @@

from meshgpt_pytorch.data import custom_collate

from meshgpt_pytorch.version import __version__

import matplotlib.pyplot as plt
from tqdm import tqdm
from meshgpt_pytorch.meshgpt_pytorch import (
    MeshAutoencoder,
    MeshTransformer
@@ -126,6 +127,7 @@ def __init__(
        accelerator_kwargs: dict = dict(),
        optimizer_kwargs: dict = dict(),
        checkpoint_every = 1000,
        checkpoint_every_epoch: Optional[int] = None,
        checkpoint_folder = './checkpoints',
        data_kwargs: Tuple[str, ...] = ['vertices', 'faces', 'face_edges'],
        warmup_steps = 1000,
@@ -204,6 +206,7 @@ def __init__(
        self.num_train_steps = num_train_steps
        self.register_buffer('step', torch.tensor(0))

        self.checkpoint_every_epoch = checkpoint_every_epoch
        self.checkpoint_every = checkpoint_every
        self.checkpoint_folder = Path(checkpoint_folder)
        self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
@@ -388,7 +391,70 @@ def forward(self):
        self.wait()

        self.print('training complete')

    def train(self, num_epochs, stop_at_loss = None, diplay_graph = False):
        epoch_losses = [] # Initialize a list to store epoch losses
        self.model.train()

        for epoch in range(num_epochs):
            total_loss = 0.0
            num_batches = 0

            progress_bar = tqdm(self.dataloader, desc = f'Epoch {epoch + 1}/{num_epochs}')

            for data in progress_bar:

                if isinstance(data, tuple):
                    forward_kwargs = dict(zip(self.data_kwargs, data))

                elif isinstance(data, dict):
                    forward_kwargs = data

                with self.accelerator.autocast():
                    loss = self.model(vertices = forward_kwargs['vertices'], faces = forward_kwargs['faces'])

                self.accelerator.backward(loss)

                if exists(self.max_grad_norm):
                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                self.optimizer.step()
                self.optimizer.zero_grad()

                if not self.accelerator.optimizer_step_was_skipped:
                    with self.warmup.dampening():
                        self.scheduler.step()

                current_loss = loss.item()
                total_loss += current_loss
                num_batches += 1
                progress_bar.set_postfix(loss = current_loss)

            avg_epoch_loss = total_loss / num_batches
            epoch_losses.append(avg_epoch_loss)
            self.print(f'Epoch {epoch + 1} average loss: {avg_epoch_loss}')
            self.wait()

            if self.checkpoint_every_epoch is not None and epoch != 0 and epoch % self.checkpoint_every_epoch == 0:
                self.save(self.checkpoint_folder / f'mesh-autoencoder.ckpt.epoch_{epoch}_avg_loss_{avg_epoch_loss:.3f}.pt')

            if stop_at_loss is not None and avg_epoch_loss < stop_at_loss:
                self.print(f'Stopping training at epoch {epoch} with average loss {avg_epoch_loss}')
                if self.checkpoint_every_epoch is not None:
                    self.save(self.checkpoint_folder / f'mesh-autoencoder.ckpt.stop_at_loss_avg_loss_{avg_epoch_loss:.3f}.pt')
                break

        self.print('Training complete')

        if diplay_graph:
            plt.figure(figsize=(10, 5))
            plt.plot(range(1, num_epochs + 1), epoch_losses, marker='o')
            plt.title('Training Loss Over Epochs')
            plt.xlabel('Epoch')
            plt.ylabel('Average Loss')
            plt.grid(True)
            plt.show()

        return epoch_losses[-1]

Review comment (on the graph-plotting block): so i haven't documented this, but you can already use the wandb.ai experiment tracker, you just have to do:

trainer = Trainer(..., use_wandb_tracking = True)

with trainer.trackers('meshgpt', 'one-experiment-name'):
    trainer.train()

# mesh transformer trainer

class MeshTransformerTrainer(Module):
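For context, a rough sketch of how this new epoch-based entry point would be driven. The autoencoder and trainer constructor arguments other than checkpoint_every_epoch are assumed from the repo's README and may not match this branch exactly, and the toy dataset is only a stand-in:

import torch
from torch.utils.data import Dataset

from meshgpt_pytorch import MeshAutoencoder, MeshAutoencoderTrainer

class ToyMeshDataset(Dataset):
    # stand-in dataset: one tetrahedron repeated; a real dataset would load meshes
    def __init__(self, length = 64):
        self.length = length
        self.vertices = torch.rand(4, 3) * 2 - 1                                  # coords in (-1, 1)
        self.faces = torch.tensor([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return dict(vertices = self.vertices, faces = self.faces)

autoencoder = MeshAutoencoder(num_discrete_coors = 128)   # assumed README-style construction

trainer = MeshAutoencoderTrainer(
    autoencoder,
    dataset = ToyMeshDataset(),
    batch_size = 8,                   # assumed; not part of this diff
    num_train_steps = 100,            # still used by the step-based forward()
    checkpoint_every_epoch = 5,       # new argument from this PR
)

# new in this PR: train by epoch, stop early once the average epoch loss
# drops below stop_at_loss, and plot the loss curve at the end
# (the parameter is spelled 'diplay_graph' in the diff above)
final_avg_loss = trainer.train(num_epochs = 20, stop_at_loss = 0.25, diplay_graph = True)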
@@ -408,7 +474,9 @@ def __init__(
        ema_kwargs: dict = dict(),
        accelerator_kwargs: dict = dict(),
        optimizer_kwargs: dict = dict(),
        checkpoint_every = 1000,
        checkpoint_every_epoch: Optional[int] = None,
        checkpoint_folder = './checkpoints',
        data_kwargs: Tuple[str, ...] = ['vertices', 'faces', 'face_edges'],
        warmup_steps = 1000,
@@ -472,6 +540,7 @@ def __init__(
        self.num_train_steps = num_train_steps
        self.register_buffer('step', torch.tensor(0))

        self.checkpoint_every_epoch = checkpoint_every_epoch
        self.checkpoint_every = checkpoint_every
        self.checkpoint_folder = Path(checkpoint_folder)
        self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
@@ -597,4 +666,65 @@ def forward(self):

        self.wait()

        self.print('training complete')

    def train(self, num_epochs, stop_at_loss = None, diplay_graph = False):
        epoch_losses = [] # Initialize a list to store epoch losses
        self.model.train()

        for epoch in range(num_epochs):
            total_loss = 0.0
            num_batches = 0

            progress_bar = tqdm(self.dataloader, desc = f'Epoch {epoch + 1}/{num_epochs}')

            for data in progress_bar:

                if isinstance(data, tuple):
                    forward_kwargs = dict(zip(self.data_kwargs, data))

                elif isinstance(data, dict):
                    forward_kwargs = data

                with self.accelerator.autocast():
                    loss = self.model(**forward_kwargs)

                self.accelerator.backward(loss / self.grad_accum_every)

                if exists(self.max_grad_norm):
                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                self.optimizer.step()
                self.optimizer.zero_grad()

                if not self.accelerator.optimizer_step_was_skipped:
                    with self.warmup.dampening():
                        self.scheduler.step()

                current_loss = loss.item()
                total_loss += current_loss
                num_batches += 1
                progress_bar.set_postfix(loss = current_loss)

            avg_epoch_loss = total_loss / num_batches
            epoch_losses.append(avg_epoch_loss)
            self.print(f'Epoch {epoch + 1} average loss: {avg_epoch_loss}')
            self.wait()

            if self.checkpoint_every_epoch is not None and epoch != 0 and epoch % self.checkpoint_every_epoch == 0:
                self.save(self.checkpoint_folder / f'mesh-transformer.ckpt.epoch_{epoch}_avg_loss_{avg_epoch_loss:.3f}.pt')

            if stop_at_loss is not None and avg_epoch_loss < stop_at_loss:
                self.print(f'Stopping training at epoch {epoch} with average loss {avg_epoch_loss}')
                if self.checkpoint_every_epoch is not None:
                    self.save(self.checkpoint_folder / f'mesh-transformer.ckpt.stop_at_loss_avg_loss_{avg_epoch_loss:.3f}.pt')
                break

        self.print('Training complete')

        if diplay_graph:
            plt.figure(figsize=(10, 5))
            plt.plot(range(1, num_epochs + 1), epoch_losses, marker='o')
            plt.title('Training Loss Over Epochs')
            plt.xlabel('Epoch')
            plt.ylabel('Average Loss')
            plt.grid(True)
            plt.show()

        return epoch_losses[-1]
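Because this loop forwards the whole batch via self.model(**forward_kwargs), a dataset that also yields texts (the other half of this PR) can condition the transformer without any further trainer changes, which is what the conversation below is about. A small sketch of that data flow, with illustrative names and a toy batch, not code from the PR:

import torch

data_kwargs = ('vertices', 'faces', 'face_edges')

def to_forward_kwargs(data):
    if isinstance(data, tuple):
        # tuple batches are matched positionally against data_kwargs
        return dict(zip(data_kwargs, data))
    if isinstance(data, dict):
        # dict batches (including extra keys such as 'texts') pass through unchanged
        return data

batch = dict(
    vertices = torch.rand(2, 4, 3) * 2 - 1,               # (batch, num vertices, 3)
    faces = torch.tensor([[[0, 1, 2], [1, 2, 3]]] * 2),   # (batch, num faces, 3)
    texts = ['a chair', 'a table'],                        # consumed by the transformer via cross attention
)

forward_kwargs = to_forward_kwargs(batch)
assert 'texts' in forward_kwargs   # would reach self.model(**forward_kwargs) in the loop above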
Conversation

Comment: Ah, so the text is actually only conditioned at the transformer stage, through cross attention. Basically the autoencoder is given the job of only compressing meshes.
Comment: Yes, I know :) But if you pass it a dict with texts it will give an error, since the arg doesn't exist. So then you would need two dataset classes. Either replace the model(**forward_kwargs) call so it uses the parameters directly:

model(vertices = data["vertices"], faces = data["faces"])

Or just implement a dummy texts :) There is probably a better solution.
Comment: Oh got it! Yea, I can take care of that within the trainer class (just scrub out the text and text_embed keys).
Comment: I don't think that will work; I'm not 100% sure, since the dataloader passes the data and maybe copies it(?). But it won't work if you access it without copying, because the dataset returns the data without copying/cloning: when you del a key, it is removed completely from the dataset. So if you train the encoder and then want to train a transformer, you'll need to recreate the dataset, since the texts key has been removed.
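One way around that is a copy-before-filter step rather than del on the dataset's dict; a sketch, where the helper and key names are illustrative and not code from this PR:

# Hypothetical helper: build a filtered kwargs dict per batch instead of
# deleting keys from the dict the dataset handed back, so the text keys
# survive for a later transformer training run.
TEXT_KEYS = ('texts', 'text_embeds')

def scrub_text_keys(data: dict) -> dict:
    # dict comprehension makes a new (shallow) dict; the dataset's dict is untouched
    return {key: value for key, value in data.items() if key not in TEXT_KEYS}

item = {'vertices': ..., 'faces': ..., 'texts': 'a chair'}   # stand-in for a dataset entry
autoencoder_kwargs = scrub_text_keys(item)

assert 'texts' in item                     # dataset entry still carries its text
assert 'texts' not in autoencoder_kwargs   # the autoencoder never sees the text kwargs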
Comment: I'd prefer it if the dataset returned the text with each set of vertices and faces.
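For what that preference could look like, a sketch of a dataset whose items carry the text alongside the mesh; this is purely illustrative and not the MeshDataset from this PR:

import torch
from torch.utils.data import Dataset

class TextMeshDataset(Dataset):
    # each item keeps its text label next to the mesh, so the same dataset can
    # feed the autoencoder (which ignores 'texts') and the text-conditioned transformer
    def __init__(self, entries):
        # entries: list of (vertices, faces, text) triples
        self.entries = entries

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        vertices, faces, text = self.entries[idx]
        return dict(vertices = vertices, faces = faces, texts = text)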