plateau during training #8
I've been toying with another dataset (Omniglot) with this method, and it does seem very sensitive to hyperparameters. I'm not sure if what you're describing is the same as what I'm getting, but for me most experiments just end up with the reconstruction loss plateauing around ~0.16 (n.b.: using Gaussian/L2 error here, not the logistic mixture). It seems like one has to carefully balance the weightings of the reconstruction loss and the KL loss to get things to work. I'm actually surprised that the unweighted ELBO seemingly just works for the experiments in this paper.
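For concreteness, here is a minimal sketch of the kind of re-weighted objective I mean, assuming a plain Gaussian/L2 reconstruction term and a single diagonal-Gaussian posterior (not this repo's hierarchical loss; `beta` is just the KL weight being discussed):

```python
import torch
import torch.nn.functional as F

def weighted_elbo_loss(x, x_hat, mu, logvar, beta=1.0):
    """Negative ELBO with a KL weight; beta = 1.0 recovers the standard unweighted ELBO.
    The reconstruction term is Gaussian/L2 (sum of squared errors, up to constants)."""
    # Per-example reconstruction error, summed over all non-batch dimensions
    recon = F.mse_loss(x_hat, x, reduction="none").flatten(1).sum(dim=1)
    # KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions
    kl = 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar).flatten(1).sum(dim=1)
    return (recon + beta * kl).mean()
```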
I guess that, because VDVAE has almost no normalization, some hyperparameters (like batch_size) matter more than we expect. Out of curiosity: I guess the model you are training on Omniglot is much smaller than the one for CIFAR10.
Yeah, I think we're talking about the same thing now: posterior collapse. Yes, it seems like you have to upweight lambda or downweight beta, and it's clear the posterior collapse actually happens most of the time, because if you look at the

Nope, I haven't tried a Bernoulli likelihood, though it would make sense. IIRC it has roughly the same number of examples, but there are way more classes.

If it helps any, when I first experimented with this code back when the paper was in review, it seemed like a lot of that instability was actually coming from the DMoL layer (see my post here: https://openreview.net/forum?id=RLRXCV6DbEJ&noteId=cra1CWLY3U_). It is also still not clear to me what the difference is between this and a Gaussian distribution (i.e. L2 error) when it comes to the metrics we care about (such as the ELBO).
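To make the "look at the KL" diagnosis concrete, a rough sketch of the per-layer check I have in mind; it assumes the forward pass returns a list of per-layer stats dicts with a per-example `'kl'` tensor (an illustrative structure, not necessarily this repo's exact interface):

```python
def kl_per_layer(stats):
    """Mean KL (nats per example) contributed by each stochastic layer.
    Assumes each entry in `stats` holds a 'kl' tensor of shape (batch, ...).
    Layers whose KL stays near zero throughout training are effectively
    ignored by the decoder, i.e. they have collapsed to the prior."""
    return [s["kl"].flatten(1).sum(dim=1).mean().item() for s in stats]

# Logging this every few hundred steps makes collapse easy to spot: most
# entries sit at ~0 while the reconstruction term plateaus.
```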
Hi @georgosgeorgos, the model is certainly sensitive to batch size/LR, and typically (as with any ML model) these should be chosen via a hyperparameter search over several options. Since you're decreasing the batch size by a factor of 4, it's possible that a lower LR than 1/2 the original is required (as you seem to have found). In my experiments I did not find that adjusting the KL weight was actually useful -- the default, unweighted loss typically led to higher performance -- so I'm surprised that you're finding it necessary for good performance. My assumption (without seeing your specific experiments) would be that the unweighted loss with better optimizer hyperparameters would do better.
Another note, which may be good for folks to know (even if you used the code as-is, without changing anything): I found that training broke entirely if I replaced the output distribution, swapping DMoL (the default) for a Gaussian (i.e. mean squared error). In this case the reconstruction error simply plateaued, and I had to disable gradient clipping entirely, as well as introduce instance norm to the
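As a side note on what "Gaussian (i.e. mean squared error)" means for the numbers: a sketch of the factorized-Gaussian NLL in nats per dimension, assuming a single scalar log standard deviation (fixed or learned). With `log_sigma = 0` it is just 0.5 * squared error plus a constant, so it only differs from a pure L2 loss by a scale and an offset:

```python
import math
import torch

def gaussian_nll_nats_per_dim(x, x_hat, log_sigma=0.0):
    """NLL under N(x_hat, sigma^2 * I), averaged over all pixels/channels (nats/dim)."""
    var = math.exp(2.0 * log_sigma)
    nll = 0.5 * ((x - x_hat) ** 2 / var) + 0.5 * math.log(2.0 * math.pi) + log_sigma
    return nll.mean()
```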
Yes, I think that's reasonable. Another thing to note is that the gradient skipping threshold depends on LR/batch size -- @georgosgeorgos it could be that your config produces much higher gradient norms, and thus updates might be getting skipped (which would result in training stopping). You should be able to see the number of skipped updates in your logs. Ideally you should set the threshold to a value that results in very few skips.
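For anyone reading along, a minimal sketch of the skip logic being described (the actual config key and threshold value in this repo may differ; treat the names here as illustrative):

```python
import math
import torch

def step_with_grad_skipping(optimizer, model, skip_threshold):
    """Apply the optimizer step only when the global gradient norm is finite and
    below skip_threshold; otherwise drop the update. Returns the norm and whether
    the step was skipped, so the skip rate can be logged and the threshold tuned."""
    grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
    grad_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g) for g in grads])
    ).item()
    skipped = not math.isfinite(grad_norm) or grad_norm > skip_threshold
    if not skipped:
        optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return grad_norm, skipped
```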
Hi Rewon. Cool work!
I tested the model with your checkpoint and everything works perfectly.
Now I am training VDVAE on CIFAR10 from scratch using one GPU (reducing the batch size 32 --> 16 and the LR 2e-4 --> 1e-4).
The model starts training without problems and then gets stuck in a plateau around ~4.7 nats/dim for a long time.
I found a similar plateau with other configurations (smaller lr, smaller model).
Did you experience this plateau during training?
Thanks!
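For anyone comparing numbers: nats/dim here is just the per-example negative ELBO divided by the number of data dimensions. A quick conversion sketch (the function name is illustrative, not from this repo):

```python
import math

def per_dim(neg_elbo_nats, image_shape=(3, 32, 32)):
    """Convert a per-example negative ELBO in nats to (nats/dim, bits/dim).
    CIFAR-10 images have 3 * 32 * 32 = 3072 dimensions."""
    ndims = math.prod(image_shape)
    nats_per_dim = neg_elbo_nats / ndims
    return nats_per_dim, nats_per_dim / math.log(2.0)
```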