Generated samples contain only noise while training unconditional Imagen #337

Open · lucala opened this issue Mar 23, 2023 · 2 comments


lucala commented Mar 23, 2023

I'm trying to train an unconditional model on CIFAR-10, but something seems to be wrong. Here is the code I'm using:

import torch
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from tqdm import tqdm
from imagen_pytorch import Unet, Imagen, ImagenTrainer

class CifarDataset(Dataset):
    def __init__(self):
        image_size = 32
        super().__init__()
        self.transform = T.Compose([
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])
        self.data = torchvision.datasets.CIFAR10(root='./cifar_data', train=True, 
                                                   download=True, transform=self.transform)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, target = self.data[index]
        return img

unet = Unet(
    dim = 128,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
).cuda()

imagen = Imagen(
    condition_on_text = False,  # this must be set to False for unconditional Imagen
    unets = unet,
    image_sizes = 32,
    timesteps = 1000
).cuda()

dataset = CifarDataset()

# working training loop
print('starting training loop for 200k iterations')

learning_rate = 1e-4
opt = optim.Adam(unet.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08,
                 weight_decay=0)  # weight_decay is a float, not a bool

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  # wrap the CifarDataset defined above
for i in range(200000):
    for image_batch in tqdm(dataloader):
        image_batch = image_batch.cuda()
        loss = imagen(image_batch, unet=unet, unet_number = 1)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
        opt.step()
        opt.zero_grad()
    print(f'loss: {loss}')
    images = imagen.sample(batch_size = 1, return_pil_images = True) # returns List[Image]
    images[0].save(f'./imagen_samples/sample-{i}.png')

The loss is decreasing but the samples I generate after each pass through the dataset are still only random noise. Any idea what I'm missing? The dataset contains 50k samples. How many epochs do you think are needed before one should be able to make out something in the generated samples?
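For reference, the imagen-pytorch README drives unconditional training through ImagenTrainer rather than a hand-rolled optimizer loop, and (as far as I can tell) the trainer also keeps the EMA copy of the unet that is used for sampling. A minimal sketch along those lines, assuming add_train_dataset, train_step and trainer.sample behave as documented in the README, would be:

# sketch of the trainer-driven unconditional loop from the README, adapted to the dataset above
trainer = ImagenTrainer(imagen = imagen).cuda()
trainer.add_train_dataset(dataset, batch_size = 32)

for i in range(200000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 8)
    print(f'loss: {loss}')

    if not (i % 1000):
        images = trainer.sample(batch_size = 1, return_pil_images = True)  # returns List[Image]
        images[0].save(f'./imagen_samples/trainer-sample-{i // 1000}.png')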

@kirilllzaitsev

@lucala, I'm facing the same problem, although I'm trying to overfit on a single input batch. Training works as expected and the noise-prediction loss gets low, but at inference time I get nothing but noise, no matter how long I train. Have you found any hints toward a solution?
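For context, a single-batch overfit check along those lines could look like the following (purely illustrative, reusing the unet, imagen, opt and dataloader objects from the snippet above):

# illustrative overfit check: train on one fixed batch only, then sample
batch = next(iter(dataloader)).cuda()

for step in range(5000):
    loss = imagen(batch, unet = unet, unet_number = 1)
    loss.backward()
    opt.step()
    opt.zero_grad()

samples = imagen.sample(batch_size = 4, return_pil_images = True)  # should resemble the training batch if overfitting works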


lucala commented Apr 25, 2023

@lucala, I'm facing the same problem, although I'm trying to overfit on a single input batch. Training works as expected and the noise-prediction loss gets low, but at inference time I get nothing but noise, no matter how long I train. Have you found any hints toward a solution?

No, unfortunately not. I pivoted to something else and haven't yet found time to come back to it.
