Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updates for python 3 #2

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions code/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os

import numpy as np


class Config_MAE_fMRI:
    """Empty placeholder config kept for backward compatibility."""
    pass
class Config_MBM_finetune: # back compatibility
    """Empty placeholder config kept for backward compatibility.

    NOTE: a later class in this file reuses (rebinds) this name.
    """
    # The scraped diff showed a duplicated `pass` (old + new line); a
    # single `pass` is sufficient.
    pass

class Config_MBM_fMRI(Config_MAE_fMRI):
# configs for fmri_pretrain.py
Expand All @@ -19,7 +21,7 @@ def __init__(self):
self.warmup_epochs = 40
self.batch_size = 100
self.clip_grad = 0.8

# Model Parameters
self.mask_ratio = 0.75
self.patch_size = 16
Expand Down Expand Up @@ -51,14 +53,14 @@ def __init__(self):

class Config_MBM_finetune(Config_MBM_finetune):
def __init__(self):

# Project setting
self.root_path = '.'
self.output_path = self.root_path
self.kam_path = os.path.join(self.root_path, 'data/Kamitani/npz')
self.bold5000_path = os.path.join(self.root_path, 'data/BOLD5000')
self.dataset = 'GOD' # GOD or BOLD5000
self.pretrain_mbm_path = os.path.join(self.root_path, f'pretrains/{self.dataset}/fmri_encoder.pth')
self.pretrain_mbm_path = os.path.join(self.root_path, f'pretrains/{self.dataset}/fmri_encoder.pth')

self.include_nonavg_test = True
self.kam_subs = ['sbj_3']
Expand All @@ -68,16 +70,16 @@ def __init__(self):
self.lr = 5.3e-5
self.weight_decay = 0.05
self.num_epoch = 15
self.batch_size = 16 if self.dataset == 'GOD' else 4
self.mask_ratio = 0.75
self.batch_size = 16 if self.dataset == 'GOD' else 4
self.mask_ratio = 0.75
self.accum_iter = 1
self.clip_grad = 0.8
self.warmup_epochs = 2
self.min_lr = 0.

# distributed training
self.local_rank = 0

class Config_Generative_Model:
def __init__(self):
# project parameters
Expand All @@ -92,11 +94,11 @@ def __init__(self):
self.pretrain_gm_path = os.path.join(self.root_path, 'pretrains/ldm/label2img')
# self.pretrain_gm_path = os.path.join(self.root_path, 'pretrains/ldm/text2img-large')
# self.pretrain_gm_path = os.path.join(self.root_path, 'pretrains/ldm/layout2img')

self.dataset = 'GOD' # GOD or BOLD5000
self.kam_subs = ['sbj_3']
self.bold5000_subs = ['CSI4']
self.pretrain_mbm_path = os.path.join(self.root_path, f'pretrains/{self.dataset}/fmri_encoder.pth')
self.pretrain_mbm_path = os.path.join(self.root_path, f'pretrains/{self.dataset}/fmri_encoder.pth')

self.img_size = 256

Expand All @@ -105,7 +107,7 @@ def __init__(self):
self.batch_size = 5 if self.dataset == 'GOD' else 25
self.lr = 5.3e-5
self.num_epoch = 500

self.precision = 32
self.accumulate_grad = 1
self.crop_ratio = 0.2
Expand Down
100 changes: 51 additions & 49 deletions code/dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@

from torch.utils.data import Dataset
import numpy as np
import os
from scipy import interpolate
from einops import rearrange
import json
import csv
import torch
import json
import os
from pathlib import Path
import torchvision.transforms as transforms

import numpy as np
import torch
from torchvision import transforms
from einops import rearrange
from scipy import interpolate
from torch.utils.data import Dataset


def identity(x):
    """No-op transform: return *x* unchanged (used as a default transform)."""
    return x
