I'm trying to train an unconditional model on CIFAR-10, but something seems to be wrong. Here is the code I'm using:
import torch
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader  # Dataset was missing from the original imports
from torchvision import transforms as T
from tqdm import tqdm

from imagen_pytorch import Unet, Imagen, ImagenTrainer
class CifarDataset(Dataset):
    def __init__(self):
        super().__init__()
        image_size = 32
        self.transform = T.Compose([
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])
        self.data = torchvision.datasets.CIFAR10(root='./cifar_data', train=True,
                                                 download=True, transform=self.transform)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # drop the class label, since the model is unconditional
        img, target = self.data[index]
        return img
unet = Unet(
    dim = 128,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
).cuda()
imagen = Imagen(
    condition_on_text = False,  # must be False for unconditional Imagen
    unets = unet,
    image_sizes = 32,
    timesteps = 1000
).cuda()
dataset = CifarDataset()
dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)

# training loop
print('starting training loop for 200k iterations')
learning_rate = 1e-4
opt = optim.Adam(unet.parameters(), lr = learning_rate, betas = (0.9, 0.999), eps = 1e-08,
                 weight_decay = 0)
for i in range(200000):
    for image_batch in tqdm(dataloader):
        image_batch = image_batch.cuda()
        loss = imagen(image_batch, unet = unet, unet_number = 1)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
        opt.step()
        opt.zero_grad()
    print(f'loss: {loss.item()}')

    # sample once per pass through the dataset
    images = imagen.sample(batch_size = 1, return_pil_images = True)  # returns List[PIL.Image]
    images[0].save(f'./imagen_samples/sample-{i}.png')
The loss is decreasing but the samples I generate after each pass through the dataset are still only random noise. Any idea what I'm missing? The dataset contains 50k samples. How many epochs do you think are needed before one should be able to make out something in the generated samples?
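For comparison, the README's unconditional example drives everything through ImagenTrainer instead of a hand-rolled Adam loop; if I understand the trainer correctly, it owns the optimizer and keeps EMA copies of the unet weights, which its sample() uses by default. A minimal sketch of that path (the batch size and sampling interval here are just placeholders):

trainer = ImagenTrainer(imagen)
trainer.add_train_dataset(dataset, batch_size = 32)

for i in range(200000):
    # one optimizer step on a batch drawn from the registered dataset
    loss = trainer.train_step(unet_number = 1)
    if i % 1000 == 0:
        # sampling through the trainer should use its EMA unet weights
        images = trainer.sample(batch_size = 1, return_pil_images = True)
        images[0].save(f'./imagen_samples/sample-{i}.png')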
@lucala, I'm facing the same problem, although I'm trying to overfit on a single input batch. Training works as expected and predicts the noise with low error, but at inference time I get noise no matter how long I train. Have you found any hints toward a solution?
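For concreteness, my overfit test looks roughly like this, reusing the variables from your script above (a sketch; fixed_batch is just a hypothetical name for one cached batch from the loader):

# sketch: repeatedly train on one fixed batch to rule out data-pipeline issues
fixed_batch = next(iter(dataloader)).cuda()
for step in range(10000):
    loss = imagen(fixed_batch, unet = unet, unet_number = 1)
    loss.backward()
    opt.step()
    opt.zero_grad()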
No, unfortunately not. I pivoted to something else and haven't yet found time to come back to it.