v0.7.0 changes -- no more explicit residual layers
tysam-code committed Nov 7, 2023
1 parent ff53cac commit ad103b4
Showing 2 changed files with 60 additions and 61 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -26,7 +26,7 @@ Goals:
* torch- and python-idiomatic
* hackable
* few external dependencies (currently only torch and torchvision)
* ~world-record single-GPU training time (this repo holds the current world record at ~<7 (!!!) seconds on an A100, down from ~18.1 seconds originally).
* ~world-record single-GPU training time (this repo holds the current world record at ~<6.3 (!!!) seconds on an A100, down from ~18.1 seconds originally).
* <2 seconds training time in <2 years (yep!)

This is an implementation of a very speedily-training neural network that originally started as a painstaking reproduction of [David Page's original ultra-fast CIFAR-10 implementation on a single GPU](https://myrtle.ai/learn/how-to-train-your-resnet/), but was rewritten nearly from the ground up to be extremely rapid-experimentation-friendly. Part of the benefit of this is that we now hold the world record for single-GPU training speeds on CIFAR10, for example.
@@ -39,6 +39,9 @@ What we've added:
* dirac initializations on non-depth-transitional layers (information passthrough on init)
* and more!

What we've removed:
* explicit residual layers. yep.
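
This goes hand in hand with the dirac initializations noted above: a dirac-initialized convolution already passes its input straight through at init, so the layer itself behaves like a residual connection without an explicit skip. A minimal sketch of that property (assuming PyTorch; the layer below is illustrative, not taken from `main.py`):

```python
import torch
import torch.nn as nn

# A square conv initialized with dirac_ copies each input channel to the matching
# output channel, so at init the layer is numerically an identity map.
conv = nn.Conv2d(16, 16, kernel_size=3, padding='same', bias=False)
nn.init.dirac_(conv.weight)

x = torch.randn(2, 16, 8, 8)
print(torch.allclose(conv(x), x))  # True: information passes through untouched at init
```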

This code, in comparison to David's original code, is in a single file and extremely flat, but is not as durable for long-term production-level bug maintenance. You're meant to check out a fresh repo whenever you have a new idea. It is excellent for rapid idea exploring -- almost everywhere in the pipeline is exposed and built to be user-friendly. I truly enjoy personally using this code, and hope you do as well! :D Please let me know if you have any feedback. I hope to continue publishing updates to this in the future, so your support is encouraged. Share this repo with someone you know that might like it!

Feel free to check out my [Patreon](https://www.patreon.com/user/posts?u=83632131) if you like what I'm doing here and want more! Additionally, if you want me to work up to a part-time amount of hours with you, feel free to reach out to me at [email protected]. I'd love to hear from you.
116 changes: 56 additions & 60 deletions main.py
@@ -43,25 +43,24 @@
default_conv_kwargs = {'kernel_size': 3, 'padding': 'same', 'bias': False}

batchsize = 1024
bias_scaler = 56
# To replicate the ~95.78%-accuracy-in-113-seconds runs, you can change the base_depth from 64->128, train_epochs from 12.1->85, ['ema'] epochs 10->75, cutmix_size 3->9, and cutmix_epochs 6->75
bias_scaler = 64
# To replicate the ~95.79%-accuracy-in-110-seconds runs, you can change the base_depth from 64->128, train_epochs from 12.1->90, ['ema'] epochs 10->80, cutmix_size 3->10, and cutmix_epochs 6->80
hyp = {
'opt': {
'bias_lr': 1.64 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :'))))
'non_bias_lr': 1.64 / 512,
'bias_decay': 1.08 * 6.45e-4 * batchsize/bias_scaler,
'non_bias_decay': 1.08 * 6.45e-4 * batchsize,
'bias_lr': 1.525 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :'))))
'non_bias_lr': 1.525 / 512,
'bias_decay': 6.687e-4 * batchsize/bias_scaler,
'non_bias_decay': 6.687e-4 * batchsize,
'scaling_factor': 1./9,
'percent_start': .23,
'loss_scale_scaler': 1./128, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
'loss_scale_scaler': 1./32, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
},
'net': {
'whitening': {
'kernel_size': 2,
'num_examples': 50000,
},
'batch_norm_momentum': .5, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( )
'conv_norm_pow': 2.6,
'batch_norm_momentum': .4, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( )
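# (If this momentum is handed straight through to torch's BatchNorm, the running stats update as
#  new = (1 - momentum) * old + momentum * batch_stat, so .4 here corresponds to a conventional momentum/decay of .6.)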
'cutmix_size': 3,
'cutmix_epochs': 6,
'pad_amount': 2,
@@ -162,42 +161,34 @@ def __init__(self, num_features, eps=1e-12, momentum=hyp['net']['batch_norm_mome
# Having an outer class like this does add space and complexity but offers us
# a ton of freedom when it comes to hacking in unique functionality for each layer type
class Conv(nn.Conv2d):
def __init__(self, *args, norm=False, **kwargs):
def __init__(self, *args, **kwargs):
kwargs = {**default_conv_kwargs, **kwargs}
super().__init__(*args, **kwargs)
self.kwargs = kwargs
self.norm = norm

def forward(self, x):
if self.training and self.norm:
# TODO: Do/should we always normalize along dimension 1 of the weight vector(s), or the height x width dims too?
with torch.no_grad():
F.normalize(self.weight.data, p=self.norm)
return super().forward(x)

class Linear(nn.Linear):
def __init__(self, *args, norm=False, **kwargs):
def __init__(self, *args, temperature=None, **kwargs):
super().__init__(*args, **kwargs)
self.kwargs = kwargs
self.norm = norm
self.temperature = temperature

def forward(self, x):
if self.training and self.norm:
# TODO: Normalize on dim 1 or dim 0 for this guy?
with torch.no_grad():
F.normalize(self.weight.data, p=self.norm)
return super().forward(x)
if self.temperature is not None:
weight = self.weight * self.temperature
else:
weight = self.weight
return x @ weight.T
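# Note: folding the temperature into the weight like this is equivalent to the old standalone
# TemperatureScaler applied after the matmul, since x @ (w * t).T == (x @ w.T) * t up to float
# rounding. A quick (illustrative) sanity check:
#     w, x, t = torch.randn(10, 512), torch.randn(4, 512), hyp['opt']['scaling_factor']
#     assert torch.allclose(x @ (w * t).T, (x @ w.T) * t)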

# can hack any changes to each residual group that you want directly in here
# can hack any changes to each convolution group that you want directly in here
class ConvGroup(nn.Module):
def __init__(self, channels_in, channels_out, norm):
def __init__(self, channels_in, channels_out):
super().__init__()
self.channels_in = channels_in
self.channels_in = channels_in
self.channels_out = channels_out

self.pool1 = nn.MaxPool2d(2)
self.conv1 = Conv(channels_in, channels_out, norm=norm)
self.conv2 = Conv(channels_out, channels_out, norm=norm)
self.conv1 = Conv(channels_in, channels_out)
self.conv2 = Conv(channels_out, channels_out)

self.norm1 = BatchNorm(channels_out)
self.norm2 = BatchNorm(channels_out)
@@ -210,20 +201,11 @@ def forward(self, x):
x = self.pool1(x)
x = self.norm1(x)
x = self.activ(x)
residual = x
x = self.conv2(x)
x = self.norm2(x)
x = self.activ(x)
x = x + residual # haiku
return x

class TemperatureScaler(nn.Module):
def __init__(self, init_val):
super().__init__()
self.scaler = torch.tensor(init_val)

def forward(self, x):
return x.mul(self.scaler)
return x

class FastGlobalMaxPooling(nn.Module):
def __init__(self):
@@ -275,7 +257,7 @@ def init_whitening_conv(layer, train_set=None, num_examples=None, previous_block
eigenvalue_list.append(eigenvalues)
eigenvector_list.append(eigenvectors)

eigenvalues = torch.stack(eigenvalue_list, dim=0).mean(0)
eigenvalues = torch.stack(eigenvalue_list, dim=0).mean(0)
eigenvectors = torch.stack(eigenvector_list, dim=0).mean(0)
# i believe the eigenvalues and eigenvectors come out in float32 for this because we implicitly cast it to float32 in the patches function (for numerical stability)
set_whitening_conv(layer, eigenvalues.to(dtype=layer.weight.dtype), eigenvectors.to(dtype=layer.weight.dtype), freeze=freeze)
@@ -284,7 +266,8 @@ def init_whitening_conv(layer, train_set=None, num_examples=None, previous_block

def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2, freeze=True):
shape = conv_layer.weight.data.shape
conv_layer.weight.data[-eigenvectors.shape[0]:, :, :, :] = (eigenvectors/torch.sqrt(eigenvalues+eps))[-shape[0]:, :, :, :] # set the first n filters of the weight data to the top n significant (sorted by importance) filters from the eigenvectors
eigenvectors_sliced = (eigenvectors/torch.sqrt(eigenvalues+eps))[-shape[0]:, :, :, :] # set the first n filters of the weight data to the top n significant (sorted by importance) filters from the eigenvectors
conv_layer.weight.data = torch.cat((eigenvectors_sliced, -eigenvectors_sliced), dim=0)
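# Note: concatenating the negated copy doubles the number of whitening filters (both signs of each
# eigenvector), which is presumably why the first conv group below takes 2*whiten_conv_depth input
# channels -- the GELU after the whitening conv mostly suppresses the negative half of each response,
# so keeping both signs preserves the information.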
## We don't want to train this, since this is implicitly whitening over the whole dataset
## For more info, see David Page's original blogposts (link in the README.md as of this commit.)
if freeze:
@@ -304,7 +287,7 @@ def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2, freeze=T
'num_classes': 10
}

class SpeedyResNet(nn.Module):
class SpeedyConvNet(nn.Module):
def __init__(self, network_dict):
super().__init__()
self.net_dict = network_dict # flexible, defined in the make_net function
@@ -314,14 +297,12 @@ def forward(self, x):
if not self.training:
x = torch.cat((x, torch.flip(x, (-1,))))
x = self.net_dict['initial_block']['whiten'](x)
x = self.net_dict['initial_block']['project'](x)
x = self.net_dict['initial_block']['activation'](x)
x = self.net_dict['residual1'](x)
x = self.net_dict['residual2'](x)
x = self.net_dict['residual3'](x)
x = self.net_dict['conv_group_1'](x)
x = self.net_dict['conv_group_2'](x)
x = self.net_dict['conv_group_3'](x)
x = self.net_dict['pooling'](x)
x = self.net_dict['linear'](x)
x = self.net_dict['temperature'](x)
if not self.training:
# Average the predictions from the lr-flipped inputs during eval
orig, flipped = x.split(x.shape[0]//2, dim=0)
@@ -335,18 +316,16 @@ def make_net():
network_dict = nn.ModuleDict({
'initial_block': nn.ModuleDict({
'whiten': Conv(3, whiten_conv_depth, kernel_size=hyp['net']['whitening']['kernel_size'], padding=0),
'project': Conv(whiten_conv_depth, depths['init'], kernel_size=1, norm=2.2), # The norm argument means we renormalize the weights to be length 1 each step, using this value as the power for the norm
'activation': nn.GELU(),
}),
'residual1': ConvGroup(depths['init'], depths['block1'], hyp['net']['conv_norm_pow']),
'residual2': ConvGroup(depths['block1'], depths['block2'], hyp['net']['conv_norm_pow']),
'residual3': ConvGroup(depths['block2'], depths['block3'], hyp['net']['conv_norm_pow']),
'conv_group_1': ConvGroup(2*whiten_conv_depth, depths['block1']),
'conv_group_2': ConvGroup(depths['block1'], depths['block2']),
'conv_group_3': ConvGroup(depths['block2'], depths['block3']),
'pooling': FastGlobalMaxPooling(),
'linear': Linear(depths['block3'], depths['num_classes'], bias=False, norm=5.),
'temperature': TemperatureScaler(hyp['opt']['scaling_factor'])
'linear': Linear(depths['block3'], depths['num_classes'], bias=False, temperature=hyp['opt']['scaling_factor']),
})

net = SpeedyResNet(network_dict)
net = SpeedyConvNet(network_dict)
net = net.to(hyp['misc']['device'])
net = net.to(memory_format=torch.channels_last) # to appropriately use tensor cores/avoid thrash while training
net.train()
@@ -365,18 +344,35 @@ def make_net():
## the index lookup in the dataloader may give you some trouble depending
## upon exactly how memory-limited you are

## We initialize the projection layer to return exactly the spatial inputs, so that we start
## at a nice clean place (the whitened image in feature space, directly) and can iterate directly from there.
torch.nn.init.dirac_(net.net_dict['initial_block']['project'].weight)

for layer_name in net.net_dict.keys():
if 'residual' in layer_name:
## We do the same for the second layer in each residual block, since this only
if 'conv_group' in layer_name:
# Create an implicit residual via a dirac-initialized tensor
dirac_weights_in = torch.nn.init.dirac_(torch.empty_like(net.net_dict[layer_name].conv1.weight))

# Add the implicit residual to the already-initialized convolutional transition layer.
# One can use more sophisticated initializations, but this one appeared to work best in testing.
# What this does is bring up the features from the previous block virtually, so that not only
# do we have residual information flow within each block, we also have a nearly direct connection from
# the early layers of the network to the loss function.
std_pre, mean_pre = torch.std_mean(net.net_dict[layer_name].conv1.weight.data)
net.net_dict[layer_name].conv1.weight.data = net.net_dict[layer_name].conv1.weight.data + dirac_weights_in
std_post, mean_post = torch.std_mean(net.net_dict[layer_name].conv1.weight.data)

# Renormalize the weights to match the original initialization statistics
net.net_dict[layer_name].conv1.weight.data.sub_(mean_post).div_(std_post).mul_(std_pre).add_(mean_pre)

## We do the same for the second layer in each convolution group, since this only
## adds a simple multiplier to the inputs instead of the noise of a randomly-initialized
## convolution. This can be easily scaled down by the network, and the weights can more easily
## pivot in whichever direction they need to go now.
## The reason I believe this works so well is a combination of MaxPool2d and the nn.GELU
## function's positive bias encouraging values towards the nearly-linear region of the GELU
## activation function at network initialization. I am not currently sure about this, however;
## it will require some more investigation. For now -- it works! D:
torch.nn.init.dirac_(net.net_dict[layer_name].conv2.weight)


return net

#############################################
