From 8eb7e5c48a0006005254cae4a8b9345faf7e500a Mon Sep 17 00:00:00 2001 From: Mark Mayo Date: Thu, 17 Nov 2022 13:56:13 +1300 Subject: [PATCH] updates for python 3 Clean up of blank space Tidy of alignments Removal of unnecessary object inheritance (py 3) Incorrectly declared f-strings Import orders --- code/config.py | 22 ++-- code/dataset.py | 100 ++++++++-------- code/dc_ldm/ldm_for_fmri.py | 53 +++++---- code/dc_ldm/models/autoencoder.py | 29 +++-- code/dc_ldm/models/diffusion/classifier.py | 17 +-- code/dc_ldm/models/diffusion/ddim.py | 17 +-- code/dc_ldm/models/diffusion/ddpm.py | 111 ++++++++++-------- code/dc_ldm/models/diffusion/plms.py | 17 +-- code/dc_ldm/modules/attention.py | 18 +-- code/dc_ldm/modules/diffusionmodules/model.py | 20 ++-- .../modules/diffusionmodules/openaimodel.py | 27 ++--- code/dc_ldm/modules/diffusionmodules/util.py | 12 +- .../modules/distributions/distributions.py | 4 +- code/dc_ldm/modules/encoders/modules.py | 10 +- code/dc_ldm/modules/losses/__init__.py | 2 +- code/dc_ldm/modules/losses/contperceptual.py | 5 +- code/dc_ldm/modules/losses/vqperceptual.py | 10 +- code/dc_ldm/modules/x_transformer.py | 14 +-- code/dc_ldm/util.py | 7 +- code/eval_metrics.py | 26 ++-- code/gen_eval.py | 52 ++++---- code/sc_mbm/mae_for_fmri.py | 47 ++++---- code/sc_mbm/trainer.py | 23 ++-- code/sc_mbm/utils.py | 14 ++- code/setup.py | 4 +- code/stageA1_mbm_pretrain.py | 64 +++++----- code/stageA2_mbm_finetune.py | 59 +++++----- code/stageB_ldm_finetune.py | 73 ++++++------ 28 files changed, 445 insertions(+), 412 deletions(-) diff --git a/code/config.py b/code/config.py index b82e8e2..b92f39c 100644 --- a/code/config.py +++ b/code/config.py @@ -1,10 +1,12 @@ import os + import numpy as np + class Config_MAE_fMRI: # back compatibility pass class Config_MBM_finetune: # back compatibility - pass + pass class Config_MBM_fMRI(Config_MAE_fMRI): # configs for fmri_pretrain.py @@ -19,7 +21,7 @@ def __init__(self): self.warmup_epochs = 40 self.batch_size = 100 self.clip_grad = 0.8 - + # Model Parameters self.mask_ratio = 0.75 self.patch_size = 16 @@ -51,14 +53,14 @@ def __init__(self): class Config_MBM_finetune(Config_MBM_finetune): def __init__(self): - + # Project setting self.root_path = '.' self.output_path = self.root_path self.kam_path = os.path.join(self.root_path, 'data/Kamitani/npz') self.bold5000_path = os.path.join(self.root_path, 'data/BOLD5000') self.dataset = 'GOD' # GOD or BOLD5000 - self.pretrain_mbm_path = os.path.join(self.root_path, f'pretrains/{self.dataset}/fmri_encoder.pth') + self.pretrain_mbm_path = os.path.join(self.root_path, f'pretrains/{self.dataset}/fmri_encoder.pth') self.include_nonavg_test = True self.kam_subs = ['sbj_3'] @@ -68,8 +70,8 @@ def __init__(self): self.lr = 5.3e-5 self.weight_decay = 0.05 self.num_epoch = 15 - self.batch_size = 16 if self.dataset == 'GOD' else 4 - self.mask_ratio = 0.75 + self.batch_size = 16 if self.dataset == 'GOD' else 4 + self.mask_ratio = 0.75 self.accum_iter = 1 self.clip_grad = 0.8 self.warmup_epochs = 2 @@ -77,7 +79,7 @@ def __init__(self): # distributed training self.local_rank = 0 - + class Config_Generative_Model: def __init__(self): # project parameters @@ -92,11 +94,11 @@ def __init__(self): self.pretrain_gm_path = os.path.join(self.root_path, 'pretrains/ldm/label2img') # self.pretrain_gm_path = os.path.join(self.root_path, 'pretrains/ldm/text2img-large') # self.pretrain_gm_path = os.path.join(self.root_path, 'pretrains/ldm/layout2img') - + self.dataset = 'GOD' # GOD or BOLD5000 self.kam_subs = ['sbj_3'] self.bold5000_subs = ['CSI4'] - self.pretrain_mbm_path = os.path.join(self.root_path, f'pretrains/{self.dataset}/fmri_encoder.pth') + self.pretrain_mbm_path = os.path.join(self.root_path, f'pretrains/{self.dataset}/fmri_encoder.pth') self.img_size = 256 @@ -105,7 +107,7 @@ def __init__(self): self.batch_size = 5 if self.dataset == 'GOD' else 25 self.lr = 5.3e-5 self.num_epoch = 500 - + self.precision = 32 self.accumulate_grad = 1 self.crop_ratio = 0.2 diff --git a/code/dataset.py b/code/dataset.py index e72caf1..c783873 100644 --- a/code/dataset.py +++ b/code/dataset.py @@ -1,14 +1,16 @@ -from torch.utils.data import Dataset -import numpy as np -import os -from scipy import interpolate -from einops import rearrange -import json import csv -import torch +import json +import os from pathlib import Path -import torchvision.transforms as transforms + +import numpy as np +import torch +from torchvision import transforms +from einops import rearrange +from scipy import interpolate +from torch.utils.data import Dataset + def identity(x): return x @@ -52,7 +54,7 @@ def augmentation(data, aug_times=2, interpolation_ratio=0.5): data: num_samples, num_voxels_padded return: data_aug: num_samples*aug_times, num_voxels_padded ''' - num_to_generate = int((aug_times-1)*len(data)) + num_to_generate = int((aug_times-1)*len(data)) if num_to_generate == 0: return data pairs_idx = np.random.choice(len(data), size=(num_to_generate, 2), replace=True) @@ -85,28 +87,28 @@ def img_norm(img): return img def channel_first(img): - if img.shape[-1] == 3: - return rearrange(img, 'h w c -> c h w') - return img + if img.shape[-1] == 3: + return rearrange(img, 'h w c -> c h w') + return img class hcp_dataset(Dataset): - def __init__(self, path='../data/HCP/npz', roi='VC', patch_size=16, transform=identity, aug_times=2, + def __init__(self, path='../data/HCP/npz', roi='VC', patch_size=16, transform=identity, aug_times=2, num_sub_limit=None, include_kam=False, include_hcp=True): - super(hcp_dataset, self).__init__() + super().__init__() data = [] images = [] - + if include_hcp: for c, sub in enumerate(os.listdir(path)): if os.path.isfile(os.path.join(path,sub,'HCP_visual_voxel.npz')) == False: - continue + continue if num_sub_limit is not None and c > num_sub_limit: break npz = dict(np.load(os.path.join(path,sub,'HCP_visual_voxel.npz'))) voxels = np.concatenate([npz['V1'],npz['V2'],npz['V3'],npz['V4']], axis=-1) if roi == 'VC' else npz[roi] # 1200, num_voxels voxels = process_voxel_ts(voxels, patch_size) # num_samples, num_voxels_padded data.append(voxels) - + data = augmentation(np.concatenate(data, axis=0), aug_times) # num_samples, num_voxels_padded data = np.expand_dims(data, axis=1) # num_samples, 1, num_voxels_padded images += [None] * len(data) @@ -124,7 +126,7 @@ def __init__(self, path='../data/HCP/npz', roi='VC', patch_size=16, transform=id images += k.images assert len(data) != 0, 'No data found' - + self.roi = roi self.patch_size = patch_size self.num_voxels = data.shape[-1] @@ -133,13 +135,13 @@ def __init__(self, path='../data/HCP/npz', roi='VC', patch_size=16, transform=id self.images = images self.images_transform = transforms.Compose([ img_norm, - transforms.Resize((112, 112)), + transforms.Resize((112, 112)), channel_first ]) def __len__(self): return len(self.data) - + def __getitem__(self, index): img = self.images[index] images_transform = self.images_transform if img is not None else identity @@ -147,10 +149,10 @@ def __getitem__(self, index): return {'fmri': self.transform(self.data[index]), 'image': images_transform(img)} - + class Kamitani_pretrain_dataset(Dataset): def __init__(self, path='../data/Kamitani/npz', roi='VC', patch_size=16, transform=identity, aug_times=2): - super(Kamitani_pretrain_dataset, self).__init__() + super().__init__() k1, k2 = create_Kamitani_dataset(path, roi, patch_size, transform, include_nonavg_test=True) # data = np.concatenate([k1.fmri, k2.fmri], axis=0) # self.images = [img for img in k1.image] + [None] * len(k2.fmri) @@ -164,10 +166,10 @@ def __init__(self, path='../data/Kamitani/npz', roi='VC', patch_size=16, transfo self.patch_size = patch_size self.num_voxels = data.shape[-1] self.transform = transform - + def __len__(self): return len(self.data) - + def __getitem__(self, index): return self.transform(self.data[index]) @@ -194,7 +196,7 @@ def get_img_label(class_index:dict, img_filename:list, naive_label_set=None): return img_label, naive_label def create_Kamitani_dataset(path='../data/Kamitani/npz', roi='VC', patch_size=16, fmri_transform=identity, - image_transform=identity, subjects = ['sbj_1', 'sbj_2', 'sbj_3', 'sbj_4', 'sbj_5'], + image_transform=identity, subjects = ['sbj_1', 'sbj_2', 'sbj_3', 'sbj_4', 'sbj_5'], test_category=None, include_nonavg_test=False): img_npz = dict(np.load(os.path.join(path, 'images_256.npz'))) with open(os.path.join(path, 'imagenet_class_index.json'), 'r') as f: @@ -223,10 +225,10 @@ def create_Kamitani_dataset(path='../data/Kamitani/npz', roi='VC', patch_size=1 train_img.append(img_npz['train_images'][npz['arr_3']]) train_lb = [train_img_label[i] for i in npz['arr_3']] test_lb = test_img_label - + roi_mask = npz[roi] tr = npz['arr_0'][..., roi_mask] # train - tt = npz['arr_2'][..., roi_mask] + tt = npz['arr_2'][..., roi_mask] if include_nonavg_test: tt = np.concatenate([tt, npz['arr_1'][..., roi_mask]], axis=0) @@ -237,14 +239,14 @@ def create_Kamitani_dataset(path='../data/Kamitani/npz', roi='VC', patch_size=1 train_fmri.append(tr) test_fmri.append(tt) if test_category is not None: - train_img_, train_fmri_, test_img_, test_fmri_, train_lb, test_lb = reorganize_train_test(train_img[-1], train_fmri[-1], + train_img_, train_fmri_, test_img_, test_fmri_, train_lb, test_lb = reorganize_train_test(train_img[-1], train_fmri[-1], test_img[-1], test_fmri[-1], train_lb, test_lb, test_category, npz['arr_3']) train_img[-1] = train_img_ train_fmri[-1] = train_fmri_ test_img[-1] = test_img_ test_fmri[-1] = test_fmri_ - + train_img_label_all += train_lb test_img_label_all += test_lb @@ -267,13 +269,13 @@ def create_Kamitani_dataset(path='../data/Kamitani/npz', roi='VC', patch_size=1 # train_img = rearrange(train_img, 'n h w c -> n c h w') if isinstance(image_transform, list): - return (Kamitani_dataset(train_fmri, train_img, train_img_label_all, fmri_transform, image_transform[0], num_voxels, len(npz['arr_0'])), + return (Kamitani_dataset(train_fmri, train_img, train_img_label_all, fmri_transform, image_transform[0], num_voxels, len(npz['arr_0'])), Kamitani_dataset(test_fmri, test_img, test_img_label_all, torch.FloatTensor, image_transform[1], num_voxels, len(npz['arr_2']))) else: - return (Kamitani_dataset(train_fmri, train_img, train_img_label_all, fmri_transform, image_transform, num_voxels, len(npz['arr_0'])), + return (Kamitani_dataset(train_fmri, train_img, train_img_label_all, fmri_transform, image_transform, num_voxels, len(npz['arr_0'])), Kamitani_dataset(test_fmri, test_img, test_img_label_all, torch.FloatTensor, image_transform, num_voxels, len(npz['arr_2']))) -def reorganize_train_test(train_img, train_fmri, test_img, test_fmri, train_lb, test_lb, +def reorganize_train_test(train_img, train_fmri, test_img, test_fmri, train_lb, test_lb, test_category, train_index_lookup): test_img_ = [] test_fmri_ = [] @@ -287,7 +289,7 @@ def reorganize_train_test(train_img, train_fmri, test_img, test_fmri, train_lb, test_fmri_.append(train_fmri[train_idx]) test_lb_.append(train_lb[train_idx]) train_idx_list.append(train_idx) - + train_img_ = np.stack([img for i, img in enumerate(train_img) if i not in train_idx_list]) train_fmri_ = np.stack([fmri for i, fmri in enumerate(train_fmri) if i not in train_idx_list]) train_lb_ = [lb for i, lb in enumerate(train_lb) if i not in train_idx_list] + test_lb @@ -301,7 +303,7 @@ def reorganize_train_test(train_img, train_fmri, test_img, test_fmri, train_lb, class Kamitani_dataset(Dataset): def __init__(self, fmri, image, img_label, fmri_transform=identity, image_transform=identity, num_voxels=0, num_per_sub=50): - super(Kamitani_dataset, self).__init__() + super().__init__() self.fmri = fmri self.image = image if len(self.image) != len(self.fmri): @@ -317,7 +319,7 @@ def __init__(self, fmri, image, img_label, fmri_transform=identity, image_transf def __len__(self): return len(self.fmri) - + def __getitem__(self, index): fmri = self.fmri[index] if index >= len(self.image): @@ -336,7 +338,7 @@ def __getitem__(self, index): class base_dataset(Dataset): def __init__(self, x, y=None, transform=identity): - super(base_dataset, self).__init__() + super().__init__() self.x = x self.y = y self.transform = transform @@ -347,7 +349,7 @@ def __getitem__(self, index): return self.transform(self.x[index]) else: return self.transform(self.x[index]), self.transform(self.y[index]) - + def remove_repeats(fmri, img_lb): assert len(fmri) == len(img_lb), 'len error' fmri_dict = {} @@ -386,7 +388,7 @@ def get_stimuli_list(root, sub): def list_get_all_index(list, value): return [i for i, v in enumerate(list) if v == value] - + def create_BOLD5000_dataset(path='../data/BOLD5000', patch_size=16, fmri_transform=identity, image_transform=identity, subjects = ['CSI1', 'CSI2', 'CSI3', 'CSI4'], include_nonavg_test=False): roi_list = ['EarlyVis', 'LOC', 'OPA', 'PPA', 'RSC'] @@ -397,7 +399,7 @@ def create_BOLD5000_dataset(path='../data/BOLD5000', patch_size=16, fmri_transfo fmri_files = [f for f in os.listdir(fmri_path) if f.endswith('.npy')] fmri_files.sort() - + fmri_train_major = [] fmri_test_major = [] img_train_major = [] @@ -411,17 +413,17 @@ def create_BOLD5000_dataset(path='../data/BOLD5000', patch_size=16, fmri_transfo fmri_data_sub.append(np.load(os.path.join(fmri_path, npy))) fmri_data_sub = np.concatenate(fmri_data_sub, axis=-1) # concatenate all rois fmri_data_sub = normalize(pad_to_patch_size(fmri_data_sub, patch_size)) - + # load image img_files = get_stimuli_list(img_path, sub) img_data_sub = [imgs_dict[name] for name in img_files] - + # split train test test_idx = [list_get_all_index(img_files, img) for img in repeated_imgs_list] test_idx = [i for i in test_idx if len(i) > 0] # remove empy list for CSI4 test_fmri = np.stack([fmri_data_sub[idx].mean(axis=0) for idx in test_idx]) test_img = np.stack([img_data_sub[idx[0]] for idx in test_idx]) - + test_idx_flatten = [] for idx in test_idx: test_idx_flatten += idx # flatten @@ -444,10 +446,10 @@ def create_BOLD5000_dataset(path='../data/BOLD5000', patch_size=16, fmri_transfo num_voxels = fmri_train_major.shape[-1] if isinstance(image_transform, list): - return (BOLD5000_dataset(fmri_train_major, img_train_major, fmri_transform, image_transform[0], num_voxels), + return (BOLD5000_dataset(fmri_train_major, img_train_major, fmri_transform, image_transform[0], num_voxels), BOLD5000_dataset(fmri_test_major, img_test_major, torch.FloatTensor, image_transform[1], num_voxels)) else: - return (BOLD5000_dataset(fmri_train_major, img_train_major, fmri_transform, image_transform, num_voxels), + return (BOLD5000_dataset(fmri_train_major, img_train_major, fmri_transform, image_transform, num_voxels), BOLD5000_dataset(fmri_test_major, img_test_major, torch.FloatTensor, image_transform, num_voxels)) class BOLD5000_dataset(Dataset): @@ -457,16 +459,16 @@ def __init__(self, fmri, image, fmri_transform=identity, image_transform=identit self.fmri_transform = fmri_transform self.image_transform = image_transform self.num_voxels = num_voxels - + def __len__(self): return len(self.fmri) - + def __getitem__(self, index): fmri = self.fmri[index] img = self.image[index] / 255.0 - fmri = np.expand_dims(fmri, axis=0) + fmri = np.expand_dims(fmri, axis=0) return {'fmri': self.fmri_transform(fmri), 'image': self.image_transform(img)} - + def switch_sub_view(self, sub, subs): # Not implemented - pass \ No newline at end of file + pass diff --git a/code/dc_ldm/ldm_for_fmri.py b/code/dc_ldm/ldm_for_fmri.py index 01f1c14..2cb6d95 100644 --- a/code/dc_ldm/ldm_for_fmri.py +++ b/code/dc_ldm/ldm_for_fmri.py @@ -1,26 +1,29 @@ +import os + import numpy as np -import wandb import torch -from dc_ldm.util import instantiate_from_config -from omegaconf import OmegaConf -import torch.nn as nn -import os +from torch import nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision.utils import make_grid +import wandb from dc_ldm.models.diffusion.plms import PLMSSampler +from dc_ldm.util import instantiate_from_config from einops import rearrange, repeat -from torchvision.utils import make_grid -from torch.utils.data import DataLoader -import torch.nn.functional as F +from omegaconf import OmegaConf from sc_mbm.mae_for_fmri import fmri_encoder + + def create_model_from_config(config, num_voxels, global_pool): model = fmri_encoder(num_voxels=num_voxels, patch_size=config.patch_size, embed_dim=config.embed_dim, - depth=config.depth, num_heads=config.num_heads, mlp_ratio=config.mlp_ratio, global_pool=global_pool) + depth=config.depth, num_heads=config.num_heads, mlp_ratio=config.mlp_ratio, global_pool=global_pool) return model class cond_stage_model(nn.Module): def __init__(self, metafile, num_voxels, cond_dim=1280, global_pool=True): super().__init__() - # prepare pretrained fmri mae + # prepare pretrained fmri mae model = create_model_from_config(metafile['config'], num_voxels, global_pool) model.load_checkpoint(metafile['model']) self.mae = model @@ -49,7 +52,7 @@ def __init__(self, metafile, num_voxels, device=torch.device('cpu'), pretrain_root='../pretrains/ldm/label2img', logger=None, ddim_steps=250, global_pool=True, use_time_cond=True): self.ckp_path = os.path.join(pretrain_root, 'model.ckpt') - self.config_path = os.path.join(pretrain_root, 'config.yaml') + self.config_path = os.path.join(pretrain_root, 'config.yaml') config = OmegaConf.load(self.config_path) config.model.params.unet_config.params.use_time_cond = use_time_cond config.model.params.unet_config.params.global_pool = global_pool @@ -58,7 +61,7 @@ def __init__(self, metafile, num_voxels, device=torch.device('cpu'), model = instantiate_from_config(config.model) pl_sd = torch.load(self.ckp_path, map_location="cpu")['state_dict'] - + m, u = model.load_state_dict(pl_sd, strict=False) model.cond_stage_trainable = True model.cond_stage_model = cond_stage_model(metafile, num_voxels, self.cond_dim, global_pool=global_pool) @@ -72,7 +75,7 @@ def __init__(self, metafile, num_voxels, device=torch.device('cpu'), model.p_image_size = config.model.params.image_size model.ch_mult = config.model.params.first_stage_config.params.ddconfig.ch_mult - self.device = device + self.device = device self.model = model self.ldm_config = config self.pretrain_root = pretrain_root @@ -88,7 +91,7 @@ def finetune(self, trainers, dataset, test_dataset, bs1, lr1, # self.model.train_dataset = dataset self.model.run_full_validation_threshold = 0.15 # stage one: train the cond encoder with the pretrained one - + # # stage one: only optimize conditional encoders print('\n##### Stage One: only optimize conditional encoders #####') dataloader = DataLoader(dataset, batch_size=bs1, shuffle=True) @@ -102,7 +105,7 @@ def finetune(self, trainers, dataset, test_dataset, bs1, lr1, trainers.fit(self.model, dataloader, val_dataloaders=test_loader) self.model.unfreeze_whole_model() - + torch.save( { 'model_state_dict': self.model.state_dict(), @@ -112,14 +115,14 @@ def finetune(self, trainers, dataset, test_dataset, bs1, lr1, }, os.path.join(output_path, 'checkpoint.pth') ) - + @torch.no_grad() def generate(self, fmri_embedding, num_samples, ddim_steps, HW=None, limit=None, state=None): # fmri_embedding: n, seq_len, embed_dim all_samples = [] if HW is None: - shape = (self.ldm_config.model.params.channels, + shape = (self.ldm_config.model.params.channels, self.ldm_config.model.params.image_size, self.ldm_config.model.params.image_size) else: num_resolutions = len(self.ldm_config.model.params.first_stage_config.params.ddconfig.ch_mult) @@ -131,7 +134,7 @@ def generate(self, fmri_embedding, num_samples, ddim_steps, HW=None, limit=None, # sampler = DDIMSampler(model) if state is not None: torch.cuda.set_rng_state(state) - + with model.ema_scope(): model.eval() for count, item in enumerate(fmri_embedding): @@ -142,9 +145,9 @@ def generate(self, fmri_embedding, num_samples, ddim_steps, HW=None, limit=None, gt_image = rearrange(item['image'], 'h w c -> 1 c h w') # h w c print(f"rendering {num_samples} examples in {ddim_steps} steps.") # assert latent.shape[-1] == self.fmri_latent_dim, 'dim error' - + c = model.get_learned_conditioning(repeat(latent, 'h w -> c h w', c=num_samples).to(self.device)) - samples_ddim, _ = sampler.sample(S=ddim_steps, + samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=num_samples, shape=shape, @@ -153,10 +156,10 @@ def generate(self, fmri_embedding, num_samples, ddim_steps, HW=None, limit=None, x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0) gt_image = torch.clamp((gt_image+1.0)/2.0, min=0.0, max=1.0) - + all_samples.append(torch.cat([gt_image, x_samples_ddim.detach().cpu()], dim=0)) # put groundtruth at first - - + + # display as grid grid = torch.stack(all_samples, 0) grid = rearrange(grid, 'n b c h w -> (n b) c h w') @@ -165,7 +168,5 @@ def generate(self, fmri_embedding, num_samples, ddim_steps, HW=None, limit=None, # to image grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() model = model.to('cpu') - - return grid, (255. * torch.stack(all_samples, 0).cpu().numpy()).astype(np.uint8) - + return grid, (255. * torch.stack(all_samples, 0).cpu().numpy()).astype(np.uint8) diff --git a/code/dc_ldm/models/autoencoder.py b/code/dc_ldm/models/autoencoder.py index 56aed97..ec3b09a 100644 --- a/code/dc_ldm/models/autoencoder.py +++ b/code/dc_ldm/models/autoencoder.py @@ -1,18 +1,22 @@ -import torch -import pytorch_lightning as pl -import torch.nn.functional as F from contextlib import contextmanager + import numpy as np +import pytorch_lightning as pl +import torch # from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer -import torch.nn as nn -from dc_ldm.modules.diffusionmodules.model import Encoder, Decoder -from dc_ldm.modules.distributions.distributions import DiagonalGaussianDistribution +from torch import nn +import torch.nn.functional as F +from torch import einsum +from torch.optim.lr_scheduler import LambdaLR +from dc_ldm.modules.diffusionmodules.model import Decoder, Encoder +from dc_ldm.modules.distributions.distributions import \ + DiagonalGaussianDistribution from dc_ldm.modules.ema import LitEma from dc_ldm.util import instantiate_from_config -from packaging import version -from torch.optim.lr_scheduler import LambdaLR -from torch import einsum from einops import rearrange +from packaging import version + + class VectorQuantizer(nn.Module): """ @@ -369,7 +373,8 @@ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): if plot_ema: with self.ema_scope(): xrec_ema, _ = self(x) - if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + if x.shape[1] > 3: + xrec_ema = self.to_rgb(xrec_ema) log["reconstructions_ema"] = xrec_ema return log @@ -402,7 +407,7 @@ def decode(self, h, force_not_quantize=False): dec = self.decoder(quant) return dec - + class AutoencoderKL(pl.LightningModule): def __init__(self, ddconfig, @@ -431,7 +436,7 @@ def __init__(self, if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.trainable = False - + def init_from_ckpt(self, path, ignore_keys=list()): sd = torch.load(path, map_location="cpu")["state_dict"] keys = list(sd.keys()) diff --git a/code/dc_ldm/models/diffusion/classifier.py b/code/dc_ldm/models/diffusion/classifier.py index c36811d..4a041b0 100644 --- a/code/dc_ldm/models/diffusion/classifier.py +++ b/code/dc_ldm/models/diffusion/classifier.py @@ -1,17 +1,18 @@ import os -import torch +from copy import deepcopy +from glob import glob + import pytorch_lightning as pl +import torch +from dc_ldm.modules.diffusionmodules.openaimodel import (EncoderUNetModel, + UNetModel) +from dc_ldm.util import default, instantiate_from_config, ismap, log_txt_as_img +from einops import rearrange +from natsort import natsorted from omegaconf import OmegaConf from torch.nn import functional as F from torch.optim import AdamW from torch.optim.lr_scheduler import LambdaLR -from copy import deepcopy -from einops import rearrange -from glob import glob -from natsort import natsorted - -from dc_ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel -from dc_ldm.util import log_txt_as_img, default, ismap, instantiate_from_config __models__ = { 'class_label': EncoderUNetModel, diff --git a/code/dc_ldm/models/diffusion/ddim.py b/code/dc_ldm/models/diffusion/ddim.py index c8d6b8d..a8765cf 100644 --- a/code/dc_ldm/models/diffusion/ddim.py +++ b/code/dc_ldm/models/diffusion/ddim.py @@ -1,14 +1,15 @@ """SAMPLING ONLY.""" -import torch -import numpy as np -from tqdm import tqdm from functools import partial -from dc_ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +import numpy as np +import torch +from dc_ldm.modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, make_ddim_timesteps, noise_like) +from tqdm import tqdm -class DDIMSampler(object): +class DDIMSampler(): def __init__(self, model, schedule="linear", **kwargs): super().__init__() self.model = model @@ -153,8 +154,10 @@ def ddim_sampling(self, cond, shape, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) diff --git a/code/dc_ldm/models/diffusion/ddpm.py b/code/dc_ldm/models/diffusion/ddpm.py index 8df6983..c3c0c88 100644 --- a/code/dc_ldm/models/diffusion/ddpm.py +++ b/code/dc_ldm/models/diffusion/ddpm.py @@ -6,28 +6,34 @@ -- merci """ import os -import torch -import torch.nn as nn -import numpy as np -import pytorch_lightning as pl -from torch.optim.lr_scheduler import LambdaLR -from einops import rearrange, repeat from contextlib import contextmanager from functools import partial -from tqdm import tqdm -from torchvision.utils import make_grid -from pytorch_lightning.utilities.distributed import rank_zero_only -from dc_ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config -from dc_ldm.modules.ema import LitEma -from dc_ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from dc_ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL -from dc_ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +import numpy as np +import pytorch_lightning as pl +import torch +from torch import nn +import torch.nn.functional as F +from torch.optim.lr_scheduler import LambdaLR +from dc_ldm.models.autoencoder import (AutoencoderKL, IdentityFirstStage, + VQModelInterface) from dc_ldm.models.diffusion.ddim import DDIMSampler from dc_ldm.models.diffusion.plms import PLMSSampler -from PIL import Image -import torch.nn.functional as F +from dc_ldm.modules.diffusionmodules.util import (extract_into_tensor, + make_beta_schedule, + noise_like) +from dc_ldm.modules.distributions.distributions import ( + DiagonalGaussianDistribution, normal_kl) +from dc_ldm.modules.ema import LitEma +from dc_ldm.util import (count_params, default, exists, + instantiate_from_config, isimage, ismap, + log_txt_as_img, mean_flat) +from einops import rearrange, repeat from eval_metrics import get_similarity_metric +from PIL import Image +from pytorch_lightning.utilities.distributed import rank_zero_only +from torchvision.utils import make_grid +from tqdm import tqdm __conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', @@ -121,14 +127,14 @@ def __init__(self, self.return_cond = False self.output_path = None self.main_config = None - self.best_val = 0.0 + self.best_val = 0.0 self.run_full_validation_threshold = 0.0 self.eval_avg = True def re_init_ema(self): if self.use_ema: self.model_ema = LitEma(self.model) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): @@ -358,7 +364,7 @@ def shared_step(self, batch): def training_step(self, batch, batch_idx): self.train() self.cond_stage_model.train() - + loss, loss_dict = self.shared_step(batch) self.log_dict(loss_dict, prog_bar=True, @@ -370,13 +376,13 @@ def training_step(self, batch, batch_idx): return loss - + @torch.no_grad() def generate(self, data, num_samples, ddim_steps=300, HW=None, limit=None, state=None): # fmri_embedding: n, seq_len, embed_dim all_samples = [] if HW is None: - shape = (self.p_channels, + shape = (self.p_channels, self.p_image_size, self.p_image_size) else: num_resolutions = len(self.ch_mult) @@ -396,7 +402,7 @@ def generate(self, data, num_samples, ddim_steps=300, HW=None, limit=None, state # rng = torch.Generator(device=self.device).manual_seed(2022).set_state(state) - # state = torch.cuda.get_rng_state() + # state = torch.cuda.get_rng_state() with model.ema_scope(): for count, item in enumerate(zip(data['fmri'], data['image'])): if limit is not None: @@ -406,7 +412,7 @@ def generate(self, data, num_samples, ddim_steps=300, HW=None, limit=None, state gt_image = rearrange(item[1], 'h w c -> 1 c h w') # h w c print(f"rendering {num_samples} examples in {ddim_steps} steps.") c = model.get_learned_conditioning(repeat(latent, 'h w -> c h w', c=num_samples).to(self.device)) - samples_ddim, _ = sampler.sample(S=ddim_steps, + samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=num_samples, shape=shape, @@ -416,9 +422,9 @@ def generate(self, data, num_samples, ddim_steps=300, HW=None, limit=None, state x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0,min=0.0, max=1.0) gt_image = torch.clamp((gt_image+1.0)/2.0,min=0.0, max=1.0) - + all_samples.append(torch.cat([gt_image.detach().cpu(), x_samples_ddim.detach().cpu()], dim=0)) # put groundtruth at first - + # display as grid grid = torch.stack(all_samples, 0) grid = rearrange(grid, 'n b c h w -> (n b) c h w') @@ -434,9 +440,9 @@ def save_images(self, all_samples, suffix=0): for sp_idx, imgs in enumerate(all_samples): for copy_idx, img in enumerate(imgs[1:]): img = rearrange(img, 'c h w -> h w c') - Image.fromarray(img).save(os.path.join(self.output_path, 'val', + Image.fromarray(img).save(os.path.join(self.output_path, 'val', f'{self.validation_count}_{suffix}', f'test{sp_idx}-{copy_idx}.png')) - + def full_validation(self, batch, state=None): print('###### run full validation! ######\n') grid, all_samples, state = self.generate(batch, ddim_steps=self.ddim_steps, num_samples=5, limit=None, state=state) @@ -445,7 +451,7 @@ def full_validation(self, batch, state=None): metric_dict = {f'val/{k}_full':v for k, v in zip(metric_list, metric)} self.logger.log_metrics(metric_dict) grid_imgs = Image.fromarray(grid.astype(np.uint8)) - self.logger.log_image(key=f'samples_test_full', images=[grid_imgs]) + self.logger.log_image(key='samples_test_full', images=[grid_imgs]) if metric[-1] > self.best_val: self.best_val = metric[-1] torch.save( @@ -462,14 +468,14 @@ def full_validation(self, batch, state=None): def validation_step(self, batch, batch_idx): if batch_idx != 0: return - + if self.validation_count % 15 == 0 and self.trainer.current_epoch != 0: self.full_validation(batch) else: grid, all_samples, state = self.generate(batch, ddim_steps=self.ddim_steps, num_samples=3, limit=5) metric, metric_list = self.get_eval_metric(all_samples, avg=self.eval_avg) grid_imgs = Image.fromarray(grid.astype(np.uint8)) - self.logger.log_image(key=f'samples_test', images=[grid_imgs]) + self.logger.log_image(key='samples_test', images=[grid_imgs]) metric_dict = {f'val/{k}':v for k, v in zip(metric_list, metric)} self.logger.log_metrics(metric_dict) if metric[-1] > self.run_full_validation_threshold: @@ -479,7 +485,7 @@ def validation_step(self, batch, batch_idx): def get_eval_metric(self, samples, avg=True): metric_list = ['mse', 'pcc', 'ssim', 'psm'] res_list = [] - + gt_images = [img[0] for img in samples] gt_images = rearrange(np.stack(gt_images), 'n c h w -> n h w c') samples_to_run = np.arange(1, len(samples[0])) if avg else [1] @@ -490,12 +496,12 @@ def get_eval_metric(self, samples, avg=True): pred_images = rearrange(np.stack(pred_images), 'n c h w -> n h w c') res = get_similarity_metric(pred_images, gt_images, method='pair-wise', metric_name=m) res_part.append(np.mean(res)) - res_list.append(np.mean(res_part)) + res_list.append(np.mean(res_part)) res_part = [] for s in samples_to_run: pred_images = [img[s] for img in samples] pred_images = rearrange(np.stack(pred_images), 'n c h w -> n h w c') - res = get_similarity_metric(pred_images, gt_images, 'class', None, + res = get_similarity_metric(pred_images, gt_images, 'class', None, n_way=50, num_trials=50, top_k=1, device='cuda') res_part.append(np.mean(res)) res_list.append(np.mean(res_part)) @@ -503,7 +509,7 @@ def get_eval_metric(self, samples, avg=True): metric_list.append('top-1-class') metric_list.append('top-1-class (max)') - return res_list, metric_list + return res_list, metric_list def on_train_batch_end(self, *args, **kwargs): if self.use_ema: @@ -593,7 +599,7 @@ def __init__(self, self.cond_stage_key = cond_stage_key try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: + except Exception: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor @@ -601,17 +607,17 @@ def __init__(self, self.register_buffer('scale_factor', torch.tensor(scale_factor)) self.instantiate_first_stage(first_stage_config) self.instantiate_cond_stage(cond_stage_config) - + self.cond_stage_forward = cond_stage_forward self.clip_denoised = False - self.bbox_tokenizer = None + self.bbox_tokenizer = None self.restarted_from_ckpt = False if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys) self.restarted_from_ckpt = True self.train_cond_stage_only = False - + def make_cond_schedule(self, ): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) @@ -663,7 +669,6 @@ def freeze_cond_stage(self): def unfreeze_cond_stage(self): for param in self.cond_stage_model.parameters(): param.requires_grad = True - def freeze_first_stage(self): self.first_stage_model.trainable = False @@ -684,7 +689,7 @@ def unfreeze_whole_model(self): self.first_stage_model.trainable = True for param in self.parameters(): param.requires_grad = True - + def instantiate_cond_stage(self, config): if not self.cond_stage_trainable: if config == "__is_first_stage__": @@ -967,7 +972,7 @@ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_qua z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): + if isinstance(self.first_stage_model, VQModelInterface): output_list = [self.first_stage_model.decode(z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize) for i in range(z.shape[-1])] @@ -1006,7 +1011,7 @@ def shared_step(self, batch, **kwargs): if self.return_cond: loss, cc = self(x, c) return loss, cc - else: + else: loss = self(x, c) return loss @@ -1210,8 +1215,10 @@ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quanti if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) return img, intermediates @torch.no_grad() @@ -1258,8 +1265,10 @@ def p_sample_loop(self, cond, shape, return_intermediates=False, if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) if return_intermediates: return img, intermediates @@ -1414,14 +1423,14 @@ def configure_optimizers(self): lr = self.learning_rate if self.train_cond_stage_only: print(f"{self.__class__.__name__}: Only optimizing conditioner params!") - cond_parms = [p for n, p in self.named_parameters() + cond_parms = [p for n, p in self.named_parameters() if 'attn2' in n or 'time_embed_condtion' in n or 'norm2' in n] - # cond_parms = [p for n, p in self.named_parameters() + # cond_parms = [p for n, p in self.named_parameters() # if 'time_embed_condtion' in n] # cond_parms = [] - + params = list(self.cond_stage_model.parameters()) + cond_parms - + for p in params: p.requires_grad = True @@ -1448,7 +1457,7 @@ def configure_optimizers(self): 'frequency': 1 }] return [opt], scheduler - + return opt @torch.no_grad() diff --git a/code/dc_ldm/models/diffusion/plms.py b/code/dc_ldm/models/diffusion/plms.py index 904b70e..f55bb51 100644 --- a/code/dc_ldm/models/diffusion/plms.py +++ b/code/dc_ldm/models/diffusion/plms.py @@ -1,14 +1,15 @@ """SAMPLING ONLY.""" -import torch -import numpy as np -from tqdm import tqdm from functools import partial -from dc_ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +import numpy as np +import torch +from dc_ldm.modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, make_ddim_timesteps, noise_like) +from tqdm import tqdm -class PLMSSampler(object): +class PLMSSampler(): def __init__(self, model, schedule="linear", **kwargs): super().__init__() self.model = model @@ -161,8 +162,10 @@ def plms_sampling(self, cond, shape, old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) diff --git a/code/dc_ldm/modules/attention.py b/code/dc_ldm/modules/attention.py index 27e65f7..13a1840 100644 --- a/code/dc_ldm/modules/attention.py +++ b/code/dc_ldm/modules/attention.py @@ -1,11 +1,11 @@ -from inspect import isfunction import math +from inspect import isfunction + import torch import torch.nn.functional as F -from torch import nn, einsum -from einops import rearrange, repeat - from dc_ldm.modules.diffusionmodules.util import checkpoint +from einops import rearrange, repeat +from torch import einsum, nn def exists(val): @@ -89,7 +89,7 @@ def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) + k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) @@ -172,8 +172,8 @@ def forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) + k = self.to_k(context) + v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) @@ -188,7 +188,7 @@ def forward(self, x, context=None, mask=None): # attention, what we cannot get enough of attn = sim.softmax(dim=-1) - out = einsum('b i j, b j d -> b i d', attn, v) + out = einsum('b i j, b j d -> b i d', attn, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=h) return self.to_out(out) @@ -258,4 +258,4 @@ def forward(self, x, context=None): x = block(x, context=context) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) - return x + x_in \ No newline at end of file + return x + x_in diff --git a/code/dc_ldm/modules/diffusionmodules/model.py b/code/dc_ldm/modules/diffusionmodules/model.py index aea4833..2612c85 100644 --- a/code/dc_ldm/modules/diffusionmodules/model.py +++ b/code/dc_ldm/modules/diffusionmodules/model.py @@ -1,12 +1,12 @@ # pytorch_diffusion + derived encoder decoder import math -import torch -import torch.nn as nn -import numpy as np -from einops import rearrange -from dc_ldm.util import instantiate_from_config +import numpy as np +import torch +from torch import nn from dc_ldm.modules.attention import LinearAttention +from dc_ldm.util import instantiate_from_config +from einops import rearrange def get_timestep_embedding(timesteps, embedding_dim): @@ -218,7 +218,8 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = self.ch*4 self.num_resolutions = len(ch_mult) @@ -371,7 +372,8 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", **ignore_kwargs): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -465,7 +467,8 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, attn_type="vanilla", **ignorekwargs): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) @@ -832,4 +835,3 @@ def forward(self,x): if self.do_reshape: z = rearrange(z,'b c h w -> b (h w) c') return z - diff --git a/code/dc_ldm/modules/diffusionmodules/openaimodel.py b/code/dc_ldm/modules/diffusionmodules/openaimodel.py index 427cdd5..975c187 100644 --- a/code/dc_ldm/modules/diffusionmodules/openaimodel.py +++ b/code/dc_ldm/modules/diffusionmodules/openaimodel.py @@ -1,25 +1,19 @@ +import math from abc import abstractmethod from functools import partial -import math from typing import Iterable import numpy as np +import torch import torch as th -import torch.nn as nn +from torch import nn import torch.nn.functional as F -import torch - - -from dc_ldm.modules.diffusionmodules.util import ( - checkpoint, - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, -) from dc_ldm.modules.attention import SpatialTransformer +from dc_ldm.modules.diffusionmodules.util import (avg_pool_nd, checkpoint, + conv_nd, linear, + normalization, + timestep_embedding, + zero_module) # dummy replace @@ -519,7 +513,7 @@ def __init__( if self.num_classes is not None: self.label_emb = nn.Embedding(num_classes, time_embed_dim) - + # self.time_embed_condtion = nn.Linear(context_dim, time_embed_dim, bias=False) if use_time_cond: self.time_embed_condtion = nn.Sequential( @@ -527,7 +521,7 @@ def __init__( nn.Conv1d(77//2, 1, 1, bias=True), nn.Linear(context_dim, time_embed_dim, bias=True) ) if global_pool == False else nn.Linear(context_dim, time_embed_dim, bias=True) - + self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( @@ -977,4 +971,3 @@ def forward(self, x, timesteps): else: h = h.type(x.dtype) return self.out(h) - diff --git a/code/dc_ldm/modules/diffusionmodules/util.py b/code/dc_ldm/modules/diffusionmodules/util.py index 3a3a0a9..a60950e 100644 --- a/code/dc_ldm/modules/diffusionmodules/util.py +++ b/code/dc_ldm/modules/diffusionmodules/util.py @@ -8,14 +8,14 @@ # thanks! -import os import math -import torch -import torch.nn as nn -import numpy as np -from einops import repeat +import os +import numpy as np +import torch +from torch import nn from dc_ldm.util import instantiate_from_config +from einops import repeat def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): @@ -264,4 +264,4 @@ def forward(self, c_concat, c_crossattn): def noise_like(shape, device, repeat=False): repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() \ No newline at end of file + return repeat_noise() if repeat else noise() diff --git a/code/dc_ldm/modules/distributions/distributions.py b/code/dc_ldm/modules/distributions/distributions.py index f2b8ef9..9154154 100644 --- a/code/dc_ldm/modules/distributions/distributions.py +++ b/code/dc_ldm/modules/distributions/distributions.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch class AbstractDistribution: @@ -21,7 +21,7 @@ def mode(self): return self.value -class DiagonalGaussianDistribution(object): +class DiagonalGaussianDistribution(): def __init__(self, parameters, deterministic=False): self.parameters = parameters self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) diff --git a/code/dc_ldm/modules/encoders/modules.py b/code/dc_ldm/modules/encoders/modules.py index 9d71d1c..e97ae11 100644 --- a/code/dc_ldm/modules/encoders/modules.py +++ b/code/dc_ldm/modules/encoders/modules.py @@ -1,8 +1,9 @@ -import torch -import torch.nn as nn from functools import partial -from dc_ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test +import torch +from torch import nn +from dc_ldm.modules.x_transformer import ( # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test + Encoder, TransformerWrapper) class AbstractEncoder(nn.Module): @@ -50,7 +51,8 @@ class BERTTokenizer(AbstractEncoder): """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" def __init__(self, device="cuda", vq_interface=True, max_length=77): super().__init__() - from transformers import BertTokenizerFast # TODO: add to reuquirements + from transformers import \ + BertTokenizerFast # TODO: add to reuquirements self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") self.device = device self.vq_interface = vq_interface diff --git a/code/dc_ldm/modules/losses/__init__.py b/code/dc_ldm/modules/losses/__init__.py index f0c3e8e..30d7a89 100644 --- a/code/dc_ldm/modules/losses/__init__.py +++ b/code/dc_ldm/modules/losses/__init__.py @@ -1 +1 @@ -from dc_ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file +from dc_ldm.modules.losses.contperceptual import LPIPSWithDiscriminator diff --git a/code/dc_ldm/modules/losses/contperceptual.py b/code/dc_ldm/modules/losses/contperceptual.py index 308146f..1a8d2fd 100644 --- a/code/dc_ldm/modules/losses/contperceptual.py +++ b/code/dc_ldm/modules/losses/contperceptual.py @@ -1,7 +1,7 @@ import torch -import torch.nn as nn - +from torch import nn from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + # from vqperceptual import * # replace taming dependency to local vqperceptual.py @@ -109,4 +109,3 @@ def forward(self, inputs, reconstructions, posteriors, optimizer_idx, "{}/logits_fake".format(split): logits_fake.detach().mean() } return d_loss, log - diff --git a/code/dc_ldm/modules/losses/vqperceptual.py b/code/dc_ldm/modules/losses/vqperceptual.py index fba9d56..9e47f63 100644 --- a/code/dc_ldm/modules/losses/vqperceptual.py +++ b/code/dc_ldm/modules/losses/vqperceptual.py @@ -1,13 +1,13 @@ import torch -from torch import nn import torch.nn.functional as F +from dc_ldm.util import exists from einops import repeat - -from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.discriminator.model import (NLayerDiscriminator, + weights_init) from taming.modules.losses.lpips import LPIPS from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss +from torch import nn -from dc_ldm.util import exists def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] @@ -112,7 +112,7 @@ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, nll_loss = rec_loss #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] nll_loss = torch.mean(nll_loss) - + if optimizer_idx == 2: log = { "{}/nll_loss".format(split): nll_loss.detach().mean(), diff --git a/code/dc_ldm/modules/x_transformer.py b/code/dc_ldm/modules/x_transformer.py index 5ed9623..93bceb2 100644 --- a/code/dc_ldm/modules/x_transformer.py +++ b/code/dc_ldm/modules/x_transformer.py @@ -1,11 +1,12 @@ """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" -import torch -from torch import nn, einsum -import torch.nn.functional as F +from collections import namedtuple from functools import partial from inspect import isfunction -from collections import namedtuple -from einops import rearrange, repeat, reduce + +import torch +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from torch import einsum, nn # constants @@ -609,7 +610,7 @@ def forward( b = x.shape[0] device = x.device num_mem = self.num_memory_tokens - + x = self.token_emb(x) x += self.pos_emb(x) x = self.emb_dropout(x) @@ -642,4 +643,3 @@ def forward( return out, attn_maps return out - diff --git a/code/dc_ldm/util.py b/code/dc_ldm/util.py index 51839cb..8607f5b 100644 --- a/code/dc_ldm/util.py +++ b/code/dc_ldm/util.py @@ -1,9 +1,8 @@ import importlib +from inspect import isfunction -import torch import numpy as np - -from inspect import isfunction +import torch from PIL import Image, ImageDraw, ImageFont @@ -83,4 +82,4 @@ def get_obj_from_str(string, reload=False): if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) \ No newline at end of file + return getattr(importlib.import_module(module, package=None), cls) diff --git a/code/eval_metrics.py b/code/eval_metrics.py index db3f1de..69bb998 100644 --- a/code/eval_metrics.py +++ b/code/eval_metrics.py @@ -1,13 +1,15 @@ from os import get_inheritable + import numpy as np -from skimage.metrics import structural_similarity as ssim -from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity -from torchmetrics.image.fid import FrechetInceptionDistance -from torchvision.models import ViT_H_14_Weights, vit_h_14 import torch from einops import rearrange -from torchmetrics.functional import accuracy from PIL import Image +from skimage.metrics import structural_similarity as ssim +from torchmetrics.functional import accuracy +from torchmetrics.image.fid import FrechetInceptionDistance +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from torchvision.models import ViT_H_14_Weights, vit_h_14 + def larger_the_better(gt, comp): return gt > comp @@ -53,7 +55,7 @@ def __call__(self, pred_imgs, gt_imgs): self.fid.reset() self.fid.update(torch.tensor(rearrange(gt_imgs, 'n w h c -> n c w h')), real=True) self.fid.update(torch.tensor(rearrange(pred_imgs, 'n w h c -> n c w h')), real=False) - return self.fid.compute().item() + return self.fid.compute().item() def pair_wise_score(pred_imgs, gt_imgs, metric, is_sucess): # pred_imgs: n, w, h, 3 @@ -100,7 +102,7 @@ def n_way_scores(pred_imgs, gt_imgs, metric, is_sucess, n=2, n_trials=100): if is_sucess(gt_score, comp_score): count += 1 if count == len(n_imgs): - correct_count += 1 + correct_count += 1 corrects.append(correct_count / n_trials) return corrects @@ -116,7 +118,7 @@ def n_way_top_k_acc(pred, class_id, n_way, num_trials=40, top_k=1): for t in range(num_trials): idxs_picked = np.random.choice(pick_range, n_way-1, replace=False) pred_picked = torch.cat([pred[class_id].unsqueeze(0), pred[idxs_picked]]) - acc = accuracy(pred_picked.unsqueeze(0), torch.tensor([0], device=pred.device), + acc = accuracy(pred_picked.unsqueeze(0), torch.tensor([0], device=pred.device), top_k=top_k) acc_list.append(acc.item()) return np.mean(acc_list), np.std(acc_list) @@ -128,7 +130,7 @@ def get_n_way_top_k_acc(pred_imgs, ground_truth, n_way, num_trials, top_k, devic preprocess = weights.transforms() model = model.to(device) model = model.eval() - + acc_list = [] std_list = [] for pred, gt in zip(pred_imgs, ground_truth): @@ -140,7 +142,7 @@ def get_n_way_top_k_acc(pred_imgs, ground_truth, n_way, num_trials, top_k, devic acc, std = n_way_top_k_acc(pred_out, gt_class_id, n_way, num_trials, top_k) acc_list.append(acc) std_list.append(std) - + if return_std: return acc_list, std_list return acc_list @@ -156,7 +158,7 @@ def get_similarity_metric(img1, img2, method='pair-wise', metric_name='mse', **k img2 = rearrange(img2, 'n c w h -> n w h c') if method == 'pair-wise': - eval_procedure_func = pair_wise_score + eval_procedure_func = pair_wise_score elif method == 'n-way': eval_procedure_func = n_way_scores elif method == 'metrics-only': @@ -183,5 +185,5 @@ def get_similarity_metric(img1, img2, method='pair-wise', metric_name='mse', **k decision_func = smaller_the_better else: raise NotImplementedError - + return eval_procedure_func(img1, img2, metric_func, decision_func, **kwargs) diff --git a/code/gen_eval.py b/code/gen_eval.py index 7351672..e1aa6f2 100644 --- a/code/gen_eval.py +++ b/code/gen_eval.py @@ -1,16 +1,19 @@ -import os, sys +import argparse +import datetime +import os +import sys + import numpy as np import torch -from eval_metrics import get_similarity_metric -from dataset import create_Kamitani_dataset, create_BOLD5000_dataset +from torchvision import transforms +import wandb +from config import * +from dataset import create_BOLD5000_dataset, create_Kamitani_dataset from dc_ldm.ldm_for_fmri import fLDM from einops import rearrange +from eval_metrics import get_similarity_metric from PIL import Image -import torchvision.transforms as transforms -from config import * -import wandb -import datetime -import argparse + def to_image(img): if img.shape[-1] != 3: @@ -40,7 +43,7 @@ def wandb_init(config): def get_eval_metric(samples, avg=True): metric_list = ['mse', 'pcc', 'ssim', 'psm'] res_list = [] - + gt_images = [img[0] for img in samples] gt_images = rearrange(np.stack(gt_images), 'n c h w -> n h w c') samples_to_run = np.arange(1, len(samples[0])) if avg else [1] @@ -51,12 +54,12 @@ def get_eval_metric(samples, avg=True): pred_images = rearrange(np.stack(pred_images), 'n c h w -> n h w c') res = get_similarity_metric(pred_images, gt_images, method='pair-wise', metric_name=m) res_part.append(np.mean(res)) - res_list.append(np.mean(res_part)) + res_list.append(np.mean(res_part)) res_part = [] for s in samples_to_run: pred_images = [img[s] for img in samples] pred_images = rearrange(np.stack(pred_images), 'n c h w -> n h w c') - res = get_similarity_metric(pred_images, gt_images, 'class', None, + res = get_similarity_metric(pred_images, gt_images, 'class', None, n_way=50, num_trials=1000, top_k=1, device='cuda') res_part.append(np.mean(res)) res_list.append(np.mean(res_part)) @@ -80,7 +83,7 @@ def get_args_parser(): root = args.root target = args.dataset model_path = os.path.join(root, 'pretrains', f'{target}', 'finetuned.pth') - + sd = torch.load(model_path, map_location='cpu') config = sd['config'] # update paths @@ -91,30 +94,30 @@ def get_args_parser(): config.pretrain_gm_path = os.path.join(root, 'pretrains/ldm/label2img') print(config.__dict__) - output_path = os.path.join(config.root_path, 'results', 'eval', + output_path = os.path.join(config.root_path, 'results', 'eval', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) - + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') img_transform_test = transforms.Compose([ - normalize, transforms.Resize((256, 256)), + normalize, transforms.Resize((256, 256)), channel_last ]) if target == 'GOD': - _, dataset_test = create_Kamitani_dataset(config.kam_path, config.roi, config.patch_size, - fmri_transform=torch.FloatTensor, image_transform=img_transform_test, + _, dataset_test = create_Kamitani_dataset(config.kam_path, config.roi, config.patch_size, + fmri_transform=torch.FloatTensor, image_transform=img_transform_test, subjects=config.kam_subs, test_category=config.test_category) elif target == 'BOLD5000': - _, dataset_test = create_BOLD5000_dataset(config.bold5000_path, config.patch_size, - fmri_transform=torch.FloatTensor, image_transform=img_transform_test, + _, dataset_test = create_BOLD5000_dataset(config.bold5000_path, config.patch_size, + fmri_transform=torch.FloatTensor, image_transform=img_transform_test, subjects=config.bold5000_subs) else: raise NotImplementedError - + num_voxels = dataset_test.num_voxels print(len(dataset_test)) - # prepare pretrained mae + # prepare pretrained mae pretrain_mbm_metafile = torch.load(config.pretrain_mbm_path, map_location='cpu') # create generateive model generative_model = fLDM(pretrain_mbm_metafile, num_voxels, @@ -123,19 +126,18 @@ def get_args_parser(): generative_model.model.load_state_dict(sd['model_state_dict']) print('load ldm successfully') state = sd['state'] - grid, samples = generative_model.generate(dataset_test, config.num_samples, + grid, samples = generative_model.generate(dataset_test, config.num_samples, config.ddim_steps, config.HW, limit=None, state=state) # generate 10 instances grid_imgs = Image.fromarray(grid.astype(np.uint8)) os.makedirs(output_path, exist_ok=True) - grid_imgs.save(os.path.join(output_path,f'./samples_test.png')) + grid_imgs.save(os.path.join(output_path,'./samples_test.png')) wandb_init(config) - wandb.log({f'summary/samples_test': wandb.Image(grid_imgs)}) + wandb.log({'summary/samples_test': wandb.Image(grid_imgs)}) metric, metric_list = get_eval_metric(samples, avg=True) metric_dict = {f'summary/pair-wise_{k}':v for k, v in zip(metric_list[:-2], metric[:-2])} metric_dict[f'summary/{metric_list[-2]}'] = metric[-2] metric_dict[f'summary/{metric_list[-1]}'] = metric[-1] print(metric_dict) wandb.log(metric_dict) - diff --git a/code/sc_mbm/mae_for_fmri.py b/code/sc_mbm/mae_for_fmri.py index 6972f93..a734ab6 100644 --- a/code/sc_mbm/mae_for_fmri.py +++ b/code/sc_mbm/mae_for_fmri.py @@ -1,9 +1,10 @@ +import numpy as np import sc_mbm.utils as ut import torch -import torch.nn as nn -import numpy as np -from timm.models.vision_transformer import Block import torch.nn.functional as F +from timm.models.vision_transformer import Block +from torch import nn + class PatchEmbed1D(nn.Module): """ 1 Dimensional version of data (fmri voxels) to Patch Embedding @@ -29,9 +30,9 @@ class MAEforFMRI(nn.Module): """ Masked Autoencoder with VisionTransformer backbone """ def __init__(self, num_voxels=224, patch_size=16, embed_dim=1024, in_chans=1, - depth=24, num_heads=16, decoder_embed_dim=512, + depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, - mlp_ratio=4., norm_layer=nn.LayerNorm, focus_range=None, focus_rate=None, img_recon_weight=1.0, + mlp_ratio=4., norm_layer=nn.LayerNorm, focus_range=None, focus_rate=None, img_recon_weight=1.0, use_nature_img_loss=False): super().__init__() @@ -90,7 +91,7 @@ def __init__(self, num_voxels=224, patch_size=16, embed_dim=1024, in_chans=1, self.focus_rate = focus_rate self.img_recon_weight = img_recon_weight self.use_nature_img_loss = use_nature_img_loss - + self.initialize_weights() def initialize_weights(self): @@ -131,8 +132,8 @@ def _init_weights(self, m): torch.nn.init.normal_(m.weight, std=.02) if m.bias is not None: nn.init.constant_(m.bias, 0) - - + + def patchify(self, imgs): """ imgs: (N, 1, num_voxels) @@ -152,7 +153,7 @@ def unpatchify(self, x): """ p = self.patch_embed.patch_size h = x.shape[1] - + imgs = x.reshape(shape=(x.shape[0], 1, h * p)) return imgs @@ -172,11 +173,11 @@ def random_masking(self, x, mask_ratio): ] = [self.focus_rate] * (self.focus_range[1] // self.patch_size - self.focus_range[0] // self.patch_size) weights = torch.tensor(weights).repeat(N, 1).to(x.device) ids_mask = torch.multinomial(weights, len_mask, replacement=False) - + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] if self.focus_range is not None: for i in range(N): - noise[i, ids_mask[i,:]] = 1.1 # set mask portion to 1.1 + noise[i, ids_mask[i,:]] = 1.1 # set mask portion to 1.1 # sort noise for each sample ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove @@ -267,27 +268,27 @@ def forward_nature_img_decoder(self, x, ids_restore): x = x.view(x.shape[0], 512, 28, 28) return x # n, 512, 28, 28 - + def forward_nature_img_loss(self, inputs, reconstructions): loss = ((torch.tanh(inputs) - torch.tanh(reconstructions))**2).mean() if torch.isnan(reconstructions).sum(): print('nan in reconstructions') if torch.isnan(inputs).sum(): print('nan in inputs') - - return loss + + return loss def forward_loss(self, imgs, pred, mask): """ imgs: [N, 1, num_voxels] pred: [N, L, p] - mask: [N, L], 0 is keep, 1 is remove, + mask: [N, L], 0 is keep, 1 is remove, """ target = self.patchify(imgs) loss = (pred - target) ** 2 loss = loss.mean(dim=-1) # [N, L], mean loss per patch - + loss = (loss * mask).sum() / mask.sum() if mask.sum() != 0 else (loss * mask).sum() # mean loss on removed patches return loss @@ -304,7 +305,7 @@ def forward(self, imgs, img_features=None, valid_idx=None, mask_ratio=0.75): if torch.isnan(loss_nature_image_recon).sum(): print(loss_nature_image_recon) print("loss_nature_image_recon is nan") - + loss = loss + self.img_recon_weight*loss_nature_image_recon return loss, pred, mask @@ -323,7 +324,7 @@ def __init__(self, num_voxels=224, patch_size=16, embed_dim=1024, in_chans=1, Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for i in range(depth)]) self.norm = norm_layer(embed_dim) - + self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.embed_dim = embed_dim @@ -374,22 +375,22 @@ def forward_encoder(self, x): x = x.mean(dim=1, keepdim=True) x = self.norm(x) - return x + return x def forward(self, imgs): if imgs.ndim == 2: imgs = torch.unsqueeze(imgs, dim=0) # N, n_seq, embed_dim latent = self.forward_encoder(imgs) # N, n_seq, embed_dim return latent # N, n_seq, embed_dim - + def load_checkpoint(self, state_dict): if self.global_pool: state_dict = {k: v for k, v in state_dict.items() if ('mask_token' not in k and 'norm' not in k)} else: - state_dict = {k: v for k, v in state_dict.items() if ('mask_token' not in k)} + state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k} ut.interpolate_pos_embed(self, state_dict) - + m, u = self.load_state_dict(state_dict, strict=False) print('missing keys:', u) print('unexpected keys:', m) - return \ No newline at end of file + return diff --git a/code/sc_mbm/trainer.py b/code/sc_mbm/trainer.py index b6f29ee..dae9401 100644 --- a/code/sc_mbm/trainer.py +++ b/code/sc_mbm/trainer.py @@ -1,9 +1,12 @@ -import math, sys -import torch +import math +import sys +import time + +import numpy as np import sc_mbm.utils as ut +import torch from torch._six import inf -import numpy as np -import time + class NativeScalerWithGradNormCount: state_dict_key = "amp_scaler" @@ -49,8 +52,8 @@ def get_grad_norm_(parameters, norm_type: float = 2.0): return total_norm -def train_one_epoch(model, data_loader, optimizer, device, epoch, - loss_scaler,log_writer=None, config=None, start_time=None, model_without_ddp=None, +def train_one_epoch(model, data_loader, optimizer, device, epoch, + loss_scaler,log_writer=None, config=None, start_time=None, model_without_ddp=None, img_feature_extractor=None, preprocess=None): model.train(True) optimizer.zero_grad() @@ -58,12 +61,12 @@ def train_one_epoch(model, data_loader, optimizer, device, epoch, total_cor = [] accum_iter = config.accum_iter for data_iter_step, (data_dcit) in enumerate(data_loader): - + # we use a per iteration (instead of per epoch) lr scheduler if data_iter_step % accum_iter == 0: ut.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, config) samples = data_dcit['fmri'] - + img_features = None valid_idx = None if img_feature_extractor is not None: @@ -109,7 +112,7 @@ def train_one_epoch(model, data_loader, optimizer, device, epoch, log_writer.log('cor', np.mean(total_cor), step=epoch) if start_time is not None: log_writer.log('time (min)', (time.time() - start_time)/60.0, step=epoch) - if config.local_rank == 0: + if config.local_rank == 0: print(f'[Epoch {epoch}] loss: {np.mean(total_loss)}') - return np.mean(total_cor) \ No newline at end of file + return np.mean(total_cor) diff --git a/code/sc_mbm/utils.py b/code/sc_mbm/utils.py index b3d6a83..1a8468b 100644 --- a/code/sc_mbm/utils.py +++ b/code/sc_mbm/utils.py @@ -1,8 +1,10 @@ -import numpy as np import math -import torch import os +import numpy as np +import torch + + def get_1d_sincos_pos_embed(embed_dim, length, cls_token=False): """ grid_size: int of the grid height and width @@ -71,7 +73,7 @@ def interpolate_pos_embed(model, checkpoint_model): def adjust_learning_rate(optimizer, epoch, config): """Decay the learning rate with half-cycle cosine after warmup""" if epoch < config.warmup_epochs: - lr = config.lr * epoch / config.warmup_epochs + lr = config.lr * epoch / config.warmup_epochs else: lr = config.min_lr + (config.lr - config.min_lr) * 0.5 * \ (1. + math.cos(math.pi * (epoch - config.warmup_epochs) / (config.num_epoch - config.warmup_epochs))) @@ -93,7 +95,7 @@ def save_model(config, epoch, model, optimizer, loss_scaler, checkpoint_paths): 'config': config, } torch.save(to_save, os.path.join(checkpoint_paths, 'checkpoint.pth')) - + def load_model(config, model, checkpoint_path ): checkpoint = torch.load(checkpoint_path, map_location='cpu') @@ -119,6 +121,6 @@ def unpatchify(x, patch_size): """ p = patch_size h = x.shape[1] - + imgs = x.reshape(shape=(x.shape[0], 1, h * p)) - return imgs \ No newline at end of file + return imgs diff --git a/code/setup.py b/code/setup.py index ebf5412..6cc1b31 100644 --- a/code/setup.py +++ b/code/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( name='mind-vis', @@ -11,4 +11,4 @@ 'tqdm', 'timm' ], -) \ No newline at end of file +) diff --git a/code/stageA1_mbm_pretrain.py b/code/stageA1_mbm_pretrain.py index 143bf35..6209f1a 100644 --- a/code/stageA1_mbm_pretrain.py +++ b/code/stageA1_mbm_pretrain.py @@ -1,22 +1,23 @@ -import os, sys -import numpy as np -import torch -from torch.utils.data import DataLoader -from torch.nn.parallel import DistributedDataParallel import argparse -import time -import timm.optim.optim_factory as optim_factory +import copy import datetime +import os +import sys +import time + import matplotlib.pyplot as plt +import numpy as np +from timm.optim import optim_factory +import torch import wandb -import copy - from config import Config_MBM_fMRI from dataset import hcp_dataset from sc_mbm.mae_for_fmri import MAEforFMRI -from sc_mbm.trainer import train_one_epoch from sc_mbm.trainer import NativeScalerWithGradNormCount as NativeScaler +from sc_mbm.trainer import train_one_epoch from sc_mbm.utils import save_model +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader os.environ["WANDB_START_METHOD"] = "thread" os.environ['WANDB_DIR'] = "." @@ -32,14 +33,14 @@ def __init__(self, config): self.config = config self.step = None - + def log(self, name, data, step=None): if step is None: wandb.log({name: data}) else: wandb.log({name: data}, step=step) self.step = step - + def watch_model(self, *args, **kwargs): wandb.watch(*args, **kwargs) @@ -54,7 +55,7 @@ def finish(self): def get_args_parser(): parser = argparse.ArgumentParser('MBM pre-training for fMRI', add_help=False) - + # Training Parameters parser.add_argument('--lr', type=float) parser.add_argument('--weight_decay', type=float) @@ -83,10 +84,10 @@ def get_args_parser(): parser.add_argument('--use_nature_img_loss', type=bool) parser.add_argument('--img_recon_weight', type=float) - + # distributed training parameters parser.add_argument('--local_rank', type=int) - + return parser def create_readme(config, path): @@ -103,39 +104,39 @@ def fmri_transform(x, sparse_rate=0.2): def main(config): if torch.cuda.device_count() > 1: - torch.cuda.set_device(config.local_rank) + torch.cuda.set_device(config.local_rank) torch.distributed.init_process_group(backend='nccl') output_path = os.path.join(config.root_path, 'results', 'fmri_pretrain', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) # output_path = os.path.join(config.root_path, 'results', 'fmri_pretrain') config.output_path = output_path logger = wandb_logger(config) if config.local_rank == 0 else None - + if config.local_rank == 0: os.makedirs(output_path, exist_ok=True) create_readme(config, output_path) - + device = torch.device(f'cuda:{config.local_rank}') if torch.cuda.is_available() else torch.device('cpu') torch.manual_seed(config.seed) np.random.seed(config.seed) # create dataset and dataloader dataset_pretrain = hcp_dataset(path=os.path.join(config.root_path, 'data/HCP/npz'), roi=config.roi, patch_size=config.patch_size, - transform=fmri_transform, aug_times=config.aug_times, num_sub_limit=config.num_sub_limit, + transform=fmri_transform, aug_times=config.aug_times, num_sub_limit=config.num_sub_limit, include_kam=config.include_kam, include_hcp=config.include_hcp) - + print(f'Dataset size: {len(dataset_pretrain)}\nNumber of voxels: {dataset_pretrain.num_voxels}') - sampler = torch.utils.data.DistributedSampler(dataset_pretrain, rank=config.local_rank) if torch.cuda.device_count() > 1 else None + sampler = torch.utils.data.DistributedSampler(dataset_pretrain, rank=config.local_rank) if torch.cuda.device_count() > 1 else None - dataloader_hcp = DataLoader(dataset_pretrain, batch_size=config.batch_size, sampler=sampler, + dataloader_hcp = DataLoader(dataset_pretrain, batch_size=config.batch_size, sampler=sampler, shuffle=(sampler is None), pin_memory=True) # create model config.num_voxels = dataset_pretrain.num_voxels model = MAEforFMRI(num_voxels=dataset_pretrain.num_voxels, patch_size=config.patch_size, embed_dim=config.embed_dim, - decoder_embed_dim=config.decoder_embed_dim, depth=config.depth, + decoder_embed_dim=config.decoder_embed_dim, depth=config.depth, num_heads=config.num_heads, decoder_num_heads=config.decoder_num_heads, mlp_ratio=config.mlp_ratio, - focus_range=config.focus_range, focus_rate=config.focus_rate, - img_recon_weight=config.img_recon_weight, use_nature_img_loss=config.use_nature_img_loss) + focus_range=config.focus_range, focus_rate=config.focus_rate, + img_recon_weight=config.img_recon_weight, use_nature_img_loss=config.use_nature_img_loss) model.to(device) model_without_ddp = model if torch.cuda.device_count() > 1: @@ -156,17 +157,18 @@ def main(config): img_feature_extractor = None preprocess = None if config.use_nature_img_loss: - from torchvision.models import resnet50, ResNet50_Weights - from torchvision.models.feature_extraction import create_feature_extractor + from torchvision.models import ResNet50_Weights, resnet50 + from torchvision.models.feature_extraction import \ + create_feature_extractor weights = ResNet50_Weights.DEFAULT preprocess = weights.transforms() - m = resnet50(weights=weights) - img_feature_extractor = create_feature_extractor(m, return_nodes={f'layer2': 'layer2'}).to(device).eval() + m = resnet50(weights=weights) + img_feature_extractor = create_feature_extractor(m, return_nodes={'layer2': 'layer2'}).to(device).eval() for param in img_feature_extractor.parameters(): param.requires_grad = False for ep in range(config.num_epoch): - if torch.cuda.device_count() > 1: + if torch.cuda.device_count() > 1: sampler.set_epoch(ep) # to shuffle the data at every epoch cor = train_one_epoch(model, dataloader_hcp, optimizer, device, ep, loss_scaler, logger, config, start_time, model_without_ddp, img_feature_extractor, preprocess) @@ -176,7 +178,7 @@ def main(config): save_model(config, ep, model_without_ddp, optimizer, loss_scaler, os.path.join(output_path,'checkpoints')) # plot figures plot_recon_figures(model, device, dataset_pretrain, output_path, 5, config, logger, model_without_ddp) - + total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) diff --git a/code/stageA2_mbm_finetune.py b/code/stageA2_mbm_finetune.py index c366523..0f88c54 100644 --- a/code/stageA2_mbm_finetune.py +++ b/code/stageA2_mbm_finetune.py @@ -1,24 +1,24 @@ -import os, sys -import numpy as np -import torch -from torch.utils.data import DataLoader -from torch.nn.parallel import DistributedDataParallel import argparse -import time -import timm.optim.optim_factory as optim_factory +import copy import datetime +import os +import sys +import time + import matplotlib.pyplot as plt +import numpy as np +from timm.optim import optim_factory +import torch import wandb -import copy - # own code from config import Config_MBM_finetune -from dataset import create_Kamitani_dataset, create_BOLD5000_dataset +from dataset import create_BOLD5000_dataset, create_Kamitani_dataset from sc_mbm.mae_for_fmri import MAEforFMRI -from sc_mbm.trainer import train_one_epoch from sc_mbm.trainer import NativeScalerWithGradNormCount as NativeScaler +from sc_mbm.trainer import train_one_epoch from sc_mbm.utils import save_model - +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader os.environ["WANDB_START_METHOD"] = "thread" os.environ['WANDB_DIR'] = "." @@ -33,14 +33,14 @@ def __init__(self, config): self.config = config self.step = None - + def log(self, name, data, step=None): if step is None: wandb.log({name: data}) else: wandb.log({name: data}, step=step) self.step = step - + def watch_model(self, *args, **kwargs): wandb.watch(*args, **kwargs) @@ -67,11 +67,11 @@ def get_args_parser(): parser.add_argument('--root_path', type=str) parser.add_argument('--pretrain_mbm_path', type=str) parser.add_argument('--dataset', type=str) - parser.add_argument('--include_nonavg_test', type=bool) - + parser.add_argument('--include_nonavg_test', type=bool) + # distributed training parameters parser.add_argument('--local_rank', type=int) - + return parser def create_readme(config, path): @@ -88,20 +88,20 @@ def fmri_transform(x, sparse_rate=0.2): def main(config): if torch.cuda.device_count() > 1: - torch.cuda.set_device(config.local_rank) + torch.cuda.set_device(config.local_rank) torch.distributed.init_process_group(backend='nccl') sd = torch.load(config.pretrain_mbm_path, map_location='cpu') config_pretrain = sd['config'] - + output_path = os.path.join(config.root_path, 'results', 'fmri_finetune', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) # output_path = os.path.join(config.root_path, 'results', 'fmri_finetune') config.output_path = output_path logger = wandb_logger(config) if config.local_rank == 0 else None - + if config.local_rank == 0: os.makedirs(output_path, exist_ok=True) create_readme(config, output_path) - + device = torch.device(f'cuda:{config.local_rank}') if torch.cuda.is_available() else torch.device('cpu') torch.manual_seed(config_pretrain.seed) np.random.seed(config_pretrain.seed) @@ -109,9 +109,9 @@ def main(config): # create model num_voxels = (sd['model']['pos_embed'].shape[1] - 1)* config_pretrain.patch_size model = MAEforFMRI(num_voxels=num_voxels, patch_size=config_pretrain.patch_size, embed_dim=config_pretrain.embed_dim, - decoder_embed_dim=config_pretrain.decoder_embed_dim, depth=config_pretrain.depth, - num_heads=config_pretrain.num_heads, decoder_num_heads=config_pretrain.decoder_num_heads, - mlp_ratio=config_pretrain.mlp_ratio, focus_range=None, use_nature_img_loss=False) + decoder_embed_dim=config_pretrain.decoder_embed_dim, depth=config_pretrain.depth, + num_heads=config_pretrain.num_heads, decoder_num_heads=config_pretrain.decoder_num_heads, + mlp_ratio=config_pretrain.mlp_ratio, focus_range=None, use_nature_img_loss=False) model.load_state_dict(sd['model'], strict=False) model.to(device) @@ -119,10 +119,10 @@ def main(config): # create dataset and dataloader if config.dataset == 'GOD': - _, test_set = create_Kamitani_dataset(path=config.kam_path, patch_size=config_pretrain.patch_size, + _, test_set = create_Kamitani_dataset(path=config.kam_path, patch_size=config_pretrain.patch_size, subjects=config.kam_subs, fmri_transform=torch.FloatTensor, include_nonavg_test=config.include_nonavg_test) elif config.dataset == 'BOLD5000': - _, test_set = create_BOLD5000_dataset(path=config.bold5000_path, patch_size=config_pretrain.patch_size, + _, test_set = create_BOLD5000_dataset(path=config.bold5000_path, patch_size=config_pretrain.patch_size, fmri_transform=torch.FloatTensor, subjects=config.bold5000_subs, include_nonavg_test=config.include_nonavg_test) else: raise NotImplementedError @@ -133,7 +133,7 @@ def main(config): else: test_set.fmri = test_set.fmri[:, :num_voxels] print(f'Dataset size: {len(test_set)}') - sampler = torch.utils.data.DistributedSampler(test_set) if torch.cuda.device_count() > 1 else torch.utils.data.RandomSampler(test_set) + sampler = torch.utils.data.DistributedSampler(test_set) if torch.cuda.device_count() > 1 else torch.utils.data.RandomSampler(test_set) dataloader_hcp = DataLoader(test_set, batch_size=config.batch_size, sampler=sampler) if torch.cuda.device_count() > 1: @@ -152,7 +152,7 @@ def main(config): start_time = time.time() print('Finetuning MAE on test fMRI ... ...') for ep in range(config.num_epoch): - if torch.cuda.device_count() > 1: + if torch.cuda.device_count() > 1: sampler.set_epoch(ep) # to shuffle the data at every epoch cor = train_one_epoch(model, dataloader_hcp, optimizer, device, ep, loss_scaler, logger, config, start_time, model_without_ddp) cor_list.append(cor) @@ -161,7 +161,7 @@ def main(config): save_model(config_pretrain, ep, model_without_ddp, optimizer, loss_scaler, os.path.join(output_path,'checkpoints')) # plot figures plot_recon_figures(model, device, test_set, output_path, 5, config, logger, model_without_ddp) - + total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) @@ -224,4 +224,3 @@ def update_config(args, config): config = Config_MBM_finetune() config = update_config(args, config) main(config) - diff --git a/code/stageB_ldm_finetune.py b/code/stageB_ldm_finetune.py index 68f980a..5c8af4f 100644 --- a/code/stageB_ldm_finetune.py +++ b/code/stageB_ldm_finetune.py @@ -1,21 +1,22 @@ -import os, sys -import numpy as np -import torch import argparse -import datetime -import wandb -import torchvision.transforms as transforms -from einops import rearrange -from PIL import Image -import pytorch_lightning as pl -from pytorch_lightning.loggers import WandbLogger import copy +import datetime +import os +import sys +import numpy as np +import pytorch_lightning as pl +import torch +from torchvision import transforms +import wandb # own code from config import Config_Generative_Model -from dataset import create_Kamitani_dataset, create_BOLD5000_dataset +from dataset import create_BOLD5000_dataset, create_Kamitani_dataset from dc_ldm.ldm_for_fmri import fLDM +from einops import rearrange from eval_metrics import get_similarity_metric +from PIL import Image +from pytorch_lightning.loggers import WandbLogger def wandb_init(config, output_path): @@ -36,14 +37,14 @@ def to_image(img): return Image.fromarray(img.astype(np.uint8)) def channel_last(img): - if img.shape[-1] == 3: - return img - return rearrange(img, 'c h w -> h w c') + if img.shape[-1] == 3: + return img + return rearrange(img, 'c h w -> h w c') def get_eval_metric(samples, avg=True): metric_list = ['mse', 'pcc', 'ssim', 'psm'] res_list = [] - + gt_images = [img[0] for img in samples] gt_images = rearrange(np.stack(gt_images), 'n c h w -> n h w c') samples_to_run = np.arange(1, len(samples[0])) if avg else [1] @@ -54,12 +55,12 @@ def get_eval_metric(samples, avg=True): pred_images = rearrange(np.stack(pred_images), 'n c h w -> n h w c') res = get_similarity_metric(pred_images, gt_images, method='pair-wise', metric_name=m) res_part.append(np.mean(res)) - res_list.append(np.mean(res_part)) + res_list.append(np.mean(res_part)) res_part = [] for s in samples_to_run: pred_images = [img[s] for img in samples] pred_images = rearrange(np.stack(pred_images), 'n c h w -> n h w c') - res = get_similarity_metric(pred_images, gt_images, 'class', None, + res = get_similarity_metric(pred_images, gt_images, 'class', None, n_way=50, num_trials=50, top_k=1, device='cuda') res_part.append(np.mean(res)) res_list.append(np.mean(res_part)) @@ -67,25 +68,25 @@ def get_eval_metric(samples, avg=True): metric_list.append('top-1-class') metric_list.append('top-1-class (max)') return res_list, metric_list - + def generate_images(generative_model, fmri_latents_dataset_train, fmri_latents_dataset_test, config): - grid, _ = generative_model.generate(fmri_latents_dataset_train, config.num_samples, + grid, _ = generative_model.generate(fmri_latents_dataset_train, config.num_samples, config.ddim_steps, config.HW, 10) # generate 10 instances grid_imgs = Image.fromarray(grid.astype(np.uint8)) grid_imgs.save(os.path.join(config.output_path, 'samples_train.png')) wandb.log({'summary/samples_train': wandb.Image(grid_imgs)}) - grid, samples = generative_model.generate(fmri_latents_dataset_test, config.num_samples, + grid, samples = generative_model.generate(fmri_latents_dataset_test, config.num_samples, config.ddim_steps, config.HW) grid_imgs = Image.fromarray(grid.astype(np.uint8)) - grid_imgs.save(os.path.join(config.output_path,f'./samples_test.png')) + grid_imgs.save(os.path.join(config.output_path,'./samples_test.png')) for sp_idx, imgs in enumerate(samples): for copy_idx, img in enumerate(imgs[1:]): img = rearrange(img, 'c h w -> h w c') - Image.fromarray(img).save(os.path.join(config.output_path, + Image.fromarray(img).save(os.path.join(config.output_path, f'./test{sp_idx}-{copy_idx}.png')) - wandb.log({f'summary/samples_test': wandb.Image(grid_imgs)}) + wandb.log({'summary/samples_test': wandb.Image(grid_imgs)}) metric, metric_list = get_eval_metric(samples, avg=config.eval_avg) metric_dict = {f'summary/pair-wise_{k}':v for k, v in zip(metric_list[:-2], metric[:-2])} @@ -126,33 +127,33 @@ def main(config): img_transform_train = transforms.Compose([ normalize, random_crop(config.img_size-crop_pix, p=0.5), - transforms.Resize((256, 256)), + transforms.Resize((256, 256)), channel_last ]) img_transform_test = transforms.Compose([ - normalize, transforms.Resize((256, 256)), + normalize, transforms.Resize((256, 256)), channel_last ]) if config.dataset == 'GOD': - fmri_latents_dataset_train, fmri_latents_dataset_test = create_Kamitani_dataset(config.kam_path, config.roi, config.patch_size, - fmri_transform=fmri_transform, image_transform=[img_transform_train, img_transform_test], + fmri_latents_dataset_train, fmri_latents_dataset_test = create_Kamitani_dataset(config.kam_path, config.roi, config.patch_size, + fmri_transform=fmri_transform, image_transform=[img_transform_train, img_transform_test], subjects=config.kam_subs) num_voxels = fmri_latents_dataset_train.num_voxels elif config.dataset == 'BOLD5000': - fmri_latents_dataset_train, fmri_latents_dataset_test = create_BOLD5000_dataset(config.bold5000_path, config.patch_size, - fmri_transform=fmri_transform, image_transform=[img_transform_train, img_transform_test], + fmri_latents_dataset_train, fmri_latents_dataset_test = create_BOLD5000_dataset(config.bold5000_path, config.patch_size, + fmri_transform=fmri_transform, image_transform=[img_transform_train, img_transform_test], subjects=config.bold5000_subs) num_voxels = fmri_latents_dataset_train.num_voxels else: raise NotImplementedError - # prepare pretrained mbm + # prepare pretrained mbm pretrain_mbm_metafile = torch.load(config.pretrain_mbm_path, map_location='cpu') # create generateive model generative_model = fLDM(pretrain_mbm_metafile, num_voxels, - device=device, pretrain_root=config.pretrain_gm_path, logger=config.logger, + device=device, pretrain_root=config.pretrain_gm_path, logger=config.logger, ddim_steps=config.ddim_steps, global_pool=config.global_pool, use_time_cond=config.use_time_cond) - + # resume training if applicable if config.checkpoint_path is not None: model_meta = torch.load(config.checkpoint_path, map_location='cpu') @@ -215,17 +216,17 @@ def create_readme(config, path): def create_trainer(num_epoch, precision=32, accumulate_grad_batches=2,logger=None,check_val_every_n_epoch=0): acc = 'gpu' if torch.cuda.is_available() else 'cpu' - return pl.Trainer(accelerator=acc, max_epochs=num_epoch, logger=logger, + return pl.Trainer(accelerator=acc, max_epochs=num_epoch, logger=logger, precision=precision, accumulate_grad_batches=accumulate_grad_batches, enable_checkpointing=False, enable_model_summary=False, gradient_clip_val=0.5, check_val_every_n_epoch=check_val_every_n_epoch) - + if __name__ == '__main__': args = get_args_parser() args = args.parse_args() config = Config_Generative_Model() config = update_config(args, config) - + if config.checkpoint_path is not None: model_meta = torch.load(config.checkpoint_path, map_location='cpu') ckp = config.checkpoint_path @@ -236,7 +237,7 @@ def create_trainer(num_epoch, precision=32, accumulate_grad_batches=2,logger=Non output_path = os.path.join(config.root_path, 'results', 'generation', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) config.output_path = output_path os.makedirs(output_path, exist_ok=True) - + wandb_init(config, output_path) logger = WandbLogger()