
Universal loader

Maksim Nikolaev edited this page Aug 29, 2021 · 1 revision

This page briefly describes how the universal model loader works and how to add your own model with non-standard inputs or loss functions.

The start.py script initializes the Trainer class, in which training takes place. The file trainer.py contains two implemented classes for different types of GAN training; the base version requires the following mandatory parameters:

conf: dictionary of training settings, loaded from the input JSON file
**kwargs: the dictionary must contain:
    "G"           # initialized generator
    "D"           # initialized discriminator
    "start_epoch" # epoch for continuing training
    "dataloader"  # dataset loader
    "optim_G"     # generator optimizer
    "optim_D"     # discriminator optimizer
    "gen_loss"    # generator loss function
    "disc_loss"   # discriminator loss function
    "z_dim"       # dimension of the generator's input (latent) vector
    "device"      # device used for training
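As a minimal sketch of this contract, a trainer could check its keyword arguments up front and store each one as an attribute. The class name `TrainerSketch` and the error message are hypothetical, and plain strings stand in for the real models, optimizers and loaders; the attribute names match the ones the WGAN-GP example on this page reads (`self.conf`, `self.D`, `self.disc_loss`).

```python
REQUIRED_KWARGS = (
    "G", "D", "start_epoch", "dataloader",
    "optim_G", "optim_D", "gen_loss", "disc_loss",
    "z_dim", "device",
)


class TrainerSketch:
    """Hypothetical stand-in for the base trainer's constructor."""

    def __init__(self, conf, **kwargs):
        missing = [key for key in REQUIRED_KWARGS if key not in kwargs]
        if missing:
            raise KeyError(f"missing trainer arguments: {missing}")
        self.conf = conf
        # Expose each argument as an attribute, e.g. self.G or self.optim_D
        for key in REQUIRED_KWARGS:
            setattr(self, key, kwargs[key])


# Placeholder values; in start.py these would be real torch objects
kwargs = {key: f"<{key}>" for key in REQUIRED_KWARGS}
kwargs["start_epoch"] = 0
kwargs["z_dim"] = 100
trainer = TrainerSketch({"epochs": 10}, **kwargs)
```

Failing fast on a missing key gives a clearer error than an `AttributeError` deep inside `train_loop`.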

For details on what the JSON file should contain, see the page on configuring the config.

Now let's see how to add your model to this loader, starting with the relevant part of start.py:

# Loading all models: merge the registries from every module in models/
generators = {}
discriminators = {}
for name_model in get_py_modules('models'):
    model = dynamic_import(f'models.{name_model}')
    generators = {**generators, **model.generators}
    discriminators = {**discriminators, **model.discriminators}
# The names requested in the config must have been registered by some module
assert conf["Generator"] in generators.keys()
assert conf["Discriminator"] in discriminators.keys()

In other words, the model named in the JSON file is resolved as follows: the script traverses all Python modules in the models folder and merges two dictionaries from each of them (so every module there must define both), which have the following form:

generators[NameGAN] = class NameGANgenerator
discriminators[NameGAN] = class NameGANdiscriminator
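The helpers `get_py_modules` and `dynamic_import` belong to the project and their code is not shown on this page. A minimal sketch of what they might do, assuming the loader runs from the project root so that `models` is both a folder and an importable package, uses `pathlib` and `importlib`; these bodies and the `collect_registries` wrapper are hypothetical.

```python
import importlib
from pathlib import Path


def get_py_modules(package_dir):
    """Hypothetical helper: names of the .py modules inside a package folder."""
    return sorted(
        path.stem
        for path in Path(package_dir).glob("*.py")
        if path.stem != "__init__"
    )


def dynamic_import(module_path):
    """Hypothetical helper: import a module from its dotted path."""
    return importlib.import_module(module_path)


def collect_registries(package):
    """Mirror of the start.py loop: merge every module's two dictionaries."""
    generators, discriminators = {}, {}
    for name in get_py_modules(package):
        module = dynamic_import(f"{package}.{name}")
        # If two modules register the same name, the later one wins
        generators = {**generators, **module.generators}
        discriminators = {**discriminators, **module.discriminators}
    return generators, discriminators
```

Because the merge silently overwrites duplicate keys, two model files that register the same name will shadow each other.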

Thus, to register your model semi-automatically, it is enough to put the file with your model in the models folder and add a few lines to it:

from torch import nn
from utils import register

generators = register.ClassRegistry()
discriminators = register.ClassRegistry()

@generators.add_to_registry("NameGAN")
class Generator(nn.Module):
    ...

@discriminators.add_to_registry("NameGAN")
class Discriminator(nn.Module):
    ...

The add_to_registry decorator, defined in utils.register, adds the decorated class to the registry dictionary under the given name.
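The actual code of `utils.register` is not shown here. A minimal sketch of such a registry, assuming it behaves like a dictionary (which the `{**generators, **model.generators}` merge and `.keys()` call in start.py require), subclasses `dict` and adds a decorator factory. Rejecting duplicate names is a design choice of this sketch, not necessarily of the project.

```python
class ClassRegistry(dict):
    """Hypothetical registry: a dict mapping names to classes or functions."""

    def add_to_registry(self, name):
        def decorator(obj):
            if name in self:
                raise KeyError(f"{name!r} is already registered")
            self[name] = obj
            # Return the object unchanged so it can still be used directly
            return obj
        return decorator


generators = ClassRegistry()


@generators.add_to_registry("NameGAN")
class Generator:
    pass
```

Since the decorator returns the class itself, `Generator` stays importable under its own name, and the keyword form `add_to_registry(name="gp")` used for trainers works with the same signature.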

Loss functions can be added in exactly the same way; they just go in the file losses.py instead.
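As a hedged illustration, assuming losses.py keeps one registry for generator losses and one for discriminator losses (the names `gen_losses` and `disc_losses` are hypothetical), the standard WGAN losses could be registered like this. A small dict-based registry is repeated so the block runs on its own, plain lists of critic scores stand in for the torch tensors the real code uses, and the loss signatures are assumptions rather than the project's actual ones.

```python
class ClassRegistry(dict):
    """Minimal stand-in for utils.register.ClassRegistry (assumed dict-like)."""

    def add_to_registry(self, name):
        def decorator(obj):
            self[name] = obj
            return obj
        return decorator


# Hypothetical registry names; the real ones in losses.py may differ
gen_losses = ClassRegistry()
disc_losses = ClassRegistry()


def mean(values):
    return sum(values) / len(values)


@gen_losses.add_to_registry("wgan")
def wgan_gen_loss(fake_scores):
    # The generator tries to raise the critic's score on fake images
    return -mean(fake_scores)


@disc_losses.add_to_registry("wgan")
def wgan_disc_loss(real_scores, fake_scores):
    # The critic tries to score real images above fake ones
    return mean(fake_scores) - mean(real_scores)


# Select losses by name, as the loader does for models
gen_loss = gen_losses["wgan"]
disc_loss = disc_losses["wgan"]
```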

Now let's consider non-standard loss functions or models. The base version of Trainer implements a universal train_loop, which relies on several methods:

def logger(self, data): ...                      # logs training progress

def save_model(self, epoch): ...                 # saves model weights

def generate_images(self, cnt=1): ...            # returns cnt batches of generated images

def train_disc(self, real_imgs, fake_imgs): ...  # trains the discriminator on real and fake images

def train_gen(self, fake_imgs): ...              # trains the generator on the generated images

def train_loop(self): ...                        # main training loop built from the methods above
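The real body of `train_loop` is not shown on this page. A minimal sketch of how such a loop can be assembled from these hooks (a template method) follows; the class names, the order of calls and the logged fields are assumptions, and strings replace real image batches.

```python
class BaseTrainerSketch:
    """Hypothetical skeleton of a train_loop built from overridable hooks."""

    def __init__(self, dataloader, epochs):
        self.dataloader = dataloader
        self.epochs = epochs
        self.history = []

    def logger(self, data):
        self.history.append(data)

    def save_model(self, epoch):
        pass  # the real trainer would write weights to disk here

    def generate_images(self, cnt=1):
        return [f"fake_batch_{i}" for i in range(cnt)]

    def train_disc(self, real_imgs, fake_imgs):
        return 0.0  # placeholder discriminator loss

    def train_gen(self, fake_imgs):
        return 0.0  # placeholder generator loss

    def train_loop(self):
        for epoch in range(self.epochs):
            for real_imgs in self.dataloader:
                fake_imgs = self.generate_images()
                d_loss = self.train_disc(real_imgs, fake_imgs)
                g_loss = self.train_gen(fake_imgs)
                self.logger({"epoch": epoch, "d_loss": d_loss, "g_loss": g_loss})
            self.save_model(epoch)


class CustomDiscTrainer(BaseTrainerSketch):
    # Only the discriminator step changes; train_loop is inherited unchanged
    def train_disc(self, real_imgs, fake_imgs):
        return 1.0
```

Because the loop only calls the hooks, a subclass can change one training step without copying the loop itself.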

So if a loss function needs more than real and fake images, for example extra parameters or the discriminator itself as an input (as in WGAN-GP), it is enough to inherit from the base Trainer and override the necessary methods. For WGAN-GP:

@trainers.add_to_registry(name="gp")
class GpGANTrainer(BaseGANTrainer):
    def __init__(self, conf, **kwargs):
        super().__init__(conf, **kwargs)

    def train_disc(self, real_imgs, fake_imgs):
        # The gradient penalty needs the discriminator itself, so self.D is
        # passed to the loss together with the penalty weight from the config
        lambda_gp = self.conf["Loss_config"]["lambda_gp"]
        return self.disc_loss(self.D, real_imgs, fake_imgs, lambda_gp)