From 01603a8a9605552d1766311375122966136f338f Mon Sep 17 00:00:00 2001
From: TySam& B
Date: Sun, 15 Jan 2023 00:29:49 -0500
Subject: [PATCH] Committing changes -- memory format and max pooling for a
 change from ~18.1 to ~12.31-12.38s, a new world record! :fireworks:
 :fireworks: :penguin: :fireworks:

---
 main.py          | 106 ++++++++++++++++++++---------------------
 requirements.txt |   3 --
 2 files changed, 44 insertions(+), 65 deletions(-)

diff --git a/main.py b/main.py
index dfa5a1e..20423c8 100644
--- a/main.py
+++ b/main.py
@@ -41,10 +41,10 @@
 bias_scaler = 64
 hyp = {
     'opt': {
-        'bias_lr': 1. * bias_scaler/batchsize, # TODO: How we're expressing this information feels somewhat clunky, is there maybe a better way to do this? :'))))
-        'non_bias_lr': 1. / batchsize,
-        'bias_decay': 5e-4 * batchsize/bias_scaler,
-        'non_bias_decay': 5e-4 * batchsize,
+        'bias_lr': 1.35 * 1. * bias_scaler/batchsize, # TODO: How we're expressing this information feels somewhat clunky, is there maybe a better way to do this? :'))))
+        'non_bias_lr': 1.35 * 1. / batchsize,
+        'bias_decay': 4.8e-4 * batchsize/bias_scaler,
+        'non_bias_decay': 4.8e-4 * batchsize,
         'scaling_factor': 1./16,
         'percent_start': .2,
     },
@@ -53,15 +53,15 @@
             'kernel_size': 3,
             'num_examples': 10000,
         },
-        'ghost_norm_group_size': 64, ## Regularization
+        'batch_norm_momentum': .4,
         'cutout_size': 0,
         'pad_amount': 4,
     },
     'misc': {
         'ema': {
-            'epochs': 2,
-            'decay_base': .99,
-            'every_n_steps': 5,
+            'epochs': 3,
+            'decay_base': .987,
+            'every_n_steps': 2,
         },
         'train_epochs': 10,
         'device': 'cuda',
@@ -124,6 +124,7 @@ def batch_normalize_images(input_images, mean, std):
     ## hyp dictionary, then we should be good. :)
     data = torch.load(hyp['misc']['data_location'])
+
     ## As you'll note above and below, one difference is that we don't count loading the raw data to GPU since it's such a variable operation, and can sort of get in the way
     ## of measuring other things. That said, measuring the preprocessing (outside of the padding) is still important to us.
@@ -136,40 +137,15 @@ def batch_normalize_images(input_images, mean, std):
 #            Network Components             #
 #############################################
 
-# We might be able to fuse this weight and save some memory/runtime/etc, since the fast version of the network doesn't need it I thinks...
+# We might be able to fuse this weight and save some memory/runtime/etc, since the fast version of the network might be able to do without somehow....
 class BatchNorm(nn.BatchNorm2d):
-    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight=False, bias=True):
+    def __init__(self, num_features, eps=1e-12, momentum=hyp['net']['batch_norm_momentum'], weight=False, bias=True):
         super().__init__(num_features, eps=eps, momentum=momentum)
         self.weight.data.fill_(1.0)
         self.bias.data.fill_(0.0)
         self.weight.requires_grad = weight
         self.bias.requires_grad = bias
 
-class GhostNorm(BatchNorm):
-    def __init__(self, num_features, num_splits=batchsize//hyp['net']['ghost_norm_group_size'], **kw):
-        super().__init__(num_features, **kw)
-        self.num_splits = num_splits
-        self.register_buffer('running_mean', torch.zeros(num_features*self.num_splits))
-        self.register_buffer('running_var', torch.ones(num_features*self.num_splits))
-
-    def train(self, mode=True):
-        if (self.training is True) and (mode is False): #lazily collate stats when we are going to use them, i.e., when we switch from the 'train' to 'eval' modes
-            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0).repeat(self.num_splits)
-            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0).repeat(self.num_splits)
-        return super().train(mode)
-
-    def forward(self, input):
-        N, C, H, W = input.shape
-        if self.training or not self.track_running_stats:
-            return torch.nn.functional.batch_norm(
-                input.view(-1, C*self.num_splits, H, W), self.running_mean, self.running_var,
-                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
-                True, self.momentum, self.eps).view(N, C, H, W)
-        else:
-            return torch.nn.functional.batch_norm(
-                input, self.running_mean[:self.num_features], self.running_var[:self.num_features],
-                self.weight, self.bias, False, self.momentum, self.eps)
-
 # Allows us to set default arguments for the whole convolution itself.
 class Conv(nn.Conv2d):
     def __init__(self, *args, **kwargs):
@@ -177,7 +153,6 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.kwargs = kwargs
 
-
 # can hack any changes to each residual group that you want directly in here
 class ConvGroup(nn.Module):
     def __init__(self, channels_in, channels_out, residual, short, pool):
@@ -191,16 +166,15 @@ def __init__(self, channels_in, channels_out, residual, short, pool):
 
         self.conv1 = Conv(channels_in, channels_out)
         self.pool1 = nn.MaxPool2d(2)
-        self.norm1 = GhostNorm(channels_out)
+        self.norm1 = BatchNorm(channels_out)
         self.activ = nn.CELU(alpha=.3) # note: this has to be flat if we're jitting things.... we just might burn a bit of extra GPU mem if so
 
         if not short:
             self.conv2 = Conv(channels_out, channels_out)
             self.conv3 = Conv(channels_out, channels_out)
-            self.norm2 = GhostNorm(channels_out)
-            self.norm3 = GhostNorm(channels_out)
-
+            self.norm2 = BatchNorm(channels_out)
+            self.norm3 = BatchNorm(channels_out)
 
     def forward(self, x):
         x = self.conv1(x)
@@ -234,6 +208,15 @@ def forward(self, x):
         ## my implementation, and David's implementation...
         return x.mul(self.scaler)
 
+class FastGlobalMaxPooling(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        # Previously was chained torch.max calls.
+        # requires less time than nn.AdaptiveMaxPool2d -- about ~.3s for the entire run, in fact (which is pretty significant! :O :D :O :O <3 <3 <3 <3)
+        return torch.amax(x, dim=(2,3)) # Global maximum pooling
+
 #############################################
 #          Init Helper Functions            #
 #############################################
@@ -288,7 +271,7 @@ def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2):
 class SpeedyResNet(nn.Module):
     def __init__(self, network_dict):
         super().__init__()
-        self.net_dict = network_dict # flexible, defined in the make_network function
+        self.net_dict = network_dict # flexible, defined in the make_net function
 
     # This allows you to customize/change the execution order of the network as needed.
     def forward(self, x):
@@ -302,7 +285,6 @@ def forward(self, x):
         x = self.net_dict['residual2'](x)
         x = self.net_dict['residual3'](x)
         x = self.net_dict['pooling'](x)
-        x = self.net_dict['reshape'](x)
         x = self.net_dict['linear'](x)
         x = self.net_dict['temperature'](x)
         if not self.training:
@@ -319,26 +301,26 @@ def make_net():
         'initial_block': nn.ModuleDict({
             'whiten': Conv(3, whiten_conv_depth, kernel_size=hyp['net']['whitening']['kernel_size']),
             'project': Conv(whiten_conv_depth, depths['init'], kernel_size=1),
-            'norm': GhostNorm(depths['init'], weight=False),
+            'norm': BatchNorm(depths['init'], weight=False),
             'activation': nn.CELU(alpha=.3),
         }),
         'residual1': ConvGroup(depths['init'], depths['block1'], residual=True, short=False, pool=True),
         'residual2': ConvGroup(depths['block1'], depths['block2'], residual=True, short=True, pool=True),
         'residual3': ConvGroup(depths['block2'], depths['block3'], residual=True, short=False, pool=True),
-        'pooling': nn.AdaptiveMaxPool2d((1, 1)),
-        'reshape': nn.Flatten(),
+        'pooling': FastGlobalMaxPooling(),
         'linear': nn.Linear(depths['block3'], depths['num_classes'], bias=False),
         'temperature': TemperatureScaler(hyp['opt']['scaling_factor'])
     })
 
     net = SpeedyResNet(network_dict)
     net = net.to(hyp['misc']['device'])
+    net = net.to(memory_format=torch.channels_last) # to appropriately use tensor cores/avoid thrash while training
     net.train()
     net.half() # Convert network to half before initializing the initial whitening layer.
 
     ## Initialize the whitening convolution
     with torch.no_grad():
-        # Initialize the first layer to be fixed weights that whiten the expected input values of the network be on the unit hypersphere. (i.e. their vector length is 1., IIRC)
+        # Initialize the first layer to be fixed weights that whiten the expected input values of the network so that they lie on the unit hypersphere. (i.e. their average vector length is roughly 1., IIRC)
         init_whitening_conv(net.net_dict['initial_block']['whiten'],
                             data['train']['images'].index_select(0, torch.randperm(data['train']['images'].shape[0], device=data['train']['images'].device)),
                             num_examples=hyp['net']['whitening']['num_examples'],
@@ -392,7 +374,7 @@ def batch_crop(inputs, crop_size):
 
 def batch_flip_lr(batch_images, flip_chance=.5):
     with torch.no_grad():
-        # TODO: More elegant way to do this? :') :'((((
+        # TODO: Is there a more elegant way to do this? :') :'((((
         return torch.where(torch.rand_like(batch_images[:, 0, 0, 0].view(-1, 1, 1, 1)) < flip_chance, torch.flip(batch_images, (-1,)), batch_images)
 
@@ -416,7 +398,8 @@ def forward(self, inputs):
         with torch.no_grad():
             return self.net_ema(inputs)
 
-# TODO: Can we jit this in the (more distant) future? :)
+# TODO: Could we jit this in the (more distant) future? :)
+@torch.no_grad()
 def get_batches(data_dict, key, batchsize):
     num_epoch_examples = len(data_dict[key]['images'])
     shuffled = torch.randperm(num_epoch_examples, device='cuda')
@@ -431,6 +414,8 @@ def get_batches(data_dict, key, batchsize):
     else:
         images = data_dict[key]['images']
 
+    # Send the images to the (still in beta) channels_last memory format to help improve tensor core occupancy (and reduce NCHW <-> NHWC thrash) during training
+    images = images.to(memory_format=torch.channels_last)
     for idx in range(num_epoch_examples // batchsize):
         if not (idx+1)*batchsize > num_epoch_examples: ## Use the shuffled randperm to assemble individual items into a minibatch
             yield images.index_select(0, shuffled[idx*batchsize:(idx+1)*batchsize]), \
@@ -472,13 +457,12 @@ def print_training_details(columns_list, separator_left='| ', separator_right='
     if is_final_entry:
         print('-'*(len(print_string))) # print the final output bar
 
-print_training_details(logging_columns_list, column_heads_only=True) # print out the training column heads.
+print_training_details(logging_columns_list, column_heads_only=True) ## print out the training column heads before we print the actual content for each run.
 
 ########################################
 #           Train and Eval             #
 ########################################
 
-# to do cast to fp16 precision for training
 def main():
     # Initializing constants for the whole run.
     net_ema = None ## Reset any existing network emas, we want to have _something_ to check for existence so we can initialize the EMA right from where the network is during training
@@ -488,13 +472,14 @@ def main():
     current_steps = 0.
 
     # TODO: Doesn't currently account for partial epochs really (since we're not doing "real" epochs across the whole batchsize)....
-    num_steps_per_epoch = len(data['train']['images']) // batchsize # todo: a bit of a tad of cleanup here. ::::))) :>>>
+    num_steps_per_epoch = len(data['train']['images']) // batchsize
     total_train_steps = num_steps_per_epoch * hyp['misc']['train_epochs']
     ema_epoch_start = hyp['misc']['train_epochs'] - hyp['misc']['ema']['epochs']
     num_low_lr_steps_for_ema = hyp['misc']['ema']['epochs'] * num_steps_per_epoch
-    ## TODO: (check? <# :)))) ) I believe this wasn't logged, but the EMA update power is adjusted by being raised to the power of the number of "every n" steps
+
+    ## I believe this wasn't logged, but the EMA update power is adjusted by being raised to the power of the number of "every n" steps
     ## to somewhat accomodate for whatever the expected information intake rate is. The tradeoff I believe, though, is that this is to some degree noisier as we
-    ## are intaking fewer samples of our distribution-over-time, with a higher individual weight each.
+    ## are intaking fewer samples of our distribution-over-time, with a higher individual weight each. This can be good or bad depending upon what we want.
     projected_ema_decay_val = hyp['misc']['ema']['decay_base'] ** hyp['misc']['ema']['every_n_steps']
 
     # Adjust pct_start based upon how many epochs we need to finetune the ema at a low lr for
@@ -540,7 +525,7 @@
         for epoch_step, (inputs, targets) in enumerate(get_batches(data, key='train', batchsize=batchsize)):
             ## Run everything through the network
             outputs = net(inputs)
-
+            ## If you want to add other losses or hack around with the loss, you can do that here.
             loss = loss_fn(outputs, targets).sum()
             ## Note, as noted in the original blog posts, the summing here does a kind of loss scaling
@@ -561,11 +546,10 @@
             # We only want to step the lr_schedulers while we have training steps to consume. Otherwise we get a not-so-friendly error from PyTorch
             lr_sched.step()
             lr_sched_bias.step()
-
+            ## Using 'set_to_none' I believe is slightly faster (albeit riskier w/ funky gradient update workflows) than the default 'set to zero' method
             opt.zero_grad(set_to_none=True)
             opt_bias.zero_grad(set_to_none=True)
-
             current_steps += 1
 
             if epoch >= ema_epoch_start and current_steps % hyp['misc']['ema']['every_n_steps'] == 0:
@@ -573,7 +557,6 @@
                 if net_ema is None:
                     net_ema = NetworkEMA(net, decay=projected_ema_decay_val)
                 net_ema.update(net)
-
         ender.record()
         torch.cuda.synchronize()
         total_time_seconds += 1e-3 * starter.elapsed_time(ender)
@@ -588,11 +571,10 @@
             loss_list_val, acc_list, acc_list_ema = [], [], []
 
             with torch.no_grad():
-                # TODO: Copy is probably slow, we can def avoid this somehow, I think....
                 for inputs, targets in get_batches(data, key='eval', batchsize=eval_batchsize):
                     if epoch >= ema_epoch_start:
-                         outputs = net_ema(inputs)
-                         acc_list_ema.append((outputs.argmax(-1) == targets).float().mean())
+                        outputs = net_ema(inputs)
+                        acc_list_ema.append((outputs.argmax(-1) == targets).float().mean())
                     outputs = net(inputs)
                     loss_list_val.append(loss_fn(outputs, targets).float().mean())
                     acc_list.append((outputs.argmax(-1) == targets).float().mean())
@@ -610,12 +592,12 @@
             format_for_table = lambda x, locals: (f"{locals[x]}".rjust(len(x))) \
                 if type(locals[x]) == int else "{:0.4f}".format(locals[x]).rjust(len(x)) \
                 if locals[x] is not None \
-                else " "*len(x)
+                else " "*len(x)
 
             # Print out our training details (sorry for the complexity, the whole logging business here is a bit of a hot mess once the columns need to be aligned and such....)
             ## We also check to see if we're in our final epoch so we can print the 'bottom' of the table for each round.
             print_training_details(list(map(partial(format_for_table, locals=locals()), logging_columns_list)), is_final_entry=(epoch == hyp['misc']['train_epochs'] - 1))
 
 if __name__ == "__main__":
-    for run_num in range(5):
+    for run_num in range(25):
         main()
diff --git a/requirements.txt b/requirements.txt
index 17f2add..ac988bd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,2 @@
 torch
 torchvision
-numpy
-ipython
-rich
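
A few illustrative sketches of the techniques this patch leans on follow; none of this code is taken from the repository itself, and any names, shapes, or values not shown in the patch above are assumptions for illustration.

The FastGlobalMaxPooling module added above collapses the old nn.AdaptiveMaxPool2d((1, 1)) + nn.Flatten() head into a single torch.amax reduction over the spatial dimensions, which is why the 'reshape' stage disappears from the forward pass. A minimal sketch of the equivalence, with a hypothetical activation shape rather than the network's real one:

    import torch
    import torch.nn as nn

    x = torch.randn(512, 64, 8, 8)          # stand-in N, C, H, W activations

    old_head = nn.Sequential(nn.AdaptiveMaxPool2d((1, 1)), nn.Flatten())
    out_old = old_head(x)                    # (512, 64) via pool-to-1x1 then reshape

    out_new = torch.amax(x, dim=(2, 3))      # (512, 64) in one reduction, no reshape needed

    print(torch.equal(out_old, out_new))     # True -- same values, fewer ops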
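
Both the network (in make_net) and the input batches (in get_batches) are converted to the channels_last memory format so the half-precision convolutions can use NHWC-friendly tensor core kernels instead of converting layouts back and forth. A hedged, self-contained sketch of the same conversion on a stand-in model (assumes a CUDA device; the module and shapes are illustrative, not SpeedyResNet):

    import torch
    import torch.nn as nn

    device = 'cuda'
    net = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1), nn.ReLU()).to(device).half()
    net = net.to(memory_format=torch.channels_last)        # weights reordered NHWC-style

    images = torch.randn(512, 3, 32, 32, device=device, dtype=torch.half)
    images = images.to(memory_format=torch.channels_last)  # inputs match the weights, so no NCHW <-> NHWC thrash

    out = net(images)
    print(out.is_contiguous(memory_format=torch.channels_last))  # True: the layout is kept end to end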
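
projected_ema_decay_val raises decay_base to the power of every_n_steps, so updating the EMA only every couple of steps keeps roughly the same overall decay schedule as updating on every step, just with fewer, more heavily weighted samples. Illustrative arithmetic only, using the new hyperparameters from the patch:

    decay_base = .987     # hyp['misc']['ema']['decay_base']
    every_n_steps = 2     # hyp['misc']['ema']['every_n_steps']

    projected = decay_base ** every_n_steps  # decay actually applied at each EMA update

    # Over a window of, say, 10 training steps, the weight left on the old EMA state is the
    # same whether we decay a little every step or more strongly every second step:
    print(decay_base ** 10)                     # ~0.877
    print(projected ** (10 // every_n_steps))   # ~0.877, identical by construction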
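
The per-batchsize learning rates in the hyp dict appear to pair with the summed (rather than averaged) loss in the training loop: summing multiplies the gradients by the batch size, and dividing the learning rate by batchsize undoes that scale. A small sketch of that relationship using a generic cross-entropy loss as a stand-in (the batch size and class count here are assumptions for illustration, not necessarily the repo's values):

    import torch
    import torch.nn as nn

    batchsize = 512
    logits = torch.randn(batchsize, 10, requires_grad=True)
    targets = torch.randint(0, 10, (batchsize,))
    loss_fn = nn.CrossEntropyLoss(reduction='none')

    loss_fn(logits, targets).sum().backward()
    grad_from_sum = logits.grad.clone()
    logits.grad = None

    loss_fn(logits, targets).mean().backward()
    grad_from_mean = logits.grad.clone()

    # Summed loss scales gradients by batchsize, so an lr of 1.35 / batchsize with .sum()
    # behaves like an lr of 1.35 with .mean().
    print(torch.allclose(grad_from_sum, batchsize * grad_from_mean))  # True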