Expand Down Expand Up @@ -52,7 +54,7 @@ def augmentation(data, aug_times=2, interpolation_ratio=0.5):
data: num_samples, num_voxels_padded
return: data_aug: num_samples*aug_times, num_voxels_padded
'''
num_to_generate = int((aug_times-1)*len(data))
num_to_generate = int((aug_times-1)*len(data))
if num_to_generate == 0:
return data
pairs_idx = np.random.choice(len(data), size=(num_to_generate, 2), replace=True)
Expand Down Expand Up @@ -85,28 +87,28 @@ def img_norm(img):
return img

def channel_first(img):
    """Convert a channels-last (H, W, 3) image to channels-first (3, H, W).

    Arrays whose last axis is not of size 3 are returned unchanged.
    The scraped diff contained both the old and new copy of this body;
    only one copy is kept here.
    """
    if img.shape[-1] == 3:
        return rearrange(img, 'h w c -> c h w')
    return img

class hcp_dataset(Dataset):
def __init__(self, path='../data/HCP/npz', roi='VC', patch_size=16, transform=identity, aug_times=2,
def __init__(self, path='../data/HCP/npz', roi='VC', patch_size=16, transform=identity, aug_times=2,
num_sub_limit=None, include_kam=False, include_hcp=True):
super(hcp_dataset, self).__init__()
super().__init__()
data = []
images = []

if include_hcp:
for c, sub in enumerate(os.listdir(path)):
if os.path.isfile(os.path.join(path,sub,'HCP_visual_voxel.npz')) == False:
continue
continue
if num_sub_limit is not None and c > num_sub_limit:
break
npz = dict(np.load(os.path.join(path,sub,'HCP_visual_voxel.npz')))
voxels = np.concatenate([npz['V1'],npz['V2'],npz['V3'],npz['V4']], axis=-1) if roi == 'VC' else npz[roi] # 1200, num_voxels
voxels = process_voxel_ts(voxels, patch_size) # num_samples, num_voxels_padded
data.append(voxels)

data = augmentation(np.concatenate(data, axis=0), aug_times) # num_samples, num_voxels_padded
data = np.expand_dims(data, axis=1) # num_samples, 1, num_voxels_padded
images += [None] * len(data)
Expand All @@ -124,7 +126,7 @@ def __init__(self, path='../data/HCP/npz', roi='VC', patch_size=16, transform=id
images += k.images

assert len(data) != 0, 'No data found'

self.roi = roi
self.patch_size = patch_size
self.num_voxels = data.shape[-1]
Expand All @@ -133,24 +135,24 @@ def __init__(self, path='../data/HCP/npz', roi='VC', patch_size=16, transform=id
self.images = images
self.images_transform = transforms.Compose([
img_norm,
transforms.Resize((112, 112)),
transforms.Resize((112, 112)),
channel_first
])

def __len__(self):
return len(self.data)

def __getitem__(self, index):
img = self.images[index]
images_transform = self.images_transform if img is not None else identity
img = img if img is not None else torch.zeros(3, 112, 112)

return {'fmri': self.transform(self.data[index]),
'image': images_transform(img)}

class Kamitani_pretrain_dataset(Dataset):
def __init__(self, path='../data/Kamitani/npz', roi='VC', patch_size=16, transform=identity, aug_times=2):
super(Kamitani_pretrain_dataset, self).__init__()
super().__init__()
k1, k2 = create_Kamitani_dataset(path, roi, patch_size, transform, include_nonavg_test=True)
# data = np.concatenate([k1.fmri, k2.fmri], axis=0)
# self.images = [img for img in k1.image] + [None] * len(k2.fmri)
Expand All @@ -164,10 +166,10 @@ def __init__(self, path='../data/Kamitani/npz', roi='VC', patch_size=16, transfo
self.patch_size = patch_size
self.num_voxels = data.shape[-1]
self.transform = transform

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.transform(self.data[index])

Expand All @@ -194,7 +196,7 @@ def get_img_label(class_index:dict, img_filename:list, naive_label_set=None):
return img_label, naive_label

def create_Kamitani_dataset(path='../data/Kamitani/npz', roi='VC', patch_size=16, fmri_transform=identity,
image_transform=identity, subjects = ['sbj_1', 'sbj_2', 'sbj_3', 'sbj_4', 'sbj_5'],
image_transform=identity, subjects = ['sbj_1', 'sbj_2', 'sbj_3', 'sbj_4', 'sbj_5'],
test_category=None, include_nonavg_test=False):
img_npz = dict(np.load(os.path.join(path, 'images_256.npz')))
with open(os.path.join(path, 'imagenet_class_index.json'), 'r') as f:
Expand Down Expand Up @@ -223,10 +225,10 @@ def create_Kamitani_dataset(path='../data/Kamitani/npz', roi='VC', patch_size=1
train_img.append(img_npz['train_images'][npz['arr_3']])
train_lb = [train_img_label[i] for i in npz['arr_3']]
test_lb = test_img_label

roi_mask = npz[roi]
tr = npz['arr_0'][..., roi_mask] # train
tt = npz['arr_2'][..., roi_mask]
tt = npz['arr_2'][..., roi_mask]
if include_nonavg_test:
tt = np.concatenate([tt, npz['arr_1'][..., roi_mask]], axis=0)

Expand All @@ -237,14 +239,14 @@ def create_Kamitani_dataset(path='../data/Kamitani/npz', roi='VC', patch_size=1
train_fmri.append(tr)
test_fmri.append(tt)
if test_category is not None:
train_img_, train_fmri_, test_img_, test_fmri_, train_lb, test_lb = reorganize_train_test(train_img[-1], train_fmri[-1],
train_img_, train_fmri_, test_img_, test_fmri_, train_lb, test_lb = reorganize_train_test(train_img[-1], train_fmri[-1],
test_img[-1], test_fmri[-1], train_lb, test_lb,
test_category, npz['arr_3'])
train_img[-1] = train_img_
train_fmri[-1] = train_fmri_
test_img[-1] = test_img_
test_fmri[-1] = test_fmri_

train_img_label_all += train_lb
test_img_label_all += test_lb

Expand All @@ -267,13 +269,13 @@ def create_Kamitani_dataset(path='../data/Kamitani/npz', roi='VC', patch_size=1
# train_img = rearrange(train_img, 'n h w c -> n c h w')

if isinstance(image_transform, list):
return (Kamitani_dataset(train_fmri, train_img, train_img_label_all, fmri_transform, image_transform[0], num_voxels, len(npz['arr_0'])),
return (Kamitani_dataset(train_fmri, train_img, train_img_label_all, fmri_transform, image_transform[0], num_voxels, len(npz['arr_0'])),
Kamitani_dataset(test_fmri, test_img, test_img_label_all, torch.FloatTensor, image_transform[1], num_voxels, len(npz['arr_2'])))
else:
return (Kamitani_dataset(train_fmri, train_img, train_img_label_all, fmri_transform, image_transform, num_voxels, len(npz['arr_0'])),
return (Kamitani_dataset(train_fmri, train_img, train_img_label_all, fmri_transform, image_transform, num_voxels, len(npz['arr_0'])),
Kamitani_dataset(test_fmri, test_img, test_img_label_all, torch.FloatTensor, image_transform, num_voxels, len(npz['arr_2'])))

def reorganize_train_test(train_img, train_fmri, test_img, test_fmri, train_lb, test_lb,
def reorganize_train_test(train_img, train_fmri, test_img, test_fmri, train_lb, test_lb,
test_category, train_index_lookup):
test_img_ = []
test_fmri_ = []
Expand All @@ -287,7 +289,7 @@ def reorganize_train_test(train_img, train_fmri, test_img, test_fmri, train_lb,
test_fmri_.append(train_fmri[train_idx])
test_lb_.append(train_lb[train_idx])
train_idx_list.append(train_idx)

train_img_ = np.stack([img for i, img in enumerate(train_img) if i not in train_idx_list])
train_fmri_ = np.stack([fmri for i, fmri in enumerate(train_fmri) if i not in train_idx_list])
train_lb_ = [lb for i, lb in enumerate(train_lb) if i not in train_idx_list] + test_lb
Expand All @@ -301,7 +303,7 @@ def reorganize_train_test(train_img, train_fmri, test_img, test_fmri, train_lb,

class Kamitani_dataset(Dataset):
def __init__(self, fmri, image, img_label, fmri_transform=identity, image_transform=identity, num_voxels=0, num_per_sub=50):
super(Kamitani_dataset, self).__init__()
super().__init__()
self.fmri = fmri
self.image = image
if len(self.image) != len(self.fmri):
Expand All @@ -317,7 +319,7 @@ def __init__(self, fmri, image, img_label, fmri_transform=identity, image_transf

def __len__(self):
return len(self.fmri)

def __getitem__(self, index):
fmri = self.fmri[index]
if index >= len(self.image):
Expand All @@ -336,7 +338,7 @@ def __getitem__(self, index):

class base_dataset(Dataset):
def __init__(self, x, y=None, transform=identity):
super(base_dataset, self).__init__()
super().__init__()
self.x = x
self.y = y
self.transform = transform
Expand All @@ -347,7 +349,7 @@ def __getitem__(self, index):
return self.transform(self.x[index])
else:
return self.transform(self.x[index]), self.transform(self.y[index])

def remove_repeats(fmri, img_lb):
assert len(fmri) == len(img_lb), 'len error'
fmri_dict = {}
Expand Down Expand Up @@ -386,7 +388,7 @@ def get_stimuli_list(root, sub):

def list_get_all_index(values, value):
    """Return every index in *values* whose element equals *value*.

    The original parameter was named ``list``, shadowing the builtin; it is
    renamed here. The visible call site in this file passes it positionally,
    so the rename is safe there.
    """
    return [i for i, v in enumerate(values) if v == value]

def create_BOLD5000_dataset(path='../data/BOLD5000', patch_size=16, fmri_transform=identity,
image_transform=identity, subjects = ['CSI1', 'CSI2', 'CSI3', 'CSI4'], include_nonavg_test=False):
roi_list = ['EarlyVis', 'LOC', 'OPA', 'PPA', 'RSC']
Expand All @@ -397,7 +399,7 @@ def create_BOLD5000_dataset(path='../data/BOLD5000', patch_size=16, fmri_transfo

fmri_files = [f for f in os.listdir(fmri_path) if f.endswith('.npy')]
fmri_files.sort()

fmri_train_major = []
fmri_test_major = []
img_train_major = []
Expand All @@ -411,17 +413,17 @@ def create_BOLD5000_dataset(path='../data/BOLD5000', patch_size=16, fmri_transfo
fmri_data_sub.append(np.load(os.path.join(fmri_path, npy)))
fmri_data_sub = np.concatenate(fmri_data_sub, axis=-1) # concatenate all rois
fmri_data_sub = normalize(pad_to_patch_size(fmri_data_sub, patch_size))

# load image
img_files = get_stimuli_list(img_path, sub)
img_data_sub = [imgs_dict[name] for name in img_files]

# split train test
test_idx = [list_get_all_index(img_files, img) for img in repeated_imgs_list]
test_idx = [i for i in test_idx if len(i) > 0] # remove empy list for CSI4
test_fmri = np.stack([fmri_data_sub[idx].mean(axis=0) for idx in test_idx])
test_img = np.stack([img_data_sub[idx[0]] for idx in test_idx])

test_idx_flatten = []
for idx in test_idx:
test_idx_flatten += idx # flatten
Expand All @@ -444,10 +446,10 @@ def create_BOLD5000_dataset(path='../data/BOLD5000', patch_size=16, fmri_transfo

num_voxels = fmri_train_major.shape[-1]
if isinstance(image_transform, list):
return (BOLD5000_dataset(fmri_train_major, img_train_major, fmri_transform, image_transform[0], num_voxels),
return (BOLD5000_dataset(fmri_train_major, img_train_major, fmri_transform, image_transform[0], num_voxels),
BOLD5000_dataset(fmri_test_major, img_test_major, torch.FloatTensor, image_transform[1], num_voxels))
else:
return (BOLD5000_dataset(fmri_train_major, img_train_major, fmri_transform, image_transform, num_voxels),
return (BOLD5000_dataset(fmri_train_major, img_train_major, fmri_transform, image_transform, num_voxels),
BOLD5000_dataset(fmri_test_major, img_test_major, torch.FloatTensor, image_transform, num_voxels))

class BOLD5000_dataset(Dataset):
Expand All @@ -457,16 +459,16 @@ def __init__(self, fmri, image, fmri_transform=identity, image_transform=identit
self.fmri_transform = fmri_transform
self.image_transform = image_transform
self.num_voxels = num_voxels

def __len__(self):
return len(self.fmri)

def __getitem__(self, index):
fmri = self.fmri[index]
img = self.image[index] / 255.0
fmri = np.expand_dims(fmri, axis=0)
fmri = np.expand_dims(fmri, axis=0)
return {'fmri': self.fmri_transform(fmri), 'image': self.image_transform(img)}

def switch_sub_view(self, sub, subs):
# Not implemented
pass
pass
Loading