Merge pull request #1 from fmi-basel/dev_release
Updated notebooks, added spiking code and 3dshapes
fzenke authored Aug 16, 2023
2 parents be3dddc + 226465f commit dd486d0
Showing 58 changed files with 11,352 additions and 20,742 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -24,3 +24,6 @@ archives/
 
 # Built Visual Studio Code Extensions
 *.vsix
+
+# DS_Store files
+**/.DS_Store
Empty file modified LICENSE
100644 → 100755
Empty file.
16 changes: 10 additions & 6 deletions README.md
100644 → 100755
@@ -23,25 +23,29 @@ To train a deep net with layer-local LPL, simply run
python lpl_main.py
```

-in the virtual environment you just created. Several useful command-line arguments are provided in `lpl_main.py` and `models\modules.py`. A few are listed below:
+in the virtual environment you just created. Several useful command-line arguments are provided in `lpl_main.py` and `models/modules.py`. A few are listed below to assist you in playing around with the code on your own:
- `--train_with_supervision` trains the same network with supervision.
- `--use_negative_samples` trains the network with a cosine-distance-based contrastive loss.
**Note:** this needs to be combined with setting the decorrelation loss coefficient to 0, and enabling the additional projection MLP.
- `--train_end_to_end` is a flag for training the network with backpropagation to optimize the specified loss (lpl, supervised or neg. samples) at the output layer only.
- `--no_pooling` optimizes the specified loss on the unpooled feature maps.
-- `--use_projector_mlp` adds additional projection dense layers at every layer in the network where the specified loss is to be optimized.
+- `--use_projector_mlp` adds additional dense projection heads at every layer in the network where the loss is optimized.
- `--pull_coeff`, `--push_coeff`, and `--decorr_coeff` are the different loss coefficients. Default values are 1.0, 1.0, and 10.0 respectively.
+- `--dataset` specifies the dataset to be used. The Shapes3D dataset needs a separate preprocessing step to make the required sequence of images (detailed in `notebooks/E5 - 3D shapes.ipynb`).
+- `--gpus` is an integer that specifies the number of GPUs to use. If you have multiple GPUs, you can use this to speed up training or support larger datasets/batches by setting it to a number greater than 1, for example `--gpus 2` to use 2 GPUs.

+By default, the code trains a VGG-11 network on the CIFAR-10 dataset with LPL. The default hyperparameters are the same as those used in the paper. The code will automatically download the dataset if it is not already present in the `~/data/datasets/` directory. Model performance and other metrics are logged under `~/data/lpl/$DATASET/`, where `$DATASET` is the name of the dataset used. You can monitor training progress by running `tensorboard --logdir ~/data/lpl/$DATASET/` in a separate terminal window.
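
As a quick illustration of how the flags above combine (a sketch, not part of the diff; every flag name is documented in this README, and the combination follows the note on `--use_negative_samples`):

```
python lpl_main.py --use_negative_samples --decorr_coeff 0 --use_projector_mlp
```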

### Note on network architectures
-Only a VGG-11 architecture is provided here, but the framework can easily be extended to other architectures. You can simply configure another encoder in `models\encoders.py`, and add it to `models\network.py`. In principle, everything should work with residual architectures as well, but layer-local learning in this case is not well-defined.
+Only a VGG-11 architecture has been extensively tested, but the framework easily extends to other architectures. You can simply configure another encoder in `models/encoders.py`, and add it to `models/network.py`. In principle, everything should work with residual architectures (provided as an option) as well, but layer-local learning in this case is not well-defined because of the non-plastic skip connections.

-## Analysis
+## Analysis and reproduction of figures

-Jupyter notebooks (3a, 3b and 3c) provided in `notebooks` contain instructions on extracting and visualising several metrics for the quality of learned representations. Also provided there are notebooks for generating all figures from the paper.
+Jupyter notebooks under `notebooks` contain instructions on extracting and visualizing several metrics for the quality of learned representations. Also provided in the notebooks is the code for generating figures from the paper.

## Spiking network code

-The spiking network code will be released in the near future, at the latest, upon manuscript acceptance.
+You can find the spiking network code and instructions for reproducing the corresponding results under `spiking_simulations/`.

## Citation

Empty file modified callbacks/__init__.py
100644 → 100755
Empty file.
Empty file modified callbacks/ssl_callbacks.py
100644 → 100755
Empty file.
Empty file added datasets/__init__.py
Empty file.
133 changes: 133 additions & 0 deletions datasets/sequence_generator.py
@@ -0,0 +1,133 @@
import os

import numpy as np
from tqdm import tqdm

_FACTORS_IN_ORDER = ['floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation']
_NUM_VALUES_PER_FACTOR = {'floor_hue': 10, 'wall_hue': 10, 'object_hue': 10, 'scale': 8, 'shape': 4, 'orientation': 15}
_PERMUTATIONS_PER_FACTOR = {}

# generate num_vectors random vectors of dimension num_dims and length sqrt(num_dims)
def sample_random_vectors(num_vectors, num_dims):
""" Samples many vectors of dimension num_dims and length sqrt(num_dims)
Args:
num_vectors: number of vectors to sample.
num_dims: dimension of vectors to sample.
Returns:
batch: vectors shape [batch_size,num_dims]
"""
vectors = np.random.normal(size=[num_vectors, num_dims])
lengths = np.sqrt(np.sum(vectors**2, axis=1))
vectors = np.sqrt(num_dims) * vectors / lengths[:, np.newaxis]
    # round each coordinate to the nearest integer
    vectors = np.round(vectors)
    # if rounding produced an all-zero vector, set one random coordinate of it to 1
    zero_rows = np.sum(vectors**2, axis=1) == 0
    vectors[zero_rows, np.random.randint(num_dims, size=int(zero_rows.sum()))] = 1
return vectors

def sample_one_hot_vectors(num_vectors, num_dims):
""" Samples many one-hot vectors of dimension num_dims
Args:
num_vectors: number of vectors to sample.
num_dims: dimension of vectors to sample.
Returns:
batch: vectors shape [batch_size,num_dims]
"""
non_zero_indices = np.random.randint(num_dims, size=num_vectors)
vectors = np.eye(num_dims)[non_zero_indices]
    # apply a random sign to each vector
    signs = np.random.choice([-1, 1], size=num_vectors)
    # scale roughly one quarter of the vectors (indices sampled with replacement) by 2
    signs[np.random.choice(num_vectors, size=num_vectors//4)] *= 2
vectors = vectors * signs[:, np.newaxis]
return vectors
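
# illustrative shapes, assuming the samplers above:
#   sample_random_vectors(4, 5)  -> shape (4, 5), integer entries, row length ~sqrt(5)
#   sample_one_hot_vectors(4, 5) -> shape (4, 5), one nonzero entry per row in {-2, -1, 1, 2}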


def get_index(factors):
""" Converts factors to indices in range(num_data)
Args:
factors: np array shape [6,batch_size].
factors[i]=factors[i,:] takes integer values in
range(_NUM_VALUES_PER_FACTOR[_FACTORS_IN_ORDER[i]]).
Returns:
indices: np array shape [batch_size].
"""
indices = 0
base = 1
for factor, name in reversed(list(enumerate(_FACTORS_IN_ORDER))):
indices += factors[factor] * base
base *= _NUM_VALUES_PER_FACTOR[name]
return indices
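
# get_index is a mixed-radix conversion with bases (10, 10, 10, 8, 4, 15), where
# orientation is the fastest-moving factor: e.g. a factor vector that is all zeros
# except orientation=1 maps to index 1, and one with only shape=1 maps to index 15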


def generate_trajectories(num_sequences=64000, batch_size=16):
    """ Generates num_sequences random walks of batch_size steps through factor space.
    Returns a flat np array of dataset indices, one sequence after another. """
    current_state = np.zeros([num_sequences, len(_FACTORS_IN_ORDER)], dtype=np.int8)
    trajectories = np.zeros([num_sequences, batch_size+1, len(_FACTORS_IN_ORDER)], dtype=np.int8)

for k, factor in enumerate(_FACTORS_IN_ORDER):
current_state[:, k] = np.random.choice(_NUM_VALUES_PER_FACTOR[factor], num_sequences)
trajectories[:, 0] = current_state

    # sample one movement direction per sequence; the shape factor never changes within a sequence
    directions = sample_one_hot_vectors(num_sequences, len(_FACTORS_IN_ORDER)-1)
    directions = np.insert(directions, _FACTORS_IN_ORDER.index('shape'), 0, axis=1)

    for i in tqdm(range(1, batch_size+1)):
        for factor in _FACTORS_IN_ORDER:
            j = _FACTORS_IN_ORDER.index(factor)
            if factor in ('floor_hue', 'wall_hue', 'object_hue'):
                # hues are cyclic: step and wrap around at the boundaries
                current_state[:, j] = (current_state[:, j] + directions[:, j]) % _NUM_VALUES_PER_FACTOR[factor]
            if factor in ('scale', 'orientation'):
                # scale and orientation are bounded: clip at the edges and reflect the direction there
                current_state[:, j] = np.clip(current_state[:, j] + directions[:, j], 0, _NUM_VALUES_PER_FACTOR[factor]-1)
                directions[:, j] = np.where(current_state[:, j] == 0, -directions[:, j], directions[:, j])
                directions[:, j] = np.where(current_state[:, j] == _NUM_VALUES_PER_FACTOR[factor]-1, -directions[:, j], directions[:, j])
        trajectories[:, i] = current_state

    # cast to int32: the mixed-radix indices computed below exceed the int8 range
    flattened_trajectories = trajectories.reshape(-1, len(_FACTORS_IN_ORDER)).astype(np.int32)
# exchange values of each factor according to _PERMUTATIONS_PER_FACTOR
for i, factor in enumerate(_FACTORS_IN_ORDER):
flattened_trajectories[:, i] = _PERMUTATIONS_PER_FACTOR[factor][flattened_trajectories[:, i]]
indexed_trajectories = get_index(flattened_trajectories.T)

return indexed_trajectories
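
# e.g. generate_trajectories(num_sequences=2, batch_size=3) yields a flat array of
# 2 * (3 + 1) = 8 dataset indices (the concatenated frames of both sequences)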

if __name__ == '__main__':
data_dir = os.path.expanduser("~/data/datasets/shapes3d")
if not os.path.exists(data_dir):
os.makedirs(data_dir)

shuffle_colors = True

permutations_file_path = os.path.join(data_dir, 'permutations.npy')
if shuffle_colors:
trajectory_file_path = os.path.join(data_dir, 'indexed_trajectories.npy')
if not os.path.exists(permutations_file_path):
print('did not find permutations file, creating new shuffling of colors ...')
            for factor in _FACTORS_IN_ORDER:
                if factor in ('floor_hue', 'wall_hue', 'object_hue'):
                    _PERMUTATIONS_PER_FACTOR[factor] = np.random.permutation(_NUM_VALUES_PER_FACTOR[factor])
                else:
                    _PERMUTATIONS_PER_FACTOR[factor] = np.arange(_NUM_VALUES_PER_FACTOR[factor])
np.save(permutations_file_path, _PERMUTATIONS_PER_FACTOR)
else:
print('found permutations file, loading ...')
_PERMUTATIONS_PER_FACTOR = np.load(permutations_file_path, allow_pickle=True).item()
else:
trajectory_file_path = os.path.join(data_dir, 'indexed_trajectories_no_shuffle.npy')
print('not shuffling colors')
_PERMUTATIONS_PER_FACTOR = {factor: np.arange(_NUM_VALUES_PER_FACTOR[factor]) for factor in _FACTORS_IN_ORDER}

if os.path.exists(trajectory_file_path):
print('trajectories file {} already exists, please delete it first if you want to regenerate it'.format(trajectory_file_path))
else:
print("generating training data...")
print("destination directory: {}".format(data_dir))
indexed_trajectories = generate_trajectories(num_sequences=64000, batch_size=16)

print("saving training data...")
np.save(trajectory_file_path, indexed_trajectories)
print('saved trajectories to {}'.format(trajectory_file_path))
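
For orientation, a minimal sketch (assuming the default settings and output path of the script above) of how the saved flat index array groups back into sequences:

```python
import os
import numpy as np

data_dir = os.path.expanduser("~/data/datasets/shapes3d")
indexed = np.load(os.path.join(data_dir, 'indexed_trajectories.npy'))

# with the defaults, each sequence contributes batch_size + 1 = 17 consecutive
# frame indices into 3dshapes.h5, so the flat array reshapes into (num_sequences, 17)
sequences = indexed.reshape(-1, 17)
print(sequences.shape)  # (64000, 17)
```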
138 changes: 138 additions & 0 deletions datasets/shapes3d_datamodule.py
@@ -0,0 +1,138 @@
import os

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule

_FACTORS_IN_ORDER = ['floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape',
'orientation']
_NUM_VALUES_PER_FACTOR = {'floor_hue': 10, 'wall_hue': 10, 'object_hue': 10,
'scale': 8, 'shape': 4, 'orientation': 15}

class Shapes3DSequences(Dataset):

    def __init__(self, data_dir, images, labels, batch_size=1024, split='train', transform=None):
self.data_dir = data_dir
self.transform = transform

self.images = images
self.labels = labels

self.image_shape = self.images.shape[1:]
self.n_samples = self.labels.shape[0]

        # use the color-shuffled trajectories produced by datasets/sequence_generator.py
        shuffled_colors = True
        if shuffled_colors:
            indexed_trajectories = np.load(os.path.join(data_dir, 'indexed_trajectories.npy'))
        else:
            indexed_trajectories = np.load(os.path.join(data_dir, 'indexed_trajectories_no_shuffle.npy'))
self.batch_size = batch_size

# split indexed_trajectories into blocks of batch_size + 1 dropping any remainder
n_blocks = len(indexed_trajectories) // (self.batch_size + 1)
self.trajectories = indexed_trajectories[:n_blocks * (self.batch_size + 1)].reshape([-1, self.batch_size + 1])
self.trajectories = torch.from_numpy(self.trajectories)

        # hold out the last trajectories for validation and testing
        if split == 'train':
            self.trajectories = self.trajectories[:-30]
        if split == 'val':
            self.trajectories = self.trajectories[-30:-10]
        if split == 'test':
            self.trajectories = self.trajectories[-10:]

        self.num_iter = 0
        self.train = split == 'train'

    def __len__(self):
        # a training "epoch" is a fixed number (1000 - 1) of randomly sampled trajectories,
        # not one pass over all of them
        return (1000 - 1) if self.train else (len(self.trajectories) - 1)

def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()

if self.train:
# sample a random trajectory
idx = np.random.randint(len(self.trajectories) - 1)
trajectory = self.trajectories[idx]
ims = torch.zeros([self.batch_size + 1] + list(self.image_shape))
labels = torch.zeros([self.batch_size + 1], dtype=torch.long)
for i, index in enumerate(trajectory):
ims[i] = self.images[index]
labels[i] = self.labels[index]

if self.transform:
ims = self.transform(ims)

        # two streams of consecutive frames (t and t+1), a repeat of the first stream,
        # and the labels for the frames at t
        return (ims[:-1], ims[1:], ims[:-1]), (labels[:-1])

class Shapes3DDataModule(LightningDataModule):

def __init__(self, data_dir, batch_size=1024, factor_to_use_as_label='shape'):
super().__init__()
self.data_path = data_dir
self.batch_size = batch_size
self.dims = (3, 64, 64)
self.output_dims = (6,)

self.size = (3, 64, 64)

dataset = h5py.File(os.path.join(data_dir, '3dshapes.h5'), 'r')
images = np.array(dataset['images']) # (480000, 64, 64, 3)
factors = np.array(dataset['labels']) # (480000, 6)

# convert from hdf5 to torch tensors
self.images = torch.from_numpy(images) / 255.0
self.factors = torch.from_numpy(factors)

# per-channel zero-mean unit-variance normalization of the images
self.images = (self.images - self.images.mean(dim=(0, 1, 2))) / self.images.std(dim=(0, 1, 2))

# change to channel first
self.images = self.images.permute(0, 3, 1, 2) # shape (batch_size, 3, 64, 64)
self._create_labels(factor_to_use_as_label)

def _create_labels(self, factor_to_use_as_label):
self.factor_to_use_as_label = factor_to_use_as_label
self.num_classes = _NUM_VALUES_PER_FACTOR[self.factor_to_use_as_label]

factors = self.factors[:, _FACTORS_IN_ORDER.index(self.factor_to_use_as_label)]
self.labels = torch.zeros([factors.shape[0]], dtype=torch.long)
        if self.factor_to_use_as_label in ('floor_hue', 'wall_hue', 'object_hue'):
            # hue values are multiples of 0.1; rescale them to integer class labels
            self.labels = (factors * 10).long()
        elif self.factor_to_use_as_label == 'scale':
            # data values do not match the description given for this factor, hence the alternative discretization
            self.labels = torch.from_numpy(np.digitize(factors, np.linspace(0.75, 1.25, 8)) - 1)
        elif self.factor_to_use_as_label == 'shape':
            self.labels = factors.long()
        elif self.factor_to_use_as_label == 'orientation':
            self.labels = torch.from_numpy(np.digitize(factors, np.linspace(-30, 30, 15)) - 1)

return

def prepare_data(self):
# download, split, etc...
# only called on 1 GPU
pass

    def setup(self, stage=None):
        # build the train/val/test sequence datasets from the preloaded images and labels
        self.shapes3d_full = Shapes3DSequences(self.data_path, self.images, self.labels, batch_size=self.batch_size, transform=None, split='train')
        self.shapes3d_val = Shapes3DSequences(self.data_path, self.images, self.labels, batch_size=self.batch_size, transform=None, split='val')
        self.shapes3d_test = Shapes3DSequences(self.data_path, self.images, self.labels, batch_size=self.batch_size, transform=None, split='test')

    def train_dataloader(self):
        # batch_size=None: the dataset itself already returns whole batches of sequence frames
        return DataLoader(self.shapes3d_full, batch_size=None, batch_sampler=None, num_workers=8, pin_memory=True, shuffle=True)

def val_dataloader(self):
return DataLoader(self.shapes3d_val, batch_size=None, batch_sampler=None, num_workers=8, pin_memory=True)

def test_dataloader(self):
return DataLoader(self.shapes3d_test, batch_size=None, batch_sampler=None, num_workers=8)
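
A minimal usage sketch, assuming the preprocessed `3dshapes.h5` and trajectory files are already in place and that the module is importable as `datasets.shapes3d_datamodule` (the import path is an assumption):

```python
import os

from datasets.shapes3d_datamodule import Shapes3DDataModule  # assumed import path

dm = Shapes3DDataModule(os.path.expanduser("~/data/datasets/shapes3d"), batch_size=1024)
dm.setup()

# each item is ((frames at t, frames at t+1, frames at t), labels at t)
(x_t, x_next, _), y = next(iter(dm.train_dataloader()))
print(x_t.shape, y.shape)  # torch.Size([1024, 3, 64, 64]) torch.Size([1024])
```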